yolac commited on
Commit
d3b5926
·
verified ·
1 Parent(s): 6153dd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -41
app.py CHANGED
@@ -1,58 +1,33 @@
1
  import torch
2
  import torch.nn as nn
 
 
3
  from torchvision import models, transforms
 
4
  from fastapi import FastAPI, UploadFile, File
5
- from pydantic import BaseModel
6
  from io import BytesIO
7
- from PIL import Image
8
 
9
- # Initialize the FastAPI app
10
  app = FastAPI()
11
 
12
- # Path to the trained model checkpoint
13
- MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/blob/main/BacterialMorphologyClassification_model.ipynb"
14
-
15
- # Load the pretrained MobileNetV2 model and modify the classifier
16
- model = models.mobilenet_v2(weights=None) # No initial weights
17
- model.classifier[1] = nn.Linear(model.last_channel, 3) # Replace with 3 classes for your dataset
18
- model.load_state_dict(torch.load(MODEL_PATH)) # Load your trained model weights
19
  model.eval()
20
 
21
- # Define image preprocessing transformations
22
- preprocess = transforms.Compose([
23
  transforms.Resize((224, 224)),
24
  transforms.ToTensor(),
25
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26
  ])
27
 
28
- # Helper function to preprocess the image
29
- def preprocess_image(image: Image.Image):
30
- image = image.convert("RGB")
31
- image = preprocess(image)
32
- image = image.unsqueeze(0) # Add a batch dimension
33
- return image
34
-
35
  @app.post("/predict/")
36
  async def predict(file: UploadFile = File(...)):
37
- try:
38
- # Read and preprocess the image
39
- image_data = await file.read()
40
- image = Image.open(BytesIO(image_data))
41
- image_tensor = preprocess_image(image)
42
-
43
- # Perform inference
44
- with torch.no_grad():
45
- outputs = model(image_tensor)
46
- predicted_class_idx = torch.argmax(outputs, dim=1).item()
47
-
48
- # Map class index to class label
49
- class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
50
- predicted_class = class_labels[predicted_class_idx]
51
-
52
- return {"predicted_class": predicted_class}
53
-
54
- except Exception as e:
55
- return {"error": str(e)}
56
-
57
- # To run the app, use the following command in the terminal:
58
- # uvicorn app:app --host 0.0.0.0 --port 8000 --reload
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.hub
4
+ import requests
5
  from torchvision import models, transforms
6
+ from PIL import Image
7
  from fastapi import FastAPI, UploadFile, File
 
8
  from io import BytesIO
 
9
 
 
10
  app = FastAPI()
11
 
12
+ # Load the model
13
+ MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/BacterialMorphologyClassification_model.pth"
14
+ model = models.mobilenet_v2(weights=None)
15
+ model.classifier[1] = nn.Linear(model.last_channel, 3)
16
+ model.load_state_dict(torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu')))
 
 
17
  model.eval()
18
 
19
+ # Define image transformation
20
+ transform = transforms.Compose([
21
  transforms.Resize((224, 224)),
22
  transforms.ToTensor(),
23
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
  ])
25
 
 
 
 
 
 
 
 
26
  @app.post("/predict/")
27
  async def predict(file: UploadFile = File(...)):
28
+ image = Image.open(BytesIO(await file.read())).convert("RGB")
29
+ image_tensor = transform(image).unsqueeze(0)
30
+ with torch.no_grad():
31
+ output = model(image_tensor)
32
+ _, predicted = output.max(1)
33
+ return {"predicted_class": int(predicted.item())}