File size: 15,573 Bytes
b2108ae
e5e340b
 
d511c44
e5e340b
 
 
c0809e5
e5e340b
64e36b4
7cecffc
971fffe
7cecffc
e5e340b
 
 
 
 
 
340176f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5e340b
fbf37ad
8d3eb2c
fbf37ad
 
 
e5e340b
 
2b898e9
 
 
e5e340b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2108ae
e5e340b
 
 
 
 
 
 
 
 
 
a9b1c26
e5e340b
 
 
 
 
 
 
 
 
 
7cecffc
e5e340b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b898e9
a4e353d
2b898e9
a4e353d
8d3eb2c
c0f7916
8ba5467
e5e340b
2b898e9
0a98de7
08750ad
1d5eb19
2b898e9
08750ad
e5e340b
c0f7916
fbf37ad
e5e340b
 
3127dc9
e5e340b
 
 
3127dc9
e5e340b
d511c44
 
90869c2
d511c44
90869c2
d511c44
 
 
a4e353d
 
1258ec5
 
 
 
 
 
 
 
 
 
0b84d55
1258ec5
971fffe
106863f
1258ec5
 
 
d511c44
 
c0f7916
 
 
d511c44
1258ec5
106863f
3127dc9
1258ec5
 
d511c44
a752441
c0f7916
d511c44
b0365ad
 
c875dd7
 
64e36b4
b0365ad
 
1eb07a6
 
 
 
 
d511c44
c875dd7
 
1eb07a6
 
 
 
c875dd7
b0365ad
 
 
 
 
a3dc2e1
78d4c43
 
a3dc2e1
 
 
 
 
 
 
 
 
 
78d4c43
 
21479cb
 
 
 
 
d511c44
90869c2
d511c44
70a8bde
90869c2
1258ec5
 
 
 
 
 
 
 
 
 
0b84d55
1258ec5
971fffe
106863f
1258ec5
3127dc9
2b898e9
106863f
3127dc9
 
 
1258ec5
 
 
d511c44
c0f7916
 
 
7cecffc
c2b17c7
78d4c43
 
c0f7916
f932524
d837990
 
 
 
 
 
7cecffc
 
 
 
 
 
 
 
 
 
 
 
 
971fffe
7cecffc
f932524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971fffe
f932524
 
 
 
 
 
 
 
 
 
 
 
 
7cecffc
d837990
f932524
 
d837990
 
7cecffc
 
 
 
 
d511c44
7cecffc
d511c44
90869c2
 
7cecffc
 
 
f932524
 
7cecffc
d511c44
 
 
 
 
c0f7916
0b84d55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from neuralforecast.core import NeuralForecast
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
from neuralforecast.losses.pytorch import HuberMQLoss
from neuralforecast.utils import AirPassengersDF
import time
from st_aggrid import AgGrid
from nixtla import NixtlaClient
import os


@st.cache_resource
def load_model(path, freq):
    """Load a saved NeuralForecast object from ``path``.

    Args:
        path: Directory handed to ``NeuralForecast.load``.
        freq: Not read by the body; callers pass the frequency key that
            belongs to ``path``.

    Returns:
        Whatever ``NeuralForecast.load(path=path)`` returns.
    """
    return NeuralForecast.load(path=path)

@st.cache_resource
def load_all_models():
    """Load every saved model listed under ``./M4``.

    Returns:
        A 4-tuple ``(nhits, timesnet, lstm, tft)``. Each element is a dict
        keyed by frequency code (``'D'``, ``'M'``, ``'H'``, ``'W'``, ``'Y'``)
        whose values come from ``load_model``.
    """
    # One entry per model family; inner keys are the frequency codes.
    checkpoint_dirs = {
        'NHITS': {
            'D': './M4/NHITS/daily',
            'M': './M4/NHITS/monthly',
            'H': './M4/NHITS/hourly',
            'W': './M4/NHITS/weekly',
            'Y': './M4/NHITS/yearly',
        },
        'TimesNet': {
            'D': './M4/TimesNet/daily',
            'M': './M4/TimesNet/monthly',
            'H': './M4/TimesNet/hourly',
            'W': './M4/TimesNet/weekly',
            'Y': './M4/TimesNet/yearly',
        },
        'LSTM': {
            'D': './M4/LSTM/daily',
            'M': './M4/LSTM/monthly',
            'H': './M4/LSTM/hourly',
            'W': './M4/LSTM/weekly',
            'Y': './M4/LSTM/yearly',
        },
        'TFT': {
            'D': './M4/TFT/daily',
            'M': './M4/TFT/monthly',
            'H': './M4/TFT/hourly',
            'W': './M4/TFT/weekly',
            'Y': './M4/TFT/yearly',
        },
    }

    # Dicts keep insertion order, so models are loaded family by family.
    loaded = {
        family: {code: load_model(folder, code) for code, folder in by_freq.items()}
        for family, by_freq in checkpoint_dirs.items()
    }
    return loaded['NHITS'], loaded['TimesNet'], loaded['LSTM'], loaded['TFT']

