File size: 2,756 Bytes
2201868
 
 
 
 
342396f
2201868
342396f
163e73a
342396f
2201868
 
 
 
0f694f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342396f
2201868
 
 
 
 
 
 
342396f
 
 
 
 
163e73a
342396f
 
163e73a
342396f
163e73a
 
 
 
342396f
 
 
 
 
 
 
52fd9c2
991ba20
 
 
 
 
 
 
 
 
 
342396f
163e73a
991ba20
342396f
dd36796
991ba20
342396f
 
dd36796
163e73a
dd36796
163e73a
 
342396f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import os
from io import BytesIO

# Size of the classifier head: index 0 is fall army worm, index 1 is healthy
# maize (matching the label mapping in predict_from_image).
num_classes: int = 2

def download_model():
    """Fetch the fine-tuned ResNet-50 checkpoint from the Hugging Face Hub.

    Returns:
        Local filesystem path of the downloaded ``pytorch_model.bin`` file
        (``hf_hub_download`` serves it from the local Hub cache when present).
    """
    return hf_hub_download(
        repo_id="jays009/Restnet50",
        filename="pytorch_model.bin",
    )

# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 classifier and load trained weights into it.

    Args:
        model_path: Path to a checkpoint containing a ``state_dict`` for a
            ResNet-50 whose final layer has ``num_classes`` outputs.

    Returns:
        The model, with weights mapped to CPU and switched to eval mode.
    """
    # No pretrained weights are needed: load_state_dict (strict by default)
    # overwrites every parameter and buffer from the checkpoint below.
    model = models.resnet50(pretrained=False)
    # Replace the 1000-way ImageNet head with one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint fetched from the Hub cannot run arbitrary code.
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    # Disable dropout / use running batch-norm statistics for inference.
    model.eval()
    return model

# Fetch the checkpoint and build the classifier once, at import time, so every
# prediction reuses the same in-memory model.
model_path = download_model()
model = load_model(model_path)

# Preprocessing for every input: scale the shorter side to 256 px, take the
# central 224x224 crop, convert to a float tensor, then normalize each channel
# with the standard ImageNet mean/std (presumably what the checkpoint was
# trained with — confirm against the training pipeline).
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Function to predict from image content
def predict_from_image(image):
    """Classify a maize photo as fall-army-worm damage or healthy.

    Args:
        image: A ``PIL.Image.Image`` in any colour mode.

    Returns:
        ``{"result": <message>}`` for a known class index, or
        ``{"error": "Unexpected class prediction."}`` otherwise.

    Raises:
        ValueError: If ``image`` is not a PIL image.
    """
    # Ensure the image is a PIL Image
    if not isinstance(image, Image.Image):
        raise ValueError("Invalid image format received. Please provide a valid image.")

    # The Normalize step expects exactly three channels; palette ("P"),
    # grayscale ("L"), RGBA (e.g. PNG with transparency) and CMYK inputs would
    # otherwise fail inside the transform pipeline.
    image = image.convert("RGB")

    # Apply transformations and add a batch dimension of size 1.
    image_tensor = transform(image).unsqueeze(0)

    # Predict without building an autograd graph.
    with torch.no_grad():
        outputs = model(image_tensor)
        predicted_class = torch.argmax(outputs, dim=1).item()

    # Interpret the result
    if predicted_class == 0:
        return {"result": "The photo is of fall army worm with problem ID 126."}
    elif predicted_class == 1:
        return {"result": "The photo is of a healthy maize image."}
    else:
        return {"error": "Unexpected class prediction."}

# Function to predict from URL
def predict_from_url(url, timeout=10):
    """Download an image from ``url`` and classify it.

    Args:
        url: HTTP(S) URL of the image. Empty or whitespace-only values (what
            Gradio sends when the textbox is blank) are rejected up front.
        timeout: Seconds to wait for the server before giving up; without a
            timeout ``requests`` can block the worker indefinitely.

    Returns:
        The dict produced by ``predict_from_image`` on success, otherwise
        ``{"error": <message>}``. This function never raises.
    """
    # Don't hand an empty string to requests; it only yields a cryptic
    # "Invalid URL" error.
    if not url or not url.strip():
        return {"error": "No image uploaded and no image URL provided."}
    try:
        # NOTE(review): the URL is user-supplied and fetched server-side (SSRF
        # risk); consider restricting schemes/hosts if this app is public.
        response = requests.get(url.strip(), timeout=timeout)
        response.raise_for_status()  # Ensure the request was successful
        image = Image.open(BytesIO(response.content))
        return predict_from_image(image)
    except Exception as e:
        # UI boundary: report every failure as a JSON error instead of crashing.
        return {"error": f"Failed to process the URL: {str(e)}"}

def _route_prediction(image, url):
    """Classify the uploaded image if there is one, otherwise the image at ``url``."""
    if not image:
        return predict_from_url(url)
    return predict_from_image(image)


# Gradio UI: an upload widget and a URL textbox feed a single JSON result panel.
# NOTE(review): live=True reruns the function whenever an input changes, which
# presumably includes each edit in the URL textbox (one HTTP fetch per change);
# consider live=False if that proves noisy.
iface = gr.Interface(
    fn=_route_prediction,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL"),
    ],
    outputs=gr.JSON(label="Prediction Result"),
    title="Maize Anomaly Detection",
    description="Upload an image or provide a URL to detect anomalies in maize crops.",
    live=True,
)

# Start the Gradio server. show_error=True surfaces Python exceptions in the
# browser UI; share=True asks Gradio for a public share link.
iface.launch(show_error=True, share=True)