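"""Knowledge-distillation training script.

Trains a student causal language model against a frozen teacher model using a
temperature-scaled forward KL divergence on the logits, optionally blended with a
cross-entropy term on the hard labels. Datasets are streamed from the Hugging Face
Hub, all settings come from a YAML config file, and training runs through a custom
SFTTrainer subclass (DistillationTrainer) with periodic CUDA memory reporting.
"""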
import os
import yaml
import torch
from datasets import load_dataset, IterableDataset, Dataset, concatenate_datasets
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, TrainingArguments, get_scheduler
from accelerate import Accelerator
import wandb
import torch.nn.functional as F
from galore_torch import GaLoreAdamW8bit
import gc
from itertools import islice

def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)
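
# Expected YAML layout, inferred from how the config is read throughout this script
# (key names come from the code below; the values shown are illustrative placeholders,
# not the settings of any particular run):
#
#   models:
#     teacher: org/teacher-model
#     student: org/student-model
#   model_config:
#     use_flash_attention: true
#   tokenizer:
#     max_length: 2048
#   dataset:
#     name: org/dataset-name
#     subsets:
#       - name: subset_a
#         split: train
#     eval_samples: 1000
#     total_train_samples: 100000
#   distillation:
#     alpha: 1.0
#     temperature: 2.0
#   training:                      # passed directly into TrainingArguments(**config["training"])
#     output_dir: ./distilled-model
#     per_device_train_batch_size: 1
#     gradient_accumulation_steps: 8
#     learning_rate: 2.0e-5
#     warmup_ratio: 0.1
#     resume_from_checkpoint: false
#   training_aux:
#     num_train_epochs: 1
#     save_steps_fraction: 0.1
#     logging_steps_fraction: 0.01
#     eval_steps_fraction: 0.1
#   gradient_checkpointing: false
#   wandb:                         # only read by the commented-out W&B setup in setup_environment()
#     wandb_project: my-project
#     wandb_entity: my-entity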

def setup_environment(config):
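    """Disable Weights & Biases logging and return an Accelerator instance."""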
    # os.environ['WANDB_PROJECT'] = config["wandb"]["wandb_project"]
    # os.environ['WANDB_ENTITY'] = config["wandb"]["wandb_entity"]
    # wandb.init(project=config["wandb"]["wandb_project"], entity=config["wandb"]["wandb_entity"])
    os.environ['WANDB_DISABLED'] = 'true'
    return Accelerator()

def load_and_preprocess_dataset(config, student_tokenizer):
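    """Build the training and evaluation datasets.

    Streams every configured subset from the Hub, keeps only the 'text' column,
    and concatenates them. The first `eval_samples` examples become an in-memory
    eval Dataset; the rest is exposed as a streaming IterableDataset. Both are
    tokenized with the student tokenizer (truncated and padded to max_length).
    """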
    def tokenize_function(examples):
        return student_tokenizer(examples["text"], truncation=True, max_length=config["tokenizer"]["max_length"], padding="max_length")

    datasets = []
    for subset in config["dataset"]["subsets"]:
        # Load the dataset as an IterableDataset
        dataset = load_dataset(
            config["dataset"]["name"],
            subset['name'],
            split=subset['split'],
            streaming=True
        )
        
        # Keep only the 'text' column for all subsets
        if 'text' in dataset.column_names:
            dataset = dataset.remove_columns([col for col in dataset.column_names if col != 'text'])
        else:
            raise ValueError(f"The 'text' column is missing in the {subset['name']} subset.")
        
        datasets.append(dataset)

    # Concatenate all datasets
    full_dataset = concatenate_datasets(datasets)

    # Create evaluation dataset (first N examples)
    eval_dataset = Dataset.from_list(list(islice(full_dataset, config["dataset"]["eval_samples"])))
    eval_dataset = eval_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=eval_dataset.column_names
    )

    # Create training dataset (skip first N examples)
    def generate_train_examples():
        for i, example in enumerate(full_dataset):
            if i >= config["dataset"]["eval_samples"]:
                yield example

    train_dataset = IterableDataset.from_generator(generate_train_examples)
    train_dataset = train_dataset.map(
        tokenize_function,
        remove_columns=train_dataset.column_names
    )

    return train_dataset, eval_dataset

def load_models_and_tokenizers(config):
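    """Load the teacher and student models in bfloat16 (optionally with FlashAttention-2)
    together with their tokenizers; the student's pad token falls back to its EOS token
    and the teacher is put in eval mode."""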
    model_kwargs = {"torch_dtype": torch.bfloat16}
    if config["model_config"]["use_flash_attention"]:
        model_kwargs["attn_implementation"] = "flash_attention_2"

    print(f"model_kwargs: {model_kwargs}")

    teacher_tokenizer = AutoTokenizer.from_pretrained(config["models"]["teacher"], add_eos_token=True)
    student_tokenizer = AutoTokenizer.from_pretrained(config["models"]["student"], add_eos_token=True)

    if student_tokenizer.pad_token is None:
        student_tokenizer.pad_token = student_tokenizer.eos_token
        print(f"Set pad_token to eos_token: {student_tokenizer.pad_token}")

    teacher_model = AutoModelForCausalLM.from_pretrained(config["models"]["teacher"], **model_kwargs)
    student_model = AutoModelForCausalLM.from_pretrained(config["models"]["student"], **model_kwargs)

    teacher_model.eval() # set teacher model to evaluation mode

    return teacher_model, student_model, teacher_tokenizer, student_tokenizer

def pad_logits(student_logits, teacher_logits):
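    """Zero-pad the smaller vocabulary's logits along the last dimension so student
    and teacher distributions have matching shapes."""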
    student_size, teacher_size = student_logits.size(-1), teacher_logits.size(-1)
    if student_size == teacher_size:
        return student_logits, teacher_logits
    pad_size = abs(student_size - teacher_size)
    if student_size < teacher_size:
        # Pad the student logits up to the teacher vocabulary size
        pad_tensor = torch.zeros((*student_logits.shape[:-1], pad_size), dtype=student_logits.dtype, device=student_logits.device)
        return torch.cat([student_logits, pad_tensor], dim=-1), teacher_logits
    # Pad the teacher logits up to the student vocabulary size
    pad_tensor = torch.zeros((*teacher_logits.shape[:-1], pad_size), dtype=teacher_logits.dtype, device=teacher_logits.device)
    return student_logits, torch.cat([teacher_logits, pad_tensor], dim=-1)

class DistillationTrainer(SFTTrainer):
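    """SFTTrainer subclass that replaces the standard language-modeling loss with a
    distillation loss: a temperature-scaled forward KL term against the teacher's
    logits, blended with cross-entropy on the hard labels according to `alpha`."""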
    def __init__(self, *args, **kwargs):
        self.config = kwargs.pop('config', None)
        self.teacher_model = kwargs.pop('teacher_model', None)
        super().__init__(*args, **kwargs)
        
        # Ensure teacher model is on the same device as the student model
        if self.teacher_model.device != self.model.device:
            self.teacher_model = self.teacher_model.to(self.model.device)
        
        # Ensure teacher model is in eval mode
        self.teacher_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False):
        if hasattr(model, 'module'):
            device = model.module.device
        else:
            device = next(model.parameters()).device
        inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
        
        student_outputs = model(**inputs)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)

        # Check if 'labels' are in the inputs, if not, use 'input_ids' as labels
        labels = inputs.get('labels', inputs.get('input_ids'))
        
        if labels is None:
            raise ValueError("Neither 'labels' nor 'input_ids' found in inputs. Cannot compute loss.")

        custom_loss = self.distillation_loss(student_outputs.logits, teacher_outputs.logits, labels)
        return (custom_loss, student_outputs) if return_outputs else custom_loss

    def distillation_loss(self, student_logits, teacher_logits, labels):
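        """Return alpha * KL(teacher || student) + (1 - alpha) * cross-entropy; the
        cross-entropy term is skipped entirely when alpha == 1."""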
        student_logits, teacher_logits = pad_logits(student_logits, teacher_logits)
        
        kl_loss = self.forward_kl_divergence(student_logits, teacher_logits)
        
        if self.config["distillation"]["alpha"] != 1:
            # Cross-entropy against the hard labels, shifted so the logits at position t
            # predict token t+1 (the shift is not applied automatically for a manual loss)
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            original_loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)
        else:
            original_loss = 0
        
        combined_loss = self.config["distillation"]["alpha"] * kl_loss + (1 - self.config["distillation"]["alpha"]) * original_loss
        return combined_loss

    def forward_kl_divergence(self, student_logits, teacher_logits):
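        """Temperature-scaled forward KL divergence KL(teacher || student) over the
        softened logits, rescaled by temperature**2 and normalized by the configured
        max sequence length (with 'batchmean', kl_div only divides by the batch size)."""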
        temperature = self.config["distillation"]["temperature"]
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
        
        kl_div = F.kl_div(
            student_log_probs,
            teacher_log_probs.exp(),
            reduction='batchmean',
            log_target=False
        )
        return kl_div * (temperature ** 2) / self.config["tokenizer"]["max_length"]

    
    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
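        """Run the standard evaluation loop, then recompute the distillation loss over
        the eval dataloader in small chunks (to bound peak GPU memory) and overwrite
        the f"{metric_key_prefix}_loss" entry in the returned metrics."""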
        output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
        
        eval_loss = 0.0
        num_examples = 0
        chunk_size = 4  # Adjust this value based on your GPU memory

        for step, inputs in enumerate(dataloader):
            for i in range(0, inputs["input_ids"].size(0), chunk_size):
                chunk_inputs = {k: v[i:i+chunk_size] for k, v in inputs.items() if isinstance(v, torch.Tensor)}
                loss = self.compute_loss(self.model, chunk_inputs)
                eval_loss += loss.detach().float() * len(chunk_inputs["input_ids"])
                num_examples += len(chunk_inputs["input_ids"])
        
        eval_loss /= num_examples
        output.metrics[f"{metric_key_prefix}_loss"] = eval_loss.item()
        return output

def print_memory_stats():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f}GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f}GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f}GB")
    print(f"Max Reserved: {torch.cuda.max_memory_reserved() / 1e9:.2f}GB")

def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

class MemoryTracker(TrainerCallback):
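    """TrainerCallback that prints CUDA memory statistics and clears cached memory
    every `print_every` training steps."""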
    def __init__(self, print_every=100):
        self.print_every = print_every

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.print_every == 0:
            print(f"Step {state.global_step}:")
            print_memory_stats()
            clear_memory()

def get_custom_scheduler(optimizer, num_warmup_steps, num_training_steps):
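    """Return a constant learning-rate schedule. Note that get_scheduler's "constant"
    schedule does not apply warmup; use "constant_with_warmup" if the warmup steps
    computed in main() should actually take effect."""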
    return get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

def main(config_path):
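    """Run the full distillation pipeline: load the config, models, and datasets;
    derive max_steps and the save/log/eval intervals from the configured fractions;
    train the student with DistillationTrainer; then save and push the result."""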
    config = load_config(config_path)
    accelerator = setup_environment(config)
    
    teacher_model, student_model, teacher_tokenizer, student_tokenizer = load_models_and_tokenizers(config)
    
    print(f"Student model: {student_model}")
    
    print("Memory after loading models:")
    print_memory_stats()
    clear_memory()

    train_dataset, eval_dataset = load_and_preprocess_dataset(config, student_tokenizer)
    
    # Ensure train_dataset is iterable and eval_dataset is a regular dataset
    # assert isinstance(train_dataset, IterableDataset)
    # assert isinstance(eval_dataset, Dataset)
    
    # Calculate max_steps
    total_samples = config["dataset"]["total_train_samples"] - config["dataset"]["eval_samples"]
    batch_size = config["training"]["per_device_train_batch_size"]
    grad_accum_steps = config["training"]["gradient_accumulation_steps"]
    num_gpus = torch.cuda.device_count()
    num_epochs = config["training_aux"]["num_train_epochs"]
    
    max_steps = int((total_samples / (batch_size * grad_accum_steps * num_gpus)) * num_epochs)
    
    # Ensure max_steps is a positive integer
    max_steps = max(1, max_steps)

    # Calculate save_steps, logging_steps, and eval_steps
    save_steps = max(1, int(max_steps * config["training_aux"]["save_steps_fraction"]))
    logging_steps = max(1, int(max_steps * config["training_aux"]["logging_steps_fraction"]))
    eval_steps = max(1, int(max_steps * config["training_aux"]["eval_steps_fraction"]))
    
    # Calculate warmup_steps if using warmup
    warmup_steps = int(max_steps * config["training"]["warmup_ratio"]) if config["training"]["warmup_ratio"] > 0 else 0

    run_name = f"distillation_v6_lr_{config['training']['learning_rate']}_rows_{total_samples}"

    training_args = TrainingArguments(
        **config["training"],
        max_steps=max_steps,  # takes precedence over num_train_epochs when both are set
        num_train_epochs=config["training_aux"]["num_train_epochs"],
        run_name=run_name,
        logging_dir=f"./logs/{run_name}",
        save_steps=save_steps,
        logging_steps=logging_steps,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        # Default optimizer
        optim="adamw_torch",
        
        # # Galore optimizer, uses 80%+ less memory than adamw_torch
        # optim="galore_adamw_8bit",
        # optim_target_modules=["mlp.down_proj","mlp.up_proj","mlp.gate_proj","self_attn.q_proj","self_attn.k_proj","self_attn.v_proj","self_attn.o_proj"],
        
        ddp_find_unused_parameters=False,
    )

    # Print out the values to verify
    print(f"max_steps: {max_steps}")
    print(f"num_train_epochs: {training_args.num_train_epochs}")

    trainer = DistillationTrainer(
        model=student_model,
        teacher_model=teacher_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,  # This is now a regular Dataset, not IterableDataset
        tokenizer=student_tokenizer,
        config=config,  # This is your custom config, not SFTConfig
        dataset_text_field="text",
        max_seq_length=config["tokenizer"]["max_length"],
        packing=True,
    )
    
    if config.get("gradient_checkpointing", False):
        # Disable caching for gradient checkpointing compatibility
        trainer.model.config.use_cache = False
    
    # Prepare the trainer, models, and datasets
    trainer, teacher_model, train_dataset, eval_dataset = accelerator.prepare(
        trainer, teacher_model, train_dataset, eval_dataset
    )
    
    # Update the teacher model and datasets in the trainer
    trainer.teacher_model = teacher_model
    trainer.train_dataset = train_dataset
    trainer.eval_dataset = eval_dataset

    # Add custom scheduler
    optimizer = trainer.create_optimizer()
    scheduler = get_custom_scheduler(optimizer, warmup_steps, max_steps)
    trainer.lr_scheduler = scheduler

    trainer.add_callback(MemoryTracker())
    
    print("Starting knowledge distillation with evaluation...")
    try:
        trainer.train(resume_from_checkpoint=config["training"]["resume_from_checkpoint"])
    except RuntimeError as e:
        print(f"An error occurred during training: {e}")
        print("Please check that your GPU has enough memory and that all tensors are on the same device.")
        raise
    finally:
        print("Final memory stats:")
        print_memory_stats()
    
    print(f"Distillation completed. Saving model to {config['training']['output_dir']}")
    trainer.save_model(config['training']['output_dir'])
    
    trainer.push_to_hub()

if __name__ == "__main__":
    main("config_v9.yaml")