File size: 18,599 Bytes
b2108ae
e5e340b
 
d511c44
e5e340b
 
 
c0809e5
e5e340b
64e36b4
7cecffc
971fffe
7cecffc
4d6d97a
 
e5e340b
 
 
 
 
340176f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5e340b
fbf37ad
8d3eb2c
fbf37ad
 
 
e5e340b
 
2b898e9
 
6355b2b
2b898e9
53cd5ff
34960fe
 
53cd5ff
34960fe
 
53cd5ff
 
34960fe
 
 
 
 
 
 
e5e340b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2108ae
e5e340b
 
 
 
 
 
 
 
 
 
a9b1c26
e5e340b
 
 
 
 
 
 
 
 
 
7cecffc
e5e340b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b898e9
a4e353d
2b898e9
a4e353d
8d3eb2c
c0f7916
8ba5467
e5e340b
2b898e9
0a98de7
08750ad
1d5eb19
65275be
2b898e9
08750ad
e5e340b
fbf37ad
e5e340b
 
3127dc9
e5e340b
 
 
3127dc9
e5e340b
d511c44
 
90869c2
d511c44
90869c2
d511c44
932e273
d13e84f
 
 
d511c44
a4e353d
 
1258ec5
932e273
 
 
 
 
 
 
d3eff15
932e273
1258ec5
d3eff15
932e273
 
 
 
 
 
542509f
932e273
b895dd3
932e273
 
0b84d55
a709b08
1258ec5
a709b08
27e1b78
 
084abe0
ba185a4
 
a709b08
1258ec5
 
 
d511c44
 
c0f7916
 
65275be
d511c44
1258ec5
106863f
3127dc9
1258ec5
 
d511c44
a752441
c0f7916
d511c44
 
5616e81
 
 
b0365ad
78d4c43
5616e81
 
 
 
 
 
 
 
 
 
0e2e1e6
 
5616e81
81f859b
5616e81
 
 
 
 
8993619
 
21479cb
752c404
8993619
 
 
 
 
 
 
 
 
 
 
 
 
0e2e1e6
 
 
 
 
 
 
21479cb
d511c44
90869c2
d511c44
d13e84f
 
 
 
 
90869c2
1258ec5
 
 
 
 
 
 
 
 
0b84d55
a709b08
1258ec5
a709b08
27e1b78
 
084abe0
a709b08
1258ec5
3127dc9
2b898e9
106863f
3127dc9
 
 
1258ec5
 
 
d511c44
c0f7916
 
65275be
 
c2b17c7
78d4c43
 
c0f7916
f932524
d837990
6cb4dbe
 
 
 
d837990
7cecffc
d13e84f
 
 
7cecffc
 
 
 
 
 
 
 
 
6cb4dbe
 
7cecffc
6cb4dbe
 
 
084abe0
6cb4dbe
4d6d97a
6cb4dbe
f932524
 
6cb4dbe
f932524
6cb4dbe
 
 
 
 
 
 
 
c7e43d5
6cb4dbe
1210377
 
6355b2b
6cb4dbe
 
 
4d6d97a
6cb4dbe
 
 
4d6d97a
6cb4dbe
4d6d97a
 
 
6cb4dbe
65275be
f932524
 
 
 
c755f1c
f932524
 
 
d13e84f
7181a20
d13e84f
 
f932524
 
 
 
 
 
 
 
 
 
a709b08
f932524
a709b08
27e1b78
 
084abe0
a709b08
f932524
 
 
 
 
da23fc7
 
 
 
 
65275be
7c0f996
1210377
 
65275be
 
 
 
 
 
4d6d97a
7cecffc
4d6d97a
 
 
7cecffc
 
 
 
 
d511c44
932e273
d511c44
932e273
90869c2
7cecffc
 
 
f932524
 
7cecffc
d511c44
 
d066a0e
0b84d55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from neuralforecast.core import NeuralForecast
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
from neuralforecast.losses.pytorch import HuberMQLoss
from neuralforecast.utils import AirPassengersDF
import time
from st_aggrid import AgGrid
from nixtla import NixtlaClient
import os

# Wide layout: let every page of the app use the full browser width.
st.set_page_config(layout='wide')
    
