File size: 6,371 Bytes
c4cf758
 
 
 
2f82680
 
c4cf758
2f82680
 
f8d0f44
 
2f82680
 
 
 
 
 
 
 
f8d0f44
2f82680
f8d0f44
2f82680
f8d0f44
 
 
 
 
2f82680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4cf758
2f82680
 
 
 
 
 
c4cf758
2f82680
 
f69d4dd
2f82680
f8d0f44
2f82680
f8d0f44
 
 
c4cf758
2f82680
 
 
 
 
 
 
f8d0f44
2f82680
f8d0f44
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import joblib
from sklearn.metrics import mean_absolute_error, mean_squared_error
from math import sqrt

# Step 1: Load the Dataset
# The CSV is read once at import time into the module-level
# `webtraffic_data` frame that the plotting callbacks read later.
print("Loading Dataset...")
data_file = "webtraffic.csv"

try:
    webtraffic_data = pd.read_csv(data_file)
except Exception as e:
    print(f"Error loading dataset: {e}")
    # Exit with a non-zero status so callers can detect the failure;
    # raising SystemExit avoids relying on the site-provided `exit()` helper.
    raise SystemExit(1) from e
else:
    print("Dataset loaded successfully!")

# Step 2: Make sure a datetime-typed 'Datetime' column exists, then order rows by it
if "Datetime" in webtraffic_data.columns:
    # Column is already present: coerce its values to datetimes.
    webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
else:
    print("Datetime column missing. Creating from 'Hour Index'.")
    # Build timestamps by adding 'Hour Index' (interpreted as a number of
    # hours) to a fixed starting timestamp.
    start_date = pd.Timestamp("2024-01-01 00:00:00")
    hour_offsets = pd.to_timedelta(webtraffic_data["Hour Index"], unit="h")
    webtraffic_data["Datetime"] = start_date + hour_offsets

# Chronological order matters: later code takes the last rows via .iloc.
webtraffic_data.sort_values("Datetime", inplace=True)

# Step 3: Load SARIMA Model
# The fitted model is loaded once at import time into the module-level
# `sarima_model` object that the plotting callbacks call `predict` on.
print("Loading SARIMA Model...")
try:
    # NOTE(security): joblib.load unpickles the file, and unpickling can
    # execute arbitrary code. Only load model files from a trusted source.
    sarima_model = joblib.load("sarima_model.pkl")
except Exception as e:
    print(f"Error loading SARIMA model: {e}")
    # Exit with a non-zero status so callers can detect the failure;
    # raising SystemExit avoids relying on the site-provided `exit()` helper.
    raise SystemExit(1) from e
else:
    print("SARIMA model loaded successfully!")

# Step 4: Define Functions for Gradio Dashboard
# Forecast horizon in hours (two days of hourly steps).
future_periods = 2 * 24

