Spaces:
Sleeping
Sleeping
File size: 6,371 Bytes
c4cf758 2f82680 c4cf758 2f82680 f8d0f44 2f82680 f8d0f44 2f82680 f8d0f44 2f82680 f8d0f44 2f82680 c4cf758 2f82680 c4cf758 2f82680 f69d4dd 2f82680 f8d0f44 2f82680 f8d0f44 c4cf758 2f82680 f8d0f44 2f82680 f8d0f44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
from math import sqrt
import sys

import gradio as gr
import joblib
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
# Step 1: Load the Dataset
print("Loading Dataset...")
data_file = "webtraffic.csv"
try:
    webtraffic_data = pd.read_csv(data_file)
except Exception as e:
    # Nothing below can run without the data. Exit with a non-zero status so
    # shells and process supervisors see a failure; a bare exit() would
    # terminate with status 0 ("success").
    print(f"Error loading dataset: {e}", file=sys.stderr)
    sys.exit(1)
print("Dataset loaded successfully!")
# Step 2: Ensure 'Datetime' column exists or create it
if "Datetime" not in webtraffic_data.columns:
    if "Hour Index" not in webtraffic_data.columns:
        # Neither timestamp source is available: stop with a clear message
        # and a non-zero status instead of an uncaught KeyError traceback.
        print(
            "Error preparing dataset: expected a 'Datetime' or 'Hour Index' column.",
            file=sys.stderr,
        )
        sys.exit(1)
    print("Datetime column missing. Creating from 'Hour Index'.")
    start_date = pd.Timestamp("2024-01-01 00:00:00")
    # Each 'Hour Index' value is interpreted as an offset in hours from start_date.
    webtraffic_data["Datetime"] = start_date + pd.to_timedelta(
        webtraffic_data["Hour Index"], unit="h"
    )
else:
    webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
# Chronological order matters: the plotting functions take the latest
# timestamp with .iloc[-1] and the most recent rows with .iloc[-future_periods:].
webtraffic_data.sort_values("Datetime", inplace=True)
# Step 3: Load SARIMA Model
print("Loading SARIMA Model...")
try:
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code -- only ever load a model file from a trusted source.
    sarima_model = joblib.load("sarima_model.pkl")
except Exception as e:
    # Exit non-zero so the failure is visible to the caller (bare exit()
    # would report status 0).
    print(f"Error loading SARIMA model: {e}", file=sys.stderr)
    sys.exit(1)
