# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
    """
    Compute the exclusive cumulative product of ``tensor`` along ``dim``.

    PyTorch provides cumprod, but it has no exclusive mode.
    cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^n x_i]
    exclusive means
    cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i]

    Args:
        tensor: input tensor; it is passed through ``safe_cumprod``, so its
            values must be non-negative.
        dim: dimension to accumulate over. Any dimension that is valid for
            ``tensor`` is accepted, including negative indices.
        eps: small constant forwarded to ``safe_cumprod``.

    Returns:
        A tensor with the same shape as ``tensor``.
    """
    ones_shape = list(tensor.size())
    ones_shape[dim] = 1

    # Prepend a slice of ones along `dim`, take the inclusive product, then
    # drop the last entry so every position only sees its predecessors.
    padded = torch.cat([tensor.new_ones(ones_shape), tensor], dim=dim)
    inclusive = safe_cumprod(padded, dim=dim, eps=eps)

    # narrow() is the dimension-agnostic form of `[..., :-1]` along `dim`.
    return inclusive.narrow(dim, 0, tensor.size(dim))


def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
    """
    Compute cumprod in log space to prevent precision issues.

    cumprod(x)
    = [x1, x1x2, x1x2x3, ....]
    = [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...]
    = exp(cumsum(log(x)))

    ``eps`` is added to every element before taking the log, so the result
    is the cumulative product of ``tensor + eps``.

    Args:
        tensor: input tensor whose elements satisfy ``tensor + eps >= 0``.
        dim: dimension to accumulate over.
        eps: small constant added before the log.

    Returns:
        A tensor with the same shape as ``tensor``.

    Raises:
        RuntimeError: if any element of ``tensor + eps`` is negative.
    """
    shifted = tensor + eps
    if (shifted < 0).any().item():
        raise RuntimeError(
            "Safe cumprod can only take non-negative tensors as input. "
            "Consider use torch.cumprod if you want to calculate negative values."
        )

    return torch.exp(torch.cumsum(torch.log(shifted), dim))


def lengths_to_mask(lengths, max_len: int, dim: int = 0, negative_mask: bool = False):
    """
    Convert a tensor of lengths to a boolean mask.

    For example, lengths = [[2, 3, 4]], max_len = 5
    mask =
       [[1, 1, 1],
        [1, 1, 1],
        [0, 1, 1],
        [0, 0, 1],
        [0, 0, 0]]

    Args:
        lengths: tensor with at most two dimensions. It is flattened in
            row-major order and every element is treated as one sequence
            length.
        max_len: number of positions in the mask.
        dim: if 0, the mask is returned as (max_len, batch_size); for any
            other value it is returned as (batch_size, max_len).
        negative_mask: if True, return the inverted mask (True at positions
            greater than or equal to the length).

    Returns:
        A boolean tensor that is True where the position index is smaller
        than the corresponding length (inverted when ``negative_mask``).
    """
    assert len(lengths.size()) <= 2

    # lengths : batch_size, 1
    # reshape (rather than view) also accepts non-contiguous inputs.
    lengths = lengths.reshape(-1, 1)

    # Build the position indices directly on the device of `lengths`.
    positions = torch.arange(max_len, device=lengths.device).type_as(lengths)

    # Broadcasting (max_len,) against (batch_size, 1) -> batch_size, max_len
    mask = positions < lengths

    if negative_mask:
        mask = ~mask

    if dim == 0:
        # max_len, batch_size
        mask = mask.t()

    return mask


def moving_sum(x, start_idx: int, end_idx: int):
    """
    Windowed sum along the first dimension of ``x``.

    From MONOTONIC CHUNKWISE ATTENTION
    https://arxiv.org/pdf/1712.05382.pdf
    Equation (18)

    x = [x_1, x_2, ..., x_N]
    MovingSum(x, start_idx, end_idx)_n = Sigma_{m=n-(start_idx-1)}^{n+end_idx-1} x_m
    for n in {1, 2, 3, ..., N}

    Positions outside [1, N] contribute zero.

    Args:
        x: tensor of shape (src_len, batch_size).
        start_idx: window extent towards earlier positions, counting the
            current one; must be > 0.
        end_idx: window extent towards later positions, counting the
            current one; must be > 0.

    Returns:
        A tensor of shape (src_len, batch_size).

    Example
    src_len = 5
    batch_size = 3
    x =
       [[ 0, 5, 10],
        [ 1, 6, 11],
        [ 2, 7, 12],
        [ 3, 8, 13],
        [ 4, 9, 14]]

    MovingSum(x, 3, 1) =
       [[ 0,  5, 10],
        [ 1, 11, 21],
        [ 3, 18, 33],
        [ 6, 21, 36],
        [ 9, 24, 39]]

    MovingSum(x, 1, 3) =
       [[ 3, 18, 33],
        [ 6, 21, 36],
        [ 9, 24, 39],
        [ 7, 17, 27],
        [ 4,  9, 14]]
    """
    assert start_idx > 0 and end_idx > 0
    assert len(x.size()) == 2

    src_len, batch_size = x.size()
    window = start_idx + end_idx - 1

    # batch_size, 1, src_len (the layout conv1d expects)
    signal = x.t().unsqueeze(1)
    # 1, 1, window : an all-ones kernel turns the convolution into a sum
    kernel = x.new_ones([1, 1, window])

    # src_len + start_idx + end_idx, batch_size
    summed = (
        torch.nn.functional.conv1d(signal, kernel, padding=window)
        .squeeze(1)
        .t()
    )
    # Trim the extra outputs produced by the zero padding on both sides.
    summed = summed[end_idx:-start_idx]

    assert src_len == summed.size(0)
    assert batch_size == summed.size(1)

    return summed