File size: 510 Bytes
1fd4e9c
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch


def distribute_into_histogram(loss: torch.Tensor,
                              t: torch.Tensor,
                              num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
    """Accumulate per-element losses into ``num_bins`` equal-width bins over ``t``.

    Each element of ``loss`` is added to the bin chosen by the matching element
    of ``t``. The bin index is ``floor(t * num_bins)`` clamped to
    ``[0, num_bins - 1]``. As a result, ``t == 1`` lands in the last bin (the
    right edge is inclusive), and out-of-range values fall into the first or
    last bin instead of raising an indexing error.

    Args:
        loss: Values to accumulate. Any shape; it is flattened, and gradients
            are not tracked.
        t: Bin coordinate for each element of ``loss``. Any shape, but it must
            have the same number of elements as ``loss``. Values are expected
            in ``[0, 1]``. It is moved to ``loss.device`` if needed.
        num_bins: Number of bins. Must be positive.

    Returns:
        ``(hist, count)``, both of shape ``(num_bins,)`` with the dtype and
        device of ``loss``. ``hist[i]`` is the sum of the losses in bin ``i``
        and ``count[i]`` is the number of elements in bin ``i``.

    Raises:
        ValueError: If ``num_bins < 1`` or if ``loss`` and ``t`` have
            different numbers of elements.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be positive, got {num_bins}")
    if loss.numel() != t.numel():
        # scatter_add_ would otherwise silently ignore trailing losses
        # (or fail cryptically when t is longer).
        raise ValueError(
            f"loss and t must have the same number of elements, "
            f"got {loss.numel()} and {t.numel()}"
        )

    loss = loss.detach().flatten()
    t = t.detach().flatten().to(loss.device)

    # Clamp so t == 1.0 (and any out-of-range t) maps to a valid bin instead of
    # index num_bins, which would make scatter_add_ raise.
    bins = (t * num_bins).long().clamp_(0, num_bins - 1)

    # scatter_add_ requires self and src to share a dtype, so the accumulators
    # follow loss.dtype rather than the float32 default.
    hist = torch.zeros(num_bins, device=loss.device, dtype=loss.dtype)
    count = torch.zeros_like(hist)
    hist.scatter_add_(0, bins, loss)
    count.scatter_add_(0, bins, torch.ones_like(loss))
    return hist, count