File size: 2,651 Bytes
2201868
 
 
 
342396f
2201868
163e73a
4af9c6b
 
 
 
2201868
 
 
 
0f694f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342396f
2201868
 
 
 
 
 
 
4af9c6b
342396f
4af9c6b
 
 
 
163e73a
4af9c6b
 
163e73a
4af9c6b
 
163e73a
4af9c6b
 
 
 
52fd9c2
4af9c6b
 
 
 
 
 
 
991ba20
4af9c6b
 
991ba20
4af9c6b
163e73a
4af9c6b
 
342396f
dd36796
163e73a
4af9c6b
163e73a
 
4dd171e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import os
import logging

# Configure the root logger so INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)

# Size of the classifier head; predict_from_image maps index 0 and 1 to labels.
num_classes: int = 2

# Download model from Hugging Face
def download_model():
    model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
    return model_path

def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        model_path: Path to a checkpoint holding a ``state_dict`` for the
            modified ResNet-50 (it is passed straight to ``load_state_dict``).

    Returns:
        The model with the checkpoint's weights loaded onto the CPU, switched
        to evaluation mode.
    """
    # ``weights=None`` is the current torchvision spelling of the deprecated
    # ``pretrained=False``. Random init is fine because ``load_state_dict``
    # (strict by default) overwrites every parameter below.
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # The checkpoint is downloaded from the network, and ``torch.load`` is
    # pickle-based. ``weights_only=True`` restricts unpickling to tensors and
    # plain containers, which is all a state_dict needs.
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Fetch the checkpoint and build the inference model once, at import time.
model_path = download_model()
model = load_model(model_path)

# Preprocessing applied to every uploaded image:
# 1. resize to 256,
# 2. take the central 224x224 crop,
# 3. convert to a tensor,
# 4. normalise each channel with the commonly used ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict_from_image(image):
    """Classify an uploaded image with the module-level ``model``.

    Args:
        image: A ``PIL.Image.Image``. Any other value (including ``None``,
            e.g. when the upload is cleared) yields an error dict.

    Returns:
        ``{"result": <message>}`` for a recognised class, or
        ``{"error": <message>}`` if the input is invalid or processing fails.
        Exceptions are caught and reported in the returned dict rather than
        raised, so the UI always receives JSON.
    """
    try:
        # Ensure the input is a valid PIL image
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")

        logging.info("Received image for prediction")

        # The network and the 3-value Normalize expect 3 channels. RGBA,
        # grayscale or palette images would otherwise fail with a channel
        # mismatch, so coerce to RGB first.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)

        with torch.no_grad():
            outputs = model(image_tensor)
            predicted_class = torch.argmax(outputs, dim=1).item()

        # Map the predicted index to a human-readable message
        if predicted_class == 0:
            return {"result": "The photo is of fall army worm with problem ID 126."}
        elif predicted_class == 1:
            return {"result": "The photo is of a healthy maize image."}
        else:
            return {"error": "Unexpected class prediction."}
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback
        # in the logs (logging.exception logs at ERROR level with exc_info).
        logging.exception("Error during prediction: %s", e)
        return {"error": f"Failed to process the image: {str(e)}"}

# Web UI: a single PIL-image upload feeds predict_from_image, and the returned
# dict is rendered as JSON. ``live=True`` re-runs the prediction whenever the
# input changes.
iface = gr.Interface(
    title="Maize Anomaly Detection",
    description="Upload an image to detect anomalies in maize crops.",
    fn=predict_from_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.JSON(label="Prediction Result"),
    live=True,
)

# Serve the interface only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()