# Spaces: Sleeping
from math import sqrt

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error
# Step 1: Load the Dataset
print("Loading Dataset...")
data_file = "webtraffic.csv"
try:
    webtraffic_data = pd.read_csv(data_file)
    print("Dataset loaded successfully!")
except Exception as e:
    print(f"Error loading dataset: {e}")
    # Bare `exit()` is a `site` convenience and exits with status 0, which
    # would report a failed startup as success; exit non-zero instead.
    raise SystemExit(1)

# Step 2: Ensure 'Datetime' column exists or create it
if "Datetime" in webtraffic_data.columns:
    webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
elif "Hour Index" in webtraffic_data.columns:
    print("Datetime column missing. Creating from 'Hour Index'.")
    # 'Hour Index' is interpreted as an hour offset from a fixed origin.
    start_date = pd.Timestamp("2024-01-01 00:00:00")
    webtraffic_data["Datetime"] = start_date + pd.to_timedelta(
        webtraffic_data["Hour Index"], unit="h"
    )
else:
    print("Error loading dataset: neither 'Datetime' nor 'Hour Index' column found.")
    raise SystemExit(1)
# The plotting functions use positional access (`.iloc[-1]`, `.iloc[-n:]`),
# so rows must be in chronological order.
webtraffic_data.sort_values("Datetime", inplace=True)

# Step 3: Load SARIMA Model
print("Loading SARIMA Model...")
try:
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only point this at a model file from a trusted source.
    sarima_model = joblib.load("sarima_model.pkl")
    print("SARIMA model loaded successfully!")
except Exception as e:
    print(f"Error loading SARIMA model: {e}")
    raise SystemExit(1)

