inwaves committed on
Commit d97c361
1 Parent(s): c13ef0b

Added tokenise method for streamed data, fixed issues with einsums

Files changed (4)
  1. main.py +33 -14
  2. model.py +20 -11
  3. requirements.txt +1 -0
  4. utils.py +42 -1
main.py CHANGED
@@ -1,4 +1,5 @@
 import argparse
+import time
 import torch as t
 import torch.nn as nn
 import torch.functional as F
@@ -9,7 +10,8 @@ import wandb
 from typing import Tuple
 from torch.utils.data.dataloader import DataLoader
 from datasets import load_dataset
-from utils import OsSoluConfig
+from transformers import AutoTokenizer
+from utils import OsSoluConfig, tokenise
 from model import OsSoluModel
 
 WANDB_PROJECT_NAME = "os_solu"
@@ -32,7 +34,7 @@ def parse_arguments() -> dict:
     parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
     parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
     parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
-    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
+    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings/sequence length.")
     parser.add_argument("--nonlinearity", type=str, default="solu", help="Nonlinearity to use inside MLP block: must be relu or solu.")
     parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
     parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
@@ -40,7 +42,7 @@ def parse_arguments() -> dict:
     parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
     parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
     parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
-    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
+    parser.add_argument("--vocab_size", type=int, default=50_278, help="Vocabulary size of the input sequence.")
     args = vars(parser.parse_args())
 
     # Parse string arguments.
@@ -67,7 +69,6 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
     Returns:
         OsSoluModel: The trained model.
     """
-    # TODO: training loop
     train_loss_fn = t.nn.CrossEntropyLoss()
    wandb.watch(model, criterion=train_loss_fn, log="all", log_freq=10, log_graph=True)
 
@@ -77,16 +78,17 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
 
     # Train loop.
     examples_seen = 0
+    train_data_iterator = iter(train_dataloader)
     for epoch in range(config.num_epochs):
-        for i, (data, target) in enumerate(tqdm(train_dataloader)):
-            print(data, target)
+        for i, batch in enumerate(tqdm(train_data_iterator
+        )):
+            data = batch["text"]
             data = data.to(DEVICE)
-            target = target.to(DEVICE)
 
             predictions = model(data)
             accuracy = (predictions.argmax(dim=-1) == target).sum() / len(data)
             optimiser.zero_grad()
-            loss = train_loss_fn(target, predictions)
+            # loss = train_loss_fn(data, predictions)
             loss.backward()
             optimiser.step()
 
@@ -109,9 +111,10 @@ def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
     total_loss, num_correct = 0, 0
     model.eval()
     with t.inference_mode():
-        for i, (data, target) in enumerate(tqdm(test_dataloader)):
+        test_data_iterator = iter(test_dataloader)
+        for i, (data, target) in enumerate(tqdm(test_data_iterator)):
+            data = batch["text"]
             data = data.to(DEVICE)
-            target = target.to(DEVICE)
 
             predictions = model(data)
             num_correct += (predictions.argmax(dim=-1) == target).sum().item()
@@ -134,15 +137,31 @@ def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
     args = parse_arguments()
     wandb.init(project=WANDB_PROJECT_NAME, config=args)
     config = OsSoluConfig(args)
-    model = OsSoluModel(config)
+    model = OsSoluModel(config).to(DEVICE)
 
+    start_data_time = time.time()
     # Load and prep data.
     ds = load_dataset("the_pile", streaming=True)
-    train_dataset = ds["train"].with_format("torch")
-    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
 
-    test_dataset = ds["test"].with_format("torch")
+    try:
+        ds = ds.remove_columns("meta")
+    except:
+        print("Dataset did not contain 'meta' column.")
+
+    train_dataset = ds["train"]
+    test_dataset = ds["test"]
+
+    # TODO: tokenise the data before sending it to the model.
+    tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+    tokeniser.add_special_tokens({"pad_token": "<PAD>"})
+
+    train_dataset = train_dataset.map(lambda x: tokenise(x, tokeniser), batched=True).with_format("torch")
+    test_dataset = test_dataset.map(tokenise, batched=True).with_format("torch")
+
+    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
     test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
+    print(f"Data loaded in {time.time() - start_data_time:.1f}s.")
+
     return config, model, (train_dataloader, test_dataloader)
 
 if __name__=="__main__":
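
For orientation, a minimal, untested sketch of the data path this commit wires up in setup(): stream the Pile, pack it with the new utils.tokenise helper, and batch it with a DataLoader. The gpt2 tokeniser, context_length=128, and batch_size=4 below are illustrative stand-ins, not values from this commit.

# Sketch only: exercises the streamed-tokenisation pipeline from setup() end to end.
# Assumptions (not from this commit): gpt2 tokeniser, context_length=128, batch_size=4.
from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader
from transformers import AutoTokenizer

from utils import tokenise

tokeniser = AutoTokenizer.from_pretrained("gpt2")
tokeniser.add_special_tokens({"pad_token": "<PAD>"})

ds = load_dataset("the_pile", streaming=True)
try:
    ds = ds.remove_columns("meta")  # Pile rows carry a 'meta' field alongside 'text'
except Exception:
    print("Dataset did not contain 'meta' column.")

train_dataset = ds["train"].map(lambda x: tokenise(x, tokeniser, context_length=128), batched=True).with_format("torch")
train_dataloader = DataLoader(train_dataset, batch_size=4)

batch = next(iter(train_dataloader))
print(batch["text"].shape)  # (4, 128): packed token ids, ready to pass as model(data)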
model.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.functional as F
 import torch.optim as optim
 import wandb
-import fancy_einsum as einsum
+from fancy_einsum import einsum
 from einops import rearrange, repeat, reduce
 from utils import OsSoluConfig
 
@@ -22,7 +22,7 @@ class OsSoluModel(nn.Module):
         self.final_ln = nn.LayerNorm(config.d_model, config.ln_eps)
 
     def forward(self, x: t.Tensor) -> t.Tensor:
-        positional_embeddings = self.embed_positions(t.arange(x.size(1)))
+        positional_embeddings = self.embed_positions(t.arange(x.size(1), device=x.device))
         token_embeddings = self.embed_tokens(x)
         embeddings = positional_embeddings + token_embeddings
         out = self.dropout(embeddings)
@@ -69,9 +69,9 @@ class UnidirectionalAttention(nn.Module):
         super().__init__()
         self.num_heads = config.num_heads
         self.d_model = config.d_model
-        self.project_q = nn.Linear(config.num_embeddings, config.d_model)
-        self.project_k = nn.Linear(config.num_embeddings, config.d_model)
-        self.project_v = nn.Linear(config.num_embeddings, config.d_model)
+        self.project_q = nn.Linear(config.d_model, config.d_model)
+        self.project_k = nn.Linear(config.d_model, config.d_model)
+        self.project_v = nn.Linear(config.d_model, config.d_model)
         self.project_out = nn.Linear(config.d_model, config.d_model)
         self.LARGE_NEGATIVE_VALUE = -1e5
 
@@ -84,7 +84,11 @@ class UnidirectionalAttention(nn.Module):
 
         Q = self.hidden_to_heads(Q)
         K = self.hidden_to_heads(K)
-        attention_pattern = einsum("batch num_heads seqlen_q head_size, batch num_heads seqlen_k head_size -> batch num_heads seqlen_q seqlen_k")
+        attention_pattern = einsum(
+            "batch num_heads seqlen_q head_size, "
+            "batch num_heads seqlen_k head_size ->"
+            "batch num_heads seqlen_q seqlen_k",
+            Q, K)
 
         return attention_pattern
 
@@ -95,18 +99,23 @@ class UnidirectionalAttention(nn.Module):
 
         # Masking attention. Since GPT is unidirectional, it should only attend to previous tokens.
         if seqlen > 1:
-            fst_range = t.arange(seqlen, device=self.device).unsqueeze(0).T
-            snd_range = t.arange(seqlen, device=self.device).unsqueeze(0)
+            fst_range = t.arange(seqlen, device=x.device).unsqueeze(0).T
+            snd_range = t.arange(seqlen, device=x.device).unsqueeze(0)
             bool_array = fst_range < snd_range
-            attention_score[..., bool_array] = self.LARGE_NEGATIVE_VALUE
+            attention_pattern[..., bool_array] = self.LARGE_NEGATIVE_VALUE
 
 
         attention_pattern = attention_pattern / t.sqrt(t.tensor(self.d_model // self.num_heads))
         attention_score = attention_pattern.softmax(dim=-1)
 
         V = self.hidden_to_heads(V)
-        out = einsum("batch num_heads seqlen_q seqlen_k, batch num_heads seqlen_k head_size -> batch num_heads seqlen_q head_size", attention_score, V)
-        out = rearrange("b nh s hs -> b s (nh hs)")
+        out = einsum(
+            "batch num_heads seqlen_q seqlen_k,"
+            "batch num_heads seqlen_k head_size ->"
+            "batch num_heads seqlen_q head_size",
+            attention_score, V)
+
+        out = rearrange(out, "b nh s hs -> b s (nh hs)")
         out = self.project_out(out)
 
 
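
The einsum fix above is easiest to sanity-check in isolation. Below is a self-contained sketch of the same named-axis contractions and causal mask on random tensors; the shapes are illustrative assumptions, not values from OsSoluConfig.

# Standalone check of the fancy_einsum attention pattern and causal mask used above.
import torch as t
from einops import rearrange
from fancy_einsum import einsum

batch, num_heads, seqlen, head_size = 2, 4, 8, 16
Q = t.randn(batch, num_heads, seqlen, head_size)
K = t.randn(batch, num_heads, seqlen, head_size)
V = t.randn(batch, num_heads, seqlen, head_size)

# Contract over head_size to get (batch, num_heads, seqlen_q, seqlen_k).
attention_pattern = einsum(
    "batch num_heads seqlen_q head_size, "
    "batch num_heads seqlen_k head_size -> "
    "batch num_heads seqlen_q seqlen_k",
    Q, K)

# Causal mask: query position i must not attend to key positions j > i.
fst_range = t.arange(seqlen).unsqueeze(0).T
snd_range = t.arange(seqlen).unsqueeze(0)
attention_pattern[..., fst_range < snd_range] = -1e5

attention_score = (attention_pattern / head_size**0.5).softmax(dim=-1)

# Weight the values and fold heads back into the hidden dimension.
out = einsum(
    "batch num_heads seqlen_q seqlen_k, "
    "batch num_heads seqlen_k head_size -> "
    "batch num_heads seqlen_q head_size",
    attention_score, V)
out = rearrange(out, "b nh s hs -> b s (nh hs)")
print(out.shape)  # torch.Size([2, 8, 64])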
requirements.txt CHANGED
@@ -9,6 +9,7 @@ notebook
 numpy-stl
 plotly
 torch
+transformers
 tqdm
 wandb
 zstandard
utils.py CHANGED
@@ -1,3 +1,6 @@
+import numpy as np
+from einops import rearrange
+
 class OsSoluConfig:
     """A class to hold hyperparameters for the model itself and for the training process."""
 
@@ -32,4 +35,42 @@ class OsSoluConfig:
         self.num_heads = args["num_heads"]
         self.optimiser_type = args["optimiser_type"]
         self.self_attention_type = args["self_attention_type"]
-        self.vocab_size = args["vocab_size"]
+        self.vocab_size = args["vocab_size"]
+
+def tokenise(batch, tokeniser, num_gpus: int = 1, context_length: int = 1024):
+    """Tokenise a batch of text data. This implementation is idiosyncratic to the Pile dataset, but can be easily modified to work with e.g. C4.
+
+    Args:
+        batch (dict): The batch of text, as a dict with a 'text' field.
+        tokeniser (-): A huggingface-API tokeniser, of type returned by AutoTokenizer.from_pretrained (depends on model chosen).
+        num_gpus (int, optional): The number of GPUs available for data parallel training. Defaults to 1.
+        context_length (int, optional): The context length of the model that will be trained on this data. Defaults to 1024.
+
+    Returns:
+        dict: A single field dictionary, 'text', whose value is a tensor of shape (batch_size, sequence_length) containing tokenised sequences.
+    """
+    batch = batch["text"]
+    full_text = tokeniser.eos_token.join(batch)
+
+    # Divide entire batch among all GPUs available.
+    seq_len = len(full_text)//num_gpus
+    sequence_list = [full_text[i*seq_len:(i+1)*seq_len] for i in range(num_gpus)]
+
+    # Tokenise sequences, removing padding tokens.
+    all_tokens = tokeniser(sequence_list, return_tensors="pt", padding=True)["input_ids"].flatten()
+    all_tokens = all_tokens[all_tokens != tokeniser.pad_token_id]
+
+    # Reshape all_tokens to be (batch_size x sequence_length) where each sequence has
+    # a "beginning of sequence" token prepended to it.
+    num_tokens = len(all_tokens)
+    current_batch_size = num_tokens // (context_length-1)
+    all_tokens = all_tokens[:(context_length-1)*current_batch_size]
+    all_tokens = rearrange(all_tokens, "(batch_size seq_len) -> batch_size seq_len", batch_size=current_batch_size, seq_len=context_length-1)
+    prefix = np.full((current_batch_size, 1), tokeniser.bos_token_id, dtype=np.int64)
+
+    tokenised_text = np.concatenate([prefix, all_tokens], axis=1)
+    assert tokenised_text.shape == (current_batch_size, context_length)
+    print(f"{current_batch_size=}, {context_length=}")
+    return {"text": tokenised_text}
+
+
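
The packing behaviour of the new tokenise helper (join documents with the EOS token, strip padding, re-chunk into rows of context_length - 1 tokens, prepend BOS) can be seen on a toy batch. A small usage sketch; the gpt2 tokeniser and context_length=8 are stand-ins for illustration only.

# Usage sketch for utils.tokenise; gpt2 tokeniser and tiny context_length are assumptions.
from transformers import AutoTokenizer
from utils import tokenise

tokeniser = AutoTokenizer.from_pretrained("gpt2")
tokeniser.add_special_tokens({"pad_token": "<PAD>"})  # gpt2 has no pad token by default

batch = {"text": ["First document in the stream.", "A second, slightly longer document follows it."]}
packed = tokenise(batch, tokeniser, num_gpus=1, context_length=8)

# Documents are joined with eos_token, padding is stripped, the token stream is
# chunked into rows of context_length - 1, and bos_token_id is prepended to each row.
print(packed["text"].shape)   # (current_batch_size, 8)
print(packed["text"][:, 0])   # every row starts with tokeniser.bos_token_id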