File size: 4,318 Bytes
2201868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ef3e33
2255b93
e0c19c3
 
 
7ef3e33
e0c19c3
 
 
7ef3e33
e0c19c3
 
 
 
7ef3e33
 
 
e0c19c3
7ef3e33
e0c19c3
52fd9c2
e0c19c3
 
fc29cbf
52fd9c2
7ef3e33
e0c19c3
 
 
 
 
52fd9c2
 
7ef3e33
5cadf06
52fd9c2
95250f9
7ef3e33
95250f9
e0c19c3
95250f9
b4d05af
95250f9
5649d80
95250f9
5649d80
2255b93
5649d80
52fd9c2
2255b93
e0c19c3
5649d80
9dfc63c
4869d07
0c47ae4
fb8a03b
7ef3e33
0c47ae4
 
fb8a03b
 
 
 
 
 
4869d07
fb8a03b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
from io import BytesIO

# Width of the classifier head that replaces ResNet-50's final fully connected
# layer. It must match the fine-tuned checkpoint's head; predict() maps index 0
# to fall army worm and index 1 to healthy maize.
num_classes: int = 2

# Download model from Hugging Face
def download_model():
    try:
        model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
        return model_path
    except Exception as e:
        print(f"Error downloading model: {e}")
        return None

# Load the model from Hugging Face
def load_model(model_path):
    try:
        model = models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Fetch the checkpoint once at import time and build the classifier from it.
# If the download fails, ``model`` stays None and predict() reports that to
# the user instead of crashing.
model_path = download_model()
if model_path:
    model = load_model(model_path)
else:
    model = None

# Preprocessing applied to every image before inference:
# - resize so the shorter side is 256 px, then take the central 224x224 crop;
# - convert to a float tensor;
# - normalize each channel with the standard ImageNet mean/std.
# Presumably the checkpoint was fine-tuned with this same preprocessing —
# confirm against the training code.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Seconds to wait for a remote image host. Without a timeout, requests.get can
# block a worker forever on an unresponsive server.
_URL_FETCH_TIMEOUT = 10

# Message returned for each class index the model can predict.
_CLASS_MESSAGES = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy maize image.",
}


def _fetch_image_from_url(url):
    """Download *url* and open the response body as a PIL image.

    Raises the underlying ``requests`` error on connection/HTTP failures and
    PIL's error if the body is not a decodable image.
    """
    response = requests.get(url, timeout=_URL_FETCH_TIMEOUT)
    response.raise_for_status()  # Raise error for HTTP issues
    return Image.open(BytesIO(response.content))


def _is_blank(value):
    """Return True for ``None`` or a whitespace-only string."""
    return value is None or (isinstance(value, str) and not value.strip())


def predict(input_data, url=None):
    """Classify a maize photo and return the verdict as a JSON string.

    The Gradio interface passes two positional values: the uploaded image and
    the URL textbox contents. An uploaded image takes precedence; the URL is
    used only when no image was supplied. For backward compatibility, a
    string passed as ``input_data`` is still treated as a URL.

    Args:
        input_data: A ``PIL.Image.Image``, a URL string, or ``None``.
        url: Optional image URL, consulted only when ``input_data`` is empty.

    Returns:
        A JSON string, either ``{"result": ...}`` or ``{"error": ...}``.
        Exceptions are caught and reported in the ``error`` field.
    """
    try:
        print(f"Input data received: {input_data}, Type: {type(input_data)}")

        # Fall back to the URL textbox only when no image was provided.
        source = input_data
        if _is_blank(source) and not _is_blank(url):
            source = url

        if isinstance(source, str):  # If it's a string, assume it's a URL
            try:
                img = _fetch_image_from_url(source.strip())
                print("Image fetched successfully from URL.")
            except Exception as e:
                print(f"Error fetching image from URL: {e}")
                return json.dumps({"error": f"Failed to fetch image from URL: {e}"})
        else:  # If it's not a string, assume it's an image object
            img = source

        # Validate the image
        if not isinstance(img, Image.Image):
            print("Invalid image format received.")
            return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
        print(f"Image successfully loaded: {img}")

        # Ensure model is loaded before doing any tensor work
        if model is None:
            return json.dumps({"error": "Model not loaded. Ensure the model file is available and correctly loaded."})

        # Normalize() expects exactly three channels. RGBA/palette/grayscale
        # inputs (common for PNG/GIF URLs) would otherwise fail on the shape.
        batch = transform(img.convert("RGB")).unsqueeze(0)
        print(f"Transformed image tensor shape: {batch.shape}")

        # Send the input to wherever the model's weights actually live.
        # Hard-coding CUDA breaks when the model was loaded on the CPU.
        batch = batch.to(next(model.parameters()).device)

        # Make predictions
        with torch.no_grad():
            outputs = model(batch)
        predicted_class = torch.argmax(outputs, dim=1).item()
        print(f"Model prediction outputs: {outputs}, Predicted class: {predicted_class}")

        message = _CLASS_MESSAGES.get(predicted_class)
        if message is None:
            return json.dumps({"error": "Unexpected class prediction."})
        return json.dumps({"result": message})

    except Exception as e:
        print(f"Error processing image: {e}")
        return json.dumps({"error": f"Error processing image: {e}"})


# Create the Gradio interface with an image upload and an optional URL input.
# Gradio calls ``fn`` with one positional argument per input component, in
# order: the uploaded image, then the URL text.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Textbox(label="Or enter image URL (if available)", placeholder="Enter a URL for the image"),
    ],
    outputs=gr.Textbox(label="Prediction Result"),
    # Not live: in live mode every keystroke in the URL box would trigger a
    # new (mostly partial, invalid) HTTP fetch. Users press Submit instead.
    title="Maize Anomaly Detection",
    description=(
        "Upload an image of maize, or enter an image URL, to check for fall "
        "army worm infestation. The model classifies images as fall army worm "
        "or healthy maize."
    ),
)

# Launch the Gradio interface only when run as a script, so importing this
# module (e.g. to reuse predict) does not start a server.
if __name__ == "__main__":
    iface.launch(share=True, show_error=True)