"""This file contains the definition of some utility functions for the quantizer."""

from typing import Tuple
import torch


def clamp_log(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Clamps the input tensor from below and computes its elementwise log.

    Clamping to `eps` keeps the log finite when `x` contains zeros.

    Args:
        x (torch.Tensor): The input tensor.
        eps (float): The epsilon value serving as the lower bound.

    Returns:
        torch.Tensor: The log of the clamped input tensor.
    """
    return torch.log(torch.clamp(x, min=eps))
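
# Illustrative sanity check (values assumed, not part of the original file):
# with the default eps of 1e-5, zeros are clamped before the log, so the
# result stays finite instead of -inf.
#
# >>> clamp_log(torch.tensor([0.0, 1.0]))
# tensor([-11.5129,   0.0000])   # approximately, since log(1e-5) ≈ -11.5129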


def entropy_loss_fn(
    affinity: torch.Tensor,
    temperature: float,
    entropy_gamma: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes the per-sample and average (codebook-usage) entropies.

    Args:
        affinity (torch.Tensor): The affinity matrix; the last dimension
            ranges over codebook entries.
        temperature (float): The softmax temperature applied to the
            affinities.
        entropy_gamma (float): Scaling factor applied to the average
            entropy term.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The per-sample entropy and the
        gamma-scaled average entropy.
    """
    # Collapse all leading dimensions so each row is one sample's affinities.
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # Avoid the in-place `/=` here: `reshape` may return a view, so dividing
    # in place would silently mutate the caller's `affinity` tensor.
    flat_affinity = flat_affinity / temperature

    probability = flat_affinity.softmax(dim=-1)
    average_probability = torch.mean(probability, dim=0)

    # Mean entropy of the individual softmax distributions.
    per_sample_entropy = -torch.mean(
        torch.sum(probability * clamp_log(probability), dim=-1)
    )
    # Entropy of the batch-averaged distribution (codebook usage).
    avg_entropy = torch.sum(-average_probability * clamp_log(average_probability))

    return per_sample_entropy, avg_entropy * entropy_gamma
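
# ---------------------------------------------------------------------------
# Minimal usage sketch (not from the original repo): the shapes, seed, and
# the way the two terms are combined into a single loss below are
# illustrative assumptions. A common entropy objective encourages confident
# per-sample assignments (low per-sample entropy) together with uniform
# codebook usage (high average entropy).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical affinities: 8 samples scored against a 16-entry codebook.
    dummy_affinity = torch.randn(8, 16)
    sample_entropy, codebook_entropy = entropy_loss_fn(
        dummy_affinity, temperature=0.1, entropy_gamma=1.0
    )
    entropy_loss = sample_entropy - codebook_entropy
    print(f"per-sample entropy:     {sample_entropy.item():.4f}")
    print(f"codebook-usage entropy: {codebook_entropy.item():.4f}")
    print(f"entropy loss:           {entropy_loss.item():.4f}")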