from flask import Flask, request, jsonify
from PIL import Image
import torch
from torchvision import transforms, models
# Initialize Flask app
app = Flask(__name__)
# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    model = models.resnet152()
    model.fc = torch.nn.Linear(model.fc.in_features, 26)
    model.load_state_dict(torch.load("trained_model.pth", map_location=device))
    model = model.to(device)
    model.eval()
    return model
# Define preprocessing for the input image
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Class labels
CLASS_LABELS = [
"bluebell", "buttercup", "colts_foot", "corn_poppy", "cowslip",
"crocus", "daffodil", "daisy", "dandelion", "foxglove",
"fritillary", "geranium", "hibiscus", "iris", "lily_valley",
"pansy", "petunia", "rose", "snowdrop", "sunflower",
"tigerlily", "tulip", "wallflower", "water_lily", "wild_tulip",
"windflower"
]
@app.route("/predict", methods=["POST"])
def predict():
model = load_model()
if "file" not in request.files:
return jsonify({"error": "No file uploaded"}), 400
file = request.files["file"]
try:
# Load and preprocess the image
image = Image.open(file.stream).convert("RGB")
input_tensor = preprocess(image).unsqueeze(0).to(device)
# Perform inference
with torch.no_grad():
outputs = model(input_tensor)
_, predicted_class = torch.max(outputs, 1)
predicted_label = CLASS_LABELS[predicted_class.item()]
return jsonify({"predicted_class": predicted_label})
except Exception as e:
return jsonify({"error": f"Error during prediction: {str(e)}"}), 500
# Run the app
if __name__ == "__main__":
    from waitress import serve
    serve(app, host="0.0.0.0", port=8080)
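
# Usage sketch (hypothetical client, kept as a comment so it does not execute with the app):
# with the server above running locally on port 8080, an image can be sent to /predict
# via the "file" form field, e.g. from a separate script using the requests library.
# The filename "flower.jpg" is a placeholder.
#
#     import requests
#
#     with open("flower.jpg", "rb") as f:
#         response = requests.post("http://localhost:8080/predict", files={"file": f})
#     print(response.json())  # e.g. {"predicted_class": "daisy"}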