# Spaces: Runtime error
# (Status text captured from the hosting page; it is not part of the program.)
import functools

import torch
from flask import Flask, request, jsonify
from PIL import Image
from torchvision import transforms, models
# Flask application object; the entry point at the bottom of the file hands
# it to the WSGI server.
app = Flask(__name__)
# Device used for the model and for every input tensor: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@functools.lru_cache(maxsize=1)
def load_model():
    """Return the ResNet-152 classifier with the trained weights loaded.

    The network is built with a 26-output final layer, populated from
    ``trained_model.pth``, moved to ``device`` and switched to eval mode.

    The result is memoized: building the network and reading the checkpoint
    from disk is expensive, so it happens once per process and every later
    call returns the same model instance.

    Returns:
        The ``torch.nn.Module`` ready for inference on ``device``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the
        checkpoint is missing or does not match the architecture.
    """
    model = models.resnet152()
    # Replace the classification head so it emits one score per class.
    model.fc = torch.nn.Linear(model.fc.in_features, 26)
    # NOTE(review): torch.load unpickles the file; only ship checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load("trained_model.pth", map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
# Preprocessing applied to every uploaded image before inference:
#   1. resize to 224x224,
#   2. convert the PIL image to a tensor,
#   3. normalize each channel with the mean/std below (these match the
#      ImageNet statistics conventionally used with torchvision ResNets).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Human-readable class names, indexed by the model's output position: the
# predicted class index is used directly as an index into this list, so the
# order (and the count of 26) must match the classifier head.
CLASS_LABELS = [
    "bluebell", "buttercup", "colts_foot", "corn_poppy", "cowslip",
    "crocus", "daffodil", "daisy", "dandelion", "foxglove",
    "fritillary", "geranium", "hibiscus", "iris", "lily_valley",
    "pansy", "petunia", "rose", "snowdrop", "sunflower",
    "tigerlily", "tulip", "wallflower", "water_lily", "wild_tulip",
    "windflower",
]
# NOTE(review): the original never registered this view with the app, so the
# service exposed no endpoint. The URL and method below are a choice made
# here; adjust them if existing clients expect something else.
@app.route("/predict", methods=["POST"])
def predict():
    """Classify the image uploaded in the ``file`` form field.

    Returns:
        On success, a JSON body ``{"predicted_class": <label>}`` where the
        label is the ``CLASS_LABELS`` entry at the highest-scoring output.
        If the request has no ``file`` part, a JSON error with status 400.
        If decoding, preprocessing or inference raises, a JSON error
        carrying the exception text with status 500.
    """
    # Validate the request first so a malformed upload is rejected without
    # paying for model loading.
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    upload = request.files["file"]

    model = load_model()

    try:
        # Decode the upload and force three channels; the preprocessing
        # normalizes exactly three channels.
        image = Image.open(upload.stream).convert("RGB")
        # unsqueeze(0) adds the batch dimension the model expects.
        input_tensor = preprocess(image).unsqueeze(0).to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(input_tensor)
            _, predicted_class = torch.max(outputs, 1)

        predicted_label = CLASS_LABELS[predicted_class.item()]
        return jsonify({"predicted_class": predicted_label})
    except Exception as e:
        # Request boundary: report the failure as JSON instead of letting
        # the exception propagate to the WSGI server.
        return jsonify({"error": f"Error during prediction: {str(e)}"}), 500
# Script entry point: serve the Flask app with waitress on all interfaces,
# port 8080. waitress is imported here so it is only required when the file
# is run directly.
if __name__ == "__main__":
    from waitress import serve

    serve(app, host="0.0.0.0", port=8080)