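"""Base utilities for LoRA fine-tuning of a foundation model (fm).

Defines the LoRA marker type and the abstract FMLoRA_Util interface for
injecting LoRA A/B weights, training only the LoRA parameters, and
absorbing them back into the original network structure.
"""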
import torch
from torch import nn
from abc import ABC, abstractmethod

from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from utils.common.log import logger


class LoRA(nn.Linear):
    """Marker type for injected LoRA weights.

    Behaves exactly like nn.Linear; subclassing only lets the utilities
    below locate LoRA parameters via isinstance() checks.
    """
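

# Illustrative sketch of how a concrete implementation might use the LoRA
# marker class: a linear layer (frozen during tuning by train_only_lora()
# below) is wrapped with a low-rank residual branch whose down/up
# projections are LoRA instances, so they can be found via isinstance().
# The name LoRALinearSketch and its interface are assumptions for
# demonstration, not this project's actual adapter.
class LoRALinearSketch(nn.Module):
    def __init__(self, base: nn.Linear, r: int):
        super().__init__()
        self.base = base                                  # original projection W
        self.a = LoRA(base.in_features, r, bias=False)    # down-projection A
        self.b = LoRA(r, base.out_features, bias=False)   # up-projection B
        nn.init.zeros_(self.b.weight)                     # start as a no-op: B(Ax) = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = Wx + B(Ax); during tuning only A and B receive gradients
        return self.base(x) + self.b(self.a(x))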


class FMLoRA_Util(ABC):
    @abstractmethod
    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: torch.Tensor):
        """
        only applying LoRA to attention weights.
        """
        pass
    
    def train_only_lora(self, fm: nn.Module):
        """Freeze all of `fm` except its LoRA modules.

        Returns the list of trainable LoRA parameters, ready to be handed
        to an optimizer.
        """
        # Freeze everything first, then re-enable only the LoRA parameters;
        # this avoids relying on the parent-before-child iteration order of
        # named_modules().
        for p in fm.parameters():
            p.requires_grad = False
        res = []
        for _, m in fm.named_modules():
            if isinstance(m, LoRA):
                for p in m.parameters():
                    p.requires_grad = True
                    res.append(p)
        return res
    
    @abstractmethod
    def absorb_lora_and_recover_net_structure(self, fm: nn.Module):
        """Merge the trained LoRA A/B weights into the original weights
        (W' = W + B A) and restore the original network structure, so
        inference incurs no extra cost.
        """
        pass
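

# ---------------------------------------------------------------------------
# Illustrative usage sketch. The toy subclass below shows how the three
# hooks are typically wired together; _SketchFMLoRA_Util, _replace_module,
# and the "qkv" name heuristic are assumptions for demonstration, since
# real implementations are model-specific.
# ---------------------------------------------------------------------------
def _replace_module(root: nn.Module, name: str, new: nn.Module):
    # Helper: replace the submodule at dotted path `name` with `new`.
    parent = root
    *path, last = name.split('.')
    for p in path:
        parent = getattr(parent, p)
    setattr(parent, last, new)


class _SketchFMLoRA_Util(FMLoRA_Util):
    def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: torch.Tensor):
        # Wrap every attention projection (here: any Linear whose name
        # contains "qkv") with a LoRA branch, then sanity-check one forward.
        for name, m in list(fm.named_modules()):
            if isinstance(m, nn.Linear) and not isinstance(m, LoRA) and 'qkv' in name:
                _replace_module(fm, name, LoRALinearSketch(m, ab_r))
        with torch.no_grad():
            fm(samples)
        return fm

    def absorb_lora_and_recover_net_structure(self, fm: nn.Module):
        # Fold B @ A into the frozen weight W and restore the plain Linear.
        for name, m in list(fm.named_modules()):
            if isinstance(m, LoRALinearSketch):
                with torch.no_grad():
                    m.base.weight += m.b.weight @ m.a.weight
                _replace_module(fm, name, m.base)
        return fm


if __name__ == '__main__':
    # Toy demo: one "qkv" projection stands in for a real attention layer.
    fm = nn.Sequential()
    fm.add_module('qkv', nn.Linear(16, 16))

    util = _SketchFMLoRA_Util()
    util.add_lora_ab_to_fm(fm, ab_r=4, samples=torch.randn(2, 16))
    lora_params = util.train_only_lora(fm)              # only A/B are trainable
    optimizer = torch.optim.AdamW(lora_params, lr=1e-3)
    util.absorb_lora_and_recover_net_structure(fm)      # back to a plain Linear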