# -*- coding: utf-8 -*-
import torch

from polos.tokenizers_ import TextEncoderBase


def average_pooling(
    tokens: torch.Tensor,
    embeddings: torch.Tensor,
    mask: torch.Tensor,
    padding_index: int,
) -> torch.Tensor:
    """Average pooling function.

    :param tokens: Word ids [batch_size x seq_length]
    :param embeddings: Word embeddings [batch_size x seq_length x hidden_size]
    :param mask: Padding mask [batch_size x seq_length]
    :param padding_index: Padding value.
    """
    wordemb = mask_fill(0.0, tokens, embeddings, padding_index)
    sentemb = torch.sum(wordemb, 1)
    sum_mask = mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1)
    return sentemb / sum_mask

def max_pooling(
    tokens: torch.Tensor, embeddings: torch.Tensor, padding_index: int
) -> torch.Tensor:
    """Max pooling function.

    :param tokens: Word ids [batch_size x seq_length]
    :param embeddings: Word embeddings [batch_size x seq_length x hidden_size]
    :param padding_index: Padding value.
    """
    return mask_fill(float("-inf"), tokens, embeddings, padding_index).max(dim=1)[0]
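
# Illustrative usage of the two pooling helpers (a minimal sketch; the tensor
# values and the padding index of 0 are assumptions, not taken from this file):
#
#     tokens = torch.tensor([[5, 6, 0], [7, 0, 0]])   # 0 = assumed padding id
#     embeddings = torch.randn(2, 3, 4)               # [batch x seq_len x hidden]
#     mask = tokens.ne(0).float()                     # 1.0 for real tokens, 0.0 for pads
#     sent_avg = average_pooling(tokens, embeddings, mask, padding_index=0)
#     sent_max = max_pooling(tokens, embeddings, padding_index=0)
#     # both results have shape [2, 4]: one sentence embedding per example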

def mask_fill(
    fill_value: float,
    tokens: torch.Tensor,
    embeddings: torch.Tensor,
    padding_index: int,
) -> torch.Tensor:
    """Function that masks embeddings representing padded elements.

    :param fill_value: the value to fill the embeddings belonging to padded tokens.
    :param tokens: The input sequences [bsz x seq_len].
    :param embeddings: Word embeddings [bsz x seq_len x hidden_size].
    :param padding_index: Index of the padding token.
    """
    padding_mask = tokens.eq(padding_index).unsqueeze(-1)
    return embeddings.float().masked_fill_(padding_mask, fill_value).type_as(embeddings)
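
# Illustrative usage (a minimal sketch; the values and padding index 0 are
# assumptions):
#
#     tokens = torch.tensor([[3, 4, 0]])              # last position is padding
#     embeddings = torch.ones(1, 3, 2)
#     masked = mask_fill(0.0, tokens, embeddings, padding_index=0)
#     # masked[0, 2] is now all zeros; the non-padded positions keep their values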

def sort_sequences(inputs: torch.Tensor, input_lengths: torch.Tensor):
    """Sort sequences by input length in descending order.

    :param inputs (Tensor): input sequences, size [B, T, D]
    :param input_lengths (Tensor): length of each sequence, size [B]
    """
    lengths_sorted, sorted_idx = input_lengths.sort(descending=True)
    _, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx
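
# Illustrative usage (a minimal sketch; feeding the sorted batch to
# torch.nn.utils.rnn.pack_padded_sequence is an assumed downstream use,
# not something this file does):
#
#     from torch.nn.utils.rnn import pack_padded_sequence
#     inputs = torch.randn(4, 10, 8)                  # [B, T, D]
#     lengths = torch.tensor([10, 3, 7, 5])
#     sorted_inputs, sorted_lengths, unsorted_idx = sort_sequences(inputs, lengths)
#     packed = pack_padded_sequence(sorted_inputs, sorted_lengths, batch_first=True)
#     # ... run an RNN over `packed`, then restore the original batch order of
#     # its output with: output = output[unsorted_idx]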

def apply_to_sample(f, sample):
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        else:
            return x

    return _apply(sample)

def move_to_cuda(sample):
    """Moves a sample to cuda. Works with dictionaries, tensors and lists."""

    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)

def move_to_cpu(sample):
    """Moves a sample to cpu. Works with dictionaries, tensors and lists."""

    def _move_to_cpu(tensor):
        return tensor.cpu()

    return apply_to_sample(_move_to_cpu, sample)
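
# Illustrative usage (a minimal sketch; the nested batch structure is hypothetical):
#
#     batch = {
#         "tokens": torch.zeros(2, 5, dtype=torch.long),
#         "extras": [torch.ones(2), {"lengths": torch.tensor([5, 3])}],
#     }
#     if torch.cuda.is_available():
#         batch = move_to_cuda(batch)   # every tensor in the nested structure moves to GPU
#     batch = move_to_cpu(batch)        # and back to CPU, preserving the structure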

# --------------- LASER auxiliary functions from facebook research ------------------------------
def buffered_arange(max):
    if not hasattr(buffered_arange, "buf"):
        buffered_arange.buf = torch.LongTensor()
    if max > buffered_arange.buf.numel():
        torch.arange(max, out=buffered_arange.buf)
    return buffered_arange.buf[:max]

def convert_padding_direction(
    src_tokens, padding_idx, right_to_left=False, left_to_right=False
):
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    range = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(range - num_pads, max_len)
    else:
        index = torch.remainder(range + num_pads, max_len)
    return src_tokens.gather(1, index)
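
# Illustrative usage (a minimal sketch; padding_idx=0 and the token values are
# assumptions):
#
#     left_padded = torch.tensor([[0, 0, 4, 5],
#                                 [0, 6, 7, 8]])
#     right_padded = convert_padding_direction(left_padded, padding_idx=0, left_to_right=True)
#     # right_padded -> tensor([[4, 5, 0, 0],
#     #                         [6, 7, 8, 0]])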