# Source: wavlm-large / s3prl_s3prl_main/s3prl/dataset/autoregressive_prediction_pipes.py
# Uploaded by lmzjms ("Upload 1162 files", commit 0b32ad6, verified)
import copy
from dataclasses import dataclass
import torch
from .base import AugmentedDynamicItemDataset, DataPipe
@dataclass
class AutoregressivePrediction(DataPipe):
    """Build autoregressive-prediction pairs by shifting features in time.

    The features stored under ``source_feat_name`` are split into:

    * a source sequence: every frame except the last ``n_future``;
    * a target sequence: every frame except the first ``n_future``.

    Frame ``t`` of the target is therefore frame ``t + n_future`` of the
    original features. Both outputs are cast to ``torch.float32``.

    Attributes:
        n_future: Number of frames the target is shifted ahead of the source.
            Must be non-negative.
        source_feat_name: Dynamic-item key holding the input features. The
            same key is also listed as an output, for the truncated source.
        target_feat_name: Dynamic-item key under which the shifted target
            features are provided.
        source_feat_len_name: Dynamic-item key under which the common length
            (number of frames) of the source/target sequences is provided.
    """

    n_future: int = 5
    # tensors in the shape of: (seq_len, feat_dim)
    source_feat_name: str = "source_feat"
    # tensors in the shape of: (seq_len, feat_dim)
    target_feat_name: str = "target_feat"
    source_feat_len_name: str = "feat_len"

    def generate_shifted_data(self, source_feat: torch.Tensor):
        """Split ``source_feat`` into time-shifted (source, target) tensors.

        Args:
            source_feat: Feature tensor indexed as ``[frame, feature]``
                (documented above as ``(seq_len, feat_dim)``).

        Returns:
            A tuple ``(source_feat, target_feat, feat_len)``:

            * ``source_feat``: the first ``feat_len`` frames, as float32;
            * ``target_feat``: a copy of frames ``n_future:``, as float32;
            * ``feat_len``: ``max(seq_len - n_future, 0)``, the length of
              both returned tensors along dim 0.

        Raises:
            ValueError: If ``n_future`` is negative.
        """
        if self.n_future < 0:
            raise ValueError(
                f"n_future must be non-negative, got {self.n_future}"
            )
        with torch.no_grad():
            # Clamp at 0: a sequence shorter than the shift yields empty
            # tensors, and the reported length must match them (never < 0).
            feat_len = max(int(source_feat.size(0)) - self.n_future, 0)
            # clone() copies only the sliced frames. deepcopy of a view would
            # copy the whole underlying storage. The copy still keeps the
            # target independent of the caller's tensor.
            target_feat = source_feat[self.n_future :, :].clone()
            # Slice by explicit length: `[: -n_future]` would wrongly give an
            # empty tensor when n_future == 0.
            source_feat = source_feat[:feat_len, :]
            target_feat = target_feat.to(dtype=torch.float32)
            source_feat = source_feat.to(dtype=torch.float32)
        return source_feat, target_feat, feat_len

    def __call__(self, dataset: AugmentedDynamicItemDataset):
        """Register :meth:`generate_shifted_data` as a dynamic item.

        The item takes ``source_feat_name`` and provides, in order,
        ``source_feat_name`` (truncated source), ``target_feat_name`` and
        ``source_feat_len_name``, matching the tuple returned by
        :meth:`generate_shifted_data`.

        Args:
            dataset: Dataset to register the dynamic item on.

        Returns:
            The same ``dataset`` object, after registration.
        """
        dataset.add_dynamic_item(
            self.generate_shifted_data,
            takes=self.source_feat_name,
            provides=[
                self.source_feat_name,
                self.target_feat_name,
                self.source_feat_len_name,
            ],
        )
        return dataset