# Author: naxalpha
# Commit: c9ebb32 ("working")
# Size: 3.07 kB
# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
import html

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
import wandb
from tqdm import tqdm
from transformers import GPT2LMHeadModel
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from accelerate import Accelerator

from c4x import C4X
def _sample_and_save(accelerator, model, dsx, step):
    """Generate a text sample, log it to W&B as HTML, and checkpoint the weights.

    Args:
        accelerator: the run's ``Accelerator``; used for unwrapping, logging,
            printing and saving.
        model: the (possibly accelerate-wrapped) ``AutoregressiveWrapper``.
        dsx: the dataset; its ``decode`` turns generated token ids into text.
        step: training step used as the W&B ``step``.

    Runs on every process. Per the Accelerate docs, ``accelerator.save``
    writes once per machine and trackers log from the main process.
    """
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    b, n = 1, 2048  # number of samples, tokens to generate per sample
    # 50256 is presumably GPT-2's <|endoftext|> id (vocab size 50257 above),
    # used as a one-token prompt.
    init = torch.tensor([[50256]] * b).to(accelerator.device)
    prd = unwrapped_model.generate(init, n)
    prd = [dsx.decode(p) for p in prd]
    # Escape first so '<', '>' and '&' in the generated text are shown
    # literally instead of being interpreted as markup by wandb.Html.
    html_text = '<hr>'.join(
        html.escape(p).replace('\n', '<br>') for p in prd
    )
    try:
        accelerator.log(dict(text=wandb.Html(html_text)), step=step)
    except Exception as ex:
        # Best effort: a W&B logging failure must not abort training.
        accelerator.print('Failed to log to W&B...', ex)
    accelerator.save(unwrapped_model.state_dict(), 'model2.pt')


def main():
    """Fine-tune a Gated State Space LM on C4X with Hugging Face Accelerate.

    Builds ``AutoregressiveWrapper(GatedStateSpacesLM(...))`` with a
    ``LayerNorm`` in front of the output projection, then restores its weights
    from ``model.pt``. It trains with AdamW (lr 5e-6), gradient accumulation
    over 4 micro-batches and gradient-norm clipping at 1.0, under
    ``accelerator.autocast()``. Every 1000 steps it samples text, logs it to
    W&B and saves the weights to ``model2.pt``. Every 10 steps it logs the
    loss.
    """
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=4,
    )
    accelerator.init_trackers("gated-state-space")

    f_emb = 1600  # model width; also the size of the LayerNorm below
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )
    # model.net.token_emb.weight.requires_grad_(False)
    # model.net.to_logits.weight.requires_grad_(False)
    # Normalize features before the output projection. This is done before
    # load_state_dict, presumably because model.pt was saved with this head
    # (as model2.pt is in _sample_and_save) -- confirm.
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )
    model = model.to(accelerator.device)
    if accelerator.is_main_process:
        wandb.watch(model)
    # map_location='cpu' stops every process from deserializing onto the
    # device the checkpoint was saved from (e.g. cuda:0). load_state_dict
    # then copies the values into this process's on-device parameters.
    model.load_state_dict(torch.load('model.pt', map_location='cpu'))
    optim = AdamW(model.parameters(), 5e-6)

    bs = 1
    kk = 2048  # sequence length
    # kk + 1 tokens per sample; presumably the wrapper shifts by one token
    # for next-token targets -- TODO confirm against C4X.
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=4,
    )
    model = accelerator.prepare(model)
    optim, dlx = accelerator.prepare(optim, dlx)
    # Build the progress bar only after prepare(), so training iterates the
    # prepared (per-process) dataloader rather than the raw one.
    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            # Clip only on steps where accumulated gradients are synchronized.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(
                    model.parameters(),
                    1.0,
                )
            optim.step()
            optim.zero_grad()
        if i % 1000 == 0:
            _sample_and_save(accelerator, model, dsx, step=i)
        if i % 10 == 0:
            accelerator.log(dict(
                loss=los.item(),
            ), step=i)
        prog.set_postfix(loss=los.item())

    # Finish the trackers (closes the W&B run) once training ends.
    accelerator.end_training()
# Script entry point: run training only when executed directly.
if __name__ == '__main__':
    main()