# sdlm/pipelines/simplex_ddpm.py
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
from sdlm.inference.inference_utils import logits_projection
from sdlm.models.utils import check_tokenizer_equal, is_cdcd_check, load_classifier
from sdlm.utils import scale, self_condition_preds, convert_to_simplex
@dataclass
class SimplexDiffusionPipelineOutput(BaseOutput):
"""
Output class for simplex diffusion pipelines.
Args:
simplex (`np.ndarray`)
numpy array showing the denoised simplex representation.
logits (`np.ndarray`) final generated logits before applying the projection.
"""
simplex: np.ndarray
logits: np.ndarray
loss: np.ndarray
def yield_func(x):
    """Wrap a single value in a one-element generator."""
    yield x
class SimplexDDPMPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
model: Model architecture to denoise the latents (encoded token ids).
scheduler ([`SchedulerMixin`]): A scheduler to denoise the encoded latent.
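
    Example (a minimal sketch; `model`, `scheduler`, `tokenizer`, and `batch` are assumed
    to be constructed elsewhere, and the remaining argument values are illustrative only):

        pipeline = SimplexDDPMPipeline(
            model=model,
            scheduler=scheduler,
            simplex_value=5.0,
            top_p=0.99,
            sampling_type="top_p",
            is_conditional_generation=True,
            tokenizer=tokenizer,
            classifier_free_uncond_input="empty_token",
            temperature=1.0,
            guidance_softmax_combination=True,
        )
        for step_output in pipeline(batch=batch):
            ...  # each step yields a SimplexDiffusionPipelineOutput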
"""
def __init__(
self,
model,
scheduler,
simplex_value,
top_p,
sampling_type,
is_conditional_generation,
tokenizer,
classifier_free_uncond_input,
temperature,
guidance_softmax_combination,
):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
self.simplex_value = simplex_value
self.top_p = top_p
self.sampling_type = sampling_type
self.is_conditional_generation = is_conditional_generation
self.tokenizer = tokenizer
self.classifier_free_uncond_input = classifier_free_uncond_input
self.temperature = temperature
self.guidance_softmax_combination = guidance_softmax_combination
@torch.inference_mode()
def __call__(
self,
seq_length: int = 512,
generator: Optional[torch.Generator] = None,
batch: Optional[torch.FloatTensor] = None,
guidance_scale: float = 1.0,
is_generator: bool = False,
) -> Union[SimplexDiffusionPipelineOutput, Tuple]:
r"""
Args:
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
seq_length: (`int`), sequence length for the generated samples.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
batch (`torch.FloatTensor`): batch of input data, mostly used in the conditional generation setting.
Returns:
[`~pipeline_utils.SimplexDiffusionPipelineOutput`]: returns the generated simplex.
"""
        # Classifier-free guidance works only in the conditional generation case.
classifier_free_guidance = (
guidance_scale > 1.0 and self.is_conditional_generation
)
"""
if classifier_free_guidance:
# Makes unconditional input for max sequence length, later we truncate it.
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=seq_length, return_tensors="pt"
).to(self.device)
# Converts this to a simplex (batch_size, max_seq, vocab_size)
uncond_simplex = convert_to_simplex(uncond_input["input_ids"], self.simplex_value, self.model.config.vocab_size)
"""
        # Sample gaussian noise to begin the loop.
        vocab_size = self.model.config.vocab_size
        batch_size = 1
        if batch is not None:
            # TODO(rabeeh): is giving the length cheating for this setting?
            # Adapts the sequence length to the given `span_mask`'s length.
            seq_length = batch["input_ids"].shape[1]
            batch_size = batch["input_ids"].shape[0]
simplex_shape = (batch_size, seq_length, vocab_size)
simplex = self.simplex_value * torch.randn(
simplex_shape, generator=generator, device=self.device
)
if self.model.config.self_condition is not None:
previous_pred = torch.zeros(
(batch_size, seq_length, vocab_size), device=self.device
)
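        # `logits_projection` maps raw logits onto a hard ±simplex_value simplex after
        # top-p / temperature filtering.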
logits_projection_fct = lambda x: logits_projection( # noqa: E731
x, self.sampling_type, self.top_p, self.simplex_value, self.temperature
)
losses = []
previous_hidden = None
warped_steps = []
prev_t = 0
for t in self.progress_bar(self.scheduler.timesteps):
original_t = torch.tensor([t], device=self.device).expand(
batch_size, seq_length
)
if is_cdcd_check(self.model):
# warp timesteps based on cdf
                # We are in inference mode; everything under `span_mask` is to be generated.
token_inputs = torch.where(
batch["span_mask"], self.tokenizer.pad_token_id, batch["input_ids"]
)
t = self.model.warp_timesteps(
original_t,
t_min=0,
t_max=len(self.scheduler) - 1,
token_input=token_inputs,
span_mask=batch["span_mask"],
)
else:
t = original_t
t_scaled = scale(t, len(self.scheduler))
warped_steps.append(t)
"""
if classifier_free_guidance:
if self.classifier_free_uncond_input == "empty_token":
uncond_input = uncond_simplex[:, : batch["input_ids"].shape[1], :]
elif self.classifier_free_uncond_input == "noisy_simplex":
uncond_input = self.simplex_value * torch.randn(simplex.shape, generator=generator, device=self.device)
else:
raise NotImplementedError
"""
            # 1. Predict the model output. Note that we must not pass `input_ids` in the
            # unconditional case, since the loss would otherwise be computed when it should not be.
model_output = self.model(
input_ids=batch["input_ids"]
if self.is_conditional_generation
else None,
span_mask=batch["span_mask"]
if self.is_conditional_generation
else None,
simplex=simplex,
timesteps=t_scaled,
previous_pred=previous_pred
if self.model.config.self_condition
else None,
classifier_free_guidance=classifier_free_guidance,
reduce_loss="none",
max_timestep=len(self.scheduler),
previous_hidden=previous_hidden,
)
model_output_logits = model_output.logits
previous_hidden = model_output.hidden_states
# Performs classifier-free guidance.
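            # The guided prediction combines the unconditional and conditional halves as
            #   guided = uncond + guidance_scale * (cond - uncond),
            # either on the softmax probabilities (if `guidance_softmax_combination`) or
            # directly on the logits.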
if classifier_free_guidance:
logits_uncond, logits_pred = model_output_logits.chunk(2)
if self.guidance_softmax_combination:
model_output_logits = F.softmax(
logits_uncond, dim=-1
) + guidance_scale * (
F.softmax(logits_pred, dim=-1)
- F.softmax(logits_uncond, dim=-1)
)
else:
model_output_logits = logits_uncond + guidance_scale * (
logits_pred - logits_uncond
)
if self.model.config.self_condition is not None:
if classifier_free_guidance:
prev_output_logits = model_output.logits.chunk(2)[1]
else:
prev_output_logits = model_output_logits
previous_pred = self_condition_preds(
self.model.config.self_condition,
prev_output_logits,
logits_projection_fct,
)
# Projection.
projected_logits = logits_projection_fct(model_output_logits)
old_simplex = simplex
# 2. compute previous logits: x_t -> x_t-1
noise = self.simplex_value * torch.randn(
simplex_shape, generator=generator, device=self.device
)
if is_cdcd_check(self.model):
# warp timesteps based on cdf
token_inputs = torch.where(
batch["span_mask"], self.tokenizer.pad_token_id, batch["input_ids"]
)
prev_t = self.model.warp_timesteps(
original_t - 1,
t_min=0,
t_max=len(self.scheduler) - 1,
token_input=token_inputs,
span_mask=batch["span_mask"],
).long()
                # Clamp, since tokenwise warping can push timesteps out of range.
prev_t = torch.clamp(prev_t, min=0, max=len(self.scheduler) - 1)
else:
prev_t = original_t - 1
simplex = self.scheduler.step(
projected_logits,
t,
prev_t,
noise,
generator=generator,
).prev_sample
            # Keep the loss for logging.
            losses.append(model_output.loss.detach().cpu())
            # Yield the intermediate output at every step (streaming; probably not
            # optimal, but convenient for callers).
            yield SimplexDiffusionPipelineOutput(
                simplex=old_simplex, logits=model_output_logits, loss=losses[-1]
            )
        # Stack the per-timestep losses for logging.
        loss = torch.stack(losses, dim=0)
return SimplexDiffusionPipelineOutput(
simplex=simplex, logits=model_output_logits, loss=loss
)
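

# A minimal decoding sketch (hypothetical; assumes `pipeline` and `batch` constructed as
# in the class docstring): consume the generator, keep the last yielded output, and
# greedily decode its logits.
#
#     last = None
#     for last in pipeline(batch=batch):
#         pass
#     token_ids = last.logits.argmax(dim=-1)
#     texts = pipeline.tokenizer.batch_decode(token_ids, skip_special_tokens=True)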
class SimplexDDPMClassifierGuidancePipeline(SimplexDDPMPipeline):
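    """A SimplexDDPMPipeline variant that steers sampling with a classifier (reward) model."""
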
def __init__(
self,
model,
scheduler,
simplex_value,
top_p,
sampling_type,
is_conditional_generation,
tokenizer,
classifier_free_uncond_input,
temperature,
guidance_softmax_combination,
classifier_model_name_or_path,
) -> None:
super().__init__(
model,
scheduler,
simplex_value,
top_p,
sampling_type,
is_conditional_generation,
tokenizer,
classifier_free_uncond_input,
temperature,
guidance_softmax_combination,
)
self.classifier = None
if classifier_model_name_or_path is not None:
classifier_tokenizer, classifier = load_classifier(
classifier_model_name_or_path
)
check_tokenizer_equal(self.tokenizer, classifier_tokenizer)
self.classifier = classifier.to(self.device)
def get_reward(
self,
logits: torch.FloatTensor,
use_gumbel_softmax: bool,
do_hard_sample: bool,
softmax_temperature: float,
one_hot: Optional[torch.Tensor] = None,
span_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
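        """
        Score the current logits with the classifier (reward) model.

        The logits are relaxed into a simplex (optionally via Gumbel-softmax), context
        positions are replaced with the ground-truth one-hot (`one_hot`), and the
        simplex is converted to soft input embeddings for the classifier. The
        classifier's logits are returned as the reward.
        """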
logits = logits.to(torch.bfloat16)
logits.requires_grad = True
if use_gumbel_softmax:
simplex = F.gumbel_softmax(
logits, tau=softmax_temperature, hard=do_hard_sample, dim=-1
)
else:
simplex = torch.softmax(logits / softmax_temperature, dim=-1)
# mask out context
if span_mask is not None:
simplex = torch.where(span_mask.unsqueeze(-1), simplex, one_hot)
        # (disabled) Forcibly add an EOS one-hot to the simplex before scoring:
# eos_token = torch.nn.functional.one_hot(
# torch.tensor(self.tokenizer.eos_token_id),
# num_classes=self.classifier.config.vocab_size,
# ).to(simplex.device)
# eos_token = eos_token.unsqueeze(0).unsqueeze(0).expand_as(simplex)
# simplex = torch.cat([simplex, eos_token], dim=1)
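        # Soft embeddings: F.linear(simplex, W.T) == simplex @ W, i.e. the expected
        # input embedding under the simplex distribution over the vocabulary.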
inputs_embeds = F.linear(
simplex, self.classifier.model.get_input_embeddings().weight.data.T
)
# forward pass through reward model
reward = self.classifier(inputs_embeds=inputs_embeds).logits
return reward
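    # NOTE: `torch.no_grad` (rather than `torch.inference_mode`) is required here, since
    # `torch.enable_grad()` below can re-enable gradients under `no_grad` but not under
    # `inference_mode`.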
@torch.no_grad()
def __call__(
self,
seq_length: int = 512,
generator: Optional[torch.Generator] = None,
batch: Optional[torch.FloatTensor] = None,
guidance_scale: float = 1.0,
is_generator: bool = False,
use_gumbel_softmax: bool = False,
do_hard_sample: bool = False,
softmax_temperature: float = 1.0,
use_ddim_sampling: bool = False,
num_guidance_steps: int = 5,
) -> Union[SimplexDiffusionPipelineOutput, Tuple]:
# check for classifier guidance
use_classifier_guidance = self.classifier is not None and guidance_scale > 0.0
        # NOTE: copied from SimplexDDPMPipeline.
        # Sample gaussian noise to begin the loop.
        vocab_size = self.model.config.vocab_size
        batch_size = 1
        if batch is not None:
            # TODO(rabeeh): is giving the length cheating for this setting?
            # Adapts the sequence length to the given `span_mask`'s length.
            seq_length = batch["input_ids"].shape[1]
            batch_size = batch["input_ids"].shape[0]
simplex_shape = (batch_size, seq_length, vocab_size)
# simplex ~ N(0, kI)
simplex = self.simplex_value * torch.randn(
simplex_shape, generator=generator, device=self.device
)
if self.model.config.self_condition is not None:
previous_pred = torch.zeros(
(batch_size, seq_length, vocab_size), device=self.device
)
# logits -> hard sampled k / -k
logits_projection_fct = lambda x: logits_projection( # noqa: E731
x, self.sampling_type, self.top_p, self.simplex_value, self.temperature
)
losses = []
previous_hidden = None
warped_steps = []
prev_t = 0
all_rewards = []
for t in self.progress_bar(self.scheduler.timesteps):
original_t = torch.tensor([t], device=self.device).expand(
batch_size, seq_length
)
if is_cdcd_check(self.model):
                # warp timesteps based on cdf
                # We are in inference mode; everything under `span_mask` is to be generated.
                # NOTE: 50264 is a hardcoded token id here; the base pipeline uses
                # `self.tokenizer.pad_token_id` instead.
                token_inputs = torch.where(
                    batch["span_mask"], 50264, batch["input_ids"]
                )
t = self.model.warp_timesteps(
original_t,
t_min=0,
t_max=len(self.scheduler) - 1,
token_input=token_inputs,
span_mask=batch["span_mask"],
)
else:
t = original_t
t_scaled = scale(t, len(self.scheduler))
warped_steps.append(t)
            # 1. Predict the model output. Note that we must not pass `input_ids` in the
            # unconditional case, since the loss would otherwise be computed when it should not be.
model_output = self.model(
input_ids=batch["input_ids"]
if self.is_conditional_generation
else None,
span_mask=batch["span_mask"]
if self.is_conditional_generation
else None,
simplex=simplex,
timesteps=t_scaled,
previous_pred=previous_pred
if self.model.config.self_condition
else None,
classifier_free_guidance=False,
reduce_loss="none",
max_timestep=len(self.scheduler),
previous_hidden=previous_hidden,
)
model_output_logits = model_output.logits
previous_hidden = model_output.hidden_states
# NOTE: classifier guidance!
# compute one_hot
span_mask = batch["span_mask"]
one_hot = F.one_hot(batch["input_ids"], len(self.tokenizer)).to(
torch.bfloat16
)
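            # `one_hot` pins the context: inside `get_reward`, positions outside the
            # span mask are replaced by the ground-truth one-hot before scoring.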
model_output_logits = model_output_logits.to(torch.bfloat16)
            if use_classifier_guidance:
                # Treat the logits as learnable parameters and ascend the reward with
                # the torch.optim API (SGD on the negated reward).
                model_output_logits = torch.nn.Parameter(model_output_logits)
                optimizer = torch.optim.SGD([model_output_logits], lr=guidance_scale)
                # guidance
                with torch.enable_grad():
                    for _ in range(num_guidance_steps):
                        # Reset gradients so each guidance step uses a fresh gradient.
                        optimizer.zero_grad()
                        reward = self.get_reward(
                            logits=model_output_logits,
                            use_gumbel_softmax=use_gumbel_softmax,
                            do_hard_sample=do_hard_sample,
                            softmax_temperature=softmax_temperature,
                            one_hot=one_hot,
                            span_mask=span_mask,
                        )
                        # all_rewards.append(reward.detach().cpu())
                        # Maximize the reward by minimizing its negation.
                        reward = reward.sum().neg()
                        reward.backward()
                        optimizer.step()
                model_output_logits = model_output_logits.detach()
if self.model.config.self_condition is not None:
prev_output_logits = model_output_logits
previous_pred = self_condition_preds(
self.model.config.self_condition,
prev_output_logits,
logits_projection_fct,
)
old_simplex = simplex
# 2. compute previous logits: x_t -> x_t-1
if is_cdcd_check(self.model):
                # warp timesteps based on cdf
                # NOTE: 50264 is a hardcoded token id here; the base pipeline uses
                # `self.tokenizer.pad_token_id` instead.
                token_inputs = torch.where(
                    batch["span_mask"], 50264, batch["input_ids"]
                )
prev_t = self.model.warp_timesteps(
original_t - 1,
t_min=0,
t_max=len(self.scheduler) - 1,
token_input=token_inputs,
span_mask=batch["span_mask"],
).long()
                # Clamp, since tokenwise warping can push timesteps out of range.
prev_t = torch.clamp(prev_t, min=0, max=len(self.scheduler) - 1)
else:
prev_t = original_t - 1
            if not use_ddim_sampling:
                # Standard TESS sampling step.
noise = self.simplex_value * torch.randn(
simplex_shape, generator=generator, device=self.device
)
# Projection.
projected_logits = logits_projection_fct(model_output_logits)
simplex = self.scheduler.step(
projected_logits,
t,
prev_t,
noise,
generator=generator,
).prev_sample
else:
                # DDIM-style step: the model maps the noisy ±k simplex (input) to a
                # clean ±k estimate (output).
x_t = old_simplex
x_0_hat = (
2 * self.simplex_value * torch.softmax(model_output_logits, dim=-1)
- self.simplex_value
)
alpha_prod_t = self.scheduler.alphas_cumprod[t[0, 0].item()]
sqrt_alpha_prod_t = torch.sqrt(alpha_prod_t)
sqrt_one_minus_alpha_prod_t = torch.sqrt(1 - alpha_prod_t)
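                # Invert the forward process x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps
                # to recover the implied noise: eps = (x_t - sqrt(a_t) * x_0_hat) / sqrt(1 - a_t).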
noise = (
x_t - sqrt_alpha_prod_t * x_0_hat
) / sqrt_one_minus_alpha_prod_t
simplex = self.scheduler.step(
x_0_hat,
t,
prev_t,
noise,
generator=generator,
).prev_sample
            # Keep the loss for logging.
            losses.append(model_output.loss.detach().cpu())
            # Yield the intermediate output at every step (streaming; probably not
            # optimal, but convenient for callers).
            yield SimplexDiffusionPipelineOutput(
                simplex=old_simplex, logits=model_output_logits, loss=losses[-1]
            )
        # Stack the per-timestep losses for logging.
        loss = torch.stack(losses, dim=0)
return SimplexDiffusionPipelineOutput(
simplex=simplex, logits=model_output_logits, loss=loss
)
# A variant of SimplexDDPMPipeline used for evaluation. The main difference is that
# we assume the ground truth is passed in and we want to compute the loss.
class SimplexDDPMPipelineForEvaluation(SimplexDDPMPipeline):
@torch.inference_mode()
def __call__(
self,
seq_length: int = 512,
generator: Optional[torch.Generator] = None,
batch: Optional[torch.FloatTensor] = None,
guidance_scale: float = 1.0,
is_generator: bool = False,
) -> Union[SimplexDiffusionPipelineOutput, Tuple]:
        # Classifier-free guidance works only in the conditional generation case.
classifier_free_guidance = (
guidance_scale > 1.0 and self.is_conditional_generation
)
# Sample gaussian noise to begin loop
vocab_size = self.model.config.vocab_size
if batch is not None:
seq_length = batch["input_ids"].shape[1]
            batch_size = batch["input_ids"].shape[0]
simplex_shape = (batch_size, seq_length, vocab_size)
# simplex here is the simplex of the actual input!
simplex = convert_to_simplex(
batch["input_ids"], self.simplex_value, self.model.config.vocab_size
)
noise = self.simplex_value * torch.randn(
simplex_shape, generator=generator, device=self.device
)
if self.model.config.self_condition is not None:
previous_pred = torch.zeros(
(batch_size, seq_length, vocab_size), device=self.device
)
logits_projection_fct = lambda x: logits_projection( # noqa: E731
x, self.sampling_type, self.top_p, self.simplex_value, self.temperature
)
losses = []
previous_hidden = None
warped_steps = []
prev_t = 0
for t in self.progress_bar(self.scheduler.timesteps):
original_t = torch.tensor([t], device=self.device).expand(
batch_size, seq_length
)
if is_cdcd_check(self.model):
# warp timesteps based on cdf
                # We are in inference mode; everything under `span_mask` is to be generated.
token_inputs = torch.where(
batch["span_mask"], self.tokenizer.pad_token_id, batch["input_ids"]
)
t = self.model.warp_timesteps(
original_t,
t_min=0,
t_max=len(self.scheduler) - 1,
token_input=token_inputs,
span_mask=batch["span_mask"],
)
else:
t = original_t
t_scaled = scale(t, len(self.scheduler))
warped_steps.append(t)
noisy_simplex = self.scheduler.add_noise(simplex, noise, t)
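            # Forward-noise the ground-truth simplex to timestep t (as during training),
            # then have the model denoise it so the per-step loss can be measured.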
attention_mask = batch["input_ids"] != self.tokenizer.pad_token_id
# TODO: do we care about self-conditioning...?
model_output = self.model(
input_ids=batch["input_ids"],
span_mask=batch["span_mask"],
attention_mask=attention_mask,
simplex=noisy_simplex,
timesteps=t_scaled,
classifier_free_guidance=classifier_free_guidance,
reduce_loss="none",
previous_pred=previous_pred,
max_timestep=len(self.scheduler),
previous_hidden=previous_hidden,
)
model_output_logits = model_output.logits
previous_hidden = model_output.hidden_states
losses.append(model_output.loss.detach().cpu())
if self.model.config.self_condition is not None:
prev_output_logits = model_output_logits
previous_pred = self_condition_preds(
self.model.config.self_condition,
prev_output_logits,
logits_projection_fct,
)
            # No sampling update here, since all we care about is the loss.
            # Yield the intermediate output at every step.
yield SimplexDiffusionPipelineOutput(
simplex=noisy_simplex, logits=model_output_logits, loss=losses[-1]
)
        # Stack the per-timestep losses for logging.
        loss = torch.stack(losses, dim=0)
return SimplexDiffusionPipelineOutput(
simplex=simplex, logits=model_output_logits, loss=loss
)
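

# A minimal evaluation sketch (hypothetical; `pipeline` is a SimplexDDPMPipelineForEvaluation
# and `batch` carries the ground-truth `input_ids` and `span_mask`): collect the per-step
# losses yielded by the generator and average them.
#
#     step_losses = [out.loss for out in pipeline(batch=batch)]
#     mean_loss = torch.stack(step_losses).mean()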