from typing import List
import math  # only needed if the commented-out bandwidth-based n_q formula below is restored

import torch
import torchaudio
from torch import nn

from inspiremusic.wavtokenizer.encoder.modules import SEANetEncoder, SEANetDecoder
from inspiremusic.wavtokenizer.encoder import EncodecModel
from inspiremusic.wavtokenizer.encoder.quantization import ResidualVectorQuantizer


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor, clamping values below
    `clip_val` to avoid taking the log of zero or near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value the input is clamped to. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the clamped input tensor.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    """Symmetric log transform: sign(x) * log(1 + |x|)."""
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    """Inverse of `symlog`: sign(x) * (exp(|x|) - 1)."""
    return torch.sign(x) * (torch.exp(x.abs()) - 1)
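

# Quick hedged sanity check (illustrative, not part of the original module):
# symexp should invert symlog up to floating-point error, e.g.
#   x = torch.linspace(-10.0, 10.0, steps=5)
#   torch.allclose(symexp(symlog(x)), x, rtol=1e-4)   # -> True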


class FeatureExtractor(nn.Module):
    """Base class for feature extractors."""

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Extract features from the given audio.

        Args:
            audio (Tensor): Input audio waveform.

        Returns:
            Tensor: Extracted features of shape (B, C, L), where B is the batch size,
                    C denotes output features, and L is the sequence length.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class MelSpectrogramFeatures(FeatureExtractor):
    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=padding == "center",
            power=1,
        )

    def forward(self, audio, **kwargs):
        if self.padding == "same":
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        mel = self.mel_spec(audio)
        features = safe_log(mel)
        return features
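

# Usage sketch (illustrative assumption, not part of the original file): with the
# default 24 kHz / n_fft=1024 / hop_length=256 settings and "center" padding, a
# (B, T) waveform maps to a (B, n_mels, T // 256 + 1) log-mel tensor, e.g.
#   MelSpectrogramFeatures()(torch.randn(2, 24000)).shape   # -> (2, 100, 94)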


class EncodecFeatures(FeatureExtractor):
    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
        train_codebooks: bool = False,
        num_quantizers: int = 1,
        dowmsamples: List[int] = [6, 5, 5, 4],  # encoder downsampling ratios (parameter name kept unchanged)
        vq_bins: int = 16384,
        vq_kmeans: int = 800,
    ):
        super().__init__()

        self.frame_rate = 25  # passed to the quantizer below; not otherwise used
        # n_q could alternatively be derived from the target bandwidth:
        # n_q = int(bandwidths[-1] * 1000 / (math.log2(2048) * self.frame_rate))
        n_q = num_quantizers  # number of residual quantizers, taken directly from the constructor argument
        # Analysis encoder: downsampling ratios come from `dowmsamples`.
        encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                true_skip=False, compress=2)
        # Synthesis decoder: upsampling ratios are fixed at [8, 5, 4, 2].
        decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                true_skip=False, compress=2)
        # Residual VQ: `n_q` codebooks with `vq_bins` entries each, k-means-initialised codebooks.
        quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
                                            decay=0.99, kmeans_init=True)

        if encodec_model == "encodec_24khz":
            self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
                                        target_bandwidths=bandwidths, sample_rate=24000, channels=1)
        else:
            raise ValueError(
                f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
            )
        for param in self.encodec.parameters():
            param.requires_grad = True
        # self.num_q = n_q
        # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
        # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
        self.bandwidths = bandwidths

    # @torch.no_grad()
    # def get_encodec_codes(self, audio):
    #     audio = audio.unsqueeze(1)
    #     emb = self.encodec.encoder(audio)
    #     codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
    #     return codes

    def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)                  # (B, T) -> (B, 1, T), e.g. (16, 24000) -> (16, 1, 24000)

        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized                 # (B, dimension, frames)
        codes = q_res.codes                         # (n_q, B, frames); original annotation: (8, 16, 75)
        commit_loss = q_res.penalty                 # VQ commitment loss

        return quantized, codes, commit_loss

        # codes = self.get_encodec_codes(audio)
        # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
        # # with offsets given by the number of bins, and finally summed in a vectorized operation.
        # offsets = torch.arange(
        #     0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
        # )
        # embeddings_idxs = codes + offsets.view(-1, 1, 1)
        # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
        # return features.transpose(1, 2)

    def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)                  # (B, T) -> (B, 1, T)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized                 # (B, dimension, frames)
        codes = q_res.codes                         # (n_q, B, frames); original annotation: (8, 16, 75)
        commit_loss = q_res.penalty                 # VQ commitment loss

        return quantized, codes, commit_loss

    def _infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)                  # (B, T) -> (B, 1, T)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized                 # (B, dimension, frames)
        codes = q_res.codes                         # (n_q, B, frames); original annotation: (8, 16, 75)
        commit_loss = q_res.penalty                 # VQ commitment loss

        return quantized, codes, commit_loss
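

if __name__ == "__main__":
    # Minimal smoke test (added as an illustrative sketch; the shapes below follow
    # from the default hyperparameters above and are assumptions, not original code).
    torch.manual_seed(0)
    extractor = MelSpectrogramFeatures(sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100)
    audio = torch.randn(2, 24000)                  # (B, T): two one-second clips at 24 kHz
    mel = extractor(audio)                         # (B, n_mels, frames)
    print("log-mel features:", tuple(mel.shape))   # expected: (2, 100, 94) with center padding
    # EncodecFeatures needs the full SEANet/EncodecModel stack and is normally driven
    # by the training pipeline, so it is not exercised in this quick check.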