# Brain_Tumor_MRI / model.py
# Inference script for the Brain_Tumor_MRI image classifier.
# Uploaded by Amr2272 (commit 97c8d50).
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageClassification
# Resolve the model directory relative to this file rather than the current
# working directory, so running/importing this script from another directory
# still finds the weights stored alongside it.
try:
    _MODEL_DIR = Path(__file__).resolve().parent
except NameError:  # no __file__ (e.g. code pasted into an interactive session)
    _MODEL_DIR = Path(".")

# Load the classifier once at import time and switch it to inference mode
# (eval() disables training-only behavior such as dropout).
model = AutoModelForImageClassification.from_pretrained(str(_MODEL_DIR))
model.eval()

# Preprocessing applied to every input image: resize to 224x224, then convert
# the PIL image to a float tensor scaled to [0, 1].
# NOTE(review): no transforms.Normalize step is applied here; confirm this
# matches the preprocessing the model was trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
def predict(img_path):
    """Classify a single image with the module-level ``model``.

    The image is converted to 3-channel RGB and passed through ``transform``.
    It is then given a leading batch dimension of size 1 before the forward
    pass, which runs without gradient tracking.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts, such as a filesystem
            path or a binary file object.

    Returns:
        tuple[int, float]: ``(label, confidence)``. ``label`` is the index of
        the highest-probability class. ``confidence`` is that class's softmax
        probability.
    """
    # Open the image in a context manager so the file handle is closed
    # promptly. convert() returns a new, fully loaded image that stays
    # usable after the file is closed.
    with Image.open(img_path) as image:
        rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch).logits
    probs = torch.softmax(logits, dim=1)

    # Batch size is 1: reduce row 0 over the class axis once to get both the
    # top probability and its index (first index wins on ties, like argmax).
    confidence, label = probs[0].max(dim=0)
    return label.item(), confidence.item()