# webtraffic / app.py
# Last commit 2f82680 by manjunathainti: "update app.py" (verified)
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import joblib
from sklearn.metrics import mean_absolute_error, mean_squared_error
from math import sqrt
# Step 1: Load the Dataset
# The whole app depends on this CSV, so a load failure stops the script.
print("Loading Dataset...")
data_file = "webtraffic.csv"
try:
    webtraffic_data = pd.read_csv(data_file)
except Exception as e:
    print(f"Error loading dataset: {e}")
    # The builtin exit() is a site-module convenience for the interactive shell
    # and exits with status 0 (success); signal the failure with status 1.
    raise SystemExit(1) from e
print("Dataset loaded successfully!")
# Step 2: Ensure 'Datetime' column exists or create it
# Normalise the time axis: parse an existing 'Datetime' column, or synthesise one
# from 'Hour Index' (hours offset from 2024-01-01 00:00). Rows are then sorted
# chronologically so that .iloc[-1] / .iloc[-n:] later refer to the latest hours.
if "Datetime" in webtraffic_data.columns:
    webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
elif "Hour Index" in webtraffic_data.columns:
    print("Datetime column missing. Creating from 'Hour Index'.")
    start_date = pd.Timestamp("2024-01-01 00:00:00")
    webtraffic_data["Datetime"] = start_date + pd.to_timedelta(
        webtraffic_data["Hour Index"], unit="h"
    )
else:
    # Without either column there is no time axis to plot or forecast against;
    # fail with a clear message instead of an unexplained KeyError.
    print("Error: dataset must contain a 'Datetime' or 'Hour Index' column.")
    raise SystemExit(1)
webtraffic_data.sort_values("Datetime", inplace=True)
# Step 3: Load SARIMA Model
print("Loading SARIMA Model...")
try:
    # joblib.load unpickles the file, which can execute arbitrary code:
    # only point this at a model file produced by a trusted training run.
    sarima_model = joblib.load("sarima_model.pkl")
except Exception as e:
    print(f"Error loading SARIMA model: {e}")
    # Exit non-zero; the builtin exit() would report success (status 0).
    raise SystemExit(1) from e
print("SARIMA model loaded successfully!")

# Step 4: Define Functions for Gradio Dashboard
future_periods = 48  # Number of hours to predict
def _forecast_future():
    """Forecast the next ``future_periods`` hours with the loaded SARIMA model.

    Returns:
        DataFrame with columns ``Datetime`` (hourly timestamps starting one hour
        after the last observed ``Datetime``) and ``SARIMA_Predicted``.
    """
    # Build future_periods + 1 hourly stamps and drop the first one, which is the
    # last observed timestamp itself. "h" replaces the "H" alias that pandas 2.2
    # deprecated.
    future_dates = pd.date_range(
        start=webtraffic_data["Datetime"].iloc[-1],
        periods=future_periods + 1,
        freq="h",
    )[1:]
    # NOTE(review): predict(n_periods=...) looks like the pmdarima API — confirm
    # against whatever produced sarima_model.pkl.
    sarima_predictions = sarima_model.predict(n_periods=future_periods)
    return pd.DataFrame({
        "Datetime": future_dates,
        "SARIMA_Predicted": sarima_predictions,
    })


