Spaces:
Sleeping
Sleeping
File size: 4,396 Bytes
cbc5566 9dfc63c 5cadf06 9dfc63c bf44ad8 871b5a8 cbc5566 38d7439 2255b93 cbc5566 5cadf06 2255b93 5cadf06 2255b93 9dfc63c 5cadf06 9dfc63c 38d7439 9dfc63c 2255b93 9dfc63c 5cadf06 610d493 5cadf06 ee2271a 5cadf06 ee2271a 5cadf06 ee2271a 5cadf06 ee2271a 5cadf06 ee2271a 170be68 5cadf06 170be68 ee2271a 5cadf06 ee2271a 5cadf06 ee2271a fc29cbf 5cadf06 95250f9 5cadf06 95250f9 fc29cbf 95250f9 5cadf06 fc29cbf 95250f9 5cadf06 95250f9 5cadf06 95250f9 2255b93 95250f9 2255b93 5cadf06 2255b93 9dfc63c fb31436 a62d15d 5cadf06 2255b93 bf44ad8 5cadf06 a62d15d ce20917 5cadf06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import base64
from io import BytesIO
import os
# Size of the classifier head that replaces ResNet-50's final layer. It must
# match the checkpoint being loaded, otherwise load_state_dict() rejects the
# shape of the `fc` weights. Update it if the dataset has a different number
# of classes.
num_classes: int = 2
# Download model from Hugging Face
def download_model():
try:
model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
return model_path
except Exception as e:
print(f"Error downloading model: {e}")
return None
# Build the classifier and load the downloaded weights into it.
def load_model(model_path):
    """Create a ResNet-50 with a ``num_classes``-way head and load a state dict.

    Args:
        model_path: Path to a checkpoint file containing a ``state_dict`` for
            the modified ResNet-50.

    Returns:
        The model in eval mode with its weights on the CPU, or ``None`` if
        building or loading raised (the exception is printed, not propagated).
    """
    try:
        # weights=None gives a randomly initialised network; the real weights
        # come from the checkpoint below. This replaces the `pretrained=False`
        # keyword, which torchvision has deprecated since 0.13.
        model = models.resnet50(weights=None)
        # Swap the 1000-way ImageNet head for one sized to this task.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        # SECURITY NOTE(review): torch.load unpickles the file, and this
        # checkpoint comes from a remote repository. Consider passing
        # weights_only=True (torch >= 1.13) so arbitrary pickled objects are
        # refused.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        # Inference mode for layers such as BatchNorm.
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
# Fetch the checkpoint once at import time and build the model from it.
# `model` is left as None if the download returned no path.
model_path = download_model()
if model_path:
    model = load_model(model_path)
else:
    model = None
# Inference preprocessing, applied in order:
#   1. resize so the shorter side is 256 px;
#   2. take the central 224x224 crop;
#   3. convert the PIL image to a float tensor;
#   4. normalise each of the three channels with the ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Maximum time to wait for a remote image, so that a slow or unresponsive
# host cannot block a request indefinitely.
_URL_FETCH_TIMEOUT_SECONDS = 10

# User-facing answer for each class index the classifier can emit.
_CLASS_MESSAGES = {
    0: "The photo you've sent is of fall army worm with problem ID 126.",
    1: "The photo you've sent is of a healthy maize image.",
}


def _open_rgb(source):
    """Open *source* (a path or file-like object) with PIL and return an RGB copy.

    ``convert`` reads the pixel data into a new image, so the file can be
    closed right away. Forcing RGB also yields the three channels that a
    3-channel ``Normalize`` expects; RGBA, greyscale and palette images would
    otherwise fail there.
    """
    with Image.open(source) as img:
        return img.convert("RGB")


def _resolve_image(image):
    """Turn a base64 dict, an http(s) URL, or a local path into a PIL image.

    Returns:
        ``(image, None)`` on success, or ``(None, message)`` when decoding,
        fetching or opening fails. Inputs that match none of these forms are
        returned unchanged, so the caller can validate them.
    """
    if isinstance(image, dict) and image.get("data"):
        try:
            payload = image["data"]
            # Accept data URLs ("data:image/png;base64,...") as well as bare
            # base64; b64decode would otherwise fold the prefix into the bytes.
            if isinstance(payload, str) and payload.startswith("data:"):
                payload = payload.partition(",")[2]
            decoded = _open_rgb(BytesIO(base64.b64decode(payload)))
        except Exception as e:
            return None, f"Error decoding base64 image: {e}"
        print(f"Decoded base64 image: {decoded}")
        return decoded, None

    if isinstance(image, str) and image.startswith("http"):
        try:
            response = requests.get(image, timeout=_URL_FETCH_TIMEOUT_SECONDS)
            # Report HTTP failures (404, 500, ...) directly, instead of handing
            # an error page to PIL and getting a "cannot identify image" error.
            response.raise_for_status()
            fetched = _open_rgb(BytesIO(response.content))
        except Exception as e:
            return None, f"Error fetching image from URL: {e}"
        print(f"Fetched image from URL: {fetched}")
        return fetched, None

    if isinstance(image, str) and os.path.isfile(image):
        try:
            opened = _open_rgb(image)
        except Exception as e:
            return None, f"Error loading image from local path: {e}"
        print(f"Loaded image from local path: {opened}")
        return opened, None

    return image, None


def predict(image):
    """Classify a maize photo and return a human-readable verdict.

    Args:
        image: A PIL image (the form supplied by the ``gr.Image(type="pil")``
            input), a dict whose ``"data"`` entry is base64-encoded image bytes
            (optionally as a ``data:`` URL), an ``http(s)`` URL, or a path to
            an existing local file.

    Returns:
        A message naming the predicted class, or an error message. Errors are
        returned as strings rather than raised, so that they appear in the
        output textbox.
    """
    try:
        print(f"Received image input: {image}")
        # The model is None when the download or load at import time failed.
        if model is None:
            print("Model is not loaded.")
            return "Model is not loaded; cannot run a prediction."

        image, error = _resolve_image(image)
        if error is not None:
            print(error)
            return error

        if not isinstance(image, Image.Image):
            print("Invalid image format received.")
            return "Invalid image format received."

        # Images passed in directly (e.g. RGBA uploads) also need 3 channels.
        image = image.convert("RGB")
        batch = transform(image).unsqueeze(0)
        print(f"Transformed image tensor: {batch.shape}")
        # Put the input on the same device as the model's weights. Choosing
        # CUDA independently of the model causes a device-mismatch error.
        batch = batch.to(next(model.parameters()).device)
        with torch.no_grad():
            outputs = model(batch)
        predicted_class = torch.argmax(outputs, dim=1).item()
        print(f"Prediction output: {outputs}, Predicted class: {predicted_class}")
        return _CLASS_MESSAGES.get(predicted_class, "Unexpected class prediction.")
    except Exception as e:
        print(f"Error processing image: {e}")
        return f"Error processing image: {e}"
# Build the Gradio UI: one image input feeding predict(), one text output.
# With type="pil", Gradio hands predict() a PIL image.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an image or provide a URL or local path"),
    outputs=gr.Textbox(label="Prediction Result"),
    live=True,  # re-run predict() as the input changes, without a Submit button
    title="Maize Anomaly Detection",
    description="Upload an image of maize to detect anomalies like disease or pest infestation. You can provide local paths, URLs, or base64-encoded images.",
)
# Start the server. share=True requests a public share link, and
# show_error=True surfaces exceptions in the browser UI.
iface.launch(share=True, show_error=True)