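"""Knowledge distillation training script.

Distills a teacher causal language model into a smaller student using TRL's SFTTrainer:
the student is trained on a temperature-scaled forward KL divergence against the teacher's
logits, optionally blended with a standard cross-entropy term (weighted by `alpha` in the
YAML config). Datasets are streamed from the Hugging Face Hub, training runs under
Accelerate, and GPU memory usage is reported periodically.
"""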
import gc
import os
import time
from itertools import islice

import torch
import torch.nn.functional as F
import wandb
import yaml
from accelerate import Accelerator
from datasets import load_dataset, IterableDataset, Dataset, concatenate_datasets
from galore_torch import GaLoreAdamW8bit  # required if the GaLore optimizer option below is enabled
from huggingface_hub import HfFolder, create_repo, upload_folder, login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainerCallback,
    TrainingArguments,
    get_scheduler,
)
from trl import SFTTrainer, SFTConfig

def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)

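# For reference, a minimal sketch of what config_v9.yaml might look like. It is inferred
# purely from the keys this script reads; every value below is a placeholder, not the
# configuration actually used for training:
#
#   models:
#     teacher: teacher-org/teacher-model        # placeholder repo id
#     student: student-org/student-model        # placeholder repo id
#   model_config:
#     use_flash_attention: true
#   tokenizer:
#     max_length: 2048
#   dataset:
#     name: some-org/some-dataset               # placeholder
#     subsets:
#       - {name: subset_a, split: train}
#     eval_samples: 1000
#     total_train_samples: 1000000
#   distillation:
#     alpha: 1.0
#     temperature: 2.0
#   gradient_checkpointing: false
#   training:                                   # splatted directly into TrainingArguments
#     output_dir: ./distilled-model
#     per_device_train_batch_size: 1
#     gradient_accumulation_steps: 8
#     learning_rate: 2.0e-5
#     warmup_ratio: 0.1
#     evaluation_strategy: steps
#     resume_from_checkpoint: false
#   training_aux:
#     num_train_epochs: 1
#     save_steps_fraction: 0.1
#     logging_steps_fraction: 0.01
#     eval_steps_fraction: 0.1
#   wandb:
#     wandb_project: my-project                 # placeholder
#     wandb_entity: my-entity                   # placeholder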
def setup_environment(config):
    # os.environ['WANDB_PROJECT'] = config["wandb"]["wandb_project"]
    # os.environ['WANDB_ENTITY'] = config["wandb"]["wandb_entity"]
    # wandb.init(project=config["wandb"]["wandb_project"], entity=config["wandb"]["wandb_entity"])
    os.environ['WANDB_DISABLED'] = 'true'
    return Accelerator()

def load_and_preprocess_dataset(config, student_tokenizer):
    def tokenize_function(examples):
        return student_tokenizer(
            examples["text"],
            truncation=True,
            max_length=config["tokenizer"]["max_length"],
            padding="max_length",
        )

    datasets = []
    for subset in config["dataset"]["subsets"]:
        # Load each subset as a streaming IterableDataset
        dataset = load_dataset(
            config["dataset"]["name"],
            subset['name'],
            split=subset['split'],
            streaming=True
        )
        # Keep only the 'text' column for all subsets
        if 'text' in dataset.column_names:
            dataset = dataset.remove_columns([col for col in dataset.column_names if col != 'text'])
        else:
            raise ValueError(f"The 'text' column is missing in the {subset['name']} subset.")
        datasets.append(dataset)

    # Concatenate all subsets into a single streaming dataset
    full_dataset = concatenate_datasets(datasets)

    # Evaluation dataset: materialize the first N streamed examples
    eval_dataset = Dataset.from_list(list(islice(full_dataset, config["dataset"]["eval_samples"])))
    eval_dataset = eval_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=eval_dataset.column_names
    )

    # Training dataset: skip the first N examples reserved for evaluation
    def generate_train_examples():
        for i, example in enumerate(full_dataset):
            if i >= config["dataset"]["eval_samples"]:
                yield example

    train_dataset = IterableDataset.from_generator(generate_train_examples)
    train_dataset = train_dataset.map(
        tokenize_function,
        remove_columns=train_dataset.column_names
    )
    return train_dataset, eval_dataset

def load_models_and_tokenizers(config):
    model_kwargs = {"torch_dtype": torch.bfloat16}
    if config["model_config"]["use_flash_attention"]:
        model_kwargs["attn_implementation"] = "flash_attention_2"
    print(f"model_kwargs: {model_kwargs}")

    teacher_tokenizer = AutoTokenizer.from_pretrained(config["models"]["teacher"], add_eos_token=True)
    student_tokenizer = AutoTokenizer.from_pretrained(config["models"]["student"], add_eos_token=True)
    if student_tokenizer.pad_token is None:
        student_tokenizer.pad_token = student_tokenizer.eos_token
        print(f"Set pad_token to eos_token: {student_tokenizer.pad_token}")

    teacher_model = AutoModelForCausalLM.from_pretrained(config["models"]["teacher"], **model_kwargs)
    student_model = AutoModelForCausalLM.from_pretrained(config["models"]["student"], **model_kwargs)
    teacher_model.eval()  # set teacher model to evaluation mode
    return teacher_model, student_model, teacher_tokenizer, student_tokenizer

def pad_logits(student_logits, teacher_logits):
    student_size, teacher_size = student_logits.size(-1), teacher_logits.size(-1)
    if student_size != teacher_size:
        pad_size = abs(student_size - teacher_size)
        pad_tensor = torch.zeros((*teacher_logits.shape[:-1], pad_size), dtype=teacher_logits.dtype, device=teacher_logits.device)
        return (
            (torch.cat([student_logits, pad_tensor], dim=-1), teacher_logits)
            if student_size < teacher_size
            else (student_logits, torch.cat([teacher_logits, pad_tensor], dim=-1))
        )
    return student_logits, teacher_logits

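# Quick illustration with hypothetical shapes (not taken from the actual models): if the
# student has a 32,000-token vocabulary and the teacher 32,768, the student's logits are
# zero-padded along the vocab dimension so both tensors line up for the KL term:
#
#   student = torch.randn(2, 16, 32000)
#   teacher = torch.randn(2, 16, 32768)
#   s, t = pad_logits(student, teacher)
#   assert s.shape == t.shape == (2, 16, 32768)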
class DistillationTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        self.config = kwargs.pop('config', None)
        self.teacher_model = kwargs.pop('teacher_model', None)
        super().__init__(*args, **kwargs)
        # Ensure teacher model is on the same device as the student model
        if self.teacher_model.device != self.model.device:
            self.teacher_model = self.teacher_model.to(self.model.device)
        # Ensure teacher model is in eval mode
        self.teacher_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
        # transformers versions pass to compute_loss
        if hasattr(model, 'module'):
            device = model.module.device
        else:
            device = next(model.parameters()).device
        inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

        student_outputs = model(**inputs)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)

        # Check if 'labels' are in the inputs; if not, use 'input_ids' as labels
        labels = inputs.get('labels', inputs.get('input_ids'))
        if labels is None:
            raise ValueError("Neither 'labels' nor 'input_ids' found in inputs. Cannot compute loss.")

        custom_loss = self.distillation_loss(student_outputs.logits, teacher_outputs.logits, labels)
        return (custom_loss, student_outputs) if return_outputs else custom_loss

    def distillation_loss(self, student_logits, teacher_logits, labels):
        student_logits, teacher_logits = pad_logits(student_logits, teacher_logits)
        kl_loss = self.forward_kl_divergence(student_logits, teacher_logits)
        if self.config["distillation"]["alpha"] != 1:
            # Hard-label cross-entropy term, shifted so that position t predicts token t+1
            # (the standard causal-LM objective)
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            original_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        else:
            original_loss = 0
        combined_loss = self.config["distillation"]["alpha"] * kl_loss + (1 - self.config["distillation"]["alpha"]) * original_loss
        return combined_loss

    def forward_kl_divergence(self, student_logits, teacher_logits):
        temperature = self.config["distillation"]["temperature"]
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
        kl_div = F.kl_div(
            student_log_probs,
            teacher_log_probs.exp(),
            reduction='batchmean',
            log_target=False
        )
        return kl_div * (temperature ** 2) / self.config["tokenizer"]["max_length"]

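    # Note on the term above: with reduction='batchmean', F.kl_div(student_log_probs, p_teacher)
    # returns sum over positions and vocab of p_teacher * (log p_teacher - log p_student),
    # divided by the batch size, i.e. the forward KL divergence KL(teacher || student) summed
    # over sequence positions. Multiplying by temperature**2 keeps gradient magnitudes
    # comparable across temperatures (as in the standard distillation formulation), and
    # dividing by max_length turns the per-sequence sum into an approximate per-token average.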
    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
        output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
        eval_loss = 0.0
        num_examples = 0
        chunk_size = 4  # Adjust this value based on your GPU memory
        for step, inputs in enumerate(dataloader):
            for i in range(0, inputs["input_ids"].size(0), chunk_size):
                chunk_inputs = {k: v[i:i + chunk_size] for k, v in inputs.items() if isinstance(v, torch.Tensor)}
                with torch.no_grad():  # no gradients needed during evaluation
                    loss = self.compute_loss(self.model, chunk_inputs)
                eval_loss += loss.detach().float() * len(chunk_inputs["input_ids"])
                num_examples += len(chunk_inputs["input_ids"])
        eval_loss /= num_examples
        output.metrics[f"{metric_key_prefix}_loss"] = eval_loss.item()
        return output

def print_memory_stats():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f}GB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f}GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f}GB")
    print(f"Max Reserved: {torch.cuda.max_memory_reserved() / 1e9:.2f}GB")

def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

class MemoryTracker(TrainerCallback):
    def __init__(self, print_every=100):
        self.print_every = print_every

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.print_every == 0:
            print(f"Step {state.global_step}:")
            print_memory_stats()
            clear_memory()

def get_custom_scheduler(optimizer, num_warmup_steps, num_training_steps):
    # Note: the "constant" schedule ignores both num_warmup_steps and num_training_steps;
    # switch to "constant_with_warmup" if warmup should actually be applied.
    return get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

def main(config_path):
    config = load_config(config_path)
    accelerator = setup_environment(config)
    teacher_model, student_model, teacher_tokenizer, student_tokenizer = load_models_and_tokenizers(config)
    print(f"Student model: {student_model}")
    print("Memory after loading models:")
    print_memory_stats()
    clear_memory()

    train_dataset, eval_dataset = load_and_preprocess_dataset(config, student_tokenizer)
    # Ensure train_dataset is iterable and eval_dataset is a regular dataset
    # assert isinstance(train_dataset, IterableDataset)
    # assert isinstance(eval_dataset, Dataset)

    # Calculate max_steps
    total_samples = config["dataset"]["total_train_samples"] - config["dataset"]["eval_samples"]
    batch_size = config["training"]["per_device_train_batch_size"]
    grad_accum_steps = config["training"]["gradient_accumulation_steps"]
    num_gpus = torch.cuda.device_count()
    num_epochs = config["training_aux"]["num_train_epochs"]
    max_steps = int((total_samples / (batch_size * grad_accum_steps * num_gpus)) * num_epochs)
    # Ensure max_steps is a positive integer
    max_steps = max(1, max_steps)

    # Calculate save_steps, logging_steps, and eval_steps as fractions of max_steps
    save_steps = max(1, int(max_steps * config["training_aux"]["save_steps_fraction"]))
    logging_steps = max(1, int(max_steps * config["training_aux"]["logging_steps_fraction"]))
    eval_steps = max(1, int(max_steps * config["training_aux"]["eval_steps_fraction"]))

    # Calculate warmup_steps if using warmup
    warmup_steps = int(max_steps * config["training"]["warmup_ratio"]) if config["training"]["warmup_ratio"] > 0 else 0

    run_name = f"distillation_v6_lr_{config['training']['learning_rate']}_rows_{total_samples}"
    training_args = TrainingArguments(
        **config["training"],
        max_steps=max_steps,  # Explicitly set max_steps
        num_train_epochs=config["training_aux"]["num_train_epochs"],  # Ignored by the Trainer when max_steps > 0
        run_name=run_name,
        logging_dir=f"./logs/{run_name}",
        save_steps=save_steps,
        logging_steps=logging_steps,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        # Default optimizer
        optim="adamw_torch",
        # # GaLore optimizer, uses 80%+ less optimizer-state memory than adamw_torch
        # optim="galore_adamw_8bit",
        # optim_target_modules=["mlp.down_proj", "mlp.up_proj", "mlp.gate_proj", "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj"],
        ddp_find_unused_parameters=False,
    )
    # Print out the values to verify
    print(f"max_steps: {max_steps}")
    print(f"num_train_epochs: {training_args.num_train_epochs}")

    trainer = DistillationTrainer(
        model=student_model,
        teacher_model=teacher_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,  # This is now a regular Dataset, not an IterableDataset
        tokenizer=student_tokenizer,
        config=config,  # This is the custom YAML config, not an SFTConfig
        dataset_text_field="text",
        max_seq_length=config["tokenizer"]["max_length"],
        packing=True,
    )
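    # Note: passing dataset_text_field / max_seq_length / packing directly to the trainer
    # matches older TRL releases; in more recent TRL versions these arguments live on
    # SFTConfig, so this call may need adjusting depending on the installed TRL version.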
    if config.get("gradient_checkpointing", False):
        # Disable caching for gradient checkpointing compatibility
        trainer.model.config.use_cache = False

    # Prepare the trainer, models, and datasets
    trainer, teacher_model, train_dataset, eval_dataset = accelerator.prepare(
        trainer, teacher_model, train_dataset, eval_dataset
    )
    # Update the teacher model and datasets in the trainer
    trainer.teacher_model = teacher_model
    trainer.train_dataset = train_dataset
    trainer.eval_dataset = eval_dataset

    # Add custom scheduler
    optimizer = trainer.create_optimizer()
    scheduler = get_custom_scheduler(optimizer, warmup_steps, max_steps)
    trainer.lr_scheduler = scheduler

    trainer.add_callback(MemoryTracker())
print("Starting knowledge distillation with evaluation...")
try:
trainer.train(resume_from_checkpoint=config["training"]["resume_from_checkpoint"])
except RuntimeError as e:
print(f"An error occurred during training: {e}")
print("Please check that your GPU has enough memory and that all tensors are on the same device.")
raise
finally:
print("Final memory stats:")
print_memory_stats()
print(f"Distillation completed. Saving model to {config['training']['output_dir']}")
trainer.save_model(config['training']['output_dir'])
trainer.push_to_hub()
if __name__ == "__main__":
main("config_v9.yaml")