inwaves committed on
Commit
c13ef0b
1 Parent(s): 0cab14c

WIP getting the Pile dataset up and running

Files changed (4):
  1. main.py +127 -20
  2. model.py +16 -8
  3. requirements.txt +6 -5
  4. utils.py +27 -19
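
Since the commit message only flags this as work in progress, here is a minimal sketch (not part of the commit) of the streaming access pattern the new setup() in main.py relies on; the exact fields of each streamed example are an assumption here, not something the diff shows:

from datasets import load_dataset

# Stream The Pile instead of downloading it in full: load_dataset returns
# IterableDatasets keyed by split when streaming=True.
ds = load_dataset("the_pile", streaming=True)

# Each streamed example is a plain dict of raw text plus metadata (assumed layout).
sample = next(iter(ds["train"]))
print(sample["text"][:200])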
main.py CHANGED
@@ -1,44 +1,151 @@
+import argparse
 import torch as t
 import torch.nn as nn
 import torch.functional as F
 import torch.optim as optim
-import argparse
+from tqdm import tqdm
+import wandb
+
+from typing import Tuple
+from torch.utils.data.dataloader import DataLoader
+from datasets import load_dataset
 from utils import OsSoluConfig
 from model import OsSoluModel
-from typing import Tuple

-def parse_arguments() -> argparse.Namespace:
-    # TODO: command-line args for hparams
+WANDB_PROJECT_NAME = "os_solu"
+DEVICE = "cuda" if t.cuda.is_available() else "cpu"
+
+def parse_arguments() -> dict:
+    """Parses command-line arguments for this model run. Arguments of type string have allowed values,
+    which are enforced. Default parameter values are provided such that fields in the config are never None.
+
+    Raises:
+        ValueError: optimiser type must be adam or sgd.
+        ValueError: attention type must be rotary or unidirectional.
+
+    Returns:
+        dict: a dictionary containing the command-line arguments parsed by this function.
+    """
     parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
+    parser.add_argument("--batch_size", type=int, default=256, help="Batch size used in training.")
     parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
-    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
-    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
-    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
-    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
     parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
+    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
     parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
-    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
-    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional. ")
     parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
-    args = parser.parse_args()
+    parser.add_argument("--nonlinearity", type=str, default="solu", help=" Nonlinearity to use inside MLP block: must be relu or solu.")
+    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
+    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
+    parser.add_argument("--num_epochs", type=int, default=5, help="Number of epochs to run for.")
+    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
+    parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
+    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
+    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
+    args = vars(parser.parse_args())
+
+    # Parse string arguments.
+    allowed_values = {
+        "optimiser_type": ["adam", "sgd"],
+        "self_attention_type": ["unidirectional", "rotary"],
+        "nonlinearity": ["relu", "solu"],
+    }
+
+    for key, values in allowed_values.items():
+        if args[key] not in values:
+            raise ValueError(f"{key} should be one of {values}.")
+
     return args

-def train(config: OsSoluConfig, model: OsSoluModel) -> OsSoluModel:
+def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader) -> OsSoluModel:
+    """Trains a model using the config and training dataset provided.
+
+    Args:
+        config (OsSoluConfig): The config object.
+        model (OsSoluModel): The model to train.
+        train_dataloader (t.utils.data.DataLoader): The training dataset provided as a torch DataLoader object.
+
+    Returns:
+        OsSoluModel: The trained model.
+    """
     # TODO: training loop
-
+    train_loss_fn = t.nn.CrossEntropyLoss()
+    wandb.watch(model, criterion=train_loss_fn, log="all", log_freq=10, log_graph=True)
+
+    # Initialise optimiser.
+    opt = optim.Adam if config.optimiser_type.lower() == "adam" else optim.SGD
+    optimiser = opt(model.parameters(), lr=config.learning_rate)
+
+    # Train loop.
+    examples_seen = 0
+    for epoch in range(config.num_epochs):
+        for i, (data, target) in enumerate(tqdm(train_dataloader)):
+            print(data, target)
+            data = data.to(DEVICE)
+            target = target.to(DEVICE)
+
+            predictions = model(data)
+            accuracy = (predictions.argmax(dim=-1) == target).sum() / len(data)
+            optimiser.zero_grad()
+            loss = train_loss_fn(target, predictions)
+            loss.backward()
+            optimiser.step()
+
+            wandb.log(dict(train_loss=loss, train_accuracy=accuracy, elapsed=time.time() - start_time), step=examples_seen)
+            examples_seen += len(data)
+
     return model

-def eval():
-    pass
+def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
+    """Evaluates a trained model on the test dataset provided.
+
+    Args:
+        model (OsSoluModel): The trained model.
+        test_dataset (t.utils.data.Dataset): The dataset on which to evaluate the model.
+    """
+    test_loss_fn = t.nn.CrossEntropyLoss()
+
+    # Eval loop.
+    examples_seen = 0
+    total_loss, num_correct = 0, 0
+    model.eval()
+    with t.inference_mode():
+        for i, (data, target) in enumerate(tqdm(test_dataloader)):
+            data = data.to(DEVICE)
+            target = target.to(DEVICE)
+
+            predictions = model(data)
+            num_correct += (predictions.argmax(dim=-1) == target).sum().item()
+            total_loss += test_loss_fn(target, predictions).item()
+            examples_seen += len(data)
+        wandb.log(dict(test_loss=total_loss, test_accuracy=num_correct / examples_seen, elapsed=time.time() - start_time), step=examples_seen)
+
+    # Save the model's state on disk, then upload to wandb.
+    filename = f"{wandb.run.dir}/model_state_dict.pt"
+    t.save(model.state_dict(), filename)
+    wandb.save(filename)
+

 def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
-    # TODO: wandb logging
+    """This function delegates the setup to various helper functions.
+
+    Returns:
+        Tuple[OsSoluConfig, OsSoluModel, datasets.iterable_dataset.IterableDataset, datasets.iterable_dataset.IterableDataset]: A tuple containing a config, a model, a training dataset and a test dataset.
+    """
     args = parse_arguments()
+    wandb.init(project=WANDB_PROJECT_NAME, config=args)
     config = OsSoluConfig(args)
     model = OsSoluModel(config)
-    return config, model
+
+    # Load and prep data.
+    ds = load_dataset("the_pile", streaming=True)
+    train_dataset = ds["train"].with_format("torch")
+    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
+
+    test_dataset = ds["test"].with_format("torch")
+    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
+    return config, model, (train_dataloader, test_dataloader)

 if __name__=="__main__":
-    config, model = setup()
-    trained_model = train(config, model)
-    eval()
+    config, model, (train_dataloader, test_dataloader) = setup()
+    trained_model = train(config, model, train_dataloader)
+    eval(trained_model, test_dataloader)
model.py CHANGED
@@ -8,28 +8,35 @@ from einops import rearrange, repeat, reduce
 from utils import OsSoluConfig


+
 class OsSoluModel(nn.Module):
+    """An open-source implementation of a SoLU-based transformer. This is a GPT-style architecture model
+    where the nonlinearity in the MLP block is replaced with SoLU(x) = x * softmax(x)."""
     def __init__(self, config: OsSoluConfig) -> None:
         super().__init__()
-        normalised_shape = None # TODO: normalised_shape should be defined properly
         self.config = config
         self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
         self.transformer_blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.num_blocks)])
-        self.final_ln = nn.LayerNorm(normalized_shape, config.ln_eps)
-        self.unembed = nn
+        self.final_ln = nn.LayerNorm(config.d_model, config.ln_eps)

     def forward(self, x: t.Tensor) -> t.Tensor:
         positional_embeddings = self.embed_positions(t.arange(x.size(1)))
         token_embeddings = self.embed_tokens(x)
         embeddings = positional_embeddings + token_embeddings
         out = self.dropout(embeddings)
-        out = self.transformer_blocks(out)
+        for block in self.transformer_blocks:
+            out = block(out)
+
+        # Unembedding is not separate, so we just einsum with token embedding weights.
+        out = einsum("vocab hidden, batch seq hidden -> batch seq vocab", self.embed_tokens.weight, out)
+        return out

 class SoLU(nn.Module):
+    """A simple wrapper around the SoLU function such that it can be used as a layer in a model."""
     def __init__(self):
-        pass
+        super().__init__()

     def forward(self, x: t.Tensor) -> t.Tensor:
         return x * x.softmax(dim=-1)
@@ -39,12 +46,13 @@ class GPT2Block(nn.Module):
         super().__init__()
         self.config = config

-        self.layer_norm1 = nn.LayerNorm(normalized_shape, config.ln_eps)
+        self.layer_norm1 = nn.LayerNorm(config.d_model, config.ln_eps)
         self.attention = UnidirectionalAttention(config) if config.self_attention_type == "unidirectional" else RotaryAttention(config)
+        nonlinearity = SoLU() if config.nonlinearity == "solu" else nn.ReLU()
         self.MLP = nn.Sequential(
-            nn.LayerNorm(normalized_shape, config.ln_eps),
+            nn.LayerNorm(config.d_model, config.ln_eps),
             nn.Linear(config.d_model, 4*config.d_model),
-            SoLU(),
+            nonlinearity,
             nn.Linear(4*config.d_model, config.d_model),
             nn.Dropout(config.dropout)
         )
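
As a quick illustration of the SoLU nonlinearity that model.py wires into the MLP block (the numbers below are just an example, not taken from the repo):

import torch as t

x = t.tensor([1.0, 2.0, 3.0])
# softmax(x) ≈ [0.090, 0.245, 0.665], so SoLU(x) = x * softmax(x) ≈ [0.090, 0.489, 1.996]
solu = x * x.softmax(dim=-1)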
requirements.txt CHANGED
@@ -1,13 +1,14 @@
-torch
-wandb
+datasets
 einops
 fancy_einsum
-tqdm
 ipykernel
-notebook
 ipywidgets
 jupyter
 matplotlib
+notebook
 numpy-stl
+plotly
+torch
+tqdm
 wandb
-plotly
+zstandard
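
Of the new entries, datasets and zstandard are the ones this commit actually needs: datasets provides load_dataset with streaming support, and zstandard is presumably required to decompress The Pile's zstd-compressed JSONL shards on the fly. The remaining churn comes from sorting the list alphabetically and dropping a duplicate wandb entry rather than from new dependencies.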
utils.py CHANGED
@@ -1,27 +1,35 @@
-import argparse
-
 class OsSoluConfig:
+    """A class to hold hyperparameters for the model itself and for the training process."""
+
+    batch_size: int # Training data batch size.
     d_model: int # Hidden size of the model.
-    vocab_size: int # Vocabulary size of the input sequence. Unsure about this.
-    learning_rate: float # Learning rate for the optimiser.
-    num_embeddings: int # Number of embeddings. Unsure about this.
-    num_blocks: int # Number of transformer blocks.
     dropout: float # Probability of dropout.
+    learning_rate: float # Learning rate for the optimiser.
     ln_eps: float # Layer norm epsilon.
+    max_positional_embeddings: int # Maximum number of positional embeddings.
+    nonlinearity: str # Nonlinearity to use inside MLP block: must be ReLU or SoLU.
+    num_blocks: int # Number of transformer blocks.
+    num_embeddings: int # Number of embeddings. Unsure about this.
+    num_epochs: int # Number of epochs for this run.
     num_heads: int # Number of attention heads in each attention layer.
     self_attention_type: str # What type of attention to use: rotary or unidirectional.
-    max_positional_embeddings: int # Maximum number of positional embeddings.
-
-    def __init__(self, args: argparse.Namespace) -> None:
+    optimiser_type: str # Optimiser type: SGD, Adam.
+    vocab_size: int # Vocabulary size of the input sequence. Unsure about this.
+
+    def __init__(self, args: dict) -> None:
         """Initialise this config class with values provided by a command-line argument parser.
         Values are never None here, as we provide suitable defaults in the parser call."""
-        self.d_model = args.d_model
-        self.vocab_size = args.vocab_size
-        self.learning_rate = args.learning_rate
-        self.num_embeddings = args.num_embeddings
-        self.num_blocks = args.num_blocks
-        self.dropout = args.dropout
-        self.ln_eps = args.ln_eps
-        self.num_heads = args.num_heads
-        self.self_attention_type = args.self_attention_type
-        self.max_positional_embeddings = args.max_positional_embeddings
+        self.batch_size = args["batch_size"]
+        self.d_model = args["d_model"]
+        self.dropout = args["dropout"]
+        self.learning_rate = args["learning_rate"]
+        self.ln_eps = args["ln_eps"]
+        self.max_positional_embeddings = args["max_positional_embeddings"]
+        self.nonlinearity = args["nonlinearity"]
+        self.num_blocks = args["num_blocks"]
+        self.num_embeddings = args["num_embeddings"]
+        self.num_epochs = args["num_epochs"]
+        self.num_heads = args["num_heads"]
+        self.optimiser_type = args["optimiser_type"]
+        self.self_attention_type = args["self_attention_type"]
+        self.vocab_size = args["vocab_size"]