dslee2601 committed on
Commit 6752381 · verified · 1 Parent(s): 9642163

Upload DAC

Files changed (3)
  1. config.json +16 -0
  2. model.py +212 -212
  3. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "architectures": [
+     "DAC"
+   ],
+   "auto_map": {
+     "AutoConfig": "model.DACConfig",
+     "AutoModel": "model.DAC"
+   },
+   "decoding_chunk_rate": 0.1,
+   "decoding_overlap_rate": 0.1,
+   "encoding_chunk_size_in_sec": 1,
+   "model_type": "dac",
+   "model_type_by_sampling_freq": "16khz",
+   "torch_dtype": "float32",
+   "transformers_version": "4.44.0"
+ }
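
Because auto_map routes AutoConfig and AutoModel to model.DACConfig and model.DAC, this checkpoint can be loaded with the stock transformers auto classes once remote code is trusted. A minimal sketch; the repo id dslee2601/DAC is an assumption, not confirmed by this commit:

    from transformers import AutoModel

    # hypothetical repo id; use the repository that hosts this commit
    model = AutoModel.from_pretrained('dslee2601/DAC', trust_remote_code=True)
    model.eval()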
model.py CHANGED
@@ -1,212 +1,212 @@
+ from typing import Union
+
+ import numpy as np
+ import torch
+ import torchaudio
+ import torch.nn as nn
+ import torchaudio.transforms as transforms
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ import dac
+ from audiotools import AudioSignal
+
+ from utils import freeze
+
+
+ class DACConfig(PretrainedConfig):
+     model_type = 'dac'
+
+     def __init__(self,
+                  model_type_by_sampling_freq:str='44khz',
+                  encoding_chunk_size_in_sec:int=1,
+                  decoding_chunk_rate:float=0.1,
+                  decoding_overlap_rate:float=0.1,
+                  **kwargs):
+         """
+         Initializes the config object.
+
+         Args:
+             model_type_by_sampling_freq (str, optional): The model type based on the sampling frequency. Defaults to '44khz'. Choose among ['44khz', '24khz', '16khz'].
+             encoding_chunk_size_in_sec (int, optional): The size of the encoding chunk in seconds. Defaults to 1.
+             decoding_chunk_rate (float, optional): The decoding chunk rate. Must be between 0 and 1. Defaults to 0.1.
+             decoding_overlap_rate (float, optional): The decoding overlap rate. Must be between 0 and 1. Defaults to 0.1.
+             **kwargs: Additional keyword arguments.
+
+         Raises:
+             AssertionError: If model_type_by_sampling_freq is not one of ['44khz', '24khz', '16khz'].
+             AssertionError: If decoding_chunk_rate is not between 0 and 1.
+             AssertionError: If decoding_overlap_rate is not between 0 and 1.
+         """
+         super().__init__(**kwargs)
+         self.model_type_by_sampling_freq = model_type_by_sampling_freq
+         self.encoding_chunk_size_in_sec = encoding_chunk_size_in_sec
+         self.decoding_chunk_rate = decoding_chunk_rate
+         self.decoding_overlap_rate = decoding_overlap_rate
+
+         assert model_type_by_sampling_freq.lower() in ['44khz', '24khz', '16khz']
+         assert 0.0 < decoding_chunk_rate <= 1.0, '`decoding_chunk_rate` must be between 0 and 1.'
+         assert 0.0 <= decoding_overlap_rate < 1.0, '`decoding_overlap_rate` must be between 0 and 1.'
+
+
+ class DAC(PreTrainedModel):
+     config_class = DACConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.model_type_by_sampling_freq = config.model_type_by_sampling_freq.lower()
+         self.model_type_by_sampling_freq_int = {'44khz':44100, '24khz':24000, '16khz':16000}[self.model_type_by_sampling_freq]
+         self.encoding_chunk_size_in_sec = config.encoding_chunk_size_in_sec
+         self.decoding_chunk_rate = config.decoding_chunk_rate
+         self.decoding_overlap_rate = config.decoding_overlap_rate
+
+         # download the pretrained DAC weights for the selected sampling rate and freeze them.
+         dac_path = dac.utils.download(model_type=self.model_type_by_sampling_freq)
+         self.dac = dac.DAC.load(dac_path)
+         self.dac.eval()
+         freeze(self.dac)
+
+         self.downsampling_rate = int(np.prod(self.dac.encoder_rates))  # 512
+
+     def load_audio(self, filename:str):
+         waveform, sample_rate = torchaudio.load(filename)  # waveform: (n_channels, length); sample_rate: const.
+         return waveform, sample_rate
+
+     def resample_audio(self, waveform:torch.FloatTensor, orig_sr:int, target_sr:int):
+         """
+         - sr: sampling rate
+         - waveform: (n_channels, length)
+         """
+         if orig_sr == target_sr:
+             return waveform
+
+         converter = transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
+         waveform = converter(waveform)  # (n_channels, new_length)
+         return waveform  # (n_channels, new_length)
+
+     def to_mono_channel(self, waveform:torch.FloatTensor):
+         """
+         - waveform: (n_channels, length)
+         """
+         n_channels = waveform.shape[0]
+         if n_channels > 1:
+             waveform = torch.mean(waveform, dim=0, keepdim=True)  # (1, length)
+         return waveform  # (1, length)
+
+     @torch.no_grad()
+     def encode(self, audio_fname:str):
+         self.eval()
+
+         waveform, sr = self.load_audio(audio_fname)
+         waveform = self.resample_audio(waveform, orig_sr=sr, target_sr=self.model_type_by_sampling_freq_int)
+         sr = self.model_type_by_sampling_freq_int
+         waveform = self.to_mono_channel(waveform)  # DAC accepts a mono channel only.
+
+         zq, s = self._chunk_encoding(waveform, sr)
+         return zq, s
+
+     def _chunk_encoding(self, waveform:torch.FloatTensor, sr:int):
+         # TODO: can I make it parallel?
+         """
+         waveform: (c l)
+         """
+         x = waveform  # brief varname
+         x = x.unsqueeze(1)  # (b 1 l); add a null batch dim
+         chunk_size = int(self.encoding_chunk_size_in_sec * sr)
+
+         # Round `chunk_size` down to a multiple of the hop length to prevent any padding in
+         # `dac.preprocess`, which would cause gaps between the mini-batches in the resulting audio.
+         # e.g., at 16 kHz with hop_length=512: 16000 - (16000 % 512) = 15872.
+         remainder = chunk_size % self.dac.hop_length
+         chunk_size = chunk_size - remainder
+
+         # process
+         zq_list, s_list = [], []
+         audio_length = x.shape[-1]
+         for start in range(0, audio_length, chunk_size):
+             end = start + chunk_size
+             chunk = x[:, :, start:end]
+             chunk = self.dac.preprocess(chunk, sr)
+             zq, s, _, _, _ = self.dac.encode(chunk.to(self.device))
+             zq = zq.cpu()
+             s = s.cpu()
+             """
+             "zq" : Tensor[B x D x T]
+                 Quantized continuous representation of input
+                 = summation of all the residual quantized vectors across every RVQ level
+                 = E(x) = z = \sum_n^N{zq_n} where N is the number of codebooks
+             "s" : Tensor[B x N x T]
+                 Codebook indices for each codebook
+                 (quantized discrete representation of input)
+                 * first element in the N dimension = first RVQ level
+             """
+             zq_list.append(zq)
+             s_list.append(s)
+             torch.cuda.empty_cache()
+
+         zq = torch.cat(zq_list, dim=2).float()  # (1, d, length)
+         s = torch.cat(s_list, dim=2).long()  # (1, n_rvq, length)
+
+         return zq, s
+
+     @torch.no_grad()
+     def decode(self, *, zq:Union[torch.FloatTensor,None]=None, s:Union[torch.IntTensor,None]=None):
+         """
+         zq: (b, d, length)
+         s: (b, n_rvq, length)
+         """
+         if zq is None and s is None:
+             raise ValueError('Either `zq` or `s` must be provided.')
+         self.eval()
+
+         if zq is not None:
+             waveform = self._chunk_decoding(zq)  # (b, 1, length); output always has a mono channel.
+         if s is not None:  # if both are given, `s` takes precedence.
+             zq = self.code_to_zq(s)
+             waveform = self._chunk_decoding(zq)  # (b, 1, length); output always has a mono channel.
+
+         return waveform
+
+     def _chunk_decoding(self, zq:torch.FloatTensor):
+         """
+         zq: (b, d, length)
+         """
+         length = zq.shape[-1]
+         chunk_size = int(self.decoding_chunk_rate * length)
+         overlap_size = round(self.decoding_overlap_rate * chunk_size)  # overlap size in terms of token length
+         overlap_size_in_data_space = round(overlap_size * self.downsampling_rate)
+         waveform_concat = None
+         for start in range(0, length, chunk_size - overlap_size):
+             end = start + chunk_size
+             chunk = zq[:, :, start:end]  # (b, d, chunk_size)
+             waveform = self.dac.decode(chunk.to(self.device))  # (b, 1, chunk_size*self.downsampling_rate)
+             waveform = waveform.cpu()
+
+             if waveform_concat is None:
+                 waveform_concat = waveform.clone()
+             else:
+                 if self.decoding_overlap_rate != 0.:
+                     # smooth the seam by averaging the overlapping regions of the previous and new chunks.
+                     prev_x = waveform_concat[:, :, :-overlap_size_in_data_space]
+                     rest_of_new_x = waveform[:, :, overlap_size_in_data_space:]
+                     overlap_x_from_prev_x = waveform_concat[:, :, -overlap_size_in_data_space:]  # (b, 1, overlap_size_in_data_space)
+                     overlap_x_from_new_x = waveform[:, :, :overlap_size_in_data_space]  # (b, 1, overlap_size_in_data_space)
+                     overlap = (overlap_x_from_prev_x + overlap_x_from_new_x) / 2  # take the mean; maybe there's a better strategy, but it seems to work fine.
+                     waveform_concat = torch.cat((prev_x, overlap, rest_of_new_x), dim=-1)  # (b, 1, ..)
+                 else:
+                     prev_x = waveform_concat
+                     rest_of_new_x = waveform
+                     waveform_concat = torch.cat((prev_x, rest_of_new_x), dim=-1)  # (b, 1, ..)
+         return waveform_concat  # (b, 1, length)
+
+     def code_to_zq(self, s:torch.IntTensor):
+         """
+         s: (b, n_rvq, length)
+         """
+         zq, _, _ = self.dac.quantizer.from_codes(s.to(self.device))  # zq: (b, d, length)
+         zq = zq.cpu()
+         return zq
+
+     def save_tensor(self, tensor:torch.Tensor, fname:str) -> None:
+         torch.save(tensor.cpu(), fname)
+
+     def load_tensor(self, fname:str):
+         return torch.load(fname)
+
+     def waveform_to_audiofile(self, waveform:torch.FloatTensor, fname:str) -> None:
+         AudioSignal(waveform, sample_rate=self.model_type_by_sampling_freq_int).write(fname)
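
Together, encode and decode give a round trip from an audio file to RVQ codes and back. A minimal usage sketch, assuming a model loaded as in the earlier example and a local file input.wav (both hypothetical):

    zq, s = model.encode('input.wav')    # zq: (1, d, length), s: (1, n_rvq, length)
    waveform = model.decode(s=s)         # reconstruct from the discrete codes; (1, 1, length), mono
    model.waveform_to_audiofile(waveform, 'reconstruction.wav')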
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4eedd71256d763a5e9806e32e96bb33d7daff6dc10acbaab5403e4057a45771
+ size 296740304