from typing import Tuple
import torch
def gumbel(shape: torch.Size, dtype: torch.dtype,
           device: torch.device) -> torch.Tensor:
    """Sample standard Gumbel random values with the given shape and float dtype.

    The values are distributed according to the probability density function:

    .. math::
        f(x) = e^{-(x + e^{-x})}

    Args:
        shape (torch.Size): shape of the sampled tensor
        dtype (torch.dtype): floating-point dtype of the sampled values
        device (torch.device): device on which to sample

    Returns:
        A random tensor with the specified shape, dtype and device.
    """
    # Inverse-transform sampling: if U ~ Uniform(0, 1), then -log(-log(U))
    # follows a standard Gumbel distribution; the lower bound
    # torch.finfo(dtype).tiny keeps the inner log away from zero.
    # See https://www.cnblogs.com/initial-h/p/9468974.html for more details.
    return -torch.log(-torch.log(
        torch.empty(shape, dtype=dtype, device=device).uniform_(
            torch.finfo(dtype).tiny, 1.)))
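# The quantizer below uses this helper for the Gumbel-softmax trick:
# softmax((logits + gumbel_noise) / temperature) is a differentiable
# approximation of drawing a sample from the categorical distribution defined
# by the logits, and it collapses towards a one-hot sample as the temperature
# approaches zero.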
class Wav2vecGumbelVectorQuantizer(torch.nn.Module):
def __init__(self,
features_dim: int = 256,
num_codebooks: int = 2,
num_embeddings: int = 8192,
embedding_dim: int = 16,
hard: bool = False) -> None:
super().__init__()
self.num_groups = num_codebooks
self.num_codevectors_per_group = num_embeddings
# codebooks
# means [C, G, D] see quantize_vector in bestrq_model.py
        assert embedding_dim % num_codebooks == 0
self.embeddings = torch.nn.parameter.Parameter(
torch.empty(1, num_codebooks * num_embeddings,
embedding_dim // num_codebooks),
requires_grad=True,
)
torch.nn.init.uniform_(self.embeddings)
self.weight_proj = torch.nn.Linear(features_dim,
num_codebooks * num_embeddings)
        # use gumbel softmax (differentiable) or argmax (non-differentiable)
self.hard = hard
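    # Shape summary (derived from __init__ above):
    #   self.embeddings  : [1, num_codebooks * num_embeddings,
    #                       embedding_dim // num_codebooks]
    #   self.weight_proj : features_dim -> num_codebooks * num_embeddings,
    #                      i.e. one logit per (codebook, codevector) pair.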
@staticmethod
    def _compute_perplexity(probs, mask=None):
        """Perplexity of the marginal codevector distribution over valid frames.

        Args:
            probs: [B*T, num_codebooks, num_embeddings] codevector probabilities.
            mask: optional [B, T] mask of valid frames; masked frames are
                excluded from the marginal distribution.
        """
if mask is not None:
mask_extended = torch.broadcast_to(mask.flatten()[:, None, None],
probs.shape)
probs = torch.where(mask_extended.to(torch.bool), probs,
torch.zeros_like(probs))
marginal_probs = probs.sum(dim=0) / mask.sum()
else:
marginal_probs = probs.mean(dim=0)
perplexity = torch.exp(-torch.sum(
marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity
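    # Interpretation of _compute_perplexity above: with the default
    # num_codebooks=2 and num_embeddings=8192, perfectly uniform codevector
    # usage gives exp(entropy) = 8192 per codebook (16384 summed over
    # codebooks), while a collapsed codebook that always selects the same
    # entry gives 1 per codebook (2 summed), so the value serves as a
    # codebook-diversity signal.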
def forward(
self,
input: torch.Tensor,
input_mask: torch.Tensor,
temperature: float = 1.
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
b, t, _ = input.size()
hidden = self.weight_proj(input)
hidden = hidden.reshape(b * t * self.num_groups, -1)
if not self.hard:
            # sample codevector probs via gumbel-softmax in a differentiable way
gumbels = gumbel(hidden.size(), hidden.dtype, hidden.device)
codevector_probs = torch.nn.functional.softmax(
(hidden + gumbels) / temperature, dim=-1)
# compute perplexity
codevector_soft_dist = torch.nn.functional.softmax(
hidden.reshape(b * t, self.num_groups, -1),
dim=-1,
) # [B*T, num_codebooks, num_embeddings]
perplexity = self._compute_perplexity(codevector_soft_dist,
input_mask)
else:
            # take the argmax in a non-differentiable way and compute
            # the hard codevector distribution (one-hot)
            codevector_idx = hidden.argmax(dim=-1)
codevector_probs = torch.nn.functional.one_hot(
codevector_idx, hidden.shape[-1]) * 1.0
codevector_probs = codevector_probs.reshape(
b * t, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, input_mask)
targets_idx = codevector_probs.argmax(-1).reshape(b, t, -1)
codevector_probs = codevector_probs.reshape(b * t, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(
-1) * self.embeddings
codevectors = codevectors_per_group.reshape(
b * t, self.num_groups, self.num_codevectors_per_group, -1)
codevectors = codevectors.sum(-2).reshape(b, t, -1)
return codevectors, perplexity, targets_idx
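

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only: the batch size, frame count and
    # temperature below are assumptions, not values from a training recipe).
    torch.manual_seed(0)
    quantizer = Wav2vecGumbelVectorQuantizer(features_dim=256,
                                              num_codebooks=2,
                                              num_embeddings=8192,
                                              embedding_dim=16,
                                              hard=False)
    feats = torch.randn(2, 10, 256)  # [B, T, features_dim]
    mask = torch.ones(2, 10)         # all frames valid
    codevectors, perplexity, targets_idx = quantizer(feats, mask,
                                                     temperature=2.0)
    print(codevectors.shape)  # torch.Size([2, 10, 16]) -> [B, T, embedding_dim]
    print(targets_idx.shape)  # torch.Size([2, 10, 2])  -> one index per codebook
    print(perplexity.item())  # scalar codebook-diversity measure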