# Step 4: Define Functions for Gradio Dashboard
future_periods = 48  # Number of hours to predict
def generate_sarima_plot():
    """Forecast the next ``future_periods`` hours with SARIMA and plot them.

    Reads the module-level ``webtraffic_data``, ``sarima_model`` and
    ``future_periods``. The full actual ``Sessions`` series is drawn as a
    dotted black line and the forecast as a solid blue line placed on the
    hours immediately after the last observation. The figure is written to
    ``sarima_prediction_plot.png`` in the current working directory.

    Returns:
        tuple: ``(plot_path, metrics_text)`` on success, or
        ``(None, error_message)`` if any step fails (the exception is
        printed to stdout).
    """
    fig = None
    try:
        # date_range includes `start` itself, so request one extra point and
        # drop it to get the hours strictly after the last observation.
        future_dates = pd.date_range(
            start=webtraffic_data["Datetime"].iloc[-1],
            periods=future_periods + 1,
            freq="h",  # the "H" alias is deprecated since pandas 2.2
        )[1:]

        # Normalize to a plain array so the slicing below is positional
        # whether the model returns an ndarray or an indexed pandas Series.
        sarima_predictions = np.asarray(sarima_model.predict(n_periods=future_periods))

        # NOTE(review): these metrics score the forecast against the LAST
        # `future_periods` observed hours, while the plot places the same
        # forecast AFTER the last observation. Both can only be right if the
        # model was fit on data ending `future_periods` hours before the CSV
        # ends; confirm against the training script.
        actual_sessions = webtraffic_data["Sessions"].iloc[-future_periods:].to_numpy()
        compared = sarima_predictions[:len(actual_sessions)]
        mae_sarima = mean_absolute_error(actual_sessions, compared)
        rmse_sarima = sqrt(mean_squared_error(actual_sessions, compared))

        # Plot Actual Traffic vs SARIMA Predictions
        fig, ax = plt.subplots(figsize=(15, 6))
        ax.plot(
            webtraffic_data["Datetime"],
            webtraffic_data["Sessions"],
            label="Actual Traffic",
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        ax.plot(
            future_dates,
            sarima_predictions,
            label="SARIMA Predicted",
            color="blue",
            linewidth=2,
        )
        ax.set_title("SARIMA Predictions vs Actual Traffic", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()

        plot_path = "sarima_prediction_plot.png"
        fig.savefig(plot_path)

        metrics = f"""
SARIMA Model Metrics:
- Mean Absolute Error (MAE): {mae_sarima:.2f}
- Root Mean Squared Error (RMSE): {rmse_sarima:.2f}
"""
        return plot_path, metrics
    except Exception as e:
        print(f"Error generating SARIMA plot: {e}")
        return None, "Error in generating output. Please check the data and model."
    finally:
        # Always release the figure, including on the error path; otherwise
        # pyplot keeps it alive and a long-running server accumulates them.
        if fig is not None:
            plt.close(fig)
def generate_zoomed_plot():
    """Plot the most recent actual traffic next to the SARIMA forecast.

    Shows only the last ``future_periods`` observed hours of ``Sessions``
    (dotted black) followed by a ``future_periods``-hour forecast (solid
    green) on the hours after the last observation. Reads the module-level
    ``webtraffic_data``, ``sarima_model`` and ``future_periods``, and writes
    the figure to ``sarima_zoomed_plot.png`` in the current working
    directory.

    Returns:
        str | None: the saved image path, or ``None`` if any step fails
        (the exception is printed to stdout).
    """
    fig = None
    try:
        # date_range includes `start` itself, so request one extra point and
        # drop it to get the hours strictly after the last observation.
        future_dates = pd.date_range(
            start=webtraffic_data["Datetime"].iloc[-1],
            periods=future_periods + 1,
            freq="h",  # the "H" alias is deprecated since pandas 2.2
        )[1:]

        # Plain array regardless of whether the model returns ndarray or Series.
        sarima_predictions = np.asarray(sarima_model.predict(n_periods=future_periods))

        # Zoomed-in view: recent observations only.
        recent = webtraffic_data.iloc[-future_periods:]

        fig, ax = plt.subplots(figsize=(15, 6))
        ax.plot(
            recent["Datetime"],
            recent["Sessions"],
            label="Actual Traffic (Zoomed)",
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        ax.plot(
            future_dates,
            sarima_predictions,
            label="SARIMA Predicted (Zoomed)",
            color="green",
            linewidth=2,
        )
        ax.set_title("Zoomed-In SARIMA Predictions vs Actual Traffic", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()

        zoomed_plot_path = "sarima_zoomed_plot.png"
        fig.savefig(zoomed_plot_path)
        return zoomed_plot_path
    except Exception as e:
        print(f"Error generating zoomed plot: {e}")
        return None
    finally:
        # Release the figure on both the success and error paths.
        if fig is not None:
            plt.close(fig)
# Step 5: Gradio Dashboard with Two Tiles and Metrics
with gr.Blocks() as dashboard:
    gr.Markdown("## Enhanced SARIMA Web Traffic Prediction Dashboard")
    gr.Markdown("This dashboard includes SARIMA predictions, performance metrics, and a zoomed-in view of recent data.")

    # Output components; Gradio renders them top-to-bottom in creation order.
    plot_output = gr.Image(label="SARIMA Prediction Plot")
    metrics_output = gr.Textbox(label="Model Metrics", lines=6)
    zoomed_plot_output = gr.Image(label="Zoomed-In Prediction Plot")

    # Action buttons, placed below the outputs.
    predict_button = gr.Button("Generate Predictions")
    zoom_button = gr.Button("Generate Zoomed-In Plot")

    # Wire callbacks; neither handler takes user input.
    predict_button.click(
        fn=generate_sarima_plot,
        inputs=[],
        outputs=[plot_output, metrics_output],
    )
    zoom_button.click(
        fn=generate_zoomed_plot,
        inputs=[],
        outputs=[zoomed_plot_output],
    )

# Launch the Gradio Dashboard when run as a script.
if __name__ == "__main__":
    print("\nLaunching Enhanced Gradio Dashboard...")
    dashboard.launch()