import os

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import wandb
from tqdm import tqdm
from transformers import BloomForCausalLM, BloomTokenizerFast
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from pile_hf import ThePile, ThePileTokenized
from accelerate import Accelerator
|
|
def main(): |
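    """Train a Gated State Spaces language model on The Pile.

    The token embedding and output projection are initialised from the frozen
    bigscience/bloomz-1b7 word embeddings; only the GSS blocks and the final
    LayerNorm are trained. Weights are resumed from "model3.pt" and periodic
    checkpoints are written to "model4.pt".
    """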
|
    # Hugging Face Accelerate handles device placement, autocast and W&B
    # logging; gradients are accumulated over 8192 micro-batches per step.
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=8192,
    )
    accelerator.init_trackers("gated-state-space")
|
emb_fn = "emb.pt" |
|
model_name = "bigscience/bloomz-1b7" |
|
if not os.path.isfile(emb_fn): |
|
bloom = BloomForCausalLM.from_pretrained(model_name) |
|
wte = bloom.transformer.word_embeddings.state_dict() |
|
torch.save(wte, emb_fn) |
|
else: |
|
wte = torch.load(emb_fn) |
|
|
|
    # Embedding width and vocabulary size must match bigscience/bloomz-1b7
    # so the pretrained embeddings can be reused directly.
    f_emb = 2048
    n_vocab = 250880
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=n_vocab,
            dim=f_emb,
            depth=24,
        ),
    )
|
    # Reuse the frozen BLOOM embeddings as the input embedding...
    model.net.token_emb.requires_grad_(False)
    model.net.token_emb.load_state_dict(wte)

    # ...and as the (also frozen) output projection; the Linear weight has the
    # same [n_vocab, f_emb] shape as the embedding matrix, so the state dict
    # loads directly. Only the LayerNorm in front of it remains trainable.
    to_logits = nn.Linear(f_emb, n_vocab, bias=False)
    to_logits.requires_grad_(False)
    to_logits.load_state_dict(wte)

    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        to_logits,
    )

    # Resume from the previous run and move the model onto the accelerator's device.
    model.load_state_dict(torch.load("model3.pt"))
    model = model.to(accelerator.device)
|
    # Watch gradients and parameters from the main process only.
    if accelerator.is_main_process:
        wandb.watch(model)
|
    # AdamW with cosine annealing warm restarts (first cycle: 1000 steps,
    # doubling in length after each restart).
    optim = AdamW(model.parameters(), lr=1e-4)
    sch = CosineAnnealingWarmRestarts(
        optim,
        T_0=1000,
        T_mult=2,
        eta_min=1e-7,
    )
|
    # Stream The Pile, tokenized with the BLOOM tokenizer into fixed-length
    # sequences of 2048 tokens.
    bs = 1
    kk = 2048
    tok: BloomTokenizerFast = BloomTokenizerFast.from_pretrained(model_name)
    dsx = ThePileTokenized(
        ThePile("train"),
        tokenizer=tok,
        max_length=kk,
        repeat_factor=4 / 3,
    )
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=12,
    )
|
    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    # Let Accelerate wrap the model, optimizer, dataloader and scheduler
    # for the current device / distributed setup.
    model = accelerator.prepare(model)
    optim, dlx, sch = accelerator.prepare(optim, dlx, sch)
|
    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        # Forward/backward under autocast; Accelerate only applies the
        # optimizer step once per accumulation window.
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            optim.zero_grad()
            if not accelerator.optimizer_step_was_skipped:
                sch.step()

        # Every 1000 batches: sample from the model, log the generated text
        # to W&B, and write a checkpoint.
        if i % 1000 == 0:
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 1, 512
            # Start generation from token id 2 (</s> in the BLOOM vocabulary).
            init = torch.tensor([[2]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [tok.decode(p) for p in prd]
            try:
                accelerator.log(
                    dict(
                        text=wandb.Html(
                            "<hr>".join(p.replace("\n", "<br>") for p in prd)
                        )
                    ),
                    step=i,
                )
            except Exception as ex:
                accelerator.print("Failed to log to W&B...", ex)

            sd = unwrapped_model.state_dict()
            accelerator.save(sd, "model4.pt")

        # Every 10 batches: log loss and learning rate.
        if i % 10 == 0:
            accelerator.log(
                dict(
                    loss=los.item(),
                    lr=optim.param_groups[0]["lr"],
                ),
                step=i,
            )
            prog.set_postfix(loss=los.item())
|
|
if __name__ == "__main__": |
|
main() |
|
|