Spaces:
Sleeping
Sleeping
File size: 3,398 Bytes
2201868 342396f 2201868 342396f 163e73a 342396f a87604d 2201868 0f694f1 342396f 2201868 29a1140 342396f 163e73a 342396f 163e73a 342396f 163e73a 342396f 52fd9c2 29a1140 a87604d 29a1140 a87604d 29a1140 a87604d 29a1140 a87604d 29a1140 a87604d 29a1140 a87604d 29a1140 5b86dff 342396f 163e73a 29a1140 342396f 29a1140 342396f 29a1140 163e73a 29a1140 163e73a 342396f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import os
from io import BytesIO
import logging
# Configure the root logger so that INFO-level messages and above are emitted.
logging.basicConfig(level="INFO")

# Number of output classes of the classifier head that replaces ResNet-50's fc.
num_classes = 2
# Download model weights from the Hugging Face Hub
def download_model(repo_id="jays009/Restnet50", filename="pytorch_model.bin"):
    """Fetch the model checkpoint file from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds the weights. Defaults to the
            repository this app was built against.
        filename: Name of the checkpoint file inside ``repo_id``.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename)
# Build the ResNet-50 architecture and load the fine-tuned weights into it
def load_model(model_path, num_classes=num_classes):
    """Create a ResNet-50 classifier and load a saved state dict into it.

    Args:
        model_path: Path to a checkpoint holding a ``state_dict`` for a
            ResNet-50 whose final ``fc`` layer has ``num_classes`` outputs.
        num_classes: Number of outputs of the replacement ``fc`` head.
            Defaults to the module-level ``num_classes``.

    Returns:
        The model, loaded onto the CPU and switched to evaluation mode.
    """
    # No ImageNet weights are needed because load_state_dict (strict by
    # default) overwrites every parameter below. ``weights=None`` is the
    # non-deprecated spelling of ``pretrained=False`` (torchvision >= 0.13).
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # The checkpoint is downloaded from a remote hub. weights_only=True
    # restricts unpickling to tensors and plain containers, so a tampered
    # file cannot execute arbitrary code while it is being loaded.
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Fetch the checkpoint once at import time and build the inference model.
model_path = download_model()
model = load_model(model_path)

# Preprocessing applied to every input image before inference:
# - resize so the shorter side is 256 px;
# - take the central 224x224 crop;
# - convert to a tensor;
# - normalize each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Most recent path submitted to process_file_path; read back by fetch_result.
file_path = None
# Classify a single PIL image with the loaded model
def predict_from_image(image):
    """Run the classifier on one image and describe the predicted class.

    Args:
        image: A ``PIL.Image.Image``. Any mode is accepted; images that are
            not already RGB are converted to RGB before preprocessing.

    Returns:
        A dict with a ``"result"`` message for a known class index, or an
        ``"error"`` message if the model predicts an unexpected index.

    Raises:
        ValueError: If ``image`` is not a PIL image.
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Invalid image format received. Please provide a valid image.")

    # Normalize() in ``transform`` uses 3-channel mean/std. RGBA PNGs (4
    # channels) and grayscale/palette images (1 channel) would otherwise make
    # preprocessing fail, so coerce everything to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of size 1 for the model.
    image_tensor = transform(image).unsqueeze(0)

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(image_tensor)
    predicted_class = torch.argmax(outputs, dim=1).item()

    # Map each class index to its human-readable description.
    messages = {
        0: "The photo is of fall army worm with problem ID 126.",
        1: "The photo is of a healthy maize image.",
    }
    if predicted_class in messages:
        return {"result": messages[predicted_class]}
    return {"error": "Unexpected class prediction."}
# Handle a path submitted through the Gradio textbox
def process_file_path(file_path_input):
    """Remember ``file_path_input`` and classify the image stored there.

    The path is also saved in the module-level ``file_path`` so that
    ``fetch_result`` can re-run the prediction later.

    NOTE(review): the path comes straight from a UI launched with
    ``share=True``. Anyone with the link can therefore probe which files
    exist on the host and ask the app to open them. Consider restricting
    input to an allow-listed directory.

    Args:
        file_path_input: Filesystem path of the image to classify.

    Returns:
        The prediction dict from ``predict_from_image``, or a dict with an
        ``"error"`` key when the path is missing, is not a regular file, or
        cannot be opened as an image.
    """
    global file_path
    file_path = file_path_input  # Store the path for fetch_result
    logging.info(f"Received file path: {file_path}")

    # isfile (rather than exists) rejects directories up front instead of
    # letting Image.open crash on them. The truthiness check also covers
    # None, which os.path functions would reject with TypeError.
    if not file_path or not os.path.isfile(file_path):
        logging.error(f"File not found at {file_path}")
        return {"error": f"File not found at {file_path}"}

    logging.info(f"Processing image from path: {file_path}")
    try:
        image = Image.open(file_path)
    except OSError as err:  # includes unidentifiable images and permission errors
        logging.error(f"Could not open image at {file_path}: {err}")
        return {"error": f"Could not open image at {file_path}: {err}"}

    # The context manager releases the file handle PIL keeps open after open().
    with image:
        return predict_from_image(image)
# Re-run the prediction for the last path stored by process_file_path
def fetch_result():
    """Classify the image at the most recently submitted ``file_path``.

    Returns:
        The prediction dict from ``predict_from_image``, or a dict with an
        ``"error"`` key in either of these cases:
        - no path has been submitted yet;
        - the stored path no longer points at a regular file;
        - the file cannot be opened as an image.
    """
    if not file_path:
        logging.warning("No file path available. Please send a POST request with a file path first.")
        return {"error": "No file path available. Please send a POST request with a file path first."}

    # The file may have been moved or deleted since it was submitted.
    if not os.path.isfile(file_path):
        logging.error(f"File not found at {file_path}")
        return {"error": f"File not found at {file_path}"}

    try:
        image = Image.open(file_path)
    except OSError as err:  # includes unidentifiable images and permission errors
        logging.error(f"Could not open image at {file_path}: {err}")
        return {"error": f"Could not open image at {file_path}: {err}"}

    logging.info(f"Making prediction for image at path: {file_path}")
    # The context manager releases the file handle PIL keeps open after open().
    with image:
        return predict_from_image(image)
# Gradio UI: a single textbox whose value is passed to process_file_path,
# with the returned dict rendered as JSON.
iface = gr.Interface(
    fn=process_file_path,
    inputs=[
        gr.Textbox(label="Enter Local Image Path", placeholder="Provide the local image path"),
    ],
    outputs=gr.JSON(label="Prediction Result"),
    live=False,
    title="Maize Anomaly Detection",
    description="Provide a local file path via POST request to process an image.",
)

# Launch only when executed as a script, so that importing this module (for
# example from tests or another app) does not start a web server.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link. Combined with a
    # free-form filesystem path input, this lets anyone with the link probe
    # the host's files. Confirm this exposure is intended.
    iface.launch(share=True, show_error=True)
|