# tess-2-demo/sdlm/models/confidence_tracker/confidence_tracker_model.py
from typing import Optional
import torch
from transformers.modeling_outputs import MaskedLMOutput
from sdlm.models.roberta.modeling_roberta import RobertaForDiffusionLM
# RoBERTa with the confidence tracker. Empirically the same as the base
# diffusion LM, but alters timesteps based on the last step's confidence.
# Operates at the token level.
class ConfidenceTrackerRobertaDiffusionLM(RobertaForDiffusionLM):
def __init__(self, config):
super().__init__(config)
def forward(
self,
timesteps: torch.FloatTensor,
input_ids: torch.LongTensor,
simplex: torch.FloatTensor,
span_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
previous_pred: Optional[torch.FloatTensor] = None,
classifier_free_guidance: bool = False,
classifier_free_guidance_in_train: bool = False,
max_timestep: int = 5000,
reduce_loss: str = "mean", # passed to 'reduction' in F.cross_entropy
# unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token loss for every item in the batch
previous_hidden: Optional[torch.FloatTensor] = None,
original_timesteps: Optional[torch.FloatTensor] = None,
last_confidence_scores: Optional[torch.FloatTensor] = None,
):
        # Main difference from the base model: timesteps become
        # min(1 - confidence, timesteps). We use 1 - confidence because a
        # timestep of 1 corresponds to full noise.
# if last_confidence_scores is not None:
# timesteps = torch.min(
# torch.where(last_confidence_scores > 0.99, 1 - last_confidence_scores, timesteps), timesteps
# )
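        # Worked example of the clamping above (currently disabled): if a token's
        # previous confidence was 0.995, its candidate timestep is 1 - 0.995 = 0.005,
        # so min(0.005, t) keeps that token near-clean, while tokens at or below
        # 0.99 confidence keep the scheduler's timestep t unchanged.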
output = super().forward(
timesteps,
input_ids,
simplex,
span_mask,
token_type_ids,
position_ids,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
labels,
output_attentions,
output_hidden_states,
return_dict,
previous_pred,
classifier_free_guidance,
classifier_free_guidance_in_train,
max_timestep,
reduce_loss=reduce_loss,
return_all_losses=False,
)
        loss = output.loss.mean() if output.loss is not None else None
        # Confidence = how much probability mass we put on the top predicted token.
        # TODO: calibrate this to the right scale.
        confidence_scores = torch.softmax(output.logits, dim=-1).max(dim=-1).values
if not self.training:
return (
MaskedLMOutput(
loss=loss,
logits=output.logits,
hidden_states=output.hidden_states,
attentions=output.attentions,
),
confidence_scores,
)
else:
return MaskedLMOutput(
loss=loss,
logits=output.logits,
hidden_states=output.hidden_states,
attentions=output.attentions,
)
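

# Hypothetical usage sketch: illustrates how the eval-mode return value, a
# (MaskedLMOutput, confidence_scores) tuple, could be fed back into the next
# denoising step via `last_confidence_scores`. The loop structure, timestep
# values, and the simplex update below are illustrative assumptions, not the
# actual SDLM sampler.
def _example_confidence_tracking_loop(
    model: ConfidenceTrackerRobertaDiffusionLM,
    input_ids: torch.LongTensor,
    simplex: torch.FloatTensor,
    span_mask: torch.FloatTensor,
    num_inference_steps: int = 100,
):
    model.eval()
    last_confidence_scores = None
    for t in reversed(range(num_inference_steps)):
        # One timestep value per batch item; a real pipeline would normalise or
        # scale this according to max_timestep.
        timesteps = torch.full((input_ids.shape[0],), float(t))
        with torch.no_grad():
            output, last_confidence_scores = model(
                timesteps=timesteps,
                input_ids=input_ids,
                simplex=simplex,
                span_mask=span_mask,
                last_confidence_scores=last_confidence_scores,
            )
        # Stand-in update: a real sampler would derive the next simplex from a
        # noise scheduler rather than reusing the raw logits.
        simplex = output.logits
    return output, last_confidence_scores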