from typing import Optional
import numpy as np
import torch
from torch import autograd
from transformers.modeling_outputs import MaskedLMOutput
from sdlm.models.cdcd.cdf import LossCDF
from sdlm.models.roberta.configuration_roberta import RobertaDiffusionConfig
from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM


# only difference is that we add n_bins to the config
class PositionwiseCDCDRobertaConfig(RobertaDiffusionConfig):
    def __init__(self, *args, n_bins=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_bins = n_bins


# Roberta with the CDF timestep warper.
class PositionwiseCDCDRobertaForDiffusionLM(RobertaForDiffusionLM):
    def __init__(self, config):
        super().__init__(config)
        # Per-position bin logits for the learned loss CDF, initialised to a
        # uniform distribution over the config.n_bins bins (in log space).
        self.position_lus = torch.nn.Parameter(
            torch.zeros([config.max_position_embeddings, config.n_bins])
            - float(np.log(config.n_bins))
        )
        self.position_lts = torch.nn.Parameter(
            torch.zeros([config.max_position_embeddings, config.n_bins])
            - float(np.log(config.n_bins))
        )
        self.cdf = LossCDF(config.n_bins)

    def warp_timesteps(
        self,
        timesteps: torch.FloatTensor,
        previous_hidden: Optional[torch.FloatTensor] = None,
        t_min=0,
        t_max=1,
    ):
        # u has to be in the normalized [0, 1] range.
        if t_max - t_min > 0:
            timesteps = (timesteps - t_min) / (t_max - t_min)
        else:
            # Degenerate case that only really happens with a single diffusion
            # step (t_min = t_max = 0); just shift the timesteps to 0.
            timesteps = timesteps - t_min
            t_max = 1  # just to avoid a div by 0
        # Warp the timesteps. Separate call so we can pass the result to the scheduler.
        # Detach so we don't backprop through this.
        # Not all batches will have the max sequence length, so cut the
        # per-position parameters to size.
        pos_lus = self.position_lus[None, : timesteps.shape[1]].expand(
            timesteps.shape[0], -1, -1
        )
        pos_lts = self.position_lts[None, : timesteps.shape[1]].expand(
            timesteps.shape[0], -1, -1
        )
        return self.cdf(
            u=timesteps,
            normalized=True,
            t_min=t_min,
            t_max=t_max,
            l_u=pos_lus,
            l_t=pos_lts,
        ).detach()

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token losses for all items in the batch
        previous_hidden: Optional[torch.FloatTensor] = None,
        original_timesteps: Optional[torch.FloatTensor] = None,
    ):
        output = super().forward(
            timesteps,
            input_ids,
            simplex,
            span_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
            previous_pred,
            classifier_free_guidance,
            classifier_free_guidance_in_train,
            max_timestep,
            reduce_loss="none",
            return_all_losses=True,
            previous_hidden=previous_hidden,  # for CDCD predictions...
        )
        loss = output.loss
        if self.training:
            # Learn the CDF from the losses. Only in train mode, since in eval
            # we just apply the warping.
            new_timesteps_clone = timesteps.clone()
            new_timesteps_clone.requires_grad = True
            with torch.enable_grad():
                # Grab the CDF predictions for the loss values. Note that at
                # this point the timesteps are normalised to [0, 1].
                pos_lus = self.position_lus[None, : timesteps.shape[1]].expand(
                    timesteps.shape[0], -1, -1
                )
                pos_lts = self.position_lts[None, : timesteps.shape[1]].expand(
                    timesteps.shape[0], -1, -1
                )
                xent_pred = self.cdf(
                    t=new_timesteps_clone,
                    normalized=False,
                    t_max=1,
                    l_u=pos_lus,
                    l_t=pos_lts,
                )
                # Importance weights: reciprocal of the gradient of the CDF.
                imp_weights = (
                    1.0 / autograd.grad(xent_pred.sum(), [new_timesteps_clone])[0]
                )
            # Detach and scale down the importance weights.
            imp_weights = imp_weights.detach() * 1e-5
            # Fit the CDF to the observed (detached) per-token losses with an
            # importance-weighted squared error.
            cdf_loss = (
                imp_weights
                * (
                    self.cdf(
                        t=timesteps, normalized=False, t_max=1, l_u=pos_lus, l_t=pos_lts
                    )
                    - loss.detach()
                ).pow(2)
            ).mean()
            loss = loss.mean() + cdf_loss  # upweight cdf loss as it's too small :(
        else:
            loss = loss.mean()
        return MaskedLMOutput(
            loss=loss,
            logits=output.logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
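

# --- Hypothetical usage sketch (illustrative, not from the original repo) ---
# A minimal example of how the position-wise CDCD model might be built and how
# `warp_timesteps` could be used to warp uniformly sampled timesteps before they
# are handed to a noise scheduler. It assumes the default RobertaDiffusionConfig
# values are enough to construct the model; the batch/sequence sizes below are
# arbitrary.
if __name__ == "__main__":
    config = PositionwiseCDCDRobertaConfig(n_bins=100)
    model = PositionwiseCDCDRobertaForDiffusionLM(config)

    batch_size, seq_len = 2, 16
    # One timestep per token position, uniformly sampled in [0, 1].
    timesteps = torch.rand(batch_size, seq_len)
    # Warp through the learned per-position loss CDF; the result is detached,
    # so it can be passed straight to a scheduler.
    warped = model.warp_timesteps(timesteps, t_min=0, t_max=1)
    print(warped.shape)  # expected to match the input shape: (batch_size, seq_len)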