|
|
|
|
|
import torch
|
|
from io import BytesIO
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
from TumorModel import TumorClassification
|
|
|
|
|
|
# Preprocessing applied to every image before it is fed to the model:
# collapse to one grayscale channel, resize to 224x224, convert the PIL image
# to a float tensor in [0, 1], then compute (x - 0.5) / 0.5 so that values
# land in [-1, 1].
_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
|
|
|
|
|
|
def _load_model():
    """Instantiate the tumor classifier and restore its saved weights.

    The state dict is loaded onto the CPU and the network is switched to
    evaluation mode before it is returned.
    """
    network = TumorClassification()
    # NOTE: this relative path is resolved against the process's current
    # working directory, not against this module's location.
    # NOTE(review): torch.load unpickles the checkpoint; consider passing
    # weights_only=True if the deployed torch version supports it.
    network.load_state_dict(torch.load("BTD_model.pth", map_location="cpu"))
    network.eval()
    return network


# Built once at import time and shared by every inference() call.
_model = _load_model()
|
|
|
|
def inference(image_bytes):
    """Classify an encoded image into one of four tumor categories.

    Intended as the Hugging Face entry point, which supplies the raw
    (still-encoded) image bytes.

    Args:
        image_bytes: Encoded image data in any format PIL can decode.

    Returns:
        ``{"label": name}`` where ``name`` is one of ``"glioma"``,
        ``"meningioma"``, ``"notumor"`` or ``"pituitary"``.
    """
    # Index i of the model's output corresponds to class_names[i].
    class_names = ["glioma", "meningioma", "notumor", "pituitary"]

    # Decode, then normalise the PIL mode to RGB whatever the source mode was,
    # so the preprocessing pipeline always starts from the same mode.
    decoded = Image.open(BytesIO(image_bytes))
    rgb_image = decoded.convert("RGB")

    # Prepend a batch axis of size 1.
    batch = _transform(rgb_image)[None, ...]

    with torch.no_grad():
        logits = _model(batch)
    best = logits.argmax(dim=1).item()

    return {"label": class_names[best]}
|
|
|