Spaces:
Sleeping
Sleeping
File size: 5,076 Bytes
2201868 dd0dfd1 2201868 dd0dfd1 2201868 f7b71ff b77b937 2201868 f7b71ff b77b937 2201868 b77b937 2201868 b77b937 5b86dff a53a305 5b86dff fc29cbf 5b86dff 5cadf06 5b86dff a53a305 5b86dff a53a305 5b86dff a53a305 5b86dff a53a305 52fd9c2 b77b937 07f42c0 b77b937 07f42c0 b77b937 2255b93 a53a305 5b86dff f7b71ff a53a305 f7b71ff 334fb08 a53a305 334fb08 a53a305 334fb08 f7b71ff 334fb08 f7b71ff 334fb08 a53a305 f7b71ff 07f42c0 f7b71ff a53a305 f7b71ff a53a305 9dfc63c 334fb08 dd0dfd1 07f42c0 dd0dfd1 a53a305 dd0dfd1 f7b71ff dd0dfd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import gradio as gr
import json
import torch
from torch import nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
from io import BytesIO
from fastapi import FastAPI
# Size of the classifier head that load_model attaches to the ResNet backbone.
num_classes: int = 2

# Process-local store of prediction results, keyed by the event id that
# predict() hands back to the caller.
results_cache: dict = {}
# Download model weights from the Hugging Face Hub
def download_model():
    """Fetch ``pytorch_model.bin`` from the ``jays009/Restnet50`` Hub repo.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.
    """
    print("Downloading model...")
    return hf_hub_download(
        repo_id="jays009/Restnet50",
        filename="pytorch_model.bin",
    )
# Load the model from Hugging Face
def load_model(model_path):
    """Build a ResNet-50 with a ``num_classes``-way head and load its weights.

    Args:
        model_path: Path to a state-dict file saved with ``torch.save``
            (for example the path returned by ``download_model``).

    Returns:
        The model on CPU, switched to evaluation mode.
    """
    print("Loading model...")
    # ``weights=None`` builds the bare architecture without fetching ImageNet
    # weights; it replaces the ``pretrained=False`` flag that torchvision
    # deprecated in 0.13.
    model = models.resnet50(weights=None)
    # Swap the default 1000-way classifier for one sized to this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(security): torch.load unpickles the file, and this checkpoint comes
    # from a remote Hub repo. Consider passing ``weights_only=True``
    # (torch >= 1.13) so only tensors/plain containers can be deserialized.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for BatchNorm/Dropout layers
    return model
# Fetch the checkpoint once at import time and build the model from it.
model_path = download_model()
model = load_model(model_path=model_path)

# Preprocessing applied to every input image:
# - resize the shorter side to 256;
# - centre-crop to 224x224;
# - convert to a float tensor;
# - normalize each channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Function to predict from image content
def predict_from_image(image):
    """Classify a PIL image with the loaded model.

    Class index 0 is reported as fall armyworm and index 1 as healthy maize.

    Args:
        image: A ``PIL.Image.Image``. Any mode is accepted; non-RGB images
            (RGBA, palette, grayscale, ...) are converted to RGB first.

    Returns:
        ``{"result": <message>}`` on success, or ``{"error": <message>}`` if
        the input is invalid or inference fails. Errors are reported in the
        returned dict, not raised.
    """
    try:
        print(f"Processing image: {image}")
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format received. Please provide a valid image.")
        # The 3-value Normalize and the ResNet stem both require exactly three
        # channels. Images opened straight from a URL keep their file's mode
        # (e.g. RGBA PNGs, palette GIFs, grayscale JPEGs) and would fail here.
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Apply transformations and add a leading batch dimension.
        image_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            outputs = model(image_tensor)
        predicted_class = torch.argmax(outputs, dim=1).item()
        if predicted_class == 0:
            return {"result": "The photo is of fall army worm with problem ID 126."}
        elif predicted_class == 1:
            return {"result": "The photo is of a healthy maize image."}
        else:
            # Defensive: unreachable while the head has num_classes == 2.
            return {"error": "Unexpected class prediction."}
    except Exception as e:
        print(f"Error during image processing: {e}")
        return {"error": str(e)}
# Function to predict from URL
def predict_from_url(url, timeout=10):
    """Download an image from ``url`` and classify it with ``predict_from_image``.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds ``requests`` waits to connect and for each read
            before giving up. Defaults to 10 so an unresponsive host cannot
            block the request handler indefinitely.

    Returns:
        The dict returned by ``predict_from_image``, or
        ``{"error": "Failed to process the URL: ..."}`` if the download or
        image decoding fails.
    """
    try:
        print(f"Fetching image from URL: {url}")
        # NOTE(security): this fetches an arbitrary user-supplied URL from the
        # server (SSRF risk). Consider restricting schemes/hosts if the app
        # runs somewhere with access to internal services.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return predict_from_image(image)
    except Exception as e:
        print(f"Error during URL processing: {e}")
        return {"error": f"Failed to process the URL: {str(e)}"}
