yolac commited on
Commit
d9d91ec
·
verified ·
1 Parent(s): 24aaab1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -4,8 +4,9 @@ import torch.nn as nn
4
  from torchvision import transforms
5
  from PIL import Image
6
  import io
 
7
 
8
- # Define the model architecture that matches the saved .pth file
9
  class BacterialMorphologyClassifier(nn.Module):
10
  def __init__(self):
11
  super(BacterialMorphologyClassifier, self).__init__()
@@ -31,16 +32,13 @@ class BacterialMorphologyClassifier(nn.Module):
31
  x = self.fc(x)
32
  return x
33
 
34
- # Load the model and weights at the start of the app
35
  model = BacterialMorphologyClassifier()
36
- MODEL_PATH = 'https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth' # Replace this with the local path if needed
37
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')), strict=False)
 
38
  model.eval()
39
 
40
- # Move model to GPU if available
41
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
- model.to(device)
43
-
44
  # Set up Flask app
45
  app = Flask(__name__)
46
 
@@ -59,7 +57,7 @@ def predict():
59
  image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
60
 
61
  # Preprocess the image
62
- image_tensor = transform(image).unsqueeze(0).to(device)
63
 
64
  # Make prediction
65
  output = model(image_tensor)
@@ -78,4 +76,4 @@ def predict():
78
  return jsonify({'error': str(e)})
79
 
80
  if __name__ == '__main__':
81
- app.run(host='0.0.0.0', port=5000, debug=False) # Set debug=False for production
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  import io
7
+ from torch.hub import load_state_dict_from_url
8
 
9
+ # Define the model architecture
10
  class BacterialMorphologyClassifier(nn.Module):
11
  def __init__(self):
12
  super(BacterialMorphologyClassifier, self).__init__()
 
32
  x = self.fc(x)
33
  return x
34
 
35
+ # Load the model and weights
36
  model = BacterialMorphologyClassifier()
37
+ MODEL_URL = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
38
+ state_dict = load_state_dict_from_url(MODEL_URL, map_location=torch.device('cpu'))
39
+ model.load_state_dict(state_dict, strict=False)
40
  model.eval()
41
 
 
 
 
 
42
  # Set up Flask app
43
  app = Flask(__name__)
44
 
 
57
  image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
58
 
59
  # Preprocess the image
60
+ image_tensor = transform(image).unsqueeze(0)
61
 
62
  # Make prediction
63
  output = model(image_tensor)
 
76
  return jsonify({'error': str(e)})
77
 
78
  if __name__ == '__main__':
79
+ app.run(host='0.0.0.0', port=5000, debug=False)