File size: 1,401 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
"""This file contains the definition of some utility functions for the quantizer."""
from typing import Tuple
import torch
def clamp_log(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Floors the input at ``eps`` and returns its element-wise natural log.

    Flooring keeps zero (or negative) entries away from ``log(0) = -inf``.

    Args:
        x -> torch.Tensor: The input tensor.
        eps -> float: Lower bound applied to ``x`` before the log is taken.

    Returns:
        torch.Tensor: ``log(max(x, eps))``, computed element-wise.
    """
    floored = x.clamp(min=eps)
    return floored.log()
def entropy_loss_fn(
    affinity: torch.Tensor,
    temperature: float,
    entropy_gamma: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes the per-sample and average entropy of softmaxed affinities.

    All leading dimensions of ``affinity`` are flattened, so every row is one
    sample's scores over the last dimension. Scores are divided by
    ``temperature`` and softmax-normalised over the last dimension. The input
    tensor is not modified.

    Args:
        affinity -> torch.Tensor: The affinity matrix; the last dimension
            indexes the categories the softmax is taken over.
        temperature -> float: Divisor applied to the scores before softmax.
        entropy_gamma -> float: Scale applied to the average entropy only.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Scalar tensors
        ``(per_sample_entropy, entropy_gamma * avg_entropy)``, where
        ``per_sample_entropy`` is the mean over rows of each row's entropy and
        ``avg_entropy`` is the entropy of the row-averaged distribution.
    """
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # Divide out of place: ``reshape`` returns a view of ``affinity`` when it
    # can, so an in-place ``/=`` would silently rescale the caller's tensor
    # (and raises for leaf tensors requiring grad or for integer dtypes).
    flat_affinity = flat_affinity / temperature
    probability = flat_affinity.softmax(dim=-1)
    average_probability = torch.mean(probability, dim=0)
    # Mean over rows of each row's entropy H(p_i) = -sum_k p_ik * log p_ik.
    per_sample_entropy = -1 * torch.mean(
        torch.sum(probability * clamp_log(probability), dim=-1)
    )
    # Entropy of the distribution averaged over all rows.
    avg_entropy = torch.sum(-1 * average_probability * clamp_log(average_probability))
    return (per_sample_entropy, avg_entropy * entropy_gamma)
|