xRunda commited on
Commit
5dcfa45
·
1 Parent(s): 2f2f78b

首次提交

Browse files
assets/gpt2_124M_loss.png ADDED
assets/nanogpt.jpg ADDED
bench.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A much shorter version of train.py for benchmarking
3
+ """
4
+ import os
5
+ from contextlib import nullcontext
6
+ import numpy as np
7
+ import time
8
+ import torch
9
+ from model import GPTConfig, GPT
10
+
11
+ # -----------------------------------------------------------------------------
12
+ batch_size = 12
13
+ block_size = 1024
14
+ bias = False
15
+ real_data = True
16
+ seed = 1337
17
+ device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
18
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
19
+ compile = True # use PyTorch 2.0 to compile the model to be faster
20
+ profile = False # use pytorch profiler, or just simple benchmarking?
21
+ exec(open('configurator.py').read()) # overrides from command line or config file
22
+ # -----------------------------------------------------------------------------
23
+
24
+ torch.manual_seed(seed)
25
+ torch.cuda.manual_seed(seed)
26
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
27
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
28
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
29
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
30
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
31
+
32
+ # data loading init
33
+ if real_data:
34
+ dataset = 'openwebtext'
35
+ data_dir = os.path.join('data', dataset)
36
+ train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
37
+ def get_batch(split):
38
+ data = train_data # note ignore split in benchmarking script
39
+ ix = torch.randint(len(data) - block_size, (batch_size,))
40
+ x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
41
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
42
+ x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
43
+ return x, y
44
+ else:
45
+ # alternatively, if fixed data is desired to not care about data loading
46
+ x = torch.randint(50304, (batch_size, block_size), device=device)
47
+ y = torch.randint(50304, (batch_size, block_size), device=device)
48
+ get_batch = lambda split: (x, y)
49
+
50
+ # model init
51
+ gptconf = GPTConfig(
52
+ block_size = block_size, # how far back does the model look? i.e. context size
53
+ n_layer = 12, n_head = 12, n_embd = 768, # size of the model
54
+ dropout = 0, # for determinism
55
+ bias = bias,
56
+ )
57
+ model = GPT(gptconf)
58
+ model.to(device)
59
+
60
+ optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)
61
+
62
+ if compile:
63
+ print("Compiling model...")
64
+ model = torch.compile(model) # pytorch 2.0
65
+
66
+ if profile:
67
+ # useful docs on pytorch profiler:
68
+ # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
69
+ # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile
70
+ wait, warmup, active = 5, 5, 5
71
+ num_steps = wait + warmup + active
72
+ with torch.profiler.profile(
73
+ activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
74
+ schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
75
+ on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
76
+ record_shapes=False,
77
+ profile_memory=False,
78
+ with_stack=False, # incurs an additional overhead, disable if not needed
79
+ with_flops=True,
80
+ with_modules=False, # only for torchscript models atm
81
+ ) as prof:
82
+
83
+ X, Y = get_batch('train')
84
+ for k in range(num_steps):
85
+ with ctx:
86
+ logits, loss = model(X, Y)
87
+ X, Y = get_batch('train')
88
+ optimizer.zero_grad(set_to_none=True)
89
+ loss.backward()
90
+ optimizer.step()
91
+ lossf = loss.item()
92
+ print(f"{k}/{num_steps} loss: {lossf:.4f}")
93
+
94
+ prof.step() # notify the profiler at end of each step
95
+
96
+ else:
97
+
98
+ # simple benchmarking
99
+ torch.cuda.synchronize()
100
+ for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
101
+ t0 = time.time()
102
+ X, Y = get_batch('train')
103
+ for k in range(num_steps):
104
+ with ctx:
105
+ logits, loss = model(X, Y)
106
+ X, Y = get_batch('train')
107
+ optimizer.zero_grad(set_to_none=True)
108
+ loss.backward()
109
+ optimizer.step()
110
+ lossf = loss.item()
111
+ print(f"{k}/{num_steps} loss: {lossf:.4f}")
112
+ torch.cuda.synchronize()
113
+ t1 = time.time()
114
+ dt = t1-t0
115
+ mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt)
116
+ if stage == 1:
117
+ print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%")
config/eval_gpt2.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # evaluate the base gpt2
2
+ # n_layer=12, n_head=12, n_embd=768
3
+ # 124M parameters
4
+ batch_size = 8
5
+ eval_iters = 500 # use more iterations to get good estimate
6
+ eval_only = True
7
+ wandb_log = False
8
+ init_from = 'gpt2'
config/eval_gpt2_large.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # evaluate the base gpt2
2
+ # n_layer=36, n_head=20, n_embd=1280
3
+ # 774M parameters
4
+ batch_size = 8
5
+ eval_iters = 500 # use more iterations to get good estimate
6
+ eval_only = True
7
+ wandb_log = False
8
+ init_from = 'gpt2-large'
config/eval_gpt2_medium.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # evaluate the base gpt2
2
+ # n_layer=24, n_head=16, n_embd=1024
3
+ # 350M parameters
4
+ batch_size = 8
5
+ eval_iters = 500 # use more iterations to get good estimate
6
+ eval_only = True
7
+ wandb_log = False
8
+ init_from = 'gpt2-medium'
config/eval_gpt2_xl.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # evaluate the base gpt2
2
+ # n_layer=48, n_head=25, n_embd=1600
3
+ # 1558M parameters
4
+ batch_size = 8
5
+ eval_iters = 500 # use more iterations to get good estimate
6
+ eval_only = True
7
+ wandb_log = False
8
+ init_from = 'gpt2-xl'
config/finetune_shakespeare.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ out_dir = 'out-shakespeare'
4
+ eval_interval = 5
5
+ eval_iters = 40
6
+ wandb_log = False # feel free to turn on
7
+ wandb_project = 'shakespeare'
8
+ wandb_run_name = 'ft-' + str(time.time())
9
+
10
+ dataset = 'shakespeare'
11
+ init_from = 'gpt2-xl' # this is the largest GPT-2 model
12
+
13
+ # only save checkpoints if the validation loss improves
14
+ always_save_checkpoint = False
15
+
16
+ # the number of examples per iter:
17
+ # 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter
18
+ # shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters
19
+ batch_size = 1
20
+ gradient_accumulation_steps = 32
21
+ max_iters = 20
22
+
23
+ # finetune at constant LR
24
+ learning_rate = 3e-5
25
+ decay_lr = False
config/train_gpt2.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB
2
+ # launch as the following (e.g. in a screen session) and wait ~5 days:
3
+ # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
4
+
5
+ wandb_log = True
6
+ wandb_project = 'owt'
7
+ wandb_run_name='gpt2-124M'
8
+
9
+ # these make the total batch size be ~0.5M
10
+ # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
11
+ batch_size = 12
12
+ block_size = 1024
13
+ gradient_accumulation_steps = 5 * 8
14
+
15
+ # this makes total number of tokens be 300B
16
+ max_iters = 600000
17
+ lr_decay_iters = 600000
18
+
19
+ # eval stuff
20
+ eval_interval = 1000
21
+ eval_iters = 200
22
+ log_interval = 10
23
+
24
+ # weight decay
25
+ weight_decay = 1e-1
config/train_shakespeare_char.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train a miniature character-level shakespeare model
2
+ # good for debugging and playing on macbooks and such
3
+
4
+ out_dir = 'out-shakespeare-char'
5
+ eval_interval = 250 # keep frequent because we'll overfit
6
+ eval_iters = 200
7
+ log_interval = 10 # don't print too too often
8
+
9
+ # we expect to overfit on this small dataset, so only save when val improves
10
+ always_save_checkpoint = False
11
+
12
+ wandb_log = False # override via command line if you like
13
+ wandb_project = 'shakespeare-char'
14
+ wandb_run_name = 'mini-gpt'
15
+
16
+ dataset = 'shakespeare_char'
17
+ gradient_accumulation_steps = 1
18
+ batch_size = 64
19
+ block_size = 256 # context of up to 256 previous characters
20
+
21
+ # baby GPT model :)
22
+ n_layer = 6
23
+ n_head = 6
24
+ n_embd = 384
25
+ dropout = 0.2
26
+
27
+ learning_rate = 1e-3 # with baby networks can afford to go a bit higher
28
+ max_iters = 5000
29
+ lr_decay_iters = 5000 # make equal to max_iters usually
30
+ min_lr = 1e-4 # learning_rate / 10 usually
31
+ beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
32
+
33
+ warmup_iters = 100 # not super necessary potentially
34
+
35
+ # on macbook also add
36
+ # device = 'cpu' # run on cpu only
37
+ # compile = False # do not torch compile the model
config/train_tangshi.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train a miniature character-level shakespeare model
2
+ # good for debugging and playing on macbooks and such
3
+
4
+ out_dir = 'out-gushi-char'
5
+ eval_interval = 250 # keep frequent because we'll overfit
6
+ eval_iters = 200
7
+ log_interval = 10 # don't print too too often
8
+
9
+ # we expect to overfit on this small dataset, so only save when val improves
10
+ always_save_checkpoint = False
11
+
12
+ wandb_log = False # override via command line if you like
13
+ wandb_project = 'gushi-char'
14
+ wandb_run_name = 'mini-gpt'
15
+
16
+ dataset = 'gushi'
17
+ gradient_accumulation_steps = 1
18
+ batch_size = 64
19
+ block_size = 256 # context of up to 256 previous characters
20
+
21
+ # baby GPT model :)
22
+ n_layer = 6
23
+ n_head = 6
24
+ n_embd = 384
25
+ dropout = 0.2
26
+
27
+ learning_rate = 1e-3 # with baby networks can afford to go a bit higher
28
+ max_iters = 5000
29
+ lr_decay_iters = 5000 # make equal to max_iters usually
30
+ min_lr = 1e-4 # learning_rate / 10 usually
31
+ beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
32
+
33
+ warmup_iters = 100 # not super necessary potentially
34
+
35
+ # on macbook also add
36
+ # device = 'cpu' # run on cpu only
37
+ # compile = False # do not torch compile the model
configurator.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
3
+ $ python train.py config/override_file.py --batch_size=32
4
+ this will first run config/override_file.py, then override batch_size to 32
5
+
6
+ The code in this file will be run as follows from e.g. train.py:
7
+ >>> exec(open('configurator.py').read())
8
+
9
+ So it's not a Python module, it's just shuttling this code away from train.py
10
+ The code in this script then overrides the globals()
11
+
12
+ I know people are not going to love this, I just really dislike configuration
13
+ complexity and having to prepend config. to every single variable. If someone
14
+ comes up with a better simple Python solution I am all ears.
15
+ """
16
+
17
+ import sys
18
+ from ast import literal_eval
19
+
20
+ for arg in sys.argv[1:]:
21
+ if '=' not in arg:
22
+ # assume it's the name of a config file
23
+ assert not arg.startswith('--')
24
+ config_file = arg
25
+ print(f"Overriding config with {config_file}:")
26
+ with open(config_file) as f:
27
+ print(f.read())
28
+ exec(open(config_file).read())
29
+ else:
30
+ # assume it's a --key=value argument
31
+ assert arg.startswith('--')
32
+ key, val = arg.split('=')
33
+ key = key[2:]
34
+ if key in globals():
35
+ try:
36
+ # attempt to eval it it (e.g. if bool, number, or etc)
37
+ attempt = literal_eval(val)
38
+ except (SyntaxError, ValueError):
39
+ # if that goes wrong, just use the string
40
+ attempt = val
41
+ # ensure the types match ok
42
+ assert type(attempt) == type(globals()[key])
43
+ # cross fingers
44
+ print(f"Overriding: {key} = {attempt}")
45
+ globals()[key] = attempt
46
+ else:
47
+ raise ValueError(f"Unknown config key: {key}")
data/gushi/getdata.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # _*_ coding:utf-8 _*_
2
+ import glob
3
+ import json
4
+ datas=glob.glob("chinese-poetry/全唐诗/poet.*json")
5
+
6
+ for data in datas:
7
+ with open(data,"r",encoding='utf-8') as fp:
8
+ tangshi=json.load(fp)
9
+ for each_shi in tangshi:
10
+ if len(each_shi["paragraphs"])==2 and len(each_shi["paragraphs"][0])==16 and len(each_shi["paragraphs"][1])==16:
11
+ with open('data/gushi/gushi.txt',"a",encoding='utf-8') as f:
12
+ f.write("".join(each_shi["paragraphs"]))
13
+ f.write("\n")
data/gushi/gushi.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/gushi/input.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/gushi/meta.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b4def92f4c91321cf56a2c9a107f5f5d9cd975f9057f971f80ee004dc624812
3
+ size 139966
data/gushi/prepare.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prepare the Shakespeare dataset for character-level language modeling.
3
+ So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
4
+ Will save train.bin, val.bin containing the ids, and meta.pkl containing the
5
+ encoder and decoder and some other related info.
6
+ """
7
+ import os
8
+ import pickle
9
+ import requests
10
+ import numpy as np
11
+
12
+ # download the tiny shakespeare dataset
13
+ input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
14
+ if not os.path.exists(input_file_path):
15
+ data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
16
+ with open(input_file_path, 'w') as f:
17
+ f.write(requests.get(data_url).text)
18
+
19
+ with open("data/gushi/gushi.txt", 'r',encoding='utf-8') as f:
20
+ data = f.read()
21
+ print(f"length of dataset in characters: {len(data):,}")
22
+
23
+ # get all the unique characters that occur in this text
24
+ chars = sorted(list(set(data)))
25
+ vocab_size = len(chars)
26
+ print("all the unique characters:", ''.join(chars))
27
+ print(f"vocab size: {vocab_size:,}")
28
+
29
+ # create a mapping from characters to integers
30
+ stoi = { ch:i for i,ch in enumerate(chars) }
31
+ itos = { i:ch for i,ch in enumerate(chars) }
32
+ def encode(s):
33
+ return [stoi[c] for c in s] # encoder: take a string, output a list of integers
34
+ def decode(l):
35
+ return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
36
+
37
+ # create the train and test splits
38
+ n = len(data)
39
+ train_data = data[:int(n*0.9)]
40
+ val_data = data[int(n*0.9):]
41
+
42
+ # encode both to integers
43
+ train_ids = encode(train_data)
44
+ val_ids = encode(val_data)
45
+ print(f"train has {len(train_ids):,} tokens")
46
+ print(f"val has {len(val_ids):,} tokens")
47
+
48
+ # export to bin files
49
+ train_ids = np.array(train_ids, dtype=np.uint16)
50
+ val_ids = np.array(val_ids, dtype=np.uint16)
51
+ train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
52
+ val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
53
+
54
+ # save the meta information as well, to help us encode/decode later
55
+ meta = {
56
+ 'vocab_size': vocab_size,
57
+ 'itos': itos,
58
+ 'stoi': stoi,
59
+ }
60
+ with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f:
61
+ pickle.dump(meta, f)
62
+
63
+ # length of dataset in characters: 1115394
64
+ # all the unique characters:
65
+ # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
66
+ # vocab size: 65
67
+ # train has 1003854 tokens
68
+ # val has 111540 tokens
data/gushi/train.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da11845b8d5ccafe712a278cf019a5e1b8895f28cedaa62b80cf5d56f5c3c3cf
3
+ size 4990846
data/gushi/val.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91f7ee08afb970103d6f1d07d209d903c3b44882ed00c2907d3264289a5ed28b
3
+ size 554540
model.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5
+ https://github.com/openai/gpt-2/blob/master/src/model.py
6
+ 2) huggingface/transformers PyTorch implementation:
7
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ """
9
+
10
+ import math
11
+ import inspect
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
+ class LayerNorm(nn.Module):
19
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
20
+
21
+ def __init__(self, ndim, bias):
22
+ super().__init__()
23
+ self.weight = nn.Parameter(torch.ones(ndim))
24
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25
+
26
+ def forward(self, input):
27
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28
+
29
+ class CausalSelfAttention(nn.Module):
30
+
31
+ def __init__(self, config):
32
+ super().__init__()
33
+ assert config.n_embd % config.n_head == 0
34
+ # key, query, value projections for all heads, but in a batch
35
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36
+ # output projection
37
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38
+ # regularization
39
+ self.attn_dropout = nn.Dropout(config.dropout)
40
+ self.resid_dropout = nn.Dropout(config.dropout)
41
+ self.n_head = config.n_head
42
+ self.n_embd = config.n_embd
43
+ self.dropout = config.dropout
44
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
46
+ if not self.flash:
47
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
48
+ # causal mask to ensure that attention is only applied to the left in the input sequence
49
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
50
+ .view(1, 1, config.block_size, config.block_size))
51
+
52
+ def forward(self, x):
53
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
54
+
55
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
56
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
57
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
58
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
59
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60
+
61
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
62
+ if self.flash:
63
+ # efficient attention using Flash Attention CUDA kernels
64
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
65
+ else:
66
+ # manual implementation of attention
67
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
68
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
69
+ att = F.softmax(att, dim=-1)
70
+ att = self.attn_dropout(att)
71
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
72
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
73
+
74
+ # output projection
75
+ y = self.resid_dropout(self.c_proj(y))
76
+ return y
77
+
78
+ class MLP(nn.Module):
79
+
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
83
+ self.gelu = nn.GELU()
84
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
85
+ self.dropout = nn.Dropout(config.dropout)
86
+
87
+ def forward(self, x):
88
+ x = self.c_fc(x)
89
+ x = self.gelu(x)
90
+ x = self.c_proj(x)
91
+ x = self.dropout(x)
92
+ return x
93
+
94
+ class Block(nn.Module):
95
+
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
99
+ self.attn = CausalSelfAttention(config)
100
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
101
+ self.mlp = MLP(config)
102
+
103
+ def forward(self, x):
104
+ x = x + self.attn(self.ln_1(x))
105
+ x = x + self.mlp(self.ln_2(x))
106
+ return x
107
+
108
+ @dataclass
109
+ class GPTConfig:
110
+ block_size: int = 1024
111
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
112
+ n_layer: int = 12
113
+ n_head: int = 12
114
+ n_embd: int = 768
115
+ dropout: float = 0.0
116
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
117
+
118
+ class GPT(nn.Module):
119
+
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ assert config.vocab_size is not None
123
+ assert config.block_size is not None
124
+ self.config = config
125
+
126
+ self.transformer = nn.ModuleDict(dict(
127
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
128
+ wpe = nn.Embedding(config.block_size, config.n_embd),
129
+ drop = nn.Dropout(config.dropout),
130
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
131
+ ln_f = LayerNorm(config.n_embd, bias=config.bias),
132
+ ))
133
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134
+ # with weight tying when using torch.compile() some warnings get generated:
135
+ # "UserWarning: functional_call was passed multiple values for tied weights.
136
+ # This behavior is deprecated and will be an error in future versions"
137
+ # not 100% sure what this is, so far seems to be harmless. TODO investigate
138
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
139
+
140
+ # init all weights
141
+ self.apply(self._init_weights)
142
+ # apply special scaled init to the residual projections, per GPT-2 paper
143
+ for pn, p in self.named_parameters():
144
+ if pn.endswith('c_proj.weight'):
145
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
146
+
147
+ # report number of parameters
148
+ print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
149
+
150
+ def get_num_params(self, non_embedding=True):
151
+ """
152
+ Return the number of parameters in the model.
153
+ For non-embedding count (default), the position embeddings get subtracted.
154
+ The token embeddings would too, except due to the parameter sharing these
155
+ params are actually used as weights in the final layer, so we include them.
156
+ """
157
+ n_params = sum(p.numel() for p in self.parameters())
158
+ if non_embedding:
159
+ n_params -= self.transformer.wpe.weight.numel()
160
+ return n_params
161
+
162
+ def _init_weights(self, module):
163
+ if isinstance(module, nn.Linear):
164
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
165
+ if module.bias is not None:
166
+ torch.nn.init.zeros_(module.bias)
167
+ elif isinstance(module, nn.Embedding):
168
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
169
+
170
+ def forward(self, idx, targets=None):
171
+ device = idx.device
172
+ b, t = idx.size()
173
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
174
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
175
+
176
+ # forward the GPT model itself
177
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
178
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
179
+ x = self.transformer.drop(tok_emb + pos_emb)
180
+ for block in self.transformer.h:
181
+ x = block(x)
182
+ x = self.transformer.ln_f(x)
183
+
184
+ if targets is not None:
185
+ # if we are given some desired targets also calculate the loss
186
+ logits = self.lm_head(x)
187
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
188
+ else:
189
+ # inference-time mini-optimization: only forward the lm_head on the very last position
190
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
191
+ loss = None
192
+
193
+ return logits, loss
194
+
195
+ def crop_block_size(self, block_size):
196
+ # model surgery to decrease the block size if necessary
197
+ # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
198
+ # but want to use a smaller block size for some smaller, simpler model
199
+ assert block_size <= self.config.block_size
200
+ self.config.block_size = block_size
201
+ self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
202
+ for block in self.transformer.h:
203
+ if hasattr(block.attn, 'bias'):
204
+ block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
205
+
206
+ @classmethod
207
+ def from_pretrained(cls, model_type, override_args=None):
208
+ assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
209
+ override_args = override_args or {} # default to empty dict
210
+ # only dropout can be overridden see more notes below
211
+ assert all(k == 'dropout' for k in override_args)
212
+ from transformers import GPT2LMHeadModel
213
+ print("loading weights from pretrained gpt: %s" % model_type)
214
+
215
+ # n_layer, n_head and n_embd are determined from model_type
216
+ config_args = {
217
+ 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
218
+ 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
219
+ 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
220
+ 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
221
+ }[model_type]
222
+ print("forcing vocab_size=50257, block_size=1024, bias=True")
223
+ config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
224
+ config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
225
+ config_args['bias'] = True # always True for GPT model checkpoints
226
+ # we can override the dropout rate, if desired
227
+ if 'dropout' in override_args:
228
+ print(f"overriding dropout rate to {override_args['dropout']}")
229
+ config_args['dropout'] = override_args['dropout']
230
+ # create a from-scratch initialized minGPT model
231
+ config = GPTConfig(**config_args)
232
+ model = GPT(config)
233
+ sd = model.state_dict()
234
+ sd_keys = sd.keys()
235
+ sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
236
+
237
+ # init a huggingface/transformers model
238
+ model_hf = GPT2LMHeadModel.from_pretrained(model_type)
239
+ sd_hf = model_hf.state_dict()
240
+
241
+ # copy while ensuring all of the parameters are aligned and match in names and shapes
242
+ sd_keys_hf = sd_hf.keys()
243
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
244
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
245
+ transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
246
+ # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
247
+ # this means that we have to transpose these weights when we import them
248
+ assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
249
+ for k in sd_keys_hf:
250
+ if any(k.endswith(w) for w in transposed):
251
+ # special treatment for the Conv1D weights we need to transpose
252
+ assert sd_hf[k].shape[::-1] == sd[k].shape
253
+ with torch.no_grad():
254
+ sd[k].copy_(sd_hf[k].t())
255
+ else:
256
+ # vanilla copy over the other parameters
257
+ assert sd_hf[k].shape == sd[k].shape
258
+ with torch.no_grad():
259
+ sd[k].copy_(sd_hf[k])
260
+
261
+ return model
262
+
263
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
264
+ # start with all of the candidate parameters
265
+ param_dict = {pn: p for pn, p in self.named_parameters()}
266
+ # filter out those that do not require grad
267
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
268
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
269
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
270
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
271
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
272
+ optim_groups = [
273
+ {'params': decay_params, 'weight_decay': weight_decay},
274
+ {'params': nodecay_params, 'weight_decay': 0.0}
275
+ ]
276
+ num_decay_params = sum(p.numel() for p in decay_params)
277
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
278
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
279
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
280
+ # Create AdamW optimizer and use the fused version if it is available
281
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
282
+ use_fused = fused_available and device_type == 'cuda'
283
+ extra_args = dict(fused=True) if use_fused else dict()
284
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
285
+ print(f"using fused AdamW: {use_fused}")
286
+
287
+ return optimizer
288
+
289
+ def estimate_mfu(self, fwdbwd_per_iter, dt):
290
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
291
+ # first estimate the number of flops we do per iteration.
292
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
293
+ N = self.get_num_params()
294
+ cfg = self.config
295
+ L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
296
+ flops_per_token = 6*N + 12*L*H*Q*T
297
+ flops_per_fwdbwd = flops_per_token * T
298
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
299
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
300
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
301
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
302
+ mfu = flops_achieved / flops_promised
303
+ return mfu
304
+
305
+ @torch.no_grad()
306
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
307
+ """
308
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
309
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
310
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
311
+ """
312
+ for _ in range(max_new_tokens):
313
+ # if the sequence context is growing too long we must crop it at block_size
314
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
315
+ # forward the model to get the logits for the index in the sequence
316
+ logits, _ = self(idx_cond)
317
+ # pluck the logits at the final step and scale by desired temperature
318
+ logits = logits[:, -1, :] / temperature
319
+ # optionally crop the logits to only the top k options
320
+ if top_k is not None:
321
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
322
+ logits[logits < v[:, [-1]]] = -float('Inf')
323
+ # apply softmax to convert logits to (normalized) probabilities
324
+ probs = F.softmax(logits, dim=-1)
325
+ # sample from the distribution
326
+ idx_next = torch.multinomial(probs, num_samples=1)
327
+ # append sampled index to the running sequence and continue
328
+ idx = torch.cat((idx, idx_next), dim=1)
329
+
330
+ return idx
out-gushi-char/ckpt.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffe9d8b81fd9e6c2b6a9bed73f959d6b0f7ad7f96ac6e1640c634070ea564a9b
3
+ size 166910314
requirements.txt ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.5
2
+ aiosignal==1.3.1
3
+ appdirs==1.4.4
4
+ async-timeout==4.0.3
5
+ attrs==21.2.0
6
+ Automat==20.2.0
7
+ bcrypt==3.2.0
8
+ blinker==1.4
9
+ certifi==2023.7.22
10
+ chardet==4.0.0
11
+ charset-normalizer==3.2.0
12
+ click==8.0.3
13
+ cmake==3.27.2
14
+ colorama==0.4.4
15
+ command-not-found==0.3
16
+ configobj==5.0.6
17
+ constantly==15.1.0
18
+ cryptography==3.4.8
19
+ datasets==2.14.4
20
+ dbus-python==1.2.18
21
+ dill==0.3.7
22
+ distro==1.7.0
23
+ distro-info===1.1build1
24
+ docker-pycreds==0.4.0
25
+ filelock==3.12.2
26
+ frozenlist==1.4.0
27
+ fsspec==2023.6.0
28
+ gitdb==4.0.10
29
+ GitPython==3.1.32
30
+ httplib2==0.20.2
31
+ huggingface-hub==0.16.4
32
+ hyperlink==21.0.0
33
+ idna==3.3
34
+ importlib-metadata==4.6.4
35
+ incremental==21.3.0
36
+ jeepney==0.7.1
37
+ Jinja2==3.1.2
38
+ keyring==23.5.0
39
+ launchpadlib==1.10.16
40
+ lazr.restfulclient==0.14.4
41
+ lazr.uri==1.0.6
42
+ lit==16.0.6
43
+ MarkupSafe==2.1.3
44
+ more-itertools==8.10.0
45
+ mpmath==1.3.0
46
+ multidict==6.0.4
47
+ multiprocess==0.70.15
48
+ netifaces==0.11.0
49
+ networkx==3.1
50
+ numpy==1.25.2
51
+ nvidia-cublas-cu11==11.10.3.66
52
+ nvidia-cuda-cupti-cu11==11.7.101
53
+ nvidia-cuda-nvrtc-cu11==11.7.99
54
+ nvidia-cuda-runtime-cu11==11.7.99
55
+ nvidia-cudnn-cu11==8.5.0.96
56
+ nvidia-cufft-cu11==10.9.0.58
57
+ nvidia-curand-cu11==10.2.10.91
58
+ nvidia-cusolver-cu11==11.4.0.1
59
+ nvidia-cusparse-cu11==11.7.4.91
60
+ nvidia-nccl-cu11==2.14.3
61
+ nvidia-nvtx-cu11==11.7.91
62
+ oauthlib==3.2.0
63
+ packaging==23.1
64
+ pandas==2.0.3
65
+ pathtools==0.1.2
66
+ pexpect==4.8.0
67
+ protobuf==4.24.0
68
+ psutil==5.9.5
69
+ ptyprocess==0.7.0
70
+ pyarrow==12.0.1
71
+ pyasn1==0.4.8
72
+ pyasn1-modules==0.2.1
73
+ PyGObject==3.42.1
74
+ PyHamcrest==2.0.2
75
+ PyJWT==2.3.0
76
+ pyOpenSSL==21.0.0
77
+ pyparsing==2.4.7
78
+ pyserial==3.5
79
+ python-apt==2.4.0+ubuntu1
80
+ python-dateutil==2.8.2
81
+ python-debian===0.1.43ubuntu1
82
+ python-magic==0.4.24
83
+ pytz==2023.3
84
+ PyYAML==5.4.1
85
+ regex==2023.8.8
86
+ requests==2.31.0
87
+ safetensors==0.3.2
88
+ SecretStorage==3.3.1
89
+ sentry-sdk==1.29.2
90
+ service-identity==18.1.0
91
+ setproctitle==1.3.2
92
+ six==1.16.0
93
+ smmap==5.0.0
94
+ sos==4.4
95
+ ssh-import-id==5.11
96
+ sympy==1.12
97
+ systemd-python==234
98
+ tiktoken==0.4.0
99
+ tokenizers==0.13.3
100
+ torch==2.0.1
101
+ tqdm==4.66.1
102
+ transformers==4.31.0
103
+ triton==2.0.0
104
+ Twisted==22.1.0
105
+ typing_extensions==4.7.1
106
+ tzdata==2023.3
107
+ ubuntu-advantage-tools==8001
108
+ ubuntu-drivers-common==0.0.0
109
+ ufw==0.36.1
110
+ unattended-upgrades==0.1
111
+ urllib3==2.0.4
112
+ wadllib==1.3.6
113
+ wandb==0.15.8
114
+ xkit==0.0.0
115
+ xxhash==3.3.0
116
+ yarl==1.9.2
117
+ zipp==1.0.0
118
+ zope.interface==5.4.0
119
+ streamlit
sample.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sample from a trained model
3
+ """
4
+ import os
5
+ import pickle
6
+ from contextlib import nullcontext
7
+ import torch
8
+ import tiktoken
9
+ from model import GPTConfig, GPT
10
+
11
+ # -----------------------------------------------------------------------------
12
+ init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
13
+ out_dir = 'out-gushi-char' # ignored if init_from is not 'resume'
14
+ start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
15
+ num_samples = 1 # number of samples to draw
16
+ max_new_tokens = 500 # number of tokens generated in each sample
17
+ temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
18
+ top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
19
+ seed = 1337
20
+ device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
21
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
22
+ compile = False # use PyTorch 2.0 to compile the model to be faster
23
+ exec(open('configurator.py').read()) # overrides from command line or config file
24
+ # -----------------------------------------------------------------------------
25
+
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
29
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
30
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
31
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
32
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
33
+
34
+ # model
35
+ if init_from == 'resume':
36
+ # init from a model saved in a specific directory
37
+ ckpt_path = os.path.join(out_dir, 'ckpt.pt')
38
+ checkpoint = torch.load(ckpt_path, map_location=device)
39
+ gptconf = GPTConfig(**checkpoint['model_args'])
40
+ model = GPT(gptconf)
41
+ state_dict = checkpoint['model']
42
+ unwanted_prefix = '_orig_mod.'
43
+ for k,v in list(state_dict.items()):
44
+ if k.startswith(unwanted_prefix):
45
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
46
+ model.load_state_dict(state_dict)
47
+ elif init_from.startswith('gpt2'):
48
+ # init from a given GPT-2 model
49
+ model = GPT.from_pretrained(init_from, dict(dropout=0.0))
50
+
51
+ model.eval()
52
+ model.to(device)
53
+ if compile:
54
+ model = torch.compile(model) # requires PyTorch 2.0 (optional)
55
+
56
+ # look for the meta pickle in case it is available in the dataset folder
57
+ load_meta = False
58
+ if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
59
+ meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
60
+ load_meta = os.path.exists(meta_path)
61
+ if load_meta:
62
+ print(f"Loading meta from {meta_path}...")
63
+ with open(meta_path, 'rb') as f:
64
+ meta = pickle.load(f)
65
+ # TODO want to make this more general to arbitrary encoder/decoder schemes
66
+ stoi, itos = meta['stoi'], meta['itos']
67
+ encode = lambda s: [stoi[c] for c in s]
68
+ decode = lambda l: ''.join([itos[i] for i in l])
69
+ else:
70
+ # ok let's assume gpt-2 encodings by default
71
+ print("No meta.pkl found, assuming GPT-2 encodings...")
72
+ enc = tiktoken.get_encoding("gpt2")
73
+ encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
74
+ decode = lambda l: enc.decode(l)
75
+
76
+ # encode the beginning of the prompt
77
+ if start.startswith('FILE:'):
78
+ with open(start[5:], 'r', encoding='utf-8') as f:
79
+ start = f.read()
80
+ start_ids = encode(start)
81
+ x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
82
+
83
+ # run generation
84
+ with torch.no_grad():
85
+ with ctx:
86
+ for k in range(num_samples):
87
+ y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
88
+ print(decode(y[0].tolist()))
89
+ print('---------------')
scaling_laws.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
train.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This training script can be run both on a single gpu in debug mode,
3
+ and also in a larger training run with distributed data parallel (ddp).
4
+
5
+ To run on a single GPU, example:
6
+ $ python train.py --batch_size=32 --compile=False
7
+
8
+ To run with DDP on 4 gpus on 1 node, example:
9
+ $ torchrun --standalone --nproc_per_node=4 train.py
10
+
11
+ To run with DDP on 4 gpus across 2 nodes, example:
12
+ - Run on the first (master) node with example IP 123.456.123.456:
13
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
14
+ - Run on the worker node:
15
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
16
+ (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
17
+ """
18
+
19
+ import os
20
+ import time
21
+ import math
22
+ import pickle
23
+ from contextlib import nullcontext
24
+
25
+ import numpy as np
26
+ import torch
27
+ from torch.nn.parallel import DistributedDataParallel as DDP
28
+ from torch.distributed import init_process_group, destroy_process_group
29
+
30
+ from model import GPTConfig, GPT
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # default config values designed to train a gpt2 (124M) on OpenWebText
34
+ # I/O
35
+ out_dir = 'out'
36
+ eval_interval = 2000
37
+ log_interval = 1
38
+ eval_iters = 200
39
+ eval_only = False # if True, script exits right after the first eval
40
+ always_save_checkpoint = True # if True, always save a checkpoint after each eval
41
+ init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
42
+ # wandb logging
43
+ wandb_log = False # disabled by default
44
+ wandb_project = 'owt'
45
+ wandb_run_name = 'gpt2' # 'run' + str(time.time())
46
+ # data
47
+ dataset = 'openwebtext'
48
+ gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
49
+ batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
50
+ block_size = 1024
51
+ # model
52
+ n_layer = 12
53
+ n_head = 12
54
+ n_embd = 768
55
+ dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
56
+ bias = False # do we use bias inside LayerNorm and Linear layers?
57
+ # adamw optimizer
58
+ learning_rate = 6e-4 # max learning rate
59
+ max_iters = 600000 # total number of training iterations
60
+ weight_decay = 1e-1
61
+ beta1 = 0.9
62
+ beta2 = 0.95
63
+ grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
64
+ # learning rate decay settings
65
+ decay_lr = True # whether to decay the learning rate
66
+ warmup_iters = 2000 # how many steps to warm up for
67
+ lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
68
+ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
69
+ # DDP settings
70
+ backend = 'nccl' # 'nccl', 'gloo', etc.
71
+ # system
72
+ device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
73
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
74
+ compile = True # use PyTorch 2.0 to compile the model to be faster
75
+ # -----------------------------------------------------------------------------
76
+ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
77
+ exec(open('configurator.py').read()) # overrides from command line or config file
78
+ config = {k: globals()[k] for k in config_keys} # will be useful for logging
79
+ # -----------------------------------------------------------------------------
80
+
81
+ # various inits, derived attributes, I/O setup
82
+ ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
83
+ if ddp:
84
+ init_process_group(backend=backend)
85
+ ddp_rank = int(os.environ['RANK'])
86
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
87
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
88
+ device = f'cuda:{ddp_local_rank}'
89
+ torch.cuda.set_device(device)
90
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
91
+ seed_offset = ddp_rank # each process gets a different seed
92
+ # world_size number of processes will be training simultaneously, so we can scale
93
+ # down the desired gradient accumulation iterations per process proportionally
94
+ assert gradient_accumulation_steps % ddp_world_size == 0
95
+ gradient_accumulation_steps //= ddp_world_size
96
+ else:
97
+ # if not ddp, we are running on a single gpu, and one process
98
+ master_process = True
99
+ seed_offset = 0
100
+ ddp_world_size = 1
101
+ tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
102
+ print(f"tokens per iteration will be: {tokens_per_iter:,}")
103
+
104
+ if master_process:
105
+ os.makedirs(out_dir, exist_ok=True)
106
+ torch.manual_seed(1337 + seed_offset)
107
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
108
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
109
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
110
+ # note: float16 data type will automatically use a GradScaler
111
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
112
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
113
+
114
+ # poor man's data loader
115
+ data_dir = os.path.join('data', dataset)
116
+ train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
117
+ val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
118
+ def get_batch(split):
119
+ data = train_data if split == 'train' else val_data
120
+ ix = torch.randint(len(data) - block_size, (batch_size,))
121
+ x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
122
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
123
+ if device_type == 'cuda':
124
+ # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
125
+ x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
126
+ else:
127
+ x, y = x.to(device), y.to(device)
128
+ return x, y
129
+
130
+ # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
131
+ iter_num = 0
132
+ best_val_loss = 1e9
133
+
134
+ # attempt to derive vocab_size from the dataset
135
+ meta_path = os.path.join(data_dir, 'meta.pkl')
136
+ meta_vocab_size = None
137
+ if os.path.exists(meta_path):
138
+ with open(meta_path, 'rb') as f:
139
+ meta = pickle.load(f)
140
+ meta_vocab_size = meta['vocab_size']
141
+ print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
142
+
143
+ # model init
144
+ model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
145
+ bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
146
+ if init_from == 'scratch':
147
+ # init a new model from scratch
148
+ print("Initializing a new model from scratch")
149
+ # determine the vocab size we'll use for from-scratch training
150
+ if meta_vocab_size is None:
151
+ print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
152
+ model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
153
+ gptconf = GPTConfig(**model_args)
154
+ model = GPT(gptconf)
155
+ elif init_from == 'resume':
156
+ print(f"Resuming training from {out_dir}")
157
+ # resume training from a checkpoint.
158
+ ckpt_path = os.path.join(out_dir, 'ckpt.pt')
159
+ checkpoint = torch.load(ckpt_path, map_location=device)
160
+ checkpoint_model_args = checkpoint['model_args']
161
+ # force these config attributes to be equal otherwise we can't even resume training
162
+ # the rest of the attributes (e.g. dropout) can stay as desired from command line
163
+ for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
164
+ model_args[k] = checkpoint_model_args[k]
165
+ # create the model
166
+ gptconf = GPTConfig(**model_args)
167
+ model = GPT(gptconf)
168
+ state_dict = checkpoint['model']
169
+ # fix the keys of the state dictionary :(
170
+ # honestly no idea how checkpoints sometimes get this prefix, have to debug more
171
+ unwanted_prefix = '_orig_mod.'
172
+ for k,v in list(state_dict.items()):
173
+ if k.startswith(unwanted_prefix):
174
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
175
+ model.load_state_dict(state_dict)
176
+ iter_num = checkpoint['iter_num']
177
+ best_val_loss = checkpoint['best_val_loss']
178
+ elif init_from.startswith('gpt2'):
179
+ print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
180
+ # initialize from OpenAI GPT-2 weights
181
+ override_args = dict(dropout=dropout)
182
+ model = GPT.from_pretrained(init_from, override_args)
183
+ # read off the created config params, so we can store them into checkpoint correctly
184
+ for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
185
+ model_args[k] = getattr(model.config, k)
186
+ # crop down the model block size if desired, using model surgery
187
+ if block_size < model.config.block_size:
188
+ model.crop_block_size(block_size)
189
+ model_args['block_size'] = block_size # so that the checkpoint will have the right value
190
+ model.to(device)
191
+
192
+ # initialize a GradScaler. If enabled=False scaler is a no-op
193
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
194
+
195
+ # optimizer
196
+ optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
197
+ if init_from == 'resume':
198
+ optimizer.load_state_dict(checkpoint['optimizer'])
199
+ checkpoint = None # free up memory
200
+
201
+ # compile the model
202
+ if compile:
203
+ print("compiling the model... (takes a ~minute)")
204
+ unoptimized_model = model
205
+ model = torch.compile(model) # requires PyTorch 2.0
206
+
207
+ # wrap model into DDP container
208
+ if ddp:
209
+ model = DDP(model, device_ids=[ddp_local_rank])
210
+
211
+ # helps estimate an arbitrarily accurate loss over either split using many batches
212
+ @torch.no_grad()
213
+ def estimate_loss():
214
+ out = {}
215
+ model.eval()
216
+ for split in ['train', 'val']:
217
+ losses = torch.zeros(eval_iters)
218
+ for k in range(eval_iters):
219
+ X, Y = get_batch(split)
220
+ with ctx:
221
+ logits, loss = model(X, Y)
222
+ losses[k] = loss.item()
223
+ out[split] = losses.mean()
224
+ model.train()
225
+ return out
226
+
227
+ # learning rate decay scheduler (cosine with warmup)
228
+ def get_lr(it):
229
+ # 1) linear warmup for warmup_iters steps
230
+ if it < warmup_iters:
231
+ return learning_rate * it / warmup_iters
232
+ # 2) if it > lr_decay_iters, return min learning rate
233
+ if it > lr_decay_iters:
234
+ return min_lr
235
+ # 3) in between, use cosine decay down to min learning rate
236
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
237
+ assert 0 <= decay_ratio <= 1
238
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
239
+ return min_lr + coeff * (learning_rate - min_lr)
240
+
241
+ # logging
242
+ if wandb_log and master_process:
243
+ import wandb
244
+ wandb.init(project=wandb_project, name=wandb_run_name, config=config)
245
+
246
+ # training loop
247
+ X, Y = get_batch('train') # fetch the very first batch
248
+ t0 = time.time()
249
+ local_iter_num = 0 # number of iterations in the lifetime of this process
250
+ raw_model = model.module if ddp else model # unwrap DDP container if needed
251
+ running_mfu = -1.0
252
+ while True:
253
+
254
+ # determine and set the learning rate for this iteration
255
+ lr = get_lr(iter_num) if decay_lr else learning_rate
256
+ for param_group in optimizer.param_groups:
257
+ param_group['lr'] = lr
258
+
259
+ # evaluate the loss on train/val sets and write checkpoints
260
+ if iter_num % eval_interval == 0 and master_process:
261
+ losses = estimate_loss()
262
+ print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
263
+ if wandb_log:
264
+ wandb.log({
265
+ "iter": iter_num,
266
+ "train/loss": losses['train'],
267
+ "val/loss": losses['val'],
268
+ "lr": lr,
269
+ "mfu": running_mfu*100, # convert to percentage
270
+ })
271
+ if losses['val'] < best_val_loss or always_save_checkpoint:
272
+ best_val_loss = losses['val']
273
+ if iter_num > 0:
274
+ checkpoint = {
275
+ 'model': raw_model.state_dict(),
276
+ 'optimizer': optimizer.state_dict(),
277
+ 'model_args': model_args,
278
+ 'iter_num': iter_num,
279
+ 'best_val_loss': best_val_loss,
280
+ 'config': config,
281
+ }
282
+ print(f"saving checkpoint to {out_dir}")
283
+ torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
284
+ if iter_num == 0 and eval_only:
285
+ break
286
+
287
+ # forward backward update, with optional gradient accumulation to simulate larger batch size
288
+ # and using the GradScaler if data type is float16
289
+ for micro_step in range(gradient_accumulation_steps):
290
+ if ddp:
291
+ # in DDP training we only need to sync gradients at the last micro step.
292
+ # the official way to do this is with model.no_sync() context manager, but
293
+ # I really dislike that this bloats the code and forces us to repeat code
294
+ # looking at the source of that context manager, it just toggles this variable
295
+ model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
296
+ with ctx:
297
+ logits, loss = model(X, Y)
298
+ loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
299
+ # immediately async prefetch next batch while model is doing the forward pass on the GPU
300
+ X, Y = get_batch('train')
301
+ # backward pass, with gradient scaling if training in fp16
302
+ scaler.scale(loss).backward()
303
+ # clip the gradient
304
+ if grad_clip != 0.0:
305
+ scaler.unscale_(optimizer)
306
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
307
+ # step the optimizer and scaler if training in fp16
308
+ scaler.step(optimizer)
309
+ scaler.update()
310
+ # flush the gradients as soon as we can, no need for this memory anymore
311
+ optimizer.zero_grad(set_to_none=True)
312
+
313
+ # timing and logging
314
+ t1 = time.time()
315
+ dt = t1 - t0
316
+ t0 = t1
317
+ if iter_num % log_interval == 0 and master_process:
318
+ # get loss as float. note: this is a CPU-GPU sync point
319
+ # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
320
+ lossf = loss.item() * gradient_accumulation_steps
321
+ if local_iter_num >= 5: # let the training loop settle a bit
322
+ mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
323
+ running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
324
+ print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
325
+ iter_num += 1
326
+ local_iter_num += 1
327
+
328
+ # termination conditions
329
+ if iter_num > max_iters:
330
+ break
331
+
332
+ if ddp:
333
+ destroy_process_group()
transformer_sizing.ipynb ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Transformer Theoretical Model\n",
9
+ "\n",
10
+ "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 1,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "from collections import OrderedDict"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 2,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "# config_args = {\n",
29
+ "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n",
30
+ "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n",
31
+ "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n",
32
+ "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n",
33
+ "# }[model_type]\n",
34
+ "\n",
35
+ "block_size = 1024\n",
36
+ "vocab_size = 50257\n",
37
+ "n_layer = 12\n",
38
+ "n_head = 12\n",
39
+ "n_embd = 768\n",
40
+ "bias = False\n",
41
+ "assert not bias, \"this notebook assumes bias=False just for simplicity\""
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 3,
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "name": "stdout",
51
+ "output_type": "stream",
52
+ "text": [
53
+ "we see: 124337664, expected: 124337664, match: True\n",
54
+ "name params ratio (%) \n",
55
+ "emebedding/position 786432 0.6325\n",
56
+ "embedding/token 38597376 31.0424\n",
57
+ "embedding 39383808 31.6749\n",
58
+ "attention/ln 768 0.0006\n",
59
+ "attention/kqv 1769472 1.4231\n",
60
+ "attention/proj 589824 0.4744\n",
61
+ "attention 2360064 1.8981\n",
62
+ "mlp/ln 768 0.0006\n",
63
+ "mlp/ffw 2359296 1.8975\n",
64
+ "mlp/proj 2359296 1.8975\n",
65
+ "mlp 4719360 3.7956\n",
66
+ "block 7079424 5.6937\n",
67
+ "transformer 84953088 68.3245\n",
68
+ "ln_f 768 0.0006\n",
69
+ "dense 0 0.0000\n",
70
+ "total 124337664 100.0000\n"
71
+ ]
72
+ }
73
+ ],
74
+ "source": [
75
+ "def params():\n",
76
+ " \"\"\" estimates the number of parameters in the model\"\"\"\n",
77
+ " out = OrderedDict()\n",
78
+ "\n",
79
+ " # token and position embeddings\n",
80
+ " out['emebedding/position'] = n_embd * block_size\n",
81
+ " out['embedding/token'] = n_embd * vocab_size\n",
82
+ " out['embedding'] = out['emebedding/position'] + out['embedding/token']\n",
83
+ "\n",
84
+ " # attention blocks\n",
85
+ " out['attention/ln'] = n_embd # note, bias=False in our LN\n",
86
+ " out['attention/kqv'] = n_embd * 3*n_embd\n",
87
+ " out['attention/proj'] = n_embd**2\n",
88
+ " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n",
89
+ "\n",
90
+ " # MLP blocks\n",
91
+ " ffw_size = 4*n_embd # feed forward size\n",
92
+ " out['mlp/ln'] = n_embd\n",
93
+ " out['mlp/ffw'] = n_embd * ffw_size\n",
94
+ " out['mlp/proj'] = ffw_size * n_embd\n",
95
+ " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n",
96
+ " \n",
97
+ " # the transformer and the rest of it\n",
98
+ " out['block'] = out['attention'] + out['mlp']\n",
99
+ " out['transformer'] = n_layer * out['block']\n",
100
+ " out['ln_f'] = n_embd # final layernorm\n",
101
+ " out['dense'] = 0 # 0 because of parameter sharing. This layer uses the weights from the embedding layer\n",
102
+ "\n",
103
+ " # total\n",
104
+ " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n",
105
+ "\n",
106
+ " return out\n",
107
+ "\n",
108
+ "# compare our param count to that reported by PyTorch\n",
109
+ "p = params()\n",
110
+ "params_total = p['total']\n",
111
+ "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n",
112
+ "# create a header\n",
113
+ "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n",
114
+ "for k,v in p.items():\n",
115
+ " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n",
116
+ " "
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": 4,
122
+ "metadata": {},
123
+ "outputs": [
124
+ {
125
+ "name": "stdout",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "est checkpoint size: 1.49 GB\n",
129
+ "measured with wc -c ckpt.pt: 1542470366\n",
130
+ "fluff ratio: 103.38%\n"
131
+ ]
132
+ }
133
+ ],
134
+ "source": [
135
+ "# we can now calculate the size of each checkpoint\n",
136
+ "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n",
137
+ "params_bytes = params_total*4\n",
138
+ "params_and_buffers_bytes = params_bytes + 2*params_bytes\n",
139
+ "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n",
140
+ "measured_bytes = 1542470366 # from wc -c ckpt.pt\n",
141
+ "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n",
142
+ "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")"
143
+ ]
144
+ },
145
+ {
146
+ "attachments": {},
147
+ "cell_type": "markdown",
148
+ "metadata": {},
149
+ "source": [
150
+ "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 5,
156
+ "metadata": {},
157
+ "outputs": [
158
+ {
159
+ "name": "stdout",
160
+ "output_type": "stream",
161
+ "text": [
162
+ "memory ratio taken up just for parameters: 3.73%\n"
163
+ ]
164
+ }
165
+ ],
166
+ "source": [
167
+ "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n",
168
+ "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")"
169
+ ]
170
+ },
171
+ {
172
+ "attachments": {},
173
+ "cell_type": "markdown",
174
+ "metadata": {},
175
+ "source": [
176
+ "i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models."
177
+ ]
178
+ },
179
+ {
180
+ "attachments": {},
181
+ "cell_type": "markdown",
182
+ "metadata": {},
183
+ "source": [
184
+ "Let's estimate FLOPs for a single forward pass."
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 6,
190
+ "metadata": {},
191
+ "outputs": [
192
+ {
193
+ "name": "stdout",
194
+ "output_type": "stream",
195
+ "text": [
196
+ "name flops ratio (%) \n",
197
+ "attention/kqv 3623878656 1.2426\n",
198
+ "attention/scores 1610612736 0.5522\n",
199
+ "attention/reduce 1610612736 0.5522\n",
200
+ "attention/proj 1207959552 0.4142\n",
201
+ "attention 8053063680 2.7612\n",
202
+ "mlp/ffw1 4831838208 1.6567\n",
203
+ "mlp/ffw2 4831838208 1.6567\n",
204
+ "mlp 9663676416 3.3135\n",
205
+ "block 17716740096 6.0747\n",
206
+ "transformer 212600881152 72.8963\n",
207
+ "dense 79047426048 27.1037\n",
208
+ "forward_total 291648307200 100.0000\n",
209
+ "backward_total 583296614400 200.0000\n",
210
+ "total 874944921600 300.0000\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "def flops():\n",
216
+ " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n",
217
+ " # we count actual FLOPs, not MACs. Hence 2* all over the place\n",
218
+ " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n",
219
+ "\n",
220
+ " out = OrderedDict()\n",
221
+ " head_size = n_embd // n_head\n",
222
+ "\n",
223
+ " # attention blocks\n",
224
+ " # 1) the projection to key, query, values\n",
225
+ " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n",
226
+ " # 2) calculating the attention scores\n",
227
+ " out['attention/scores'] = 2 * block_size * block_size * n_embd\n",
228
+ " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
229
+ " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n",
230
+ " # 4) the final linear projection\n",
231
+ " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n",
232
+ " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n",
233
+ "\n",
234
+ " # MLP blocks\n",
235
+ " ffw_size = 4*n_embd # feed forward size\n",
236
+ " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n",
237
+ " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n",
238
+ " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n",
239
+ "\n",
240
+ " # the transformer and the rest of it\n",
241
+ " out['block'] = out['attention'] + out['mlp']\n",
242
+ " out['transformer'] = n_layer * out['block']\n",
243
+ " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n",
244
+ "\n",
245
+ " # forward,backward,total\n",
246
+ " out['forward_total'] = out['transformer'] + out['dense']\n",
247
+ " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n",
248
+ " out['total'] = out['forward_total'] + out['backward_total']\n",
249
+ "\n",
250
+ " return out\n",
251
+ " \n",
252
+ "# compare our param count to that reported by PyTorch\n",
253
+ "f = flops()\n",
254
+ "flops_total = f['forward_total']\n",
255
+ "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n",
256
+ "for k,v in f.items():\n",
257
+ " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n",
258
+ " "
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 7,
264
+ "metadata": {},
265
+ "outputs": [
266
+ {
267
+ "name": "stdout",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n"
271
+ ]
272
+ }
273
+ ],
274
+ "source": [
275
+ "# now here is an estimate copy pasted from the PaLM paper\n",
276
+ "# this formula is often used to calculate MFU (model flops utilization)\n",
277
+ "def palm_flops():\n",
278
+ " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n",
279
+ " # non-embedding model parameters. note that we do not subtract the\n",
280
+ " # embedding/token params because those are tied and get used in the last layer.\n",
281
+ " N = params()['total'] - params()['emebedding/position']\n",
282
+ " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n",
283
+ " mf_per_token = 6*N + 12*L*H*Q*T\n",
284
+ " mf = mf_per_token * block_size\n",
285
+ " return mf\n",
286
+ "\n",
287
+ "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")"
288
+ ]
289
+ },
290
+ {
291
+ "attachments": {},
292
+ "cell_type": "markdown",
293
+ "metadata": {},
294
+ "source": [
295
+ "Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. We get:"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 8,
301
+ "metadata": {},
302
+ "outputs": [
303
+ {
304
+ "name": "stdout",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "fraction of A100 used: 37.14%\n"
308
+ ]
309
+ }
310
+ ],
311
+ "source": [
312
+ "# here is what we currently roughly measure\n",
313
+ "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n",
314
+ "measured_time = 0.755 # in seconds per iteration\n",
315
+ "measured_throughput = batch_size / measured_time\n",
316
+ "flops_achieved = f['total'] * measured_throughput\n",
317
+ "\n",
318
+ "# A100 is cited to be 312 TFLOPS of bloat16 running on tensor cores\n",
319
+ "a100_flops_promised = 312e12\n",
320
+ "\n",
321
+ "# the fraction of the A100 that we are using:\n",
322
+ "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")"
323
+ ]
324
+ },
325
+ {
326
+ "attachments": {},
327
+ "cell_type": "markdown",
328
+ "metadata": {},
329
+ "source": [
330
+ "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU."
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 9,
336
+ "metadata": {},
337
+ "outputs": [
338
+ {
339
+ "name": "stdout",
340
+ "output_type": "stream",
341
+ "text": [
342
+ "time needed to train the model: 3.46 days\n"
343
+ ]
344
+ }
345
+ ],
346
+ "source": [
347
+ "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n",
348
+ "model_size = params()['total'] # this is number of parameters, N\n",
349
+ "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n",
350
+ "a100_flops = 312e12 # 312 TFLOPS\n",
351
+ "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n",
352
+ "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n",
353
+ "flops_needed = 6 * model_size * tokens_num # 6ND\n",
354
+ "time_needed_s = flops_needed / flops_throughput # in seconds\n",
355
+ "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")"
356
+ ]
357
+ },
358
+ {
359
+ "attachments": {},
360
+ "cell_type": "markdown",
361
+ "metadata": {},
362
+ "source": [
363
+ "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)."
364
+ ]
365
+ },
366
+ {
367
+ "attachments": {},
368
+ "cell_type": "markdown",
369
+ "metadata": {},
370
+ "source": [
371
+ "Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later."
372
+ ]
373
+ }
374
+ ],
375
+ "metadata": {
376
+ "kernelspec": {
377
+ "display_name": "pytorch2",
378
+ "language": "python",
379
+ "name": "python3"
380
+ },
381
+ "language_info": {
382
+ "codemirror_mode": {
383
+ "name": "ipython",
384
+ "version": 3
385
+ },
386
+ "file_extension": ".py",
387
+ "mimetype": "text/x-python",
388
+ "name": "python",
389
+ "nbconvert_exporter": "python",
390
+ "pygments_lexer": "ipython3",
391
+ "version": "3.10.8"
392
+ },
393
+ "orig_nbformat": 4,
394
+ "vscode": {
395
+ "interpreter": {
396
+ "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
397
+ }
398
+ }
399
+ },
400
+ "nbformat": 4,
401
+ "nbformat_minor": 2
402
+ }