import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


# from https://github.com/igul222/plaid/blob/main/train_cdcd.py
# CDCD loss CDF: a learned, piecewise-linear monotone map between timesteps t
# and cross-entropy values u. The warping is shared across all tokens in the sequence.
class LossCDF(nn.Module):
    def __init__(self, n_bins):
        super().__init__()
        # bin logits, initialized to log(1 / n_bins) so both partitions start uniform
        self.l_t = nn.Parameter(torch.zeros([n_bins]) - float(np.log(n_bins)))
        self.l_u = nn.Parameter(torch.zeros([n_bins]) - float(np.log(n_bins)))

    def forward(
        self, t=None, u=None, normalized=True, t_min=0, t_max=1, l_t=None, l_u=None
    ):
        """Map t -> u (forward warp) or u -> t (inverse warp).

        Exactly one of t, u should be given; both have shape [batch, seq_len].
        """
        bsz = t.shape[0] if t is not None else u.shape[0]
        seq_len = t.shape[1] if t is not None else u.shape[1]
        if l_t is None:
            l_t = self.l_t
        l_t = l_t.expand(bsz, seq_len, -1)
        if l_u is None:
            l_u = self.l_u
        l_u = l_u.expand(bsz, seq_len, -1)
        # apply softmax over the logits to get a partition of the unit interval
        w_t = F.softmax(l_t, dim=-1)
        # add a small constant to avoid numerical issues / enforce a minimum bin width
        w_t = w_t + 1e-3
        # renormalize back to the unit range
        w_t = w_t / w_t.sum(-1)[:, :, None]
        # for the output logits we use exp instead of softmax (so the scale can fit raw loss values)
        w_u = l_u.exp()
        # same small constant as above; if we then normalize, this is effectively a softmax
        w_u = w_u + 1e-3
        if normalized:
            w_u = w_u / w_u.sum(-1)[:, :, None]
        # the first bucket edge is zero; the rest are the cumulative sums of the
        # bin widths, so e_t[..., i] and e_t[..., i + 1] are the edges of bucket i
        e_t = torch.cat(
            [
                torch.zeros(list(w_t.shape[:-1]) + [1], device=w_t.device),
                w_t.cumsum(dim=-1),
            ],
            dim=-1,
        )
        e_u = torch.cat(
            [
                torch.zeros(list(w_u.shape[:-1]) + [1], device=w_u.device),
                w_u.cumsum(dim=-1),
            ],
            dim=-1,
        )
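        # e.g. with n_bins = 4 and uniform weights, each row of e_t is
        # [0.0, 0.25, 0.5, 0.75, 1.0] -- n_bins + 1 edges per token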
        # if we have t, we want to map it to u (= cross-entropy values)
        if t is not None:
            # remember the original shape so we can restore it before returning
            original_shape = t.shape
            # rescale t to the unit range
            t_prime = (t - t_min) / (t_max - t_min)
            # find the bucket t lies in
            t_idx = (e_t <= t_prime[:, :, None]).long().sum(dim=-1) - 1
            # the clamp does fire: at t_prime == 1 every one of the n_bins + 1
            # edges satisfies e_t <= t_prime, which would index past the last bucket
            t_idx = t_idx.clamp(min=0, max=w_t.shape[-1] - 1)
            # the actual warping operation: find how far through the e_t bucket
            # we are, and use that fraction to interpolate between the edges of
            # the corresponding e_u bucket
            u = torch.gather(e_u, -1, t_idx[:, :, None]).squeeze(-1) + (
                torch.gather(e_u, -1, t_idx[:, :, None] + 1).squeeze(-1)
                - torch.gather(e_u, -1, t_idx[:, :, None]).squeeze(-1)
            ) * (
                (t_prime - torch.gather(e_t, -1, t_idx[:, :, None]).squeeze(-1))
                / (
                    torch.gather(e_t, -1, t_idx[:, :, None] + 1).squeeze(-1)
                    - torch.gather(e_t, -1, t_idx[:, :, None]).squeeze(-1)
                )
            )
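            # equivalently, with i = t_idx:
            # u = e_u[i] + (e_u[i+1] - e_u[i]) * (t' - e_t[i]) / (e_t[i+1] - e_t[i])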
            # restore the original shape
            return u.view(original_shape)
        elif u is not None:
            # inverse direction: map u values (e.g. uniform samples) back to the
            # timesteps at which the learned cross-entropy curve reaches them, so
            # that uniform steps in u correspond to linear progress in cross-entropy
            original_shape = u.shape
            # find bucket indices as above; here the clamp fires when u lies on or
            # beyond the last edge (e.g. u == 1, or u > e_u[..., -1] when
            # normalized=False)
            u_idx = (e_u <= u[:, :, None]).long().sum(dim=-1) - 1
            u_idx = u_idx.clamp(min=0, max=w_u.shape[-1] - 1)
            # again, linearly interpolate
            t_prime = torch.gather(e_t, -1, u_idx[:, :, None]).squeeze(-1) + (
                torch.gather(e_t, -1, u_idx[:, :, None] + 1).squeeze(-1)
                - torch.gather(e_t, -1, u_idx[:, :, None]).squeeze(-1)
            ) * (
                (u - torch.gather(e_u, -1, u_idx[:, :, None]).squeeze(-1))
                / (
                    torch.gather(e_u, -1, u_idx[:, :, None] + 1).squeeze(-1)
                    - torch.gather(e_u, -1, u_idx[:, :, None]).squeeze(-1)
                )
            )
            # rescale t_prime from the unit interval back to [t_min, t_max]
            t = t_prime * (t_max - t_min) + t_min
            return t.view(original_shape)
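

# A minimal usage sketch (not from the original repo): warp timesteps through the
# learned CDF and invert the mapping. With normalized=True the two branches are
# piecewise-linear inverses of each other, so the round trip should recover t up
# to floating-point error.
if __name__ == "__main__":
    torch.manual_seed(0)
    cdf = LossCDF(n_bins=100)
    t = torch.rand(2, 8)   # [batch, seq_len] timesteps in [0, 1)
    u = cdf(t=t)           # forward warp: timesteps -> normalized CE values
    t_rec = cdf(u=u)       # inverse warp: back to timesteps
    print(torch.allclose(t, t_rec, atol=1e-5))  # expected: True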