|
import gc
import itertools
import math
import os
import sys
import time
import typing
from dataclasses import dataclass

import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import transformers
from torch import Tensor
from torch.optim.lr_scheduler import _LRScheduler
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

import ema
import noise_schedule
import pl_data_loader as dataloader
import utils
|
|
|
LOG2 = math.log(2) |
|
|
|
class CosineWarmup(_LRScheduler): |
|
def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1): |
|
self.warmup_steps = warmup_steps |
|
self.total_steps = total_steps |
|
self.eta_ratio = eta_ratio |
|
super(CosineWarmup, self).__init__(optimizer, last_epoch) |
|
|
|
def get_lr(self): |
|
if self.last_epoch < self.warmup_steps: |
|
return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs] |
|
|
|
progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps) |
|
cosine_decay = 0.5 * (1 + np.cos(np.pi * progress)) |
|
decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio |
|
|
|
return [decayed_lr * base_lr for base_lr in self.base_lrs] |
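

# Illustrative sketch (not used by the training code): attach CosineWarmup to
# an optimizer and record the learning-rate trajectory. The toy parameter and
# the step counts below are assumptions chosen only for demonstration.
def _demo_cosine_warmup():
  param = torch.nn.Parameter(torch.zeros(1))
  optimizer = torch.optim.AdamW([param], lr=1e-3)
  scheduler = CosineWarmup(optimizer, warmup_steps=10, total_steps=100)
  lrs = []
  for _ in range(100):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
  # Linear ramp to 1e-3 over the first 10 steps, then cosine decay towards
  # eta_ratio * base_lr = 1e-4.
  return lrs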
|
|
|
|
|
def _sample_categorical(categorical_probs): |
|
gumbel_norm = ( |
|
1e-10 |
|
- (torch.rand_like(categorical_probs) + 1e-10).log()) |
|
return (categorical_probs / gumbel_norm).argmax(dim=-1) |
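

# Illustrative sketch (not used elsewhere): _sample_categorical is the
# Gumbel-max / exponential-races trick, so empirical frequencies over many
# draws should approach the input probabilities. The probabilities below are
# assumptions used only for this check.
def _demo_sample_categorical(n_draws: int = 20000):
  probs = torch.tensor([0.1, 0.2, 0.7]).expand(n_draws, -1)
  draws = _sample_categorical(probs)
  return torch.bincount(draws, minlength=3).float() / n_draws  # ~ (0.1, 0.2, 0.7)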
|
|
|
|
|
def _unsqueeze(x, reference): |
|
return x.view( |
|
* x.shape, |
|
* ((1,) * (len(reference.shape) - len(x.shape)))) |
|
|
|
|
|
@dataclass |
|
class Loss: |
|
loss: torch.FloatTensor |
|
nlls: torch.FloatTensor |
|
token_mask: torch.FloatTensor |
|
|
|
|
|
class NLL(torchmetrics.aggregation.MeanMetric): |
|
pass |
|
|
|
|
|
class BPD(NLL): |
|
def compute(self) -> Tensor: |
|
"""Computes the bits per dimension. |
|
|
|
Returns: |
|
bpd |
|
""" |
|
return self.mean_value / self.weight / LOG2 |
|
|
|
|
|
class Perplexity(NLL): |
|
def compute(self) -> Tensor: |
|
"""Computes the Perplexity. |
|
|
|
Returns: |
|
Perplexity |
|
""" |
|
return torch.exp(self.mean_value / self.weight) |
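

# Illustrative sketch (not used elsewhere): the three aggregators share the
# same running state, a weighted mean of per-token NLLs in nats; they differ
# only in how `compute` post-processes it. The NLL values below are made up.
def _demo_nll_metrics():
  nlls = torch.tensor([2.0, 3.0, 1.0])
  weights = torch.tensor([1.0, 1.0, 2.0])
  nll, bpd, ppl = NLL(), BPD(), Perplexity()
  for metric in (nll, bpd, ppl):
    metric.update(nlls, weights)
  # weighted mean NLL = (2 + 3 + 1 * 2) / 4 = 1.75 nats
  return nll.compute(), bpd.compute(), ppl.compute()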
|
|
|
|
|
class WrapVanillaESM(nn.Module): |
|
def __init__(self, bert_model_path): |
|
super(WrapVanillaESM, self).__init__() |
|
|
|
|
|
self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu') |
|
self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path) |
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
return self.model(*args, **kwargs) |
|
|
|
def unfreeze_attn_layers(self): |
|
model_layers = len(self.model.esm.encoder.layer) |
|
|
|
for i, layer in enumerate(self.model.esm.encoder.layer): |
|
if i >= model_layers-5: |
|
for module in layer.attention.self.key.modules(): |
|
for param in module.parameters(): |
|
param.requires_grad = True |
|
for module in layer.attention.self.query.modules(): |
|
for param in module.parameters(): |
|
param.requires_grad = True |
|
for module in layer.attention.self.value.modules(): |
|
for param in module.parameters(): |
|
param.requires_grad = True |
|
|
|
def unfreeze_all_layers(self): |
|
for param in self.model.parameters(): |
|
param.requires_grad = True |
|
|
|
  def forward(self, inputs, sigma, attention_mask):
    # sigma is accepted for interface compatibility with time-conditioned
    # backbones but is unused: ESM is not conditioned on the noise level.
    logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
    return logits
|
|
|
def save_model(self, save_dir): |
|
self.model.save_pretrained(save_dir) |
|
self.tokenizer.save_pretrained(save_dir) |
|
|
|
  def load_model(self, load_dir):
    # Reload with the MLM head; plain AutoModel would drop the LM head that
    # forward() relies on for logits.
    self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
|
|
|
class WrapMembraneESM(nn.Module): |
|
def __init__(self, bert_model_path): |
|
super(WrapMembraneESM, self).__init__() |
|
|
|
|
|
self.model = AutoModelForMaskedLM.from_pretrained(bert_model_path, device_map='cpu') |
|
self.tokenizer = AutoTokenizer.from_pretrained(bert_model_path) |
|
|
|
def __call__(self, *args, **kwargs): |
|
return self.model(*args, **kwargs) |
|
|
|
def freeze_model(self): |
|
for param in self.model.parameters(): |
|
param.requires_grad = False |
|
|
|
def unfreeze_all_layers(self): |
|
for param in self.model.parameters(): |
|
param.requires_grad = True |
|
|
|
def unfreeze_attn_layers(self): |
|
model_layers = len(self.model.esm.encoder.layer) |
|
|
|
for i, layer in enumerate(self.model.esm.encoder.layer): |
|
if i >= model_layers-11: |
|
for module in layer.attention.self.key.modules(): |
|
for param in module.parameters(): |
|
param.requires_grad = True |
|
for module in layer.attention.self.query.modules(): |
|
for param in module.parameters(): |
|
param.requires_grad = True |
|
for module in layer.attention.self.value.modules(): |
|
for param in module.parameters(): |
|
param.requires_grad = True |
|
|
|
  def forward(self, inputs, sigma, attention_mask):
    # sigma is accepted for interface compatibility with time-conditioned
    # backbones but is unused: ESM is not conditioned on the noise level.
    logits = self.model(input_ids=inputs, attention_mask=attention_mask).logits
    return logits
|
|
|
def save_model(self, save_dir): |
|
self.model.save_pretrained(save_dir) |
|
self.tokenizer.save_pretrained(save_dir) |
|
|
|
  def load_model(self, load_dir):
    # Reload with the MLM head; plain AutoModel would drop the LM head that
    # forward() relies on for logits.
    self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
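

# Illustrative sketch (not used by the training code): load one of the ESM
# wrappers, unfreeze it, and run a single forward pass. It assumes network
# access and the public checkpoint "facebook/esm2_t6_8M_UR50D"; the real
# configs point at their own model paths.
def _demo_wrap_esm(sequence: str = 'MKTAYIAKQR'):
  wrapper = WrapVanillaESM(bert_model_path='facebook/esm2_t6_8M_UR50D')
  wrapper.unfreeze_all_layers()
  batch = wrapper.tokenizer(sequence, return_tensors='pt')
  logits = wrapper.forward(batch['input_ids'],
                           sigma=None,
                           attention_mask=batch['attention_mask'])
  return logits.shape  # (1, sequence length + special tokens, vocab size)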
|
|
|
class Diffusion(L.LightningModule): |
|
def __init__( |
|
self, |
|
config, |
|
tokenizer: transformers.PreTrainedTokenizer): |
|
super().__init__() |
|
self.save_hyperparameters() |
|
self.config = config |
|
|
|
self.tokenizer = tokenizer |
|
self.vocab_size = self.tokenizer.vocab_size |
|
self.sampler = self.config.sampling.predictor |
|
self.gen_ppl_eval_model_name_or_path = self.config.eval.\ |
|
gen_ppl_eval_model_name_or_path |
|
self.antithetic_sampling = self.config.training.antithetic_sampling |
|
self.importance_sampling = self.config.training.importance_sampling |
|
self.change_of_variables = self.config.training.change_of_variables |
|
if (not hasattr(self.tokenizer, 'mask_token') |
|
or self.tokenizer.mask_token is None): |
|
self.mask_index = self.vocab_size |
|
self.vocab_size += 1 |
|
else: |
|
self.mask_index = self.tokenizer.mask_token_id |
|
self.parameterization = self.config.parameterization |
|
|
|
|
|
|
|
|
|
|
|
    if self.config.backbone == 'vanilla_esm_pretrain':
      self.backbone = WrapVanillaESM(bert_model_path=self.config.training.esm_model_path)
      self.backbone.unfreeze_all_layers()
      self.backbone = torch.compile(self.backbone)
    elif self.config.backbone == 'membrane_esm_finetune':
      self.backbone = WrapMembraneESM(bert_model_path=self.config.checkpointing.pretrained_esm_mdlm_automodel_path)
      self.backbone.unfreeze_all_layers()
    else:
      raise ValueError(f'Unknown backbone: {self.config.backbone}')
|
|
self.T = self.config.T |
|
self.subs_masking = self.config.subs_masking |
|
|
|
self.softplus = torch.nn.Softplus() |
|
|
|
metrics = torchmetrics.MetricCollection({ |
|
'nll': NLL(), |
|
'bpd': BPD(), |
|
'ppl': Perplexity(), |
|
}) |
|
metrics.set_dtype(torch.float64) |
|
self.train_metrics = metrics.clone(prefix='train/') |
|
self.valid_metrics = metrics.clone(prefix='val/') |
|
self.test_metrics = metrics.clone(prefix='test/') |
|
|
|
|
|
self.gen_ppl_metric = Perplexity() |
|
self.eval_model_tokenizer = transformers.AutoTokenizer.\ |
|
from_pretrained(self.gen_ppl_eval_model_name_or_path) |
|
if self.eval_model_tokenizer.pad_token is None: |
|
self.eval_model_tokenizer.pad_token =\ |
|
self.eval_model_tokenizer.eos_token |
|
self.eval_model_tokenizer.pad_token_id =\ |
|
self.eval_model_tokenizer.eos_token_id |
|
|
|
self.noise = noise_schedule.get_noise(self.config, |
|
dtype=self.dtype) |
|
if self.config.training.ema > 0: |
|
self.ema = ema.ExponentialMovingAverage( |
|
itertools.chain(self.backbone.parameters(), |
|
self.noise.parameters()), |
|
decay=self.config.training.ema) |
|
else: |
|
self.ema = None |
|
|
|
self.lr = self.config.optim.lr |
|
self.sampling_eps = self.config.training.sampling_eps |
|
self.time_conditioning = self.config.time_conditioning |
|
self.neg_infinity = -1000000.0 |
|
self.fast_forward_epochs = None |
|
self.fast_forward_batches = None |
|
self._validate_configuration() |
|
|
|
def _validate_configuration(self): |
|
assert not (self.change_of_variables |
|
and self.importance_sampling) |
|
if self.parameterization == 'sedd': |
|
assert not self.importance_sampling |
|
assert not self.change_of_variables |
|
if self.parameterization == 'd3pm': |
|
assert self.T > 0 |
|
if self.T > 0: |
|
assert self.parameterization in {'d3pm', 'subs'} |
|
if self.subs_masking: |
|
assert self.parameterization == 'd3pm' |
|
|
|
def on_load_checkpoint(self, checkpoint): |
|
if self.ema: |
|
self.ema.load_state_dict(checkpoint['ema']) |
|
|
|
|
|
self.fast_forward_epochs = checkpoint['loops'][ |
|
'fit_loop']['epoch_progress']['current']['completed'] |
|
self.fast_forward_batches = checkpoint['loops'][ |
|
'fit_loop']['epoch_loop.batch_progress'][ |
|
'current']['completed'] |
|
|
|
def on_save_checkpoint(self, checkpoint): |
|
if self.ema: |
|
checkpoint['ema'] = self.ema.state_dict() |
|
|
|
|
|
|
|
|
|
checkpoint['loops']['fit_loop'][ |
|
'epoch_loop.batch_progress']['total'][ |
|
'completed'] = checkpoint['loops']['fit_loop'][ |
|
'epoch_loop.automatic_optimization.optim_progress'][ |
|
'optimizer']['step']['total'][ |
|
'completed'] * self.trainer.accumulate_grad_batches |
|
checkpoint['loops']['fit_loop'][ |
|
'epoch_loop.batch_progress']['current'][ |
|
'completed'] = checkpoint['loops']['fit_loop'][ |
|
'epoch_loop.automatic_optimization.optim_progress'][ |
|
'optimizer']['step']['current'][ |
|
'completed'] * self.trainer.accumulate_grad_batches |
|
|
|
|
|
checkpoint['loops']['fit_loop'][ |
|
'epoch_loop.state_dict'][ |
|
'_batches_that_stepped'] = checkpoint['loops']['fit_loop'][ |
|
'epoch_loop.automatic_optimization.optim_progress'][ |
|
'optimizer']['step']['total']['completed'] |
|
if 'sampler' not in checkpoint.keys(): |
|
checkpoint['sampler'] = {} |
|
if hasattr(self.trainer.train_dataloader.sampler, |
|
'state_dict'): |
|
sampler_state_dict = self.trainer.\ |
|
train_dataloader.sampler.state_dict() |
|
checkpoint['sampler'][ |
|
'random_state'] = sampler_state_dict.get( |
|
'random_state', None) |
|
else: |
|
checkpoint['sampler']['random_state'] = None |
|
|
|
self.backbone.save_model(self.config.checkpointing.fine_tuned_esm_mdlm_ckpt_path) |
|
|
|
def on_train_start(self): |
|
torch.cuda.empty_cache() |
|
if self.ema: |
|
self.ema.move_shadow_params_to_device(self.device) |
|
|
|
|
|
|
|
distributed = ( |
|
self.trainer._accelerator_connector.use_distributed_sampler |
|
and self.trainer._accelerator_connector.is_distributed) |
|
if distributed: |
|
sampler_cls = dataloader.FaultTolerantDistributedSampler |
|
else: |
|
sampler_cls = dataloader.RandomFaultTolerantSampler |
|
updated_dls = [] |
|
for dl in self.trainer.fit_loop._combined_loader.flattened: |
|
if hasattr(dl.sampler, 'shuffle'): |
|
dl_sampler = sampler_cls( |
|
dl.dataset, shuffle=dl.sampler.shuffle) |
|
else: |
|
dl_sampler = sampler_cls(dl.dataset) |
|
if (distributed |
|
and self.fast_forward_epochs is not None |
|
and self.fast_forward_batches is not None): |
|
dl_sampler.load_state_dict({ |
|
'epoch': self.fast_forward_epochs, |
|
'counter': (self.fast_forward_batches |
|
* self.config.loader.batch_size)}) |
|
|
|
from functools import partial |
|
from pl_data_loader import collate_fn |
|
collate_partial = partial(collate_fn, tokenizer=self.tokenizer) |
|
torch.cuda.empty_cache() |
|
|
|
updated_dls.append( |
|
torch.utils.data.DataLoader( |
|
dl.dataset, |
|
batch_size=self.config.loader.batch_size, |
|
num_workers=self.config.loader.num_workers, |
|
pin_memory=self.config.loader.pin_memory, |
|
sampler=dl_sampler, |
|
shuffle=False, |
|
persistent_workers=False, |
|
collate_fn=collate_partial)) |
|
self.trainer.fit_loop._combined_loader.flattened = updated_dls |
|
|
|
def optimizer_step(self, *args, **kwargs): |
|
super().optimizer_step(*args, **kwargs) |
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
if self.ema: |
|
self.ema.update(itertools.chain( |
|
self.backbone.parameters(), |
|
self.noise.parameters())) |
|
|
|
|
  def _subs_parameterization(self, logits, xt):
    # The backbone returns a HuggingFace MaskedLMOutput; pull out the raw
    # logits tensor of shape (batch, length, vocab).
    logits = logits.logits
    # Zero masking probabilities: the model is never allowed to predict the
    # [MASK] token itself.
    logits[:, :, self.mask_index] += self.neg_infinity

    # Normalize so the output is a valid log-probability distribution over
    # the vocabulary.
    logits = logits - torch.logsumexp(logits, dim=-1,
                                      keepdim=True)

    # Carry-over unmasking: positions that are already unmasked keep their
    # current token with probability one.
    unmasked_indices = (xt != self.mask_index)
    logits[unmasked_indices] = self.neg_infinity
    logits[unmasked_indices, xt[unmasked_indices]] = 0
    return logits
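
  # Illustrative sketch (not called anywhere): the two SUBS constraints
  # enforced above are (i) the [MASK] token gets (near) zero probability and
  # (ii) already-unmasked positions put all mass on their current token
  # ("carry-over unmasking"). The toy vocab size and mask id are assumptions.
  @staticmethod
  def _demo_subs_constraints(vocab_size: int = 5, mask_index: int = 4):
    logits = torch.randn(1, 3, vocab_size)
    xt = torch.tensor([[0, mask_index, 2]])
    logits[:, :, mask_index] += -1e6
    logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    unmasked = (xt != mask_index)
    logits[unmasked] = -1e6
    logits[unmasked, xt[unmasked]] = 0
    assert logits[0, 0, 0] == 0 and logits[0, 2, 2] == 0
    return logits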
|
|
|
def _d3pm_parameterization(self, logits): |
|
if self.subs_masking: |
|
logits[:, :, self.mask_index] += self.neg_infinity |
|
logits = logits - torch.logsumexp(logits, dim=-1, |
|
keepdim=True) |
|
return logits |
|
|
|
def _sedd_parameterization(self, logits, xt, sigma): |
|
esigm1_log = torch.where( |
|
sigma < 0.5, |
|
torch.expm1(sigma), |
|
sigma.exp() - 1).log().to(logits.dtype) |
|
|
|
|
|
logits = logits - esigm1_log[:, None, None] - np.log( |
|
logits.shape[-1] - 1) |
|
|
|
|
|
logits = torch.scatter(logits, -1, xt[..., None], |
|
torch.zeros_like(logits[..., :1])) |
|
return logits |
|
|
|
def _process_sigma(self, sigma): |
|
if sigma is None: |
|
assert self.parameterization == 'ar' |
|
return sigma |
|
if sigma.ndim > 1: |
|
sigma = sigma.squeeze(-1) |
|
if not self.time_conditioning: |
|
sigma = torch.zeros_like(sigma) |
|
assert sigma.ndim == 1, sigma.shape |
|
return sigma |
|
|
|
def forward(self, x, sigma, attention_mask, print_logits=False): |
|
"""Returns log score.""" |
|
sigma = self._process_sigma(sigma) |
|
    with torch.amp.autocast("cuda", dtype=torch.float32):
      # The ESM wrappers route __call__ straight to the underlying HF model,
      # so only the input ids and attention mask are passed here; sigma is
      # not used by the non-time-conditioned backbone.
      logits = self.backbone(x, attention_mask)
|
|
|
|
|
|
|
|
|
if self.parameterization == 'subs': |
|
return self._subs_parameterization(logits=logits, xt=x) |
|
return logits |
|
|
|
  def _d3pm_loss(self, model_output, xt, x0, t, attention_mask=None):
|
dt = 1 / self.T |
|
|
|
if torch.is_tensor(t): |
|
t = t[:, None] |
|
assert t.ndim == 2 |
|
t = t.clamp(0., 1. - 1e-4) |
|
alpha_t = 1 - t + torch.zeros_like(xt) |
|
alpha_s = 1 - (t - dt) + torch.zeros_like(xt) |
|
|
|
log_x_theta_at_x0 = torch.gather( |
|
model_output, -1, x0[:, :, None]).squeeze(-1) |
|
log_x_theta_at_m = model_output[:, :, self.mask_index] |
|
x_theta_at_m = log_x_theta_at_m.exp() |
|
|
|
term_1_coef = dt / t |
|
term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1) |
|
term_1_log_dr = log_x_theta_at_x0 |
|
|
|
term_2_coef = 1 - dt / t |
|
term_2_log_nr = term_1_log_nr |
|
term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1) |
|
|
|
L_vb_masked = ( |
|
term_1_coef * (term_1_log_nr - term_1_log_dr) |
|
+ term_2_coef * (term_2_log_nr - term_2_log_dr)) |
|
|
|
L_vb = L_vb_masked * (xt == self.mask_index) |
|
|
|
return self.T * L_vb |
|
|
|
def _compute_loss(self, batch, prefix): |
|
if 'attention_mask' in batch: |
|
attention_mask = batch['attention_mask'] |
|
else: |
|
attention_mask = None |
|
    if 'mask' in batch:
      mask = batch['mask']
    else:
      mask = None
|
|
|
losses = self._loss(batch['input_ids'], attention_mask, mask) |
|
loss = losses.loss |
|
|
|
if prefix == 'train': |
|
self.train_metrics.update(losses.nlls, losses.token_mask) |
|
metrics = self.train_metrics |
|
elif prefix == 'val': |
|
self.valid_metrics.update(losses.nlls, losses.token_mask) |
|
metrics = self.valid_metrics |
|
elif prefix == 'test': |
|
self.test_metrics.update(losses.nlls, losses.token_mask) |
|
metrics = self.test_metrics |
|
else: |
|
raise ValueError(f'Invalid prefix: {prefix}') |
|
|
|
self.log_dict(metrics, |
|
on_step=False, |
|
on_epoch=True, |
|
sync_dist=True) |
|
return loss |
|
|
|
def on_train_epoch_start(self): |
|
self.backbone.train() |
|
self.noise.train() |
|
|
|
def training_step(self, batch, batch_idx): |
|
|
|
start_time = time.time() |
|
|
|
loss = self._compute_loss(batch, prefix='train') |
|
self.log(name='trainer/loss', |
|
value=loss.item(), |
|
on_step=True, |
|
on_epoch=False, |
|
sync_dist=True) |
|
|
|
|
|
elapsed_time = time.time() - start_time |
|
total_tokens = batch['input_ids'].numel() |
|
throughput = total_tokens / elapsed_time |
|
|
|
self.log(name='trainer/throughput', |
|
value=throughput, |
|
on_step=True, |
|
on_epoch=False, |
|
sync_dist=True) |
|
|
|
return loss |
|
|
|
def on_validation_epoch_start(self): |
|
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
if self.ema: |
|
self.ema.store( |
|
itertools.chain( |
|
self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.ema.copy_to(itertools.chain( |
|
self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.backbone.eval() |
|
self.noise.eval() |
|
assert self.valid_metrics.nll.mean_value == 0 |
|
assert self.valid_metrics.nll.weight == 0 |
|
|
|
|
|
def validation_step(self, batch, batch_idx): |
|
loss = self._compute_loss(batch, prefix='val') |
|
self.log(name='trainer/val_loss', |
|
value=loss.item(), |
|
on_step=True, |
|
on_epoch=False, |
|
prog_bar=True, |
|
sync_dist=True) |
|
return loss |
|
|
|
def on_validation_epoch_end(self): |
|
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
if self.ema: |
|
self.ema.restore( |
|
itertools.chain( |
|
self.backbone.parameters(), |
|
self.noise.parameters())) |
|
|
|
def test_step(self, batch, batch_idx): |
|
loss = self._compute_loss(batch, prefix='test') |
|
self.log('test/loss', |
|
value=loss.item(), |
|
on_step=False, |
|
on_epoch=True, |
|
sync_dist=True) |
|
|
|
if self.config.eval.compute_generative_perplexity: |
|
samples, text_samples = None, None |
|
for _ in range( |
|
self.config.sampling.num_sample_batches): |
|
samples = self._sample() |
|
|
|
text_samples = self.tokenizer.batch_decode(samples) |
|
if self.config.eval.compute_generative_perplexity: |
|
self.compute_generative_perplexity(text_samples) |
|
if self.trainer.global_rank == 0 and hasattr( |
|
self.trainer.logger, 'log_table'): |
|
|
|
text_samples = text_samples[ |
|
: self.config.sampling.num_sample_log] |
|
self.trainer.logger.log_table( |
|
key=f'samples@global_step{self.global_step}', |
|
columns=['Generated Samples'], |
|
data=[[s] for s in text_samples]) |
|
if self.config.eval.compute_generative_perplexity: |
|
self.log('test/gen_ppl', |
|
self.gen_ppl_metric, |
|
on_epoch=False, |
|
on_step=True, |
|
sync_dist=True) |
|
|
|
def on_test_epoch_start(self): |
|
|
if self.ema: |
|
self.ema.store(itertools.chain( |
|
self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.ema.copy_to(itertools.chain( |
|
self.backbone.parameters(), |
|
self.noise.parameters())) |
|
|
|
self.backbone.eval() |
|
self.noise.eval() |
|
self.test_metrics.reset() |
|
|
|
def on_test_epoch_end(self): |
|
|
if self.ema: |
|
self.ema.restore(itertools.chain( |
|
self.backbone.parameters(), |
|
self.noise.parameters())) |
|
|
|
for metric_name, metric_value in self.test_metrics.compute().items(): |
|
self.log(metric_name, metric_value, sync_dist=True) |
|
|
|
def configure_optimizers(self): |
|
|
optimizer = torch.optim.AdamW( |
|
itertools.chain(self.backbone.parameters(), |
|
self.noise.parameters()), |
|
lr=self.config.optim.lr, |
|
betas=(self.config.optim.beta1, |
|
self.config.optim.beta2), |
|
eps=self.config.optim.eps, |
|
weight_decay=self.config.optim.weight_decay |
|
) |
|
|
|
|
self.total_steps = self.config.trainer.max_steps |
|
scheduler = CosineWarmup(optimizer, |
|
warmup_steps=self.config.lr_scheduler.num_warmup_steps, |
|
total_steps=self.total_steps) |
|
|
|
scheduler_dict = { |
|
'scheduler': scheduler, |
|
'interval': 'step', |
|
'frequency': 1, |
|
'monitor': 'val/loss', |
|
'name': 'trainer/lr' |
|
} |
|
|
|
return [optimizer], [scheduler_dict] |
|
|
|
@torch.no_grad() |
|
def eval_retokenize(self, text_samples, max_length): |
|
"""Retokenizes samples for the eval model. |
|
|
|
Args: |
|
text_samples: List of sentences generated by the model. |
|
Returns: |
|
samples: Samples re-tokenized for the eval model |
|
attn_mask: Attention mask for the eval model |
|
eval_context_size: Size of the context for the eval model |
|
""" |
|
if 'llama2' in self.gen_ppl_eval_model_name_or_path: |
|
      # The text is passed positionally to the tokenizer call below, so it is
      # not duplicated inside the kwargs.
      tokenizer_kwargs = {
        'return_tensors': 'pt',
        'return_token_type_ids': False,
        'return_attention_mask': True,
        'truncation': True,
        'padding': True,
        'max_length': max_length,
      }
|
eval_context_size = 4096 |
|
else: |
|
tokenizer_kwargs = { |
|
'return_tensors': 'pt', |
|
'return_token_type_ids': False, |
|
'return_attention_mask': True, |
|
'truncation': True, |
|
'padding': True, |
|
'max_length': max_length, |
|
} |
|
eval_context_size = 1024 |
|
samples = self.eval_model_tokenizer( |
|
text_samples, ** tokenizer_kwargs) |
|
attn_mask = samples['attention_mask'] |
|
samples = samples['input_ids'] |
|
if 'llama2' not in self.gen_ppl_eval_model_name_or_path: |
|
attn_mask = attn_mask.to(self.device) |
|
samples = samples.to(self.device) |
|
return samples, attn_mask, eval_context_size |
|
|
|
|
@torch.no_grad() |
|
def compute_masked_perplexity(self, sequences, masked): |
|
"""Compute the pseudo-perplexity of the generated protein sequences.""" |
|
total_nll = 0 |
|
total_tokens = 0 |
|
|
|
for sequence in sequences: |
|
|
|
input_ids = self.tokenizer(masked, return_tensors="pt").input_ids.to(self.device) |
|
gt_ids = self.tokenizer(sequence.upper(), return_tensors="pt").input_ids.to(self.device) |
|
|
|
|
|
|
|
|
|
|
|
      attention_mask = torch.ones_like(input_ids)
      if self.config.mode in ['train', 'ppl_eval']:
        outputs = self.backbone.model.forward(input_ids=input_ids, attention_mask=attention_mask)
      elif self.config.mode == "sample_eval":
        outputs = self.backbone.model.forward(input_ids)
      # Use the named field rather than positional indexing into the
      # HuggingFace output object.
      logits = outputs.logits
|
|
|
|
|
|
|
|
|
|
|
|
|
      # Score only the positions that were masked in the input;
      # self.mask_index is the ESM <mask> token id (32) when the tokenizer
      # defines one.
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                             gt_ids.where(input_ids == self.mask_index,
                                          torch.full_like(input_ids, -100)).view(-1),
                             reduction='sum')
|
|
|
total_nll += loss.item() |
|
|
|
total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() |
|
|
|
|
|
pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens)) |
|
self.gen_ppl_metric.update(pseudo_perplexity) |
|
|
|
return pseudo_perplexity.item() |
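
  # Illustrative sketch (not called anywhere) of the pseudo-perplexity
  # arithmetic used above: cross-entropy summed over the scored positions,
  # normalised by a token count, then exponentiated. Toy logits and targets
  # are made up; -100 marks positions that cross_entropy ignores.
  @staticmethod
  def _demo_pseudo_perplexity():
    logits = torch.randn(1, 4, 6)
    targets = torch.tensor([[1, -100, 3, -100]])
    nll = F.cross_entropy(logits.view(-1, 6), targets.view(-1),
                          reduction='sum')
    n_tokens = (targets != -100).sum()
    return torch.exp(nll / n_tokens)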
|
|
|
@torch.no_grad() |
|
def compute_generative_perplexity( |
|
self, |
|
text_samples: typing.List[str], |
|
retokenize: bool = True, |
|
max_length: typing.Optional[int] = None) -> None: |
|
"""Compute the generative perplexity of the model. |
|
|
|
Args: |
|
text_samples: List of sentences generated by the model. |
|
|
|
    Returns:
      None. The perplexity of the generated text under a separate
        pre-trained AR model (e.g., GPT-2) is accumulated in
        `self.gen_ppl_metric`.
    """
|
os.environ['TOKENIZERS_PARALLELISM'] = 'false' |
|
eval_model = transformers.AutoModelForCausalLM.from_pretrained( |
|
self.gen_ppl_eval_model_name_or_path).eval() |
|
if max_length is None: |
|
max_length = self.config.model.length |
|
if 'llama2' not in self.gen_ppl_eval_model_name_or_path: |
|
eval_model = eval_model.to(self.device) |
|
|
|
if retokenize: |
|
(samples, attn_mask, |
|
eval_context_size) = self.eval_retokenize( |
|
text_samples, max_length=max_length) |
|
else: |
|
samples = text_samples |
|
attn_mask = torch.ones(samples.shape).to(self.device) |
|
eval_context_size = samples.shape[-1] |
|
batch_size = min( |
|
self.config.eval.perplexity_batch_size, |
|
samples.shape[0]) |
|
num_batches = samples.shape[0] // batch_size |
|
for i in range(num_batches): |
|
_samples = torch.split( |
|
samples[i * batch_size: (i + 1) * batch_size], |
|
eval_context_size, |
|
dim=-1) |
|
_attn_mask = torch.split( |
|
attn_mask[i * batch_size: (i + 1) * batch_size], |
|
eval_context_size, |
|
dim=-1) |
|
for (sample_chunk, attn_mask_chunk) in zip( |
|
_samples, _attn_mask): |
|
logits = eval_model( |
|
sample_chunk, attention_mask=attn_mask_chunk)[0] |
|
logits = logits.transpose(-1, -2) |
|
|
|
nlls = F.cross_entropy(logits[..., :-1], |
|
sample_chunk[..., 1:], |
|
reduction='none') |
|
first_eos = (sample_chunk == self.eval_model_tokenizer\ |
|
.eos_token_id).cumsum(-1) == 1 |
|
token_mask = ( |
|
sample_chunk |
|
!= self.eval_model_tokenizer.eos_token_id) |
|
self.gen_ppl_metric.update( |
|
nlls, first_eos[..., 1:] + token_mask[..., 1:]) |
|
|
|
def q_xt(self, x, move_chance): |
|
"""Computes the noisy sample xt. |
|
|
|
Args: |
|
x: int torch.Tensor with shape (batch_size, |
|
diffusion_model_input_length), input. |
|
      move_chance: float torch.Tensor with shape (batch_size, 1).

    Note:
      Masking is additionally capped at 75% of each sequence's non-pad
      length (pad token id 1), so a sequence is never fully masked.
    """
|
|
|
actual_seq_length = (x != 1).sum(dim=1, keepdim=True) |
|
|
|
max_mask_length = (actual_seq_length * 0.75).long() |
|
|
|
move_indices = torch.rand(*x.shape, device=x.device) < move_chance |
|
|
|
restricted_move_indices = torch.zeros_like(move_indices, dtype=torch.bool) |
|
|
|
for i in range(x.shape[0]): |
|
true_positions = torch.where(move_indices[i])[0] |
|
if len(true_positions) > max_mask_length[i]: |
|
selected_positions = true_positions[:max_mask_length[i].item()] |
|
restricted_move_indices[i, selected_positions] = True |
|
else: |
|
restricted_move_indices[i] = move_indices[i] |
|
xt = torch.where(restricted_move_indices, self.mask_index, x) |
|
|
|
return xt |
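
  # Illustrative sketch (not called anywhere) of the forward noising above:
  # tokens are masked independently with probability `move_chance`, and the
  # number of masked positions is capped at 75% of the unpadded length.
  # Pad id 1 and mask id 32 follow the ESM tokenizer but are assumptions here.
  @staticmethod
  def _demo_q_xt(move_chance: float = 0.5, pad_id: int = 1, mask_id: int = 32):
    x = torch.tensor([[5, 6, 7, 8, pad_id, pad_id]])
    seq_len = (x != pad_id).sum(dim=1, keepdim=True)
    cap = (seq_len * 0.75).long()
    move = torch.rand(x.shape, device=x.device) < move_chance
    keep = torch.zeros_like(move)
    for i in range(x.shape[0]):
      picked = torch.where(move[i])[0][:cap[i].item()]
      keep[i, picked] = True
    return torch.where(keep, torch.full_like(x, mask_id), x)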
|
|
|
def _sample_prior(self, *batch_dims): |
|
return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64) |
|
|
|
def _ddpm_caching_update(self, x, t, dt, p_x0=None, attention_mask=None): |
|
assert self.config.noise.type == 'loglinear' |
|
sigma_t, _ = self.noise(t) |
|
if t.ndim > 1: |
|
t = t.squeeze(-1) |
|
assert t.ndim == 1 |
|
move_chance_t = t[:, None, None] |
|
move_chance_s = (t - dt)[:, None, None] |
|
assert move_chance_t.ndim == 3, move_chance_t.shape |
|
if p_x0 is None: |
|
p_x0 = self.forward(x, sigma_t, attention_mask).exp() |
|
|
|
assert move_chance_t.ndim == p_x0.ndim |
|
q_xs = p_x0 * (move_chance_t - move_chance_s) |
|
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] |
|
_x = _sample_categorical(q_xs) |
|
|
|
copy_flag = (x != self.mask_index).to(x.dtype) |
|
return p_x0, copy_flag * x + (1 - copy_flag) * _x |
|
|
|
def _ddpm_update(self, x, t, dt, attention_mask): |
|
sigma_t, _ = self.noise(t) |
|
sigma_s, _ = self.noise(t - dt) |
|
if sigma_t.ndim > 1: |
|
sigma_t = sigma_t.squeeze(-1) |
|
if sigma_s.ndim > 1: |
|
sigma_s = sigma_s.squeeze(-1) |
|
assert sigma_t.ndim == 1, sigma_t.shape |
|
assert sigma_s.ndim == 1, sigma_s.shape |
|
move_chance_t = 1 - torch.exp(-sigma_t) |
|
move_chance_s = 1 - torch.exp(-sigma_s) |
|
move_chance_t = move_chance_t[:, None, None] |
|
move_chance_s = move_chance_s[:, None, None] |
|
unet_conditioning = sigma_t |
|
log_p_x0 = self.forward(x, unet_conditioning, attention_mask) |
|
assert move_chance_t.ndim == log_p_x0.ndim |
|
|
|
|
|
|
|
q_xs = log_p_x0.exp() * (move_chance_t |
|
- move_chance_s) |
|
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0] |
|
_x = _sample_categorical(q_xs) |
|
|
|
copy_flag = (x != self.mask_index).to(x.dtype) |
|
return copy_flag * x + (1 - copy_flag) * _x |
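
  # Illustrative sketch (not called anywhere) of the reverse step above for a
  # single masked position: the unnormalised weight of revealing token k is
  # p_theta(k) * (move_chance_t - move_chance_s) and the weight of staying
  # masked is move_chance_s. Toy numbers are assumptions.
  @staticmethod
  def _demo_ddpm_step(mask_index: int = 3):
    p_x0 = torch.tensor([[[0.2, 0.3, 0.5, 0.0]]])  # last slot is [MASK]
    move_chance_t, move_chance_s = 0.8, 0.6
    q_xs = p_x0 * (move_chance_t - move_chance_s)
    q_xs[:, :, mask_index] = move_chance_s
    # The weights sum to move_chance_t; _sample_categorical only needs
    # unnormalised weights, so this defines a valid categorical.
    return q_xs / q_xs.sum(dim=-1, keepdim=True)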
|
|
|
def _ar_sampler(self, bsz): |
|
|
|
num_pred_tokens = self.config.model.length - 1 |
|
x = torch.zeros( |
|
(bsz, num_pred_tokens + 1), |
|
dtype=torch.long, |
|
device=self.device) |
|
x[:, 0] = self.tokenizer.bos_token_id |
|
|
|
noise = (torch.distributions.Gumbel(0, 1) |
|
.sample((bsz, num_pred_tokens, self.vocab_size)) |
|
.to(self.device)) |
|
for i in range(num_pred_tokens): |
|
      next_logits = self.forward(
        x[:, :i + 1], None, torch.ones_like(x[:, :i + 1]))[:, -1]
|
y = (next_logits + noise[:, i]).argmax(-1) |
|
x[:, i + 1] = y |
|
return x |
|
|
|
@torch.no_grad() |
|
def _sample(self, num_steps=None, eps=1e-5, x_input = None): |
|
"""Generate samples from the model.""" |
|
batch_size_per_gpu = self.config.eval.perplexity_batch_size |
|
if self.parameterization == 'ar': |
|
return self._ar_sampler(batch_size_per_gpu) |
|
|
|
if num_steps is None: |
|
num_steps = self.config.sampling.steps |
|
if x_input is not None: |
|
x = x_input.input_ids |
|
attention_mask = x_input.attention_mask |
|
else: |
|
x = self._sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device) |
|
attention_mask = torch.ones_like(x) |
|
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device) |
|
dt = (1 - eps) / num_steps |
|
p_x0_cache = None |
|
|
|
for i in range(num_steps): |
|
t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device) |
|
if self.sampler == 'ddpm': |
|
        x = self._ddpm_update(x, t, dt, attention_mask)
|
elif self.sampler == 'ddpm_cache': |
|
p_x0_cache, x_next = self._ddpm_caching_update(x, t, dt, p_x0=p_x0_cache, attention_mask=attention_mask) |
|
if (not torch.allclose(x_next, x) or self.time_conditioning): |
|
|
|
p_x0_cache = None |
|
x = x_next |
|
|
|
else: |
|
x = self._analytic_update(x, t, dt, attention_mask) |
|
|
|
if self.config.sampling.noise_removal: |
|
t = timesteps[-1] * torch.ones(x.shape[0], 1, |
|
device=self.device) |
|
if self.sampler == 'analytic': |
|
x = self._denoiser_update(x, t) |
|
else: |
|
unet_conditioning = self.noise(t)[0] |
|
x = self.forward(x, unet_conditioning, attention_mask, print_logits=True).argmax(dim=-1) |
|
|
|
return x |
|
|
|
def restore_model_and_sample(self, num_steps, eps=1e-5): |
|
"""Generate samples from the model.""" |
|
|
if self.ema: |
|
self.ema.store(itertools.chain(self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.ema.copy_to(itertools.chain(self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.backbone.eval() |
|
self.noise.eval() |
|
samples = self._sample(num_steps=num_steps, eps=eps) |
|
if self.ema: |
|
self.ema.restore(itertools.chain(self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.backbone.train() |
|
self.noise.train() |
|
return samples |
|
|
|
def get_score(self, x, sigma, attention_mask=None): |
|
model_output = self.forward(x, sigma, attention_mask) |
|
if self.parameterization == 'subs': |
|
|
log_k = - torch.log(torch.expm1(sigma)).squeeze(-1) |
|
assert log_k.ndim == 1 |
|
|
|
masked_score = model_output + log_k[:, None, None] |
|
masked_score[:, :, self.mask_index] = 0 |
|
|
|
unmasked_score = self.neg_infinity * torch.ones_like( |
|
model_output) |
|
unmasked_score = torch.scatter( |
|
unmasked_score, |
|
-1, |
|
x[..., None], |
|
torch.zeros_like(unmasked_score[..., :1])) |
|
unmasked_score[:, :, self.mask_index] = - ( |
|
log_k[:, None] * torch.ones_like(x)) |
|
|
|
masked_indices = (x == self.mask_index).to( |
|
model_output.dtype)[:, :, None] |
|
model_output = ( |
|
masked_score * masked_indices |
|
+ unmasked_score * (1 - masked_indices)) |
|
return model_output.exp() |
|
|
|
def _staggered_score(self, score, dsigma): |
|
score = score.clone() |
|
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1) |
|
score *= dsigma.exp()[:, None] |
|
score[..., self.mask_index] += extra_const |
|
return score |
|
|
|
def _analytic_update(self, x, t, step_size, attention_mask=None): |
|
curr_sigma, _ = self.noise(t) |
|
next_sigma, _ = self.noise(t - step_size) |
|
dsigma = curr_sigma - next_sigma |
|
score = self.get_score(x, curr_sigma, attention_mask) |
|
stag_score = self._staggered_score(score, dsigma) |
|
probs = stag_score * self._transp_transition(x, dsigma) |
|
return _sample_categorical(probs) |
|
|
|
def _denoiser_update(self, x, t): |
|
sigma, _ = self.noise(t) |
|
score = self.get_score(x, sigma) |
|
stag_score = self._staggered_score(score, sigma) |
|
probs = stag_score * self._transp_transition(x, sigma) |
|
probs[..., self.mask_index] = 0 |
|
samples = _sample_categorical(probs) |
|
return samples |
|
|
|
def _transp_transition(self, i, sigma): |
|
sigma = _unsqueeze(sigma, reference=i[..., None]) |
|
edge = torch.exp(-sigma) * F.one_hot( |
|
i, num_classes=self.vocab_size) |
|
edge += torch.where(i == self.mask_index, |
|
1 - torch.exp(-sigma).squeeze(-1), |
|
0)[..., None] |
|
return edge |
|
|
|
def _sample_t(self, n, device): |
|
_eps_t = torch.rand(n, device=device) |
|
if self.antithetic_sampling: |
|
offset = torch.arange(n, device=device) / n |
|
_eps_t = (_eps_t / n + offset) % 1 |
|
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps |
|
if self.importance_sampling: |
|
return self.noise.importance_sampling_transformation(t) |
|
return t |
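
  # Illustrative sketch (not called anywhere) of the antithetic scheme above:
  # one uniform draw per stratum yields times that cover [sampling_eps, 1)
  # evenly within a batch, reducing the variance of the diffusion loss.
  @staticmethod
  def _demo_antithetic_t(n: int = 8, sampling_eps: float = 1e-3):
    _eps_t = torch.rand(n)
    offset = torch.arange(n) / n
    _eps_t = (_eps_t / n + offset) % 1  # one value per interval [i/n, (i+1)/n)
    return (1 - sampling_eps) * _eps_t + sampling_eps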
|
|
|
def _maybe_sub_sample(self, x0, attention_mask): |
|
|
input_tokens = x0 |
|
output_tokens = None |
|
new_attention_mask = attention_mask |
|
return input_tokens, output_tokens, new_attention_mask |
|
|
|
def _reconstruction_loss(self, x0, attention_mask): |
|
t0 = torch.zeros(x0.shape[0], dtype=self.dtype, |
|
device=self.device) |
|
assert self.config.noise.type == 'loglinear' |
|
|
|
unet_conditioning = self.noise(t0)[0][:, None] |
|
model_output_t0 = self.forward(x0, unet_conditioning, attention_mask) |
|
return - torch.gather(input=model_output_t0, |
|
dim=-1, |
|
index=x0[:, :, None]).squeeze(-1) |
|
|
|
def _forward_pass_diffusion(self, x0, attention_mask, mask=None): |
|
t = self._sample_t(x0.shape[0], x0.device) |
|
if self.T > 0: |
|
t = (t * self.T).to(torch.int) |
|
t = t / self.T |
|
|
|
t += (1 / self.T) |
|
|
|
if self.change_of_variables: |
|
unet_conditioning = t[:, None] |
|
f_T = torch.log1p(- torch.exp(- self.noise.sigma_max)) |
|
f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min)) |
|
move_chance = torch.exp(f_0 + t * (f_T - f_0)) |
|
move_chance = move_chance[:, None] |
|
else: |
|
sigma, dsigma = self.noise(t) |
|
unet_conditioning = sigma[:, None] |
|
move_chance = 1 - torch.exp(-sigma[:, None]) |
|
|
|
    if mask is None:
      xt = self.q_xt(x0, move_chance)
    else:
      xt = x0.where(mask == 1, torch.full_like(x0, self.tokenizer.mask_token_id))
|
model_output = self.forward(xt, unet_conditioning, attention_mask) |
|
|
|
|
|
utils.print_nans(model_output, 'model_output') |
|
|
|
if self.parameterization == 'sedd': |
|
return dsigma[:, None] * self._score_entropy( |
|
model_output, sigma[:, None], xt, x0) |
|
|
|
if self.T > 0: |
|
diffusion_loss = self._d3pm_loss( |
|
model_output=model_output, xt=xt, x0=x0, t=t) |
|
if self.parameterization == 'd3pm': |
|
        reconstruction_loss = self._reconstruction_loss(x0, attention_mask)
|
elif self.parameterization == 'subs': |
|
reconstruction_loss = 0 |
|
return reconstruction_loss + diffusion_loss |
|
|
|
|
|
log_p_theta = torch.gather( |
|
input=model_output, |
|
dim=-1, |
|
index=x0[:, :, None]).squeeze(-1) |
|
|
|
if self.change_of_variables or self.importance_sampling: |
|
return log_p_theta * torch.log1p( |
|
- torch.exp(- self.noise.sigma_min)) |
|
|
|
return - log_p_theta * ( |
|
dsigma / torch.expm1(sigma))[:, None] |
|
|
|
def _loss(self, x0, attention_mask, mask=None): |
|
(input_tokens, output_tokens, |
|
attention_mask) = self._maybe_sub_sample( |
|
x0, attention_mask) |
|
|
|
if self.parameterization == 'ar': |
|
logprobs = self.backbone(input_tokens, None, attention_mask) |
|
loss = - logprobs.gather( |
|
-1, output_tokens[:, :, None])[:, :, 0] |
|
else: |
|
loss = self._forward_pass_diffusion(input_tokens, attention_mask, mask) |
|
|
|
nlls = loss * attention_mask |
|
count = attention_mask.sum() |
|
|
|
batch_nll = nlls.sum() |
|
token_nll = batch_nll / count |
|
|
|
return Loss(loss=token_nll, |
|
nlls=nlls, |
|
token_mask=attention_mask) |
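
  # Illustrative sketch (not called anywhere) of the aggregation above:
  # per-token losses are zeroed at padded positions and averaged over the
  # real tokens only. Toy values are made up.
  @staticmethod
  def _demo_token_nll():
    per_token_loss = torch.tensor([[2.0, 4.0, 6.0]])
    attention_mask = torch.tensor([[1.0, 1.0, 0.0]])  # last position is padding
    nlls = per_token_loss * attention_mask
    return nlls.sum() / attention_mask.sum()  # (2 + 4) / 2 = 3.0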
|
|
|
def _score_entropy(self, log_score, sigma, xt, x0): |
|
"""Computes the SEDD loss. |
|
|
|
Args: |
|
log_score: float torch.Tensor with shape (batch_size, |
|
diffusion_model_input_length, vocab_size), |
|
log score, output of the denoising network. |
|
xt: int torch.Tensor with shape (batch_size, |
|
diffusion_model_input_length), input. |
|
x0: int torch.Tensor with shape (batch_size, |
|
diffusion_model_input_length), input. |
|
sigma: float torch.Tensor with shape (batch_size, 1). |
|
|
|
Returns: |
|
loss with shape (batch_size, diffusion_model_input_length) |
|
""" |
|
masked_indices = xt == self.mask_index |
|
|
|
expsig_minus_1 = torch.expm1(sigma).expand_as(xt) |
|
q_ratio = 1 / expsig_minus_1[masked_indices] |
|
|
|
words_that_were_masked = x0[masked_indices] |
|
|
|
neg_term = q_ratio * torch.gather( |
|
log_score[masked_indices], |
|
-1, |
|
words_that_were_masked[..., None]).squeeze(-1) |
|
score = log_score[masked_indices].exp() |
|
if self.mask_index == self.vocab_size - 1: |
|
pos_term = score[:, :-1].sum(dim=-1) |
|
else: |
|
pos_term = score[:, : self.mask_index].sum( |
|
dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1) |
|
const = q_ratio * (q_ratio.log() - 1) |
|
|
|
entropy = torch.zeros(* xt.shape, device=xt.device) |
|
entropy[masked_indices] += pos_term - neg_term + const |
|
return entropy |
|
|
|
@torch.no_grad |
|
def sample_subs_guidance( |
|
self, n_samples, stride_length, num_strides, dt=0.001): |
|
ones = torch.ones(n_samples, dtype=self.dtype, |
|
device=self.device) |
|
|
|
num_steps = int(1 / dt) |
|
sampling_steps = 0 |
|
intermediate_tokens = [] |
|
target = None |
|
for _ in range(num_strides + 1): |
|
p_x0_cache = None |
|
x = self._sample_prior( |
|
n_samples, |
|
self.config.model.length).to(self.device) |
|
if target is not None: |
|
x[:, : -stride_length] = target |
|
for i in range(num_steps + 1): |
|
p_x0_cache, x_next = self._ddpm_caching_update( |
|
x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache) |
|
if (not torch.allclose(x_next, x) |
|
or self.time_conditioning): |
|
p_x0_cache = None |
|
sampling_steps += 1 |
|
x = x_next |
|
      x = self.forward(x, 0 * ones, torch.ones_like(x)).argmax(dim=-1)
|
intermediate_tokens.append( |
|
x[:, :stride_length].cpu().numpy()) |
|
target = x[:, stride_length:] |
|
|
|
intermediate_tokens.append(target.cpu().numpy()) |
|
intermediate_text_samples = [] |
|
sequence_lengths = (( |
|
np.concatenate(intermediate_tokens, axis=1)[:, 1:] |
|
== self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1) |
|
for i in range(2, len(intermediate_tokens) + 1): |
|
intermediate_text_samples.append( |
|
self.tokenizer.batch_decode( |
|
np.concatenate(intermediate_tokens[:i], axis=1))) |
|
return (sampling_steps, intermediate_text_samples, |
|
sequence_lengths) |
|
|
|
def restore_model_and_semi_ar_sample( |
|
self, stride_length, num_strides, dt=0.001): |
|
"""Generate samples from the model.""" |
|
|
if self.ema: |
|
self.ema.store(itertools.chain(self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.ema.copy_to(itertools.chain(self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.backbone.eval() |
|
self.noise.eval() |
|
(sampling_steps, samples, |
|
sequence_lengths) = self.sample_subs_guidance( |
|
n_samples=self.config.loader.eval_batch_size, |
|
stride_length=stride_length, |
|
num_strides=num_strides, |
|
dt=dt) |
|
if self.ema: |
|
self.ema.restore(itertools.chain(self.backbone.parameters(), |
|
self.noise.parameters())) |
|
self.backbone.train() |
|
self.noise.train() |
|
return sampling_steps, samples, sequence_lengths |