Spaces:
Runtime error
Delete app.py
app.py
DELETED
@@ -1,64 +0,0 @@
-import gradio as gr
-import torch
-from torchvision import transforms as tt
-from PIL import Image
-
-# Load models
-potato_model = torch.load("models/potato_model_statedict__f.pth", map_location=torch.device('cpu'))
-tomato_model = torch.load("models/tomato_model_statedict__f.pth", map_location=torch.device('cpu'))
-
-# Set both models to evaluation mode
-potato_model.eval()
-tomato_model.eval()
-
-# Prediction function for Potato and Tomato models
-def predict(image, model_type):
-    # Define transformations
-    transforms = tt.Compose([
-        tt.Resize((224, 224)),
-        tt.ToTensor()
-    ])
-
-    # Open image and apply transforms
-    image = Image.open(image).convert("RGB")
-    image_tensor = transforms(image).unsqueeze(0)
-
-    # Choose the correct model
-    model = potato_model if model_type == 'Potato' else tomato_model
-
-    # Perform prediction
-    with torch.no_grad():
-        output = model(image_tensor)
-
-    probabilities = torch.nn.functional.softmax(output[0], dim=0)
-    predicted_class = torch.argmax(probabilities).item()
-    predicted_probability = probabilities[predicted_class].item()
-
-    # Define class labels
-    if model_type == 'Potato':
-        class_labels = ['Potato Early Blight', 'Potato Late Blight', 'Potato Healthy']
-    else:
-        class_labels = ['Tomato Early Blight', 'Tomato Late Blight', 'Tomato Healthy']
-
-    return {
-        "Class": class_labels[predicted_class],
-        "Probability (%)": round(predicted_probability * 100, 2)
-    }
-
-# Gradio Interface
-def inference(image, model_type):
-    prediction = predict(image, model_type)
-    return prediction
-
-gr.Interface(
-    fn=inference,
-    inputs=[
-        gr.Image(type="file", label="Upload an Image"),
-        gr.Radio(["Potato", "Tomato"], label="Model Type", value="Potato")
-    ],
-    outputs=[
-        gr.Label(label="Prediction"),
-    ],
-    title="Plant Disease Classification",
-    description="Upload an image to predict whether the plant is healthy or has a disease."
-).launch()
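
The deleted script targets an older Gradio input API: recent Gradio releases (3.x and later) no longer accept type="file" on gr.Image (the supported values are "numpy", "pil", and "filepath"), and gr.Label expects either a string or a dict mapping class names to numeric confidences rather than an arbitrary dict. Either mismatch is a plausible source of the Space's runtime error, though the actual cause is not recorded in this commit. Below is a minimal sketch, not the author's fix, of how the interface could be adapted under those assumptions; the model loading is stubbed with a dummy classifier because the original .pth checkpoints are not part of this diff.

import gradio as gr
import torch
from torchvision import transforms as tt

CLASS_LABELS = {
    "Potato": ["Potato Early Blight", "Potato Late Blight", "Potato Healthy"],
    "Tomato": ["Tomato Early Blight", "Tomato Late Blight", "Tomato Healthy"],
}

# Dummy stand-in for the real checkpoints; in the original app these would be the
# torch.load(...) calls on models/potato_model_statedict__f.pth and
# models/tomato_model_statedict__f.pth.
_dummy = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 3))
_dummy.eval()
models = {"Potato": _dummy, "Tomato": _dummy}

def predict(image, model_type):
    # With type="pil" below, Gradio passes a PIL.Image directly, so Image.open() is unnecessary.
    transforms = tt.Compose([tt.Resize((224, 224)), tt.ToTensor()])
    image_tensor = transforms(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        output = models[model_type](image_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # gr.Label renders a dict of class name -> numeric confidence.
    return {label: float(p) for label, p in zip(CLASS_LABELS[model_type], probabilities)}

gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Radio(["Potato", "Tomato"], label="Model Type", value="Potato"),
    ],
    outputs=gr.Label(label="Prediction"),
    title="Plant Disease Classification",
    description="Upload an image to predict whether the plant is healthy or has a disease.",
).launch()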