LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame
932 Bytes
import torch
from torch import nn
from abc import ABC, abstractmethod
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from utils.common.log import logger
class LoRA(nn.Linear):
    """Marker subclass of ``nn.Linear`` with no added behaviour.

    It exists so that LoRA layers can be told apart from ordinary linear
    layers with ``isinstance`` checks (see ``FMLoRA_Util.train_only_lora``).
    """
class FMLoRA_Util(ABC):
    """Abstract helper for adding LoRA modules to a model and training them.

    Subclasses supply the model-specific steps (inserting the LoRA modules
    and later absorbing them); this base class provides the shared logic that
    restricts training to ``LoRA`` modules.
    """

    @abstractmethod
    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: torch.Tensor):
        """Add LoRA modules to ``fm``.

        Only applying LoRA to attention weights.

        Args:
            fm: The model to modify.
            ab_r: Presumably the rank of the LoRA A/B matrices -- confirm
                against the concrete implementations.
            samples: Presumably example inputs for ``fm`` -- confirm
                against the concrete implementations.
        """

    def train_only_lora(self, fm: nn.Module) -> list:
        """Freeze ``fm`` except for the parameters of its ``LoRA`` modules.

        Sets ``requires_grad = False`` on every parameter of ``fm`` and then
        ``requires_grad = True`` on every parameter that belongs to a
        ``LoRA`` module.

        Args:
            fm: The model whose parameters are (un)frozen in place.

        Returns:
            The trainable LoRA parameters, in module traversal order, each
            listed once.
        """
        # Freeze everything up front. Doing this in a separate pass (instead
        # of interleaving it with the unfreezing below) keeps the result
        # independent of module traversal order: a non-LoRA module that is
        # visited after a LoRA module and shares parameters or submodules
        # with it can no longer re-freeze parameters we return as trainable.
        for param in fm.parameters():
            param.requires_grad = False

        trainable = []
        # Track identity, not equality: ``==`` on tensors is element-wise.
        seen_ids = set()
        for module in fm.modules():
            if not isinstance(module, LoRA):
                continue
            for param in module.parameters():
                param.requires_grad = True
                if id(param) not in seen_ids:
                    seen_ids.add(id(param))
                    trainable.append(param)
        return trainable

    @abstractmethod
    def absorb_lora_and_recover_net_structure(self, fm: nn.Module):
        """Absorb the LoRA modules of ``fm`` and recover its network structure.

        NOTE(review): judging by the name, implementations are expected to
        merge the LoRA weights into the base weights and restore the layout
        the model had before ``add_lora_ab_to_fm`` -- confirm against the
        concrete implementations.

        Args:
            fm: The model previously modified by ``add_lora_ab_to_fm``.
        """