# Copyright (c) 2021 microsoft
#               2023 Alan ([email protected])
# -----------------------------------------------------------------------------
# Licensed under the MIT License (MIT). See LICENSE in the repo root for
# license information.
# -----------------------------------------------------------------------------
import logging
from typing import Dict, List

import torch
import torch.nn as nn

import wenet.finetune.lora.layers as lora


def get_nested_attr(module, attr_path):
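    """Resolve a dotted attribute path (e.g. ``"a.b.c"``) on ``module``,
    returning ``None`` if any attribute along the path is missing."""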
    attrs = attr_path.split('.')
    for attr in attrs:
        if hasattr(module, attr):
            module = getattr(module, attr)
        else:
            return None
    return module


def inject_lora(module, lora_config):
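    """Replace each attribute named in ``lora_config["lora_list"]`` on
    ``module`` with a square ``lora.Linear`` layer (sized by the original
    layer's ``in_features``) using the configured rank, alpha and dropout."""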
    lora_rank = lora_config["lora_rank"]
    lora_alpha = lora_config["lora_alpha"]
    lora_dropout = lora_config["lora_dropout"]
    for lora_attr in lora_config["lora_list"]:
        if hasattr(module, lora_attr):
            submodule = getattr(module, lora_attr)
            n_feat = submodule.in_features
            lora_linear = lora.Linear(n_feat, n_feat, r=lora_rank,
                                      lora_alpha=lora_alpha,
                                      lora_dropout=lora_dropout)
            setattr(module, lora_attr, lora_linear)


def inject_lora_to_model(model, lora_config):
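    """Walk the layer containers named in ``lora_config["lora_modules"]``,
    collect the attention sub-modules named in
    ``lora_config["lora_attn_attr"]``, and inject LoRA adapters into each of
    them via ``inject_lora``."""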
    lora_modules = []
    for module in lora_config["lora_modules"]:
        submodule = get_nested_attr(model, module)
        for layer in submodule:
            lora_modules.append(layer)
    updated_lora_modules = []
    for i in range(len(lora_modules)):
        for attn_attr in lora_config["lora_attn_attr"]:
            if hasattr(lora_modules[i], attn_attr):
                updated_lora_modules.append(getattr(lora_modules[i], attn_attr))
    for lora_module in updated_lora_modules:
        inject_lora(lora_module, lora_config)


def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
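    """Freeze every parameter whose name does not contain ``lora_``.

    ``bias`` controls bias handling: ``'none'`` keeps biases frozen,
    ``'all'`` unfreezes every bias, and ``'lora_only'`` unfreezes only the
    biases of LoRA layers.
    """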
    logging.info('freezing all params except lora module.')
    for n, p in model.named_parameters():
        if 'lora_' not in n:
            p.requires_grad = False
    if bias == 'none':
        return
    elif bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, lora.LoRALayer) and \
                    hasattr(m, 'bias') and \
                    m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError


def lora_state_dict(model: nn.Module,
                    bias: str = 'none') -> Dict[str, torch.Tensor]:
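    """Return only the LoRA-related entries of ``model``'s state dict.

    With ``bias='all'`` every bias tensor is also included; with
    ``bias='lora_only'`` only the biases belonging to LoRA layers are kept.
    """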
    my_state_dict = model.state_dict()
    if bias == 'none':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
    elif bias == 'all':
        return {
            k: my_state_dict[k]
            for k in my_state_dict if 'lora_' in k or 'bias' in k
        }
    elif bias == 'lora_only':
        to_return = {}
        for k in my_state_dict:
            if 'lora_' in k:
                to_return[k] = my_state_dict[k]
                bias_name = k.split('lora_')[0] + 'bias'
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError


def get_record_gradient_hook(model, record_dict):
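    """Return a backward hook that accumulates the gradients of all trainable
    parameters into ``record_dict`` on the CPU and clears ``p.grad`` to free
    GPU memory."""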
    def record_gradient_hook(grad):
        for n, p in model.named_parameters():
            if p.requires_grad and p.grad is not None:
                if n not in record_dict:
                    record_dict[n] = p.grad.cpu()
                else:
                    record_dict[n] += p.grad.cpu()
                p.grad = None
        return grad
    return record_gradient_hook


def estimate_gradient(
        model, dataloader, max_iters: int = 8,
        device: torch.device = torch.device("cpu")
) -> Dict[str, torch.Tensor]:
    r"""
    Estimate the gradient of the model on the given dataset
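    The gradient is accumulated over at most ``max_iters`` batches and then
    averaged. Every parameter is temporarily set to ``requires_grad=True``
    and its gradients are recorded on the CPU via backward hooks; the
    original ``requires_grad`` flags are restored before returning.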
""" | |
logging.info("Estimating gradient layer by layer, time needed") | |
model.train() | |
named_grads = {} | |
hooks = [] | |
requires_grad_states = {} | |
for name, param in model.named_parameters(): | |
requires_grad_states[name] = param.requires_grad | |
param.requires_grad = True | |
hook = param.register_hook(get_record_gradient_hook(model, named_grads)) | |
hooks.append(hook) | |
num = 0 | |
for _, batch_dict in enumerate(dataloader): | |
num += 1 | |
if max_iters is not None and num >= max_iters: | |
break | |
outputs = model(batch_dict, device) | |
outputs['loss'].backward() | |
get_record_gradient_hook(model, named_grads)(None) # get gradient of last layer | |
# make sure the gradient is cleared | |
for n, p in model.named_parameters(): | |
if p.grad is not None: | |
p.grad = None | |
for n, _ in named_grads.items(): | |
named_grads[n] /= num | |
for hook in hooks: | |
hook.remove() | |
# recover original requires_grad states | |
for name, param in model.named_parameters(): | |
param.requires_grad = requires_grad_states[name] | |
torch.cuda.empty_cache() | |
return named_grads | |


def reinit_lora_modules(name, module, init_config, **kwargs):
    r"""Refer to https://github.com/Outsider565/LoRA-GA/blob/c185846309ea9012d0bcd46ebd30347dda1c592c/run_exp.py#L67
    Reinitialize the LoRA module with the given configuration.
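    Supported ``init_config.mode`` values handled below: "simple" (direct
    initializers for lora_A / lora_B), "svd" (low-rank SVD of the frozen
    weight), and "gradient" (LoRA-GA style initialization from the estimated
    gradients passed in via ``kwargs["named_grads"]``).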
""" | |
    import math
    lora_r = min(module.lora_A.shape)
    a_dim = max(module.lora_A.shape)
    b_dim = max(module.lora_B.shape)
    if init_config.mode == "simple":
        match init_config.lora_A:
            case "gaussian":
                torch.nn.init.normal_(
                    module.lora_A, mean=0.0,
                    std=init_config.lora_A_std
                )
            case "kaiming":
                # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
                torch.nn.init.kaiming_uniform_(module.lora_A,
                                               a=math.sqrt(5))
            case "fan_out_kaiming":
                torch.nn.init.kaiming_normal_(
                    module.lora_A, mode="fan_out"
                )
            case "xavier":
                torch.nn.init.xavier_normal_(module.lora_A)
            case "zeros":
                torch.nn.init.zeros_(module.lora_A)
            case "unit":
                torch.nn.init.normal_(
                    module.lora_A, mean=0.0,
                    std=1.0 / (a_dim**0.5)
                )
            case "orthogonal":
                torch.nn.init.orthogonal_(module.lora_A)
            case _:
                raise ValueError(
                    f"Unknown lora_A initialization: {init_config.lora_A}"
                )
        match init_config.lora_B:
            case "gaussian":
                torch.nn.init.normal_(
                    module.lora_B, mean=0.0,
                    std=init_config.lora_B_std
                )
            case "kaiming":
                torch.nn.init.kaiming_normal_(module.lora_B)
            case "fan_out_kaiming":
                torch.nn.init.kaiming_normal_(
                    module.lora_B, mode="fan_out"
                )
            case "xavier":
                torch.nn.init.xavier_normal_(module.lora_B)
            case "zeros":
                torch.nn.init.zeros_(module.lora_B)
            case "unit":
                torch.nn.init.normal_(
                    module.lora_B, mean=0.0,
                    std=1.0 / (b_dim**0.5)
                )
            case "orthogonal":
                torch.nn.init.orthogonal_(module.lora_B)
            case _:
                raise ValueError(
                    f"Unknown lora_B initialization: {init_config.lora_B}"
                )
        if getattr(init_config, 'scale', '') == "stable":
            gamma = init_config.stable_gamma
            m, n = module.weight.shape
            module.lora_B.data *= (m**0.25) / gamma**0.5
            module.lora_A.data *= (n**0.25) / gamma**0.5
    elif init_config.mode == "svd":
        U, S, V = torch.svd_lowrank(module.weight.float(), q=4 * lora_r,
                                    niter=4)
        V = V.T
        m, n = module.weight.shape
        if init_config.scale == "default":
            S = S / module.scaling
            module.lora_B = torch.nn.Parameter(
                (U[:, :lora_r] * torch.sqrt(S[:lora_r])).contiguous()
            )
            module.lora_A = torch.nn.Parameter(
                (V[:lora_r, :].T * torch.sqrt(S[:lora_r])).T.contiguous()
            )
        elif init_config.scale == "stable":
            gamma = init_config.stable_gamma
            module.lora_B = torch.nn.Parameter(
                (U[:, :lora_r] * (m**0.25) / gamma**0.5).contiguous()
            )
            module.lora_A = torch.nn.Parameter(
                (V[:lora_r, :] * (n**0.25) / gamma**0.5).contiguous()
            )
        elif init_config.scale == "unit":
            module.lora_B = torch.nn.Parameter((U[:, :lora_r]).contiguous())
            module.lora_A = torch.nn.Parameter((V[:lora_r, :]).contiguous())
        elif init_config.scale == "normalized":
            S_sum = S[:lora_r].sum()
            module.lora_B = torch.nn.Parameter(
                (U[:, :lora_r] * torch.sqrt(S[:lora_r])
                 / torch.sqrt(S_sum) * lora_r**0.5).contiguous()
            )
            module.lora_A = torch.nn.Parameter(
                (V[:lora_r, :].T * torch.sqrt(S[:lora_r])
                 / torch.sqrt(S_sum) * lora_r**0.5).T.contiguous()
            )
    elif init_config.mode == "gradient":
        named_grad = kwargs["named_grads"]
        grad_name = name + ".weight"
        grads = named_grad[grad_name]
        U, S, V = torch.svd_lowrank(grads.cuda().float(), q=4 * lora_r, niter=4)
        V = V.T
        # set direction
        if init_config.direction == "ArBr":
            B = U[:, 0 : 2 * lora_r : 2]
            A = V[1 : 2 * lora_r : 2, :]
        elif init_config.direction == "A2rBr":
            B = U[:, :lora_r]
            A = V[lora_r : 2 * lora_r, :]
        elif init_config.direction == "ArB2r":
            B = U[:, lora_r : 2 * lora_r]
            A = V[:lora_r, :]
        else:
            raise ValueError(
                f"Unknown direction: {init_config.direction}"
            )
        scaling_factor = module.scaling
        if init_config.scale == "gd":
            A = A / scaling_factor
            B = B / scaling_factor
        elif init_config.scale == "unit":
            # A and B are orthogonal, so no extra scaling is needed
            pass
        elif init_config.scale == "stable":
            m, n = grads.shape
            # m: feature_out, n: feature_in
            # the scale of the output is only related to feature_out
            gamma = init_config.stable_gamma
            B = B * m**0.25 / gamma**0.5
            A = A * m**0.25 / gamma**0.5
        elif init_config.scale == "weightS":
            _, S, _ = torch.svd_lowrank(module.weight.float(), q=4 * lora_r,
                                        niter=4)
            S = S / module.scaling
            avg_s = torch.sqrt(S[:lora_r]).mean().to(A.device)
            B = B * avg_s
            A = A * avg_s
        module.lora_B = torch.nn.Parameter(B.contiguous().cuda())
        module.lora_A = torch.nn.Parameter(A.contiguous().cuda())
    with torch.no_grad():
        # dtype may not be present in init_config
        if not hasattr(init_config, "dtype"):
            pass
        elif init_config.dtype == "bf16":
            module.lora_A.data = module.lora_A.data.to(torch.bfloat16)
            module.lora_B.data = module.lora_B.data.to(torch.bfloat16)
        elif init_config.dtype == "fp32":
            module.lora_A.data = module.lora_A.data.to(torch.float32)
            module.lora_B.data = module.lora_B.data.to(torch.float32)
        # If lora_B @ lora_A is not zero, subtract the (scaled) product from
        # the original weight matrix so that the effective weight at
        # initialization stays unchanged.
        offset = (
            module.lora_B @ module.lora_A
        ).to(module.weight.data.device)
        scaling_factor = module.scaling
        offset *= scaling_factor
        if hasattr(init_config, "norm_clip") and init_config.norm_clip:
            # for numerical stability, the offset's largest value must be
            # less than the weight's largest value
            ratio = torch.max(torch.abs(module.weight.data)) / torch.max(
                torch.abs(offset)
            )
            if ratio < 1:
                offset *= ratio
                module.lora_A.data *= ratio**0.5
                module.lora_B.data *= ratio**0.5
                logging.warning(f"Clipping offset by {ratio}")
        try:
            module.weight.data -= offset
        except Exception as e:
            logging.warning(f"{e}")
            breakpoint()  # drop into the debugger to inspect the failure
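

# Rough usage sketch (comments only, not executed): the config keys below are
# the ones read by inject_lora_to_model / inject_lora, but the values, module
# paths, attribute names, and the ``device`` / ``init_config`` objects are
# illustrative placeholders that depend on the actual wenet model and setup.
#
#     lora_config = {
#         "lora_rank": 8,
#         "lora_alpha": 8,
#         "lora_dropout": 0.0,
#         "lora_list": ["linear_q", "linear_k", "linear_v", "linear_out"],
#         "lora_modules": ["encoder.encoders"],
#         "lora_attn_attr": ["self_attn"],
#     }
#     inject_lora_to_model(model, lora_config)
#
#     # Optional LoRA-GA style initialization from estimated gradients:
#     named_grads = estimate_gradient(model, dataloader, max_iters=8,
#                                     device=device)
#     for name, module in model.named_modules():
#         if isinstance(module, lora.Linear):
#             reinit_lora_modules(name, module, init_config,
#                                 named_grads=named_grads)
#
#     mark_only_lora_as_trainable(model, bias="none")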