def generate_forecast(model, df, tag=False):
    """Return ``model``'s predictions.

    Args:
        model: Object exposing a ``predict`` method.
        df: Data passed as ``predict(df=df)`` unless ``tag`` is ``'retrain'``.
        tag: When equal to ``'retrain'``, ``predict()`` is called with no
            arguments and ``df`` is ignored.

    Returns:
        The value returned by ``model.predict``.
    """
    if tag == 'retrain':
        return model.predict()
    return model.predict(df=df)

def determine_frequency(df):
    """Infer the frequency of the ``ds`` column with ``pd.infer_freq``.

    Side effect: ``df['ds']`` is overwritten in place with its
    ``pd.to_datetime`` conversion, so the caller's frame holds datetimes
    afterwards.

    Args:
        df: DataFrame with a ``ds`` column.

    Returns:
        The value returned by ``pd.infer_freq`` for the ``ds`` timestamps.
    """
    timestamps = pd.to_datetime(df['ds'])
    df['ds'] = timestamps
    return pd.infer_freq(pd.DatetimeIndex(timestamps))

def plot_forecasts(forecast_df, train_df, title):
    """Plot the history and forecast on one axis and render it via ``st.pyplot``.

    ``train_df`` and ``forecast_df`` are concatenated and indexed by ``ds``.
    The forecast line is the first column whose name contains ``'median'``;
    a shaded band is drawn when columns containing ``'lo-90'`` and ``'hi-90'``
    are both present.

    Args:
        forecast_df: Frame holding ``ds`` and the forecast columns.
        train_df: Frame holding ``ds`` and the historical ``y`` column.
        title: Text used as the axis title.

    Raises:
        KeyError: If no column name contains ``'median'``.
    """
    plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
    historical_col = 'y'
    forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
    lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
    hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
    # Validate before allocating a figure so the error path creates nothing.
    if forecast_col is None:
        raise KeyError("No forecast column found in the data.")

    fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    try:
        plot_df[[historical_col, forecast_col]].plot(ax=ax, linewidth=2, label=['Historical', 'Forecast'])
        if lo_col and hi_col:
            ax.fill_between(
                plot_df.index,
                plot_df[lo_col],
                plot_df[hi_col],
                color='blue',
                alpha=0.3,
                label='90% Confidence Interval'
            )
        ax.set_title(title, fontsize=22)
        ax.set_ylabel('Value', fontsize=20)
        ax.set_xlabel('Timestamp [t]', fontsize=20)
        ax.legend(prop={'size': 15})
        ax.grid()
        st.pyplot(fig)
    finally:
        # Figures from plt.subplots stay registered with pyplot until closed;
        # without this every rerun of the app leaves another figure open.
        plt.close(fig)

def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
    """Pick the model of each family that matches a pandas frequency string.

    The frequency is reduced to its base alias (anything after the first
    ``-`` such as ``-SUN`` or ``-DEC`` is dropped) and mapped to one of the
    dictionary keys ``'D'``, ``'M'``, ``'H'``, ``'W'``, ``'Y'``. Both the
    older and the newer pandas spellings are accepted (``M``/``ME``,
    ``H``/``h``, ``A``/``Y``/``YE``) so the lookup does not depend on the
    installed pandas version.

    Args:
        freq: Frequency string such as the result of ``pd.infer_freq``.
        nhits_models: Dict of NHITS models keyed by frequency code.
        timesnet_models: Dict of TimesNet models keyed by frequency code.
        lstm_models: Dict of LSTM models keyed by frequency code.
        tft_models: Dict of TFT models keyed by frequency code.

    Returns:
        Tuple ``(nhits, timesnet, lstm, tft)`` for the matching frequency.

    Raises:
        ValueError: If ``freq`` is not a string or has no supported alias.
    """
    alias_to_key = {
        'D': 'D',
        'M': 'M', 'ME': 'M', 'MS': 'M',
        'H': 'H', 'h': 'H',
        'W': 'W',
        'Y': 'Y', 'YE': 'Y', 'YS': 'Y', 'A': 'Y', 'AS': 'Y',
    }
    # Strip the anchor suffix ("W-SUN" -> "W", "YE-DEC" -> "YE").
    base = freq.split('-', 1)[0] if isinstance(freq, str) else None
    key = alias_to_key.get(base)
    if key is None:
        raise ValueError(f"Unsupported frequency: {freq}")
    return nhits_models[key], timesnet_models[key], lstm_models[key], tft_models[key]

