Spaces:
Sleeping
Sleeping
File size: 2,651 Bytes
2201868 342396f 2201868 163e73a 4af9c6b 2201868 0f694f1 342396f 2201868 4af9c6b 342396f 4af9c6b 163e73a 4af9c6b 163e73a 4af9c6b 163e73a 4af9c6b 52fd9c2 4af9c6b 991ba20 4af9c6b 991ba20 4af9c6b 163e73a 4af9c6b 342396f dd36796 163e73a 4af9c6b 163e73a 4dd171e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import os
import logging
# Width of the classifier head; predict_from_image maps class indices 0 and 1
# to user-facing messages.
num_classes = 2
# Route log records at INFO level and above through the root logger.
logging.basicConfig(level=logging.INFO)
# Download model from Hugging Face
def download_model():
    """Fetch the classifier checkpoint from the Hugging Face Hub.

    Returns:
        The local filesystem path of ``pytorch_model.bin`` from the
        ``jays009/Restnet50`` repository, as reported by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id="jays009/Restnet50",
        filename="pytorch_model.bin",
    )
# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        model_path: Local path to a checkpoint holding a ``state_dict`` for
            this exact architecture (at module level it comes from
            ``download_model``).

    Returns:
        The network with the checkpoint weights loaded onto the CPU,
        switched to evaluation mode.
    """
    # Start from an uninitialised backbone; the checkpoint supplies the weights.
    # NOTE(review): ``pretrained=`` is deprecated in newer torchvision in favour
    # of ``weights=None``; switch once the deployed version is confirmed.
    network = models.resnet50(pretrained=False)
    # Replace the ImageNet classifier with one sized for this task.
    network.fc = nn.Linear(network.fc.in_features, num_classes)
    # NOTE(security): torch.load unpickles the file, and a malicious checkpoint
    # can execute arbitrary code. This file is downloaded from a remote repo;
    # consider ``weights_only=True`` if the installed torch supports it.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    network.load_state_dict(state_dict)
    # Module.eval() returns the module itself, so it can be returned directly.
    return network.eval()
# Resolve the checkpoint path first, then build the inference model from it.
# Both calls run at import time, so the app is not ready until they finish.
model_path = download_model()
model = load_model(model_path)
# Preprocessing applied to each image before inference:
# 1. resize so the shorter side is 256 px;
# 2. take a centred 224x224 crop;
# 3. convert to a float tensor;
# 4. normalise each channel with the ImageNet mean/std listed below.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function for an uploaded image

# User-facing message for each class index produced by the classifier head.
_CLASS_MESSAGES = {
    0: "The photo is of fall army worm with problem ID 126.",
    1: "The photo is of a healthy maize image.",
}


def predict_from_image(image):
    """Classify a maize photo with the module-level ``model``.

    Args:
        image: The value supplied by the ``gr.Image(type="pil")`` input.
            Anything that is not a ``PIL.Image.Image`` (including ``None``)
            is rejected.

    Returns:
        ``{"result": message}`` for a recognised class. Otherwise returns
        ``{"error": message}``. Failures are reported in the returned dict
        rather than raised, so the JSON output always has a payload to show.
    """
    try:
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")
        logging.info("Received image for prediction")
        # Images may arrive as RGBA (PNG with alpha), greyscale ("L"), palette
        # ("P") or CMYK, at least when this function is called directly. The
        # 3-value Normalize in ``transform`` and the ResNet-50 input layer need
        # exactly three channels, so coerce to RGB instead of failing inside
        # the transform.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim
        with torch.no_grad():
            outputs = model(image_tensor)
            predicted_class = torch.argmax(outputs, dim=1).item()
    except Exception as e:  # UI boundary: report any failure to the user
        # logging.exception also records the traceback for debugging.
        logging.exception("Error during prediction: %s", e)
        return {"error": f"Failed to process the image: {e}"}

    message = _CLASS_MESSAGES.get(predicted_class)
    if message is None:
        return {"error": "Unexpected class prediction."}
    return {"result": message}
# Gradio UI: a single PIL image input, with the prediction dict rendered as
# JSON. live=True re-runs predict_from_image whenever the input changes.
iface = gr.Interface(
    predict_from_image,
    gr.Image(type="pil", label="Upload an Image"),
    gr.JSON(label="Prediction Result"),
    title="Maize Anomaly Detection",
    description="Upload an image to detect anomalies in maize crops.",
    live=True,
)
# Launch the interface locally
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    # (The stray " |" left by the page scrape after this call was a syntax error.)
    iface.launch()