# MeMDLM / diffusion.py
# (Residue of the Hugging Face file viewer: uploaded by sgoel30, "Upload 12 files",
#  commit d061944, raw | history | blame, 52 kB. Kept as comments so the module parses.)
import itertools
import math
import os
import sys
import typing
from dataclasses import dataclass
import hydra.utils
import lightning as L
import numpy as np
import torch.nn as nn
import torch
# import dit
import ema
import time
import gc
import pl_data_loader as dataloader
import torch.nn.functional as F
import torchmetrics
import transformers
from torch import Tensor
from torch.optim.lr_scheduler import _LRScheduler
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer
import utils
import noise_schedule
# Natural logarithm of 2. Dividing a quantity measured in nats by LOG2
# converts it to bits (BPD.compute uses it this way).
LOG2 = math.log(2)
class CosineWarmup(_LRScheduler):
    """Linear warmup followed by cosine decay to ``eta_ratio * base_lr``.

    For ``step < warmup_steps`` the learning rate grows linearly from 0 to
    ``base_lr``. It then follows half a cosine from ``base_lr`` down to
    ``eta_ratio * base_lr``, reaching that floor at ``total_steps``, and stays
    at the floor for every later step.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of linear-warmup steps.
        total_steps: Step at which the cosine decay reaches its floor.
        eta_ratio: Ratio of the minimum to the maximum learning rate.
        last_epoch: Index of the last step (``-1`` starts fresh).
    """

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_ratio = eta_ratio  # The ratio of minimum to maximum learning rate
        super(CosineWarmup, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]
        # max(1, ...) avoids a ZeroDivisionError when total_steps <= warmup_steps.
        # Without the clamp, cos would be periodic past total_steps and the
        # LR would climb back up.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
        return [decayed_lr * base_lr for base_lr in self.base_lrs]
def _sample_categorical(categorical_probs):
    """Draw one index along the last dimension from (unnormalized) probabilities.

    Uses the "exponential race" form of the Gumbel-max trick. The
    probabilities are divided by i.i.d. noise ``-log(U)`` (an Exp(1) draw),
    with small 1e-10 offsets to avoid ``log(0)`` and division by zero. The
    argmax of the result is the sample.
    """
    uniform = torch.rand_like(categorical_probs)
    exp_noise = 1e-10 - (uniform + 1e-10).log()
    race = categorical_probs / exp_noise
    return race.argmax(dim=-1)
def _unsqueeze(x, reference):
    """Return a view of ``x`` with trailing size-1 dims up to ``reference``'s rank.

    No dims are added if ``x`` already has at least as many dims as
    ``reference``.
    """
    missing_dims = len(reference.shape) - len(x.shape)
    trailing_ones = (1,) * missing_dims
    return x.view(*x.shape, *trailing_ones)
@dataclass
class Loss:
    """Bundle of values produced by a loss computation.

    Attributes:
        loss: Value returned as the step loss by ``_compute_loss``
            (``.item()`` is called on it, so a one-element tensor).
        nlls: Passed as the value to the NLL/BPD/Perplexity metrics' ``update``.
        token_mask: Passed as the weight to those metrics' ``update``.
    """
    loss: torch.FloatTensor
    nlls: torch.FloatTensor
    token_mask: torch.FloatTensor
class NLL(torchmetrics.aggregation.MeanMetric):
    """Weighted mean of the values passed to ``update``.

    All state handling is inherited from torchmetrics' ``MeanMetric``. In this
    file the values are negative log-likelihoods; BPD and Perplexity derive
    from this class.
    """
class BPD(NLL):
    """Bits-per-dimension computed from the accumulated mean NLL."""

    def compute(self) -> Tensor:
        """Computes the bits per dimension.

        Returns the weighted-mean NLL divided by ln(2), i.e. nats converted
        to bits (assuming the accumulated NLLs are in nats).

        Returns:
            bpd
        """
        mean_nll = self.mean_value / self.weight
        return mean_nll / LOG2
class Perplexity(NLL):
    """Perplexity: the exponential of the accumulated weighted-mean NLL."""

    def compute(self) -> Tensor:
        """Computes the Perplexity as ``exp(mean_value / weight)``.

        Returns:
            Perplexity
        """
        mean_nll = self.mean_value / self.weight
        return torch.exp(mean_nll)
class WrapVanillaESM(nn.Module):
    """Wrapper around a Hugging Face ESM-style masked-LM checkpoint.

    Calling the wrapper forwards directly to the underlying Hugging Face model
    and returns its output object, not raw logits. ``forward`` returns
    ``.logits`` for callers that want the tensor.
    """

    def __init__(self, bert_model_path):
        super(WrapVanillaESM, self).__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)

    def __call__(self, *args, **kwargs):
        # Deliberately bypasses nn.Module.__call__/forward. Diffusion.forward
        # passes (input_ids, attention_mask) positionally and reads `.logits`
        # from the returned Hugging Face output object.
        return self.model(*args, **kwargs)

    def freeze_model(self):
        """Disable gradients for every model parameter (mirrors WrapMembraneESM)."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_attn_layers(self, num_layers=5):
        """Enable gradients for the attention key/query/value projections.

        Args:
            num_layers: How many of the final encoder layers to unfreeze
                (default 5, the previously hard-coded value).
        """
        encoder_layers = self.model.esm.encoder.layer
        first_trainable = len(encoder_layers) - num_layers
        for i, layer in enumerate(encoder_layers):
            if i < first_trainable:
                continue
            attention = layer.attention.self
            for projection in (attention.key, attention.query, attention.value):
                # parameters() is recursive, so this covers every submodule.
                for param in projection.parameters():
                    param.requires_grad = True

    def unfreeze_all_layers(self):
        """Enable gradients for every model parameter."""
        for param in self.model.parameters():
            param.requires_grad = True

    def forward(self, inputs, sigma, attention_mask):
        """Return the model's logits; ``sigma`` is accepted but unused."""
        logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
        return logits

    def save_model(self, save_dir):
        """Write model and tokenizer to ``save_dir`` via ``save_pretrained``."""
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    def load_model(self, load_dir):
        """Reload model and tokenizer from ``load_dir`` (as written by ``save_model``).

        Uses ``AutoModelForMaskedLM`` to match ``__init__``. The previous
        ``AutoModel`` loaded the bare encoder, which has no LM head (no
        ``.logits``) and no ``.esm`` attribute, so ``forward`` and
        ``unfreeze_attn_layers`` broke after a reload.
        """
        self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
class WrapMembraneESM(nn.Module):
    """Wrapper around a Hugging Face ESM-style masked-LM checkpoint for fine-tuning.

    Calling the wrapper forwards directly to the underlying Hugging Face model
    and returns its output object, not raw logits. ``forward`` returns
    ``.logits`` for callers that want the tensor.
    """

    def __init__(self, bert_model_path):
        super(WrapMembraneESM, self).__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)

    def __call__(self, *args, **kwargs):
        # Deliberately bypasses nn.Module.__call__/forward. Diffusion.forward
        # passes (input_ids, attention_mask) positionally and reads `.logits`
        # from the returned Hugging Face output object.
        return self.model(*args, **kwargs)

    def freeze_model(self):
        """Disable gradients for every model parameter."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_all_layers(self):
        """Enable gradients for every model parameter."""
        for param in self.model.parameters():
            param.requires_grad = True

    def unfreeze_attn_layers(self, num_layers=11):
        """Enable gradients for the attention key/query/value projections.

        Args:
            num_layers: How many of the final encoder layers to unfreeze
                (default 11, the previously hard-coded value).
        """
        encoder_layers = self.model.esm.encoder.layer
        first_trainable = len(encoder_layers) - num_layers
        for i, layer in enumerate(encoder_layers):
            if i < first_trainable:
                continue
            attention = layer.attention.self
            for projection in (attention.key, attention.query, attention.value):
                # parameters() is recursive, so this covers every submodule.
                for param in projection.parameters():
                    param.requires_grad = True

    def forward(self, inputs, sigma, attention_mask):
        """Return the model's logits; ``sigma`` is accepted but unused."""
        logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
        return logits

    def save_model(self, save_dir):
        """Write model and tokenizer to ``save_dir`` via ``save_pretrained``."""
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    def load_model(self, load_dir):
        """Reload model and tokenizer from ``load_dir`` (as written by ``save_model``).

        Uses ``AutoModelForMaskedLM`` to match ``__init__``. The previous
        ``AutoModel`` loaded the bare encoder, which has no LM head (no
        ``.logits``) and no ``.esm`` attribute, so ``forward`` and
        ``unfreeze_attn_layers`` broke after a reload.
        """
        self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
