hassaanik commited on
Commit
99d3103
·
verified ·
1 Parent(s): 1009e17

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -64
app.py DELETED
@@ -1,64 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- from torchvision import transforms as tt
4
- from PIL import Image
5
-
6
- # Load models
7
- potato_model = torch.load("models/potato_model_statedict__f.pth", map_location=torch.device('cpu'))
8
- tomato_model = torch.load("models/tomato_model_statedict__f.pth", map_location=torch.device('cpu'))
9
-
10
- # Set both models to evaluation mode
11
- potato_model.eval()
12
- tomato_model.eval()
13
-
14
- # Prediction function for Potato and Tomato models
15
- def predict(image, model_type):
16
- # Define transformations
17
- transforms = tt.Compose([
18
- tt.Resize((224, 224)),
19
- tt.ToTensor()
20
- ])
21
-
22
- # Open image and apply transforms
23
- image = Image.open(image).convert("RGB")
24
- image_tensor = transforms(image).unsqueeze(0)
25
-
26
- # Choose the correct model
27
- model = potato_model if model_type == 'Potato' else tomato_model
28
-
29
- # Perform prediction
30
- with torch.no_grad():
31
- output = model(image_tensor)
32
-
33
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
34
- predicted_class = torch.argmax(probabilities).item()
35
- predicted_probability = probabilities[predicted_class].item()
36
-
37
- # Define class labels
38
- if model_type == 'Potato':
39
- class_labels = ['Potato Early Blight', 'Potato Late Blight', 'Potato Healthy']
40
- else:
41
- class_labels = ['Tomato Early Blight', 'Tomato Late Blight', 'Tomato Healthy']
42
-
43
- return {
44
- "Class": class_labels[predicted_class],
45
- "Probability (%)": round(predicted_probability * 100, 2)
46
- }
47
-
48
- # Gradio Interface
49
- def inference(image, model_type):
50
- prediction = predict(image, model_type)
51
- return prediction
52
-
53
- gr.Interface(
54
- fn=inference,
55
- inputs=[
56
- gr.Image(type="file", label="Upload an Image"),
57
- gr.Radio(["Potato", "Tomato"], label="Model Type", value="Potato")
58
- ],
59
- outputs=[
60
- gr.Label(label="Prediction"),
61
- ],
62
- title="Plant Disease Classification",
63
- description="Upload an image to predict whether the plant is healthy or has a disease."
64
- ).launch()