Spaces:
Running
Running
File size: 3,933 Bytes
2201868 aab569a 2255b93 aab569a 66345ab aab569a 5649d80 87e4f7b 9eebef2 aab569a 9eebef2 aab569a fc29cbf 95250f9 5cadf06 fc29cbf 95250f9 5cadf06 95250f9 5cadf06 95250f9 5649d80 95250f9 5649d80 2255b93 5649d80 2255b93 5cadf06 5649d80 9dfc63c fb31436 a62d15d 5cadf06 aab569a 5649d80 2255b93 bf44ad8 5cadf06 a62d15d 42edc6c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import base64
from io import BytesIO
import os
# Number of output classes of the classifier head: `load_model` replaces the
# ResNet-50 final layer with a linear layer producing one output per class.
num_classes: int = 2
# Download model from Hugging Face
def download_model(repo_id="jays009/Restnet50", filename="pytorch_model.bin"):
    """Fetch the classifier checkpoint from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that hosts the checkpoint. Defaults to the
            repository this app was built for.
        filename: Name of the weights file inside ``repo_id``.

    Returns:
        The local path returned by ``hf_hub_download``, or ``None`` if the
        download failed. The error is printed rather than raised so module
        import does not crash; callers must handle the ``None`` case.
    """
    try:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:  # best effort: report and let the caller decide
        print(f"Error downloading model: {e}")
        return None
# Load the model from Hugging Face
def load_model(model_path, device=None):
    """Build a ResNet-50 with a ``num_classes``-way head and load its weights.

    Args:
        model_path: Path to a state dict file readable by ``torch.load``.
        device: Device to place the model on. Defaults to CUDA when it is
            available, otherwise CPU.

    Returns:
        The model in eval mode on ``device``, or ``None`` if construction or
        loading failed (the error is printed, not raised).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # The no-argument constructor yields a randomly initialised backbone on
        # both old (``pretrained=False`` default) and new (``weights=None``
        # default) torchvision APIs, without the deprecated keyword.
        model = models.resnet50()
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        # SECURITY NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from trusted repositories. On torch>=1.13 consider
        # passing weights_only=True, since a plain state dict is expected here.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        # Place the model where inference runs; inputs must be on the same
        # device or the forward pass raises a device-mismatch error.
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
# Fetch and build the model once, at import time. Both helpers return None on
# failure, in which case `model` is None and callers must cope with that.
model_path = download_model()
if model_path:
    model = load_model(model_path)
else:
    model = None
# Preprocessing applied to every input image before inference:
# resize the shorter side to 256 px, take the central 224x224 crop, convert to
# a float tensor, then normalise each channel. The mean/std values are the
# commonly used ImageNet statistics (presumably the ones the checkpoint was
# trained with; confirm against the training code).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Message returned for each class index the classifier head can produce.
_CLASS_RESULTS = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy maize image.",
}

# Seconds to wait for a remote image; without a timeout a stalled server
# would block the request handler indefinitely.
_FETCH_TIMEOUT = 10


def _fetch_image(url, timeout=_FETCH_TIMEOUT):
    """Download ``url`` and decode it as an RGB PIL image.

    Raises on network/HTTP errors (via ``requests``) and on undecodable data.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # turn 4xx/5xx responses into exceptions
    # Normalize in `transform` expects exactly 3 channels; RGBA, grayscale or
    # palette images would otherwise fail later with a shape error.
    return Image.open(BytesIO(response.content)).convert("RGB")


def predict(data):
    """Classify the maize image referenced by a JSON payload.

    Args:
        data: JSON-decoded input. Expected to be a non-empty list whose first
            element is a dict with a ``"path"`` key holding an ``http://`` or
            ``https://`` image URL.

    Returns:
        A JSON string: ``{"result": ...}`` on success, otherwise
        ``{"error": ...}``. Errors are returned in-band instead of raised so
        the client always receives a JSON payload.
    """
    try:
        if not isinstance(data, list) or len(data) == 0:
            return json.dumps({"error": "Input data should be a non-empty list."})

        first = data[0]
        image_input = first.get("path") if isinstance(first, dict) else None
        if not image_input:
            return json.dumps({"error": "No image provided."})
        print(f"Received image input: {image_input}")

        # Only remote URLs are supported; anything else (including non-string
        # values) is rejected up front instead of failing later.
        if not isinstance(image_input, str) or not image_input.startswith(
            ("http://", "https://")
        ):
            return json.dumps({"error": "Invalid URL format. Please provide a valid URL starting with 'http://' or 'https://'."})

        if model is None:
            return json.dumps({"error": "Model is not loaded; check the server logs."})

        try:
            image = _fetch_image(image_input)
            print(f"Fetched image from URL: {image}")
        except Exception as e:
            print(f"Error fetching image from URL: {e}")
            return json.dumps({"error": f"Error fetching image from URL: {e}"})

        # Send the input to wherever the model's weights live, so CPU and GPU
        # deployments both work.
        device = next(model.parameters()).device
        batch = transform(image).unsqueeze(0)
        print(f"Transformed image tensor: {batch.shape}")
        batch = batch.to(device)

        with torch.no_grad():
            outputs = model(batch)
        predicted_class = torch.argmax(outputs, dim=1).item()
        print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")

        result = _CLASS_RESULTS.get(predicted_class)
        if result is None:
            return json.dumps({"error": "Unexpected class prediction."})
        return json.dumps({"result": result})
    except Exception as e:
        print(f"Error processing image: {e}")
        return json.dumps({"error": f"Error processing image: {e}"})
# Gradio UI/API wrapper around `predict`. Clients send a JSON list such as
# [{"path": "https://example.com/maize.jpg"}] and receive the JSON string that
# `predict` returns.
iface = gr.Interface(
    fn=predict,
    inputs=gr.JSON(label="Input JSON"),
    outputs=gr.Textbox(label="Prediction Result"),
    live=True,
    title="Maize Anomaly Detection",
    description=(
        "Classify a maize image as fall armyworm infestation or healthy. "
        'Provide a JSON list whose first element has a "path" key holding an '
        "http:// or https:// URL of the image, e.g. "
        '[{"path": "https://example.com/maize.jpg"}].'
    ),
)

# Launch only when run as a script, so importing this module (e.g. for tests)
# does not start a server.
if __name__ == "__main__":
    # show_error=True surfaces exceptions in the UI; share=True asks Gradio
    # for a public share link.
    iface.launch(share=True, show_error=True)
|