File size: 8,096 Bytes
c13ef0b
d97c361
4e1467d
 
 
 
c13ef0b
 
 
 
 
 
d97c361
1bcfe48
405f5b1
4e1467d
c13ef0b
 
 
1bcfe48
 
 
c13ef0b
 
 
 
 
 
 
 
 
 
 
405f5b1
1bcfe48
2c547b1
405f5b1
 
c13ef0b
405f5b1
d97c361
c13ef0b
 
 
1bcfe48
c13ef0b
 
 
d97c361
c13ef0b
 
 
 
 
 
 
 
 
 
 
 
 
405f5b1
4e1467d
c13ef0b
 
 
 
 
 
 
 
 
 
 
1bcfe48
c13ef0b
 
 
 
 
 
 
d97c361
c13ef0b
d97c361
 
1bcfe48
 
 
c13ef0b
1bcfe48
c13ef0b
1bcfe48
c13ef0b
 
 
1bcfe48
2c547b1
 
 
1bcfe48
 
 
 
 
 
 
 
 
 
 
 
 
 
c13ef0b
405f5b1
4e1467d
c13ef0b
 
 
 
 
 
 
 
 
 
 
 
 
 
d97c361
1bcfe48
 
 
 
 
 
 
 
c13ef0b
 
 
 
 
 
4e1467d
405f5b1
c13ef0b
 
 
 
 
405f5b1
 
d97c361
1bcfe48
 
c13ef0b
d97c361
c13ef0b
 
 
d97c361
 
 
 
 
 
 
 
 
 
 
2c547b1
d97c361
 
 
c13ef0b
d97c361
 
c13ef0b
4e1467d
 
c13ef0b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import argparse
import time
from typing import Optional, Sequence, Tuple

import torch as t
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from model import OsSoluModel
from utils import OsSoluConfig, tokenise, loss_fn, count_parameters

# Name of the Weights & Biases project that every run of this script logs to.
WANDB_PROJECT_NAME = "os_solu"

# Run on the default CUDA device when torch can see one; otherwise fall back to the CPU.
if t.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# TODO: Add support for distributed training.
# TODO: Use only book data from dataset.

def parse_arguments(argv: Optional[Sequence[str]] = None) -> dict:
    """Parses and validates the command-line arguments for this model run.

    Every option has a default value, so no field of the config built from the result is ever None.
    String-valued options are restricted to a fixed set of allowed values, which is enforced after
    parsing.

    Args:
        argv (Optional[Sequence[str]]): The arguments to parse. When None (the default), argparse reads
            sys.argv[1:], exactly as before; passing an explicit list makes the function testable.

    Raises:
        ValueError: optimiser_type must be adam or sgd.
        ValueError: self_attention_type must be unidirectional or rotary.
        ValueError: nonlinearity must be relu or solu.

    Returns:
        dict: a dictionary mapping each option name (e.g. "batch_size") to its parsed value.
    """
    parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
    parser.add_argument("--batch_size", type=int, default=40, help="Batch size used in training.")
    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=500_000_000, help="Save a checkpoint of the model every n tokens processed.")
    parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
    parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings/sequence length.")
    parser.add_argument("--nonlinearity", type=str, default="solu", help="Nonlinearity to use inside MLP block: must be relu or solu.")
    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to run for.")
    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
    parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type: adam or sgd.")
    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
    parser.add_argument("--vocab_size", type=int, default=50_278, help="Vocabulary size of the input sequence.")
    args = vars(parser.parse_args(argv))

    # Validate string arguments by hand rather than with argparse `choices=`, so that callers keep
    # receiving ValueError (argparse would print usage and raise SystemExit instead).
    allowed_values = {
        "optimiser_type": ["adam", "sgd"],
        "self_attention_type": ["unidirectional", "rotary"],
        "nonlinearity": ["relu", "solu"],
    }

    for key, values in allowed_values.items():
        if args[key] not in values:
            raise ValueError(f"{key} should be one of {values}, got {args[key]!r}.")

    return args

def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader) -> OsSoluModel:
    """Trains a model using the config and training dataloader provided.

    Each batch's "text" field is moved to DEVICE and used both as the model input and as the target
    passed to loss_fn. The training loss and per-step wall-clock time are logged to wandb, keyed by the
    number of tokens processed before that step. Each time another config.checkpoint_every_n_tokens
    tokens have been processed, the model and optimiser state are saved and uploaded to wandb.

    Args:
        config (OsSoluConfig): The config object. Reads optimiser_type, learning_rate, num_epochs and
            checkpoint_every_n_tokens.
        model (OsSoluModel): The model to train. Batches are moved to DEVICE, so the model is assumed
            to live there already.
        train_dataloader (DataLoader): The training data, iterated once per epoch.

    Returns:
        OsSoluModel: The trained model (the same object that was passed in).
    """
    wandb.watch(model, criterion=loss_fn, log="all", log_freq=10, log_graph=True)

    # Initialise optimiser.
    opt = optim.Adam if config.optimiser_type.lower() == "adam" else optim.SGD
    optimiser = opt(model.parameters(), lr=config.learning_rate)

    checkpoint_every = config.checkpoint_every_n_tokens
    next_checkpoint = checkpoint_every
    tokens_seen = 0

    model.train()
    for _ in range(config.num_epochs):
        # Iterate the dataloader itself each epoch. A single iterator created up front would be
        # exhausted after the first epoch, and later epochs would silently do nothing.
        for batch in tqdm(train_dataloader):
            start_time = time.time()
            batch = batch["text"].to(DEVICE)

            logits = model(batch)
            optimiser.zero_grad()
            loss = loss_fn(logits, batch)
            loss.backward()
            optimiser.step()

            loss_value = loss.item()
            wandb.log(dict(train_loss=loss_value, elapsed=time.time() - start_time), step=tokens_seen)

            # Number of tokens processed is batch_size * sequence_length.
            tokens_seen += batch.numel()

            # tokens_seen grows by batch.numel() per step, so it rarely lands exactly on a multiple of
            # checkpoint_every. Compare against a running threshold instead of using modulo.
            if checkpoint_every > 0 and tokens_seen >= next_checkpoint:
                _save_checkpoint(model, optimiser, tokens_seen, loss_value)
                while next_checkpoint <= tokens_seen:
                    next_checkpoint += checkpoint_every

    return model


