# cython: language_level=3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import numpy as np import torch from itertools import chain from libc.math cimport ceil cimport cython cimport numpy as np from libc.stdint cimport int32_t, int64_t DTYPE = np.int64 ctypedef int64_t DTYPE_t @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size): cdef DTYPE_t total_size = sizes.sum() cdef DTYPE_t length = ceil(total_size / block_size) cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE) cdef DTYPE_t[:, :] slice_indices_view = slice_indices cdef DTYPE_t i cdef DTYPE_t start cdef DTYPE_t end for i in range(length): start = i * block_size end = min(start + block_size, total_size) slice_indices_view[i][0] = start slice_indices_view[i][1] = end return slice_indices cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list): """ Faster function to convert DTYPE_t list of list. Only fast when there are huge number of rows and low number of columns. """ cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1) return flat.reshape((len(list_of_list), -1)) @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): cdef DTYPE_t tok_idx = 0 cdef DTYPE_t sz_idx = 0 cdef DTYPE_t curr_size = 0 cdef DTYPE_t i = 0 cdef DTYPE_t length cdef DTYPE_t total_size cdef DTYPE_t[:] sizes_view = sizes cdef np.ndarray[DTYPE_t, ndim=2] slice_indices cdef list slice_indices_list = [] if break_mode is None or break_mode == 'none': slice_indices = _get_slice_indices_none_mode(sizes, block_size) elif break_mode == 'complete': while sz_idx < len(sizes_view): if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: curr_size += sizes_view[sz_idx] sz_idx += 1 else: slice_indices_list.append((tok_idx, tok_idx + curr_size)) tok_idx += curr_size curr_size = 0 if curr_size > 0: slice_indices_list.append((tok_idx, tok_idx + curr_size)) slice_indices = _fast_convert_to_np_array(slice_indices_list) elif break_mode == 'complete_doc': while sz_idx < len(sizes_view): if ( (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0) # an empty sentence indicates end-of-document: and sizes_view[sz_idx] != document_sep_len ): curr_size += sizes_view[sz_idx] sz_idx += 1 else: # Only keep non-empty documents. if curr_size > 1: slice_indices_list.append((tok_idx, tok_idx + curr_size)) tok_idx += curr_size curr_size = 0 if sizes_view[sz_idx] == document_sep_len: tok_idx += sizes_view[sz_idx] sz_idx += 1 if curr_size > 1: slice_indices_list.append((tok_idx, tok_idx + curr_size)) slice_indices = _fast_convert_to_np_array(slice_indices_list) elif break_mode == 'eos': slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) cumsum = sizes.cumsum(axis=0) slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] slice_indices[:, 1] = cumsum else: raise ValueError('Invalid break_mode: ' + break_mode) return slice_indices @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices): cdef DTYPE_t start_ds_idx cdef DTYPE_t start_offset cdef DTYPE_t end_ds_idx cdef DTYPE_t i cdef DTYPE_t s cdef DTYPE_t e cdef DatasetSearcher ds = DatasetSearcher(sizes) cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE) cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index cdef DTYPE_t[:, :] slice_indices_view = slice_indices cdef Py_ssize_t x_max = slice_indices.shape[0] for i in range(x_max): s = slice_indices_view[i][0] e = slice_indices_view[i][1] ds.seek(s) start_ds_idx = ds.current_index start_offset = ds.current_offset if e <= s: end_ds_idx = start_ds_idx else: ds.seek(e - 1) end_ds_idx = ds.current_index block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset return block_to_dataset_index cdef class DatasetSearcher(object): """Helper for mapping "flat" indices to indices and offsets in an underlying dataset.""" cdef DTYPE_t current_i cdef DTYPE_t current_offset cdef DTYPE_t current_index cdef DTYPE_t[:] sizes def __init__(self, DTYPE_t[:] sizes): self.sizes = sizes self.reset() cdef reset(self): self.current_offset = 0 # offset within current index in underlying dataset self.current_i = 0 # "flat" index self.current_index = 0 # index in underlying dataset @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) cdef int step(self, DTYPE_t i): cdef DTYPE_t to_consume cdef DTYPE_t remaining if i < self.current_i: self.reset() if i > self.current_i: to_consume = i - self.current_i remaining = self.sizes[self.current_index] - self.current_offset if remaining > to_consume: self.current_offset += to_consume self.current_i += to_consume else: assert remaining >= 0 self.current_i += remaining self.current_index += 1 self.current_offset = 0 return 1 return 0 @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) cdef seek(self, DTYPE_t i): cdef int not_done = 1 while not_done == 1: not_done = self.step(i) assert self.current_i == i