sanket09 commited on
Commit
21c3d04
·
verified ·
1 Parent(s): 88ab16a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -5
app.py CHANGED
@@ -57,8 +57,38 @@ def predict_traditional(model_name, year, state, crop, yield_):
57
  else:
58
  return "Model not found"
59
 
60
- # Load the pre-trained RandomForestRegressor model
61
- rf_model = joblib.load('crop_yield_model.joblib')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def predict_random_forest(file):
64
  if file is not None:
@@ -146,13 +176,12 @@ def predict_deep_learning(model_name, file):
146
  plt.imshow(img_data, cmap='gray', alpha=0.5)
147
  plt.imshow(overlay, cmap='jet', alpha=0.5)
148
  plt.title('Crop Yield Prediction Overlay')
 
149
  plt.savefig('/tmp/dl_prediction_overlay.png')
150
 
151
  return '/tmp/dl_prediction_overlay.png'
152
  else:
153
  return "No file uploaded"
154
- elif model_name == 'Random Forest':
155
- return predict_random_forest(file)
156
  else:
157
  return "Model not found"
158
 
@@ -182,7 +211,7 @@ with gr.Blocks() as demo:
182
 
183
  with gr.Tab("Deep Learning Models"):
184
  gr.Interface(
185
- fn=predict_deep_learning,
186
  inputs=inputs_deep_learning,
187
  outputs=outputs_deep_learning,
188
  title="Crop Yield Prediction using Deep Learning Models and Random Forest"
 
57
  else:
58
  return "Model not found"
59
 
60
+ # Train RandomForestRegressor model for deep learning model
61
+ def train_random_forest_model():
62
+ def process_tiff(file_path):
63
+ with rasterio.open(file_path) as src:
64
+ tiff_data = src.read()
65
+ B2_image = tiff_data[1, :, :] # Assuming B2 is the second band
66
+ target_size = (50, 50)
67
+ B2_resized = cv2.resize(B2_image, target_size, interpolation=cv2.INTER_NEAREST)
68
+ return B2_resized.reshape(-1, 1)
69
+
70
+ data_dir = '/Data'
71
+ X_list = []
72
+ y_list = []
73
+
74
+ for root, dirs, files in os.walk(data_dir):
75
+ for file in files:
76
+ if file.endswith('.tiff'):
77
+ file_path = os.path.join(root, file)
78
+ X_list.append(process_tiff(file_path))
79
+ y_list.append(np.random.rand(2500)) # Replace with actual target data
80
+
81
+ X = np.vstack(X_list)
82
+ y = np.hstack(y_list)
83
+
84
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
85
+
86
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
87
+ model.fit(X_train, y_train)
88
+
89
+ return model
90
+
91
+ rf_model = train_random_forest_model()
92
 
93
  def predict_random_forest(file):
94
  if file is not None:
 
176
  plt.imshow(img_data, cmap='gray', alpha=0.5)
177
  plt.imshow(overlay, cmap='jet', alpha=0.5)
178
  plt.title('Crop Yield Prediction Overlay')
179
+ plt.colorbar()
180
  plt.savefig('/tmp/dl_prediction_overlay.png')
181
 
182
  return '/tmp/dl_prediction_overlay.png'
183
  else:
184
  return "No file uploaded"
 
 
185
  else:
186
  return "Model not found"
187
 
 
211
 
212
  with gr.Tab("Deep Learning Models"):
213
  gr.Interface(
214
+ fn=lambda model_name, file: predict_deep_learning(model_name, file) if model_name != 'Random Forest' else predict_random_forest(file),
215
  inputs=inputs_deep_learning,
216
  outputs=outputs_deep_learning,
217
  title="Crop Yield Prediction using Deep Learning Models and Random Forest"