# Copied from the Hugging Face file viewer for this Space:
# author: jays009 — commit 20e6ace ("Update app.py", verified) — 5.49 kB
import json
import os
from io import BytesIO

import gradio as gr
import requests
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import models, transforms
# Size of the classifier head; predict_from_image maps index 0 and index 1
# to user-facing messages.
num_classes: int = 2

# Process-local store of prediction results, keyed by the event ID that
# predict_with_event_id hands back. Nothing ever removes entries, so it grows
# for the lifetime of the process.
results_cache: dict = {}
def download_model():
    """Download the model weights file from the Hugging Face Hub.

    Returns:
        The local path, as returned by ``hf_hub_download``, of
        ``pytorch_model.bin`` from the ``jays009/Restnet50`` repository.
    """
    return hf_hub_download(
        repo_id="jays009/Restnet50",
        filename="pytorch_model.bin",
    )
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load its weights.

    Args:
        model_path: Path to a saved ``state_dict`` for this architecture,
            e.g. the file returned by ``download_model``.

    Returns:
        The model with weights mapped onto the CPU, switched to eval mode.
    """
    # ``weights=None`` is the current torchvision spelling of the deprecated
    # ``pretrained=False``: no pretrained weights, they come from the file.
    model = models.resnet50(weights=None)
    # Swap the default classification layer for one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # The checkpoint comes from the network, so do not unpickle arbitrary
    # objects: ``weights_only=True`` restricts loading to tensors and plain
    # containers, which is all a state_dict needs.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Download the weights and build the inference model once, at import time;
# predict_from_image uses this module-level ``model``.
# NOTE(review): this runs the download on import with no error handling, so
# any failure here stops the app from starting.
model_path = download_model()
model = load_model(model_path)
# Preprocessing applied to every input image before inference: resize (smaller
# edge to 256), take the central 224x224 crop, convert to a tensor, then
# normalize each channel with the mean/std values commonly used for
# ImageNet-pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict_from_image(image):
    """Classify a PIL image as fall army worm or healthy maize.

    Args:
        image: A ``PIL.Image.Image`` in any mode; it is converted to RGB
            before preprocessing.

    Returns:
        ``{"result": <message>}`` on success, or ``{"error": <message>}`` if
        the input is not a PIL image, the model predicts an unknown class
        index, or any step raises.
    """
    # User-facing message for each class index the model can output.
    class_messages = {
        0: "The photo is of fall army worm with problem ID 126.",
        1: "The photo is of a healthy maize image.",
    }
    try:
        print(f"Processing image: {image}")
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")
        # Normalize is configured for exactly three channels. RGBA PNGs,
        # grayscale ("L") and palette ("P") images would otherwise fail
        # inside ``transform``.
        image = image.convert("RGB")
        # Add a leading batch dimension of size 1.
        image_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            outputs = model(image_tensor)
        predicted_class = torch.argmax(outputs, dim=1).item()
        message = class_messages.get(predicted_class)
        if message is None:
            return {"error": "Unexpected class prediction."}
        return {"result": message}
    except Exception as e:
        # Boundary handler: report any failure to the caller as a JSON-able
        # dict instead of raising into the UI/API layer.
        print(f"Error during image processing: {e}")
        return {"error": str(e)}
def predict_from_url(url, timeout=10):
    """Download an image from ``url`` and classify it.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait on the server (passed to ``requests.get``).
            Without a timeout, ``requests`` can block indefinitely on an
            unresponsive host.

    Returns:
        The dict produced by ``predict_from_image``, or
        ``{"error": "Failed to process the URL: ..."}`` if fetching or
        opening the image fails.
    """
    # NOTE(review): ``url`` is supplied by the caller, so the server will fetch
    # arbitrary addresses (including internal ones). Consider an allow-list
    # if this endpoint is publicly exposed.
    try:
        print(f"Fetching image from URL: {url}")
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return predict_from_image(image)
    except Exception as e:
        print(f"Error during URL processing: {e}")
        return {"error": f"Failed to process the URL: {str(e)}"}
def predict_direct_upload(image):
    """Classify an image uploaded through the Gradio UI (no event ID).

    Args:
        image: The uploaded image (the UI component uses ``type="pil"``), or
            ``None`` when nothing was provided.

    Returns:
        The dict from ``predict_from_image``, or an ``{"error": ...}`` dict.
    """
    # Check for None explicitly instead of relying on truthiness. Array-like
    # inputs raise on ``bool()``, and an image's truth value says nothing
    # about whether something was uploaded.
    if image is None:
        return {"error": "No image provided. Please upload an image."}
    try:
        return predict_from_image(image)
    except Exception as e:
        print(f"Error in direct upload prediction function: {e}")
        return {"error": str(e)}
def predict_with_event_id(data):
    """Classify the image at ``data[0]["path"]`` and cache the result.

    Args:
        data: A sequence whose first element is a dict-like object with a
            ``"path"`` key holding an HTTP(S) URL or a local file path.
            (Presumably the JSON payload of an API request; confirm against
            the caller.)

    Returns:
        ``{"event_id": <int>, "result": <prediction dict>}`` when a path is
        given; the result is also stored in ``results_cache`` under that ID
        for ``get_result``. Otherwise an ``{"error": ...}`` dict.
    """
    try:
        path = data[0].get("path", None)
        if not path:
            return {"error": "No path provided. Please provide a valid path."}

        if path.startswith(("http://", "https://")):
            result = predict_from_url(path)
        elif os.path.isfile(path):
            # NOTE(review): this opens any server-side file the caller names.
            # Restrict it to an upload directory if the API is public.
            # ``with`` closes the file handle that Image.open keeps open.
            with Image.open(path) as image:
                result = predict_from_image(image)
        else:
            result = {"error": "Invalid path format. Please provide a valid URL or local file path."}

        # id() is unique among live objects, and the cache keeps ``result``
        # alive, so it is unique among the cached entries.
        event_id = id(result)
        results_cache[event_id] = result
        print(f"Event ID: {event_id}, Result: {result}")
        return {"event_id": event_id, "result": result}
    except Exception as e:
        print(f"Error in prediction function with event ID: {e}")
        return {"error": str(e)}
def get_result(event_id):
    """Look up a cached prediction by event ID.

    Args:
        event_id: The ID returned by ``predict_with_event_id``, either as an
            int or as a string of digits (e.g. from the Gradio textbox).

    Returns:
        The cached result dict. Returns an ``{"error": ...}`` dict if the ID
        cannot be parsed as an int or is not in ``results_cache``.
    """
    print(f"Retrieving result for event ID: {event_id}")
    # Only the parsing step can legitimately fail on bad input; do not let a
    # broad handler disguise unrelated errors as "Invalid event ID".
    try:
        key = int(event_id)
    except (TypeError, ValueError, OverflowError) as e:
        return {"error": f"Invalid event ID: {str(e)}"}
    result = results_cache.get(key)
    # Test for absence explicitly so a falsy stored value is not "missing".
    if result is None:
        return {"error": "No result found for the provided event ID."}
    return result
# FastAPI app exposing cached results over HTTP as GET /result/{event_id}.
# NOTE(review): nothing visible in this file serves ``app``. The UI is started
# with ``iface.launch(...)`` on Gradio's own server, so this route appears
# unreachable as written. Mounting the Gradio UI onto this app (e.g. via
# ``gr.mount_gradio_app``) and running it with an ASGI server would be needed
# to expose it; confirm the intended deployment.
app = FastAPI()
@app.get("/result/{event_id}")
def get_result_api(event_id: int):
    """Return the cached prediction for ``event_id`` via ``get_result``.

    The ``int`` annotation makes FastAPI validate the path parameter before
    this function runs.
    """
    return get_result(event_id)
def _predict_upload_or_url(image, url):
    """Handle Submit: classify the uploaded image, or else the URL if given."""
    if image is not None:
        return predict_direct_upload(image)
    url = (url or "").strip()
    if url:
        return predict_from_url(url)
    return {"error": "No image provided. Please upload an image or enter an image URL."}


# Gradio interface setup
with gr.Blocks() as iface:
    gr.Markdown("# Maize Anomaly Detection")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        url_input = gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL")
    output = gr.JSON(label="Prediction Result")
    submit_button = gr.Button("Submit")
    # Pass both inputs so a URL typed into the textbox is actually used when
    # no image was uploaded.
    submit_button.click(
        fn=_predict_upload_or_url,
        inputs=[image_input, url_input],
        outputs=output,
    )

    # Event ID retrieval section
    with gr.Row():
        event_id_input = gr.Textbox(label="Event ID", placeholder="Enter Event ID")
    event_output = gr.JSON(label="Retrieved Result")
    retrieve_button = gr.Button("Get Result")
    retrieve_button.click(
        fn=get_result,
        inputs=[event_id_input],
        outputs=event_output,
    )
# Start the Gradio server on all interfaces, port 7860. It also requests a
# public share link and shows errors in the UI.
iface.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True,
    show_error=True,
)