File size: 2,523 Bytes
cbc5566
9dfc63c
 
 
 
 
cbc5566
2b3983d
cbc5566
2b3983d
9dfc63c
9975291
9dfc63c
 
2b3983d
9dfc63c
2b3983d
 
 
 
9dfc63c
 
2b3983d
 
9dfc63c
 
2b3983d
9dfc63c
 
 
 
2b3983d
9dfc63c
 
2b3983d
9dfc63c
2b3983d
 
 
9dfc63c
2b3983d
 
9975291
2b3983d
9dfc63c
 
 
 
 
 
 
fb31436
a62d15d
2b3983d
 
 
 
9975291
fb31436
a62d15d
 
fb31436
0253941
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image

num_classes: int = 2  # Output size of the classifier head; must match the fine-tuned checkpoint

# Download model weights from Hugging Face
def download_model():
    """Fetch the fine-tuned ResNet-50 checkpoint from the Hugging Face Hub.

    Returns:
        Local path of the downloaded ``pytorch_model.bin`` file, suitable for
        passing to ``torch.load``.
    """
    # NOTE: "Restnet50" is the actual repository name on the Hub; do not "fix" it.
    return hf_hub_download(
        repo_id="jays009/Restnet50",
        filename="pytorch_model.bin",
    )

# Load the model from the downloaded weights
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        model_path: Path to a saved ``state_dict`` matching this architecture
            (ResNet-50 body with its ``fc`` layer replaced).

    Returns:
        The model with weights loaded onto the CPU, switched to eval mode.
    """
    # Architecture only: no ImageNet weights are fetched because the checkpoint
    # supplies them. `weights=None` is the replacement for the deprecated
    # `pretrained=False` argument.
    model = models.resnet50(weights=None)
    # Swap the 1000-way ImageNet head for one sized to this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # The checkpoint comes from a remote repository, so restrict unpickling to
    # tensors and plain containers instead of arbitrary Python objects.
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference behavior for BatchNorm/Dropout
    return model

# Download the checkpoint and build the model once, at import time, so every
# prediction reuses the same loaded weights.
model_path = download_model()
model = load_model(model_path)

# Preprocessing applied to each uploaded image before inference.
# NOTE(review): these are the standard ImageNet mean/std values; presumably the
# checkpoint was fine-tuned with the same normalization — confirm.
transform = transforms.Compose([
    transforms.Resize(256),  # Scale the shorter side to 256 px, keeping the aspect ratio
    transforms.CenterCrop(224),  # Take the central 224x224 patch
    transforms.ToTensor(),  # PIL image -> float tensor of shape (C, H, W)
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # Per-channel ImageNet mean/std
])

# Verdict text for each class index the classifier can output.
# NOTE(review): class 1's text says "wheat" while the UI title/description
# mention maize — confirm the crop before changing this user-facing wording.
_CLASS_MESSAGES = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy wheat image.",
}


# Prediction function
def predict(image):
    """Classify one uploaded image and return a human-readable verdict.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when no image is present (with ``live=True`` this can happen, e.g.
            when the user clears the upload).

    Returns:
        The message for the predicted class, ``"Unexpected class prediction."``
        for an index with no message, or an upload prompt when ``image`` is None.
    """
    # Without this guard, transform(None) raises and the UI shows an error.
    if image is None:
        return "Please upload an image."

    # Force 3 channels so RGBA/grayscale inputs match the 3-channel Normalize step.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # Add a batch dimension: (1, C, H, W)
    batch = batch.to(torch.device("cpu"))  # load_model maps the weights to CPU

    with torch.no_grad():  # Inference only; skip autograd bookkeeping
        logits = model(batch)
    predicted_class = torch.argmax(logits, dim=1).item()

    return _CLASS_MESSAGES.get(predicted_class, "Unexpected class prediction.")

# Web UI: a PIL-image upload wired to `predict`, with a text box for its verdict.
image_input = gr.Image(type="pil")
verdict_output = gr.Textbox()

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=verdict_output,
    live=True,  # Re-run `predict` whenever the input image changes
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation.",
)

# Serve the UI (and its API endpoint) on all network interfaces at port 7860,
# without creating a public share link.
launch_options = {
    "share": False,
    "server_name": "0.0.0.0",
    "server_port": 7860,
}
iface.launch(**launch_options)