from typing import Optional

import torch
from transformers.modeling_outputs import MaskedLMOutput

from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM


# RoBERTa with a confidence tracker. Empirically the same as the base model,
# but alters timesteps based on the last confidence scores.
# Operates at the token level.
class ConfidenceTrackerRobertaDiffusionLM(RobertaForDiffusionLM):
    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed as 'reduction' to F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return the per-token loss for every item in the batch
        previous_hidden: Optional[torch.FloatTensor] = None,
        original_timesteps: Optional[torch.FloatTensor] = None,
        last_confidence_scores: Optional[torch.FloatTensor] = None,
    ):
        # Main difference: timesteps become min(1 - confidence, timesteps);
        # we use 1 - confidence because a timestep of 1 corresponds to full noise.
        # if last_confidence_scores is not None:
        #     timesteps = torch.min(
        #         torch.where(last_confidence_scores > 0.99, 1 - last_confidence_scores, timesteps),
        #         timesteps,
        #     )
        output = super().forward(
            timesteps,
            input_ids,
            simplex,
            span_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
            previous_pred,
            classifier_free_guidance,
            classifier_free_guidance_in_train,
            max_timestep,
            reduce_loss=reduce_loss,
            return_all_losses=False,
        )
        loss = output.loss.mean()
        # Confidence = how much probability mass the model puts on its predicted
        # (argmax) token at each position.
        # TODO: calibrate this to the right scale.
        confidence_scores = torch.softmax(output.logits, dim=-1).max(dim=-1).values
        if not self.training:
            return (
                MaskedLMOutput(
                    loss=loss,
                    logits=output.logits,
                    hidden_states=output.hidden_states,
                    attentions=output.attentions,
                ),
                confidence_scores,
            )
        else:
            return MaskedLMOutput(
                loss=loss,
                logits=output.logits,
                hidden_states=output.hidden_states,
                attentions=output.attentions,
            )
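

# Minimal usage sketch (illustrative, not part of the original module). It assumes the
# caller already has a ConfidenceTrackerRobertaDiffusionLM instance plus the tensors
# that RobertaForDiffusionLM.forward expects (timesteps, input_ids, simplex, span_mask,
# labels); the helper name and `num_steps` are hypothetical. In eval mode the model
# returns (MaskedLMOutput, confidence_scores), and the per-token confidence can be fed
# back as `last_confidence_scores` on the next call.
def denoise_with_confidence(model, timesteps, input_ids, simplex, span_mask, labels, num_steps=3):
    model.eval()
    last_confidence_scores = None
    with torch.no_grad():
        for _ in range(num_steps):
            output, confidence_scores = model(
                timesteps=timesteps,
                input_ids=input_ids,
                simplex=simplex,
                span_mask=span_mask,
                labels=labels,  # labels are needed because forward reduces output.loss
                last_confidence_scores=last_confidence_scores,
            )
            # Carry the per-token max-softmax confidence into the next step; the
            # (currently commented-out) clamp in forward would use it to shrink the
            # timestep for tokens the model is already confident about.
            last_confidence_scores = confidence_scores
            simplex = torch.softmax(output.logits, dim=-1)
    return output, last_confidence_scores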