yolac commited on
Commit
f4aea2c
·
verified ·
1 Parent(s): 9047e05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -6,14 +6,19 @@ from pydantic import BaseModel
6
  from io import BytesIO
7
  from PIL import Image
8
 
 
9
  app = FastAPI()
10
 
11
- # Load the pretrained MobileNetV2 model from torchvision
12
- model = models.mobilenet_v2(pretrained=True)
13
- model.classifier[1] = nn.Linear(model.last_channel, 3) # Replace with the number of classes in your task
 
 
 
 
14
  model.eval()
15
 
16
- # Define the image preprocessing transform
17
  preprocess = transforms.Compose([
18
  transforms.Resize((224, 224)),
19
  transforms.ToTensor(),
@@ -24,7 +29,7 @@ preprocess = transforms.Compose([
24
  def preprocess_image(image: Image.Image):
25
  image = image.convert("RGB")
26
  image = preprocess(image)
27
- image = image.unsqueeze(0)
28
  return image
29
 
30
  @app.post("/predict/")
@@ -40,7 +45,7 @@ async def predict(file: UploadFile = File(...)):
40
  outputs = model(image_tensor)
41
  predicted_class_idx = torch.argmax(outputs, dim=1).item()
42
 
43
- # Map the predicted class index to the class labels
44
  class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
45
  predicted_class = class_labels[predicted_class_idx]
46
 
@@ -49,4 +54,5 @@ async def predict(file: UploadFile = File(...)):
49
  except Exception as e:
50
  return {"error": str(e)}
51
 
52
- # Run the app with: uvicorn app:app --reload
 
 
6
  from io import BytesIO
7
  from PIL import Image
8
 
9
+ # Initialize the FastAPI app
10
  app = FastAPI()
11
 
12
+ # Path to the trained model checkpoint
13
+ MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/BacterialMorphologyClassification_model.ipynb"
14
+
15
+ # Load the pretrained MobileNetV2 model and modify the classifier
16
+ model = models.mobilenet_v2(weights=None) # No initial weights
17
+ model.classifier[1] = nn.Linear(model.last_channel, 3) # Replace with 3 classes for your dataset
18
+ model.load_state_dict(torch.load(MODEL_PATH)) # Load your trained model weights
19
  model.eval()
20
 
21
+ # Define image preprocessing transformations
22
  preprocess = transforms.Compose([
23
  transforms.Resize((224, 224)),
24
  transforms.ToTensor(),
 
29
  def preprocess_image(image: Image.Image):
30
  image = image.convert("RGB")
31
  image = preprocess(image)
32
+ image = image.unsqueeze(0) # Add a batch dimension
33
  return image
34
 
35
  @app.post("/predict/")
 
45
  outputs = model(image_tensor)
46
  predicted_class_idx = torch.argmax(outputs, dim=1).item()
47
 
48
+ # Map class index to class label
49
  class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
50
  predicted_class = class_labels[predicted_class_idx]
51
 
 
54
  except Exception as e:
55
  return {"error": str(e)}
56
 
57
+ # To run the app, use the following command in the terminal:
58
+ # uvicorn app:app --host 0.0.0.0 --port 8000 --reload