Update app.py
Browse files
app.py
CHANGED
@@ -31,11 +31,10 @@ class BacterialMorphologyClassifier(nn.Module):
|
|
31 |
x = self.fc(x)
|
32 |
return x
|
33 |
|
34 |
-
# Load the model and weights at app
|
35 |
model = BacterialMorphologyClassifier()
|
36 |
-
MODEL_PATH =
|
37 |
-
|
38 |
-
model.load_state_dict(state_dict, strict=False)
|
39 |
model.eval()
|
40 |
|
41 |
# Move model to GPU if available
|
@@ -79,4 +78,4 @@ def predict():
|
|
79 |
return jsonify({'error': str(e)})
|
80 |
|
81 |
if __name__ == '__main__':
|
82 |
-
app.run(host='0.0.0.0', port=5000, debug=
|
|
|
31 |
x = self.fc(x)
|
32 |
return x
|
33 |
|
34 |
+
# Load the model and weights at the start of the app
|
35 |
model = BacterialMorphologyClassifier()
|
36 |
+
MODEL_PATH = 'https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth' # Replace this with the local path if needed
|
37 |
+
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')), strict=False)
|
|
|
38 |
model.eval()
|
39 |
|
40 |
# Move model to GPU if available
|
|
|
78 |
return jsonify({'error': str(e)})
|
79 |
|
80 |
if __name__ == '__main__':
|
81 |
+
app.run(host='0.0.0.0', port=5000, debug=False) # Set debug=False for production
|