# Source: brain-tumor-classifier / inference.py (scraped from the repository web page)
# Author: Codewithsalty
# Last commit: ea3cc1e "Rename Inference.py to inference.py" (verified)
# inference.py
import torch
from io import BytesIO
from PIL import Image
from torchvision import transforms
from TumorModel import TumorClassification
# 1) Preprocessing pipeline: grayscale -> fixed 224x224 -> tensor -> normalized.
# Steps are listed individually so each one can be read (and tweaked) on its own.
_PREPROCESS_STEPS = [
    # Collapse the image to a single luminance channel.
    transforms.Grayscale(num_output_channels=1),
    # Fixed square input size; the aspect ratio is not preserved.
    transforms.Resize(size=(224, 224)),
    # PIL image -> float tensor with values in [0, 1].
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1] for the single channel.
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
_transform = transforms.Compose(_PREPROCESS_STEPS)
# 2) Load model once, at import time, so every inference call reuses it.
def _load_model(weights_path):
    """Build a TumorClassification network, restore its weights, and return it in eval mode.

    Args:
        weights_path: Path to a saved ``state_dict`` file.

    Returns:
        The restored network, with ``.eval()`` already applied.
    """
    net = TumorClassification()
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
    # weights_only=True would restrict unpickling, but that argument only exists on
    # torch >= 1.13 — confirm the deployed torch version before adding it.
    checkpoint = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(checkpoint)
    # Inference mode for layers such as dropout / batch-norm, if the network has any.
    net.eval()
    return net


_model = _load_model("BTD_model.pth")
# Class names, indexed by the position of the model's output logit.
# NOTE(review): this order presumably matches the class-to-index mapping used during
# training — confirm against the training script.
_LABELS = ("glioma", "meningioma", "notumor", "pituitary")


def inference(image_bytes):
    """Classify a brain MRI image supplied as raw encoded bytes.

    Hugging Face will pass the raw image bytes here.

    Args:
        image_bytes: Encoded image data in any format PIL can open (e.g. PNG, JPEG).

    Returns:
        dict: ``{"label": <one of glioma, meningioma, notumor, pituitary>}``.

    Raises:
        PIL.UnidentifiedImageError: (an ``OSError`` subclass) if ``image_bytes`` is not
            a recognised image format.
    """
    # Close the decoder once the pixels have been copied out. convert() returns a new,
    # fully loaded image, so nothing reads from `raw` after the with-block ends.
    with Image.open(BytesIO(image_bytes)) as raw:
        img = raw.convert("RGB")
    x = _transform(img).unsqueeze(0)  # add a leading batch dimension of size 1
    with torch.no_grad():  # no autograd graph is needed for inference
        logits = _model(x)
    idx = torch.argmax(logits, dim=1).item()
    return {"label": _LABELS[idx]}