class Diffusion(L.LightningModule):
def __init__(
    self,
    config,
    tokenizer: transformers.PreTrainedTokenizer):
    """Build the masked-diffusion LightningModule.

    Args:
        config: Experiment config. It is read for ``sampling``, ``eval``,
            ``training``, ``optim``, ``checkpointing``, ``parameterization``,
            ``backbone``, ``T``, ``subs_masking`` and ``time_conditioning``.
        tokenizer: Tokenizer of the diffusion vocabulary. If it has no mask
            token, id ``vocab_size`` is reserved for masking and the
            vocabulary grows by one.

    Raises:
        ValueError: If ``config.backbone`` is not a supported backbone.
        AssertionError: From ``_validate_configuration`` on incompatible
            parameterization / training options.
    """
    super().__init__()
    self.save_hyperparameters()
    self.config = config
    self.tokenizer = tokenizer
    self.vocab_size = self.tokenizer.vocab_size
    self.sampler = self.config.sampling.predictor
    self.gen_ppl_eval_model_name_or_path = (
        self.config.eval.gen_ppl_eval_model_name_or_path)
    self.antithetic_sampling = self.config.training.antithetic_sampling
    self.importance_sampling = self.config.training.importance_sampling
    self.change_of_variables = self.config.training.change_of_variables
    if (not hasattr(self.tokenizer, 'mask_token')
            or self.tokenizer.mask_token is None):
        # No mask token in the vocabulary: reserve one extra id for it.
        self.mask_index = self.vocab_size
        self.vocab_size += 1
    else:
        self.mask_index = self.tokenizer.mask_token_id
    self.parameterization = self.config.parameterization

    # Only the two ESM backbones are wired up. The DiT / DiMamba / AR / HF-DiT
    # options from upstream MDLM were commented out of this file.
    if self.config.backbone == "vanilla_esm_pretrain":
        self.backbone = WrapVanillaESM(
            bert_model_path=self.config.training.esm_model_path)
        self.backbone.unfreeze_all_layers()
        self.backbone = torch.compile(self.backbone)
    elif self.config.backbone == 'membrane_esm_finetune':
        self.backbone = WrapMembraneESM(
            bert_model_path=self.config.checkpointing.pretrained_esm_mdlm_automodel_path)
        self.backbone.unfreeze_all_layers()
    else:
        # Previously `self.backbone` was silently left unset and the module
        # failed later (EMA setup, configure_optimizers) with an AttributeError.
        raise ValueError(f'Unknown backbone: {self.config.backbone}')

    self.T = self.config.T
    self.subs_masking = self.config.subs_masking
    self.softplus = torch.nn.Softplus()
    # metrics are automatically reset at end of epoch
    metrics = torchmetrics.MetricCollection({
        'nll': NLL(),
        'bpd': BPD(),
        'ppl': Perplexity(),
    })
    metrics.set_dtype(torch.float64)
    self.train_metrics = metrics.clone(prefix='train/')
    self.valid_metrics = metrics.clone(prefix='val/')
    self.test_metrics = metrics.clone(prefix='test/')
    # generative perplexity
    self.gen_ppl_metric = Perplexity()
    self.eval_model_tokenizer = transformers.AutoTokenizer.from_pretrained(
        self.gen_ppl_eval_model_name_or_path)
    if self.eval_model_tokenizer.pad_token is None:
        # Causal-LM tokenizers often ship without a pad token; reuse EOS.
        self.eval_model_tokenizer.pad_token = self.eval_model_tokenizer.eos_token
        self.eval_model_tokenizer.pad_token_id = self.eval_model_tokenizer.eos_token_id
    self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
    if self.config.training.ema > 0:
        self.ema = ema.ExponentialMovingAverage(
            itertools.chain(self.backbone.parameters(),
                            self.noise.parameters()),
            decay=self.config.training.ema)
    else:
        self.ema = None
    self.lr = self.config.optim.lr
    self.sampling_eps = self.config.training.sampling_eps
    self.time_conditioning = self.config.time_conditioning
    self.neg_infinity = -1000000.0
    self.fast_forward_epochs = None
    self.fast_forward_batches = None
    self._validate_configuration()
def _validate_configuration(self):
    """Check that the parameterization and training options are compatible.

    Raises:
        AssertionError: On the first incompatible combination found.
    """
    param = self.parameterization
    # Change-of-variables and importance sampling are mutually exclusive.
    assert not (self.change_of_variables and self.importance_sampling)
    if param == 'sedd':
        # SEDD supports neither option.
        assert not (self.importance_sampling or self.change_of_variables)
    if param == 'd3pm':
        # D3PM is discrete-time and needs T > 0.
        assert self.T > 0
    if self.T > 0:
        assert param in {'d3pm', 'subs'}
    if self.subs_masking:
        assert param == 'd3pm'
def on_load_checkpoint(self, checkpoint):
    """Restore EMA state and remember how far the fit loop had progressed.

    The completed epoch and batch counters are stored in
    ``fast_forward_epochs`` / ``fast_forward_batches``. ``on_train_start``
    uses them to fast-forward the sampler when training is distributed.
    """
    if self.ema:
        self.ema.load_state_dict(checkpoint['ema'])
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
    fit_loop = checkpoint['loops']['fit_loop']
    epoch_progress = fit_loop['epoch_progress']['current']
    batch_progress = fit_loop['epoch_loop.batch_progress']['current']
    self.fast_forward_epochs = epoch_progress['completed']
    self.fast_forward_batches = batch_progress['completed']
def on_save_checkpoint(self, checkpoint):
    """Augment the Lightning checkpoint before it is written.

    Steps:
      1. Store the EMA state, if any.
      2. Rewrite the fit-loop batch counters from the optimizer step counters.
      3. Record the train sampler's ``random_state`` (``None`` if the sampler
         has no ``state_dict``).
      4. Save the Hugging Face backbone to
         ``config.checkpointing.fine_tuned_esm_mdlm_ckpt_path``.
    """
    if self.ema:
        checkpoint['ema'] = self.ema.state_dict()
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
    # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
    # behind, so we're using the optimizer's progress.
    fit_loop = checkpoint['loops']['fit_loop']
    optimizer_steps = fit_loop[
        'epoch_loop.automatic_optimization.optim_progress']['optimizer']['step']
    batch_progress = fit_loop['epoch_loop.batch_progress']
    batch_progress['total']['completed'] = (
        optimizer_steps['total']['completed']
        * self.trainer.accumulate_grad_batches)
    batch_progress['current']['completed'] = (
        optimizer_steps['current']['completed']
        * self.trainer.accumulate_grad_batches)
    # _batches_that_stepped tracks the number of global steps, not the number
    # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
    fit_loop['epoch_loop.state_dict']['_batches_that_stepped'] = (
        optimizer_steps['total']['completed'])
    if 'sampler' not in checkpoint:
        checkpoint['sampler'] = {}
    random_state = None
    if hasattr(self.trainer.train_dataloader.sampler, 'state_dict'):
        sampler_state = self.trainer.train_dataloader.sampler.state_dict()
        random_state = sampler_state.get('random_state', None)
    checkpoint['sampler']['random_state'] = random_state
    self.backbone.save_model(self.config.checkpointing.fine_tuned_esm_mdlm_ckpt_path)
