from typing import Optional

import torch
from transformers.modeling_outputs import MaskedLMOutput

from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM


# RoBERTa diffusion LM with a per-token confidence tracker. Empirically it
# performs the same as the base model, but it can clamp each token's timestep
# based on the confidence from the previous denoising step (see the usage
# sketch at the bottom of the file).
class ConfidenceTrackerRobertaDiffusionLM(RobertaForDiffusionLM):
    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token losses for every item in the batch
        previous_hidden: Optional[torch.FloatTensor] = None,
        original_timesteps: Optional[torch.FloatTensor] = None,
        last_confidence_scores: Optional[torch.FloatTensor] = None,
    ):
        # Main difference from the parent: timesteps become
        # min(1 - confidence, timesteps). We use 1 - confidence because a
        # timestep of 1 means full noise. The clamp below is currently disabled.
        # if last_confidence_scores is not None:
        #     timesteps = torch.min(
        #         torch.where(last_confidence_scores > 0.99, 1 - last_confidence_scores, timesteps), timesteps
        #     )
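        # Worked example of the disabled clamp (illustrative numbers): a token
        # with confidence 0.995 (> 0.99) and scheduled timestep 0.4 would get
        # min(1 - 0.995, 0.4) = 0.005, i.e. it is treated as nearly denoised,
        # while a token with confidence 0.5 keeps its scheduled timestep 0.4.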
        output = super().forward(
            timesteps,
            input_ids,
            simplex,
            span_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
            previous_pred,
            classifier_free_guidance,
            classifier_free_guidance_in_train,
            max_timestep,
            reduce_loss=reduce_loss,
            return_all_losses=False,  # the return_all_losses argument is not forwarded to the parent
        )
        loss = output.loss.mean()  # reduce to a scalar; a no-op if the parent already averaged
        # Confidence = probability mass on the model's top-ranked token at each
        # position (not necessarily the correct one).
        # TODO: calibrate this to the right scale.
        confidence_scores = torch.softmax(output.logits, dim=-1).max(dim=-1).values
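        # e.g. per-token logits [2.0, 0.5, 0.1] -> softmax ~= [0.73, 0.16, 0.11]
        # -> confidence 0.73 at that position.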
        result = MaskedLMOutput(
            loss=loss,
            logits=output.logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
        # At inference time, also return the per-token confidence so the caller
        # can feed it back as last_confidence_scores on the next denoising step.
        if not self.training:
            return result, confidence_scores
        return result
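

# Minimal usage sketch (illustrative only). The shapes, the plain RobertaConfig,
# and the assumption that the parent forward populates `loss` for these inputs
# are inferred from the signature above, not from the sdlm training code; the
# real sdlm config may carry extra diffusion-specific fields.
if __name__ == "__main__":
    from transformers import RobertaConfig

    config = RobertaConfig(vocab_size=50265)
    model = ConfidenceTrackerRobertaDiffusionLM(config).eval()

    batch, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch, seq_len))
    simplex = torch.randn(batch, seq_len, config.vocab_size)  # assumed (B, L, V)
    span_mask = torch.ones(batch, seq_len)  # assumed: 1 = position to denoise
    timesteps = torch.full((batch, seq_len), 0.5)  # assumed per-token, in [0, 1]

    with torch.no_grad():
        # First step: no confidence from a previous step yet.
        out, confidence = model(timesteps, input_ids, simplex, span_mask)
        # Next step: feed the confidence back; with the clamp enabled, highly
        # confident tokens would receive a smaller effective timestep.
        out, confidence = model(
            timesteps, input_ids, simplex, span_mask,
            last_confidence_scores=confidence,
        )
    print(confidence.shape)  # expected: (batch, seq_len)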