@st.cache_resource
def load_model(path, freq):
    """Load a saved NeuralForecast object from the directory ``path``.

    ``freq`` is not read by the loader; it only takes part in the
    ``st.cache_resource`` cache key.  NOTE(review): confirm it is still needed.
    """
    return NeuralForecast.load(path=path)

@st.cache_resource
def load_all_models():
    """Load every pre-trained M4 checkpoint, grouped by architecture.

    Checkpoints live under ``./M4/<Architecture>/<period>``.

    Returns:
        A 4-tuple ``(nhits, timesnet, lstm, tft)`` of dicts, each mapping a
        frequency code ('D', 'M', 'H', 'W', 'Y') to the loaded model.
    """
    period_dirs = (
        ('D', 'daily'),
        ('M', 'monthly'),
        ('H', 'hourly'),
        ('W', 'weekly'),
        ('Y', 'yearly'),
    )
    architectures = ('NHITS', 'TimesNet', 'LSTM', 'TFT')

    loaded = []
    for arch in architectures:
        by_code = {}
        for code, period in period_dirs:
            by_code[code] = load_model(f'./M4/{arch}/{period}', code)
        loaded.append(by_code)
    return tuple(loaded)

def generate_forecast(model, df, tag=False):
    """Return ``model.predict`` output for the given history.

    Args:
        model: a fitted NeuralForecast object.
        df: history passed to ``predict(df=...)``; ignored when ``tag`` is
            ``'retrain'``.
        tag: ``'retrain'`` means the model was just fit in this session, so
            ``predict`` is called with no ``df`` argument.
    """
    use_fitted_data = tag == 'retrain'
    return model.predict() if use_fitted_data else model.predict(df=df)

def determine_frequency(df):
    """Infer the pandas frequency alias of the ``ds`` column of ``df``.

    Side effect: ``df['ds']`` is converted to datetime in place.  Callers in
    this module keep using ``df`` afterwards, so the conversion is deliberate.

    Falls back to daily ('D') with a Streamlit warning when no regular
    frequency can be inferred, including when fewer than three distinct
    timestamps are available.

    Args:
        df: DataFrame with a ``ds`` column parseable by ``pd.to_datetime``.

    Returns:
        A pandas frequency alias string such as 'D' or 'W-SUN'.
    """
    df['ds'] = pd.to_datetime(df['ds'])
    # Unique, ascending timestamps.  Without sorting, rows in descending
    # order infer a negative alias (e.g. '-1D') and shuffled rows infer None.
    dates = pd.DatetimeIndex(df['ds'].drop_duplicates()).sort_values()

    try:
        freq = pd.infer_freq(dates)
    except ValueError:
        # infer_freq requires at least three timestamps.
        freq = None

    if not freq:
        st.warning('The forecast will use default Daily forecast due to date inconsistency. Please check your data.',icon="⚠️")
        freq = 'D'

    return freq

def plot_forecasts(forecast_df, train_df, title):
    """Render history, forecast and (when present) the 90% band in Streamlit.

    Args:
        forecast_df: forecast frame with a ``ds`` column and a column whose
            name contains 'median'; columns containing 'lo-90' / 'hi-90'
            are drawn as a shaded interval when both exist.
        train_df: history frame with ``ds`` and ``y`` columns.
        title: plot title.

    Raises:
        KeyError: if no column name contains 'median'.
    """
    plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
    historical_col = 'y'
    forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
    lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
    hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)
    # Validate before creating the figure so a failure does not leave one open.
    if forecast_col is None:
        raise KeyError("No forecast column found in the data.")

    fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    try:
        # Renaming the columns fixes the legend text without relying on how
        # the installed pandas version treats a list passed as ``label``.
        lines_df = plot_df[[historical_col, forecast_col]].rename(
            columns={historical_col: 'Historical', forecast_col: 'Forecast'}
        )
        lines_df.plot(ax=ax, linewidth=2)
        if lo_col and hi_col:
            ax.fill_between(
                plot_df.index,
                plot_df[lo_col],
                plot_df[hi_col],
                color='blue',
                alpha=0.3,
                label='90% Confidence Interval'
            )
        ax.set_title(title, fontsize=22)
        ax.set_ylabel('Value', fontsize=20)
        ax.set_xlabel('Timestamp [t]', fontsize=20)
        ax.legend(prop={'size': 15})
        ax.grid()
        st.pyplot(fig)
    finally:
        # Streamlit reruns the script on every interaction; pyplot keeps
        # every unclosed figure alive, so close it once it has been rendered.
        plt.close(fig)