def on_train_start(self):
    """Replace the fit loop's training dataloaders with fault-tolerant ones.

    Each loader is rebuilt with the same dataset and a
    ``FaultTolerantDistributedSampler`` (distributed runs) or a
    ``RandomFaultTolerantSampler`` (otherwise). It uses the config's loader
    settings and a ``collate_fn`` bound to ``self.tokenizer``. In distributed
    runs resumed from a checkpoint, the sampler is fast-forwarded to the saved
    epoch/batch position.
    """
    # Import once instead of on every loop iteration (imports are cached anyway).
    from functools import partial
    from pl_data_loader import collate_fn
    torch.cuda.empty_cache()
    if self.ema:
        self.ema.move_shadow_params_to_device(self.device)
    # Adapted from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
    connector = self.trainer._accelerator_connector
    distributed = connector.use_distributed_sampler and connector.is_distributed
    sampler_cls = (dataloader.FaultTolerantDistributedSampler
                   if distributed else dataloader.RandomFaultTolerantSampler)
    resume_sampler = (distributed
                      and self.fast_forward_epochs is not None
                      and self.fast_forward_batches is not None)
    collate_partial = partial(collate_fn, tokenizer=self.tokenizer)
    updated_dls = []
    for dl in self.trainer.fit_loop._combined_loader.flattened:
        sampler_kwargs = {}
        if hasattr(dl.sampler, 'shuffle'):
            sampler_kwargs['shuffle'] = dl.sampler.shuffle
        dl_sampler = sampler_cls(dl.dataset, **sampler_kwargs)
        if resume_sampler:
            dl_sampler.load_state_dict({
                'epoch': self.fast_forward_epochs,
                'counter': (self.fast_forward_batches
                            * self.config.loader.batch_size)})
        torch.cuda.empty_cache()
        updated_dls.append(
            torch.utils.data.DataLoader(
                dl.dataset,
                batch_size=self.config.loader.batch_size,
                num_workers=self.config.loader.num_workers,
                pin_memory=self.config.loader.pin_memory,
                sampler=dl_sampler,
                shuffle=False,
                persistent_workers=False,
                collate_fn=collate_partial))
    self.trainer.fit_loop._combined_loader.flattened = updated_dls
def optimizer_step(self, *args, **kwargs):
    """Run Lightning's optimizer step, release memory, then update the EMA.

    All arguments are forwarded unchanged to the parent implementation.
    """
    super().optimizer_step(*args, **kwargs)
    # Free Python garbage and cached CUDA blocks after every step.
    gc.collect()
    torch.cuda.empty_cache()
    if not self.ema:
        return
    tracked_params = itertools.chain(self.backbone.parameters(),
                                     self.noise.parameters())
    self.ema.update(tracked_params)
def _subs_parameterization(self, logits, xt):
    """SUBS parameterization: backbone output -> per-position log-probabilities.

    Args:
        logits: Backbone output object. Its ``.logits`` tensor is modified in
            place: ``neg_infinity`` is added to the mask-token column.
        xt: Noisy token ids, indexed like the first two dims of
            ``logits.logits``.

    Returns:
        Log-probabilities over the vocabulary. Positions whose token is not
        the mask token become a point mass on that token (0 at its id,
        ``neg_infinity`` everywhere else).
    """
    scores = logits.logits
    # log prob at the mask index = - infinity
    scores[:, :, self.mask_index] += self.neg_infinity
    # Normalize so that exp(log_probs) is a distribution over the vocabulary.
    log_probs = scores - torch.logsumexp(scores, dim=-1, keepdim=True)
    # Unmasked tokens are carried over unchanged.
    keep = xt != self.mask_index
    log_probs[keep] = self.neg_infinity
    log_probs[keep, xt[keep]] = 0
    return log_probs
def _d3pm_parameterization(self, logits):
    """Normalize logits into log-probabilities over the last dimension.

    When ``subs_masking`` is set, ``neg_infinity`` is first added in place
    to the mask-token column of ``logits``.
    """
    if self.subs_masking:
        logits[:, :, self.mask_index] += self.neg_infinity
    log_normalizer = torch.logsumexp(logits, dim=-1, keepdim=True)
    return logits - log_normalizer
def _sedd_parameterization(self, logits, xt, sigma):
    """SEDD score parameterization.

    Subtracts ``log(exp(sigma) - 1)`` (one value per batch row) and
    ``log(V - 1)`` (``V`` = size of the last dim) from ``logits``. It then
    writes a 0 log score at the index of the current token ``xt``.
    """
    # expm1 is the numerically safer form for small sigma.
    esigm1 = torch.where(sigma < 0.5, torch.expm1(sigma), sigma.exp() - 1)
    esigm1_log = esigm1.log().to(logits.dtype)
    # logits shape
    # (batch_size, diffusion_model_input_length, vocab_size)
    shifted = (logits
               - esigm1_log[:, None, None]
               - np.log(logits.shape[-1] - 1))
    zero_scores = torch.zeros_like(shifted[..., :1])
    return torch.scatter(shifted, -1, xt[..., None], zero_scores)
def _process_sigma(self, sigma):
    """Normalize the noise level before it reaches the backbone.

    Behavior:
      - ``None`` is returned unchanged; this is only allowed for the ``'ar'``
        parameterization.
      - Inputs with more than one dim have their last dim squeezed.
      - Values are zeroed when time conditioning is disabled.
      - The result must be 1-D.
    """
    if sigma is None:
        assert self.parameterization == 'ar'
        return None
    squeezed = sigma.squeeze(-1) if sigma.ndim > 1 else sigma
    if self.time_conditioning:
        result = squeezed
    else:
        result = torch.zeros_like(squeezed)
    assert result.ndim == 1, result.shape
    return result
def forward(self, x, sigma, attention_mask, print_logits=False):
    """Returns log score.

    Runs the backbone on ``(x, attention_mask)`` inside a float32 CUDA
    autocast region. For the ``'subs'`` parameterization the output goes
    through ``_subs_parameterization``; otherwise it is returned unchanged.

    Args:
        x: Input token ids.
        sigma: Noise level. It is checked/normalized by ``_process_sigma``
            but not passed to the backbone.
        attention_mask: Forwarded to the backbone.
        print_logits: Unused; kept for call-site compatibility.
    """
    # Called for its shape/conditioning checks only; the result is unused.
    self._process_sigma(sigma)
    with torch.amp.autocast("cuda", dtype=torch.float32):
        backbone_out = self.backbone(x, attention_mask)
    if self.parameterization != 'subs':
        return backbone_out
    return self._subs_parameterization(logits=backbone_out, xt=x)
def _d3pm_loss(self, model_output, xt, x0, t, attention_mask):
    """Discrete-time variational loss, evaluated at masked positions only.

    Args:
        model_output: Log-probabilities indexed ``[batch, position, vocab]``.
        xt: Noisy tokens; only positions equal to ``mask_index`` contribute.
        x0: Clean tokens, used to gather ``log p(x0)``.
        t: Per-example time. A tensor is reshaped to ``(batch, 1)`` and
            clamped to ``[0, 1 - 1e-4]``.
        attention_mask: Unused.

    Returns:
        ``T`` times the per-position loss (zero at unmasked positions).
    """
    dt = 1 / self.T
    if torch.is_tensor(t):
        t = t[:, None]
        assert t.ndim == 2
        t = t.clamp(0., 1. - 1e-4)
    zeros = torch.zeros_like(xt)
    alpha_t = 1 - t + zeros
    alpha_s = 1 - (t - dt) + zeros
    log_p_x0 = torch.gather(model_output, -1, x0[:, :, None]).squeeze(-1)
    p_mask = model_output[:, :, self.mask_index].exp()
    step_ratio = dt / t
    log_nr = torch.log(alpha_t * p_mask / t + 1)
    log_dr_s = torch.log(alpha_s * p_mask / (t - dt) + 1)
    per_token = (step_ratio * (log_nr - log_p_x0)
                 + (1 - step_ratio) * (log_nr - log_dr_s))
    per_token = per_token * (xt == self.mask_index)
    return self.T * per_token
