Plachta committed on
Commit: f32cd36
1 Parent(s): d65c610

Upload 42 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
 *.zip filter=lfs diff=lfs merge=lfs -text
34
 *.zst filter=lfs diff=lfs merge=lfs -text
35
 *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,140 +1,195 @@
1
- import spaces
2
- import gradio as gr
3
- import torch
4
- import torchaudio
5
- import librosa
6
- from modules.commons import build_model, load_checkpoint, recursive_munch
7
- import yaml
8
- from hf_utils import load_custom_model_from_hf
9
-
10
- # Load model and configuration
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
-
13
- dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
14
- "DiT_step_315000_seed_v2_online_pruned.pth",
15
- "config_dit_mel_seed.yml")
16
-
17
- config = yaml.safe_load(open(dit_config_path, 'r'))
18
- model_params = recursive_munch(config['model_params'])
19
- model = build_model(model_params, stage='DiT')
20
- hop_length = config['preprocess_params']['spect_params']['hop_length']
21
- sr = config['preprocess_params']['sr']
22
-
23
- # Load checkpoints
24
- model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
25
- load_only_params=True, ignore_modules=[], is_distributed=False)
26
- for key in model:
27
- model[key].eval()
28
- model[key].to(device)
29
- model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
30
-
31
- # Load additional modules
32
- from modules.campplus.DTDNN import CAMPPlus
33
-
34
- campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
35
- campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path'], map_location='cpu'))
36
- campplus_model.eval()
37
- campplus_model.to(device)
38
-
39
- from modules.hifigan.generator import HiFTGenerator
40
- from modules.hifigan.f0_predictor import ConvRNNF0Predictor
41
-
42
- hift_checkpoint_path, hift_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
43
- "hift.pt",
44
- "hifigan.yml")
45
- hift_config = yaml.safe_load(open(hift_config_path, 'r'))
46
- hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
47
- hift_gen.load_state_dict(torch.load(hift_checkpoint_path, map_location='cpu'))
48
- hift_gen.eval()
49
- hift_gen.to(device)
50
-
51
- from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
52
-
53
- speech_tokenizer_path = load_custom_model_from_hf("Plachta/Seed-VC", "speech_tokenizer_v1.onnx", None)
54
-
55
- cosyvoice_frontend = CosyVoiceFrontEnd(speech_tokenizer_model=speech_tokenizer_path,
56
- device='cuda', device_id=0)
57
- # Generate mel spectrograms
58
- mel_fn_args = {
59
- "n_fft": config['preprocess_params']['spect_params']['n_fft'],
60
- "win_size": config['preprocess_params']['spect_params']['win_length'],
61
- "hop_size": config['preprocess_params']['spect_params']['hop_length'],
62
- "num_mels": config['preprocess_params']['spect_params']['n_mels'],
63
- "sampling_rate": sr,
64
- "fmin": 0,
65
- "fmax": 8000,
66
- "center": False
67
- }
68
- from modules.audio import mel_spectrogram
69
-
70
- to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
71
-
72
- @spaces.GPU
73
- @torch.no_grad()
74
- def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate):
75
- # Load audio
76
- source_audio = librosa.load(source, sr=sr)[0]
77
- ref_audio = librosa.load(target, sr=sr)[0]
78
-
79
- # Process audio
80
- source_audio = torch.tensor(source_audio[:sr * 30]).unsqueeze(0).float().to(device)
81
- ref_audio = torch.tensor(ref_audio[:sr * 30]).unsqueeze(0).float().to(device)
82
-
83
- # Resample
84
- source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
85
- ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
86
-
87
- # Extract features
88
- S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
89
- S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
90
-
91
- mel = to_mel(source_audio.to(device).float())
92
- mel2 = to_mel(ref_audio.to(device).float())
93
-
94
- target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
95
- target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
96
-
97
- # Style encoding
98
- feat = torchaudio.compliance.kaldi.fbank(source_waves_16k,
99
- num_mel_bins=80,
100
- dither=0,
101
- sample_frequency=16000)
102
- feat = feat - feat.mean(dim=0, keepdim=True)
103
- style1 = campplus_model(feat.unsqueeze(0))
104
-
105
- feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
106
- num_mel_bins=80,
107
- dither=0,
108
- sample_frequency=16000)
109
- feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
110
- style2 = campplus_model(feat2.unsqueeze(0))
111
-
112
- # Length regulation
113
- cond = model.length_regulator(S_alt, ylens=target_lengths)[0]
114
- prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths)[0]
115
- cat_condition = torch.cat([prompt_condition, cond], dim=1)
116
-
117
- # Voice Conversion
118
- vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
119
- mel2, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
120
- vc_target = vc_target[:, :, mel2.size(-1):]
121
-
122
- # Convert to waveform
123
- vc_wave = hift_gen.inference(vc_target)
124
-
125
- return (sr, vc_wave.squeeze(0).cpu().numpy())
126
-
127
-
128
- if __name__ == "__main__":
129
- description = "Zero-shot voice conversion with in-context learning. Check out our [GitHub repository](https://github.com/Plachtaa/seed-vc) for details and updates."
130
- inputs = [
131
- gr.Audio(type="filepath", label="Source Audio"),
132
- gr.Audio(type="filepath", label="Reference Audio"),
133
- gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Diffusion Steps"),
134
- gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust"),
135
- gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate"),
136
- ]
137
-
138
- outputs = gr.Audio(label="Output Audio")
139
-
140
- gr.Interface(fn=voice_conversion, description=description, inputs=inputs, outputs=outputs, title="Seed Voice Conversion").launch()
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import librosa
5
+ from modules.commons import build_model, load_checkpoint, recursive_munch
6
+ import yaml
7
+ from hf_utils import load_custom_model_from_hf
8
+
9
+ # Load model and configuration
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
13
+ "DiT_step_298000_seed_uvit_facodec_small_wavenet_pruned.pth",
14
+ "config_dit_mel_seed_facodec_small_wavenet.yml")
15
+
16
+ config = yaml.safe_load(open(dit_config_path, 'r'))
17
+ model_params = recursive_munch(config['model_params'])
18
+ model = build_model(model_params, stage='DiT')
19
+ hop_length = config['preprocess_params']['spect_params']['hop_length']
20
+ sr = config['preprocess_params']['sr']
21
+
22
+ # Load checkpoints
23
+ model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
24
+ load_only_params=True, ignore_modules=[], is_distributed=False)
25
+ for key in model:
26
+ model[key].eval()
27
+ model[key].to(device)
28
+ model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
29
+
30
+ # Load additional modules
31
+ from modules.campplus.DTDNN import CAMPPlus
32
+
33
+ campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
34
+ campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path']))
35
+ campplus_model.eval()
36
+ campplus_model.to(device)
37
+
38
+ from modules.hifigan.generator import HiFTGenerator
39
+ from modules.hifigan.f0_predictor import ConvRNNF0Predictor
40
+
41
+ hift_checkpoint_path, hift_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
42
+ "hift.pt",
43
+ "hifigan.yml")
44
+ hift_config = yaml.safe_load(open(hift_config_path, 'r'))
45
+ hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
46
+ hift_gen.load_state_dict(torch.load(hift_checkpoint_path, map_location='cpu'))
47
+ hift_gen.eval()
48
+ hift_gen.to(device)
49
+
50
+ speech_tokenizer_type = config['model_params']['speech_tokenizer'].get('type', 'cosyvoice')
51
+ if speech_tokenizer_type == 'cosyvoice':
52
+ from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
53
+ speech_tokenizer_path = load_custom_model_from_hf("Plachta/Seed-VC", "speech_tokenizer_v1.onnx", None)
54
+ cosyvoice_frontend = CosyVoiceFrontEnd(speech_tokenizer_model=speech_tokenizer_path,
55
+ device='cuda', device_id=0)
56
+ elif speech_tokenizer_type == 'facodec':
57
+ ckpt_path, config_path = load_custom_model_from_hf("Plachta/FAcodec", 'pytorch_model.bin', 'config.yml')
58
+
59
+ codec_config = yaml.safe_load(open(config_path))
60
+ codec_model_params = recursive_munch(codec_config['model_params'])
61
+ codec_encoder = build_model(codec_model_params, stage="codec")
62
+
63
+ ckpt_params = torch.load(ckpt_path, map_location="cpu")
64
+
65
+ for key in codec_encoder:
66
+ codec_encoder[key].load_state_dict(ckpt_params[key], strict=False)
67
+ _ = [codec_encoder[key].eval() for key in codec_encoder]
68
+ _ = [codec_encoder[key].to(device) for key in codec_encoder]
69
+ # Generate mel spectrograms
70
+ mel_fn_args = {
71
+ "n_fft": config['preprocess_params']['spect_params']['n_fft'],
72
+ "win_size": config['preprocess_params']['spect_params']['win_length'],
73
+ "hop_size": config['preprocess_params']['spect_params']['hop_length'],
74
+ "num_mels": config['preprocess_params']['spect_params']['n_mels'],
75
+ "sampling_rate": sr,
76
+ "fmin": 0,
77
+ "fmax": 8000,
78
+ "center": False
79
+ }
80
+ from modules.audio import mel_spectrogram
81
+
82
+ to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
83
+
84
+ @torch.no_grad()
85
+ @torch.inference_mode()
86
+ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate, n_quantizers):
87
+ # Load audio
88
+ source_audio = librosa.load(source, sr=sr)[0]
89
+ ref_audio = librosa.load(target, sr=sr)[0]
90
+ # source_sr, source_audio = source
91
+ # ref_sr, ref_audio = target
92
+ # # if any of the inputs has 2 channels, take the first only
93
+ # if source_audio.ndim == 2:
94
+ # source_audio = source_audio[:, 0]
95
+ # if ref_audio.ndim == 2:
96
+ # ref_audio = ref_audio[:, 0]
97
+ #
98
+ # source_audio, ref_audio = source_audio / 32768.0, ref_audio / 32768.0
99
+ #
100
+ # # if source or audio sr not equal to default sr, resample
101
+ # if source_sr != sr:
102
+ # source_audio = librosa.resample(source_audio, source_sr, sr)
103
+ # if ref_sr != sr:
104
+ # ref_audio = librosa.resample(ref_audio, ref_sr, sr)
105
+
106
+ # Process audio
107
+ source_audio = torch.tensor(source_audio[:sr * 30]).unsqueeze(0).float().to(device)
108
+ ref_audio = torch.tensor(ref_audio[:sr * 30]).unsqueeze(0).float().to(device)
109
+
110
+ # Resample
111
+ source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
112
+ ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
113
+
114
+ # Extract features
115
+ if speech_tokenizer_type == 'cosyvoice':
116
+ S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
117
+ S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
118
+ elif speech_tokenizer_type == 'facodec':
119
+ converted_waves_24k = torchaudio.functional.resample(source_audio, sr, 24000)
120
+ wave_lengths_24k = torch.LongTensor([converted_waves_24k.size(1)]).to(converted_waves_24k.device)
121
+ waves_input = converted_waves_24k.unsqueeze(1)
122
+ z = codec_encoder.encoder(waves_input)
123
+ (
124
+ quantized,
125
+ codes
126
+ ) = codec_encoder.quantizer(
127
+ z,
128
+ waves_input,
129
+ )
130
+ S_alt = torch.cat([codes[1], codes[0]], dim=1)
131
+
132
+ # S_ori should be extracted in the same way
133
+ waves_24k = torchaudio.functional.resample(ref_audio, sr, 24000)
134
+ waves_input = waves_24k.unsqueeze(1)
135
+ z = codec_encoder.encoder(waves_input)
136
+ (
137
+ quantized,
138
+ codes
139
+ ) = codec_encoder.quantizer(
140
+ z,
141
+ waves_input,
142
+ )
143
+ S_ori = torch.cat([codes[1], codes[0]], dim=1)
144
+
145
+ mel = to_mel(source_audio.to(device).float())
146
+ mel2 = to_mel(ref_audio.to(device).float())
147
+
148
+ target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
149
+ target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
150
+
151
+ feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
152
+ num_mel_bins=80,
153
+ dither=0,
154
+ sample_frequency=16000)
155
+ feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
156
+ style2 = campplus_model(feat2.unsqueeze(0))
157
+
158
+ # Length regulation
159
+ cond = model.length_regulator(S_alt, ylens=target_lengths, n_quantizers=int(n_quantizers))[0]
160
+ prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=int(n_quantizers))[0]
161
+ cat_condition = torch.cat([prompt_condition, cond], dim=1)
162
+
163
+ # Voice Conversion
164
+ vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
165
+ mel2, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
166
+ vc_target = vc_target[:, :, mel2.size(-1):]
167
+
168
+ # Convert to waveform
169
+ vc_wave = hift_gen.inference(vc_target)
170
+
171
+ return sr, vc_wave.squeeze(0).cpu().numpy()
172
+
173
+
174
+ if __name__ == "__main__":
175
+ description = "Zero-shot voice conversion with in-context learning. Check out our [GitHub repository](https://github.com/Plachtaa/seed-vc) for details and updates."
176
+ inputs = [
177
+ gr.Audio(type="filepath", label="Source Audio"),
178
+ gr.Audio(type="filepath", label="Reference Audio"),
179
+ gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps", info="10 by default, 50~100 for best quality"),
180
+ gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust", info="<1.0 speeds up speech, >1.0 slows it down"),
181
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate", info="has a subtle influence"),
182
+ gr.Slider(minimum=1, maximum=3, step=1, value=3, label="N Quantizers", info="the fewer quantizers used, the less of the source audio's prosody is preserved"),
183
+ ]
184
+
185
+ examples = [["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.7, 1]]
186
+
187
+ outputs = gr.Audio(label="Output Audio")
188
+
189
+ gr.Interface(fn=voice_conversion,
190
+ description=description,
191
+ inputs=inputs,
192
+ outputs=outputs,
193
+ title="Seed Voice Conversion",
194
+ examples=examples,
195
+ ).launch()
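For reference, a minimal sketch of exercising the new `voice_conversion` signature outside the Gradio UI. It assumes `app.py` is importable with the referenced checkpoints already downloaded; `soundfile` and the output path are illustrative additions, not part of this commit.

```python
# Hedged sketch: call the updated voice_conversion() directly (importing app
# runs the module-level model/checkpoint loading shown above).
import soundfile as sf  # assumed only for writing the result to disk

from app import voice_conversion, sr

out_sr, wav = voice_conversion(
    source="examples/source/yae_0.wav",          # source speech (see `examples` above)
    target="examples/reference/dingzhen_0.wav",  # reference speaker
    diffusion_steps=50,                          # UI default is 10; 50~100 for best quality
    length_adjust=1.0,                           # <1.0 speeds up, >1.0 slows down
    inference_cfg_rate=0.7,
    n_quantizers=3,                              # argument introduced in this commit
)
assert out_sr == sr
sf.write("converted.wav", wav, out_sr)           # illustrative output file
```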
configs/config_dit_mel_seed_facodec_small.yml ADDED
@@ -0,0 +1,97 @@
1
+ log_dir: "./runs/run_dit_mel_seed_facodec_small"
2
+ save_freq: 1
3
+ log_interval: 10
4
+ save_interval: 1000
5
+ device: "cuda"
6
+ epochs: 1000 # number of epochs for first stage training (pre-training)
7
+ batch_size: 2
8
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
9
+ max_len: 80 # maximum number of frames
10
+ pretrained_model: ""
11
+ pretrained_encoder: ""
12
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "modules/JDC/bst.t7"
15
+
16
+ data_params:
17
+ train_data: "./data/train.txt"
18
+ val_data: "./data/val.txt"
19
+ root_path: "./data/"
20
+
21
+ preprocess_params:
22
+ sr: 22050
23
+ spect_params:
24
+ n_fft: 1024
25
+ win_length: 1024
26
+ hop_length: 256
27
+ n_mels: 80
28
+
29
+ model_params:
30
+ dit_type: "DiT" # uDiT or DiT
31
+ reg_loss_type: "l1" # l1 or l2
32
+
33
+ speech_tokenizer:
34
+ type: 'facodec' # facodec or cosyvoice
35
+ path: "checkpoints/speech_tokenizer_v1.onnx"
36
+
37
+ style_encoder:
38
+ dim: 192
39
+ campplus_path: "checkpoints/campplus_cn_common.bin"
40
+
41
+ DAC:
42
+ encoder_dim: 64
43
+ encoder_rates: [2, 5, 5, 6]
44
+ decoder_dim: 1536
45
+ decoder_rates: [ 6, 5, 5, 2 ]
46
+ sr: 24000
47
+
48
+ length_regulator:
49
+ channels: 512
50
+ is_discrete: true
51
+ content_codebook_size: 1024
52
+ in_frame_rate: 80
53
+ out_frame_rate: 80
54
+ sampling_ratios: [1, 1, 1, 1]
55
+ token_dropout_prob: 0.3 # probability of performing token dropout
56
+ token_dropout_range: 1.0 # maximum percentage of tokens to drop out
57
+ n_codebooks: 3
58
+ quantizer_dropout: 0.5
59
+ f0_condition: false
60
+ n_f0_bins: 512
61
+
62
+ DiT:
63
+ hidden_dim: 512
64
+ num_heads: 8
65
+ depth: 13
66
+ class_dropout_prob: 0.1
67
+ block_size: 8192
68
+ in_channels: 80
69
+ style_condition: true
70
+ final_layer_type: 'wavenet'
71
+ target: 'mel' # mel or codec
72
+ content_dim: 512
73
+ content_codebook_size: 1024
74
+ content_type: 'discrete'
75
+ f0_condition: true
76
+ n_f0_bins: 512
77
+ content_codebooks: 1
78
+ is_causal: false
79
+ long_skip_connection: true
80
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
81
+ time_as_token: false
82
+ style_as_token: false
83
+ uvit_skip_connection: true
84
+ add_resblock_in_transformer: false
85
+
86
+ wavenet:
87
+ hidden_dim: 512
88
+ num_layers: 8
89
+ kernel_size: 5
90
+ dilation_rate: 1
91
+ p_dropout: 0.2
92
+ style_condition: true
93
+
94
+ loss_params:
95
+ base_lr: 0.0001
96
+ lambda_mel: 45
97
+ lambda_kl: 1.0
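A minimal sketch of how a config like the one above is consumed, mirroring the loading code in app.py (the path is illustrative):

```python
# Hedged sketch: parse the YAML config and build the DiT model dict as app.py does.
import yaml
from modules.commons import build_model, recursive_munch

config = yaml.safe_load(open("configs/config_dit_mel_seed_facodec_small.yml"))
model_params = recursive_munch(config["model_params"])   # dict -> attribute-style access
model = build_model(model_params, stage="DiT")           # sub-modules incl. cfm, length_regulator

hop_length = config["preprocess_params"]["spect_params"]["hop_length"]  # 256
sr = config["preprocess_params"]["sr"]                                  # 22050
```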
configs/config_dit_mel_seed_wavenet.yml ADDED
@@ -0,0 +1,79 @@
1
+ log_dir: "./runs/run_dit_mel_seed"
2
+ save_freq: 1
3
+ log_interval: 10
4
+ save_interval: 1000
5
+ device: "cuda"
6
+ epochs: 1000 # number of epochs for first stage training (pre-training)
7
+ batch_size: 4
8
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
9
+ max_len: 80 # maximum number of frames
10
+ pretrained_model: ""
11
+ pretrained_encoder: ""
12
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "modules/JDC/bst.t7"
15
+
16
+ preprocess_params:
17
+ sr: 22050
18
+ spect_params:
19
+ n_fft: 1024
20
+ win_length: 1024
21
+ hop_length: 256
22
+ n_mels: 80
23
+
24
+ model_params:
25
+ dit_type: "DiT" # uDiT or DiT
26
+ reg_loss_type: "l2" # l1 or l2
27
+
28
+ speech_tokenizer:
29
+ path: "checkpoints/speech_tokenizer_v1.onnx"
30
+
31
+ style_encoder:
32
+ dim: 192
33
+ campplus_path: "campplus_cn_common.bin"
34
+
35
+ DAC:
36
+ encoder_dim: 64
37
+ encoder_rates: [2, 5, 5, 6]
38
+ decoder_dim: 1536
39
+ decoder_rates: [ 6, 5, 5, 2 ]
40
+ sr: 24000
41
+
42
+ length_regulator:
43
+ channels: 768
44
+ is_discrete: true
45
+ content_codebook_size: 4096
46
+ in_frame_rate: 50
47
+ out_frame_rate: 80
48
+ sampling_ratios: [1, 1, 1, 1]
49
+
50
+ DiT:
51
+ hidden_dim: 768
52
+ num_heads: 12
53
+ depth: 12
54
+ class_dropout_prob: 0.1
55
+ block_size: 8192
56
+ in_channels: 80
57
+ style_condition: true
58
+ final_layer_type: 'wavenet'
59
+ target: 'mel' # mel or codec
60
+ content_dim: 768
61
+ content_codebook_size: 1024
62
+ content_type: 'discrete'
63
+ f0_condition: false
64
+ n_f0_bins: 512
65
+ content_codebooks: 1
66
+ is_causal: false
67
+ long_skip_connection: true
68
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
69
+
70
+ wavenet:
71
+ hidden_dim: 768
72
+ num_layers: 8
73
+ kernel_size: 5
74
+ dilation_rate: 1
75
+ p_dropout: 0.2
76
+ style_condition: true
77
+
78
+ loss_params:
79
+ base_lr: 0.0001
configs/hifigan.yml CHANGED
@@ -22,4 +22,4 @@ f0_predictor:
22
 in_channels: 80
23
 cond_channels: 512
24

25
- pretrained_model_path: "hift.pt"
25
+ pretrained_model_path: "checkpoints/hift.pt"
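A minimal sketch of loading the vocoder with the updated local checkpoint path, mirroring the HiFTGenerator setup in app.py (assumes checkpoints/hift.pt exists on disk):

```python
# Hedged sketch: build the HiFT vocoder from configs/hifigan.yml.
import yaml
import torch
from modules.hifigan.generator import HiFTGenerator
from modules.hifigan.f0_predictor import ConvRNNF0Predictor

hift_config = yaml.safe_load(open("configs/hifigan.yml"))
hift_gen = HiFTGenerator(**hift_config["hift"],
                         f0_predictor=ConvRNNF0Predictor(**hift_config["f0_predictor"]))
state = torch.load(hift_config["pretrained_model_path"], map_location="cpu")  # "checkpoints/hift.pt"
hift_gen.load_state_dict(state)
hift_gen.eval()
```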
dac/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ __version__ = "1.0.0"
2
+
3
+ # preserved here for legacy reasons
4
+ __model_version__ = "latest"
5
+
6
+ import audiotools
7
+
8
+ audiotools.ml.BaseModel.INTERN += ["dac.**"]
9
+ audiotools.ml.BaseModel.EXTERN += ["einops"]
10
+
11
+
12
+ from . import nn
13
+ from . import model
14
+ from . import utils
15
+ from .model import DAC
16
+ from .model import DACFile
dac/__main__.py ADDED
@@ -0,0 +1,36 @@
1
+ import sys
2
+
3
+ import argbind
4
+
5
+ from dac.utils import download
6
+ from dac.utils.decode import decode
7
+ from dac.utils.encode import encode
8
+
9
+ STAGES = ["encode", "decode", "download"]
10
+
11
+
12
+ def run(stage: str):
13
+ """Run stages.
14
+
15
+ Parameters
16
+ ----------
17
+ stage : str
18
+ Stage to run
19
+ """
20
+ if stage not in STAGES:
21
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22
+ stage_fn = globals()[stage]
23
+
24
+ if stage == "download":
25
+ stage_fn()
26
+ return
27
+
28
+ stage_fn()
29
+
30
+
31
+ if __name__ == "__main__":
32
+ group = sys.argv.pop(1)
33
+ args = argbind.parse_args(group=group)
34
+
35
+ with argbind.scope(args):
36
+ run(group)
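The dispatcher above is what `python -m dac <stage>` runs; a minimal sketch of driving the same path in-process (shown for the download stage; any stage-specific flags come from argbind bindings):

```python
# Hedged sketch: equivalent of `python -m dac download`, called programmatically.
import argbind
from dac.__main__ import run

args = argbind.parse_args(group="download")  # collect any argbind flags for this stage
with argbind.scope(args):
    run("download")                          # validates the stage, then calls dac.utils.download()
```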
dac/model/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .base import CodecMixin
2
+ from .base import DACFile
3
+ from .dac import DAC
4
+ from .discriminator import Discriminator
dac/model/base.py ADDED
@@ -0,0 +1,294 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(
52
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
53
+ )
54
+ return cls(codes=codes, **artifacts["metadata"])
55
+
56
+
57
+ class CodecMixin:
58
+ @property
59
+ def padding(self):
60
+ if not hasattr(self, "_padding"):
61
+ self._padding = True
62
+ return self._padding
63
+
64
+ @padding.setter
65
+ def padding(self, value):
66
+ assert isinstance(value, bool)
67
+
68
+ layers = [
69
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70
+ ]
71
+
72
+ for layer in layers:
73
+ if value:
74
+ if hasattr(layer, "original_padding"):
75
+ layer.padding = layer.original_padding
76
+ else:
77
+ layer.original_padding = layer.padding
78
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
79
+
80
+ self._padding = value
81
+
82
+ def get_delay(self):
83
+ # Any number works here, delay is invariant to input length
84
+ l_out = self.get_output_length(0)
85
+ L = l_out
86
+
87
+ layers = []
88
+ for layer in self.modules():
89
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90
+ layers.append(layer)
91
+
92
+ for layer in reversed(layers):
93
+ d = layer.dilation[0]
94
+ k = layer.kernel_size[0]
95
+ s = layer.stride[0]
96
+
97
+ if isinstance(layer, nn.ConvTranspose1d):
98
+ L = ((L - d * (k - 1) - 1) / s) + 1
99
+ elif isinstance(layer, nn.Conv1d):
100
+ L = (L - 1) * s + d * (k - 1) + 1
101
+
102
+ L = math.ceil(L)
103
+
104
+ l_in = L
105
+
106
+ return (l_in - l_out) // 2
107
+
108
+ def get_output_length(self, input_length):
109
+ L = input_length
110
+ # Calculate output length
111
+ for layer in self.modules():
112
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113
+ d = layer.dilation[0]
114
+ k = layer.kernel_size[0]
115
+ s = layer.stride[0]
116
+
117
+ if isinstance(layer, nn.Conv1d):
118
+ L = ((L - d * (k - 1) - 1) / s) + 1
119
+ elif isinstance(layer, nn.ConvTranspose1d):
120
+ L = (L - 1) * s + d * (k - 1) + 1
121
+
122
+ L = math.floor(L)
123
+ return L
124
+
125
+ @torch.no_grad()
126
+ def compress(
127
+ self,
128
+ audio_path_or_signal: Union[str, Path, AudioSignal],
129
+ win_duration: float = 1.0,
130
+ verbose: bool = False,
131
+ normalize_db: float = -16,
132
+ n_quantizers: int = None,
133
+ ) -> DACFile:
134
+ """Processes an audio signal from a file or AudioSignal object into
135
+ discrete codes. This function processes the signal in short windows,
136
+ using constant GPU memory.
137
+
138
+ Parameters
139
+ ----------
140
+ audio_path_or_signal : Union[str, Path, AudioSignal]
141
+ audio signal to reconstruct
142
+ win_duration : float, optional
143
+ window duration in seconds, by default 1.0
144
+ verbose : bool, optional
145
+ by default False
146
+ normalize_db : float, optional
147
+ normalize db, by default -16
148
+
149
+ Returns
150
+ -------
151
+ DACFile
152
+ Object containing compressed codes and metadata
153
+ required for decompression
154
+ """
155
+ audio_signal = audio_path_or_signal
156
+ if isinstance(audio_signal, (str, Path)):
157
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158
+
159
+ self.eval()
160
+ original_padding = self.padding
161
+ original_device = audio_signal.device
162
+
163
+ audio_signal = audio_signal.clone()
164
+ original_sr = audio_signal.sample_rate
165
+
166
+ resample_fn = audio_signal.resample
167
+ loudness_fn = audio_signal.loudness
168
+
169
+ # If audio is > 10 hours long, use the ffmpeg versions
170
+ if audio_signal.signal_duration >= 10 * 60 * 60:
171
+ resample_fn = audio_signal.ffmpeg_resample
172
+ loudness_fn = audio_signal.ffmpeg_loudness
173
+
174
+ original_length = audio_signal.signal_length
175
+ resample_fn(self.sample_rate)
176
+ input_db = loudness_fn()
177
+
178
+ if normalize_db is not None:
179
+ audio_signal.normalize(normalize_db)
180
+ audio_signal.ensure_max_of_audio()
181
+
182
+ nb, nac, nt = audio_signal.audio_data.shape
183
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
184
+ win_duration = (
185
+ audio_signal.signal_duration if win_duration is None else win_duration
186
+ )
187
+
188
+ if audio_signal.signal_duration <= win_duration:
189
+ # Unchunked compression (used if signal length < win duration)
190
+ self.padding = True
191
+ n_samples = nt
192
+ hop = nt
193
+ else:
194
+ # Chunked inference
195
+ self.padding = False
196
+ # Zero-pad signal on either side by the delay
197
+ audio_signal.zero_pad(self.delay, self.delay)
198
+ n_samples = int(win_duration * self.sample_rate)
199
+ # Round n_samples to nearest hop length multiple
200
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
201
+ hop = self.get_output_length(n_samples)
202
+
203
+ codes = []
204
+ range_fn = range if not verbose else tqdm.trange
205
+
206
+ for i in range_fn(0, nt, hop):
207
+ x = audio_signal[..., i : i + n_samples]
208
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
209
+
210
+ audio_data = x.audio_data.to(self.device)
211
+ audio_data = self.preprocess(audio_data, self.sample_rate)
212
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
213
+ codes.append(c.to(original_device))
214
+ chunk_length = c.shape[-1]
215
+
216
+ codes = torch.cat(codes, dim=-1)
217
+
218
+ dac_file = DACFile(
219
+ codes=codes,
220
+ chunk_length=chunk_length,
221
+ original_length=original_length,
222
+ input_db=input_db,
223
+ channels=nac,
224
+ sample_rate=original_sr,
225
+ padding=self.padding,
226
+ dac_version=SUPPORTED_VERSIONS[-1],
227
+ )
228
+
229
+ if n_quantizers is not None:
230
+ codes = codes[:, :n_quantizers, :]
231
+
232
+ self.padding = original_padding
233
+ return dac_file
234
+
235
+ @torch.no_grad()
236
+ def decompress(
237
+ self,
238
+ obj: Union[str, Path, DACFile],
239
+ verbose: bool = False,
240
+ ) -> AudioSignal:
241
+ """Reconstruct audio from a given .dac file
242
+
243
+ Parameters
244
+ ----------
245
+ obj : Union[str, Path, DACFile]
246
+ .dac file location or corresponding DACFile object.
247
+ verbose : bool, optional
248
+ Prints progress if True, by default False
249
+
250
+ Returns
251
+ -------
252
+ AudioSignal
253
+ Object with the reconstructed audio
254
+ """
255
+ self.eval()
256
+ if isinstance(obj, (str, Path)):
257
+ obj = DACFile.load(obj)
258
+
259
+ original_padding = self.padding
260
+ self.padding = obj.padding
261
+
262
+ range_fn = range if not verbose else tqdm.trange
263
+ codes = obj.codes
264
+ original_device = codes.device
265
+ chunk_length = obj.chunk_length
266
+ recons = []
267
+
268
+ for i in range_fn(0, codes.shape[-1], chunk_length):
269
+ c = codes[..., i : i + chunk_length].to(self.device)
270
+ z = self.quantizer.from_codes(c)[0]
271
+ r = self.decode(z)
272
+ recons.append(r.to(original_device))
273
+
274
+ recons = torch.cat(recons, dim=-1)
275
+ recons = AudioSignal(recons, self.sample_rate)
276
+
277
+ resample_fn = recons.resample
278
+ loudness_fn = recons.loudness
279
+
280
+ # If audio is > 10 hours long, use the ffmpeg versions
281
+ if recons.signal_duration >= 10 * 60 * 60:
282
+ resample_fn = recons.ffmpeg_resample
283
+ loudness_fn = recons.ffmpeg_loudness
284
+
285
+ recons.normalize(obj.input_db)
286
+ resample_fn(obj.sample_rate)
287
+ recons = recons[..., : obj.original_length]
288
+ loudness_fn()
289
+ recons.audio_data = recons.audio_data.reshape(
290
+ -1, obj.channels, obj.original_length
291
+ )
292
+
293
+ self.padding = original_padding
294
+ return recons
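A minimal sketch of the compress/decompress round trip these mixin methods provide (uses a randomly initialized DAC instance and random audio purely as stand-ins):

```python
# Hedged sketch: chunked compress -> .dac artifact -> decompress via CodecMixin.
import torch
from audiotools import AudioSignal
from dac.model import DAC

model = DAC().eval()                      # illustrative; real use loads trained weights

signal = AudioSignal(torch.randn(1, 1, 44100 * 10), 44100)     # 10 s stand-in signal
dac_file = model.compress(signal, win_duration=1.0, verbose=True)
dac_file.save("example.dac")              # numpy-serialized codes + metadata

recon = model.decompress("example.dac", verbose=True)          # returns an AudioSignal
print(recon.audio_data.shape, recon.sample_rate)
```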
dac/model/dac.py ADDED
@@ -0,0 +1,400 @@
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from dac.nn.layers import Snake1d
13
+ from dac.nn.layers import WNConv1d
14
+ from dac.nn.layers import WNConvTranspose1d
15
+ from dac.nn.quantize import ResidualVectorQuantize
16
+ from .encodec import SConv1d, SConvTranspose1d, SLSTM
17
+
18
+
19
+ def init_weights(m):
20
+ if isinstance(m, nn.Conv1d):
21
+ nn.init.trunc_normal_(m.weight, std=0.02)
22
+ nn.init.constant_(m.bias, 0)
23
+
24
+
25
+ class ResidualUnit(nn.Module):
26
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
27
+ super().__init__()
28
+ conv1d_type = SConv1d# if causal else WNConv1d
29
+ pad = ((7 - 1) * dilation) // 2
30
+ self.block = nn.Sequential(
31
+ Snake1d(dim),
32
+ conv1d_type(dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal, norm='weight_norm'),
33
+ Snake1d(dim),
34
+ conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
35
+ )
36
+
37
+ def forward(self, x):
38
+ y = self.block(x)
39
+ pad = (x.shape[-1] - y.shape[-1]) // 2
40
+ if pad > 0:
41
+ x = x[..., pad:-pad]
42
+ return x + y
43
+
44
+
45
+ class EncoderBlock(nn.Module):
46
+ def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
47
+ super().__init__()
48
+ conv1d_type = SConv1d# if causal else WNConv1d
49
+ self.block = nn.Sequential(
50
+ ResidualUnit(dim // 2, dilation=1, causal=causal),
51
+ ResidualUnit(dim // 2, dilation=3, causal=causal),
52
+ ResidualUnit(dim // 2, dilation=9, causal=causal),
53
+ Snake1d(dim // 2),
54
+ conv1d_type(
55
+ dim // 2,
56
+ dim,
57
+ kernel_size=2 * stride,
58
+ stride=stride,
59
+ padding=math.ceil(stride / 2),
60
+ causal=causal,
61
+ norm='weight_norm',
62
+ ),
63
+ )
64
+
65
+ def forward(self, x):
66
+ return self.block(x)
67
+
68
+
69
+ class Encoder(nn.Module):
70
+ def __init__(
71
+ self,
72
+ d_model: int = 64,
73
+ strides: list = [2, 4, 8, 8],
74
+ d_latent: int = 64,
75
+ causal: bool = False,
76
+ lstm: int = 2,
77
+ ):
78
+ super().__init__()
79
+ conv1d_type = SConv1d# if causal else WNConv1d
80
+ # Create first convolution
81
+ self.block = [conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
82
+
83
+ # Create EncoderBlocks that double channels as they downsample by `stride`
84
+ for stride in strides:
85
+ d_model *= 2
86
+ self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]
87
+
88
+ # Add LSTM if needed
89
+ self.use_lstm = lstm
90
+ if lstm:
91
+ self.block += [SLSTM(d_model, lstm)]
92
+
93
+ # Create last convolution
94
+ self.block += [
95
+ Snake1d(d_model),
96
+ conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'),
97
+ ]
98
+
99
+ # Wrap black into nn.Sequential
100
+ self.block = nn.Sequential(*self.block)
101
+ self.enc_dim = d_model
102
+
103
+ def forward(self, x):
104
+ return self.block(x)
105
+
106
+ def reset_cache(self):
107
+ # recursively find all submodules named SConv1d in self.block and use their reset_cache method
108
+ def reset_cache(m):
109
+ if isinstance(m, SConv1d) or isinstance(m, SLSTM):
110
+ m.reset_cache()
111
+ return
112
+ for child in m.children():
113
+ reset_cache(child)
114
+
115
+ reset_cache(self.block)
116
+
117
+
118
+ class DecoderBlock(nn.Module):
119
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False):
120
+ super().__init__()
121
+ conv1d_type = SConvTranspose1d #if causal else WNConvTranspose1d
122
+ self.block = nn.Sequential(
123
+ Snake1d(input_dim),
124
+ conv1d_type(
125
+ input_dim,
126
+ output_dim,
127
+ kernel_size=2 * stride,
128
+ stride=stride,
129
+ padding=math.ceil(stride / 2),
130
+ causal=causal,
131
+ norm='weight_norm'
132
+ ),
133
+ ResidualUnit(output_dim, dilation=1, causal=causal),
134
+ ResidualUnit(output_dim, dilation=3, causal=causal),
135
+ ResidualUnit(output_dim, dilation=9, causal=causal),
136
+ )
137
+
138
+ def forward(self, x):
139
+ return self.block(x)
140
+
141
+
142
+ class Decoder(nn.Module):
143
+ def __init__(
144
+ self,
145
+ input_channel,
146
+ channels,
147
+ rates,
148
+ d_out: int = 1,
149
+ causal: bool = False,
150
+ lstm: int = 2,
151
+ ):
152
+ super().__init__()
153
+ conv1d_type = SConv1d# if causal else WNConv1d
154
+ # Add first conv layer
155
+ layers = [conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
156
+
157
+ if lstm:
158
+ layers += [SLSTM(channels, num_layers=lstm)]
159
+
160
+ # Add upsampling + MRF blocks
161
+ for i, stride in enumerate(rates):
162
+ input_dim = channels // 2**i
163
+ output_dim = channels // 2 ** (i + 1)
164
+ layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]
165
+
166
+ # Add final conv layer
167
+ layers += [
168
+ Snake1d(output_dim),
169
+ conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
170
+ nn.Tanh(),
171
+ ]
172
+
173
+ self.model = nn.Sequential(*layers)
174
+
175
+ def forward(self, x):
176
+ return self.model(x)
177
+
178
+
179
+ class DAC(BaseModel, CodecMixin):
180
+ def __init__(
181
+ self,
182
+ encoder_dim: int = 64,
183
+ encoder_rates: List[int] = [2, 4, 8, 8],
184
+ latent_dim: int = None,
185
+ decoder_dim: int = 1536,
186
+ decoder_rates: List[int] = [8, 8, 4, 2],
187
+ n_codebooks: int = 9,
188
+ codebook_size: int = 1024,
189
+ codebook_dim: Union[int, list] = 8,
190
+ quantizer_dropout: bool = False,
191
+ sample_rate: int = 44100,
192
+ lstm: int = 2,
193
+ causal: bool = False,
194
+ ):
195
+ super().__init__()
196
+
197
+ self.encoder_dim = encoder_dim
198
+ self.encoder_rates = encoder_rates
199
+ self.decoder_dim = decoder_dim
200
+ self.decoder_rates = decoder_rates
201
+ self.sample_rate = sample_rate
202
+
203
+ if latent_dim is None:
204
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
205
+
206
+ self.latent_dim = latent_dim
207
+
208
+ self.hop_length = np.prod(encoder_rates)
209
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm)
210
+
211
+ self.n_codebooks = n_codebooks
212
+ self.codebook_size = codebook_size
213
+ self.codebook_dim = codebook_dim
214
+ self.quantizer = ResidualVectorQuantize(
215
+ input_dim=latent_dim,
216
+ n_codebooks=n_codebooks,
217
+ codebook_size=codebook_size,
218
+ codebook_dim=codebook_dim,
219
+ quantizer_dropout=quantizer_dropout,
220
+ )
221
+
222
+ self.decoder = Decoder(
223
+ latent_dim,
224
+ decoder_dim,
225
+ decoder_rates,
226
+ lstm=lstm,
227
+ causal=causal,
228
+ )
229
+ self.sample_rate = sample_rate
230
+ self.apply(init_weights)
231
+
232
+ self.delay = self.get_delay()
233
+
234
+ def preprocess(self, audio_data, sample_rate):
235
+ if sample_rate is None:
236
+ sample_rate = self.sample_rate
237
+ assert sample_rate == self.sample_rate
238
+
239
+ length = audio_data.shape[-1]
240
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
241
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
242
+
243
+ return audio_data
244
+
245
+ def encode(
246
+ self,
247
+ audio_data: torch.Tensor,
248
+ n_quantizers: int = None,
249
+ ):
250
+ """Encode given audio data and return quantized latent codes
251
+
252
+ Parameters
253
+ ----------
254
+ audio_data : Tensor[B x 1 x T]
255
+ Audio data to encode
256
+ n_quantizers : int, optional
257
+ Number of quantizers to use, by default None
258
+ If None, all quantizers are used.
259
+
260
+ Returns
261
+ -------
262
+ dict
263
+ A dictionary with the following keys:
264
+ "z" : Tensor[B x D x T]
265
+ Quantized continuous representation of input
266
+ "codes" : Tensor[B x N x T]
267
+ Codebook indices for each codebook
268
+ (quantized discrete representation of input)
269
+ "latents" : Tensor[B x N*D x T]
270
+ Projected latents (continuous representation of input before quantization)
271
+ "vq/commitment_loss" : Tensor[1]
272
+ Commitment loss to train encoder to predict vectors closer to codebook
273
+ entries
274
+ "vq/codebook_loss" : Tensor[1]
275
+ Codebook loss to update the codebook
276
+ "length" : int
277
+ Number of samples in input audio
278
+ """
279
+ z = self.encoder(audio_data)
280
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
281
+ z, n_quantizers
282
+ )
283
+ return z, codes, latents, commitment_loss, codebook_loss
284
+
285
+ def decode(self, z: torch.Tensor):
286
+ """Decode given latent codes and return audio data
287
+
288
+ Parameters
289
+ ----------
290
+ z : Tensor[B x D x T]
291
+ Quantized continuous representation of input
292
+ length : int, optional
293
+ Number of samples in output audio, by default None
294
+
295
+ Returns
296
+ -------
297
+ dict
298
+ A dictionary with the following keys:
299
+ "audio" : Tensor[B x 1 x length]
300
+ Decoded audio data.
301
+ """
302
+ return self.decoder(z)
303
+
304
+ def forward(
305
+ self,
306
+ audio_data: torch.Tensor,
307
+ sample_rate: int = None,
308
+ n_quantizers: int = None,
309
+ ):
310
+ """Model forward pass
311
+
312
+ Parameters
313
+ ----------
314
+ audio_data : Tensor[B x 1 x T]
315
+ Audio data to encode
316
+ sample_rate : int, optional
317
+ Sample rate of audio data in Hz, by default None
318
+ If None, defaults to `self.sample_rate`
319
+ n_quantizers : int, optional
320
+ Number of quantizers to use, by default None.
321
+ If None, all quantizers are used.
322
+
323
+ Returns
324
+ -------
325
+ dict
326
+ A dictionary with the following keys:
327
+ "z" : Tensor[B x D x T]
328
+ Quantized continuous representation of input
329
+ "codes" : Tensor[B x N x T]
330
+ Codebook indices for each codebook
331
+ (quantized discrete representation of input)
332
+ "latents" : Tensor[B x N*D x T]
333
+ Projected latents (continuous representation of input before quantization)
334
+ "vq/commitment_loss" : Tensor[1]
335
+ Commitment loss to train encoder to predict vectors closer to codebook
336
+ entries
337
+ "vq/codebook_loss" : Tensor[1]
338
+ Codebook loss to update the codebook
339
+ "length" : int
340
+ Number of samples in input audio
341
+ "audio" : Tensor[B x 1 x length]
342
+ Decoded audio data.
343
+ """
344
+ length = audio_data.shape[-1]
345
+ audio_data = self.preprocess(audio_data, sample_rate)
346
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(
347
+ audio_data, n_quantizers
348
+ )
349
+
350
+ x = self.decode(z)
351
+ return {
352
+ "audio": x[..., :length],
353
+ "z": z,
354
+ "codes": codes,
355
+ "latents": latents,
356
+ "vq/commitment_loss": commitment_loss,
357
+ "vq/codebook_loss": codebook_loss,
358
+ }
359
+
360
+
361
+ if __name__ == "__main__":
362
+ import numpy as np
363
+ from functools import partial
364
+
365
+ model = DAC().to("cpu")
366
+
367
+ for n, m in model.named_modules():
368
+ o = m.extra_repr()
369
+ p = sum([np.prod(p.size()) for p in m.parameters()])
370
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
371
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
372
+ print(model)
373
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
374
+
375
+ length = 88200 * 2
376
+ x = torch.randn(1, 1, length).to(model.device)
377
+ x.requires_grad_(True)
378
+ x.retain_grad()
379
+
380
+ # Make a forward pass
381
+ out = model(x)["audio"]
382
+ print("Input shape:", x.shape)
383
+ print("Output shape:", out.shape)
384
+
385
+ # Create gradient variable
386
+ grad = torch.zeros_like(out)
387
+ grad[:, :, grad.shape[-1] // 2] = 1
388
+
389
+ # Make a backward pass
390
+ out.backward(grad)
391
+
392
+ # Check non-zero values
393
+ gradmap = x.grad.squeeze(0)
394
+ gradmap = (gradmap != 0).sum(0) # sum across features
395
+ rf = (gradmap != 0).sum()
396
+
397
+ print(f"Receptive field: {rf.item()}")
398
+
399
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
400
+ model.decompress(model.compress(x, verbose=True), verbose=True)
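Beyond the self-test above, a minimal sketch of using encode()/decode() directly (randomly initialized weights; shapes are illustrative):

```python
# Hedged sketch: pad, encode to codes/latents, and decode back to audio.
import torch
from dac.model import DAC

model = DAC(sample_rate=44100).eval()     # illustrative, untrained weights

x = torch.randn(1, 1, 44100)              # 1 s, mono, batch of 1
x = model.preprocess(x, 44100)            # right-pad to a multiple of hop_length

with torch.no_grad():
    z, codes, latents, commitment_loss, codebook_loss = model.encode(x)
    y = model.decode(z)                   # Tensor[B x 1 x T'] reconstruction

print(codes.shape, y.shape)               # codes: (batch, n_codebooks, frames)
```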
dac/model/discriminator.py ADDED
@@ -0,0 +1,228 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from audiotools import AudioSignal
5
+ from audiotools import ml
6
+ from audiotools import STFTParams
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ act = kwargs.pop("act", True)
13
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
+ if not act:
15
+ return conv
16
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
+
18
+
19
+ def WNConv2d(*args, **kwargs):
20
+ act = kwargs.pop("act", True)
21
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
+ if not act:
23
+ return conv
24
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
+
26
+
27
+ class MPD(nn.Module):
28
+ def __init__(self, period):
29
+ super().__init__()
30
+ self.period = period
31
+ self.convs = nn.ModuleList(
32
+ [
33
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
+ ]
39
+ )
40
+ self.conv_post = WNConv2d(
41
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
+ )
43
+
44
+ def pad_to_period(self, x):
45
+ t = x.shape[-1]
46
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
+ return x
48
+
49
+ def forward(self, x):
50
+ fmap = []
51
+
52
+ x = self.pad_to_period(x)
53
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
+
55
+ for layer in self.convs:
56
+ x = layer(x)
57
+ fmap.append(x)
58
+
59
+ x = self.conv_post(x)
60
+ fmap.append(x)
61
+
62
+ return fmap
63
+
64
+
65
+ class MSD(nn.Module):
66
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
+ super().__init__()
68
+ self.convs = nn.ModuleList(
69
+ [
70
+ WNConv1d(1, 16, 15, 1, padding=7),
71
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
+ WNConv1d(1024, 1024, 5, 1, padding=2),
76
+ ]
77
+ )
78
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
+ self.sample_rate = sample_rate
80
+ self.rate = rate
81
+
82
+ def forward(self, x):
83
+ x = AudioSignal(x, self.sample_rate)
84
+ x.resample(self.sample_rate // self.rate)
85
+ x = x.audio_data
86
+
87
+ fmap = []
88
+
89
+ for l in self.convs:
90
+ x = l(x)
91
+ fmap.append(x)
92
+ x = self.conv_post(x)
93
+ fmap.append(x)
94
+
95
+ return fmap
96
+
97
+
98
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
+
100
+
101
+ class MRD(nn.Module):
102
+ def __init__(
103
+ self,
104
+ window_length: int,
105
+ hop_factor: float = 0.25,
106
+ sample_rate: int = 44100,
107
+ bands: list = BANDS,
108
+ ):
109
+ """Complex multi-band spectrogram discriminator.
110
+ Parameters
111
+ ----------
112
+ window_length : int
113
+ Window length of STFT.
114
+ hop_factor : float, optional
115
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
116
+ sample_rate : int, optional
117
+ Sampling rate of audio in Hz, by default 44100
118
+ bands : list, optional
119
+ Bands to run discriminator over.
120
+ """
121
+ super().__init__()
122
+
123
+ self.window_length = window_length
124
+ self.hop_factor = hop_factor
125
+ self.sample_rate = sample_rate
126
+ self.stft_params = STFTParams(
127
+ window_length=window_length,
128
+ hop_length=int(window_length * hop_factor),
129
+ match_stride=True,
130
+ )
131
+
132
+ n_fft = window_length // 2 + 1
133
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
+ self.bands = bands
135
+
136
+ ch = 32
137
+ convs = lambda: nn.ModuleList(
138
+ [
139
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
+ ]
145
+ )
146
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
+
149
+ def spectrogram(self, x):
150
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
+ x = torch.view_as_real(x.stft())
152
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
+ # Split into bands
154
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
+ return x_bands
156
+
157
+ def forward(self, x):
158
+ x_bands = self.spectrogram(x)
159
+ fmap = []
160
+
161
+ x = []
162
+ for band, stack in zip(x_bands, self.band_convs):
163
+ for layer in stack:
164
+ band = layer(band)
165
+ fmap.append(band)
166
+ x.append(band)
167
+
168
+ x = torch.cat(x, dim=-1)
169
+ x = self.conv_post(x)
170
+ fmap.append(x)
171
+
172
+ return fmap
173
+
174
+
175
+ class Discriminator(nn.Module):
176
+ def __init__(
177
+ self,
178
+ rates: list = [],
179
+ periods: list = [2, 3, 5, 7, 11],
180
+ fft_sizes: list = [2048, 1024, 512],
181
+ sample_rate: int = 44100,
182
+ bands: list = BANDS,
183
+ ):
184
+ """Discriminator that combines multiple discriminators.
185
+
186
+ Parameters
187
+ ----------
188
+ rates : list, optional
189
+ sampling rates (in Hz) to run MSD at, by default []
190
+ If empty, MSD is not used.
191
+ periods : list, optional
192
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
+ fft_sizes : list, optional
194
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
+ sample_rate : int, optional
196
+ Sampling rate of audio in Hz, by default 44100
197
+ bands : list, optional
198
+ Bands to run MRD at, by default `BANDS`
199
+ """
200
+ super().__init__()
201
+ discs = []
202
+ discs += [MPD(p) for p in periods]
203
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
+ self.discriminators = nn.ModuleList(discs)
206
+
207
+ def preprocess(self, y):
208
+ # Remove DC offset
209
+ y = y - y.mean(dim=-1, keepdims=True)
210
+ # Peak normalize the volume of input audio
211
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
+ return y
213
+
214
+ def forward(self, x):
215
+ x = self.preprocess(x)
216
+ fmaps = [d(x) for d in self.discriminators]
217
+ return fmaps
218
+
219
+
220
+ if __name__ == "__main__":
221
+ disc = Discriminator()
222
+ x = torch.zeros(1, 1, 44100)
223
+ results = disc(x)
224
+ for i, result in enumerate(results):
225
+ print(f"disc{i}")
226
+ for i, r in enumerate(result):
227
+ print(r.shape, r.mean(), r.min(), r.max())
228
+ print()
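A minimal sketch of how the nested feature maps returned above are typically consumed; the L1 feature-matching and least-squares adversarial terms are illustrative, not part of this commit:

```python
# Hedged sketch: turn Discriminator outputs into simple GAN-style loss terms.
import torch
import torch.nn.functional as F
from dac.model import Discriminator

disc = Discriminator()
real = torch.randn(1, 1, 44100)           # stand-in waveforms
fake = torch.randn(1, 1, 44100)

d_real = disc(real)   # list (one entry per sub-discriminator) of feature-map lists
d_fake = disc(fake)   # the last tensor of each inner list is that discriminator's score map

feat_loss = sum(
    F.l1_loss(f_fake, f_real)
    for maps_real, maps_fake in zip(d_real, d_fake)
    for f_real, f_fake in zip(maps_real[:-1], maps_fake[:-1])   # intermediate features only
)
adv_loss = sum(F.mse_loss(m[-1], torch.ones_like(m[-1])) for m in d_fake)
print(float(feat_loss), float(adv_loss))
```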
dac/model/encodec.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Convolutional layers wrappers and utilities."""
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm, weight_norm
17
+
18
+ import typing as tp
19
+
20
+ import einops
21
+
22
+
23
+ class ConvLayerNorm(nn.LayerNorm):
24
+ """
25
+ Convolution-friendly LayerNorm that moves channels to last dimensions
26
+ before running the normalization and moves them back to original position right after.
27
+ """
28
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
29
+ super().__init__(normalized_shape, **kwargs)
30
+
31
+ def forward(self, x):
32
+ x = einops.rearrange(x, 'b ... t -> b t ...')
33
+ x = super().forward(x)
34
+ x = einops.rearrange(x, 'b t ... -> b ... t')
35
+ return x
36
+
37
+
38
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
39
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
40
+
41
+
42
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
43
+ assert norm in CONV_NORMALIZATIONS
44
+ if norm == 'weight_norm':
45
+ return weight_norm(module)
46
+ elif norm == 'spectral_norm':
47
+ return spectral_norm(module)
48
+ else:
49
+ # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
50
+ # doesn't need reparametrization.
51
+ return module
52
+
53
+
54
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
55
+ """Return the proper normalization module. If causal is True, this will ensure the returned
56
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
57
+ """
58
+ assert norm in CONV_NORMALIZATIONS
59
+ if norm == 'layer_norm':
60
+ assert isinstance(module, nn.modules.conv._ConvNd)
61
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
62
+ elif norm == 'time_group_norm':
63
+ if causal:
64
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
65
+ assert isinstance(module, nn.modules.conv._ConvNd)
66
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
67
+ else:
68
+ return nn.Identity()
69
+
70
+
71
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
72
+ padding_total: int = 0) -> int:
73
+ """See `pad_for_conv1d`.
74
+ """
75
+ length = x.shape[-1]
76
+ n_frames = (length - kernel_size + padding_total) / stride + 1
77
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
78
+ return ideal_length - length
79
+
80
+
81
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
82
+ """Pad for a convolution to make sure that the last window is full.
83
+ Extra padding is added at the end. This is required to ensure that we can rebuild
84
+ an output of the same length, as otherwise, even with padding, some time steps
85
+ might get removed.
86
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
87
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
88
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
89
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
90
+ 1 2 3 4 # once you removed padding, we are missing one time step !
91
+ """
92
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
93
+ return F.pad(x, (0, extra_padding))
94
+
95
+
96
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
97
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
98
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
99
+ """
100
+ length = x.shape[-1]
101
+ padding_left, padding_right = paddings
102
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
103
+ if mode == 'reflect':
104
+ max_pad = max(padding_left, padding_right)
105
+ extra_pad = 0
106
+ if length <= max_pad:
107
+ extra_pad = max_pad - length + 1
108
+ x = F.pad(x, (0, extra_pad))
109
+ padded = F.pad(x, paddings, mode, value)
110
+ end = padded.shape[-1] - extra_pad
111
+ return padded[..., :end]
112
+ else:
113
+ return F.pad(x, paddings, mode, value)
114
+
115
+
116
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
117
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
118
+ padding_left, padding_right = paddings
119
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
120
+ assert (padding_left + padding_right) <= x.shape[-1]
121
+ end = x.shape[-1] - padding_right
122
+ return x[..., padding_left: end]
123
+
124
+
125
+ class NormConv1d(nn.Module):
126
+ """Wrapper around Conv1d and normalization applied to this conv
127
+ to provide a uniform interface across normalization approaches.
128
+ """
129
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
130
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
131
+ super().__init__()
132
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
133
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
134
+ self.norm_type = norm
135
+
136
+ def forward(self, x):
137
+ x = self.conv(x)
138
+ x = self.norm(x)
139
+ return x
140
+
141
+
142
+ class NormConv2d(nn.Module):
143
+ """Wrapper around Conv2d and normalization applied to this conv
144
+ to provide a uniform interface across normalization approaches.
145
+ """
146
+ def __init__(self, *args, norm: str = 'none',
147
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
148
+ super().__init__()
149
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
150
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
151
+ self.norm_type = norm
152
+
153
+ def forward(self, x):
154
+ x = self.conv(x)
155
+ x = self.norm(x)
156
+ return x
157
+
158
+
159
+ class NormConvTranspose1d(nn.Module):
160
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
161
+ to provide a uniform interface across normalization approaches.
162
+ """
163
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
164
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
165
+ super().__init__()
166
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
167
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
168
+ self.norm_type = norm
169
+
170
+ def forward(self, x):
171
+ x = self.convtr(x)
172
+ x = self.norm(x)
173
+ return x
174
+
175
+
176
+ class NormConvTranspose2d(nn.Module):
177
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
178
+ to provide a uniform interface across normalization approaches.
179
+ """
180
+ def __init__(self, *args, norm: str = 'none',
181
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
182
+ super().__init__()
183
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
184
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
185
+
186
+ def forward(self, x):
187
+ x = self.convtr(x)
188
+ x = self.norm(x)
189
+ return x
190
+
191
+
192
+ class SConv1d(nn.Module):
193
+ """Conv1d with some builtin handling of asymmetric or causal padding
194
+ and normalization.
195
+ """
196
+ def __init__(self, in_channels: int, out_channels: int,
197
+ kernel_size: int, stride: int = 1, dilation: int = 1,
198
+ groups: int = 1, bias: bool = True, causal: bool = False,
199
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
200
+ pad_mode: str = 'reflect', **kwargs):
201
+ super().__init__()
202
+ # warn user on unusual setup between dilation and stride
203
+ if stride > 1 and dilation > 1:
204
+ warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
205
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
206
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
207
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
208
+ norm=norm, norm_kwargs=norm_kwargs)
209
+ self.causal = causal
210
+ self.pad_mode = pad_mode
211
+
212
+ self.cache_enabled = False
213
+
214
+ def reset_cache(self):
215
+ """Reset the cache when starting a new stream."""
216
+ self.cache = None
217
+ self.cache_enabled = True
218
+
219
+ def forward(self, x):
220
+ B, C, T = x.shape
221
+ kernel_size = self.conv.conv.kernel_size[0]
222
+ stride = self.conv.conv.stride[0]
223
+ dilation = self.conv.conv.dilation[0]
224
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
225
+ padding_total = kernel_size - stride
226
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
227
+
228
+ if self.causal:
229
+ # Left padding for causal
230
+ if self.cache_enabled and self.cache is not None:
231
+ # Concatenate the cache (previous inputs) with the new input for streaming
232
+ x = torch.cat([self.cache, x], dim=2)
233
+ else:
234
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
235
+ else:
236
+ # Asymmetric padding required for odd strides
237
+ padding_right = padding_total // 2
238
+ padding_left = padding_total - padding_right
239
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
240
+
241
+ # Store the most recent input frames for future cache use
242
+ if self.cache_enabled:
243
+ if self.cache is None:
244
+ # Initialize cache with zeros (at the start of streaming)
245
+ self.cache = torch.zeros(B, C, kernel_size - 1, device=x.device)
246
+ # Update the cache by storing the latest input frames
247
+ if kernel_size > 1:
248
+ self.cache = x[:, :, -kernel_size + 1:].detach() # Only store the necessary frames
249
+
250
+ return self.conv(x)
251
+
252
+
253
+
254
+ class SConvTranspose1d(nn.Module):
255
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
256
+ and normalization.
257
+ """
258
+ def __init__(self, in_channels: int, out_channels: int,
259
+ kernel_size: int, stride: int = 1, causal: bool = False,
260
+ norm: str = 'none', trim_right_ratio: float = 1.,
261
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
262
+ super().__init__()
263
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
264
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
265
+ self.causal = causal
266
+ self.trim_right_ratio = trim_right_ratio
267
+ assert self.causal or self.trim_right_ratio == 1., \
268
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
269
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
270
+
271
+ def forward(self, x):
272
+ kernel_size = self.convtr.convtr.kernel_size[0]
273
+ stride = self.convtr.convtr.stride[0]
274
+ padding_total = kernel_size - stride
275
+
276
+ y = self.convtr(x)
277
+
278
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
279
+ # removed at the very end, when keeping only the right length for the output,
280
+ # as removing it here would require also passing the length at the matching layer
281
+ # in the encoder.
282
+ if self.causal:
283
+ # Trim the padding on the right according to the specified ratio
284
+ # if trim_right_ratio = 1.0, trim everything from right
285
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
286
+ padding_left = padding_total - padding_right
287
+ y = unpad1d(y, (padding_left, padding_right))
288
+ else:
289
+ # Asymmetric padding required for odd strides
290
+ padding_right = padding_total // 2
291
+ padding_left = padding_total - padding_right
292
+ y = unpad1d(y, (padding_left, padding_right))
293
+ return y
294
+
295
+ class SLSTM(nn.Module):
296
+ """
297
+ LSTM without worrying about the hidden state, nor the layout of the data.
298
+ Expects input as convolutional layout.
299
+ """
300
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
301
+ super().__init__()
302
+ self.skip = skip
303
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
304
+ self.hidden = None
305
+ self.cache_enabled = False
306
+
307
+ def forward(self, x):
308
+ x = x.permute(2, 0, 1)
309
+ if self.training or not self.cache_enabled:
310
+ y, _ = self.lstm(x)
311
+ else:
312
+ y, self.hidden = self.lstm(x, self.hidden)
313
+ if self.skip:
314
+ y = y + x
315
+ y = y.permute(1, 2, 0)
316
+ return y
317
+
318
+ def reset_cache(self):
319
+ self.hidden = None
320
+ self.cache_enabled = True
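The padding helpers above exist so that the final convolution window is always full: without the extra right padding, a causal or streaming convolution would silently drop trailing samples and the decoder could no longer reproduce the original length. The following standalone sketch (illustrative names only, not part of the repository) reproduces the arithmetic of `get_extra_padding_for_conv1d` so it can be checked by hand.

import math

def extra_padding_for_conv1d(length: int, kernel_size: int, stride: int,
                             padding_total: int = 0) -> int:
    # Fractional number of frames the convolution would produce on `length` samples.
    n_frames = (length - kernel_size + padding_total) / stride + 1
    # Smallest length whose last frame ends exactly at the end of the signal.
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length

# kernel_size=4, stride=2 => padding_total = kernel_size - stride = 2, as in SConv1d.
for length in (5, 6, 7, 8):
    extra = extra_padding_for_conv1d(length, kernel_size=4, stride=2, padding_total=2)
    # With the extra samples appended, no input sample is left outside a window.
    assert (length + 2 + extra - 4) % 2 == 0
    print(f"length={length} -> extra right padding={extra}")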
dac/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import layers
2
+ from . import loss
3
+ from . import quantize
dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
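Snake1d above applies the snake activation x + (1/alpha) * sin^2(alpha * x) with a learnable per-channel alpha. A quick illustrative check (hypothetical usage, assuming the dac package and its dependencies are importable):

import torch
from dac.nn.layers import Snake1d

snake_layer = Snake1d(channels=4)      # alpha is initialised to ones
x = torch.randn(2, 4, 16)              # [B, C, T] layout used throughout dac
y = snake_layer(x)

assert y.shape == x.shape              # elementwise and shape-preserving
# With alpha = 1 the activation is x + sin(x)^2, so every value is shifted
# upward by at most 1 and never downward:
assert torch.all(y >= x) and torch.all(y <= x + 1.0)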
dac/nn/loss.py ADDED
@@ -0,0 +1,368 @@
1
+ import typing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from audiotools import AudioSignal
7
+ from audiotools import STFTParams
8
+ from torch import nn
9
+
10
+
11
+ class L1Loss(nn.L1Loss):
12
+ """L1 Loss between AudioSignals. Defaults
13
+ to comparing ``audio_data``, but any
14
+ attribute of an AudioSignal can be used.
15
+
16
+ Parameters
17
+ ----------
18
+ attribute : str, optional
19
+ Attribute of signal to compare, defaults to ``audio_data``.
20
+ weight : float, optional
21
+ Weight of this loss, defaults to 1.0.
22
+
23
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24
+ """
25
+
26
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27
+ self.attribute = attribute
28
+ self.weight = weight
29
+ super().__init__(**kwargs)
30
+
31
+ def forward(self, x: AudioSignal, y: AudioSignal):
32
+ """
33
+ Parameters
34
+ ----------
35
+ x : AudioSignal
36
+ Estimate AudioSignal
37
+ y : AudioSignal
38
+ Reference AudioSignal
39
+
40
+ Returns
41
+ -------
42
+ torch.Tensor
43
+ L1 loss between AudioSignal attributes.
44
+ """
45
+ if isinstance(x, AudioSignal):
46
+ x = getattr(x, self.attribute)
47
+ y = getattr(y, self.attribute)
48
+ return super().forward(x, y)
49
+
50
+
51
+ class SISDRLoss(nn.Module):
52
+ """
53
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54
+ of estimated and reference audio signals or aligned features.
55
+
56
+ Parameters
57
+ ----------
58
+ scaling : int, optional
59
+ Whether to use scale-invariant (True) or
60
+ signal-to-noise ratio (False), by default True
61
+ reduction : str, optional
62
+ How to reduce across the batch (either 'mean',
63
+ 'sum', or none).], by default ' mean'
64
+ zero_mean : int, optional
65
+ Zero mean the references and estimates before
66
+ computing the loss, by default True
67
+ clip_min : int, optional
68
+ The minimum possible loss value. Helps network
69
+ to not focus on making already good examples better, by default None
70
+ weight : float, optional
71
+ Weight of this loss, defaults to 1.0.
72
+
73
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ scaling: int = True,
79
+ reduction: str = "mean",
80
+ zero_mean: int = True,
81
+ clip_min: int = None,
82
+ weight: float = 1.0,
83
+ ):
84
+ self.scaling = scaling
85
+ self.reduction = reduction
86
+ self.zero_mean = zero_mean
87
+ self.clip_min = clip_min
88
+ self.weight = weight
89
+ super().__init__()
90
+
91
+ def forward(self, x: AudioSignal, y: AudioSignal):
92
+ eps = 1e-8
93
+ # nb, nc, nt
94
+ if isinstance(x, AudioSignal):
95
+ references = x.audio_data
96
+ estimates = y.audio_data
97
+ else:
98
+ references = x
99
+ estimates = y
100
+
101
+ nb = references.shape[0]
102
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104
+
105
+ # samples now on axis 1
106
+ if self.zero_mean:
107
+ mean_reference = references.mean(dim=1, keepdim=True)
108
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
109
+ else:
110
+ mean_reference = 0
111
+ mean_estimate = 0
112
+
113
+ _references = references - mean_reference
114
+ _estimates = estimates - mean_estimate
115
+
116
+ references_projection = (_references**2).sum(dim=-2) + eps
117
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118
+
119
+ scale = (
120
+ (references_on_estimates / references_projection).unsqueeze(1)
121
+ if self.scaling
122
+ else 1
123
+ )
124
+
125
+ e_true = scale * _references
126
+ e_res = _estimates - e_true
127
+
128
+ signal = (e_true**2).sum(dim=1)
129
+ noise = (e_res**2).sum(dim=1)
130
+ sdr = -10 * torch.log10(signal / noise + eps)
131
+
132
+ if self.clip_min is not None:
133
+ sdr = torch.clamp(sdr, min=self.clip_min)
134
+
135
+ if self.reduction == "mean":
136
+ sdr = sdr.mean()
137
+ elif self.reduction == "sum":
138
+ sdr = sdr.sum()
139
+ return sdr
140
+
141
+
142
+ class MultiScaleSTFTLoss(nn.Module):
143
+ """Computes the multi-scale STFT loss from [1].
144
+
145
+ Parameters
146
+ ----------
147
+ window_lengths : List[int], optional
148
+ Length of each window of each STFT, by default [2048, 512]
149
+ loss_fn : typing.Callable, optional
150
+ How to compare each loss, by default nn.L1Loss()
151
+ clamp_eps : float, optional
152
+ Clamp on the log magnitude, below, by default 1e-5
153
+ mag_weight : float, optional
154
+ Weight of raw magnitude portion of loss, by default 1.0
155
+ log_weight : float, optional
156
+ Weight of log magnitude portion of loss, by default 1.0
157
+ pow : float, optional
158
+ Power to raise magnitude to before taking log, by default 2.0
159
+ weight : float, optional
160
+ Weight of this loss, by default 1.0
161
+ match_stride : bool, optional
162
+ Whether to match the stride of convolutional layers, by default False
163
+
164
+ References
165
+ ----------
166
+
167
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168
+ "DDSP: Differentiable Digital Signal Processing."
169
+ International Conference on Learning Representations. 2019.
170
+
171
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ window_lengths: List[int] = [2048, 512],
177
+ loss_fn: typing.Callable = nn.L1Loss(),
178
+ clamp_eps: float = 1e-5,
179
+ mag_weight: float = 1.0,
180
+ log_weight: float = 1.0,
181
+ pow: float = 2.0,
182
+ weight: float = 1.0,
183
+ match_stride: bool = False,
184
+ window_type: str = None,
185
+ ):
186
+ super().__init__()
187
+ self.stft_params = [
188
+ STFTParams(
189
+ window_length=w,
190
+ hop_length=w // 4,
191
+ match_stride=match_stride,
192
+ window_type=window_type,
193
+ )
194
+ for w in window_lengths
195
+ ]
196
+ self.loss_fn = loss_fn
197
+ self.log_weight = log_weight
198
+ self.mag_weight = mag_weight
199
+ self.clamp_eps = clamp_eps
200
+ self.weight = weight
201
+ self.pow = pow
202
+
203
+ def forward(self, x: AudioSignal, y: AudioSignal):
204
+ """Computes multi-scale STFT between an estimate and a reference
205
+ signal.
206
+
207
+ Parameters
208
+ ----------
209
+ x : AudioSignal
210
+ Estimate signal
211
+ y : AudioSignal
212
+ Reference signal
213
+
214
+ Returns
215
+ -------
216
+ torch.Tensor
217
+ Multi-scale STFT loss.
218
+ """
219
+ loss = 0.0
220
+ for s in self.stft_params:
221
+ x.stft(s.window_length, s.hop_length, s.window_type)
222
+ y.stft(s.window_length, s.hop_length, s.window_type)
223
+ loss += self.log_weight * self.loss_fn(
224
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226
+ )
227
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228
+ return loss
229
+
230
+
231
+ class MelSpectrogramLoss(nn.Module):
232
+ """Compute distance between mel spectrograms. Can be used
233
+ in a multi-scale way.
234
+
235
+ Parameters
236
+ ----------
237
+ n_mels : List[int]
238
+ Number of mels per STFT, by default [150, 80],
239
+ window_lengths : List[int], optional
240
+ Length of each window of each STFT, by default [2048, 512]
241
+ loss_fn : typing.Callable, optional
242
+ How to compare each loss, by default nn.L1Loss()
243
+ clamp_eps : float, optional
244
+ Clamp on the log magnitude, below, by default 1e-5
245
+ mag_weight : float, optional
246
+ Weight of raw magnitude portion of loss, by default 1.0
247
+ log_weight : float, optional
248
+ Weight of log magnitude portion of loss, by default 1.0
249
+ pow : float, optional
250
+ Power to raise magnitude to before taking log, by default 2.0
251
+ weight : float, optional
252
+ Weight of this loss, by default 1.0
253
+ match_stride : bool, optional
254
+ Whether to match the stride of convolutional layers, by default False
255
+
256
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ n_mels: List[int] = [150, 80],
262
+ window_lengths: List[int] = [2048, 512],
263
+ loss_fn: typing.Callable = nn.L1Loss(),
264
+ clamp_eps: float = 1e-5,
265
+ mag_weight: float = 1.0,
266
+ log_weight: float = 1.0,
267
+ pow: float = 2.0,
268
+ weight: float = 1.0,
269
+ match_stride: bool = False,
270
+ mel_fmin: List[float] = [0.0, 0.0],
271
+ mel_fmax: List[float] = [None, None],
272
+ window_type: str = None,
273
+ ):
274
+ super().__init__()
275
+ self.stft_params = [
276
+ STFTParams(
277
+ window_length=w,
278
+ hop_length=w // 4,
279
+ match_stride=match_stride,
280
+ window_type=window_type,
281
+ )
282
+ for w in window_lengths
283
+ ]
284
+ self.n_mels = n_mels
285
+ self.loss_fn = loss_fn
286
+ self.clamp_eps = clamp_eps
287
+ self.log_weight = log_weight
288
+ self.mag_weight = mag_weight
289
+ self.weight = weight
290
+ self.mel_fmin = mel_fmin
291
+ self.mel_fmax = mel_fmax
292
+ self.pow = pow
293
+
294
+ def forward(self, x: AudioSignal, y: AudioSignal):
295
+ """Computes mel loss between an estimate and a reference
296
+ signal.
297
+
298
+ Parameters
299
+ ----------
300
+ x : AudioSignal
301
+ Estimate signal
302
+ y : AudioSignal
303
+ Reference signal
304
+
305
+ Returns
306
+ -------
307
+ torch.Tensor
308
+ Mel loss.
309
+ """
310
+ loss = 0.0
311
+ for n_mels, fmin, fmax, s in zip(
312
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313
+ ):
314
+ kwargs = {
315
+ "window_length": s.window_length,
316
+ "hop_length": s.hop_length,
317
+ "window_type": s.window_type,
318
+ }
319
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321
+
322
+ loss += self.log_weight * self.loss_fn(
323
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325
+ )
326
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327
+ return loss
328
+
329
+
330
+ class GANLoss(nn.Module):
331
+ """
332
+ Computes a discriminator loss, given a discriminator on
333
+ generated waveforms/spectrograms compared to ground truth
334
+ waveforms/spectrograms. Computes the loss for both the
335
+ discriminator and the generator in separate functions.
336
+ """
337
+
338
+ def __init__(self, discriminator):
339
+ super().__init__()
340
+ self.discriminator = discriminator
341
+
342
+ def forward(self, fake, real):
343
+ d_fake = self.discriminator(fake.audio_data)
344
+ d_real = self.discriminator(real.audio_data)
345
+ return d_fake, d_real
346
+
347
+ def discriminator_loss(self, fake, real):
348
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
349
+
350
+ loss_d = 0
351
+ for x_fake, x_real in zip(d_fake, d_real):
352
+ loss_d += torch.mean(x_fake[-1] ** 2)
353
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
354
+ return loss_d
355
+
356
+ def generator_loss(self, fake, real):
357
+ d_fake, d_real = self.forward(fake, real)
358
+
359
+ loss_g = 0
360
+ for x_fake in d_fake:
361
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362
+
363
+ loss_feature = 0
364
+
365
+ for i in range(len(d_fake)):
366
+ for j in range(len(d_fake[i]) - 1):
367
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368
+ return loss_g, loss_feature
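A hedged usage sketch of the reconstruction losses defined above (it assumes the descript-audiotools package is installed and that this file is importable as dac.nn.loss; the equal weighting is a placeholder, not the Seed-VC training configuration):

import torch
from audiotools import AudioSignal
from dac.nn.loss import L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss

sr = 44100
estimate = AudioSignal(torch.randn(1, 1, sr), sr)    # 1 s of noise as a stand-in
reference = AudioSignal(torch.randn(1, 1, sr), sr)

waveform_l1 = L1Loss()
stft_loss = MultiScaleSTFTLoss()
mel_loss = MelSpectrogramLoss()

# Codec-style objectives typically sum a waveform term and spectral terms.
total = (waveform_l1(estimate, reference)
         + stft_loss(estimate, reference)
         + mel_loss(estimate, reference))
print(float(total))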
dac/nn/quantize.py ADDED
@@ -0,0 +1,339 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from dac.nn.layers import WNConv1d
11
+
12
+ class VectorQuantizeLegacy(nn.Module):
13
+ """
14
+ Implementation of VQ similar to Karpathy's repo:
15
+ https://github.com/karpathy/deep-vector-quantization
16
+ removed in-out projection
17
+ """
18
+
19
+ def __init__(self, input_dim: int, codebook_size: int):
20
+ super().__init__()
21
+ self.codebook_size = codebook_size
22
+ self.codebook = nn.Embedding(codebook_size, input_dim)
23
+
24
+ def forward(self, z, z_mask=None):
25
+ """Quantized the input tensor using a fixed codebook and returns
26
+ the corresponding codebook vectors
27
+
28
+ Parameters
29
+ ----------
30
+ z : Tensor[B x D x T]
31
+
32
+ Returns
33
+ -------
34
+ Tensor[B x D x T]
35
+ Quantized continuous representation of input
36
+ Tensor[1]
37
+ Commitment loss to train encoder to predict vectors closer to codebook
38
+ entries
39
+ Tensor[1]
40
+ Codebook loss to update the codebook
41
+ Tensor[B x T]
42
+ Codebook indices (quantized discrete representation of input)
43
+ Tensor[B x D x T]
44
+ Projected latents (continuous representation of input before quantization)
45
+ """
46
+
47
+ z_e = z
48
+ z_q, indices = self.decode_latents(z)
49
+
50
+ if z_mask is not None:
51
+ commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
52
+ codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
53
+ else:
54
+ commitment_loss = F.mse_loss(z_e, z_q.detach())
55
+ codebook_loss = F.mse_loss(z_q, z_e.detach())
56
+ z_q = (
57
+ z_e + (z_q - z_e).detach()
58
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
59
+
60
+ return z_q, indices, z_e, commitment_loss, codebook_loss
61
+
62
+ def embed_code(self, embed_id):
63
+ return F.embedding(embed_id, self.codebook.weight)
64
+
65
+ def decode_code(self, embed_id):
66
+ return self.embed_code(embed_id).transpose(1, 2)
67
+
68
+ def decode_latents(self, latents):
69
+ encodings = rearrange(latents, "b d t -> (b t) d")
70
+ codebook = self.codebook.weight # codebook: (N x D)
71
+
72
+ # L2 normalize encodings and codebook (ViT-VQGAN)
73
+ encodings = F.normalize(encodings)
74
+ codebook = F.normalize(codebook)
75
+
76
+ # Compute euclidean distance with codebook
77
+ dist = (
78
+ encodings.pow(2).sum(1, keepdim=True)
79
+ - 2 * encodings @ codebook.t()
80
+ + codebook.pow(2).sum(1, keepdim=True).t()
81
+ )
82
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
83
+ z_q = self.decode_code(indices)
84
+ return z_q, indices
85
+
86
+ class VectorQuantize(nn.Module):
87
+ """
88
+ Implementation of VQ similar to Karpathy's repo:
89
+ https://github.com/karpathy/deep-vector-quantization
90
+ Additionally uses following tricks from Improved VQGAN
91
+ (https://arxiv.org/pdf/2110.04627.pdf):
92
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
93
+ for improved codebook usage
94
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
95
+ improves training stability
96
+ """
97
+
98
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
99
+ super().__init__()
100
+ self.codebook_size = codebook_size
101
+ self.codebook_dim = codebook_dim
102
+
103
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
104
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
105
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
106
+
107
+ def forward(self, z, z_mask=None):
108
+ """Quantized the input tensor using a fixed codebook and returns
109
+ the corresponding codebook vectors
110
+
111
+ Parameters
112
+ ----------
113
+ z : Tensor[B x D x T]
114
+
115
+ Returns
116
+ -------
117
+ Tensor[B x D x T]
118
+ Quantized continuous representation of input
119
+ Tensor[1]
120
+ Commitment loss to train encoder to predict vectors closer to codebook
121
+ entries
122
+ Tensor[1]
123
+ Codebook loss to update the codebook
124
+ Tensor[B x T]
125
+ Codebook indices (quantized discrete representation of input)
126
+ Tensor[B x D x T]
127
+ Projected latents (continuous representation of input before quantization)
128
+ """
129
+
130
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
131
+ z_e = self.in_proj(z) # z_e : (B x D x T)
132
+ z_q, indices = self.decode_latents(z_e)
133
+
134
+ if z_mask is not None:
135
+ commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
136
+ codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
137
+ else:
138
+ commitment_loss = F.mse_loss(z_e, z_q.detach())
139
+ codebook_loss = F.mse_loss(z_q, z_e.detach())
140
+
141
+ z_q = (
142
+ z_e + (z_q - z_e).detach()
143
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
144
+
145
+ z_q = self.out_proj(z_q)
146
+
147
+ return z_q, commitment_loss, codebook_loss, indices, z_e
148
+
149
+ def embed_code(self, embed_id):
150
+ return F.embedding(embed_id, self.codebook.weight)
151
+
152
+ def decode_code(self, embed_id):
153
+ return self.embed_code(embed_id).transpose(1, 2)
154
+
155
+ def decode_latents(self, latents):
156
+ encodings = rearrange(latents, "b d t -> (b t) d")
157
+ codebook = self.codebook.weight # codebook: (N x D)
158
+
159
+ # L2 normalize encodings and codebook (ViT-VQGAN)
160
+ encodings = F.normalize(encodings)
161
+ codebook = F.normalize(codebook)
162
+
163
+ # Compute euclidean distance with codebook
164
+ dist = (
165
+ encodings.pow(2).sum(1, keepdim=True)
166
+ - 2 * encodings @ codebook.t()
167
+ + codebook.pow(2).sum(1, keepdim=True).t()
168
+ )
169
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
170
+ z_q = self.decode_code(indices)
171
+ return z_q, indices
172
+
173
+
174
+ class ResidualVectorQuantize(nn.Module):
175
+ """
176
+ Introduced in SoundStream: An end2end neural audio codec
177
+ https://arxiv.org/abs/2107.03312
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ input_dim: int = 512,
183
+ n_codebooks: int = 9,
184
+ codebook_size: int = 1024,
185
+ codebook_dim: Union[int, list] = 8,
186
+ quantizer_dropout: float = 0.0,
187
+ ):
188
+ super().__init__()
189
+ if isinstance(codebook_dim, int):
190
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
191
+
192
+ self.n_codebooks = n_codebooks
193
+ self.codebook_dim = codebook_dim
194
+ self.codebook_size = codebook_size
195
+
196
+ self.quantizers = nn.ModuleList(
197
+ [
198
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
199
+ for i in range(n_codebooks)
200
+ ]
201
+ )
202
+ self.quantizer_dropout = quantizer_dropout
203
+
204
+ def forward(self, z, n_quantizers: int = None):
205
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
206
+ the corresponding codebook vectors
207
+ Parameters
208
+ ----------
209
+ z : Tensor[B x D x T]
210
+ n_quantizers : int, optional
211
+ No. of quantizers to use
212
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
213
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
214
+ when in training mode, and a random number of quantizers is used.
215
+ Returns
216
+ -------
217
+ dict
218
+ A dictionary with the following keys:
219
+
220
+ "z" : Tensor[B x D x T]
221
+ Quantized continuous representation of input
222
+ "codes" : Tensor[B x N x T]
223
+ Codebook indices for each codebook
224
+ (quantized discrete representation of input)
225
+ "latents" : Tensor[B x N*D x T]
226
+ Projected latents (continuous representation of input before quantization)
227
+ "vq/commitment_loss" : Tensor[1]
228
+ Commitment loss to train encoder to predict vectors closer to codebook
229
+ entries
230
+ "vq/codebook_loss" : Tensor[1]
231
+ Codebook loss to update the codebook
232
+ """
233
+ z_q = 0
234
+ residual = z
235
+ commitment_loss = 0
236
+ codebook_loss = 0
237
+
238
+ codebook_indices = []
239
+ latents = []
240
+
241
+ if n_quantizers is None:
242
+ n_quantizers = self.n_codebooks
243
+ if self.training:
244
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
245
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
246
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
247
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
248
+ n_quantizers = n_quantizers.to(z.device)
249
+
250
+ for i, quantizer in enumerate(self.quantizers):
251
+ if self.training is False and i >= n_quantizers:
252
+ break
253
+
254
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
255
+ residual
256
+ )
257
+
258
+ # Create mask to apply quantizer dropout
259
+ mask = (
260
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
261
+ )
262
+ z_q = z_q + z_q_i * mask[:, None, None]
263
+ residual = residual - z_q_i
264
+
265
+ # Sum losses
266
+ commitment_loss += (commitment_loss_i * mask).mean()
267
+ codebook_loss += (codebook_loss_i * mask).mean()
268
+
269
+ codebook_indices.append(indices_i)
270
+ latents.append(z_e_i)
271
+
272
+ codes = torch.stack(codebook_indices, dim=1)
273
+ latents = torch.cat(latents, dim=1)
274
+
275
+ return z_q, codes, latents, commitment_loss, codebook_loss
276
+
277
+ def from_codes(self, codes: torch.Tensor):
278
+ """Given the quantized codes, reconstruct the continuous representation
279
+ Parameters
280
+ ----------
281
+ codes : Tensor[B x N x T]
282
+ Quantized discrete representation of input
283
+ Returns
284
+ -------
285
+ Tensor[B x D x T]
286
+ Quantized continuous representation of input
287
+ """
288
+ z_q = 0.0
289
+ z_p = []
290
+ n_codebooks = codes.shape[1]
291
+ for i in range(n_codebooks):
292
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
293
+ z_p.append(z_p_i)
294
+
295
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
296
+ z_q = z_q + z_q_i
297
+ return z_q, torch.cat(z_p, dim=1), codes
298
+
299
+ def from_latents(self, latents: torch.Tensor):
300
+ """Given the unquantized latents, reconstruct the
301
+ continuous representation after quantization.
302
+
303
+ Parameters
304
+ ----------
305
+ latents : Tensor[B x N x T]
306
+ Continuous representation of input after projection
307
+
308
+ Returns
309
+ -------
310
+ Tensor[B x D x T]
311
+ Quantized representation of full-projected space
312
+ Tensor[B x D x T]
313
+ Quantized representation of latent space
314
+ """
315
+ z_q = 0
316
+ z_p = []
317
+ codes = []
318
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
319
+
320
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
321
+ 0
322
+ ]
323
+ for i in range(n_codebooks):
324
+ j, k = dims[i], dims[i + 1]
325
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
326
+ z_p.append(z_p_i)
327
+ codes.append(codes_i)
328
+
329
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
330
+ z_q = z_q + z_q_i
331
+
332
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
333
+
334
+
335
+ if __name__ == "__main__":
336
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
337
+ x = torch.randn(16, 512, 80)
338
+ z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
339
+ print(latents.shape)
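A hedged sketch of the residual quantization round trip (assuming this file is importable as dac.nn.quantize): each stage quantizes the residual left by the previous stage, and `from_codes` rebuilds the same continuous representation from the discrete indices alone.

import torch
from dac.nn.quantize import ResidualVectorQuantize

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9,
                             codebook_size=1024, codebook_dim=8)
rvq.eval()

x = torch.randn(2, 512, 80)                                  # [B x D x T]
with torch.no_grad():
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(z_q.shape, codes.shape, latents.shape)             # (2,512,80) (2,9,80) (2,72,80)

    # The discrete codes alone are enough to rebuild the quantized latents.
    z_q_from_codes, _, _ = rvq.from_codes(codes)
    assert torch.allclose(z_q, z_q_from_codes, atol=1e-5)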
dac/utils/__init__.py ADDED
@@ -0,0 +1,123 @@
1
+ from pathlib import Path
2
+
3
+ import argbind
4
+ from audiotools import ml
5
+
6
+ import dac
7
+
8
+ DAC = dac.model.DAC
9
+ Accelerator = ml.Accelerator
10
+
11
+ __MODEL_LATEST_TAGS__ = {
12
+ ("44khz", "8kbps"): "0.0.1",
13
+ ("24khz", "8kbps"): "0.0.4",
14
+ ("16khz", "8kbps"): "0.0.5",
15
+ ("44khz", "16kbps"): "1.0.0",
16
+ }
17
+
18
+ __MODEL_URLS__ = {
19
+ (
20
+ "44khz",
21
+ "0.0.1",
22
+ "8kbps",
23
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
24
+ (
25
+ "24khz",
26
+ "0.0.4",
27
+ "8kbps",
28
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
29
+ (
30
+ "16khz",
31
+ "0.0.5",
32
+ "8kbps",
33
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
34
+ (
35
+ "44khz",
36
+ "1.0.0",
37
+ "16kbps",
38
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
39
+ }
40
+
41
+
42
+ @argbind.bind(group="download", positional=True, without_prefix=True)
43
+ def download(
44
+ model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
45
+ ):
46
+ """
47
+ Function that downloads the weights file from URL if a local cache is not found.
48
+
49
+ Parameters
50
+ ----------
51
+ model_type : str
52
+ The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
53
+ model_bitrate: str
54
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
55
+ Only 44khz model supports 16kbps.
56
+ tag : str
57
+ The tag of the model to download. Defaults to "latest".
58
+
59
+ Returns
60
+ -------
61
+ Path
62
+ Directory path required to load model via audiotools.
63
+ """
64
+ model_type = model_type.lower()
65
+ tag = tag.lower()
66
+
67
+ assert model_type in [
68
+ "44khz",
69
+ "24khz",
70
+ "16khz",
71
+ ], "model_type must be one of '44khz', '24khz', or '16khz'"
72
+
73
+ assert model_bitrate in [
74
+ "8kbps",
75
+ "16kbps",
76
+ ], "model_bitrate must be one of '8kbps', or '16kbps'"
77
+
78
+ if tag == "latest":
79
+ tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
80
+
81
+ download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
82
+
83
+ if download_link is None:
84
+ raise ValueError(
85
+ f"Could not find model with tag {tag} and model type {model_type}"
86
+ )
87
+
88
+ local_path = (
89
+ Path.home()
90
+ / ".cache"
91
+ / "descript"
92
+ / "dac"
93
+ / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
94
+ )
95
+ if not local_path.exists():
96
+ local_path.parent.mkdir(parents=True, exist_ok=True)
97
+
98
+ # Download the model
99
+ import requests
100
+
101
+ response = requests.get(download_link)
102
+
103
+ if response.status_code != 200:
104
+ raise ValueError(
105
+ f"Could not download model. Received response code {response.status_code}"
106
+ )
107
+ local_path.write_bytes(response.content)
108
+
109
+ return local_path
110
+
111
+
112
+ def load_model(
113
+ model_type: str = "44khz",
114
+ model_bitrate: str = "8kbps",
115
+ tag: str = "latest",
116
+ load_path: str = None,
117
+ ):
118
+ if not load_path:
119
+ load_path = download(
120
+ model_type=model_type, model_bitrate=model_bitrate, tag=tag
121
+ )
122
+ generator = DAC.load(load_path)
123
+ return generator
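A hedged usage sketch of the helpers above (requires network access on first use; the cache lands under ~/.cache/descript/dac/ as shown in `download`):

from dac.utils import download, load_model

# Resolves "latest" to a concrete tag and downloads the weights once;
# subsequent calls reuse ~/.cache/descript/dac/weights_44khz_8kbps_<tag>.pth.
weights_path = download(model_type="44khz", model_bitrate="8kbps", tag="latest")
print(weights_path)

# load_model wraps the same lookup and returns the DAC generator directly.
generator = load_model(model_type="44khz", model_bitrate="8kbps")
generator.eval()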
dac/utils/decode.py ADDED
@@ -0,0 +1,95 @@
1
+ import warnings
2
+ from pathlib import Path
3
+
4
+ import argbind
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from tqdm import tqdm
9
+
10
+ from dac import DACFile
11
+ from dac.utils import load_model
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+
16
+ @argbind.bind(group="decode", positional=True, without_prefix=True)
17
+ @torch.inference_mode()
18
+ @torch.no_grad()
19
+ def decode(
20
+ input: str,
21
+ output: str = "",
22
+ weights_path: str = "",
23
+ model_tag: str = "latest",
24
+ model_bitrate: str = "8kbps",
25
+ device: str = "cuda",
26
+ model_type: str = "44khz",
27
+ verbose: bool = False,
28
+ ):
29
+ """Decode audio from codes.
30
+
31
+ Parameters
32
+ ----------
33
+ input : str
34
+ Path to input directory or file
35
+ output : str, optional
36
+ Path to output directory, by default "".
37
+ If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38
+ weights_path : str, optional
39
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40
+ model_tag and model_type.
41
+ model_tag : str, optional
42
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43
+ model_bitrate: str
44
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45
+ device : str, optional
46
+ Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47
+ model_type : str, optional
48
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49
+ """
50
+ generator = load_model(
51
+ model_type=model_type,
52
+ model_bitrate=model_bitrate,
53
+ tag=model_tag,
54
+ load_path=weights_path,
55
+ )
56
+ generator.to(device)
57
+ generator.eval()
58
+
59
+ # Find all .dac files in input directory
60
+ _input = Path(input)
61
+ input_files = list(_input.glob("**/*.dac"))
62
+
63
+ # If input is a .dac file, add it to the list
64
+ if _input.suffix == ".dac":
65
+ input_files.append(_input)
66
+
67
+ # Create output directory
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(input_files)), desc="Decoding files"):
72
+ # Load file
73
+ artifact = DACFile.load(input_files[i])
74
+
75
+ # Reconstruct audio from codes
76
+ recons = generator.decompress(artifact, verbose=verbose)
77
+
78
+ # Compute output path
79
+ relative_path = input_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = input_files[i]
84
+ output_name = relative_path.with_suffix(".wav").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ # Write to file
89
+ recons.write(output_path)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ args = argbind.parse_args()
94
+ with argbind.scope(args):
95
+ decode()
dac/utils/encode.py ADDED
@@ -0,0 +1,94 @@
1
+ import math
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import numpy as np
7
+ import torch
8
+ from audiotools import AudioSignal
9
+ from audiotools.core import util
10
+ from tqdm import tqdm
11
+
12
+ from dac.utils import load_model
13
+
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+
16
+
17
+ @argbind.bind(group="encode", positional=True, without_prefix=True)
18
+ @torch.inference_mode()
19
+ @torch.no_grad()
20
+ def encode(
21
+ input: str,
22
+ output: str = "",
23
+ weights_path: str = "",
24
+ model_tag: str = "latest",
25
+ model_bitrate: str = "8kbps",
26
+ n_quantizers: int = None,
27
+ device: str = "cuda",
28
+ model_type: str = "44khz",
29
+ win_duration: float = 5.0,
30
+ verbose: bool = False,
31
+ ):
32
+ """Encode audio files in input path to .dac format.
33
+
34
+ Parameters
35
+ ----------
36
+ input : str
37
+ Path to input audio file or directory
38
+ output : str, optional
39
+ Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40
+ weights_path : str, optional
41
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42
+ model_tag and model_type.
43
+ model_tag : str, optional
44
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45
+ model_bitrate: str
46
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47
+ n_quantizers : int, optional
48
+ Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49
+ device : str, optional
50
+ Device to use, by default "cuda"
51
+ model_type : str, optional
52
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53
+ """
54
+ generator = load_model(
55
+ model_type=model_type,
56
+ model_bitrate=model_bitrate,
57
+ tag=model_tag,
58
+ load_path=weights_path,
59
+ )
60
+ generator.to(device)
61
+ generator.eval()
62
+ kwargs = {"n_quantizers": n_quantizers}
63
+
64
+ # Find all audio files in input path
65
+ input = Path(input)
66
+ audio_files = util.find_audio(input)
67
+
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72
+ # Load file
73
+ signal = AudioSignal(audio_files[i])
74
+
75
+ # Encode audio to .dac format
76
+ artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77
+
78
+ # Compute output path
79
+ relative_path = audio_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = audio_files[i]
84
+ output_name = relative_path.with_suffix(".dac").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ artifact.save(output_path)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ args = argbind.parse_args()
93
+ with argbind.scope(args):
94
+ encode()
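Both scripts are argbind CLIs, but the underlying functions can also be called directly from Python; a hedged sketch (paths are illustrative, and model weights are downloaded on first use):

from dac.utils.encode import encode
from dac.utils.decode import decode

# Compress every audio file found under ./wavs into .dac artifacts ...
encode("wavs", output="codes", model_type="44khz", device="cpu", win_duration=5.0)

# ... then reconstruct the artifacts back into .wav files.
decode("codes", output="reconstructed", model_type="44khz", device="cpu")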
examples/reference/dingzhen_0.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3db260824d11f56cdf2fccf2b84ad83c95a732ddfa2f8cb8a20b68ca06ea9ff8
3
+ size 1088420
examples/source/yae_0.wav ADDED
Binary file (528 kB).
modules/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ from .filter import *
4
+ from .resample import *
5
+ from .act import *
modules/alias_free_torch/act.py ADDED
@@ -0,0 +1,29 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from .resample import UpSample1d, DownSample1d
5
+
6
+
7
+ class Activation1d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12,
15
+ ):
16
+ super().__init__()
17
+ self.up_ratio = up_ratio
18
+ self.down_ratio = down_ratio
19
+ self.act = activation
20
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
+
23
+ # x: [B,C,T]
24
+ def forward(self, x):
25
+ x = self.upsample(x)
26
+ x = self.act(x)
27
+ x = self.downsample(x)
28
+
29
+ return x
modules/alias_free_torch/filter.py ADDED
@@ -0,0 +1,96 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ if "sinc" in dir(torch):
9
+ sinc = torch.sinc
10
+ else:
11
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
12
+ # https://adefossez.github.io/julius/julius/core.html
13
+ def sinc(x: torch.Tensor):
14
+ """
15
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
16
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
17
+ """
18
+ return torch.where(
19
+ x == 0,
20
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
21
+ torch.sin(math.pi * x) / math.pi / x,
22
+ )
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ def kaiser_sinc_filter1d(
28
+ cutoff, half_width, kernel_size
29
+ ): # return filter [1,1,kernel_size]
30
+ even = kernel_size % 2 == 0
31
+ half_size = kernel_size // 2
32
+
33
+ # For kaiser window
34
+ delta_f = 4 * half_width
35
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36
+ if A > 50.0:
37
+ beta = 0.1102 * (A - 8.7)
38
+ elif A >= 21.0:
39
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
40
+ else:
41
+ beta = 0.0
42
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43
+
44
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45
+ if even:
46
+ time = torch.arange(-half_size, half_size) + 0.5
47
+ else:
48
+ time = torch.arange(kernel_size) - half_size
49
+ if cutoff == 0:
50
+ filter_ = torch.zeros_like(time)
51
+ else:
52
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
54
+ # of the constant component in the input signal.
55
+ filter_ /= filter_.sum()
56
+ filter = filter_.view(1, 1, kernel_size)
57
+
58
+ return filter
59
+
60
+
61
+ class LowPassFilter1d(nn.Module):
62
+ def __init__(
63
+ self,
64
+ cutoff=0.5,
65
+ half_width=0.6,
66
+ stride: int = 1,
67
+ padding: bool = True,
68
+ padding_mode: str = "replicate",
69
+ kernel_size: int = 12,
70
+ ):
71
+ # kernel_size should be an even number for the StyleGAN3 setup;
72
+ # in this implementation, an odd number is also possible.
73
+ super().__init__()
74
+ if cutoff < -0.0:
75
+ raise ValueError("Minimum cutoff must be larger than zero.")
76
+ if cutoff > 0.5:
77
+ raise ValueError("A cutoff above 0.5 does not make sense.")
78
+ self.kernel_size = kernel_size
79
+ self.even = kernel_size % 2 == 0
80
+ self.pad_left = kernel_size // 2 - int(self.even)
81
+ self.pad_right = kernel_size // 2
82
+ self.stride = stride
83
+ self.padding = padding
84
+ self.padding_mode = padding_mode
85
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
86
+ self.register_buffer("filter", filter)
87
+
88
+ # input [B, C, T]
89
+ def forward(self, x):
90
+ _, C, _ = x.shape
91
+
92
+ if self.padding:
93
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
94
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
95
+
96
+ return out
modules/alias_free_torch/resample.py ADDED
@@ -0,0 +1,57 @@
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from .filter import LowPassFilter1d
6
+ from .filter import kaiser_sinc_filter1d
7
+
8
+
9
+ class UpSample1d(nn.Module):
10
+ def __init__(self, ratio=2, kernel_size=None):
11
+ super().__init__()
12
+ self.ratio = ratio
13
+ self.kernel_size = (
14
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ )
16
+ self.stride = ratio
17
+ self.pad = self.kernel_size // ratio - 1
18
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
19
+ self.pad_right = (
20
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
21
+ )
22
+ filter = kaiser_sinc_filter1d(
23
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
24
+ )
25
+ self.register_buffer("filter", filter)
26
+
27
+ # x: [B, C, T]
28
+ def forward(self, x):
29
+ _, C, _ = x.shape
30
+
31
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
32
+ x = self.ratio * F.conv_transpose1d(
33
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
34
+ )
35
+ x = x[..., self.pad_left : -self.pad_right]
36
+
37
+ return x
38
+
39
+
40
+ class DownSample1d(nn.Module):
41
+ def __init__(self, ratio=2, kernel_size=None):
42
+ super().__init__()
43
+ self.ratio = ratio
44
+ self.kernel_size = (
45
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
46
+ )
47
+ self.lowpass = LowPassFilter1d(
48
+ cutoff=0.5 / ratio,
49
+ half_width=0.6 / ratio,
50
+ stride=ratio,
51
+ kernel_size=self.kernel_size,
52
+ )
53
+
54
+ def forward(self, x):
55
+ xx = self.lowpass(x)
56
+
57
+ return xx
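A hedged round-trip sketch for the anti-aliased resamplers above (again assuming the repository root is on PYTHONPATH): upsampling by 2 and then downsampling by 2 approximately recovers a band-limited input, apart from edge effects from the replicate padding.

import math
import torch
from modules.alias_free_torch.resample import UpSample1d, DownSample1d

up = UpSample1d(ratio=2)
down = DownSample1d(ratio=2)

# A tone well below Nyquist so the low-pass filters barely touch it.
t = torch.arange(256, dtype=torch.float32)
x = torch.sin(2 * math.pi * 0.05 * t).view(1, 1, -1)     # [B, C, T]

y = up(x)                  # [1, 1, 512]
z = down(y)                # [1, 1, 256]
print(y.shape, z.shape)
print(float((x - z).abs()[..., 16:-16].max()))           # small away from the edges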
modules/commons.py CHANGED
@@ -384,12 +384,50 @@ def build_model(args, stage="DiT"):
384
  sampling_ratios=args.length_regulator.sampling_ratios,
385
  is_discrete=args.length_regulator.is_discrete,
386
  codebook_size=args.length_regulator.content_codebook_size,
 
 
 
 
 
 
387
  )
388
  cfm = CFM(args)
389
  nets = Munch(
390
  cfm=cfm,
391
  length_regulator=length_regulator,
392
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  else:
394
  raise ValueError(f"Unknown stage: {stage}")
395
 
 
384
  sampling_ratios=args.length_regulator.sampling_ratios,
385
  is_discrete=args.length_regulator.is_discrete,
386
  codebook_size=args.length_regulator.content_codebook_size,
387
+ token_dropout_prob=args.length_regulator.token_dropout_prob if hasattr(args.length_regulator, "token_dropout_prob") else 0.0,
388
+ token_dropout_range=args.length_regulator.token_dropout_range if hasattr(args.length_regulator, "token_dropout_range") else 0.0,
389
+ n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
390
+ quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
391
+ f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
392
+ n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
393
  )
394
  cfm = CFM(args)
395
  nets = Munch(
396
  cfm=cfm,
397
  length_regulator=length_regulator,
398
  )
399
+ elif stage == 'codec':
400
+ from dac.model.dac import Encoder
401
+ from modules.quantize import (
402
+ FAquantizer,
403
+ )
404
+
405
+ encoder = Encoder(
406
+ d_model=args.DAC.encoder_dim,
407
+ strides=args.DAC.encoder_rates,
408
+ d_latent=1024,
409
+ causal=args.causal,
410
+ lstm=args.lstm,
411
+ )
412
+
413
+ quantizer = FAquantizer(
414
+ in_dim=1024,
415
+ n_p_codebooks=1,
416
+ n_c_codebooks=args.n_c_codebooks,
417
+ n_t_codebooks=2,
418
+ n_r_codebooks=3,
419
+ codebook_size=1024,
420
+ codebook_dim=8,
421
+ quantizer_dropout=0.5,
422
+ causal=args.causal,
423
+ separate_prosody_encoder=args.separate_prosody_encoder,
424
+ timbre_norm=args.timbre_norm,
425
+ )
426
+
427
+ nets = Munch(
428
+ encoder=encoder,
429
+ quantizer=quantizer,
430
+ )
431
  else:
432
  raise ValueError(f"Unknown stage: {stage}")
433
 
modules/cosyvoice_tokenizer/frontend.py CHANGED
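The new 'codec' branch assembles a DAC encoder and an FAquantizer purely from fields on `args`; below is a hedged sketch of the fields it reads (the values are illustrative placeholders, not the released training configuration):

from munch import Munch

codec_args = Munch(
    causal=True,
    lstm=2,
    n_c_codebooks=2,
    separate_prosody_encoder=True,
    timbre_norm=True,
    DAC=Munch(
        encoder_dim=64,              # -> Encoder(d_model=...)
        encoder_rates=[2, 5, 5, 6],  # -> Encoder(strides=...)
    ),
)

# nets = build_model(codec_args, stage='codec')  # -> Munch(encoder=..., quantizer=...)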
@@ -1,52 +1,54 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
 import onnxruntime
 import torch
 import numpy as np
 import whisper
 import torchaudio.compliance.kaldi as kaldi
 
 class CosyVoiceFrontEnd:
 
     def __init__(self, speech_tokenizer_model: str, device: str = 'cuda', device_id: int = 0):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         option = onnxruntime.SessionOptions()
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
-        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CPUExecutionProvider"])
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if device == "cuda" else "CPUExecutionProvider"])
+        if device == 'cuda':
+            self.speech_tokenizer_session.set_providers(['CUDAExecutionProvider'], [{'device_id': device_id}])
 
     def extract_speech_token(self, speech):
         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
         speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
                                                                 self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
         return speech_token, speech_token_len
 
     def _extract_spk_embedding(self, speech):
         feat = kaldi.fbank(speech,
                            num_mel_bins=80,
                            dither=0,
                            sample_frequency=16000)
         feat = feat - feat.mean(dim=0, keepdim=True)
         embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
         embedding = torch.tensor([embedding]).to(self.device)
         return embedding
 
     def _extract_speech_feat(self, speech):
         speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
         speech_feat = speech_feat.unsqueeze(dim=0)
         speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
         return speech_feat, speech_feat_len
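
With this change the speech tokenizer session is created with CUDAExecutionProvider whenever device == 'cuda' and is pinned to the requested GPU via set_providers; otherwise it falls back to the CPU provider. A minimal usage sketch follows; the file paths, mono input, and 16 kHz resample are assumptions for illustration, not part of this commit.

    import torch
    import torchaudio
    from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd

    # Assumed local path to the exported ONNX tokenizer; adjust to wherever the model lives.
    frontend = CosyVoiceFrontEnd(
        speech_tokenizer_model="speech_tokenizer_v1.onnx",
        device="cuda" if torch.cuda.is_available() else "cpu",
        device_id=0,
    )

    wav, sr = torchaudio.load("reference.wav")             # assumed mono reference clip
    wav = torchaudio.functional.resample(wav, sr, 16000)   # whisper mel features expect 16 kHz audio
    tokens, token_len = frontend.extract_speech_token(wav)
    print(tokens.shape, token_len)                         # (1, T_tokens), tensor([T_tokens])
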
modules/diffusion_transformer.py CHANGED
@@ -106,7 +106,7 @@ class DiT(torch.nn.Module):
         self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
         self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
         model_args = ModelArgs(
-            block_size=args.DiT.block_size,
+            block_size=8192,  # args.DiT.block_size,
             n_layer=args.DiT.depth,
             n_head=args.DiT.num_heads,
             dim=args.DiT.hidden_dim,
@@ -139,7 +139,7 @@ class DiT(torch.nn.Module):
         # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
         # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
 
-        input_pos = torch.arange(args.DiT.block_size)
+        input_pos = torch.arange(8192)
         self.register_buffer("input_pos", input_pos)
 
         self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
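
Both the attention block size and the registered input_pos buffer are now fixed at 8192 positions instead of reading args.DiT.block_size, so sequences up to 8192 frames can be handled regardless of the value stored in the checkpoint's config. A generic sketch of the buffer pattern; the names here are illustrative, not the repo's API.

    import torch
    import torch.nn as nn

    MAX_SEQ_LEN = 8192  # the hard-coded block size used above

    class PositionalIndexExample(nn.Module):
        def __init__(self):
            super().__init__()
            # register_buffer keeps the index tensor on the module's device without making it a parameter
            self.register_buffer("input_pos", torch.arange(MAX_SEQ_LEN))

        def forward(self, x):
            # x: (B, T, D); slice the precomputed indices to the current sequence length
            return self.input_pos[: x.shape[1]]

    m = PositionalIndexExample()
    print(m(torch.randn(1, 100, 8)).shape)  # torch.Size([100])
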
modules/length_regulator.py CHANGED
@@ -1,4 +1,5 @@
 from typing import Tuple
+import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from modules.commons import sequence_mask
@@ -13,6 +14,12 @@ class InterpolateRegulator(nn.Module):
                  codebook_size: int = 1024,  # for discrete only
                  out_channels: int = None,
                  groups: int = 1,
+                 token_dropout_prob: float = 0.5,  # randomly drop out input tokens
+                 token_dropout_range: float = 0.5,  # randomly drop out input tokens
+                 n_codebooks: int = 1,  # number of codebooks
+                 quantizer_dropout: float = 0.0,  # dropout for quantizer
+                 f0_condition: bool = False,
+                 n_f0_bins: int = 512,
                  ):
         super().__init__()
         self.sampling_ratios = sampling_ratios
@@ -31,12 +38,59 @@
         self.embedding = nn.Embedding(codebook_size, channels)
         self.is_discrete = is_discrete
 
-    def forward(self, x, ylens=None):
+        self.mask_token = nn.Parameter(torch.zeros(1, channels))
+
+        self.n_codebooks = n_codebooks
+        if n_codebooks > 1:
+            self.extra_codebooks = nn.ModuleList([
+                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
+            ])
+        self.token_dropout_prob = token_dropout_prob
+        self.token_dropout_range = token_dropout_range
+        self.quantizer_dropout = quantizer_dropout
+
+        if f0_condition:
+            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
+            self.f0_condition = f0_condition
+            self.n_f0_bins = n_f0_bins
+            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
+            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
+        else:
+            self.f0_condition = False
+
+    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
+        # apply token drop
+        if self.training:
+            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
+            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
+            n_dropout = int(x.shape[0] * self.quantizer_dropout)
+            n_quantizers[:n_dropout] = dropout[:n_dropout]
+            n_quantizers = n_quantizers.to(x.device)
+            # decide whether to drop for each sample in batch
+        else:
+            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
         if self.is_discrete:
-            x = self.embedding(x)
+            if self.n_codebooks > 1:
+                assert len(x.size()) == 3
+                x_emb = self.embedding(x[:, 0])
+                for i, emb in enumerate(self.extra_codebooks):
+                    x_emb = x_emb + (n_quantizers > i + 1)[..., None, None] * emb(x[:, i + 1])
+                x = x_emb
+            elif self.n_codebooks == 1:
+                if len(x.size()) == 2:
+                    x = self.embedding(x)
+                else:
+                    x = self.embedding(x[:, 0])
         # x in (B, T, D)
         mask = sequence_mask(ylens).unsqueeze(-1)
         x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+        if self.f0_condition:
+            quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
+            drop_f0 = torch.rand(quantized_f0.size(0)).to(f0.device) < self.quantizer_dropout
+            f0_emb = self.f0_embedding(quantized_f0)
+            f0_emb[drop_f0] = self.f0_mask
+            f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+            x = x + f0_emb
         out = self.model(x).transpose(1, 2).contiguous()
         olens = ylens
         return out * mask, olens
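
The regulator's forward now accepts multi-codebook token inputs (with per-sample quantizer dropout during training) and optional frame-level F0 conditioning via torch.bucketize. A standalone sketch of just the F0 conditioning step; the bin count, channel size, target length, and random F0 contour are assumptions for illustration, not values from the repo's config.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    n_f0_bins, channels = 512, 192                       # assumed sizes
    f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)   # bin boundaries in Hz, as above
    f0_embedding = nn.Embedding(n_f0_bins, channels)

    f0 = torch.rand(2, 100) * 400                        # (B, T) fake F0 contour in Hz
    quantized_f0 = torch.bucketize(f0, f0_bins)          # (B, T) bin indices
    f0_emb = f0_embedding(quantized_f0)                  # (B, T, C)
    # stretch to the target mel length, as the regulator does with F.interpolate
    target_len = 250
    f0_emb = F.interpolate(f0_emb.transpose(1, 2), size=target_len, mode='nearest')
    print(f0_emb.shape)                                  # torch.Size([2, 192, 250])
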
modules/quantize.py ADDED
@@ -0,0 +1,229 @@
+from dac.nn.quantize import ResidualVectorQuantize
+from torch import nn
+from modules.wavenet import WN
+import torch
+import torchaudio
+import torchaudio.functional as audio_F
+import numpy as np
+from .alias_free_torch import *
+from torch.nn.utils import weight_norm
+from torch import nn, sin, pow
+from einops.layers.torch import Rearrange
+from dac.model.encodec import SConv1d
+
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        nn.init.constant_(m.bias, 0)
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snakebeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(
+        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+    ):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            beta is initialized to 1 by default, higher values = higher-magnitude.
+            alpha will be trained along with the rest of your model.
+        """
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
+            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
+            self.beta = nn.Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta := x + 1/b * sin^2 (xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
+
+class ResidualUnit(nn.Module):
+    def __init__(self, dim: int = 16, dilation: int = 1):
+        super().__init__()
+        pad = ((7 - 1) * dilation) // 2
+        self.block = nn.Sequential(
+            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
+            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return x + self.block(x)
+
+class CNNLSTM(nn.Module):
+    def __init__(self, indim, outdim, head, global_pred=False):
+        super().__init__()
+        self.global_pred = global_pred
+        self.model = nn.Sequential(
+            ResidualUnit(indim, dilation=1),
+            ResidualUnit(indim, dilation=2),
+            ResidualUnit(indim, dilation=3),
+            Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
+            Rearrange("b c t -> b t c"),
+        )
+        self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
+
+    def forward(self, x):
+        # x: [B, C, T]
+        x = self.model(x)
+        if self.global_pred:
+            x = torch.mean(x, dim=1, keepdim=False)
+        outs = [head(x) for head in self.heads]
+        return outs
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+class FAquantizer(nn.Module):
+    def __init__(self, in_dim=1024,
+                 n_p_codebooks=1,
+                 n_c_codebooks=2,
+                 n_t_codebooks=2,
+                 n_r_codebooks=3,
+                 codebook_size=1024,
+                 codebook_dim=8,
+                 quantizer_dropout=0.5,
+                 causal=False,
+                 separate_prosody_encoder=False,
+                 timbre_norm=False,):
+        super(FAquantizer, self).__init__()
+        conv1d_type = SConv1d  # if causal else nn.Conv1d
+        self.prosody_quantizer = ResidualVectorQuantize(
+            input_dim=in_dim,
+            n_codebooks=n_p_codebooks,
+            codebook_size=codebook_size,
+            codebook_dim=codebook_dim,
+            quantizer_dropout=quantizer_dropout,
+        )
+
+        self.content_quantizer = ResidualVectorQuantize(
+            input_dim=in_dim,
+            n_codebooks=n_c_codebooks,
+            codebook_size=codebook_size,
+            codebook_dim=codebook_dim,
+            quantizer_dropout=quantizer_dropout,
+        )
+
+        self.residual_quantizer = ResidualVectorQuantize(
+            input_dim=in_dim,
+            n_codebooks=n_r_codebooks,
+            codebook_size=codebook_size,
+            codebook_dim=codebook_dim,
+            quantizer_dropout=quantizer_dropout,
+        )
+
+        self.melspec_linear = conv1d_type(in_channels=20, out_channels=256, kernel_size=1, causal=causal)
+        self.melspec_encoder = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8, gin_channels=0, p_dropout=0.2, causal=causal)
+        self.melspec_linear2 = conv1d_type(in_channels=256, out_channels=1024, kernel_size=1, causal=causal)
+
+        self.prob_random_mask_residual = 0.75
+
+        SPECT_PARAMS = {
+            "n_fft": 2048,
+            "win_length": 1200,
+            "hop_length": 300,
+        }
+        MEL_PARAMS = {
+            "n_mels": 80,
+        }
+
+        self.to_mel = torchaudio.transforms.MelSpectrogram(
+            n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
+        )
+        self.mel_mean, self.mel_std = -4, 4
+        self.frame_rate = 24000 / 300
+        self.hop_length = 300
+
+    def preprocess(self, wave_tensor, n_bins=20):
+        mel_tensor = self.to_mel(wave_tensor.squeeze(1))
+        mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
+        return mel_tensor[:, :n_bins, :int(wave_tensor.size(-1) / self.hop_length)]
+
+    def forward(self, x, wave_segments):
+        outs = 0
+        prosody_feature = self.preprocess(wave_segments)
+
+        f0_input = prosody_feature  # (B, T, 20)
+        f0_input = self.melspec_linear(f0_input)
+        f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(
+            f0_input.device).bool())
+        f0_input = self.melspec_linear2(f0_input)
+
+        common_min_size = min(f0_input.size(2), x.size(2))
+        f0_input = f0_input[:, :, :common_min_size]
+
+        x = x[:, :, :common_min_size]
+
+        z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
+            f0_input, 1
+        )
+        outs += z_p.detach()
+
+        z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer(
+            x, 2
+        )
+        outs += z_c.detach()
+
+        residual_feature = x - z_p.detach() - z_c.detach()
+
+        z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer(
+            residual_feature, 3
+        )
+
+        quantized = [z_p, z_c, z_r]
+        codes = [codes_p, codes_c, codes_r]
+
+        return quantized, codes
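
A hypothetical shape walkthrough for FAquantizer.forward, with sizes inferred from the code above rather than documented by the commit; it assumes the descript-audio-codec package and the repo's wavenet and alias_free_torch modules are importable.

    import torch
    from modules.quantize import FAquantizer

    quantizer = FAquantizer()            # defaults: 1 prosody, 2 content, 3 residual codebooks
    x = torch.randn(2, 1024, 160)        # (B, in_dim, T) content features
    wave = torch.randn(2, 1, 160 * 300)  # matching 24 kHz waveform chunk (hop_length = 300)
    quantized, codes = quantizer(x, wave)
    print([q.shape for q in quantized])  # three (B, 1024, T) tensors: prosody, content, residual
    print([c.shape for c in codes])      # (B, n_codebooks, T) indices per quantizer
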