def generate_sarima_plot():
    """Forecast the next ``future_periods`` hours, plot them and report metrics.

    Reads the module-level ``webtraffic_data``, ``sarima_model`` and
    ``future_periods``. The full "Sessions" history and the forecast are
    drawn on one figure, which is saved to ``sarima_prediction_plot.png``.

    Returns:
        A ``(plot_path, metrics_text)`` tuple. If anything fails, the error
        is printed and ``(None, <generic error message>)`` is returned so the
        Gradio callback never raises.
    """
    figure = None
    try:
        # Forecast timestamps start one hour after the last observation.
        # An explicit Timedelta step is used instead of the "H" alias, which
        # newer pandas releases deprecate.
        one_hour = pd.Timedelta(hours=1)
        last_timestamp = webtraffic_data["Datetime"].iloc[-1]
        future_dates = pd.date_range(
            start=last_timestamp + one_hour,
            periods=future_periods,
            freq=one_hour,
        )

        # Generate SARIMA predictions
        # (presumably a pmdarima-style model, given the n_periods keyword).
        sarima_predictions = sarima_model.predict(n_periods=future_periods)

        # Extract actual data for the last 'future_periods' hours
        actual_sessions = webtraffic_data["Sessions"].iloc[-future_periods:].values

        # NOTE(review): the forecast is plotted on timestamps *after* the last
        # observation, while `actual_sessions` holds the last *observed* hours.
        # The metrics are only meaningful if the model was fitted on data that
        # excludes those hours -- confirm against the training script.
        # The slice guards against a history shorter than the horizon.
        matched_predictions = sarima_predictions[:len(actual_sessions)]
        mae_sarima = mean_absolute_error(actual_sessions, matched_predictions)
        rmse_sarima = sqrt(mean_squared_error(actual_sessions, matched_predictions))

        # Combine predictions into a DataFrame for plotting
        future_predictions = pd.DataFrame({
            "Datetime": future_dates,
            "SARIMA_Predicted": sarima_predictions
        })

        # Draw on an explicit Figure/Axes rather than pyplot's implicit
        # "current figure", so the callback does not depend on global state.
        figure, axes = plt.subplots(figsize=(15, 6))
        axes.plot(
            webtraffic_data["Datetime"],
            webtraffic_data["Sessions"],
            label="Actual Traffic",
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        axes.plot(
            future_predictions["Datetime"],
            future_predictions["SARIMA_Predicted"],
            label="SARIMA Predicted",
            color="blue",
            linewidth=2,
        )

        axes.set_title("SARIMA Predictions vs Actual Traffic", fontsize=16)
        axes.set_xlabel("Datetime", fontsize=12)
        axes.set_ylabel("Sessions", fontsize=12)
        axes.legend(loc="upper left")
        axes.grid(True)
        figure.tight_layout()

        # Save the plot
        plot_path = "sarima_prediction_plot.png"
        figure.savefig(plot_path)

        # Return plot path and metrics
        metrics = f"""
        SARIMA Model Metrics:
        - Mean Absolute Error (MAE): {mae_sarima:.2f}
        - Root Mean Squared Error (RMSE): {rmse_sarima:.2f}
        """
        return plot_path, metrics

    except Exception as e:
        print(f"Error generating SARIMA plot: {e}")
        return None, "Error in generating output. Please check the data and model."

    finally:
        # Always release the figure, including when plotting or saving fails;
        # otherwise every failed click would leave a figure open.
        if figure is not None:
            plt.close(figure)

def generate_zoomed_plot():
    """Plot the forecast next to only the most recent observed hours.

    Reads the module-level ``webtraffic_data``, ``sarima_model`` and
    ``future_periods``. The last ``future_periods`` rows of history and the
    ``future_periods``-hour forecast are drawn on one figure, which is saved
    to ``sarima_zoomed_plot.png``.

    Returns:
        The saved image path, or ``None`` if anything fails (the error is
        printed so the Gradio callback never raises).
    """
    figure = None
    try:
        # Forecast timestamps start one hour after the last observation.
        # An explicit Timedelta step is used instead of the "H" alias, which
        # newer pandas releases deprecate.
        one_hour = pd.Timedelta(hours=1)
        last_timestamp = webtraffic_data["Datetime"].iloc[-1]
        future_dates = pd.date_range(
            start=last_timestamp + one_hour,
            periods=future_periods,
            freq=one_hour,
        )

        # Generate SARIMA predictions
        # (presumably a pmdarima-style model, given the n_periods keyword).
        sarima_predictions = sarima_model.predict(n_periods=future_periods)

        # Combine predictions into a DataFrame for plotting
        future_predictions = pd.DataFrame({
            "Datetime": future_dates,
            "SARIMA_Predicted": sarima_predictions
        })

        # Zoomed-in view: restrict the history to its most recent rows.
        recent_datetimes = webtraffic_data["Datetime"].iloc[-future_periods:]
        recent_sessions = webtraffic_data["Sessions"].iloc[-future_periods:]

        # Draw on an explicit Figure/Axes rather than pyplot's implicit
        # "current figure", so the callback does not depend on global state.
        figure, axes = plt.subplots(figsize=(15, 6))
        axes.plot(
            recent_datetimes,
            recent_sessions,
            label="Actual Traffic (Zoomed)",
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        axes.plot(
            future_predictions["Datetime"],
            future_predictions["SARIMA_Predicted"],
            label="SARIMA Predicted (Zoomed)",
            color="green",
            linewidth=2,
        )

        axes.set_title("Zoomed-In SARIMA Predictions vs Actual Traffic", fontsize=16)
        axes.set_xlabel("Datetime", fontsize=12)
        axes.set_ylabel("Sessions", fontsize=12)
        axes.legend(loc="upper left")
        axes.grid(True)
        figure.tight_layout()

        # Save the zoomed plot
        zoomed_plot_path = "sarima_zoomed_plot.png"
        figure.savefig(zoomed_plot_path)

        return zoomed_plot_path

    except Exception as e:
        print(f"Error generating zoomed plot: {e}")
        return None

    finally:
        # Always release the figure, including when plotting or saving fails;
        # otherwise every failed click would leave a figure open.
        if figure is not None:
            plt.close(figure)

# Step 5: Gradio dashboard -- two image tiles, a metrics box and two buttons
with gr.Blocks() as dashboard:
    # Header text
    gr.Markdown("## Enhanced SARIMA Web Traffic Prediction Dashboard")
    gr.Markdown("This dashboard includes SARIMA predictions, performance metrics, and a zoomed-in view of recent data.")

    # Output components, created in the order they appear on the page:
    # main plot, metrics text, then the zoomed plot.
    plot_output = gr.Image(label="SARIMA Prediction Plot")
    metrics_output = gr.Textbox(label="Model Metrics", lines=6)
    zoomed_plot_output = gr.Image(label="Zoomed-In Prediction Plot")

    # Full-history plot plus metrics
    predict_button = gr.Button("Generate Predictions")
    predict_button.click(
        fn=generate_sarima_plot,
        inputs=[],
        outputs=[plot_output, metrics_output],
    )

    # Recent-data-only plot
    zoom_button = gr.Button("Generate Zoomed-In Plot")
    zoom_button.click(
        fn=generate_zoomed_plot,
        inputs=[],
        outputs=[zoomed_plot_output],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    print("\nLaunching Enhanced Gradio Dashboard...")
    dashboard.launch()