import gc
import itertools
import math
import os
import sys
import time
import typing
from dataclasses import dataclass

import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import transformers
from torch import Tensor
from torch.optim.lr_scheduler import _LRScheduler
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

# import dit
import ema
import noise_schedule
import pl_data_loader as dataloader
import utils

LOG2 = math.log(2)


class CosineWarmup(_LRScheduler):
  def __init__(self, optimizer, warmup_steps, total_steps,
               eta_ratio=0.1, last_epoch=-1):
    self.warmup_steps = warmup_steps
    self.total_steps = total_steps
    # The ratio of minimum to maximum learning rate.
    self.eta_ratio = eta_ratio
    super(CosineWarmup, self).__init__(optimizer, last_epoch)

  def get_lr(self):
    if self.last_epoch < self.warmup_steps:
      return [base_lr * self.last_epoch / self.warmup_steps
              for base_lr in self.base_lrs]
    progress = ((self.last_epoch - self.warmup_steps)
                / (self.total_steps - self.warmup_steps))
    cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
    decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
    return [decayed_lr * base_lr for base_lr in self.base_lrs]


def _sample_categorical(categorical_probs):
  gumbel_norm = (
    1e-10
    - (torch.rand_like(categorical_probs) + 1e-10).log())
  return (categorical_probs / gumbel_norm).argmax(dim=-1)


def _unsqueeze(x, reference):
  return x.view(
    *x.shape,
    *((1,) * (len(reference.shape) - len(x.shape))))


@dataclass
class Loss:
  loss: torch.FloatTensor
  nlls: torch.FloatTensor
  token_mask: torch.FloatTensor


class NLL(torchmetrics.aggregation.MeanMetric):
  pass


class BPD(NLL):
  def compute(self) -> Tensor:
    """Computes the bits per dimension.

    Returns:
      bpd
    """
    return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
  def compute(self) -> Tensor:
    """Computes the Perplexity.

    Returns:
      Perplexity
    """
    return torch.exp(self.mean_value / self.weight)
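
# Illustrative sketch (an addition, not referenced anywhere else in this
# module): shows how the NLL-based metrics above relate to each other. Both
# are fed per-token negative log-likelihoods in nats plus a token mask used
# as the weight; BPD divides the mean NLL by ln(2), Perplexity exponentiates
# it.
def _example_nll_bpd_ppl():
  nlls = torch.tensor([0.5, 1.0, 1.5])
  token_mask = torch.ones_like(nlls)
  bpd, ppl = BPD(), Perplexity()
  bpd.update(nlls, token_mask)
  ppl.update(nlls, token_mask)
  # Mean NLL is 1.0 nat/token, so bpd == 1 / ln(2) bits and ppl == e.
  return bpd.compute(), ppl.compute()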
class WrapVanillaESM(nn.Module):
  def __init__(self, bert_model_path):
    super(WrapVanillaESM, self).__init__()
    self.model = AutoModelForMaskedLM.from_pretrained(
      bert_model_path, device_map='cpu')
    self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)

  def __call__(self, *args, **kwargs):
    return self.model(*args, **kwargs)

  def unfreeze_attn_layers(self):
    model_layers = len(self.model.esm.encoder.layer)
    for i, layer in enumerate(self.model.esm.encoder.layer):
      if i >= model_layers - 5:  # fine-tune only the last n layers
        for module in layer.attention.self.key.modules():
          for param in module.parameters():
            param.requires_grad = True
        for module in layer.attention.self.query.modules():
          for param in module.parameters():
            param.requires_grad = True
        for module in layer.attention.self.value.modules():
          for param in module.parameters():
            param.requires_grad = True

  def unfreeze_all_layers(self):
    for param in self.model.parameters():
      param.requires_grad = True

  def forward(self, inputs, sigma, attention_mask):
    logits = self.model(
      input_ids=inputs, attention_mask=attention_mask).logits
    return logits

  def save_model(self, save_dir):
    self.model.save_pretrained(save_dir)
    self.tokenizer.save_pretrained(save_dir)

  def load_model(self, load_dir):
    self.model = AutoModel.from_pretrained(load_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(load_dir)


class WrapMembraneESM(nn.Module):
  def __init__(self, bert_model_path):
    super(WrapMembraneESM, self).__init__()
    self.model = AutoModelForMaskedLM.from_pretrained(
      bert_model_path, device_map='cpu')
    self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path)

  def __call__(self, *args, **kwargs):
    return self.model(*args, **kwargs)

  def freeze_model(self):
    for param in self.model.parameters():
      param.requires_grad = False

  def unfreeze_all_layers(self):
    for param in self.model.parameters():
      param.requires_grad = True

  def unfreeze_attn_layers(self):
    model_layers = len(self.model.esm.encoder.layer)
    for i, layer in enumerate(self.model.esm.encoder.layer):
      if i >= model_layers - 11:  # fine-tune only the last n layers
        for module in layer.attention.self.key.modules():
          for param in module.parameters():
            param.requires_grad = True
        for module in layer.attention.self.query.modules():
          for param in module.parameters():
            param.requires_grad = True
        for module in layer.attention.self.value.modules():
          for param in module.parameters():
            param.requires_grad = True

  def forward(self, inputs, sigma, attention_mask):
    logits = self.model(
      input_ids=inputs, attention_mask=attention_mask).logits
    return logits

  def save_model(self, save_dir):
    self.model.save_pretrained(save_dir)
    self.tokenizer.save_pretrained(save_dir)

  def load_model(self, load_dir):
    self.model = AutoModel.from_pretrained(load_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
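
# Illustrative sketch (an addition; the checkpoint name below is an
# assumption -- any ESM-style masked-LM checkpoint path from the training
# config works the same way, and instantiating it requires the weights to be
# available locally or over the network). Shows the intended fine-tuning
# setup: freeze every parameter, then re-enable gradients only for the Q/K/V
# projections of the last few encoder layers via unfreeze_attn_layers().
def _example_wrap_esm_finetune_setup(
    bert_model_path='facebook/esm2_t33_650M_UR50D'):
  wrapper = WrapVanillaESM(bert_model_path=bert_model_path)
  for param in wrapper.model.parameters():
    param.requires_grad = False
  wrapper.unfreeze_attn_layers()  # last 5 layers, attention projections only
  trainable = sum(
    p.numel() for p in wrapper.model.parameters() if p.requires_grad)
  total = sum(p.numel() for p in wrapper.model.parameters())
  return trainable, total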
class Diffusion(L.LightningModule):
  def __init__(
      self,
      config,
      tokenizer: transformers.PreTrainedTokenizer):
    super().__init__()
    self.save_hyperparameters()
    self.config = config
    self.tokenizer = tokenizer
    self.vocab_size = self.tokenizer.vocab_size
    self.sampler = self.config.sampling.predictor
    self.gen_ppl_eval_model_name_or_path = \
      self.config.eval.gen_ppl_eval_model_name_or_path
    self.antithetic_sampling = self.config.training.antithetic_sampling
    self.importance_sampling = self.config.training.importance_sampling
    self.change_of_variables = self.config.training.change_of_variables
    if (not hasattr(self.tokenizer, 'mask_token')
        or self.tokenizer.mask_token is None):
      self.mask_index = self.vocab_size
      self.vocab_size += 1
    else:
      self.mask_index = self.tokenizer.mask_token_id
    self.parameterization = self.config.parameterization
    # if self.config.backbone == 'dit':
    #   self.backbone = dit.DIT(
    #     self.config, vocab_size=self.vocab_size,
    #     mlm_model_path=config.training.mlm_model_path)
    if self.config.backbone == 'vanilla_esm_pretrain':
      self.backbone = WrapVanillaESM(
        bert_model_path=self.config.training.esm_model_path)
      self.backbone.unfreeze_all_layers()
      self.backbone = torch.compile(self.backbone)
    elif self.config.backbone == 'membrane_esm_finetune':
      self.backbone = WrapMembraneESM(
        bert_model_path=self.config.checkpointing.pretrained_esm_mdlm_automodel_path)
      self.backbone.unfreeze_all_layers()
      # self.backbone = torch.compile(self.backbone)
    # elif self.config.backbone == 'dimamba':
    #   self.backbone = dimamba.DiMamba(
    #     self.config,
    #     vocab_size=self.vocab_size,
    #     pad_token_id=self.tokenizer.pad_token_id)
    # elif self.config.backbone == 'ar':
    #   self.backbone = autoregressive.AR(
    #     self.config,
    #     vocab_size=self.vocab_size,
    #     mask_index=self.mask_index)
    # elif self.config.backbone == 'hf_dit':
    #   self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(
    #     config.eval.checkpoint_path, trust_remote_code=True)
    else:
      raise ValueError(f'Unknown backbone: {self.config.backbone}')

    self.T = self.config.T
    self.subs_masking = self.config.subs_masking
    self.softplus = torch.nn.Softplus()
    # Metrics are automatically reset at the end of each epoch.
    metrics = torchmetrics.MetricCollection({
      'nll': NLL(),
      'bpd': BPD(),
      'ppl': Perplexity(),
    })
    metrics.set_dtype(torch.float64)
    self.train_metrics = metrics.clone(prefix='train/')
    self.valid_metrics = metrics.clone(prefix='val/')
    self.test_metrics = metrics.clone(prefix='test/')

    # Generative perplexity.
    self.gen_ppl_metric = Perplexity()
    self.eval_model_tokenizer = transformers.AutoTokenizer.from_pretrained(
      self.gen_ppl_eval_model_name_or_path)
    if self.eval_model_tokenizer.pad_token is None:
      self.eval_model_tokenizer.pad_token = \
        self.eval_model_tokenizer.eos_token
      self.eval_model_tokenizer.pad_token_id = \
        self.eval_model_tokenizer.eos_token_id

    self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
    if self.config.training.ema > 0:
      self.ema = ema.ExponentialMovingAverage(
        itertools.chain(self.backbone.parameters(),
                        self.noise.parameters()),
        decay=self.config.training.ema)
    else:
      self.ema = None
    self.lr = self.config.optim.lr
    self.sampling_eps = self.config.training.sampling_eps
    self.time_conditioning = self.config.time_conditioning
    self.neg_infinity = -1000000.0
    self.fast_forward_epochs = None
    self.fast_forward_batches = None
    self._validate_configuration()

  def _validate_configuration(self):
    assert not (self.change_of_variables and self.importance_sampling)
    if self.parameterization == 'sedd':
      assert not self.importance_sampling
      assert not self.change_of_variables
    if self.parameterization == 'd3pm':
      assert self.T > 0
    if self.T > 0:
      assert self.parameterization in {'d3pm', 'subs'}
    if self.subs_masking:
      assert self.parameterization == 'd3pm'
  def on_load_checkpoint(self, checkpoint):
    if self.ema:
      self.ema.load_state_dict(checkpoint['ema'])
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
    self.fast_forward_epochs = checkpoint['loops'][
      'fit_loop']['epoch_progress']['current']['completed']
    self.fast_forward_batches = checkpoint['loops'][
      'fit_loop']['epoch_loop.batch_progress'][
      'current']['completed']

  def on_save_checkpoint(self, checkpoint):
    if self.ema:
      checkpoint['ema'] = self.ema.state_dict()
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
    # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
    # behind, so we're using the optimizer's progress.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['total'][
      'completed'] = checkpoint['loops']['fit_loop'][
      'epoch_loop.automatic_optimization.optim_progress'][
      'optimizer']['step']['total'][
      'completed'] * self.trainer.accumulate_grad_batches
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['current'][
      'completed'] = checkpoint['loops']['fit_loop'][
      'epoch_loop.automatic_optimization.optim_progress'][
      'optimizer']['step']['current'][
      'completed'] * self.trainer.accumulate_grad_batches
    # _batches_that_stepped tracks the number of global steps, not the
    # number of local steps, so we don't multiply with
    # self.trainer.accumulate_grad_batches here.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.state_dict'][
      '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
      'epoch_loop.automatic_optimization.optim_progress'][
      'optimizer']['step']['total']['completed']
    if 'sampler' not in checkpoint.keys():
      checkpoint['sampler'] = {}
    if hasattr(self.trainer.train_dataloader.sampler, 'state_dict'):
      sampler_state_dict = \
        self.trainer.train_dataloader.sampler.state_dict()
      checkpoint['sampler'][
        'random_state'] = sampler_state_dict.get('random_state', None)
    else:
      checkpoint['sampler']['random_state'] = None
    self.backbone.save_model(
      self.config.checkpointing.fine_tuned_esm_mdlm_ckpt_path)

  def on_train_start(self):
    torch.cuda.empty_cache()
    if self.ema:
      self.ema.move_shadow_params_to_device(self.device)
    # Adapted from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
    distributed = (
      self.trainer._accelerator_connector.use_distributed_sampler
      and self.trainer._accelerator_connector.is_distributed)
    if distributed:
      sampler_cls = dataloader.FaultTolerantDistributedSampler
    else:
      sampler_cls = dataloader.RandomFaultTolerantSampler
    updated_dls = []
    for dl in self.trainer.fit_loop._combined_loader.flattened:
      if hasattr(dl.sampler, 'shuffle'):
        dl_sampler = sampler_cls(
          dl.dataset, shuffle=dl.sampler.shuffle)
      else:
        dl_sampler = sampler_cls(dl.dataset)
      if (distributed
          and self.fast_forward_epochs is not None
          and self.fast_forward_batches is not None):
        dl_sampler.load_state_dict({
          'epoch': self.fast_forward_epochs,
          'counter': (self.fast_forward_batches
                      * self.config.loader.batch_size)})
      from functools import partial
      from pl_data_loader import collate_fn
      collate_partial = partial(collate_fn, tokenizer=self.tokenizer)
      torch.cuda.empty_cache()
      updated_dls.append(
        torch.utils.data.DataLoader(
          dl.dataset,
          batch_size=self.config.loader.batch_size,
          num_workers=self.config.loader.num_workers,
          pin_memory=self.config.loader.pin_memory,
          sampler=dl_sampler,
          shuffle=False,
          persistent_workers=False,
          collate_fn=collate_partial))
    self.trainer.fit_loop._combined_loader.flattened = updated_dls
  def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)
    gc.collect()
    torch.cuda.empty_cache()
    if self.ema:
      self.ema.update(itertools.chain(
        self.backbone.parameters(), self.noise.parameters()))

  def _subs_parameterization(self, logits, xt):
    # Log prob at the mask index = -infinity.
    logits = logits.logits
    logits[:, :, self.mask_index] += self.neg_infinity
    # logits[:, :, self.tokenizer.eos_token_id] += self.neg_infinity
    # logits[:, :, self.tokenizer.cls_token_id] += self.neg_infinity

    # Normalize the logits such that x.exp() is
    # a probability distribution over vocab_size.
    logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

    # Apply updates directly in the logits matrix.
    # For the logits of the unmasked tokens, set all values
    # to -infinity except for the indices corresponding to
    # the unmasked tokens.
    unmasked_indices = (xt != self.mask_index)
    logits[unmasked_indices] = self.neg_infinity
    logits[unmasked_indices, xt[unmasked_indices]] = 0
    return logits

  def _d3pm_parameterization(self, logits):
    if self.subs_masking:
      logits[:, :, self.mask_index] += self.neg_infinity
    logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    return logits

  def _sedd_parameterization(self, logits, xt, sigma):
    esigm1_log = torch.where(
      sigma < 0.5,
      torch.expm1(sigma),
      sigma.exp() - 1).log().to(logits.dtype)
    # logits shape
    # (batch_size, diffusion_model_input_length, vocab_size)
    logits = logits - esigm1_log[:, None, None] - np.log(
      logits.shape[-1] - 1)
    # The below scatter operation sets the log score
    # for the input word to 0.
    logits = torch.scatter(logits, -1, xt[..., None],
                           torch.zeros_like(logits[..., :1]))
    return logits

  def _process_sigma(self, sigma):
    if sigma is None:
      assert self.parameterization == 'ar'
      return sigma
    if sigma.ndim > 1:
      sigma = sigma.squeeze(-1)
    if not self.time_conditioning:
      sigma = torch.zeros_like(sigma)
    assert sigma.ndim == 1, sigma.shape
    return sigma

  def forward(self, x, sigma, attention_mask, print_logits=False):
    """Returns log score."""
    sigma = self._process_sigma(sigma)
    with torch.amp.autocast("cuda", dtype=torch.float32):
      logits = self.backbone(x, attention_mask)
    # if print_logits:
    #   torch.set_printoptions(profile="full")
    #   print(logits)
    #   torch.set_printoptions(profile="default")
    if self.parameterization == 'subs':
      return self._subs_parameterization(logits=logits, xt=x)
    return logits
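
  # Illustrative sketch (an addition, not called by the training loop): a
  # self-contained toy version of the two SUBS constraints enforced in
  # `_subs_parameterization` above -- (1) the mask token itself is never
  # predicted, (2) positions that are already unmasked keep their token with
  # probability one ("carry-over unmasking").
  @staticmethod
  def _example_subs_constraints():
    mask_index, neg_inf = 3, -1e6
    logits = torch.randn(1, 4, 5)                  # (batch, length, vocab)
    xt = torch.tensor([[0, mask_index, 2, mask_index]])
    logits[:, :, mask_index] += neg_inf            # zero mask probability
    logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    unmasked = xt != mask_index
    logits[unmasked] = neg_inf
    logits[unmasked, xt[unmasked]] = 0             # log prob 0 => prob 1
    # Rows at unmasked positions are now (numerically) one-hot.
    return logits.exp()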
  def _d3pm_loss(self, model_output, xt, x0, t, attention_mask):
    dt = 1 / self.T

    if torch.is_tensor(t):
      t = t[:, None]
      assert t.ndim == 2
      t = t.clamp(0., 1. - 1e-4)
    alpha_t = 1 - t + torch.zeros_like(xt)
    alpha_s = 1 - (t - dt) + torch.zeros_like(xt)

    log_x_theta_at_x0 = torch.gather(
      model_output, -1, x0[:, :, None]).squeeze(-1)
    log_x_theta_at_m = model_output[:, :, self.mask_index]
    x_theta_at_m = log_x_theta_at_m.exp()

    term_1_coef = dt / t
    term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
    term_1_log_dr = log_x_theta_at_x0

    term_2_coef = 1 - dt / t
    term_2_log_nr = term_1_log_nr
    term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)

    L_vb_masked = (
      term_1_coef * (term_1_log_nr - term_1_log_dr)
      + term_2_coef * (term_2_log_nr - term_2_log_dr))
    L_vb = L_vb_masked * (xt == self.mask_index)
    return self.T * L_vb

  def _compute_loss(self, batch, prefix):
    if 'attention_mask' in batch:
      attention_mask = batch['attention_mask']
    else:
      attention_mask = None
    if 'mask' in batch:
      mask = batch['mask']
    else:
      mask = None
    losses = self._loss(batch['input_ids'], attention_mask, mask)
    loss = losses.loss

    if prefix == 'train':
      self.train_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.train_metrics
    elif prefix == 'val':
      self.valid_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.valid_metrics
    elif prefix == 'test':
      self.test_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.test_metrics
    else:
      raise ValueError(f'Invalid prefix: {prefix}')

    self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
    return loss

  def on_train_epoch_start(self):
    self.backbone.train()
    self.noise.train()

  def training_step(self, batch, batch_idx):
    # Time the step so token throughput can be logged alongside the loss.
    start_time = time.time()
    loss = self._compute_loss(batch, prefix='train')
    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    elapsed_time = time.time() - start_time
    total_tokens = batch['input_ids'].numel()
    throughput = total_tokens / elapsed_time
    self.log(name='trainer/throughput',
             value=throughput,
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    return loss

  def on_validation_epoch_start(self):
    gc.collect()
    torch.cuda.empty_cache()
    if self.ema:
      self.ema.store(
        itertools.chain(self.backbone.parameters(),
                        self.noise.parameters()))
      self.ema.copy_to(
        itertools.chain(self.backbone.parameters(),
                        self.noise.parameters()))
    self.backbone.eval()
    self.noise.eval()
    assert self.valid_metrics.nll.mean_value == 0
    assert self.valid_metrics.nll.weight == 0

  def validation_step(self, batch, batch_idx):
    loss = self._compute_loss(batch, prefix='val')
    self.log(name='trainer/val_loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             prog_bar=True,
             sync_dist=True)
    return loss
  def on_validation_epoch_end(self):
    # Sample generation and generative-perplexity logging are disabled during
    # validation; `test_step` below contains the equivalent active code.
    gc.collect()
    torch.cuda.empty_cache()
    if self.ema:
      self.ema.restore(
        itertools.chain(self.backbone.parameters(),
                        self.noise.parameters()))

  def test_step(self, batch, batch_idx):
    loss = self._compute_loss(batch, prefix='test')
    self.log('test/loss',
             value=loss.item(),
             on_step=False,
             on_epoch=True,
             sync_dist=True)
    if self.config.eval.compute_generative_perplexity:
      samples, text_samples = None, None
      for _ in range(self.config.sampling.num_sample_batches):
        samples = self._sample()
        # Decode the samples to be re-tokenized by the eval model.
        text_samples = self.tokenizer.batch_decode(samples)
        if self.config.eval.compute_generative_perplexity:
          self.compute_generative_perplexity(text_samples)
      if self.trainer.global_rank == 0 and hasattr(
          self.trainer.logger, 'log_table'):
        # Log the last generated samples.
        text_samples = text_samples[
          : self.config.sampling.num_sample_log]
        self.trainer.logger.log_table(
          key=f'samples@global_step{self.global_step}',
          columns=['Generated Samples'],
          data=[[s] for s in text_samples])
      if self.config.eval.compute_generative_perplexity:
        self.log('test/gen_ppl',
                 self.gen_ppl_metric,
                 on_epoch=False,
                 on_step=True,
                 sync_dist=True)

  def on_test_epoch_start(self):
    if self.ema:
      self.ema.store(itertools.chain(
        self.backbone.parameters(), self.noise.parameters()))
      self.ema.copy_to(itertools.chain(
        self.backbone.parameters(), self.noise.parameters()))
    self.backbone.eval()
    self.noise.eval()
    self.test_metrics.reset()

  def on_test_epoch_end(self):
    if self.ema:
      self.ema.restore(itertools.chain(
        self.backbone.parameters(), self.noise.parameters()))
    for metric_name, metric_value in self.test_metrics.compute().items():
      self.log(metric_name, metric_value, sync_dist=True)
  def configure_optimizers(self):
    # (yair): Lightning currently gives this warning when using `fp16`:
    #  "Detected call of `lr_scheduler.step()` before `optimizer.step()`."
    #  Not clear if this is a problem or not.
    #  See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
    optimizer = torch.optim.AdamW(
      itertools.chain(self.backbone.parameters(),
                      self.noise.parameters()),
      lr=self.config.optim.lr,
      betas=(self.config.optim.beta1, self.config.optim.beta2),
      eps=self.config.optim.eps,
      weight_decay=self.config.optim.weight_decay)

    # scheduler = hydra.utils.instantiate(
    #   self.config.lr_scheduler, optimizer=optimizer)
    self.total_steps = self.config.trainer.max_steps
    scheduler = CosineWarmup(
      optimizer,
      warmup_steps=self.config.lr_scheduler.num_warmup_steps,
      total_steps=self.total_steps)
    scheduler_dict = {
      'scheduler': scheduler,
      'interval': 'step',
      'frequency': 1,
      'monitor': 'val/loss',
      'name': 'trainer/lr',
    }
    return [optimizer], [scheduler_dict]

  @torch.no_grad()
  def eval_retokenize(self, text_samples, max_length):
    """Retokenizes samples for the eval model.

    Args:
      text_samples: List of sentences generated by the model.

    Returns:
      samples: Samples re-tokenized for the eval model
      attn_mask: Attention mask for the eval model
      eval_context_size: Size of the context for the eval model
    """
    if 'llama2' in self.gen_ppl_eval_model_name_or_path:
      tokenizer_kwargs = {
        'text_samples': text_samples,
        'return_tensors': 'pt',
        'return_token_type_ids': False,
        'return_attention_mask': True,
        'truncation': True,
        'padding': True,
        'max_length': max_length,
      }
      eval_context_size = 4096
    else:
      tokenizer_kwargs = {
        'return_tensors': 'pt',
        'return_token_type_ids': False,
        'return_attention_mask': True,
        'truncation': True,
        'padding': True,
        'max_length': max_length,
      }
      eval_context_size = 1024

    samples = self.eval_model_tokenizer(
      text_samples, **tokenizer_kwargs)
    attn_mask = samples['attention_mask']
    samples = samples['input_ids']
    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
      attn_mask = attn_mask.to(self.device)
      samples = samples.to(self.device)
    return samples, attn_mask, eval_context_size
# """ # os.environ['TOKENIZERS_PARALLELISM'] = 'false' # eval_model = transformers.AutoModelForCausalLM.from_pretrained( # self.gen_ppl_eval_model_name_or_path).eval() # if max_length is None: # max_length = self.config.model.length # if 'llama2' not in self.gen_ppl_eval_model_name_or_path: # eval_model = eval_model.to(self.device) # # Re-tokenize using eval model's tokenizer # if retokenize: # (samples, attn_mask, # eval_context_size) = self.eval_retokenize( # text_samples, max_length=max_length) # else: # samples = text_samples # attn_mask = torch.ones(samples.shape).to(self.device) # eval_context_size = samples.shape[-1] # batch_size = min( # self.config.eval.perplexity_batch_size, # samples.shape[0]) # num_batches = samples.shape[0] // batch_size # for i in range(num_batches): # _samples = torch.split( # samples[i * batch_size: (i + 1) * batch_size], # eval_context_size, # dim=-1) # _attn_mask = torch.split( # attn_mask[i * batch_size: (i + 1) * batch_size], # eval_context_size, # dim=-1) # for (sample_chunk, attn_mask_chunk) in zip( # _samples, _attn_mask): # logits = eval_model( # sample_chunk, attention_mask=attn_mask_chunk)[0] # logits = logits.transpose(-1, -2) # nlls = F.cross_entropy(logits[..., :-1], # sample_chunk[..., 1:], # reduction='none') # first_eos = (sample_chunk == self.eval_model_tokenizer\ # .eos_token_id).cumsum(-1) == 1 # token_mask = ( # sample_chunk # != self.eval_model_tokenizer.eos_token_id) # self.gen_ppl_metric.update( # nlls, first_eos[..., 1:] + token_mask[..., 1:]) @torch.no_grad() def compute_masked_perplexity(self, sequences, masked): """Compute the pseudo-perplexity of the generated protein sequences.""" total_nll = 0 total_tokens = 0 for sequence in sequences: # Tokenize the sequence input_ids = self.tokenizer(masked, return_tensors="pt").input_ids.to(self.device) gt_ids = self.tokenizer(sequence.upper(), return_tensors="pt").input_ids.to(self.device) # print(input_ids.shape) # print(gt_ids.shape) # Forward pass through the ESM model attention_mask = torch.ones_like(input_ids) if self.config.mode in ['train', 'ppl_eval']: outputs = self.backbone.model.forward(input_ids=input_ids, attention_mask=attention_mask) elif self.config.mode == "sample_eval": outputs = self.backbone.model.forward(input_ids) logits = outputs[-1] # B, L, V # Compute loss # shift_logits = logits[:, :-1, :].contiguous() # remove eos # shift_labels = input_ids[:, 1:].contiguous() # print(masked) # print(gt_ids.where(input_ids==32, torch.full_like(input_ids, -100)).view(-1)) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), gt_ids.where(input_ids==32, torch.full_like(input_ids, -100)).view(-1), reduction='sum') total_nll += loss.item() #total_tokens += (input_ids != self.tokenizer.pad_token_id).sum().item() - 1 # -1 for the first token total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() # count in bos and eos # Compute pseudo-perplexity # print(total_nll, ",;,", total_tokens) pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens)) self.gen_ppl_metric.update(pseudo_perplexity) return pseudo_perplexity.item() @torch.no_grad() def compute_generative_perplexity( self, text_samples: typing.List[str], retokenize: bool = True, max_length: typing.Optional[int] = None) -> None: """Compute the generative perplexity of the model. Args: text_samples: List of sentences generated by the model. Returns: Perplexity of the generated text under a different pre-trained AR model (e.g., GPT2). 
""" os.environ['TOKENIZERS_PARALLELISM'] = 'false' eval_model = transformers.AutoModelForCausalLM.from_pretrained( self.gen_ppl_eval_model_name_or_path).eval() if max_length is None: max_length = self.config.model.length if 'llama2' not in self.gen_ppl_eval_model_name_or_path: eval_model = eval_model.to(self.device) # Re-tokenize using eval model's tokenizer if retokenize: (samples, attn_mask, eval_context_size) = self.eval_retokenize( text_samples, max_length=max_length) else: samples = text_samples attn_mask = torch.ones(samples.shape).to(self.device) eval_context_size = samples.shape[-1] batch_size = min( self.config.eval.perplexity_batch_size, samples.shape[0]) num_batches = samples.shape[0] // batch_size for i in range(num_batches): _samples = torch.split( samples[i * batch_size: (i + 1) * batch_size], eval_context_size, dim=-1) _attn_mask = torch.split( attn_mask[i * batch_size: (i + 1) * batch_size], eval_context_size, dim=-1) for (sample_chunk, attn_mask_chunk) in zip( _samples, _attn_mask): logits = eval_model( sample_chunk, attention_mask=attn_mask_chunk)[0] logits = logits.transpose(-1, -2) nlls = F.cross_entropy(logits[..., :-1], sample_chunk[..., 1:], reduction='none') first_eos = (sample_chunk == self.eval_model_tokenizer\ .eos_token_id).cumsum(-1) == 1 token_mask = ( sample_chunk != self.eval_model_tokenizer.eos_token_id) self.gen_ppl_metric.update( nlls, first_eos[..., 1:] + token_mask[..., 1:]) def q_xt(self, x, move_chance): """Computes the noisy sample xt. Args: x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input. move_chance: float torch.Tensor with shape (batch_size, 1). """ actual_seq_length = (x != 1).sum(dim=1, keepdim=True) max_mask_length = (actual_seq_length * 0.75).long() move_indices = torch.rand(*x.shape, device=x.device) < move_chance restricted_move_indices = torch.zeros_like(move_indices, dtype=torch.bool) for i in range(x.shape[0]): true_positions = torch.where(move_indices[i])[0] if len(true_positions) > max_mask_length[i]: selected_positions = true_positions[:max_mask_length[i].item()] restricted_move_indices[i, selected_positions] = True else: restricted_move_indices[i] = move_indices[i] xt = torch.where(restricted_move_indices, self.mask_index, x) return xt def _sample_prior(self, *batch_dims): return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64) def _ddpm_caching_update(self, x, t, dt, p_x0=None, attention_mask=None): assert self.config.noise.type == 'loglinear' sigma_t, _ = self.noise(t) if t.ndim > 1: t = t.squeeze(-1) assert t.ndim == 1 move_chance_t = t[:, None, None] move_chance_s = (t - dt)[:, None, None] assert move_chance_t.ndim == 3, move_chance_t.shape if p_x0 is None: p_x0 = self.forward(x, sigma_t, attention_mask).exp() assert move_chance_t.ndim == p_x0.ndim q_xs = p_x0 * (move_chance_t - move_chance_s) q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] _x = _sample_categorical(q_xs) copy_flag = (x != self.mask_index).to(x.dtype) return p_x0, copy_flag * x + (1 - copy_flag) * _x def _ddpm_update(self, x, t, dt, attention_mask): sigma_t, _ = self.noise(t) sigma_s, _ = self.noise(t - dt) if sigma_t.ndim > 1: sigma_t = sigma_t.squeeze(-1) if sigma_s.ndim > 1: sigma_s = sigma_s.squeeze(-1) assert sigma_t.ndim == 1, sigma_t.shape assert sigma_s.ndim == 1, sigma_s.shape move_chance_t = 1 - torch.exp(-sigma_t) move_chance_s = 1 - torch.exp(-sigma_s) move_chance_t = move_chance_t[:, None, None] move_chance_s = move_chance_s[:, None, None] unet_conditioning = sigma_t log_p_x0 = self.forward(x, 
  def _ddpm_update(self, x, t, dt, attention_mask):
    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
      sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
      sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = 1 - torch.exp(-sigma_t)
    move_chance_s = 1 - torch.exp(-sigma_s)
    move_chance_t = move_chance_t[:, None, None]
    move_chance_s = move_chance_s[:, None, None]
    unet_conditioning = sigma_t
    log_p_x0 = self.forward(x, unet_conditioning, attention_mask)
    assert move_chance_t.ndim == log_p_x0.ndim
    # Technically, this isn't q_xs since there's a division
    # term that is missing. This division term doesn't affect
    # the samples.
    q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    _x = _sample_categorical(q_xs)

    copy_flag = (x != self.mask_index).to(x.dtype)
    return copy_flag * x + (1 - copy_flag) * _x

  def _ar_sampler(self, bsz):
    # Precompute token buffer.
    num_pred_tokens = self.config.model.length - 1
    x = torch.zeros(
      (bsz, num_pred_tokens + 1),
      dtype=torch.long,
      device=self.device)
    x[:, 0] = self.tokenizer.bos_token_id
    # Precompute noise.
    noise = (torch.distributions.Gumbel(0, 1)
             .sample((bsz, num_pred_tokens, self.vocab_size))
             .to(self.device))
    for i in range(num_pred_tokens):
      next_logits = self.forward(
        x[:, :i + 1], None, attention_mask=None)[:, -1]
      y = (next_logits + noise[:, i]).argmax(-1)
      x[:, i + 1] = y
    return x

  @torch.no_grad()
  def _sample(self, num_steps=None, eps=1e-5, x_input=None):
    """Generate samples from the model."""
    batch_size_per_gpu = self.config.eval.perplexity_batch_size
    if self.parameterization == 'ar':
      return self._ar_sampler(batch_size_per_gpu)
    # Lightning auto-casting is not working in this method for some reason.
    if num_steps is None:
      num_steps = self.config.sampling.steps
    if x_input is not None:
      x = x_input.input_ids
      attention_mask = x_input.attention_mask
    else:
      x = self._sample_prior(
        batch_size_per_gpu,
        self.config.model.length).to(self.device)
      attention_mask = torch.ones_like(x)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    for i in range(num_steps):
      t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
      if self.sampler == 'ddpm':
        x = self._ddpm_update(x, t, dt, attention_mask)
      elif self.sampler == 'ddpm_cache':
        p_x0_cache, x_next = self._ddpm_caching_update(
          x, t, dt, p_x0=p_x0_cache, attention_mask=attention_mask)
        if not torch.allclose(x_next, x) or self.time_conditioning:
          # Disable caching.
          p_x0_cache = None
        x = x_next
      else:
        x = self._analytic_update(x, t, dt, attention_mask)

    if self.config.sampling.noise_removal:
      t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
      if self.sampler == 'analytic':
        x = self._denoiser_update(x, t)
      else:
        unet_conditioning = self.noise(t)[0]
        x = self.forward(
          x, unet_conditioning, attention_mask,
          print_logits=True).argmax(dim=-1)
    return x

  def restore_model_and_sample(self, num_steps, eps=1e-5):
    """Generate samples from the model."""
    # Lightning auto-casting is not working in this method for some reason.
    if self.ema:
      self.ema.store(itertools.chain(self.backbone.parameters(),
                                     self.noise.parameters()))
      self.ema.copy_to(itertools.chain(self.backbone.parameters(),
                                       self.noise.parameters()))
    self.backbone.eval()
    self.noise.eval()
    samples = self._sample(num_steps=num_steps, eps=eps)
    if self.ema:
      self.ema.restore(itertools.chain(self.backbone.parameters(),
                                       self.noise.parameters()))
    self.backbone.train()
    self.noise.train()
    return samples
  def get_score(self, x, sigma, attention_mask=None):
    model_output = self.forward(x, sigma, attention_mask)
    if self.parameterization == 'subs':
      # score(x, t) = p_t(y) / p_t(x)
      # => log score(x, t) = log p_t(y) - log p_t(x)

      # case 1: x = masked
      #   (i) y = unmasked
      #     log score(x, t) = log p_\theta(x)|_y + log k
      #     where k = exp(- sigma) / (1 - exp(- sigma))
      #   (ii) y = masked
      #     log score(x, t) = 0

      # case 2: x = unmasked
      #   (i) y != masked, y != x
      #     log score(x_i, t) = - inf
      #   (ii) y = x
      #     log score(x_i, t) = 0
      #   (iii) y = masked token
      #     log score(x_i, t) = - log k
      #     where k = exp(- sigma) / (1 - exp(- sigma))

      log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
      assert log_k.ndim == 1

      masked_score = model_output + log_k[:, None, None]
      masked_score[:, :, self.mask_index] = 0

      unmasked_score = self.neg_infinity * torch.ones_like(model_output)
      unmasked_score = torch.scatter(
        unmasked_score,
        -1,
        x[..., None],
        torch.zeros_like(unmasked_score[..., :1]))
      unmasked_score[:, :, self.mask_index] = - (
        log_k[:, None] * torch.ones_like(x))

      masked_indices = (x == self.mask_index).to(
        model_output.dtype)[:, :, None]
      model_output = (
        masked_score * masked_indices
        + unmasked_score * (1 - masked_indices))
    return model_output.exp()

  def _staggered_score(self, score, dsigma):
    score = score.clone()
    extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
    score *= dsigma.exp()[:, None]
    score[..., self.mask_index] += extra_const
    return score

  def _analytic_update(self, x, t, step_size, attention_mask=None):
    curr_sigma, _ = self.noise(t)
    next_sigma, _ = self.noise(t - step_size)
    dsigma = curr_sigma - next_sigma
    score = self.get_score(x, curr_sigma, attention_mask)
    stag_score = self._staggered_score(score, dsigma)
    probs = stag_score * self._transp_transition(x, dsigma)
    return _sample_categorical(probs)

  def _denoiser_update(self, x, t):
    sigma, _ = self.noise(t)
    score = self.get_score(x, sigma)
    stag_score = self._staggered_score(score, sigma)
    probs = stag_score * self._transp_transition(x, sigma)
    probs[..., self.mask_index] = 0
    samples = _sample_categorical(probs)
    return samples

  def _transp_transition(self, i, sigma):
    sigma = _unsqueeze(sigma, reference=i[..., None])
    edge = torch.exp(-sigma) * F.one_hot(
      i, num_classes=self.vocab_size)
    edge += torch.where(i == self.mask_index,
                        1 - torch.exp(-sigma).squeeze(-1),
                        0)[..., None]
    return edge

  def _sample_t(self, n, device):
    _eps_t = torch.rand(n, device=device)
    if self.antithetic_sampling:
      offset = torch.arange(n, device=device) / n
      _eps_t = (_eps_t / n + offset) % 1
    t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
    if self.importance_sampling:
      return self.noise.importance_sampling_transformation(t)
    return t
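
  # Illustrative sketch (an addition): the stratified ("antithetic") scheme
  # used in `_sample_t`. Each uniform draw is squeezed into its own bin of
  # width 1/n, so a batch of n timesteps covers [0, 1) evenly instead of
  # clumping, which lowers the variance of the diffusion-loss estimate.
  @staticmethod
  def _example_antithetic_time_sampling(n=8):
    _eps_t = torch.rand(n)
    offset = torch.arange(n) / n
    # Sample i lands in [i/n, (i+1)/n).
    return (_eps_t / n + offset) % 1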
  def _maybe_sub_sample(self, x0, attention_mask):
    # seqlen = x0.shape[1]
    # if seqlen > self.config.model.length:
    #   assert seqlen == 2 * self.config.model.length
    #   # cropping is needed for text8-crop dataset
    #   # try the same starting point for now
    #   start = np.random.choice(self.config.model.length)
    #   end = start + self.config.model.length
    #   input_tokens = x0[:, start: end]
    #   output_tokens = x0[:, start + 1: end + 1]
    #   new_attention_mask = attention_mask[:, start: end]
    #   # Helps with validation PPL, since the val
    #   # examples will all start and end with BOS/EOS
    #   input_tokens[:, 0] = self.tokenizer.bos_token_id
    #   output_tokens[:, -1] = self.tokenizer.eos_token_id
    # elif self.parameterization == 'ar':
    #   input_tokens = x0[:, :-1]
    #   output_tokens = x0[:, 1:]
    #   new_attention_mask = attention_mask[:, 1:]
    # else:
    input_tokens = x0
    output_tokens = None
    new_attention_mask = attention_mask
    return input_tokens, output_tokens, new_attention_mask

  def _reconstruction_loss(self, x0, attention_mask):
    t0 = torch.zeros(x0.shape[0], dtype=self.dtype, device=self.device)
    assert self.config.noise.type == 'loglinear'
    # The above assert is for the d3pm parameterization.
    unet_conditioning = self.noise(t0)[0][:, None]
    model_output_t0 = self.forward(x0, unet_conditioning, attention_mask)
    return - torch.gather(input=model_output_t0,
                          dim=-1,
                          index=x0[:, :, None]).squeeze(-1)

  def _forward_pass_diffusion(self, x0, attention_mask, mask=None):
    t = self._sample_t(x0.shape[0], x0.device)
    if self.T > 0:
      t = (t * self.T).to(torch.int)
      t = t / self.T
      # t \in {1/T, 2/T, ..., 1}
      t += (1 / self.T)

    if self.change_of_variables:
      unet_conditioning = t[:, None]
      f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
      f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
      move_chance = torch.exp(f_0 + t * (f_T - f_0))
      move_chance = move_chance[:, None]
    else:
      sigma, dsigma = self.noise(t)
      unet_conditioning = sigma[:, None]
      move_chance = 1 - torch.exp(-sigma[:, None])

    if mask is None:
      xt = self.q_xt(x0, move_chance)
    else:
      xt = x0.where(
        mask == 1,
        torch.full_like(x0, self.tokenizer.mask_token_id))

    model_output = self.forward(xt, unet_conditioning, attention_mask)
    utils.print_nans(model_output, 'model_output')

    if self.parameterization == 'sedd':
      return dsigma[:, None] * self._score_entropy(
        model_output, sigma[:, None], xt, x0)

    if self.T > 0:
      diffusion_loss = self._d3pm_loss(
        model_output=model_output,
        xt=xt,
        x0=x0,
        t=t,
        attention_mask=attention_mask)
      if self.parameterization == 'd3pm':
        reconstruction_loss = self._reconstruction_loss(x0, attention_mask)
      elif self.parameterization == 'subs':
        reconstruction_loss = 0
      return reconstruction_loss + diffusion_loss

    # SUBS parameterization, continuous time.
    log_p_theta = torch.gather(
      input=model_output,
      dim=-1,
      index=x0[:, :, None]).squeeze(-1)

    if self.change_of_variables or self.importance_sampling:
      return log_p_theta * torch.log1p(
        - torch.exp(- self.noise.sigma_min))

    return - log_p_theta * (
      dsigma / torch.expm1(sigma))[:, None]

  def _loss(self, x0, attention_mask, mask=None):
    (input_tokens, output_tokens,
     attention_mask) = self._maybe_sub_sample(x0, attention_mask)

    if self.parameterization == 'ar':
      logprobs = self.backbone(input_tokens, None, attention_mask)
      loss = - logprobs.gather(
        -1, output_tokens[:, :, None])[:, :, 0]
    else:
      loss = self._forward_pass_diffusion(
        input_tokens, attention_mask, mask)

    nlls = loss * attention_mask
    count = attention_mask.sum()
    batch_nll = nlls.sum()
    token_nll = batch_nll / count

    return Loss(loss=token_nll,
                nlls=nlls,
                token_mask=attention_mask)
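
  # Illustrative sketch (an addition): the continuous-time SUBS weighting
  # applied at the end of `_forward_pass_diffusion`, where the per-token loss
  # is -log p_theta(x0 | xt) * dsigma / expm1(sigma). Assuming the log-linear
  # schedule sigma(t) = -log1p(-(1 - eps) * t) used by this codebase, that
  # weight works out to exactly 1/t, so tokens masked at small t (little
  # noise) contribute more to the loss.
  @staticmethod
  def _example_subs_nelbo_weight(eps=1e-3):
    t = torch.tensor([0.1, 0.5, 0.9])
    sigma = -torch.log1p(-(1 - eps) * t)
    dsigma = (1 - eps) / (1 - (1 - eps) * t)
    return dsigma / torch.expm1(sigma)  # == 1 / t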
  def _score_entropy(self, log_score, sigma, xt, x0):
    """Computes the SEDD loss.

    Args:
      log_score: float torch.Tensor with shape
        (batch_size, diffusion_model_input_length, vocab_size),
        log score, output of the denoising network.
      xt: int torch.Tensor with shape
        (batch_size, diffusion_model_input_length), input.
      x0: int torch.Tensor with shape
        (batch_size, diffusion_model_input_length), input.
      sigma: float torch.Tensor with shape (batch_size, 1).

    Returns:
      loss with shape (batch_size, diffusion_model_input_length)
    """
    masked_indices = xt == self.mask_index

    expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
    q_ratio = 1 / expsig_minus_1[masked_indices]

    words_that_were_masked = x0[masked_indices]

    neg_term = q_ratio * torch.gather(
      log_score[masked_indices],
      -1,
      words_that_were_masked[..., None]).squeeze(-1)
    score = log_score[masked_indices].exp()
    if self.mask_index == self.vocab_size - 1:
      pos_term = score[:, :-1].sum(dim=-1)
    else:
      pos_term = score[:, : self.mask_index].sum(
        dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
    const = q_ratio * (q_ratio.log() - 1)

    entropy = torch.zeros(*xt.shape, device=xt.device)
    entropy[masked_indices] += pos_term - neg_term + const
    return entropy

  @torch.no_grad()
  def sample_subs_guidance(
      self, n_samples, stride_length, num_strides, dt=0.001):
    ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)
    num_steps = int(1 / dt)
    sampling_steps = 0
    intermediate_tokens = []
    target = None

    for _ in range(num_strides + 1):
      p_x0_cache = None
      x = self._sample_prior(
        n_samples, self.config.model.length).to(self.device)
      if target is not None:
        x[:, : -stride_length] = target
      for i in range(num_steps + 1):
        p_x0_cache, x_next = self._ddpm_caching_update(
          x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
        if not torch.allclose(x_next, x) or self.time_conditioning:
          p_x0_cache = None
          sampling_steps += 1
        x = x_next
      x = self.forward(
        x, 0 * ones, attention_mask=torch.ones_like(x)).argmax(dim=-1)
      intermediate_tokens.append(
        x[:, :stride_length].cpu().numpy())
      target = x[:, stride_length:]

    intermediate_tokens.append(target.cpu().numpy())
    intermediate_text_samples = []
    sequence_lengths = ((
      np.concatenate(intermediate_tokens, axis=1)[:, 1:]
      == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
    for i in range(2, len(intermediate_tokens) + 1):
      intermediate_text_samples.append(
        self.tokenizer.batch_decode(
          np.concatenate(intermediate_tokens[:i], axis=1)))
    return (sampling_steps,
            intermediate_text_samples,
            sequence_lengths)

  def restore_model_and_semi_ar_sample(
      self, stride_length, num_strides, dt=0.001):
    """Generate samples from the model."""
    # Lightning auto-casting is not working in this method for some reason.
    if self.ema:
      self.ema.store(itertools.chain(self.backbone.parameters(),
                                     self.noise.parameters()))
      self.ema.copy_to(itertools.chain(self.backbone.parameters(),
                                       self.noise.parameters()))
    self.backbone.eval()
    self.noise.eval()
    (sampling_steps, samples,
     sequence_lengths) = self.sample_subs_guidance(
      n_samples=self.config.loader.eval_batch_size,
      stride_length=stride_length,
      num_strides=num_strides,
      dt=dt)
    if self.ema:
      self.ema.restore(itertools.chain(self.backbone.parameters(),
                                       self.noise.parameters()))
    self.backbone.train()
    self.noise.train()
    return sampling_steps, samples, sequence_lengths