Spaces:
Sleeping
Sleeping
from PIL import Image | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torchvision.models import efficientnet_b0 | |
import torchvision.transforms.functional as tf | |
from fastapi import FastAPI, UploadFile | |
# ASGI application object for the service; the title is passed to FastAPI.
app = FastAPI(
    title="Alzheimer Detection API",
)
def classify_img(model, img):
    """Run ``model`` on one image and return its top class and confidence.

    Args:
        model: Callable invoked as ``model(batch)``; its output is passed
            through a softmax over dimension 1.
        img: Image in any form accepted by
            ``torchvision.transforms.functional.to_tensor``.

    Returns:
        A ``(label, probability)`` tuple of zero-dimensional tensors: the
        index of the largest softmax value (taken over the flattened
        output) and that largest value.
    """
    # Convert to a tensor and prepend a batch dimension of size one.
    batch = tf.to_tensor(img).unsqueeze(0)
    with torch.no_grad():  # inference only, so skip autograd bookkeeping
        scores = torch.softmax(model(batch), dim=1)
        top_index = scores.argmax()
        top_score = scores.max()
    return top_index, top_score
# Process-wide cache: the network is built and its checkpoint read from disk
# only on the first call to get_alzheimer_model().
_ALZHEIMER_MODEL = None


def get_alzheimer_model():
    """Return the 4-class EfficientNet-B0 classifier in eval mode.

    The first call builds an ``efficientnet_b0`` without pretrained weights,
    replaces ``classifier[1]`` with a 4-output linear layer, loads the state
    dict from ``alzheimer_weight.pth`` onto the CPU and switches the model to
    eval mode. Later calls return that same instance instead of rebuilding
    the network and re-reading the checkpoint.

    Returns:
        The shared model instance. Callers should treat it as read-only.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the
        checkpoint is missing or does not match the architecture.
    """
    global _ALZHEIMER_MODEL
    if _ALZHEIMER_MODEL is None:
        model = efficientnet_b0(weights=None)
        in_features = model.classifier[1].in_features
        # Swap the final layer so the head emits 4 outputs.
        model.classifier[1] = nn.Linear(in_features=in_features, out_features=4)
        # NOTE(security): torch.load unpickles the file; only load a
        # checkpoint that comes from a trusted source.
        state_dict = torch.load("alzheimer_weight.pth", map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        # Publish only after the model is fully initialised.
        _ALZHEIMER_MODEL = model
    return _ALZHEIMER_MODEL
def display():
    """Return the service's fixed welcome message."""
    greeting = "Welcome to Alzheimer Detection Api"
    return greeting
def predict(file: UploadFile):
    """Classify an uploaded image and report the winning class.

    Args:
        file: Uploaded file whose ``file`` attribute is opened with
            ``PIL.Image.open``.

    Returns:
        A dict with ``"label"`` and ``"probability"``, each obtained by
        calling ``.item()`` on the values returned by ``classify_img``.
    """
    img = Image.open(file.file).convert("RGB")
    # Image.resize expects the target size as one (width, height) tuple;
    # passing two separate ints makes the second one the resampling filter
    # and the call fails.
    img = img.resize((480, 480))
    img = np.array(img)
    model = get_alzheimer_model()
    label, probability = classify_img(model, img)
    return {"label": label.item(), "probability": probability.item()}