import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


# from https://github.com/igul222/plaid/blob/main/train_cdcd.py
# CDCD loss CDF (learned time warping). The same warp is shared by every token in the sequence.
class LossCDF(nn.Module):
    def __init__(self, n_bins):
        super().__init__()
        # Bucket logits for the t-axis and u-axis partitions of the unit interval.
        self.l_t = nn.Parameter(torch.zeros([n_bins]) - float(np.log(n_bins)))
        self.l_u = nn.Parameter(torch.zeros([n_bins]) - float(np.log(n_bins)))
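        # Note: softmax is shift-invariant, so zero logits already give equal-width t-bins;
        # the -log(n_bins) offset mainly matters for the exp() path below, making exp(l_u)
        # sum to ~1 at init so the unnormalized w_u starts on the same scale as the
        # normalized case. With normalized=True the warp starts as (roughly) the identity.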

    def forward(
        self, t=None, u=None, normalized=True, t_min=0, t_max=1, l_t=None, l_u=None
    ):
        """Piecewise-linear CDF warp, shared across the whole sequence.

        Pass t (shape [bsz, seq_len], i.e. [n, l]) to map timesteps to u values
        (cross-entropy values under the learned CDF), or pass u to map back to
        warped timesteps in [t_min, t_max]. Exactly one of t / u should be given.
        """
        bsz = t.shape[0] if t is not None else u.shape[0]
        seq_len = t.shape[1] if t is not None else u.shape[1]
        if l_t is None:
            l_t = self.l_t
        l_t = l_t.expand(bsz, seq_len, -1)
        if l_u is None:
            l_u = self.l_u
        l_u = l_u.expand(bsz, seq_len, -1)
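        # l_t / l_u are now [bsz, seq_len, n_bins]: broadcast views of the shared bin logits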
        # apply softmax over the logits to get a partition of the unit interval
        w_t = F.softmax(l_t, dim=-1)
        # add a small constant to avoid numerical issues / enforce a minimum bin size
        w_t = w_t + 1e-3
        # renormalize to the unit range
        w_t = w_t / w_t.sum(-1)[:, :, None]
        # for the u-axis we use exp instead of softmax, so the bins can fit
        # unnormalized loss values
        w_u = l_u.exp()
        # minimum bin size again; if we normalize below, exp + renormalize is
        # effectively a softmax
        w_u = w_u + 1e-3
        if normalized:
            w_u = w_u / w_u.sum(-1)[:, :, None]
        # Bucket edges: the first edge is zero, the rest are the cumsum of the bin
        # widths, so e_t[..., i] and e_t[..., i + 1] bound bucket i.
        e_t = torch.cat(
            [
                torch.zeros(list(w_t.shape[:-1]) + [1], device=w_t.device),
                w_t.cumsum(dim=-1),
            ],
            dim=-1,
        )
        e_u = torch.cat(
            [
                torch.zeros(list(w_u.shape[:-1]) + [1], device=w_u.device),
                w_u.cumsum(dim=-1),
            ],
            dim=-1,
        )
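        # e_t and e_u have shape [bsz, seq_len, n_bins + 1]; e_t runs from 0 to 1, and so
        # does e_u when normalized (otherwise it runs up to roughly the sum of exp(l_u))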
        # if we have t, map it to u (= cross-entropy values)
        if t is not None:
            # remember the original shape so we can restore it at the end
            original_shape = t.shape
            # rescale t to the unit range
            t_prime = (t - t_min) / (t_max - t_min)
            # find the bucket t lies in
            t_idx = (e_t <= t_prime[:, :, None]).long().sum(dim=-1) - 1
            # clamp: when t_prime hits the last edge (e.g. t == t_max) the count above
            # reaches n_bins, which would index past the final bucket
            t_idx = t_idx.clamp(min=0, max=w_t.shape[-1] - 1)
            # The actual warping operation: find how far through the e_t bucket we are,
            # and use that fraction to interpolate between the edges of the matching
            # e_u bucket.
            u = torch.gather(e_u, -1, t_idx[:, :, None]).squeeze(-1) + (
                torch.gather(e_u, -1, t_idx[:, :, None] + 1).squeeze(-1)
                - torch.gather(e_u, -1, t_idx[:, :, None]).squeeze(-1)
            ) * (
                (t_prime - torch.gather(e_t, -1, t_idx[:, :, None]).squeeze(-1))
                / (
                    torch.gather(e_t, -1, t_idx[:, :, None] + 1).squeeze(-1)
                    - torch.gather(e_t, -1, t_idx[:, :, None]).squeeze(-1)
                )
            )
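            # i.e. u = e_u[i] + (e_u[i+1] - e_u[i]) * (t' - e_t[i]) / (e_t[i+1] - e_t[i])
            # with i = t_idx: the piecewise-linear CDF evaluated at t'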
            # reshape back to the original shape
            return u.view(original_shape)
        elif u is not None:
            # inverse direction: we have unwarped values u (e.g. uniform samples in [0, 1])
            # and want the warped timesteps that, under the learned CDF, correspond to a
            # linear reduction in cross-entropy
            original_shape = u.shape
            # find the bucket u lies in, as above; the clamp again catches u at or above
            # the last edge
            u_idx = (e_u <= u[:, :, None]).long().sum(dim=-1) - 1
            u_idx = u_idx.clamp(min=0, max=w_u.shape[-1] - 1)
            # again, linearly interpolate
            t_prime = torch.gather(e_t, -1, u_idx[:, :, None]).squeeze(-1) + (
                torch.gather(e_t, -1, u_idx[:, :, None] + 1).squeeze(-1)
                - torch.gather(e_t, -1, u_idx[:, :, None]).squeeze(-1)
            ) * (
                (u - torch.gather(e_u, -1, u_idx[:, :, None]).squeeze(-1))
                / (
                    torch.gather(e_u, -1, u_idx[:, :, None] + 1).squeeze(-1)
                    - torch.gather(e_u, -1, u_idx[:, :, None]).squeeze(-1)
                )
            )
            # t_prime lives in [0, 1]; rescale it back to [t_min, t_max]
            t = t_prime * (t_max - t_min) + t_min
            return t.view(original_shape)
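

# Minimal usage sketch (toy shapes, not part of the original training script): warp
# timesteps to cross-entropy targets, then map them back through the inverse warp.
# Because the logits are zero-initialized, the round trip is (roughly) the identity.
if __name__ == "__main__":
    torch.manual_seed(0)
    cdf = LossCDF(n_bins=100)
    t = torch.rand(2, 8)  # [bsz, seq_len] timesteps in [0, 1]
    u = cdf(t=t)          # t -> u: (normalized) cross-entropy targets
    t_back = cdf(u=u)     # u -> t: inverse warp
    print(u.shape, torch.allclose(t, t_back, atol=1e-4))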