Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -9,13 +9,14 @@ app = Flask(__name__)
|
|
9 |
# Load the trained model
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
|
12 |
-
|
13 |
-
model = models.resnet152()
|
14 |
-
model.fc = torch.nn.Linear(model.fc.in_features, 26)
|
15 |
-
model.load_state_dict(torch.load("trained_model.pth", map_location=device))
|
16 |
-
model = model.to(device)
|
17 |
-
model.eval()
|
18 |
-
|
|
|
19 |
# Define preprocessing for the input image
|
20 |
preprocess = transforms.Compose([
|
21 |
transforms.Resize((224, 224)),
|
@@ -23,7 +24,7 @@ preprocess = transforms.Compose([
|
|
23 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
24 |
])
|
25 |
|
26 |
-
# Class labels
|
27 |
CLASS_LABELS = [
|
28 |
"bluebell", "buttercup", "colts_foot", "corn_poppy", "cowslip",
|
29 |
"crocus", "daffodil", "daisy", "dandelion", "foxglove",
|
@@ -35,6 +36,7 @@ CLASS_LABELS = [
|
|
35 |
|
36 |
@app.route("/predict", methods=["POST"])
|
37 |
def predict():
|
|
|
38 |
if "file" not in request.files:
|
39 |
return jsonify({"error": "No file uploaded"}), 400
|
40 |
|
@@ -56,6 +58,7 @@ def predict():
|
|
56 |
except Exception as e:
|
57 |
return jsonify({"error": f"Error during prediction: {str(e)}"}), 500
|
58 |
|
59 |
-
# Run the app
|
60 |
if __name__ == "__main__":
|
61 |
-
|
|
|
|
9 |
# Load the trained model
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
|
12 |
+
def load_model():
|
13 |
+
model = models.resnet152()
|
14 |
+
model.fc = torch.nn.Linear(model.fc.in_features, 26)
|
15 |
+
model.load_state_dict(torch.load("trained_model.pth", map_location=device))
|
16 |
+
model = model.to(device)
|
17 |
+
model.eval()
|
18 |
+
return model
|
19 |
+
|
20 |
# Define preprocessing for the input image
|
21 |
preprocess = transforms.Compose([
|
22 |
transforms.Resize((224, 224)),
|
|
|
24 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
25 |
])
|
26 |
|
27 |
+
# Class labels
|
28 |
CLASS_LABELS = [
|
29 |
"bluebell", "buttercup", "colts_foot", "corn_poppy", "cowslip",
|
30 |
"crocus", "daffodil", "daisy", "dandelion", "foxglove",
|
|
|
36 |
|
37 |
@app.route("/predict", methods=["POST"])
|
38 |
def predict():
|
39 |
+
model = load_model()
|
40 |
if "file" not in request.files:
|
41 |
return jsonify({"error": "No file uploaded"}), 400
|
42 |
|
|
|
58 |
except Exception as e:
|
59 |
return jsonify({"error": f"Error during prediction: {str(e)}"}), 500
|
60 |
|
61 |
+
# Run the app
|
62 |
if __name__ == "__main__":
|
63 |
+
from waitress import serve
|
64 |
+
serve(app, host="0.0.0.0", port=8080)
|