Spaces:
Running
Running
import torch | |
def distribute_into_histogram(loss: torch.Tensor,
                              t: torch.Tensor,
                              num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
    """Accumulate per-element losses into bins indexed by ``t``.

    Each element of ``t`` is mapped to the bin ``int(t * num_bins)``; the
    matching element of ``loss`` is added to that bin's running sum and the
    bin's sample counter is incremented. Both inputs are detached and
    flattened first, so no gradient flows through the result.

    Args:
        loss: Loss values, any shape. Cast to float32 for accumulation.
        t: Bin coordinates, one per loss element. Values are presumably
            normalized to ``[0, 1]`` (the bin index is ``t * num_bins``);
            anything falling outside the valid bins is clamped into the
            first or last bin instead of raising.
        num_bins: Number of histogram bins.

    Returns:
        ``(hist, count)``: two float32 tensors of shape ``(num_bins,)`` on
        ``loss``'s device holding the per-bin loss sum and the per-bin
        number of samples.
    """
    # The accumulators below are float32; scatter_add_ requires the source
    # to share the destination dtype, so half/bfloat16/double losses are
    # cast up front rather than failing inside the scatter.
    values = loss.detach().flatten().to(torch.float32)

    # t == 1.0 would map to index ``num_bins`` (one past the end) and
    # negative t to a negative index; clamp both into the valid range so
    # boundary samples land in the edge bins.
    bin_idx = (t.detach().flatten() * num_bins).long()
    bin_idx = bin_idx.clamp(min=0, max=num_bins - 1)

    hist = torch.zeros(num_bins, device=values.device)
    count = torch.zeros(num_bins, device=values.device)
    hist.scatter_add_(0, bin_idx, values)
    count.scatter_add_(0, bin_idx, torch.ones_like(values))
    return hist, count