def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
    """Pick the pre-trained model of each architecture matching ``freq``.

    Args:
        freq: pandas frequency alias as returned by ``pd.infer_freq``.
        nhits_models, timesnet_models, lstm_models, tft_models: dicts keyed by
            'D', 'M', 'H', 'W', 'Y'.

    Returns:
        ``(nhits, timesnet, lstm, tft)`` models for the matching period.

    Raises:
        ValueError: if ``freq`` is not one of the supported periods.
    """
    # pandas spells the same period differently across versions: >= 2.2
    # infers 'h', 'ME' and 'YE-DEC'; older releases 'H', 'M' and
    # 'A-DEC' / 'Y-DEC'.  Accept every spelling of the supported periods.
    model_key_by_alias = {
        'D': 'D',
        'M': 'M', 'ME': 'M',
        'H': 'H', 'h': 'H',
        'W': 'W', 'W-SUN': 'W',
        'Y': 'Y', 'Y-DEC': 'Y',
        'YE': 'Y', 'YE-DEC': 'Y',
        'A': 'Y', 'A-DEC': 'Y',
    }
    key = model_key_by_alias.get(freq) if isinstance(freq, str) else None
    if key is None:
        raise ValueError(f"Unsupported frequency: {freq}")
    return nhits_models[key], timesnet_models[key], lstm_models[key], tft_models[key]

def select_model(horizon, model_type, max_steps=50):
    """Instantiate an untrained forecasting model by architecture name.

    Every architecture uses ``HuberMQLoss(level=[90])`` and a standard scaler.

    Args:
        horizon: forecast length ``h``; also sizes the input window
            (5x horizon for NHITS / TimesNet / LSTM, 1x for TFT).
        model_type: one of 'NHITS', 'TimesNet', 'LSTM', 'TFT'.
        max_steps: number of training steps.

    Raises:
        ValueError: for any other ``model_type``.
    """
    def build_nhits():
        return NHITS(h=horizon,
                     input_size=5 * horizon,
                     max_steps=max_steps,
                     stack_types=['identity'] * 3,
                     n_blocks=[1] * 3,
                     mlp_units=[[256, 256] for _ in range(3)],
                     n_pool_kernel_size=[1] * 3,
                     batch_size=32,
                     scaler_type='standard',
                     n_freq_downsample=[12, 4, 1],
                     loss=HuberMQLoss(level=[90]))

    def build_timesnet():
        return TimesNet(h=horizon,
                        input_size=5 * horizon,
                        hidden_size=16,
                        conv_hidden_size=32,
                        loss=HuberMQLoss(level=[90]),
                        scaler_type='standard',
                        learning_rate=1e-3,
                        max_steps=max_steps,
                        val_check_steps=200,
                        valid_batch_size=64,
                        windows_batch_size=128,
                        inference_windows_batch_size=512)

    def build_lstm():
        return LSTM(h=horizon,
                    input_size=5 * horizon,
                    loss=HuberMQLoss(level=[90]),
                    scaler_type='standard',
                    encoder_n_layers=2,
                    encoder_hidden_size=64,
                    context_size=10,
                    decoder_hidden_size=64,
                    decoder_layers=2,
                    max_steps=max_steps)

    def build_tft():
        return TFT(h=horizon,
                   input_size=horizon,
                   hidden_size=16,
                   loss=HuberMQLoss(level=[90]),
                   learning_rate=0.005,
                   scaler_type='standard',
                   windows_batch_size=128,
                   max_steps=max_steps,
                   val_check_steps=200,
                   valid_batch_size=64,
                   enable_progress_bar=True)

    factories = (
        ('NHITS', build_nhits),
        ('TimesNet', build_timesnet),
        ('LSTM', build_lstm),
        ('TFT', build_tft),
    )
    for name, build in factories:
        if model_type == name:
            return build()
    raise ValueError(f"Unsupported model type: {model_type}")

