import json
from dataclasses import asdict, dataclass, field
from typing import Literal, Optional


@dataclass
class FreezeArguments:
    r"""
    Arguments pertaining to freeze (partial-parameter) fine-tuning.
    """
    name_module_trainable: Optional[str] = field(
        default="mlp",
        metadata={
            "help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \
                  Use commas to separate multiple modules. \
                  LLaMA choices: ["mlp", "self_attn"], \
                  BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
                  Qwen choices: ["mlp", "attn"], \
                  Phi choices: ["mlp", "mixer"], \
                  Others choices: the same as LLaMA.'
        },
    )
    num_layer_trainable: Optional[int] = field(
        default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
    )


@dataclass
class LoraArguments:
    r"""
    Arguments pertaining to LoRA fine-tuning.
    """
    additional_target: Optional[str] = field(
        default=None,
        metadata={
            "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
        },
    )
    lora_alpha: Optional[int] = field(
        default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
    )
    lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."})
    lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
    lora_target: Optional[str] = field(
        default=None,
        metadata={
            "help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
                  LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
                  BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
                  Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
                  Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
                  Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \
                  Others choices: the same as LLaMA.'
        },
    )
    lora_bf16_mode: Optional[bool] = field(
        default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
    )
    create_new_adapter: Optional[bool] = field(
        default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
    )


@dataclass
class RLHFArguments:
    r"""
    Arguments pertaining to PPO and DPO training.
    """
    dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."})
    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
        default="sigmoid", metadata={"help": "The type of DPO loss to use."}
    )
    dpo_ftx: Optional[float] = field(
        default=0.0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
    )
    ppo_buffer_size: Optional[int] = field(
        default=1,
        metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
    )
    ppo_epochs: Optional[int] = field(
        default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."}
    )
    ppo_logger: Optional[str] = field(
        default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}
    )
    ppo_score_norm: Optional[bool] = field(
        default=False, metadata={"help": "Use score normalization in PPO training."}
    )
    ppo_target: Optional[float] = field(
        default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."}
    )
    ppo_whiten_rewards: Optional[bool] = field(
        default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
    )
    ref_model: Optional[str] = field(
        default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."}
    )
    ref_model_adapters: Optional[str] = field(
        default=None, metadata={"help": "Path to the adapters of the reference model."}
    )
    ref_model_quantization_bit: Optional[int] = field(
        default=None, metadata={"help": "The number of bits to quantize the reference model."}
    )
    reward_model: Optional[str] = field(
        default=None, metadata={"help": "Path to the reward model used for the PPO training."}
    )
    reward_model_adapters: Optional[str] = field(
        default=None, metadata={"help": "Path to the adapters of the reward model."}
    )
    reward_model_quantization_bit: Optional[int] = field(
        default=None, metadata={"help": "The number of bits to quantize the reward model."}
    )
    reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
        default="lora",
        metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
    )


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
    r"""
    Arguments pertaining to which fine-tuning techniques we are going to use.
    """
    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
        default="sft", metadata={"help": "Which stage will be performed in training."}
    )
    finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
        default="lora", metadata={"help": "Which fine-tuning method to use."}
    )
    plot_loss: Optional[bool] = field(
        default=False, metadata={"help": "Whether or not to save the training loss curves."}
    )

    def __post_init__(self):
        def split_arg(arg):
            if isinstance(arg, str):
                return [item.strip() for item in arg.split(",")]
            return arg

        # Normalize comma-separated strings into lists of module names.
        self.name_module_trainable = split_arg(self.name_module_trainable)
        self.lora_alpha = self.lora_alpha or self.lora_rank * 2  # default: lora_rank * 2
        self.lora_target = split_arg(self.lora_target)
        self.additional_target = split_arg(self.additional_target)

        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

        if self.stage == "ppo" and self.reward_model is None:
            raise ValueError("Reward model is necessary for PPO training.")

        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
            raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")

    def save_to_json(self, json_path: str):
        r"""Saves the content of this instance in JSON format inside `json_path`."""
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

    @classmethod
    def load_from_json(cls, json_path: str):
        r"""Creates an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()

        return cls(**json.loads(text))
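

if __name__ == "__main__":
    # Illustrative round-trip sketch (not part of the original module): the file
    # name and field values below are arbitrary examples, chosen only to show how
    # __post_init__ normalizes inputs and how the JSON helpers behave.
    args = FinetuningArguments(stage="sft", finetuning_type="lora", lora_target="q_proj,v_proj")

    # Comma-separated strings are split into lists, and lora_alpha falls back
    # to lora_rank * 2 when it is not given explicitly.
    assert args.lora_target == ["q_proj", "v_proj"]
    assert args.lora_alpha == args.lora_rank * 2

    # Save the arguments to JSON and restore an equivalent instance.
    args.save_to_json("finetuning_args_example.json")
    restored = FinetuningArguments.load_from_json("finetuning_args_example.json")
    assert restored.lora_target == args.lora_target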