# fairseq/data/token_block_utils_fast.pyx
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from itertools import chain
from libc.math cimport ceil

cimport cython
cimport numpy as np

from libc.stdint cimport int32_t, int64_t

DTYPE = np.int64
ctypedef int64_t DTYPE_t


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
cdef DTYPE_t total_size = sizes.sum()
cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
cdef DTYPE_t[:, :] slice_indices_view = slice_indices
cdef DTYPE_t i
cdef DTYPE_t start
cdef DTYPE_t end
for i in range(length):
start = i * block_size
end = min(start + block_size, total_size)
        slice_indices_view[i, 0] = start
        slice_indices_view[i, 1] = end
return slice_indices
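

# Illustrative sketch (not part of the original file) of what 'none' mode
# produces: sentence boundaries are ignored and the concatenated token stream
# is cut into fixed-size blocks, the last of which may be shorter. With
# assumed values:
#
#   sizes = np.array([3, 4, 2], dtype=np.int64)        # 9 tokens in total
#   _get_slice_indices_none_mode(sizes, block_size=4)
#   # -> [[0, 4], [4, 8], [8, 9]]
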
cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
"""
Faster function to convert DTYPE_t list of list.
Only fast when there are huge number of rows and low number of columns.
"""
cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1)
return flat.reshape((len(list_of_list), -1))
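

# Illustrative usage (assumed values): every inner sequence must have the
# same length, since the flat array is reshaped to (num_rows, -1); here the
# rows are always (start, end) pairs.
#
#   _fast_convert_to_np_array([(0, 3), (3, 9)])
#   # -> array([[0, 3],
#   #           [3, 9]])
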
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
cdef DTYPE_t tok_idx = 0
cdef DTYPE_t sz_idx = 0
cdef DTYPE_t curr_size = 0
cdef DTYPE_t i = 0
cdef DTYPE_t length
cdef DTYPE_t total_size
cdef DTYPE_t[:] sizes_view = sizes
cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
cdef list slice_indices_list = []
if break_mode is None or break_mode == 'none':
slice_indices = _get_slice_indices_none_mode(sizes, block_size)
    elif break_mode == 'complete':
        # Pack whole sentences into each block without ever splitting one;
        # a sentence longer than block_size gets a block of its own.
while sz_idx < len(sizes_view):
if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0:
curr_size += sizes_view[sz_idx]
sz_idx += 1
else:
slice_indices_list.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if curr_size > 0:
slice_indices_list.append((tok_idx, tok_idx + curr_size))
slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'complete_doc':
        # Like 'complete', but blocks never cross document boundaries.
while sz_idx < len(sizes_view):
if (
(curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0)
# an empty sentence indicates end-of-document:
and sizes_view[sz_idx] != document_sep_len
):
curr_size += sizes_view[sz_idx]
sz_idx += 1
else:
# Only keep non-empty documents.
if curr_size > 1:
slice_indices_list.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if sizes_view[sz_idx] == document_sep_len:
tok_idx += sizes_view[sz_idx]
sz_idx += 1
if curr_size > 1:
slice_indices_list.append((tok_idx, tok_idx + curr_size))
slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'eos':
        # One block per sentence.
slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE)
cumsum = sizes.cumsum(axis=0)
slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1]
slice_indices[:, 1] = cumsum
else:
raise ValueError('Invalid break_mode: ' + break_mode)
return slice_indices
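

# Sketch of the supported break modes (assumed values sizes = [3, 4, 2] and
# block_size = 6, worked through the logic above):
#
#   'none'     -> [[0, 6], [6, 9]]            fixed-size blocks; may split
#                                             sentences
#   'complete' -> [[0, 3], [3, 9]]            whole sentences only; 3 + 4
#                                             overflows, but 4 + 2 fits
#   'eos'      -> [[0, 3], [3, 7], [7, 9]]    one block per sentence
#
# 'complete_doc' behaves like 'complete' but additionally never crosses an
# end-of-document marker (a sentence whose size equals document_sep_len).
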
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices):
cdef DTYPE_t start_ds_idx
cdef DTYPE_t start_offset
cdef DTYPE_t end_ds_idx
cdef DTYPE_t i
cdef DTYPE_t s
cdef DTYPE_t e
cdef DatasetSearcher ds = DatasetSearcher(sizes)
cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE)
cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index
cdef DTYPE_t[:, :] slice_indices_view = slice_indices
cdef Py_ssize_t x_max = slice_indices.shape[0]
for i in range(x_max):
        s = slice_indices_view[i, 0]
        e = slice_indices_view[i, 1]
ds.seek(s)
start_ds_idx = ds.current_index
start_offset = ds.current_offset
if e <= s:
end_ds_idx = start_ds_idx
else:
ds.seek(e - 1)
end_ds_idx = ds.current_index
        block_to_dataset_index_view[i, 0] = start_ds_idx  # starting index in dataset
        block_to_dataset_index_view[i, 1] = start_offset  # starting offset within starting index
        block_to_dataset_index_view[i, 2] = end_ds_idx  # ending index in dataset
return block_to_dataset_index
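

# Illustrative sketch (assumed values): with sizes = [3, 4, 2], the token
# slice [4, 8) starts inside sentence 1 at offset 1 and ends inside
# sentence 2, so its row is (start_ds_idx=1, start_offset=1, end_ds_idx=2):
#
#   sizes = np.array([3, 4, 2], dtype=np.int64)
#   slices = np.array([[4, 8]], dtype=np.int64)
#   _get_block_to_dataset_index_fast(sizes, slices)
#   # -> [[1, 1, 2]]
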
cdef class DatasetSearcher(object):
"""Helper for mapping "flat" indices to indices and offsets in an
underlying dataset."""
cdef DTYPE_t current_i
cdef DTYPE_t current_offset
cdef DTYPE_t current_index
cdef DTYPE_t[:] sizes
def __init__(self, DTYPE_t[:] sizes):
self.sizes = sizes
self.reset()
cdef reset(self):
self.current_offset = 0 # offset within current index in underlying dataset
self.current_i = 0 # "flat" index
self.current_index = 0 # index in underlying dataset
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
    cdef int step(self, DTYPE_t i):
        """Advance towards flat index *i*. Returns 1 if another step is
        still needed, 0 once *i* has been reached."""
        cdef DTYPE_t to_consume
        cdef DTYPE_t remaining
        if i < self.current_i:
            # We can only scan forwards; restart from the beginning.
            self.reset()
        if i > self.current_i:
            to_consume = i - self.current_i
            remaining = self.sizes[self.current_index] - self.current_offset
            if remaining > to_consume:
                # Target lies within the current dataset item.
                self.current_offset += to_consume
                self.current_i += to_consume
            else:
                # Exhaust the current item and advance to the next one.
                assert remaining >= 0
                self.current_i += remaining
                self.current_index += 1
                self.current_offset = 0
                return 1
        return 0
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
    cdef seek(self, DTYPE_t i):
        """Position the searcher at flat index *i*."""
        cdef int not_done = 1
        while not_done == 1:
            not_done = self.step(i)
        assert self.current_i == i
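

# Sketch of the seek semantics (assumed values; the cdef methods are only
# callable from Cython code within this module):
#
#   ds = DatasetSearcher(np.array([3, 4, 2], dtype=np.int64))
#   ds.seek(5)   # flat position 5 is sentence 1, offset 2
#   # ds.current_index == 1 and ds.current_offset == 2
#   ds.seek(2)   # seeking backwards resets and rescans from position 0
#   # ds.current_index == 0 and ds.current_offset == 2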