def model_train(df,model, freq):
    """Fit ``model`` on ``df`` inside a fresh NeuralForecast wrapper.

    Side effect: ``df['ds']`` is converted to datetime in place.

    Returns:
        The fitted NeuralForecast object.
    """
    forecaster = NeuralForecast(models=[model], freq=freq)
    df['ds'] = pd.to_datetime(df['ds'])
    forecaster.fit(df=df)
    return forecaster

def forecast_time_series(df, model_type, horizon, max_steps,y_col):
    """Train ``model_type`` on ``df``, plot its forecast and report timing.

    Args:
        df: frame with ``unique_id``, ``ds`` and ``y`` columns (as prepared by
            dynamic_forecasting).
        model_type: architecture name accepted by ``select_model``.
        horizon: forecast length.
        max_steps: training steps.
        y_col: original target column name, used only in the plot title.
    """
    start_time = time.perf_counter()  # monotonic clock for elapsed time
    freq = determine_frequency(df)
    st.sidebar.write(f"Data frequency: {freq}")

    selected_model = select_model(horizon, model_type, max_steps)
    # st.spinner is a context manager: it is only shown while the block runs.
    # Called bare (as before), it never appeared.
    with st.spinner(f"Training {model_type} model..."):
        model = model_train(df, selected_model, freq)

    forecast_df = generate_forecast(model, df, tag='retrain')
    plot_forecasts(forecast_df, df, f'{model_type} Forecast for {y_col}')

    time_taken = time.perf_counter() - start_time
    st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")

@st.cache_data
def load_default():
    """Return a private copy of the bundled AirPassengers sample dataset."""
    return AirPassengersDF.copy()

def transfer_learning_forecasting():
    """Streamlit page: forecast with pre-trained M4 models (no training).

    The user uploads a CSV (or uses the AirPassengers sample), picks the date
    and target columns, and forecasts with the pre-trained NHITS / TimesNet /
    LSTM / TFT model that matches the detected frequency.  Results are kept
    in ``st.session_state.forecast_results`` so the data tabs survive reruns.
    """
    st.title("Zero-shot Forecasting")
    st.markdown("""
    Instant time series forecasting and visualization by using various pre-trained deep neural network-based model trained on M4 data.
    """)

    nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        if 'uploaded_file' not in st.session_state:
            uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
            if uploaded_file:
                df = pd.read_csv(uploaded_file)
                st.session_state.df = df
                st.session_state.uploaded_file = uploaded_file
            else:
                df = load_default()
                st.session_state.df = df
        else:
            # A file was uploaded on an earlier run: reuse the stored frame
            # unless the user explicitly asks to upload another one.
            if st.checkbox("Upload a new file (CSV)"):
                uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
                if uploaded_file:
                    df = pd.read_csv(uploaded_file)
                    st.session_state.df = df
                    st.session_state.uploaded_file = uploaded_file
                else:
                    df = st.session_state.df
            else:
                df = st.session_state.df

        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        if 'ds' in columns and 'unique_id' in columns:
            columns.pop(columns.index('ds'))
            columns.pop(columns.index('unique_id'))
        # Never offer the chosen date column (or a literal 'ds') as the target.
        # Picking it collapsed the rename below to a single key, dropping 'ds'
        # and crashing on the column selection.
        opt = [col for col in columns if col not in (ds_col, 'ds')]
        y_col = st.selectbox("Select Target column", options=opt, index=0)

    # An empty option list makes the selectbox return None.
    if y_col is None:
        st.error("The dataset needs at least one target column besides the date column.")
        return

    st.session_state.ds_col = ds_col
    st.session_state.y_col = y_col

    # Model selection and forecasting
    st.sidebar.subheader("Model Selection and Forecasting")
    model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
    horizon = st.sidebar.number_input("Forecast horizon", value=12)

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id']=1
    df = df[['unique_id','ds','y']]
    st.session_state.df = df

    # Determine frequency of data (also converts df['ds'] to datetime in place)
    frequency = determine_frequency(df)
    st.sidebar.write(f"Detected frequency: {frequency}")

    nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
    forecast_results = {}

    if st.sidebar.button("Submit"):
        start_time = time.time()  # Start timing
        if model_choice == "NHITS":
            forecast_results['NHITS'] = generate_forecast(nhits_model, df)
        elif model_choice == "TimesNet":
            forecast_results['TimesNet'] = generate_forecast(timesnet_model, df)
        elif model_choice == "LSTM":
            forecast_results['LSTM'] = generate_forecast(lstm_model, df)
        elif model_choice == "TFT":
            forecast_results['TFT'] = generate_forecast(tft_model, df)

        st.session_state.forecast_results = forecast_results
        for model_name, forecast_df in forecast_results.items():
            # Only the first `horizon` predicted rows are plotted.
            plot_forecasts(forecast_df.iloc[:horizon,:], df, f'{model_name} Forecast for {y_col}')

        end_time = time.time()  # End timing
        time_taken = end_time - start_time
        st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")

    if 'forecast_results' in st.session_state:
        forecast_results = st.session_state.forecast_results

        st.markdown('You can download Input and Forecast Data below')
        tab_insample, tab_forecast  = st.tabs(
                        ["Input data", "Forecast"]
                    )

        with tab_insample:
            df_grid = df.drop(columns="unique_id")
            st.write(df_grid)

        with tab_forecast:
            # Results persist across reruns; show them only for the model
            # that is currently selected.
            if model_choice in forecast_results:
                df_grid = forecast_results[model_choice]
                st.write(df_grid)


