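"""Llama 3 tokenizer utilities.

Provides a tiktoken-based BPE tokenizer wrapper (Llama3Tokenizer), a helper
that wraps prompts in the Llama 3 instruct chat template (ChatFormat), and a
small function for stripping the assistant header from generated text
(clean_text).
"""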
import os
from pathlib import Path

import tiktoken
from tiktoken.load import load_tiktoken_bpe


class Llama3Tokenizer:
    """Thin wrapper around a tiktoken BPE encoding for Llama 3 models."""

    def __init__(self, model_path):
        assert os.path.isfile(model_path), f"Model file {model_path} not found"
        mergeable_ranks = load_tiktoken_bpe(model_path)

        # Llama 3 special tokens sit above the 128,000-entry BPE vocabulary
        self.special_tokens = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        # Fill the remaining IDs with reserved placeholder tokens
        self.special_tokens.update({
            f"<|reserved_{i}|>": 128002 + i
            for i in range(256)
            if (128002 + i) not in self.special_tokens.values()
        })

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):
        # Optionally prepend the beginning-of-text token
        if bos:
            tokens = [self.special_tokens["<|begin_of_text|>"]]
        else:
            tokens = []

        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)

        # Optionally append the end-of-text token
        if eos:
            tokens.append(self.special_tokens["<|end_of_text|>"])
        return tokens

    def decode(self, tokens):
        return self.model.decode(tokens)


class ChatFormat:
    """Wraps plain text in the Llama 3 instruct chat template."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message):
        # Produces: <|start_header_id|>{role}<|end_header_id|>\n\n
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode(self, text):
        # Encode `text` as a single user message terminated by <|eot_id|>
        message = {
            "role": "user",
            "content": text,
        }

        tokens = self.encode_header(message)
        tokens.extend(
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        )
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        return tokens

    def decode(self, token_ids):
        return self.tokenizer.decode(token_ids)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # Drop everything up to and including the assistant header so that only
    # the model's response remains
    index = text.find(header_end)
    if index != -1:
        return text[index + len(header_end):].strip()
    else:
        return text
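

# Usage sketch: requires a local Llama 3 tokenizer model file; the
# "tokenizer.model" path below is a placeholder and should be adjusted.
if __name__ == "__main__":
    tokenizer = Llama3Tokenizer("tokenizer.model")

    # Plain round trip through the BPE tokenizer
    ids = tokenizer.encode("Hello, world!", bos=True, eos=True)
    print(ids)
    print(tokenizer.decode(ids))

    # Wrap a user prompt in the Llama 3 chat format
    chat = ChatFormat(tokenizer)
    prompt_ids = chat.encode("What do llamas eat?")
    print(chat.decode(prompt_ids))

    # Strip everything up to and including the assistant header from a
    # decoded generation
    generated = (
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Llamas mostly eat grass and hay."
    )
    print(clean_text(generated))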