from typing import Optional

import numpy as np
import torch
from torch import autograd
from transformers import RobertaForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput

from sdlm.models.cdcd.cdf import LossCDF
from sdlm.models.roberta.configuration_roberta import RobertaDiffusionConfig
from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM
# Identical to RobertaDiffusionConfig, except that we also store n_bins,
# the number of bins used by the CDF timestep warper.
class TokenwiseCDCDRobertaConfig(RobertaDiffusionConfig):
    def __init__(self, *args, n_bins=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_bins = n_bins
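

# Usage sketch (an assumption, not from the original file): the config is
# built like any RobertaConfig, with the extra CDF bin count, e.g.
#   config = TokenwiseCDCDRobertaConfig.from_pretrained("roberta-base", n_bins=100)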
# RoBERTa diffusion LM with the tokenwise CDF timestep warper.
class TokenwiseCDCDRobertaForDiffusionLM(RobertaForDiffusionLM):
    def __init__(self, config):
        super().__init__(config)
        self.cdf = LossCDF(config.n_bins)
        # Base LM used to encode the (masked) token context for the
        # per-token warp. Keep the hidden dim larger?
        self.base_lm = RobertaForMaskedLM.from_pretrained("roberta-base")
        # Two small MLP heads predict per-token bin logits: l_u for the
        # inverse CDF (sampling) and l_t for the forward CDF (training).
        self.linear_lu = torch.nn.Sequential(
            torch.nn.Linear(self.config.hidden_size, self.config.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.config.hidden_size, config.n_bins),
        )
        self.linear_lt = torch.nn.Sequential(
            torch.nn.Linear(self.config.hidden_size, self.config.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.config.hidden_size, config.n_bins),
        )
        # Uniform initial bin logits, log(1 / n_bins).
        self.start_lt = torch.full([config.n_bins], -float(np.log(config.n_bins)))
        self.start_lu = torch.full([config.n_bins], -float(np.log(config.n_bins)))
        # Learnable scales on the predicted logit offsets, initialized to 1.
        self.linear_lu_start_a = torch.nn.Parameter(torch.ones([1]))
        self.linear_lt_start_a = torch.nn.Parameter(torch.ones([1]))
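
    # Note on the initialization above: uniform logits of log(1 / n_bins)
    # give every bin equal mass under a softmax, so (assuming LossCDF
    # normalizes the logits over equal-width bins) the learned warp starts
    # out close to the identity map on [0, 1].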
    def warp_timesteps(
        self,
        timesteps: torch.FloatTensor,
        token_input: Optional[torch.LongTensor] = None,
        t_min=0,
        t_max=1,
    ):
        # u must be in the normalized [0, 1] range for the inverse CDF.
        if t_max - t_min > 0:
            timesteps = (timesteps - t_min) / (t_max - t_min)
        else:
            # Degenerate case: only really happens with a single diffusion
            # step (t_min == t_max == 0). Shift to 0 and reset t_max so the
            # downstream call does not divide by zero.
            timesteps = timesteps - t_min
            t_max = 1
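        # E.g. with t_min=0 and t_max=5000 (the default max_timestep in
        # forward below), a raw timestep of 2500 normalizes to u = 0.5.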
        if token_input is None:
            lu, lt = None, None
        else:
            # Replace padding tokens (id 1 for roberta-base) with <mask>
            # (id 50264) so the base LM does not simply ignore them.
            token_input = torch.where(token_input == 1, 50264, token_input)
            hidden_states = self.base_lm.roberta(
                input_ids=token_input, output_hidden_states=True
            ).hidden_states[-1]
            # Predict the per-token bin logits from the context encoding.
            lu = self.start_lu.to(
                self.linear_lu_start_a.device
            ) + self.linear_lu_start_a * self.linear_lu(hidden_states)
            lt = self.start_lt.to(
                self.linear_lt_start_a.device
            ) + self.linear_lt_start_a * self.linear_lt(hidden_states)
        # Warp the timesteps through the inverse CDF. Separate call so the
        # result can be passed to the scheduler; detach so we do not backprop
        # through the warp here.
        return self.cdf(
            u=timesteps, normalized=True, t_min=t_min, t_max=t_max, l_u=lu, l_t=lt
        ).detach()
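
    # Sampling-time usage sketch (hypothetical; the real call sites live in
    # the sdlm inference loop, and `scheduler` / `noise` are assumptions):
    #   u = torch.rand(batch, seq_len)                  # uniform u in [0, 1]
    #   t = model.warp_timesteps(u, token_input=input_ids)
    #   noisy_simplex = scheduler.add_noise(simplex, noise, t)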
    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token losses for every item in the batch
        previous_hidden: Optional[torch.FloatTensor] = None,
        original_timesteps: Optional[torch.FloatTensor] = None,
    ):
        output = super().forward(
            timesteps,
            input_ids,
            simplex,
            span_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
            previous_pred,
            classifier_free_guidance,
            classifier_free_guidance_in_train,
            max_timestep,
            reduce_loss="none",  # we need per-token losses to fit the CDF
            return_all_losses=True,
            previous_hidden=previous_hidden,  # for the CDCD predictions
        )
        loss = output.loss
        if self.training:
            # Learn the CDF from the observed losses. Train mode only; in
            # eval we just apply the learned warping.
            new_timesteps_clone = timesteps.clone()
            new_timesteps_clone.requires_grad = True
            with torch.enable_grad():
                # Encode the token context with the base LM: replace the
                # tokens we are predicting (non-special ids under the span
                # mask) with <mask> (id 50264). Note that the timesteps are
                # already normalized to [0, 1] at this point.
                token_input = torch.where((input_ids * span_mask) > 1, 50264, input_ids)
                previous_hidden = self.base_lm.roberta(
                    input_ids=token_input, output_hidden_states=True
                ).hidden_states[-1]
                # Per-token bin logits, mirroring warp_timesteps.
                lu = self.start_lu.to(
                    self.linear_lu_start_a.device
                ) + self.linear_lu_start_a * self.linear_lu(previous_hidden)
                lt = self.start_lt.to(
                    self.linear_lt_start_a.device
                ) + self.linear_lt_start_a * self.linear_lt(previous_hidden)
                # The forward CDF predicts the per-token cross-entropy at the
                # (unnormalized) sampled timesteps.
                xent_pred = self.cdf(
                    t=new_timesteps_clone, normalized=False, t_max=1, l_u=lu, l_t=lt
                )
                # Importance weights: reciprocal of the CDF gradient w.r.t. t,
                # correcting for the non-uniform timestep distribution the
                # warp induces (the 1e-5 factor scales the CDF loss down,
                # presumably so it does not dominate the diffusion loss).
                imp_weights = (
                    1.0 / autograd.grad(xent_pred.sum(), [new_timesteps_clone])[0]
                )
                imp_weights = imp_weights.detach() * 1e-5
                # Regress the CDF onto the observed per-token losses.
                cdf_loss = imp_weights * (
                    self.cdf(t=timesteps, normalized=False, t_max=1, l_u=lu, l_t=lt)
                    - loss.detach()
                ).pow(2)
                # Mask the regular (conditioning) input out of the loss, since
                # we never warp it, and also mask the padding at the end.
                cdf_loss = cdf_loss * span_mask * (input_ids != 1)
            loss = loss.mean() + cdf_loss.mean()
        else:
            loss = loss.mean()
        return MaskedLMOutput(
            loss=loss,
            logits=output.logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
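

# End-to-end training sketch (shapes and names are assumptions; `simplex` is
# the soft one-hot tensor of shape (batch, seq_len, vocab_size) consumed by
# RobertaForDiffusionLM):
#   config = TokenwiseCDCDRobertaConfig.from_pretrained("roberta-base", n_bins=100)
#   model = TokenwiseCDCDRobertaForDiffusionLM(config).train()
#   out = model(timesteps=t, input_ids=input_ids, simplex=simplex,
#               span_mask=span_mask)
#   out.loss.backward()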