from typing import List, Optional, Tuple, Union

import torch
from torch import autograd
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    MaskedLMOutput,
    SequenceClassifierOutputWithPast,
)

from sdlm.data.data_utils import pad_sequence
from sdlm.models.cdcd.cdf import LossCDF
from sdlm.utils import mix_values_based_on_self_condition


class DiffusionModelMixin:
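    """Mixin implementing the simplex-based diffusion forward pass.

    The noisy vocabulary simplex is softmaxed into token probabilities and
    mapped to a weighted-average embedding, optionally mixed with a
    self-conditioning prediction and summed with a timestep embedding, before
    being fed through the underlying language model. The loss is a shifted
    cross-entropy over the masked span.
    """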

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        reduce_loss: str = "mean",
        attention_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
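        """Run one denoising forward pass.

        Args:
            timesteps: diffusion timesteps used for the timestep embedding.
            input_ids: ground-truth token ids, also used to build the loss labels.
            simplex: noisy logits over the vocabulary.
            span_mask: True at positions being diffused; all other positions keep
                their original word embeddings.
            previous_pred: logits from the previous pass, used for self-conditioning.
            reduce_loss: reduction passed to ``CrossEntropyLoss`` ("mean" or "none").

        Returns:
            ``MaskedLMOutput`` with the masked-LM loss and logits re-aligned to the
            input positions.
        """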
        # Convert the noisy simplex into a weighted-average embedding.
        inputs_probs = F.softmax(simplex, dim=-1)
        inputs_embeds = self.vocab_to_hidden_dim_embed(inputs_probs)
        if self.config.self_condition is not None:
            if previous_pred is None:
                previous_pred = torch.zeros_like(simplex, device=simplex.device)
            previous_pred_probs = F.softmax(previous_pred, dim=-1)
            if not self.config.self_condition_mix_logits_before_weights:
                previous_pred = self.vocab_to_hidden_dim_embed(previous_pred_probs)
            # In this setting, we mix the probabilities and then apply the weights.
            if self.config.self_condition_mix_before_weights:
                mixed_probs = mix_values_based_on_self_condition(
                    self.config.self_condition, inputs_probs, previous_pred_probs
                )
                inputs_embeds = self.vocab_to_hidden_dim_embed(mixed_probs)
        # Original word embeddings without noise.
        inputs_word_embeds = self.get_input_embeddings()(input_ids)
        if not self.config.disable_timestep_embed:
            timesteps = torch.where(span_mask, timesteps, torch.zeros_like(timesteps))
            timesteps_embed = self.timestep_embed(timesteps.unsqueeze(-1).float())
            inputs_embeds = inputs_embeds + timesteps_embed
        # Unmasked tokens keep their original word embeddings. Note that this also
        # resets the self-conditioned inputs we condition on to their original
        # word-embedding values.
        inputs_embeds = torch.where(
            span_mask.unsqueeze(-1), inputs_embeds, inputs_word_embeds
        )
        outputs = self.model(
            input_ids=None,  # TODO(rabeeh): remove this hack once the loss is computed outside.
            attention_mask=attention_mask,  # only used to handle padding during evals
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)
        masked_lm_loss = None
        if input_ids is not None:
            prediction_scores_for_loss = prediction_scores
            loss_fct = CrossEntropyLoss(reduction=reduce_loss)
            labels = (
                torch.where(span_mask, input_ids, -100)
                if span_mask is not None
                else input_ids
            )
            if self.config.mask_padding_in_loss:
                # Also exclude padding tokens from the loss.
                labels = torch.where(labels == self.config.pad_token_id, -100, labels)
            # Important: shift so that position i predicts token i + 1, mimicking
            # causal pretraining.
            labels = labels[:, 1:]
            prediction_scores_for_loss = prediction_scores_for_loss[:, :-1]
            masked_lm_loss = loss_fct(
                prediction_scores_for_loss.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
            )
            if reduce_loss == "none":
                # Average the per-example loss over the span tokens only; other
                # positions contribute zero.
                masked_lm_loss = masked_lm_loss.view(input_ids.shape[0], -1)
                masked_lm_loss = masked_lm_loss.sum(dim=-1) / span_mask.sum(dim=-1)
        # Shift the logits forward by one so that inputs and outputs line up.
        prediction_scores = prediction_scores[:, :-1]
        # Prepend a zero logit row in place of the start token.
        padding_pred = torch.zeros_like(prediction_scores[:, 0])[:, None]
        prediction_scores = torch.cat([padding_pred, prediction_scores], dim=1)
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )


class CDCDDiffusionModelMixin(DiffusionModelMixin):
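    """Diffusion mixin with a CDCD-style learned timestep warping.

    A monotone CDF (``LossCDF``) maps uniform timesteps to noise levels; during
    training the CDF is regressed onto the observed cross-entropy losses.
    """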

    def __init__(self, config):
        super().__init__(config)
        self.cdf = LossCDF(config.n_bins)

    def warp_timesteps(
        self,
        timesteps: torch.FloatTensor,
        token_input=None,
        t_min=0,
        t_max=1,
        **kwargs,
    ):
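        """Warp raw timesteps from [t_min, t_max] through the learned CDF.

        The timesteps are first normalized to [0, 1]; the result is detached so
        the scheduler never backpropagates through the warping.
        """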
        # u has to be in the normalized [0, 1] range.
        if t_max - t_min > 0:
            timesteps = (timesteps - t_min) / (t_max - t_min)
        else:
            # Degenerate case that only really happens with a single diffusion step
            # (t_min=0, t_max=0): just shift the timesteps and reset t_max to avoid
            # a division by zero downstream.
            timesteps = timesteps - t_min
            t_max = 1
        # Warp the timesteps. This is a separate call so the scheduler can use it
        # too; detach so we do not backprop through the warping.
        return self.cdf(u=timesteps, normalized=True, t_min=t_min, t_max=t_max).detach()

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        reduce_loss: str = "mean",
        **kwargs,
    ):
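        """Diffusion forward pass plus, in training, a CDF-fitting loss.

        The CDF prediction at each timestep is regressed onto the detached
        per-example cross-entropy loss, weighted by the reciprocal of the CDF's
        derivative (intended to correct for the warped timestep sampling).
        """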
        output = super().forward(
            timesteps=timesteps,
            input_ids=input_ids,
            simplex=simplex,
            span_mask=span_mask,
            position_ids=position_ids,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            previous_pred=previous_pred,
            reduce_loss=reduce_loss,
            **kwargs,
        )
        loss = output.loss
        # NOTE: the inference-mode check prevents the CDF loss from being computed
        # for the previous-prediction pass used in self-conditioning.
        if self.training and not torch.is_inference_mode_enabled():
            # Learn the CDF from the losses. This only happens in train mode; in
            # eval we just apply the warping.
            new_timesteps_clone = timesteps.clone()
            new_timesteps_clone.requires_grad = True
            with torch.enable_grad():
                # Grab the CDF predictions for the loss values. Note that at this
                # point the timesteps are normalized to [0, 1].
                xent_pred = self.cdf(t=new_timesteps_clone, normalized=False, t_max=1)
                # Importance weights: the reciprocal of the gradient of the CDF.
                imp_weights = (
                    1.0 / autograd.grad(xent_pred.sum(), [new_timesteps_clone])[0]
                )[:, 0]
                imp_weights = imp_weights.detach() * 1e-5
                # Use a single timestep index since all positions share the same
                # timestep; kept for compatibility with token-wise timesteps.
                cdf_loss = (
                    imp_weights
                    * (
                        self.cdf(t=timesteps, normalized=False, t_max=1)[:, 0]
                        - loss.detach()
                    ).pow(2)
                ).mean()
            # Add the (scaled) CDF fitting loss to the LM loss.
            loss = loss.mean() + cdf_loss
        else:
            loss = loss.mean()
        return MaskedLMOutput(
            loss=loss,
            logits=output.logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )


class CausalLMForSeq2SeqMixin:
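    """Adapts a causal LM to the seq2seq-style training and evaluation pipeline.

    ``forward`` accepts ``pad_lengths`` and ``context_lengths`` so the Trainer
    does not strip them; ``generate`` uses them to slice the un-padded context
    out of each example before delegating to the base ``generate``.
    """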

    def forward(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        pad_lengths=None,
        context_lengths=None,
    ):
        """
        HACK: `pad_lengths` and `context_lengths` are accepted here so they are
        available to `generate()`; otherwise `Trainer`'s `remove_unused_columns`
        would strip them from the kwargs. They are not passed to the underlying
        model.
        """
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def generate(self, *args, **kwargs):
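        """Generate continuations from the context portion of each example.

        When ``pad_lengths``/``context_lengths`` are given, the context tokens are
        sliced out of each padded input and re-padded as a batch; only the newly
        generated tokens after the prompt are returned.
        """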
        context_tokens = []
        # Labels are not needed for generation.
        kwargs.pop("labels", None)
        input_ids = kwargs.pop("input_ids")
        if "pad_lengths" in kwargs:
            pad_lengths = kwargs.pop("pad_lengths")
            context_lengths = kwargs.pop("context_lengths")
            for input_id, pad_length, context_length in zip(
                input_ids, pad_lengths, context_lengths
            ):
                # Grab the non-padding context, without the labels.
                context_tokens.append(
                    input_id[pad_length : pad_length + context_length]
                )
        else:
            context_tokens = input_ids
        input_ids = pad_sequence(
            context_tokens,
            padding_value=self.config.pad_token_id,
            batch_first=True,
            padding_side=self.config.padding_side,
        )
        kwargs["input_ids"] = input_ids.to(self.device)
        kwargs["attention_mask"] = kwargs["input_ids"] != self.config.pad_token_id
        # The cache has to be disabled because of flash attention.
        kwargs["use_cache"] = False
        kwargs["max_new_tokens"] = kwargs.pop("max_length", 512)
        outputs = super().generate(*args, **kwargs)
        # Return only the newly generated tokens, stripping the prompt.
        seq_len = input_ids.size(1)
        output_ids = outputs[:, seq_len:]
        return output_ids.to(self.device)


class PaddingIncludedSequenceClassificationMixin:
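    """Sequence classification head that always pools the final position.

    Unlike the stock implementation, which locates the last non-padding token,
    this mixin always takes the logits at the last position, padding included.
    """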

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is
            computed (mean-squared error); if `config.num_labels > 1`, a classification loss is
            computed (cross-entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]
        # Always pool the logits at the last position for classification; this is
        # the only change from the original implementation.
        sequence_lengths = -1
        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )