Spaces:
Sleeping
Sleeping
File size: 3,753 Bytes
cbc5566 9dfc63c bf44ad8 01a4ed7 cbc5566 38d7439 2255b93 cbc5566 38d7439 9dfc63c 2255b93 9dfc63c 38d7439 9dfc63c 2255b93 9dfc63c 38d7439 2255b93 9dfc63c 38d7439 9dfc63c 2255b93 9dfc63c fc29cbf 610d493 fc29cbf 610d493 fc29cbf 2255b93 fc29cbf 2255b93 40efeb4 2255b93 9dfc63c fb31436 a62d15d fc29cbf 2255b93 bf44ad8 a62d15d ce20917 40efeb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import base64
from io import BytesIO
import os
# Number of outputs of the classifier head built in load_model. It must match
# the checkpoint, otherwise load_state_dict rejects the fc weights.
# Update with the actual number of classes in your dataset.
num_classes: int = 2
# Download model from Hugging Face
def download_model():
try:
model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
return model_path
except Exception as e:
print(f"Error downloading model: {e}")
return None
# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load weights into it.

    Args:
        model_path: Path to a saved ``state_dict`` checkpoint (for example the
            path returned by ``download_model``).

    Returns:
        The model in eval mode with its weights on the CPU, or ``None`` if
        building or loading failed (the error is printed, not propagated).

    NOTE(review): ``weights=`` needs torchvision >= 0.13 and ``weights_only=``
    needs torch >= 1.13. Confirm the Space's pinned versions.
    """
    try:
        # weights=None replaces the deprecated pretrained=False: random init,
        # no ImageNet weights are downloaded (the checkpoint overwrites them).
        model = models.resnet50(weights=None)
        # Swap the ImageNet head for one sized to this dataset.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        # The checkpoint comes from a remote repo. weights_only=True restricts
        # unpickling to tensors and plain containers, so a tampered file cannot
        # run arbitrary code during torch.load.
        state_dict = torch.load(
            model_path,
            map_location=torch.device("cpu"),
            weights_only=True,
        )
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
# Download the checkpoint and build the model once, at import time.
# `model` is left as None when the download failed. load_model also returns
# None when building/loading fails.
model_path = download_model()
model = None
if model_path:
    model = load_model(model_path)
# Preprocessing pipeline applied to every input image:
# Resize(256) -> 224x224 centre crop -> float tensor -> per-channel normalisation.
# The mean/std values are the usual ImageNet statistics. Presumably they match
# how the checkpoint was trained (confirm).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def process_image(image_input):
try:
# Process the image input (URL, local file, or base64)
if isinstance(image_input, dict):
# Check if the input contains a URL
if image_input.get("url"):
image_url = image_input["url"]
response = requests.get(image_url)
image = Image.open(BytesIO(response.content))
# Check if the input contains a file path
elif image_input.get("path"):
image_path = image_input["path"]
image = Image.open(image_path)
# Handle base64 if it's included
elif image_input.get("data"):
image_data = base64.b64decode(image_input["data"])
image = Image.open(BytesIO(image_data))
else:
return "Invalid input data format. Please provide a URL or path."
# Apply transformations
image = transform(image).unsqueeze(0)
image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# Make the prediction
with torch.no_grad():
outputs = model(image)
predicted_class = torch.argmax(outputs, dim=1).item()
# Return prediction result
if predicted_class == 0:
return "The photo you've sent is of fall army worm with problem ID 126."
elif predicted_class == 1:
return "The photo you've sent is of a healthy maize image."
else:
return "Unexpected class prediction."
else:
return "Invalid input. Please provide a dictionary with 'url' or 'path'."
except Exception as e:
print(f"Error processing image: {e}")
return f"Error processing image: {e}"
# Gradio UI: a JSON box whose value is passed to process_image, and a textbox
# that shows the returned message.
# NOTE(review): with live=True Gradio presumably re-runs process_image (including
# any URL fetch) as the JSON value changes. A submit button may be preferable.
# Confirm.
iface = gr.Interface(
    title="Maize Anomaly Detection",
    description=(
        "Upload an image of maize to detect anomalies like disease or pest infestation. "
        "You can provide local paths, URLs, or base64-encoded images."
    ),
    fn=process_image,
    inputs=gr.JSON(label="Upload an image (URL or Local Path)"),
    outputs=gr.Textbox(label="Prediction Result"),
    live=True,
)

# Start the server. show_error surfaces handler exceptions in the UI.
# NOTE(review): share=True asks Gradio for a public share link. That is likely
# unnecessary when hosted on Hugging Face Spaces. Verify.
iface.launch(show_error=True, share=True)
|