def dynamic_forecasting():
    """Streamlit page: train a model from scratch on the user's data.

    The user uploads a CSV (or falls back to the AirPassengers sample), picks
    the date and target columns, the architecture, horizon and training
    steps.  Pressing Submit trains and plots via ``forecast_time_series``.
    """
    st.title("Dynamic Forecasting")
    st.markdown("""
    Train time series forecasting model from scratch and provide forecasts/visualization by using various deep neural network-based model trained on user data.
    
    Forecasting speed depends on CPU/GPU availabilty.
    """)

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        if uploaded_file:
            df = pd.read_csv(uploaded_file)
        else:
            df = load_default()
        st.session_state.df = df

        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        if 'ds' in columns and 'unique_id' in columns:
            columns.pop(columns.index('ds'))
            columns.pop(columns.index('unique_id'))
        # Exclude the chosen date column from the targets.  Selecting it (the
        # previous default for a CSV such as `date,value`) collapsed the rename
        # below to one key and the ['unique_id','ds','y'] selection raised.
        opt = [col for col in columns if col != ds_col]
        y_col = st.selectbox("Select Target column", options=opt, index=0)

    # An empty option list makes the selectbox return None.
    if y_col is None:
        st.error("The dataset needs at least one target column besides the date column.")
        return

    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id']=1
    df = df[['unique_id','ds','y']]
    st.session_state.df = df

    st.session_state.ds_col = ds_col
    st.session_state.y_col = y_col

    # Dynamic forecasting
    st.sidebar.subheader("Dynamic Model Selection and Forecasting")
    dynamic_model_choice = st.sidebar.selectbox("Select model for dynamic forecasting", ["NHITS", "TimesNet", "LSTM", "TFT"], key="dynamic_model_choice")
    dynamic_horizon = st.sidebar.number_input("Forecast horizon", value=12)
    dynamic_max_steps = st.sidebar.number_input('Max steps', value=20)

    if st.sidebar.button("Submit"):
        forecast_time_series(df, dynamic_model_choice, dynamic_horizon, dynamic_max_steps, y_col)

