Update app.py
Browse files
app.py
CHANGED
@@ -1,58 +1,33 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
|
|
|
|
3 |
from torchvision import models, transforms
|
|
|
4 |
from fastapi import FastAPI, UploadFile, File
|
5 |
-
from pydantic import BaseModel
|
6 |
from io import BytesIO
|
7 |
-
from PIL import Image
|
8 |
|
9 |
-
# Initialize the FastAPI app
|
10 |
app = FastAPI()
|
11 |
|
12 |
-
#
|
13 |
-
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/
|
14 |
-
|
15 |
-
|
16 |
-
model =
|
17 |
-
model.classifier[1] = nn.Linear(model.last_channel, 3) # Replace with 3 classes for your dataset
|
18 |
-
model.load_state_dict(torch.load(MODEL_PATH)) # Load your trained model weights
|
19 |
model.eval()
|
20 |
|
21 |
-
# Define image
|
22 |
-
|
23 |
transforms.Resize((224, 224)),
|
24 |
transforms.ToTensor(),
|
25 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
26 |
])
|
27 |
|
28 |
-
# Helper function to preprocess the image
|
29 |
-
def preprocess_image(image: Image.Image):
|
30 |
-
image = image.convert("RGB")
|
31 |
-
image = preprocess(image)
|
32 |
-
image = image.unsqueeze(0) # Add a batch dimension
|
33 |
-
return image
|
34 |
-
|
35 |
@app.post("/predict/")
|
36 |
async def predict(file: UploadFile = File(...)):
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
# Perform inference
|
44 |
-
with torch.no_grad():
|
45 |
-
outputs = model(image_tensor)
|
46 |
-
predicted_class_idx = torch.argmax(outputs, dim=1).item()
|
47 |
-
|
48 |
-
# Map class index to class label
|
49 |
-
class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
|
50 |
-
predicted_class = class_labels[predicted_class_idx]
|
51 |
-
|
52 |
-
return {"predicted_class": predicted_class}
|
53 |
-
|
54 |
-
except Exception as e:
|
55 |
-
return {"error": str(e)}
|
56 |
-
|
57 |
-
# To run the app, use the following command in the terminal:
|
58 |
-
# uvicorn app:app --host 0.0.0.0 --port 8000 --reload
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
+
import torch.hub
|
4 |
+
import requests
|
5 |
from torchvision import models, transforms
|
6 |
+
from PIL import Image
|
7 |
from fastapi import FastAPI, UploadFile, File
|
|
|
8 |
from io import BytesIO
|
|
|
9 |
|
|
|
10 |
app = FastAPI()
|
11 |
|
12 |
+
# Load the model
|
13 |
+
MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/BacterialMorphologyClassification_model.pth"
|
14 |
+
model = models.mobilenet_v2(weights=None)
|
15 |
+
model.classifier[1] = nn.Linear(model.last_channel, 3)
|
16 |
+
model.load_state_dict(torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu')))
|
|
|
|
|
17 |
model.eval()
|
18 |
|
19 |
+
# Define image transformation
|
20 |
+
transform = transforms.Compose([
|
21 |
transforms.Resize((224, 224)),
|
22 |
transforms.ToTensor(),
|
23 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
24 |
])
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
@app.post("/predict/")
|
27 |
async def predict(file: UploadFile = File(...)):
|
28 |
+
image = Image.open(BytesIO(await file.read())).convert("RGB")
|
29 |
+
image_tensor = transform(image).unsqueeze(0)
|
30 |
+
with torch.no_grad():
|
31 |
+
output = model(image_tensor)
|
32 |
+
_, predicted = output.max(1)
|
33 |
+
return {"predicted_class": int(predicted.item())}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|