Update README.md
@@ -52,23 +52,28 @@ TTMs that can cater to many common forecasting settings in practice.

Each pre-trained model will be released in a different branch name in this model card. Kindly access the required model using our getting started [notebook](https://github.com/IBM/tsfm/blob/main/notebooks/hfdemo/ttm_getting_started.ipynb), mentioning the branch name.

## Model Releases (along with the branch name where the models are stored):

- **512-96-r2**: Given the last 512 time-points (i.e. context length), this model can forecast up to the next 96 time-points (i.e. forecast length) in the future. (branch name: main)

- **1024-96-r2**: Given the last 1024 time-points (i.e. context length), this model can forecast up to the next 96 time-points (i.e. forecast length) in the future. (branch name: 1024-96-r2) [[Benchmarks]]

- **1536-96-r2**: Given the last 1536 time-points (i.e. context length), this model can forecast up to the next 96 time-points (i.e. forecast length) in the future. (branch name: 1536-96-r2)

- Likewise, we have released models for forecast lengths up to 720 time-points. The branch names for these are: `512-192-r2`, `1024-192-r2`, `1536-192-r2`, `512-336-r2`, `1024-336-r2`, `1536-336-r2`, `512-720-r2`, `1024-720-r2`, `1536-720-r2`.

- Please use the [get_model](https://github.com/ibm-granite/granite-tsfm/blob/main/tsfm_public/toolkit/get_model.py) utility to automatically select the required model based on your input context length and forecast length requirements.

- We currently allow 3 context lengths (512, 1024 and 1536) and 4 forecast lengths (96, 192, 336, 720). Users need to provide the exact context length as input, but can provide any forecast length up to 720 in `get_model()`. (A short loading sketch follows this list.)
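For illustration, here is a minimal loading sketch for one of the branches listed above. It assumes the `tsfm_public` toolkit (from the granite-tsfm repository) is installed and that `TinyTimeMixerForPrediction` is importable from the package's top level, as in its notebooks; the model card path matches the one used later in this card.

```
# Hedged sketch: load the checkpoint stored in the 1024-96-r2 branch directly.
# Assumes the granite-tsfm toolkit (tsfm_public) is installed and the HF Hub is reachable.
from tsfm_public import TinyTimeMixerForPrediction

model = TinyTimeMixerForPrediction.from_pretrained(
    "ibm-granite/granite-timeseries-ttm-r2",  # this model card
    revision="1024-96-r2",                    # branch name from the list above
)
```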
@@ -143,12 +148,63 @@ In addition, TTM also supports exogenous infusion and categorical data infusion.

## Uses

Automatic model selection:

```
def get_model(
    model_path,
    model_name: str = "ttm",
    context_length: int = None,
    prediction_length: int = None,
    freq_prefix_tuning: bool = None,
    **kwargs,
):
    """
    The TTM model card offers a suite of models with varying context_length and forecast_length
    combinations. This wrapper automatically selects the right model based on the given input
    context_length and prediction_length, abstracting away the internal complexity.

    Args:
        model_path (str):
            HF model card path or local model path (e.g. ibm-granite/granite-timeseries-ttm-r1).
        model_name (str, *optional*):
            Model name to use. Allowed values: ttm.
        context_length (int):
            Input context length. For ibm-granite/granite-timeseries-ttm-r1, we allow 512 and 1024.
            For ibm-granite/granite-timeseries-ttm-r2 and ibm/ttm-research-r2, we allow 512, 1024 and 1536.
        prediction_length (int):
            Forecast length to predict. For ibm-granite/granite-timeseries-ttm-r1, we can forecast up to 96.
            For ibm-granite/granite-timeseries-ttm-r2 and ibm/ttm-research-r2, we can forecast up to 720.
            The models are trained for fixed forecast lengths (96, 192, 336, 720), and get_model sets the
            required `prediction_filter_length` on the model instance for the required pruning.
            For example, if we need to forecast 150 time-points given the last 512 time-points using
            model_path = ibm-granite/granite-timeseries-ttm-r2, then get_model selects the model from the
            512-192-r2 branch and applies prediction_filter_length = 150 to prune the forecasts from 192
            to 150. prediction_filter_length also applies the loss only to the pruned forecasts during
            fine-tuning.
        freq_prefix_tuning (bool, *optional*):
            Reserved for future use. Do not use this parameter currently.
        kwargs:
            Pass all the extra fine-tuning model parameters intended to be passed in the from_pretrained
            call to update the model configuration.
    """
```
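The selection rule described above can be pictured with a tiny stand-alone helper. This is an illustrative sketch only, not part of `tsfm_public`; the released forecast lengths and the `-r2` branch naming come from the release list earlier in this card (note that the 512-96 checkpoint is served from the `main` branch).

```
# Illustrative only: mimic the branch-selection rule described in the docstring above.
RELEASED_FORECAST_LENGTHS = (96, 192, 336, 720)

def select_branch(context_length, prediction_length):
    # Smallest released forecast length that covers the request; the remainder
    # is handled by prediction_filter_length on the loaded model.
    forecast = next(fl for fl in RELEASED_FORECAST_LENGTHS if fl >= prediction_length)
    return f"{context_length}-{forecast}-r2", prediction_length

print(select_branch(512, 150))  # ('512-192-r2', 150), matching the example above
```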

```
# Load Model from HF Model Hub mentioning the branch name in revision field

model = TinyTimeMixerForPrediction.from_pretrained(
    "https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2", revision="main"
)

# or

from tsfm_public.toolkit.get_model import get_model

model = get_model(
    model_path="https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2",
    context_length=512,
    prediction_length=96
)

# Do zeroshot
zeroshot_trainer = Trainer(
```
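The block above is cut off at the hunk boundary. As a rough guide, here is a hedged sketch of how the zero-shot evaluation typically continues: the training-argument values are placeholders, `dset_test` stands for the prepared test dataset from the getting-started notebook, and the `zeroshot_output = zeroshot_trainer.evaluate(dset_test)` line is taken from the context of the next hunk.

```
# Hedged sketch of the zero-shot evaluation flow; argument values are placeholders.
from transformers import Trainer, TrainingArguments

zeroshot_trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="ttm_zeroshot",        # hypothetical output directory
        per_device_eval_batch_size=64,    # placeholder batch size
    ),
)
zeroshot_output = zeroshot_trainer.evaluate(dset_test)
print(zeroshot_output)
```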
@@ -166,6 +222,14 @@ zeroshot_output = zeroshot_trainer.evaluate(dset_test)

for param in model.backbone.parameters():
    param.requires_grad = False

finetune_model = get_model(
    model_path="https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2",
    context_length=512,
    prediction_length=96,
    # pass other finetune params of decoder or head
    head_dropout=0.2,
)

finetune_forecast_trainer = Trainer(
    model=model,
    args=finetune_forecast_args,
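The fine-tuning block is likewise truncated by the hunk. Below is a hedged sketch of how it is typically completed, under stated assumptions: `finetune_forecast_args`, `dset_train` and `dset_val` are placeholder names following the getting-started notebook's conventions, the argument values are illustrative, and the sketch trains the `finetune_model` copy created above via `get_model(..., head_dropout=0.2)`.

```
# Hedged sketch, not the card's exact recipe: complete the fine-tuning setup.
from transformers import Trainer, TrainingArguments

finetune_forecast_args = TrainingArguments(
    output_dir="ttm_finetuned",          # hypothetical output directory
    num_train_epochs=50,                 # placeholder values; tune per dataset
    per_device_train_batch_size=64,
    learning_rate=1e-4,
)

# Freeze the backbone of the copy being fine-tuned so only the decoder/head update.
for param in finetune_model.backbone.parameters():
    param.requires_grad = False

finetune_forecast_trainer = Trainer(
    model=finetune_model,                # the get_model(..., head_dropout=0.2) copy
    args=finetune_forecast_args,
    train_dataset=dset_train,            # prepared training split
    eval_dataset=dset_val,               # prepared validation split
)
finetune_forecast_trainer.train()
```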