megiddo commited on
Commit
12cdeaf
·
verified ·
1 Parent(s): 07d5288

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -9,13 +9,14 @@ app = Flask(__name__)
9
  # Load the trained model
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
- # Define the model architecture
13
- model = models.resnet152()
14
- model.fc = torch.nn.Linear(model.fc.in_features, 26) # Adjust for the number of classes
15
- model.load_state_dict(torch.load("trained_model.pth", map_location=device))
16
- model = model.to(device)
17
- model.eval()
18
-
 
19
  # Define preprocessing for the input image
20
  preprocess = transforms.Compose([
21
  transforms.Resize((224, 224)),
@@ -23,7 +24,7 @@ preprocess = transforms.Compose([
23
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
  ])
25
 
26
- # Class labels (replace with your dataset's classes)
27
  CLASS_LABELS = [
28
  "bluebell", "buttercup", "colts_foot", "corn_poppy", "cowslip",
29
  "crocus", "daffodil", "daisy", "dandelion", "foxglove",
@@ -35,6 +36,7 @@ CLASS_LABELS = [
35
 
36
  @app.route("/predict", methods=["POST"])
37
  def predict():
 
38
  if "file" not in request.files:
39
  return jsonify({"error": "No file uploaded"}), 400
40
 
@@ -56,6 +58,7 @@ def predict():
56
  except Exception as e:
57
  return jsonify({"error": f"Error during prediction: {str(e)}"}), 500
58
 
59
- # Run the app (Hugging Face Spaces requires `app.run()` here)
60
  if __name__ == "__main__":
61
- app.run(host="0.0.0.0", port=8080)
 
 
9
  # Load the trained model
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ def load_model():
13
+ model = models.resnet152()
14
+ model.fc = torch.nn.Linear(model.fc.in_features, 26)
15
+ model.load_state_dict(torch.load("trained_model.pth", map_location=device))
16
+ model = model.to(device)
17
+ model.eval()
18
+ return model
19
+
20
  # Define preprocessing for the input image
21
  preprocess = transforms.Compose([
22
  transforms.Resize((224, 224)),
 
24
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
25
  ])
26
 
27
+ # Class labels
28
  CLASS_LABELS = [
29
  "bluebell", "buttercup", "colts_foot", "corn_poppy", "cowslip",
30
  "crocus", "daffodil", "daisy", "dandelion", "foxglove",
 
36
 
37
  @app.route("/predict", methods=["POST"])
38
  def predict():
39
+ model = load_model()
40
  if "file" not in request.files:
41
  return jsonify({"error": "No file uploaded"}), 400
42
 
 
58
  except Exception as e:
59
  return jsonify({"error": f"Error during prediction: {str(e)}"}), 500
60
 
61
+ # Run the app
62
  if __name__ == "__main__":
63
+ from waitress import serve
64
+ serve(app, host="0.0.0.0", port=8080)