Spaces:
Sleeping
Sleeping
File size: 3,424 Bytes
c4cf758 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import tensorflow as tf
import joblib
# Module-level state shared by every prediction request. Everything below is
# loaded once, when the script starts, from paths relative to the working
# directory.

# Historical traffic. The 'Datetime' and 'Sessions' columns are the ones the
# prediction code below reads.
webtraffic_data = pd.read_csv("webtraffic.csv")
# Parse 'Datetime' so its last value can anchor the forecast timestamps and so
# the column plots on a real time axis.
webtraffic_data['Datetime'] = pd.to_datetime(webtraffic_data['Datetime'])
# Pre-trained forecasters.
# NOTE(review): joblib artifacts are pickle-based, so loading them can execute
# arbitrary code -- only ship model files from a trusted source.
sarima_model = joblib.load("sarima_model.pkl") # SARIMA model; used through .forecast(steps=...)
lstm_model = tf.keras.models.load_model("lstm_model.keras") # Keras LSTM model; used through .predict(...)
# Scaler applied around the LSTM: inputs go through scaler.transform and
# outputs through scaler.inverse_transform. Despite being described as
# optional, this load is unconditional, so a missing "scaler.pkl" stops the
# app at startup.
# NOTE(review): presumably the scaler fitted on 'Sessions' during training -- confirm.
scaler = joblib.load("scaler.pkl")
# Models selectable from the dashboard; any other value is rejected up front.
_SUPPORTED_MODELS = ("SARIMA", "LSTM")


def _forecast_lstm(history, steps):
    """Forecast ``steps`` future values of ``history`` with the loaded LSTM.

    The most recent observations are scaled with the module-level ``scaler``
    and fed to ``lstm_model``. When a single ``predict`` call yields fewer
    than ``steps`` values, the predictions are appended to the window and the
    model is called again (rolling forecast) until enough values exist.

    Args:
        history: 1-D array of observed values, oldest first.
        steps: Number of future values to produce (>= 1).

    Returns:
        1-D NumPy array of ``steps`` predictions on the original scale.

    Raises:
        ValueError: If ``history`` is shorter than the required input window.
        RuntimeError: If the model returns an empty prediction.
    """
    # A model saved with a fixed number of timesteps only accepts windows of
    # exactly that length. When the time dimension is unspecified, keep the
    # previous behaviour of using the forecast horizon as the window length.
    input_shape = getattr(lstm_model, "input_shape", None)
    fixed_window = (
        input_shape[1]
        if isinstance(input_shape, tuple) and len(input_shape) == 3
        else None
    )
    window = int(fixed_window) if fixed_window else steps
    if window > len(history):
        raise ValueError(
            f"LSTM needs {window} past observations but only "
            f"{len(history)} are available"
        )

    # The scaler is used on a single-feature column, hence the (-1, 1) shape.
    scaled = scaler.transform(history[-window:].reshape(-1, 1)).reshape(-1)
    forecast = np.empty(0, dtype=float)
    while forecast.size < steps:
        batch = scaled[-window:].reshape(1, window, 1)
        step_out = np.asarray(
            lstm_model.predict(batch, verbose=0), dtype=float
        ).reshape(-1)
        if step_out.size == 0:
            raise RuntimeError("LSTM model returned no predictions")
        forecast = np.concatenate([forecast, step_out])
        # Predictions are in the scaled space, so they can extend the window.
        # NOTE(review): assumes the model outputs the next value(s) of the
        # same series it receives -- confirm against the training code.
        scaled = np.concatenate([scaled, step_out])

    return scaler.inverse_transform(forecast[:steps].reshape(-1, 1)).reshape(-1)


def generate_custom_prediction(model, future_hours):
    """Forecast web traffic and save a plot of history plus predictions.

    Args:
        model: ``"SARIMA"`` or ``"LSTM"``.
        future_hours: Number of hourly steps to forecast after the last
            timestamp in ``webtraffic_data``. Converted with ``int()``, so a
            fractional value is truncated.

    Returns:
        Path of the saved PNG, ``"<model lowercased>_prediction_plot.png"`` in
        the working directory. The file is overwritten on every call for the
        same model.

    Raises:
        ValueError: If ``model`` is not supported, ``future_hours`` is missing
            or smaller than 1, or the LSTM lacks enough history.
    """
    if model not in _SUPPORTED_MODELS:
        raise ValueError(
            f"Unsupported model {model!r}; expected one of {_SUPPORTED_MODELS}"
        )
    if future_hours is None:
        raise ValueError("future_hours is required")
    future_hours = int(future_hours)
    if future_hours < 1:
        # 0 would make ``values[-0:]`` select the whole series, and negative
        # horizons are meaningless, so reject both explicitly.
        raise ValueError("future_hours must be at least 1")

    # Hourly timestamps strictly after the last observation. A Timedelta
    # frequency avoids the deprecated 'H' offset alias.
    one_hour = pd.Timedelta(hours=1)
    last_observed = webtraffic_data['Datetime'].iloc[-1]
    future_datetimes = pd.date_range(
        start=last_observed + one_hour,
        periods=future_hours,
        freq=one_hour,
    )

    if model == "SARIMA":
        predictions = sarima_model.forecast(steps=future_hours)
    else:
        predictions = _forecast_lstm(
            webtraffic_data['Sessions'].values, future_hours
        )

    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        ax.plot(
            webtraffic_data['Datetime'],
            webtraffic_data['Sessions'],
            label="Actual Data",
            color="blue",
        )
        ax.plot(
            future_datetimes,
            predictions,
            label=f"{model} Predictions",
            color="green",
        )
        ax.set_title(f"{model} Web Traffic Predictions", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()

        plot_path = f"{model.lower()}_prediction_plot.png"
        fig.savefig(plot_path)
    finally:
        # Release the figure even when plotting or saving fails; otherwise
        # every failed request would leave an open figure behind.
        plt.close(fig)
    return plot_path
# Gradio click handler
def prediction_dashboard(model, future_hours):
    """Return the image path of the forecast plot for the chosen inputs.

    Thin adapter for the dashboard: both arguments are forwarded unchanged to
    ``generate_custom_prediction`` and its result is returned as-is.
    """
    return generate_custom_prediction(model, future_hours)
# Assemble the Gradio UI. Components are created in display order.
with gr.Blocks() as dashboard:
    gr.Markdown("## Interactive Web Traffic Prediction Dashboard")
    gr.Markdown("Input the number of hours to predict and select a model for future web traffic forecasting.")

    # Controls: which model to run and how many hours ahead to forecast.
    model_selection = gr.Dropdown(["SARIMA", "LSTM"], label="Select Model", value="SARIMA")
    future_hours_input = gr.Number(label="Future Hours to Predict", value=24)

    # Receives the image path returned by the click handler.
    plot_output = gr.Image(label="Prediction Plot")

    # Wire the button to the handler: two inputs in, one image out.
    generate_button = gr.Button("Generate Prediction")
    generate_button.click(
        fn=prediction_dashboard,
        inputs=[model_selection, future_hours_input],
        outputs=[plot_output],
    )

# Start serving the dashboard
dashboard.launch()
|