# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import warnings

import nncore
import torch
from deepspeed import zero
from safetensors.torch import load_model, save_file
from torch.utils.data import Sampler
from transformers import Trainer, TrainerCallback
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import CHAT_TEMPLATE_NAME


def gather(param):
    """Return a detached CPU copy of ``param``.

    Parameters that carry a ``ds_id`` attribute are treated as DeepSpeed ZeRO
    partitioned parameters and are materialized with ``zero.GatheredParameters``
    before being copied.
    """
    if hasattr(param, 'ds_id'):
        with zero.GatheredParameters([param]):
            return param.data.detach().cpu().clone()
    return param.detach().cpu().clone()


def gather_lora_params(model, bias):
    """Collect the LoRA-related parameters of ``model`` as detached CPU tensors.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        bias (str): which bias terms to include.
            ``'none'``: only ``lora_`` and ``modules_to_save`` parameters.
            ``'all'``: additionally every parameter whose name contains ``bias``.
            ``'lora_only'``: additionally only the biases of modules that own
            LoRA weights, i.e. ``<prefix>bias`` for every ``<prefix>lora_...``.

    Returns:
        dict: parameter name -> gathered CPU tensor.
    """
    assert bias in ('lora_only', 'all', 'none')

    if bias == 'lora_only':
        state_dict, maybe_lora_bias, lora_bias_names = dict(), dict(), set()
        for n, p in model.named_parameters():
            if 'modules_to_save' in n:
                state_dict[n] = p
            elif 'lora_' in n:
                state_dict[n] = p
                lora_bias_names.add(n.split('lora_')[0] + 'bias')
            elif 'bias' in n:
                maybe_lora_bias[n] = p
        # keep only the biases belonging to a module that has LoRA weights
        for n, p in maybe_lora_bias.items():
            if n in lora_bias_names:
                state_dict[n] = p
    else:
        keys = ['lora_', 'modules_to_save', 'bias'] if bias == 'all' else ['lora_', 'modules_to_save']
        state_dict = {n: p for n, p in model.named_parameters() if any(k in n for k in keys)}

    return {n: gather(p) for n, p in state_dict.items()}


def gather_key_params(model, keys):
    """Collect trainable parameters whose names contain any substring in ``keys``.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        keys: iterable of name substrings; ``None`` is treated as empty.

    Returns:
        dict: parameter name -> gathered CPU tensor (empty when nothing matches).
    """
    keys = tuple(keys or ())
    state_dict = {n: p for n, p in model.named_parameters() if p.requires_grad and any(k in n for k in keys)}
    return {n: gather(p) for n, p in state_dict.items()}


