#
"""Command-line for audio compression."""
import os
import torch
from omegaconf import OmegaConf
import logging
from ..abs_tokenizer import AbsTokenizer
from .models.soundstream import SoundStream
import sys
class AudioTokenizer(AbsTokenizer):
    def __init__(self,
                 ckpt_path,
                 device=torch.device('cpu'),
                 ):
""" soundstream with fixed bandwidth of 4kbps
It encodes audio with 50 fps and 8-dim vector for each frame
The value of each entry is in [0, 1023]
"""
        super(AudioTokenizer, self).__init__()
        # GPU is only for offline tokenization.
        # So, when distributed training is launched, this should still be on CPU.
        self.device = device

        config_path = os.path.join(os.path.dirname(ckpt_path), 'config.yaml')
        if not os.path.isfile(config_path):
            raise ValueError(f"{config_path} file does not exist.")
        config = OmegaConf.load(config_path)
        self.ckpt_path = ckpt_path
        logging.info(f"using config {config_path} and model {self.ckpt_path}")

        self.soundstream = self.build_codec_model(config)
        # Properties
        self.sr = 16000
        self.dim_codebook = 1024
        self.n_codebook = 3
        self.bw = 1.5  # bw=1.5 ---> 3 codebooks
        self.freq = self.n_codebook * 50  # token rate: 3 codebooks x 50 fps = 150 tokens/s
        self.mask_id = self.dim_codebook * self.n_codebook
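        # Note (assumption, not stated in this file): downstream code
        # typically flattens the 3 codebooks into a single index space,
        # token = code + k * dim_codebook for codebook k, giving valid
        # ids in [0, 3071]; mask_id = 3072 is then the first unused id,
        # usable as a mask / padding token.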

    def build_codec_model(self, config):
        model = eval(config.generator.name)(**config.generator.config)
        parameter_dict = torch.load(self.ckpt_path, map_location='cpu')
        model.load_state_dict(parameter_dict['codec_model'])  # load model
        model = model.to(self.device)
        return model
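
    # Note: `eval(config.generator.name)` instantiates whichever generator
    # class config.yaml names (expected to be SoundStream, imported above);
    # the checkpoint is assumed to store the codec weights under the
    # 'codec_model' key.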

    @torch.no_grad()
    def encode(self, wav):
        wav = wav.unsqueeze(1).to(self.device)  # (B, 1, len)
        # codes come back with shape (n_codebook, B, n_frames)
        compressed = self.soundstream.encode(wav, target_bw=self.bw)
        compressed = compressed.detach().cpu().numpy()
        return compressed
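
    # Example (hypothetical numbers): one second of 16 kHz mono audio,
    # i.e. wav of shape (1, 16000), encodes to a (3, 1, 50) array of
    # integer codes, each in [0, 1023].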

    @torch.no_grad()
    def decode(self, audio, rescale):
        compressed = audio.unsqueeze(1)  # (n_codebook, 1, n_frames)
        logging.debug(f"compressed shape: {compressed.shape}")
        out = self.soundstream.decode(compressed)
        out = out.detach().cpu().squeeze(0)
        check_clipping(out, rescale)
        return out


def check_clipping(wav, rescale):
    if rescale:
        # Caller rescales the output, so clipping is not a concern here.
        return
    mx = wav.abs().max()
    limit = 0.99
    if mx > limit:
        print(
            f"Clipping!! max scale {mx}, limit is {limit}. "
            "To avoid clipping, use the `-r` option to rescale the output.",
            file=sys.stderr)


if __name__ == '__main__':
    # NOTE: the ckpt_path below is a placeholder; point it at a real
    # SoundStream checkpoint with a config.yaml in the same directory.
    tokenizer = AudioTokenizer(
        ckpt_path='/path/to/checkpoint.pth',
        device=torch.device('cuda:0'),
    )
    wav = '/home/v-dongyang/data/FSD/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/FreeSound_flac/537271.flac'
    codec = tokenizer.tokenize(wav)
    print(codec)
    # wav = tokenizer.detokenize(codec)
    # import torchaudio
    # torchaudio.save('design.wav', wav, 16000, bits_per_sample=16, encoding='PCM_S')
    # print(wav.shape)