Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,11 +8,10 @@ import os
|
|
8 |
import contextlib
|
9 |
from transformers import ViTForImageClassification, pipeline
|
10 |
|
11 |
-
# Suppress warnings
|
12 |
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
|
13 |
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
|
14 |
|
15 |
-
# Suppress output for copying files and verbose model initialization messages
|
16 |
@contextlib.contextmanager
|
17 |
def suppress_stdout():
|
18 |
with open(os.devnull, 'w') as devnull:
|
@@ -23,16 +22,12 @@ def suppress_stdout():
|
|
23 |
finally:
|
24 |
sys.stdout = old_stdout
|
25 |
|
26 |
-
# Load the
|
27 |
with suppress_stdout():
|
28 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
|
29 |
model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')))
|
30 |
model.eval()
|
31 |
|
32 |
-
# Load a general-purpose classifier (e.g., MobileNetV2)
|
33 |
-
general_model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
|
34 |
-
general_model.eval()
|
35 |
-
|
36 |
# Define the same transformation used during training
|
37 |
transform = transforms.Compose([
|
38 |
transforms.Resize((224, 224)),
|
@@ -40,9 +35,12 @@ transform = transforms.Compose([
|
|
40 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
41 |
])
|
42 |
|
43 |
-
# Load the class names
|
44 |
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
|
45 |
|
|
|
|
|
|
|
46 |
# Knowledge base for sugarcane diseases
|
47 |
knowledge_base = {
|
48 |
'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
|
@@ -53,31 +51,10 @@ knowledge_base = {
|
|
53 |
'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
|
54 |
}
|
55 |
|
56 |
-
#
|
57 |
-
def is_plant_image(image):
|
58 |
-
general_transform = transforms.Compose([
|
59 |
-
transforms.Resize((224, 224)),
|
60 |
-
transforms.ToTensor(),
|
61 |
-
])
|
62 |
-
img_tensor = general_transform(image).unsqueeze(0)
|
63 |
-
with torch.no_grad():
|
64 |
-
outputs = general_model(img_tensor)
|
65 |
-
_, predicted_class = torch.max(outputs, 1)
|
66 |
-
# Check if the predicted class corresponds to plant-like images
|
67 |
-
plant_related_classes = range(20, 25) # Replace with specific classes for plants
|
68 |
-
return predicted_class.item() in plant_related_classes
|
69 |
-
|
70 |
-
# Predict disease or detect non-sugarcane images
|
71 |
def predict_disease(image):
|
72 |
-
if not is_plant_image(image):
|
73 |
-
return """
|
74 |
-
<div style='font-size: 18px; color: #FF5722; font-weight: bold;'>
|
75 |
-
The uploaded image is not related to sugarcane. Please upload a sugarcane image.
|
76 |
-
</div>
|
77 |
-
"""
|
78 |
-
|
79 |
# Apply transformations to the image
|
80 |
-
img_tensor = transform(image).unsqueeze(0)
|
81 |
|
82 |
# Make prediction
|
83 |
with torch.no_grad():
|
@@ -85,35 +62,64 @@ def predict_disease(image):
|
|
85 |
probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
|
86 |
max_prob, predicted_class = torch.max(probabilities, 1)
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
# Get the predicted label
|
89 |
predicted_label = class_names[predicted_class.item()]
|
90 |
|
91 |
# Retrieve response from knowledge base
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
95 |
output_message = f"""
|
96 |
<div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
|
97 |
Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
|
98 |
</div>
|
99 |
"""
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
return output_message
|
106 |
|
107 |
# Create Gradio interface
|
108 |
inputs = gr.Image(type="pil")
|
109 |
outputs = gr.HTML()
|
110 |
|
|
|
|
|
111 |
demo_app = gr.Interface(
|
112 |
fn=predict_disease,
|
113 |
inputs=inputs,
|
114 |
outputs=outputs,
|
115 |
title="Sugarcane Disease Detection",
|
116 |
-
|
|
|
|
|
117 |
)
|
118 |
|
119 |
demo_app.launch(debug=True)
|
|
|
8 |
import contextlib
|
9 |
from transformers import ViTForImageClassification, pipeline
|
10 |
|
11 |
+
# Suppress warnings
|
12 |
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
|
13 |
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
|
14 |
|
|
|
15 |
@contextlib.contextmanager
|
16 |
def suppress_stdout():
|
17 |
with open(os.devnull, 'w') as devnull:
|
|
|
22 |
finally:
|
23 |
sys.stdout = old_stdout
|
24 |
|
25 |
+
# Load the saved model and suppress the warnings
|
26 |
with suppress_stdout():
|
27 |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
|
28 |
model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')))
|
29 |
model.eval()
|
30 |
|
|
|
|
|
|
|
|
|
31 |
# Define the same transformation used during training
|
32 |
transform = transforms.Compose([
|
33 |
transforms.Resize((224, 224)),
|
|
|
35 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
36 |
])
|
37 |
|
38 |
+
# Load the class names
|
39 |
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
|
40 |
|
41 |
+
# Load AI response generator
|
42 |
+
ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
|
43 |
+
|
44 |
# Knowledge base for sugarcane diseases
|
45 |
knowledge_base = {
|
46 |
'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
|
|
|
51 |
'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
|
52 |
}
|
53 |
|
54 |
+
# Update the predict_disease function to handle non-sugarcane images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
def predict_disease(image):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
# Apply transformations to the image
|
57 |
+
img_tensor = transform(image).unsqueeze(0) # Add batch dimension
|
58 |
|
59 |
# Make prediction
|
60 |
with torch.no_grad():
|
|
|
62 |
probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
|
63 |
max_prob, predicted_class = torch.max(probabilities, 1)
|
64 |
|
65 |
+
# Confidence threshold for non-sugarcane detection
|
66 |
+
confidence_threshold = 0.6 # Adjust based on experimentation
|
67 |
+
|
68 |
+
# Check if the confidence is below the threshold
|
69 |
+
if max_prob.item() < confidence_threshold:
|
70 |
+
return """
|
71 |
+
<div style='font-size: 18px; color: #FF5722; font-weight: bold;'>
|
72 |
+
The uploaded image does not belong to the sugarcane dataset.
|
73 |
+
</div>
|
74 |
+
"""
|
75 |
+
|
76 |
# Get the predicted label
|
77 |
predicted_label = class_names[predicted_class.item()]
|
78 |
|
79 |
# Retrieve response from knowledge base
|
80 |
+
if predicted_label in knowledge_base:
|
81 |
+
detailed_response = knowledge_base[predicted_label]
|
82 |
+
else:
|
83 |
+
# Fallback to AI-generated response
|
84 |
+
prompt = f"The detected sugarcane disease is '{predicted_label}'. Provide detailed advice for managing this condition."
|
85 |
+
detailed_response = ai_pipeline(prompt, max_length=100, num_return_sequences=1, truncation=True)[0]['generated_text']
|
86 |
+
|
87 |
+
# Create a styled HTML output
|
88 |
output_message = f"""
|
89 |
<div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
|
90 |
Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
|
91 |
</div>
|
92 |
"""
|
93 |
+
|
94 |
+
if predicted_label != "Healthy":
|
95 |
+
output_message += f"""
|
96 |
+
<p style='font-size: 16px; color: #757575;'>
|
97 |
+
{detailed_response}
|
98 |
+
</p>
|
99 |
+
"""
|
100 |
+
else:
|
101 |
+
output_message += f"""
|
102 |
+
<p style='font-size: 16px; color: #757575;'>
|
103 |
+
{detailed_response}
|
104 |
+
</p>
|
105 |
+
"""
|
106 |
+
|
107 |
return output_message
|
108 |
|
109 |
# Create Gradio interface
|
110 |
inputs = gr.Image(type="pil")
|
111 |
outputs = gr.HTML()
|
112 |
|
113 |
+
EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
|
114 |
+
|
115 |
demo_app = gr.Interface(
|
116 |
fn=predict_disease,
|
117 |
inputs=inputs,
|
118 |
outputs=outputs,
|
119 |
title="Sugarcane Disease Detection",
|
120 |
+
examples=EXAMPLES,
|
121 |
+
live=True,
|
122 |
+
theme="huggingface"
|
123 |
)
|
124 |
|
125 |
demo_app.launch(debug=True)
|