from sentencepiece import SentencePieceProcessor
import os
import torch


class ExLlamaTokenizer:

    def __init__(self, tokenizer_model_path):

        # Load the SentencePiece model from the given path

        self.path = tokenizer_model_path
        self.tokenizer = SentencePieceProcessor(model_file = self.path)

        # Special token strings and their IDs, as defined by the SentencePiece model

        self.unk_token = "<unk>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token_id = self.tokenizer.unk_id()
        self.eos_token_id = self.tokenizer.eos_id()
        self.bos_token_id = self.tokenizer.bos_id()

        # Hard-coded padding and newline token IDs

        self.pad_token_id = 0
        self.newline_token_id = 13

        # (string, ID) pairs used when encoding/decoding special characters explicitly

        self.special_characters = [(self.bos_token, self.bos_token_id), (self.eos_token, self.eos_token_id), (self.unk_token, self.unk_token_id)]

    # Encode a string or a list of strings into a tensor of token IDs

    def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False, add_eos = False, encode_special_characters = False):

        if isinstance(text, list):

            # text is a list of strings; encode the whole batch at once

            list_ids = self.tokenizer.EncodeAsIds(text)

            # Prepend BOS / append EOS tokens if requested

            if add_bos:
                for ids in list_ids: ids.insert(0, self.bos_token_id)
            if add_eos:
                for ids in list_ids: ids.append(self.eos_token_id)

            # Left-pad every sequence to the length of the longest one

            max_length = max([len(ids) for ids in list_ids])

            needs_mask = False
            padded_ids = []
            for ids in list_ids:
                if len(ids) != len(list_ids[0]): needs_mask = True
                padding = torch.full((max_length - len(ids),), self.pad_token_id)
                sequence = torch.tensor(ids)
                padded_ids.append(torch.cat((padding, sequence), dim = 0).long())

            stacked_ids = torch.stack(padded_ids, dim = 0)

            if return_mask:
                if needs_mask:

                    # Mask is True at real token positions, False at padding, extended with
                    # True columns out to max_seq_len

                    mask_padding = torch.full((stacked_ids.shape[0], max_seq_len - stacked_ids.shape[1]), True, dtype = torch.bool, device = "cpu")
                    mask = stacked_ids != self.pad_token_id
                    mask = torch.cat((mask, mask_padding), dim = 1)
                    return stacked_ids, mask
                else:
                    return stacked_ids, None
            else:
                return stacked_ids

        else:

            # text is a single string

            split_text = [text]

            # Optionally split the text on special token strings so they map to their token IDs
            # instead of being tokenized as literal text

            if encode_special_characters:
                for special_character, special_token_id in self.special_characters:
                    temp_text = []
                    for segment in split_text:
                        if isinstance(segment, str) and special_character in segment:

                            # Interleave the text between occurrences with the special token ID

                            parts = segment.split(special_character)
                            new_parts = []
                            for i, part in enumerate(parts):
                                new_parts.append(part)
                                if i < len(parts) - 1:
                                    new_parts.append(special_token_id)
                            temp_text.extend(new_parts)
                        else:
                            temp_text.append(segment)
                    split_text = temp_text

            # Encode the string segments and splice the special token IDs back in

            ids = []

            for text_chunk in split_text:
                if isinstance(text_chunk, str):
                    ids += self.tokenizer.EncodeAsIds(text_chunk)
                else:
                    ids.append(text_chunk)

            if add_bos:
                ids = [self.bos_token_id] + ids
            if add_eos:
                ids = ids + [self.eos_token_id]

            stacked_ids = torch.tensor(ids).unsqueeze(0)

            if return_mask:
                return stacked_ids, None
            else:
                return stacked_ids
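
    # Note on encode(): a single string produces a tensor of shape (1, seq_len); a list of
    # strings produces a (batch, max_len) tensor, left-padded with pad_token_id. With
    # return_mask = True and uneven batch lengths, the mask is True at real token positions
    # and is extended with True columns out to max_seq_len; otherwise the mask is None.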

    # Decode a tensor of token IDs into text; ids can be 1D (one sequence) or 2D (a batch)

    def decode(self, ids, decode_special_characters = False):

        # Map special token IDs back to their string forms

        special_ids = {id_: char for char, id_ in self.special_characters}

        if ids.dim() > 1:

            # Batch of sequences: decode each row, skipping padding tokens

            texts = []
            for i in range(ids.shape[0]):

                seq = ids[i].tolist()
                seq = [t for t in seq if t != self.pad_token_id]

                if decode_special_characters:

                    # Split the sequence on special token IDs, decode the plain segments,
                    # then interleave the decoded text with the special token strings

                    text_parts = []
                    normal_ids = []
                    current_normal_ids = []
                    for idx, id_ in enumerate(seq):
                        if id_ in special_ids:
                            normal_ids.append(current_normal_ids)
                            current_normal_ids = []
                            text_parts.append(special_ids[id_])
                        else:
                            current_normal_ids.append(id_)
                    normal_ids.append(current_normal_ids)

                    decoded_segments = [self.tokenizer.Decode(segment) for segment in normal_ids]
                    for idx, decoded_segment in enumerate(decoded_segments):
                        text_parts.insert(2 * idx, decoded_segment)

                    texts.append("".join(text_parts))

                else:

                    # Truncate at the first EOS token, if any

                    if self.eos_token_id in seq:
                        seq = seq[:seq.index(self.eos_token_id)]
                    texts.append(self.tokenizer.Decode(seq))

            return texts

        else:

            # Single sequence

            ids = ids.tolist()

            if decode_special_characters:

                # Same interleaving as above, for a single sequence

                text_parts = []
                normal_ids = []
                current_normal_ids = []
                for idx, id_ in enumerate(ids):
                    if id_ in special_ids:
                        normal_ids.append(current_normal_ids)
                        current_normal_ids = []
                        text_parts.append(special_ids[id_])
                    else:
                        current_normal_ids.append(id_)
                normal_ids.append(current_normal_ids)

                decoded_segments = [self.tokenizer.Decode(segment) for segment in normal_ids]
                for idx, decoded_segment in enumerate(decoded_segments):
                    text_parts.insert(2 * idx, decoded_segment)

                text = "".join(text_parts)

            else:

                text = self.tokenizer.Decode(ids)

            return text

    # Return the number of tokens the given text encodes to

    def num_tokens(self, text, encode_special_characters = False):

        if encode_special_characters:

            ids = self.encode(text, encode_special_characters = True)
            return ids.size(1)

        else:

            ids = self.tokenizer.Encode(text)
            return len(ids)
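

# A minimal usage sketch (illustrative, not part of the original module); assumes a
# Llama-style SentencePiece model file at the placeholder path below. Exact token IDs
# and counts depend on the model file.

if __name__ == "__main__":

    tokenizer = ExLlamaTokenizer("/path/to/tokenizer.model")  # placeholder path

    # Single string -> tensor of shape (1, seq_len)
    ids = tokenizer.encode("Hello, world!", add_bos = True)
    print(ids.shape)

    # List of strings -> left-padded (batch, max_len) tensor, plus an optional boolean mask
    batch_ids, mask = tokenizer.encode(["Hi", "A somewhat longer sequence"], return_mask = True)
    print(batch_ids.shape, None if mask is None else mask.shape)

    # Round trip back to text; padding is stripped and each row is cut at the first EOS
    print(tokenizer.decode(batch_ids))

    # Token count helper
    print(tokenizer.num_tokens("Hello, world!"))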