jays009's picture
Update app.py
a87604d verified
raw
history blame
3.4 kB
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
import os
from io import BytesIO
import logging
# Configure the root logger at INFO level so the logging.info / logging.warning
# calls made by the request handlers below are emitted.
logging.basicConfig(level="INFO")

# Output size of the classifier head. The labels are interpreted in
# predict_from_image (index 0 = fall army worm, index 1 = healthy maize).
num_classes = 2
# Download model from Hugging Face
def download_model():
model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
return model_path
def load_model(model_path):
    """Rebuild the fine-tuned ResNet-50 classifier from a saved state dict.

    Args:
        model_path: Local path to the checkpoint file (for example the path
            returned by download_model).

    Returns:
        The network with the checkpoint weights loaded, switched to eval mode.
    """
    # Backbone without ImageNet weights: parameters are filled from the checkpoint.
    net = models.resnet50(pretrained=False)
    # Replace the final layer so its output size is the module-level num_classes.
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    # Map all tensors to CPU so the checkpoint loads without a GPU.
    # NOTE(security): torch.load is pickle-based; only point it at checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    net.load_state_dict(checkpoint)
    # Inference mode: dropout off, BatchNorm uses running statistics.
    net.eval()
    return net
# Build the inference model once, at import time; predict_from_image reuses
# this module-level instance for every request.
model_path = download_model()
model = load_model(model_path)

# Preprocessing pipeline for every incoming image: resize, center-crop to
# 224x224, convert to a float tensor, then normalise each channel with the
# standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Most recent path submitted via process_file_path; fetch_result reads it.
# Stays None until the first submission.
file_path = None
def predict_from_image(image):
    """Classify a maize photo as fall-army-worm damage or healthy.

    Args:
        image: A ``PIL.Image.Image``. Any mode is accepted; non-RGB images
            are converted to RGB before preprocessing.

    Returns:
        ``{"result": <message>}`` for a recognised class index, or
        ``{"error": "Unexpected class prediction."}`` for any other index.

    Raises:
        ValueError: If ``image`` is not a PIL image.
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Invalid image format received. Please provide a valid image.")

    # RGBA (e.g. PNG with alpha), grayscale "L" and palette "P" images would
    # produce a 4- or 1-channel tensor, which the 3-channel Normalize step and
    # the ResNet input layer cannot handle. Convert so the transform sees RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of size 1 for the model.
    image_tensor = transform(image).unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(image_tensor)
    predicted_class = torch.argmax(outputs, dim=1).item()

    # Class index -> user-facing message.
    messages = {
        0: "The photo is of fall army worm with problem ID 126.",
        1: "The photo is of a healthy maize image.",
    }
    if predicted_class in messages:
        return {"result": messages[predicted_class]}
    return {"error": "Unexpected class prediction."}
def process_file_path(file_path_input):
    """Classify the image at a local path and remember that path.

    The path is stored in the module-level ``file_path`` before it is
    validated, so fetch_result always refers to the most recent submission.
    NOTE(review): that is a single module-level slot; concurrent clients
    would overwrite each other's path.

    Args:
        file_path_input: Filesystem path entered in the Gradio textbox.

    Returns:
        The dict produced by predict_from_image on success. Otherwise an
        ``{"error": <message>}`` dict when no path is given, the path is not a
        regular file, or the file cannot be read as an image.
    """
    global file_path
    file_path = file_path_input  # Store the path for fetch_result.
    logging.info("Received file path: %s", file_path)

    # Guard empty / None input: os.path.isfile(None) raises TypeError.
    if not file_path:
        logging.error("No file path provided")
        return {"error": "No file path provided."}

    # isfile (rather than exists) also rejects directories, which would
    # otherwise make Image.open raise.
    if not os.path.isfile(file_path):
        logging.error("File not found at %s", file_path)
        return {"error": f"File not found at {file_path}"}

    try:
        # The context manager closes the file handle even if prediction fails.
        # Pixel data is decoded lazily, so predict_from_image stays inside it.
        with Image.open(file_path) as image:
            logging.info("Processing image from path: %s", file_path)
            return predict_from_image(image)
    except OSError as err:
        # Covers unreadable/truncated files and PIL.UnidentifiedImageError
        # (an OSError subclass) for files that are not images.
        logging.exception("Could not read image at %s", file_path)
        return {"error": f"Could not read image at {file_path}: {err}"}
def fetch_result():
    """Re-run the prediction for the path last stored by process_file_path.

    Returns:
        The dict produced by predict_from_image on success. Otherwise an
        ``{"error": <message>}`` dict when no path has been submitted yet, the
        stored path is not a regular file, or the file cannot be read as an
        image.
    """
    if not file_path:
        logging.warning("No file path available. Please send a POST request with a file path first.")
        return {"error": "No file path available. Please send a POST request with a file path first."}

    # process_file_path stores the path before validating it, so the stored
    # value may point at a missing file or a directory. Check it here instead
    # of letting Image.open raise.
    if not os.path.isfile(file_path):
        logging.error("File not found at %s", file_path)
        return {"error": f"File not found at {file_path}"}

    try:
        # Close the file handle deterministically. Decoding happens lazily
        # inside predict_from_image, so the call stays within the block.
        with Image.open(file_path) as image:
            logging.info("Making prediction for image at path: %s", file_path)
            return predict_from_image(image)
    except OSError as err:
        # Includes PIL.UnidentifiedImageError for files that are not images.
        logging.exception("Could not read image at %s", file_path)
        return {"error": f"Could not read image at {file_path}: {err}"}
# Gradio UI: one textbox whose value is passed to process_file_path; the
# returned dict is rendered as JSON.
# NOTE(review): fetch_result is not wired to any input or endpoint here.
iface = gr.Interface(
    fn=process_file_path,
    inputs=[
        gr.Textbox(label="Enter Local Image Path", placeholder="Provide the local image path"),
    ],
    outputs=gr.JSON(label="Prediction Result"),
    live=False,
    title="Maize Anomaly Detection",
    description="Provide a local file path via POST request to process an image.",
)

# Start the server only when run as a script (`python app.py`). Importing the
# module (e.g. from tests or another app) then builds the interface without
# launching it.
if __name__ == "__main__":
    iface.launch(share=True, show_error=True)