# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

from dataclasses import dataclass, field
from typing import Optional

import nncore
import torch
import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoProcessor, HfArgumentParser, TrainingArguments

from videomind.constants import REG_TOKEN, SEG_E_TOKEN, SEG_S_TOKEN
from videomind.dataset import HybridDataCollator, HybridDataset
from videomind.model import MODELS
from videomind.model.builder import build_model
from videomind.train.custom_trainer import CustomTrainer
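
# command-line argument groups; HfArgumentParser builds flags from these dataclass fields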
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)
    base_model: Optional[str] = field(default=None)
    conv_type: Optional[str] = field(default=None)
    role: Optional[str] = field(default=None)


@dataclass
class DataArguments:
    datasets: Optional[str] = field(default=None)
    min_video_len: Optional[int] = field(default=-1)
    max_video_len: Optional[int] = field(default=-1)
    min_num_words: Optional[int] = field(default=-1)
    max_num_words: Optional[int] = field(default=-1)
    max_retries: Optional[int] = field(default=10)


@dataclass
class CustomArguments:
    optim: Optional[str] = field(default='adamw_torch')
    group_by_data_type: Optional[bool] = field(default=True)
    merge_adapter: Optional[bool] = field(default=False)
    lora_enable: Optional[bool] = field(default=False)
    lora_type: Optional[str] = field(default='qkvo')
    lora_r: Optional[int] = field(default=64)
    lora_alpha: Optional[int] = field(default=64)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_bias: Optional[str] = field(default='none')
    lora_lr: Optional[float] = field(default=None)
    head_lr: Optional[float] = field(default=None)
    tuning_modules: Optional[str] = field(default=None)
    save_full_model: Optional[bool] = field(default=False)
    remove_unused_columns: Optional[bool] = field(default=False)


@dataclass
class TrainingArguments(CustomArguments, TrainingArguments):
    pass
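
# collect the names of nn.Linear layers that should receive LoRA adapters;
# 'qkvo' restricts them to attention projections, 'all' also covers the visual encoder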
def get_target_modules(model, lora_type, base_model):
    lora_type = lora_type.split('_')
    assert all(t in ('qkvo', 'linear', 'all') for t in lora_type)

    if base_model == 'qwen2_vl':
        # all qkvo layers in the visual encoder and the llm
        qkvo_keys = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'attn.qkv', 'attn.proj']

        target_modules = set()
        for n, m in model.named_modules():
            if not isinstance(m, nn.Linear):
                continue
            if 'all' not in lora_type and 'visual' in n:
                continue
            if 'qkvo' in lora_type and not any(n.endswith(k) for k in qkvo_keys):
                continue
            target_modules.add(n)
    else:
        raise ValueError(f'unknown base model: {base_model}')

    return target_modules
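
# training entry point; the TrainingArguments and Trainer classes are injected as arguments
# so that custom subclasses can be plugged in (see __main__)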
def train(TrainingArguments, Trainer):
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    assert model_args.role in ('all_in_one', 'planner', 'grounder', 'verifier', 'answerer')

    config_cls, model_cls = MODELS[model_args.base_model]

    dtype = torch.bfloat16 if training_args.bf16 else torch.float32

    config = config_cls.from_pretrained(model_args.model_name_or_path, torch_dtype=dtype)
    config.update(model_args.__dict__)
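
    # agent-style checkpoints are restored through build_model (which can merge existing
    # adapters); plain base checkpoints are loaded directly in the else branch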
    if config.model_type == 'agent_qwen2_vl':
        model, processor = build_model(
            model_args.model_name_or_path,
            config=config,
            is_trainable=True,
            merge_adapter=training_args.merge_adapter,
            dtype=dtype)
    else:
        # set do_resize to false to avoid duplicated resizing
        # https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, do_resize=False)

        # eager attention has known & unknown bugs
        # [4.46.2] broken causality fp16: https://github.com/huggingface/transformers/issues/35151
        # [4.48.1] broken sliding window: https://github.com/huggingface/transformers/issues/35924
        model = model_cls.from_pretrained(model_args.model_name_or_path, config=config, attn_implementation='sdpa')

        # save base model path for inference
        model.config.base_model_path = model_args.model_name_or_path

        # conv parameters may become inf after casting to fp16
        model.reset_conv_parameters()
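
    # freeze all parameters first; LoRA adapters, new token embeddings, and grounding heads
    # are selectively unfrozen below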
    model.requires_grad_(False)

    if training_args.lora_enable and not isinstance(model, PeftModel):
        target_modules = get_target_modules(model, training_args.lora_type, model.config.base_model)
        tune_lm_head = model.config.role in ('all_in_one', 'grounder', 'verifier')
        print(f'LoRA target modules: {target_modules}')

        lora_config = LoraConfig(
            task_type='CAUSAL_LM',
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            target_modules=target_modules,
            modules_to_save=['embed_tokens', 'lm_head'] if tune_lm_head else None)

        # transformers integration does not support merge_and_unload, use peft instead
        model = get_peft_model(model, lora_config, adapter_name=model_args.role)
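
    # register the region and segment boundary tokens as special tokens and record their ids
    # in the model config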
    new_tokens = processor.tokenizer.add_special_tokens(
        dict(additional_special_tokens=[REG_TOKEN, SEG_S_TOKEN, SEG_E_TOKEN]))
    print(f'Added {new_tokens} new token(s)')

    model.config.reg_token_id = processor.tokenizer.convert_tokens_to_ids(REG_TOKEN)
    model.config.seg_s_token_id = processor.tokenizer.convert_tokens_to_ids(SEG_S_TOKEN)
    model.config.seg_e_token_id = processor.tokenizer.convert_tokens_to_ids(SEG_E_TOKEN)

    if new_tokens > 0 and len(processor.tokenizer) > model.config.vocab_size:
        print(f'Expanding vocab size: {model.config.vocab_size} -> {len(processor.tokenizer)}')
        model.resize_token_embeddings(len(processor.tokenizer))
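        # initialize the new token rows with the mean of the existing embeddings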
        i_emb = model.get_input_embeddings().weight.data
        o_emb = model.get_output_embeddings().weight.data
        i_emb[-new_tokens:] = i_emb[:-new_tokens].mean(0, keepdim=True)
        o_emb[-new_tokens:] = o_emb[:-new_tokens].mean(0, keepdim=True)
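
    # extra module groups requested via --tuning_modules (e.g. 'projector' unfreezes the visual merger below)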
    tuning_modules = [] if training_args.tuning_modules is None else training_args.tuning_modules.split(',')
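
    # name fragments of the grounding head parameters; these stay fully trainable for the
    # all_in_one and grounder roles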
    head_keys = [
        'vis_proj', 'reg_proj', 'vis_fuse', 'vis_norm', 'vis_pos', 'vis_emb', 'reg_emb', 'pyramid', 'class_head',
        'coord_head', 'coef', 'bundle_loss'
    ]

    for n, p in model.named_parameters():
        # embed_tokens and lm_head might be handled by lora
        if not training_args.lora_enable and new_tokens > 0 and any(k in n for k in ('embed_tokens', 'lm_head')):
            p.requires_grad = True
        if 'projector' in tuning_modules and 'visual.merger' in n:
            p.requires_grad = True
        if model_args.role in ('all_in_one', 'grounder') and any(k in n for k in head_keys):
            p.requires_grad = True
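
    # report per-parameter status and the overall trainable ratio on the main process only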
    if training_args.local_rank in (0, -1):
        for n, p in model.named_parameters():
            print(p.requires_grad, p.dtype, p.shape, n)

        total_params = sum(p.numel() for p in model.parameters())
        learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        ratio = round(learnable_params / total_params * 100, 2) if total_params > 0 else 0
        print(f'Total params: {total_params} Learnable params: {learnable_params} ({ratio}%)')

        i_size = model.get_input_embeddings().num_embeddings
        o_size = model.get_output_embeddings().out_features
        assert i_size == o_size, (i_size, o_size)
        print(f'Tokenizer size: {len(processor.tokenizer)} Vocab size: {model.config.vocab_size} Embed size: {i_size}')
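
    # build the trainer on the hybrid dataset and resume from the latest checkpoint if one exists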
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=HybridDataCollator(processor.tokenizer),
        train_dataset=HybridDataset(processor, model.config, model_args, data_args, training_args),
        processor=processor,
        head_keys=head_keys)

    has_ckpt = bool(nncore.find(training_args.output_dir, 'checkpoint-*'))
    trainer.train(resume_from_checkpoint=has_ckpt)

    trainer.save_state()
    trainer.gather_and_save_model()

if __name__ == '__main__':
    train(TrainingArguments, CustomTrainer)
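
# example launch command (a sketch only; the script path, checkpoint, dataset names, and
# hyperparameters are placeholders, not the repo's documented settings; HfArgumentParser
# exposes every dataclass field above as a --flag):
#
#   torchrun --nproc_per_node=8 videomind/train/train.py \
#       --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
#       --base_model qwen2_vl \
#       --role grounder \
#       --datasets <dataset_names> \
#       --lora_enable True \
#       --bf16 True \
#       --output_dir work_dirs/grounder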