File size: 2,784 Bytes
1f6dbae
2255b93
1f6dbae
 
 
9dfc63c
1f6dbae
66345ab
42edc6c
1f6dbae
 
42edc6c
2251f70
1f6dbae
 
 
2251f70
 
1f6dbae
 
 
 
 
 
42edc6c
 
1f6dbae
 
 
fc29cbf
95250f9
 
5cadf06
fc29cbf
95250f9
5cadf06
95250f9
 
 
5cadf06
95250f9
 
66345ab
95250f9
66345ab
2255b93
66345ab
2255b93
5cadf06
66345ab
9dfc63c
fb31436
a62d15d
5cadf06
2251f70
42edc6c
2255b93
bf44ad8
5cadf06
a62d15d
 
42edc6c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Seconds to wait for a remote image before giving up; without a timeout
# requests.get() can block the request handler indefinitely.
_URL_TIMEOUT_SECONDS = 10

# Maximum number of characters of the raw input echoed to the log; base64
# payloads can be very long.
_MAX_LOGGED_CHARS = 100

# User-facing message for each class index the model can produce.
_CLASS_MESSAGES = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy maize image.",
}


def _load_image(image_input):
    """Build a PIL image from the ``image`` field of a request.

    Accepts an already-loaded ``PIL.Image.Image``, an ``http://`` or
    ``https://`` URL, or base64-encoded image bytes (optionally given as a
    ``data:<mime>;base64,<payload>`` URI).

    Raises:
        ValueError: with a user-facing message when the input has an
            unsupported type or the image cannot be fetched or decoded.
    """
    if isinstance(image_input, Image.Image):
        print(f"Image is already loaded as PIL Image: {image_input}")
        return image_input

    if not isinstance(image_input, str):
        raise ValueError(
            f"Unsupported image input type: {type(image_input).__name__}"
        )

    # Only the scheme matters; slice first to avoid lower-casing a huge payload.
    # ':' is not in the base64 alphabet, so this cannot misfire on base64 data.
    if image_input[:8].lower().startswith(("http://", "https://")):
        try:
            response = requests.get(image_input, timeout=_URL_TIMEOUT_SECONDS)
            # Surface 4xx/5xx as errors instead of handing an HTML error page to PIL.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        except Exception as e:
            print(f"Error fetching image from URL: {e}")
            raise ValueError(f"Error fetching image from URL: {e}") from e
        print(f"Fetched image from URL: {image}")
        return image

    # Otherwise treat the string as base64; drop a data-URI header if present.
    payload = image_input
    if payload.startswith("data:"):
        _header, sep, rest = payload.partition(",")
        if sep:
            payload = rest
    try:
        image = Image.open(BytesIO(base64.b64decode(payload)))
    except Exception as e:
        print(f"Error decoding base64 image: {e}")
        raise ValueError(f"Error decoding base64 image: {e}") from e
    print(f"Decoded base64 image: {image}")
    return image


def predict(data):
    """Classify a maize photo and return the outcome as a JSON string.

    Args:
        data: Request payload, expected to be a dict whose ``"image"`` key
            holds a PIL image, an http(s) URL, or base64-encoded image data.

    Returns:
        A JSON string: ``{"result": <message>}`` for a recognised class, or
        ``{"error": <message>}`` when the input is missing or invalid, the
        image cannot be loaded, or inference fails. Errors are reported in
        the return value rather than raised, so the UI always gets a reply.
    """
    try:
        if not isinstance(data, dict):
            return json.dumps(
                {"error": "Input must be a JSON object with an 'image' field."}
            )

        image_input = data.get("image")
        if not image_input:
            return json.dumps({"error": "No image provided."})

        shown = str(image_input)
        if len(shown) > _MAX_LOGGED_CHARS:
            shown = shown[:_MAX_LOGGED_CHARS] + "..."
        print(f"Received image input: {shown}")

        try:
            image = _load_image(image_input)
        except ValueError as e:
            return json.dumps({"error": str(e)})

        # Normalise to 3-channel RGB so RGBA/greyscale/palette images do not
        # yield a tensor with an unexpected channel count.
        # NOTE(review): assumes `transform`/`model` expect RGB input — confirm.
        image = image.convert("RGB")

        # Apply transformations and add a leading batch dimension.
        tensor = transform(image).unsqueeze(0)
        print(f"Transformed image tensor: {tensor.shape}")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tensor = tensor.to(device)

        with torch.no_grad():
            outputs = model(tensor)
            predicted_class = torch.argmax(outputs, dim=1).item()
            print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")

        message = _CLASS_MESSAGES.get(predicted_class)
        if message is None:
            return json.dumps({"error": "Unexpected class prediction."})
        return json.dumps({"result": message})
    except Exception as e:
        # Top-level handler boundary: report any failure back to the caller.
        print(f"Error processing image: {e}")
        return json.dumps({"error": f"Error processing image: {e}"})

# Create the Gradio interface: a JSON input box whose value is passed to
# predict(), and a text box that shows the JSON string predict() returns.
iface = gr.Interface(
    fn=predict,
    inputs=gr.JSON(label="Input JSON"),
    outputs=gr.Textbox(label="Prediction Result"),
    # Re-run predict() automatically whenever the input changes.
    live=True,
    title="Maize Anomaly Detection",
    # Describe only the inputs predict() actually handles (URLs and base64);
    # local file paths are not supported.
    description=(
        'Provide a JSON object such as {"image": "<value>"}, where the value is '
        "an http(s) URL or base64-encoded image data, to classify a maize photo "
        "as healthy or as showing fall armyworm infestation."
    ),
)

# Launch the Gradio interface. share=True requests a public share link;
# show_error=True surfaces handler errors in the browser.
iface.launch(share=True, show_error=True)