Srastog commited on
Commit
881a8e8
·
1 Parent(s): ccde474
Files changed (2) hide show
  1. app/main.py +26 -4
  2. app/utils.py +0 -24
app/main.py CHANGED
@@ -1,19 +1,41 @@
1
- import numpy as np
2
  from PIL import Image
 
 
 
 
 
3
  from fastapi import FastAPI, UploadFile
4
- from utils import classify_img,get_alzheimer_model
5
 
6
  app=FastAPI(title="Alzheimer Detection API")
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  @app.get("/generate")
9
  def display(text: str):
10
  return {"yoy":text}
11
 
12
- """@app.post("/predict")
13
  def predict(file: UploadFile):
14
  img=Image.open(file.file).convert("RGB")
15
  img=img.resize(480,480)
16
  img=np.array(img)
17
  model= get_alzheimer_model()
18
  label,probability=classify_img(model,img)
19
- return {"label":label.item(),"probability":probability.item()}"""
 
 
1
  from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision.models import efficientnet_b0
6
+ import torchvision.transforms.functional as tf
7
  from fastapi import FastAPI, UploadFile
 
8
 
9
  app=FastAPI(title="Alzheimer Detection API")
10
 
11
+ def classify_img(model,img):
12
+ img=tf.to_tensor(img)
13
+ img=img.unsqueeze(0)
14
+ with torch.no_grad():
15
+ predict=model(img)
16
+ predict=nn.functional.softmax(predict,1)
17
+ label=torch.argmax(predict)
18
+ probability=torch.max(predict)
19
+ return label,probability
20
+
21
+ def get_alzheimer_model():
22
+ model=efficientnet_b0(weights=None)
23
+ in_features=model.classifier[1].in_features
24
+ model.classifier[1]=nn.Linear(in_features=in_features,out_features=4)
25
+ weights=torch.load("alzheimer_weight.pth",map_location="cpu")
26
+ model.load_state_dict(weights)
27
+ model.eval()
28
+ return model
29
+
30
  @app.get("/generate")
31
  def display(text: str):
32
  return {"yoy":text}
33
 
34
+ @app.post("/predict")
35
  def predict(file: UploadFile):
36
  img=Image.open(file.file).convert("RGB")
37
  img=img.resize(480,480)
38
  img=np.array(img)
39
  model= get_alzheimer_model()
40
  label,probability=classify_img(model,img)
41
+ return {"label":label.item(),"probability":probability.item()}
app/utils.py DELETED
@@ -1,24 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- from torchvision.models import efficientnet_b0
5
- import torchvision.transforms.functional as tf
6
-
7
- def classify_img(model,img):
8
- img=tf.to_tensor(img)
9
- img=img.unsqueeze(0)
10
- with torch.no_grad():
11
- predict=model(img)
12
- predict=nn.functional.softmax(predict,1)
13
- label=torch.argmax(predict)
14
- probability=torch.max(predict)
15
- return label,probability
16
-
17
- def get_alzheimer_model():
18
- model=efficientnet_b0(weights=None)
19
- in_features=model.classifier[1].in_features
20
- model.classifier[1]=nn.Linear(in_features=in_features,out_features=4)
21
- weights=torch.load("alzheimer_weight.pth",map_location="cpu")
22
- model.load_state_dict(weights)
23
- model.eval()
24
- return model