Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,39 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from torchvision import transforms as tt
|
4 |
+
from PIL import Image
|
5 |
+
from predict import predict_potato, predict_tomato # Ensure these functions are defined to use your models
|
6 |
+
from model import model_potato, model_tomato # Ensure these models are correctly defined in your script
|
7 |
|
8 |
+
# Load models (assuming they are saved in the "models" directory within the Hugging Face Space)
|
9 |
+
model_potato.load_state_dict(torch.load("models/potato_model_statedict__f.pth", map_location=torch.device('cpu')))
|
10 |
+
model_tomato.load_state_dict(torch.load("models/tomato_model_statedict__f.pth", map_location=torch.device('cpu')))
|
11 |
+
|
12 |
+
# Define a function to make predictions based on the selected model
|
13 |
+
def classify_image(image, plant_type):
|
14 |
+
if plant_type == "Potato":
|
15 |
+
class_name, probability = predict_potato(image, model_potato)
|
16 |
+
else:
|
17 |
+
class_name, probability = predict_tomato(image, model_tomato)
|
18 |
+
|
19 |
+
return {
|
20 |
+
"Predicted Class": class_name,
|
21 |
+
"Confidence (%)": f"{probability * 100:.2f}%"
|
22 |
+
}
|
23 |
+
|
24 |
+
# Set up Gradio interface
|
25 |
+
iface = gr.Interface(
|
26 |
+
fn=classify_image,
|
27 |
+
inputs=[
|
28 |
+
gr.inputs.Image(type="file", label="Upload an Image"),
|
29 |
+
gr.inputs.Radio(["Potato", "Tomato"], label="Select Plant Type")
|
30 |
+
],
|
31 |
+
outputs=[
|
32 |
+
gr.outputs.Label(num_top_classes=1, label="Prediction Results")
|
33 |
+
],
|
34 |
+
title="Potato and Tomato Disease Classification",
|
35 |
+
description="Upload an image of a potato or tomato leaf to classify the disease."
|
36 |
+
)
|
37 |
+
|
38 |
+
# Launch Gradio interface
|
39 |
+
iface.launch()
|