def _compute_loss(self, batch, prefix):
    """Compute the loss for ``batch`` and update/log the ``prefix`` metrics.

    Args:
        batch: Mapping with ``'input_ids'`` and optional ``'attention_mask'``
            and ``'mask'`` entries (``None`` is passed when absent).
        prefix: ``'train'``, ``'val'`` or ``'test'``; selects the metric
            collection.

    Returns:
        ``losses.loss`` from ``self._loss``.

    Raises:
        ValueError: If ``prefix`` is not one of the three names above.
    """
    attention_mask = batch['attention_mask'] if 'attention_mask' in batch else None
    mask = batch['mask'] if 'mask' in batch else None
    losses = self._loss(batch['input_ids'], attention_mask, mask)
    prefix_to_attr = (('train', 'train_metrics'),
                      ('val', 'valid_metrics'),
                      ('test', 'test_metrics'))
    attr = next((name for key, name in prefix_to_attr if prefix == key), None)
    if attr is None:
        raise ValueError(f'Invalid prefix: {prefix}')
    metrics = getattr(self, attr)
    metrics.update(losses.nlls, losses.token_mask)
    self.log_dict(metrics,
                  on_step=False,
                  on_epoch=True,
                  sync_dist=True)
    return losses.loss
def on_train_epoch_start(self):
    """Switch the backbone and the noise schedule to training mode."""
    for module in (self.backbone, self.noise):
        module.train()
def training_step(self, batch, batch_idx):
    """Compute the training loss and log it together with token throughput.

    Throughput is ``batch['input_ids'].numel()`` divided by the wall time of
    the loss computation plus the loss logging. Backward is run by Lightning
    after this method returns, so it is not included.

    Returns:
        The loss from ``_compute_loss(batch, prefix='train')``.
    """
    # perf_counter is monotonic and high-resolution. time.time() can jump
    # with clock adjustments and is coarse enough on some platforms to
    # report an elapsed time of 0, which raised ZeroDivisionError below.
    start_time = time.perf_counter()
    loss = self._compute_loss(batch, prefix='train')
    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    elapsed_time = time.perf_counter() - start_time
    total_tokens = batch['input_ids'].numel()
    if elapsed_time > 0:
        throughput = total_tokens / elapsed_time
        self.log(name='trainer/throughput',
                 value=throughput,
                 on_step=True,
                 on_epoch=False,
                 sync_dist=True)
    return loss
def on_validation_epoch_start(self):
    """Prepare for validation.

    Steps: free memory, swap in the EMA weights (if any), switch to eval mode,
    and check that the validation metrics start empty.
    """
    gc.collect()
    torch.cuda.empty_cache()
    model_params = lambda: itertools.chain(self.backbone.parameters(),
                                           self.noise.parameters())
    if self.ema:
        # Stash the live weights (restored in on_validation_epoch_end),
        # then load the EMA weights.
        self.ema.store(model_params())
        self.ema.copy_to(model_params())
    for module in (self.backbone, self.noise):
        module.eval()
    nll_metric = self.valid_metrics.nll
    assert nll_metric.mean_value == 0
    assert nll_metric.weight == 0