def _save_forecast_plot(actual, forecast, *, path, title, actual_label,
                        predicted_label, predicted_color):
    """Draw actual vs. forecast sessions on one chart, save it as PNG, return *path*.

    *actual* needs ``Datetime``/``Sessions`` columns; *forecast* needs
    ``Datetime``/``SARIMA_Predicted`` (as returned by ``_forecast_future``).
    """
    # Draw on an explicit Axes instead of pyplot's implicit "current figure",
    # and always close the figure, even when plotting or saving raises;
    # otherwise every failed request would leave an open figure behind.
    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        ax.plot(
            actual["Datetime"],
            actual["Sessions"],
            label=actual_label,
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        ax.plot(
            forecast["Datetime"],
            forecast["SARIMA_Predicted"],
            label=predicted_label,
            color=predicted_color,
            linewidth=2,
        )
        ax.set_title(title, fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def generate_sarima_plot():
    """Generate SARIMA predictions and return a detailed plot with metrics.

    Plots the full actual series together with the next ``future_periods``
    forecast hours.

    Returns:
        ``(plot_path, metrics_text)`` on success, or ``(None, error_message)``
        if anything fails (the exception is printed to the console).
    """
    try:
        future_predictions = _forecast_future()
        predicted = future_predictions["SARIMA_Predicted"].to_numpy()
        actual_sessions = webtraffic_data["Sessions"].iloc[-future_periods:].to_numpy()

        # NOTE(review): this compares forecasts for the *next* hours with the
        # *last observed* hours, i.e. two different time windows, so MAE/RMSE
        # here are not a real accuracy measure. A proper evaluation needs
        # predictions for a held-out window of observed data.
        compared = predicted[:len(actual_sessions)]
        mae_sarima = mean_absolute_error(actual_sessions, compared)
        rmse_sarima = sqrt(mean_squared_error(actual_sessions, compared))

        plot_path = _save_forecast_plot(
            webtraffic_data,
            future_predictions,
            path="sarima_prediction_plot.png",
            title="SARIMA Predictions vs Actual Traffic",
            actual_label="Actual Traffic",
            predicted_label="SARIMA Predicted",
            predicted_color="blue",
        )

        metrics = (
            "\nSARIMA Model Metrics:\n"
            f"- Mean Absolute Error (MAE): {mae_sarima:.2f}\n"
            f"- Root Mean Squared Error (RMSE): {rmse_sarima:.2f}\n"
        )
        return plot_path, metrics
    except Exception as e:
        print(f"Error generating SARIMA plot: {e}")
        return None, "Error in generating output. Please check the data and model."


def generate_zoomed_plot():
    """Generate a zoomed-in SARIMA prediction plot.

    Shows only the last ``future_periods`` observed hours followed by the
    forecast for the next ``future_periods`` hours.

    Returns:
        Path of the saved PNG, or None if anything fails (the exception is
        printed to the console).
    """
    try:
        future_predictions = _forecast_future()
        recent = webtraffic_data.iloc[-future_periods:]
        return _save_forecast_plot(
            recent,
            future_predictions,
            path="sarima_zoomed_plot.png",
            title="Zoomed-In SARIMA Predictions vs Actual Traffic",
            actual_label="Actual Traffic (Zoomed)",
            predicted_label="SARIMA Predicted (Zoomed)",
            predicted_color="green",
        )
    except Exception as e:
        print(f"Error generating zoomed plot: {e}")
        return None
# Step 5: Gradio Dashboard with Two Tiles and Metrics
# Layout, top to bottom: heading, description, main forecast image, metrics
# textbox, zoomed image, then one button per view.
with gr.Blocks() as dashboard:
    gr.Markdown("## Enhanced SARIMA Web Traffic Prediction Dashboard")
    gr.Markdown("This dashboard includes SARIMA predictions, performance metrics, and a zoomed-in view of recent data.")

    plot_output = gr.Image(label="SARIMA Prediction Plot")
    metrics_output = gr.Textbox(label="Model Metrics", lines=6)
    zoomed_plot_output = gr.Image(label="Zoomed-In Prediction Plot")

    # Both handlers take no inputs; each button refreshes only its own outputs.
    predict_button = gr.Button("Generate Predictions")
    predict_button.click(
        fn=generate_sarima_plot,
        inputs=[],
        outputs=[plot_output, metrics_output],
    )

    zoom_button = gr.Button("Generate Zoomed-In Plot")
    zoom_button.click(
        fn=generate_zoomed_plot,
        inputs=[],
        outputs=[zoomed_plot_output],
    )
# Launch the Gradio Dashboard only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    launch_banner = "\nLaunching Enhanced Gradio Dashboard..."
    print(launch_banner)
    dashboard.launch()