aiben / src /prompter_utils.py
abugaber's picture
Upload folder using huggingface_hub
3943768 verified
raw
history blame
2.01 kB
import base64
from enums import unknown_prompt_type, template_prompt_type
def get_use_chat_template(tokenizer, prompt_type=None):
    """Decide whether the tokenizer's own chat template should be used.

    Args:
        tokenizer: Tokenizer object, or None.
        prompt_type: Requested prompt type. Only None, '', or the
            ``unknown_prompt_type`` / ``template_prompt_type`` values imported
            from ``enums`` allow the chat template to be used.

    Returns:
        True only if a tokenizer is given, ``prompt_type`` is one of the values
        above, and ``has_chat_template(tokenizer)`` is true; otherwise False.
    """
    if tokenizer is None:
        return False
    # Any prompt_type outside this set disables the tokenizer's chat template.
    defers_to_tokenizer = prompt_type in (None, '', unknown_prompt_type, template_prompt_type)
    if not defers_to_tokenizer:
        return False
    return has_chat_template(tokenizer)
def has_chat_template(tokenizer):
    """Return True if *tokenizer* exposes a non-empty chat template.

    ``chat_template`` is checked first, then ``default_chat_template``. An
    attribute counts only when it exists and is neither None nor ''.
    A ``tokenizer`` of None simply yields False.
    """
    candidate_attrs = ('chat_template', 'default_chat_template')
    return any(
        hasattr(tokenizer, attr_name) and getattr(tokenizer, attr_name) not in (None, '')
        for attr_name in candidate_attrs
    )
def get_chat_template(tokenizer):
    """Return the tokenizer's chat template, or None if it has none.

    ``chat_template`` takes precedence over ``default_chat_template``. A value
    that is missing, None, or '' is skipped. If *tokenizer* is None, None is
    returned.
    """
    if tokenizer is None:
        return None
    for attr_name in ('chat_template', 'default_chat_template'):
        template = getattr(tokenizer, attr_name, None)
        if template not in (None, ''):
            return template
    return None
def base64_encode_jinja_template(template_str):
    """Encode a template string as standard-alphabet base64 text.

    The string is first encoded to UTF-8 bytes. The base64 output (pure
    ASCII) is returned as a ``str``.
    """
    return base64.b64encode(template_str.encode('utf-8')).decode('utf-8')
def base64_decode_jinja_template(encoded_str):
    """Decode *encoded_str* from base64 if it looks encoded; otherwise pass it through.

    Whether the input "looks encoded" is decided by ``is_base64``. When it
    does, the decoded bytes are returned as a UTF-8 ``str``. Any other input
    is returned unchanged, on the assumption that it is already a plain
    template.

    NOTE(review): a short plain string that uses only base64 characters and
    happens to decode to valid UTF-8 (e.g. "AAAA") will also be decoded.
    """
    if not is_base64(encoded_str):
        # Plain template text: hand it back untouched.
        return encoded_str
    raw_bytes = base64.b64decode(encoded_str.encode('utf-8'))
    return raw_bytes.decode('utf-8')
def is_base64(s):
    """Return True if *s* is strict base64 whose payload is valid UTF-8 text.

    *s* may be a ``str`` or a bytes-like object. The function never raises.
    Inputs that are not str/bytes-like, such as None or an int, return False
    instead of propagating a ``TypeError``.

    The checks are:
      * the length is a multiple of 4 (canonical padded base64);
      * every character is in the standard base64 alphabet
        (``validate=True``);
      * the decoded bytes are valid UTF-8.

    An empty string passes all three checks and therefore returns True.
    """
    try:
        # Cheap pre-check. It lives inside the try so that objects without
        # len() (e.g. None) yield False rather than a TypeError.
        if len(s) % 4 != 0:
            return False
        # validate=True rejects characters outside the base64 alphabet
        # instead of silently discarding them.
        decoded = base64.b64decode(s, validate=True)
        # Only accept payloads that are themselves UTF-8 text.
        decoded.decode('utf-8')
    except (TypeError, ValueError):
        # TypeError: no len() / not str or bytes-like.
        # ValueError: binascii.Error (bad alphabet or padding), non-ASCII str
        # input, and UnicodeDecodeError are all ValueError subclasses.
        return False
    return True