File size: 3,016 Bytes
cbc5566
9dfc63c
 
 
 
 
bf44ad8
 
 
cbc5566
38d7439
2255b93
cbc5566
38d7439
9dfc63c
2255b93
 
 
 
 
 
9dfc63c
38d7439
9dfc63c
2255b93
 
 
 
 
 
 
 
 
9dfc63c
38d7439
2255b93
 
9dfc63c
38d7439
9dfc63c
2255b93
 
 
 
9dfc63c
 
aae3560
 
 
2255b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf44ad8
2255b93
 
 
bf44ad8
2255b93
 
 
 
 
 
 
 
9dfc63c
fb31436
a62d15d
2255b93
 
 
 
bf44ad8
 
a62d15d
 
ce20917
2255b93
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import base64
from io import BytesIO

# Output size of the classifier head built in load_model; it must equal the
# number of labels the downloaded checkpoint was trained with.
num_classes: int = 2  # Update with the actual number of classes in your dataset

# Download model from Hugging Face
def download_model():
    try:
        model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
        return model_path
    except Exception as e:
        print(f"Error downloading model: {e}")
        return None

def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load weights into it.

    Args:
        model_path: Path to a checkpoint that ``torch.load`` turns into a state
            dict matching this architecture.

    Returns:
        The model switched to eval mode, with its weights mapped to CPU, or
        ``None`` if any step raised (the error is printed, not propagated).
    """
    try:
        # Architecture only; the trained weights come from the checkpoint below.
        net = models.resnet50(pretrained=False)
        # Swap the classifier head so its output size matches the fine-tuned checkpoint.
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        # NOTE(review): torch.load unpickles the file, which can run code embedded
        # in a downloaded checkpoint; consider weights_only=True if the installed
        # torch version supports it.
        state = torch.load(model_path, map_location=torch.device("cpu"))
        net.load_state_dict(state)
        net.eval()
    except Exception as err:
        print(f"Error loading model: {err}")
        return None
    return net

# Fetch and load the checkpoint once at import time. If either step fails,
# `model` is left as None instead of raising.
model_path = download_model()
if model_path:
    model = load_model(model_path)
else:
    model = None

# Input preprocessing applied to each image before inference:
# - scale the shorter side to 256 and take a centred 224x224 crop;
# - convert to a float tensor;
# - normalise each channel.
# The mean/std are the conventional ImageNet statistics, presumably the ones
# the checkpoint was trained with (confirm).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Human-readable verdict for each class index produced by the model's head.
_CLASS_MESSAGES = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy maize image.",
}


def strip_data_uri(data):
    """Return the base64 payload of *data*, dropping any ``data:...;base64,`` header.

    ``base64.b64decode`` skips characters outside the base64 alphabet. The
    letters and '/' in a data-URI header are valid base64 characters, so a
    header left in place would be decoded into the image bytes and corrupt them.
    Non-string values (e.g. ``bytes``) are returned unchanged.
    """
    if isinstance(data, str) and data.startswith("data:") and "," in data:
        return data.split(",", 1)[1]
    return data


def predict(image):
    """Classify a maize photo and return a human-readable verdict.

    Args:
        image: One of the following:
            - a PIL image (what ``gr.Image(type="pil")`` supplies);
            - a dict whose ``"data"`` entry is a base64 string, optionally
              given as a ``data:`` URI;
            - a string holding either a ``data:`` URI or a URL to fetch;
            - ``None``, e.g. when the input is cleared while ``live=True``.

    Returns:
        A message for the predicted class, or an error / prompt message.
        Problems are returned as text rather than raised, so they are shown
        in the output textbox.
    """
    if image is None:
        return "Please upload an image."

    if isinstance(image, dict) and image.get("data"):
        try:
            image_data = base64.b64decode(strip_data_uri(image["data"]))
            image = Image.open(BytesIO(image_data))
        except Exception as e:
            return f"Error decoding base64 image: {e}"

    elif isinstance(image, str):
        if image.startswith("data:"):
            try:
                image_data = base64.b64decode(strip_data_uri(image))
                image = Image.open(BytesIO(image_data))
            except Exception as e:
                return f"Error decoding base64 image: {e}"
        else:
            # NOTE(review): this fetches an arbitrary caller-supplied URL from the
            # server; on a publicly shared app that may reach internal hosts.
            try:
                # Timeout so a slow host cannot hang the worker; the status
                # check reports HTTP errors instead of failing later in Image.open.
                response = requests.get(image, timeout=30)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content))
            except Exception as e:
                return f"Error fetching image from URL: {e}"

    if model is None:
        return "Model is not loaded; cannot make predictions."

    try:
        if isinstance(image, Image.Image):
            # Normalize is configured for 3 channels; RGBA, grayscale and
            # palette images would otherwise fail with a shape error.
            image = image.convert("RGB")
        # Put the input wherever the model's weights live. load_model maps them
        # to CPU, so picking CUDA just because it is available causes a
        # device-mismatch error.
        device = next(model.parameters()).device
        batch = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(batch)
            predicted_class = torch.argmax(outputs, dim=1).item()

        return _CLASS_MESSAGES.get(predicted_class, "Unexpected class prediction.")
    except Exception as e:
        return f"Error processing image: {e}"

# Web UI: a PIL image goes in and a text verdict comes out. With live=True,
# predict() is re-run whenever the input changes.
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    gr.Textbox(),
    live=True,
    title="Maize Anomaly Detection",
    description=(
        "Upload an image of maize to detect anomalies like disease or pest "
        "infestation. You can provide local paths, URLs, or base64-encoded images."
    ),
)

# Start serving; share=True also requests a temporary public Gradio share link.
iface.launch(share=True)