"""
This module contains a collection of classes that implement collate
functionality for various tasks.

Collaters should know what data to expect for each sample, and they pack /
collate those samples into batches.
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils


class Seq2SeqCollater(object):
    """
    Implements a collate function mainly for seq2seq tasks.

    Each sample is expected to contain features (src_tokens) and targets.
    This collater is also used for the aligned training task.
    """
|
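    # Usage sketch (a hedged example; the {"id": ..., "data": [features, labels]}
    # sample layout is the one collate() below reads via feature_index /
    # label_index, and `samples` stands for any list of such dicts):
    #
    #   collater = Seq2SeqCollater(feature_index=0, label_index=1)
    #   batch = collater.collate(samples)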
    def __init__(
        self,
        feature_index=0,
        label_index=1,
        pad_index=1,
        eos_index=2,
        move_eos_to_beginning=True,
    ):
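        # The pad/eos defaults match fairseq's standard Dictionary layout
        # (bos=0, pad=1, eos=2, unk=3); override them if your dictionary
        # uses different special-token indices.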
        self.feature_index = feature_index
        self.label_index = label_index
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.move_eos_to_beginning = move_eos_to_beginning

    def _collate_frames(self, frames):
        """Convert a list of 2d frame tensors into a padded 3d tensor.

        Args:
            frames (list): list of 2d feature tensors of size L[i] x f_dim,
                where L[i] is the number of frames in the i-th sample and
                f_dim is the fixed feature dimension.
        Returns:
            3d tensor of size len(frames) x len_max x f_dim, where len_max
            is the maximum of the L[i].
        """
|
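        # Shape example (a sketch): inputs of shape (3, 40) and (1, 40) are
        # packed into a (2, 3, 40) tensor whose second row is zero-padded
        # after its single real frame.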
        len_max = max(frame.size(0) for frame in frames)
        f_dim = frames[0].size(1)
        res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)

        for i, v in enumerate(frames):
            res[i, : v.size(0)] = v

        return res

    def collate(self, samples):
        """
        Utility function to collate samples into a batch for speech recognition.
        """
|
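        # Returned batch layout (a sketch of the dict assembled below):
        #   {"id", "ntokens", "nsentences", "target", "target_lengths",
        #    "net_input": {"src_tokens", "src_lengths"[, "prev_output_tokens"]}}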
        if len(samples) == 0:
            return {}

        parsed_samples = []
        for s in samples:
            # skip samples whose features are missing
            if s["data"][self.feature_index] is None:
                continue
            source = s["data"][self.feature_index]
            if isinstance(source, (np.ndarray, np.generic)):
                source = torch.from_numpy(source)
            target = s["data"][self.label_index]
            if isinstance(target, (np.ndarray, np.generic)):
                target = torch.from_numpy(target).long()
            elif isinstance(target, list):
                target = torch.LongTensor(target)

            parsed_sample = {"id": s["id"], "source": source, "target": target}
            parsed_samples.append(parsed_sample)
        samples = parsed_samples
        # every sample may have been filtered out above, so check again
        if len(samples) == 0:
            return {}

        id = torch.LongTensor([s["id"] for s in samples])
        frames = self._collate_frames([s["source"] for s in samples])
        # sort the batch by decreasing number of frames
        frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
        frames_lengths, sort_order = frames_lengths.sort(descending=True)
        id = id.index_select(0, sort_order)
        frames = frames.index_select(0, sort_order)
|
        target = None
        target_lengths = None
        prev_output_tokens = None
        if samples[0].get("target", None) is not None:
            ntokens = sum(len(s["target"]) for s in samples)
            target = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, sort_order)
            target_lengths = torch.LongTensor(
                [s["target"].size(0) for s in samples]
            ).index_select(0, sort_order)
            # decoder input: the same tokens, optionally with eos moved to
            # the front for teacher forcing
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=self.move_eos_to_beginning,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
        else:
            ntokens = sum(len(s["source"]) for s in samples)

        batch = {
            "id": id,
            "ntokens": ntokens,
            "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
            "target": target,
            "target_lengths": target_lengths,
            "nsentences": len(samples),
        }
        if prev_output_tokens is not None:
            batch["net_input"]["prev_output_tokens"] = prev_output_tokens
        return batch
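

if __name__ == "__main__":
    # Minimal smoke test (a hedged sketch, not part of the original module;
    # the feature shapes and label ids below are made up for illustration,
    # with eos=2 terminating each label sequence and pad=1 used for padding).
    dummy_samples = [
        {"id": 0, "data": [np.random.randn(7, 40).astype(np.float32), [4, 5, 6, 2]]},
        {"id": 1, "data": [np.random.randn(5, 40).astype(np.float32), [4, 7, 2]]},
    ]
    collater = Seq2SeqCollater(feature_index=0, label_index=1, pad_index=1, eos_index=2)
    batch = collater.collate(dummy_samples)
    print(batch["net_input"]["src_tokens"].shape)  # torch.Size([2, 7, 40])
    print(batch["net_input"]["src_lengths"])  # tensor([7, 5]), sorted descending
    print(batch["net_input"]["prev_output_tokens"])  # eos moved to the front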
|
|