Update app.py
Browse files
app.py
CHANGED
@@ -4,8 +4,9 @@ import torch.nn as nn
|
|
4 |
from torchvision import transforms
|
5 |
from PIL import Image
|
6 |
import io
|
|
|
7 |
|
8 |
-
# Define the model architecture
|
9 |
class BacterialMorphologyClassifier(nn.Module):
|
10 |
def __init__(self):
|
11 |
super(BacterialMorphologyClassifier, self).__init__()
|
@@ -31,16 +32,13 @@ class BacterialMorphologyClassifier(nn.Module):
|
|
31 |
x = self.fc(x)
|
32 |
return x
|
33 |
|
34 |
-
# Load the model and weights
|
35 |
model = BacterialMorphologyClassifier()
|
36 |
-
|
37 |
-
|
|
|
38 |
model.eval()
|
39 |
|
40 |
-
# Move model to GPU if available
|
41 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
42 |
-
model.to(device)
|
43 |
-
|
44 |
# Set up Flask app
|
45 |
app = Flask(__name__)
|
46 |
|
@@ -59,7 +57,7 @@ def predict():
|
|
59 |
image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
|
60 |
|
61 |
# Preprocess the image
|
62 |
-
image_tensor = transform(image).unsqueeze(0)
|
63 |
|
64 |
# Make prediction
|
65 |
output = model(image_tensor)
|
@@ -78,4 +76,4 @@ def predict():
|
|
78 |
return jsonify({'error': str(e)})
|
79 |
|
80 |
if __name__ == '__main__':
|
81 |
-
app.run(host='0.0.0.0', port=5000, debug=False)
|
|
|
4 |
from torchvision import transforms
|
5 |
from PIL import Image
|
6 |
import io
|
7 |
+
from torch.hub import load_state_dict_from_url
|
8 |
|
9 |
+
# Define the model architecture
|
10 |
class BacterialMorphologyClassifier(nn.Module):
|
11 |
def __init__(self):
|
12 |
super(BacterialMorphologyClassifier, self).__init__()
|
|
|
32 |
x = self.fc(x)
|
33 |
return x
|
34 |
|
35 |
+
# Load the model and weights
|
36 |
model = BacterialMorphologyClassifier()
|
37 |
+
MODEL_URL = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
|
38 |
+
state_dict = load_state_dict_from_url(MODEL_URL, map_location=torch.device('cpu'))
|
39 |
+
model.load_state_dict(state_dict, strict=False)
|
40 |
model.eval()
|
41 |
|
|
|
|
|
|
|
|
|
42 |
# Set up Flask app
|
43 |
app = Flask(__name__)
|
44 |
|
|
|
57 |
image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
|
58 |
|
59 |
# Preprocess the image
|
60 |
+
image_tensor = transform(image).unsqueeze(0)
|
61 |
|
62 |
# Make prediction
|
63 |
output = model(image_tensor)
|
|
|
76 |
return jsonify({'error': str(e)})
|
77 |
|
78 |
if __name__ == '__main__':
|
79 |
+
app.run(host='0.0.0.0', port=5000, debug=False)
|