yolac commited on
Commit
2551488
·
verified ·
1 Parent(s): 3d4277c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -19
app.py CHANGED
@@ -1,33 +1,78 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import torch.hub
4
- import requests
5
- from torchvision import models, transforms
6
  from PIL import Image
7
- from fastapi import FastAPI, UploadFile, File
8
- from io import BytesIO
9
 
10
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Load the model
 
13
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
14
- model = models.mobilenet_v2(weights=None)
15
- model.classifier[1] = nn.Linear(model.last_channel, 3)
16
- model.load_state_dict(torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu')))
17
  model.eval()
18
 
19
- # Define image transformation
 
 
 
20
  transform = transforms.Compose([
21
  transforms.Resize((224, 224)),
22
  transforms.ToTensor(),
23
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
  ])
25
 
26
- @app.post("/predict/")
27
- async def predict(file: UploadFile = File(...)):
28
- image = Image.open(BytesIO(await file.read())).convert("RGB")
29
- image_tensor = transform(image).unsqueeze(0)
30
- with torch.no_grad():
 
 
 
 
 
 
31
  output = model(image_tensor)
32
- _, predicted = output.max(1)
33
- return {"predicted_class": int(predicted.item())}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
  import torch
3
  import torch.nn as nn
4
+ from torchvision import transforms
 
 
5
  from PIL import Image
6
+ import io
 
7
 
8
+ # Define the model architecture that matches the saved .pth file
9
+ class BacterialMorphologyClassifier(nn.Module):
10
+ def __init__(self):
11
+ super(BacterialMorphologyClassifier, self).__init__()
12
+ self.feature_extractor = nn.Sequential(
13
+ nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
14
+ nn.ReLU(),
15
+ nn.MaxPool2d(kernel_size=2, stride=2),
16
+ nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
17
+ nn.ReLU(),
18
+ nn.MaxPool2d(kernel_size=2, stride=2),
19
+ )
20
+ self.fc = nn.Sequential(
21
+ nn.Flatten(),
22
+ nn.Linear(64 * 56 * 56, 128),
23
+ nn.ReLU(),
24
+ nn.Dropout(0.5),
25
+ nn.Linear(128, 3),
26
+ nn.Softmax(dim=1),
27
+ )
28
+
29
+ def forward(self, x):
30
+ x = self.feature_extractor(x)
31
+ x = self.fc(x)
32
+ return x
33
 
34
+ # Load the model and weights
35
+ model = BacterialMorphologyClassifier()
36
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
37
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
38
+ model.load_state_dict(state_dict)
 
39
  model.eval()
40
 
41
+ # Set up Flask app
42
+ app = Flask(__name__)
43
+
44
+ # Define image preprocessing transformations
45
  transform = transforms.Compose([
46
  transforms.Resize((224, 224)),
47
  transforms.ToTensor(),
48
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
49
  ])
50
 
51
+ @app.route('/predict', methods=['POST'])
52
+ def predict():
53
+ try:
54
+ # Get image from request
55
+ image_file = request.files['image']
56
+ image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
57
+
58
+ # Preprocess the image
59
+ image_tensor = transform(image).unsqueeze(0)
60
+
61
+ # Make prediction
62
  output = model(image_tensor)
63
+ prediction = output.argmax().item()
64
+
65
+ # Class mapping
66
+ class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
67
+
68
+ # Return prediction result
69
+ response = {
70
+ 'predicted_class': class_labels[prediction],
71
+ 'confidence': output.max().item()
72
+ }
73
+ return jsonify(response)
74
+ except Exception as e:
75
+ return jsonify({'error': str(e)})
76
+
77
+ if __name__ == '__main__':
78
+ app.run(host='0.0.0.0', port=5000, debug=True)