"""Flask service exposing a single ``/predict`` endpoint for image classification.

An uploaded image is preprocessed, run through a ResNet-152 whose final layer
has 26 outputs, and the label at the highest-scoring index is returned as JSON.
"""

import functools
import logging

import torch
from flask import Flask, request, jsonify
from PIL import Image
from torchvision import transforms, models

logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model() -> torch.nn.Module:
    """Build a ResNet-152 with a 26-way head and load weights from disk.

    Every call constructs a new network and re-reads ``trained_model.pth``;
    request handlers should go through ``_get_model()`` instead.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = models.resnet152()
    model.fc = torch.nn.Linear(model.fc.in_features, 26)
    # NOTE(review): torch.load unpickles the checkpoint; only point this at a
    # trusted file (consider weights_only=True if the installed torch has it).
    model.load_state_dict(torch.load("trained_model.pth", map_location=device))
    model = model.to(device)
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _get_model() -> torch.nn.Module:
    """Return the shared model instance, loading it on first use only."""
    return load_model()


# Define preprocessing for the input image: resize to 224x224, convert to a
# tensor, then normalize each channel with the mean/std below (the values
# commonly used for ImageNet; presumably the same as at training time).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class labels, indexed by the position of the model's output score.
CLASS_LABELS = [
    "bluebell", "buttercup", "colts_foot", "corn_poppy", "cowslip", "crocus",
    "daffodil", "daisy", "dandelion", "foxglove", "fritillary", "geranium",
    "hibiscus", "iris", "lily_valley", "pansy", "petunia", "rose", "snowdrop",
    "sunflower", "tigerlily", "tulip", "wallflower", "water_lily", "wild_tulip",
    "windflower",
]


@app.route("/predict", methods=["POST"])
def predict():
    """Classify the image uploaded under the ``file`` form field.

    Returns:
        ``{"predicted_class": <label>}`` on success;
        ``{"error": "No file uploaded"}`` with status 400 when the ``file``
        field is absent; ``{"error": "Error during prediction: ..."}`` with
        status 500 when loading the model, decoding the image, or running
        inference raises.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]

    try:
        # Reuse the cached model rather than rebuilding it and re-reading the
        # weights file on every request.
        model = _get_model()

        # Load and preprocess the image; unsqueeze adds the batch dimension.
        image = Image.open(file.stream).convert("RGB")
        input_tensor = preprocess(image).unsqueeze(0).to(device)

        # Perform inference without tracking gradients.
        with torch.no_grad():
            outputs = model(input_tensor)
            _, predicted_class = torch.max(outputs, 1)

        predicted_label = CLASS_LABELS[predicted_class.item()]
        return jsonify({"predicted_class": predicted_label})
    except Exception as e:
        # Top-level request boundary: record the traceback server-side and
        # report the failure to the client as JSON.
        logger.exception("Prediction failed")
        return jsonify({"error": f"Error during prediction: {str(e)}"}), 500


# Run the app
if __name__ == "__main__":
    from waitress import serve

    serve(app, host="0.0.0.0", port=8080)