init
- .gitignore +3 -0
- app.py +139 -0
- c4x.py +61 -0
- model2.pt +3 -0
- model3.pt +3 -0
- model4.pt +3 -0
- pile.py +107 -0
- pile_hf.py +50 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
wandb
__pycache__
.ipynb_checkpoints
app.py
ADDED
@@ -0,0 +1,139 @@
# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
import os

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import wandb
from tqdm import tqdm
from transformers import BloomForCausalLM, BloomTokenizerFast
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# from c4x import C4X
from pile_hf import ThePile, ThePileTokenized
from accelerate import Accelerator


def main():
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=8192,
    )
    accelerator.init_trackers("gated-state-space")

    # Cache BLOOM's word embeddings locally so the full 1.7B model only
    # has to be downloaded once.
    emb_fn = "emb.pt"
    model_name = "bigscience/bloomz-1b7"
    if not os.path.isfile(emb_fn):
        bloom = BloomForCausalLM.from_pretrained(model_name)
        wte = bloom.transformer.word_embeddings.state_dict()
        torch.save(wte, emb_fn)
    else:
        wte = torch.load(emb_fn)

    f_emb = 2048
    n_vocab = 250880
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=n_vocab,
            dim=f_emb,
            depth=24,
        ),
    )

    # Input embeddings: frozen, initialized from BLOOM.
    model.net.token_emb.requires_grad_(False)
    model.net.token_emb.load_state_dict(wte)

    # Output head: a frozen linear layer tied to the same embedding matrix
    # (the Linear and Embedding weights share the [n_vocab, f_emb] shape).
    to_logits = nn.Linear(f_emb, n_vocab, bias=False)
    to_logits.requires_grad_(False)
    to_logits.load_state_dict(wte)

    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        to_logits,
    )
    model.load_state_dict(torch.load("model3.pt"))
    model = model.to(accelerator.device)

    if accelerator.is_main_process:
        wandb.watch(model)

    optim = AdamW(model.parameters(), lr=1e-4)
    sch = CosineAnnealingWarmRestarts(
        optim,
        T_0=1000,
        T_mult=2,
        eta_min=1e-7,
    )

    bs = 1
    kk = 2048
    tok: BloomTokenizerFast = BloomTokenizerFast.from_pretrained(model_name)
    dsx = ThePileTokenized(
        ThePile("train"),
        tokenizer=tok,
        max_length=kk,
        repeat_factor=4 / 3,
    )
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=12,
    )

    model = accelerator.prepare(model)
    optim, dlx, sch = accelerator.prepare(optim, dlx, sch)

    # Build the progress bar from the prepared dataloader so each process
    # iterates its own shard.
    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            optim.zero_grad()
            if not accelerator.optimizer_step_was_skipped:
                sch.step()

        if i % 1000 == 0:
            # Periodically sample from the model, log the text to W&B,
            # and checkpoint.
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 1, 512
            # Seed generation with token id 2 (</s>), the separator used
            # when packing documents.
            init = torch.tensor([[2]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [tok.decode(p) for p in prd]
            try:
                accelerator.log(
                    dict(
                        text=wandb.Html(
                            "<hr>".join(p.replace("\n", "<br>") for p in prd)
                        )
                    ),
                    step=i,
                )
            except Exception as ex:
                accelerator.print("Failed to log to W&B...", ex)
            sd = unwrapped_model.state_dict()
            # sd.pop('net.to_logits.weight')
            accelerator.save(sd, "model4.pt")

        if i % 10 == 0:
            accelerator.log(
                dict(
                    loss=los.item(),
                    lr=optim.param_groups[0]["lr"],
                ),
                step=i,
            )
            prog.set_postfix(loss=los.item())


if __name__ == "__main__":
    main()
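For orientation, a quick sanity check of the effective batch implied by the settings above; a sketch using app.py's values, not part of the commit:

# bs, kk, and gradient_accumulation_steps as set in app.py above
bs, seq_len, accum_steps = 1, 2048, 8192
tokens_per_update = bs * seq_len * accum_steps
print(f"{tokens_per_update:,}")  # 16,777,216 tokens (~16.8M) per optimizer step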
c4x.py
ADDED
@@ -0,0 +1,61 @@
# Stream the C4 dataset from Hugging Face with the BLOOM tokenizer for
# PyTorch language-model training.
import random

import torch
from datasets import load_dataset
from transformers import BloomTokenizerFast
from torch.utils.data import Dataset, get_worker_info


def cycled(itr):
    # Restart the (finite) stream whenever it is exhausted.
    while True:
        for itm in itr:
            yield itm


class C4X(Dataset):

    def __init__(self, seq_len=512, split='train'):
        self.seq = seq_len
        self.ds = load_dataset(
            'c4',
            name='en',
            split=split,
            streaming=True,
        )
        self.tok = BloomTokenizerFast.from_pretrained('bigscience/bloomz-1b7')
        self.init = False

    def __len__(self):
        # Effectively infinite: samples are drawn from a cycled stream.
        return 1_000_000_000

    def _init(self):
        # Shuffle lazily, once per worker process, seeded per worker
        # (fall back to an unseeded shuffle in the main process).
        if self.init:
            return
        wi = get_worker_info()
        self.ds = cycled(
            self.ds.shuffle(
                seed=wi.seed if wi is not None else None,
                buffer_size=10_000,
            )
        )
        self.init = True

    def _get_next(self):
        self._init()
        obj = next(self.ds)['text']
        tkn = self.tok.encode(obj)
        return tkn

    def _get_full(self):
        # Concatenate EOS-separated documents until a full window of
        # seq_len tokens can be cut at a random offset.
        obj = []
        while len(obj) < self.seq:
            obj += self._get_next()
            obj.append(self.tok.eos_token_id)
        s = random.randint(0, len(obj) - self.seq)
        return obj[s:s + self.seq]

    def __getitem__(self, _):
        return torch.tensor(self._get_full())

    def decode(self, tkns):
        return self.tok.decode(tkns)
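C4X indexes like a map-style dataset but actually streams; a minimal usage sketch, not part of the commit, assuming network access to the Hugging Face c4 dataset:

from torch.utils.data import DataLoader
from c4x import C4X

ds = C4X(seq_len=512)
dl = DataLoader(ds, batch_size=2, num_workers=2)
xb = next(iter(dl))  # LongTensor of shape (2, 512)
print(ds.decode(xb[0][:32].tolist()))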
model2.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c772feff0cf8504a46f872f6a32ae4269632b3e805e9437438f29268b795b
size 7609367025
model3.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c89f900da2bae9f79193ba785df8be4118d99135ffe66848e60f1ee6627b4bac
size 7609367025
model4.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c89f900da2bae9f79193ba785df8be4118d99135ffe66848e60f1ee6627b4bac
size 7609367025
pile.py
ADDED
@@ -0,0 +1,107 @@
import json
import random
from typing import Literal

import requests
import zstandard as zstd
from torch.utils.data import IterableDataset, get_worker_info


Subset = Literal["train", "val", "test"]
URLs = {
    "val": [
        "https://the-eye.eu/public/AI/pile/val.jsonl.zst",
    ],
    "test": [
        "https://the-eye.eu/public/AI/pile/test.jsonl.zst",
    ],
    "train": [
        "https://the-eye.eu/public/AI/pile/train/00.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/01.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/02.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/03.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/04.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/05.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/06.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/07.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/08.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/09.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/10.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/11.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/12.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/13.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/14.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/15.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/16.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/17.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/18.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/19.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/20.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/21.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/22.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/23.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/24.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/25.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/26.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/27.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/28.jsonl.zst",
        "https://the-eye.eu/public/AI/pile/train/29.jsonl.zst",
    ],
}


def _read_line_from_stream(reader, initial_line=b"", buffer_size=4096):
    # Accumulate raw bytes until a newline arrives; splitting before
    # decoding avoids breaking multi-byte UTF-8 sequences across chunks.
    line = initial_line
    while True:
        c = reader.read(buffer_size)
        if not c:
            # End of stream. EOFError rather than StopIteration, which
            # must not escape a generator frame (PEP 479).
            raise EOFError
        line += c
        if b"\n" in line:
            break
    head, rest = line.split(b"\n", 1)
    return head.decode("utf-8"), rest


def _line_streamer(reader, buffer_size=4096):
    rest = b""
    while True:
        try:
            line, rest = _read_line_from_stream(
                reader,
                rest,
                buffer_size,
            )
            yield line
        except EOFError:
            break


class ThePile(IterableDataset):
    TEXT_BUFFER_SIZE = 4096

    def __init__(self, subset: Subset):
        self.subset = subset

    def __iter__(self):
        # Shuffle the shard order per worker, then stream each shard over
        # HTTP and decompress it on the fly.
        urls = URLs[self.subset].copy()
        while True:
            wi = get_worker_info()
            seed = wi.id if wi is not None else None
            rnd = random.Random(seed)
            rnd.shuffle(urls)
            for url in urls:
                r = requests.get(url, stream=True)
                with zstd.ZstdDecompressor().stream_reader(r.raw) as reader:
                    for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
                        data = json.loads(line)
                        yield data


if __name__ == "__main__":
    from tqdm import tqdm

    dataset = ThePile("train")
    for data in tqdm(dataset, smoothing=0.01):
        pass
    # Average: ~2000 samples/sec/worker
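The per-worker seeding in ThePile.__iter__ only comes into play under a multi-worker DataLoader; a minimal sketch, not part of the commit:

from torch.utils.data import DataLoader
from pile import ThePile

# batch_size=None yields raw JSON dicts without collation; each of the 4
# workers shuffles the shard URLs with its own worker id as seed.
dl = DataLoader(ThePile("val"), batch_size=None, num_workers=4)
for i, doc in enumerate(dl):
    print(doc["text"][:80])
    if i >= 3:
        break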
pile_hf.py
ADDED
@@ -0,0 +1,50 @@
import torch
from torch.utils.data import IterableDataset

from transformers import PreTrainedTokenizerBase

from pile import ThePile


class ThePileTokenized(IterableDataset):
    def __init__(
        self,
        base_dataset: ThePile,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 1024,
        repeat_factor: float = 1.0,
    ):
        self.pile = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.repeat_factor = repeat_factor

    def __iter__(self):
        # Pack EOS-separated documents into one long token buffer and emit
        # fixed-length windows. Advancing by max_length / repeat_factor
        # tokens makes consecutive windows overlap, so each token is seen
        # ~repeat_factor times on average.
        ds = iter(self.pile)
        buffer = []
        while True:
            tokens = self.tokenizer.encode(next(ds)["text"])
            buffer += [self.tokenizer.eos_token_id] + tokens
            while len(buffer) > self.max_length:
                yield torch.tensor(buffer[: self.max_length])
                buffer = buffer[int(self.max_length / self.repeat_factor) :]


if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    dataset = ThePileTokenized(
        ThePile("train"),
        GPT2Tokenizer.from_pretrained("gpt2"),
        max_length=2048,
        repeat_factor=4 / 3,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
    )
    for batch in tqdm(dataloader, smoothing=0.01):
        pass
    # ~6 iters/s for 1 worker
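The repeat_factor stride is easiest to see with small numbers; a worked sketch with hypothetical values, not from the commit:

max_length, repeat_factor = 8, 4 / 3
stride = int(max_length / repeat_factor)  # 6 -> consecutive windows overlap by 2
buffer = list(range(21))
while len(buffer) > max_length:
    print(buffer[:max_length])
    buffer = buffer[stride:]
# windows [0..7], [6..13], [12..19]: each token lands in ~4/3 windows on average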