File size: 4,607 Bytes
cbc5566
9dfc63c
 
 
5cadf06
9dfc63c
bf44ad8
 
 
871b5a8
cbc5566
38d7439
42edc6c
cbc5566
5cadf06
 
 
 
 
 
 
 
 
 
 
2255b93
 
 
5cadf06
2255b93
 
 
 
 
9dfc63c
5cadf06
 
 
9dfc63c
38d7439
9dfc63c
2255b93
 
 
 
9dfc63c
 
5cadf06
610d493
5cadf06
42edc6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc29cbf
95250f9
 
5cadf06
fc29cbf
95250f9
5cadf06
95250f9
 
 
5cadf06
95250f9
 
 
 
 
2255b93
95250f9
2255b93
5cadf06
2255b93
9dfc63c
fb31436
a62d15d
5cadf06
 
42edc6c
2255b93
bf44ad8
5cadf06
a62d15d
 
42edc6c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import base64
from io import BytesIO
import os

# Size of the classifier head. predict() maps index 0 to fall armyworm and
# index 1 to healthy maize.
num_classes: int = 2

# Download model from Hugging Face
def download_model():
    try:
        model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
        return model_path
    except Exception as e:
        print(f"Error downloading model: {e}")
        return None

# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        model_path: Local path to a file holding the model's ``state_dict``.

    Returns:
        The model in eval mode, placed on CUDA when available (else CPU).
        Returns ``None`` if construction or loading fails; the error is
        printed rather than raised.
    """
    try:
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favour of `weights=None`; kept here for compatibility with older installs.
        model = models.resnet50(pretrained=False)
        # Swap the stock classifier head for one sized to this task's classes.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        # SECURITY(review): torch.load unpickles the file, and the checkpoint
        # comes from a downloaded, external source. Consider passing
        # weights_only=True (supported by torch >= 1.13) to refuse arbitrary
        # objects.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        # Weights are deserialized on CPU, then moved to CUDA when available.
        # Inference inputs are sent to CUDA on GPU hosts, so a CPU-resident
        # model would fail there with a device-mismatch error.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Fetch and build the classifier once, at start-up. If either helper fails,
# it prints the error and returns None, so `model` ends up None.
model_path = download_model()
if model_path:
    model = load_model(model_path)
else:
    model = None

# Preprocessing applied to every input image:
# - scale so the shorter side is 256 px;
# - take the central 224x224 crop;
# - convert to a tensor;
# - normalize each channel with the mean/std values commonly used for
#   ImageNet-pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Reply text for each class index the classifier can output.
_CLASS_MESSAGES = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy maize image.",
}


class _ImageInputError(Exception):
    """Raised by _load_image; its message is the text returned to the user."""


def _load_image(image, timeout=10):
    """Resolve one of the supported input forms into an RGB PIL image.

    Supported forms:
    - an already-loaded ``PIL.Image.Image``;
    - a dict whose ``"data"`` entry is base64 text (a leading
      ``data:...;base64,`` prefix is stripped);
    - a string starting with ``"http"``, which is fetched with ``requests``;
    - a string naming an existing local file.

    Args:
        image: The raw input received by :func:`predict`.
        timeout: Seconds to wait for a URL fetch before giving up.

    Returns:
        The image converted to RGB mode.

    Raises:
        _ImageInputError: The input cannot be decoded, fetched or opened, or
            is not one of the supported forms. The message is user-facing.
    """
    if isinstance(image, Image.Image):
        print(f"Image is already loaded as PIL Image: {image}")
    elif isinstance(image, dict) and image.get("data"):
        try:
            data = image["data"]
            # Accept data URLs as well as bare base64. Otherwise the prefix
            # characters would be decoded into garbage bytes.
            if isinstance(data, str) and data.startswith("data:"):
                data = data.partition(",")[2]
            image = Image.open(BytesIO(base64.b64decode(data)))
            print(f"Decoded base64 image: {image}")
        except Exception as e:
            print(f"Error decoding base64 image: {e}")
            raise _ImageInputError(f"Error decoding base64 image: {e}") from e
    elif isinstance(image, str) and image.startswith("http"):
        try:
            # The timeout stops a stalled server from blocking this worker
            # forever. raise_for_status reports HTTP errors directly instead
            # of failing later while trying to decode an error page as an image.
            response = requests.get(image, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            print(f"Fetched image from URL: {image}")
        except Exception as e:
            print(f"Error fetching image from URL: {e}")
            raise _ImageInputError(f"Error fetching image from URL: {e}") from e
    elif isinstance(image, str) and os.path.isfile(image):
        try:
            # convert() reads the pixel data, so the file can be closed at once.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
            print(f"Loaded image from local path: {image}")
        except Exception as e:
            print(f"Error loading image from local path: {e}")
            raise _ImageInputError(f"Error loading image from local path: {e}") from e

    if not isinstance(image, Image.Image):
        print("Invalid image format received.")
        raise _ImageInputError("Invalid image format received.")

    # The network takes 3-channel input. Grayscale, palette and RGBA images
    # would otherwise fail in Normalize or in the first convolution.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def predict(image):
    """Classify a maize image and return a human-readable verdict.

    Args:
        image: A PIL image, a dict with base64 ``"data"``, a string starting
            with ``"http"``, or a local file path (see ``_load_image``).

    Returns:
        The message for the predicted class, or an error message. Errors are
        returned as strings, not raised, so they appear in the output textbox.
    """
    try:
        # Truncated because a base64 payload can be very large.
        print(f"Received image input: {str(image)[:200]}")

        if model is None:
            print("Model is not loaded.")
            return "Model is not loaded; cannot make a prediction."

        try:
            image = _load_image(image)
        except _ImageInputError as e:
            return str(e)

        batch = transform(image).unsqueeze(0)
        print(f"Transformed image tensor: {batch.shape}")

        # Send the input to wherever the model's weights live. Picking CUDA
        # independently of the model causes a device-mismatch error.
        batch = batch.to(next(model.parameters()).device)

        with torch.no_grad():
            outputs = model(batch)
            predicted_class = torch.argmax(outputs, dim=1).item()
            print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")

        return _CLASS_MESSAGES.get(predicted_class, "Unexpected class prediction.")
    except Exception as e:
        print(f"Error processing image: {e}")
        return f"Error processing image: {e}"

# Gradio UI: one image input (delivered to predict() as a PIL image) and a
# textbox showing the string predict() returns.
_image_input = gr.Image(type="pil", label="Upload an image or provide a URL or local path")
_result_output = gr.Textbox(label="Prediction Result")

# live=True asks Gradio to re-run predict() automatically when the input changes.
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_result_output,
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images.",
    live=True,
)

# Start serving. share=True requests a public link; show_error=True surfaces
# errors in the UI.
iface.launch(share=True, show_error=True)