File size: 1,627 Bytes
61ebb7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
from tclogger import logger
from transformers import AutoTokenizer
from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED
class TokenChecker:
    """Count the tokens of a prompt and check them against a model's limit.

    The tokenizer is loaded in ``__init__`` via ``AutoTokenizer.from_pretrained``;
    the limit comes from ``TOKEN_LIMIT_MAP`` and ``TOKEN_RESERVED`` tokens are
    subtracted from it when computing the remaining budget.
    """

    # Model key used when the requested model is not present in MODEL_MAP.
    DEFAULT_MODEL = "nous-mixtral-8x7b"

    # As some models are gated, we need to fetch tokenizers from alternatives.
    GATED_MODEL_MAP = {
        "llama3-70b": "NousResearch/Meta-Llama-3-70B",
        "gemma-7b": "unsloth/gemma-7b",
        "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
        "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
    }

    def __init__(self, input_str: str, model: str):
        """Store the prompt, resolve the model key and load its tokenizer.

        Args:
            input_str: Prompt text whose tokens will be counted.
            model: Short model key. Keys that are not in ``MODEL_MAP`` fall
                back to ``DEFAULT_MODEL``.
        """
        self.input_str = input_str
        self.model = model if model in MODEL_MAP else self.DEFAULT_MODEL
        self.model_fullname = MODEL_MAP[self.model]

        # Use the alternative repository for gated models, otherwise the
        # model's own full name.
        tokenizer_source = self.GATED_MODEL_MAP.get(self.model, self.model_fullname)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    def count_tokens(self) -> int:
        """Return the length of ``tokenizer.encode(input_str)`` and log it."""
        token_count = len(self.tokenizer.encode(self.input_str))
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def get_token_limit(self):
        """Return the ``TOKEN_LIMIT_MAP`` entry for the resolved model.

        Raises:
            KeyError: If ``TOKEN_LIMIT_MAP`` has no entry for the model.
        """
        return TOKEN_LIMIT_MAP[self.model]

    def get_token_redundancy(self) -> int:
        """Return the tokens left after the prompt and the reserved budget.

        Computed as ``limit - TOKEN_RESERVED - prompt_tokens``; a value of
        zero or less means the prompt does not fit.
        """
        return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())

    def check_token_limit(self) -> bool:
        """Return ``True`` if the prompt fits, otherwise raise.

        Raises:
            ValueError: If ``limit - TOKEN_RESERVED - prompt_tokens <= 0``.
        """
        # Tokenize once and reuse the result for both the check and the error
        # message, so the prompt is not re-encoded (and re-logged) on failure.
        token_count = self.count_tokens()
        token_limit = self.get_token_limit()
        if int(token_limit - TOKEN_RESERVED - token_count) <= 0:
            # NOTE: the message reports the raw limit, while the check above
            # also subtracts TOKEN_RESERVED from it.
            raise ValueError(
                f"Prompt exceeded token limit: {token_count} > {token_limit}"
            )
        return True
|