def validation_step(self, batch, batch_idx):
    """Compute the validation loss, log it per step (progress bar), and return it."""
    loss = self._compute_loss(batch, prefix='val')
    per_step = dict(on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
    self.log(name='trainer/val_loss', value=loss.item(), **per_step)
    return loss
def on_validation_epoch_end(self):
    """Free memory and restore the live (non-EMA) weights after validation.

    The sample-generation / generative-perplexity logging that upstream runs
    here had been commented out, and that dead code was removed. test_step
    still implements the same sampling flow.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not self.ema:
        return
    self.ema.restore(itertools.chain(self.backbone.parameters(),
                                     self.noise.parameters()))
def test_step(self, batch, batch_idx):
loss = self._compute_loss(batch, prefix='test')
self.log('test/loss',
value=loss.item(),
on_step=False,
on_epoch=True,
sync_dist=True)
if self.config.eval.compute_generative_perplexity:
samples, text_samples = None, None
for _ in range(
self.config.sampling.num_sample_batches):
samples = self._sample()
# Decode the samples to be re-tokenized by eval model
text_samples = self.tokenizer.batch_decode(samples)
if self.config.eval.compute_generative_perplexity:
self.compute_generative_perplexity(text_samples)
if self.trainer.global_rank == 0 and hasattr(
self.trainer.logger, 'log_table'):
# Log the last generated samples
text_samples = text_samples[
: self.config.sampling.num_sample_log]
self.trainer.logger.log_table(
key=f'samples@global_step{self.global_step}',
columns=['Generated Samples'],
data=[[s] for s in text_samples])
if self.config.eval.compute_generative_perplexity:
self.log('test/gen_ppl',
self.gen_ppl_metric,
on_epoch=False,
on_step=True,
sync_dist=True)
def on_test_epoch_start(self):
    """Swap in EMA weights (if any), switch to eval mode, and reset test metrics."""
    model_params = lambda: itertools.chain(self.backbone.parameters(),
                                           self.noise.parameters())
    if self.ema:
        self.ema.store(model_params())
        self.ema.copy_to(model_params())
    for module in (self.backbone, self.noise):
        module.eval()
    self.test_metrics.reset()
def on_test_epoch_end(self):
    """Restore the live weights (if EMA is used) and log the aggregated test metrics."""
    if self.ema:
        live_params = itertools.chain(self.backbone.parameters(),
                                      self.noise.parameters())
        self.ema.restore(live_params)
    computed = self.test_metrics.compute()
    for name, value in computed.items():
        self.log(name, value, sync_dist=True)
def configure_optimizers(self):
    """Create AdamW over backbone + noise parameters with a per-step CosineWarmup.

    Returns:
        ``([optimizer], [scheduler_dict])`` in Lightning's format.
    """
    # (yair): Lightning currently giving this warning when using `fp16`:
    # "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
    # Not clear if this is a problem or not.
    # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
    optim_cfg = self.config.optim
    optimizer = torch.optim.AdamW(
        itertools.chain(self.backbone.parameters(), self.noise.parameters()),
        lr=optim_cfg.lr,
        betas=(optim_cfg.beta1, optim_cfg.beta2),
        eps=optim_cfg.eps,
        weight_decay=optim_cfg.weight_decay)
    self.total_steps = self.config.trainer.max_steps
    scheduler = CosineWarmup(
        optimizer,
        warmup_steps=self.config.lr_scheduler.num_warmup_steps,
        total_steps=self.total_steps)
    # NOTE(review): nothing in this module logs 'val/loss' (validation logs
    # 'trainer/val_loss' and the 'val/*' metric collection). Lightning
    # documents `monitor` as needed for ReduceLROnPlateau-style schedulers;
    # confirm it is ignored for this one.
    scheduler_dict = {
        'scheduler': scheduler,
        'interval': 'step',
        'frequency': 1,
        'monitor': 'val/loss',
        'name': 'trainer/lr',
    }
    return [optimizer], [scheduler_dict]
@torch.no_grad()
def eval_retokenize(self, text_samples, max_length):
    """Retokenizes samples for the eval model.

    Args:
        text_samples: List of sentences generated by the model.
        max_length: Maximum token length for the eval tokenizer (longer
            samples are truncated, shorter ones padded).
    Returns:
        samples: Samples re-tokenized for the eval model
        attn_mask: Attention mask for the eval model
        eval_context_size: Size of the context for the eval model
            (4096 for llama2 models, 1024 otherwise)
    """
    # The llama2 branch used to also pass `text_samples` as a keyword
    # argument on top of the positional one. That key is not a tokenizer
    # argument, so both branches now share one set of kwargs.
    tokenizer_kwargs = {
        'return_tensors': 'pt',
        'return_token_type_ids': False,
        'return_attention_mask': True,
        'truncation': True,
        'padding': True,
        'max_length': max_length,
    }
    is_llama2 = 'llama2' in self.gen_ppl_eval_model_name_or_path
    eval_context_size = 4096 if is_llama2 else 1024
    samples = self.eval_model_tokenizer(text_samples, **tokenizer_kwargs)
    attn_mask = samples['attention_mask']
    samples = samples['input_ids']
    if not is_llama2:
        # Mirrors compute_generative_perplexity, which only moves non-llama2
        # eval models to self.device.
        attn_mask = attn_mask.to(self.device)
        samples = samples.to(self.device)
    return samples, attn_mask, eval_context_size
@torch.no_grad()
def compute_masked_perplexity(self, sequences, masked):
    """Compute the pseudo-perplexity of the generated protein sequences.

    ``masked`` is tokenized once and used as the model input for every
    ground-truth sequence in ``sequences``. Cross-entropy is summed only at
    positions where that input holds the mask token (``self.mask_index``).

    Args:
        sequences: Ground-truth sequences, upper-cased before tokenizing.
            Each must tokenize to the same length as ``masked``.
        masked: The masked input string.

    Returns:
        ``exp(total_masked_nll / total_non_pad_tokens)`` as a Python float.
        The same value is also passed to ``self.gen_ppl_metric.update``.
    """
    # The model input is identical for every sequence, so tokenize it once.
    input_ids = self.tokenizer(masked, return_tensors="pt").input_ids.to(self.device)
    attention_mask = torch.ones_like(input_ids)
    # Previously hard-coded as 32 (ESM's <mask> id); use the module's mask id.
    is_masked = input_ids == self.mask_index
    ignore_index = torch.full_like(input_ids, -100)
    num_input_tokens = input_ids.ne(self.tokenizer.pad_token_id).sum().item()
    total_nll = 0
    total_tokens = 0
    for sequence in sequences:
        gt_ids = self.tokenizer(sequence.upper(), return_tensors="pt").input_ids.to(self.device)
        if self.config.mode == "sample_eval":
            outputs = self.backbone.model.forward(input_ids)
        else:
            # Previously only 'train'/'ppl_eval' were handled; any other mode
            # left `outputs` unbound.
            outputs = self.backbone.model.forward(input_ids=input_ids, attention_mask=attention_mask)
        # `.logits` rather than `outputs[-1]`: the last tuple entry is not the
        # logits when hidden states / attentions are also returned.
        logits = outputs.logits  # B, L, V
        labels = gt_ids.where(is_masked, ignore_index).view(-1)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels, reduction='sum')
        total_nll += loss.item()
        # NOTE(review): the denominator counts every non-pad input token
        # (BOS/EOS and unmasked positions included) while the NLL covers
        # masked positions only. Kept as-is to preserve reported numbers.
        total_tokens += num_input_tokens  # count in bos and eos
    pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
    # NOTE(review): gen_ppl_metric is a Perplexity (its compute() takes the exp
    # of the mean), so feeding it a perplexity yields exp(ppl); verify intent.
    self.gen_ppl_metric.update(pseudo_perplexity)
    return pseudo_perplexity.item()
@torch.no_grad()
def compute_generative_perplexity(
    self,
    text_samples: typing.List[str],
    retokenize: bool = True,
    max_length: typing.Optional[int] = None) -> None:
    """Compute the generative perplexity of the model.

    Loads ``gen_ppl_eval_model_name_or_path`` as a causal LM (e.g. GPT2).
    The samples are optionally re-tokenized with its tokenizer. Per-token NLLs
    of every context-sized chunk are then accumulated into
    ``self.gen_ppl_metric``.

    Args:
        text_samples: List of sentences generated by the model, or an
            already-tokenized id tensor when ``retokenize`` is False.
        retokenize: Re-tokenize with the eval model's tokenizer.
        max_length: Max tokens per sample when re-tokenizing (defaults to
            ``config.model.length``).
    Returns:
        None; results accumulate in ``self.gen_ppl_metric``.
    """
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    eval_model = transformers.AutoModelForCausalLM.from_pretrained(
        self.gen_ppl_eval_model_name_or_path).eval()
    if max_length is None:
        max_length = self.config.model.length
    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
        eval_model = eval_model.to(self.device)
    # Re-tokenize using eval model's tokenizer
    if retokenize:
        (samples, attn_mask,
         eval_context_size) = self.eval_retokenize(
            text_samples, max_length=max_length)
    else:
        samples = text_samples
        attn_mask = torch.ones(samples.shape).to(self.device)
        eval_context_size = samples.shape[-1]
    num_samples = samples.shape[0]
    if num_samples == 0:
        # Nothing to score; previously min(..., 0) led to a division by zero.
        return
    batch_size = min(self.config.eval.perplexity_batch_size, num_samples)
    eos_id = self.eval_model_tokenizer.eos_token_id
    # Step by batch_size so the trailing partial batch is scored too. The old
    # `num_samples // batch_size` silently dropped num_samples % batch_size samples.
    for start in range(0, num_samples, batch_size):
        stop = start + batch_size
        sample_chunks = torch.split(samples[start:stop], eval_context_size, dim=-1)
        mask_chunks = torch.split(attn_mask[start:stop], eval_context_size, dim=-1)
        for sample_chunk, attn_mask_chunk in zip(sample_chunks, mask_chunks):
            logits = eval_model(sample_chunk, attention_mask=attn_mask_chunk)[0]
            logits = logits.transpose(-1, -2)
            nlls = F.cross_entropy(logits[..., :-1],
                                   sample_chunk[..., 1:],
                                   reduction='none')
            # Weight = every non-EOS token plus the first EOS.
            first_eos = (sample_chunk == eos_id).cumsum(-1) == 1
            token_mask = sample_chunk != eos_id
            self.gen_ppl_metric.update(
                nlls, first_eos[..., 1:] + token_mask[..., 1:])
def q_xt(self, x, move_chance):
    """Computes the noisy sample xt.

    Each position is independently selected for masking with probability
    ``move_chance``. Per row, at most ``int(0.75 * n_non_pad)`` selected
    positions are actually masked: the ones with the smallest indices.

    Args:
        x: int torch.Tensor with shape (batch_size,
            diffusion_model_input_length), input.
        move_chance: float torch.Tensor with shape (batch_size, 1).

    Returns:
        ``x`` with the kept positions replaced by ``self.mask_index``.
    """
    # Pad id comes from the tokenizer; 1 (ESM's <pad>, the old hard-coded
    # value) is the fallback.
    pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
    if pad_token_id is None:
        pad_token_id = 1
    actual_seq_length = (x != pad_token_id).sum(dim=1, keepdim=True)
    max_mask_length = (actual_seq_length * 0.75).long()
    move_indices = torch.rand(*x.shape, device=x.device) < move_chance
    # 1-based rank of each selected position within its row. Keeping ranks
    # <= max_mask_length matches the old per-row Python loop exactly,
    # without iterating over the batch.
    # NOTE(review): this keeps the *leftmost* selected positions, biasing
    # the cap toward the start of each sequence. A random subset may be intended.
    selection_rank = move_indices.long().cumsum(dim=1)
    restricted_move_indices = move_indices & (selection_rank <= max_mask_length)
    xt = torch.where(restricted_move_indices, self.mask_index, x)
    return xt
def _sample_prior(self, *batch_dims):
    """Return an int64 tensor of shape ``batch_dims`` filled with ``mask_index``."""
    prior = torch.ones(*batch_dims, dtype=torch.int64)
    return prior * self.mask_index
def _ddpm_caching_update(self, x, t, dt, p_x0=None, attention_mask=None):
    """One ancestral denoising step that can reuse a cached ``p_x0``.

    Only the log-linear noise schedule is supported (asserted). Under it the
    code uses ``t`` itself as the mask probability at time ``t``.

    Args:
        x: Current tokens.
        t: Current time, shape ``(batch,)`` or ``(batch, 1)``.
        dt: Step size; the target time is ``t - dt``.
        p_x0: Cached predicted clean-token probabilities, or ``None`` to
            recompute them with ``forward``.
        attention_mask: Forwarded to ``forward`` when recomputing.

    Returns:
        ``(p_x0, x_next)``. Positions of ``x`` that are not masked are copied
        into ``x_next`` unchanged.
    """
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    t_flat = t.squeeze(-1) if t.ndim > 1 else t
    assert t_flat.ndim == 1
    move_chance_t = t_flat[:, None, None]
    move_chance_s = (t_flat - dt)[:, None, None]
    assert move_chance_t.ndim == 3, move_chance_t.shape
    if p_x0 is None:
        p_x0 = self.forward(x, sigma_t, attention_mask).exp()
    assert move_chance_t.ndim == p_x0.ndim
    q_xs = p_x0 * (move_chance_t - move_chance_s)
    # Probability of staying masked.
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    proposal = _sample_categorical(q_xs)
    copy_flag = (x != self.mask_index).to(x.dtype)
    return p_x0, copy_flag * x + (1 - copy_flag) * proposal
def _ddpm_update(self, x, t, dt, attention_mask=None):
    """One ancestral (DDPM) denoising step from time ``t`` to ``t - dt``.

    Mask probabilities are ``1 - exp(-sigma)`` from the noise schedule at
    both times. Masked positions are resampled; unmasked ones are copied
    through unchanged.

    Args:
        x: Current tokens.
        t: Current time, passed to ``self.noise``.
        dt: Step size.
        attention_mask: Forwarded to ``forward``. Defaults to ``None``
            because ``_sample`` called this method with only three arguments,
            which raised a TypeError.

    Returns:
        The updated tokens.
    """
    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
        sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
        sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = 1 - torch.exp(-sigma_t)
    move_chance_s = 1 - torch.exp(-sigma_s)
    move_chance_t = move_chance_t[:, None, None]
    move_chance_s = move_chance_s[:, None, None]
    unet_conditioning = sigma_t
    log_p_x0 = self.forward(x, unet_conditioning, attention_mask)
    assert move_chance_t.ndim == log_p_x0.ndim
    # Technically, this isn't q_xs since there's a division
    # term that is missing. This division term doesn't affect
    # the samples.
    q_xs = log_p_x0.exp() * (move_chance_t
                             - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    _x = _sample_categorical(q_xs)
    copy_flag = (x != self.mask_index).to(x.dtype)
    return copy_flag * x + (1 - copy_flag) * _x
def _ar_sampler(self, bsz):
    """Sample ``bsz`` sequences left-to-right with the Gumbel-max trick.

    Position 0 holds ``tokenizer.bos_token_id``. Each next token is the argmax
    of the model's last-position logits plus precomputed Gumbel(0, 1) noise.

    NOTE(review): this assumes ``forward`` returns a logits tensor for the
    ``'ar'`` parameterization. ``__init__`` builds no AR backbone in this file,
    so confirm before relying on this path.

    Returns:
        Long tensor of shape ``(bsz, config.model.length)``.
    """
    # precompute token buffer
    num_pred_tokens = self.config.model.length - 1
    x = torch.zeros(
        (bsz, num_pred_tokens + 1),
        dtype=torch.long,
        device=self.device)
    x[:, 0] = self.tokenizer.bos_token_id
    # precompute noise
    noise = (torch.distributions.Gumbel(0, 1)
             .sample((bsz, num_pred_tokens, self.vocab_size))
             .to(self.device))
    for i in range(num_pred_tokens):
        prefix = x[:, :i + 1]
        # forward() requires an attention mask (it was previously omitted,
        # raising a TypeError). The prefix contains no padding.
        next_logits = self.forward(prefix, None, torch.ones_like(prefix))[:, -1]
        y = (next_logits + noise[:, i]).argmax(-1)
        x[:, i + 1] = y
    return x
@torch.no_grad()
def _sample(self, num_steps=None, eps=1e-5, x_input=None):
    """Generate samples from the model.

    For the 'ar' parameterization this defers to `self._ar_sampler`.
    Otherwise it runs `num_steps` reverse steps from t = 1 down to t = eps.
    The step function depends on `self.sampler`:

    - 'ddpm': `_ddpm_update`.
    - 'ddpm_cache': `_ddpm_caching_update`, reusing the cached p_x0.
    - anything else: `_analytic_update`.

    If `self.config.sampling.noise_removal` is set, a final denoising step
    follows.

    Args:
      num_steps: number of reverse steps; defaults to
        `self.config.sampling.steps`.
      eps: final (smallest) time on the schedule.
      x_input: optional object exposing `input_ids` and `attention_mask`
        to start from. When omitted, `self._sample_prior` draws
        `self.config.eval.perplexity_batch_size` sequences of length
        `self.config.model.length`, with an all-ones attention mask.

    Returns:
      Tensor of token ids.
    """
    batch_size_per_gpu = self.config.eval.perplexity_batch_size
    if self.parameterization == 'ar':
        return self._ar_sampler(batch_size_per_gpu)
    # Lightning auto-casting is not working in this method for some reason
    if num_steps is None:
        num_steps = self.config.sampling.steps
    if x_input is not None:
        x = x_input.input_ids
        attention_mask = x_input.attention_mask
    else:
        x = self._sample_prior(
            batch_size_per_gpu, self.config.model.length).to(self.device)
        attention_mask = torch.ones_like(x)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None
    for i in range(num_steps):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
        if self.sampler == 'ddpm':
            # attention_mask is a required argument of `_ddpm_update`; pass
            # it like the other samplers do.
            x = self._ddpm_update(x, t, dt, attention_mask)
        elif self.sampler == 'ddpm_cache':
            p_x0_cache, x_next = self._ddpm_caching_update(
                x, t, dt, p_x0=p_x0_cache, attention_mask=attention_mask)
            if not torch.allclose(x_next, x) or self.time_conditioning:
                # The model input changed (or the model depends on t), so
                # the cached p_x0 cannot be reused: disable caching.
                p_x0_cache = None
            x = x_next
        else:
            x = self._analytic_update(x, t, dt, attention_mask)
    if self.config.sampling.noise_removal:
        t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
        if self.sampler == 'analytic':
            x = self._denoiser_update(x, t)
        else:
            unet_conditioning = self.noise(t)[0]
            x = self.forward(
                x, unet_conditioning, attention_mask,
                print_logits=True).argmax(dim=-1)
    return x
def restore_model_and_sample(self, num_steps, eps=1e-5):
    """Generate samples with EMA weights (if any) and modules in eval mode.

    When `self.ema` is truthy, the current backbone/noise parameters are
    stored and the EMA weights are copied in before sampling. Afterwards
    the stored parameters are restored. `self.backbone` and `self.noise`
    are put in eval mode for sampling and switched back to train mode
    afterwards. The restore and train-mode steps run in a `finally` block,
    so they happen even if sampling raises.

    Args:
      num_steps: forwarded to `self._sample`.
      eps: forwarded to `self._sample`.

    Returns:
      Whatever `self._sample` returns.
    """
    # Lightning auto-casting is not working in this method for some reason
    if self.ema:
        self.ema.store(itertools.chain(self.backbone.parameters(),
                                       self.noise.parameters()))
    try:
        if self.ema:
            self.ema.copy_to(itertools.chain(self.backbone.parameters(),
                                             self.noise.parameters()))
        self.backbone.eval()
        self.noise.eval()
        return self._sample(num_steps=num_steps, eps=eps)
    finally:
        # Always undo the EMA swap and eval mode, even on failure.
        if self.ema:
            self.ema.restore(itertools.chain(self.backbone.parameters(),
                                             self.noise.parameters()))
        self.backbone.train()
        self.noise.train()
def get_score(self, x, sigma, attention_mask=None):
    """Return the exponentiated score for every position and vocab entry.

    For any parameterization other than 'subs', this is simply
    `self.forward(...).exp()`. For 'subs', the network's log-probabilities
    are first converted into log score ratios log p_t(y) - log p_t(x)
    (derivation below) and then exponentiated.

    Args:
      x: token ids, compared element-wise with `self.mask_index`.
      sigma: noise level; `-log(expm1(sigma))` must be 1-D after a
        trailing squeeze (asserted).
      attention_mask: forwarded to `self.forward`.
    """
    log_out = self.forward(x, sigma, attention_mask)
    if self.parameterization != 'subs':
        return log_out.exp()
    # score(x, t) = p_t(y) / p_t(x)
    # => log score(x, t) = log p_t(y) - log p_t(x)
    # case 1: x = masked
    #   (i) y = unmasked
    #     log score(x, t) = log p_\theta(x)|_y + log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    #   (ii) y = masked
    #     log score(x, t) = 0
    # case 2: x = unmasked
    #   (i) y != masked, y != x
    #     log score(x_i, t) = - inf
    #   (ii) y = x
    #     log score(x_i, t) = 0
    #   (iii) y = masked token
    #     log score(x_i, t) = - log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    # k = exp(-sigma) / (1 - exp(-sigma)) = 1 / expm1(sigma).
    log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
    assert log_k.ndim == 1
    # Case 1: current token is masked.
    from_masked = log_out + log_k[:, None, None]
    from_masked[:, :, self.mask_index] = 0
    # Case 2: current token is unmasked. Start from neg_infinity everywhere,
    # zero the entry of the current token, then set the mask column.
    from_unmasked = torch.scatter(
        self.neg_infinity * torch.ones_like(log_out),
        -1,
        x[..., None],
        torch.zeros_like(log_out[..., :1]))
    from_unmasked[:, :, self.mask_index] = -(
        log_k[:, None] * torch.ones_like(x))
    # Select per position with arithmetic weights (0/1) rather than indexing.
    is_masked = (x == self.mask_index).to(log_out.dtype)[:, :, None]
    blended = from_masked * is_masked + from_unmasked * (1 - is_masked)
    return blended.exp()
def _staggered_score(self, score, dsigma):
    """Rescale `score` by `exp(dsigma)` and adjust the mask column.

    Every entry is multiplied by `exp(dsigma)` (broadcast over one inserted
    axis after the first). The mask column (`self.mask_index`) additionally
    receives `(1 - exp(dsigma))` times the row sum of the *unscaled* score.
    The input tensor is not modified.
    """
    factor = dsigma.exp()
    # Must be taken from the unscaled score, before the rescale below.
    mask_bonus = (1 - factor) * score.sum(dim=-1)
    staggered = score.clone()
    staggered.mul_(factor[:, None])
    staggered[..., self.mask_index] += mask_bonus
    return staggered
def _analytic_update(self, x, t, step_size, attention_mask=None):
    """Take one analytic reverse step from time `t` to `t - step_size`.

    Sampling weights are the staggered score (`_staggered_score`) times
    the transposed transition (`_transp_transition`), both evaluated at
    `dsigma = sigma(t) - sigma(t - step_size)`. They are sampled with
    `_sample_categorical`.
    """
    sigma_now, _ = self.noise(t)
    sigma_next, _ = self.noise(t - step_size)
    delta = sigma_now - sigma_next
    staggered = self._staggered_score(
        self.get_score(x, sigma_now, attention_mask), delta)
    weights = staggered * self._transp_transition(x, delta)
    return _sample_categorical(weights)
def _denoiser_update(self, x, t, attention_mask=None):
    """Final analytic denoising step at time `t`; never emits the mask token.

    Sampling weights are the staggered score times the transposed transition,
    both evaluated at `sigma(t)`. The mask column is zeroed before sampling.

    Args:
      x: current token ids.
      t: time passed to `self.noise`.
      attention_mask: optional, forwarded to `get_score`. It defaults to
        None, which matches the previous two-argument call exactly.

    Returns:
      Sampled token ids.
    """
    sigma, _ = self.noise(t)
    score = self.get_score(x, sigma, attention_mask)
    stag_score = self._staggered_score(score, sigma)
    probs = stag_score * self._transp_transition(x, sigma)
    # Zero weight on the mask token so every position is decoded.
    probs[..., self.mask_index] = 0
    return _sample_categorical(probs)
def _transp_transition(self, i, sigma):
    """Build per-position transition weights over the vocabulary.

    For token ids `i`, the result has one extra trailing axis of size
    `self.vocab_size`:

    - The entry at the current token's index gets `exp(-sigma)`.
    - At positions where `i == self.mask_index`, `1 - exp(-sigma)` is
      added to every vocabulary entry.

    `sigma` is expanded with `_unsqueeze` to broadcast against `i[..., None]`.
    """
    sigma = _unsqueeze(sigma, reference=i[..., None])
    keep = torch.exp(-sigma)
    edge = keep * F.one_hot(i, num_classes=self.vocab_size)
    at_mask = i == self.mask_index
    edge += torch.where(at_mask, 1 - keep.squeeze(-1), 0)[..., None]
    return edge
def _sample_t(self, n, device):
    """Draw `n` diffusion times on `device`.

    Uniform draws `u` are mapped to `sampling_eps + (1 - sampling_eps) * u`.
    With `self.antithetic_sampling`, `u` is stratified first: the i-th draw
    is moved into the bin `[i/n, (i+1)/n)`. With `self.importance_sampling`,
    the result is passed through
    `self.noise.importance_sampling_transformation`.
    """
    u = torch.rand(n, device=device)
    if self.antithetic_sampling:
        bins = torch.arange(n, device=device) / n
        u = (u / n + bins) % 1
    t = (1 - self.sampling_eps) * u + self.sampling_eps
    return (self.noise.importance_sampling_transformation(t)
            if self.importance_sampling else t)
def _maybe_sub_sample(self, x0, attention_mask):
    """Prepare model inputs, targets and the matching attention mask.

    For the 'ar' parameterization the batch is shifted for next-token
    prediction:

    - inputs are `x0[:, :-1]`;
    - targets are `x0[:, 1:]`;
    - the mask is `attention_mask[:, 1:]`.

    These are what the 'ar' branch of `_loss` indexes. For every other
    parameterization, `x0` and `attention_mask` are returned unchanged and
    the targets are None.

    No random cropping is done here. Inputs presumably already have length
    `config.model.length` (TODO confirm against the data loader).

    Returns:
      (input_tokens, output_tokens, new_attention_mask)
    """
    if self.parameterization == 'ar':
        # Without this shift output_tokens would be None and the 'ar' loss
        # would fail when indexing it.
        return x0[:, :-1], x0[:, 1:], attention_mask[:, 1:]
    return x0, None, attention_mask
def _reconstruction_loss(self, x0, attention_mask):
    """Return the per-token negative log-likelihood of `x0` at time 0.

    The model is evaluated on the clean input `x0`, conditioned on the
    noise level at t = 0. The result is minus the model output gathered at
    each token of `x0`.
    """
    t0 = torch.zeros(x0.shape[0], dtype=self.dtype, device=self.device)
    # Required for the d3pm parameterization.
    assert self.config.noise.type == 'loglinear'
    sigma0 = self.noise(t0)[0]
    log_probs = self.forward(x0, sigma0[:, None], attention_mask)
    picked = log_probs.gather(-1, x0[:, :, None]).squeeze(-1)
    return -picked
def _forward_pass_diffusion(self, x0, attention_mask, mask=None):
    """Compute the per-token diffusion training loss for a batch.

    Steps:

    1. Sample one time per sequence (discretised when `self.T > 0`).
    2. Corrupt `x0` into `xt`: via `self.q_xt` when `mask` is None, or by
       keeping `x0` where `mask == 1` and writing the tokenizer's mask token
       everywhere else.
    3. Run the denoiser and return the per-token loss for the active
       parameterization ('sedd', discrete-time d3pm/subs, or
       continuous-time subs).

    Args:
      x0: clean token ids.
      attention_mask: forwarded to the model and, for d3pm, to the
        reconstruction term.
      mask: optional explicit keep-mask (see step 2).

    Returns:
      Per-token loss tensor.

    Raises:
      ValueError: if `self.T > 0` and the parameterization is neither
        'd3pm' nor 'subs'.
    """
    t = self._sample_t(x0.shape[0], x0.device)
    if self.T > 0:
        t = (t * self.T).to(torch.int)
        t = t / self.T
        # t \in {1/T, 2/T, ..., 1}
        t += (1 / self.T)

    if self.change_of_variables:
        unet_conditioning = t[:, None]
        f_T = torch.log1p(-torch.exp(-self.noise.sigma_max))
        f_0 = torch.log1p(-torch.exp(-self.noise.sigma_min))
        move_chance = torch.exp(f_0 + t * (f_T - f_0))
        move_chance = move_chance[:, None]
    else:
        sigma, dsigma = self.noise(t)
        unet_conditioning = sigma[:, None]
        move_chance = 1 - torch.exp(-sigma[:, None])

    if mask is None:
        xt = self.q_xt(x0, move_chance)
    else:
        xt = x0.where(
            mask == 1, torch.full_like(x0, self.tokenizer.mask_token_id))

    model_output = self.forward(xt, unet_conditioning, attention_mask)
    utils.print_nans(model_output, 'model_output')

    if self.parameterization == 'sedd':
        return dsigma[:, None] * self._score_entropy(
            model_output, sigma[:, None], xt, x0)

    if self.T > 0:
        diffusion_loss = self._d3pm_loss(
            model_output=model_output, xt=xt, x0=x0, t=t)
        if self.parameterization == 'd3pm':
            # `_reconstruction_loss` requires the attention mask.
            reconstruction_loss = self._reconstruction_loss(
                x0, attention_mask)
        elif self.parameterization == 'subs':
            reconstruction_loss = 0
        else:
            raise ValueError(
                f'Unsupported parameterization {self.parameterization!r} '
                'for discrete-time (T > 0) training.')
        return reconstruction_loss + diffusion_loss

    # SUBS parameterization, continuous time.
    log_p_theta = torch.gather(
        input=model_output,
        dim=-1,
        index=x0[:, :, None]).squeeze(-1)
    if self.change_of_variables or self.importance_sampling:
        return log_p_theta * torch.log1p(
            -torch.exp(-self.noise.sigma_min))
    return -log_p_theta * (dsigma / torch.expm1(sigma))[:, None]
def _loss(self, x0, attention_mask, mask=None):
    """Compute the training loss for a batch and wrap it in `Loss`.

    The 'ar' parameterization uses the negative log-probability of the
    target tokens under `self.backbone`. Every other parameterization uses
    `_forward_pass_diffusion`. Per-token losses are multiplied by the
    attention mask, and the scalar loss is their sum divided by the mask
    sum.
    """
    inputs, targets, token_mask = self._maybe_sub_sample(
        x0, attention_mask)
    if self.parameterization != 'ar':
        per_token = self._forward_pass_diffusion(inputs, token_mask, mask)
    else:
        log_probs = self.backbone(inputs, None, token_mask)
        per_token = -log_probs.gather(-1, targets[:, :, None])[:, :, 0]
    nlls = per_token * token_mask
    mean_nll = nlls.sum() / token_mask.sum()
    return Loss(loss=mean_nll, nlls=nlls, token_mask=token_mask)
def _score_entropy(self, log_score, sigma, xt, x0):
    """Computes the SEDD loss.

    Only positions where `xt == self.mask_index` contribute. At each such
    position the loss is

        sum_{y != mask} exp(log_score_y) - q * log_score_{x0} + q * (log q - 1)

    with `q = 1 / expm1(sigma)`. All other positions are 0.

    Args:
      log_score: float torch.Tensor with shape (batch_size,
          diffusion_model_input_length, vocab_size),
          log score, output of the denoising network.
      xt: int torch.Tensor with shape (batch_size,
          diffusion_model_input_length), input.
      x0: int torch.Tensor with shape (batch_size,
          diffusion_model_input_length), input.
      sigma: float torch.Tensor with shape (batch_size, 1).

    Returns:
      loss with shape (batch_size, diffusion_model_input_length)
    """
    is_masked = xt == self.mask_index
    q = 1 / torch.expm1(sigma).expand_as(xt)[is_masked]
    targets = x0[is_masked]
    masked_log_score = log_score[is_masked]
    neg_term = q * masked_log_score.gather(
        -1, targets[..., None]).squeeze(-1)
    score = masked_log_score.exp()
    # Sum the exponentiated score over every entry except the mask column.
    if self.mask_index != self.vocab_size - 1:
        pos_term = (score[:, :self.mask_index].sum(dim=-1)
                    + score[:, self.mask_index + 1:].sum(dim=-1))
    else:
        pos_term = score[:, :-1].sum(dim=-1)
    const = q * (q.log() - 1)
    entropy = torch.zeros(*xt.shape, device=xt.device)
    entropy[is_masked] += pos_term - neg_term + const
    return entropy
@torch.no_grad()
def sample_subs_guidance(
        self, n_samples, stride_length, num_strides, dt=0.001):
    """Semi-autoregressive sampling in overlapping windows.

    There are `num_strides + 1` rounds. Each round:

    1. Draws a fresh prior window of length `self.config.model.length`.
    2. Overwrites its first `length - stride_length` positions with the
       last `length - stride_length` tokens of the previous window.
    3. Runs `int(1 / dt) + 1` calls to `_ddpm_caching_update` from t = 1
       downward.
    4. Takes the argmax of one more forward pass.

    The first `stride_length` tokens of each window are kept as a chunk.
    After the loop, the final window's tail is appended as the last chunk.

    Returns:
      A tuple `(sampling_steps, intermediate_text_samples,
      sequence_lengths)`:

      - sampling_steps: number of inner steps at which the p_x0 cache was
        discarded.
      - intermediate_text_samples: list whose k-th entry is
        `self.tokenizer.batch_decode` of the first k + 2 chunks
        concatenated.
      - sequence_lengths: per sample, the number of tokens (skipping
        position 0) before the first EOS in the full concatenation.
    """
    ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)
    num_steps = int(1 / dt)
    sampling_steps = 0
    intermediate_tokens = []
    target = None
    for _ in range(num_strides + 1):
        p_x0_cache = None
        x = self._sample_prior(
            n_samples,
            self.config.model.length).to(self.device)
        if target is not None:
            # Condition this window on the tail of the previous one.
            x[:, :-stride_length] = target
        for i in range(num_steps + 1):
            # NOTE(review): unlike `_sample`, no attention_mask is passed
            # here; confirm `_ddpm_caching_update` defaults it.
            p_x0_cache, x_next = self._ddpm_caching_update(
                x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
            if not torch.allclose(x_next, x) or self.time_conditioning:
                p_x0_cache = None
                sampling_steps += 1
            x = x_next
        x = self.forward(x, 0 * ones).argmax(dim=-1)
        intermediate_tokens.append(x[:, :stride_length].cpu().numpy())
        target = x[:, stride_length:]
    intermediate_tokens.append(target.cpu().numpy())
    intermediate_text_samples = []
    sequence_lengths = ((
        np.concatenate(intermediate_tokens, axis=1)[:, 1:]
        == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
    for i in range(2, len(intermediate_tokens) + 1):
        intermediate_text_samples.append(
            self.tokenizer.batch_decode(
                np.concatenate(intermediate_tokens[:i], axis=1)))
    return (sampling_steps, intermediate_text_samples,
            sequence_lengths)
def restore_model_and_semi_ar_sample(
        self, stride_length, num_strides, dt=0.001):
    """Run `sample_subs_guidance` with EMA weights (if any) in eval mode.

    When `self.ema` is truthy, the current backbone/noise parameters are
    stored, the EMA weights are copied in for sampling, and the stored
    parameters are restored afterwards. Both modules are switched to eval
    mode for sampling and back to train mode afterwards. The restore and
    train-mode steps run in a `finally` block, so they happen even if
    sampling raises. The batch size is `self.config.loader.eval_batch_size`.

    Returns:
      (sampling_steps, samples, sequence_lengths) from
      `sample_subs_guidance`.
    """
    # Lightning auto-casting is not working in this method for some reason
    if self.ema:
        self.ema.store(itertools.chain(self.backbone.parameters(),
                                       self.noise.parameters()))
    try:
        if self.ema:
            self.ema.copy_to(itertools.chain(self.backbone.parameters(),
                                             self.noise.parameters()))
        self.backbone.eval()
        self.noise.eval()
        (sampling_steps, samples,
         sequence_lengths) = self.sample_subs_guidance(
            n_samples=self.config.loader.eval_batch_size,
            stride_length=stride_length,
            num_strides=num_strides,
            dt=dt)
    finally:
        # Always undo the EMA swap and eval mode, even on failure.
        if self.ema:
            self.ema.restore(itertools.chain(self.backbone.parameters(),
                                             self.noise.parameters()))
        self.backbone.train()
        self.noise.train()
    return sampling_steps, samples, sequence_lengths