Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License. | |
import warnings | |
import nncore | |
import torch | |
from deepspeed import zero | |
from safetensors.torch import load_model, save_file | |
from torch.utils.data import Sampler | |
from transformers import Trainer, TrainerCallback | |
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS | |
from transformers.trainer_pt_utils import get_parameter_names | |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR | |
from transformers.utils import CHAT_TEMPLATE_NAME | |
def gather(param):
    """Return a detached, cloned CPU copy of ``param``.

    Parameters carrying a DeepSpeed ``ds_id`` attribute are materialised inside
    ``zero.GatheredParameters`` before being copied; any other tensor is copied
    directly.
    """
    if not hasattr(param, 'ds_id'):
        return param.detach().cpu().clone()
    # The copy is taken while the parameter is gathered, i.e. before the
    # context manager exits.
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()
def gather_lora_params(model, bias):
    """Collect LoRA-related parameters of ``model`` as detached CPU copies.

    Parameters whose name contains ``'lora_'`` or ``'modules_to_save'`` are always
    included. ``bias`` selects which bias parameters are added on top:

    - ``'none'``: no extra bias parameters.
    - ``'all'``: every parameter whose name contains ``'bias'``.
    - ``'lora_only'``: only biases named ``<prefix>bias``, where ``<prefix>`` is the
      text before ``'lora_'`` in some LoRA parameter name.

    Args:
        model: module whose ``named_parameters()`` are scanned by name.
        bias: one of ``'lora_only'``, ``'all'`` or ``'none'``.

    Returns:
        dict mapping parameter name to the copy produced by ``gather``.
    """
    assert bias in ('lora_only', 'all', 'none')
    if bias == 'lora_only':
        state_dict, maybe_lora_bias, lora_bias_names = dict(), dict(), set()
        for n, p in model.named_parameters():
            if 'modules_to_save' in n:
                state_dict[n] = p
            elif 'lora_' in n:
                state_dict[n] = p
                # bias of the layer this LoRA weight is attached to
                lora_bias_names.add(n.split('lora_')[0] + 'bias')
            elif 'bias' in n:
                maybe_lora_bias[n] = p
        # Keep only biases that belong to a layer carrying LoRA weights. Iterate
        # the (name, param) pairs and test each bias's own name.
        for n, p in maybe_lora_bias.items():
            if n in lora_bias_names:
                state_dict[n] = p
    else:
        keys = ['lora_', 'modules_to_save', 'bias'] if bias == 'all' else ['lora_', 'modules_to_save']
        state_dict = {n: p for n, p in model.named_parameters() if any(k in n for k in keys)}
    return {n: gather(p) for n, p in state_dict.items()}
def gather_key_params(model, keys):
    """Collect trainable parameters whose name contains any of ``keys``.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        keys: iterable of name substrings. ``None`` selects nothing, which matches
            the ``head_keys=None`` default used by ``CustomTrainer``.

    Returns:
        dict mapping parameter name to the copy produced by ``gather``; empty
        when nothing matches.
    """
    if keys is None:
        return {}
    selected = {n: p for n, p in model.named_parameters() if p.requires_grad and any(k in n for k in keys)}
    return {n: gather(p) for n, p in selected.items()}
