from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch


class TextTokenCollater:
    """Collate lists of text tokens.

    Maps sentences to integer token IDs. Sentences are padded to equal length,
    and beginning- and end-of-sequence symbols can be added.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(text)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <bos> and <eos>
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Special symbols come first so that <pad> always maps to index 0.
        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = list(unique_tokens)

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            # Every token must already be present in the vocabulary.
            assert all(token in self.token2idx for token in tokens)
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        # Pad every sequence to the length of the longest one in the batch.
        max_len = max(seq_lens)
        for seq, seq_len in zip(seqs, seq_lens):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Split each input into individual tokens and pad to the longest one.
        tokens_seqs = [[p for p in text] for text in texts]
        max_len = len(max(tokens_seqs, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in tokens_seqs
        ]

        # Unlike `index`, the entries are cast to int64 directly (no token2idx
        # lookup), so they must be convertible to int, e.g. numeric token IDs.
        tokens_batch = torch.from_numpy(
            np.array(
                seqs,
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in tokens_seqs
            ]
        )

        return tokens_batch, tokens_lens


def get_text_token_collater() -> TextTokenCollater:
    # Collater with a single placeholder vocabulary entry and no <bos>/<eos>;
    # its __call__ path passes already-numeric token IDs through unchanged.
    collater = TextTokenCollater(
        ["0"], add_bos=False, add_eos=False
    )
    return collater
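

# Usage sketch: a minimal, self-contained demo of both entry points. The toy
# character vocabulary and the integer-ID inputs below are illustrative
# assumptions for this demo, not values taken from the rest of the codebase.
if __name__ == "__main__":
    # Character-level collater built from an explicit vocabulary.
    demo_collater = TextTokenCollater(
        sorted(set("helloworld")), add_bos=True, add_eos=True
    )
    batch, lens = demo_collater.index(["hello", "world"])
    print(batch.shape, lens)  # torch.Size([2, 7]) tensor([7, 7], dtype=torch.int32)

    # The factory collater casts sequences directly to int64, so we feed it
    # lists of integer token IDs (its type hint says List[str], but any
    # int-convertible sequence works at runtime).
    id_collater = get_text_token_collater()
    id_batch, id_lens = id_collater([[5, 2, 9], [1, 4, 7]])
    print(id_batch.shape, id_lens)  # torch.Size([2, 3]) tensor([3, 3], dtype=torch.int32)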