File size: 3,341 Bytes
cbc5566
9dfc63c
 
 
 
bf44ad8
 
 
871b5a8
cbc5566
38d7439
2255b93
cbc5566
95250f9
 
2255b93
 
 
95250f9
2255b93
 
 
 
 
9dfc63c
95250f9
9dfc63c
38d7439
9dfc63c
2255b93
 
 
 
9dfc63c
 
95250f9
170be68
610d493
ee2271a
 
 
 
170be68
 
ee2271a
 
 
 
 
 
170be68
 
 
ee2271a
 
 
 
fc29cbf
ee2271a
95250f9
 
fc29cbf
95250f9
 
fc29cbf
95250f9
 
 
 
 
 
 
 
 
 
2255b93
95250f9
2255b93
 
9dfc63c
fb31436
a62d15d
fc29cbf
170be68
 
871b5a8
170be68
fc29cbf
2255b93
bf44ad8
170be68
a62d15d
 
ce20917
40efeb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
import requests
import base64
from io import BytesIO
import os

# Define the number of classes
num_classes = 2  # Update with the actual number of classes in your dataset

# Location of the fine-tuned weights (a state_dict written with torch.save).
MODEL_PATH = "path_to_your_model.pth"


# Load the model (assuming you've already downloaded it)
def load_model(model_path=MODEL_PATH, n_classes=None):
    """Build a ResNet-50 with a new classification head and load saved weights.

    Args:
        model_path: Path to a saved ``state_dict`` for this architecture.
        n_classes: Output size of the final linear layer. Defaults to the
            module-level ``num_classes`` (resolved at call time).

    Returns:
        The model in eval mode, placed on CUDA when available and on the CPU
        otherwise. Returns ``None`` if building or loading fails; the error is
        printed rather than raised (presumably so the UI can still start).
    """
    if n_classes is None:
        n_classes = num_classes
    try:
        # weights=None: no ImageNet download; the fine-tuned weights below
        # overwrite every parameter anyway. (Replaces deprecated pretrained=False.)
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, n_classes)
        # weights_only=True limits unpickling to tensors and plain containers,
        # which is all a state_dict needs, so a tampered checkpoint file cannot
        # execute arbitrary code on load.
        state_dict = torch.load(
            model_path, map_location=torch.device("cpu"), weights_only=True
        )
        model.load_state_dict(state_dict)
        # Use the same device choice the prediction code uses for its inputs,
        # so weights and inputs never end up on different devices.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


model = load_model()

# Preprocessing applied to every input image before inference: match the
# smaller edge to 256 px, centre-crop to 224x224, convert to a float tensor,
# then normalise each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Prediction function

# User-facing result text for each class index the classifier can emit.
_CLASS_MESSAGES = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy maize image.",
}

# Seconds to wait for a remote image before giving up, so an unresponsive
# host cannot block the request indefinitely.
URL_TIMEOUT_SECONDS = 10


def _open_as_rgb(source):
    """Open *source* (path or file-like) with PIL and return a loaded RGB copy.

    RGB conversion guarantees three channels, as the 3-channel Normalize step
    requires. The ``with`` block closes the underlying file once the pixel
    data has been copied out by ``convert``.
    """
    with Image.open(source) as img:
        return img.convert("RGB")


def _load_image(image, image_url):
    """Resolve the Gradio inputs to an RGB PIL image.

    A non-empty ``image_url`` takes precedence over ``image``.

    Returns:
        ``(pil_image, None)`` on success, or ``(None, error_message)``.
    """
    if image_url:
        try:
            response = requests.get(image_url, timeout=URL_TIMEOUT_SECONDS)
            response.raise_for_status()  # Raise an error if the request fails
            return _open_as_rgb(BytesIO(response.content)), None
        except Exception as e:
            return None, f"Error fetching image from URL: {e}"

    if isinstance(image, Image.Image):
        return image.convert("RGB"), None

    # gr.File may hand over a plain path string or (in older Gradio versions)
    # a tempfile-like object exposing the path as ``.name``; accept either.
    path = getattr(image, "name", image)
    if isinstance(path, str) and os.path.isfile(path):
        try:
            return _open_as_rgb(path), None
        except Exception as e:
            return None, f"Error loading image from local path: {e}"

    return None, "Invalid image format received."


def process_image(image, image_url=None):
    """Classify a maize image supplied as an upload or a URL.

    Args:
        image: An uploaded file (path string, object with a ``.name`` path, or
            a PIL image), or ``None``.
        image_url: Optional URL of an image; used in preference to ``image``
            when non-blank.

    Returns:
        A human-readable string: the prediction, or a description of what
        went wrong. Errors are reported as text rather than raised, since the
        result is shown directly in the UI.
    """
    try:
        # Ignore stray whitespace so a blank-looking textbox counts as empty.
        url = image_url.strip() if isinstance(image_url, str) else image_url
        if image is None and not url:
            return "No image or URL provided."

        pil_image, error = _load_image(image, url)
        if error is not None:
            return error

        if model is None:
            return "Model could not be loaded; see the server log for details."

        # Send the input to whatever device holds the model's weights, so the
        # two can never be on different devices.
        device = next(model.parameters()).device
        batch = transform(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(batch)
        predicted_class = torch.argmax(outputs, dim=1).item()

        return _CLASS_MESSAGES.get(predicted_class, "Unexpected class prediction.")
    except Exception as e:
        return f"Error processing image: {e}"

# Create the Gradio interface.
# live=False (Submit button): with live=True every edit of the URL textbox
# re-ran process_image, issuing an HTTP request for each partially typed URL.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        # Input: local file; the picker is restricted to image files.
        gr.File(label="Upload an image (Local File Path)", file_types=["image"]),
        # Input: image URL (takes precedence over the uploaded file).
        gr.Textbox(label="Enter Image URL", placeholder="Enter image URL here", lines=1),
    ],
    outputs=gr.Textbox(label="Prediction Result"),  # Output: prediction result
    live=False,
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation. You can upload local images or provide an image URL."
)

# Launch the Gradio interface.
# NOTE: share=True also publishes a temporary public URL for this app;
# show_error=True surfaces errors in the browser UI.
iface.launch(share=True, show_error=True)