# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import LoRALayer


class SVDLinear(nn.Linear, LoRALayer):
    # SVD-based adaptation implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r, in_features))
            )
            self.lora_E = nn.Parameter(
                self.weight.new_zeros(r, 1)
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features, r))
            )
            self.ranknum = nn.Parameter(
                self.weight.new_zeros(1), requires_grad=False
            )
            self.ranknum.data.fill_(float(self.r))
            self.scaling = self.lora_alpha if self.lora_alpha > 0 else float(self.r)
            # Freeze the pre-trained weight matrix
            self.weight.requires_grad = False
            self.ranknum.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Initialize A and B with a small normal distribution,
            # and E (the singular values) to zero
            nn.init.zeros_(self.lora_E)
            nn.init.normal_(self.lora_A, mean=0.0, std=0.02)
            nn.init.normal_(self.lora_B, mean=0.0, std=0.02)

    def train(self, mode: bool = True):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                self.weight.data -= T(
                    self.lora_B @ (self.lora_A * self.lora_E)
                ) * self.scaling / (self.ranknum + 1e-5)
            self.merged = False

    def eval(self):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += T(
                    self.lora_B @ (self.lora_A * self.lora_E)
                ) * self.scaling / (self.ranknum + 1e-5)
            self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                result += (
                    self.lora_dropout(x) @ (self.lora_A * self.lora_E).T @ self.lora_B.T
                ) * self.scaling / (self.ranknum + 1e-5)
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
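
# Usage sketch for SVDLinear (dimensions and hyperparameters below are
# illustrative examples, not values mandated by this module):
#
#     layer = SVDLinear(768, 768, r=12, lora_alpha=32, lora_dropout=0.1)
#     x = torch.randn(4, 128, 768)
#     y = layer(x)   # frozen W x plus the scaled low-rank update B (A * E) x
#
# Only lora_A, lora_E, and lora_B receive gradients; the pre-trained weight
# stays frozen and is merged back into `weight` when eval() is called.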

class RankAllocator(object):
    """
    The RankAllocator for the AdaLoRA model, called at every training step.
    Paper: https://openreview.net/pdf?id=lq62uWRJjiY

    Args:
        model: the model that we apply AdaLoRA to.
        lora_r (`int`): The initial rank for each incremental matrix.
        target_rank (`int`): The target average rank of the incremental matrices.
        init_warmup (`int`): The steps of initial fine-tuning warmup.
        final_warmup (`int`): The steps of final fine-tuning.
        mask_interval (`int`): The time interval between two budget allocations.
        beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing.
        beta2 (`float`): The hyperparameter of EMA for uncertainty quantification.
        total_step (`int`): The total number of training steps; must be correctly configured before training.
        target_total_rank (`Optional[int]`): The specified final total rank.
        tb_writter (`SummaryWriter`): Tensorboard SummaryWriter.
        tb_writter_loginterval (`int`): The logging interval of the SummaryWriter.
    """

    def __init__(
        self, model,
        lora_r: int,
        target_rank: int,
        init_warmup: int,
        final_warmup: int,
        mask_interval: int,
        beta1: float,
        beta2: float,
        total_step: Optional[int] = None,
        target_total_rank: Optional[int] = None,
        tb_writter=None,
        tb_writter_loginterval: int = 500,
    ):
        self.ave_target_rank = target_rank
        self.target_rank = target_total_rank
        self.lora_init_rank = lora_r
        self.initial_warmup = init_warmup
        self.final_warmup = final_warmup
        self.mask_interval = mask_interval
        self.beta1 = beta1
        self.beta2 = beta2
        self.total_step = total_step

        self.model = model
        self.ipt = {}
        self.exp_avg_ipt = {}
        self.exp_avg_unc = {}
        self.cat_ipt = {}
        self.rank_pattern = {}
        self.get_lora_param_name()

        self.tb_writter = tb_writter
        self.log_interval = tb_writter_loginterval

        assert 0 < self.beta1 < 1
        assert 0 < self.beta2 < 1

    def set_total_step(self, total_step: int):
        # Set the total step number
        self.total_step = total_step
        assert self.total_step > self.initial_warmup + self.final_warmup

    def get_rank_pattern(self):
        # Return the rank pattern
        return self.rank_pattern

    def get_lora_param_name(self):
        # Prepare the budget scheduler
        self.name_set = set()
        self.total_rank = 0
        self.shape_dict = {}
        for n, p in self.model.named_parameters():
            if "lora_A" in n:
                name_mat = n.replace("lora_A", "%s")
                self.name_set.add(name_mat)
                self.total_rank += p.size(0)
                self.shape_dict[n] = p.shape
            if "lora_B" in n:
                self.shape_dict[n] = p.shape
        self.name_set = list(sorted(self.name_set))
        if self.target_rank is None:
            self.target_rank = self.ave_target_rank * len(self.name_set)

    def schedule_threshold(self, step: int):
        # Global budget schedule
        mask_ind = False
        target_rank = self.target_rank
        initial_warmup = self.initial_warmup
        final_warmup = self.final_warmup
        total_step = self.total_step
        self.global_step = step
        if step <= initial_warmup:
            # Initial warmup
            curr_rank = self.total_rank
            mask_ind = False
        elif step > total_step - final_warmup:
            # Final fine-tuning
            curr_rank = self.target_rank
            # Fix the rank pattern by always masking
            # the same unimportant singular values
            mask_ind = True
        else:
            # Budget decreasing with a cubic schedule
            mul_coeff = 1 - (step - initial_warmup) / (total_step - final_warmup - initial_warmup)
            curr_rank = target_rank + (self.total_rank - target_rank) * (mul_coeff ** 3)
            curr_rank = int(curr_rank)
            mask_ind = (step % self.mask_interval == 0)
        return curr_rank, mask_ind

    def update_ipt(self, model):
        for n, p in model.named_parameters():
            if "lora_" in n:
                if n not in self.ipt:
                    self.ipt[n] = torch.zeros_like(p)
                    self.exp_avg_ipt[n] = torch.zeros_like(p)
                    self.exp_avg_unc[n] = torch.zeros_like(p)
                with torch.no_grad():
                    # Calculate sensitivity
                    self.ipt[n] = (p * p.grad).abs().detach()
                    # Update sensitivity
                    self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + \
                        (1 - self.beta1) * self.ipt[n]
                    # Update uncertainty
                    self.exp_avg_unc[n] = self.beta2 * self.exp_avg_unc[n] + \
                        (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()

    def calculate_score(self, n, p=None, metric="ipt"):
        if metric == "ipt":
            # Combine the sensitivity and uncertainty
            ipt_score = self.exp_avg_ipt[n] * self.exp_avg_unc[n]
        elif metric == "mag":
            ipt_score = p.abs().detach().clone()
        else:
            raise ValueError("Unexpected Metric: %s" % metric)
        return ipt_score
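
    # Worked example of the score above (numbers are illustrative, not defaults):
    # with beta1 = beta2 = 0.85, a current sensitivity ipt = |w * grad| = 0.40,
    # and running averages exp_avg_ipt = 0.20, exp_avg_unc = 0.05, one update gives
    #     exp_avg_ipt <- 0.85 * 0.20 + 0.15 * 0.40            = 0.23
    #     exp_avg_unc <- 0.85 * 0.05 + 0.15 * |0.40 - 0.23|   = 0.068
    #     score (metric="ipt") = exp_avg_ipt * exp_avg_unc     = 0.23 * 0.068 ~= 0.0156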
    def _combine_ipt(self, ipt_E, ipt_AB):
        ipt_AB = ipt_AB.sum(dim=1, keepdim=False)
        sum_ipt = ipt_E.view(-1) + ipt_AB.view(-1)
        return sum_ipt

    def mask_to_target_rank(self, model, curr_rank):
        is_dict = {}
        combine_dict = {}
        singular_dict = {}
        # Calculate the importance score for each sub matrix
        for n, p in model.named_parameters():
            if "lora_A" in n:
                rdim, hdim_a = p.shape
                ipt_score = self.calculate_score(n, metric="ipt")
                comb_ipt = torch.mean(ipt_score, dim=1, keepdim=True)
                name_mat = n.replace("lora_A", "%s")
                if name_mat not in combine_dict:
                    combine_dict[name_mat] = [comb_ipt]
                else:
                    combine_dict[name_mat].append(comb_ipt)
            if "lora_B" in n:
                hdim_b, rdim = p.shape
                ipt_score = self.calculate_score(n, metric="ipt")
                comb_ipt = torch.mean(ipt_score, dim=0, keepdim=False).view(-1, 1)
                name_mat = n.replace("lora_B", "%s")
                if name_mat not in combine_dict:
                    combine_dict[name_mat] = [comb_ipt]
                else:
                    combine_dict[name_mat].append(comb_ipt)
            if "lora_E" in n:
                ipt_score = self.calculate_score(n, p=p, metric="ipt")
                name_mat = n.replace("lora_E", "%s")
                singular_dict[name_mat] = ipt_score

        # Combine the importance scores
        all_is = []
        for name_mat in combine_dict:
            ipt_E = singular_dict[name_mat]
            ipt_AB = torch.cat(combine_dict[name_mat], dim=1)
            sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
            name_E = name_mat % "lora_E"
            is_dict[name_E] = sum_ipt.view(-1, 1)
            all_is.append(sum_ipt.view(-1))

        # Calculate the masking threshold
        mask_threshold = torch.kthvalue(torch.cat(all_is), (self.total_rank - curr_rank))[0].item()

        # Mask out unimportant singular values
        with torch.no_grad():
            curr_sum_rank = 0
            sum_param = 0
            for n, p in model.named_parameters():
                if "lora_E" in n:
                    p.data.masked_fill_(is_dict[n] <= mask_threshold, 0.0)
                    ranknum = (is_dict[n] > mask_threshold).sum().item()

                    if self.tb_writter is not None and self.global_step % self.log_interval == 0:
                        self.tb_writter.add_scalar("Ranknum/%s" % (n,), ranknum, self.global_step)
                    self.rank_pattern[n] = ranknum
                    curr_sum_rank += ranknum
                    sum_param += ranknum * self.shape_dict[n.replace("lora_E", "lora_A")][1]
                    sum_param += ranknum * self.shape_dict[n.replace("lora_E", "lora_B")][0]

            if self.tb_writter is not None and self.global_step % self.log_interval == 0:
                self.tb_writter.add_scalar("Budget/total_rank", curr_sum_rank, self.global_step)
                self.tb_writter.add_scalar("Budget/mask_threshold", mask_threshold, self.global_step)
                self.tb_writter.add_scalar("Budget/sum_param", sum_param, self.global_step)

        return mask_threshold

    def update_and_mask(self, model, global_step):
        if global_step < self.total_step - self.final_warmup:
            # Update the importance scores element-wise;
            # skipped during final fine-tuning, where the rank pattern is fixed
            self.update_ipt(model)
        # Budget schedule
        curr_rank, mask_ind = self.schedule_threshold(global_step)
        if mask_ind:
            # Mask to the target budget
            mask_threshold = self.mask_to_target_rank(model, curr_rank)
        else:
            mask_threshold = None
        return curr_rank, mask_threshold
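
# Minimal training-loop sketch for RankAllocator (`model`, `optimizer`, and
# `loader` are placeholders from the caller's training script, and the
# hyperparameters are illustrative, not prescribed by this module):
#
#     allocator = RankAllocator(
#         model, lora_r=12, target_rank=8,
#         init_warmup=500, final_warmup=1500, mask_interval=10,
#         beta1=0.85, beta2=0.85, total_step=10000,
#     )
#     for step, batch in enumerate(loader):
#         loss = model(**batch).loss
#         loss.backward()
#         optimizer.step()
#         # Gradients must still be populated here: update_ipt() reads p.grad.
#         allocator.update_and_mask(model, step)
#         optimizer.zero_grad()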