# OMG-LLaVA/xtuner/engine/hooks/evaluate_chat_hook.py
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
import torch
from mmengine.dist import master_only
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.utils import mkdir_or_exist
from mmengine.utils.misc import get_object_from_string
from transformers import GenerationConfig, StoppingCriteriaList
from xtuner.dataset.utils import expand2square, load_image
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.registry import BUILDER
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
StopWordStoppingCriteria)


class EvaluateChatHook(Hook):
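    """Hook that periodically chats with the model being trained.

    At configurable points (before training, every ``every_n_iters`` train
    iterations, after training, or after validation) the hook feeds
    ``evaluation_inputs`` (optionally paired with ``evaluation_images``)
    through the model and logs the generations, so sample quality can be
    inspected directly from the training log.
    """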
priority = 'LOW'

    def __init__(self,
tokenizer,
evaluation_inputs,
evaluation_images=None,
image_processor=None,
system='',
prompt_template=None,
every_n_iters=None,
max_new_tokens=600,
                 stop_word=None,
                 stop_words=None,
                 generation_kwargs=None):
        # Avoid shared mutable defaults: ``stop_words`` is extended and
        # appended to below, so a literal ``[]`` default would leak state
        # across instances.
        stop_words = list(stop_words) if stop_words is not None else []
        if generation_kwargs is None:
            generation_kwargs = {}
self.evaluation_inputs = evaluation_inputs
if isinstance(self.evaluation_inputs, str):
self.evaluation_inputs = [self.evaluation_inputs]
self.evaluation_images = evaluation_images
if isinstance(self.evaluation_images, str):
self.evaluation_images = [self.evaluation_images]
if self.evaluation_images is not None:
assert len(
self.evaluation_images) in [1, len(self.evaluation_inputs)]
if len(self.evaluation_images) == 1:
self.evaluation_images = [self.evaluation_images[0]] * len(
self.evaluation_inputs)
self.evaluation_images = [
load_image(img) for img in self.evaluation_images
]
if prompt_template is None:
instruction = '{input}'
else:
if isinstance(prompt_template, str): # for resume
prompt_template = get_object_from_string(prompt_template)
instruction = prompt_template.get('INSTRUCTION', '{input}')
if system != '':
system = prompt_template.get(
'SYSTEM', '{system}\n').format(system=system)
stop_words += prompt_template.get('STOP_WORDS', [])
if stop_word is not None:
# TODO: deprecation, v0.3.0
warnings.warn(
('The `stop_word` argument is deprecated and will be removed '
'in v0.3.0, use `stop_words` instead.'), DeprecationWarning)
stop_words.append(stop_word)
self.instruction = instruction
self.system = system
self.every_n_iters = every_n_iters
self.max_new_tokens = max_new_tokens
self.tokenizer = BUILDER.build(tokenizer)
        self.image_processor = None
        if image_processor is not None:
            self.image_processor = BUILDER.build(image_processor)
        # Default generation config; user-supplied kwargs take precedence.
default_generation_kwargs = dict(
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.1,
top_p=0.75,
top_k=40,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None else
self.tokenizer.eos_token_id)
default_generation_kwargs.update(generation_kwargs)
self.gen_config = GenerationConfig(**default_generation_kwargs)
self.stop_criteria = StoppingCriteriaList()
for word in stop_words:
self.stop_criteria.append(
StopWordStoppingCriteria(self.tokenizer, word))
self.is_first_run = True

    @master_only
def _save_eval_output(self, runner, eval_outputs):
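        """Dump generations to ``<log_dir>/vis_data`` (rank 0 only)."""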
save_path = os.path.join(runner.log_dir, 'vis_data',
f'eval_outputs_iter_{runner.iter}.txt')
mkdir_or_exist(os.path.dirname(save_path))
with open(save_path, 'w', encoding='utf-8') as f:
for i, output in enumerate(eval_outputs):
f.write(f'Eval output {i + 1}:\n{output}\n\n')

    def _eval_images(self,
runner,
model,
device,
max_new_tokens=None,
save_eval_output=False):
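        # Multimodal path: every prompt is paired with an image, which is
        # padded to a square with the processor's mean color, preprocessed,
        # encoded and projected into the LLM embedding space.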
if save_eval_output:
eval_outputs = []
for sample_image, sample_input in zip(self.evaluation_images,
self.evaluation_inputs):
image = expand2square(
sample_image,
tuple(int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
image = image.to(device)
sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
inputs = (self.system + self.instruction).format(
input=sample_input, round=1, **runner.cfg)
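            # Tokenize around the image placeholder: only the first chunk
            # keeps special tokens (e.g. BOS); IMAGE_TOKEN_INDEX is put back
            # between chunks so visual features can be spliced in later.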
chunk_encode = []
for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
if idx == 0:
cur_encode = self.tokenizer.encode(chunk)
else:
cur_encode = self.tokenizer.encode(
chunk, add_special_tokens=False)
chunk_encode.append(cur_encode)
assert len(chunk_encode) == 2
input_ids = []
for idx, cur_chunk_encode in enumerate(chunk_encode):
input_ids.extend(cur_chunk_encode)
if idx != len(chunk_encode) - 1:
input_ids.append(IMAGE_TOKEN_INDEX)
input_ids = torch.tensor(input_ids).to(device)
visual_outputs = model.visual_encoder(
image.unsqueeze(0).to(model.visual_encoder.dtype),
output_hidden_states=True)
pixel_values = model.projector(
visual_outputs.hidden_states[model.visual_select_layer][:, 1:])
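            # Replace the IMAGE_TOKEN_INDEX placeholder with the projected
            # visual tokens to build the multimodal inputs.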
mm_inputs = prepare_inputs_labels_for_multimodal(
llm=model.llm,
input_ids=input_ids.unsqueeze(0),
pixel_values=pixel_values)
generation_output = model.generate(
**mm_inputs,
max_new_tokens=max_new_tokens,
generation_config=self.gen_config,
bos_token_id=self.tokenizer.bos_token_id,
stopping_criteria=self.stop_criteria)
generation_output = self.tokenizer.decode(generation_output[0])
runner.logger.info(f'Sample output:\n'
f'{inputs + generation_output}\n')
if save_eval_output:
eval_outputs.append(f'{inputs + generation_output}\n')
if save_eval_output:
self._save_eval_output(runner, eval_outputs)

    def _eval_language(self,
runner,
model,
device,
max_new_tokens=None,
save_eval_output=False):
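        # Text-only path: prompts go straight through the tokenizer.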
if save_eval_output:
eval_outputs = []
for sample_input in self.evaluation_inputs:
inputs = (self.system + self.instruction).format(
input=sample_input, round=1, **runner.cfg)
input_ids = self.tokenizer.encode(inputs, return_tensors='pt')
input_ids = input_ids.to(device)
generation_output = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
generation_config=self.gen_config,
stopping_criteria=self.stop_criteria)
generation_output = self.tokenizer.decode(generation_output[0])
runner.logger.info(f'Sample output:\n{generation_output}\n')
if save_eval_output:
eval_outputs.append(f'{generation_output}\n')
if save_eval_output:
self._save_eval_output(runner, eval_outputs)

    def _generate_samples(self,
runner,
max_new_tokens=None,
save_eval_output=False):
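        # Shared driver: unwrap the model, switch it to inference mode, run
        # the image or language eval path, then restore the training state.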
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
model = runner.model
if is_model_wrapper(model):
model = model.module
device = next(iter(model.parameters())).device
if self.is_first_run:
            # Workaround for QLoRA with DeepSpeed ZeRO-3: on the first run,
            # move buffers and bitsandbytes QuantState objects to the device.
model.to(device)
self.is_first_run = False
is_checkpointing = model.llm.is_gradient_checkpointing
use_cache = model.llm.config.use_cache
        # Switch to inference mode.
model.activation_checkpointing_disable()
model.llm.config.use_cache = True
model.eval()
if self.evaluation_images is not None:
self._eval_images(runner, model, device, max_new_tokens,
save_eval_output)
else:
self._eval_language(runner, model, device, max_new_tokens,
save_eval_output)
        # Restore the saved training-mode state.
if is_checkpointing:
model.activation_checkpointing_enable()
model.llm.config.use_cache = use_cache
model.train()

    def before_train(self, runner):
runner.logger.info('before_train in EvaluateChatHook.')
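        # Quick smoke test with a short generation budget before training.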
self._generate_samples(runner, max_new_tokens=50)

    def _is_save_checkpoint(self, runner):
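        # True when an iteration-based CheckpointHook will save at this
        # iteration, so eval outputs get dumped alongside the checkpoint.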
hooks = runner.hooks
checkpoint_hook = None
for hook in hooks:
if type(hook).__name__ == 'CheckpointHook':
checkpoint_hook = hook
break
if checkpoint_hook is None or checkpoint_hook.by_epoch:
return False
if checkpoint_hook.every_n_train_iters(
runner, checkpoint_hook.interval, checkpoint_hook.save_begin) or \
(checkpoint_hook.save_last and
checkpoint_hook.is_last_train_iter(runner)):
return True
return False

    def after_train_iter(self,
runner,
batch_idx: int,
data_batch=None,
outputs=None) -> None:
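        # Chat every ``every_n_iters`` iterations, and also whenever a
        # checkpoint is about to be saved so its outputs are persisted.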
if self.every_n_iters is None:
return
save_eval_output = self._is_save_checkpoint(runner)
do_chat = (
save_eval_output
or self.every_n_train_iters(runner, self.every_n_iters))
if not do_chat:
return
runner.logger.info('after_train_iter in EvaluateChatHook.')
self._generate_samples(runner, save_eval_output=save_eval_output)

    def after_train(self, runner):
runner.logger.info('after_train in EvaluateChatHook.')
self._generate_samples(runner)

    def after_val(self, runner) -> None:
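        # Only sample here when per-iteration evaluation is disabled.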
if self.every_n_iters is not None:
return
runner.logger.info('after_val in EvaluateChatHook.')
self._generate_samples(runner)