# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Sequence

import torch
import transformers

from .constants import IGNORE_INDEX, SENTINEL_TOKEN
from .conversation import SeparatorStyle, default_conversation
from .mm_utils import tokenizer_image_token
# __all__ = [
#     "tokenize_conversation",
#     "preprocess_conversation",
#     "infer_stop_tokens",
# ]
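
# A synthetic ten-round QA conversation used by `infer_stop_tokens` to probe
# the chat template without relying on real data.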
DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10


def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
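    """Tokenize a conversation using the legacy `Conversation` template.

    `messages` must alternate between "human" and "gpt" senders; a leading
    non-human message is dropped. `overrides` maps a sender name to text that
    replaces every message from that sender (used to inject sentinel tokens),
    and `add_generation_prompt` appends an empty assistant turn so the prompt
    ends where generation should begin.
    """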
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Skip the first message if it is not from human
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Add a generation prompt if needed. Build a new list rather than calling
    # `append`, since `messages` is typed as a Sequence and should not be
    # mutated in place.
    if add_generation_prompt:
        messages = list(messages) + [{"from": "gpt", "value": None}]

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2], "Conversation roles must alternate human/gpt."
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])
    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")


def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
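    """Tokenize `messages` into a 1-D tensor of input IDs.

    Message values are stripped in place, then the conversation is rendered
    with the tokenizer's chat template when the default separator style is
    AUTO; otherwise it falls back to `tokenize_conversation_legacy`.
    """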
    # Normalize the conversation before tokenization
    for message in messages:
        message["value"] = message["value"].strip()

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")


def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
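    """Register `SENTINEL_TOKEN` as a special token, exactly once per tokenizer."""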
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)


def preprocess_conversation(
    conversation: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    no_system_prompt: bool = False,
    retried: bool = False,
) -> Dict[str, Any]:
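    """Tokenize a conversation and build masked labels for training.

    The conversation is tokenized twice: once verbatim, and once with every
    assistant ("gpt") message replaced by a sentinel token. Removing each
    sentinel (and the stop token that follows it) from the second pass leaves
    a template containing only the non-assistant scaffolding. Input tokens
    that match this template are masked to IGNORE_INDEX; every unmatched
    token belongs to an assistant response and is kept as a label.
    """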
    inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
    labels = torch.full_like(inputs, IGNORE_INDEX)

    # Generate the template by replacing the assistant's response with a sentinel.
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(
        conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
    )

    # Remove each sentinel token (and the stop token following it) from the template.
    mask = torch.ones_like(template, dtype=torch.bool)
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            mask[k : k + 2] = False
            # NOTE(zhijianl): This is to handle the corner case where there is an empty token before the sentinel token.
            if k > 0 and retried:
                mask[k - 1] = False
    template = template[mask]

    # Match the tokenized conversation with the template (with no assistant's response).
    # Every token that is not matched will be included in the label for training.
    p = 0
    for k in range(inputs.size(0)):
        if p < template.size(0) and inputs[k] == template[p]:
            p += 1
        else:
            labels[k] = inputs[k]

    # Mask all tokens in the label if the template is not fully matched.
    if p < template.size(0):
        if not retried:
            return preprocess_conversation(
                conversation,
                tokenizer,
                no_system_prompt=no_system_prompt,
                retried=True,
            )
        print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
        labels[:] = IGNORE_INDEX

    return {"input_ids": inputs, "labels": labels}


def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
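    """Infer the tokenizer's stop tokens from its chat template.

    Tokenizes a dummy conversation with assistant replies replaced by
    sentinels; whatever token the template emits right after each sentinel
    acts as a stop token, in addition to the EOS token.
    """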
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
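

# Illustrative usage sketch (not part of the original module). Assumes a
# tokenizer checkpoint that ships a chat template; the checkpoint name and
# the conversation below are placeholders, not values from this repo.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("<checkpoint-with-chat-template>")
#   conversation = [
#       {"from": "human", "value": "What is shown in the image?"},
#       {"from": "gpt", "value": "A cat sleeping on a sofa."},
#   ]
#   batch = preprocess_conversation(conversation, tokenizer)
#   # batch["input_ids"] and batch["labels"] are aligned 1-D tensors; labels
#   # equal IGNORE_INDEX everywhere except inside the assistant's reply.
#   print(infer_stop_tokens(tokenizer))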