|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
from itertools import chain |
|
from libc.math cimport ceil |
|
|
|
cimport cython |
|
cimport numpy as np |
|
|
|
from libc.stdint cimport int32_t, int64_t |
|
|
|
# NumPy dtype passed as ``dtype=`` when this module allocates its index arrays.
DTYPE = np.int64 |
|
# C-level 64-bit integer type used in the typed buffer and memoryview
# declarations below; intended to match DTYPE.
ctypedef int64_t DTYPE_t |
|
|
|
|
|
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
    """Cut the concatenated token stream into consecutive fixed-size blocks.

    Item boundaries in ``sizes`` are ignored; only the sum of ``sizes``
    matters. Every block covers ``block_size`` tokens except possibly the
    last one, whose end is clipped to the total.

    Args:
        sizes: 1-D int64 array of per-item token counts.
        block_size: number of tokens per block.

    Returns:
        An int64 array with ``ceil(total / block_size)`` rows and two
        columns: the start offset and the (exclusive) end offset of each
        block within the concatenated stream.
    """
    cdef DTYPE_t total_size = sizes.sum()
    cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef DTYPE_t i
    cdef DTYPE_t start
    cdef DTYPE_t end
    for i in range(length):
        start = i * block_size
        # The final block may be shorter than block_size.
        end = min(start + block_size, total_size)
        # Index both axes at once; ``view[i][0]`` would create a temporary
        # 1-D slice of the memoryview on every access.
        slice_indices_view[i, 0] = start
        slice_indices_view[i, 1] = end
    return slice_indices
|
|
|
|
|
cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
    """
    Turn a list of rows of DTYPE_t values into a 2-D array.

    All rows are streamed through one flat iterator into a 1-D array, which
    is then reshaped to ``len(list_of_list)`` rows; the column count is left
    for ``reshape`` to infer. This is only faster than a plain conversion
    when there are very many rows and few columns.
    """
    cdef np.ndarray[DTYPE_t, ndim=1] values
    # count=-1 lets fromiter consume the iterator until it is exhausted.
    values = np.fromiter(chain.from_iterable(list_of_list), dtype=DTYPE, count=-1)
    return values.reshape((len(list_of_list), -1))
|
|
|
|
|
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
    """Compute (start, end) token offsets of the blocks for a break mode.

    Args:
        sizes: 1-D int64 array of per-item token counts.
        break_mode: how the token stream is split into blocks:
            * ``None`` / ``'none'``: fixed-size blocks that ignore item
              boundaries.
            * ``'complete'``: whole items are packed greedily while the
              running total stays within ``block_size``; an item that is
              larger than ``block_size`` gets a block of its own.
            * ``'complete_doc'``: like ``'complete'``, but an item whose
              size equals ``document_sep_len`` closes the current block and
              is skipped itself; blocks of at most one token are dropped.
            * ``'eos'``: one block per item.
        block_size: maximum number of tokens per block (unused by ``'eos'``).
        document_sep_len: item size that marks a document separator
            (only used by ``'complete_doc'``).

    Returns:
        An int64 array of shape ``(num_blocks, 2)``; each row holds the start
        offset and the exclusive end offset of one block. When a packing mode
        produces no block at all, an empty ``(0, 2)`` array is returned.

    Raises:
        ValueError: if ``break_mode`` is not one of the modes above.
    """
    cdef DTYPE_t tok_idx = 0
    cdef DTYPE_t sz_idx = 0
    cdef DTYPE_t curr_size = 0
    cdef DTYPE_t item_size
    cdef DTYPE_t[:] sizes_view = sizes
    cdef DTYPE_t num_items = sizes_view.shape[0]
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
    cdef list slice_indices_list = []

    if break_mode is None or break_mode == 'none':
        return _get_slice_indices_none_mode(sizes, block_size)
    elif break_mode == 'complete':
        while sz_idx < num_items:
            item_size = sizes_view[sz_idx]
            # ``curr_size == 0`` lets an oversized item open a block anyway,
            # which guarantees progress.
            if curr_size + item_size <= block_size or curr_size == 0:
                curr_size += item_size
                sz_idx += 1
            else:
                slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
        if curr_size > 0:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
    elif break_mode == 'complete_doc':
        while sz_idx < num_items:
            item_size = sizes_view[sz_idx]
            if (
                (curr_size + item_size <= block_size or curr_size == 0)
                # an item of separator length marks the end of a document
                and item_size != document_sep_len
            ):
                curr_size += item_size
                sz_idx += 1
            else:
                # Only keep blocks that hold more than a single token.
                if curr_size > 1:
                    slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
                if item_size == document_sep_len:
                    # Skip over the separator itself.
                    tok_idx += item_size
                    sz_idx += 1
        if curr_size > 1:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
    elif break_mode == 'eos':
        slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE)
        cumsum = sizes.cumsum(axis=0)
        # Each block starts where the previous one ended; the first starts at 0.
        slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1]
        slice_indices[:, 1] = cumsum
        return slice_indices
    else:
        raise ValueError('Invalid break_mode: ' + break_mode)

    # Only the packing modes reach this point. With no rows the column count
    # cannot be inferred by reshape, so build the empty result explicitly.
    if not slice_indices_list:
        return np.zeros((0, 2), dtype=DTYPE)
    return _fast_convert_to_np_array(slice_indices_list)
|
|
|
|
|
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices):
    """Map each block's flat token range onto items of the underlying dataset.

    Args:
        sizes: 1-D int64 array of per-item token counts.
        slice_indices: int64 array of shape ``(num_blocks, 2)`` holding the
            start and exclusive end token offset of every block.

    Returns:
        An int64 array of shape ``(num_blocks, 3)``. For each block the
        columns are: the index of the item containing the start token, the
        offset of the start token inside that item, and the index of the item
        containing the last token (``end - 1``). For an empty block
        (``end <= start``) the last column repeats the start item index.
    """
    cdef DTYPE_t start_ds_idx
    cdef DTYPE_t start_offset
    cdef DTYPE_t end_ds_idx
    cdef DTYPE_t i
    cdef DTYPE_t s
    cdef DTYPE_t e
    cdef DatasetSearcher ds = DatasetSearcher(sizes)
    cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE)
    cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef Py_ssize_t x_max = slice_indices.shape[0]

    for i in range(x_max):
        # Index both axes at once; ``view[i][0]`` would create a temporary
        # 1-D slice of the memoryview on every access.
        s = slice_indices_view[i, 0]
        e = slice_indices_view[i, 1]
        ds.seek(s)
        start_ds_idx = ds.current_index
        start_offset = ds.current_offset
        if e <= s:
            # Empty block: there is no last token to look up.
            end_ds_idx = start_ds_idx
        else:
            # ``e`` is exclusive, so the last token of the block is ``e - 1``.
            ds.seek(e - 1)
            end_ds_idx = ds.current_index
        block_to_dataset_index_view[i, 0] = start_ds_idx
        block_to_dataset_index_view[i, 1] = start_offset
        block_to_dataset_index_view[i, 2] = end_ds_idx
    return block_to_dataset_index
|
|
|
|
|
cdef class DatasetSearcher(object):
    """Helper for mapping "flat" indices to indices and offsets in an
    underlying dataset.

    The searcher keeps a cursor into the concatenation of all items. Seeking
    forward walks on from the current cursor; seeking backward restarts from
    the beginning.
    """
    # Flat token position the cursor currently points at.
    cdef DTYPE_t current_i
    # Offset of the cursor inside the item ``current_index``.
    cdef DTYPE_t current_offset
    # Index into ``sizes`` of the item the cursor is in.
    cdef DTYPE_t current_index
    # Per-item token counts.
    cdef DTYPE_t[:] sizes

    def __init__(self, DTYPE_t[:] sizes):
        self.sizes = sizes
        self.reset()

    cdef reset(self):
        """Move the cursor back to the very first token."""
        self.current_offset = 0
        self.current_i = 0
        self.current_index = 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef int step(self, DTYPE_t i) except -1:
        """Advance the cursor towards flat index ``i`` by at most one item.

        Returns 1 when the cursor crossed into the next item and another call
        is needed, and 0 once no further stepping is possible. ``except -1``
        makes exceptions raised here (the assertion, the range check)
        propagate to the caller instead of being swallowed.

        Raises:
            IndexError: if ``i`` lies beyond the last token of the dataset.
        """
        cdef DTYPE_t to_consume
        cdef DTYPE_t remaining
        if i < self.current_i:
            # The cursor only moves forward, so restart from the beginning.
            self.reset()
        if i > self.current_i:
            # Bounds checking is disabled for this method, so guard the read
            # of ``sizes`` explicitly rather than reading past its end.
            if self.current_index >= self.sizes.shape[0]:
                raise IndexError('flat index is past the end of the dataset')
            to_consume = i - self.current_i
            remaining = self.sizes[self.current_index] - self.current_offset
            if remaining > to_consume:
                # Target lies inside the current item.
                self.current_offset += to_consume
                self.current_i += to_consume
            else:
                assert remaining >= 0
                # Consume the rest of this item and move on to the next one.
                self.current_i += remaining
                self.current_index += 1
                self.current_offset = 0
                return 1
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef seek(self, DTYPE_t i):
        """Position the cursor at flat index ``i``.

        Afterwards ``current_index`` and ``current_offset`` describe where
        ``i`` falls in the underlying items.
        """
        cdef int not_done = 1
        while not_done == 1:
            not_done = self.step(i)
        assert self.current_i == i
|
|