Spaces:
Running
Running
from tclogger import logger | |
from transformers import AutoTokenizer | |
from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED | |
class TokenChecker:
    """Count the tokens of a prompt for a model and check them against its limit.

    If ``model`` is not a key of ``MODEL_MAP``, the checker falls back to
    ``DEFAULT_MODEL``. Tokenizers are loaded with
    ``AutoTokenizer.from_pretrained`` and cached per repository id on the
    class, so creating one checker per prompt does not reload the tokenizer
    each time.

    NOTE: because of that cache, ``self.tokenizer`` is shared between
    instances that use the same repository id; treat it as read-only.
    """

    # Model used when the requested model is not a key of MODEL_MAP.
    DEFAULT_MODEL = "mixtral-8x7b"

    # As some models are gated, we need to fetch tokenizers from alternatives.
    # Keys are short model names (as used in MODEL_MAP); values are the
    # repository ids passed to AutoTokenizer.from_pretrained instead of the
    # MODEL_MAP entry.
    GATED_MODEL_MAP = {
        "llama3-70b": "NousResearch/Meta-Llama-3-70B",
        "gemma-7b": "unsloth/gemma-7b",
        "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
        "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
    }

    # Repository id -> loaded tokenizer, shared by all TokenChecker instances.
    # Bounded by the number of distinct models, so it cannot grow without limit.
    _tokenizer_cache: dict = {}

    def __init__(self, input_str: str, model: str):
        """Prepare a checker for ``input_str`` under ``model``.

        Args:
            input_str: The prompt text whose tokens will be counted.
            model: Short model name; unknown names fall back to
                ``DEFAULT_MODEL``.
        """
        self.input_str = input_str
        self.model = model if model in MODEL_MAP else self.DEFAULT_MODEL
        self.model_fullname = MODEL_MAP[self.model]
        # Gated models use an alternative public repo for their tokenizer;
        # everything else loads from the MODEL_MAP entry directly.
        repo_id = self.GATED_MODEL_MAP.get(self.model, self.model_fullname)
        self.tokenizer = self._load_tokenizer(repo_id)

    @classmethod
    def _load_tokenizer(cls, repo_id: str):
        """Return the tokenizer for ``repo_id``, loading it on first use only."""
        tokenizer = cls._tokenizer_cache.get(repo_id)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
            cls._tokenizer_cache[repo_id] = tokenizer
        return tokenizer

    def count_tokens(self) -> int:
        """Tokenize ``input_str``, log the count, and return it.

        ``encode`` is called with its default arguments, so the count
        presumably includes any special tokens the tokenizer adds
        (e.g. BOS) — confirm per tokenizer if exact counts matter.
        """
        token_count = len(self.tokenizer.encode(self.input_str))
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def get_token_limit(self):
        """Return the token limit configured for ``self.model`` in TOKEN_LIMIT_MAP."""
        return TOKEN_LIMIT_MAP[self.model]

    def get_token_redundancy(self) -> int:
        """Return how many tokens remain after the prompt and the reserve.

        Computed as ``limit - TOKEN_RESERVED - prompt_tokens``. A value of
        zero or less means the prompt does not fit. (TOKEN_RESERVED is
        presumably headroom kept for the model's reply — confirm in
        ``constants.models``.)
        """
        return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())

    def check_token_limit(self) -> bool:
        """Return True if the prompt fits within the model's limit.

        Raises:
            ValueError: If no tokens remain (redundancy is zero or negative).
        """
        if self.get_token_redundancy() <= 0:
            raise ValueError(f"Prompt exceeded token limit: {self.get_token_limit()}")
        return True