# webtraffic / app.py
# manjunathainti - Initial commit of web traffic prediction app (c4cf758)
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import tensorflow as tf
import joblib
# Historical traffic: read the CSV and convert 'Datetime' to real timestamps
# so it can be plotted and used as the anchor for future dates.
webtraffic_data = pd.read_csv("webtraffic.csv").assign(
    Datetime=lambda frame: pd.to_datetime(frame["Datetime"])
)

# Pre-trained forecasters, loaded once at import time.
# NOTE: joblib.load unpickles the file, so these artifacts must come from a
# trusted source.
sarima_model = joblib.load("sarima_model.pkl")
lstm_model = tf.keras.models.load_model("lstm_model.keras")

# Scaler used by the LSTM branch to scale its input and un-scale its output.
scaler = joblib.load("scaler.pkl")
# Function to generate predictions and plots
def generate_custom_prediction(model, future_hours):
    """Forecast future web traffic and save a plot of history plus forecast.

    Args:
        model: Name of the forecaster to use, either ``"SARIMA"`` or
            ``"LSTM"``.
        future_hours: Number of hourly steps to forecast. Converted with
            ``int()``, so a fractional value is truncated.

    Returns:
        Path of the PNG file the plot was written to
        (``"<model lower-cased>_prediction_plot.png"``).

    Raises:
        ValueError: If ``model`` is not a supported name, ``future_hours``
            is not a number >= 1, the LSTM window is longer than the
            available history, or the chosen model returns a forecast whose
            length differs from ``future_hours``.
    """
    # Validate up front so bad input fails with a clear message instead of
    # an empty plot or an obscure reshape/plotting error further down.
    if model not in ("SARIMA", "LSTM"):
        raise ValueError(
            f"Unsupported model {model!r}; expected 'SARIMA' or 'LSTM'."
        )
    try:
        future_hours = int(future_hours)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(
            f"future_hours must be a whole number, got {future_hours!r}."
        ) from err
    if future_hours < 1:
        raise ValueError(
            f"future_hours must be at least 1, got {future_hours}."
        )

    # Hourly timestamps following the last observation. The first generated
    # point is the last observation itself, hence the [1:]. A Timedelta is
    # used as the frequency because the 'H' string alias is deprecated in
    # recent pandas releases.
    future_datetimes = pd.date_range(
        start=webtraffic_data['Datetime'].iloc[-1],
        periods=future_hours + 1,
        freq=pd.Timedelta(hours=1),
    )[1:]

    if model == "SARIMA":
        predictions = sarima_model.forecast(steps=future_hours)
    else:
        sessions = webtraffic_data['Sessions'].values
        if future_hours > len(sessions):
            raise ValueError(
                f"LSTM needs {future_hours} past observations but only "
                f"{len(sessions)} are available."
            )
        # The most recent `future_hours` observations form the input window.
        # NOTE(review): this ties the window length to the requested horizon;
        # presumably the model was trained on a fixed window - confirm.
        lstm_input = sessions[-future_hours:].reshape(-1, 1)
        lstm_input_scaled = scaler.transform(lstm_input)
        # Shape (batch=1, timesteps, features=1) for the LSTM model.
        lstm_input_scaled = lstm_input_scaled.reshape(1, future_hours, 1)
        lstm_predictions = lstm_model.predict(lstm_input_scaled)
        # Map the model output back to the original 'Sessions' scale.
        predictions = scaler.inverse_transform(lstm_predictions)

    predictions = np.asarray(predictions).ravel()
    if len(predictions) != future_hours:
        raise ValueError(
            f"{model} returned {len(predictions)} prediction(s) for a "
            f"horizon of {future_hours} hour(s)."
        )

    plot_path = f"{model.lower()}_prediction_plot.png"

    # Draw on an explicit figure rather than pyplot's implicit "current
    # figure", which is global state shared between concurrent requests.
    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        ax.plot(
            webtraffic_data['Datetime'],
            webtraffic_data['Sessions'],
            label="Actual Data",
            color="blue",
        )
        ax.plot(
            future_datetimes,
            predictions,
            label=f"{model} Predictions",
            color="green",
        )
        ax.set_title(f"{model} Web Traffic Predictions", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return plot_path
# Gradio interface function
def prediction_dashboard(model, future_hours):
    """Gradio click handler: produce the forecast plot for the chosen model.

    Args:
        model: Model name selected in the dropdown.
        future_hours: Forecast horizon entered in the number field.

    Returns:
        Path of the saved plot image, as returned by
        ``generate_custom_prediction``.
    """
    return generate_custom_prediction(model, future_hours)
# Assemble the dashboard: model picker + horizon input -> forecast image.
with gr.Blocks() as dashboard:
    # Heading and a short usage hint.
    gr.Markdown("## Interactive Web Traffic Prediction Dashboard")
    gr.Markdown("Input the number of hours to predict and select a model for future web traffic forecasting.")

    # Inputs: which forecaster to run and how many hours ahead to predict.
    model_selection = gr.Dropdown(
        choices=["SARIMA", "LSTM"],
        value="SARIMA",
        label="Select Model",
    )
    future_hours_input = gr.Number(value=24, label="Future Hours to Predict")

    # Output: the rendered prediction plot.
    plot_output = gr.Image(label="Prediction Plot")

    # Clicking the button runs the handler with both inputs and shows the
    # returned image path in the output component.
    generate_button = gr.Button(value="Generate Prediction")
    generate_button.click(
        fn=prediction_dashboard,
        inputs=[model_selection, future_hours_input],
        outputs=[plot_output],
    )

# Start serving the dashboard.
dashboard.launch()