saritha commited on
Commit
59542a9
·
verified ·
1 Parent(s): 9159a27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -40
app.py CHANGED
@@ -8,11 +8,10 @@ import os
8
  import contextlib
9
  from transformers import ViTForImageClassification, pipeline
10
 
11
- # Suppress warnings related to the model weights initialization
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
13
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
14
 
15
- # Suppress output for copying files and verbose model initialization messages
16
  @contextlib.contextmanager
17
  def suppress_stdout():
18
  with open(os.devnull, 'w') as devnull:
@@ -23,16 +22,12 @@ def suppress_stdout():
23
  finally:
24
  sys.stdout = old_stdout
25
 
26
- # Load the sugarcane disease model
27
  with suppress_stdout():
28
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
29
  model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')))
30
  model.eval()
31
 
32
- # Load a general-purpose classifier (e.g., MobileNetV2)
33
- general_model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
34
- general_model.eval()
35
-
36
  # Define the same transformation used during training
37
  transform = transforms.Compose([
38
  transforms.Resize((224, 224)),
@@ -40,9 +35,12 @@ transform = transforms.Compose([
40
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
41
  ])
42
 
43
- # Load the class names (disease types)
44
  class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
45
 
 
 
 
46
  # Knowledge base for sugarcane diseases
47
  knowledge_base = {
48
  'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
@@ -53,31 +51,10 @@ knowledge_base = {
53
  'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
54
  }
55
 
56
- # Function to check if the image is plant-related
57
- def is_plant_image(image):
58
- general_transform = transforms.Compose([
59
- transforms.Resize((224, 224)),
60
- transforms.ToTensor(),
61
- ])
62
- img_tensor = general_transform(image).unsqueeze(0)
63
- with torch.no_grad():
64
- outputs = general_model(img_tensor)
65
- _, predicted_class = torch.max(outputs, 1)
66
- # Check if the predicted class corresponds to plant-like images
67
- plant_related_classes = range(20, 25) # Replace with specific classes for plants
68
- return predicted_class.item() in plant_related_classes
69
-
70
- # Predict disease or detect non-sugarcane images
71
  def predict_disease(image):
72
- if not is_plant_image(image):
73
- return """
74
- <div style='font-size: 18px; color: #FF5722; font-weight: bold;'>
75
- The uploaded image is not related to sugarcane. Please upload a sugarcane image.
76
- </div>
77
- """
78
-
79
  # Apply transformations to the image
80
- img_tensor = transform(image).unsqueeze(0)
81
 
82
  # Make prediction
83
  with torch.no_grad():
@@ -85,35 +62,64 @@ def predict_disease(image):
85
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
86
  max_prob, predicted_class = torch.max(probabilities, 1)
87
 
 
 
 
 
 
 
 
 
 
 
 
88
  # Get the predicted label
89
  predicted_label = class_names[predicted_class.item()]
90
 
91
  # Retrieve response from knowledge base
92
- detailed_response = knowledge_base.get(predicted_label, "No additional information available.")
93
-
94
- # Create styled HTML output
 
 
 
 
 
95
  output_message = f"""
96
  <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
97
  Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
98
  </div>
99
  """
100
- output_message += f"""
101
- <p style='font-size: 16px; color: #757575;'>
102
- {detailed_response}
103
- </p>
104
- """
 
 
 
 
 
 
 
 
 
105
  return output_message
106
 
107
  # Create Gradio interface
108
  inputs = gr.Image(type="pil")
109
  outputs = gr.HTML()
110
 
 
 
111
  demo_app = gr.Interface(
112
  fn=predict_disease,
113
  inputs=inputs,
114
  outputs=outputs,
115
  title="Sugarcane Disease Detection",
116
- live=True
 
 
117
  )
118
 
119
  demo_app.launch(debug=True)
 
8
  import contextlib
9
  from transformers import ViTForImageClassification, pipeline
10
 
11
+ # Suppress warnings
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
13
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
14
 
 
15
  @contextlib.contextmanager
16
  def suppress_stdout():
17
  with open(os.devnull, 'w') as devnull:
 
22
  finally:
23
  sys.stdout = old_stdout
24
 
25
+ # Load the saved model and suppress the warnings
26
  with suppress_stdout():
27
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
28
  model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')))
29
  model.eval()
30
 
 
 
 
 
31
  # Define the same transformation used during training
32
  transform = transforms.Compose([
33
  transforms.Resize((224, 224)),
 
35
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
36
  ])
37
 
38
+ # Load the class names
39
  class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
40
 
41
+ # Load AI response generator
42
+ ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
43
+
44
  # Knowledge base for sugarcane diseases
45
  knowledge_base = {
46
  'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
 
51
  'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
52
  }
53
 
54
+ # Update the predict_disease function to handle non-sugarcane images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def predict_disease(image):
 
 
 
 
 
 
 
56
  # Apply transformations to the image
57
+ img_tensor = transform(image).unsqueeze(0) # Add batch dimension
58
 
59
  # Make prediction
60
  with torch.no_grad():
 
62
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
63
  max_prob, predicted_class = torch.max(probabilities, 1)
64
 
65
+ # Confidence threshold for non-sugarcane detection
66
+ confidence_threshold = 0.6 # Adjust based on experimentation
67
+
68
+ # Check if the confidence is below the threshold
69
+ if max_prob.item() < confidence_threshold:
70
+ return """
71
+ <div style='font-size: 18px; color: #FF5722; font-weight: bold;'>
72
+ The uploaded image does not belong to the sugarcane dataset.
73
+ </div>
74
+ """
75
+
76
  # Get the predicted label
77
  predicted_label = class_names[predicted_class.item()]
78
 
79
  # Retrieve response from knowledge base
80
+ if predicted_label in knowledge_base:
81
+ detailed_response = knowledge_base[predicted_label]
82
+ else:
83
+ # Fallback to AI-generated response
84
+ prompt = f"The detected sugarcane disease is '{predicted_label}'. Provide detailed advice for managing this condition."
85
+ detailed_response = ai_pipeline(prompt, max_length=100, num_return_sequences=1, truncation=True)[0]['generated_text']
86
+
87
+ # Create a styled HTML output
88
  output_message = f"""
89
  <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
90
  Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
91
  </div>
92
  """
93
+
94
+ if predicted_label != "Healthy":
95
+ output_message += f"""
96
+ <p style='font-size: 16px; color: #757575;'>
97
+ {detailed_response}
98
+ </p>
99
+ """
100
+ else:
101
+ output_message += f"""
102
+ <p style='font-size: 16px; color: #757575;'>
103
+ {detailed_response}
104
+ </p>
105
+ """
106
+
107
  return output_message
108
 
109
  # Create Gradio interface
110
  inputs = gr.Image(type="pil")
111
  outputs = gr.HTML()
112
 
113
+ EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
114
+
115
  demo_app = gr.Interface(
116
  fn=predict_disease,
117
  inputs=inputs,
118
  outputs=outputs,
119
  title="Sugarcane Disease Detection",
120
+ examples=EXAMPLES,
121
+ live=True,
122
+ theme="huggingface"
123
  )
124
 
125
  demo_app.launch(debug=True)