from typing import Optional

import torch
from torch import autograd
from transformers.modeling_outputs import MaskedLMOutput

from sdlm.models.cdcd.cdf import LossCDF
from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM


class CDCDGARRobertaForDiffusionLM(RobertaForDiffusionLM):
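    """RoBERTa diffusion LM that combines CDCD timestep warping (a learned CDF
    over the training loss) with a GAR positional warp, so that different
    sequence positions receive different effective noise levels.
    """
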
    def __init__(self, config):
        super().__init__(config)
        self.cdf = LossCDF(100)  # learned CDF over the diffusion loss, used to warp timesteps (CDCD)

    def apply_gar(
        self, timesteps: torch.FloatTensor, token_input=None, t_min=0, t_max=1
    ):
        # Ensure timesteps is a floating point tensor for computations
        timesteps = timesteps.float()

        # Token mask excluding RoBERTa's <mask> (id 50264) and <pad> (id 1) tokens
        token_masks = (token_input != 50264) & (token_input != 1)

        # Create a tensor representing each position in the sequence [0, 1, ..., seq_len-1]
        seq_len = token_input.size(1)
        positions = torch.arange(seq_len, device=token_input.device).float()

        # Pairwise position differences give a matrix of relative distances.
        # Shape of distances: [1, seq_len, seq_len], broadcast over the batch below
        distances = positions.unsqueeze(0).unsqueeze(2) - positions.unsqueeze(
            0
        ).unsqueeze(1)
        distances = distances.abs() / (seq_len - 1)  # normalize distances to [0, 1]
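        # For example, with seq_len = 4 the normalized distance matrix
        # (before masking) is:
        #   [[0,   1/3, 2/3, 1  ],
        #    [1/3, 0,   1/3, 2/3],
        #    [2/3, 1/3, 0,   1/3],
        #    [1,   2/3, 1/3, 0  ]]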

        # Apply token masks to the distances, setting distances for masked tokens to 0
        masked_distances = distances * token_masks.unsqueeze(1).float()

        # Sum the distances for each position; normalization to [0, 1] happens below
        composed = masked_distances.sum(dim=2)
        # Set padding tokens to 1, since we don't want them to affect the warping
        composed = torch.where(
            token_input == 1, torch.tensor(1.0, device=token_input.device), composed
        )
        composed_max, _ = composed.max(dim=1, keepdim=True)
        composed_normalized = composed / composed_max  # now in range [0, 1]
        composed_normalized = 1 - composed_normalized  # invert: smaller distance sums score higher
        composed_normalized = composed_normalized * 0.5  # scale to range [0, 0.5]

        # Adjust timesteps with a per-position linear warp that fixes t_max:
        # slope = -t_max / (t_max * c - t_max) = 1 / (1 - c), so c = 0 leaves
        # timesteps unchanged and larger c pulls warped values toward t_min.
        # Ensure the operation is broadcastable: [batch_size, 1] * [batch_size, seq_len]
        slope = -t_max / torch.clip(t_max * composed_normalized - t_max, max=1e-8)
        adjusted_timesteps = slope * (timesteps - t_max) + t_max
        adjusted_timesteps = torch.clip(adjusted_timesteps, min=t_min, max=t_max)
        # Return the warped timesteps as floats: the CDF consumes continuous
        # values in [t_min, t_max]; truncating with .long() would collapse
        # normalized timesteps to 0.
        return adjusted_timesteps

    def warp_timesteps(
        self, timesteps: torch.FloatTensor, token_input=None, t_min=0, t_max=1
    ):
        # u has to be in the normalized range [0, 1]
        if t_max - t_min > 0:
            timesteps = (timesteps - t_min) / (t_max - t_min)
        else:
            # Degenerate case; effectively only happens with a single diffusion
            # step (t_min == t_max == 0). In that case, just shift timesteps to 0.
            timesteps = timesteps - t_min
            t_max = 1  # avoid a downstream division by zero
        # first warp timesteps positionally with GAR
        timesteps = self.apply_gar(timesteps, token_input, t_min, t_max)
        # then apply the learned CDF
        return self.cdf(u=timesteps, normalized=True, t_min=t_min, t_max=t_max).detach()

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token loss for all items in batch
        previous_hidden: Optional[torch.FloatTensor] = None,
        original_timesteps: Optional[torch.FloatTensor] = None,
    ):
        output = super().forward(
            timesteps,
            input_ids,
            simplex,
            span_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
            previous_pred,
            classifier_free_guidance,
            classifier_free_guidance_in_train,
            max_timestep,
            reduce_loss="none",
            return_all_losses=False,
            previous_hidden=previous_hidden,  # for CDCD predictions...
        )
        loss = output.loss
        if self.training:
            # Learn the CDF from the losses. Only done in train mode; in eval we
            # just apply the warping.
            new_timesteps_clone = timesteps.clone().requires_grad_(True)
            with torch.enable_grad():
                # grab the predictions for the loss values - note at this point timesteps
                # are normalised to [0, 1]
                xent_pred = self.cdf(t=new_timesteps_clone, normalized=False, t_max=1)
                # Importance weights: the reciprocal of the CDF's gradient (its
                # implied density), correcting for the non-uniform timestep
                # sampling that the warping induces.
                imp_weights = (
                    1.0 / autograd.grad(xent_pred.sum(), [new_timesteps_clone])[0]
                )[:, 0]
            imp_weights = imp_weights.detach() * 1e-5  # detach and rescale so the CDF loss stays small
            # Use just one index of timesteps, since all positions share the same
            # value; required for compatibility with the tokenwise variant
            cdf_loss = (
                imp_weights
                * (
                    self.cdf(t=timesteps, normalized=False, t_max=1)[:, 0]
                    - loss.detach()
                ).pow(2)
            ).mean()
            loss = loss.mean() + cdf_loss  # total loss: LM loss plus CDF regression loss
        else:
            loss = loss.mean()
        return MaskedLMOutput(
            loss=loss,
            logits=output.logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
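

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the training pipeline): apply_gar
# never touches `self`, so the positional warp can be exercised on toy inputs
# without building the full RoBERTa stack. The token ids below are hypothetical.
if __name__ == "__main__":
    # One toy sequence containing a <mask> (50264) and a trailing <pad> (1).
    token_input = torch.tensor([[0, 713, 16, 50264, 2, 1]])
    # The same normalized timestep at every position, as in warp_timesteps.
    timesteps = torch.full((1, 6), 0.5)
    warped = CDCDGARRobertaForDiffusionLM.apply_gar(
        None, timesteps, token_input, t_min=0, t_max=1
    )
    print(warped)  # positions with larger inverted distance mass get lower timesteps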