sanket09 commited on
Commit
aa651ec
·
verified ·
1 Parent(s): 08018bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -13
app.py CHANGED
@@ -15,6 +15,8 @@ import rasterio
15
  import matplotlib.pyplot as plt
16
  from tensorflow.keras.applications import ResNet50
17
  from tensorflow.keras.models import Model
 
 
18
 
19
  # Load crop data
20
  def load_data():
@@ -55,7 +57,34 @@ def predict_traditional(model_name, year, state, crop, yield_):
55
  else:
56
  return "Model not found"
57
 
58
- # Load pre-trained deep learning models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def load_deep_learning_model(model_name):
60
  base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
61
  base_model.trainable = False
@@ -105,27 +134,22 @@ def predict_deep_learning(model_name, file):
105
  predictions = model.predict(preprocessed_patches)
106
  predictions = predictions.reshape((n_patches_y, n_patches_x))
107
 
108
- # Set a threshold to highlight areas with higher predicted yields
109
- threshold = np.percentile(predictions, 90) # Adjust the percentile as needed
110
 
111
- # Create an overlay image to visualize predictions
112
  overlay = np.zeros_like(img_data, dtype=np.float32)
113
  for i in range(n_patches_y):
114
  for j in range(n_patches_x):
115
  if predictions[i, j] > threshold:
116
  overlay[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = predictions[i, j]
117
 
118
- # Plot the overlay on the original image
119
  plt.figure(figsize=(10, 10))
120
  plt.imshow(img_data, cmap='gray', alpha=0.5)
121
  plt.imshow(overlay, cmap='jet', alpha=0.5)
122
  plt.title('Crop Yield Prediction Overlay')
123
- plt.colorbar()
124
-
125
- # Save the plot to a file
126
- plt.savefig('/tmp/prediction_overlay.png')
127
 
128
- return '/tmp/prediction_overlay.png'
129
  else:
130
  return "No file uploaded"
131
  else:
@@ -141,7 +165,7 @@ inputs_traditional = [
141
  outputs_traditional = gr.Textbox(label='Predicted Profit')
142
 
143
  inputs_deep_learning = [
144
- gr.Dropdown(choices=list(deep_learning_models.keys()), label='Model'),
145
  gr.File(label='Upload TIFF File')
146
  ]
147
  outputs_deep_learning = gr.Image(label='Prediction Overlay')
@@ -157,10 +181,10 @@ with gr.Blocks() as demo:
157
 
158
  with gr.Tab("Deep Learning Models"):
159
  gr.Interface(
160
- fn=predict_deep_learning,
161
  inputs=inputs_deep_learning,
162
  outputs=outputs_deep_learning,
163
- title="Crop Yield Prediction using Deep Learning Models"
164
  )
165
 
166
  demo.launch()
 
15
  import matplotlib.pyplot as plt
16
  from tensorflow.keras.applications import ResNet50
17
  from tensorflow.keras.models import Model
18
+ import cv2
19
+ import joblib
20
 
21
  # Load crop data
22
  def load_data():
 
57
  else:
58
  return "Model not found"
59
 
60
+ # Load the pre-trained RandomForestRegressor model
61
+ rf_model = joblib.load('crop_yield_model.joblib')
62
+
63
+ def predict_random_forest(file):
64
+ if file is not None:
65
+ def process_tiff(file_path):
66
+ with rasterio.open(file_path) as src:
67
+ tiff_data = src.read()
68
+ B2_image = tiff_data[1, :, :]
69
+ target_size = (50, 50)
70
+ B2_resized = cv2.resize(B2_image, target_size, interpolation=cv2.INTER_NEAREST)
71
+ return B2_resized.reshape(-1, 1)
72
+
73
+ tiff_processed = process_tiff(file.name)
74
+ prediction = rf_model.predict(tiff_processed)
75
+ prediction_reshaped = prediction.reshape((50, 50))
76
+
77
+ plt.figure(figsize=(10, 10))
78
+ plt.imshow(prediction_reshaped, cmap='viridis')
79
+ plt.colorbar()
80
+ plt.title('Yield Prediction for Single TIFF File')
81
+ plt.savefig('/tmp/rf_prediction_overlay.png')
82
+
83
+ return '/tmp/rf_prediction_overlay.png'
84
+ else:
85
+ return "No file uploaded"
86
+
87
+ # Load deep learning models
88
  def load_deep_learning_model(model_name):
89
  base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
90
  base_model.trainable = False
 
134
  predictions = model.predict(preprocessed_patches)
135
  predictions = predictions.reshape((n_patches_y, n_patches_x))
136
 
137
+ threshold = np.percentile(predictions, 90)
 
138
 
 
139
  overlay = np.zeros_like(img_data, dtype=np.float32)
140
  for i in range(n_patches_y):
141
  for j in range(n_patches_x):
142
  if predictions[i, j] > threshold:
143
  overlay[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = predictions[i, j]
144
 
 
145
  plt.figure(figsize=(10, 10))
146
  plt.imshow(img_data, cmap='gray', alpha=0.5)
147
  plt.imshow(overlay, cmap='jet', alpha=0.5)
148
  plt.title('Crop Yield Prediction Overlay')
149
+ # plt.colorbar()
150
+ plt.savefig('/tmp/dl_prediction_overlay.png')
 
 
151
 
152
+ return '/tmp/dl_prediction_overlay.png'
153
  else:
154
  return "No file uploaded"
155
  else:
 
165
  outputs_traditional = gr.Textbox(label='Predicted Profit')
166
 
167
  inputs_deep_learning = [
168
+ gr.Dropdown(choices=list(deep_learning_models.keys()) + ['Random Forest'], label='Model'),
169
  gr.File(label='Upload TIFF File')
170
  ]
171
  outputs_deep_learning = gr.Image(label='Prediction Overlay')
 
181
 
182
  with gr.Tab("Deep Learning Models"):
183
  gr.Interface(
184
+ fn=lambda model_name, file: predict_deep_learning(model_name, file) if model_name != 'Random Forest' else predict_random_forest(file),
185
  inputs=inputs_deep_learning,
186
  outputs=outputs_deep_learning,
187
+ title="Crop Yield Prediction using Deep Learning Models and Random Forest"
188
  )
189
 
190
  demo.launch()