# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
import os
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import wandb
from tqdm import tqdm
from transformers import BloomForCausalLM, BloomTokenizerFast
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
# from c4x import C4X
from pile_hf import ThePile, ThePileTokenized
from accelerate import Accelerator
def main():
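    # Accelerate handles device placement, autocast, and gradient accumulation;
    # with batch_size=1, accumulating over 8192 steps gives a very large effective batch.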
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=8192,
    )
    accelerator.init_trackers("gated-state-space")
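
    # Reuse the word-embedding weights of bloomz-1b7, caching them to disk so the
    # full BLOOM checkpoint only has to be downloaded once.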
    emb_fn = "emb.pt"
    model_name = "bigscience/bloomz-1b7"
    if not os.path.isfile(emb_fn):
        bloom = BloomForCausalLM.from_pretrained(model_name)
        wte = bloom.transformer.word_embeddings.state_dict()
        torch.save(wte, emb_fn)
    else:
        wte = torch.load(emb_fn)
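
    # Width and vocabulary size match bloomz-1b7, so its embedding matrix drops straight in;
    # the token embedding is kept frozen during training.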
    f_emb = 2048
    n_vocab = 250880
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=n_vocab,
            dim=f_emb,
            depth=24,
        ),
    )
    model.net.token_emb.requires_grad_(False)
    model.net.token_emb.load_state_dict(wte)
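
    # The output projection is the same (frozen) embedding matrix, preceded by a LayerNorm;
    # training then resumes from the model3.pt checkpoint.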
    to_logits = nn.Linear(f_emb, n_vocab, bias=False)
    to_logits.requires_grad_(False)
    to_logits.load_state_dict(wte)
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        to_logits,
    )
    model.load_state_dict(torch.load("model3.pt"))
    model = model.to(accelerator.device)
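
    # Only the main process reports gradients and parameters to W&B.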
    if accelerator.is_main_process:
        wandb.watch(model)

    optim = AdamW(model.parameters(), 1e-4)
    sch = CosineAnnealingWarmRestarts(
        optim,
        T_0=1000,
        T_mult=2,
        eta_min=1e-7,
    )
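
    # The Pile, tokenized with the BLOOM tokenizer into fixed-length 2048-token sequences.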
    bs = 1
    kk = 2048
    tok: BloomTokenizerFast = BloomTokenizerFast.from_pretrained(model_name)
    dsx = ThePileTokenized(
        ThePile("train"),
        tokenizer=tok,
        max_length=kk,
        repeat_factor=4 / 3,
    )
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=12,
    )
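
    # Hand model, optimizer, dataloader, and scheduler to Accelerate.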
    model = accelerator.prepare(model)
    optim, dlx, sch = accelerator.prepare(optim, dlx, sch)

    # Wrap the *prepared* dataloader in the progress bar (shown on the main process only),
    # so iteration goes through Accelerate's sharding when more than one process is used.
    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    optim.zero_grad()
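
    # Training loop: Accelerate accumulates gradients, clips them at the sync step,
    # and the LR schedule only advances when the optimizer step was not skipped.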
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            optim.zero_grad()
            if not accelerator.optimizer_step_was_skipped:
                sch.step()
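
        # Every 1000 batches: sample 512 tokens from the model, log the text to W&B,
        # and save a checkpoint to model4.pt.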
        if i % 1000 == 0:
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 1, 512
            init = torch.tensor([[2]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [tok.decode(p) for p in prd]
            try:
                accelerator.log(
                    dict(
                        text=wandb.Html(
                            "<hr>".join(p.replace("\n", "<br>") for p in prd)
                        )
                    ),
                    step=i,
                )
            except Exception as ex:
                accelerator.print("Failed to log to W&B...", ex)
            sd = unwrapped_model.state_dict()
            # sd.pop('net.to_logits.weight')
            accelerator.save(sd, "model4.pt")
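
        # Scalar logging every 10 batches.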
        if i % 10 == 0:
            accelerator.log(
                dict(
                    loss=los.item(),
                    lr=optim.param_groups[0]["lr"],
                ),
                step=i,
            )
            prog.set_postfix(loss=los.item())
if __name__ == "__main__":
main()