#

"""Command-line for audio compression."""

import logging
import os
import sys

import torch
from omegaconf import OmegaConf

from ..abs_tokenizer import AbsTokenizer
from .models.soundstream import SoundStream


class AudioTokenizer(AbsTokenizer):
    def __init__(self, ckpt_path, device=torch.device('cpu')):
        """SoundStream codec used at a fixed bandwidth of 1.5 kbps.

        Audio is encoded at 50 frames per second; each frame is represented
        by 3 codebook indices, each with a value in [0, 1023].
        """
        super().__init__()
        # The GPU is only used for offline tokenization, so when distributed
        # training is launched this tokenizer should still live on the CPU.

        self.device = device
        config_path = os.path.join(os.path.dirname(ckpt_path), 'config.yaml')
        if not os.path.isfile(config_path):
            raise ValueError(f"Config file {config_path} does not exist.")
        config = OmegaConf.load(config_path)
        
        self.ckpt_path = ckpt_path
        logging.info(f"using config {config_path} and model {self.ckpt_path}")
        
        self.soundstream = self.build_codec_model(config)
        # properties
        self.sr = 16000
        self.dim_codebook = 1024
        self.n_codebook = 3
        self.bw = 1.5 # bw=1.5 ---> 3 codebooks
        self.freq = self.n_codebook * 50
        self.mask_id = self.dim_codebook * self.n_codebook
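        # With 3 codebooks at 50 frames/s, the flattened token stream runs at
        # self.freq = 150 tokens per second, and mask_id = 3 * 1024 = 3072 sits
        # one past the largest valid token id (assuming codebook k is offset by
        # k * dim_codebook when the codebooks are flattened), so it can serve as
        # a mask/padding symbol without colliding with real codes.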
        

    def build_codec_model(self, config):
        # Instantiate the generator class named in the config (e.g. SoundStream)
        # and load the codec weights from the checkpoint.
        model = eval(config.generator.name)(**config.generator.config)
        parameter_dict = torch.load(self.ckpt_path, map_location='cpu')
        model.load_state_dict(parameter_dict['codec_model'])
        model = model.to(self.device)
        return model
    
    
    @torch.no_grad()
    def encode(self, wav):
        wav = wav.unsqueeze(1).to(self.device)  # (B, 1, len)
        compressed = self.soundstream.encode(wav, target_bw=self.bw)  # (n_codebook, B, n_frames)
        compressed = compressed.detach().cpu().numpy()
        return compressed


    @torch.no_grad()
    def decode(self, audio, rescale):
        compressed = audio.unsqueeze(1).to(self.device)  # (n_codebook, 1, n_frames)
        logging.debug(f"compressed shape: {compressed.shape}")
        out = self.soundstream.decode(compressed)
        out = out.detach().cpu().squeeze(0)
        check_clipping(out, rescale)
        return out
    
def check_clipping(wav, rescale):
    if rescale:
        return
    mx = wav.abs().max()
    limit = 0.99
    if mx > limit:
        print(
            f"Clipping!! max scale {mx}, limit is {limit}. "
            "To avoid clipping, use the `-r` option to rescale the output.",
            file=sys.stderr)
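
# Minimal round-trip sketch, assuming a valid checkpoint path and a mono 16 kHz
# waveform tensor of shape (B, len); only the `encode`/`decode` methods defined
# above are used:
#
#   tokenizer = AudioTokenizer('/path/to/codec_checkpoint.pth')
#   wav = torch.randn(1, 16000)                  # one second of audio at 16 kHz
#   codes = tokenizer.encode(wav)                # numpy array, (n_codebook, 1, n_frames)
#   codes = torch.from_numpy(codes).squeeze(1)   # (n_codebook, n_frames) for decode
#   wav_hat = tokenizer.decode(codes, rescale=False)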



if __name__ == '__main__':
    # Pass the codec checkpoint path on the command line; config.yaml is
    # expected to sit in the same directory as the checkpoint.
    ckpt_path = sys.argv[1]
    tokenizer = AudioTokenizer(ckpt_path, device=torch.device('cuda:0')).cuda()
    wav = '/home/v-dongyang/data/FSD/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/FreeSound_flac/537271.flac'
    codec = tokenizer.tokenize(wav)
    print(codec)
    # wav = tokenizer.detokenize(codec)
    # import torchaudio
    # torchaudio.save('design.wav', wav, 16000, bits_per_sample=16, encoding='PCM_S')
    # print(wav.shape)