def timegpt_fcst():
    """Streamlit page: forecast the user's series with Nixtla's TimeGPT API.

    The API key is read from the ``NIXTLA_API_KEY`` environment variable.
    The latest forecast is kept in ``st.session_state.forecast_df`` and
    plotted on every rerun.
    """
    nixtla_token = os.environ.get("NIXTLA_API_KEY")
    nixtla_client = NixtlaClient(api_key=nixtla_token)

    st.title("TimeGPT Forecasting")
    st.markdown("""
    Instant time series forecasting and visualization by using the TimeGPT API provided by Nixtla.
    """)
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        if uploaded_file:
            df = pd.read_csv(uploaded_file)
        else:
            df = load_default()
        st.session_state.df = df

        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        if 'ds' in columns and 'unique_id' in columns:
            columns.pop(columns.index('ds'))
            columns.pop(columns.index('unique_id'))
        # Exclude the chosen date column from the targets.  Selecting it
        # collapsed the rename below to one key and the column selection raised.
        opt = [col for col in columns if col != ds_col]
        y_col = st.selectbox("Select Target column", options=opt, index=0)
        # An empty option list makes the selectbox return None.
        if y_col is None:
            st.error("The dataset needs at least one target column besides the date column.")
            return
        h = st.number_input("Forecast horizon", value=14)

        df = df.rename(columns={ds_col: 'ds', y_col: 'y'})

        id_col = 'ts_test'
        df['unique_id']=id_col
        df = df[['unique_id','ds','y']]
        st.session_state.df = df

        st.session_state.ds_col = ds_col
        st.session_state.y_col = y_col

        # Also converts df['ds'] to datetime in place.
        freq = determine_frequency(df)

        df = df.drop_duplicates(subset=['ds']).reset_index(drop=True)

        if st.sidebar.button("Submit"):
            forecast_df = nixtla_client.forecast(
                df=df,
                h=h,
                freq=freq,
                level=[90]
            )
            st.session_state.forecast_df = forecast_df

    if 'forecast_df' in st.session_state:
        forecast_df = st.session_state.forecast_df
        st.pyplot(nixtla_client.plot(df, forecast_df, level=[90]))



def timegpt_anom():
    """Streamlit page: detect anomalies in the user's series with TimeGPT.

    The API key is read from the ``NIXTLA_API_KEY`` environment variable.
    The latest result is kept in ``st.session_state.anom_df`` and plotted on
    every rerun.
    """
    nixtla_token = os.environ.get("NIXTLA_API_KEY")
    nixtla_client = NixtlaClient(api_key=nixtla_token)

    st.title("TimeGPT Anomaly Detection")
    st.markdown("""
    Instant time series anomaly detection and visualization by using the TimeGPT API provided by Nixtla.
    """)
    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        if uploaded_file:
            df = pd.read_csv(uploaded_file)
        else:
            df = load_default()
        st.session_state.df = df

        columns = df.columns.tolist()  # Convert Index to list
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=columns.index('ds') if 'ds' in columns else 0)
        if 'ds' in columns and 'unique_id' in columns:
            columns.pop(columns.index('ds'))
            columns.pop(columns.index('unique_id'))
        # Exclude the chosen date column from the targets.  Selecting it
        # collapsed the rename below to one key and the column selection raised.
        opt = [col for col in columns if col != ds_col]
        y_col = st.selectbox("Select Target column", options=opt, index=0)
        # An empty option list makes the selectbox return None.
        if y_col is None:
            st.error("The dataset needs at least one target column besides the date column.")
            return

        df = df.rename(columns={ds_col: 'ds', y_col: 'y'})

        id_col = 'ts_test'
        df['unique_id']=id_col
        df = df[['unique_id','ds','y']]
        st.session_state.df = df

        st.session_state.ds_col = ds_col
        st.session_state.y_col = y_col

        # Also converts df['ds'] to datetime in place.
        freq = determine_frequency(df)

        df = df.drop_duplicates(subset=['ds']).reset_index(drop=True)
        if st.sidebar.button("Submit"):
            anom_df = nixtla_client.detect_anomalies(
                df=df,
                freq=freq,
                level=90
            )
            st.session_state.anom_df = anom_df

    if 'anom_df' in st.session_state:
        anom_df = st.session_state.anom_df
        st.pyplot(nixtla_client.plot(df, anom_df))
        
    
    
    

# Sidebar navigation: two sections, each backed by a page function above.
neuralforecast_pages = [
    st.Page(transfer_learning_forecasting, title="Zero-shot Forecasting", default=True, icon=":material/query_stats:"),
    st.Page(dynamic_forecasting, title="Dynamic Forecasting", icon=":material/monitoring:"),
]
timegpt_pages = [
    st.Page(timegpt_fcst, title="TimeGPT Forecast", icon=":material/smart_toy:"),
    st.Page(timegpt_anom, title="TimeGPT Anomalies Detection", icon=":material/detector_offline:"),
]

pg = st.navigation({
    "Neuralforecast": neuralforecast_pages,
    "TimeGPT": timegpt_pages,
})
pg.run()