import math
import os
import os.path as osp
import warnings
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
import transformers
from huggingface_hub import file_exists, repo_exists
from huggingface_hub.utils import HFValidationError
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from .conversation import SeparatorStyle, default_conversation

SENTINEL_TOKEN = "<vila/sentinel>"
MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
}

DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10


def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
    # Media tokens such as "<image>" are registered as special tokens on the
    # tokenizer (see `build_llm_and_tokenizer`), so a plain tokenizer call maps
    # them to their ids.
    return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
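
# Illustrative call (the prompt is hypothetical; `tokenizer` is assumed to come
# from `build_llm_and_tokenizer`, so "<image>" maps to a single special-token id):
#     ids = tokenizer_image_token("<image>\nDescribe this image.", tokenizer, return_tensors="pt")

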
def has_tokenizer(repo_id_or_path: str) -> bool:
    # Local checkpoint: the tokenizer config sits next to the weights.
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True

    # Otherwise, treat the argument as a Hugging Face Hub repo id.
    try:
        return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        return False
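
# Illustrative checks (both arguments below are hypothetical):
#     has_tokenizer("./checkpoints/my-model/llm")   # True if tokenizer_config.json exists locally
#     has_tokenizer("some-org/some-model")          # falls back to a Hub lookup

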
def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)


def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Skip the first message if it is not from a human.
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Append an empty assistant turn if a generation prompt is requested,
    # without mutating the caller's list.
    if add_generation_prompt:
        messages = list(messages) + [{"from": "gpt", "value": None}]

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2]
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")


def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    # Normalize message text before tokenization.
    for message in messages:
        message["value"] = message["value"].strip()

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    # Convert the internal {"from", "value"} schema to the chat-template schema.
    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
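
# Illustrative usage (the checkpoint name is hypothetical; messages follow the
# internal {"from", "value"} schema):
#     tokenizer = AutoTokenizer.from_pretrained("some-org/llm-with-chat-template")
#     ids = tokenize_conversation(
#         [{"from": "human", "value": "What is in <image>?"}],
#         tokenizer,
#         add_generation_prompt=True,
#     )  # 1-D tensor of input ids ending with the assistant prefix

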
def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    # Render a dummy conversation in which every assistant reply is a sentinel
    # token; whatever token follows a sentinel in the rendered template is a
    # turn separator, i.e. a stop token for generation.
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
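
# Illustrative result (the exact strings depend on the tokenizer's chat template):
#     infer_stop_tokens(tokenizer)  # e.g. ["</s>"] or ["<|eot_id|>", "<|end_of_text|>"]

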
def context_length_extension(config):
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    if orig_ctx_len and model_max_length and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
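
# Worked example: with max_position_embeddings=4096 and model_max_length=16384,
# the scaling factor is ceil(16384 / 4096) = 4.0, so the config gains
# rope_scaling = {"type": "linear", "factor": 4.0}.

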
def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # Placeholder flag: the FP8 restore path below is kept for reference but is
    # currently never taken.
    quantization_restore_from_checkpoint = False

    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
        llm = AutoModelForCausalLM.from_pretrained(
            fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
    else:
        # `config.model_dtype` is expected to be a string such as "torch.float16".
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )

    # Locate the tokenizer: either next to the model or in its "llm" subfolder.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Load the chat template, if one is specified in the config.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        if not os.path.exists(fpath):
            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        # Strip indentation and newlines so the Jinja template renders without stray whitespace.
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")

    # Infer stop tokens from the chat template.
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Register media tokens as special tokens.
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
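
# Illustrative usage sketch (the checkpoint path is hypothetical; `config` must
# provide a `model_dtype` string such as "torch.bfloat16" and may name a
# `chat_template`):
#     config = AutoConfig.from_pretrained("./checkpoints/my-vila")
#     llm, tokenizer = build_llm_and_tokenizer(
#         "./checkpoints/my-vila/llm",
#         config,
#         attn_implementation="flash_attention_2",
#         model_max_length=4096,
#     )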