# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.
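#
# Training entry point for VideoMind: fine-tunes a Qwen2-VL-based model for a
# single role (planner / grounder / verifier / answerer) or for all roles at
# once ('all_in_one'), typically with LoRA adapters plus a few task-specific
# heads.
#
# Example launch (illustrative values only; the actual model path, datasets
# and hyper-parameters depend on the training recipe):
#
#   torchrun --nproc_per_node 8 <path/to/this/script> \
#       --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
#       --base_model qwen2_vl --role grounder \
#       --datasets <dataset_names> \
#       --lora_enable True --bf16 True \
#       --output_dir work_dirs/grounder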

from dataclasses import dataclass, field
from typing import Optional

import nncore
import torch
import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoProcessor, HfArgumentParser, TrainingArguments

from videomind.constants import REG_TOKEN, SEG_E_TOKEN, SEG_S_TOKEN
from videomind.dataset import HybridDataCollator, HybridDataset
from videomind.model import MODELS
from videomind.model.builder import build_model
from videomind.train.custom_trainer import CustomTrainer
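

# Every dataclass field below is exposed as a `--<field_name>` command-line
# flag by HfArgumentParser.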
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)
    base_model: Optional[str] = field(default=None)
    conv_type: Optional[str] = field(default=None)
    role: Optional[str] = field(default=None)


@dataclass
class DataArguments:
    datasets: Optional[str] = field(default=None)
    min_video_len: Optional[int] = field(default=-1)
    max_video_len: Optional[int] = field(default=-1)
    min_num_words: Optional[int] = field(default=-1)
    max_num_words: Optional[int] = field(default=-1)
    max_retries: Optional[int] = field(default=10)


@dataclass
class CustomArguments:
    optim: Optional[str] = field(default='adamw_torch')
    group_by_data_type: Optional[bool] = field(default=True)
    merge_adapter: Optional[bool] = field(default=False)
    lora_enable: Optional[bool] = field(default=False)
    lora_type: Optional[str] = field(default='qkvo')
    lora_r: Optional[int] = field(default=64)
    lora_alpha: Optional[int] = field(default=64)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_bias: Optional[str] = field(default='none')
    lora_lr: Optional[float] = field(default=None)
    head_lr: Optional[float] = field(default=None)
    tuning_modules: Optional[str] = field(default=None)
    save_full_model: Optional[bool] = field(default=False)
    remove_unused_columns: Optional[bool] = field(default=False)
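

# Fold the custom options into the stock transformers TrainingArguments so a
# single parser handles both.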
@dataclass
class TrainingArguments(CustomArguments, TrainingArguments):
    pass


def get_target_modules(model, lora_type, base_model):
    lora_type = lora_type.split('_')
    assert all(t in ('qkvo', 'linear', 'all') for t in lora_type)

    if base_model == 'qwen2_vl':
        # all qkvo layers in the visual encoder and the llm
        qkvo_keys = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'attn.qkv', 'attn.proj']
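
        # 'qkvo' keeps only the attention projections above, 'linear' keeps
        # every nn.Linear, and 'all' additionally includes the visual encoder
        # (which is skipped otherwise).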
        target_modules = set()
        for n, m in model.named_modules():
            if not isinstance(m, nn.Linear):
                continue
            if 'all' not in lora_type and 'visual' in n:
                continue
            if 'qkvo' in lora_type and not any(n.endswith(k) for k in qkvo_keys):
                continue
            target_modules.add(n)
    else:
        raise ValueError(f'unknown base model: {base_model}')

    return target_modules


def train(TrainingArguments, Trainer):
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    assert model_args.role in ('all_in_one', 'planner', 'grounder', 'verifier', 'answerer')

    config_cls, model_cls = MODELS[model_args.base_model]

    dtype = torch.bfloat16 if training_args.bf16 else torch.float32

    config = config_cls.from_pretrained(model_args.model_name_or_path, torch_dtype=dtype)
    config.update(model_args.__dict__)
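
    # 'agent_qwen2_vl' checkpoints are restored via build_model (optionally
    # merging existing adapters); otherwise the plain backbone is loaded.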
    if config.model_type == 'agent_qwen2_vl':
        model, processor = build_model(
            model_args.model_name_or_path,
            config=config,
            is_trainable=True,
            merge_adapter=training_args.merge_adapter,
            dtype=dtype)
    else:
        # set do_resize to false to avoid duplicated resizing
        # https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, do_resize=False)

        # eager attention has known & unknown bugs
        # [4.46.2] broken causality fp16: https://github.com/huggingface/transformers/issues/35151
        # [4.48.1] broken sliding window: https://github.com/huggingface/transformers/issues/35924
        model = model_cls.from_pretrained(model_args.model_name_or_path, config=config, attn_implementation='sdpa')

        # save base model path for inference
        model.config.base_model_path = model_args.model_name_or_path

        # conv parameters may become inf after casting to fp16
        model.reset_conv_parameters()

    model.requires_grad_(False)
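
    # Everything is frozen at this point; gradients are re-enabled selectively
    # below (LoRA adapters, new token embeddings, and the grounding heads).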
    if training_args.lora_enable and not isinstance(model, PeftModel):
        target_modules = get_target_modules(model, training_args.lora_type, model.config.base_model)
        tune_lm_head = model.config.role in ('all_in_one', 'grounder', 'verifier')
        print(f'LoRA target modules: {target_modules}')
        lora_config = LoraConfig(
            task_type='CAUSAL_LM',
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            target_modules=target_modules,
            modules_to_save=['embed_tokens', 'lm_head'] if tune_lm_head else None)
        # transformers integration does not support merge_and_unload, use peft instead
        model = get_peft_model(model, lora_config, adapter_name=model_args.role)
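
    # Register the task-specific special tokens and record their ids in the
    # model config.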
    new_tokens = processor.tokenizer.add_special_tokens(
        dict(additional_special_tokens=[REG_TOKEN, SEG_S_TOKEN, SEG_E_TOKEN]))
    print(f'Added {new_tokens} new token(s)')

    model.config.reg_token_id = processor.tokenizer.convert_tokens_to_ids(REG_TOKEN)
    model.config.seg_s_token_id = processor.tokenizer.convert_tokens_to_ids(SEG_S_TOKEN)
    model.config.seg_e_token_id = processor.tokenizer.convert_tokens_to_ids(SEG_E_TOKEN)
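
    # If the tokenizer grew, resize the embedding matrices and initialize the
    # new rows with the mean of the existing embeddings.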
    if new_tokens > 0 and len(processor.tokenizer) > model.config.vocab_size:
        print(f'Expanding vocab size: {model.config.vocab_size} -> {len(processor.tokenizer)}')
        model.resize_token_embeddings(len(processor.tokenizer))
        i_emb = model.get_input_embeddings().weight.data
        o_emb = model.get_output_embeddings().weight.data
        i_emb[-new_tokens:] = i_emb[:-new_tokens].mean(0, keepdim=True)
        o_emb[-new_tokens:] = o_emb[:-new_tokens].mean(0, keepdim=True)

    tuning_modules = [] if training_args.tuning_modules is None else training_args.tuning_modules.split(',')

    head_keys = [
        'vis_proj', 'reg_proj', 'vis_fuse', 'vis_norm', 'vis_pos', 'vis_emb', 'reg_emb', 'pyramid', 'class_head',
        'coord_head', 'coef', 'bundle_loss'
    ]
    for n, p in model.named_parameters():
        # embed_tokens and lm_head might be handled by lora
        if not training_args.lora_enable and new_tokens > 0 and any(k in n for k in ('embed_tokens', 'lm_head')):
            p.requires_grad = True
        if 'projector' in tuning_modules and 'visual.merger' in n:
            p.requires_grad = True
        if model_args.role in ('all_in_one', 'grounder') and any(k in n for k in head_keys):
            p.requires_grad = True
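
    # Print a per-parameter summary and sanity-check the embedding sizes on
    # the main process only.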
    if training_args.local_rank in (0, -1):
        for n, p in model.named_parameters():
            print(p.requires_grad, p.dtype, p.shape, n)

        total_params = sum(p.numel() for p in model.parameters())
        learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        ratio = round(learnable_params / total_params * 100, 2) if total_params > 0 else 0
        print(f'Total params: {total_params} Learnable params: {learnable_params} ({ratio}%)')

        i_size = model.get_input_embeddings().num_embeddings
        o_size = model.get_output_embeddings().out_features
        assert i_size == o_size, (i_size, o_size)
        print(f'Tokenizer size: {len(processor.tokenizer)} Vocab size: {model.config.vocab_size} Embed size: {i_size}')

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=HybridDataCollator(processor.tokenizer),
        train_dataset=HybridDataset(processor, model.config, model_args, data_args, training_args),
        processor=processor,
        head_keys=head_keys)
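
    # Resume automatically when a previous checkpoint exists in the output
    # directory.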
    has_ckpt = bool(nncore.find(training_args.output_dir, 'checkpoint-*'))
    trainer.train(resume_from_checkpoint=has_ckpt)

    trainer.save_state()
    trainer.gather_and_save_model()


if __name__ == '__main__':
    train(TrainingArguments, CustomTrainer)