Spaces:
Sleeping
Sleeping
File size: 4,318 Bytes
2201868 7ef3e33 2255b93 e0c19c3 7ef3e33 e0c19c3 7ef3e33 e0c19c3 7ef3e33 e0c19c3 7ef3e33 e0c19c3 52fd9c2 e0c19c3 fc29cbf 52fd9c2 7ef3e33 e0c19c3 52fd9c2 7ef3e33 5cadf06 52fd9c2 95250f9 7ef3e33 95250f9 e0c19c3 95250f9 b4d05af 95250f9 5649d80 95250f9 5649d80 2255b93 5649d80 52fd9c2 2255b93 e0c19c3 5649d80 9dfc63c 4869d07 0c47ae4 fb8a03b 7ef3e33 0c47ae4 fb8a03b 4869d07 fb8a03b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
from io import BytesIO
# Number of output classes of the fine-tuned ResNet-50 head. predict() maps
# index 0 to "fall army worm" and index 1 to "healthy maize".
num_classes: int = 2
# Download model from Hugging Face
def download_model():
    """Fetch the ResNet-50 checkpoint ``pytorch_model.bin`` from the Hub repo ``jays009/Restnet50``.

    Returns:
        The local filesystem path returned by ``hf_hub_download``, or ``None`` if
        the download raised. The error is printed rather than propagated, so the
        module can still finish importing and ``predict`` reports the missing model.
    """
    repo, weights_file = "jays009/Restnet50", "pytorch_model.bin"
    try:
        local_path = hf_hub_download(repo_id=repo, filename=weights_file)
    except Exception as exc:
        print(f"Error downloading model: {exc}")
        return None
    return local_path
# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load weights from *model_path*.

    The weights are mapped onto the CPU and the network is put in eval mode.

    Args:
        model_path: Path to a saved ``state_dict``, e.g. the file returned by
            ``download_model``.

    Returns:
        The ready-to-use network, or ``None`` if building or loading raised. The
        error is printed, not propagated.
    """
    try:
        network = models.resnet50(pretrained=False)
        # Swap the stock classifier for one sized to this task's labels.
        network.fc = nn.Linear(network.fc.in_features, num_classes)
        # NOTE(security): torch.load unpickles the file, and this checkpoint comes
        # from a remote Hub repo. Consider torch.load(..., weights_only=True) (the
        # default since PyTorch 2.6) once the minimum supported torch allows it.
        state = torch.load(model_path, map_location=torch.device("cpu"))
        network.load_state_dict(state)
        network.eval()
    except Exception as exc:
        print(f"Error loading model: {exc}")
        return None
    return network
# Fetch and build the classifier once, at import time. If the download fails,
# `model` stays None and predict() reports that instead of crashing the app.
model_path = download_model()
if model_path:
    model = load_model(model_path)
else:
    model = None

# Eval-time preprocessing: resize the shorter side to 256, centre-crop 224x224,
# convert to a tensor, then normalise per channel with the standard ImageNet
# mean/std (presumably matching the preprocessing used in fine-tuning; confirm).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Seconds to wait for a remote image; requests waits indefinitely unless told otherwise.
_URL_FETCH_TIMEOUT = 10

# Verdict returned for each output index of the classifier head.
_CLASS_RESULTS = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy maize image.",
}


def _fetch_image_from_url(url):
    """Download *url* and open the response body as a PIL image.

    Any failure (connection error, timeout, HTTP error status, or a body PIL
    cannot decode) propagates to the caller as an exception.
    """
    response = requests.get(url, timeout=_URL_FETCH_TIMEOUT)
    response.raise_for_status()  # Turn 4xx/5xx responses into an exception.
    return Image.open(BytesIO(response.content))


def predict(input_data, url=None):
    """Classify a maize photo and return the verdict as a JSON string.

    Gradio calls this with one positional argument per input component: the
    uploaded image (a PIL image, or ``None`` when nothing is uploaded) and the
    text of the URL box. The uploaded image takes precedence. The URL is used
    only when no image was supplied. For backward compatibility, a string passed
    as *input_data* is still treated as a URL.

    Args:
        input_data: A ``PIL.Image.Image``, an image URL string, or ``None``.
        url: Optional image URL, used when *input_data* is empty.

    Returns:
        ``'{"result": ...}'`` on success, or ``'{"error": ...}'`` when anything
        goes wrong. Failures are reported in the JSON rather than raised.
    """
    try:
        print(f"Input data received: {input_data}, Type: {type(input_data)}")

        # Prefer the uploaded image; fall back to the URL box only when it is empty.
        source = input_data
        if source is None or (isinstance(source, str) and not source.strip()):
            source = url

        # A non-blank string is a URL to fetch. Anything else must already be an
        # image. Blank strings fall through to the validation below, so clearing
        # the URL box (which re-triggers a live interface) sends no request.
        if isinstance(source, str) and source.strip():
            try:
                img = _fetch_image_from_url(source.strip())
                print("Image fetched successfully from URL.")
            except Exception as e:
                print(f"Error fetching image from URL: {e}")
                return json.dumps({"error": f"Failed to fetch image from URL: {e}"})
        else:
            img = source

        if not isinstance(img, Image.Image):
            print("Invalid image format received.")
            return json.dumps({"error": "Invalid image format received. Please provide a valid image."})
        print(f"Image successfully loaded: {img}")

        # Fail fast, before spending time on preprocessing.
        if model is None:
            return json.dumps({"error": "Model not loaded. Ensure the model file is available and correctly loaded."})

        # The network takes 3-channel input, but fetched images may be RGBA,
        # grayscale or palette-based; Normalize's 3-channel mean/std would fail on those.
        batch = transform(img.convert("RGB")).unsqueeze(0)
        print(f"Transformed image tensor shape: {batch.shape}")

        # Put the input on the same device as the weights. load_model maps them to
        # the CPU, so choosing CUDA independently would raise a device-mismatch
        # error on GPU hosts.
        device = next(model.parameters()).device
        batch = batch.to(device)

        with torch.no_grad():
            outputs = model(batch)
        predicted_class = torch.argmax(outputs, dim=1).item()
        print(f"Model prediction outputs: {outputs}, Predicted class: {predicted_class}")

        result = _CLASS_RESULTS.get(predicted_class)
        if result is None:
            return json.dumps({"error": "Unexpected class prediction."})
        return json.dumps({"result": result})
    except Exception as e:
        print(f"Error processing image: {e}")
        return json.dumps({"error": f"Error processing image: {e}"})
# Create the Gradio interface with both local file upload and URL input.
# Gradio calls `fn` with one positional argument per input component, in the
# order listed: (uploaded PIL image or None, text of the URL box).
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Textbox(label="Or enter image URL (if available)", placeholder="Enter a URL for the image"),
    ],
    outputs=gr.Textbox(label="Prediction Result"),
    # live=True re-runs `fn` whenever any input changes, including edits to the URL box.
    live=True,
    title="Maize Anomaly Detection",
    # Advertise only what predict() handles: uploaded images and image URLs.
    description="Upload an image of maize, or enter an image URL, to detect anomalies like disease or pest infestation.",
)
# Launch the Gradio interface
iface.launch(share=True, show_error=True)
|