Srastog commited on
Commit
dd2548a
·
1 Parent(s): 94e7aac

Final Application

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app/main.py +6 -27
  3. app/utils.py +29 -0
  4. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv/
2
+ app/__pycache__/
app/main.py CHANGED
@@ -1,43 +1,22 @@
1
  from PIL import Image
 
2
  import numpy as np
3
- import torch
4
- import torch.nn as nn
5
- from torchvision.models import efficientnet_b0
6
- import torchvision.transforms.functional as tf
7
  from fastapi import FastAPI, UploadFile
 
8
 
9
  app=FastAPI(title="Alzheimer Detection API")
10
 
11
- def classify_img(model,img):
12
- img=tf.to_tensor(img)
13
- img=img.unsqueeze(0)
14
- with torch.no_grad():
15
- predict=model(img)
16
- predict=nn.functional.softmax(predict,1)
17
- label=torch.argmax(predict)
18
- probability=torch.max(predict)
19
- return label,probability
20
-
21
- def get_alzheimer_model():
22
- model=efficientnet_b0(weights=None)
23
- in_features=model.classifier[1].in_features
24
- model.classifier[1]=nn.Linear(in_features=in_features,out_features=4)
25
- weights=torch.load("alzheimer_weight.pth",map_location="cpu")
26
- model.load_state_dict(weights)
27
- model.eval()
28
- return model
29
-
30
-
31
-
32
  @app.get("/")
33
  def display():
34
  return "Welcome to Alzheimer Detection Api"
35
 
36
  @app.post("/predict")
37
  def predict(file: UploadFile):
38
- img=Image.open(file.file).convert("RGB")
39
- img=img.resize(480,480)
40
  img=np.array(img)
 
 
 
41
  model= get_alzheimer_model()
42
  label,probability=classify_img(model,img)
43
  return {"label":label.item(),"probability":probability.item()}
 
1
  from PIL import Image
2
+ import cv2
3
  import numpy as np
 
 
 
 
4
  from fastapi import FastAPI, UploadFile
5
+ from .utils import classify_img,get_alzheimer_model
6
 
7
  app=FastAPI(title="Alzheimer Detection API")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  @app.get("/")
10
  def display():
11
  return "Welcome to Alzheimer Detection Api"
12
 
13
  @app.post("/predict")
14
  def predict(file: UploadFile):
15
+ img=Image.open(file.file)
 
16
  img=np.array(img)
17
+ if len(img.shape)==2:
18
+ img=cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)
19
+ img=cv2.resize(img,(480,480))
20
  model= get_alzheimer_model()
21
  label,probability=classify_img(model,img)
22
  return {"label":label.item(),"probability":probability.item()}
app/utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision
5
+ import os
6
+ from torchvision.models import efficientnet_b0
7
+ import torchvision.transforms.functional as tf
8
+
9
+ root_dir=os.path.dirname(os.path.abspath(__file__))
10
+ weights_pth=os.path.join(root_dir,"alzheimer_weight.pth")
11
+
12
+ def classify_img(model,img):
13
+ img=tf.to_tensor(img)
14
+ img=img.unsqueeze(0)
15
+ with torch.no_grad():
16
+ predict=model(img)
17
+ predict=nn.functional.softmax(predict,1)
18
+ label=torch.argmax(predict)
19
+ probability=torch.max(predict)
20
+ return label,probability
21
+
22
+ def get_alzheimer_model():
23
+ model=efficientnet_b0(weights=None)
24
+ in_features=model.classifier[1].in_features
25
+ model.classifier[1]=nn.Linear(in_features=in_features,out_features=4)
26
+ weights=torch.load(weights_pth,map_location="cpu")
27
+ model.load_state_dict(weights)
28
+ model.eval()
29
+ return model
requirements.txt CHANGED
@@ -3,4 +3,5 @@ torch==2.0.1
3
  fastapi==0.104.1
4
  uvicorn==0.24.0
5
  python-multipart==0.0.6
 
6
  torchvision==0.15.2
 
3
  fastapi==0.104.1
4
  uvicorn==0.24.0
5
  python-multipart==0.0.6
6
+ opencv-python-headless==4.5.4.60
7
  torchvision==0.15.2