# Source: Hugging Face file view of app.py, last updated by jays009
# (commit 95250f9, "Update app.py", verified; 3.06 kB).
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
import requests
import base64
from io import BytesIO
import os
# Number of output classes; load_model() uses it as the width of the
# replacement fully connected layer. process_image() only has messages for
# class indices 0 and 1 and reports any other index as unexpected.
num_classes = 2 # Update with the actual number of classes in your dataset
# Load the model (assuming you've already downloaded it)
def load_model(weights_path="path_to_your_model.pth"):
    """Build a ResNet-50 with a ``num_classes``-way head and load its weights.

    Args:
        weights_path: Path to a ``state_dict`` checkpoint readable by
            ``torch.load``. Defaults to the placeholder path the app ships
            with, so existing zero-argument calls behave as before.

    Returns:
        The model in eval mode with its weights mapped to the CPU, or
        ``None`` if building or loading failed. The error is printed rather
        than raised.
    """
    try:
        # Called without arguments the network is randomly initialised, which
        # is what the deprecated ``pretrained=False`` keyword used to request.
        net = models.resnet50()
        # Swap the final layer so its output width matches num_classes.
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        net.load_state_dict(state_dict)
        net.eval()
        return net
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


model = load_model()
# Preprocessing applied to every input image before inference:
# resize, centre-crop to 224, convert to a tensor, then normalise per channel.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function
def _load_image(data):
    """Open the PIL image described by the ``data``, ``url`` or ``path`` key of *data*."""
    if "data" in data:
        payload = data["data"]
        # Accept data URIs ("data:image/png;base64,...") as well as bare
        # base64 by dropping everything up to the first comma.
        if isinstance(payload, str) and payload.startswith("data:"):
            payload = payload.partition(",")[2]
        return Image.open(BytesIO(base64.b64decode(payload)))
    if "url" in data:
        # NOTE(security): this fetches a caller-supplied URL from the server.
        # Restrict the allowed hosts if the app is exposed publicly.
        response = requests.get(data["url"], timeout=10)
        # Fail on 4xx/5xx instead of trying to parse an error page as an image.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    if "path" in data:
        # NOTE(security): this opens a caller-supplied path on the server's
        # filesystem. Restrict it to a known directory if exposed publicly.
        return Image.open(data["path"])
    raise ValueError("no image source key in input")


def process_image(data):
    """Classify the maize image described by *data* and return a message.

    Args:
        data: A dict holding one of the keys ``"data"`` (base64-encoded image,
            optionally as a data URI), ``"url"`` (image fetched over HTTP) or
            ``"path"`` (image opened from a local path). When several keys
            are present they are tried in that order.

    Returns:
        A human-readable string: the prediction message, or a description of
        what went wrong. Errors are returned as text and never raised.
    """
    # Reject anything that is not a dict with a known source key up front
    # (e.g. None from an empty JSON box) so the caller gets a clear message.
    if not isinstance(data, dict) or not any(
        key in data for key in ("data", "url", "path")
    ):
        return "Invalid input data structure."
    if model is None:
        # load_model() returns None on failure; say so explicitly.
        return "Error processing image: model is not loaded."
    try:
        image = _load_image(data)
        # Validate image
        if not isinstance(image, Image.Image):
            return "Invalid image format received."
        # The normalisation step expects three channels, so convert
        # grayscale, palette and RGBA images to RGB first.
        batch = transform(image.convert("RGB")).unsqueeze(0)
        # Run on whatever device the model's weights live on; sending the
        # input to CUDA while the model stays on the CPU fails.
        device = next(model.parameters()).device
        with torch.no_grad():
            outputs = model(batch.to(device))
        predicted_class = torch.argmax(outputs, dim=1).item()
        messages = {
            0: "The photo you've sent is of fall army worm with problem ID 126.",
            1: "The photo you've sent is of a healthy maize image.",
        }
        return messages.get(predicted_class, "Unexpected class prediction.")
    except Exception as e:
        return f"Error processing image: {e}"
# Wire process_image into a Gradio UI: a JSON payload goes in, the
# prediction message comes out as text.
iface = gr.Interface(
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images.",
    fn=process_image,
    inputs=gr.JSON(label="Upload an image (URL or Local Path)"),  # JSON carrying a URL, path or base64 data
    outputs=gr.Textbox(label="Prediction Result"),  # the returned message
    live=True,
)

# Start the app with a public share link and with errors shown in the UI.
iface.launch(show_error=True, share=True)