yolac commited on
Commit
5590937
·
verified ·
1 Parent(s): fd02441

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -31,21 +31,20 @@ class BacterialMorphologyClassifier(nn.Module):
31
  x = self.fc(x)
32
  return x
33
 
34
- # Load the model and weights
35
  model = BacterialMorphologyClassifier()
36
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
37
  state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
38
- model.load_state_dict(state_dict)
39
  model.eval()
40
 
 
 
 
 
41
  # Set up Flask app
42
  app = Flask(__name__)
43
 
44
- # Add a basic route
45
- @app.route('/')
46
- def home():
47
- return "Flask app is running!"
48
-
49
  # Define image preprocessing transformations
50
  transform = transforms.Compose([
51
  transforms.Resize((224, 224)),
@@ -61,7 +60,7 @@ def predict():
61
  image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
62
 
63
  # Preprocess the image
64
- image_tensor = transform(image).unsqueeze(0)
65
 
66
  # Make prediction
67
  output = model(image_tensor)
 
31
  x = self.fc(x)
32
  return x
33
 
34
+ # Load the model and weights at app startup
35
  model = BacterialMorphologyClassifier()
36
  MODEL_PATH = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
37
  state_dict = torch.hub.load_state_dict_from_url(MODEL_PATH, map_location=torch.device('cpu'))
38
+ model.load_state_dict(state_dict, strict=False)
39
  model.eval()
40
 
41
+ # Move model to GPU if available
42
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
+ model.to(device)
44
+
45
  # Set up Flask app
46
  app = Flask(__name__)
47
 
 
 
 
 
 
48
  # Define image preprocessing transformations
49
  transform = transforms.Compose([
50
  transforms.Resize((224, 224)),
 
60
  image = Image.open(io.BytesIO(image_file.read())).convert('RGB')
61
 
62
  # Preprocess the image
63
+ image_tensor = transform(image).unsqueeze(0).to(device)
64
 
65
  # Make prediction
66
  output = model(image_tensor)