from typing import Optional

import torch
from torch import autograd
from transformers.modeling_outputs import MaskedLMOutput

from sdlm.models.cdcd.cdf import LossCDF
from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM


class CDCDGARRobertaForDiffusionLM(RobertaForDiffusionLM):
    def __init__(self, config):
        super().__init__(config)
        # Learned loss CDF used to warp timesteps (CDCD).
        self.cdf = LossCDF(100)

    def apply_gar(
        self, timesteps: torch.FloatTensor, token_input=None, t_min=0, t_max=1
    ):
        # Ensure timesteps is a floating point tensor for the computations below.
        timesteps = timesteps.float()
        # Token mask excluding special tokens: 50264 is RoBERTa's <mask> id, 1 is <pad>.
        token_masks = (token_input != 50264) & (token_input != 1)
        # Tensor of positions in the sequence: [0, 1, ..., seq_len - 1].
        seq_len = token_input.size(1)
        positions = torch.arange(seq_len, device=token_input.device).float()
        # Pairwise differences between positions give a matrix of relative distances.
        # Shape of distances: [1, seq_len, seq_len], broadcast over the batch below.
        distances = positions.unsqueeze(0).unsqueeze(2) - positions.unsqueeze(
            0
        ).unsqueeze(1)
        # Normalize distances to the range [0, 1].
        distances = distances.abs() / (seq_len - 1)
        # Apply the token mask, zeroing out distances to masked and padding tokens.
        masked_distances = distances * token_masks.unsqueeze(1).float()
        # Sum the distances for each position.
        composed = masked_distances.sum(dim=2)
        # Set padding tokens to 1, since we don't want them to affect the warping.
        composed = torch.where(
            token_input == 1, torch.tensor(1.0, device=token_input.device), composed
        )
        # Normalize by the per-sequence maximum so composed_normalized is in [0, 1].
        composed_max, _ = composed.max(dim=1, keepdim=True)
        composed_normalized = composed / composed_max
        # Invert the values, then scale them to the range [0, 0.5].
        composed_normalized = 1 - composed_normalized
        composed_normalized = composed_normalized * 0.5
        # Adjust the timesteps based on the composed_normalized values.
        # Ensure the operation is broadcastable: [batch_size, 1] * [batch_size, seq_len].
        slope = -t_max / torch.clip(t_max * composed_normalized - t_max, max=1e-8)
        adjusted_timesteps = slope * (timesteps - t_max) + t_max
        adjusted_timesteps = torch.clip(adjusted_timesteps, min=t_min, max=t_max)
        return adjusted_timesteps.long()

    def warp_timesteps(
        self, timesteps: torch.FloatTensor, token_input=None, t_min=0, t_max=1
    ):
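        """Warp timesteps into CDF space.

        The timesteps are first rescaled to [0, 1], then warped per token by
        apply_gar based on token positions, and finally mapped through the
        learned loss CDF. The result is detached, so no gradient flows back
        through the warping; the CDF itself is fitted in forward during
        training.
        """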
        # u has to be in the normalized range for the CDF, so rescale to [0, 1] first.
        if t_max - t_min > 0:
            timesteps = (timesteps - t_min) / (t_max - t_min)
        else:
            # Weird case; only really happens with 1 diffusion step (t_min=0, t_max=0).
            # In this case we just shift the timesteps to 0.
            timesteps = timesteps - t_min
            t_max = 1  # just to avoid division by zero
        # Warp the timesteps based on GAR.
        timesteps = self.apply_gar(timesteps, token_input, t_min, t_max)
        # Then apply the learned CDF.
        return self.cdf(u=timesteps, normalized=True, t_min=t_min, t_max=t_max).detach()

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token loss for all items in the batch
        previous_hidden: Optional[torch.FloatTensor] = None,
        original_timesteps: Optional[torch.FloatTensor] = None,
    ):
        output = super().forward(
            timesteps,
            input_ids,
            simplex,
            span_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
            previous_pred,
            classifier_free_guidance,
            classifier_free_guidance_in_train,
            max_timestep,
            reduce_loss="none",  # keep the unreduced loss; it is reduced manually below
            return_all_losses=False,
            previous_hidden=previous_hidden,
            # for CDCD predictions...
        )
        loss = output.loss
        if self.training:
            # Learn the CDF from the losses. Only done in train mode, since in eval we
            # just apply the warping.
            new_timesteps_clone = timesteps.clone()
            new_timesteps_clone.requires_grad = True
            with torch.enable_grad():
                # Grab the CDF predictions for the loss values - note that at this point
                # the timesteps are normalised to [0, 1].
                xent_pred = self.cdf(t=new_timesteps_clone, normalized=False, t_max=1)
                # Importance weights -> reciprocal of the gradient of the CDF.
                imp_weights = (
                    1.0 / autograd.grad(xent_pred.sum(), [new_timesteps_clone])[0]
                )[:, 0]
            imp_weights = imp_weights.detach() * 1e-5
            # Just one index of the timesteps, since all are the same; required for
            # compatibility with the tokenwise CDF.
            cdf_loss = (
                imp_weights
                * (
                    self.cdf(t=timesteps, normalized=False, t_max=1)[:, 0]
                    - loss.detach()
                ).pow(2)
            ).mean()
            loss = loss.mean() + cdf_loss  # upweight cdf loss as it's too small :(
        else:
            loss = loss.mean()
        return MaskedLMOutput(
            loss=loss,
            logits=output.logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
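

# Minimal standalone sketch (illustration only, not used by the class above): it re-derives
# the per-token warp strengths that apply_gar computes from token positions, so their
# behaviour is easy to inspect on a toy batch. The toy token ids are made up for this
# example; 50264 (<mask>) and 1 (<pad>) mirror the ids hard-coded in apply_gar.
if __name__ == "__main__":
    toy_tokens = torch.tensor([[0, 50264, 12, 13, 2, 1, 1, 1]])  # [batch=1, seq_len=8]
    seq_len = toy_tokens.size(1)
    token_masks = (toy_tokens != 50264) & (toy_tokens != 1)
    positions = torch.arange(seq_len).float()
    # Normalized pairwise distances between positions, masked to content tokens only.
    distances = positions.unsqueeze(0).unsqueeze(2) - positions.unsqueeze(0).unsqueeze(1)
    distances = distances.abs() / (seq_len - 1)
    composed = (distances * token_masks.unsqueeze(1).float()).sum(dim=2)
    composed = torch.where(toy_tokens == 1, torch.tensor(1.0), composed)
    composed = composed / composed.max(dim=1, keepdim=True).values
    # Per-token warp strengths in [0, 0.5] that apply_gar uses to adjust each token's timestep.
    print((1 - composed) * 0.5)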