manjunathainti committed
Commit a32c351 · 1 Parent(s): b37c21f

Update app.py and requirements.txt for SARIMA and LSTM models

Files changed (2)
  1. app.py +82 -49
  2. requirements.txt +4 -4
app.py CHANGED
@@ -1,10 +1,10 @@
-
 import gradio as gr
 import matplotlib.pyplot as plt
 import pandas as pd
 import numpy as np
 import tensorflow as tf
 import joblib
+from sklearn.metrics import mean_absolute_error, mean_squared_error
 
 # Load the dataset
 webtraffic_data = pd.read_csv("webtraffic.csv")
@@ -13,83 +13,116 @@ webtraffic_data = pd.read_csv("webtraffic.csv")
 webtraffic_data.rename(columns={"Hour Index": "Datetime"}, inplace=True)
 
 # Create a datetime-like index for visualization purposes
-webtraffic_data['Datetime'] = pd.to_datetime(webtraffic_data['Datetime'], unit='h', origin='unix')
+webtraffic_data['Datetime'] = pd.date_range(start='2023-01-01', periods=len(webtraffic_data), freq='H')
+
+# Split the data into train/test for evaluation
+train_size = int(len(webtraffic_data) * 0.8)
+test_size = len(webtraffic_data) - train_size
+train_data = webtraffic_data.iloc[:train_size]
+test_data = webtraffic_data.iloc[train_size:]
 
 # Load the pre-trained models
-sarima_model = joblib.load("sarima_model.pkl") # Load SARIMA model
-lstm_model = tf.keras.models.load_model("lstm_model.keras") # Load LSTM model
+sarima_model = joblib.load("sarima_model.pkl") # SARIMA model
+lstm_model = tf.keras.models.load_model("lstm_model.keras") # LSTM model
 
-# Load the scaler for LSTM (if used during training)
-scaler = joblib.load("scaler.pkl")
+# Initialize future periods for prediction
+future_periods = len(test_data)
 
-# Function to generate predictions and plots
-def generate_custom_prediction(model, future_hours):
-    future_hours = int(future_hours)
-    future_datetimes = pd.date_range(
-        start=webtraffic_data['Datetime'].iloc[-1],
-        periods=future_hours + 1,
-        freq='H'
-    )[1:]
+# Generate predictions for SARIMA
+sarima_predictions = sarima_model.forecast(steps=future_periods)
 
-    if model == "SARIMA":
-        # SARIMA Predictions
-        sarima_predictions = sarima_model.forecast(steps=future_hours)
-        plt.figure(figsize=(15, 6))
-        plt.plot(webtraffic_data['Datetime'], webtraffic_data['Sessions'], label="Actual Data", color="blue")
-        plt.plot(future_datetimes, sarima_predictions, label="SARIMA Predictions", color="green")
+# Prepare data for LSTM predictions
+from sklearn.preprocessing import MinMaxScaler
 
-    elif model == "LSTM":
-        # Prepare data for LSTM (reshape and scale as necessary)
-        lstm_input = webtraffic_data['Sessions'].values[-future_hours:].reshape(-1, 1)
-        lstm_input_scaled = scaler.transform(lstm_input) # Scale input using the saved scaler
-        lstm_input_scaled = lstm_input_scaled.reshape(1, future_hours, 1) # Reshape for LSTM model
+scaler_X = MinMaxScaler(feature_range=(0, 1))
+scaler_y = MinMaxScaler(feature_range=(0, 1))
+
+# Fit the scaler to the training data
+X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
+y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
+
+# Scale test data
+X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
+y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
+
+# Reshape data for LSTM input
+X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, X_test_scaled.shape[1]))
+
+# Predict with LSTM
+lstm_predictions_scaled = lstm_model.predict(X_test_lstm)
+lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled).flatten()
+
+# Combine predictions into a DataFrame for visualization
+future_predictions = pd.DataFrame({
+    "Datetime": test_data['Datetime'],
+    "SARIMA_Predicted": sarima_predictions,
+    "LSTM_Predicted": lstm_predictions
+})
 
-        # LSTM Predictions
-        lstm_predictions = lstm_model.predict(lstm_input_scaled)
-        lstm_predictions = scaler.inverse_transform(lstm_predictions).flatten() # Inverse scale
+# Calculate metrics for both models
+mae_sarima_future = mean_absolute_error(test_data['Sessions'], sarima_predictions)
+rmse_sarima_future = mean_squared_error(test_data['Sessions'], sarima_predictions, squared=False)
 
-        plt.figure(figsize=(15, 6))
-        plt.plot(webtraffic_data['Datetime'], webtraffic_data['Sessions'], label="Actual Data", color="blue")
-        plt.plot(future_datetimes, lstm_predictions, label="LSTM Predictions", color="green")
+mae_lstm_future = mean_absolute_error(test_data['Sessions'], lstm_predictions)
+rmse_lstm_future = mean_squared_error(test_data['Sessions'], lstm_predictions, squared=False)
 
-    # Customize the plot
-    plt.title(f"{model} Web Traffic Predictions", fontsize=16)
+# Function to generate plot based on the selected model
+def generate_plot(model):
+    """Generate plot based on the selected model."""
+    plt.figure(figsize=(15, 6))
+    actual_dates = test_data['Datetime']
+    plt.plot(actual_dates, test_data['Sessions'], label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
+
+    if model == "SARIMA":
+        plt.plot(future_predictions['Datetime'], future_predictions['SARIMA_Predicted'], label='SARIMA Predicted', color='blue', linewidth=2)
+    elif model == "LSTM":
+        plt.plot(future_predictions['Datetime'], future_predictions['LSTM_Predicted'], label='LSTM Predicted', color='green', linewidth=2)
+
+    plt.title(f"{model} Predictions vs Actual Traffic", fontsize=16)
     plt.xlabel("Datetime", fontsize=12)
     plt.ylabel("Sessions", fontsize=12)
     plt.legend(loc="upper left")
     plt.grid(True)
     plt.tight_layout()
-
-    # Save the plot as an image
-    plot_path = f"{model.lower()}_prediction_plot.png"
+    plot_path = f"{model.lower()}_plot.png"
     plt.savefig(plot_path)
     plt.close()
     return plot_path
 
+# Function to display metrics for both models
+def display_metrics():
+    """Generate a DataFrame with metrics for SARIMA and LSTM."""
+    metrics = {
+        "Model": ["SARIMA", "LSTM"],
+        "Mean Absolute Error (MAE)": [mae_sarima_future, mae_lstm_future],
+        "Root Mean Squared Error (RMSE)": [rmse_sarima_future, rmse_lstm_future]
+    }
+    return pd.DataFrame(metrics)
+
 # Gradio interface function
-def prediction_dashboard(model, future_hours):
-    plot_path = generate_custom_prediction(model, future_hours)
-    return plot_path
+def dashboard_interface(model="SARIMA"):
+    """Generate plot and metrics for the selected model."""
+    plot_path = generate_plot(model) # Generate plot for the selected model
+    metrics_df = display_metrics() # Get metrics
+    return plot_path, metrics_df.to_string()
 
 # Build the Gradio interface
 with gr.Blocks() as dashboard:
     gr.Markdown("## Interactive Web Traffic Prediction Dashboard")
-    gr.Markdown("Input the number of hours to predict and select a model for future web traffic forecasting.")
+    gr.Markdown("Use the dropdown menu to select a model and view its predictions vs actual traffic along with performance metrics.")
 
     # Dropdown for model selection
     model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
 
-    # Input for future hours
-    future_hours_input = gr.Number(label="Future Hours to Predict", value=24)
-
-    # Output: Plot
+    # Outputs: Plot and Metrics
    plot_output = gr.Image(label="Prediction Plot")
+    metrics_output = gr.Textbox(label="Metrics", lines=15)
 
-    # Button to generate predictions
-    gr.Button("Generate Prediction").click(
-        fn=prediction_dashboard,
-        inputs=[model_selection, future_hours_input],
-        outputs=[plot_output]
+    # Button to update dashboard
+    gr.Button("Update Dashboard").click(
+        fn=dashboard_interface, # Function to call
+        inputs=[model_selection], # Inputs to the function
+        outputs=[plot_output, metrics_output] # Outputs from the function
     )
 
 # Launch the Gradio dashboard
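A note on the metric step in the updated app.py: RMSE is computed with mean_squared_error(..., squared=False), and requirements.txt (below) leaves scikit-learn unpinned, so that keyword may be unavailable depending on the installed version, since it has been deprecated and later removed upstream. The following is only a minimal, version-agnostic sketch of the same computation, assuming the test_data, sarima_predictions, and lstm_predictions objects defined above; the mae_rmse helper is hypothetical and not part of the commit.

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error

def mae_rmse(y_true, y_pred):
    """Return (MAE, RMSE); RMSE is taken as the square root of MSE,
    so the deprecated squared=False keyword is not required."""
    mae = mean_absolute_error(y_true, y_pred)
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return mae, rmse

# Hypothetical drop-in usage with the objects built in app.py:
# mae_sarima_future, rmse_sarima_future = mae_rmse(test_data['Sessions'], sarima_predictions)
# mae_lstm_future, rmse_lstm_future = mae_rmse(test_data['Sessions'], lstm_predictions)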
requirements.txt CHANGED
@@ -1,8 +1,8 @@
-pandas
+pmdarima==2.0.4
 numpy==1.23.5
-pmdarima==1.8.5
+pandas
+tensorflow
+joblib
 matplotlib
 gradio
 scikit-learn
-statsmodels
-tensorflow
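The updated requirements pin pmdarima 2.0.4 and drop the explicit statsmodels entry, while app.py loads sarima_model.pkl with joblib and calls .forecast(steps=...), which matches the statsmodels SARIMAX results API. The training script that produced that pickle is not part of this commit; the sketch below is purely hypothetical, with placeholder (p, d, q)(P, D, Q, s) orders and statsmodels assumed as the backing library.

# Hypothetical training sketch, not part of this commit.
import joblib
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX

webtraffic_data = pd.read_csv("webtraffic.csv")
train = webtraffic_data["Sessions"].iloc[:int(len(webtraffic_data) * 0.8)]

# Placeholder orders; the real model's hyperparameters are unknown.
results = SARIMAX(train, order=(1, 1, 1), seasonal_order=(1, 1, 1, 24)).fit(disp=False)
joblib.dump(results, "sarima_model.pkl")  # the file app.py loads via joblib.load

Unpickling a statsmodels results object at runtime still requires statsmodels to be importable; it is no longer listed directly in requirements.txt, though pmdarima normally pulls it in as a dependency.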