vijaye12 committed · Commit 07c3734 · verified · 1 Parent(s): a2efe1c

Update README.md

Files changed (1): README.md (+72 -8)
README.md CHANGED
@@ -52,23 +52,28 @@ TTMs that can cater to many common forecasting settings in practice.
 Each pre-trained model will be released in a different branch name in this model card. Kindly access the required model using our
 getting started [notebook](https://github.com/IBM/tsfm/blob/main/notebooks/hfdemo/ttm_getting_started.ipynb) mentioning the branch name.
 
 ## Model Releases (along with the branch name where the models are stored):
 
- - **512-96-r2**: Given the last 512 time-points (i.e. context length), this model can forecast up to next 96 time-points (i.e. forecast length)
+ - **512-96-r2**: Given the last 512 time-points (i.e. context length), this model can forecast up to the next 96 time-points (i.e. forecast length)
 in future. (branch name: main)
 
- - **1024-96-r2**: Given the last 1024 time-points (i.e. context length), this model can forecast up to next 96 time-points (i.e. forecast length)
+ - **1024-96-r2**: Given the last 1024 time-points (i.e. context length), this model can forecast up to the next 96 time-points (i.e. forecast length)
 in future. (branch name: 1024-96-r2) [[Benchmarks]]
 
- - **1536-96-r2**: Given the last 1536 time-points (i.e. context length), this model can forecast up to next 96 time-points (i.e. forecast length)
+ - **1536-96-r2**: Given the last 1536 time-points (i.e. context length), this model can forecast up to the next 96 time-points (i.e. forecast length)
 in future. (branch name: 1536-96-r2)
 
- - Likewise, we have models released for forecast lengths up to 720 timepoints. The branch names for these are as follows: 512-192-r2, 1024-192-r2, 1536-192-r2, 512-336-r2,
- 512-336-r2, 1024-336-r2, 1536-336-r2, 512-720-r2, 1024-720-r2, 1536-720-r2
+ - Likewise, we have models released for forecast lengths up to 720 time-points. The branch names for these are as follows: `512-192-r2`, `1024-192-r2`, `1536-192-r2`,
+ `512-336-r2`, `1024-336-r2`, `1536-336-r2`, `512-720-r2`, `1024-720-r2`, `1536-720-r2`
 
 - Please use the [[get_model]](https://github.com/ibm-granite/granite-tsfm/blob/main/tsfm_public/toolkit/get_model.py) utility to automatically select the required model based on your input context length and forecast length requirement.
+
+ - We currently allow 3 context lengths (512, 1024 and 1536) and 4 forecast lengths (96, 192, 336, 720). Users need to provide the exact context length as input,
+ but can provide any forecast length up to 720 in get_model().
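As a concrete illustration of the branch scheme added in the hunk above, here is a minimal sketch of loading a non-default release by passing its branch name as the `revision` argument. It is not part of this commit: the import path follows the granite-tsfm toolkit's conventions and is an assumption, as are the printed config attribute names.

```
# Minimal sketch: load the 1024-96-r2 release by passing its branch name
# as the `revision` argument of from_pretrained.
from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction

model = TinyTimeMixerForPrediction.from_pretrained(
    "ibm-granite/granite-timeseries-ttm-r2",
    revision="1024-96-r2",  # branch name from the release list above
)

# The loaded configuration should reflect the chosen release
# (context length 1024, forecast length 96).
print(model.config.context_length, model.config.prediction_length)
```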
@@ -143,12 +148,63 @@ In addition, TTM also supports exogenous infusion and categorical data infusion.
 
 ## Uses
 
+ Automatic model selection:
+
+ ```
+ def get_model(
+     model_path,
+     model_name: str = "ttm",
+     context_length: int = None,
+     prediction_length: int = None,
+     freq_prefix_tuning: bool = None,
+     **kwargs,
+ ):
+     """
+     The TTM model card offers a suite of models with varying context_length and forecast_length combinations.
+     This wrapper automatically selects the right model based on the given input context_length and prediction_length,
+     abstracting away the internal complexity.
+
+     Args:
+         model_path (str):
+             HF model card path or local model path (Ex. ibm-granite/granite-timeseries-ttm-r1)
+         model_name (str, *optional*):
+             Model name to use. Allowed values: ttm
+         context_length (int):
+             Input context length. For ibm-granite/granite-timeseries-ttm-r1, we allow 512 and 1024.
+             For ibm-granite/granite-timeseries-ttm-r2 and ibm/ttm-research-r2, we allow 512, 1024 and 1536.
+         prediction_length (int):
+             Forecast length to predict. For ibm-granite/granite-timeseries-ttm-r1, we can forecast up to 96.
+             For ibm-granite/granite-timeseries-ttm-r2 and ibm/ttm-research-r2, we can forecast up to 720.
+             The model is trained for fixed forecast lengths (96, 192, 336, 720), and get_model sets the required
+             prediction_filter_length on the model instance to prune the forecasts as needed. For example, to forecast
+             150 time-points given the last 512 time-points with model_path = ibm-granite/granite-timeseries-ttm-r2,
+             get_model selects the model from the 512-192-r2 branch and applies prediction_filter_length = 150 to
+             prune the forecasts from 192 to 150. prediction_filter_length also applies the loss only to the pruned
+             forecasts during fine-tuning.
+         freq_prefix_tuning (bool, *optional*):
+             For future use. Do not use this parameter currently.
+         kwargs:
+             Extra model parameters intended to be passed in the from_pretrained call to update the model configuration.
+     """
+ ```
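The selection behaviour the docstring describes can be sketched as follows, using the 150-step example from the docstring itself. The import path appears in this commit; the comments simply restate the documented behaviour, and the example assumes granite-tsfm is installed.

```
# Sketch of automatic selection: request a forecast length that is not one
# of the trained lengths (96, 192, 336, 720) and let get_model prune it.
from tsfm_public.toolkit.get_model import get_model

model = get_model(
    model_path="ibm-granite/granite-timeseries-ttm-r2",
    context_length=512,     # must be one of the supported context lengths
    prediction_length=150,  # any value up to 720 is accepted
)

# Per the docstring, the 512-192-r2 branch is chosen and the model is
# configured with prediction_filter_length=150, pruning 192 -> 150 steps.
```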
 ```
 # Load Model from HF Model Hub mentioning the branch name in revision field
 
 model = TinyTimeMixerForPrediction.from_pretrained(
-     "https://huggingface.co/ibm/TTM", revision="main"
- )
+     "ibm-granite/granite-timeseries-ttm-r2", revision="main"
+ )
+
+ # or use the get_model utility:
+ from tsfm_public.toolkit.get_model import get_model
+ model = get_model(
+     model_path="ibm-granite/granite-timeseries-ttm-r2",
+     context_length=512,
+     prediction_length=96
+ )
 
 # Do zeroshot
 zeroshot_trainer = Trainer(
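The hunk ends mid-call; in the README this continues into a plain Hugging Face `Trainer` used only for evaluation. A self-contained sketch follows, assuming `dset_test` is a torch dataset of TTM-formatted windows (for example from the toolkit's `ForecastDFDataset`); the output directory and batch size are illustrative, not from this commit.

```
# Zero-shot evaluation sketch: wrap the pretrained model in a Trainer and
# call evaluate(); no weights are updated. `dset_test` is assumed to yield
# dicts with "past_values" / "future_values" tensors.
from transformers import Trainer, TrainingArguments

zeroshot_trainer = Trainer(
    model=model,  # the TTM loaded in the snippet above
    args=TrainingArguments(
        output_dir="ttm_zeroshot",  # illustrative path
        per_device_eval_batch_size=64,
        report_to="none",
    ),
)

# Returns metrics such as eval_loss computed over dset_test.
zeroshot_output = zeroshot_trainer.evaluate(dset_test)
print(zeroshot_output)
```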
@@ -166,6 +222,14 @@ zeroshot_output = zeroshot_trainer.evaluate(dset_test)
 for param in model.backbone.parameters():
     param.requires_grad = False
 
+ finetune_model = get_model(
+     model_path="ibm-granite/granite-timeseries-ttm-r2",
+     context_length=512,
+     prediction_length=96,
+     # pass other fine-tuning parameters of the decoder or head
+     head_dropout=0.2
+ )
+
 finetune_forecast_trainer = Trainer(
     model=model,
     args=finetune_forecast_args,
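Note that the hunk adds `finetune_model` but leaves `model=model` in the unchanged `Trainer(...)` lines; presumably the fresh `finetune_model` is what should be fine-tuned. A self-contained sketch under that assumption (the dataset names and training hyperparameters are illustrative, not from this commit):

```
# Few-shot fine-tuning sketch: fetch a fresh model with a custom head
# dropout, freeze the backbone, and train only the decoder/head.
from transformers import Trainer, TrainingArguments
from tsfm_public.toolkit.get_model import get_model

finetune_model = get_model(
    model_path="ibm-granite/granite-timeseries-ttm-r2",
    context_length=512,
    prediction_length=96,
    head_dropout=0.2,  # extra model kwarg forwarded to from_pretrained
)

# Freeze the pretrained backbone so gradients only reach the decoder/head.
for param in finetune_model.backbone.parameters():
    param.requires_grad = False

finetune_forecast_args = TrainingArguments(
    output_dir="ttm_finetune",  # illustrative values throughout
    num_train_epochs=20,
    per_device_train_batch_size=64,
    learning_rate=1e-3,
    report_to="none",
)

finetune_forecast_trainer = Trainer(
    model=finetune_model,  # train the fresh copy, not `model`
    args=finetune_forecast_args,
    train_dataset=dset_train,  # assumed TTM-formatted torch datasets
    eval_dataset=dset_val,
)
finetune_forecast_trainer.train()
```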
 