from typing import Optional
import torch
from transformers.modeling_outputs import MaskedLMOutput
from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM

# RoBERTa diffusion LM with a confidence tracker. Empirically the same as the
# parent class, but intended to adjust timesteps based on the confidence of the
# previous prediction. Confidence is tracked at the token level.
class ConfidenceTrackerRobertaDiffusionLM(RobertaForDiffusionLM):
    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return the per-token loss for every item in the batch
        previous_hidden: Optional[torch.FloatTensor] = None,
        original_timesteps: Optional[torch.FloatTensor] = None,
        last_confidence_scores: Optional[torch.FloatTensor] = None,
    ):
        # Main difference from the parent class: timesteps become
        # min(1 - confidence, timesteps); we use 1 - confidence because a
        # timestep of 1 corresponds to full noise. The adjustment below is
        # currently disabled.
        # if last_confidence_scores is not None:
        #     timesteps = torch.min(
        #         torch.where(last_confidence_scores > 0.99, 1 - last_confidence_scores, timesteps),
        #         timesteps,
        #     )
        output = super().forward(
            timesteps,
            input_ids,
            simplex,
            span_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
            previous_pred,
            classifier_free_guidance,
            classifier_free_guidance_in_train,
            max_timestep,
            reduce_loss=reduce_loss,
            return_all_losses=False,
        )
        # Reduce the per-token loss to a scalar; guard against label-free
        # (inference) calls, where the parent returns no loss.
        loss = output.loss.mean() if output.loss is not None else None
        # Confidence = the maximum softmax probability per token, i.e. how much
        # probability mass the model placed on its predicted token.
        # TODO: calibrate this to the right scale.
        confidence_scores = torch.softmax(output.logits, dim=-1).max(dim=-1).values
        if not self.training:
            # At eval time, also return the per-token confidence so the caller
            # can feed it back as `last_confidence_scores` on the next step.
            return (
                MaskedLMOutput(
                    loss=loss,
                    logits=output.logits,
                    hidden_states=output.hidden_states,
                    attentions=output.attentions,
                ),
                confidence_scores,
            )
        else:
            return MaskedLMOutput(
                loss=loss,
                logits=output.logits,
                hidden_states=output.hidden_states,
                attentions=output.attentions,
            )
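

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the sdlm pipeline): it
# shows one way the per-token confidence returned in eval mode could be fed
# back into the next denoising call as `last_confidence_scores`. The function
# name and the model/inputs passed in are hypothetical placeholders, and it
# assumes the parent forward accepts label-free (inference) calls; the real
# sampling loop lives elsewhere in the sdlm codebase.
# ---------------------------------------------------------------------------
def confidence_feedback_step(
    model: ConfidenceTrackerRobertaDiffusionLM,
    timesteps: torch.FloatTensor,
    input_ids: torch.LongTensor,
    simplex: torch.FloatTensor,
    span_mask: Optional[torch.FloatTensor] = None,
    last_confidence_scores: Optional[torch.FloatTensor] = None,
):
    """Run one eval-mode forward pass and return (output, per-token confidence)."""
    model.eval()
    with torch.no_grad():
        # In eval mode the model returns a (MaskedLMOutput, confidence_scores) tuple.
        output, confidence_scores = model(
            timesteps,
            input_ids,
            simplex,
            span_mask=span_mask,
            last_confidence_scores=last_confidence_scores,
        )
    # The caller would pass `confidence_scores` back in as `last_confidence_scores`
    # on the following step (note the timestep adjustment itself is currently
    # commented out in forward above).
    return output, confidence_scores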