Spaces:
Sleeping
Sleeping
File size: 1,246 Bytes
bfbce39 881a8e8 9f7b5d5 bfbce39 94e7aac 881a8e8 6eb3411 9f7b5d5 94e7aac 881a8e8 bfbce39 94e7aac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as tf
from fastapi import FastAPI, UploadFile
from PIL import Image
from torchvision.models import efficientnet_b0
# ASGI application object; the route decorators further down register their
# handlers on it.
app=FastAPI(title="Alzheimer Detection API")
def classify_img(model, img):
    """Run ``model`` on one image and return its top class and score.

    ``img`` is converted with ``tf.to_tensor`` and given a leading batch
    dimension of size 1 before the forward pass. The model output is
    soft-maxed along dimension 1; the index and the value of its largest
    entry are then returned.

    Args:
        model: Callable network applied to the batched image tensor.
        img: Image accepted by ``tf.to_tensor`` (the ``/predict`` handler
            passes a NumPy array built from an RGB PIL image).

    Returns:
        A ``(label, probability)`` tuple of tensors: the arg-max index and
        the maximum soft-max value over the whole output.
    """
    batch = tf.to_tensor(img).unsqueeze(0)
    # Inference only: no gradient bookkeeping is needed for the forward pass.
    with torch.no_grad():
        scores = nn.functional.softmax(model(batch), dim=1)
        label = scores.argmax()
        probability = scores.max()
    return label, probability
@lru_cache(maxsize=1)
def get_alzheimer_model():
    """Build the 4-output EfficientNet-B0 classifier and load its weights.

    The network is created without pretrained weights, the second entry of
    its ``classifier`` head is replaced by a linear layer with 4 outputs, and
    the state dict stored in ``alzheimer_weight.pth`` (a relative path, so it
    is resolved against the current working directory) is loaded onto the
    CPU. The model is switched to evaluation mode before being returned.

    The result is memoized: the network is constructed and the checkpoint is
    read from disk on the first successful call only, and every later call
    returns that same model instance. Without this, each caller (the
    ``/predict`` handler calls this once per request) would rebuild the model
    and re-read the weights file. A call that raises is not cached, so a
    later call retries the load.

    Returns:
        The loaded model, in eval mode, shared between all callers.

    Raises:
        Any exception raised by ``torch.load`` or ``load_state_dict`` (for
        example when the checkpoint file is missing or its keys do not match
        the network) propagates unchanged.
    """
    model = efficientnet_b0(weights=None)
    # Swap the stock classification layer for one with 4 outputs, keeping
    # the input width the backbone already produces.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=in_features, out_features=4)
    # NOTE(review): torch.load unpickles the file, so the checkpoint must come
    # from a trusted source; consider weights_only=True if the deployed
    # PyTorch version supports it.
    weights = torch.load("alzheimer_weight.pth", map_location="cpu")
    model.load_state_dict(weights)
    model.eval()
    return model
@app.get("/")
def display():
    """Return the welcome message for the root endpoint (``GET /``)."""
    greeting = "Welcome to Alzheimer Detection Api"
    return greeting
@app.post("/predict")
def predict(file: UploadFile):
    """Classify an uploaded image and return the predicted label and score.

    The upload is decoded with Pillow, converted to RGB, resized to
    480 x 480, turned into a NumPy array and passed to ``classify_img``
    together with the model from ``get_alzheimer_model()``.

    Args:
        file: The uploaded image file from the multipart request.

    Returns:
        A dict with ``"label"`` (the predicted class index as a Python
        number) and ``"probability"`` (the soft-max score of that class).

    Raises:
        Nothing is caught here: any exception raised while decoding the
        upload, loading the model or running inference propagates to the
        framework.
    """
    # Force three channels so every upload yields the same channel layout.
    rgb_image = Image.open(file.file).convert("RGB")
    # Image.resize expects the target size as a single (width, height) tuple.
    # Passing two positional ints makes the second one the resampling filter
    # and the call fails. 480 x 480 is presumably the resolution the model
    # was trained on -- TODO confirm.
    rgb_image = rgb_image.resize((480, 480))
    pixels = np.array(rgb_image)
    model = get_alzheimer_model()
    label, probability = classify_img(model, pixels)
    # .item() unwraps the tensors into plain Python numbers so the response
    # is JSON-serializable.
    return {"label": label.item(), "probability": probability.item()}