sanket09 commited on
Commit
5e1746f
·
verified ·
1 Parent(s): 5bfc41d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -16
app.py CHANGED
@@ -3,16 +3,14 @@ import numpy as np
3
  import rasterio
4
  import cv2
5
  import matplotlib.pyplot as plt
6
- import joblib
7
  from sklearn.ensemble import RandomForestRegressor
8
  from sklearn.model_selection import train_test_split
9
- from PIL import Image
10
- import io
11
 
12
  # Function to process a single TIFF file
13
- def process_tiff(file):
14
- # Read file from BytesIO object
15
- with rasterio.open(io.BytesIO(file.read())) as src:
16
  tiff_data = src.read()
17
  B2_image = tiff_data[1, :, :] # Assuming B2 is the second band
18
  target_size = (50, 50)
@@ -20,20 +18,36 @@ def process_tiff(file):
20
  return B2_resized.reshape(-1, 1) # Reshape for the model input
21
 
22
  # Function to train the RandomForestRegressor model
23
- def train_random_forest_model():
24
  X_list = []
25
  y_list = []
26
 
27
- # Placeholder for actual file paths
28
- # Modify to load and process multiple TIFF files if needed
29
- # For now, using a single uploaded file for training
30
- return
 
 
 
 
31
 
32
- # Function to make predictions using the RandomForestRegressor model
33
- def predict_crop_yield(file, model_name):
 
 
 
 
 
 
 
 
 
 
 
 
34
  if model_name == 'Random Forest':
35
- model = joblib.load('crop_yield_model.joblib') # Load pre-trained model
36
- processed_image = process_tiff(file)
37
  prediction = model.predict(processed_image)
38
  prediction_reshaped = prediction.reshape((50, 50))
39
  plt.imshow(prediction_reshaped, cmap='viridis')
@@ -44,6 +58,8 @@ def predict_crop_yield(file, model_name):
44
  else:
45
  return "Model not found"
46
 
 
 
47
  inputs = [
48
  gr.File(label='Upload TIFF File'),
49
  gr.Dropdown(choices=['Random Forest'], label='Model')
@@ -51,7 +67,7 @@ inputs = [
51
  outputs = gr.Image(type='filepath', label='Predicted Yield Visualization')
52
 
53
  demo = gr.Interface(
54
- fn=predict_crop_yield,
55
  inputs=inputs,
56
  outputs=outputs,
57
  title="Crop Yield Prediction using Random Forest",
 
3
  import rasterio
4
  import cv2
5
  import matplotlib.pyplot as plt
 
6
  from sklearn.ensemble import RandomForestRegressor
7
  from sklearn.model_selection import train_test_split
8
+ from sklearn.metrics import mean_squared_error
9
+ import os
10
 
11
  # Function to process a single TIFF file
12
+ def process_tiff(file_path):
13
+ with rasterio.open(file_path) as src:
 
14
  tiff_data = src.read()
15
  B2_image = tiff_data[1, :, :] # Assuming B2 is the second band
16
  target_size = (50, 50)
 
18
  return B2_resized.reshape(-1, 1) # Reshape for the model input
19
 
20
  # Function to train the RandomForestRegressor model
21
+ def train_random_forest_model(data_dir):
22
  X_list = []
23
  y_list = []
24
 
25
+ # Load all TIFF files and preprocess data
26
+ for root, dirs, files in os.walk(data_dir):
27
+ for file in files:
28
+ if file.endswith('.tiff'):
29
+ file_path = os.path.join(root, file)
30
+ X_list.append(process_tiff(file_path))
31
+ # Generate synthetic target data for demonstration (replace with actual targets)
32
+ y_list.append(np.random.rand(2500)) # Assuming target_size is (50, 50)
33
 
34
+ X = np.vstack(X_list)
35
+ y = np.hstack(y_list)
36
+
37
+ # Split the data into training and testing sets
38
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
39
+
40
+ # Train the model
41
+ model = RandomForestRegressor(n_estimators=100, random_state=42)
42
+ model.fit(X_train, y_train)
43
+
44
+ return model
45
+
46
+ # Function to make predictions using the trained model
47
+ def predict_crop_yield(file, model_name, data_dir):
48
  if model_name == 'Random Forest':
49
+ model = train_random_forest_model(data_dir)
50
+ processed_image = process_tiff(file.name)
51
  prediction = model.predict(processed_image)
52
  prediction_reshaped = prediction.reshape((50, 50))
53
  plt.imshow(prediction_reshaped, cmap='viridis')
 
58
  else:
59
  return "Model not found"
60
 
61
+ data_dir = 'Data' # Path to the folder containing TIFF files
62
+
63
  inputs = [
64
  gr.File(label='Upload TIFF File'),
65
  gr.Dropdown(choices=['Random Forest'], label='Model')
 
67
  outputs = gr.Image(type='filepath', label='Predicted Yield Visualization')
68
 
69
  demo = gr.Interface(
70
+ fn=lambda file, model_name: predict_crop_yield(file, model_name, data_dir),
71
  inputs=inputs,
72
  outputs=outputs,
73
  title="Crop Yield Prediction using Random Forest",