class GroupSampler(Sampler):
    """Sampler whose consecutive ``group_size`` indices all share one data type.

    Indices are grouped by ``data_types``, shuffled within each group, chunked
    into batches of ``group_size`` (each group's incomplete tail is dropped),
    and the batches are then shuffled across groups. The torch generator is
    seeded with ``seed + epoch``, so the order is reproducible per epoch.
    """

    def __init__(self, group_size, data_types, seed):
        self.group_size = group_size
        self.data_types = data_types
        self.seed = seed
        # default so iteration works even if set_epoch() is never called
        self.epoch = 0

    def _groups(self):
        """Return per-type index lists, ordered by first occurrence of each type."""
        # avoid using dict or set here as they are not deterministic
        unique_types, groups = [], []
        for i, t in enumerate(self.data_types):
            if t not in unique_types:
                unique_types.append(t)
                groups.append([])
            groups[unique_types.index(t)].append(i)
        return groups

    def __len__(self):
        # must match what __iter__ yields: incomplete tails of each group are dropped
        return sum(len(group) // self.group_size * self.group_size for group in self._groups())

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        group_batches = []
        for group in self._groups():
            inds = [group[i] for i in torch.randperm(len(group), generator=g).tolist()]
            batches = [inds[i:i + self.group_size] for i in range(0, len(inds), self.group_size)]
            if len(batches[-1]) < self.group_size:
                batches = batches[:-1]
            group_batches += batches

        perm = torch.randperm(len(group_batches), generator=g).tolist()
        return iter([i for b in perm for i in group_batches[b]])

    def set_epoch(self, epoch):
        self.epoch = epoch


class SetEpochCallback(TrainerCallback):
    """Forward the current epoch to the training sampler's ``set_epoch``.

    Reaches through ``batch_sampler.batch_sampler`` when the dataloader's batch
    sampler wraps another one.
    """

    # partially fixed in https://github.com/huggingface/accelerate/pull/3246
    # but not for the case of batch_sampler.batch_sampler.sampler
    def on_epoch_begin(self, args, state, control, **kwargs):
        shard_sampler = kwargs['train_dataloader'].batch_sampler
        batch_sampler = getattr(shard_sampler, 'batch_sampler', shard_sampler)
        batch_sampler.sampler.set_epoch(int(state.epoch))


class CustomTrainer(Trainer):
    """``Trainer`` with grouped sampling, per-module learning rates and partial checkpoints.

    Args:
        processor: forwarded to ``Trainer`` as ``tokenizer`` and saved with every
            checkpoint (including its chat template).
        head_keys: name substrings selecting "head" parameters, which use
            ``args.head_lr`` and are saved to ``pytorch_model.safetensors``.
            ``None`` selects no parameters.
    """

    def __init__(self, *args, processor=None, head_keys=None, **kwargs):
        super().__init__(*args, tokenizer=processor, **kwargs)
        self.add_callback(SetEpochCallback())
        self.processor = processor
        self.head_keys = head_keys

    def _get_train_sampler(self):
        if self.args.group_by_data_type:
            return GroupSampler(self.args.train_batch_size * self.args.world_size, self.train_dataset.data_types,
                                self.args.seed)
        else:
            return super()._get_train_sampler()

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        if model is None:
            model = self.model

        super()._load_from_checkpoint(resume_from_checkpoint, model=model)

        # head weights saved separately by gather_and_save_model / _save_checkpoint
        partial_path = nncore.join(resume_from_checkpoint, 'pytorch_model.safetensors')
        if nncore.is_file(partial_path):
            load_model(model, partial_path, strict=False, device=model.device)

    def _save_processor(self, output_dir):
        """Save the processor and its chat template when ``args.should_save``."""
        if self.processor is None or not self.args.should_save:
            return

        self.processor.save_pretrained(output_dir)

        # https://github.com/huggingface/transformers/pull/33462
        if self.processor.chat_template is not None:
            chat_template = {'chat_template': self.processor.chat_template}
            nncore.dump(chat_template, nncore.join(output_dir, CHAT_TEMPLATE_NAME), indent=2)

    def create_optimizer(self):
        """Build the optimizer with separate parameter groups for base, LoRA and head weights.

        Each family is split into a weight-decay group and a no-decay group
        (biases and parameters inside ``ALL_LAYERNORM_LAYERS`` modules). LoRA
        groups use ``args.lora_lr`` and head groups ``args.head_lr``; both default
        to ``args.learning_rate`` and the default is written back into ``self.args``.
        """
        if self.optimizer is None:
            head_keys = tuple(self.head_keys or ())
            grad_ps = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]

            decay_ps = {n for n in get_parameter_names(self.model, ALL_LAYERNORM_LAYERS) if 'bias' not in n}

            if self.args.lora_lr is None:
                self.args.lora_lr = self.args.learning_rate

            if self.args.head_lr is None:
                self.args.head_lr = self.args.learning_rate

            lora_ps = {n for n, _ in grad_ps if 'lora' in n}
            head_ps = {n for n, _ in grad_ps if any(k in n for k in head_keys)}
            assert lora_ps.isdisjoint(head_ps), 'a parameter cannot be both a LoRA and a head parameter'
            base_ps = {n for n, _ in grad_ps} - lora_ps - head_ps

            # order: base (decay, no-decay), LoRA (decay, no-decay), head (decay, no-decay)
            groups = []
            for names, lr in ((base_ps, None), (lora_ps, self.args.lora_lr), (head_ps, self.args.head_lr)):
                for decay in (True, False):
                    group = {
                        'params': [p for n, p in grad_ps if n in names and (n in decay_ps) == decay],
                        'weight_decay': self.args.weight_decay if decay else 0.0
                    }
                    if lr is not None:
                        group['lr'] = lr
                    groups.append(group)

            optim_cls, kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optim_cls(groups, **kwargs)

        return self.optimizer

    def gather_and_save_model(self):
        """Save the final model to ``args.output_dir``.

        With ``args.save_full_model`` the full (LoRA-merged) weights are saved;
        otherwise the LoRA adapters (if enabled), configs, tokenizer and trainable
        head parameters are written. Parameter gathering runs on every process,
        while files are only written when ``args.should_save``.
        """
        # NOTE(review): assumes the run uses DeepSpeed (deepspeed_config is indexed unconditionally)
        deepspeed_zero3 = self.accelerator.deepspeed_config['zero_optimization']['stage'] == 3
        output_dir = self.args.output_dir

        if self.args.should_save:
            print(f'Saving final model to {nncore.abs_path(output_dir)}...')

        self._save_processor(output_dir)

        if self.args.save_full_model and self.args.lora_enable and deepspeed_zero3:
            warnings.warn('LoRA models cannot be saved in full mode under zero3, saving adapters instead')
            self.args.save_full_model = False

        if self.args.save_full_model:
            if self.args.lora_enable:
                self.model = self.model.merge_and_unload()

            if deepspeed_zero3 and not self.model_wrapped.zero_gather_16bit_weights_on_model_save():
                warnings.warn('Saving zero checkpoint, use zero_to_fp32.py to recover weights')
                self.model_wrapped.save_checkpoint(output_dir)
                return

            if deepspeed_zero3:
                state_dict = self.model_wrapped._zero3_consolidated_16bit_state_dict()
            else:
                state_dict = self.model.state_dict()

            if self.args.should_save:
                # strip the wrapper prefix so keys match the bare model
                prefix = 'base_model.model.'
                state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
                self._save(output_dir, state_dict=state_dict)
        else:
            if self.args.lora_enable:
                state_dict = gather_lora_params(self.model, self.args.lora_bias)
                if self.args.should_save:
                    self.model.save_pretrained(output_dir, state_dict=state_dict)

            if self.args.should_save:
                self.model.config.save_pretrained(output_dir)
                if getattr(self.model, 'generation_config', None) is not None:
                    self.model.generation_config.save_pretrained(output_dir)
                if self.tokenizer is not None:
                    self.tokenizer.save_pretrained(output_dir)

            state_dict = gather_key_params(self.model, self.head_keys)
            if self.args.should_save and state_dict:
                save_file(state_dict, nncore.join(output_dir, 'pytorch_model.safetensors'))

    def _save_checkpoint(self, model, trial, **kwargs):
        """Save a regular checkpoint plus processor files and, for LoRA runs, head weights."""
        output_dir = self._get_output_dir(trial)
        output_dir = nncore.join(output_dir, f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}')

        if self.args.should_save:
            print(f'Saving checkpoint to {nncore.abs_path(output_dir)}...')

        super()._save_checkpoint(model, trial, **kwargs)

        self._save_processor(output_dir)

        if self.args.lora_enable:
            state_dict = gather_key_params(self.model, self.head_keys)
            if self.args.should_save:
                self.model.config.save_pretrained(output_dir)
                if state_dict:
                    save_file(state_dict, nncore.join(output_dir, 'pytorch_model.safetensors'))