File size: 1,246 Bytes
bfbce39
881a8e8
 
 
 
 
9f7b5d5
bfbce39
 
94e7aac
881a8e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6eb3411
9f7b5d5
94e7aac
 
 
 
881a8e8
bfbce39
 
 
 
 
 
94e7aac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import functools

from PIL import Image
from PIL import UnidentifiedImageError
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0
import torchvision.transforms.functional as tf
from fastapi import FastAPI, UploadFile
from fastapi import HTTPException

# Application object; the @app.get / @app.post route decorators register on it.
app = FastAPI(title="Alzheimer Detection API")

def classify_img(model, img):
    """Run *model* on a single image and return its top class and confidence.

    Args:
        model: callable (e.g. an ``nn.Module``) applied to a batch holding one
            image; its output is softmaxed along dim 1, so it presumably
            returns logits of shape (1, num_classes) — confirm for new models.
        img: anything ``torchvision.transforms.functional.to_tensor`` accepts;
            ``predict`` in this file passes an HxWx3 uint8 NumPy array.

    Returns:
        ``(label, probability)``: 0-dim tensors holding the index of the most
        probable class and that class's softmax probability.
    """
    # to_tensor yields a CxHxW float tensor; add a leading batch axis of size 1.
    batch = tf.to_tensor(img).unsqueeze(0)
    with torch.no_grad():  # inference only, no autograd graph needed
        probs = torch.softmax(model(batch), dim=1)
        top_label = probs.argmax()
        top_prob = probs.max()
        return top_label, top_prob

def get_alzheimer_model(weights_path="alzheimer_weight.pth", num_classes=4):
    """Build an EfficientNet-B0 classifier and load fine-tuned weights into it.

    The stock classification head (``classifier[1]``) is replaced by a new
    ``nn.Linear`` with *num_classes* outputs so that its shape matches the
    saved checkpoint before ``load_state_dict`` is called.

    Args:
        weights_path: path to a state dict saved with ``torch.save``.
            Defaults to ``"alzheimer_weight.pth"``, resolved relative to the
            process's working directory.
        num_classes: number of outputs of the checkpoint's head (default 4).

    Returns:
        The model with the checkpoint loaded (tensors mapped to CPU), switched
        to eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. a missing
        file or mismatched checkpoint keys/shapes.
    """
    # Architecture only: all parameters come from the checkpoint below.
    model = efficientnet_b0(weights=None)
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=in_features, out_features=num_classes)
    # NOTE: torch.load unpickles the file, so weights_path must be trusted.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # inference behaviour for dropout / batch-norm layers
    return model



@app.get("/")
def display():
    """Handle ``GET /``: respond with a fixed welcome string."""
    greeting = "Welcome to Alzheimer Detection Api"
    return greeting

@functools.lru_cache(maxsize=1)
def _cached_alzheimer_model():
    """Build the classifier once and reuse it for every request."""
    # The model is only used for inference (eval mode, under no_grad in
    # classify_img), so one shared instance avoids re-reading the checkpoint.
    return get_alzheimer_model()


@app.post("/predict")
def predict(file: UploadFile):
    """Classify an uploaded image.

    The upload is decoded with Pillow, converted to RGB and resized to
    480x480. It is then passed to the cached model as an HxWx3 uint8 array.

    Args:
        file: the uploaded image file.

    Returns:
        ``{"label": int, "probability": float}``: the predicted class index
        and its softmax probability.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        img = Image.open(file.file).convert("RGB")
    except UnidentifiedImageError as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from err
    # Image.resize takes the size as ONE (width, height) tuple; passing two
    # ints made the second one the resampling filter and crashed every call.
    img = img.resize((480, 480))
    label, probability = classify_img(_cached_alzheimer_model(), np.array(img))
    return {"label": label.item(), "probability": probability.item()}