yolac commited on
Commit
9047e05
·
verified ·
1 Parent(s): ca75eaf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -17
app.py CHANGED
@@ -1,49 +1,44 @@
1
  import torch
 
 
2
  from fastapi import FastAPI, UploadFile, File
3
  from pydantic import BaseModel
4
  from io import BytesIO
5
  from PIL import Image
6
- from torchvision import transforms
7
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
8
 
9
- # Initialize the FastAPI app
10
  app = FastAPI()
11
 
12
- # Load the model and feature extractor
13
- model = AutoModelForImageClassification.from_pretrained("yolac/BacterialMorphologyClassification")
14
- feature_extractor = AutoFeatureExtractor.from_pretrained("yolac/BacterialMorphologyClassification")
 
15
 
16
  # Define the image preprocessing transform
17
  preprocess = transforms.Compose([
18
- transforms.Resize((224, 224)), # Adjust size as needed for your model
19
  transforms.ToTensor(),
20
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21
  ])
22
 
23
  # Helper function to preprocess the image
24
  def preprocess_image(image: Image.Image):
25
- image = image.convert("RGB") # Ensure the image is in RGB format
26
  image = preprocess(image)
27
- image = image.unsqueeze(0) # Add a batch dimension
28
  return image
29
 
30
- # Define the prediction endpoint
31
  @app.post("/predict/")
32
  async def predict(file: UploadFile = File(...)):
33
  try:
34
- # Read the uploaded image
35
  image_data = await file.read()
36
  image = Image.open(BytesIO(image_data))
37
-
38
- # Preprocess the image
39
  image_tensor = preprocess_image(image)
40
 
41
  # Perform inference
42
- model.eval()
43
  with torch.no_grad():
44
  outputs = model(image_tensor)
45
- logits = outputs.logits # Get the model's raw output
46
- predicted_class_idx = torch.argmax(logits, dim=1).item()
47
 
48
  # Map the predicted class index to the class labels
49
  class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
@@ -54,4 +49,4 @@ async def predict(file: UploadFile = File(...)):
54
  except Exception as e:
55
  return {"error": str(e)}
56
 
57
- # Run the app using the command: uvicorn app:app --reload
 
1
  import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
  from fastapi import FastAPI, UploadFile, File
5
  from pydantic import BaseModel
6
  from io import BytesIO
7
  from PIL import Image
 
 
8
 
 
9
  app = FastAPI()
10
 
11
+ # Load the pretrained MobileNetV2 model from torchvision
12
+ model = models.mobilenet_v2(pretrained=True)
13
+ model.classifier[1] = nn.Linear(model.last_channel, 3) # Replace with the number of classes in your task
14
+ model.eval()
15
 
16
  # Define the image preprocessing transform
17
  preprocess = transforms.Compose([
18
+ transforms.Resize((224, 224)),
19
  transforms.ToTensor(),
20
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21
  ])
22
 
23
  # Helper function to preprocess the image
24
  def preprocess_image(image: Image.Image):
25
+ image = image.convert("RGB")
26
  image = preprocess(image)
27
+ image = image.unsqueeze(0)
28
  return image
29
 
 
30
  @app.post("/predict/")
31
  async def predict(file: UploadFile = File(...)):
32
  try:
33
+ # Read and preprocess the image
34
  image_data = await file.read()
35
  image = Image.open(BytesIO(image_data))
 
 
36
  image_tensor = preprocess_image(image)
37
 
38
  # Perform inference
 
39
  with torch.no_grad():
40
  outputs = model(image_tensor)
41
+ predicted_class_idx = torch.argmax(outputs, dim=1).item()
 
42
 
43
  # Map the predicted class index to the class labels
44
  class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
 
49
  except Exception as e:
50
  return {"error": str(e)}
51
 
52
+ # Run the app with: uvicorn app:app --reload