saritha commited on
Commit
d521450
·
verified ·
1 Parent(s): 59542a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -20
app.py CHANGED
@@ -8,10 +8,11 @@ import os
8
  import contextlib
9
  from transformers import ViTForImageClassification, pipeline
10
 
11
- # Suppress warnings
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
13
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
14
 
 
15
  @contextlib.contextmanager
16
  def suppress_stdout():
17
  with open(os.devnull, 'w') as devnull:
@@ -35,13 +36,13 @@ transform = transforms.Compose([
35
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
36
  ])
37
 
38
- # Load the class names
39
  class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
40
 
41
- # Load AI response generator
42
  ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
43
 
44
- # Knowledge base for sugarcane diseases
45
  knowledge_base = {
46
  'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
47
  'Mosaic': "Mosaic disease results in streaked and mottled leaves, reducing photosynthesis. Use disease-resistant varieties and control aphids to prevent spread.",
@@ -51,7 +52,7 @@ knowledge_base = {
51
  'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
52
  }
53
 
54
- # Update the predict_disease function to handle non-sugarcane images
55
  def predict_disease(image):
56
  # Apply transformations to the image
57
  img_tensor = transform(image).unsqueeze(0) # Add batch dimension
@@ -59,20 +60,8 @@ def predict_disease(image):
59
  # Make prediction
60
  with torch.no_grad():
61
  outputs = model(img_tensor)
62
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
63
- max_prob, predicted_class = torch.max(probabilities, 1)
64
 
65
- # Confidence threshold for non-sugarcane detection
66
- confidence_threshold = 0.6 # Adjust based on experimentation
67
-
68
- # Check if the confidence is below the threshold
69
- if max_prob.item() < confidence_threshold:
70
- return """
71
- <div style='font-size: 18px; color: #FF5722; font-weight: bold;'>
72
- The uploaded image does not belong to the sugarcane dataset.
73
- </div>
74
- """
75
-
76
  # Get the predicted label
77
  predicted_label = class_names[predicted_class.item()]
78
 
@@ -108,7 +97,7 @@ def predict_disease(image):
108
 
109
  # Create Gradio interface
110
  inputs = gr.Image(type="pil")
111
- outputs = gr.HTML()
112
 
113
  EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
114
 
@@ -124,4 +113,3 @@ demo_app = gr.Interface(
124
 
125
  demo_app.launch(debug=True)
126
 
127
-
 
8
  import contextlib
9
  from transformers import ViTForImageClassification, pipeline
10
 
11
+ # Suppress warnings related to the model weights initialization
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
13
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
14
 
15
+ # Suppress output for copying files and verbose model initialization messages
16
  @contextlib.contextmanager
17
  def suppress_stdout():
18
  with open(os.devnull, 'w') as devnull:
 
36
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
37
  ])
38
 
39
+ # Load the class names (disease types)
40
  class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
41
 
42
+ # Load AI response generator (using a local GPT pipeline or OpenAI's GPT-3/4 API)
43
  ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
44
 
45
+ # Knowledge base for sugarcane diseases (example data from the website)
46
  knowledge_base = {
47
  'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
48
  'Mosaic': "Mosaic disease results in streaked and mottled leaves, reducing photosynthesis. Use disease-resistant varieties and control aphids to prevent spread.",
 
52
  'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
53
  }
54
 
55
+ # Update the predict_disease function
56
  def predict_disease(image):
57
  # Apply transformations to the image
58
  img_tensor = transform(image).unsqueeze(0) # Add batch dimension
 
60
  # Make prediction
61
  with torch.no_grad():
62
  outputs = model(img_tensor)
63
+ _, predicted_class = torch.max(outputs.logits, 1)
 
64
 
 
 
 
 
 
 
 
 
 
 
 
65
  # Get the predicted label
66
  predicted_label = class_names[predicted_class.item()]
67
 
 
97
 
98
  # Create Gradio interface
99
  inputs = gr.Image(type="pil")
100
+ outputs = gr.HTML() # Use HTML output for styled text
101
 
102
  EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
103
 
 
113
 
114
  demo_app.launch(debug=True)
115