Spaces:
Sleeping
Sleeping
import gradio as gr | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import numpy as np | |
import tensorflow as tf | |
import joblib | |
# Load the dataset and parse 'Datetime' into pandas timestamps.
webtraffic_data = pd.read_csv("webtraffic.csv")
webtraffic_data['Datetime'] = pd.to_datetime(webtraffic_data['Datetime'])
# Order rows chronologically: the forecasting code treats the last row as the
# most recent observation (forecast start time and LSTM input window).
# A stable sort keeps rows with equal timestamps in file order.
webtraffic_data = webtraffic_data.sort_values('Datetime', kind='mergesort').reset_index(drop=True)
# Load the pre-trained models.
# NOTE(review): joblib.load unpickles arbitrary objects -- only load trusted,
# locally shipped model files here, never user-supplied ones.
sarima_model = joblib.load("sarima_model.pkl")  # SARIMA model
lstm_model = tf.keras.models.load_model("lstm_model.keras")  # LSTM model
# Scaler for the LSTM's inputs/outputs. It is required, not optional:
# the LSTM prediction path always applies it.
scaler = joblib.load("scaler.pkl")
# Function to generate predictions and plots
def _validated_horizon(future_hours):
    """Coerce the requested forecast horizon to a positive ``int``.

    Fractional values are truncated by ``int()``, as before.

    Raises:
        ValueError: if ``future_hours`` is missing, not numeric, non-finite,
            or smaller than 1.
    """
    if future_hours is None:
        raise ValueError("Please enter the number of hours to predict.")
    try:
        hours = int(future_hours)
    except (TypeError, ValueError, OverflowError) as err:
        # int(nan) -> ValueError, int(inf) -> OverflowError, int("x") -> ValueError
        raise ValueError(f"Invalid number of hours: {future_hours!r}") from err
    if hours < 1:
        # Guard explicitly: a 0 horizon would make values[-0:] select the whole series.
        raise ValueError("The number of hours to predict must be at least 1.")
    return hours


def _forecast_sarima(steps):
    """Return ``steps`` out-of-sample SARIMA forecasts as a 1-D array."""
    return np.asarray(sarima_model.forecast(steps=steps)).reshape(-1)


def _forecast_lstm(steps):
    """Forecast ``steps`` hours with the LSTM, feeding each prediction back in.

    The look-back window comes from the model's declared input shape
    ``(batch, window, features)``. A model whose time axis is variable (or
    whose shape cannot be read) falls back to a window of ``steps``, which
    is the window the previous implementation used.

    NOTE(review): assumes the model emits next-step value(s) in the
    scaler's units -- confirm against the training code.

    Raises:
        ValueError: if there is not enough history for one input window.
    """
    declared = getattr(lstm_model, "input_shape", None)
    look_back = declared[1] if isinstance(declared, tuple) and len(declared) == 3 else None
    look_back = look_back or steps

    history = webtraffic_data['Sessions'].to_numpy(dtype=float).reshape(-1, 1)
    if len(history) < look_back:
        raise ValueError(
            f"The LSTM needs at least {look_back} observations, got {len(history)}."
        )

    # Work in scaled units throughout. Only the final result is inverse-transformed.
    window = list(scaler.transform(history[-look_back:]).reshape(-1))
    predicted = []
    while len(predicted) < steps:
        batch = np.asarray(window[-look_back:], dtype=float).reshape(1, look_back, 1)
        # Calling the model directly is much cheaper than .predict() for a
        # single small batch inside a loop.
        step_out = np.asarray(lstm_model(batch, training=False)).reshape(-1)
        if step_out.size == 0:
            raise ValueError("The LSTM model returned an empty prediction.")
        # The model may emit one or several future steps per call.
        predicted.extend(step_out)
        window.extend(step_out)

    scaled = np.asarray(predicted[:steps], dtype=float).reshape(-1, 1)
    return scaler.inverse_transform(scaled).reshape(-1)


def _plot_forecast(model, future_datetimes, predictions):
    """Plot history plus forecast, save it as a PNG and return the file path."""
    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        ax.plot(webtraffic_data['Datetime'], webtraffic_data['Sessions'], label="Actual Data", color="blue")
        ax.plot(future_datetimes, predictions, label=f"{model} Predictions", color="green")
        ax.set_title(f"{model} Web Traffic Predictions", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()
        # NOTE(review): one fixed file per model, so concurrent requests for
        # the same model overwrite each other's image.
        plot_path = f"{model.lower()}_prediction_plot.png"
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return plot_path


# Supported model names mapped to their forecasting functions.
_FORECASTERS = {"SARIMA": _forecast_sarima, "LSTM": _forecast_lstm}


def generate_custom_prediction(model, future_hours):
    """Forecast web-traffic sessions with ``model`` and save a plot of them.

    Args:
        model: ``"SARIMA"`` or ``"LSTM"``.
        future_hours: number of hourly steps to forecast. It is coerced with
            ``int()`` and must be at least 1.

    Returns:
        Relative path of the saved PNG,
        ``"<model lowercased>_prediction_plot.png"``.

    Raises:
        ValueError: for an unknown model or an invalid horizon.
    """
    forecaster = _FORECASTERS.get(model)
    if forecaster is None:
        raise ValueError(f"Unknown model {model!r}; expected one of {sorted(_FORECASTERS)}.")
    hours = _validated_horizon(future_hours)

    # Hourly timestamps strictly after the last observation ([1:] drops the start).
    future_datetimes = pd.date_range(
        start=webtraffic_data['Datetime'].iloc[-1],
        periods=hours + 1,
        freq=pd.Timedelta(hours=1),
    )[1:]

    predictions = forecaster(hours)
    return _plot_forecast(model, future_datetimes, predictions)
# Gradio interface function
def prediction_dashboard(model, future_hours):
    """Gradio click handler: build the forecast plot and return its file path.

    ``ValueError``s raised while forecasting (e.g. an invalid number of
    hours) are re-raised as ``gr.Error``. Gradio then shows the message to
    the user instead of a generic error indicator.
    """
    try:
        return generate_custom_prediction(model, future_hours)
    except ValueError as err:
        raise gr.Error(str(err)) from err
# Build the Gradio interface
with gr.Blocks() as dashboard:
    gr.Markdown("## Interactive Web Traffic Prediction Dashboard")
    gr.Markdown("Input the number of hours to predict and select a model for future web traffic forecasting.")
    # Dropdown for model selection
    model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
    # Input for future hours. precision=0 makes Gradio round to the nearest
    # integer and pass an int, so fractional input isn't silently truncated later.
    future_hours_input = gr.Number(label="Future Hours to Predict", value=24, precision=0)
    # Output: Plot
    plot_output = gr.Image(label="Prediction Plot")
    # Button to generate predictions
    gr.Button("Generate Prediction").click(
        fn=prediction_dashboard,
        inputs=[model_selection, future_hours_input],
        outputs=[plot_output],
    )

# Launch the Gradio dashboard only when run as a script (`python app.py`),
# so importing this module does not start a server.
if __name__ == "__main__":
    dashboard.launch()