File size: 4,266 Bytes
ae81e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Helpers for parameter-efficient finetuning via low-rank adapters (LoRA)
-> Mainly follow PEFT / llama recipes

Right now quantization not super tested
"""
import torch
from torch.nn import Module


# Modified from https://github.com/facebookresearch/llama-recipes/blob/main/examples/quickstart.ipynb
def create_peft_config(model: Module, 
                       peft_config: dict, 
                       target_dtype: str = 'bfloat16',
                       preserve_requires_grad: bool = False,
                       use_gradient_checkpointing: bool = None,
                       add_self_attn_prefix: bool = True):
    """
    Create a parameter-efficient finetuning model (e.g., attaching LoRAs)
    -> Assumes that all non-trainable weights have been frozen already.
       If not, freeze them before calling this function.
    """
    if peft_config['method'] == 'lora':
        from peft import (
            get_peft_model,
            LoraConfig,
            TaskType,
            prepare_model_for_kbit_training,
        )
        try:
            target_modules = []  # hack to only do self_attn terms
            for module_name in peft_config['kwargs']['target_modules']:
                if ('_proj' in module_name and 'self_attn' not in module_name 
                    and add_self_attn_prefix):
                    target_modules.append(f'self_attn.{module_name}')
                elif '_proj' in module_name:
                    target_modules.append(module_name)
            peft_config['kwargs']['target_modules'] = target_modules
        except Exception as e:
            print(e)
            target_modules = []

        if 'layers_to_ignore' in peft_config:
            peft_config['kwargs']['layers_to_transform'] = [
                i for i in range(len(model.model.layers))
                if i not in peft_config['layers_to_ignore']
            ]
            
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            **peft_config['kwargs'],
        )
        # Save parameters that did not have frozen weights before to unfreeze later
        trainable_weights = [
            n for n, p in model.named_parameters() if p.requires_grad 
        ]
        # Prepare int-8 or int-4 model for training
        loaded_in_kbit = (getattr(model, "is_loaded_in_8bit", False) or 
                          getattr(model, "is_loaded_in_4bit", False))
        if loaded_in_kbit:  # From https://huggingface.co/docs/peft/en/package_reference/peft_model:
            # This method wraps the entire protocol for preparing a model before running a training. 
            # 1- Cast the layernorm in fp32 
            # 2- making output embedding layer require grads 
            # 3- Add the upcasting of the lm head to fp32
            model.enable_input_require_grads()
            ugc = (use_gradient_checkpointing 
                   if use_gradient_checkpointing is not None else True)
            print('-> use_gradient_checkpointing:', ugc)
            # model.gradient_checkpointing_enable()
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=ugc,
                gradient_checkpointing_kwargs={'use_reentrant': False},
            )
        
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

        for n, p in model.named_parameters():
            # Unfreeze weights frozen by get_peft_model()
            if preserve_requires_grad:
                if n[len('base_model.model.'):] in trainable_weights:
                    p.requires_grad = True

            # prepare_model_for_kbit_training will cast all non INT8 parameters to fp32
            # -> https://github.com/huggingface/peft/blob/7e84dec20b3106bdd0a90ba8e80187f0aec835b7/src/peft/utils/other.py#L103
            # So we'll cast these back to their prior dtype
            if p.requires_grad and loaded_in_kbit:
                p.data = p.data.to(getattr(torch, target_dtype))

        if not loaded_in_kbit:
            model.to(dtype=getattr(torch, target_dtype))

        return model, peft_config
    else:
        raise NotImplementedError(f"Sorry PEFT method {peft_config['method']} not implemented yet.")