Spaces:
Sleeping
Sleeping
File size: 2,523 Bytes
cbc5566 9dfc63c cbc5566 2b3983d cbc5566 2b3983d 9dfc63c 9975291 9dfc63c 2b3983d 9dfc63c 2b3983d 9dfc63c 2b3983d 9dfc63c 2b3983d 9dfc63c 2b3983d 9dfc63c 2b3983d 9dfc63c 2b3983d 9dfc63c 2b3983d 9975291 2b3983d 9dfc63c fb31436 a62d15d 2b3983d 9975291 fb31436 a62d15d fb31436 0253941 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
# Output size of the classifier head. load_model() builds the final Linear
# layer with this many units, so it must match the checkpoint being loaded.
num_classes: int = 2
# Download model weights from the Hugging Face Hub
def download_model(repo_id="jays009/Restnet50", filename="pytorch_model.bin"):
    """Fetch the model checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hub repository holding the weights. Defaults to the
            repository this app was built against.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        Local filesystem path of the downloaded file, as returned by
        ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename)
# Load the model from the downloaded weights
def load_model(model_path):
    """Build a ResNet-50 classifier and load fine-tuned weights into it.

    The ImageNet head is replaced by a ``num_classes``-way Linear layer
    before loading. ``load_state_dict`` is strict by default, so the
    checkpoint must match exactly that architecture.

    Args:
        model_path: Path to the checkpoint file; expected to hold a plain
            state dict (tensors only) -- TODO confirm against how it was saved.

    Returns:
        The model with weights mapped to CPU, switched to evaluation mode.
    """
    # weights=None -> randomly initialised backbone; the real weights come
    # from the checkpoint below. Replaces the deprecated `pretrained=False`.
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered file fetched from the Hub cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: BatchNorm uses its running statistics
    return model
# Download and load the model once at import time; predict() reuses this object
model_path = download_model()
model = load_model(model_path)

# Image preprocessing, ImageNet-style. It should match the preprocessing used
# when the checkpoint was trained -- TODO confirm.
transform = transforms.Compose([
    transforms.Resize(256),  # scale so the *shorter* side is 256 px; aspect ratio is kept
    transforms.CenterCrop(224),  # keep the central 224x224 patch
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) with values in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet per-channel mean/std
])
# Response text for each class index the classifier can emit.
# NOTE(review): class 1 says "wheat" while the app title/description talk about
# maize -- confirm which crop is intended before changing the text.
_PREDICTION_MESSAGES = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy wheat image.",
}


# Prediction function
def predict(image):
    """Classify a single PIL image and return a human-readable verdict.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio input, or ``None``
            (presumably what Gradio passes when the input is cleared while
            the interface runs with ``live=True``).

    Returns:
        The message for the predicted class, ``"Unexpected class prediction."``
        for an index with no message, or a prompt to upload an image when
        ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."
    # Grayscale ("L"), palette ("P") or RGBA uploads would yield a tensor whose
    # channel count does not match the 3-channel Normalize step.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)
    # Put the input on whatever device the model's weights live on.
    device = next(model.parameters()).device
    with torch.no_grad():
        outputs = model(batch.to(device))
    predicted_class = torch.argmax(outputs, dim=1).item()
    return _PREDICTION_MESSAGES.get(predicted_class, "Unexpected class prediction.")
# Gradio UI: one image in, one text verdict out
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),  # hand predict() a PIL image
    outputs=gr.Textbox(),  # predicted class description
    live=True,  # re-run predict() whenever the input changes
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation."
)

# Serve the interface (and its API endpoint) on all network interfaces, port 7860.
# Guarded so that importing this module does not start a server; running it as a
# script (`python app.py`) still launches it.
if __name__ == "__main__":
    iface.launch(share=False, server_name="0.0.0.0", server_port=7860)
|