def select_model(horizon, model_type, max_steps=50):
    """Construct a model of the requested type for ``horizon`` steps.

    Args:
        horizon: Forecast length, passed as ``h``; also used to size
            ``input_size``.
        model_type: One of ``'NHITS'``, ``'TimesNet'``, ``'LSTM'``, ``'TFT'``.
        max_steps: Passed through as the model's ``max_steps``.

    Returns:
        A new ``NHITS``, ``TimesNet``, ``LSTM`` or ``TFT`` instance, each
        using ``HuberMQLoss(level=[90])``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    def build_nhits():
        return NHITS(
            input_size=5 * horizon,
            h=horizon,
            max_steps=max_steps,
            stack_types=3*['identity'],
            n_blocks=3*[1],
            mlp_units=[[256, 256] for _ in range(3)],
            n_pool_kernel_size=3*[1],
            batch_size=32,
            scaler_type='standard',
            n_freq_downsample=[12, 4, 1],
            loss=HuberMQLoss(level=[90]),
        )

    def build_timesnet():
        return TimesNet(
            h=horizon,
            input_size=horizon * 5,
            hidden_size=16,
            conv_hidden_size=32,
            loss=HuberMQLoss(level=[90]),
            scaler_type='standard',
            learning_rate=1e-3,
            max_steps=max_steps,
            val_check_steps=200,
            valid_batch_size=64,
            windows_batch_size=128,
            inference_windows_batch_size=512,
        )

    def build_lstm():
        return LSTM(
            h=horizon,
            input_size=horizon * 5,
            loss=HuberMQLoss(level=[90]),
            scaler_type='standard',
            encoder_n_layers=2,
            encoder_hidden_size=64,
            context_size=10,
            decoder_hidden_size=64,
            decoder_layers=2,
            max_steps=max_steps,
        )

    def build_tft():
        return TFT(
            h=horizon,
            input_size=horizon,
            hidden_size=16,
            loss=HuberMQLoss(level=[90]),
            learning_rate=0.005,
            scaler_type='standard',
            windows_batch_size=128,
            max_steps=max_steps,
            val_check_steps=200,
            valid_batch_size=64,
            enable_progress_bar=True,
        )

    # Checked in order with ==; only the matching builder is ever called.
    recipes = (
        ('NHITS', build_nhits),
        ('TimesNet', build_timesnet),
        ('LSTM', build_lstm),
        ('TFT', build_tft),
    )
    for name, build in recipes:
        if model_type == name:
            return build()
    raise ValueError(f"Unsupported model type: {model_type}")

def model_train(df, model, freq):
    """Fit ``model`` on ``df`` inside a new ``NeuralForecast`` wrapper.

    Side effect: ``df['ds']`` is converted in place with ``pd.to_datetime``
    before fitting.

    Args:
        df: Training frame with a ``ds`` column.
        model: Model instance placed in the wrapper's ``models`` list.
        freq: Frequency passed to ``NeuralForecast``.

    Returns:
        The ``NeuralForecast`` object after ``fit`` has been called.
    """
    forecaster = NeuralForecast(models=[model], freq=freq)
    df['ds'] = pd.to_datetime(df['ds'])
    forecaster.fit(df)
    return forecaster

def forecast_time_series(df, model_type, horizon, max_steps, y_col):
    """Train a fresh model on ``df``, plot its forecast and report the time.

    Args:
        df: Frame passed to ``determine_frequency``, ``model_train`` and
            ``plot_forecasts``.
        model_type: Model name forwarded to ``select_model``.
        horizon: Forecast length forwarded to ``select_model``.
        max_steps: Training step limit forwarded to ``select_model``.
        y_col: Target column name, used only in the plot title.
    """
    started_at = time.time()

    freq = determine_frequency(df)
    st.sidebar.write(f"Data frequency: {freq}")

    trained = model_train(df, select_model(horizon, model_type, max_steps), freq)

    st.sidebar.write(f"Generating forecast using {model_type} model...")
    # 'retrain' makes generate_forecast call predict() without a frame.
    forecast_df = generate_forecast(trained, df, tag='retrain')
    plot_forecasts(forecast_df, df, f'{model_type} Forecast for {y_col}')

    # The reported duration covers training, prediction and plotting.
    elapsed = time.time() - started_at
    st.success(f"Time taken for {model_type} forecast: {elapsed:.2f} seconds")

@st.cache_data
def load_default():
    """Return a copy of ``AirPassengersDF``, used when no file is uploaded."""
    return AirPassengersDF.copy()

def transfer_learning_forecasting():
    """Render the "Transfer Learning Forecasting" page.

    The sidebar collects a CSV (falling back to ``load_default()``), the
    date/target columns and a model family. The data is reshaped to the
    ``unique_id`` / ``ds`` / ``y`` layout, its frequency is detected, and
    on "Submit" the saved model for that family and frequency produces
    the forecast that is plotted in the right-hand column.

    Writes ``df``, ``ds_col`` and ``y_col`` to ``st.session_state``.
    """
    st.title("Transfer Learning Forecasting")

    nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        # Column selection
        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        # The target defaults to 'y'; defaulting to 'ds' would select the
        # date column twice and break the rename below.
        y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 0)

    st.session_state.ds_col = ds_col
    st.session_state.y_col = y_col

    # Model selection and forecasting
    st.sidebar.subheader("Model Selection and Forecasting")
    model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
    # NOTE(review): this value is collected but never passed to the saved
    # models; the widget is kept so the sidebar layout is unchanged.
    st.sidebar.number_input("Forecast horizon", value=18)

    if ds_col == y_col:
        # Renaming one column to both 'ds' and 'y' is impossible; stop here
        # with a readable message instead of a KeyError further down.
        st.sidebar.error("Date/Time and Target must be different columns.")
        return

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 1
    df = df[['unique_id', 'ds', 'y']]
    st.session_state.df = df

    # Determine frequency of data (also converts df['ds'] to datetimes).
    frequency = determine_frequency(df)
    st.sidebar.write(f"Detected frequency: {frequency}")

    col1, col2 = st.columns([2, 4])
    with col1:
        # The "Forecast" tab is created but nothing is rendered in it.
        tab_insample, _tab_forecast = st.tabs(["Input data", "Forecast"])
        with tab_insample:
            st.write(df.drop(columns="unique_id"))

    with col2:
        # Pick the saved models that match the detected frequency.
        nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(
            frequency, nhits_models, timesnet_models, lstm_models, tft_models
        )
        models_by_name = {
            "NHITS": nhits_model,
            "TimesNet": timesnet_model,
            "LSTM": lstm_model,
            "TFT": tft_model,
        }

        if st.sidebar.button("Submit"):
            start_time = time.time()  # Start timing
            forecast_df = generate_forecast(models_by_name[model_choice], df)
            plot_forecasts(forecast_df, df, f'{model_choice} Forecast for {y_col}')

            time_taken = time.time() - start_time
            st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")


def dynamic_forecasting():
    """Render the "Dynamic Forecasting" page.

    The sidebar collects a CSV (falling back to ``load_default()``), the
    date/target columns, a model family, the forecast horizon and the
    number of training steps. On "Submit" the data, reshaped to
    ``unique_id`` / ``ds`` / ``y``, is handed to ``forecast_time_series``.

    Writes ``df``, ``ds_col`` and ``y_col`` to ``st.session_state``.
    """
    st.title("Dynamic Forecasting")
    st.subheader("Speed depends on CPU/GPU availability", divider="gray")

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        # Column selection
        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        # The target defaults to 'y'; defaulting to 'ds' would select the
        # date column twice and break the rename below.
        y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 0)

    st.session_state.ds_col = ds_col
    st.session_state.y_col = y_col

    # Dynamic forecasting
    st.sidebar.subheader("Dynamic Model Selection and Forecasting")
    dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
    dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=18)
    dynamic_max_steps = st.sidebar.number_input('Max steps', value=10)

    if ds_col == y_col:
        # Renaming one column to both 'ds' and 'y' is impossible; stop here
        # with a readable message instead of a KeyError further down.
        st.sidebar.error("Date/Time and Target must be different columns.")
        return

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 1
    df = df[['unique_id', 'ds', 'y']]
    st.session_state.df = df

    if st.sidebar.button("Submit"):
        forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps, y_col)

def timegpt_fcst():
    """Render the "TimeGPT Forecasting" page.

    Builds a ``NixtlaClient`` from the ``NIXTLA_API_KEY`` environment
    variable, collects a CSV (falling back to ``load_default()``) and the
    date/target columns in the sidebar, requests a 7-step forecast with a
    90 level and renders the client's plot of history plus forecast.

    Writes the loaded ``df`` to ``st.session_state``.
    """
    # The key is read from the environment. The original passed an undefined
    # name ``api_key`` here, which raised NameError on every page load.
    nixtla_client = NixtlaClient(api_key=os.environ.get("NIXTLA_API_KEY"))

    st.title("TimeGPT Forecasting")
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        # Column selection
        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        # The target defaults to 'y'; defaulting to 'ds' would select the
        # date column twice and break the rename below.
        y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 0)

    if ds_col == y_col:
        st.sidebar.error("Date/Time and Target must be different columns.")
        return

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 'ts_test'
    # Also converts df['ds'] to datetimes in place.
    freq = determine_frequency(df)

    forecast_df = nixtla_client.forecast(
        df=df,
        h=7,
        freq=freq,
        level=[90]
    )

    # plot() takes the historical frame first and the forecasts second.
    # Its return value is assumed to be a Matplotlib figure (default engine)
    # and is shown in the main area, outside the sidebar expander.
    fig = nixtla_client.plot(
        df,
        forecast_df,
        level=[90],
        max_insample_length=365
    )
    st.pyplot(fig)

def timegpt_anom():
    """Render the TimeGPT anomaly-detection page.

    Builds a ``NixtlaClient`` from the ``NIXTLA_API_KEY`` environment
    variable, collects a CSV (falling back to ``load_default()``) and the
    date/target columns in the sidebar, runs ``detect_anomalies`` with a
    90 level and renders the client's anomaly plot.

    Writes the loaded ``df`` to ``st.session_state``.
    """
    # The key is read from the environment. The original passed an undefined
    # name ``api_key`` here, which raised NameError on every page load.
    nixtla_client = NixtlaClient(api_key=os.environ.get("NIXTLA_API_KEY"))

    # This page is registered as anomaly detection, so the title and the API
    # call no longer duplicate the forecasting page.
    st.title("TimeGPT Anomaly Detection")
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        # Column selection
        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        # The target defaults to 'y'; defaulting to 'ds' would select the
        # date column twice and break the rename below.
        y_col = st.selectbox("Select Target column", options=columns, index=columns.index('y') if 'y' in columns else 0)

    if ds_col == y_col:
        st.sidebar.error("Date/Time and Target must be different columns.")
        return

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 'ts_test'
    # Also converts df['ds'] to datetimes in place.
    freq = determine_frequency(df)

    # NOTE(review): level=90 mirrors the interval used on the forecast page;
    # confirm the desired sensitivity.
    anomalies_df = nixtla_client.detect_anomalies(
        df=df,
        freq=freq,
        level=90
    )

    # plot() takes the historical frame first and the detection output second.
    # Its return value is assumed to be a Matplotlib figure (default engine)
    # and is shown in the main area, outside the sidebar expander.
    fig = nixtla_client.plot(
        df,
        anomalies_df,
        plot_anomalies=True,
        max_insample_length=365
    )
    st.pyplot(fig)
        
    
    
    

# Pages are plain functions; each section below becomes a group in the
# navigation menu.
_neuralforecast_pages = [
    st.Page(transfer_learning_forecasting, title="Transfer Learning Forecasting", default=True, icon=":material/query_stats:"),
    st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
]
_timegpt_pages = [
    st.Page(timegpt_fcst, title="TimeGPT Forecast", icon=":material/smart_toy:"),
    st.Page(timegpt_anom, title="TimeGPT Anomalies Detection", icon=":material/detector_offline:"),
]

pg = st.navigation({
    "NeuralForecast": _neuralforecast_pages,
    "TimeGPT": _timegpt_pages,
})

# Any exception raised while running the selected page is shown in the
# sidebar instead of propagating.
try:
    pg.run()
except Exception as e:
    st.sidebar.error(f"Something went wrong: {e}", icon=":material/error:")