|
import copy |
|
from dataclasses import dataclass |
|
|
|
import torch |
|
|
|
from .base import AugmentedDynamicItemDataset, DataPipe |
|
|
|
|
|
@dataclass |
|
class AutoregressivePrediction(DataPipe): |
|
n_future: int = 5 |
|
source_feat_name: str = ( |
|
"source_feat" |
|
) |
|
target_feat_name: str = ( |
|
"target_feat" |
|
) |
|
source_feat_len_name: str = "feat_len" |
|
|
|
def generate_shifted_data(self, source_feat): |
|
with torch.no_grad(): |
|
feat_len = int(source_feat.size(0)) - self.n_future |
|
target_feat = copy.deepcopy(source_feat[self.n_future :, :]) |
|
source_feat = source_feat[: -self.n_future, :] |
|
|
|
target_feat = target_feat.to(dtype=torch.float32) |
|
source_feat = source_feat.to(dtype=torch.float32) |
|
|
|
return source_feat, target_feat, feat_len |
|
|
|
def __call__(self, dataset: AugmentedDynamicItemDataset): |
|
dataset.add_dynamic_item( |
|
self.generate_shifted_data, |
|
takes=self.source_feat_name, |
|
provides=[ |
|
self.source_feat_name, |
|
self.target_feat_name, |
|
self.source_feat_len_name, |
|
], |
|
) |
|
return dataset |
|
|