Spaces:
Sleeping
Sleeping
File size: 3,341 Bytes
cbc5566 9dfc63c bf44ad8 871b5a8 cbc5566 38d7439 2255b93 cbc5566 95250f9 2255b93 95250f9 2255b93 9dfc63c 95250f9 9dfc63c 38d7439 9dfc63c 2255b93 9dfc63c 95250f9 170be68 610d493 ee2271a 170be68 ee2271a 170be68 ee2271a fc29cbf ee2271a 95250f9 fc29cbf 95250f9 fc29cbf 95250f9 2255b93 95250f9 2255b93 9dfc63c fb31436 a62d15d fc29cbf 170be68 871b5a8 170be68 fc29cbf 2255b93 bf44ad8 170be68 a62d15d ce20917 40efeb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
import requests
import base64
from io import BytesIO
import os
# Number of output classes of the fine-tuned classifier head.
# Update with the actual number of classes in your dataset.
num_classes = 2

# Default location of the fine-tuned weights (a state_dict saved with torch.save).
MODEL_PATH = "path_to_your_model.pth"


def load_model(model_path=MODEL_PATH, num_classes=num_classes):
    """Build a ResNet-50 with a ``num_classes``-way head and load fine-tuned weights.

    The backbone is created without pretrained weights; every parameter comes
    from the state_dict stored at ``model_path``. The model is moved to CUDA
    when available (CPU otherwise) and switched to eval mode.

    Args:
        model_path: Path to a state_dict file produced by ``torch.save``.
        num_classes: Number of output classes; must match the checkpoint.

    Returns:
        The ready-to-use model, or ``None`` if construction or loading failed.
        The error is printed rather than raised so the app can still start.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # weights=None is the current torchvision spelling of pretrained=False.
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        # Load onto the CPU first so a checkpoint saved on GPU also loads on
        # CPU-only hosts. weights_only=True refuses to unpickle arbitrary
        # objects; a plain state_dict of tensors loads fine under it.
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        # Put the weights on the same device inference inputs are sent to.
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


# Loaded once at import; None signals that loading failed (see printed error).
model = load_model()
# Preprocessing applied to every input image before inference: scale the
# shorter side to 256 px, keep the central 224x224 crop, convert to a float
# tensor, then normalize each channel. The mean/std values are the ImageNet
# statistics used by torchvision's ResNet recipes (presumably also what the
# checkpoint was trained with — confirm against the training script).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function
def process_image(image, image_url=None):
try:
# Ensure that the image is not None
if image is None and not image_url:
return "No image or URL provided."
# Handle URL-based image loading
if image_url:
try:
response = requests.get(image_url)
response.raise_for_status() # Raise an error if the request fails
image = Image.open(BytesIO(response.content))
except Exception as e:
return f"Error fetching image from URL: {e}"
# Handle local file path image loading (Gradio File input)
elif isinstance(image, str) and os.path.isfile(image):
try:
image = Image.open(image)
except Exception as e:
return f"Error loading image from local path: {e}"
# Validate that the image is loaded correctly
if not isinstance(image, Image.Image):
return "Invalid image format received."
# Apply transformations
image = transform(image).unsqueeze(0)
# Prediction
image = image.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
with torch.no_grad():
outputs = model(image)
predicted_class = torch.argmax(outputs, dim=1).item()
if predicted_class == 0:
return "The photo you've sent is of fall army worm with problem ID 126."
elif predicted_class == 1:
return "The photo you've sent is of a healthy maize image."
else:
return "Unexpected class prediction."
except Exception as e:
return f"Error processing image: {e}"
# Create the Gradio interface.
# process_image gives the URL field precedence over an uploaded file.
# live=False makes predictions run on Submit. In live mode every keystroke in
# the URL box re-runs the handler, which fires an HTTP request for each
# partially typed URL.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        # Local file upload; the picker is restricted to image files.
        gr.File(label="Upload an image (Local File Path)", file_types=["image"]),
        # Remote image URL.
        gr.Textbox(label="Enter Image URL", placeholder="Enter image URL here", lines=1),
    ],
    outputs=gr.Textbox(label="Prediction Result"),
    live=False,
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation. You can upload local images or provide an image URL.",
)

# Launch the Gradio interface; show_error surfaces handler exceptions in the UI.
iface.launch(share=True, show_error=True)
|