import os
import torch
from torch import nn
from tqdm import tqdm
import numpy as np

from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
from transformers import get_scheduler
from accelerate import Accelerator
from accelerate.utils import set_seed
from collections import defaultdict
from torch.utils.data import DataLoader
import torch.optim as optim

from ..utils.data_utils import load_lm_dataset, extract_new_words_from_dataset, get_group_texts_func, get_tokenize_func


class EmbeddingCalibrator(nn.Module):
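    """Residual correction for embedding rows: x -> x + x @ W^T, either full-rank (W initialized to zero) or a LoRA-style low-rank variant."""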
    def __init__(self, hidden_size, lora_r=None, lora_alpha=None, dtype=torch.bfloat16):
        super().__init__()
        self.use_lora = lora_r is not None

        if not self.use_lora:
            self.weight = nn.Parameter(torch.zeros(hidden_size, hidden_size, dtype=dtype))
        else:
            self.lora_scaling = lora_alpha / lora_r if lora_alpha is not None else 1.0
            self.lora_A = nn.Parameter(torch.randn(lora_r, hidden_size, dtype=dtype) * (1 / lora_r))
            self.lora_B = nn.Parameter(torch.zeros(hidden_size, lora_r, dtype=dtype))

    def forward(self, x):
        if not self.use_lora:
            return x + torch.matmul(x, self.weight.t())
        else:
            # Low-rank adaptation
            lora_out = torch.matmul(x, self.lora_A.t())
            lora_out = torch.matmul(lora_out, self.lora_B.t())
            return x + self.lora_scaling * lora_out


class CalibrationModel(nn.Module):
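    """Wraps a frozen base model and LM head and learns EmbeddingCalibrator modules for the rows of newly added tokens,
    using a causal-LM loss that separately weights original tokens, new tokens, and the tokens that follow new tokens."""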
    def __init__(
            self,
            base_model, lm_head, original_vocab_size, num_new_tokens,
            calibrate_embedding=True, calibrate_lm_head=True, empty_init=False,
            lora_alpha=None, lora_r=None,
            target_loss_weight=0.15, subsequent_loss_weight=0.15,
    ):
        super().__init__()
        self.base_model = base_model
        self.lm_head = lm_head
        self.new_tokens_start = original_vocab_size
        self.new_tokens_end = original_vocab_size + num_new_tokens
        
        self.calibrate_lm_head = calibrate_lm_head
        self.calibrate_embedding = calibrate_embedding
        if not empty_init:
            self.lm_head_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)
            self.embedding_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)

        self.loss_fct = nn.CrossEntropyLoss(reduction="none")
        self.subsequent_tokens_loss_alpha = subsequent_loss_weight
        self.new_tokens_loss_alpha = target_loss_weight
        self.original_tokens_loss_alpha = 1 - self.new_tokens_loss_alpha - self.subsequent_tokens_loss_alpha

    def forward(self, input_ids, labels, attention_mask=None):
        # shift labels by 1 for CLM
        labels = labels[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()

        if self.calibrate_embedding:
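            # rebuild the embedding matrix with calibrated new-token rows and look up the input embeddings manually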
            E_weights = self.base_model.get_input_embeddings().weight.data
            E_weights = torch.cat((E_weights[:self.new_tokens_start], self.embedding_calibrator(E_weights[self.new_tokens_start:])))
            input_embeddings = E_weights[input_ids]
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.long)
            outputs = self.base_model(inputs_embeds=input_embeddings, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                # Forward pass through the base model
                outputs = self.base_model(input_ids, attention_mask=attention_mask)

        if self.calibrate_lm_head:
            with torch.no_grad():
                lm_head_weights = self.lm_head.weight
                normed_weights = lm_head_weights.clone()
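            # only the new-token rows are recomputed through the calibrator, so gradients reach the calibrator parameters alone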
            normed_weights[self.new_tokens_start:self.new_tokens_end] = self.lm_head_calibrator(lm_head_weights[self.new_tokens_start:self.new_tokens_end])
            logits = torch.matmul(outputs['last_hidden_state'], normed_weights.T)
        else:
            if self.calibrate_embedding:
                logits = self.lm_head(outputs['last_hidden_state'])
            else:
                with torch.no_grad():
                    logits = self.lm_head(outputs['last_hidden_state'])

        per_example_loss = self.loss_fct(logits.transpose(1,2), labels)
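        # weight the loss separately for original-vocab tokens, new tokens, and (below) tokens immediately following a new token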
        original_tokens_mask = labels < self.new_tokens_start
        new_tokens_mask = ~original_tokens_mask
        loss = 0.0
        if self.original_tokens_loss_alpha > 0.0:
            loss += self.original_tokens_loss_alpha * per_example_loss[original_tokens_mask].mean()
        if self.new_tokens_loss_alpha > 0.0:
            loss += self.new_tokens_loss_alpha * per_example_loss[new_tokens_mask].mean()
        if self.subsequent_tokens_loss_alpha > 0.0:
            subsequent_tokens_mask = torch.zeros_like(original_tokens_mask, dtype=torch.bool)
            subsequent_tokens_mask[:, 1:][new_tokens_mask[:, :-1]] = True
            loss += self.subsequent_tokens_loss_alpha * per_example_loss[subsequent_tokens_mask].mean()

        return {'loss': loss, 'logits': logits}

    def get_calibrators(self):
        embedding_calibrator = self.embedding_calibrator if self.calibrate_embedding else None
        lm_head_calibrator = self.lm_head_calibrator if self.calibrate_lm_head else None
        return {
            "embedding_calibrator": embedding_calibrator,
            "lm_head_calibrator": lm_head_calibrator,
            "new_tokens_start": self.new_tokens_start,
            "new_tokens_end": self.new_tokens_end,
        }

    def set_calibrators(self, embedding_calibrator=None, lm_head_calibrator=None):
        self.embedding_calibrator = embedding_calibrator
        self.lm_head_calibrator = lm_head_calibrator
        
    def save_calibrators(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        if self.calibrate_embedding:
            torch.save(self.embedding_calibrator, os.path.join(save_dir, "embedding_calibrator.pt"))
        if self.calibrate_lm_head:
            torch.save(self.lm_head_calibrator, os.path.join(save_dir, "lm_head_calibrator.pt"))

    def load_calibrators(self, load_dir, fail_ok=False):
        """Loads the model's state dictionary from a file."""
        try:
            if self.calibrate_embedding:
                self.embedding_calibrator = torch.load(os.path.join(load_dir, "embedding_calibrator.pt"))
            if self.calibrate_lm_head:
                self.lm_head_calibrator = torch.load(os.path.join(load_dir, "lm_head_calibrator.pt"))
            return True
        except Exception:
            if fail_ok:
                return False
            raise FileNotFoundError(f"Loading calibrators from '{load_dir}' failed")


def get_calibration_model(model, original_vocab_size, num_new_tokens, target_loss_weight=0.15, subsequent_loss_weight=0.15):
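    """Builds a CalibrationModel around a HF causal LM (expects `model.model` and `model.lm_head`); only the calibrators are trainable."""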
    calibrated_model = CalibrationModel(model.model, model.lm_head, original_vocab_size, num_new_tokens, target_loss_weight=target_loss_weight, subsequent_loss_weight=subsequent_loss_weight)
    calibrated_model.base_model.eval()
    calibrated_model.lm_head.eval()

    for param in calibrated_model.base_model.parameters():
        param.requires_grad = False
    for param in calibrated_model.lm_head.parameters():
        param.requires_grad = False
    for param in calibrated_model.lm_head_calibrator.parameters():
        param.requires_grad = True
    for param in calibrated_model.embedding_calibrator.parameters():
        param.requires_grad = True

    return calibrated_model
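
# Example usage (sketch; the model path and new tokens are illustrative, not part of this module):
#   tokenizer = AutoTokenizer.from_pretrained("path/to/model")
#   model = AutoModelForCausalLM.from_pretrained("path/to/model")
#   original_vocab_size = model.config.vocab_size
#   num_added = tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])
#   model.resize_token_embeddings(len(tokenizer))
#   calib_model = get_calibration_model(model, original_vocab_size, num_added)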


def train_calibration_model(calibrated_model: CalibrationModel, tokenizer, dataset, save_dir=None, max_samples=None, filter_examples_without_new_tokens=True, lr=1e-4, lr_schedule="linear", num_epochs=1, batch_size=8, max_length=256, n_warmup_steps=0, text_col_name="text", clip_grad_norm=1.0, mixed_precision=None):
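    """Fits the calibrators on `dataset` with a causal-LM objective; the base model and LM head stay frozen.
    Optionally saves the calibrators to `save_dir` and returns the (possibly accelerator-wrapped) CalibrationModel."""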
    accelerator = Accelerator(mixed_precision=mixed_precision)
    # Optimizer
    optimizer = optim.AdamW(calibrated_model.parameters(), lr=lr)

    # Tokenize data
    if tokenizer.bos_token is not None and max_length:
        add_start_token = True
        # leave room for <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        add_start_token = False
        max_tokenized_len = max_length

    def _add_start_token(batch):
        bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
        batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
        batch["attention_mask"] = torch.cat(
            [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1)
        return batch

    tokenize_function = get_tokenize_func(tokenizer, text_col_name)

    column_names = dataset.column_names

    with accelerator.main_process_first():
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=column_names,
            load_from_cache_file=False,
            desc="Running tokenizer on dataset",
        )
        group_texts = get_group_texts_func(block_size=max_tokenized_len)
        lm_dataset = tokenized_dataset.map(
            group_texts,
            batched=True,
        )

    if filter_examples_without_new_tokens:
        examples_w_new_token = np.arange(len(lm_dataset))[np.any(np.array(lm_dataset['input_ids']) >= calibrated_model.new_tokens_start, axis=1)]
        lm_dataset = lm_dataset.select(examples_w_new_token)

    if max_samples is not None:
        lm_dataset = lm_dataset.select(np.arange(min(max_samples, len(lm_dataset))))

    data_collator = default_data_collator

    # Create data loaders
    dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=True, shuffle=True,
    )

    # Learning rate scheduler
    if isinstance(n_warmup_steps, float):
        n_warmup_steps = n_warmup_steps * len(dataloader)
    scheduler = get_scheduler(lr_schedule, optimizer=optimizer, num_warmup_steps=n_warmup_steps, num_training_steps=len(dataloader) * num_epochs)

    calibrated_model, optimizer, dataloader = accelerator.prepare(calibrated_model, optimizer, dataloader)

    # Freeze the original lm_head weights
    for param in calibrated_model.lm_head.parameters():
        param.requires_grad = False

    calibrated_model.train()
    for epoch in tqdm(range(num_epochs), unit="epochs", desc="Fitting calibration"):
        total_loss = 0.0
        for step, batch in tqdm(enumerate(dataloader), total=len(dataloader), miniters=10, unit="batches"):
            if add_start_token:
                batch = _add_start_token(batch)
            batch["labels"] = batch["input_ids"]
            optimizer.zero_grad()
            outputs = calibrated_model(**batch)
            loss = outputs['loss']
            accelerator.backward(loss)
            accelerator.clip_grad_norm_(calibrated_model.parameters(), clip_grad_norm)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            # # Log loss
            # if step % 10 == 0:
            #     print(f"Epoch {epoch + 1}, Step {step}, Loss: {loss.item()}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss}")

    if save_dir is not None:
        calibrated_model.save_calibrators(save_dir)

    return calibrated_model
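
# Example usage (sketch; `dataset` is assumed to be a datasets.Dataset with a "text" column, and the save path is illustrative):
#   calib_model = train_calibration_model(
#       calib_model, tokenizer, dataset,
#       save_dir="calibrators/", num_epochs=1, batch_size=8, max_length=256,
#   )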


def merge_calibrators_to_hf_model(hf_model, new_tokens_start, new_tokens_end=None, embedding_calibrator=None, lm_head_calibrator=None):
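    """Applies saved calibrator modules to the new-token rows of `hf_model`'s input embeddings and/or LM head, in place."""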
    if embedding_calibrator is not None:
        embedding_calibrator.to(hf_model.device)
        embedding_weights = hf_model.get_input_embeddings().weight
        with torch.no_grad():
            calibrated_weights = embedding_calibrator(embedding_weights[new_tokens_start:new_tokens_end])
            hf_model.model.embed_tokens.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights

    if lm_head_calibrator is not None:
        lm_head_calibrator.to(hf_model.device)
        lm_head_weights = hf_model.get_output_embeddings().weight
        with torch.no_grad():
            calibrated_weights = lm_head_calibrator(lm_head_weights[new_tokens_start:new_tokens_end])
            hf_model.lm_head.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights

    return hf_model


def merge_calibration_model_to_hf_model(hf_model, calibrated_model):
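    """Copies the calibrated new-token embedding and LM-head rows from a CalibrationModel into `hf_model`, in place."""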
    calibrated_model.to(hf_model.device)
    if calibrated_model.calibrate_lm_head:
        with torch.no_grad():
            lm_head_weights = calibrated_model.lm_head.weight
            normed_weights = calibrated_model.lm_head_calibrator(lm_head_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
            hf_model.lm_head.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
    if calibrated_model.calibrate_embedding:
        with torch.no_grad():
            embedding_weights = calibrated_model.base_model.get_input_embeddings().weight
            normed_weights = calibrated_model.embedding_calibrator(embedding_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
            hf_model.model.embed_tokens.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
    return hf_model
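
# Example usage (sketch; merges the trained calibrators back into the same HF model passed to get_calibration_model, then saves it):
#   model = merge_calibration_model_to_hf_model(model, calib_model)
#   model.save_pretrained("path/to/output")
#   tokenizer.save_pretrained("path/to/output")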