class GroupSampler(Sampler):
    """Sampler whose indices come in fixed-size blocks sharing one data type.

    Indices are grouped by ``data_types[i]``. Each group is shuffled and cut into
    blocks of ``group_size`` indices, and a trailing partial block is dropped. The
    blocks from all groups are then shuffled together and flattened. Shuffling is
    seeded with ``seed + epoch``; call ``set_epoch`` to change the order between
    epochs.

    Args:
        group_size: number of consecutive indices drawn from a single data type.
        data_types: per-sample type labels, compared with ``==``.
        seed: base seed for the shuffling generator.
    """

    def __init__(self, group_size, data_types, seed):
        self.group_size = group_size
        self.data_types = data_types
        self.seed = seed
        # default so iteration works even if set_epoch() is never called
        self.epoch = 0

    def _groups(self):
        """Return per-type lists of sample indices, ordered by first appearance."""
        # Plain lists (no set) keep the group order stable across processes and
        # only require labels to support ``==``.
        unique_types, groups = [], []
        for i, t in enumerate(self.data_types):
            if t not in unique_types:
                unique_types.append(t)
                groups.append([])
            groups[unique_types.index(t)].append(i)
        return groups

    def __len__(self):
        # Only complete blocks are yielded, so each group's trailing partial
        # block does not count.
        return sum(len(group) // self.group_size * self.group_size for group in self._groups())

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        group_batches = []
        for group in self._groups():
            inds = [group[i] for i in torch.randperm(len(group), generator=g)]
            batches = [inds[i:i + self.group_size] for i in range(0, len(inds), self.group_size)]
            if len(batches[-1]) < self.group_size:
                batches = batches[:-1]
            group_batches += batches
        perm_group_batches = [group_batches[i] for i in torch.randperm(len(group_batches), generator=g)]
        inds = [i for batch in perm_group_batches for i in batch]
        return iter(inds)

    def set_epoch(self, epoch):
        self.epoch = epoch
class SetEpochCallback(TrainerCallback):
    """Forward the current epoch to the training sampler at the start of each epoch.

    The sampler is reached through ``train_dataloader.batch_sampler``, unwrapping
    one extra ``.batch_sampler`` level when present.
    """

    # partially fixed in https://github.com/huggingface/accelerate/pull/3246
    # but not for the case of batch_sampler.batch_sampler.sampler
    def on_epoch_begin(self, args, state, control, **kwargs):
        shard_sampler = kwargs['train_dataloader'].batch_sampler
        batch_sampler = getattr(shard_sampler, 'batch_sampler', shard_sampler)
        sampler = batch_sampler.sampler
        # The callback is registered unconditionally, but the sampler may be the
        # parent Trainer's, which does not necessarily implement set_epoch.
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(int(state.epoch))
class CustomTrainer(Trainer):
    """Trainer with data-type grouped sampling, per-module learning rates and
    partial (LoRA / head) checkpoint saving.

    Args:
        processor: forwarded to ``Trainer`` as ``tokenizer``; also saved (with its
            chat template) next to checkpoints and the final model.
        head_keys: name substrings selecting trainable "head" parameters that get
            ``args.head_lr`` and are saved to ``pytorch_model.safetensors``.
            ``None`` selects nothing.
    """

    def __init__(self, *args, processor=None, head_keys=None, **kwargs):
        super().__init__(*args, tokenizer=processor, **kwargs)
        # Passes the current epoch to samplers that implement set_epoch (see GroupSampler).
        self.add_callback(SetEpochCallback())
        self.processor = processor
        self.head_keys = head_keys

    def _get_head_keys(self):
        """Return ``head_keys``, treating ``None`` as an empty list."""
        return self.head_keys if self.head_keys is not None else []

    def _save_processor(self, output_dir):
        """Save the processor and its chat template, only when ``args.should_save``."""
        if self.processor is None or not self.args.should_save:
            return
        self.processor.save_pretrained(output_dir)
        # https://github.com/huggingface/transformers/pull/33462
        if self.processor.chat_template is not None:
            chat_template = {'chat_template': self.processor.chat_template}
            nncore.dump(chat_template, nncore.join(output_dir, CHAT_TEMPLATE_NAME), indent=2)

    def _get_train_sampler(self):
        if self.args.group_by_data_type:
            # Block size is the per-device batch size times the number of
            # processes, presumably so each global batch has a single data type.
            return GroupSampler(self.args.train_batch_size * self.args.world_size, self.train_dataset.data_types,
                                self.args.seed)
        else:
            return super()._get_train_sampler()

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """Resume as usual, then load the partial head weights file if present."""
        if model is None:
            model = self.model
        super()._load_from_checkpoint(resume_from_checkpoint, model=model)
        partial_path = nncore.join(resume_from_checkpoint, 'pytorch_model.safetensors')
        if nncore.is_file(partial_path):
            # the file holds only a subset of parameters, hence strict=False
            load_model(model, partial_path, strict=False, device=model.device)

    def create_optimizer(self):
        """Build the optimizer from six groups: {regular, LoRA, head} x {decay, no-decay}.

        Decay applies to parameters returned by ``get_parameter_names`` (layer-norm
        layers excluded) whose name does not contain ``'bias'``. LoRA parameters
        (name contains ``'lora'``) use ``args.lora_lr`` and head parameters use
        ``args.head_lr``. A missing rate is set on ``self.args`` to
        ``args.learning_rate``.
        """
        if self.optimizer is None:
            head_keys = self._get_head_keys()
            grad_ps = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
            # sets keep the membership tests in the group comprehensions O(1)
            decay_ps = {n for n in get_parameter_names(self.model, ALL_LAYERNORM_LAYERS) if 'bias' not in n}
            if self.args.lora_lr is None:
                self.args.lora_lr = self.args.learning_rate
            if self.args.head_lr is None:
                self.args.head_lr = self.args.learning_rate
            lora_ps = {n for n, _ in grad_ps if 'lora' in n}
            head_ps = {n for n, _ in grad_ps if any(k in n for k in head_keys)}
            assert lora_ps.isdisjoint(head_ps)
            groups = [{
                'params': [p for n, p in grad_ps if (n in decay_ps and n not in lora_ps and n not in head_ps)],
                'weight_decay': self.args.weight_decay
            }, {
                'params': [p for n, p in grad_ps if (n not in decay_ps and n not in lora_ps and n not in head_ps)],
                'weight_decay': 0.0
            }, {
                'params': [p for n, p in grad_ps if (n in decay_ps and n in lora_ps)],
                'weight_decay': self.args.weight_decay,
                'lr': self.args.lora_lr
            }, {
                'params': [p for n, p in grad_ps if (n not in decay_ps and n in lora_ps)],
                'weight_decay': 0.0,
                'lr': self.args.lora_lr
            }, {
                'params': [p for n, p in grad_ps if (n in decay_ps and n in head_ps)],
                'weight_decay': self.args.weight_decay,
                'lr': self.args.head_lr
            }, {
                'params': [p for n, p in grad_ps if (n not in decay_ps and n in head_ps)],
                'weight_decay': 0.0,
                'lr': self.args.head_lr
            }]
            optim_cls, kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optim_cls(groups, **kwargs)
        return self.optimizer

    def gather_and_save_model(self):
        """Save the final model, processor and head parameters to ``args.output_dir``.

        Intended to run on every process: parameter gathering happens outside the
        ``should_save`` checks, while file writes are restricted to ``should_save``.
        """
        # NOTE(review): reads the DeepSpeed config unconditionally; presumably
        # fails when DeepSpeed is not enabled - confirm.
        deepspeed_zero3 = self.accelerator.deepspeed_config['zero_optimization']['stage'] == 3
        output_dir = self.args.output_dir
        if self.args.should_save:
            print(f'Saving final model to {nncore.abs_path(output_dir)}...')
        self._save_processor(output_dir)
        if self.args.save_full_model and self.args.lora_enable and deepspeed_zero3:
            warnings.warn('LoRA models cannot be saved in full mode under zero3, saving adapters instead')
            self.args.save_full_model = False
        if self.args.save_full_model:
            if self.args.lora_enable:
                self.model = self.model.merge_and_unload()
            if deepspeed_zero3 and not self.model_wrapped.zero_gather_16bit_weights_on_model_save():
                warnings.warn('Saving zero checkpoint, use zero_to_fp32.py to recover weights')
                self.model_wrapped.save_checkpoint(output_dir)
                return
            if deepspeed_zero3:
                state_dict = self.model_wrapped._zero3_consolidated_16bit_state_dict()
            else:
                state_dict = self.model.state_dict()
            if self.args.should_save:
                prefix = 'base_model.model.'
                state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
                self._save(output_dir, state_dict=state_dict)
        elif self.args.lora_enable:
            state_dict = gather_lora_params(self.model, self.args.lora_bias)
            if self.args.should_save:
                self.model.save_pretrained(output_dir, state_dict=state_dict)
        if self.args.should_save:
            self.model.config.save_pretrained(output_dir)
            # generation_config may be None for models that cannot generate
            generation_config = getattr(self.model, 'generation_config', None)
            if generation_config is not None:
                generation_config.save_pretrained(output_dir)
            # tokenizer is the processor given at construction and may be None
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)
        state_dict = gather_key_params(self.model, self._get_head_keys())
        if self.args.should_save and state_dict:
            save_file(state_dict, nncore.join(output_dir, 'pytorch_model.safetensors'))

    def _save_checkpoint(self, model, trial, **kwargs):
        output_dir = self._get_output_dir(trial)
        output_dir = nncore.join(output_dir, f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}')
        if self.args.should_save:
            print(f'Saving checkpoint to {nncore.abs_path(output_dir)}...')
        super()._save_checkpoint(model, trial, **kwargs)
        self._save_processor(output_dir)
        if self.args.lora_enable:
            # Trainable head parameters are written to the file that
            # _load_from_checkpoint restores; skipped when there are none.
            state_dict = gather_key_params(self.model, self._get_head_keys())
            if self.args.should_save:
                self.model.config.save_pretrained(output_dir)
                if state_dict:
                    save_file(state_dict, nncore.join(output_dir, 'pytorch_model.safetensors'))