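"""CDCD-GAR variant of the RoBERTa diffusion language model.

Combines a positional (GAR) warping of the diffusion timesteps with a
CDCD-style learned CDF that is fitted to the training loss, built on top of
``RobertaForDiffusionLM``.
"""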
from typing import Optional

import torch
from torch import autograd
from transformers.modeling_outputs import MaskedLMOutput

from sdlm.models.cdcd.cdf import LossCDF
from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM


class CDCDGARRobertaForDiffusionLM(RobertaForDiffusionLM):
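    """RoBERTa diffusion LM with GAR positional timestep warping and a
    CDCD-style learned CDF.

    ``apply_gar`` rescales each token's timestep according to its summed,
    normalized distance to the non-mask, non-padding tokens;
    ``warp_timesteps`` then maps the result through the learned ``LossCDF``.
    In training mode, ``forward`` additionally fits the CDF to the observed
    per-token losses.
    """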

    def __init__(self, config):
super().__init__(config)
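        # Learned CDF over timesteps, fitted to the training loss (CDCD-style
        # time warping). The constructor argument (100 here) presumably sets
        # the resolution of the learned CDF; see sdlm.models.cdcd.cdf.LossCDF.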
self.cdf = LossCDF(100)

    def apply_gar(
self, timesteps: torch.FloatTensor, token_input=None, t_min=0, t_max=1
):
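        """Warp per-token timesteps according to positional distances.

        For every position, the normalized distances to all non-mask,
        non-padding tokens are summed, inverted, and rescaled to [0, 0.5].
        That factor sets the slope of a linear warp of ``timesteps`` anchored
        at ``t_max``; the result is clipped to [``t_min``, ``t_max``] and cast
        to ``long``.
        """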
# Ensure timesteps is a floating point tensor for computations
timesteps = timesteps.float()
        # Token mask: True for real tokens, excluding the RoBERTa <mask> token (id 50264) and padding (id 1)
        token_masks = (token_input != 50264) & (token_input != 1)
# Create a tensor representing each position in the sequence [0, 1, ..., seq_len-1]
seq_len = token_input.size(1)
positions = torch.arange(seq_len, device=token_input.device).float()
        # Difference between positions gives a matrix of relative distances
        # Shape of distances: [1, seq_len, seq_len], broadcast over the batch
distances = positions.unsqueeze(0).unsqueeze(2) - positions.unsqueeze(
0
).unsqueeze(1)
        # Normalize distances to the range [0, 1]
        distances = distances.abs() / (seq_len - 1)
# Apply token masks to the distances, setting distances for masked tokens to 0
masked_distances = distances * token_masks.unsqueeze(1).float()
        # Sum the distances to all unmasked tokens for each position; normalized to [0, 1] below
        composed = masked_distances.sum(dim=2)
        # Set padding tokens to 1, since we don't want them to affect the warping
composed = torch.where(
token_input == 1, torch.tensor(1.0, device=token_input.device), composed
)
composed_max, _ = composed.max(dim=1, keepdim=True)
        # Normalize to [0, 1], invert, then scale to [0, 0.5]
        composed_normalized = composed / composed_max
        composed_normalized = 1 - composed_normalized
        composed_normalized = composed_normalized * 0.5
# Adjust timesteps based on composed_normalized values
# Ensure the operation is broadcastable: [batch_size, 1] * [batch_size, seq_len]
slope = -t_max / torch.clip(t_max * composed_normalized - t_max, max=1e-8)
adjusted_timesteps = slope * (timesteps - t_max) + t_max
adjusted_timesteps = torch.clip(adjusted_timesteps, min=t_min, max=t_max)
return adjusted_timesteps.long()

    def warp_timesteps(
self, timesteps: torch.FloatTensor, token_input=None, t_min=0, t_max=1
):
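        """Normalize ``timesteps`` to [0, 1], apply the GAR warp, then map the
        result through the learned CDF. The output is detached, so no gradient
        flows back through the warping path.
        """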
        # The CDF input u has to be in the normalized [0, 1] range
if t_max - t_min > 0:
timesteps = (timesteps - t_min) / (t_max - t_min)
else:
            # Degenerate case: only really happens with a single diffusion step
            # (t_min == t_max == 0); in that case just shift timesteps to 0
            timesteps = timesteps - t_min
            t_max = 1  # avoid division by zero downstream
        # Warp timesteps based on GAR (positional distances)
timesteps = self.apply_gar(timesteps, token_input, t_min, t_max)
# then apply CDF
return self.cdf(u=timesteps, normalized=True, t_min=t_min, t_max=t_max).detach()

    def forward(
self,
timesteps: torch.FloatTensor,
input_ids: torch.LongTensor,
simplex: torch.FloatTensor,
span_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
previous_pred: Optional[torch.FloatTensor] = None,
classifier_free_guidance: bool = False,
classifier_free_guidance_in_train: bool = False,
max_timestep: int = 5000,
reduce_loss: str = "mean", # passed to 'reduction' in F.cross_entropy
# unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token loss for all items in the batch
previous_hidden: Optional[torch.FloatTensor] = None,
original_timesteps: Optional[torch.FloatTensor] = None,
):
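        """Run the parent diffusion LM forward pass with per-token losses.

        In training mode, the learned CDF is additionally fitted to those
        losses via an importance-weighted squared-error term (``cdf_loss``);
        in eval mode only the mean loss is returned.
        """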
output = super().forward(
timesteps,
input_ids,
simplex,
span_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
labels,
output_attentions,
output_hidden_states,
return_dict,
previous_pred,
classifier_free_guidance,
classifier_free_guidance_in_train,
max_timestep,
reduce_loss="none",
return_all_losses=False,
previous_hidden=previous_hidden, # for CDCD predictions...
)
loss = output.loss
if self.training:
            # Learn the CDF from the losses. This only happens in train mode;
            # in eval we just apply the warping.
new_timesteps_clone = timesteps.clone()
new_timesteps_clone.requires_grad = True
with torch.enable_grad():
                # Predict the loss values; note that at this point the
                # timesteps are normalized to [0, 1]
xent_pred = self.cdf(t=new_timesteps_clone, normalized=False, t_max=1)
                # Importance weights: reciprocal of the gradient of the CDF w.r.t. the timesteps
imp_weights = (
1.0 / autograd.grad(xent_pred.sum(), [new_timesteps_clone])[0]
)[:, 0]
imp_weights = imp_weights.detach() * 1e-5
            # Use a single timestep index, since all per-token timesteps are the
            # same; kept for compatibility with the token-wise variant
cdf_loss = (
imp_weights
* (
self.cdf(t=timesteps, normalized=False, t_max=1)[:, 0]
- loss.detach()
).pow(2)
).mean()
            loss = loss.mean() + cdf_loss  # upweight the cdf loss as it's too small :(
else:
loss = loss.mean()
return MaskedLMOutput(
loss=loss,
logits=output.logits,
hidden_states=output.hidden_states,
attentions=output.attentions,
)
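

if __name__ == "__main__":
    # Minimal, self-contained sketch (not part of the model): it re-traces the
    # GAR distance weighting from `apply_gar` on a toy batch so shapes and
    # value ranges can be inspected. Token ids are illustrative; as in
    # `apply_gar`, 50264 is treated as the mask token and 1 as padding.
    token_input = torch.tensor([[0, 11, 50264, 22, 1, 1]])
    token_masks = (token_input != 50264) & (token_input != 1)
    seq_len = token_input.size(1)
    positions = torch.arange(seq_len).float()
    # Pairwise |i - j| distances, normalized to [0, 1]
    distances = (positions[None, :, None] - positions[None, None, :]).abs()
    distances = distances / (seq_len - 1)
    # Zero out distances to mask/padding tokens, then sum per position
    composed = (distances * token_masks.unsqueeze(1).float()).sum(dim=2)
    composed = torch.where(token_input == 1, torch.tensor(1.0), composed)
    composed = composed / composed.max(dim=1, keepdim=True).values
    composed = (1.0 - composed) * 0.5  # per-token warp factor in [0, 0.5]
    print(composed)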