Spaces:
Running
Running
File size: 510 Bytes
1fd4e9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch
def distribute_into_histogram(loss: torch.Tensor,
                              t: torch.Tensor,
                              num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
    """Accumulate ``loss`` values into ``num_bins`` equal-width bins keyed by ``t``.

    Each element of ``loss`` is added to bin ``floor(t * num_bins)`` of the
    matching element of ``t``. ``t`` is presumably normalized to ``[0, 1]``
    (TODO confirm with callers); keys outside that range, including
    ``t == 1.0`` exactly, are clamped into the first/last bin instead of
    indexing past the end of the histogram.

    Args:
        loss: Values to accumulate. Detached and flattened, so any shape works.
        t: Bin keys, one per element of ``loss`` (same number of elements).
            Detached, flattened and moved to ``loss``'s device.
        num_bins: Number of equal-width bins; must be >= 1.

    Returns:
        ``(hist, count)``: the per-bin sum of ``loss`` and the per-bin number
        of contributing elements, both of shape ``(num_bins,)`` with
        ``loss``'s dtype and device. Per-bin mean is ``hist / count`` (empty
        bins have ``count == 0``).

    Raises:
        ValueError: If ``num_bins < 1`` or ``loss`` and ``t`` have a different
            number of elements.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")
    loss = loss.detach().flatten()
    t = t.detach().flatten().to(loss.device)
    if t.numel() != loss.numel():
        # scatter_add_ would otherwise silently use only a prefix of `loss`
        # (when t is shorter) or fail (when t is longer).
        raise ValueError(
            f"loss and t must have the same number of elements, "
            f"got {loss.numel()} and {t.numel()}"
        )
    # Truncate to a bin index, then clamp so t == 1.0 (-> num_bins) and any
    # out-of-range keys land in the edge bins instead of overflowing.
    idx = (t * num_bins).long().clamp_(0, num_bins - 1)
    # Match loss's dtype: scatter_add_ requires self and src dtypes to agree.
    hist = torch.zeros(num_bins, dtype=loss.dtype, device=loss.device)
    count = torch.zeros(num_bins, dtype=loss.dtype, device=loss.device)
    hist.scatter_add_(0, idx, loss)
    count.scatter_add_(0, idx, torch.ones_like(loss))
    return hist, count
|