import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
from io import BytesIO

# Define the number of classes
num_classes = 2
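# Class index convention (see predict_from_image below): 0 -> fall armyworm, 1 -> healthy maize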

# Download model from Hugging Face
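# hf_hub_download caches the checkpoint locally (under ~/.cache/huggingface by default),
# so repeated startups reuse the same file instead of re-downloading it.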
def download_model():
    model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
    return model_path

# Load the model from Hugging Face
def load_model(model_path):
    model = models.resnet50(weights=None)  # bare architecture; trained weights are loaded below
    model.fc = nn.Linear(model.fc.in_features, num_classes)  # swap the ImageNet head for a 2-class head
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.eval()
    return model

# Download the model and load it
model_path = download_model()
model = load_model(model_path)

# Define the transformation for the input image
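# (resize -> 224x224 center crop -> tensor -> normalization with the ImageNet
#  mean/std that ResNet-50 was pretrained with)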
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Function to predict from image content
def predict_from_image(image):
    try:
        # Ensure the image is a PIL Image, then force 3-channel RGB so grayscale/RGBA
        # uploads don't break the normalization or the model's expected input shape
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")
        image = image.convert("RGB")

        # Apply transformations and add a batch dimension: (C, H, W) -> (1, C, H, W)
        image_tensor = transform(image).unsqueeze(0)

        # Predict
        with torch.no_grad():
            outputs = model(image_tensor)
            predicted_class = torch.argmax(outputs, dim=1).item()

        # Interpret the result
        if predicted_class == 0:
            return {"status": "success", "result": "Fall army worm detected (Problem ID: 126)."}
        elif predicted_class == 1:
            return {"status": "success", "result": "Healthy maize image detected."}
        else:
            return {"status": "error", "message": "Unexpected class prediction."}
    except Exception as e:
        return {"status": "error", "message": f"Error during prediction: {str(e)}"}

# Function to predict from URL
def predict_from_url(url):
    try:
        if not url.startswith(("http://", "https://")):
            raise ValueError("Invalid URL format. Please provide a valid image URL.")
        
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # Ensure the request was successful
        image = Image.open(BytesIO(response.content))
        return predict_from_image(image)
    except Exception as e:
        return {"status": "error", "message": f"Failed to process the URL: {str(e)}"}

# Combined prediction function for Gradio
def combined_predict(image, url):
    if image is not None and url:
        return {"status": "error", "message": "Provide either an image or a URL, not both."}
    elif image is not None:
        return predict_from_image(image)
    elif url:
        return predict_from_url(url)
    else:
        return {"status": "error", "message": "No input provided. Please upload an image or provide a URL."}

# Gradio interface
iface = gr.Interface(
    fn=combined_predict,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL"),
    ],
    outputs=gr.JSON(label="Prediction Result"),
    live=True,
    title="Maize Anomaly Detection",
    description="Upload an image or provide a URL to detect anomalies in maize crops.",
    examples=[
        [None, "https://example.com/sample-image.jpg"],  # Replace with a valid example URL
    ]
)
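
# Note: live=True re-runs combined_predict automatically whenever either input changes,
# so URL predictions can fire while an address is still being typed.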

# Launch the interface
iface.launch(share=True, show_error=True)
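
# Quick sanity check without the UI (a minimal sketch; "sample_maize.jpg" and the URL
# are placeholders, not files shipped with this Space):
#
#   from PIL import Image
#   print(predict_from_image(Image.open("sample_maize.jpg")))
#   print(predict_from_url("https://example.com/sample-image.jpg"))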