Spaces:
Runtime error
Runtime error
File size: 1,986 Bytes
52eea83 12cdeaf 52eea83 12cdeaf 52eea83 12cdeaf 52eea83 12cdeaf 52eea83 12cdeaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import functools

from flask import Flask, request, jsonify
from PIL import Image
import torch
from torchvision import transforms, models
# WSGI application object; the /predict route is registered on it below.
app = Flask(import_name=__name__)

# Pick the compute device once, at import time: a CUDA device when PyTorch can
# see one, otherwise the CPU. The model weights and input tensors are placed here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
@functools.lru_cache(maxsize=1)
def load_model():
    """Return the fine-tuned ResNet-152 flower classifier, ready for inference.

    Builds a torchvision ResNet-152 with no pretrained weights and replaces its
    final fully connected layer with a 26-way head (one output per entry of
    ``CLASS_LABELS``). It then loads the weights from ``trained_model.pth``,
    moves the network to ``device`` and switches it to eval mode.

    The result is memoised, so the checkpoint is read from disk only on the
    first successful call; later calls return the same model object. Without
    the cache, every call (``predict`` makes one per request) would rebuild the
    network and re-read the whole checkpoint.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g.
        ``FileNotFoundError`` for a missing checkpoint or ``RuntimeError`` for
        mismatched keys/shapes. Exceptions are not cached, so the next call
        retries.
    """
    model = models.resnet152()
    # Swap the default ImageNet classification head for one output per flower class.
    model.fc = torch.nn.Linear(model.fc.in_features, 26)
    # map_location lets a checkpoint that was saved on GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file, so only point it at trusted
    # checkpoints; newer PyTorch versions accept weights_only=True to restrict this.
    state_dict = torch.load("trained_model.pth", map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
# Inference-time image preprocessing. Presumably this mirrors the transform used
# when trained_model.pth was trained (TODO confirm against the training script);
# any mismatch silently degrades accuracy.
# Per-channel (R, G, B) normalisation statistics; these are the standard
# ImageNet mean/std values.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        # Resize to exactly 224x224; the aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor laid out as (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # (x - mean) / std, channel by channel.
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Class names indexed by model output position: logit i of the 26-way head
# built in load_model corresponds to CLASS_LABELS[i]. The names are in
# alphabetical order, presumably the sorted class-folder order used at training
# time (e.g. torchvision ImageFolder); TODO confirm against the training code.
CLASS_LABELS = """
    bluebell buttercup colts_foot corn_poppy cowslip
    crocus daffodil daisy dandelion foxglove
    fritillary geranium hibiscus iris lily_valley
    pansy petunia rose snowdrop sunflower
    tigerlily tulip wallflower water_lily wild_tulip
    windflower
""".split()
@app.route("/predict", methods=["POST"])
def predict():
    """Classify the image uploaded in the ``file`` multipart form field.

    Responses (all JSON):
        200 ``{"predicted_class": <label from CLASS_LABELS>}`` on success.
        400 ``{"error": ...}`` if no ``file`` part was sent, or if its bytes
            cannot be opened/decoded as an image.
        500 ``{"error": "Error during prediction: ..."}`` for any other
            failure, including the model failing to load.
    """
    # Reject malformed requests before paying for model loading.
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    file = request.files["file"]

    # Pillow reports both unrecognised formats (UnidentifiedImageError) and
    # truncated data as OSError. That is a bad upload, i.e. a client error.
    try:
        image = Image.open(file.stream).convert("RGB")
    except OSError as e:
        return jsonify({"error": f"Invalid image file: {e}"}), 400

    try:
        model = load_model()
        # Add a batch dimension of 1 and place the tensor on the model's device.
        input_tensor = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_tensor)
        # The highest-scoring output of the single batch item indexes CLASS_LABELS.
        predicted_label = CLASS_LABELS[outputs.argmax(dim=1).item()]
    except Exception as e:
        # Request boundary: keep the traceback in the server log and answer
        # with JSON instead of letting Flask render an HTML 500 page.
        app.logger.exception("Prediction failed")
        return jsonify({"error": f"Error during prediction: {str(e)}"}), 500
    return jsonify({"predicted_class": predicted_label})
# Script entry point: serve the app with the waitress WSGI server on all
# interfaces, port 8080. waitress is imported here, so merely importing this
# module does not require it.
if __name__ == "__main__":
    import waitress

    waitress.serve(app, host="0.0.0.0", port=8080)
|