from typing import Optional

import numpy as np
import torch
from torch import autograd
from transformers.modeling_outputs import MaskedLMOutput

from sdlm.models.cdcd.cdf import LossCDF
from sdlm.models.roberta.configuration_roberta import RobertaDiffusionConfig
from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM


# only difference is that we add n_bins to the config
class PositionwiseCDCDRobertaConfig(RobertaDiffusionConfig):
    def __init__(self, *args, n_bins=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_bins = n_bins


# Roberta with the CDF timestep warper.
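# Following CDCD-style time warping, the model keeps per-position CDF parameters for
# the diffusion (cross-entropy) loss as a function of the timestep; uniform draws are
# pushed through the inverse CDF so that timesteps are sampled more densely where the
# loss changes quickly.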
class PositionwiseCDCDRobertaForDiffusionLM(RobertaForDiffusionLM):
    def __init__(self, config):
        super().__init__(config)
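        # Per-position logits parameterizing the piecewise-linear loss CDF
        # (l_u for the u-axis bins, l_t for the t-axis bins); the constant init
        # gives every bin equal mass, i.e. an (approximately) identity warp.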
        self.position_lus = torch.nn.Parameter(
            torch.zeros([config.max_position_embeddings, config.n_bins])
            - float(np.log(config.n_bins))
        )
        self.position_lts = torch.nn.Parameter(
            torch.zeros([config.max_position_embeddings, config.n_bins])
            - float(np.log(config.n_bins))
        )
        self.cdf = LossCDF(config.n_bins)

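    # Warp (uniform) timesteps through the learned per-position inverse CDF.
    # The output is detached: the warp acts as a fixed schedule here, and the CDF
    # parameters are trained only through the regression loss in forward().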
    def warp_timesteps(
        self,
        timesteps: torch.FloatTensor,
        previous_hidden: Optional[torch.FloatTensor] = None,
        t_min=0,
        t_max=1,
    ):
        # u has to be in normalized range...
        if t_max - t_min > 0:
            timesteps = (timesteps - t_min) / (t_max - t_min)
        else:
            # degenerate case, only really happens with one diffusion step (t_min=0, t_max=0);
            # in this case, we just set timesteps to 0
            timesteps = timesteps - t_min
            t_max = 1  # just to avoid a division by zero downstream

        # warp timesteps. separate call so we can pass the result to the scheduler.
        # detach so we don't backprop through this.
        # not all batches will have the max sequence length, so cut to size.
        pos_lus = self.position_lus[None, : timesteps.shape[1]].expand(
            timesteps.shape[0], -1, -1
        )
        pos_lts = self.position_lts[None, : timesteps.shape[1]].expand(
            timesteps.shape[0], -1, -1
        )
        return self.cdf(
            u=timesteps,
            normalized=True,
            t_min=t_min,
            t_max=t_max,
            l_u=pos_lus,
            l_t=pos_lts,
        ).detach()

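    # Same forward pass as the base diffusion LM, but per-token losses are requested
    # so that, in training, the per-position CDFs can be regressed onto the observed
    # cross-entropy losses (importance-weighted MSE below).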
    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token loss for all items in the batch
        previous_hidden: Optional[torch.FloatTensor] = None,
        original_timesteps: Optional[torch.FloatTensor] = None,
    ):
        output = super().forward(
            timesteps,
            input_ids,
            simplex,
            span_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
            previous_pred,
            classifier_free_guidance,
            classifier_free_guidance_in_train,
            max_timestep,
            reduce_loss="none",
            return_all_losses=True,
            previous_hidden=previous_hidden,  # for CDCD predictions...
        )
        loss = output.loss
        if self.training:
            # then we learn the cdf from the losses
            # only in train mode, since in eval we just apply the warping.
            new_timesteps_clone = timesteps.clone()
            new_timesteps_clone.requires_grad = True
            with torch.enable_grad():
                # grab the predictions for the loss values - note at this point timesteps
                # are normalised to [0, 1]
                pos_lus = self.position_lus[None, : timesteps.shape[1]].expand(
                    timesteps.shape[0], -1, -1
                )
                pos_lts = self.position_lts[None, : timesteps.shape[1]].expand(
                    timesteps.shape[0], -1, -1
                )
                xent_pred = self.cdf(
                    t=new_timesteps_clone,
                    normalized=False,
                    t_max=1,
                    l_u=pos_lus,
                    l_t=pos_lts,
                )
                # importance weights -> reciprocal of grad of CDF.
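                # dCDF/dt is (proportional to) the density of the warped timestep
                # distribution, so weighting by its reciprocal lets the regression
                # below behave as if timesteps had been sampled uniformly.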
                imp_weights = (
                    1.0 / autograd.grad(xent_pred.sum(), [new_timesteps_clone])[0]
                )
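            # detach and rescale the importance weights (the 1e-5 factor keeps the
            # CDF regression term small relative to the diffusion loss)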
            imp_weights = imp_weights.detach() * 1e-5
            cdf_loss = (
                imp_weights
                * (
                    self.cdf(
                        t=timesteps, normalized=False, t_max=1, l_u=pos_lus, l_t=pos_lts
                    )
                    - loss.detach()
                ).pow(2)
            ).mean()
            loss = loss.mean() + cdf_loss  # add the CDF regression term; note it tends to be much smaller than the diffusion loss
        else:
            loss = loss.mean()
        return MaskedLMOutput(
            loss=loss,
            logits=output.logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
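

# Hypothetical usage sketch (not part of the module): assumes the default
# RobertaDiffusionConfig values are valid and only exercises the timestep warp.
#
#   config = PositionwiseCDCDRobertaConfig(n_bins=100)
#   model = PositionwiseCDCDRobertaForDiffusionLM(config)
#   u = torch.rand(2, 16)  # one uniform draw per (batch, position)
#   t = model.warp_timesteps(u, t_min=0, t_max=1)  # warped timesteps for the scheduler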