File size: 7,588 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from fairseq.data import FairseqDataset, plasma_utils
from fairseq.data.indexed_dataset import best_fitting_int_dtype
from typing import Tuple
class TokenBlockDataset(FairseqDataset):
    """Break a Dataset of tokens into blocks.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
        block_size (int): maximum block size (ignored in 'eos' break mode)
        pad (int): token index used to left-pad ``past_target`` when
            *include_targets* is set
        eos (int): token index prepended to ``source`` (and ``past_target``)
            when a block starts at the beginning of its buffer
        break_mode (str, optional): Mode used for breaking tokens. Values can
            be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contains complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'complete_doc': similar to 'complete' mode, but do not
              cross document boundaries
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets (bool, optional): return next tokens as targets
            (default: False).
        document_sep_len (int, optional): document separator size (required for
            'complete_doc' break mode). Typically 1 if the sentences have eos
            and 0 otherwise.
        use_plasma_view (bool, optional): if set, the index arrays are wrapped
            in ``plasma_utils.PlasmaView`` instead of
            ``plasma_utils.PlasmaArray`` (default: False).
        split_path (str, optional): forwarded to ``PlasmaView``; only used
            when *use_plasma_view* is set.
        plasma_path (str, optional): forwarded to ``PlasmaView``; only used
            when *use_plasma_view* is set.
    """

    def __init__(
        self,
        dataset,
        sizes,
        block_size,
        pad,
        eos,
        break_mode=None,
        include_targets=False,
        document_sep_len=1,
        use_plasma_view=False,
        split_path=None,
        plasma_path=None,
    ):
        super().__init__()
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets

        assert len(dataset) > 0
        assert len(dataset) == len(sizes)

        _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices(
            sizes, break_mode, document_sep_len, block_size
        )
        if use_plasma_view:
            # The three arrays share one id prefix and are told apart by the
            # trailing 0/1/2 component.
            plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset))
            self._slice_indices = plasma_utils.PlasmaView(
                slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path
            )
            self._sizes = plasma_utils.PlasmaView(
                _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path
            )
            self._block_to_dataset_index = plasma_utils.PlasmaView(
                block_to_dataset_index,
                split_path,
                (plasma_id, 2),
                plasma_path=plasma_path,
            )
        else:
            self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
            self._sizes = plasma_utils.PlasmaArray(_sizes)
            self._block_to_dataset_index = plasma_utils.PlasmaArray(
                block_to_dataset_index
            )

    @staticmethod
    def _build_slice_indices(
        sizes, break_mode, document_sep_len, block_size
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Use token_block_utils_fast to build arrays for indexing into self.dataset

        Returns:
            A ``(_sizes, block_to_dataset_index, slice_indices)`` tuple:
            per-block token counts, per-block
            ``(start_ds_idx, start_offset, end_ds_idx)`` rows, and per-block
            ``(start, end)`` token offsets, each cast to a compact integer
            dtype.

        Raises:
            ImportError: if the Cython extension has not been built.
        """
        try:
            from fairseq.data.token_block_utils_fast import (
                _get_slice_indices_fast,
                _get_block_to_dataset_index_fast,
            )
        except ImportError as err:
            raise ImportError(
                "Please build Cython components with: `pip install --editable .` "
                "or `python setup.py build_ext --inplace`"
            ) from err

        if isinstance(sizes, list):
            sizes = np.array(sizes, dtype=np.int64)
        else:
            if torch.is_tensor(sizes):
                sizes = sizes.numpy()
            sizes = sizes.astype(np.int64)

        break_mode = break_mode if break_mode is not None else "none"

        # For "eos" break-mode, block_size is not a required parameter.
        if break_mode == "eos" and block_size is None:
            block_size = 0

        slice_indices = _get_slice_indices_fast(
            sizes, str(break_mode), block_size, document_sep_len
        )
        _sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # build index mapping block indices to the underlying dataset indices
        if break_mode == "eos":
            # much faster version for eos break mode
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
                    # np.int64 replaces the deprecated `np.compat.long` alias.
                    np.zeros(
                        len(sizes), dtype=np.int64
                    ),  # starting offset within starting index
                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes, slice_indices,
            )

        # A block can be longer than block_size ('eos' ignores it, 'complete'
        # modes may exceed it), so the dtype must also fit the largest block
        # actually produced; otherwise long blocks silently wrap in uint16.
        largest_block = int(_sizes.max()) if len(_sizes) > 0 else 0
        size_dtype = (
            np.uint16 if max(block_size, largest_block) < 65535 else np.uint32
        )
        num_tokens = slice_indices[-1].max()
        slice_indices_dtype = best_fitting_int_dtype(num_tokens)
        slice_indices = slice_indices.astype(slice_indices_dtype)
        _sizes = _sizes.astype(size_dtype)
        block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype)
        return _sizes, block_to_dataset_index, slice_indices

    @property
    def slice_indices(self):
        """Per-block ``(start, end)`` token offsets."""
        return self._slice_indices.array

    @property
    def sizes(self):
        """Number of tokens in each block."""
        return self._sizes.array

    @property
    def block_to_dataset_index(self):
        """Per-block ``(start_ds_idx, start_offset, end_ds_idx)`` rows."""
        return self._block_to_dataset_index.array

    def attr(self, attr: str, index: int):
        """Return ``self.dataset.attr(attr, ...)`` for the first underlying
        dataset item covered by block *index*."""
        start_ds_idx, _, _ = self.block_to_dataset_index[index]
        # Hand the wrapped dataset a plain int rather than a NumPy scalar.
        return self.dataset.attr(attr, int(start_ds_idx))

    def __getitem__(self, index):
        """Return block *index*.

        Returns:
            The block's tokens, or ``(source, item, past_target)`` when
            ``include_targets`` is set, where all three have the same length.
        """
        # Convert to Python ints up front: the index arrays are stored in
        # unsigned dtypes, and arithmetic such as ``e - 2`` on an unsigned
        # NumPy scalar can wrap around instead of going negative.
        start_ds_idx, start_offset, end_ds_idx = (
            int(v) for v in self.block_to_dataset_index[index]
        )

        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )

        slice_s, slice_e = (int(v) for v in self.slice_indices[index])
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # *target* is the original sentence (=item)
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            #
            # Prepending [pad, eos] once puts buffer[i] at padded[i + 2], so
            # both shifted views are plain slices. This also keeps them the
            # same length as *item* for blocks shorter than two tokens, where
            # slicing ``buffer[0 : e - 2]`` directly would use a negative end.
            padded = torch.cat([item.new([self.pad, self.eos]), buffer])
            source = padded[s + 1 : e + 1]
            past_target = padded[s:e]
            return source, item, past_target

        return item

    def __len__(self):
        return len(self.slice_indices)

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset reports ``supports_prefetch``."""
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        """Prefetch every underlying dataset item covered by the given blocks."""
        ds_indices = set()
        for index in indices:
            start_ds_idx, _, end_ds_idx = self.block_to_dataset_index[index]
            ds_indices.update(range(int(start_ds_idx), int(end_ds_idx) + 1))
        self.dataset.prefetch(ds_indices)
|