# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
#
# Fine-tune a Gated State Spaces LM on The Pile, reusing the frozen
# BLOOMZ-1b7 word embeddings as both the input embedding and the
# output projection.
import os

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader

import wandb
from tqdm import tqdm
from transformers import BloomForCausalLM, BloomTokenizerFast
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# from c4x import C4X
from pile_hf import ThePile, ThePileTokenized
from accelerate import Accelerator


def main():
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=8192,
    )
    accelerator.init_trackers("gated-state-space")

    # Cache the BLOOMZ word-embedding weights locally so the full BLOOM
    # checkpoint only has to be downloaded once.
    emb_fn = "emb.pt"
    model_name = "bigscience/bloomz-1b7"
    if not os.path.isfile(emb_fn):
        bloom = BloomForCausalLM.from_pretrained(model_name)
        wte = bloom.transformer.word_embeddings.state_dict()
        torch.save(wte, emb_fn)
    else:
        wte = torch.load(emb_fn)

    # Model dimensions match the BLOOMZ-1b7 embedding table
    # (2048-dim embeddings, 250880-token vocabulary).
    f_emb = 2048
    n_vocab = 250880
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=n_vocab,
            dim=f_emb,
            depth=24,
        ),
    )

    # Freeze the token embedding and initialise it from BLOOMZ.
    model.net.token_emb.requires_grad_(False)
    model.net.token_emb.load_state_dict(wte)

    # Reuse the same frozen embedding matrix as the output projection.
    to_logits = nn.Linear(f_emb, n_vocab, bias=False)
    to_logits.requires_grad_(False)
    to_logits.load_state_dict(wte)

    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        to_logits,
    )

    # Resume from the previous checkpoint.
    model.load_state_dict(torch.load("model3.pt"))
    model = model.to(accelerator.device)

    if accelerator.is_main_process:
        wandb.watch(model)

    optim = AdamW(model.parameters(), lr=1e-4)
    sch = CosineAnnealingWarmRestarts(
        optim,
        T_0=1000,
        T_mult=2,
        eta_min=1e-7,
    )

    # The Pile, tokenized with the BLOOM tokenizer into fixed-length 2048-token chunks.
    bs = 1
    kk = 2048
    tok: BloomTokenizerFast = BloomTokenizerFast.from_pretrained(model_name)
    dsx = ThePileTokenized(
        ThePile("train"),
        tokenizer=tok,
        max_length=kk,
        repeat_factor=4 / 3,
    )
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=12,
    )

    # Let Accelerate wrap the model, optimizer, dataloader, and scheduler,
    # then build the progress bar over the *prepared* dataloader so the
    # training loop iterates the wrapped version.
    model = accelerator.prepare(model)
    optim, dlx, sch = accelerator.prepare(optim, dlx, sch)
    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)

        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            optim.zero_grad()
            if not accelerator.optimizer_step_was_skipped:
                sch.step()

        # Every 1000 steps: sample from the model, log the text, and checkpoint.
        if i % 1000 == 0:
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 1, 512
            init = torch.tensor([[2]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [tok.decode(p) for p in prd]
            try:
                accelerator.log(
                    dict(
                        text=wandb.Html(
                            "<br>".join(p.replace("\n", "<br>") for p in prd)
                        )
                    ),
                    step=i,
                )
            except Exception as ex:
                accelerator.print("Failed to log to W&B...", ex)
            sd = unwrapped_model.state_dict()
            # sd.pop('net.to_logits.weight')
            accelerator.save(sd, "model4.pt")

        # Every 10 steps: log loss and learning rate.
        if i % 10 == 0:
            accelerator.log(
                dict(
                    loss=los.item(),
                    lr=optim.param_groups[0]["lr"],
                ),
                step=i,
            )
            prog.set_postfix(loss=los.item())


if __name__ == "__main__":
    main()