File size: 3,682 Bytes
2201868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4869d07
2255b93
5479ea5
52fd9c2
 
fc29cbf
52fd9c2
95250f9
52fd9c2
 
95250f9
5cadf06
52fd9c2
95250f9
 
 
 
b4d05af
95250f9
5649d80
95250f9
5649d80
2255b93
5649d80
52fd9c2
2255b93
5649d80
9dfc63c
4869d07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c47ae4
fb8a03b
4869d07
0c47ae4
 
fb8a03b
 
 
 
 
 
4869d07
fb8a03b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
from io import BytesIO

# Number of output classes of the classifier head built in load_model.
# predict_from_image reports index 0 as fall army worm and index 1 as healthy maize.
num_classes = 2

# Fetch the trained checkpoint from the Hugging Face Hub
def download_model():
    """Download the checkpoint file and return its local path.

    Returns:
        The path returned by ``hf_hub_download``, or ``None`` when the
        download raised (the error is printed, not propagated).
    """
    try:
        return hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
    except Exception as err:
        print(f"Error downloading model: {err}")
    return None

# Build the network and load the downloaded weights into it
def load_model(model_path):
    """Create a ResNet-50 with a ``num_classes``-way head and load its weights.

    Args:
        model_path: Path to the state-dict file on local disk.

    Returns:
        The model in eval mode with weights mapped to the CPU, or ``None``
        when construction or loading raised (the error is printed, not
        propagated).
    """
    try:
        # No `pretrained=` argument: it is deprecated in recent torchvision
        # releases, and the default already builds a network without
        # pretrained weights on both old and new versions.
        net = models.resnet50()
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        # NOTE(security): torch.load unpickles the file, and the checkpoint is
        # downloaded from a remote repository. Consider `weights_only=True`
        # once the minimum supported torch version provides it.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        net.load_state_dict(state_dict)
        net.eval()
        return net
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Fetch and load the checkpoint once at import time; `model` stays None when
# either the download or the load step fails.
model_path = download_model()
if model_path:
    model = load_model(model_path)
else:
    model = None

# Preprocessing pipeline applied to every input image before inference
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def predict_from_image(image):
    """Classify a maize image and describe the result as a JSON string.

    Args:
        image: A PIL image (what ``gr.Image(type="pil")`` delivers).

    Returns:
        A JSON string with a ``"result"`` key on success, or an ``"error"``
        key when the input is not a PIL image, the model is unavailable, or
        inference raised.
    """
    try:
        # Check if the input is a PIL Image type (Gradio automatically provides a PIL image)
        if not isinstance(image, Image.Image):
            return json.dumps({"error": "Invalid image format received. Please provide a valid image."})

        # The module-level model is None when download or loading failed;
        # report that directly instead of a "'NoneType' is not callable" error.
        if model is None:
            return json.dumps({"error": "Model is not available; it could not be downloaded or loaded."})

        # Normalize is configured for three channels, so RGBA, palette and
        # grayscale inputs must be converted to RGB first.
        batch = transform(image.convert("RGB")).unsqueeze(0)

        # Run on whatever device the model's weights live on. Picking CUDA
        # whenever it is available breaks inference when the model is on CPU.
        device = next(model.parameters()).device
        batch = batch.to(device)

        # Make predictions
        with torch.no_grad():
            outputs = model(batch)
            predicted_class = torch.argmax(outputs, dim=1).item()

        # Return the result based on the predicted class
        if predicted_class == 0:
            return json.dumps({"result": "The photo you've sent is of fall army worm with problem ID 126."})
        elif predicted_class == 1:
            return json.dumps({"result": "The photo you've sent is of a healthy maize image."})
        else:
            return json.dumps({"error": "Unexpected class prediction."})

    except Exception as e:
        return json.dumps({"error": f"Error processing image: {e}"})


def predict_from_url(url, timeout=10):
    """Download an image from ``url`` and classify it.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The JSON string produced by ``predict_from_image`` on success, or a
        JSON string with an ``"error"`` key when the URL is empty, the
        response status is not 200, or fetching/decoding raised.
    """
    # Reject empty or non-string input up front rather than issuing a request
    # that can only fail.
    if not isinstance(url, str) or not url.strip():
        return json.dumps({"error": "No image URL provided."})

    try:
        # A timeout keeps an unresponsive server from blocking the app forever.
        response = requests.get(url.strip(), timeout=timeout)
        if response.status_code != 200:
            return json.dumps({"error": "Unable to fetch image from the URL."})

        img = Image.open(BytesIO(response.content))
        # Call the predict function for the image
        return predict_from_image(img)
    except Exception as e:
        return json.dumps({"error": f"Error fetching image from URL: {e}"})


# The interface has two inputs, so its callback receives two arguments.
# Route them to the single-argument prediction functions.
def _predict(image, url):
    """Dispatch the interface inputs: an uploaded image wins over a URL.

    Args:
        image: PIL image from the upload component, or ``None`` if empty.
        url: Text from the URL box, possibly empty or ``None``.

    Returns:
        The JSON string from ``predict_from_image`` or ``predict_from_url``,
        or a JSON error string when neither input was provided.
    """
    if image is not None:
        return predict_from_image(image)
    if isinstance(url, str) and url.strip():
        return predict_from_url(url.strip())
    return json.dumps({"error": "Please upload an image or enter an image URL."})


# Create the Gradio interface with both local file upload and URL input
iface = gr.Interface(
    fn=_predict,
    inputs=[gr.Image(type="pil", label="Upload an image or provide a local path"), 
            gr.Textbox(label="Or enter image URL (if available)", placeholder="Enter a URL for the image")],
    outputs=gr.Textbox(label="Prediction Result"),
    live=True,
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images."
)

# Launch the Gradio interface
iface.launch(share=True, show_error=True)