File size: 1,986 Bytes
52eea83
 
 
 
 
 
 
 
 
 
 
12cdeaf
 
 
 
 
 
 
 
52eea83
 
 
 
 
 
 
12cdeaf
52eea83
 
 
 
 
 
 
 
 
 
 
12cdeaf
52eea83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12cdeaf
52eea83
12cdeaf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import functools

from flask import Flask, request, jsonify
from PIL import Image
import torch
from torchvision import transforms, models

# WSGI application object; the route decorators below register against it.
app = Flask(__name__)

# Inference device: CUDA when a GPU is visible to torch, otherwise the CPU.
# The model and every input tensor are moved here before use.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

@functools.lru_cache(maxsize=1)
def load_model():
    """Build the ResNet-152 classifier and load its trained weights.

    The network's final fully connected layer is replaced with one output per
    entry in ``CLASS_LABELS`` before the state dict in ``trained_model.pth``
    is loaded, so the checkpoint must come from an identically shaped model.
    The model is moved to ``device`` and switched to eval mode.

    The result is cached. The first call reads the checkpoint from disk, and
    every later call returns that same model instance. Callers that run once
    per request therefore do not rebuild the network or re-read the weights
    each time. If loading raises, nothing is cached and the next call
    retries.

    Returns:
        torch.nn.Module: the loaded model, in eval mode, on ``device``.
    """
    model = models.resnet152()
    # Size the classification head from the label list so the number of
    # outputs and the labels used to decode them cannot disagree.
    model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_LABELS))
    # NOTE(review): torch.load unpickles the file. Only point it at a trusted
    # checkpoint (consider weights_only=True if the torch version supports it).
    state_dict = torch.load("trained_model.pth", map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
    
# Transform pipeline applied to every uploaded image before inference:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each RGB channel with the commonly used ImageNet mean/std.
# NOTE(review): presumably mirrors the evaluation transforms used when the
# checkpoint was trained; confirm against the training script.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Label for each model output. predict() decodes a prediction by indexing
# this list with the winning output index, so position i must be the class
# the model learned as index i. The list is alphabetical, presumably matching
# the sorted class order used at training time. Confirm before reordering.
CLASS_LABELS = [
    "bluebell",
    "buttercup",
    "colts_foot",
    "corn_poppy",
    "cowslip",
    "crocus",
    "daffodil",
    "daisy",
    "dandelion",
    "foxglove",
    "fritillary",
    "geranium",
    "hibiscus",
    "iris",
    "lily_valley",
    "pansy",
    "petunia",
    "rose",
    "snowdrop",
    "sunflower",
    "tigerlily",
    "tulip",
    "wallflower",
    "water_lily",
    "wild_tulip",
    "windflower",
]

@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the predicted flower label.

    Expects a multipart form upload with the image in the ``file`` field.

    Responses:
        200 ``{"predicted_class": <label>}`` on success.
        400 ``{"error": ...}`` when no file was sent, or the upload cannot be
            decoded as an image.
        500 ``{"error": ...}`` when loading the model or running inference
            fails.
    """
    # Validate the request before touching the model so malformed requests
    # are rejected cheaply.
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]

    # Decoding failures are client errors, e.g. not an image, a truncated
    # upload, or an empty file field. PIL reports these as OSError (its
    # UnidentifiedImageError subclasses OSError).
    try:
        image = Image.open(file.stream).convert("RGB")
    except OSError as e:
        return jsonify({"error": f"Invalid image file: {e}"}), 400

    try:
        # Loaded inside the try so a missing or corrupt checkpoint produces a
        # JSON 500 rather than Flask's default HTML error page.
        model = load_model()

        # Add a batch dimension of 1 and move the input to the model's device.
        input_tensor = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(input_tensor)
            _, predicted_class = torch.max(outputs, 1)

        predicted_label = CLASS_LABELS[predicted_class.item()]
        return jsonify({"predicted_class": predicted_label})

    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # instead of silently discarding it, then answer with JSON.
        app.logger.exception("Prediction failed")
        return jsonify({"error": f"Error during prediction: {e}"}), 500

# Script entry point: serve the app with waitress on all interfaces, port 8080.
if __name__ == "__main__":
    # Imported here rather than at the top, so importing this module (e.g.
    # from another WSGI server) does not require waitress to be installed.
    from waitress import serve as waitress_serve

    waitress_serve(app, port=8080, host="0.0.0.0")