| from transformers import AutoModelForImageClassification | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
# Preprocessing applied to every input image: resize to 224x224, then convert
# the PIL image to a tensor.
# NOTE(review): no mean/std normalization is applied here -- confirm this
# matches the preprocessing the checkpoint was trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Load the classifier stored in the current working directory and put it in
# evaluation mode (nn.Module.eval() returns the module itself).
model = AutoModelForImageClassification.from_pretrained(".").eval()
def predict(img_path):
    """Classify a single image and return the top class with its probability.

    Args:
        img_path: Whatever ``PIL.Image.open`` accepts for the image to
            classify (typically a filesystem path).

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the integer index of the
        class with the highest softmax probability; ``confidence`` is that
        probability as a Python float.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically; convert() produces an independent in-memory image
    # that remains valid after the file is closed.
    with Image.open(img_path) as raw:
        rgb = raw.convert("RGB")

    # Add a leading batch dimension of size 1 before calling the model.
    batch = transform(rgb).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch).logits
        # Single-image batch: take row 0 to get per-class probabilities.
        probs = torch.softmax(logits, dim=1)[0]

    # Read the confidence at the argmax index so label and probability are
    # guaranteed to refer to the same class.
    label = int(probs.argmax().item())
    return label, float(probs[label])