saritha commited on
Commit
7794a17
·
verified ·
1 Parent(s): a31e72e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -6,7 +6,7 @@ import warnings
6
  import sys
7
  import os
8
  import contextlib
9
- from transformers import ViTForImageClassification
10
 
11
  # Suppress warnings related to the model weights initialization
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
@@ -39,6 +39,9 @@ transform = transforms.Compose([
39
  # Load the class names (disease types)
40
  class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
41
 
 
 
 
42
  # Function to predict disease type from an image
43
  def predict_disease(image):
44
  # Apply transformations to the image
@@ -51,7 +54,11 @@ def predict_disease(image):
51
 
52
  # Get the predicted label
53
  predicted_label = class_names[predicted_class.item()]
54
-
 
 
 
 
55
  # Create a styled HTML output
56
  output_message = f"""
57
  <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
@@ -62,7 +69,7 @@ def predict_disease(image):
62
  if predicted_label != "Healthy":
63
  output_message += f"""
64
  <p style='font-size: 16px; color: #757575;'>
65
- This indicates the presence of <strong>{predicted_label}</strong>. Please take immediate action to prevent further spread.
66
  </p>
67
  """
68
  else:
@@ -90,4 +97,6 @@ demo_app = gr.Interface(
90
  theme="huggingface"
91
  )
92
 
 
 
93
  demo_app.launch(debug=True)
 
6
  import sys
7
  import os
8
  import contextlib
9
+ from transformers import ViTForImageClassification, pipeline
10
 
11
  # Suppress warnings related to the model weights initialization
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
 
39
  # Load the class names (disease types)
40
  class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
41
 
42
+ # Load AI response generator (using a local GPT pipeline or OpenAI's GPT-3/4 API)
43
+ ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
44
+
45
  # Function to predict disease type from an image
46
  def predict_disease(image):
47
  # Apply transformations to the image
 
54
 
55
  # Get the predicted label
56
  predicted_label = class_names[predicted_class.item()]
57
+
58
+ # Generate AI response based on the detected disease
59
+ prompt = f"You are an expert in sugarcane farming. The detected disease is '{predicted_label}'. Provide advice for the farmer."
60
+ ai_response = ai_pipeline(prompt, max_length=50, num_return_sequences=1)[0]['generated_text']
61
+
62
  # Create a styled HTML output
63
  output_message = f"""
64
  <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
 
69
  if predicted_label != "Healthy":
70
  output_message += f"""
71
  <p style='font-size: 16px; color: #757575;'>
72
+ <strong>AI Response:</strong> {ai_response}
73
  </p>
74
  """
75
  else:
 
97
  theme="huggingface"
98
  )
99
 
100
+ demo_app.launch(debug=True)
101
+
102
  demo_app.launch(debug=True)