from sentencepiece import SentencePieceProcessor
import os
import torch

class ExLlamaTokenizer:

    def __init__(self, tokenizer_model_path):

        self.path = tokenizer_model_path
        self.tokenizer = SentencePieceProcessor(model_file = self.path)

        self.unk_token = "<unk>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token_id = self.tokenizer.unk_id()  # same ID as the pad token
        self.eos_token_id = self.tokenizer.eos_id()
        self.bos_token_id = self.tokenizer.bos_id()
        self.pad_token_id = 0  # self.tokenizer.pad_id()
        self.newline_token_id = 13

        # Special tokens and their IDs, used for tokenizer encoding/decoding
        self.special_characters = [(self.bos_token, self.bos_token_id), (self.eos_token, self.eos_token_id), (self.unk_token, self.unk_token_id)]
    # Encode string or list of strings

    def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False, add_eos = False, encode_special_characters = False):

        if isinstance(text, list):

            # text is a list of strings

            list_ids = self.tokenizer.EncodeAsIds(text)

            # pad bos and eos

            if add_bos:
                for ids in list_ids: ids.insert(0, self.bos_token_id)
            if add_eos:
                for ids in list_ids: ids.append(self.eos_token_id)

            max_length = max([len(ids) for ids in list_ids])

            needs_mask = False
            padded_ids = []
            for ids in list_ids:
                if len(ids) != len(list_ids[0]): needs_mask = True
                # Left-pad each sequence with the pad token up to the longest sequence in the batch
                padding = torch.full((max_length - len(ids),), self.pad_token_id)
                sequence = torch.tensor(ids)
                padded_ids.append(torch.cat((padding, sequence), dim = 0).long())

            stacked_ids = torch.stack(padded_ids, dim = 0)

            if return_mask:
                if needs_mask:
                    # Mask is False at the left-padding positions, then extended with True out to max_seq_len
                    mask_padding = torch.full((stacked_ids.shape[0], max_seq_len - stacked_ids.shape[1]), True, dtype = torch.bool, device = "cpu")
                    mask = stacked_ids != 0
                    mask = torch.cat((mask, mask_padding), dim = 1)
                    return stacked_ids, mask
                else:
                    return stacked_ids, None
            else:
                return stacked_ids

        else:

            # text is a single string
            split_text = [text]

            # look for special characters
            if encode_special_characters:
                for special_character, special_token_id in self.special_characters:
                    temp_text = []
                    for segment in split_text:
                        if isinstance(segment, str) and special_character in segment:
                            # Split the segment on the special character and interleave the pieces with its token ID
                            parts = segment.split(special_character)
                            new_parts = []
                            for i, part in enumerate(parts):
                                new_parts.append(part)
                                if i < len(parts) - 1:  # add the special token ID between parts, but not after the last part
                                    new_parts.append(special_token_id)
                            temp_text.extend(new_parts)
                        else:
                            temp_text.append(segment)
                    split_text = temp_text

            ids = []
            for text_chunk in split_text:
                if isinstance(text_chunk, str):
                    ids += self.tokenizer.EncodeAsIds(text_chunk)
                else:
                    ids.append(text_chunk)

            # pad bos and eos

            if add_bos:
                ids = [self.bos_token_id] + ids
            if add_eos:
                ids = ids + [self.eos_token_id]

            stacked_ids = torch.tensor(ids).unsqueeze(0)

            if return_mask:
                return stacked_ids, None
            else:
                return stacked_ids
    # Decode a tensor of IDs (1D or 2D) back to text

    def decode(self, ids, decode_special_characters = False):

        special_ids = {id_: char for char, id_ in self.special_characters}  # lookup dictionary from ID to special token string

        if ids.dim() > 1:

            texts = []
            for i in range(ids.shape[0]):
                seq = ids[i].tolist()
                seq = [t for t in seq if t != self.pad_token_id]
                if decode_special_characters:
                    text_parts = []
                    normal_ids = []  # list of lists
                    current_normal_ids = []  # current list of normal IDs
                    for idx, id_ in enumerate(seq):
                        if id_ in special_ids:
                            # Save the current list of normal IDs, then start a new one
                            normal_ids.append(current_normal_ids)
                            current_normal_ids = []
                            # Store special token as a string
                            text_parts.append(special_ids[id_])
                        else:
                            current_normal_ids.append(id_)
                    normal_ids.append(current_normal_ids)  # save the last segment of normal IDs

                    # Interleave decoded normal segments with the special token strings
                    decoded_segments = [self.tokenizer.Decode(segment) for segment in normal_ids]
                    for idx, decoded_segment in enumerate(decoded_segments):
                        text_parts.insert(2 * idx, decoded_segment)

                    texts.append("".join(text_parts))
                else:
                    if self.eos_token_id in seq:  # truncate at EOS in this branch only, so special character decoding isn't affected
                        seq = seq[:seq.index(self.eos_token_id)]
                    texts.append(self.tokenizer.Decode(seq))
            return texts

        else:

            ids = ids.tolist()
            if decode_special_characters:
                text_parts = []
                normal_ids = []  # list of lists
                current_normal_ids = []  # current list of normal IDs
                for idx, id_ in enumerate(ids):
                    if id_ in special_ids:
                        # Save the current list of normal IDs, then start a new one
                        normal_ids.append(current_normal_ids)
                        current_normal_ids = []
                        # Store special token as a string
                        text_parts.append(special_ids[id_])
                    else:
                        current_normal_ids.append(id_)
                normal_ids.append(current_normal_ids)  # save the last segment of normal IDs

                # Interleave decoded normal segments with the special token strings
                decoded_segments = [self.tokenizer.Decode(segment) for segment in normal_ids]
                for idx, decoded_segment in enumerate(decoded_segments):
                    text_parts.insert(2 * idx, decoded_segment)

                text = "".join(text_parts)
            else:
                text = self.tokenizer.Decode(ids)
            return text
    # Get the length of a string in tokens

    def num_tokens(self, text, encode_special_characters = False):

        if encode_special_characters:
            ids = self.encode(text, encode_special_characters = True)
            return ids.size(1)
        else:
            ids = self.tokenizer.Encode(text)
            return len(ids)
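

# Example usage: a minimal sketch, not part of the original module. The model path
# "tokenizer.model" is a placeholder; pass whatever SentencePiece model file ships
# with the checkpoint you are loading.

if __name__ == "__main__":

    tokenizer = ExLlamaTokenizer("tokenizer.model")

    # Batch encoding: shorter sequences are left-padded with the pad token, and the
    # boolean mask marks the left-padding positions as False.
    batch_ids, mask = tokenizer.encode(["Hello, world!", "Hi"], return_mask = True, add_bos = True)

    # Single-string round trip, with "<s>" / "</s>" treated as special tokens rather
    # than literal text.
    ids = tokenizer.encode("<s>Hello, world!</s>", encode_special_characters = True)
    text = tokenizer.decode(ids[0], decode_special_characters = True)

    print(tokenizer.num_tokens("Hello, world!"))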