import torch
from torch import nn
from abc import ABC, abstractmethod

from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from utils.common.log import logger


class LoRA(nn.Linear):
    """Marker type for LoRA layers: a plain linear layer whose class is used
    to identify LoRA parameters via isinstance() checks."""
    pass


class FMLoRA_Util(ABC):
    @abstractmethod
    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: torch.Tensor):
        """
        Add LoRA A/B matrices (rank `ab_r`) to the foundation model `fm`.
        LoRA is only applied to attention weights.
        """
        pass

    def train_only_lora(self, fm: nn.Module):
        """Freeze all parameters except those inside LoRA modules and
        return the list of trainable (LoRA) parameters."""
        res = []
        for n, m in fm.named_modules():
            if isinstance(m, LoRA):
                for p in m.parameters():
                    p.requires_grad = True
                    res += [p]
            else:
                for p in m.parameters():
                    p.requires_grad = False
        return res

    @abstractmethod
    def absorb_lora_and_recover_net_structure(self, fm: nn.Module):
        pass
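
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of this module's API): the toy
# model and the trivial `_DemoLoRAUtil` subclass below are hypothetical. They
# only demonstrate how `train_only_lora` freezes every parameter except those
# belonging to `LoRA` modules.
if __name__ == '__main__':
    class _DemoLoRAUtil(FMLoRA_Util):
        # No-op implementations so the ABC can be instantiated for the demo.
        def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: torch.Tensor):
            return fm

        def absorb_lora_and_recover_net_structure(self, fm: nn.Module):
            return fm

    toy_fm = nn.Sequential(
        nn.Linear(16, 16),  # ordinary layer -> expected to be frozen
        LoRA(16, 16),       # LoRA layer -> expected to stay trainable
    )

    trainable = _DemoLoRAUtil().train_only_lora(toy_fm)
    assert all(p.requires_grad for p in trainable)
    assert not any(p.requires_grad for p in toy_fm[0].parameters())
    print(f'{len(trainable)} trainable LoRA parameter tensors')  # -> 2 (weight + bias)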