Text Generation
English
instruction-following
reasoning

import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from transformers import get_linear_schedule_with_warmup
from utils.data_preprocessing import get_dataloader, load_tokenizer
from models.gem_model import GEM
from configs.config import MODEL_CONFIG, TRAINING_CONFIG
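
# Configuration keys this script expects (inferred from the usages below; the actual values
# live in configs/config.py and are project-specific):
#   MODEL_CONFIG:    D_MODEL, N_HEADS, D_FF, N_LAYERS, DROPOUT, DEVICE, MAX_SEQ_LEN, BATCH_SIZE,
#                    LEARNING_RATE, ADAM_EPSILON, NUM_EPOCHS, GRADIENT_ACCUMULATION_STEPS,
#                    WARMUP_STEPS, MAX_GRAD_NORM
#   TRAINING_CONFIG: LOGGING_STEPS, EVAL_STEPS, CHECKPOINT_SAVE_STEPS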

def train():
    wandb.init(project="GEM_Project", config=MODEL_CONFIG, mode="offline")
    print("WandB initialized in offline mode.")

    tokenizer = load_tokenizer()
    print("Tokenizer loaded.")

    dataloader = get_dataloader('wikitext', 'wikitext-2-raw-v1', tokenizer, MODEL_CONFIG['MAX_SEQ_LEN'], MODEL_CONFIG['BATCH_SIZE'])
    print("Dataloader created.")

    model = GEM(
        vocab_size=len(tokenizer),
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(MODEL_CONFIG['DEVICE'])
    print("Model initialized.")

    optimizer = optim.AdamW(model.parameters(), lr=MODEL_CONFIG['LEARNING_RATE'], eps=MODEL_CONFIG['ADAM_EPSILON'])
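    # total_steps counts optimizer updates (one per GRADIENT_ACCUMULATION_STEPS batches),
    # which is what the linear warmup/decay scheduler below is stepped over.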
    total_steps = len(dataloader) * MODEL_CONFIG['NUM_EPOCHS'] // MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=MODEL_CONFIG['WARMUP_STEPS'],
        num_training_steps=total_steps
    )
    print("Optimizer and scheduler set up.")

    # Mixed precision setup
    scaler = torch.cuda.amp.GradScaler()
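    # Note: torch.cuda.amp.GradScaler/autocast is the older AMP API; recent PyTorch versions
    # expose the same functionality as torch.amp.GradScaler("cuda") and torch.amp.autocast("cuda").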

    model.train()
    print("Starting training loop.")
    for epoch in range(MODEL_CONFIG['NUM_EPOCHS']):
        print(f"Epoch {epoch + 1}/{MODEL_CONFIG['NUM_EPOCHS']} started.")
        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}")):
            batch = batch.to(MODEL_CONFIG['DEVICE'])

            # Mixed precision training
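            # The targets below are the unshifted input ids; the GEM model (models/gem_model.py,
            # not shown here) is assumed to handle any causal target shift internally.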
            with torch.cuda.amp.autocast():
                outputs = model(batch)
                loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), batch.view(-1))

            # Gradient accumulation
            loss = loss / MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']
            scaler.scale(loss).backward()

            if (step + 1) % MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS'] == 0:
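                # Unscale gradients before clipping so the norm threshold applies to the true gradient values.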
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), MODEL_CONFIG['MAX_GRAD_NORM'])
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

                if step % TRAINING_CONFIG['LOGGING_STEPS'] == 0:
                    wandb.log({"loss": loss.item() * MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']})

                if step % TRAINING_CONFIG['EVAL_STEPS'] == 0:
                    # Quick evaluation pass (no separate validation split is wired in, so this reuses the training dataloader).
                    model.eval()
                    val_loss = 0.0
                    with torch.no_grad():
                        for val_batch in dataloader:
                            val_batch = val_batch.to(MODEL_CONFIG['DEVICE'])
                            val_outputs = model(val_batch)
                            val_loss += F.cross_entropy(val_outputs.view(-1, val_outputs.size(-1)), val_batch.view(-1)).item()
                    wandb.log({"val_loss": val_loss / len(dataloader)})
                    model.train()

                if step % TRAINING_CONFIG['CHECKPOINT_SAVE_STEPS'] == 0:
                    torch.save(model.state_dict(), f"checkpoint_{epoch}_{step}.pt")

    torch.save(model.state_dict(), "GEM_1o_Aug_15.pt")
    print("Training complete. Final model saved.")

if __name__ == "__main__":
    train()