File size: 222,369 Bytes
f4355cb |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "86ad0e30",
"metadata": {},
"outputs": [],
"source": [
"# tutorial url\n",
"# https://huggingface.co/blog/time-series-transformers"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4357ee0e",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a7009beb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset monash_tsf (C:/Users/yozhan/.cache/huggingface/datasets/monash_tsf/tourism_monthly/1.0.0/fc869f3ae1577c9def2a919ab1dd0c3d4a7a44826b8e0e8fa423bb0161b629e2)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "553772b851a041aeb196651882da5fbe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dataset = load_dataset(\"monash_tsf\", \"tourism_monthly\")"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "27005a68",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['start', 'target', 'feat_static_cat', 'feat_dynamic_real', 'item_id'])\n"
]
}
],
"source": [
"train_example = dataset[\"train\"][1]\n",
"print(train_example.keys())"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "9415bae0",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1979-01\n",
"[65072.37109375, 48612.19921875, 58452.58984375, 57033.96875, 71498.953125, 79187.109375, 101896.1015625, 115971.796875, 94962.1484375, 80648.328125, 64196.078125, 50364.859375, 57624.05859375, 47163.87109375, 48874.0703125, 62737.609375, 69621.1328125, 71454.21875, 107916.796875, 120461.5, 99441.1796875, 84936.5390625, 62809.51953125, 54028.48046875, 58605.91015625, 50516.33984375, 55711.5390625, 55798.41015625, 65033.1796875, 89421.140625, 119027.8984375, 133411.296875, 112890.703125, 96718.140625, 76462.796875, 57951.6796875, 62094.69140625, 55118.23046875, 66128.3515625, 71334.2578125, 75644.71875, 98380.4296875, 127255.0, 146442.703125, 121934.796875, 88537.546875, 71126.1484375, 80209.5078125, 72614.40625, 64114.46875, 70382.5, 77124.5390625, 85675.6796875, 110282.1015625, 135361.296875, 165238.90625, 139575.0, 113179.796875, 92898.84375, 69113.3984375, 74665.15625, 70439.328125, 72906.4609375, 79968.9609375, 114617.0, 123295.703125, 158611.703125, 187202.40625, 145933.703125, 142136.0, 104437.6015625, 77719.5390625, 87537.671875, 74318.6484375, 78690.7421875, 93133.796875, 107493.0, 115749.5, 170180.0, 182643.703125, 150198.40625, 144580.09375, 88946.34375, 79351.4609375, 93303.1171875, 83155.078125, 84729.3984375, 118132.703125, 109726.8984375, 136382.296875, 184945.203125, 183298.09375, 182025.203125, 154070.09375, 107512.6015625, 95252.90625, 100485.796875, 85406.546875, 99865.203125, 104420.0, 131736.796875, 151818.703125, 204872.0, 200877.90625, 172959.40625, 143267.203125, 100896.703125, 91154.3671875, 100889.296875, 105025.5, 103881.0, 117237.1015625, 125309.8984375, 145201.703125, 175963.296875, 220512.40625, 179558.0, 145964.90625, 109480.0, 88714.75, 104875.703125, 92488.0, 103057.296875, 105123.703125, 130371.296875, 139400.0, 162441.09375, 198709.203125, 154935.703125, 151893.296875, 76786.90625, 79548.8984375, 109728.203125, 88898.578125, 82404.8984375, 101708.703125, 112285.3984375, 106262.3984375, 155264.0]\n"
]
}
],
"source": [
"print(train_example[\"start\"])\n",
"print(train_example[\"target\"])"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "7c5d086c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['start', 'target', 'feat_static_cat', 'feat_dynamic_real', 'item_id'])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"validation_example = dataset['validation'][1]\n",
"validation_example.keys()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "dad75154",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1979-01-01 00:00:00\n",
"[65072.37109375, 48612.19921875, 58452.58984375, 57033.96875, 71498.953125, 79187.109375, 101896.1015625, 115971.796875, 94962.1484375, 80648.328125, 64196.078125, 50364.859375, 57624.05859375, 47163.87109375, 48874.0703125, 62737.609375, 69621.1328125, 71454.21875, 107916.796875, 120461.5, 99441.1796875, 84936.5390625, 62809.51953125, 54028.48046875, 58605.91015625, 50516.33984375, 55711.5390625, 55798.41015625, 65033.1796875, 89421.140625, 119027.8984375, 133411.296875, 112890.703125, 96718.140625, 76462.796875, 57951.6796875, 62094.69140625, 55118.23046875, 66128.3515625, 71334.2578125, 75644.71875, 98380.4296875, 127255.0, 146442.703125, 121934.796875, 88537.546875, 71126.1484375, 80209.5078125, 72614.40625, 64114.46875, 70382.5, 77124.5390625, 85675.6796875, 110282.1015625, 135361.296875, 165238.90625, 139575.0, 113179.796875, 92898.84375, 69113.3984375, 74665.15625, 70439.328125, 72906.4609375, 79968.9609375, 114617.0, 123295.703125, 158611.703125, 187202.40625, 145933.703125, 142136.0, 104437.6015625, 77719.5390625, 87537.671875, 74318.6484375, 78690.7421875, 93133.796875, 107493.0, 115749.5, 170180.0, 182643.703125, 150198.40625, 144580.09375, 88946.34375, 79351.4609375, 93303.1171875, 83155.078125, 84729.3984375, 118132.703125, 109726.8984375, 136382.296875, 184945.203125, 183298.09375, 182025.203125, 154070.09375, 107512.6015625, 95252.90625, 100485.796875, 85406.546875, 99865.203125, 104420.0, 131736.796875, 151818.703125, 204872.0, 200877.90625, 172959.40625, 143267.203125, 100896.703125, 91154.3671875, 100889.296875, 105025.5, 103881.0, 117237.1015625, 125309.8984375, 145201.703125, 175963.296875, 220512.40625, 179558.0, 145964.90625, 109480.0, 88714.75, 104875.703125, 92488.0, 103057.296875, 105123.703125, 130371.296875, 139400.0, 162441.09375, 198709.203125, 154935.703125, 151893.296875, 76786.90625, 79548.8984375, 109728.203125, 88898.578125, 82404.8984375, 101708.703125, 112285.3984375, 106262.3984375, 155264.0, 170488.703125, 132573.703125, 
132368.203125, 102293.8984375, 77060.4609375, 105801.703125, 70909.859375, 71229.46875, 99842.4375, 106268.8984375, 102494.3984375, 163841.796875, 177779.296875, 146222.296875, 143864.203125, 95231.2890625, 87472.6875, 98428.953125, 80790.8984375, 90682.2109375, 97428.0078125, 115606.6015625, 130511.796875, 175490.203125]\n"
]
}
],
"source": [
"print(validation_example['start'])\n",
"print(validation_example['target'])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "df44a30f",
"metadata": {},
"outputs": [],
"source": [
"freq = \"1M\"\n",
"prediction_length = 24\n",
"\n",
"assert len(train_example[\"target\"]) + prediction_length == len(\n",
" validation_example[\"target\"]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ef6e21e5",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"figure, axes = plt.subplots()\n",
"axes.plot(train_example[\"target\"], color=\"blue\")\n",
"axes.plot(validation_example[\"target\"], color=\"red\", alpha=0.5)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2e02e6de",
"metadata": {},
"outputs": [],
"source": [
"train_dataset = dataset[\"train\"]\n",
"test_dataset = dataset[\"test\"]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "030b3d23",
"metadata": {},
"outputs": [],
"source": [
"# lru_cache is a memoizing decorator: it stores a function's return value per argument tuple,\n",
"# so repeated calls with the same arguments reuse the cached result instead of recomputing it\n",
"\n",
"from functools import lru_cache\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"\n",
"@lru_cache(maxsize=10_000)\n",
"def convert_to_pandas_period(date, freq):\n",
"    \"\"\"Return ``pd.Period(date, freq)``; results are memoized per ``(date, freq)`` pair.\"\"\"\n",
"    period = pd.Period(date, freq=freq)\n",
"    return period\n",
"\n",
"def transform_start_field(batch, freq):\n",
"    \"\"\"Convert each entry of ``batch['start']`` to a ``pd.Period`` with frequency ``freq``.\n",
"\n",
"    The ``start`` list of ``batch`` is replaced in place and the same ``batch`` object is returned.\n",
"    \"\"\"\n",
"    periods = []\n",
"    for raw_start in batch[\"start\"]:\n",
"        periods.append(convert_to_pandas_period(raw_start, freq))\n",
"    batch[\"start\"] = periods\n",
"    return batch"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "965bdb44",
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"train_dataset.set_transform(partial(transform_start_field, freq=freq))\n",
"test_dataset.set_transform(partial(transform_start_field, freq=freq))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f162ef59",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'We specify a couple of additional parameters to the model:\\n\\nprediction_length (in our case, 24 months): this is the horizon that the decoder of the Transformer will learn to predict for;\\ncontext_length: the model will set the context_length (input of the encoder) \\n equal to the prediction_length, if no context_length is specified;\\n \\nlags for a given frequency: these specify how much we \"look back\", to be added as \\n additional features. e.g. for a Daily frequency we might consider \\n a look back of [1, 2, 7, 30, ...] or in other words look back 1, 2, ... \\n days while for Minute data we might consider [1, 30, 60, 60*24, ...] etc.;\\n \\nthe number of time features: in our case, this will be 2 as we\\'ll add MonthOfYear and Age features;\\n\\nthe number of static categorical features: in our case, this will be just 1 as we\\'ll add a single \"time series ID\" feature;\\n\\nthe cardinality: the number of values of each static categorical feature, \\n as a list which for our case will be [366] as we have 366 different time series\\n \\nthe embedding dimension: the embedding dimension for each static categorical feature, \\n as a list, for example [3] means the model will learn an embedding \\n vector of size 3 for each of the 366 time series (regions). '"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"We specify a couple of additional parameters to the model:\n",
"\n",
"prediction_length (in our case, 24 months): this is the horizon that the decoder of the Transformer will learn to predict for;\n",
"context_length: the model will set the context_length (input of the encoder) \n",
" equal to the prediction_length, if no context_length is specified;\n",
" \n",
"lags for a given frequency: these specify how much we \"look back\", to be added as \n",
" additional features. e.g. for a Daily frequency we might consider \n",
" a look back of [1, 2, 7, 30, ...] or in other words look back 1, 2, ... \n",
" days while for Minute data we might consider [1, 30, 60, 60*24, ...] etc.;\n",
" \n",
"the number of time features: in our case, this will be 2 as we'll add MonthOfYear and Age features;\n",
"\n",
"the number of static categorical features: in our case, this will be just 1 as we'll add a single \"time series ID\" feature;\n",
"\n",
"the cardinality: the number of values of each static categorical feature, \n",
" as a list which for our case will be [366] as we have 366 different time series\n",
" \n",
"the embedding dimension: the embedding dimension for each static categorical feature, \n",
" as a list, for example [3] means the model will learn an embedding \n",
" vector of size 3 for each of the 366 time series (regions). \"\"\""
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "09f1f194",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 23, 24, 25, 35, 36, 37]\n"
]
}
],
"source": [
"from gluonts.time_feature import get_lags_for_frequency\n",
"\n",
"lags_sequence = get_lags_for_frequency(freq)\n",
"print(lags_sequence)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "ccd2aa3f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[<function month_of_year at 0x000002DFA39B55A0>]\n"
]
}
],
"source": [
"from gluonts.time_feature import time_features_from_frequency_str\n",
"\n",
"time_features = time_features_from_frequency_str(freq)\n",
"print(time_features)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "c84c2203",
"metadata": {},
"outputs": [],
"source": [
"from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerForPrediction\n",
"\n",
"config = TimeSeriesTransformerConfig(\n",
" prediction_length=prediction_length,\n",
" # context length:\n",
" context_length=prediction_length * 2,\n",
" # lags coming from helper given the freq:\n",
" lags_sequence=lags_sequence,\n",
" # we'll add 2 time features (\"month of year\" and \"age\", see further):\n",
" num_time_features=len(time_features) + 1,\n",
" # we have a single static categorical feature, namely time series ID:\n",
" num_static_categorical_features=1,\n",
" # it has 366 possible values:\n",
" cardinality=[len(train_dataset)],\n",
" # the model will learn an embedding of size 2 for each of the 366 possible values:\n",
" embedding_dimension=[2],\n",
" \n",
" # transformer params:\n",
" encoder_layers=4,\n",
" decoder_layers=4,\n",
" d_model=32,\n",
" \n",
")\n",
"\n",
"model = TimeSeriesTransformerForPrediction(config)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "76c9c724",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\note that, similar to other models in the 🤗 Transformers library, \\nTimeSeriesTransformerModel corresponds to the encoder-decoder Transformer without any head on top, \\nand TimeSeriesTransformerForPrediction corresponds to TimeSeriesTransformerModel with a distribution head on top. \\nBy default, the model uses a Student-t distribution (but this is configurable):\\n'"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"Note that, similar to other models in the 🤗 Transformers library, \n",
"TimeSeriesTransformerModel corresponds to the encoder-decoder Transformer without any head on top, \n",
"and TimeSeriesTransformerForPrediction corresponds to TimeSeriesTransformerModel with a distribution head on top. \n",
"By default, the model uses a Student-t distribution (but this is configurable):\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e97434e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'student_t'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.config.distribution_output"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "83e665f0",
"metadata": {},
"outputs": [],
"source": [
"from gluonts.time_feature import (\n",
" time_features_from_frequency_str,\n",
" TimeFeature,\n",
" get_lags_for_frequency,\n",
")\n",
"from gluonts.dataset.field_names import FieldName\n",
"from gluonts.transform import (\n",
" AddAgeFeature,\n",
" AddObservedValuesIndicator,\n",
" AddTimeFeatures,\n",
" AsNumpyArray,\n",
" Chain,\n",
" ExpectedNumInstanceSampler,\n",
" InstanceSplitter,\n",
" RemoveFields,\n",
" SelectFields,\n",
" SetField,\n",
" TestSplitSampler,\n",
" Transformation,\n",
" ValidationSplitSampler,\n",
" VstackFeatures,\n",
" RenameFields,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e3d6ce89",
"metadata": {},
"outputs": [],
"source": [
"from transformers import PretrainedConfig\n",
"\n",
"def create_transformation(freq: str, config: PretrainedConfig) -> Transformation:\n",
"    \"\"\"Build the GluonTS preprocessing chain applied to every raw time series.\n",
"\n",
"    Args:\n",
"        freq: frequency string of the dataset, used to derive the calendar time features.\n",
"        config: model configuration; its feature counts, ``input_size`` and\n",
"            ``prediction_length`` decide which steps are included and how they are set up.\n",
"\n",
"    Returns:\n",
"        A ``Chain`` that removes unused feature fields, converts fields to NumPy arrays,\n",
"        adds the observed-values mask plus time and age features, stacks the temporal\n",
"        features and renames the fields to the names used for the model inputs.\n",
"    \"\"\"\n",
"    # feature fields the model is not configured to consume are dropped up front\n",
"    remove_field_names = []\n",
"    if config.num_static_real_features == 0:\n",
"        remove_field_names.append(FieldName.FEAT_STATIC_REAL)\n",
"    if config.num_dynamic_real_features == 0:\n",
"        remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)\n",
"    if config.num_static_categorical_features == 0:\n",
"        remove_field_names.append(FieldName.FEAT_STATIC_CAT)\n",
"\n",
"    # step 1: remove static/dynamic fields if not specified\n",
"    steps = [RemoveFields(field_names=remove_field_names)]\n",
"\n",
"    # step 2: convert the data to NumPy (potentially not needed)\n",
"    if config.num_static_categorical_features > 0:\n",
"        steps.append(\n",
"            AsNumpyArray(\n",
"                field=FieldName.FEAT_STATIC_CAT,\n",
"                expected_ndim=1,\n",
"                dtype=int,\n",
"            )\n",
"        )\n",
"    if config.num_static_real_features > 0:\n",
"        steps.append(\n",
"            AsNumpyArray(\n",
"                field=FieldName.FEAT_STATIC_REAL,\n",
"                expected_ndim=1,\n",
"            )\n",
"        )\n",
"    steps.append(\n",
"        AsNumpyArray(\n",
"            field=FieldName.TARGET,\n",
"            # we expect an extra dim for the multivariate case:\n",
"            expected_ndim=1 if config.input_size == 1 else 2,\n",
"        )\n",
"    )\n",
"\n",
"    # step 3: handle the NaN's by filling in the target with zero\n",
"    # and return the mask (which is in the observed values)\n",
"    # true for observed values, false for nan's\n",
"    # the decoder uses this mask (no loss is incurred for unobserved values)\n",
"    # see loss_weights inside the xxxForPrediction model\n",
"    steps.append(\n",
"        AddObservedValuesIndicator(\n",
"            target_field=FieldName.TARGET,\n",
"            output_field=FieldName.OBSERVED_VALUES,\n",
"        )\n",
"    )\n",
"\n",
"    # step 4: add temporal features based on freq of the dataset\n",
"    # month of year in the case when freq=\"M\"\n",
"    # these serve as positional encodings\n",
"    steps.append(\n",
"        AddTimeFeatures(\n",
"            start_field=FieldName.START,\n",
"            target_field=FieldName.TARGET,\n",
"            output_field=FieldName.FEAT_TIME,\n",
"            time_features=time_features_from_frequency_str(freq),\n",
"            pred_length=config.prediction_length,\n",
"        )\n",
"    )\n",
"\n",
"    # step 5: add another temporal feature (just a single number)\n",
"    # tells the model where in its life the value of the time series is,\n",
"    # sort of a running counter\n",
"    steps.append(\n",
"        AddAgeFeature(\n",
"            target_field=FieldName.TARGET,\n",
"            output_field=FieldName.FEAT_AGE,\n",
"            pred_length=config.prediction_length,\n",
"            log_scale=True,\n",
"        )\n",
"    )\n",
"\n",
"    # step 6: vertically stack all the temporal features into the key FEAT_TIME\n",
"    stacked_fields = [FieldName.FEAT_TIME, FieldName.FEAT_AGE]\n",
"    if config.num_dynamic_real_features > 0:\n",
"        stacked_fields.append(FieldName.FEAT_DYNAMIC_REAL)\n",
"    steps.append(\n",
"        VstackFeatures(\n",
"            output_field=FieldName.FEAT_TIME,\n",
"            input_fields=stacked_fields,\n",
"        )\n",
"    )\n",
"\n",
"    # step 7: rename to match HuggingFace names\n",
"    steps.append(\n",
"        RenameFields(\n",
"            mapping={\n",
"                FieldName.FEAT_STATIC_CAT: \"static_categorical_features\",\n",
"                FieldName.FEAT_STATIC_REAL: \"static_real_features\",\n",
"                FieldName.FEAT_TIME: \"time_features\",\n",
"                FieldName.TARGET: \"values\",\n",
"                FieldName.OBSERVED_VALUES: \"observed_mask\",\n",
"            }\n",
"        )\n",
"    )\n",
"\n",
"    # a bit like torchvision.transforms.Compose\n",
"    return Chain(steps)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "73c8d1fc",
"metadata": {},
"outputs": [],
"source": [
"from gluonts.transform.sampler import InstanceSampler\n",
"from typing import Optional\n",
"\n",
"def create_instance_splitter(\n",
"    config: PretrainedConfig,\n",
"    mode: str,\n",
"    train_sampler: Optional[InstanceSampler] = None,\n",
"    validation_sampler: Optional[InstanceSampler] = None,\n",
") -> Transformation:\n",
"    \"\"\"Build the ``InstanceSplitter`` that cuts past/future windows for the given mode.\n",
"\n",
"    Args:\n",
"        config: model configuration providing ``context_length``, ``lags_sequence``\n",
"            and ``prediction_length``.\n",
"        mode: one of ``\"train\"``, ``\"validation\"`` or ``\"test\"``; selects the sampler.\n",
"        train_sampler: optional sampler used in train mode instead of the default.\n",
"        validation_sampler: optional sampler used in validation mode instead of the default.\n",
"\n",
"    Returns:\n",
"        An ``InstanceSplitter`` with a past window of ``context_length + max(lags_sequence)``\n",
"        steps and a future window of ``prediction_length`` steps.\n",
"    \"\"\"\n",
"    assert mode in [\"train\", \"validation\", \"test\"]\n",
"\n",
"    # a caller-supplied sampler wins; otherwise fall back to the default for that mode\n",
"    train_choice = train_sampler or ExpectedNumInstanceSampler(\n",
"        num_instances=1.0, min_future=config.prediction_length\n",
"    )\n",
"    validation_choice = validation_sampler or ValidationSplitSampler(\n",
"        min_future=config.prediction_length\n",
"    )\n",
"    test_choice = TestSplitSampler()\n",
"\n",
"    samplers_by_mode = {\n",
"        \"train\": train_choice,\n",
"        \"validation\": validation_choice,\n",
"        \"test\": test_choice,\n",
"    }\n",
"    instance_sampler = samplers_by_mode[mode]\n",
"\n",
"    # the past window must also cover the largest lag that is looked up\n",
"    past_length = config.context_length + max(config.lags_sequence)\n",
"\n",
"    return InstanceSplitter(\n",
"        target_field=\"values\",\n",
"        is_pad_field=FieldName.IS_PAD,\n",
"        start_field=FieldName.START,\n",
"        forecast_start_field=FieldName.FORECAST_START,\n",
"        instance_sampler=instance_sampler,\n",
"        past_length=past_length,\n",
"        future_length=config.prediction_length,\n",
"        time_series_fields=[\"time_features\", \"observed_mask\"],\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "82d5a50c",
"metadata": {},
"outputs": [],
"source": [
"from typing import Iterable\n",
"\n",
"import torch\n",
"from gluonts.itertools import Cached, Cyclic\n",
"from gluonts.dataset.loader import as_stacked_batches\n",
"\n",
"\n",
"def create_train_dataloader(\n",
"    config: PretrainedConfig,\n",
"    freq,\n",
"    data,\n",
"    batch_size: int,\n",
"    num_batches_per_epoch: int,\n",
"    shuffle_buffer_length: Optional[int] = None,\n",
"    cache_data: bool = True,\n",
"    **kwargs,\n",
") -> Iterable:\n",
"    \"\"\"Create the batched stream of training windows.\n",
"\n",
"    Args:\n",
"        config: model configuration; decides which static feature fields are batched.\n",
"        freq: dataset frequency, forwarded to ``create_transformation``.\n",
"        data: the dataset to transform and sample windows from.\n",
"        batch_size: number of windows per batch.\n",
"        num_batches_per_epoch: number of batches produced per pass over the loader.\n",
"        shuffle_buffer_length: forwarded to ``as_stacked_batches``.\n",
"        cache_data: when true, the transformed dataset is wrapped in ``Cached``.\n",
"        **kwargs: accepted but not used by this function.\n",
"\n",
"    Returns:\n",
"        The iterable returned by ``as_stacked_batches``, yielding ``torch`` tensors.\n",
"    \"\"\"\n",
"    input_names = [\n",
"        \"past_time_features\",\n",
"        \"past_values\",\n",
"        \"past_observed_mask\",\n",
"        \"future_time_features\",\n",
"    ]\n",
"    if config.num_static_categorical_features > 0:\n",
"        input_names.append(\"static_categorical_features\")\n",
"    if config.num_static_real_features > 0:\n",
"        input_names.append(\"static_real_features\")\n",
"    # training additionally needs the future targets and their mask\n",
"    input_names.extend([\"future_values\", \"future_observed_mask\"])\n",
"\n",
"    transformed_data = create_transformation(freq, config).apply(data, is_train=True)\n",
"    if cache_data:\n",
"        transformed_data = Cached(transformed_data)\n",
"\n",
"    # we initialize a Training instance splitter\n",
"    instance_splitter = create_instance_splitter(config, \"train\")\n",
"\n",
"    # the splitter samples windows of context length + lags + prediction length\n",
"    # randomly from within the (endlessly cycled) transformed time series\n",
"    cyclic_stream = Cyclic(transformed_data).stream()\n",
"    training_instances = instance_splitter.apply(cyclic_stream, is_train=True)\n",
"\n",
"    return as_stacked_batches(\n",
"        training_instances,\n",
"        batch_size=batch_size,\n",
"        shuffle_buffer_length=shuffle_buffer_length,\n",
"        field_names=input_names,\n",
"        output_type=torch.tensor,\n",
"        num_batches_per_epoch=num_batches_per_epoch,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "4b3f4fb2",
"metadata": {},
"outputs": [],
"source": [
"def create_test_dataloader(\n",
"    config: PretrainedConfig,\n",
"    freq,\n",
"    data,\n",
"    batch_size: int,\n",
"    **kwargs,\n",
"):\n",
"    \"\"\"Create the batched loader of prediction inputs for the test split.\n",
"\n",
"    Args:\n",
"        config: model configuration; decides which static feature fields are batched.\n",
"        freq: dataset frequency, forwarded to ``create_transformation``.\n",
"        data: the dataset to transform and split.\n",
"        batch_size: number of windows per batch.\n",
"        **kwargs: accepted but not used by this function.\n",
"\n",
"    Returns:\n",
"        The iterable returned by ``as_stacked_batches``, yielding ``torch`` tensors\n",
"        restricted to the prediction input fields (no future values).\n",
"    \"\"\"\n",
"    input_names = [\n",
"        \"past_time_features\",\n",
"        \"past_values\",\n",
"        \"past_observed_mask\",\n",
"        \"future_time_features\",\n",
"    ]\n",
"    if config.num_static_categorical_features > 0:\n",
"        input_names.append(\"static_categorical_features\")\n",
"    if config.num_static_real_features > 0:\n",
"        input_names.append(\"static_real_features\")\n",
"\n",
"    transformed_data = create_transformation(freq, config).apply(data, is_train=False)\n",
"\n",
"    # we create a Test Instance splitter which will sample the very last\n",
"    # context window seen during training only for the encoder.\n",
"    instance_splitter = create_instance_splitter(config, \"test\")\n",
"\n",
"    # we apply the transformations in test mode\n",
"    testing_instances = instance_splitter.apply(transformed_data, is_train=False)\n",
"\n",
"    return as_stacked_batches(\n",
"        testing_instances,\n",
"        batch_size=batch_size,\n",
"        output_type=torch.tensor,\n",
"        field_names=input_names,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "2cbf8ec7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['feat_static_real', 'feat_dynamic_real']\n",
"['feat_static_real', 'feat_dynamic_real']\n"
]
}
],
"source": [
"train_dataloader = create_train_dataloader(\n",
" config=config,\n",
" freq=freq,\n",
" data=train_dataset,\n",
" batch_size=256,\n",
" num_batches_per_epoch=100,\n",
")\n",
"\n",
"test_dataloader = create_test_dataloader(\n",
" config=config,\n",
" freq=freq,\n",
" data=test_dataset,\n",
" batch_size=64,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "064793ed",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"past_time_features torch.Size([256, 85, 2]) torch.FloatTensor\n",
"past_values torch.Size([256, 85]) torch.FloatTensor\n",
"past_observed_mask torch.Size([256, 85]) torch.FloatTensor\n",
"future_time_features torch.Size([256, 24, 2]) torch.FloatTensor\n",
"static_categorical_features torch.Size([256, 1]) torch.IntTensor\n",
"future_values torch.Size([256, 24]) torch.FloatTensor\n",
"future_observed_mask torch.Size([256, 24]) torch.FloatTensor\n"
]
}
],
"source": [
"batch = next(iter(train_dataloader))\n",
"for k, v in batch.items():\n",
" print(k, v.shape, v.type())"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "9a1e2036",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# perform forward pass\n",
"outputs = model(\n",
" past_values=batch[\"past_values\"],\n",
" past_time_features=batch[\"past_time_features\"],\n",
" past_observed_mask=batch[\"past_observed_mask\"],\n",
" static_categorical_features=batch[\"static_categorical_features\"]\n",
" if config.num_static_categorical_features > 0\n",
" else None,\n",
" static_real_features=batch[\"static_real_features\"]\n",
" if config.num_static_real_features > 0\n",
" else None,\n",
" future_values=batch[\"future_values\"],\n",
" future_time_features=batch[\"future_time_features\"],\n",
" future_observed_mask=batch[\"future_observed_mask\"],\n",
" output_hidden_states=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "29b4d896",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 9.22168254852295\n",
"CPU times: total: 0 ns\n",
"Wall time: 1 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"print(\"Loss:\", outputs.loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "b00eda51",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9.025437355041504\n",
"CPU times: total: 32min 9s\n",
"Wall time: 4min 36s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"from accelerate import Accelerator\n",
"from torch.optim import AdamW\n",
"\n",
"accelerator = Accelerator()\n",
"device = accelerator.device\n",
"\n",
"model.to(device)\n",
"optimizer = AdamW(\n",
"    model.parameters(),\n",
"    lr=6e-4,\n",
"    betas=(0.9, 0.95),\n",
"    weight_decay=1e-1,\n",
")\n",
"\n",
"# hand model, optimizer and dataloader to accelerate before training\n",
"model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)\n",
"\n",
"model.train()\n",
"for epoch in range(1):\n",
"    for idx, batch in enumerate(train_dataloader):\n",
"        optimizer.zero_grad()\n",
"\n",
"        # optional static inputs are passed only when the config enables them\n",
"        outputs = model(\n",
"            static_categorical_features=(\n",
"                batch[\"static_categorical_features\"].to(device)\n",
"                if config.num_static_categorical_features > 0\n",
"                else None\n",
"            ),\n",
"            static_real_features=(\n",
"                batch[\"static_real_features\"].to(device)\n",
"                if config.num_static_real_features > 0\n",
"                else None\n",
"            ),\n",
"            past_values=batch[\"past_values\"].to(device),\n",
"            past_time_features=batch[\"past_time_features\"].to(device),\n",
"            past_observed_mask=batch[\"past_observed_mask\"].to(device),\n",
"            future_values=batch[\"future_values\"].to(device),\n",
"            future_time_features=batch[\"future_time_features\"].to(device),\n",
"            future_observed_mask=batch[\"future_observed_mask\"].to(device),\n",
"        )\n",
"        loss = outputs.loss\n",
"\n",
"        # Backpropagation\n",
"        accelerator.backward(loss)\n",
"        optimizer.step()\n",
"\n",
"        # report the loss on every 100th batch\n",
"        if idx % 100 == 0:\n",
"            print(loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "2d4001c6",
"metadata": {},
"outputs": [],
"source": [
"model.eval()\n",
"\n",
"forecasts = []\n",
"\n",
"for batch in test_dataloader:\n",
"    # optional static inputs are passed only when the config enables them\n",
"    outputs = model.generate(\n",
"        static_categorical_features=(\n",
"            batch[\"static_categorical_features\"].to(device)\n",
"            if config.num_static_categorical_features > 0\n",
"            else None\n",
"        ),\n",
"        static_real_features=(\n",
"            batch[\"static_real_features\"].to(device)\n",
"            if config.num_static_real_features > 0\n",
"            else None\n",
"        ),\n",
"        past_values=batch[\"past_values\"].to(device),\n",
"        past_time_features=batch[\"past_time_features\"].to(device),\n",
"        past_observed_mask=batch[\"past_observed_mask\"].to(device),\n",
"        future_time_features=batch[\"future_time_features\"].to(device),\n",
"    )\n",
"    # collect the generated sequences of this batch as a NumPy array on the CPU\n",
"    forecasts.append(outputs.sequences.cpu().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "f924996a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(64, 100, 24)\n"
]
}
],
"source": [
"print(forecasts[0].shape)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "d3a9c5db",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(366, 100, 24)\n"
]
}
],
"source": [
"forecasts = np.vstack(forecasts)\n",
"print(forecasts.shape)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "7dd17819",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7b3f56e729e3434ba3a89ce3681b6443",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading builder script: 0%| | 0.00/5.50k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb80639aba454f75b9620c0c720eefaf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading builder script: 0%| | 0.00/6.65k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from evaluate import load\n",
"from gluonts.time_feature import get_seasonality\n",
"\n",
"mase_metric = load(\"mase\")\n",
"smape_metric = load(\"smape\")\n",
"\n",
"# point forecast per series: median over axis 1 of ``forecasts``\n",
"forecast_median = np.median(forecasts, 1)\n",
"\n",
"# the seasonal period depends only on ``freq``, so compute it once instead of per series\n",
"seasonality = get_seasonality(freq)\n",
"\n",
"mase_metrics = []\n",
"smape_metrics = []\n",
"for item_id, ts in enumerate(test_dataset):\n",
"    # the last ``prediction_length`` values are held out as ground truth,\n",
"    # everything before them is the in-sample history used to scale MASE\n",
"    training_data = ts[\"target\"][:-prediction_length]\n",
"    ground_truth = ts[\"target\"][-prediction_length:]\n",
"\n",
"    mase = mase_metric.compute(\n",
"        predictions=forecast_median[item_id],\n",
"        references=np.array(ground_truth),\n",
"        training=np.array(training_data),\n",
"        periodicity=seasonality,\n",
"    )\n",
"    mase_metrics.append(mase[\"mase\"])\n",
"\n",
"    smape = smape_metric.compute(\n",
"        predictions=forecast_median[item_id],\n",
"        references=np.array(ground_truth),\n",
"    )\n",
"    smape_metrics.append(smape[\"smape\"])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "b0d85234",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MASE: 3.1912513800374036\n",
"sMAPE: 0.34196129648343504\n"
]
}
],
"source": [
"print(f\"MASE: {np.mean(mase_metrics)}\")\n",
"print(f\"sMAPE: {np.mean(smape_metrics)}\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "29d1ffc2",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(mase_metrics, smape_metrics, alpha=0.3)\n",
"plt.xlabel(\"MASE\")\n",
"plt.ylabel(\"sMAPE\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "d0be3a68",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.dates as mdates\n",
"\n",
"def plot(ts_index):\n",
"    \"\"\"Plot the tail of one test series together with its forecast.\n",
"\n",
"    Args:\n",
"        ts_index: position of the series in ``test_dataset`` and along the first\n",
"            axis of ``forecasts``.\n",
"\n",
"    Draws the last ``2 * prediction_length`` actual values, the median of the forecast\n",
"    samples over the last ``prediction_length`` steps, and a band of the sample mean\n",
"    plus/minus one standard deviation, then shows the figure.\n",
"    \"\"\"\n",
"    # look the row up once: ``test_dataset`` has an on-the-fly transform attached,\n",
"    # so each ``test_dataset[ts_index]`` access repeats that work\n",
"    series = test_dataset[ts_index]\n",
"    samples = forecasts[ts_index]\n",
"\n",
"    fig, ax = plt.subplots()\n",
"\n",
"    index = pd.period_range(\n",
"        start=series[FieldName.START],\n",
"        periods=len(series[FieldName.TARGET]),\n",
"        freq=freq,\n",
"    ).to_timestamp()\n",
"    forecast_index = index[-prediction_length:]\n",
"\n",
"    # Major ticks every half year, minor ticks every month\n",
"    ax.xaxis.set_major_locator(mdates.MonthLocator(bymonth=(1, 7)))\n",
"    ax.xaxis.set_minor_locator(mdates.MonthLocator())\n",
"\n",
"    ax.plot(\n",
"        index[-2 * prediction_length:],\n",
"        series[\"target\"][-2 * prediction_length:],\n",
"        label=\"actual\",\n",
"    )\n",
"\n",
"    ax.plot(\n",
"        forecast_index,\n",
"        np.median(samples, axis=0),\n",
"        label=\"median\",\n",
"    )\n",
"\n",
"    # compute the sample statistics once and reuse them for both edges of the band\n",
"    sample_mean = samples.mean(0)\n",
"    sample_std = samples.std(axis=0)\n",
"    ax.fill_between(\n",
"        forecast_index,\n",
"        sample_mean - sample_std,\n",
"        sample_mean + sample_std,\n",
"        alpha=0.3,\n",
"        interpolate=True,\n",
"        label=\"+/- 1-std\",\n",
"    )\n",
"    ax.legend()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "015533fd",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot(334)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bae5b44c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|