Spaces:
Sleeping
Sleeping
Commit
·
a32c351
1
Parent(s):
b37c21f
Update app.py and requirements.txt for SARIMA and LSTM models
Browse files- app.py +82 -49
- requirements.txt +4 -4
app.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
-
|
2 |
import gradio as gr
|
3 |
import matplotlib.pyplot as plt
|
4 |
import pandas as pd
|
5 |
import numpy as np
|
6 |
import tensorflow as tf
|
7 |
import joblib
|
|
|
8 |
|
9 |
# Load the dataset
|
10 |
webtraffic_data = pd.read_csv("webtraffic.csv")
|
@@ -13,83 +13,116 @@ webtraffic_data = pd.read_csv("webtraffic.csv")
|
|
13 |
webtraffic_data.rename(columns={"Hour Index": "Datetime"}, inplace=True)
|
14 |
|
15 |
# Create a datetime-like index for visualization purposes
|
16 |
-
webtraffic_data['Datetime'] = pd.
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# Load the pre-trained models
|
19 |
-
sarima_model = joblib.load("sarima_model.pkl") #
|
20 |
-
lstm_model = tf.keras.models.load_model("lstm_model.keras") #
|
21 |
|
22 |
-
#
|
23 |
-
|
24 |
|
25 |
-
#
|
26 |
-
|
27 |
-
future_hours = int(future_hours)
|
28 |
-
future_datetimes = pd.date_range(
|
29 |
-
start=webtraffic_data['Datetime'].iloc[-1],
|
30 |
-
periods=future_hours + 1,
|
31 |
-
freq='H'
|
32 |
-
)[1:]
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
sarima_predictions = sarima_model.forecast(steps=future_hours)
|
37 |
-
plt.figure(figsize=(15, 6))
|
38 |
-
plt.plot(webtraffic_data['Datetime'], webtraffic_data['Sessions'], label="Actual Data", color="blue")
|
39 |
-
plt.plot(future_datetimes, sarima_predictions, label="SARIMA Predictions", color="green")
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
plt.plot(future_datetimes, lstm_predictions, label="LSTM Predictions", color="green")
|
54 |
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
plt.xlabel("Datetime", fontsize=12)
|
58 |
plt.ylabel("Sessions", fontsize=12)
|
59 |
plt.legend(loc="upper left")
|
60 |
plt.grid(True)
|
61 |
plt.tight_layout()
|
62 |
-
|
63 |
-
# Save the plot as an image
|
64 |
-
plot_path = f"{model.lower()}_prediction_plot.png"
|
65 |
plt.savefig(plot_path)
|
66 |
plt.close()
|
67 |
return plot_path
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
# Gradio interface function
|
70 |
-
def
|
71 |
-
|
72 |
-
|
|
|
|
|
73 |
|
74 |
# Build the Gradio interface
|
75 |
with gr.Blocks() as dashboard:
|
76 |
gr.Markdown("## Interactive Web Traffic Prediction Dashboard")
|
77 |
-
gr.Markdown("
|
78 |
|
79 |
# Dropdown for model selection
|
80 |
model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
|
81 |
|
82 |
-
#
|
83 |
-
future_hours_input = gr.Number(label="Future Hours to Predict", value=24)
|
84 |
-
|
85 |
-
# Output: Plot
|
86 |
plot_output = gr.Image(label="Prediction Plot")
|
|
|
87 |
|
88 |
-
# Button to
|
89 |
-
gr.Button("
|
90 |
-
fn=
|
91 |
-
inputs=[model_selection,
|
92 |
-
outputs=[plot_output]
|
93 |
)
|
94 |
|
95 |
# Launch the Gradio dashboard
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import matplotlib.pyplot as plt
|
3 |
import pandas as pd
|
4 |
import numpy as np
|
5 |
import tensorflow as tf
|
6 |
import joblib
|
7 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
8 |
|
9 |
# Load the dataset
|
10 |
webtraffic_data = pd.read_csv("webtraffic.csv")
|
|
|
13 |
webtraffic_data.rename(columns={"Hour Index": "Datetime"}, inplace=True)
|
14 |
|
15 |
# Create a datetime-like index for visualization purposes
|
16 |
+
webtraffic_data['Datetime'] = pd.date_range(start='2023-01-01', periods=len(webtraffic_data), freq='H')
|
17 |
+
|
18 |
+
# Split the data into train/test for evaluation
|
19 |
+
train_size = int(len(webtraffic_data) * 0.8)
|
20 |
+
test_size = len(webtraffic_data) - train_size
|
21 |
+
train_data = webtraffic_data.iloc[:train_size]
|
22 |
+
test_data = webtraffic_data.iloc[train_size:]
|
23 |
|
24 |
# Load the pre-trained models
|
25 |
+
sarima_model = joblib.load("sarima_model.pkl") # SARIMA model
|
26 |
+
lstm_model = tf.keras.models.load_model("lstm_model.keras") # LSTM model
|
27 |
|
28 |
+
# Initialize future periods for prediction
|
29 |
+
future_periods = len(test_data)
|
30 |
|
31 |
+
# Generate predictions for SARIMA
|
32 |
+
sarima_predictions = sarima_model.forecast(steps=future_periods)
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
# Prepare data for LSTM predictions
|
35 |
+
from sklearn.preprocessing import MinMaxScaler
|
|
|
|
|
|
|
|
|
36 |
|
37 |
+
scaler_X = MinMaxScaler(feature_range=(0, 1))
|
38 |
+
scaler_y = MinMaxScaler(feature_range=(0, 1))
|
39 |
+
|
40 |
+
# Fit the scaler to the training data
|
41 |
+
X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
42 |
+
y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].values.reshape(-1, 1))
|
43 |
+
|
44 |
+
# Scale test data
|
45 |
+
X_test_scaled = scaler_X.transform(test_data['Sessions'].values.reshape(-1, 1))
|
46 |
+
y_test_scaled = scaler_y.transform(test_data['Sessions'].values.reshape(-1, 1))
|
47 |
+
|
48 |
+
# Reshape data for LSTM input
|
49 |
+
X_test_lstm = X_test_scaled.reshape((X_test_scaled.shape[0], 1, X_test_scaled.shape[1]))
|
50 |
+
|
51 |
+
# Predict with LSTM
|
52 |
+
lstm_predictions_scaled = lstm_model.predict(X_test_lstm)
|
53 |
+
lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled).flatten()
|
54 |
+
|
55 |
+
# Combine predictions into a DataFrame for visualization
|
56 |
+
future_predictions = pd.DataFrame({
|
57 |
+
"Datetime": test_data['Datetime'],
|
58 |
+
"SARIMA_Predicted": sarima_predictions,
|
59 |
+
"LSTM_Predicted": lstm_predictions
|
60 |
+
})
|
61 |
|
62 |
+
# Calculate metrics for both models
|
63 |
+
mae_sarima_future = mean_absolute_error(test_data['Sessions'], sarima_predictions)
|
64 |
+
rmse_sarima_future = mean_squared_error(test_data['Sessions'], sarima_predictions, squared=False)
|
65 |
|
66 |
+
mae_lstm_future = mean_absolute_error(test_data['Sessions'], lstm_predictions)
|
67 |
+
rmse_lstm_future = mean_squared_error(test_data['Sessions'], lstm_predictions, squared=False)
|
|
|
68 |
|
69 |
+
# Function to generate plot based on the selected model
|
70 |
+
def generate_plot(model):
|
71 |
+
"""Generate plot based on the selected model."""
|
72 |
+
plt.figure(figsize=(15, 6))
|
73 |
+
actual_dates = test_data['Datetime']
|
74 |
+
plt.plot(actual_dates, test_data['Sessions'], label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
|
75 |
+
|
76 |
+
if model == "SARIMA":
|
77 |
+
plt.plot(future_predictions['Datetime'], future_predictions['SARIMA_Predicted'], label='SARIMA Predicted', color='blue', linewidth=2)
|
78 |
+
elif model == "LSTM":
|
79 |
+
plt.plot(future_predictions['Datetime'], future_predictions['LSTM_Predicted'], label='LSTM Predicted', color='green', linewidth=2)
|
80 |
+
|
81 |
+
plt.title(f"{model} Predictions vs Actual Traffic", fontsize=16)
|
82 |
plt.xlabel("Datetime", fontsize=12)
|
83 |
plt.ylabel("Sessions", fontsize=12)
|
84 |
plt.legend(loc="upper left")
|
85 |
plt.grid(True)
|
86 |
plt.tight_layout()
|
87 |
+
plot_path = f"{model.lower()}_plot.png"
|
|
|
|
|
88 |
plt.savefig(plot_path)
|
89 |
plt.close()
|
90 |
return plot_path
|
91 |
|
92 |
+
# Function to display metrics for both models
|
93 |
+
def display_metrics():
|
94 |
+
"""Generate a DataFrame with metrics for SARIMA and LSTM."""
|
95 |
+
metrics = {
|
96 |
+
"Model": ["SARIMA", "LSTM"],
|
97 |
+
"Mean Absolute Error (MAE)": [mae_sarima_future, mae_lstm_future],
|
98 |
+
"Root Mean Squared Error (RMSE)": [rmse_sarima_future, rmse_lstm_future]
|
99 |
+
}
|
100 |
+
return pd.DataFrame(metrics)
|
101 |
+
|
102 |
# Gradio interface function
|
103 |
+
def dashboard_interface(model="SARIMA"):
|
104 |
+
"""Generate plot and metrics for the selected model."""
|
105 |
+
plot_path = generate_plot(model) # Generate plot for the selected model
|
106 |
+
metrics_df = display_metrics() # Get metrics
|
107 |
+
return plot_path, metrics_df.to_string()
|
108 |
|
109 |
# Build the Gradio interface
|
110 |
with gr.Blocks() as dashboard:
|
111 |
gr.Markdown("## Interactive Web Traffic Prediction Dashboard")
|
112 |
+
gr.Markdown("Use the dropdown menu to select a model and view its predictions vs actual traffic along with performance metrics.")
|
113 |
|
114 |
# Dropdown for model selection
|
115 |
model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
|
116 |
|
117 |
+
# Outputs: Plot and Metrics
|
|
|
|
|
|
|
118 |
plot_output = gr.Image(label="Prediction Plot")
|
119 |
+
metrics_output = gr.Textbox(label="Metrics", lines=15)
|
120 |
|
121 |
+
# Button to update dashboard
|
122 |
+
gr.Button("Update Dashboard").click(
|
123 |
+
fn=dashboard_interface, # Function to call
|
124 |
+
inputs=[model_selection], # Inputs to the function
|
125 |
+
outputs=[plot_output, metrics_output] # Outputs from the function
|
126 |
)
|
127 |
|
128 |
# Launch the Gradio dashboard
|
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
|
2 |
numpy==1.23.5
|
3 |
-
|
|
|
|
|
4 |
matplotlib
|
5 |
gradio
|
6 |
scikit-learn
|
7 |
-
statsmodels
|
8 |
-
tensorflow
|
|
|
1 |
+
pmdarima==2.0.4
|
2 |
numpy==1.23.5
|
3 |
+
pandas
|
4 |
+
tensorflow
|
5 |
+
joblib
|
6 |
matplotlib
|
7 |
gradio
|
8 |
scikit-learn
|
|
|
|