import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from transformers import get_linear_schedule_with_warmup
from utils.data_preprocessing import get_dataloader, load_tokenizer
from models.gem_model import GEM
from configs.config import MODEL_CONFIG, TRAINING_CONFIG
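
# The dictionaries imported above are expected to provide the keys read in train().
# The values below are an illustrative sketch only (hypothetical numbers, not the
# project's actual configs/config.py):
#
# MODEL_CONFIG = {
#     'D_MODEL': 512, 'N_HEADS': 8, 'D_FF': 2048, 'N_LAYERS': 6, 'DROPOUT': 0.1,
#     'MAX_SEQ_LEN': 512, 'BATCH_SIZE': 8, 'DEVICE': 'cuda',
#     'LEARNING_RATE': 3e-4, 'ADAM_EPSILON': 1e-8, 'NUM_EPOCHS': 3,
#     'GRADIENT_ACCUMULATION_STEPS': 4, 'WARMUP_STEPS': 500, 'MAX_GRAD_NORM': 1.0,
# }
# TRAINING_CONFIG = {
#     'LOGGING_STEPS': 50, 'EVAL_STEPS': 500, 'CHECKPOINT_SAVE_STEPS': 1000,
# }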


def train():
    wandb.init(project="GEM_Project", config=MODEL_CONFIG, mode="offline")
    print("WandB initialized in offline mode.")

    tokenizer = load_tokenizer()
    print("Tokenizer loaded.")

    dataloader = get_dataloader(
        'wikitext', 'wikitext-2-raw-v1', tokenizer,
        MODEL_CONFIG['MAX_SEQ_LEN'], MODEL_CONFIG['BATCH_SIZE']
    )
    print("Dataloader created.")

    model = GEM(
        vocab_size=len(tokenizer),
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(MODEL_CONFIG['DEVICE'])
    print("Model initialized.")
    optimizer = optim.AdamW(model.parameters(), lr=MODEL_CONFIG['LEARNING_RATE'], eps=MODEL_CONFIG['ADAM_EPSILON'])

    # One optimizer step per GRADIENT_ACCUMULATION_STEPS batches, so scale the schedule accordingly.
    total_steps = len(dataloader) * MODEL_CONFIG['NUM_EPOCHS'] // MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=MODEL_CONFIG['WARMUP_STEPS'],
        num_training_steps=total_steps
    )
    print("Optimizer and scheduler set up.")

    # Mixed precision setup
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    print("Starting training loop.")
    for epoch in range(MODEL_CONFIG['NUM_EPOCHS']):
        print(f"Epoch {epoch + 1}/{MODEL_CONFIG['NUM_EPOCHS']} started.")
        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}")):
            batch = batch.to(MODEL_CONFIG['DEVICE'])

            # Mixed precision forward pass
            with torch.cuda.amp.autocast():
                outputs = model(batch)
                loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), batch.view(-1))

            # Gradient accumulation: average the loss over the accumulated micro-batches
            loss = loss / MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']
            scaler.scale(loss).backward()

            if (step + 1) % MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS'] == 0:
                # Unscale before clipping so the norm is computed on the true gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), MODEL_CONFIG['MAX_GRAD_NORM'])
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
            if step % TRAINING_CONFIG['LOGGING_STEPS'] == 0:
                # Undo the accumulation scaling so the logged value is the per-batch loss
                wandb.log({"loss": loss.item() * MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']})

            if step % TRAINING_CONFIG['EVAL_STEPS'] == 0:
                # Evaluation pass over the same dataloader (no held-out split is used here)
                model.eval()
                with torch.no_grad():
                    val_loss = 0.0
                    for val_batch in dataloader:
                        val_batch = val_batch.to(MODEL_CONFIG['DEVICE'])
                        val_logits = model(val_batch)
                        val_loss += F.cross_entropy(val_logits.view(-1, val_logits.size(-1)), val_batch.view(-1)).item()
                    wandb.log({"val_loss": val_loss / len(dataloader)})
                model.train()
            if step % TRAINING_CONFIG['CHECKPOINT_SAVE_STEPS'] == 0:
                torch.save(model.state_dict(), f"checkpoint_{epoch}_{step}.pt")

    torch.save(model.state_dict(), "GEM_1o_Aug_15.pt")
    print("Training complete. Final model saved.")


if __name__ == "__main__":
    train()