def _save_checkpoint(model: OsSoluModel, optimiser: optim.Optimizer, tokens_seen: int, loss: float) -> None:
    """Saves model/optimiser state to the wandb run directory, then uploads the file to wandb."""
    filename = f"{wandb.run.dir}/os_solu_model_ckpt_step_{tokens_seen}.pt"
    t.save({
        "step": tokens_seen,
        "model_state_dict": model.state_dict(),
        "optimiser_state_dict": optimiser.state_dict(),
        "loss": loss,
    }, filename)
    wandb.save(filename)
    print(f"Checkpointing model at {tokens_seen} tokens seen.")

def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
    """Evaluates a trained model on the test data, logs the result to wandb, then saves the model.

    NOTE: this name shadows the builtin ``eval`` within this module; it is kept for existing callers.

    Each batch's "text" field is moved to DEVICE and scored with loss_fn, using the batch as its own
    target. The mean per-batch loss, the number of test tokens and the elapsed wall-clock time are
    logged to wandb. Afterwards the model's state dict is written to the wandb run directory and
    uploaded.

    Args:
        model (OsSoluModel): The trained model.
        test_dataloader (DataLoader): The data on which to evaluate the model.
    """
    start_time = time.time()
    total_loss, num_batches, tokens_seen = 0.0, 0, 0

    # Eval loop.
    model.eval()
    with t.inference_mode():
        for batch in tqdm(test_dataloader):
            batch = batch["text"].to(DEVICE)

            logits = model(batch)
            total_loss += loss_fn(logits, batch).item()
            num_batches += 1
            tokens_seen += batch.numel()

    # loss_fn yields one scalar per batch (it is read with .item()). Average over batches so the
    # metric does not grow with the size of the test set, and guard against an empty dataloader.
    mean_loss = total_loss / num_batches if num_batches else float("nan")

    # Deliberately no explicit `step`: training has already logged at larger steps, and wandb drops
    # any log whose step is lower than the current one.
    wandb.log(dict(test_loss=mean_loss, test_tokens=tokens_seen, elapsed=time.time() - start_time))

    # Save the model's state on disk, then upload to wandb.
    filename = f"{wandb.run.dir}/model_state_dict.pt"
    t.save(model.state_dict(), filename)
    wandb.save(filename)


def setup() -> Tuple[OsSoluConfig, OsSoluModel, Tuple[DataLoader, DataLoader]]:
    """Parses arguments, builds the model, starts the wandb run and prepares the data.

    Returns:
        Tuple[OsSoluConfig, OsSoluModel, Tuple[DataLoader, DataLoader]]: the config, the model (already
        moved to DEVICE), and a (train_dataloader, test_dataloader) pair.
    """
    args = parse_arguments()
    config = OsSoluConfig(args)
    model = OsSoluModel(config).to(DEVICE)
    # Added after the config is built so the parameter count is recorded in the wandb run config.
    args["num_parameters"] = count_parameters(model)
    wandb.init(project=WANDB_PROJECT_NAME, config=args)

    start_data_time = time.time()
    train_dataloader, test_dataloader = _build_dataloaders(config)
    print(f"Data loaded in {time.time() - start_data_time:.1f}s.")

    return config, model, (train_dataloader, test_dataloader)


def _build_dataloaders(config: OsSoluConfig) -> Tuple[DataLoader, DataLoader]:
    """Streams The Pile, tokenises both splits identically, and wraps each in a DataLoader."""
    # Load and prep data.
    ds = load_dataset("the_pile", streaming=True)

    # Best-effort removal of the 'meta' column. Catch Exception rather than a bare except so that
    # KeyboardInterrupt/SystemExit still propagate.
    try:
        ds = ds.remove_columns("meta")
    except Exception:
        print("Dataset did not contain 'meta' column.")

    tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokeniser.add_special_tokens({"pad_token": "<PAD>"})

    def tokenise_batch(examples):
        # Both splits must receive the same tokeniser arguments. The meaning of the literal 1 is
        # defined by utils.tokenise (not visible here) — NOTE(review): confirm.
        return tokenise(examples, tokeniser, 1, config.max_positional_embeddings)

    train_dataset = ds["train"].map(tokenise_batch, batched=True).with_format("torch")
    test_dataset = ds["test"].map(tokenise_batch, batched=True).with_format("torch")

    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
    return train_dataloader, test_dataloader

if __name__ == "__main__":
    # Build config/model/data, train on the training split, then evaluate on the test split.
    run_config, fresh_model, loaders = setup()
    train_loader, test_loader = loaders
    trained_model = train(run_config, fresh_model, train_loader)
    eval(trained_model, test_loader)