|
import json
import os
import unittest
from collections import Counter
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Sized, Tuple, Union

import numpy as np
import torch
from tokenizers import AddedToken
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import (
    BatchEncoding,
    EncodedInput,
    TruncationStrategy,
)
from transformers.utils import logging
from transformers.utils.generic import (
    PaddingStrategy,
    TensorType,
    is_tf_tensor,
    is_torch_tensor,
    to_py_obj,
)

from .ngme import ngrams as ngram_tokenizer
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
def load_vocab(vocab_file): |
|
"""Loads a vocabulary file into a dictionary.""" |
|
with open(vocab_file, "r", encoding="utf-8") as f: |
|
vocab = json.load(f) |
|
return vocab |
|
|
|
|
|
def all_same(items): |
|
return all(x == items[0] for x in items) |
|
|
|
|
|
class NGMETokenizer(PreTrainedTokenizer): |
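    """Character-level NGME tokenizer that exposes extra n-gram id sequences as model inputs."""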
|
model_input_names = ["input_ids", "attention_mask"] |
|
vocab_file = "vocab.json" |
|
vocab_files_names = {"vocab_file": vocab_file} |
|
|
|
    def __init__(
        self,
        vocab_file,
        eos_token="\n",
        pad_token="\n",
        unk_token="<unk>",
        eod_token="<eod>",
        **kwargs,
    ):
        eos_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        pad_token = (
            AddedToken(pad_token, lstrip=False, rstrip=False)
            if isinstance(pad_token, str)
            else pad_token
        )
        unk_token = (
            AddedToken(unk_token, lstrip=False, rstrip=False)
            if isinstance(unk_token, str)
            else unk_token
        )

        super().__init__(
            eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, **kwargs
        )

        self._ngram2word2idx = {}
        self._ngram2idx2word = {}
        self._current_max_idx = 0
        self._frequencies: Counter = Counter()

        self._load_from_file(vocab_file)

        # Expose one additional model input per higher n-gram order.
        for n in range(2, self.ngram + 1):
            self.model_input_names.append(f"ngram_{n}_sequence")

        # Marker character reserved by the tokenizer; it must not clash with a unigram.
        self._special_token = "Ġ"
        assert self._special_token not in self._ngram2word2idx[1]
|
def __call__(self, *args, **kwargs) -> BatchEncoding: |
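        """Tokenize like the base class and, when ``return_ngram_sequences=True`` is
        passed, additionally build, pad/truncate, and attach the higher-order n-gram
        id sequences produced by ``create_ngram_sequences``."""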
|
if "return_ngram_sequences" in kwargs: |
|
return_ngram_sequences = kwargs["return_ngram_sequences"] |
|
del kwargs["return_ngram_sequences"] |
|
else: |
|
return_ngram_sequences = False |
|
|
|
|
|
|
|
batch_encoding = super().__call__(*args, **kwargs) |
|
|
|
if return_ngram_sequences: |
|
ngram_sequences = self.create_ngram_sequences(args[0]) |
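            # The n-gram sequences are built from the raw text, so they are padded and
            # truncated here to line up with the base `input_ids`.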
|
|
|
|
|
            padding = kwargs.get("padding", False)
            if padding and padding != "do_not_pad":
                padded_sequences = {}
                if padding == "max_length":
                    for n_key, sequence in ngram_sequences.items():
                        padded_sequences[n_key] = self.pad_sequence_right(
                            sequence,
                            len(batch_encoding["input_ids"][0]),
                            self.pad_token_id,
                        )
                elif padding is True or padding == "longest":
                    for n_key, sequence in ngram_sequences.items():
                        padded_sequences[n_key] = self.pad_sequence_right(
                            sequence,
                            max(len(seq) for seq in sequence),
                            self.pad_token_id,
                        )
                else:
                    raise ValueError(
                        f"Padding {padding} not supported for ngram sequences"
                    )
                ngram_sequences = padded_sequences
|
|
|
if "truncation" in kwargs and kwargs["truncation"]: |
|
truncated_sequences = {} |
|
for n_key, sequence in ngram_sequences.items(): |
|
truncated_sequences[n_key] = self.truncate_sequence_right( |
|
sequence, len(batch_encoding["input_ids"][0]) |
|
) |
|
ngram_sequences = truncated_sequences |
|
|
|
batch_encoding.update(ngram_sequences) |
|
|
|
if "return_tensors" in kwargs: |
|
batch_encoding.convert_to_tensors(kwargs["return_tensors"]) |
|
|
|
return batch_encoding |
|
|
|
def pad_sequence_right( |
|
self, batched_sequence: List[List[int]], padding_length: int, padding_value: int |
|
) -> List[List[int]]: |
|
padded_sequence = [] |
|
for sequence in batched_sequence: |
|
padded_sequence.append( |
|
sequence + [padding_value] * (padding_length - len(sequence)) |
|
) |
|
return padded_sequence |
|
|
|
def truncate_sequence_right( |
|
self, batched_sequence: List[List[int]], max_length: int |
|
) -> List[List[int]]: |
|
truncated_sequence = [] |
|
for sequence in batched_sequence: |
|
truncated_sequence.append(sequence[:max_length]) |
|
return truncated_sequence |
|
|
|
def create_ngram_sequences(self, char_sequences: List[str]) -> Dict[str, Any]: |
|
ngram_sequences_output = {} |
|
|
|
if isinstance(char_sequences, str): |
|
char_sequences = [char_sequences] |
|
|
|
for n in range(2, self.ngram + 1): |
|
ngram_sequences = [] |
|
for char_sequence in char_sequences: |
|
ngrams = ["".join(ngram) for ngram in ngram_tokenizer(char_sequence, n)] |
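                # Prepend the first n-1 characters as unigram tokens so the n-gram
                # stream has the same length as the character-level sequence.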
|
|
|
|
|
ngrams = list(char_sequence[: n - 1]) + ngrams |
|
encoded_ngrams = self.encode(ngrams) if len(ngrams) > 0 else [] |
|
ngram_sequences.append(encoded_ngrams) |
|
|
|
ngram_sequences_output[f"label_gram_{n}_sequence"] = ngram_sequences |
|
|
|
return ngram_sequences_output |
|
|
|
def _seq_size(self, encoded) -> Union[int, List[int]]: |
|
if isinstance(encoded, torch.Tensor): |
|
encoded = encoded.tolist() |
|
|
|
if isinstance(encoded[0], list): |
|
return [len(enc) for enc in encoded] |
|
|
|
return len(encoded) |
|
|
|
def _load_from_file(self, filename: str): |
|
"""Loads a dictionary from a file.""" |
|
vocab_file = load_vocab(filename) |
|
self.ngram = vocab_file["ngram"] |
|
|
|
        # Guarantee the newline token (used as eos/pad) is present in the unigram vocab.
        if "\n" not in [entry["token"] for entry in vocab_file["vocab"]]:
            self._add_ngram("\n", 1)
|
|
|
for token in vocab_file["vocab"]: |
|
self._add_ngram(token["token"], token["ngram"]) |
|
self._frequencies.update({token["token"]: token["frequency"]}) |
|
|
|
def _add_ngram(self, word, ngram: int) -> int: |
|
"""Add a new n-gram token to the dictionary.""" |
|
self._frequencies.update({word: 1}) |
|
|
|
if ngram not in self._ngram2idx2word: |
|
self._ngram2idx2word[ngram] = {self._current_max_idx: word} |
|
self._ngram2word2idx[ngram] = {word: self._current_max_idx} |
|
self._current_max_idx += 1 |
|
else: |
|
if word not in self._ngram2word2idx[ngram]: |
|
self._ngram2idx2word[ngram][self._current_max_idx] = word |
|
self._ngram2word2idx[ngram][word] = self._current_max_idx |
|
self._current_max_idx += 1 |
|
|
|
return self._ngram2word2idx[ngram][word] |
|
|
|
def _is_contiguous(self): |
|
vocab_size = len(self) |
|
return list(range(vocab_size)) == [idx for idx, token in self._get_all_tokens()] |
|
|
|
def _get_all_tokens(self): |
|
"""Returns all tokens in the dictionary.""" |
|
for ngram in range(1, self.ngram + 1): |
|
for idx, token in self._ngram2idx2word[ngram].items(): |
|
yield idx, token |
|
|
|
def save_vocabulary( |
|
self, save_directory: str, filename_prefix: Optional[str] = None |
|
) -> Tuple[str]: |
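        """Write the vocabulary to `[filename_prefix-]vocab.json` using the same JSON
        layout (top-level "ngram" plus a "vocab" list of entries) that
        `_load_from_file` reads back."""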
|
        filename = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + self.vocab_file,
        )
|
|
|
index = 0 |
|
vocab = {"ngram": self.ngram, "vocab": []} |
|
|
|
for ngram in range(1, self.ngram + 1): |
|
for idx, token in self._ngram2idx2word[ngram].items(): |
|
if index != idx: |
|
index = idx |
|
|
|
try: |
|
frequency = self._frequencies[token] |
|
except KeyError: |
|
frequency = -1 |
|
|
|
index += 1 |
|
vocab["vocab"].append( |
|
{ |
|
"token": token, |
|
"index": idx, |
|
"frequency": frequency, |
|
"ngram": ngram, |
|
} |
|
) |
|
|
|
with open(filename, "w", encoding="utf-8") as writer: |
|
json.dump(vocab, writer, indent=4, ensure_ascii=False) |
|
|
|
return (filename,) |
|
|
|
@property |
|
def vocab_size(self) -> int: |
|
return self._current_max_idx |
|
|
|
def _tokenize(self, text: str) -> List[str]: |
|
return list(text) |
|
|
|
def get_idx(self, token: str, ngram: Optional[int] = None) -> int: |
|
if ngram: |
|
if token in self._ngram2word2idx[ngram]: |
|
return self._ngram2word2idx[ngram][token] |
|
else: |
|
return self._ngram2word2idx[1]["<unk>"] |
|
|
|
for ngram in range(1, self.ngram + 1): |
|
if token in self._ngram2word2idx[ngram]: |
|
return self._ngram2word2idx[ngram][token] |
|
|
|
return self._ngram2word2idx[1]["<unk>"] |
|
|
|
def _convert_ngram_tokens_to_ids(self, ngram_tokens: List[str]) -> List[int]: |
|
return [self.get_idx(token) for token in ngram_tokens] |
|
|
|
def convert_tokens_to_ids(self, tokens: List[str]): |
|
if not tokens: |
|
return [] |
|
|
|
if isinstance(tokens, str): |
|
return self.get_idx(tokens) |
|
|
|
return self._convert_ngram_tokens_to_ids(tokens) |
|
|
|
def _convert_id_to_token(self, index: int) -> str: |
|
return self.get_item_for_index(index) |
|
|
|
def get_item_for_index(self, idx) -> str: |
|
"""Return the token for a given index.""" |
|
for idxs in self._ngram2idx2word.values(): |
|
if idx in idxs: |
|
return idxs[idx] |
|
|
|
return self.unk_token |
|
|
|
def convert_tokens_to_string(self, tokens): |
|
return "".join(tokens) |
|
|
|
def create_weight_tensor(self) -> torch.Tensor: |
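        """Build per-token loss weights: frequent tokens are down-weighted via
        1 - freq / (max_freq + 1), and <unk>/<start> marker tokens get weight 0."""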
|
unked_freqs = self._frequencies.most_common() |
|
|
|
t = torch.ones(len(self)) |
|
|
|
for token, freq in unked_freqs: |
|
t[self._ngram2word2idx[self._token_to_n_order(token)][token]] = freq |
|
|
|
|
|
t[self._ngram2word2idx[1][" "]] = 1.0 |
|
|
|
max_t = max(t) |
|
|
|
normed_weights = torch.tensor([(1 - (x / (max_t + 1))).item() for x in t]) |
|
|
|
marker_tokens = [self.get_idx("<unk>", n) for n in range(1, self.ngram + 1)] |
|
marker_tokens.extend( |
|
[self.get_idx("<start>", n) for n in range(1, self.ngram + 1)] |
|
) |
|
|
|
for marker in marker_tokens: |
|
normed_weights[marker] = 0 |
|
|
|
return normed_weights |
|
|
|
def _token_to_n_order(self, token: str) -> int: |
|
"""Get N-gram order for a token""" |
|
for n_gram, word2idx in self._ngram2word2idx.items(): |
|
if token in word2idx: |
|
return n_gram |
|
|
|
return 0 |
|
|
|
|
|
class GPTNGMETokenizer(PreTrainedTokenizer): |
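    """NGME tokenizer variant whose `_tokenize` produces one aligned token stream per n-gram order."""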
|
model_input_names = ["input_ids", "attention_mask"] |
|
vocab_file = "vocab.json" |
|
vocab_files_names = {"vocab_file": vocab_file} |
|
|
|
def __init__( |
|
self, vocab_file, eos_token="\n", pad_token="\n", unk_token="<unk>", **kwargs |
|
): |
|
        eos_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        pad_token = (
            AddedToken(pad_token, lstrip=False, rstrip=False)
            if isinstance(pad_token, str)
            else pad_token
        )
        unk_token = (
            AddedToken(unk_token, lstrip=False, rstrip=False)
            if isinstance(unk_token, str)
            else unk_token
        )
|
|
|
super().__init__( |
|
eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, **kwargs |
|
) |
|
|
|
self._ngram2word2idx = {} |
|
self._ngram2idx2word = {} |
|
self._current_max_idx = 0 |
|
self._frequencies: Counter = Counter() |
|
|
|
self._load_from_file(vocab_file) |
|
|
|
def _load_from_file(self, filename: str): |
|
"""Loads a dictionary from a file.""" |
|
vocab_file = load_vocab(filename) |
|
self.ngram = vocab_file["ngram"] |
|
|
|
        # Guarantee the newline token (used as eos/pad) is present in the unigram vocab.
        if "\n" not in [entry["token"] for entry in vocab_file["vocab"]]:
            self._add_ngram("\n", 1)
|
|
|
for token in vocab_file["vocab"]: |
|
self._add_ngram(token["token"], token["ngram"]) |
|
self._frequencies.update({token["token"]: token["frequency"]}) |
|
|
|
def _add_ngram(self, word, ngram: int) -> int: |
|
"""Add a new n-gram token to the dictionary.""" |
|
self._frequencies.update({word: 1}) |
|
|
|
if ngram not in self._ngram2idx2word: |
|
self._ngram2idx2word[ngram] = {self._current_max_idx: word} |
|
self._ngram2word2idx[ngram] = {word: self._current_max_idx} |
|
self._current_max_idx += 1 |
|
else: |
|
if word not in self._ngram2word2idx[ngram]: |
|
self._ngram2idx2word[ngram][self._current_max_idx] = word |
|
self._ngram2word2idx[ngram][word] = self._current_max_idx |
|
self._current_max_idx += 1 |
|
|
|
return self._ngram2word2idx[ngram][word] |
|
|
|
def _is_contiguous(self): |
|
vocab_size = len(self) |
|
return list(range(vocab_size)) == [idx for idx, token in self._get_all_tokens()] |
|
|
|
def _get_all_tokens(self): |
|
"""Returns all tokens in the dictionary.""" |
|
for ngram in range(1, self.ngram + 1): |
|
for idx, token in self._ngram2idx2word[ngram].items(): |
|
yield idx, token |
|
|
|
def save_vocabulary( |
|
self, save_directory: str, filename_prefix: Optional[str] = None |
|
) -> Tuple[str]: |
|
        filename = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + self.vocab_file,
        )
|
|
|
index = 0 |
|
vocab = {"ngram": self.ngram, "vocab": []} |
|
|
|
for ngram in range(1, self.ngram + 1): |
|
for idx, token in self._ngram2idx2word[ngram].items(): |
|
if index != idx: |
|
index = idx |
|
|
|
try: |
|
frequency = self._frequencies[token] |
|
except KeyError: |
|
frequency = -1 |
|
|
|
index += 1 |
|
vocab["vocab"].append( |
|
{ |
|
"token": token, |
|
"index": idx, |
|
"frequency": frequency, |
|
"ngram": ngram, |
|
} |
|
) |
|
|
|
with open(filename, "w", encoding="utf-8") as writer: |
|
json.dump(vocab, writer, indent=4, ensure_ascii=False) |
|
|
|
return (filename,) |
|
|
|
@property |
|
def vocab_size(self) -> int: |
|
return self._current_max_idx |
|
|
|
def retokenize(self, input_ids, *args, **kwargs): |
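        """Decode `input_ids` back into text and run it through the tokenizer again."""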
|
decoded = self.convert_ids_to_tokens(input_ids) |
|
sequence = "".join(decoded) |
|
new_decoded = self(sequence, *args, **kwargs).input_ids |
|
return new_decoded |
|
|
|
def _tokenize(self, text): |
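        # Build one token stream per n-gram order; each order is front-filled with
        # n-1 "<start>" markers so every stream aligns with the character positions.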
|
ngram_sequences = [] |
|
for n in range(1, self.ngram + 1): |
|
words = ["<start>" for _ in range(1, n)] |
|
words.extend(list(text)) |
|
|
|
tokens = [] |
|
            for word in ngram_tokenizer(words, n):
|
if "<start>" in word: |
|
word = [w for w in list(word) if w != "<start>"] |
|
tokens.append("".join(word)) |
|
|
|
ngram_sequences.append(tokens) |
|
|
|
return ngram_sequences |
|
|
|
def get_idx(self, token: str, ngram: Optional[int] = None) -> int: |
|
if ngram: |
|
if token in self._ngram2word2idx[ngram]: |
|
return self._ngram2word2idx[ngram][token] |
|
else: |
|
return self._ngram2word2idx[1]["<unk>"] |
|
|
|
for ngram in range(1, self.ngram + 1): |
|
if token in self._ngram2word2idx[ngram]: |
|
return self._ngram2word2idx[ngram][token] |
|
|
|
return self._ngram2word2idx[1]["<unk>"] |
|
|
|
def _convert_ngram_tokens_to_ids(self, ngram_tokens: List[str]) -> List[int]: |
|
return [self.get_idx(token) for token in ngram_tokens] |
|
|
|
def convert_tokens_to_ids(self, tokens: List[List[str]]): |
|
if not tokens: |
|
return [] |
|
|
|
if isinstance(tokens, str): |
|
return self.get_idx(tokens) |
|
|
|
return [ |
|
self._convert_ngram_tokens_to_ids(ngram_tokens) for ngram_tokens in tokens |
|
] |
|
|
|
def _convert_id_to_token(self, index: int) -> str: |
|
return self.get_item_for_index(index) |
|
|
|
def get_item_for_index(self, idx) -> str: |
|
"""Return the token for a given index.""" |
|
for idxs in self._ngram2idx2word.values(): |
|
if idx in idxs: |
|
return idxs[idx] |
|
|
|
return self.unk_token |
|
|
|
def _decode( |
|
self, token_ids: List[List[int]], skip_special_tokens: bool = False, **kwargs |
|
) -> str: |
|
return "".join(self.convert_ids_to_tokens(token_ids[0])) |
|
|
|
def debug_decode(self, token_ids: List[List[int]]): |
|
for n in range(1, self.ngram + 1): |
|
print(f"{n}-gram: {self.convert_ids_to_tokens(token_ids[n-1])}") |
|
|
|
def _pad( |
|
self, |
|
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], |
|
max_length: Optional[int] = None, |
|
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, |
|
pad_to_multiple_of: Optional[int] = None, |
|
return_attention_mask: Optional[bool] = None, |
|
) -> dict: |
|
""" |
|
Pad encoded inputs (on left/right and up to predefined length or max length in the batch) |
|
|
|
Args: |
|
encoded_inputs: |
|
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). |
|
max_length: maximum length of the returned list and optionally padding length (see below). |
|
Will truncate by taking into account the special tokens. |
|
padding_strategy: PaddingStrategy to use for padding. |
|
|
|
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch |
|
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default) |
|
- PaddingStrategy.DO_NOT_PAD: Do not pad |
|
The tokenizer padding sides are defined in self.padding_side: |
|
|
|
- 'left': pads on the left of the sequences |
|
- 'right': pads on the right of the sequences |
|
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. |
|
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability |
|
`>= 7.5` (Volta). |
|
return_attention_mask: |
|
(optional) Set to False to avoid returning attention mask (default: set to model specifics) |
|
""" |
|
|
|
|
|
|
|
if return_attention_mask is None: |
|
return_attention_mask = "attention_mask" in self.model_input_names |
|
|
|
required_input = encoded_inputs[self.model_input_names[0]] |
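        # Handle doubly nested input_ids (one list of n-gram sequences per sample):
        # use the first entry to determine sequence lengths below.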
|
|
|
if ( |
|
len(required_input) != 0 |
|
and isinstance(required_input[0], list) |
|
and isinstance(required_input[0][0], list) |
|
): |
|
required_input = required_input[0] |
|
|
|
if padding_strategy == PaddingStrategy.LONGEST: |
|
max_length = len(required_input) |
|
|
|
if ( |
|
max_length is not None |
|
and pad_to_multiple_of is not None |
|
and (max_length % pad_to_multiple_of != 0) |
|
): |
|
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of |
|
|
|
needs_to_be_padded = ( |
|
padding_strategy != PaddingStrategy.DO_NOT_PAD |
|
and len(required_input[0]) != max_length |
|
) |
|
|
|
|
|
if return_attention_mask and "attention_mask" not in encoded_inputs: |
|
if len(required_input) == 0: |
|
encoded_inputs["attention_mask"] = [] |
|
else: |
|
encoded_inputs["attention_mask"] = [1] * len(required_input[0]) |
|
|
|
if needs_to_be_padded: |
|
difference = max_length - len(required_input[0]) |
|
|
|
if self.padding_side == "right": |
|
if return_attention_mask: |
|
encoded_inputs["attention_mask"] = ( |
|
encoded_inputs["attention_mask"] + [0] * difference |
|
) |
|
if "token_type_ids" in encoded_inputs: |
|
encoded_inputs["token_type_ids"] = ( |
|
encoded_inputs["token_type_ids"] |
|
+ [self.pad_token_type_id] * difference |
|
) |
|
if "special_tokens_mask" in encoded_inputs: |
|
encoded_inputs["special_tokens_mask"] = ( |
|
encoded_inputs["special_tokens_mask"] + [1] * difference |
|
) |
|
for i in range(len(encoded_inputs[self.model_input_names[0]])): |
|
encoded_inputs[self.model_input_names[0]][i] = ( |
|
required_input[i] + [self.pad_token_id] * difference |
|
) |
|
elif self.padding_side == "left": |
|
if return_attention_mask: |
|
encoded_inputs["attention_mask"] = [ |
|
0 |
|
] * difference + encoded_inputs["attention_mask"] |
|
if "token_type_ids" in encoded_inputs: |
|
encoded_inputs["token_type_ids"] = [ |
|
self.pad_token_type_id |
|
] * difference + encoded_inputs["token_type_ids"] |
|
if "special_tokens_mask" in encoded_inputs: |
|
encoded_inputs["special_tokens_mask"] = [ |
|
1 |
|
] * difference + encoded_inputs["special_tokens_mask"] |
|
|
|
for i in range(len(encoded_inputs[self.model_input_names[0]])): |
|
encoded_inputs[self.model_input_names[0]][i] = [ |
|
self.pad_token_id |
|
] * difference + required_input[i] |
|
else: |
|
raise ValueError("Invalid padding strategy:" + str(self.padding_side)) |
|
|
|
return encoded_inputs |
|
|
|
def pad( |
|
self, |
|
encoded_inputs: Union[ |
|
BatchEncoding, |
|
List[BatchEncoding], |
|
Dict[str, EncodedInput], |
|
Dict[str, List[EncodedInput]], |
|
List[Dict[str, EncodedInput]], |
|
], |
|
padding: Union[bool, str, PaddingStrategy] = True, |
|
max_length: Optional[int] = None, |
|
pad_to_multiple_of: Optional[int] = None, |
|
return_attention_mask: Optional[bool] = None, |
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
verbose: bool = True, |
|
) -> BatchEncoding: |
|
""" |
|
Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length |
|
in the batch. |
|
|
|
Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`, |
|
|
|
`self.pad_token_id` and `self.pad_token_type_id`). |
|
|
|
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the |
|
text followed by a call to the `pad` method to get a padded encoding. |
|
|
|
<Tip> |
|
|
|
If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the |
|
result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of |
|
PyTorch tensors, you will lose the specific device of your tensors however. |
|
|
|
</Tip> |
|
|
|
Args: |
|
encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `Dict[str, List[int]]`, `Dict[str, List[List[int]]` or `List[Dict[str, List[int]]]`): |
|
Tokenized inputs. Can represent one input ([`BatchEncoding`] or `Dict[str, List[int]]`) or a batch of |
|
tokenized inputs (list of [`BatchEncoding`], *Dict[str, List[List[int]]]* or *List[Dict[str, |
|
List[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader |
|
collate function. |
|
|
|
Instead of `List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see |
|
the note above for the return type. |
|
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): |
|
Select a strategy to pad the returned sequences (according to the model's padding side and padding |
|
index) among: |
|
|
|
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
|
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum |
|
acceptable input length for the model if that argument is not provided. |
|
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different |
|
lengths). |
|
max_length (`int`, *optional*): |
|
Maximum length of the returned list and optionally padding length (see above). |
|
pad_to_multiple_of (`int`, *optional*): |
|
If set will pad the sequence to a multiple of the provided value. |
|
|
|
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability |
|
`>= 7.5` (Volta). |
|
return_attention_mask (`bool`, *optional*): |
|
Whether to return the attention mask. If left to the default, will return the attention mask according |
|
to the specific tokenizer's default, defined by the `return_outputs` attribute. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
return_tensors (`str` or [`~utils.TensorType`], *optional*): |
|
If set, will return tensors instead of list of python integers. Acceptable values are: |
|
|
|
- `'tf'`: Return TensorFlow `tf.constant` objects. |
|
- `'pt'`: Return PyTorch `torch.Tensor` objects. |
|
- `'np'`: Return Numpy `np.ndarray` objects. |
|
verbose (`bool`, *optional*, defaults to `True`): |
|
Whether or not to print more information and warnings. |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(encoded_inputs, (list, tuple)) and isinstance( |
|
encoded_inputs[0], Mapping |
|
): |
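            # A list/tuple of encodings was passed: flip it into a dict of lists keyed
            # by model input name so it can be padded as a batch.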
|
encoded_inputs = { |
|
key: [example[key] for example in encoded_inputs] |
|
for key in encoded_inputs[0].keys() |
|
} |
|
|
|
|
|
if self.model_input_names[0] not in encoded_inputs: |
|
raise ValueError( |
|
"You should supply an encoding or a list of encodings to this method " |
|
f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" |
|
) |
|
|
|
required_input = encoded_inputs[self.model_input_names[0]] |
|
|
|
if required_input is None or ( |
|
isinstance(required_input, Sized) and len(required_input) == 0 |
|
): |
|
if return_attention_mask: |
|
encoded_inputs["attention_mask"] = [] |
|
return encoded_inputs |
|
|
|
|
|
|
|
|
|
|
|
first_element = required_input[0] |
|
|
|
if isinstance(first_element, (list, tuple)): |
|
|
|
for item in required_input: |
|
if len(item) != 0: |
|
first_element = item[0] |
|
break |
|
|
|
if not isinstance(first_element, (int, list, tuple)): |
|
if is_tf_tensor(first_element): |
|
return_tensors = "tf" if return_tensors is None else return_tensors |
|
elif is_torch_tensor(first_element): |
|
return_tensors = "pt" if return_tensors is None else return_tensors |
|
elif isinstance(first_element, np.ndarray): |
|
return_tensors = "np" if return_tensors is None else return_tensors |
|
else: |
|
raise ValueError( |
|
f"type of {first_element} unknown: {type(first_element)}. " |
|
"Should be one of a python, numpy, pytorch or tensorflow object." |
|
) |
|
|
|
for key, value in encoded_inputs.items(): |
|
encoded_inputs[key] = to_py_obj(value) |
|
|
|
|
|
padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( |
|
padding=padding, max_length=max_length, verbose=verbose |
|
) |
|
|
|
required_input = encoded_inputs[self.model_input_names[0]] |
|
|
|
if required_input: |
|
if isinstance(required_input[0], (list, tuple)): |
|
if len(required_input[0]) > 0 and not isinstance( |
|
required_input[0][0], (list, tuple) |
|
): |
|
encoded_inputs = self._pad( |
|
encoded_inputs, |
|
max_length=max_length, |
|
padding_strategy=padding_strategy, |
|
pad_to_multiple_of=pad_to_multiple_of, |
|
return_attention_mask=return_attention_mask, |
|
) |
|
return BatchEncoding(encoded_inputs, tensor_type=return_tensors) |
|
|
|
batch_size = len(required_input) |
|
assert all( |
|
len(v) == batch_size for v in encoded_inputs.values() |
|
), "Some items in the output dictionary have a different batch size than others." |
|
|
|
if padding_strategy == PaddingStrategy.LONGEST: |
|
max_length = max(len(inputs[0]) for inputs in required_input) |
|
padding_strategy = PaddingStrategy.MAX_LENGTH |
|
|
|
batch_outputs = {} |
|
for i in range(batch_size): |
|
inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) |
|
outputs = self._pad( |
|
inputs, |
|
max_length=max_length, |
|
padding_strategy=padding_strategy, |
|
pad_to_multiple_of=pad_to_multiple_of, |
|
return_attention_mask=return_attention_mask, |
|
) |
|
|
|
for key, value in outputs.items(): |
|
if key not in batch_outputs: |
|
batch_outputs[key] = [] |
|
batch_outputs[key].append(value) |
|
|
|
return BatchEncoding(batch_outputs, tensor_type=return_tensors) |
|
|
|
def prepare_for_model( |
|
self, |
|
ids: List[int], |
|
pair_ids: Optional[List[int]] = None, |
|
add_special_tokens: bool = True, |
|
padding: Union[bool, str, PaddingStrategy] = False, |
|
truncation: Union[bool, str, TruncationStrategy] = None, |
|
max_length: Optional[int] = None, |
|
stride: int = 0, |
|
pad_to_multiple_of: Optional[int] = None, |
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
return_token_type_ids: Optional[bool] = None, |
|
return_attention_mask: Optional[bool] = None, |
|
return_overflowing_tokens: bool = False, |
|
return_special_tokens_mask: bool = False, |
|
return_offsets_mapping: bool = False, |
|
return_length: bool = False, |
|
verbose: bool = True, |
|
prepend_batch_axis: bool = False, |
|
**kwargs, |
|
) -> BatchEncoding: |
|
""" |
|
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It |
|
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and |
|
manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids* |
|
different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return |
|
overflowing tokens. Such a combination of arguments will raise an error. |
|
Args: |
|
ids (`List[int]`): |
|
Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and |
|
`convert_tokens_to_ids` methods. |
|
pair_ids (`List[int]`, *optional*): |
|
Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` |
|
and `convert_tokens_to_ids` methods. |
|
""" |
|
|
|
|
|
( |
|
padding_strategy, |
|
truncation_strategy, |
|
max_length, |
|
kwargs, |
|
) = self._get_padding_truncation_strategies( |
|
padding=padding, |
|
truncation=truncation, |
|
max_length=max_length, |
|
pad_to_multiple_of=pad_to_multiple_of, |
|
verbose=verbose, |
|
**kwargs, |
|
) |
|
|
|
pair = bool(pair_ids is not None) |
|
|
|
if len(ids) == 0: |
|
len_ids = 0 |
|
else: |
|
len_ids = len(ids[0]) |
|
|
|
if pair and len(pair_ids) == 0: |
|
len_pair_ids = 0 |
|
elif pair and len(pair_ids) > 0: |
|
len_pair_ids = len(pair_ids[0]) |
|
else: |
|
len_pair_ids = 0 |
|
|
|
if return_token_type_ids and not add_special_tokens: |
|
raise ValueError( |
|
"Asking to return token_type_ids while setting add_special_tokens to False " |
|
"results in an undefined behavior. Please set add_special_tokens to True or " |
|
"set return_token_type_ids to None." |
|
) |
|
|
|
if ( |
|
return_overflowing_tokens |
|
and truncation_strategy == TruncationStrategy.LONGEST_FIRST |
|
and pair_ids is not None |
|
): |
|
raise ValueError( |
|
"Not possible to return overflowing tokens for pair of sequences with the " |
|
"`longest_first`. Please select another truncation strategy than `longest_first`, " |
|
"for instance `only_second` or `only_first`." |
|
) |
|
|
|
|
|
if return_token_type_ids is None: |
|
return_token_type_ids = "token_type_ids" in self.model_input_names |
|
if return_attention_mask is None: |
|
return_attention_mask = "attention_mask" in self.model_input_names |
|
|
|
encoded_inputs = {} |
|
|
|
|
|
total_len = ( |
|
len_ids |
|
+ len_pair_ids |
|
+ (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) |
|
) |
|
|
|
|
|
overflowing_tokens = [] |
|
if ( |
|
truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE |
|
and max_length |
|
and total_len > max_length |
|
): |
|
ids, pair_ids, overflowing_tokens = self.truncate_sequences( |
|
ids, |
|
pair_ids=pair_ids, |
|
num_tokens_to_remove=total_len - max_length, |
|
truncation_strategy=truncation_strategy, |
|
stride=stride, |
|
) |
|
|
|
if return_overflowing_tokens: |
|
encoded_inputs["overflowing_tokens"] = overflowing_tokens |
|
encoded_inputs["num_truncated_tokens"] = total_len - max_length |
|
|
|
|
|
if add_special_tokens: |
|
sequence = self.build_inputs_with_special_tokens(ids, pair_ids) |
|
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) |
|
else: |
|
sequence = self.build_inputs_with_special_tokens(ids, pair_ids) |
|
token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) |
|
|
|
|
|
encoded_inputs["input_ids"] = sequence |
|
if return_token_type_ids: |
|
encoded_inputs["token_type_ids"] = token_type_ids |
|
if return_special_tokens_mask: |
|
if add_special_tokens: |
|
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask( |
|
ids, pair_ids |
|
) |
|
else: |
|
encoded_inputs["special_tokens_mask"] = [0] * len(sequence) |
|
|
|
|
|
self._eventual_warn_about_too_long_sequence( |
|
encoded_inputs["input_ids"], max_length, verbose |
|
) |
|
|
|
|
|
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: |
|
encoded_inputs = self.pad( |
|
encoded_inputs, |
|
max_length=max_length, |
|
padding=padding_strategy.value, |
|
pad_to_multiple_of=pad_to_multiple_of, |
|
return_attention_mask=return_attention_mask, |
|
) |
|
|
|
if return_length: |
|
encoded_inputs["length"] = len(encoded_inputs["input_ids"]) |
|
|
|
batch_outputs = BatchEncoding( |
|
encoded_inputs, |
|
tensor_type=return_tensors, |
|
prepend_batch_axis=prepend_batch_axis, |
|
) |
|
|
|
return batch_outputs |
|
|
|
def build_inputs_with_special_tokens( |
|
self, |
|
token_ids_0: List[List[int]], |
|
token_ids_1: Optional[List[List[int]]] = None, |
|
) -> List[List[int]]: |
|
""" |
|
Concatenate nested ngram sequences. |
|
|
|
Args: |
|
token_ids_0 (`List[List[int]]`): The first tokenized sequence. |
|
token_ids_1 (`List[List[int]]`, *optional*): The second tokenized sequence. |
|
|
|
Returns: |
|
`List[List[int]]`: The model input with special tokens. |
|
""" |
|
if token_ids_1 is None or len(token_ids_1) == 0: |
|
return token_ids_0 |
|
|
|
if len(token_ids_0) == 0: |
|
return token_ids_1 |
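        # Both inputs are [n-gram order][position] id lists; concatenate along the
        # position axis so every order stays aligned after joining the pair.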
|
|
|
return np.concatenate( |
|
(np.array(token_ids_0), np.array(token_ids_1)), axis=1 |
|
).tolist() |
|
|
|
def truncate_sequences( |
|
self, |
|
ids: List[int], |
|
pair_ids: Optional[List[int]] = None, |
|
num_tokens_to_remove: int = 0, |
|
truncation_strategy: Union[str, TruncationStrategy] = "longest_first", |
|
stride: int = 0, |
|
) -> Tuple[List[int], List[int], List[int]]: |
|
""" |
|
Truncates a sequence pair in-place following the strategy. |
|
Args: |
|
ids (`List[int]`): |
|
Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and |
|
`convert_tokens_to_ids` methods. |
|
pair_ids (`List[int]`, *optional*): |
|
Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` |
|
and `convert_tokens_to_ids` methods. |
|
num_tokens_to_remove (`int`, *optional*, defaults to 0): |
|
Number of tokens to remove using the truncation strategy. |
|
            truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`):
|
The strategy to follow for truncation. Can be: |
|
- `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the |
|
maximum acceptable input length for the model if that argument is not provided. This will truncate |
|
token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a |
|
batch of pairs) is provided. |
|
- `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the |
|
maximum acceptable input length for the model if that argument is not provided. This will only |
|
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. |
|
- `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the |
|
maximum acceptable input length for the model if that argument is not provided. This will only |
|
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. |
|
- `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater |
|
than the model maximum admissible input size). |
|
stride (`int`, *optional*, defaults to 0): |
|
If set to a positive number, the overflowing tokens returned will contain some tokens from the main |
|
sequence returned. The value of this argument defines the number of additional tokens. |
|
Returns: |
|
`Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of |
|
overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair |
|
of sequences (or a batch of pairs) is provided. |
|
""" |
|
if num_tokens_to_remove <= 0: |
|
return ids, pair_ids, [] |
|
|
|
if not isinstance(truncation_strategy, TruncationStrategy): |
|
truncation_strategy = TruncationStrategy(truncation_strategy) |
|
|
|
overflowing_tokens = [] |
|
if truncation_strategy == TruncationStrategy.ONLY_FIRST or ( |
|
truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None |
|
): |
|
ids = np.array(ids) |
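            # `ids` has shape (n-gram orders, sequence length); truncate along axis 1.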
|
|
|
|
|
if ids.shape[1] > num_tokens_to_remove: |
|
window_len = min(ids.shape[1], stride + num_tokens_to_remove) |
|
if self.truncation_side == "left": |
|
overflowing_tokens = ids[:, :window_len] |
|
ids = ids[:, num_tokens_to_remove:] |
|
                elif self.truncation_side == "right":
                    # ids is 2D (one row per n-gram order), so slice along the position axis.
                    overflowing_tokens = ids[:, -window_len:]
                    ids = ids[:, :-num_tokens_to_remove]
|
else: |
|
raise ValueError( |
|
f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'." |
|
) |
|
|
|
ids = ids.tolist() |
|
|
|
else: |
|
error_msg = ( |
|
f"We need to remove {num_tokens_to_remove} to truncate the input " |
|
f"but the first sequence has a length {len(ids)}. " |
|
) |
|
if truncation_strategy == TruncationStrategy.ONLY_FIRST: |
|
error_msg = ( |
|
error_msg + "Please select another truncation strategy than " |
|
f"{truncation_strategy}, for instance 'longest_first' or 'only_second'." |
|
) |
|
logger.error(error_msg) |
|
elif truncation_strategy == TruncationStrategy.LONGEST_FIRST: |
|
logger.warning( |
|
"Be aware, overflowing tokens are not returned for the setting you have chosen," |
|
f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' " |
|
"truncation strategy. So the returned list will always be empty even if some " |
|
"tokens have been removed." |
|
) |
|
ids = np.array(ids) |
|
pair_ids = np.array(pair_ids) |
|
|
|
for _ in range(num_tokens_to_remove): |
|
if pair_ids is None or ids.shape[1] > pair_ids.shape[1]: |
|
if self.truncation_side == "right": |
|
ids = ids[:, :-1] |
|
elif self.truncation_side == "left": |
|
ids = ids[:, 1:] |
|
else: |
|
raise ValueError( |
|
"invalid truncation strategy:" + str(self.truncation_side) |
|
) |
|
else: |
|
if self.truncation_side == "right": |
|
pair_ids = pair_ids[:, :-1] |
|
elif self.truncation_side == "left": |
|
pair_ids = pair_ids[:, 1:] |
|
else: |
|
raise ValueError( |
|
"invalid truncation strategy:" + str(self.truncation_side) |
|
) |
|
|
|
ids = ids.tolist() |
|
pair_ids = pair_ids.tolist() |
|
|
|
elif ( |
|
truncation_strategy == TruncationStrategy.ONLY_SECOND |
|
and pair_ids is not None |
|
): |
|
raise NotImplementedError( |
|
"PHA: I think we only truncate with longest first" |
|
) |
|
if len(pair_ids) > num_tokens_to_remove: |
|
window_len = min(len(pair_ids), stride + num_tokens_to_remove) |
|
if self.truncation_side == "right": |
|
overflowing_tokens = pair_ids[-window_len:] |
|
pair_ids = pair_ids[:-num_tokens_to_remove] |
|
elif self.truncation_side == "left": |
|
overflowing_tokens = pair_ids[:window_len] |
|
pair_ids = pair_ids[num_tokens_to_remove:] |
|
else: |
|
raise ValueError( |
|
"invalid truncation strategy:" + str(self.truncation_side) |
|
) |
|
else: |
|
logger.error( |
|
f"We need to remove {num_tokens_to_remove} to truncate the input " |
|
f"but the second sequence has a length {len(pair_ids)}. " |
|
f"Please select another truncation strategy than {truncation_strategy}, " |
|
"for instance 'longest_first' or 'only_first'." |
|
) |
|
|
|
return (ids, pair_ids, overflowing_tokens) |
|
|
|
def _token_to_n_order(self, token: str) -> int: |
|
"""Get N-gram order for a token""" |
|
for n_gram, word2idx in self._ngram2word2idx.items(): |
|
if token in word2idx: |
|
return n_gram |
|
|
|
return 0 |
|
|
|
def create_weight_tensor(self) -> torch.Tensor: |
|
unked_freqs = self._frequencies.most_common() |
|
|
|
t = torch.ones(len(self)) |
|
|
|
for token, freq in unked_freqs: |
|
t[self._ngram2word2idx[self._token_to_n_order(token)][token]] = freq |
|
|
|
|
|
t[self._ngram2word2idx[1][" "]] = 1.0 |
|
|
|
normed_weights = torch.tensor([(1 - (x / (max(t) + 1))).item() for x in t]) |
|
|
|
marker_tokens = [self.get_idx("<unk>", n) for n in range(1, self.ngram + 1)] |
|
marker_tokens.extend( |
|
[self.get_idx("<start>", n) for n in range(1, self.ngram + 1)] |
|
) |
|
|
|
for marker in marker_tokens: |
|
normed_weights[marker] = 0 |
|
|
|
return normed_weights |
|
|
|
|
|
class TestTokenizer(unittest.TestCase): |
|
def test_one(self): |
|
vocab_file = "/home/phmaker/Projects/ngme/vocabs/1-gram-babylm.json" |
|
|
|
t = NGMETokenizer(vocab_file) |
|
self.assertEqual(t.get_idx("<unk>", 1), 1) |
|
|
|
result = t("hello world") |
|
self.assertEqual(result.input_ids, [16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12]) |
|
|
|
result = t("<unk>") |
|
self.assertEqual(result.input_ids, [1, 13, 5, 24, 1]) |
|
|
|
result = t(["hello world", "<unk>"]) |
|
self.assertEqual( |
|
result.input_ids, |
|
[[16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12], [1, 13, 5, 24, 1]], |
|
) |
|
|
|
def test_three(self): |
|
vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json" |
|
|
|
t = NGMETokenizer(vocab_file) |
|
|
|
result = t("hello world") |
|
self.assertEqual(result.input_ids, [16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12]) |
|
|
|
result = t("hello", return_ngram_sequences=True) |
|
|
|
result = t(["hello world"], return_ngram_sequences=True) |
|
two_gram_expected = [[16, 208, 229, 230, 231, 1, 1, 312, 257, 499, 306]] |
|
|
|
self.assertEqual(result["gram_2_sequence"], two_gram_expected) |
|
self.assertEqual(t._ngram2idx2word[1][16], "h") |
|
self.assertEqual(t._ngram2idx2word[2][208], "he") |
|
self.assertEqual(t._ngram2idx2word[2][229], "el") |
|
|
|
def test_unks(self): |
|
vocab_file = "/home/phmaker/Projects/ngme/vocabs/2-gram-wiki-en.json" |
|
t = NGMETokenizer(vocab_file) |
|
result = t("OciVDjöShG", return_ngram_sequences=True, return_tensors="pt") |
|
|
|
def test_decode(self): |
|
vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json" |
|
t = NGMETokenizer(vocab_file) |
|
decoded = t.decode(208) |
|
assert decoded == "he" |
|
|
|
def test_padding(self): |
|
vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json" |
|
t = NGMETokenizer(vocab_file) |
|
result = t( |
|
"hello world", |
|
return_tensors="pt", |
|
padding="max_length", |
|
max_length=20, |
|
return_ngram_sequences=True, |
|
) |
|
|
|
self.assertEqual(result.input_ids.shape, (1, 20)) |
|
self.assertEqual(result.gram_2_sequence.shape, (1, 20)) |
|
self.assertEqual(result.gram_3_sequence.shape, (1, 20)) |
|
|
|
def test_truncation(self): |
|
vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json" |
|
t = NGMETokenizer(vocab_file) |
|
|
|
result = t( |
|
"hello world", |
|
return_tensors="pt", |
|
truncation=True, |
|
max_length=5, |
|
return_ngram_sequences=True, |
|
) |
|
self.assertEqual(result.input_ids.shape, (1, 5)) |
|
self.assertEqual(result.gram_2_sequence.shape, (1, 5)) |
|
|
|
def test_padding_and_truncation(self): |
|
vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json" |
|
t = NGMETokenizer(vocab_file) |
|
|
|
result = t( |
|
["four", "something longer"], |
|
return_tensors="pt", |
|
padding="max_length", |
|
truncation=True, |
|
max_length=5, |
|
return_ngram_sequences=True, |
|
) |
|
self.assertEqual(result.input_ids.shape, (2, 5)) |
|
self.assertEqual(result.gram_2_sequence.shape, (2, 5)) |
|
|
|
|
|
if __name__ == "__main__": |
|
unittest.main() |
|
|