import numpy as np
import torch

from fairseq.data import FairseqDataset, plasma_utils
from fairseq.data.indexed_dataset import best_fitting_int_dtype
from typing import Tuple


class TokenBlockDataset(FairseqDataset):
    """Break a Dataset of tokens into blocks.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
        block_size (int): maximum block size (ignored in 'eos' break mode)
        pad (int): padding token index
        eos (int): end-of-sentence token index
        break_mode (str, optional): Mode used for breaking tokens. Values can
            be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contain complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'complete_doc': similar to 'complete' mode, but do not
              cross document boundaries
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets (bool, optional): return next tokens as targets
            (default: False).
        document_sep_len (int, optional): document separator size (required for
            'complete_doc' break mode). Typically 1 if the sentences have eos
            and 0 otherwise.
    """
    def __init__(
        self,
        dataset,
        sizes,
        block_size,
        pad,
        eos,
        break_mode=None,
        include_targets=False,
        document_sep_len=1,
        use_plasma_view=False,
        split_path=None,
        plasma_path=None,
    ):
        super().__init__()
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets

        assert len(dataset) > 0
        assert len(dataset) == len(sizes)
        _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices(
            sizes, break_mode, document_sep_len, block_size
        )
        if use_plasma_view:
            # store the index arrays in the Plasma object store so they can be
            # shared across data-loader workers without copying; the id tuple
            # keys the cached view
            plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset))
            self._slice_indices = plasma_utils.PlasmaView(
                slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path
            )
            self._sizes = plasma_utils.PlasmaView(
                _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path
            )
            self._block_to_dataset_index = plasma_utils.PlasmaView(
                block_to_dataset_index,
                split_path,
                (plasma_id, 2),
                plasma_path=plasma_path,
            )
        else:
            self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
            self._sizes = plasma_utils.PlasmaArray(_sizes)
            self._block_to_dataset_index = plasma_utils.PlasmaArray(
                block_to_dataset_index
            )

    @staticmethod
    def _build_slice_indices(
        sizes, break_mode, document_sep_len, block_size
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Use token_block_utils_fast to build arrays for indexing into self.dataset."""
        try:
            from fairseq.data.token_block_utils_fast import (
                _get_slice_indices_fast,
                _get_block_to_dataset_index_fast,
            )
        except ImportError:
            raise ImportError(
                "Please build Cython components with: `pip install --editable .` "
                "or `python setup.py build_ext --inplace`"
            )

        if isinstance(sizes, list):
            sizes = np.array(sizes, dtype=np.int64)
        else:
            if torch.is_tensor(sizes):
                sizes = sizes.numpy()
            sizes = sizes.astype(np.int64)

        break_mode = break_mode if break_mode is not None else "none"

        # the "eos" break mode ignores block_size, so any value works here
        if break_mode == "eos" and block_size is None:
            block_size = 0

        # slice_indices[i] = (start, end) token offsets of block i in the
        # flattened token stream
        slice_indices = _get_slice_indices_fast(
            sizes, str(break_mode), block_size, document_sep_len
        )
        _sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # build an index mapping block indices into the underlying dataset:
        # block_to_dataset_index[i] = (start_ds_idx, start_offset, end_ds_idx)
        if break_mode == "eos":
            # much faster version for eos break mode: block i is exactly
            # sentence i, starting at offset 0
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
                    np.zeros(len(sizes), dtype=np.int64),  # starting offset within sentence
                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes, slice_indices
            )
        # downcast to the smallest integer dtypes that fit, to save memory
        size_dtype = np.uint16 if block_size < 65535 else np.uint32
        num_tokens = slice_indices[-1].max()
        slice_indices_dtype = best_fitting_int_dtype(num_tokens)
        slice_indices = slice_indices.astype(slice_indices_dtype)
        _sizes = _sizes.astype(size_dtype)
        block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype)
        return _sizes, block_to_dataset_index, slice_indices
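
    # Illustration (hypothetical numbers): for sizes=[2, 1, 4] with
    # break_mode='complete' and block_size=3, _build_slice_indices returns
    #   slice_indices          = [[0, 3], [3, 7]]        # token spans
    #   _sizes                 = [3, 4]                  # block lengths
    #   block_to_dataset_index = [[0, 0, 1], [2, 0, 2]]  # (start_ds_idx, start_offset, end_ds_idx)
    # These are the arrays exposed by the properties below.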
    @property
    def slice_indices(self):
        return self._slice_indices.array

    @property
    def sizes(self):
        return self._sizes.array

    @property
    def block_to_dataset_index(self):
        return self._block_to_dataset_index.array

    def attr(self, attr: str, index: int):
        start_ds_idx, _, _ = self.block_to_dataset_index[index]
        return self.dataset.attr(attr, start_ds_idx)
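
    # Illustration (hypothetical values): for a block whose buffer is
    # [5, 6, 7, eos] with s == 0 and include_targets=True, __getitem__ returns
    #   source      = [eos, 5, 6, 7]    target shifted right by one
    #   item/target = [5, 6, 7, eos]
    #   past_target = [pad, eos, 5, 6]  target shifted right by two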
    def __getitem__(self, index):
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]

        # concatenate every sentence this block touches, then slice out the
        # block itself
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # *target* is the original sentence (=item)
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            if s == 0:
                source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
                past_target = torch.cat(
                    [item.new([self.pad, self.eos]), buffer[0 : e - 2]]
                )
            else:
                source = buffer[s - 1 : e - 1]
                if s == 1:
                    past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
                else:
                    past_target = buffer[s - 2 : e - 2]

            return source, item, past_target

        return item

    def __len__(self):
        return len(self.slice_indices)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        # prefetch every underlying sentence touched by the requested blocks
        self.dataset.prefetch(
            {
                ds_idx
                for index in indices
                for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
                for ds_idx in range(start_ds_idx, end_ds_idx + 1)
            }
        )
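

if __name__ == "__main__":
    # Minimal usage sketch (hypothetical toy data, not part of the original
    # module; requires fairseq's Cython extensions to be built). Wraps three
    # "sentences" ending in eos=2 into blocks of at most 6 tokens that respect
    # sentence boundaries, returning LM-style (source, target, past_target).
    toy = [
        torch.LongTensor([4, 5, 2]),
        torch.LongTensor([6, 2]),
        torch.LongTensor([7, 8, 9, 2]),
    ]
    ds = TokenBlockDataset(
        toy,
        sizes=[len(t) for t in toy],
        block_size=6,
        pad=1,
        eos=2,
        break_mode="complete",
        include_targets=True,
    )
    for i in range(len(ds)):
        source, target, past_target = ds[i]
        print(i, source.tolist(), target.tolist(), past_target.tolist())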