# Upper bound on entries kept in ``results_cache``. The oldest results are
# evicted first, so a long-running server does not grow without limit.
MAX_CACHED_RESULTS = 10_000


def _new_event_id():
    """Return a random event id that is not currently a key of ``results_cache``."""
    import secrets

    # 52 random bits keep the id below 2**53, so it round-trips unchanged
    # through JavaScript numbers in the gr.JSON output.
    while True:
        candidate = secrets.randbits(52)
        if candidate not in results_cache:
            return candidate


def _evict_old_results():
    """Drop the oldest cached results until at most MAX_CACHED_RESULTS remain."""
    # dicts preserve insertion order, so the first key is the oldest entry.
    while len(results_cache) > MAX_CACHED_RESULTS:
        results_cache.pop(next(iter(results_cache)), None)


# Main prediction function with caching
def predict(image, url):
    """Classify an uploaded image or an image URL and cache the outcome.

    An uploaded image takes precedence over a URL. Every outcome, including
    "no input" and per-input errors, is stored in ``results_cache`` under a
    fresh event id so it can be fetched later with ``get_result``.

    Args:
        image: PIL image from the upload widget, or ``None``.
        url: Image URL typed by the user. Empty or whitespace-only counts as
            not provided.

    Returns:
        ``{"event_id": int, "result": dict}`` normally, or ``{"error": str}``
        if something unexpected fails inside this function.
    """
    try:
        print("Starting prediction...")
        if image is not None:
            print("Processing uploaded image...")
            result = predict_from_image(image)
        elif url and url.strip():
            print("Processing image from URL...")
            result = predict_from_url(url.strip())
        else:
            result = {"error": "No input provided. Please upload an image or provide a URL."}
        # Random ids instead of id(result): id() values are memory addresses.
        # They are reused once an evicted result is garbage-collected, which
        # would silently map an old event id to a newer, unrelated result.
        event_id = _new_event_id()
        results_cache[event_id] = result
        _evict_old_results()
        print(f"Event ID: {event_id}, Result: {result}")
        return {"event_id": event_id, "result": result}
    except Exception as e:
        print(f"Error in prediction function: {e}")
        return {"error": str(e)}
# Function to retrieve result by event_id
def get_result(event_id):
    """Look up a previously cached prediction by its event id.

    Args:
        event_id: The id returned by ``predict``, either as an int or as a
            numeric string (e.g. typed into the Gradio textbox).

    Returns:
        The cached result dict. Otherwise:
        - ``{"error": "No result found ..."}`` when the id is unknown;
        - ``{"error": "Invalid event ID: ..."}`` when converting the id to
          ``int`` (or the lookup) raises.
    """
    try:
        print(f"Retrieving result for event ID: {event_id}")
        # Convert before touching the cache so a bad id is reported as such.
        key = int(event_id)
        cached = results_cache.get(key)
        if not cached:
            return {"error": "No result found for the provided event ID."}
        return cached
    except Exception as exc:
        return {"error": f"Invalid event ID: {str(exc)}"}
# Create a FastAPI app that exposes cached prediction results over HTTP GET.
# NOTE(review): in the visible source, `app` is never mounted into the Gradio
# app or served. The file ends with iface.launch(), which starts Gradio's own
# server, so this route looks unreachable as written. Confirm; one option is
# gr.mount_gradio_app(app, iface, path="/") and then serving `app`.
app = FastAPI()


# GET /result/{event_id}
# FastAPI uses the `int` annotation to parse and validate the path parameter,
# then returns the dict produced by get_result() as the JSON response.
# Comments are used instead of a docstring because FastAPI publishes a route
# function's docstring as its OpenAPI description.
@app.get("/result/{event_id}")
def get_result_api(event_id: int):
    return get_result(event_id)
# Gradio interface setup
with gr.Blocks() as iface:
    gr.Markdown("# Maize Anomaly Detection")

    # --- Prediction: upload an image or paste a URL, then press Submit ---
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload an Image")
        url_input = gr.Textbox(
            label="Or Enter an Image URL",
            placeholder="Provide a valid image URL",
        )
    output = gr.JSON(label="Prediction Result")
    submit_button = gr.Button("Submit")
    # predict(image, url) returns a JSON-serializable dict shown in `output`.
    submit_button.click(fn=predict, inputs=[image_input, url_input], outputs=output)

    # --- Retrieval: look up an earlier result by its event id ---
    with gr.Row():
        event_id_input = gr.Textbox(label="Event ID", placeholder="Enter Event ID")
    event_output = gr.JSON(label="Retrieved Result")
    retrieve_button = gr.Button("Get Result")
    retrieve_button.click(fn=get_result, inputs=[event_id_input], outputs=event_output)

# Serve on all interfaces at port 7860, request a public share link, and
# surface handler errors in the UI.
iface.launch(share=True, show_error=True, server_name="0.0.0.0", server_port=7860)
|