print("SARIMA model loaded successfully!")
# Step 4: Define Functions for Gradio Dashboard
# Forecast horizon, in hours: two days of hourly steps.
future_periods = 24 * 2
def generate_sarima_plot():
    """Forecast the next ``future_periods`` hours and plot them after the history.

    Reads the module-level ``webtraffic_data`` (sorted by "Datetime" at load
    time and expected to contain a "Sessions" column), ``sarima_model`` and
    ``future_periods``. The chart is written to ``sarima_prediction_plot.png``
    in the current working directory.

    Returns:
        tuple: ``(plot_path, metrics)`` -- the saved PNG path and a short text
        block with MAE and RMSE. If any step raises, the error is printed and
        ``(None, "Error in generating output. Please check the data and
        model.")`` is returned so the Gradio UI shows a message instead of
        failing.
    """
    try:
        # Hourly timestamps strictly after the last observation: date_range
        # includes its start point, so one extra step is generated and dropped.
        future_dates = pd.date_range(
            start=webtraffic_data["Datetime"].iloc[-1],
            periods=future_periods + 1,
            freq="h",  # the upper-case "H" alias is deprecated since pandas 2.2
        )[1:]

        # Plain float array: slicing below is positional, and no index
        # alignment can happen if the model returns a pandas Series.
        sarima_predictions = np.asarray(
            sarima_model.predict(n_periods=future_periods), dtype=float
        )

        # Most recent observed hours, used as the comparison baseline.
        actual_sessions = webtraffic_data["Sessions"].iloc[-future_periods:].to_numpy()
        compared = sarima_predictions[: len(actual_sessions)]
        # NOTE(review): the forecast covers the hours AFTER the data, while
        # actual_sessions covers the LAST observed hours -- two different time
        # windows. These MAE/RMSE values are therefore not an out-of-sample
        # accuracy measure; a hold-out backtest would be needed for that.
        mae_sarima = mean_absolute_error(actual_sessions, compared)
        rmse_sarima = sqrt(mean_squared_error(actual_sessions, compared))

        # A standalone Figure (not pyplot) keeps no shared "current figure"
        # state between calls and is not registered with pyplot, so no figure
        # is left open if a later step raises.
        fig = Figure(figsize=(15, 6))
        ax = fig.subplots()
        ax.plot(
            webtraffic_data["Datetime"],
            webtraffic_data["Sessions"],
            label="Actual Traffic",
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        ax.plot(
            future_dates,
            sarima_predictions,
            label="SARIMA Predicted",
            color="blue",
            linewidth=2,
        )
        ax.set_title("SARIMA Predictions vs Actual Traffic", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()

        plot_path = "sarima_prediction_plot.png"
        fig.savefig(plot_path)

        metrics = (
            "\nSARIMA Model Metrics:\n"
            f"- Mean Absolute Error (MAE): {mae_sarima:.2f}\n"
            f"- Root Mean Squared Error (RMSE): {rmse_sarima:.2f}\n"
        )
        return plot_path, metrics
    except Exception as e:
        print(f"Error generating SARIMA plot: {e}")
        return None, "Error in generating output. Please check the data and model."
def generate_zoomed_plot():
    """Plot the most recent observed hours next to the upcoming forecast.

    Shows the last ``future_periods`` rows of the module-level
    ``webtraffic_data``, followed by ``sarima_model``'s forecast for the next
    ``future_periods`` hours. The chart is written to
    ``sarima_zoomed_plot.png`` in the current working directory.

    Returns:
        str | None: Path of the saved PNG, or ``None`` if any step raises
        (the error is printed to stdout).
    """
    try:
        # Hourly timestamps strictly after the last observation: date_range
        # includes its start point, so one extra step is generated and dropped.
        future_dates = pd.date_range(
            start=webtraffic_data["Datetime"].iloc[-1],
            periods=future_periods + 1,
            freq="h",  # the upper-case "H" alias is deprecated since pandas 2.2
        )[1:]

        # Plain float array so the model's return type (ndarray or Series)
        # does not matter for plotting.
        sarima_predictions = np.asarray(
            sarima_model.predict(n_periods=future_periods), dtype=float
        )

        # Only the most recent rows, giving the "zoomed-in" view.
        recent = webtraffic_data.iloc[-future_periods:]

        # A standalone Figure (not pyplot) keeps no shared "current figure"
        # state between calls and is not registered with pyplot, so no figure
        # is left open if a later step raises.
        fig = Figure(figsize=(15, 6))
        ax = fig.subplots()
        ax.plot(
            recent["Datetime"],
            recent["Sessions"],
            label="Actual Traffic (Zoomed)",
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        ax.plot(
            future_dates,
            sarima_predictions,
            label="SARIMA Predicted (Zoomed)",
            color="green",
            linewidth=2,
        )
        ax.set_title("Zoomed-In SARIMA Predictions vs Actual Traffic", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()

        zoomed_plot_path = "sarima_zoomed_plot.png"
        fig.savefig(zoomed_plot_path)
        return zoomed_plot_path
    except Exception as e:
        print(f"Error generating zoomed plot: {e}")
        return None
# Step 5: Gradio Dashboard with Two Tiles and Metrics
with gr.Blocks() as dashboard:
    gr.Markdown("## Enhanced SARIMA Web Traffic Prediction Dashboard")
    gr.Markdown("This dashboard includes SARIMA predictions, performance metrics, and a zoomed-in view of recent data.")

    # Output components: full-history plot, its metrics, then the zoomed plot.
    plot_output = gr.Image(label="SARIMA Prediction Plot")
    metrics_output = gr.Textbox(label="Model Metrics", lines=6)
    zoomed_plot_output = gr.Image(label="Zoomed-In Prediction Plot")

    # One button per handler; neither handler takes any inputs.
    predict_button = gr.Button("Generate Predictions")
    zoom_button = gr.Button("Generate Zoomed-In Plot")

    predict_button.click(
        fn=generate_sarima_plot,
        inputs=[],
        outputs=[plot_output, metrics_output],
    )
    zoom_button.click(
        fn=generate_zoomed_plot,
        inputs=[],
        outputs=[zoomed_plot_output],
    )
# Launch the Gradio Dashboard
def _main():
    """Print a startup banner, then launch the ``dashboard`` Blocks app."""
    print("\nLaunching Enhanced Gradio Dashboard...")
    dashboard.launch()


if __name__ == "__main__":
    _main()
|