# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import math
import os
import os.path as osp
import warnings
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple
import torch
import transformers
from huggingface_hub import file_exists, repo_exists
from huggingface_hub.utils import HFValidationError
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
# from .conversation import *
from .conversation import SeparatorStyle, default_conversation
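# Special tokens: the sentinel is used only to infer stop tokens from the chat
# template; the media tokens mark image/video placeholders in prompts.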
SENTINEL_TOKEN = "<vila/sentinel>"
MEDIA_TOKENS = {
"image": "<image>",
"video": "<vila/video>",
}
# from llava.model.utils import packing
# from llava.utils.logging import logger
# from llava.utils.tokenizer import infer_stop_tokens
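# Synthetic conversation used by `infer_stop_tokens` to probe the chat template.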
DUMMY_CONVERSATION = [
{"from": "human", "value": "question"},
{"from": "gpt", "value": "answer"},
] * 10
def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
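    """Tokenize a prompt that may contain media tokens such as <image>.

    The media tokens are added to the tokenizer vocabulary in `build_llm_and_tokenizer`,
    so a plain tokenizer call suffices here.
    """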
    return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
def has_tokenizer(repo_id_or_path: str) -> bool:
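    """Return True if a tokenizer config is found locally or in a Hugging Face Hub repo."""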
    # Check if the tokenizer is in a local directory
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True
    # Check if the tokenizer is in a Hugging Face Hub repo
    try:
        return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        return False
def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
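    """Register the sentinel token on the tokenizer if it has not been added yet."""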
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
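    """Tokenize a conversation with the separator-based template from `default_conversation`."""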
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    if no_system_prompt:
        conv.system = ""

    # Skip the first message if it is not from human
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Add a generation prompt if needed
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2]
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])
    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
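    """Tokenize a conversation into input ids.

    Conversations use the `{"from": ..., "value": ...}` message format. When the
    default separator style is AUTO, the tokenizer's chat template is applied;
    otherwise the legacy separator-based template is used.
    """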
    # Normalize the conversation before tokenization
    for message in messages:
        message["value"] = message["value"].strip()

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
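    """Infer the stop tokens produced by the tokenizer's chat template.

    A dummy conversation is tokenized with a sentinel in every assistant slot; the
    token that immediately follows each sentinel is collected as a stop token.
    """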
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
def context_length_extension(config):
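    """Enable linear RoPE scaling when `model_max_length` exceeds the pretrained context length."""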
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    if orig_ctx_len and model_max_length and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
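    """Load the language model and its tokenizer for VILA.

    Besides loading the checkpoint, this sets up RoPE scaling, the chat template,
    stop tokens, and the media tokens on the tokenizer.
    """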
    # print(model_name_or_path)
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # Quantization related. The FP8 restore path is disabled by default.
    quantization_restore_from_checkpoint = False
    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
        llm = AutoModelForCausalLM.from_pretrained(
            fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
    else:
        # config.model_dtype is expected to be a dtype expression such as "torch.bfloat16".
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )

    # NOTE(ligeng): not sure whether it affects the training
    # packing.patch(llm)

    # Locate the tokenizer.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Load chat template if specified.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        if not os.path.exists(fpath):
            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        # Strip indentation and newlines from the pretty-printed Jinja template.
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")

    # NOTE(ligeng): disabled temporarily; let's see whether any bugs are introduced
    # Set stop tokens for the tokenizer
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Add media tokens to the tokenizer
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    # TODO(ligeng): is this necessary for llava?
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer