from typing import Tuple
import torch


def gumbel(shape: torch.Size, dtype: torch.dtype,
           device: torch.device) -> torch.Tensor:
    """Sample Gumbel random values with the given shape and float dtype.

    The values are distributed according to the probability density function:

    .. math::
       f(x) = e^{-(x + e^{-x})}

    Args:
      shape (torch.Size): shape of the sampled tensor
      dtype (torch.dtype): float dtype of the sampled values
      device (torch.device): device on which to allocate the tensor

    Returns:
       A random tensor with the specified shape, dtype and device.
    """
    # see https://www.cnblogs.com/initial-h/p/9468974.html for more details
    # clamp the uniform sample to [tiny, 1) so the inner log never sees 0
    return -torch.log(-torch.log(
        torch.empty(shape, dtype=dtype, device=device).uniform_(
            torch.finfo(dtype).tiny, 1.)))

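# Illustrative note (not from the original file): the helper above enables the
# Gumbel-max trick, where adding Gumbel noise to logits and taking the argmax
# draws an exact sample from Categorical(softmax(logits)), e.g.:
#
#   logits = torch.randn(4, 10)
#   g = gumbel(logits.size(), logits.dtype, logits.device)
#   sample = (logits + g).argmax(dim=-1)  # one categorical draw per row
#
# Dividing (logits + g) by a temperature and taking a softmax instead of the
# argmax gives the differentiable relaxation used in forward() below.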

class Wav2vecGumbelVectorQuantizer(torch.nn.Module):

    def __init__(self,
                 features_dim: int = 256,
                 num_codebooks: int = 2,
                 num_embeddings: int = 8192,
                 embedding_dim: int = 16,
                 hard: bool = False) -> None:

        super().__init__()

        self.num_groups = num_codebooks
        self.num_codevectors_per_group = num_embeddings
        # codebooks, stored flat as [1, num_codebooks * num_embeddings,
        # embedding_dim // num_codebooks]
        # cf. the [C, G, D] means in quantize_vector in bestrq_model.py
        assert embedding_dim % num_codebooks == 0
        self.embeddings = torch.nn.parameter.Parameter(
            torch.empty(1, num_codebooks * num_embeddings,
                        embedding_dim // num_codebooks),
            requires_grad=True,
        )
        torch.nn.init.uniform_(self.embeddings)

        self.weight_proj = torch.nn.Linear(features_dim,
                                           num_codebooks * num_embeddings)
        # if hard, use non-differentiable argmax instead of Gumbel-softmax
        self.hard = hard
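
    # Illustrative note (not from the original file): with the defaults above
    # (num_codebooks=2, num_embeddings=8192, embedding_dim=16), embeddings has
    # shape [1, 16384, 8]: 2 groups of 8192 codevectors, 8 dims each; the
    # quantized output concatenates one selected 8-dim vector per group.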

    @staticmethod
    def _compute_perplexity(probs, mask=None):
        """exp(entropy) of the marginal code distribution, summed over
        groups; probs has shape [B*T, num_groups, num_embeddings]."""
        if mask is not None:
            mask_extended = torch.broadcast_to(mask.flatten()[:, None, None],
                                               probs.shape)
            probs = torch.where(mask_extended.to(torch.bool), probs,
                                torch.zeros_like(probs))
            marginal_probs = probs.sum(dim=0) / mask.sum()
        else:
            marginal_probs = probs.mean(dim=0)

        perplexity = torch.exp(-torch.sum(
            marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
        return perplexity
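
    # Illustrative note (not from the original file): with G groups of V codes
    # each, a uniform marginal gives the maximal perplexity G * V, while a
    # collapsed codebook (one code per group) gives G, so this value directly
    # measures codebook usage; e.g. G=2, V=8192 spans [2, 16384].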

    def forward(
        self,
        input: torch.Tensor,
        input_mask: torch.Tensor,
        temperature: float = 1.
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize features of shape [B, T, features_dim].

        Returns codevectors [B, T, embedding_dim], codebook perplexity
        (a scalar), and target codebook indices [B, T, num_codebooks].
        """
        b, t, _ = input.size()

        hidden = self.weight_proj(input)
        hidden = hidden.reshape(b * t * self.num_groups, -1)
        if not self.hard:
            # sample codevector probs via Gumbel-softmax in a differentiable way
            gumbels = gumbel(hidden.size(), hidden.dtype, hidden.device)
            codevector_probs = torch.nn.functional.softmax(
                (hidden + gumbels) / temperature, dim=-1)
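            # note (not from the original file): lower temperature yields
            # sharper, closer-to-one-hot probs; wav2vec 2.0-style recipes
            # typically anneal it downward over training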

            # compute perplexity
            codevector_soft_dist = torch.nn.functional.softmax(
                hidden.reshape(b * t, self.num_groups, -1),
                dim=-1,
            )  # [B*T, num_codebooks, num_embeddings]
            perplexity = self._compute_perplexity(codevector_soft_dist,
                                                  input_mask)
        else:
            # take the argmax in a non-differentiable way and
            # compute the hard codevector distribution (one hot)
            codevector_idx = hidden.argmax(dim=-1)
            codevector_probs = torch.nn.functional.one_hot(
                codevector_idx, hidden.shape[-1]).to(hidden.dtype)
            codevector_probs = codevector_probs.reshape(
                b * t, self.num_groups, -1)
            perplexity = self._compute_perplexity(codevector_probs, input_mask)

        targets_idx = codevector_probs.argmax(-1).reshape(b, t, -1)
        codevector_probs = codevector_probs.reshape(b * t, -1)
        # use probs as weights to retrieve codevectors:
        # [B*T, G*V, 1] * [1, G*V, D] -> [B*T, G*V, D]
        codevectors_per_group = codevector_probs.unsqueeze(
            -1) * self.embeddings
        codevectors = codevectors_per_group.reshape(
            b * t, self.num_groups, self.num_codevectors_per_group, -1)
        # sum each group's weighted codevectors, then concatenate the groups
        codevectors = codevectors.sum(-2).reshape(b, t, -1)
        return codevectors, perplexity, targets_idx
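

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original file; batch/time sizes
    # below are illustrative only.
    quantizer = Wav2vecGumbelVectorQuantizer()
    x = torch.randn(2, 50, 256)  # [B, T, features_dim]
    mask = torch.ones(2, 50)  # all frames valid
    codevectors, perplexity, targets_idx = quantizer(x, mask, temperature=2.0)
    print(codevectors.shape)  # torch.Size([2, 50, 16])
    print(targets_idx.shape)  # torch.Size([2, 50, 2])
    print(perplexity.item())  # codebook-usage perplexity in [2, 16384]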