import argparse
import math
import os
import sys
import json
import jsonlines
import copy
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from tqdm import tqdm
# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ======================================
# Import Custom Components from lightbulb_custom
# ======================================
from lightbulb_custom import (
RotaryPositionalEncoding,
MultiHeadAttention,
MoE,
TransformerBlock,
Transformer,
InfoNCE_Loss,
CovarianceRegularization,
DynamicsPerformanceLoss,
ThoughtConsistencyLoss,
PolicyValueJointLoss,
ActionDiversityReward,
ExpectedThoughtValueLoss,
ExplorationRegularization,
KL_DivergenceLoss,
ActionEncoder,
RepresentationNetwork,
DynamicsNetwork,
PredictionNetwork,
ThoughtNode,
MCTS,
State
)
# ==========================
# Custom Dataset Definition
# ==========================
class CustomDataset(Dataset):
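    """Minimal Dataset pairing pre-tokenized input IDs with label IDs for distillation."""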
def __init__(self, inputs, labels):
self.inputs = inputs
self.labels = labels
def __len__(self):
return len(self.inputs)
def __getitem__(self, idx):
return {'input_ids': self.inputs[idx], 'labels': self.labels[idx]}
# ================================
# Utility Functions for Data Loading
# ================================
def load_filtered_dataset(dataset_name: str, config: str, queries: Optional[List[str]] = None):
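    """Load a Hugging Face dataset and, if query terms are given, keep only examples whose 'text' contains any term (case-insensitive)."""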
dataset = load_dataset(dataset_name, config)
if queries:
def filter_func(examples):
return [any(query.lower() in text.lower() for query in queries) for text in examples["text"]]
dataset = dataset.filter(filter_func, batched=True)
return dataset
def load_custom_data_from_files(file_paths):
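    """Aggregate records from .json files (a single object or a list) and .jsonl files; other extensions are skipped."""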
custom_data = []
for file_path in file_paths:
if file_path.endswith('.json'):
with open(file_path, 'r') as f:
data = json.load(f)
if isinstance(data, list):
custom_data.extend(data)
else:
custom_data.append(data)
elif file_path.endswith('.jsonl'):
with jsonlines.open(file_path) as reader:
custom_data.extend(reader)
return custom_data
def preprocess_custom_data(data_list):
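    """Normalize raw records into flat dicts with a combined 'text' field plus numeric RAG/ranking metrics; JSON strings that fail to parse are skipped."""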
processed_data = []
for item in data_list:
# Check if the item is a string (JSON)
if isinstance(item, str):
try:
item = json.loads(item)
except json.JSONDecodeError:
print(f"Failed to parse JSON: {item[:100]}...") # Print first 100 chars for debugging
continue # Skip this item if it's not valid JSON
# Process query and content
query = item.get('query', '')
content = item.get('content', '')
if content == "RAG response generation failed.":
content = ""
# Combine query and content
combined_text = f"Query: {query} Content: {content}"
# Process numerical data (assuming these are available in the item dict)
episode_reward = item.get('episode_reward', 0)
loss = item.get('loss', 0)
cosine_similarity = item.get('cosine_similarity', 0)
rag_performance = item.get('rag_performance', 0)
ranking_model_performance = item.get('ranking_model_performance', 0)
# Create a dictionary with processed data
processed_item = {
'text': combined_text,
'episode_reward': episode_reward,
'loss': loss,
'cosine_similarity': cosine_similarity,
'rag_performance': rag_performance,
'ranking_model_performance': ranking_model_performance
}
processed_data.append(processed_item)
return processed_data
def load_custom_data(args, tokenizer, custom_data):
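    """Tokenize preprocessed custom records and return 80/20 train/eval DataLoaders; expects args.max_length and args.batch_size."""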
# Preprocess the custom data
processed_data = preprocess_custom_data(custom_data)
# Create a custom dataset
class CustomDatasetProcessed(torch.utils.data.Dataset):
def __init__(self, data, tokenizer, max_length):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
encoded = self.tokenizer.encode_plus(
item['text'],
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': encoded['input_ids'].squeeze(),
'attention_mask': encoded['attention_mask'].squeeze(),
'episode_reward': torch.tensor(item['episode_reward'], dtype=torch.float),
'loss': torch.tensor(item['loss'], dtype=torch.float),
'cosine_similarity': torch.tensor(item['cosine_similarity'], dtype=torch.float),
'rag_performance': torch.tensor(item['rag_performance'], dtype=torch.float),
'ranking_model_performance': torch.tensor(item['ranking_model_performance'], dtype=torch.float)
}
# Create dataset and dataloader
dataset = CustomDatasetProcessed(processed_data, tokenizer, args.max_length)
# Split the dataset into train and eval
train_size = int(0.8 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=4
)
eval_loader = DataLoader(
eval_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=4
)
return train_loader, eval_loader
def prepare_data(tokenizer, dataset, max_length, batch_size):
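    """Tokenize the train split (labels mirror the inputs) and return 90/10 train/validation DataLoaders."""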
    # Inputs and labels are identical here (soft targets come from the teacher at training time), so tokenize once.
    tokenized_inputs = tokenizer(dataset["train"]["text"], return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    tokenized_labels = tokenized_inputs
# Create custom dataset
custom_dataset = CustomDataset(tokenized_inputs["input_ids"], tokenized_labels["input_ids"])
# Split into training and validation sets
train_size = int(0.9 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_dataset, val_dataset = random_split(custom_dataset, [train_size, val_size])
# Create DataLoaders
train_loader = DataLoader(
train_dataset,
shuffle=True,
batch_size=batch_size,
num_workers=4,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
shuffle=False,
batch_size=batch_size,
num_workers=4,
pin_memory=True
)
return train_loader, val_loader
# ==========================
# Training and Validation Functions
# ==========================
def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
"""
Save all models to the specified directory.
Args:
transformer_model (nn.Module): Transformer model.
representation_network (nn.Module): Representation network.
dynamics_network (nn.Module): Dynamics network.
prediction_network (nn.Module): Prediction network.
action_encoder (nn.Module): Action encoder.
save_dir (str): Directory to save the models.
epoch (int): Current epoch number.
"""
os.makedirs(save_dir, exist_ok=True)
torch.save(transformer_model.state_dict(), os.path.join(save_dir, f'transformer_model_epoch_{epoch}.pt'))
torch.save(representation_network.state_dict(), os.path.join(save_dir, f'representation_network_epoch_{epoch}.pt'))
torch.save(dynamics_network.state_dict(), os.path.join(save_dir, f'dynamics_network_epoch_{epoch}.pt'))
torch.save(prediction_network.state_dict(), os.path.join(save_dir, f'prediction_network_epoch_{epoch}.pt'))
torch.save(action_encoder.state_dict(), os.path.join(save_dir, f'action_encoder_epoch_{epoch}.pt'))
print(f"All models saved for epoch {epoch}.")
def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
    """Train the world-model heads (representation, dynamics, prediction, action encoder) for one epoch against a frozen transformer, combining the auxiliary losses from lightbulb_custom."""
    # The components tuple also supplies the transformer, overriding the model_transformer argument (kept for API compatibility).
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
representation_network.train()
dynamics_network.train()
prediction_network.train()
action_encoder.train()
ppo_agent.policy_network.train()
total_loss = 0.0
optimizer.zero_grad()
print(f"Starting World Model training epoch with {len(train_loader)} batches...")
for i, batch in enumerate(train_loader):
print(f"Processing batch {i+1}/{len(train_loader)}...")
# Move batches to the device
src_batch = batch['input_ids'].to(device)
tgt_batch = batch['labels'].to(device)
        with autocast():
print("Forward pass through Transformer (frozen)...")
with torch.no_grad():
transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])
# World Model - Representation
state_representation = representation_network(transformer_output)
# For simplicity, let's assume true actions are provided (e.g., next tokens)
true_actions = tgt_batch[:, :-1]
print(f"True actions shape: {true_actions.shape}")
action_sequences = true_actions
# Get action embeddings
action_embeddings = action_encoder(action_sequences)
print(f"Action embeddings shape: {action_embeddings.shape}")
# Apply dynamics network
predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)
print(f"Predicted next state batch shape: {predicted_next_state_batch.shape}")
# Prediction Network - Policy logits and value
policy_logits, value_estimates = prediction_network(predicted_next_state_batch)
# Define true_policy and true_value as placeholders on the GPU
true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
true_value = torch.zeros_like(value_estimates).to(device)
# Compute individual losses
ppo_loss = ppo_agent.compute_loss(
state_representation,
torch.zeros_like(true_actions, dtype=torch.float32).to(device),
true_actions,
torch.zeros_like(value_estimates, dtype=torch.float32).to(device),
torch.zeros_like(value_estimates, dtype=torch.float32).to(device)
)
info_nce = InfoNCE_Loss()(state_representation.reshape(-1, state_dim),
F.dropout(state_representation.reshape(-1, state_dim), p=0.1, training=True))
covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)
perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)
pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))
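            # Placeholder MCTS statistics; real best values and visit counts would come from running the search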
mcts_best_values = torch.zeros(true_actions.size(0)).to(device)
etv = ExpectedThoughtValueLoss()(mcts_best_values)
visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1)).to(device)
exploration = ExplorationRegularization()(visit_counts)
old_policy = F.softmax(policy_logits.detach(), dim=-1)
new_policy = F.softmax(policy_logits, dim=-1)
kl_loss = KL_DivergenceLoss()(old_policy, new_policy)
# Total Loss
loss = (
ppo_loss +
info_nce +
covariance +
dynamics_loss +
thought_loss +
pv_loss +
action_diversity +
etv +
exploration +
kl_loss
)
loss = loss / args.accumulation_steps
print("Backward pass...")
scaler.scale(loss).backward()
if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
print("Gradient clipping...")
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
[param for group in optimizer.param_groups for param in group['params']],
args.max_grad_norm
)
print("Optimizer step...")
scaler.step(optimizer)
scaler.update()
print("Zeroing gradients...")
optimizer.zero_grad()
print("Updating learning rate...")
scheduler.step()
total_loss += loss.item() * args.accumulation_steps
# Print individual losses and total loss for this batch
print(f"Batch {i+1} completed. Losses:")
print(f" PPO Loss: {ppo_loss.item():.4f}")
print(f" InfoNCE Loss: {info_nce.item():.4f}")
print(f" Covariance Loss: {covariance.item():.4f}")
print(f" Dynamics Loss: {dynamics_loss.item():.4f}")
print(f" Thought Consistency Loss: {thought_loss.item():.4f}")
print(f" Policy-Value Loss: {pv_loss.item():.4f}")
print(f" Action Diversity Loss: {action_diversity.item():.4f}")
print(f" Expected Thought Value Loss: {etv.item():.4f}")
print(f" Exploration Loss: {exploration.item():.4f}")
print(f" KL Divergence Loss: {kl_loss.item():.4f}")
print(f" Total Loss: {loss.item():.4f}")
avg_loss = total_loss / len(train_loader)
print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
return avg_loss
def train_step(teacher, student, data_loader, optimizer, criterion, scaler, temperature=2.0):
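    """Run one distillation epoch over data_loader.

    Minimizes the temperature-scaled soft-target loss
    T^2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))
    (Hinton et al., 2015) under mixed precision. The batch's hard labels are
    loaded but unused; the teacher's distribution is the only training target.
    """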
teacher.eval()
student.train()
total_loss = 0
for batch in tqdm(data_loader, desc="Training"):
inputs = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
with autocast():
with torch.no_grad():
teacher_outputs = teacher(inputs).logits
teacher_logits = teacher_outputs / temperature
student_outputs = student(inputs).logits
student_logits = student_outputs / temperature
# Compute KL Divergence Loss
loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
loss = loss * (temperature ** 2) # Scale loss by temperature squared
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
total_loss += loss.item()
avg_loss = total_loss / len(data_loader)
return avg_loss
def validate(teacher, student, data_loader, criterion, temperature=2.0):
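    """Compute the same temperature-scaled distillation loss as train_step on held-out data, without gradient updates."""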
teacher.eval()
student.eval()
total_loss = 0
with torch.no_grad():
for batch in tqdm(data_loader, desc="Validation"):
inputs = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
teacher_outputs = teacher(inputs).logits
teacher_logits = teacher_outputs / temperature
student_outputs = student(inputs).logits
student_logits = student_outputs / temperature
loss = criterion(nn.functional.log_softmax(student_logits, dim=-1), nn.functional.softmax(teacher_logits, dim=-1))
loss = loss * (temperature ** 2)
total_loss += loss.item()
avg_loss = total_loss / len(data_loader)
return avg_loss
def save_checkpoint(state, save_dir, epoch):
os.makedirs(save_dir, exist_ok=True)
checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
torch.save(state, checkpoint_path)
print(f"Checkpoint saved at {checkpoint_path}")
# ==========================
# Inference Functions
# ==========================
def infer(query, world_model_components, root_thought_node, tokenizer, max_length=2000, inference_mode='world_model', beam_size=5, n_tokens_predict=3, mcts_iterations=10, exploration_constant=1.414):
"""
Perform inference given a query, utilizing the Tree of Thought and MCTS with multi-token beam search.
Args:
query (str): The input query or prompt.
world_model_components (tuple): Tuple containing the model components.
root_thought_node (ThoughtNode): The root node of the Tree of Thought.
tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
max_length (int): Maximum length for the generated sequence.
        inference_mode (str): Inference mode ('world_model', 'without_world_model', 'world_model_tree_of_thought').
        beam_size (int): Beam size for multi-token beam search.
        n_tokens_predict (int): Number of tokens to predict at each step.
        mcts_iterations (int): Number of MCTS iterations.
        exploration_constant (float): Exploration constant for MCTS.
Returns:
List[str] or str: The sequence of actions (thoughts) selected or generated text.
"""
if inference_mode != 'world_model':
print("Inference mode other than 'world_model' not implemented yet.")
return ""
representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
# Tokenize and encode the query
input_ids = tokenizer.encode(query, return_tensors='pt').to(device)
attention_mask = (input_ids != tokenizer.pad_token_id).long()
# Use the world model components
with torch.no_grad():
transformer_output = model_transformer(input_ids, input_ids)
# Get the initial state representation
initial_representation = representation_network(transformer_output) # Shape: (batch_size=1, seq_len, state_dim)
initial_representation = initial_representation[:, -1, :].unsqueeze(1) # Shape: (batch_size=1, 1, state_dim)
initial_state = State(
representation=initial_representation,
dynamics_network=dynamics_network,
action_encoder=action_encoder,
thought_node=root_thought_node
)
# Use MCTS with Tree of Thought and multi-token beam search
mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=mcts_iterations, exploration_constant=exploration_constant)
current_state = initial_state
thought_sequence = []
for _ in range(max_length // n_tokens_predict):
best_actions = mcts.search_with_beam(current_state)
thought_sequence.extend(best_actions)
# Apply the best actions to get the next state
for action in best_actions:
current_state = current_state.apply_action(action)
# Check if we've reached a leaf node (no further actions)
if len(current_state.thought_node.children) == 0:
break
return thought_sequence
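
# Example use of infer() (a hypothetical sketch assuming trained world-model
# components and a Tree of Thought root node are available; names are illustrative):
#
#   components = (representation_net, dynamics_net, prediction_net,
#                 action_enc, ppo_agent, transformer)
#   thoughts = infer("Summarize the document", components, root_node,
#                    tokenizer, inference_mode='world_model')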
# ==========================
# Main Training Function
# ==========================
def distill_model(
teacher_model_name: str,
student_model_name: str,
dataset_name: str,
config: str,
distill_full_model: bool = True,
query_terms: Optional[List[str]] = None,
num_epochs: int = 3,
batch_size: int = 4,
max_length: int = 128,
learning_rate: float = 5e-5,
temperature: float = 2.0,
save_path: str = "./distilled_model",
log_dir: str = "./logs",
checkpoint_dir: str = "./checkpoints",
early_stopping_patience: int = 3,
accumulation_steps: int = 1,
max_grad_norm: float = 1.0,
weight_decay: float = 0.01
):
# Initialize TensorBoard writer
writer = SummaryWriter(log_dir=log_dir)
# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded successfully.")
# Load teacher model
print("Loading teacher model...")
teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
print("Teacher model loaded successfully.")
if distill_full_model:
# Full World Model Distillation
print(f"Starting Full World Model Distillation into '{student_model_name}'.")
# Load or instantiate student model
print(f"Attempting to load student model '{student_model_name}'...")
try:
student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
print(f"Student model '{student_model_name}' loaded successfully.")
        except (OSError, ValueError) as e:
            print(f"Student model '{student_model_name}' not found ({e}). Instantiating a new student model.")
# Instantiate a smaller pre-trained model as the student, e.g., distilgpt2
try:
student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
# Save the instantiated student model with the desired name
student.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"New student model '{student_model_name}' instantiated and saved to '{save_path}'.")
except Exception as inst_e:
print(f"Failed to instantiate and save student model: {inst_e}")
sys.exit(1)
# Optionally freeze teacher model parameters
for param in teacher.parameters():
param.requires_grad = False
# Load and prepare dataset
print(f"Loading full dataset '{dataset_name}' with config '{config}'...")
dataset = load_dataset(dataset_name, config)
train_loader, val_loader = prepare_data(tokenizer, dataset, max_length, batch_size)
print("Data loaded and preprocessed successfully.")
# Define optimizer, scheduler, and scaler for mixed precision
optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
scaler = GradScaler()
# Define loss criterion
criterion = nn.KLDivLoss(reduction="batchmean")
best_val_loss = float('inf')
epochs_no_improve = 0
# Training loop
for epoch in range(1, num_epochs + 1):
print(f"\nEpoch {epoch}/{num_epochs}")
print("-" * 20)
# Training
train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
print(f"Training Loss: {train_loss:.4f}")
writer.add_scalar("Loss/Train", train_loss, epoch)
# Validation
val_loss = validate(teacher, student, val_loader, criterion, temperature)
print(f"Validation Loss: {val_loss:.4f}")
writer.add_scalar("Loss/Validation", val_loss, epoch)
# Check for improvement
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
# Save the best model
save_checkpoint({
'epoch': epoch,
'model_state_dict': student.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict(),
'best_val_loss': best_val_loss
}, checkpoint_dir, epoch)
# Save the model as the best one
student.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"Best model saved at epoch {epoch}")
else:
epochs_no_improve += 1
print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
if epochs_no_improve >= early_stopping_patience:
print("Early stopping triggered")
break
# Step the scheduler
scheduler.step()
writer.close()
print("\nFull World Model Distillation completed.")
else:
# Standard Language Model Distillation
print(f"Starting Standard Language Model Distillation into '{student_model_name}'.")
if not query_terms:
print("Error: --query_terms must be provided for standard language model distillation.")
sys.exit(1)
# Load or instantiate student model
print(f"Attempting to load student model '{student_model_name}'...")
try:
student = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)
print(f"Student model '{student_model_name}' loaded successfully.")
        except (OSError, ValueError) as e:
            print(f"Student model '{student_model_name}' not found ({e}). Instantiating a new student model.")
# Instantiate a smaller pre-trained model as the student, e.g., distilgpt2
try:
student = AutoModelForCausalLM.from_pretrained('distilgpt2').to(device)
# Save the instantiated student model with the desired name
student.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"New student model '{student_model_name}' instantiated and saved to '{save_path}'.")
except Exception as inst_e:
print(f"Failed to instantiate and save student model: {inst_e}")
sys.exit(1)
# Optionally freeze teacher model parameters
for param in teacher.parameters():
param.requires_grad = False
# Load and prepare custom dataset
print(f"Loading custom data files: {query_terms}")
custom_data = load_custom_data_from_files(query_terms)
        train_loader, val_loader = load_custom_data(
            args=argparse.Namespace(max_length=max_length, batch_size=batch_size),
            tokenizer=tokenizer,
            custom_data=custom_data
        )
print("Custom data loaded and preprocessed successfully.")
# Define optimizer, scheduler, and scaler for mixed precision
optimizer = optim.AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
scaler = GradScaler()
# Define loss criterion
criterion = nn.KLDivLoss(reduction="batchmean")
best_val_loss = float('inf')
epochs_no_improve = 0
# Training loop
for epoch in range(1, num_epochs + 1):
print(f"\nEpoch {epoch}/{num_epochs}")
print("-" * 20)
# Training
train_loss = train_step(teacher, student, train_loader, optimizer, criterion, scaler, temperature)
print(f"Training Loss: {train_loss:.4f}")
writer.add_scalar("Loss/Train", train_loss, epoch)
# Validation
val_loss = validate(teacher, student, val_loader, criterion, temperature)
print(f"Validation Loss: {val_loss:.4f}")
writer.add_scalar("Loss/Validation", val_loss, epoch)
# Check for improvement
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
# Save the best model
save_checkpoint({
'epoch': epoch,
'model_state_dict': student.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict(),
'best_val_loss': best_val_loss
}, checkpoint_dir, epoch)
# Save the model as the best one
student.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"Best model saved at epoch {epoch}")
else:
epochs_no_improve += 1
print(f"No improvement in validation loss for {epochs_no_improve} epoch(s)")
if epochs_no_improve >= early_stopping_patience:
print("Early stopping triggered")
break
# Step the scheduler
scheduler.step()
writer.close()
print("\nStandard Language Model Distillation completed.")
# ==========================
# Argument Parsing
# ==========================
def parse_args():
parser = argparse.ArgumentParser(description="Distill a large LLM into a smaller one or a full language world model.")
# Required arguments
parser.add_argument("--teacher_model_name", type=str, required=True, help="Name of the teacher model")
parser.add_argument("--student_model_name", type=str, required=True, help="Name of the student model")
# Dataset arguments
parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
parser.add_argument("--config", type=str, default=None, help="Dataset configuration (e.g., 'wikitext-2-raw-v1')")
# Mode selection
parser.add_argument("--distill_full_model", action="store_true", help="Whether to distill into the full language world model")
# For standard distillation
parser.add_argument("--query_terms", type=str, nargs="+", help="Paths to custom data files for standard language model distillation")
# Training hyperparameters
parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
parser.add_argument("--temperature", type=float, default=2.0, help="Distillation temperature")
# Saving and logging
parser.add_argument("--save_path", type=str, default="./distilled_model", help="Path to save the distilled model")
parser.add_argument("--log_dir", type=str, default="./logs", help="Directory for TensorBoard logs")
parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
# Early stopping
parser.add_argument("--early_stopping_patience", type=int, default=3, help="Early stopping patience")
# Gradient accumulation and optimization
parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for optimizer")
return parser.parse_args()
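
# Example invocations (hypothetical script name, models, and data paths):
#
#   # Full world-model distillation on a Hugging Face dataset:
#   python distill.py --teacher_model_name gpt2-large --student_model_name ./student \
#       --dataset_name wikitext --config wikitext-2-raw-v1 --distill_full_model
#
#   # Standard distillation from custom JSON/JSONL records
#   # (--dataset_name is required by the parser even though this mode ignores it):
#   python distill.py --teacher_model_name gpt2-large --student_model_name ./student \
#       --dataset_name wikitext --query_terms data/run1.jsonl data/run2.json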
# ==========================
# Main Function
# ==========================
def main():
args = parse_args()
print("Arguments parsed successfully.")
# Create save directories
os.makedirs(args.save_path, exist_ok=True)
os.makedirs(args.log_dir, exist_ok=True)
os.makedirs(args.checkpoint_dir, exist_ok=True)
print(f"Save directory created: {args.save_path}")
print(f"Log directory created: {args.log_dir}")
print(f"Checkpoint directory created: {args.checkpoint_dir}")
# Handle dataset loading based on distillation mode
if args.distill_full_model:
# Full World Model Distillation
distill_model(
teacher_model_name=args.teacher_model_name,
student_model_name=args.student_model_name,
dataset_name=args.dataset_name,
config=args.config,
distill_full_model=args.distill_full_model,
query_terms=args.query_terms, # Not used in this mode
num_epochs=args.num_epochs,
batch_size=args.batch_size,
max_length=args.max_length,
learning_rate=args.learning_rate,
temperature=args.temperature,
save_path=args.save_path,
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
early_stopping_patience=args.early_stopping_patience,
accumulation_steps=args.accumulation_steps,
max_grad_norm=args.max_grad_norm,
weight_decay=args.weight_decay
)
else:
# Standard Language Model Distillation
distill_model(
teacher_model_name=args.teacher_model_name,
student_model_name=args.student_model_name,
dataset_name=args.dataset_name,
config=args.config,
distill_full_model=args.distill_full_model,
query_terms=args.query_terms,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
max_length=args.max_length,
learning_rate=args.learning_rate,
temperature=args.temperature,
save_path=args.save_path,
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
early_stopping_patience=args.early_stopping_patience,
accumulation_steps=args.accumulation_steps,
max_grad_norm=args.max_grad_norm,
weight_decay=args.weight_decay
)
if __name__ == "__main__":
main()