Update app.py
app.py CHANGED
@@ -6,24 +6,19 @@ from torchvision import models, transforms
 from huggingface_hub import hf_hub_download
 from PIL import Image
 import requests
+import os
 from io import BytesIO
-from fastapi import FastAPI

 # Define the number of classes
 num_classes = 2

-# In-memory storage for results
-results_cache = {}
-
 # Download model from Hugging Face
 def download_model():
-    print("Downloading model...")
     model_path = hf_hub_download(repo_id="jays009/Restnet50", filename="pytorch_model.bin")
     return model_path

 # Load the model from Hugging Face
 def load_model(model_path):
-    print("Loading model...")
     model = models.resnet50(pretrained=False)
     model.fc = nn.Linear(model.fc.in_features, num_classes)
     model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
@@ -44,116 +39,48 @@ transform = transforms.Compose([

 # Function to predict from image content
 def predict_from_image(image):
-    [old lines 47-65: body of the previous predict_from_image implementation; content not preserved in the rendered diff]
-        print(f"Error during image processing: {e}")
-        return {"error": str(e)}
+    # Ensure the image is a PIL Image
+    if not isinstance(image, Image.Image):
+        raise ValueError("Invalid image format received. Please provide a valid image.")
+
+    # Apply transformations
+    image_tensor = transform(image).unsqueeze(0)
+
+    # Predict
+    with torch.no_grad():
+        outputs = model(image_tensor)
+        predicted_class = torch.argmax(outputs, dim=1).item()
+
+    # Interpret the result
+    if predicted_class == 0:
+        return {"result": "The photo is of fall army worm with problem ID 126."}
+    elif predicted_class == 1:
+        return {"result": "The photo is of a healthy maize image."}
+    else:
+        return {"error": "Unexpected class prediction."}

 # Function to predict from URL
 def predict_from_url(url):
     try:
-        print(f"Fetching image from URL: {url}")
         response = requests.get(url)
-        response.raise_for_status()
+        response.raise_for_status()  # Ensure the request was successful
         image = Image.open(BytesIO(response.content))
         return predict_from_image(image)
     except Exception as e:
-        print(f"Error during URL processing: {e}")
         return {"error": f"Failed to process the URL: {str(e)}"}

-#
-    [old lines 82-95: body of the previous predict() function; content not preserved in the rendered diff]
-        else:
-            result = {"error": "No input provided. Please upload an image or provide a URL."}
-
-        event_id = id(result)  # Use Python's id() function to generate a unique identifier
-        results_cache[event_id] = result
-
-        print(f"Event ID: {event_id}, Result: {result}")
-        return {"event_id": event_id, "result": result}
-    except Exception as e:
-        print(f"Error in prediction function: {e}")
-        return {"error": str(e)}
-
-
-# Function to retrieve result by event_id
-def get_result(event_id):
-    try:
-        print(f"Retrieving result for event ID: {event_id}")
-        event_id = int(event_id)
-        result = results_cache.get(event_id)
-        if result:
-            return result
-        else:
-            return {"error": "No result found for the provided event ID."}
-    except Exception as e:
-        return {"error": f"Invalid event ID: {str(e)}"}
-
-# Create a FastAPI app for handling the GET request
-app = FastAPI()
-
-@app.get("/result/{event_id}")
-def get_result_api(event_id: int):
-    return get_result(event_id)
-
-# Gradio interface setup
-iface = gr.Blocks()
-
-with iface:
-    gr.Markdown("# Maize Anomaly Detection")
-    with gr.Row():
-        image_input = gr.Image(type="pil", label="Upload an Image")
-        url_input = gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL")
-    output = gr.JSON(label="Prediction Result")
-    submit_button = gr.Button("Submit")
-
-    submit_button.click(
-        fn=predict,
-        inputs=[image_input, url_input],
-        outputs=output
-    )
-
-    # Event ID retrieval section
-    with gr.Row():
-        event_id_input = gr.Textbox(label="Event ID", placeholder="Enter Event ID")
-        event_output = gr.JSON(label="Retrieved Result")
-    retrieve_button = gr.Button("Get Result")
-
-    retrieve_button.click(
-        fn=get_result,
-        inputs=[event_id_input],
-        outputs=event_output
-    )
-
-# Launch the Gradio interface
-iface.launch(share=True, show_error=True, server_name="0.0.0.0", server_port=7860)
+# Gradio interface
+iface = gr.Interface(
+    fn=lambda image, url: predict_from_image(image) if image else predict_from_url(url),
+    inputs=[
+        gr.Image(type="pil", label="Upload an Image"),
+        gr.Textbox(label="Or Enter an Image URL", placeholder="Provide a valid image URL"),
+    ],
+    outputs=gr.JSON(label="Prediction Result"),
+    live=True,
+    title="Maize Anomaly Detection",
+    description="Upload an image or provide a URL to detect anomalies in maize crops.",
+)
+
+# Launch the interface
+iface.launch(share=True, show_error=True)
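For a quick sanity check of the refactored prediction path, a couple of calls can be dropped in just above iface.launch(...). This is only a sketch, not part of the commit: the URL is a placeholder, and it assumes the unchanged lines between the two hunks still define the module-level model and transform that the new predict_from_image references.

# Sketch of a local smoke test (placeholder URL; remove before committing)
print(predict_from_url("https://example.com/maize_leaf.jpg"))

# With no uploaded image and an empty URL box, the lambda passed to gr.Interface
# falls through to predict_from_url(""), which returns the {"error": ...} dict.
print(predict_from_url(""))

Note that with live=True the lambda re-runs on every input change, so the URL branch will fire (and return the error dict) while a URL is still being typed; dropping live=True in favor of the default submit button would avoid that.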