StormblessedKal committed on
Commit
e1412bc
1 Parent(s): edd1b7f

docker for runpod

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +4 -0
  2. Dockerfile +55 -0
  3. builder/fetch_models.py +13 -0
  4. builder/requirements.txt +280 -0
  5. src/.gitattributes +2 -0
  6. src/Configs/config.yml +116 -0
  7. src/Configs/config_ft.yml +111 -0
  8. src/Configs/config_libritts.yml +113 -0
  9. src/Configs/hg.yml +21 -0
  10. src/Data/OOD_texts.txt +3 -0
  11. src/Data/train_list.txt +3 -0
  12. src/Data/val_list.txt +3 -0
  13. src/Models/epochs_2nd_00020.pth +3 -0
  14. src/Modules/__init__.py +1 -0
  15. src/Modules/__pycache__/__init__.cpython-310.pyc +0 -0
  16. src/Modules/__pycache__/discriminators.cpython-310.pyc +0 -0
  17. src/Modules/__pycache__/hifigan.cpython-310.pyc +0 -0
  18. src/Modules/__pycache__/utils.cpython-310.pyc +0 -0
  19. src/Modules/diffusion/__init__.py +1 -0
  20. src/Modules/diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  21. src/Modules/diffusion/__pycache__/diffusion.cpython-310.pyc +0 -0
  22. src/Modules/diffusion/__pycache__/modules.cpython-310.pyc +0 -0
  23. src/Modules/diffusion/__pycache__/sampler.cpython-310.pyc +0 -0
  24. src/Modules/diffusion/__pycache__/utils.cpython-310.pyc +0 -0
  25. src/Modules/diffusion/diffusion.py +92 -0
  26. src/Modules/diffusion/modules.py +700 -0
  27. src/Modules/diffusion/sampler.py +685 -0
  28. src/Modules/diffusion/utils.py +83 -0
  29. src/Modules/discriminators.py +267 -0
  30. src/Modules/hifigan.py +643 -0
  31. src/Modules/istftnet.py +720 -0
  32. src/Modules/slmadv.py +256 -0
  33. src/Modules/utils.py +14 -0
  34. src/Utils/ASR/__init__.py +1 -0
  35. src/Utils/ASR/__pycache__/__init__.cpython-310.pyc +0 -0
  36. src/Utils/ASR/__pycache__/layers.cpython-310.pyc +0 -0
  37. src/Utils/ASR/__pycache__/models.cpython-310.pyc +0 -0
  38. src/Utils/ASR/config.yml +29 -0
  39. src/Utils/ASR/epoch_00080.pth +3 -0
  40. src/Utils/ASR/layers.py +455 -0
  41. src/Utils/ASR/models.py +217 -0
  42. src/Utils/JDC/__init__.py +1 -0
  43. src/Utils/JDC/__pycache__/__init__.cpython-310.pyc +0 -0
  44. src/Utils/JDC/__pycache__/model.cpython-310.pyc +0 -0
  45. src/Utils/JDC/bst.t7 +3 -0
  46. src/Utils/JDC/model.py +212 -0
  47. src/Utils/PLBERT/__pycache__/util.cpython-310.pyc +0 -0
  48. src/Utils/PLBERT/config.yml +30 -0
  49. src/Utils/PLBERT/step_1000000.t7 +3 -0
  50. src/Utils/PLBERT/util.py +49 -0
.gitattributes CHANGED
@@ -21,6 +21,8 @@
  *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
+ *.t7 filter=lfs diff=lfs merge=lfs -text
+ OOD_texts.txt filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@@ -32,4 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,55 @@
+ # Use specific version of nvidia cuda image
+ FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
+
+ # Remove any third-party apt sources to avoid issues with expiring keys.
+ RUN rm -f /etc/apt/sources.list.d/*.list
+
+ # Set shell and noninteractive environment variables
+ SHELL ["/bin/bash", "-c"]
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV SHELL=/bin/bash
+
+ # Set working directory
+ WORKDIR /
+
+ # Update and upgrade the system packages (Worker Template)
+ RUN apt-get update -y && \
+ apt-get upgrade -y && \
+ apt-get install --yes --no-install-recommends sudo ca-certificates git wget curl bash libgl1 libx11-6 software-properties-common ffmpeg build-essential -y &&\
+ apt-get autoremove -y && \
+ apt-get clean -y && \
+ rm -rf /var/lib/apt/lists/*
+
+ # Add the deadsnakes PPA and install Python 3.10
+ RUN add-apt-repository ppa:deadsnakes/ppa -y && \
+ apt-get install python3.10-dev python3.10-venv python3-pip -y --no-install-recommends && \
+ ln -s /usr/bin/python3.10 /usr/bin/python && \
+ rm /usr/bin/python3 && \
+ ln -s /usr/bin/python3.10 /usr/bin/python3 && \
+ apt-get autoremove -y && \
+ apt-get clean -y && \
+ rm -rf /var/lib/apt/lists/*
+
+ # Download and install pip
+ RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \
+ python get-pip.py && \
+ rm get-pip.py
+
+ # Install Python dependencies (Worker Template)
+ COPY builder/requirements.txt /requirements.txt
+ RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install --upgrade pip && \
+ pip install -r /requirements.txt --no-cache-dir && \
+ rm /requirements.txt
+ # Copy source code into image
+ COPY src .
+
+ # Copy and run script to fetch models
+ COPY builder/fetch_models.py /fetch_models.py
+ RUN python /fetch_models.py && \
+ rm /fetch_models.py
+
+
+
+ # Set default command
+ CMD python -u /rp_handler.py
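
Two notes on this Dockerfile. The pip layer uses RUN --mount=type=cache, so the image has to be built with BuildKit enabled (for example DOCKER_BUILDKIT=1). The default command runs /rp_handler.py, which is not part of this commit; below is a minimal, hypothetical sketch of what such a RunPod serverless handler could look like. It assumes the runpod SDK, which is not pinned in builder/requirements.txt, and the payload fields are placeholders, not the worker's real interface.

# Hypothetical sketch of /rp_handler.py (not included in this commit).
# Assumes the RunPod serverless SDK ("runpod" package); the handler actually
# shipped with this worker may look different.
import runpod

def handler(job):
    # job["input"] carries the JSON payload sent to the serverless endpoint
    job_input = job.get("input", {})
    text = job_input.get("text", "")
    # ... run the TTS / enhancement pipeline here and return a serializable result ...
    return {"ok": True, "text_length": len(text)}

# Start the worker loop that polls RunPod for jobs and dispatches them to `handler`
runpod.serverless.start({"handler": handler})
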
builder/fetch_models.py ADDED
@@ -0,0 +1,13 @@
+ import se_extractor as se
+
+ _ = se.generate_voice_segments('openai_source_output.mp3',vad=False)
+ _ = se.generate_voice_segments('openai_source_output.mp3',vad=True)
+
+ from resemble_enhance.enhancer.inference import denoise, enhance
+ import torchaudio
+
+
+ dwav, sr = torchaudio.load('openai_source_output.mp3')
+ dwav = dwav.mean(dim=0)
+
+ wav1, new_sr = enhance(dwav, sr, 'cuda:0', nfe=32, solver='midpoint', lambd=0.9, tau=0.5)
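
This script runs once at image build time (the RUN python /fetch_models.py step in the Dockerfile), so the se_extractor voice segments and the resemble-enhance weights are downloaded into the image rather than on the first request; note that passing 'cuda:0' to enhance assumes a GPU is visible during the build. The denoise helper is imported but never called here. A short sketch of how both helpers could be applied at request time, assuming they behave as used above; the input path and device are placeholders:

# Illustrative sketch only: the denoise/enhance pattern from fetch_models.py
# applied at request time. Input path and device are placeholders.
import torch
import torchaudio
from resemble_enhance.enhancer.inference import denoise, enhance

device = "cuda:0" if torch.cuda.is_available() else "cpu"

dwav, sr = torchaudio.load("input.wav")  # placeholder path
dwav = dwav.mean(dim=0)                  # mix down to mono, as in fetch_models.py

# Light denoise pass, then full enhancement with the same solver settings
# used in builder/fetch_models.py.
denoised, denoised_sr = denoise(dwav, sr, device)
enhanced, enhanced_sr = enhance(dwav, sr, device, nfe=32, solver="midpoint", lambd=0.9, tau=0.5)

torchaudio.save("enhanced.wav", enhanced.unsqueeze(0).cpu(), enhanced_sr)
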
builder/requirements.txt ADDED
@@ -0,0 +1,280 @@
1
+ accelerate==0.25.0
2
+ aiofiles==23.2.1
3
+ altair==5.2.0
4
+ amqp==5.2.0
5
+ annotated-types==0.6.0
6
+ antlr4-python3-runtime==4.9.3
7
+ anyio==4.2.0
8
+ asttokens==2.0.5
9
+ astunparse==1.6.3
10
+ async-timeout==4.0.3
11
+ attrs==23.1.0
12
+ audioread==3.0.1
13
+ av==10.0.0
14
+ Babel==2.14.0
15
+ backcall==0.2.0
16
+ beartype==0.16.4
17
+ beautifulsoup4==4.12.2
18
+ bibtexparser==2.0.0b4
19
+ billiard==4.2.0
20
+ boltons==23.0.0
21
+ boto3==1.34.11
22
+ botocore==1.34.11
23
+ brotlipy==0.7.0
24
+ cached-path==1.5.1
25
+ cachetools==5.3.2
26
+ celery==5.3.6
27
+ celluloid==0.2.0
28
+ certifi==2023.7.22
29
+ cffi==1.15.1
30
+ chardet==4.0.0
31
+ charset-normalizer==2.0.4
32
+ click==8.1.7
33
+ click-didyoumean==0.3.0
34
+ click-plugins==1.1.1
35
+ click-repl==0.3.0
36
+ clldutils==3.22.1
37
+ cloudpickle==3.0.0
38
+ cn2an==0.5.22
39
+ colorama==0.4.6
40
+ coloredlogs==15.0.1
41
+ colorlog==6.8.0
42
+ conda==23.9.0
43
+ conda-build==3.27.0
44
+ conda-content-trust==0.2.0
45
+ conda_index==0.3.0
46
+ conda-libmamba-solver==23.7.0
47
+ conda-package-handling==2.2.0
48
+ conda_package_streaming==0.9.0
49
+ contourpy==1.2.0
50
+ cryptography==41.0.3
51
+ csvw==3.2.1
52
+ ctranslate2==3.23.0
53
+ cycler==0.12.1
54
+ Cython==3.0.7
55
+ dateparser==1.1.8
56
+ decorator==5.1.1
57
+ deepspeed==0.12.4
58
+ distro==1.9.0
59
+ dlinfo==1.2.1
60
+ dnspython==2.4.2
61
+ docopt==0.6.2
62
+ dtw-python==1.3.1
63
+ einops==0.7.0
64
+ einops-exts==0.0.4
65
+ email-validator==2.1.0.post1
66
+ eng-to-ipa==0.0.2
67
+ eventlet==0.34.2
68
+ exceptiongroup==1.0.4
69
+ executing==0.8.3
70
+ expecttest==0.1.6
71
+ fastapi==0.108.0
72
+ faster-whisper==0.10.0
73
+ ffmpy==0.3.1
74
+ filelock==3.9.0
75
+ flatbuffers==23.5.26
76
+ fonttools==4.47.0
77
+ fsspec==2023.9.2
78
+ gmpy2==2.1.2
79
+ google-api-core==2.15.0
80
+ google-auth==2.25.2
81
+ google-cloud-core==2.4.1
82
+ google-cloud-storage==2.14.0
83
+ google-crc32c==1.5.0
84
+ google-resumable-media==2.7.0
85
+ googleapis-common-protos==1.62.0
86
+ gradio==4.8.0
87
+ gradio_client==0.7.1
88
+ greenlet==3.0.3
89
+ gruut==2.3.4
90
+ gruut-ipa==0.13.0
91
+ gruut-lang-en==2.0.0
92
+ h11==0.14.0
93
+ hjson==3.1.0
94
+ httpcore==1.0.2
95
+ httptools==0.6.1
96
+ httpx==0.26.0
97
+ huggingface-hub==0.19.4
98
+ humanfriendly==10.0
99
+ hypothesis==6.87.1
100
+ icontract==2.6.6
101
+ idna==3.4
102
+ importlib-resources==6.1.1
103
+ inflect==7.0.0
104
+ interegular==0.3.2
105
+ ipython==8.15.0
106
+ isodate==0.6.1
107
+ itsdangerous==2.1.2
108
+ jedi==0.18.1
109
+ jieba==0.42.1
110
+ Jinja2==3.1.2
111
+ jmespath==1.0.1
112
+ joblib==1.3.2
113
+ jsonlines==1.2.0
114
+ jsonpatch==1.32
115
+ jsonpointer==2.1
116
+ jsonschema==4.20.0
117
+ jsonschema-specifications==2023.12.1
118
+ kiwisolver==1.4.5
119
+ kombu==5.3.4
120
+ language-tags==1.2.0
121
+ lark==1.1.8
122
+ lazy_loader==0.3
123
+ libarchive-c==2.9
124
+ libmambapy==1.4.1
125
+ librosa==0.10.1
126
+ llvmlite==0.41.1
127
+ lxml==5.0.0
128
+ Markdown==3.5.1
129
+ markdown-it-py==3.0.0
130
+ MarkupSafe==2.1.1
131
+ matplotlib==3.8.1
132
+ matplotlib-inline==0.1.6
133
+ mdurl==0.1.2
134
+ mkl-fft==1.3.8
135
+ mkl-random==1.2.4
136
+ mkl-service==2.4.0
137
+ monotonic_align==1.2
138
+ more-itertools==8.12.0
139
+ mpmath==1.3.0
140
+ msgpack==1.0.7
141
+ munch==4.0.0
142
+ nest-asyncio==1.5.8
143
+ networkx==2.8.8
144
+ ninja==1.11.1.1
145
+ nltk==3.8.1
146
+ num2words==0.5.13
147
+ numba==0.58.1
148
+ numpy==1.26.2
149
+ nvidia-cublas-cu12==12.1.3.1
150
+ nvidia-cuda-cupti-cu12==12.1.105
151
+ nvidia-cuda-nvrtc-cu12==12.1.105
152
+ nvidia-cuda-runtime-cu12==12.1.105
153
+ nvidia-cudnn-cu12==8.9.2.26
154
+ nvidia-cufft-cu12==11.0.2.54
155
+ nvidia-curand-cu12==10.3.2.106
156
+ nvidia-cusolver-cu12==11.4.5.107
157
+ nvidia-cusparse-cu12==12.1.0.106
158
+ nvidia-nccl-cu12==2.18.1
159
+ nvidia-nvjitlink-cu12==12.3.101
160
+ nvidia-nvtx-cu12==12.1.105
161
+ omegaconf==2.3.0
162
+ onnxruntime==1.16.3
163
+ openai==1.6.1
164
+ openai-whisper==20231117
165
+ orjson==3.9.10
166
+ outlines==0.0.21
167
+ packaging==23.1
168
+ pandas==2.1.3
169
+ parso==0.8.3
170
+ perscache==0.6.1
171
+ pexpect==4.8.0
172
+ phonemizer==3.2.1
173
+ pickleshare==0.7.5
174
+ Pillow==9.4.0
175
+ pip==23.2.1
176
+ pkginfo==1.9.6
177
+ platformdirs==4.1.0
178
+ pluggy==1.0.0
179
+ pooch==1.8.0
180
+ proces==0.1.7
181
+ progressbar==2.5
182
+ prompt-toolkit==3.0.36
183
+ protobuf==4.25.1
184
+ psutil==5.9.0
185
+ ptflops==0.7.1.2
186
+ ptyprocess==0.7.0
187
+ pure-eval==0.2.2
188
+ py-cpuinfo==9.0.0
189
+ pyasn1==0.5.1
190
+ pyasn1-modules==0.3.0
191
+ pycosat==0.6.4
192
+ pycparser==2.21
193
+ pydantic==2.5.3
194
+ pydantic_core==2.14.6
195
+ pydantic-extra-types==2.3.0
196
+ pydantic-settings==2.1.0
197
+ pydub==0.25.1
198
+ Pygments==2.15.1
199
+ pylatexenc==2.10
200
+ pynvml==11.5.0
201
+ pyOpenSSL==23.2.0
202
+ pyparsing==3.1.1
203
+ pypinyin==0.50.0
204
+ PySocks==1.7.1
205
+ python-crfsuite==0.9.10
206
+ python-dateutil==2.8.2
207
+ python-dotenv==1.0.0
208
+ python-etcd==0.4.5
209
+ python-multipart==0.0.6
210
+ pytz==2023.3.post1
211
+ PyYAML==6.0
212
+ rdflib==7.0.0
213
+ redis==5.0.1
214
+ referencing==0.32.0
215
+ regex==2023.12.25
216
+ requests==2.31.0
217
+ resampy==0.4.2
218
+ resemble-enhance==0.0.1
219
+ rfc3986==1.5.0
220
+ rich==13.7.0
221
+ rotary-embedding-torch==0.5.3
222
+ rpds-py==0.16.2
223
+ rsa==4.9
224
+ ruamel.yaml==0.17.21
225
+ ruamel.yaml.clib==0.2.6
226
+ s3transfer==0.10.0
227
+ safetensors==0.4.1
228
+ scikit-learn==1.3.2
229
+ scipy==1.11.4
230
+ segments==2.2.1
231
+ semantic-version==2.10.0
232
+ setuptools==68.0.0
233
+ shellingham==1.5.4
234
+ six==1.16.0
235
+ sniffio==1.3.0
236
+ sortedcontainers==2.4.0
237
+ soundfile==0.12.1
238
+ soupsieve==2.5
239
+ sox==1.4.1
240
+ soxr==0.3.7
241
+ stack-data==0.2.0
242
+ starlette==0.32.0.post1
243
+ sympy==1.11.1
244
+ tabulate==0.8.10
245
+ threadpoolctl==3.2.0
246
+ tiktoken==0.5.2
247
+ tokenizers==0.13.3
248
+ tomli==2.0.1
249
+ tomlkit==0.12.0
250
+ toolz==0.12.0
251
+ torch==2.1.1
252
+ torchaudio==2.1.1
253
+ torchelastic==0.2.2
254
+ torchvision==0.16.1
255
+ tortoise-tts==3.0.0
256
+ tqdm==4.66.1
257
+ traitlets==5.7.1
258
+ transformers==4.31.0
259
+ triton==2.1.0
260
+ truststore==0.8.0
261
+ typer==0.9.0
262
+ types-dataclasses==0.6.6
263
+ typing==3.7.4.3
264
+ typing_extensions==4.8.0
265
+ tzdata==2023.4
266
+ tzlocal==5.2
267
+ ujson==5.9.0
268
+ Unidecode==1.3.7
269
+ uritemplate==4.1.1
270
+ urllib3==1.26.16
271
+ uuid==1.30
272
+ uvicorn==0.25.0
273
+ uvloop==0.19.0
274
+ vine==5.1.0
275
+ watchfiles==0.21.0
276
+ wcwidth==0.2.5
277
+ websockets==11.0.3
278
+ wheel==0.41.2
279
+ whisper-timestamped==1.14.2
280
+ zstandard==0.19.0
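
These pins install torch==2.1.1 together with the nvidia-*-cu12 wheels, so PyTorch ships its own CUDA 12.x runtime libraries even though the base image is nvidia/cuda:11.7.1; that combination generally works as long as the host driver is recent enough, but it is worth verifying inside the container. A quick, illustrative check:

# Quick sanity check that the bundled CUDA runtime is usable inside the container.
import torch

print(torch.__version__)         # expected: 2.1.1
print(torch.version.cuda)        # CUDA version the wheel was built against (12.x for these pins)
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
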
src/.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.txt filter=lfs diff=lfs merge=lfs -text
2
+ *.t7 filter=lfs diff=lfs merge=lfs -text
src/Configs/config.yml ADDED
@@ -0,0 +1,116 @@
1
+ log_dir: "Models/LJSpeech"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 2
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 200 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 100 # number of epochs for second stage training (joint training)
8
+ batch_size: 16
9
+ max_len: 400 # maximum number of frames
10
+ pretrained_model: ""
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/epoch_00080.pth"
17
+ PLBERT_dir: 'Utils/PLBERT/'
18
+
19
+ data_params:
20
+ train_data: "Data/train_list.txt"
21
+ val_data: "Data/val_list.txt"
22
+ root_path: "/local/LJSpeech-1.1/wavs"
23
+ OOD_data: "Data/OOD_texts.txt"
24
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: false
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ # config for decoder
49
+ decoder:
50
+ type: 'istftnet' # either hifigan or istftnet
51
+ resblock_kernel_sizes: [3,7,11]
52
+ upsample_rates : [10, 6]
53
+ upsample_initial_channel: 512
54
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
55
+ upsample_kernel_sizes: [20, 12]
56
+ gen_istft_n_fft: 20
57
+ gen_istft_hop_size: 5
58
+
59
+ # speech language model config
60
+ slm:
61
+ model: 'microsoft/wavlm-base-plus'
62
+ sr: 16000 # sampling rate of SLM
63
+ hidden: 768 # hidden size of SLM
64
+ nlayers: 13 # number of layers of SLM
65
+ initial_channel: 64 # initial channels of SLM discriminator head
66
+
67
+ # style diffusion model config
68
+ diffusion:
69
+ embedding_mask_proba: 0.1
70
+ # transformer config
71
+ transformer:
72
+ num_layers: 3
73
+ num_heads: 8
74
+ head_features: 64
75
+ multiplier: 2
76
+
77
+ # diffusion distribution config
78
+ dist:
79
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
80
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
81
+ mean: -3.0
82
+ std: 1.0
83
+
84
+ loss_params:
85
+ lambda_mel: 5. # mel reconstruction loss
86
+ lambda_gen: 1. # generator loss
87
+ lambda_slm: 1. # slm feature matching loss
88
+
89
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
90
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
91
+ TMA_epoch: 50 # TMA starting epoch (1st stage)
92
+
93
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
94
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
95
+ lambda_dur: 1. # duration loss (2nd stage)
96
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
97
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
98
+ lambda_diff: 1. # score matching loss (2nd stage)
99
+
100
+ diff_epoch: 20 # style diffusion starting epoch (2nd stage)
101
+ joint_epoch: 50 # joint training starting epoch (2nd stage)
102
+
103
+ optimizer_params:
104
+ lr: 0.0001 # general learning rate
105
+ bert_lr: 0.00001 # learning rate for PLBERT
106
+ ft_lr: 0.00001 # learning rate for acoustic modules
107
+
108
+ slmadv_params:
109
+ min_len: 400 # minimum length of samples
110
+ max_len: 500 # maximum length of samples
111
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
112
+ iter: 10 # update the discriminator every this iterations of generator update
113
+ thresh: 5 # gradient norm above which the gradient is scaled
114
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
115
+ sig: 1.5 # sigma for differentiable duration modeling
116
+
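
This config (and the variants that follow) is plain YAML, so it can be loaded with PyYAML, which is pinned in builder/requirements.txt; how the training and inference code actually consumes it is not part of this diff. A minimal sketch of reading a few of the nested fields, with the path relative to src/ (which the Dockerfile copies to /):

# Minimal sketch: load Configs/config.yml and read a few nested fields.
import yaml

with open("Configs/config.yml") as f:
    config = yaml.safe_load(f)

model_params = config["model_params"]
print(model_params["decoder"]["type"])           # 'istftnet' in this file
print(model_params["diffusion"]["transformer"])  # {'num_layers': 3, 'num_heads': 8, ...}
print(config["preprocess_params"]["sr"])         # 24000
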
src/Configs/config_ft.yml ADDED
@@ -0,0 +1,111 @@
1
+ log_dir: "Models/LJSpeech"
2
+ save_freq: 5
3
+ log_interval: 10
4
+ device: "cuda"
5
+ epochs: 50 # number of finetuning epochs (1 hour of data)
6
+ batch_size: 8
7
+ max_len: 400 # maximum number of frames
8
+ pretrained_model: "Models/LibriTTS/epochs_2nd_00020.pth"
9
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
10
+ load_only_params: true # set to true if do not want to load epoch numbers and optimizer parameters
11
+
12
+ F0_path: "Utils/JDC/bst.t7"
13
+ ASR_config: "Utils/ASR/config.yml"
14
+ ASR_path: "Utils/ASR/epoch_00080.pth"
15
+ PLBERT_dir: 'Utils/PLBERT/'
16
+
17
+ data_params:
18
+ train_data: "Data/train_list.txt"
19
+ val_data: "Data/val_list.txt"
20
+ root_path: "/local/LJSpeech-1.1/wavs"
21
+ OOD_data: "Data/OOD_texts.txt"
22
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
23
+
24
+ preprocess_params:
25
+ sr: 24000
26
+ spect_params:
27
+ n_fft: 2048
28
+ win_length: 1200
29
+ hop_length: 300
30
+
31
+ model_params:
32
+ multispeaker: true
33
+
34
+ dim_in: 64
35
+ hidden_dim: 512
36
+ max_conv_dim: 512
37
+ n_layer: 3
38
+ n_mels: 80
39
+
40
+ n_token: 178 # number of phoneme tokens
41
+ max_dur: 50 # maximum duration of a single phoneme
42
+ style_dim: 128 # style vector size
43
+
44
+ dropout: 0.2
45
+
46
+ # config for decoder
47
+ decoder:
48
+ type: 'hifigan' # either hifigan or istftnet
49
+ resblock_kernel_sizes: [3,7,11]
50
+ upsample_rates : [10,5,3,2]
51
+ upsample_initial_channel: 512
52
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
53
+ upsample_kernel_sizes: [20,10,6,4]
54
+
55
+ # speech language model config
56
+ slm:
57
+ model: 'microsoft/wavlm-base-plus'
58
+ sr: 16000 # sampling rate of SLM
59
+ hidden: 768 # hidden size of SLM
60
+ nlayers: 13 # number of layers of SLM
61
+ initial_channel: 64 # initial channels of SLM discriminator head
62
+
63
+ # style diffusion model config
64
+ diffusion:
65
+ embedding_mask_proba: 0.1
66
+ # transformer config
67
+ transformer:
68
+ num_layers: 3
69
+ num_heads: 8
70
+ head_features: 64
71
+ multiplier: 2
72
+
73
+ # diffusion distribution config
74
+ dist:
75
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
76
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
77
+ mean: -3.0
78
+ std: 1.0
79
+
80
+ loss_params:
81
+ lambda_mel: 5. # mel reconstruction loss
82
+ lambda_gen: 1. # generator loss
83
+ lambda_slm: 1. # slm feature matching loss
84
+
85
+ lambda_mono: 1. # monotonic alignment loss (TMA)
86
+ lambda_s2s: 1. # sequence-to-sequence loss (TMA)
87
+
88
+ lambda_F0: 1. # F0 reconstruction loss
89
+ lambda_norm: 1. # norm reconstruction loss
90
+ lambda_dur: 1. # duration loss
91
+ lambda_ce: 20. # duration predictor probability output CE loss
92
+ lambda_sty: 1. # style reconstruction loss
93
+ lambda_diff: 1. # score matching loss
94
+
95
+ diff_epoch: 10 # style diffusion starting epoch
96
+ joint_epoch: 30 # joint training starting epoch
97
+
98
+ optimizer_params:
99
+ lr: 0.0001 # general learning rate
100
+ bert_lr: 0.00001 # learning rate for PLBERT
101
+ ft_lr: 0.0001 # learning rate for acoustic modules
102
+
103
+ slmadv_params:
104
+ min_len: 400 # minimum length of samples
105
+ max_len: 500 # maximum length of samples
106
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
107
+ iter: 10 # update the discriminator every this iterations of generator update
108
+ thresh: 5 # gradient norm above which the gradient is scaled
109
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
110
+ sig: 1.5 # sigma for differentiable duration modeling
111
+
src/Configs/config_libritts.yml ADDED
@@ -0,0 +1,113 @@
1
+ log_dir: "Models/LibriTTS"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 1
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 50 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 30 # number of epochs for second stage training (joint training)
8
+ batch_size: 16
9
+ max_len: 300 # maximum number of frames
10
+ pretrained_model: ""
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/epoch_00080.pth"
17
+ PLBERT_dir: 'Utils/PLBERT/'
18
+
19
+ data_params:
20
+ train_data: "Data/train_list.txt"
21
+ val_data: "Data/val_list.txt"
22
+ root_path: ""
23
+ OOD_data: "Data/OOD_texts.txt"
24
+ min_length: 50 # sample until texts with this size are obtained for OOD texts
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: true
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ # config for decoder
49
+ decoder:
50
+ type: 'hifigan' # either hifigan or istftnet
51
+ resblock_kernel_sizes: [3,7,11]
52
+ upsample_rates : [10,5,3,2]
53
+ upsample_initial_channel: 512
54
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
55
+ upsample_kernel_sizes: [20,10,6,4]
56
+
57
+ # speech language model config
58
+ slm:
59
+ model: 'microsoft/wavlm-base-plus'
60
+ sr: 16000 # sampling rate of SLM
61
+ hidden: 768 # hidden size of SLM
62
+ nlayers: 13 # number of layers of SLM
63
+ initial_channel: 64 # initial channels of SLM discriminator head
64
+
65
+ # style diffusion model config
66
+ diffusion:
67
+ embedding_mask_proba: 0.1
68
+ # transformer config
69
+ transformer:
70
+ num_layers: 3
71
+ num_heads: 8
72
+ head_features: 64
73
+ multiplier: 2
74
+
75
+ # diffusion distribution config
76
+ dist:
77
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
78
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
79
+ mean: -3.0
80
+ std: 1.0
81
+
82
+ loss_params:
83
+ lambda_mel: 5. # mel reconstruction loss
84
+ lambda_gen: 1. # generator loss
85
+ lambda_slm: 1. # slm feature matching loss
86
+
87
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
88
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
89
+ TMA_epoch: 5 # TMA starting epoch (1st stage)
90
+
91
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
92
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
93
+ lambda_dur: 1. # duration loss (2nd stage)
94
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
95
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
96
+ lambda_diff: 1. # score matching loss (2nd stage)
97
+
98
+ diff_epoch: 10 # style diffusion starting epoch (2nd stage)
99
+ joint_epoch: 15 # joint training starting epoch (2nd stage)
100
+
101
+ optimizer_params:
102
+ lr: 0.0001 # general learning rate
103
+ bert_lr: 0.00001 # learning rate for PLBERT
104
+ ft_lr: 0.00001 # learning rate for acoustic modules
105
+
106
+ slmadv_params:
107
+ min_len: 400 # minimum length of samples
108
+ max_len: 500 # maximum length of samples
109
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
110
+ iter: 20 # update the discriminator every this iterations of generator update
111
+ thresh: 5 # gradient norm above which the gradient is scaled
112
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
113
+ sig: 1.5 # sigma for differentiable duration modeling
src/Configs/hg.yml ADDED
@@ -0,0 +1,21 @@
+ {ASR_config: Utils/ASR/config.yml, ASR_path: Utils/ASR/epoch_00080.pth, F0_path: Utils/JDC/bst.t7,
+ PLBERT_dir: Utils/PLBERT/, batch_size: 8, data_params: {OOD_data: Data/OOD_texts.txt,
+ min_length: 50, root_path: '', train_data: Data/train_list.txt, val_data: Data/val_list.txt},
+ device: cuda, epochs_1st: 40, epochs_2nd: 25, first_stage_path: first_stage.pth,
+ load_only_params: false, log_dir: Models/LibriTTS, log_interval: 10, loss_params: {
+ TMA_epoch: 4, diff_epoch: 0, joint_epoch: 0, lambda_F0: 1.0, lambda_ce: 20.0,
+ lambda_diff: 1.0, lambda_dur: 1.0, lambda_gen: 1.0, lambda_mel: 5.0, lambda_mono: 1.0,
+ lambda_norm: 1.0, lambda_s2s: 1.0, lambda_slm: 1.0, lambda_sty: 1.0}, max_len: 300,
+ model_params: {decoder: {resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3,
+ 5]], resblock_kernel_sizes: [3, 7, 11], type: hifigan, upsample_initial_channel: 512,
+ upsample_kernel_sizes: [20, 10, 6, 4], upsample_rates: [10, 5, 3, 2]}, diffusion: {
+ dist: {estimate_sigma_data: true, mean: -3.0, sigma_data: 0.19926648961191362,
+ std: 1.0}, embedding_mask_proba: 0.1, transformer: {head_features: 64, multiplier: 2,
+ num_heads: 8, num_layers: 3}}, dim_in: 64, dropout: 0.2, hidden_dim: 512,
+ max_conv_dim: 512, max_dur: 50, multispeaker: true, n_layer: 3, n_mels: 80, n_token: 178,
+ slm: {hidden: 768, initial_channel: 64, model: microsoft/wavlm-base-plus, nlayers: 13,
+ sr: 16000}, style_dim: 128}, optimizer_params: {bert_lr: 1.0e-05, ft_lr: 1.0e-05,
+ lr: 0.0001}, preprocess_params: {spect_params: {hop_length: 300, n_fft: 2048,
+ win_length: 1200}, sr: 24000}, pretrained_model: Models/LibriTTS/epoch_2nd_00002.pth,
+ save_freq: 1, second_stage_load_pretrained: true, slmadv_params: {batch_percentage: 0.5,
+ iter: 20, max_len: 500, min_len: 400, scale: 0.01, sig: 1.5, thresh: 5}}
src/Data/OOD_texts.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0989ef6a9873b711befefcbe60660ced7a65532359277f766f4db504c558a72
+ size 31758898
src/Data/train_list.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a02392d09b88cb0dd5794d5aef056068b9741cde680c37fb34c607de83d77da0
+ size 2195448
src/Data/val_list.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e2a6f76b7698ce50ba199dfba60c784b758b3ef4981c05dffd0768db2934208
+ size 17203
src/Models/epochs_2nd_00020.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1164ffe19a17449d2c722234cecaf2836b35a698fb8ffd42562d2663657dca0a
+ size 771390526
src/Modules/__init__.py ADDED
@@ -0,0 +1 @@
+
src/Modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes).
 
src/Modules/__pycache__/discriminators.cpython-310.pyc ADDED
Binary file (6.18 kB).
 
src/Modules/__pycache__/hifigan.cpython-310.pyc ADDED
Binary file (14.7 kB).
 
src/Modules/__pycache__/utils.cpython-310.pyc ADDED
Binary file (757 Bytes).
 
src/Modules/diffusion/__init__.py ADDED
@@ -0,0 +1 @@
+
src/Modules/diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes).
 
src/Modules/diffusion/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (3.65 kB).
 
src/Modules/diffusion/__pycache__/modules.cpython-310.pyc ADDED
Binary file (16.2 kB).
 
src/Modules/diffusion/__pycache__/sampler.cpython-310.pyc ADDED
Binary file (22.1 kB).
 
src/Modules/diffusion/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.49 kB).
 
src/Modules/diffusion/diffusion.py ADDED
@@ -0,0 +1,92 @@
1
+ from math import pi
2
+ from random import randint
3
+ from typing import Any, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+ from tqdm import tqdm
9
+
10
+ from .utils import *
11
+ from .sampler import *
12
+
13
+ """
14
+ Diffusion Classes (generic for 1d data)
15
+ """
16
+
17
+
18
+ class Model1d(nn.Module):
19
+ def __init__(self, unet_type: str = "base", **kwargs):
20
+ super().__init__()
21
+ diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
22
+ self.unet = None
23
+ self.diffusion = None
24
+
25
+ def forward(self, x: Tensor, **kwargs) -> Tensor:
26
+ return self.diffusion(x, **kwargs)
27
+
28
+ def sample(self, *args, **kwargs) -> Tensor:
29
+ return self.diffusion.sample(*args, **kwargs)
30
+
31
+
32
+ """
33
+ Audio Diffusion Classes (specific for 1d audio data)
34
+ """
35
+
36
+
37
+ def get_default_model_kwargs():
38
+ return dict(
39
+ channels=128,
40
+ patch_size=16,
41
+ multipliers=[1, 2, 4, 4, 4, 4, 4],
42
+ factors=[4, 4, 4, 2, 2, 2],
43
+ num_blocks=[2, 2, 2, 2, 2, 2],
44
+ attentions=[0, 0, 0, 1, 1, 1, 1],
45
+ attention_heads=8,
46
+ attention_features=64,
47
+ attention_multiplier=2,
48
+ attention_use_rel_pos=False,
49
+ diffusion_type="v",
50
+ diffusion_sigma_distribution=UniformDistribution(),
51
+ )
52
+
53
+
54
+ def get_default_sampling_kwargs():
55
+ return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
56
+
57
+
58
+ class AudioDiffusionModel(Model1d):
59
+ def __init__(self, **kwargs):
60
+ super().__init__(**{**get_default_model_kwargs(), **kwargs})
61
+
62
+ def sample(self, *args, **kwargs):
63
+ return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
64
+
65
+
66
+ class AudioDiffusionConditional(Model1d):
67
+ def __init__(
68
+ self,
69
+ embedding_features: int,
70
+ embedding_max_length: int,
71
+ embedding_mask_proba: float = 0.1,
72
+ **kwargs,
73
+ ):
74
+ self.embedding_mask_proba = embedding_mask_proba
75
+ default_kwargs = dict(
76
+ **get_default_model_kwargs(),
77
+ unet_type="cfg",
78
+ context_embedding_features=embedding_features,
79
+ context_embedding_max_length=embedding_max_length,
80
+ )
81
+ super().__init__(**{**default_kwargs, **kwargs})
82
+
83
+ def forward(self, *args, **kwargs):
84
+ default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
85
+ return super().forward(*args, **{**default_kwargs, **kwargs})
86
+
87
+ def sample(self, *args, **kwargs):
88
+ default_kwargs = dict(
89
+ **get_default_sampling_kwargs(),
90
+ embedding_scale=5.0,
91
+ )
92
+ return super().sample(*args, **{**default_kwargs, **kwargs})
src/Modules/diffusion/modules.py ADDED
@@ -0,0 +1,700 @@
1
+ from math import floor, log, pi
2
+ from typing import Any, List, Optional, Sequence, Tuple, Union
3
+
4
+ from .utils import *
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange, reduce, repeat
9
+ from einops.layers.torch import Rearrange
10
+ from einops_exts import rearrange_many
11
+ from torch import Tensor, einsum
12
+
13
+
14
+ """
15
+ Utils
16
+ """
17
+
18
+
19
+ class AdaLayerNorm(nn.Module):
20
+ def __init__(self, style_dim, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.fc = nn.Linear(style_dim, channels * 2)
26
+
27
+ def forward(self, x, s):
28
+ x = x.transpose(-1, -2)
29
+ x = x.transpose(1, -1)
30
+
31
+ h = self.fc(s)
32
+ h = h.view(h.size(0), h.size(1), 1)
33
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
34
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
35
+
36
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
37
+ x = (1 + gamma) * x + beta
38
+ return x.transpose(1, -1).transpose(-1, -2)
39
+
40
+
41
+ class StyleTransformer1d(nn.Module):
42
+ def __init__(
43
+ self,
44
+ num_layers: int,
45
+ channels: int,
46
+ num_heads: int,
47
+ head_features: int,
48
+ multiplier: int,
49
+ use_context_time: bool = True,
50
+ use_rel_pos: bool = False,
51
+ context_features_multiplier: int = 1,
52
+ rel_pos_num_buckets: Optional[int] = None,
53
+ rel_pos_max_distance: Optional[int] = None,
54
+ context_features: Optional[int] = None,
55
+ context_embedding_features: Optional[int] = None,
56
+ embedding_max_length: int = 512,
57
+ ):
58
+ super().__init__()
59
+
60
+ self.blocks = nn.ModuleList(
61
+ [
62
+ StyleTransformerBlock(
63
+ features=channels + context_embedding_features,
64
+ head_features=head_features,
65
+ num_heads=num_heads,
66
+ multiplier=multiplier,
67
+ style_dim=context_features,
68
+ use_rel_pos=use_rel_pos,
69
+ rel_pos_num_buckets=rel_pos_num_buckets,
70
+ rel_pos_max_distance=rel_pos_max_distance,
71
+ )
72
+ for i in range(num_layers)
73
+ ]
74
+ )
75
+
76
+ self.to_out = nn.Sequential(
77
+ Rearrange("b t c -> b c t"),
78
+ nn.Conv1d(
79
+ in_channels=channels + context_embedding_features,
80
+ out_channels=channels,
81
+ kernel_size=1,
82
+ ),
83
+ )
84
+
85
+ use_context_features = exists(context_features)
86
+ self.use_context_features = use_context_features
87
+ self.use_context_time = use_context_time
88
+
89
+ if use_context_time or use_context_features:
90
+ context_mapping_features = channels + context_embedding_features
91
+
92
+ self.to_mapping = nn.Sequential(
93
+ nn.Linear(context_mapping_features, context_mapping_features),
94
+ nn.GELU(),
95
+ nn.Linear(context_mapping_features, context_mapping_features),
96
+ nn.GELU(),
97
+ )
98
+
99
+ if use_context_time:
100
+ assert exists(context_mapping_features)
101
+ self.to_time = nn.Sequential(
102
+ TimePositionalEmbedding(
103
+ dim=channels, out_features=context_mapping_features
104
+ ),
105
+ nn.GELU(),
106
+ )
107
+
108
+ if use_context_features:
109
+ assert exists(context_features) and exists(context_mapping_features)
110
+ self.to_features = nn.Sequential(
111
+ nn.Linear(
112
+ in_features=context_features, out_features=context_mapping_features
113
+ ),
114
+ nn.GELU(),
115
+ )
116
+
117
+ self.fixed_embedding = FixedEmbedding(
118
+ max_length=embedding_max_length, features=context_embedding_features
119
+ )
120
+
121
+ def get_mapping(
122
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
123
+ ) -> Optional[Tensor]:
124
+ """Combines context time features and features into mapping"""
125
+ items, mapping = [], None
126
+ # Compute time features
127
+ if self.use_context_time:
128
+ assert_message = "use_context_time=True but no time features provided"
129
+ assert exists(time), assert_message
130
+ items += [self.to_time(time)]
131
+ # Compute features
132
+ if self.use_context_features:
133
+ assert_message = "context_features exists but no features provided"
134
+ assert exists(features), assert_message
135
+ items += [self.to_features(features)]
136
+
137
+ # Compute joint mapping
138
+ if self.use_context_time or self.use_context_features:
139
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
140
+ mapping = self.to_mapping(mapping)
141
+
142
+ return mapping
143
+
144
+ def run(self, x, time, embedding, features):
145
+ mapping = self.get_mapping(time, features)
146
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
147
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
148
+
149
+ for block in self.blocks:
150
+ x = x + mapping
151
+ x = block(x, features)
152
+
153
+ x = x.mean(axis=1).unsqueeze(1)
154
+ x = self.to_out(x)
155
+ x = x.transpose(-1, -2)
156
+
157
+ return x
158
+
159
+ def forward(
160
+ self,
161
+ x: Tensor,
162
+ time: Tensor,
163
+ embedding_mask_proba: float = 0.0,
164
+ embedding: Optional[Tensor] = None,
165
+ features: Optional[Tensor] = None,
166
+ embedding_scale: float = 1.0,
167
+ ) -> Tensor:
168
+ b, device = embedding.shape[0], embedding.device
169
+ fixed_embedding = self.fixed_embedding(embedding)
170
+ if embedding_mask_proba > 0.0:
171
+ # Randomly mask embedding
172
+ batch_mask = rand_bool(
173
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
174
+ )
175
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
176
+
177
+ if embedding_scale != 1.0:
178
+ # Compute both normal and fixed embedding outputs
179
+ out = self.run(x, time, embedding=embedding, features=features)
180
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
181
+ # Scale conditional output using classifier-free guidance
182
+ return out_masked + (out - out_masked) * embedding_scale
183
+ else:
184
+ return self.run(x, time, embedding=embedding, features=features)
185
+
186
+ return x
187
+
188
+
189
+ class StyleTransformerBlock(nn.Module):
190
+ def __init__(
191
+ self,
192
+ features: int,
193
+ num_heads: int,
194
+ head_features: int,
195
+ style_dim: int,
196
+ multiplier: int,
197
+ use_rel_pos: bool,
198
+ rel_pos_num_buckets: Optional[int] = None,
199
+ rel_pos_max_distance: Optional[int] = None,
200
+ context_features: Optional[int] = None,
201
+ ):
202
+ super().__init__()
203
+
204
+ self.use_cross_attention = exists(context_features) and context_features > 0
205
+
206
+ self.attention = StyleAttention(
207
+ features=features,
208
+ style_dim=style_dim,
209
+ num_heads=num_heads,
210
+ head_features=head_features,
211
+ use_rel_pos=use_rel_pos,
212
+ rel_pos_num_buckets=rel_pos_num_buckets,
213
+ rel_pos_max_distance=rel_pos_max_distance,
214
+ )
215
+
216
+ if self.use_cross_attention:
217
+ self.cross_attention = StyleAttention(
218
+ features=features,
219
+ style_dim=style_dim,
220
+ num_heads=num_heads,
221
+ head_features=head_features,
222
+ context_features=context_features,
223
+ use_rel_pos=use_rel_pos,
224
+ rel_pos_num_buckets=rel_pos_num_buckets,
225
+ rel_pos_max_distance=rel_pos_max_distance,
226
+ )
227
+
228
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
229
+
230
+ def forward(
231
+ self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None
232
+ ) -> Tensor:
233
+ x = self.attention(x, s) + x
234
+ if self.use_cross_attention:
235
+ x = self.cross_attention(x, s, context=context) + x
236
+ x = self.feed_forward(x) + x
237
+ return x
238
+
239
+
240
+ class StyleAttention(nn.Module):
241
+ def __init__(
242
+ self,
243
+ features: int,
244
+ *,
245
+ style_dim: int,
246
+ head_features: int,
247
+ num_heads: int,
248
+ context_features: Optional[int] = None,
249
+ use_rel_pos: bool,
250
+ rel_pos_num_buckets: Optional[int] = None,
251
+ rel_pos_max_distance: Optional[int] = None,
252
+ ):
253
+ super().__init__()
254
+ self.context_features = context_features
255
+ mid_features = head_features * num_heads
256
+ context_features = default(context_features, features)
257
+
258
+ self.norm = AdaLayerNorm(style_dim, features)
259
+ self.norm_context = AdaLayerNorm(style_dim, context_features)
260
+ self.to_q = nn.Linear(
261
+ in_features=features, out_features=mid_features, bias=False
262
+ )
263
+ self.to_kv = nn.Linear(
264
+ in_features=context_features, out_features=mid_features * 2, bias=False
265
+ )
266
+ self.attention = AttentionBase(
267
+ features,
268
+ num_heads=num_heads,
269
+ head_features=head_features,
270
+ use_rel_pos=use_rel_pos,
271
+ rel_pos_num_buckets=rel_pos_num_buckets,
272
+ rel_pos_max_distance=rel_pos_max_distance,
273
+ )
274
+
275
+ def forward(
276
+ self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None
277
+ ) -> Tensor:
278
+ assert_message = "You must provide a context when using context_features"
279
+ assert not self.context_features or exists(context), assert_message
280
+ # Use context if provided
281
+ context = default(context, x)
282
+ # Normalize then compute q from input and k,v from context
283
+ x, context = self.norm(x, s), self.norm_context(context, s)
284
+
285
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
286
+ # Compute and return attention
287
+ return self.attention(q, k, v)
288
+
289
+
290
+ class Transformer1d(nn.Module):
291
+ def __init__(
292
+ self,
293
+ num_layers: int,
294
+ channels: int,
295
+ num_heads: int,
296
+ head_features: int,
297
+ multiplier: int,
298
+ use_context_time: bool = True,
299
+ use_rel_pos: bool = False,
300
+ context_features_multiplier: int = 1,
301
+ rel_pos_num_buckets: Optional[int] = None,
302
+ rel_pos_max_distance: Optional[int] = None,
303
+ context_features: Optional[int] = None,
304
+ context_embedding_features: Optional[int] = None,
305
+ embedding_max_length: int = 512,
306
+ ):
307
+ super().__init__()
308
+
309
+ self.blocks = nn.ModuleList(
310
+ [
311
+ TransformerBlock(
312
+ features=channels + context_embedding_features,
313
+ head_features=head_features,
314
+ num_heads=num_heads,
315
+ multiplier=multiplier,
316
+ use_rel_pos=use_rel_pos,
317
+ rel_pos_num_buckets=rel_pos_num_buckets,
318
+ rel_pos_max_distance=rel_pos_max_distance,
319
+ )
320
+ for i in range(num_layers)
321
+ ]
322
+ )
323
+
324
+ self.to_out = nn.Sequential(
325
+ Rearrange("b t c -> b c t"),
326
+ nn.Conv1d(
327
+ in_channels=channels + context_embedding_features,
328
+ out_channels=channels,
329
+ kernel_size=1,
330
+ ),
331
+ )
332
+
333
+ use_context_features = exists(context_features)
334
+ self.use_context_features = use_context_features
335
+ self.use_context_time = use_context_time
336
+
337
+ if use_context_time or use_context_features:
338
+ context_mapping_features = channels + context_embedding_features
339
+
340
+ self.to_mapping = nn.Sequential(
341
+ nn.Linear(context_mapping_features, context_mapping_features),
342
+ nn.GELU(),
343
+ nn.Linear(context_mapping_features, context_mapping_features),
344
+ nn.GELU(),
345
+ )
346
+
347
+ if use_context_time:
348
+ assert exists(context_mapping_features)
349
+ self.to_time = nn.Sequential(
350
+ TimePositionalEmbedding(
351
+ dim=channels, out_features=context_mapping_features
352
+ ),
353
+ nn.GELU(),
354
+ )
355
+
356
+ if use_context_features:
357
+ assert exists(context_features) and exists(context_mapping_features)
358
+ self.to_features = nn.Sequential(
359
+ nn.Linear(
360
+ in_features=context_features, out_features=context_mapping_features
361
+ ),
362
+ nn.GELU(),
363
+ )
364
+
365
+ self.fixed_embedding = FixedEmbedding(
366
+ max_length=embedding_max_length, features=context_embedding_features
367
+ )
368
+
369
+ def get_mapping(
370
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
371
+ ) -> Optional[Tensor]:
372
+ """Combines context time features and features into mapping"""
373
+ items, mapping = [], None
374
+ # Compute time features
375
+ if self.use_context_time:
376
+ assert_message = "use_context_time=True but no time features provided"
377
+ assert exists(time), assert_message
378
+ items += [self.to_time(time)]
379
+ # Compute features
380
+ if self.use_context_features:
381
+ assert_message = "context_features exists but no features provided"
382
+ assert exists(features), assert_message
383
+ items += [self.to_features(features)]
384
+
385
+ # Compute joint mapping
386
+ if self.use_context_time or self.use_context_features:
387
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
388
+ mapping = self.to_mapping(mapping)
389
+
390
+ return mapping
391
+
392
+ def run(self, x, time, embedding, features):
393
+ mapping = self.get_mapping(time, features)
394
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
395
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
396
+
397
+ for block in self.blocks:
398
+ x = x + mapping
399
+ x = block(x)
400
+
401
+ x = x.mean(axis=1).unsqueeze(1)
402
+ x = self.to_out(x)
403
+ x = x.transpose(-1, -2)
404
+
405
+ return x
406
+
407
+ def forward(
408
+ self,
409
+ x: Tensor,
410
+ time: Tensor,
411
+ embedding_mask_proba: float = 0.0,
412
+ embedding: Optional[Tensor] = None,
413
+ features: Optional[Tensor] = None,
414
+ embedding_scale: float = 1.0,
415
+ ) -> Tensor:
416
+ b, device = embedding.shape[0], embedding.device
417
+ fixed_embedding = self.fixed_embedding(embedding)
418
+ if embedding_mask_proba > 0.0:
419
+ # Randomly mask embedding
420
+ batch_mask = rand_bool(
421
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
422
+ )
423
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
424
+
425
+ if embedding_scale != 1.0:
426
+ # Compute both normal and fixed embedding outputs
427
+ out = self.run(x, time, embedding=embedding, features=features)
428
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
429
+ # Scale conditional output using classifier-free guidance
430
+ return out_masked + (out - out_masked) * embedding_scale
431
+ else:
432
+ return self.run(x, time, embedding=embedding, features=features)
433
+
434
+ return x
435
+
436
+
437
+ """
438
+ Attention Components
439
+ """
440
+
441
+
442
+ class RelativePositionBias(nn.Module):
443
+ def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
444
+ super().__init__()
445
+ self.num_buckets = num_buckets
446
+ self.max_distance = max_distance
447
+ self.num_heads = num_heads
448
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
449
+
450
+ @staticmethod
451
+ def _relative_position_bucket(
452
+ relative_position: Tensor, num_buckets: int, max_distance: int
453
+ ):
454
+ num_buckets //= 2
455
+ ret = (relative_position >= 0).to(torch.long) * num_buckets
456
+ n = torch.abs(relative_position)
457
+
458
+ max_exact = num_buckets // 2
459
+ is_small = n < max_exact
460
+
461
+ val_if_large = (
462
+ max_exact
463
+ + (
464
+ torch.log(n.float() / max_exact)
465
+ / log(max_distance / max_exact)
466
+ * (num_buckets - max_exact)
467
+ ).long()
468
+ )
469
+ val_if_large = torch.min(
470
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
471
+ )
472
+
473
+ ret += torch.where(is_small, n, val_if_large)
474
+ return ret
475
+
476
+ def forward(self, num_queries: int, num_keys: int) -> Tensor:
477
+ i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
478
+ q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
479
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
480
+ rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
481
+
482
+ relative_position_bucket = self._relative_position_bucket(
483
+ rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
484
+ )
485
+
486
+ bias = self.relative_attention_bias(relative_position_bucket)
487
+ bias = rearrange(bias, "m n h -> 1 h m n")
488
+ return bias
489
+
490
+
491
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
492
+ mid_features = features * multiplier
493
+ return nn.Sequential(
494
+ nn.Linear(in_features=features, out_features=mid_features),
495
+ nn.GELU(),
496
+ nn.Linear(in_features=mid_features, out_features=features),
497
+ )
498
+
499
+
500
+ class AttentionBase(nn.Module):
501
+ def __init__(
502
+ self,
503
+ features: int,
504
+ *,
505
+ head_features: int,
506
+ num_heads: int,
507
+ use_rel_pos: bool,
508
+ out_features: Optional[int] = None,
509
+ rel_pos_num_buckets: Optional[int] = None,
510
+ rel_pos_max_distance: Optional[int] = None,
511
+ ):
512
+ super().__init__()
513
+ self.scale = head_features**-0.5
514
+ self.num_heads = num_heads
515
+ self.use_rel_pos = use_rel_pos
516
+ mid_features = head_features * num_heads
517
+
518
+ if use_rel_pos:
519
+ assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
520
+ self.rel_pos = RelativePositionBias(
521
+ num_buckets=rel_pos_num_buckets,
522
+ max_distance=rel_pos_max_distance,
523
+ num_heads=num_heads,
524
+ )
525
+ if out_features is None:
526
+ out_features = features
527
+
528
+ self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
529
+
530
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
531
+ # Split heads
532
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
533
+ # Compute similarity matrix
534
+ sim = einsum("... n d, ... m d -> ... n m", q, k)
535
+ sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
536
+ sim = sim * self.scale
537
+ # Get attention matrix with softmax
538
+ attn = sim.softmax(dim=-1)
539
+ # Compute values
540
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
541
+ out = rearrange(out, "b h n d -> b n (h d)")
542
+ return self.to_out(out)
543
+
544
+
545
+ class Attention(nn.Module):
546
+ def __init__(
547
+ self,
548
+ features: int,
549
+ *,
550
+ head_features: int,
551
+ num_heads: int,
552
+ out_features: Optional[int] = None,
553
+ context_features: Optional[int] = None,
554
+ use_rel_pos: bool,
555
+ rel_pos_num_buckets: Optional[int] = None,
556
+ rel_pos_max_distance: Optional[int] = None,
557
+ ):
558
+ super().__init__()
559
+ self.context_features = context_features
560
+ mid_features = head_features * num_heads
561
+ context_features = default(context_features, features)
562
+
563
+ self.norm = nn.LayerNorm(features)
564
+ self.norm_context = nn.LayerNorm(context_features)
565
+ self.to_q = nn.Linear(
566
+ in_features=features, out_features=mid_features, bias=False
567
+ )
568
+ self.to_kv = nn.Linear(
569
+ in_features=context_features, out_features=mid_features * 2, bias=False
570
+ )
571
+
572
+ self.attention = AttentionBase(
573
+ features,
574
+ out_features=out_features,
575
+ num_heads=num_heads,
576
+ head_features=head_features,
577
+ use_rel_pos=use_rel_pos,
578
+ rel_pos_num_buckets=rel_pos_num_buckets,
579
+ rel_pos_max_distance=rel_pos_max_distance,
580
+ )
581
+
582
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
583
+ assert_message = "You must provide a context when using context_features"
584
+ assert not self.context_features or exists(context), assert_message
585
+ # Use context if provided
586
+ context = default(context, x)
587
+ # Normalize then compute q from input and k,v from context
588
+ x, context = self.norm(x), self.norm_context(context)
589
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
590
+ # Compute and return attention
591
+ return self.attention(q, k, v)
592
+
593
+
594
+ """
595
+ Transformer Blocks
596
+ """
597
+
598
+
599
+ class TransformerBlock(nn.Module):
600
+ def __init__(
601
+ self,
602
+ features: int,
603
+ num_heads: int,
604
+ head_features: int,
605
+ multiplier: int,
606
+ use_rel_pos: bool,
607
+ rel_pos_num_buckets: Optional[int] = None,
608
+ rel_pos_max_distance: Optional[int] = None,
609
+ context_features: Optional[int] = None,
610
+ ):
611
+ super().__init__()
612
+
613
+ self.use_cross_attention = exists(context_features) and context_features > 0
614
+
615
+ self.attention = Attention(
616
+ features=features,
617
+ num_heads=num_heads,
618
+ head_features=head_features,
619
+ use_rel_pos=use_rel_pos,
620
+ rel_pos_num_buckets=rel_pos_num_buckets,
621
+ rel_pos_max_distance=rel_pos_max_distance,
622
+ )
623
+
624
+ if self.use_cross_attention:
625
+ self.cross_attention = Attention(
626
+ features=features,
627
+ num_heads=num_heads,
628
+ head_features=head_features,
629
+ context_features=context_features,
630
+ use_rel_pos=use_rel_pos,
631
+ rel_pos_num_buckets=rel_pos_num_buckets,
632
+ rel_pos_max_distance=rel_pos_max_distance,
633
+ )
634
+
635
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
636
+
637
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
638
+ x = self.attention(x) + x
639
+ if self.use_cross_attention:
640
+ x = self.cross_attention(x, context=context) + x
641
+ x = self.feed_forward(x) + x
642
+ return x
643
+
644
+
645
+ """
646
+ Time Embeddings
647
+ """
648
+
649
+
650
+ class SinusoidalEmbedding(nn.Module):
651
+ def __init__(self, dim: int):
652
+ super().__init__()
653
+ self.dim = dim
654
+
655
+ def forward(self, x: Tensor) -> Tensor:
656
+ device, half_dim = x.device, self.dim // 2
657
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
658
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
659
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
660
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
661
+
662
+
663
+ class LearnedPositionalEmbedding(nn.Module):
664
+ """Used for continuous time"""
665
+
666
+ def __init__(self, dim: int):
667
+ super().__init__()
668
+ assert (dim % 2) == 0
669
+ half_dim = dim // 2
670
+ self.weights = nn.Parameter(torch.randn(half_dim))
671
+
672
+ def forward(self, x: Tensor) -> Tensor:
673
+ x = rearrange(x, "b -> b 1")
674
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
675
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
676
+ fouriered = torch.cat((x, fouriered), dim=-1)
677
+ return fouriered
678
+
679
+
680
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
681
+ return nn.Sequential(
682
+ LearnedPositionalEmbedding(dim),
683
+ nn.Linear(in_features=dim + 1, out_features=out_features),
684
+ )
685
+
686
+
687
+ class FixedEmbedding(nn.Module):
688
+ def __init__(self, max_length: int, features: int):
689
+ super().__init__()
690
+ self.max_length = max_length
691
+ self.embedding = nn.Embedding(max_length, features)
692
+
693
+ def forward(self, x: Tensor) -> Tensor:
694
+ batch_size, length, device = *x.shape[0:2], x.device
695
+ assert_message = "Input sequence length must be <= max_length"
696
+ assert length <= self.max_length, assert_message
697
+ position = torch.arange(length, device=device)
698
+ fixed_embedding = self.embedding(position)
699
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
700
+ return fixed_embedding
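For orientation, a minimal usage sketch of the time embeddings defined above (illustrative only, not part of the diff; the import path and sizes are assumptions based on the src/ layout):

import torch
from Modules.diffusion.modules import TimePositionalEmbedding

emb = TimePositionalEmbedding(dim=16, out_features=64)
sigmas = torch.rand(4)   # one continuous noise level / time value per batch item
features = emb(sigmas)   # LearnedPositionalEmbedding followed by a Linear projection
print(features.shape)    # torch.Size([4, 64])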
src/Modules/diffusion/sampler.py ADDED
@@ -0,0 +1,685 @@
1
+ from math import atan, cos, pi, sin, sqrt
2
+ from typing import Any, Callable, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, reduce
8
+ from torch import Tensor
9
+
10
+ from .utils import *
11
+
12
+ """
13
+ Diffusion Training
14
+ """
15
+
16
+ """ Distributions """
17
+
18
+
19
+ class Distribution:
20
+ def __call__(self, num_samples: int, device: torch.device):
21
+ raise NotImplementedError()
22
+
23
+
24
+ class LogNormalDistribution(Distribution):
25
+ def __init__(self, mean: float, std: float):
26
+ self.mean = mean
27
+ self.std = std
28
+
29
+ def __call__(
30
+ self, num_samples: int, device: torch.device = torch.device("cpu")
31
+ ) -> Tensor:
32
+ normal = self.mean + self.std * torch.randn((num_samples,), device=device)
33
+ return normal.exp()
34
+
35
+
36
+ class UniformDistribution(Distribution):
37
+ def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
38
+ return torch.rand(num_samples, device=device)
39
+
40
+
41
+ class VKDistribution(Distribution):
42
+ def __init__(
43
+ self,
44
+ min_value: float = 0.0,
45
+ max_value: float = float("inf"),
46
+ sigma_data: float = 1.0,
47
+ ):
48
+ self.min_value = min_value
49
+ self.max_value = max_value
50
+ self.sigma_data = sigma_data
51
+
52
+ def __call__(
53
+ self, num_samples: int, device: torch.device = torch.device("cpu")
54
+ ) -> Tensor:
55
+ sigma_data = self.sigma_data
56
+ min_cdf = atan(self.min_value / sigma_data) * 2 / pi
57
+ max_cdf = atan(self.max_value / sigma_data) * 2 / pi
58
+ u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
59
+ return torch.tan(u * pi / 2) * sigma_data
60
+
61
+
62
+ """ Diffusion Classes """
63
+
64
+
65
+ def pad_dims(x: Tensor, ndim: int) -> Tensor:
66
+ # Pads additional ndims to the right of the tensor
67
+ return x.view(*x.shape, *((1,) * ndim))
68
+
69
+
70
+ def clip(x: Tensor, dynamic_threshold: float = 0.0):
71
+ if dynamic_threshold == 0.0:
72
+ return x.clamp(-1.0, 1.0)
73
+ else:
74
+ # Dynamic thresholding
75
+ # Find dynamic threshold quantile for each batch
76
+ x_flat = rearrange(x, "b ... -> b (...)")
77
+ scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
78
+ # Clamp to a min of 1.0
79
+ scale.clamp_(min=1.0)
80
+ # Clamp all values and scale
81
+ scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
82
+ x = x.clamp(-scale, scale) / scale
83
+ return x
84
+
85
+
86
+ def to_batch(
87
+ batch_size: int,
88
+ device: torch.device,
89
+ x: Optional[float] = None,
90
+ xs: Optional[Tensor] = None,
91
+ ) -> Tensor:
92
+ assert exists(x) ^ exists(xs), "Either x or xs must be provided"
93
+ # If x provided use the same for all batch items
94
+ if exists(x):
95
+ xs = torch.full(size=(batch_size,), fill_value=x).to(device)
96
+ assert exists(xs)
97
+ return xs
98
+
99
+
100
+ class Diffusion(nn.Module):
101
+ alias: str = ""
102
+
103
+ """Base diffusion class"""
104
+
105
+ def denoise_fn(
106
+ self,
107
+ x_noisy: Tensor,
108
+ sigmas: Optional[Tensor] = None,
109
+ sigma: Optional[float] = None,
110
+ **kwargs,
111
+ ) -> Tensor:
112
+ raise NotImplementedError("Diffusion class missing denoise_fn")
113
+
114
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
115
+ raise NotImplementedError("Diffusion class missing forward function")
116
+
117
+
118
+ class VDiffusion(Diffusion):
119
+ alias = "v"
120
+
121
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
122
+ super().__init__()
123
+ self.net = net
124
+ self.sigma_distribution = sigma_distribution
125
+
126
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
127
+ angle = sigmas * pi / 2
128
+ alpha = torch.cos(angle)
129
+ beta = torch.sin(angle)
130
+ return alpha, beta
131
+
132
+ def denoise_fn(
133
+ self,
134
+ x_noisy: Tensor,
135
+ sigmas: Optional[Tensor] = None,
136
+ sigma: Optional[float] = None,
137
+ **kwargs,
138
+ ) -> Tensor:
139
+ batch_size, device = x_noisy.shape[0], x_noisy.device
140
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
141
+ return self.net(x_noisy, sigmas, **kwargs)
142
+
143
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
144
+ batch_size, device = x.shape[0], x.device
145
+
146
+ # Sample amount of noise to add for each batch element
147
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
148
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
149
+
150
+ # Get noise
151
+ noise = default(noise, lambda: torch.randn_like(x))
152
+
153
+ # Combine input and noise weighted by half-circle
154
+ alpha, beta = self.get_alpha_beta(sigmas_padded)
155
+ x_noisy = x * alpha + noise * beta
156
+ x_target = noise * alpha - x * beta
157
+
158
+ # Denoise and return loss
159
+ x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
160
+ return F.mse_loss(x_denoised, x_target)
161
+
162
+
163
+ class KDiffusion(Diffusion):
164
+ """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
165
+
166
+ alias = "k"
167
+
168
+ def __init__(
169
+ self,
170
+ net: nn.Module,
171
+ *,
172
+ sigma_distribution: Distribution,
173
+ sigma_data: float, # data distribution standard deviation
174
+ dynamic_threshold: float = 0.0,
175
+ ):
176
+ super().__init__()
177
+ self.net = net
178
+ self.sigma_data = sigma_data
179
+ self.sigma_distribution = sigma_distribution
180
+ self.dynamic_threshold = dynamic_threshold
181
+
182
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
183
+ sigma_data = self.sigma_data
184
+ c_noise = torch.log(sigmas) * 0.25
185
+ sigmas = rearrange(sigmas, "b -> b 1 1")
186
+ c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2)
187
+ c_out = sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5
188
+ c_in = (sigmas**2 + sigma_data**2) ** -0.5
189
+ return c_skip, c_out, c_in, c_noise
190
+
191
+ def denoise_fn(
192
+ self,
193
+ x_noisy: Tensor,
194
+ sigmas: Optional[Tensor] = None,
195
+ sigma: Optional[float] = None,
196
+ **kwargs,
197
+ ) -> Tensor:
198
+ batch_size, device = x_noisy.shape[0], x_noisy.device
199
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
200
+
201
+ # Predict network output and add skip connection
202
+ c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
203
+ x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
204
+ x_denoised = c_skip * x_noisy + c_out * x_pred
205
+
206
+ return x_denoised
207
+
208
+ def loss_weight(self, sigmas: Tensor) -> Tensor:
209
+ # Computes weight depending on data distribution
210
+ return (sigmas**2 + self.sigma_data**2) * (sigmas * self.sigma_data) ** -2
211
+
212
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
213
+ batch_size, device = x.shape[0], x.device
214
+
216
+ # Sample amount of noise to add for each batch element
217
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
218
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
219
+
220
+ # Add noise to input
221
+ noise = default(noise, lambda: torch.randn_like(x))
222
+ x_noisy = x + sigmas_padded * noise
223
+
224
+ # Compute denoised values
225
+ x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
226
+
227
+ # Compute weighted loss
228
+ losses = F.mse_loss(x_denoised, x, reduction="none")
229
+ losses = reduce(losses, "b ... -> b", "mean")
230
+ losses = losses * self.loss_weight(sigmas)
231
+ loss = losses.mean()
232
+ return loss
233
+
234
+
235
+ class VKDiffusion(Diffusion):
236
+ alias = "vk"
237
+
238
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
239
+ super().__init__()
240
+ self.net = net
241
+ self.sigma_distribution = sigma_distribution
242
+
243
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
244
+ sigma_data = 1.0
245
+ sigmas = rearrange(sigmas, "b -> b 1 1")
246
+ c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2)
247
+ c_out = -sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5
248
+ c_in = (sigmas**2 + sigma_data**2) ** -0.5
249
+ return c_skip, c_out, c_in
250
+
251
+ def sigma_to_t(self, sigmas: Tensor) -> Tensor:
252
+ return sigmas.atan() / pi * 2
253
+
254
+ def t_to_sigma(self, t: Tensor) -> Tensor:
255
+ return (t * pi / 2).tan()
256
+
257
+ def denoise_fn(
258
+ self,
259
+ x_noisy: Tensor,
260
+ sigmas: Optional[Tensor] = None,
261
+ sigma: Optional[float] = None,
262
+ **kwargs,
263
+ ) -> Tensor:
264
+ batch_size, device = x_noisy.shape[0], x_noisy.device
265
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
266
+
267
+ # Predict network output and add skip connection
268
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
269
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
270
+ x_denoised = c_skip * x_noisy + c_out * x_pred
271
+ return x_denoised
272
+
273
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
274
+ batch_size, device = x.shape[0], x.device
275
+
276
+ # Sample amount of noise to add for each batch element
277
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
278
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
279
+
280
+ # Add noise to input
281
+ noise = default(noise, lambda: torch.randn_like(x))
282
+ x_noisy = x + sigmas_padded * noise
283
+
284
+ # Compute model output
285
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
286
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
287
+
288
+ # Compute v-objective target
289
+ v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
290
+
291
+ # Compute loss
292
+ loss = F.mse_loss(x_pred, v_target)
293
+ return loss
294
+
295
+
296
+ """
297
+ Diffusion Sampling
298
+ """
299
+
300
+ """ Schedules """
301
+
302
+
303
+ class Schedule(nn.Module):
304
+ """Interface used by different sampling schedules"""
305
+
306
+ def forward(self, num_steps: int, device: torch.device) -> Tensor:
307
+ raise NotImplementedError()
308
+
309
+
310
+ class LinearSchedule(Schedule):
311
+ def forward(self, num_steps: int, device: Any) -> Tensor:
312
+ sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
313
+ return sigmas
314
+
315
+
316
+ class KarrasSchedule(Schedule):
317
+ """https://arxiv.org/abs/2206.00364 equation 5"""
318
+
319
+ def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
320
+ super().__init__()
321
+ self.sigma_min = sigma_min
322
+ self.sigma_max = sigma_max
323
+ self.rho = rho
324
+
325
+ def forward(self, num_steps: int, device: Any) -> Tensor:
326
+ rho_inv = 1.0 / self.rho
327
+ steps = torch.arange(num_steps, device=device, dtype=torch.float32)
328
+ sigmas = (
329
+ self.sigma_max**rho_inv
330
+ + (steps / (num_steps - 1))
331
+ * (self.sigma_min**rho_inv - self.sigma_max**rho_inv)
332
+ ) ** self.rho
333
+ sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
334
+ return sigmas
335
+
336
+
337
+ """ Samplers """
338
+
339
+
340
+ class Sampler(nn.Module):
341
+ diffusion_types: List[Type[Diffusion]] = []
342
+
343
+ def forward(
344
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
345
+ ) -> Tensor:
346
+ raise NotImplementedError()
347
+
348
+ def inpaint(
349
+ self,
350
+ source: Tensor,
351
+ mask: Tensor,
352
+ fn: Callable,
353
+ sigmas: Tensor,
354
+ num_steps: int,
355
+ num_resamples: int,
356
+ ) -> Tensor:
357
+ raise NotImplementedError("Inpainting not available with current sampler")
358
+
359
+
360
+ class VSampler(Sampler):
361
+ diffusion_types = [VDiffusion]
362
+
363
+ def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
364
+ angle = sigma * pi / 2
365
+ alpha = cos(angle)
366
+ beta = sin(angle)
367
+ return alpha, beta
368
+
369
+ def forward(
370
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
371
+ ) -> Tensor:
372
+ x = sigmas[0] * noise
373
+ alpha, beta = self.get_alpha_beta(sigmas[0].item())
374
+
375
+ for i in range(num_steps - 1):
376
+ is_last = i == num_steps - 1
377
+
378
+ x_denoised = fn(x, sigma=sigmas[i])
379
+ x_pred = x * alpha - x_denoised * beta
380
+ x_eps = x * beta + x_denoised * alpha
381
+
382
+ if not is_last:
383
+ alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
384
+ x = x_pred * alpha + x_eps * beta
385
+
386
+ return x_pred
387
+
388
+
389
+ class KarrasSampler(Sampler):
390
+ """https://arxiv.org/abs/2206.00364 algorithm 1"""
391
+
392
+ diffusion_types = [KDiffusion, VKDiffusion]
393
+
394
+ def __init__(
395
+ self,
396
+ s_tmin: float = 0,
397
+ s_tmax: float = float("inf"),
398
+ s_churn: float = 0.0,
399
+ s_noise: float = 1.0,
400
+ ):
401
+ super().__init__()
402
+ self.s_tmin = s_tmin
403
+ self.s_tmax = s_tmax
404
+ self.s_noise = s_noise
405
+ self.s_churn = s_churn
406
+
407
+ def step(
408
+ self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
409
+ ) -> Tensor:
410
+ """Algorithm 2 (step)"""
411
+ # Select temporarily increased noise level
412
+ sigma_hat = sigma + gamma * sigma
413
+ # Add noise to move from sigma to sigma_hat
414
+ epsilon = self.s_noise * torch.randn_like(x)
415
+ x_hat = x + sqrt(sigma_hat**2 - sigma**2) * epsilon
416
+ # Evaluate ∂x/∂sigma at sigma_hat
417
+ d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
418
+ # Take euler step from sigma_hat to sigma_next
419
+ x_next = x_hat + (sigma_next - sigma_hat) * d
420
+ # Second order correction
421
+ if sigma_next != 0:
422
+ model_out_next = fn(x_next, sigma=sigma_next)
423
+ d_prime = (x_next - model_out_next) / sigma_next
424
+ x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
425
+ return x_next
426
+
427
+ def forward(
428
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
429
+ ) -> Tensor:
430
+ x = sigmas[0] * noise
431
+ # Compute gammas
432
+ gammas = torch.where(
433
+ (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
434
+ min(self.s_churn / num_steps, sqrt(2) - 1),
435
+ 0.0,
436
+ )
437
+ # Denoise to sample
438
+ for i in range(num_steps - 1):
439
+ x = self.step(
440
+ x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
441
+ )
442
+
443
+ return x
444
+
445
+
446
+ class AEulerSampler(Sampler):
447
+ diffusion_types = [KDiffusion, VKDiffusion]
448
+
449
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
450
+ sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
451
+ sigma_down = sqrt(sigma_next**2 - sigma_up**2)
452
+ return sigma_up, sigma_down
453
+
454
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
455
+ # Sigma steps
456
+ sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
457
+ # Derivative at sigma (∂x/∂sigma)
458
+ d = (x - fn(x, sigma=sigma)) / sigma
459
+ # Euler method
460
+ x_next = x + d * (sigma_down - sigma)
461
+ # Add randomness
462
+ x_next = x_next + torch.randn_like(x) * sigma_up
463
+ return x_next
464
+
465
+ def forward(
466
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
467
+ ) -> Tensor:
468
+ x = sigmas[0] * noise
469
+ # Denoise to sample
470
+ for i in range(num_steps - 1):
471
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
472
+ return x
473
+
474
+
475
+ class ADPM2Sampler(Sampler):
476
+ """https://www.desmos.com/calculator/jbxjlqd9mb"""
477
+
478
+ diffusion_types = [KDiffusion, VKDiffusion]
479
+
480
+ def __init__(self, rho: float = 1.0):
481
+ super().__init__()
482
+ self.rho = rho
483
+
484
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
485
+ r = self.rho
486
+ sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
487
+ sigma_down = sqrt(sigma_next**2 - sigma_up**2)
488
+ sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
489
+ return sigma_up, sigma_down, sigma_mid
490
+
491
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
492
+ # Sigma steps
493
+ sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
494
+ # Derivative at sigma (∂x/∂sigma)
495
+ d = (x - fn(x, sigma=sigma)) / sigma
496
+ # Denoise to midpoint
497
+ x_mid = x + d * (sigma_mid - sigma)
498
+ # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
499
+ d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
500
+ # Denoise to next
501
+ x = x + d_mid * (sigma_down - sigma)
502
+ # Add randomness
503
+ x_next = x + torch.randn_like(x) * sigma_up
504
+ return x_next
505
+
506
+ def forward(
507
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
508
+ ) -> Tensor:
509
+ x = sigmas[0] * noise
510
+ # Denoise to sample
511
+ for i in range(num_steps - 1):
512
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
513
+ return x
514
+
515
+ def inpaint(
516
+ self,
517
+ source: Tensor,
518
+ mask: Tensor,
519
+ fn: Callable,
520
+ sigmas: Tensor,
521
+ num_steps: int,
522
+ num_resamples: int,
523
+ ) -> Tensor:
524
+ x = sigmas[0] * torch.randn_like(source)
525
+
526
+ for i in range(num_steps - 1):
527
+ # Noise source to current noise level
528
+ source_noisy = source + sigmas[i] * torch.randn_like(source)
529
+ for r in range(num_resamples):
530
+ # Merge noisy source and current then denoise
531
+ x = source_noisy * mask + x * ~mask
532
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
533
+ # Renoise if not last resample step
534
+ if r < num_resamples - 1:
535
+ sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
536
+ x = x + sigma * torch.randn_like(x)
537
+
538
+ return source * mask + x * ~mask
539
+
540
+
541
+ """ Main Classes """
542
+
543
+
544
+ class DiffusionSampler(nn.Module):
545
+ def __init__(
546
+ self,
547
+ diffusion: Diffusion,
548
+ *,
549
+ sampler: Sampler,
550
+ sigma_schedule: Schedule,
551
+ num_steps: Optional[int] = None,
552
+ clamp: bool = True,
553
+ ):
554
+ super().__init__()
555
+ self.denoise_fn = diffusion.denoise_fn
556
+ self.sampler = sampler
557
+ self.sigma_schedule = sigma_schedule
558
+ self.num_steps = num_steps
559
+ self.clamp = clamp
560
+
561
+ # Check sampler is compatible with diffusion type
562
+ sampler_class = sampler.__class__.__name__
563
+ diffusion_class = diffusion.__class__.__name__
564
+ message = f"{sampler_class} incompatible with {diffusion_class}"
565
+ assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
566
+
567
+ def forward(
568
+ self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
569
+ ) -> Tensor:
570
+ device = noise.device
571
+ num_steps = default(num_steps, self.num_steps) # type: ignore
572
+ assert exists(num_steps), "Parameter `num_steps` must be provided"
573
+ # Compute sigmas using schedule
574
+ sigmas = self.sigma_schedule(num_steps, device)
575
+ # Append additional kwargs to denoise function (used e.g. for conditional unet)
576
+ fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
577
+ # Sample using sampler
578
+ x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
579
+ x = x.clamp(-1.0, 1.0) if self.clamp else x
580
+ return x
581
+
582
+
583
+ class DiffusionInpainter(nn.Module):
584
+ def __init__(
585
+ self,
586
+ diffusion: Diffusion,
587
+ *,
588
+ num_steps: int,
589
+ num_resamples: int,
590
+ sampler: Sampler,
591
+ sigma_schedule: Schedule,
592
+ ):
593
+ super().__init__()
594
+ self.denoise_fn = diffusion.denoise_fn
595
+ self.num_steps = num_steps
596
+ self.num_resamples = num_resamples
597
+ self.inpaint_fn = sampler.inpaint
598
+ self.sigma_schedule = sigma_schedule
599
+
600
+ @torch.no_grad()
601
+ def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
602
+ x = self.inpaint_fn(
603
+ source=inpaint,
604
+ mask=inpaint_mask,
605
+ fn=self.denoise_fn,
606
+ sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
607
+ num_steps=self.num_steps,
608
+ num_resamples=self.num_resamples,
609
+ )
610
+ return x
611
+
612
+
613
+ def sequential_mask(like: Tensor, start: int) -> Tensor:
614
+ length, device = like.shape[2], like.device
615
+ mask = torch.ones_like(like, dtype=torch.bool)
616
+ mask[:, :, start:] = torch.zeros((length - start,), device=device)
617
+ return mask
618
+
619
+
620
+ class SpanBySpanComposer(nn.Module):
621
+ def __init__(
622
+ self,
623
+ inpainter: DiffusionInpainter,
624
+ *,
625
+ num_spans: int,
626
+ ):
627
+ super().__init__()
628
+ self.inpainter = inpainter
629
+ self.num_spans = num_spans
630
+
631
+ def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
632
+ half_length = start.shape[2] // 2
633
+
634
+ spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
635
+ # Inpaint second half from first half
636
+ inpaint = torch.zeros_like(start)
637
+ inpaint[:, :, :half_length] = start[:, :, half_length:]
638
+ inpaint_mask = sequential_mask(like=start, start=half_length)
639
+
640
+ for i in range(self.num_spans):
641
+ # Inpaint second half
642
+ span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
643
+ # Replace first half with generated second half
644
+ second_half = span[:, :, half_length:]
645
+ inpaint[:, :, :half_length] = second_half
646
+ # Save generated span
647
+ spans.append(second_half)
648
+
649
+ return torch.cat(spans, dim=2)
650
+
651
+
652
+ class XDiffusion(nn.Module):
653
+ def __init__(self, type: str, net: nn.Module, **kwargs):
654
+ super().__init__()
655
+
656
+ diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
657
+ aliases = [t.alias for t in diffusion_classes] # type: ignore
658
+ message = f"type='{type}' must be one of {*aliases,}"
659
+ assert type in aliases, message
660
+ self.net = net
661
+
662
+ for XDiffusion in diffusion_classes:
663
+ if XDiffusion.alias == type: # type: ignore
664
+ self.diffusion = XDiffusion(net=net, **kwargs)
665
+
666
+ def forward(self, *args, **kwargs) -> Tensor:
667
+ return self.diffusion(*args, **kwargs)
668
+
669
+ def sample(
670
+ self,
671
+ noise: Tensor,
672
+ num_steps: int,
673
+ sigma_schedule: Schedule,
674
+ sampler: Sampler,
675
+ clamp: bool,
676
+ **kwargs,
677
+ ) -> Tensor:
678
+ diffusion_sampler = DiffusionSampler(
679
+ diffusion=self.diffusion,
680
+ sampler=sampler,
681
+ sigma_schedule=sigma_schedule,
682
+ num_steps=num_steps,
683
+ clamp=clamp,
684
+ )
685
+ return diffusion_sampler(noise, **kwargs)
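A rough sketch of how the classes above fit together at inference time (illustrative only; the stand-in denoiser, tensor shape, and hyperparameters are assumptions, not values used elsewhere in this repo):

import torch
import torch.nn as nn
from Modules.diffusion.sampler import (
    ADPM2Sampler, DiffusionSampler, KDiffusion, KarrasSchedule, LogNormalDistribution,
)

class DummyNet(nn.Module):
    # stand-in denoiser: any module accepting (x_noisy, sigmas, **kwargs) fits here
    def forward(self, x, sigmas, **kwargs):
        return x

diffusion = KDiffusion(
    net=DummyNet(),
    sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
    sigma_data=0.2,
)
sampler = DiffusionSampler(
    diffusion,
    sampler=ADPM2Sampler(rho=1.0),
    sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
    num_steps=5,
)
out = sampler(torch.randn(1, 256, 100))   # clamped to [-1, 1] since clamp defaults to True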
src/Modules/diffusion/utils.py ADDED
@@ -0,0 +1,83 @@
1
+ from functools import reduce
2
+ from inspect import isfunction
3
+ from math import ceil, floor, log2, pi
4
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Generator, Tensor
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
+ def iff(condition: bool, value: T) -> Optional[T]:
20
+ return value if condition else None
21
+
22
+
23
+ def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
+ return isinstance(obj, list) or isinstance(obj, tuple)
25
+
26
+
27
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
28
+ if exists(val):
29
+ return val
30
+ return d() if isfunction(d) else d
31
+
32
+
33
+ def to_list(val: Union[T, Sequence[T]]) -> List[T]:
34
+ if isinstance(val, tuple):
35
+ return list(val)
36
+ if isinstance(val, list):
37
+ return val
38
+ return [val] # type: ignore
39
+
40
+
41
+ def prod(vals: Sequence[int]) -> int:
42
+ return reduce(lambda x, y: x * y, vals)
43
+
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2**z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+
52
+ def rand_bool(shape, proba, device=None):
53
+ if proba == 1:
54
+ return torch.ones(shape, device=device, dtype=torch.bool)
55
+ elif proba == 0:
56
+ return torch.zeros(shape, device=device, dtype=torch.bool)
57
+ else:
58
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
59
+
60
+
61
+ """
62
+ Kwargs Utils
63
+ """
64
+
65
+
66
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
67
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
68
+ for key in d.keys():
69
+ no_prefix = int(not key.startswith(prefix))
70
+ return_dicts[no_prefix][key] = d[key]
71
+ return return_dicts
72
+
73
+
74
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
75
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
76
+ if keep_prefix:
77
+ return kwargs_with_prefix, kwargs
78
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
79
+ return kwargs_no_prefix, kwargs
80
+
81
+
82
+ def prefix_dict(prefix: str, d: Dict) -> Dict:
83
+ return {prefix + str(k): v for k, v in d.items()}
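The kwargs helpers above split a dict by key prefix and optionally strip that prefix; a small sketch of the expected behaviour (illustrative only, import path assumed):

from Modules.diffusion.utils import groupby

kwargs = {"attention_heads": 8, "attention_features": 64, "channels": 128}
attention_kwargs, rest = groupby("attention_", kwargs)
# attention_kwargs == {"heads": 8, "features": 64}
# rest == {"channels": 128}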
src/Modules/discriminators.py ADDED
@@ -0,0 +1,267 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, spectral_norm
6
+
7
+ from .utils import get_padding
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+
12
+ def stft(x, fft_size, hop_size, win_length, window):
13
+ """Perform STFT and convert to magnitude spectrogram.
14
+ Args:
15
+ x (Tensor): Input signal tensor (B, T).
16
+ fft_size (int): FFT size.
17
+ hop_size (int): Hop size.
18
+ win_length (int): Window length.
19
+ window (Tensor): Window tensor applied to each frame.
20
+ Returns:
21
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
22
+ """
23
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
24
+ # with return_complex=True the magnitude is taken directly from the complex STFT
26
+
27
+ return torch.abs(x_stft).transpose(2, 1)
28
+
29
+
30
+ class SpecDiscriminator(nn.Module):
31
+ """docstring for Discriminator."""
32
+
33
+ def __init__(
34
+ self,
35
+ fft_size=1024,
36
+ shift_size=120,
37
+ win_length=600,
38
+ window="hann_window",
39
+ use_spectral_norm=False,
40
+ ):
41
+ super(SpecDiscriminator, self).__init__()
42
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
43
+ self.fft_size = fft_size
44
+ self.shift_size = shift_size
45
+ self.win_length = win_length
46
+ self.window = getattr(torch, window)(win_length)
47
+ self.discriminators = nn.ModuleList(
48
+ [
49
+ norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
50
+ norm_f(
51
+ nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
52
+ ),
53
+ norm_f(
54
+ nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
55
+ ),
56
+ norm_f(
57
+ nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
58
+ ),
59
+ norm_f(
60
+ nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
61
+ ),
62
+ ]
63
+ )
64
+
65
+ self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
66
+
67
+ def forward(self, y):
68
+ fmap = []
69
+ y = y.squeeze(1)
70
+ y = stft(
71
+ y,
72
+ self.fft_size,
73
+ self.shift_size,
74
+ self.win_length,
75
+ self.window.to(y.get_device()),
76
+ )
77
+ y = y.unsqueeze(1)
78
+ for i, d in enumerate(self.discriminators):
79
+ y = d(y)
80
+ y = F.leaky_relu(y, LRELU_SLOPE)
81
+ fmap.append(y)
82
+
83
+ y = self.out(y)
84
+ fmap.append(y)
85
+
86
+ return torch.flatten(y, 1, -1), fmap
87
+
88
+
89
+ class MultiResSpecDiscriminator(torch.nn.Module):
90
+ def __init__(
91
+ self,
92
+ fft_sizes=[1024, 2048, 512],
93
+ hop_sizes=[120, 240, 50],
94
+ win_lengths=[600, 1200, 240],
95
+ window="hann_window",
96
+ ):
97
+ super(MultiResSpecDiscriminator, self).__init__()
98
+ self.discriminators = nn.ModuleList(
99
+ [
100
+ SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
101
+ SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
102
+ SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window),
103
+ ]
104
+ )
105
+
106
+ def forward(self, y, y_hat):
107
+ y_d_rs = []
108
+ y_d_gs = []
109
+ fmap_rs = []
110
+ fmap_gs = []
111
+ for i, d in enumerate(self.discriminators):
112
+ y_d_r, fmap_r = d(y)
113
+ y_d_g, fmap_g = d(y_hat)
114
+ y_d_rs.append(y_d_r)
115
+ fmap_rs.append(fmap_r)
116
+ y_d_gs.append(y_d_g)
117
+ fmap_gs.append(fmap_g)
118
+
119
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
120
+
121
+
122
+ class DiscriminatorP(torch.nn.Module):
123
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
124
+ super(DiscriminatorP, self).__init__()
125
+ self.period = period
126
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
127
+ self.convs = nn.ModuleList(
128
+ [
129
+ norm_f(
130
+ Conv2d(
131
+ 1,
132
+ 32,
133
+ (kernel_size, 1),
134
+ (stride, 1),
135
+ padding=(get_padding(5, 1), 0),
136
+ )
137
+ ),
138
+ norm_f(
139
+ Conv2d(
140
+ 32,
141
+ 128,
142
+ (kernel_size, 1),
143
+ (stride, 1),
144
+ padding=(get_padding(5, 1), 0),
145
+ )
146
+ ),
147
+ norm_f(
148
+ Conv2d(
149
+ 128,
150
+ 512,
151
+ (kernel_size, 1),
152
+ (stride, 1),
153
+ padding=(get_padding(5, 1), 0),
154
+ )
155
+ ),
156
+ norm_f(
157
+ Conv2d(
158
+ 512,
159
+ 1024,
160
+ (kernel_size, 1),
161
+ (stride, 1),
162
+ padding=(get_padding(5, 1), 0),
163
+ )
164
+ ),
165
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
166
+ ]
167
+ )
168
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
169
+
170
+ def forward(self, x):
171
+ fmap = []
172
+
173
+ # 1d to 2d
174
+ b, c, t = x.shape
175
+ if t % self.period != 0: # pad first
176
+ n_pad = self.period - (t % self.period)
177
+ x = F.pad(x, (0, n_pad), "reflect")
178
+ t = t + n_pad
179
+ x = x.view(b, c, t // self.period, self.period)
180
+
181
+ for l in self.convs:
182
+ x = l(x)
183
+ x = F.leaky_relu(x, LRELU_SLOPE)
184
+ fmap.append(x)
185
+ x = self.conv_post(x)
186
+ fmap.append(x)
187
+ x = torch.flatten(x, 1, -1)
188
+
189
+ return x, fmap
190
+
191
+
192
+ class MultiPeriodDiscriminator(torch.nn.Module):
193
+ def __init__(self):
194
+ super(MultiPeriodDiscriminator, self).__init__()
195
+ self.discriminators = nn.ModuleList(
196
+ [
197
+ DiscriminatorP(2),
198
+ DiscriminatorP(3),
199
+ DiscriminatorP(5),
200
+ DiscriminatorP(7),
201
+ DiscriminatorP(11),
202
+ ]
203
+ )
204
+
205
+ def forward(self, y, y_hat):
206
+ y_d_rs = []
207
+ y_d_gs = []
208
+ fmap_rs = []
209
+ fmap_gs = []
210
+ for i, d in enumerate(self.discriminators):
211
+ y_d_r, fmap_r = d(y)
212
+ y_d_g, fmap_g = d(y_hat)
213
+ y_d_rs.append(y_d_r)
214
+ fmap_rs.append(fmap_r)
215
+ y_d_gs.append(y_d_g)
216
+ fmap_gs.append(fmap_g)
217
+
218
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
219
+
220
+
221
+ class WavLMDiscriminator(nn.Module):
222
+ """docstring for Discriminator."""
223
+
224
+ def __init__(
225
+ self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
226
+ ):
227
+ super(WavLMDiscriminator, self).__init__()
228
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
229
+ self.pre = norm_f(
230
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
231
+ )
232
+
233
+ self.convs = nn.ModuleList(
234
+ [
235
+ norm_f(
236
+ nn.Conv1d(
237
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
238
+ )
239
+ ),
240
+ norm_f(
241
+ nn.Conv1d(
242
+ initial_channel * 2,
243
+ initial_channel * 4,
244
+ kernel_size=5,
245
+ padding=2,
246
+ )
247
+ ),
248
+ norm_f(
249
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
250
+ ),
251
+ ]
252
+ )
253
+
254
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
255
+
256
+ def forward(self, x):
257
+ x = self.pre(x)
258
+
259
+ fmap = []
260
+ for l in self.convs:
261
+ x = l(x)
262
+ x = F.leaky_relu(x, LRELU_SLOPE)
263
+ fmap.append(x)
264
+ x = self.conv_post(x)
265
+ x = torch.flatten(x, 1, -1)
266
+
267
+ return x
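A shape-level sketch of how the multi-period discriminator above is typically called (illustrative only; batch size and sample count are assumptions):

import torch
from Modules.discriminators import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()
real = torch.randn(2, 1, 24000)   # (batch, 1, samples)
fake = torch.randn(2, 1, 24000)
scores_real, scores_fake, fmaps_real, fmaps_fake = mpd(real, fake)
assert len(scores_real) == 5      # one score per period (2, 3, 5, 7, 11)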
src/Modules/hifigan.py ADDED
@@ -0,0 +1,643 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
+ class AdaIN1d(nn.Module):
16
+ def __init__(self, style_dim, num_features):
17
+ super().__init__()
18
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
19
+ self.fc = nn.Linear(style_dim, num_features * 2)
20
+
21
+ def forward(self, x, s):
22
+ h = self.fc(s)
23
+ h = h.view(h.size(0), h.size(1), 1)
24
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
25
+ return (1 + gamma) * self.norm(x) + beta
26
+
27
+
28
+ class AdaINResBlock1(torch.nn.Module):
29
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
30
+ super(AdaINResBlock1, self).__init__()
31
+ self.convs1 = nn.ModuleList(
32
+ [
33
+ weight_norm(
34
+ Conv1d(
35
+ channels,
36
+ channels,
37
+ kernel_size,
38
+ 1,
39
+ dilation=dilation[0],
40
+ padding=get_padding(kernel_size, dilation[0]),
41
+ )
42
+ ),
43
+ weight_norm(
44
+ Conv1d(
45
+ channels,
46
+ channels,
47
+ kernel_size,
48
+ 1,
49
+ dilation=dilation[1],
50
+ padding=get_padding(kernel_size, dilation[1]),
51
+ )
52
+ ),
53
+ weight_norm(
54
+ Conv1d(
55
+ channels,
56
+ channels,
57
+ kernel_size,
58
+ 1,
59
+ dilation=dilation[2],
60
+ padding=get_padding(kernel_size, dilation[2]),
61
+ )
62
+ ),
63
+ ]
64
+ )
65
+ self.convs1.apply(init_weights)
66
+
67
+ self.convs2 = nn.ModuleList(
68
+ [
69
+ weight_norm(
70
+ Conv1d(
71
+ channels,
72
+ channels,
73
+ kernel_size,
74
+ 1,
75
+ dilation=1,
76
+ padding=get_padding(kernel_size, 1),
77
+ )
78
+ ),
79
+ weight_norm(
80
+ Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size,
84
+ 1,
85
+ dilation=1,
86
+ padding=get_padding(kernel_size, 1),
87
+ )
88
+ ),
89
+ weight_norm(
90
+ Conv1d(
91
+ channels,
92
+ channels,
93
+ kernel_size,
94
+ 1,
95
+ dilation=1,
96
+ padding=get_padding(kernel_size, 1),
97
+ )
98
+ ),
99
+ ]
100
+ )
101
+ self.convs2.apply(init_weights)
102
+
103
+ self.adain1 = nn.ModuleList(
104
+ [
105
+ AdaIN1d(style_dim, channels),
106
+ AdaIN1d(style_dim, channels),
107
+ AdaIN1d(style_dim, channels),
108
+ ]
109
+ )
110
+
111
+ self.adain2 = nn.ModuleList(
112
+ [
113
+ AdaIN1d(style_dim, channels),
114
+ AdaIN1d(style_dim, channels),
115
+ AdaIN1d(style_dim, channels),
116
+ ]
117
+ )
118
+
119
+ self.alpha1 = nn.ParameterList(
120
+ [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]
121
+ )
122
+ self.alpha2 = nn.ParameterList(
123
+ [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]
124
+ )
125
+
126
+ def forward(self, x, s):
127
+ for c1, c2, n1, n2, a1, a2 in zip(
128
+ self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2
129
+ ):
130
+ xt = n1(x, s)
131
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
132
+ xt = c1(xt)
133
+ xt = n2(xt, s)
134
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
135
+ xt = c2(xt)
136
+ x = xt + x
137
+ return x
138
+
139
+ def remove_weight_norm(self):
140
+ for l in self.convs1:
141
+ remove_weight_norm(l)
142
+ for l in self.convs2:
143
+ remove_weight_norm(l)
144
+
145
+
146
+ class SineGen(torch.nn.Module):
147
+ """Definition of sine generator
148
+ SineGen(samp_rate, harmonic_num = 0,
149
+ sine_amp = 0.1, noise_std = 0.003,
150
+ voiced_threshold = 0,
151
+ flag_for_pulse=False)
152
+ samp_rate: sampling rate in Hz
153
+ harmonic_num: number of harmonic overtones (default 0)
154
+ sine_amp: amplitude of sine waveform (default 0.1)
155
+ noise_std: std of Gaussian noise (default 0.003)
156
+ voiced_threshold: F0 threshold for U/V classification (default 0)
157
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
158
+ Note: when flag_for_pulse is True, the first time step of a voiced
159
+ segment is always sin(np.pi) or cos(0)
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ samp_rate,
165
+ upsample_scale,
166
+ harmonic_num=0,
167
+ sine_amp=0.1,
168
+ noise_std=0.003,
169
+ voiced_threshold=0,
170
+ flag_for_pulse=False,
171
+ ):
172
+ super(SineGen, self).__init__()
173
+ self.sine_amp = sine_amp
174
+ self.noise_std = noise_std
175
+ self.harmonic_num = harmonic_num
176
+ self.dim = self.harmonic_num + 1
177
+ self.sampling_rate = samp_rate
178
+ self.voiced_threshold = voiced_threshold
179
+ self.flag_for_pulse = flag_for_pulse
180
+ self.upsample_scale = upsample_scale
181
+
182
+ def _f02uv(self, f0):
183
+ # generate uv signal
184
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
185
+ return uv
186
+
187
+ def _f02sine(self, f0_values):
188
+ """f0_values: (batchsize, length, dim)
189
+ where dim indicates fundamental tone and overtones
190
+ """
191
+ # convert F0 to values in rad. The integer part n can be ignored
192
+ # because 2 * np.pi * n doesn't affect phase
193
+ rad_values = (f0_values / self.sampling_rate) % 1
194
+
195
+ # initial phase noise (no noise for fundamental component)
196
+ rand_ini = torch.rand(
197
+ f0_values.shape[0], f0_values.shape[2], device=f0_values.device
198
+ )
199
+ rand_ini[:, 0] = 0
200
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
201
+
202
+ # instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad_i)
203
+ if not self.flag_for_pulse:
204
+ # # for normal case
205
+
206
+ # # To prevent torch.cumsum numerical overflow,
207
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
208
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
209
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
210
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
211
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
212
+ # cumsum_shift = torch.zeros_like(rad_values)
213
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
214
+
215
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
216
+ rad_values = torch.nn.functional.interpolate(
217
+ rad_values.transpose(1, 2),
218
+ scale_factor=1 / self.upsample_scale,
219
+ mode="linear",
220
+ ).transpose(1, 2)
221
+
222
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
223
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
224
+ # cumsum_shift = torch.zeros_like(rad_values)
225
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
226
+
227
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
228
+ phase = torch.nn.functional.interpolate(
229
+ phase.transpose(1, 2) * self.upsample_scale,
230
+ scale_factor=self.upsample_scale,
231
+ mode="linear",
232
+ ).transpose(1, 2)
233
+ sines = torch.sin(phase)
234
+
235
+ else:
236
+ # If necessary, make sure that the first time step of every
237
+ # voiced segments is sin(pi) or cos(0)
238
+ # This is used for pulse-train generation
239
+
240
+ # identify the last time step in unvoiced segments
241
+ uv = self._f02uv(f0_values)
242
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
243
+ uv_1[:, -1, :] = 1
244
+ u_loc = (uv < 1) * (uv_1 > 0)
245
+
246
+ # get the instantaneous phase
247
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
248
+ # different batch needs to be processed differently
249
+ for idx in range(f0_values.shape[0]):
250
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
251
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
252
+ # stores the accumulation of i.phase within
253
+ # each voiced segments
254
+ tmp_cumsum[idx, :, :] = 0
255
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
256
+
257
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
258
+ # within the previous voiced segment.
259
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
260
+
261
+ # get the sines
262
+ sines = torch.cos(i_phase * 2 * np.pi)
263
+ return sines
264
+
265
+ def forward(self, f0):
266
+ """sine_tensor, uv = forward(f0)
267
+ input F0: tensor(batchsize=1, length, dim=1)
268
+ f0 for unvoiced steps should be 0
269
+ output sine_tensor: tensor(batchsize=1, length, dim)
270
+ output uv: tensor(batchsize=1, length, 1)
271
+ """
272
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
273
+ # fundamental component
274
+ fn = torch.multiply(
275
+ f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
276
+ )
277
+
278
+ # generate sine waveforms
279
+ sine_waves = self._f02sine(fn) * self.sine_amp
280
+
281
+ # generate uv signal
282
+ # uv = torch.ones(f0.shape)
283
+ # uv = uv * (f0 > self.voiced_threshold)
284
+ uv = self._f02uv(f0)
285
+
286
+ # noise: for unvoiced should be similar to sine_amp
287
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
288
+ # . for voiced regions is self.noise_std
289
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
290
+ noise = noise_amp * torch.randn_like(sine_waves)
291
+
292
+ # first: set the unvoiced part to 0 by uv
293
+ # then: additive noise
294
+ sine_waves = sine_waves * uv + noise
295
+ return sine_waves, uv, noise
296
+
297
+
298
+ class SourceModuleHnNSF(torch.nn.Module):
299
+ """SourceModule for hn-nsf
300
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
301
+ add_noise_std=0.003, voiced_threshod=0)
302
+ sampling_rate: sampling_rate in Hz
303
+ harmonic_num: number of harmonic above F0 (default: 0)
304
+ sine_amp: amplitude of sine source signal (default: 0.1)
305
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
306
+ note that amplitude of noise in unvoiced is decided
307
+ by sine_amp
308
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
309
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
310
+ F0_sampled (batchsize, length, 1)
311
+ Sine_source (batchsize, length, 1)
312
+ noise_source (batchsize, length, 1)
313
+ uv (batchsize, length, 1)
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ sampling_rate,
319
+ upsample_scale,
320
+ harmonic_num=0,
321
+ sine_amp=0.1,
322
+ add_noise_std=0.003,
323
+ voiced_threshod=0,
324
+ ):
325
+ super(SourceModuleHnNSF, self).__init__()
326
+
327
+ self.sine_amp = sine_amp
328
+ self.noise_std = add_noise_std
329
+
330
+ # to produce sine waveforms
331
+ self.l_sin_gen = SineGen(
332
+ sampling_rate,
333
+ upsample_scale,
334
+ harmonic_num,
335
+ sine_amp,
336
+ add_noise_std,
337
+ voiced_threshod,
338
+ )
339
+
340
+ # to merge source harmonics into a single excitation
341
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
342
+ self.l_tanh = torch.nn.Tanh()
343
+
344
+ def forward(self, x):
345
+ """
346
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
347
+ F0_sampled (batchsize, length, 1)
348
+ Sine_source (batchsize, length, 1)
349
+ noise_source (batchsize, length, 1)
350
+ """
351
+ # source for harmonic branch
352
+ with torch.no_grad():
353
+ sine_wavs, uv, _ = self.l_sin_gen(x)
354
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
355
+
356
+ # source for noise branch, in the same shape as uv
357
+ noise = torch.randn_like(uv) * self.sine_amp / 3
358
+ return sine_merge, noise, uv
359
+
360
+
361
+ def padDiff(x):
362
+ return F.pad(
363
+ F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
364
+ )
365
+
366
+
367
+ class Generator(torch.nn.Module):
368
+ def __init__(
369
+ self,
370
+ style_dim,
371
+ resblock_kernel_sizes,
372
+ upsample_rates,
373
+ upsample_initial_channel,
374
+ resblock_dilation_sizes,
375
+ upsample_kernel_sizes,
376
+ ):
377
+ super(Generator, self).__init__()
378
+ self.num_kernels = len(resblock_kernel_sizes)
379
+ self.num_upsamples = len(upsample_rates)
380
+ resblock = AdaINResBlock1
381
+
382
+ self.m_source = SourceModuleHnNSF(
383
+ sampling_rate=24000,
384
+ upsample_scale=np.prod(upsample_rates),
385
+ harmonic_num=8,
386
+ voiced_threshod=10,
387
+ )
388
+
389
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
390
+ self.noise_convs = nn.ModuleList()
391
+ self.ups = nn.ModuleList()
392
+ self.noise_res = nn.ModuleList()
393
+
394
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
395
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
396
+
397
+ self.ups.append(
398
+ weight_norm(
399
+ ConvTranspose1d(
400
+ upsample_initial_channel // (2**i),
401
+ upsample_initial_channel // (2 ** (i + 1)),
402
+ k,
403
+ u,
404
+ padding=(u // 2 + u % 2),
405
+ output_padding=u % 2,
406
+ )
407
+ )
408
+ )
409
+
410
+ if i + 1 < len(upsample_rates): #
411
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
412
+ self.noise_convs.append(
413
+ Conv1d(
414
+ 1,
415
+ c_cur,
416
+ kernel_size=stride_f0 * 2,
417
+ stride=stride_f0,
418
+ padding=(stride_f0 + 1) // 2,
419
+ )
420
+ )
421
+ self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim))
422
+ else:
423
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
424
+ self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim))
425
+
426
+ self.resblocks = nn.ModuleList()
427
+
428
+ self.alphas = nn.ParameterList()
429
+ self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
430
+
431
+ for i in range(len(self.ups)):
432
+ ch = upsample_initial_channel // (2 ** (i + 1))
433
+ self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
434
+
435
+ for j, (k, d) in enumerate(
436
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
437
+ ):
438
+ self.resblocks.append(resblock(ch, k, d, style_dim))
439
+
440
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
441
+ self.ups.apply(init_weights)
442
+ self.conv_post.apply(init_weights)
443
+
444
+ def forward(self, x, s, f0):
445
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
446
+
447
+ har_source, noi_source, uv = self.m_source(f0)
448
+ har_source = har_source.transpose(1, 2)
449
+
450
+ for i in range(self.num_upsamples):
451
+ x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
452
+ x_source = self.noise_convs[i](har_source)
453
+ x_source = self.noise_res[i](x_source, s)
454
+
455
+ x = self.ups[i](x)
456
+ x = x + x_source
457
+
458
+ xs = None
459
+ for j in range(self.num_kernels):
460
+ if xs is None:
461
+ xs = self.resblocks[i * self.num_kernels + j](x, s)
462
+ else:
463
+ xs += self.resblocks[i * self.num_kernels + j](x, s)
464
+ x = xs / self.num_kernels
465
+ x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)
466
+ x = self.conv_post(x)
467
+ x = torch.tanh(x)
468
+
469
+ return x
470
+
471
+ def remove_weight_norm(self):
472
+ print("Removing weight norm...")
473
+ for l in self.ups:
474
+ remove_weight_norm(l)
475
+ for l in self.resblocks:
476
+ l.remove_weight_norm()
477
+ # note: this Generator defines no conv_pre layer, so only conv_post carries weight norm
478
+ remove_weight_norm(self.conv_post)
479
+
480
+
481
+ class AdainResBlk1d(nn.Module):
482
+ def __init__(
483
+ self,
484
+ dim_in,
485
+ dim_out,
486
+ style_dim=64,
487
+ actv=nn.LeakyReLU(0.2),
488
+ upsample="none",
489
+ dropout_p=0.0,
490
+ ):
491
+ super().__init__()
492
+ self.actv = actv
493
+ self.upsample_type = upsample
494
+ self.upsample = UpSample1d(upsample)
495
+ self.learned_sc = dim_in != dim_out
496
+ self._build_weights(dim_in, dim_out, style_dim)
497
+ self.dropout = nn.Dropout(dropout_p)
498
+
499
+ if upsample == "none":
500
+ self.pool = nn.Identity()
501
+ else:
502
+ self.pool = weight_norm(
503
+ nn.ConvTranspose1d(
504
+ dim_in,
505
+ dim_in,
506
+ kernel_size=3,
507
+ stride=2,
508
+ groups=dim_in,
509
+ padding=1,
510
+ output_padding=1,
511
+ )
512
+ )
513
+
514
+ def _build_weights(self, dim_in, dim_out, style_dim):
515
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
516
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
517
+ self.norm1 = AdaIN1d(style_dim, dim_in)
518
+ self.norm2 = AdaIN1d(style_dim, dim_out)
519
+ if self.learned_sc:
520
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
521
+
522
+ def _shortcut(self, x):
523
+ x = self.upsample(x)
524
+ if self.learned_sc:
525
+ x = self.conv1x1(x)
526
+ return x
527
+
528
+ def _residual(self, x, s):
529
+ x = self.norm1(x, s)
530
+ x = self.actv(x)
531
+ x = self.pool(x)
532
+ x = self.conv1(self.dropout(x))
533
+ x = self.norm2(x, s)
534
+ x = self.actv(x)
535
+ x = self.conv2(self.dropout(x))
536
+ return x
537
+
538
+ def forward(self, x, s):
539
+ out = self._residual(x, s)
540
+ out = (out + self._shortcut(x)) / math.sqrt(2)
541
+ return out
542
+
543
+
544
+ class UpSample1d(nn.Module):
545
+ def __init__(self, layer_type):
546
+ super().__init__()
547
+ self.layer_type = layer_type
548
+
549
+ def forward(self, x):
550
+ if self.layer_type == "none":
551
+ return x
552
+ else:
553
+ return F.interpolate(x, scale_factor=2, mode="nearest")
554
+
555
+
556
+ class Decoder(nn.Module):
557
+ def __init__(
558
+ self,
559
+ dim_in=512,
560
+ F0_channel=512,
561
+ style_dim=64,
562
+ dim_out=80,
563
+ resblock_kernel_sizes=[3, 7, 11],
564
+ upsample_rates=[10, 5, 3, 2],
565
+ upsample_initial_channel=512,
566
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
567
+ upsample_kernel_sizes=[20, 10, 6, 4],
568
+ ):
569
+ super().__init__()
570
+
571
+ self.decode = nn.ModuleList()
572
+
573
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
574
+
575
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
576
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
577
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
578
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
579
+
580
+ self.F0_conv = weight_norm(
581
+ nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
582
+ )
583
+
584
+ self.N_conv = weight_norm(
585
+ nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
586
+ )
587
+
588
+ self.asr_res = nn.Sequential(
589
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
590
+ )
591
+
592
+ self.generator = Generator(
593
+ style_dim,
594
+ resblock_kernel_sizes,
595
+ upsample_rates,
596
+ upsample_initial_channel,
597
+ resblock_dilation_sizes,
598
+ upsample_kernel_sizes,
599
+ )
600
+
601
+ def forward(self, asr, F0_curve, N, s):
602
+ if self.training:
603
+ downlist = [0, 3, 7]
604
+ F0_down = downlist[random.randint(0, 2)]
605
+ downlist = [0, 3, 7, 15]
606
+ N_down = downlist[random.randint(0, 3)]
607
+ if F0_down:
608
+ F0_curve = (
609
+ nn.functional.conv1d(
610
+ F0_curve.unsqueeze(1),
611
+ torch.ones(1, 1, F0_down).to("cuda"),
612
+ padding=F0_down // 2,
613
+ ).squeeze(1)
614
+ / F0_down
615
+ )
616
+ if N_down:
617
+ N = (
618
+ nn.functional.conv1d(
619
+ N.unsqueeze(1),
620
+ torch.ones(1, 1, N_down).to("cuda"),
621
+ padding=N_down // 2,
622
+ ).squeeze(1)
623
+ / N_down
624
+ )
625
+
626
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
627
+ N = self.N_conv(N.unsqueeze(1))
628
+
629
+ x = torch.cat([asr, F0, N], axis=1)
630
+ x = self.encode(x, s)
631
+
632
+ asr_res = self.asr_res(asr)
633
+
634
+ res = True
635
+ for block in self.decode:
636
+ if res:
637
+ x = torch.cat([x, asr_res, F0, N], axis=1)
638
+ x = block(x, s)
639
+ if block.upsample_type != "none":
640
+ res = False
641
+
642
+ x = self.generator(x, s, F0_curve)
643
+ return x
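An eval-mode, shape-level sketch of the Decoder above (illustrative only; the frame count, curve rates, and import path are assumptions):

import torch
from Modules.hifigan import Decoder

dec = Decoder(dim_in=512, style_dim=64, dim_out=80).eval()   # defaults from this file
B, T = 1, 50
asr = torch.randn(B, 512, T)        # aligned phoneme/ASR features
F0 = torch.rand(B, 2 * T) * 200.0   # F0 curve at twice the feature frame rate
N = torch.rand(B, 2 * T)            # energy curve, same rate as F0
s = torch.randn(B, 64)              # style vector
with torch.no_grad():
    wav = dec(asr, F0, N, s)        # -> (B, 1, 600 * T) waveform samples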
src/Modules/istftnet.py ADDED
@@ -0,0 +1,720 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+ from scipy.signal import get_window
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+
16
+ class AdaIN1d(nn.Module):
17
+ def __init__(self, style_dim, num_features):
18
+ super().__init__()
19
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
20
+ self.fc = nn.Linear(style_dim, num_features * 2)
21
+
22
+ def forward(self, x, s):
23
+ h = self.fc(s)
24
+ h = h.view(h.size(0), h.size(1), 1)
25
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
26
+ return (1 + gamma) * self.norm(x) + beta
27
+
28
+
29
+ class AdaINResBlock1(torch.nn.Module):
30
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
31
+ super(AdaINResBlock1, self).__init__()
32
+ self.convs1 = nn.ModuleList(
33
+ [
34
+ weight_norm(
35
+ Conv1d(
36
+ channels,
37
+ channels,
38
+ kernel_size,
39
+ 1,
40
+ dilation=dilation[0],
41
+ padding=get_padding(kernel_size, dilation[0]),
42
+ )
43
+ ),
44
+ weight_norm(
45
+ Conv1d(
46
+ channels,
47
+ channels,
48
+ kernel_size,
49
+ 1,
50
+ dilation=dilation[1],
51
+ padding=get_padding(kernel_size, dilation[1]),
52
+ )
53
+ ),
54
+ weight_norm(
55
+ Conv1d(
56
+ channels,
57
+ channels,
58
+ kernel_size,
59
+ 1,
60
+ dilation=dilation[2],
61
+ padding=get_padding(kernel_size, dilation[2]),
62
+ )
63
+ ),
64
+ ]
65
+ )
66
+ self.convs1.apply(init_weights)
67
+
68
+ self.convs2 = nn.ModuleList(
69
+ [
70
+ weight_norm(
71
+ Conv1d(
72
+ channels,
73
+ channels,
74
+ kernel_size,
75
+ 1,
76
+ dilation=1,
77
+ padding=get_padding(kernel_size, 1),
78
+ )
79
+ ),
80
+ weight_norm(
81
+ Conv1d(
82
+ channels,
83
+ channels,
84
+ kernel_size,
85
+ 1,
86
+ dilation=1,
87
+ padding=get_padding(kernel_size, 1),
88
+ )
89
+ ),
90
+ weight_norm(
91
+ Conv1d(
92
+ channels,
93
+ channels,
94
+ kernel_size,
95
+ 1,
96
+ dilation=1,
97
+ padding=get_padding(kernel_size, 1),
98
+ )
99
+ ),
100
+ ]
101
+ )
102
+ self.convs2.apply(init_weights)
103
+
104
+ self.adain1 = nn.ModuleList(
105
+ [
106
+ AdaIN1d(style_dim, channels),
107
+ AdaIN1d(style_dim, channels),
108
+ AdaIN1d(style_dim, channels),
109
+ ]
110
+ )
111
+
112
+ self.adain2 = nn.ModuleList(
113
+ [
114
+ AdaIN1d(style_dim, channels),
115
+ AdaIN1d(style_dim, channels),
116
+ AdaIN1d(style_dim, channels),
117
+ ]
118
+ )
119
+
120
+ self.alpha1 = nn.ParameterList(
121
+ [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]
122
+ )
123
+ self.alpha2 = nn.ParameterList(
124
+ [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]
125
+ )
126
+
127
+ def forward(self, x, s):
128
+ for c1, c2, n1, n2, a1, a2 in zip(
129
+ self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2
130
+ ):
131
+ xt = n1(x, s)
132
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
133
+ xt = c1(xt)
134
+ xt = n2(xt, s)
135
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
136
+ xt = c2(xt)
137
+ x = xt + x
138
+ return x
139
+
140
+ def remove_weight_norm(self):
141
+ for l in self.convs1:
142
+ remove_weight_norm(l)
143
+ for l in self.convs2:
144
+ remove_weight_norm(l)
145
+
146
+
147
+ class TorchSTFT(torch.nn.Module):
148
+ def __init__(
149
+ self, filter_length=800, hop_length=200, win_length=800, window="hann"
150
+ ):
151
+ super().__init__()
152
+ self.filter_length = filter_length
153
+ self.hop_length = hop_length
154
+ self.win_length = win_length
155
+ self.window = torch.from_numpy(
156
+ get_window(window, win_length, fftbins=True).astype(np.float32)
157
+ )
158
+
159
+ def transform(self, input_data):
160
+ forward_transform = torch.stft(
161
+ input_data,
162
+ self.filter_length,
163
+ self.hop_length,
164
+ self.win_length,
165
+ window=self.window.to(input_data.device),
166
+ return_complex=True,
167
+ )
168
+
169
+ return torch.abs(forward_transform), torch.angle(forward_transform)
170
+
171
+ def inverse(self, magnitude, phase):
172
+ inverse_transform = torch.istft(
173
+ magnitude * torch.exp(phase * 1j),
174
+ self.filter_length,
175
+ self.hop_length,
176
+ self.win_length,
177
+ window=self.window.to(magnitude.device),
178
+ )
179
+
180
+ return inverse_transform.unsqueeze(
181
+ -2
182
+ ) # unsqueeze to stay consistent with conv_transpose1d implementation
183
+
184
+ def forward(self, input_data):
185
+ self.magnitude, self.phase = self.transform(input_data)
186
+ reconstruction = self.inverse(self.magnitude, self.phase)
187
+ return reconstruction
188
+
189
+
190
+ class SineGen(torch.nn.Module):
191
+ """Definition of sine generator
192
+ SineGen(samp_rate, harmonic_num = 0,
193
+ sine_amp = 0.1, noise_std = 0.003,
194
+ voiced_threshold = 0,
195
+ flag_for_pulse=False)
196
+ samp_rate: sampling rate in Hz
197
+ harmonic_num: number of harmonic overtones (default 0)
198
+ sine_amp: amplitude of sine-waveform (default 0.1)
199
+ noise_std: std of Gaussian noise (default 0.003)
200
+ voiced_threshold: F0 threshold for U/V classification (default 0)
201
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
202
+ Note: when flag_for_pulse is True, the first time step of a voiced
203
+ segment is always sin(np.pi) or cos(0)
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ samp_rate,
209
+ upsample_scale,
210
+ harmonic_num=0,
211
+ sine_amp=0.1,
212
+ noise_std=0.003,
213
+ voiced_threshold=0,
214
+ flag_for_pulse=False,
215
+ ):
216
+ super(SineGen, self).__init__()
217
+ self.sine_amp = sine_amp
218
+ self.noise_std = noise_std
219
+ self.harmonic_num = harmonic_num
220
+ self.dim = self.harmonic_num + 1
221
+ self.sampling_rate = samp_rate
222
+ self.voiced_threshold = voiced_threshold
223
+ self.flag_for_pulse = flag_for_pulse
224
+ self.upsample_scale = upsample_scale
225
+
226
+ def _f02uv(self, f0):
227
+ # generate uv signal
228
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
229
+ return uv
230
+
231
+ def _f02sine(self, f0_values):
232
+ """f0_values: (batchsize, length, dim)
233
+ where dim indicates fundamental tone and overtones
234
+ """
235
+ # convert to F0 in rad. The integer part n can be ignored
236
+ # because 2 * np.pi * n doesn't affect phase
237
+ rad_values = (f0_values / self.sampling_rate) % 1
238
+
239
+ # initial phase noise (no noise for fundamental component)
240
+ rand_ini = torch.rand(
241
+ f0_values.shape[0], f0_values.shape[2], device=f0_values.device
242
+ )
243
+ rand_ini[:, 0] = 0
244
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
245
+
246
+ # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
247
+ if not self.flag_for_pulse:
248
+ # # for normal case
249
+
250
+ # # To prevent torch.cumsum numerical overflow,
251
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
252
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
253
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
254
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
255
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
256
+ # cumsum_shift = torch.zeros_like(rad_values)
257
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
258
+
259
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
260
+ rad_values = torch.nn.functional.interpolate(
261
+ rad_values.transpose(1, 2),
262
+ scale_factor=1 / self.upsample_scale,
263
+ mode="linear",
264
+ ).transpose(1, 2)
265
+
266
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
267
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
268
+ # cumsum_shift = torch.zeros_like(rad_values)
269
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
270
+
271
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
272
+ phase = torch.nn.functional.interpolate(
273
+ phase.transpose(1, 2) * self.upsample_scale,
274
+ scale_factor=self.upsample_scale,
275
+ mode="linear",
276
+ ).transpose(1, 2)
277
+ sines = torch.sin(phase)
278
+
279
+ else:
280
+ # If necessary, make sure that the first time step of every
281
+ # voiced segments is sin(pi) or cos(0)
282
+ # This is used for pulse-train generation
283
+
284
+ # identify the last time step in unvoiced segments
285
+ uv = self._f02uv(f0_values)
286
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
287
+ uv_1[:, -1, :] = 1
288
+ u_loc = (uv < 1) * (uv_1 > 0)
289
+
290
+ # get the instantaneous phase
291
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
292
+ # different batch needs to be processed differently
293
+ for idx in range(f0_values.shape[0]):
294
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
295
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
296
+ # stores the accumulation of i.phase within
297
+ # each voiced segments
298
+ tmp_cumsum[idx, :, :] = 0
299
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
300
+
301
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
302
+ # within the previous voiced segment.
303
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
304
+
305
+ # get the sines
306
+ sines = torch.cos(i_phase * 2 * np.pi)
307
+ return sines
308
+
309
+ def forward(self, f0):
310
+ """sine_tensor, uv = forward(f0)
311
+ input F0: tensor(batchsize=1, length, dim=1)
312
+ f0 for unvoiced steps should be 0
313
+ output sine_tensor: tensor(batchsize=1, length, dim)
314
+ output uv: tensor(batchsize=1, length, 1)
315
+ """
316
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
317
+ # fundamental component
318
+ fn = torch.multiply(
319
+ f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
320
+ )
321
+
322
+ # generate sine waveforms
323
+ sine_waves = self._f02sine(fn) * self.sine_amp
324
+
325
+ # generate uv signal
326
+ # uv = torch.ones(f0.shape)
327
+ # uv = uv * (f0 > self.voiced_threshold)
328
+ uv = self._f02uv(f0)
329
+
330
+ # noise: for unvoiced should be similar to sine_amp
331
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
332
+ # . for voiced regions is self.noise_std
333
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
334
+ noise = noise_amp * torch.randn_like(sine_waves)
335
+
336
+ # first: set the unvoiced part to 0 by uv
337
+ # then: additive noise
338
+ sine_waves = sine_waves * uv + noise
339
+ return sine_waves, uv, noise
340
+
341
+
342
+ class SourceModuleHnNSF(torch.nn.Module):
343
+ """SourceModule for hn-nsf
344
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
345
+ add_noise_std=0.003, voiced_threshod=0)
346
+ sampling_rate: sampling_rate in Hz
347
+ harmonic_num: number of harmonic above F0 (default: 0)
348
+ sine_amp: amplitude of sine source signal (default: 0.1)
349
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
350
+ note that amplitude of noise in unvoiced is decided
351
+ by sine_amp
352
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
353
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
354
+ F0_sampled (batchsize, length, 1)
355
+ Sine_source (batchsize, length, 1)
356
+ noise_source (batchsize, length 1)
357
+ uv (batchsize, length, 1)
358
+ """
359
+
360
+ def __init__(
361
+ self,
362
+ sampling_rate,
363
+ upsample_scale,
364
+ harmonic_num=0,
365
+ sine_amp=0.1,
366
+ add_noise_std=0.003,
367
+ voiced_threshod=0,
368
+ ):
369
+ super(SourceModuleHnNSF, self).__init__()
370
+
371
+ self.sine_amp = sine_amp
372
+ self.noise_std = add_noise_std
373
+
374
+ # to produce sine waveforms
375
+ self.l_sin_gen = SineGen(
376
+ sampling_rate,
377
+ upsample_scale,
378
+ harmonic_num,
379
+ sine_amp,
380
+ add_noise_std,
381
+ voiced_threshod,
382
+ )
383
+
384
+ # to merge source harmonics into a single excitation
385
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
386
+ self.l_tanh = torch.nn.Tanh()
387
+
388
+ def forward(self, x):
389
+ """
390
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
391
+ F0_sampled (batchsize, length, 1)
392
+ Sine_source (batchsize, length, 1)
393
+ noise_source (batchsize, length 1)
394
+ """
395
+ # source for harmonic branch
396
+ with torch.no_grad():
397
+ sine_wavs, uv, _ = self.l_sin_gen(x)
398
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
399
+
400
+ # source for noise branch, in the same shape as uv
401
+ noise = torch.randn_like(uv) * self.sine_amp / 3
402
+ return sine_merge, noise, uv
403
+
404
+
405
+ def padDiff(x):
406
+ return F.pad(
407
+ F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
408
+ )
409
+
410
+
411
+ class Generator(torch.nn.Module):
412
+ def __init__(
413
+ self,
414
+ style_dim,
415
+ resblock_kernel_sizes,
416
+ upsample_rates,
417
+ upsample_initial_channel,
418
+ resblock_dilation_sizes,
419
+ upsample_kernel_sizes,
420
+ gen_istft_n_fft,
421
+ gen_istft_hop_size,
422
+ ):
423
+ super(Generator, self).__init__()
424
+
425
+ self.num_kernels = len(resblock_kernel_sizes)
426
+ self.num_upsamples = len(upsample_rates)
427
+ resblock = AdaINResBlock1
428
+
429
+ self.m_source = SourceModuleHnNSF(
430
+ sampling_rate=24000,
431
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
432
+ harmonic_num=8,
433
+ voiced_threshod=10,
434
+ )
435
+ self.f0_upsamp = torch.nn.Upsample(
436
+ scale_factor=np.prod(upsample_rates) * gen_istft_hop_size
437
+ )
438
+ self.noise_convs = nn.ModuleList()
439
+ self.noise_res = nn.ModuleList()
440
+
441
+ self.ups = nn.ModuleList()
442
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
443
+ self.ups.append(
444
+ weight_norm(
445
+ ConvTranspose1d(
446
+ upsample_initial_channel // (2**i),
447
+ upsample_initial_channel // (2 ** (i + 1)),
448
+ k,
449
+ u,
450
+ padding=(k - u) // 2,
451
+ )
452
+ )
453
+ )
454
+
455
+ self.resblocks = nn.ModuleList()
456
+ for i in range(len(self.ups)):
457
+ ch = upsample_initial_channel // (2 ** (i + 1))
458
+ for j, (k, d) in enumerate(
459
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
460
+ ):
461
+ self.resblocks.append(resblock(ch, k, d, style_dim))
462
+
463
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
464
+
465
+ if i + 1 < len(upsample_rates): #
466
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
467
+ self.noise_convs.append(
468
+ Conv1d(
469
+ gen_istft_n_fft + 2,
470
+ c_cur,
471
+ kernel_size=stride_f0 * 2,
472
+ stride=stride_f0,
473
+ padding=(stride_f0 + 1) // 2,
474
+ )
475
+ )
476
+ self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim))
477
+ else:
478
+ self.noise_convs.append(
479
+ Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)
480
+ )
481
+ self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim))
482
+
483
+ self.post_n_fft = gen_istft_n_fft
484
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
485
+ self.ups.apply(init_weights)
486
+ self.conv_post.apply(init_weights)
487
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
488
+ self.stft = TorchSTFT(
489
+ filter_length=gen_istft_n_fft,
490
+ hop_length=gen_istft_hop_size,
491
+ win_length=gen_istft_n_fft,
492
+ )
493
+
494
+ def forward(self, x, s, f0):
495
+ with torch.no_grad():
496
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
497
+
498
+ har_source, noi_source, uv = self.m_source(f0)
499
+ har_source = har_source.transpose(1, 2).squeeze(1)
500
+ har_spec, har_phase = self.stft.transform(har_source)
501
+ har = torch.cat([har_spec, har_phase], dim=1)
502
+
503
+ for i in range(self.num_upsamples):
504
+ x = F.leaky_relu(x, LRELU_SLOPE)
505
+ x_source = self.noise_convs[i](har)
506
+ x_source = self.noise_res[i](x_source, s)
507
+
508
+ x = self.ups[i](x)
509
+ if i == self.num_upsamples - 1:
510
+ x = self.reflection_pad(x)
511
+
512
+ x = x + x_source
513
+ xs = None
514
+ for j in range(self.num_kernels):
515
+ if xs is None:
516
+ xs = self.resblocks[i * self.num_kernels + j](x, s)
517
+ else:
518
+ xs += self.resblocks[i * self.num_kernels + j](x, s)
519
+ x = xs / self.num_kernels
520
+ x = F.leaky_relu(x)
521
+ x = self.conv_post(x)
522
+ spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
523
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
524
+ return self.stft.inverse(spec, phase)
525
+
526
+ def fw_phase(self, x, s):
527
+ for i in range(self.num_upsamples):
528
+ x = F.leaky_relu(x, LRELU_SLOPE)
529
+ x = self.ups[i](x)
530
+ xs = None
531
+ for j in range(self.num_kernels):
532
+ if xs is None:
533
+ xs = self.resblocks[i * self.num_kernels + j](x, s)
534
+ else:
535
+ xs += self.resblocks[i * self.num_kernels + j](x, s)
536
+ x = xs / self.num_kernels
537
+ x = F.leaky_relu(x)
538
+ x = self.reflection_pad(x)
539
+ x = self.conv_post(x)
540
+ spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
541
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
542
+ return spec, phase
543
+
544
+ def remove_weight_norm(self):
545
+ print("Removing weight norm...")
546
+ for l in self.ups:
547
+ remove_weight_norm(l)
548
+ for l in self.resblocks:
549
+ l.remove_weight_norm()
550
+ remove_weight_norm(self.conv_pre)
551
+ remove_weight_norm(self.conv_post)
552
+
553
+
554
+ class AdainResBlk1d(nn.Module):
555
+ def __init__(
556
+ self,
557
+ dim_in,
558
+ dim_out,
559
+ style_dim=64,
560
+ actv=nn.LeakyReLU(0.2),
561
+ upsample="none",
562
+ dropout_p=0.0,
563
+ ):
564
+ super().__init__()
565
+ self.actv = actv
566
+ self.upsample_type = upsample
567
+ self.upsample = UpSample1d(upsample)
568
+ self.learned_sc = dim_in != dim_out
569
+ self._build_weights(dim_in, dim_out, style_dim)
570
+ self.dropout = nn.Dropout(dropout_p)
571
+
572
+ if upsample == "none":
573
+ self.pool = nn.Identity()
574
+ else:
575
+ self.pool = weight_norm(
576
+ nn.ConvTranspose1d(
577
+ dim_in,
578
+ dim_in,
579
+ kernel_size=3,
580
+ stride=2,
581
+ groups=dim_in,
582
+ padding=1,
583
+ output_padding=1,
584
+ )
585
+ )
586
+
587
+ def _build_weights(self, dim_in, dim_out, style_dim):
588
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
589
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
590
+ self.norm1 = AdaIN1d(style_dim, dim_in)
591
+ self.norm2 = AdaIN1d(style_dim, dim_out)
592
+ if self.learned_sc:
593
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
594
+
595
+ def _shortcut(self, x):
596
+ x = self.upsample(x)
597
+ if self.learned_sc:
598
+ x = self.conv1x1(x)
599
+ return x
600
+
601
+ def _residual(self, x, s):
602
+ x = self.norm1(x, s)
603
+ x = self.actv(x)
604
+ x = self.pool(x)
605
+ x = self.conv1(self.dropout(x))
606
+ x = self.norm2(x, s)
607
+ x = self.actv(x)
608
+ x = self.conv2(self.dropout(x))
609
+ return x
610
+
611
+ def forward(self, x, s):
612
+ out = self._residual(x, s)
613
+ out = (out + self._shortcut(x)) / math.sqrt(2)
614
+ return out
615
+
616
+
617
+ class UpSample1d(nn.Module):
618
+ def __init__(self, layer_type):
619
+ super().__init__()
620
+ self.layer_type = layer_type
621
+
622
+ def forward(self, x):
623
+ if self.layer_type == "none":
624
+ return x
625
+ else:
626
+ return F.interpolate(x, scale_factor=2, mode="nearest")
627
+
628
+
629
+ class Decoder(nn.Module):
630
+ def __init__(
631
+ self,
632
+ dim_in=512,
633
+ F0_channel=512,
634
+ style_dim=64,
635
+ dim_out=80,
636
+ resblock_kernel_sizes=[3, 7, 11],
637
+ upsample_rates=[10, 6],
638
+ upsample_initial_channel=512,
639
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
640
+ upsample_kernel_sizes=[20, 12],
641
+ gen_istft_n_fft=20,
642
+ gen_istft_hop_size=5,
643
+ ):
644
+ super().__init__()
645
+
646
+ self.decode = nn.ModuleList()
647
+
648
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
649
+
650
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
651
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
652
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
653
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
654
+
655
+ self.F0_conv = weight_norm(
656
+ nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
657
+ )
658
+
659
+ self.N_conv = weight_norm(
660
+ nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
661
+ )
662
+
663
+ self.asr_res = nn.Sequential(
664
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
665
+ )
666
+
667
+ self.generator = Generator(
668
+ style_dim,
669
+ resblock_kernel_sizes,
670
+ upsample_rates,
671
+ upsample_initial_channel,
672
+ resblock_dilation_sizes,
673
+ upsample_kernel_sizes,
674
+ gen_istft_n_fft,
675
+ gen_istft_hop_size,
676
+ )
677
+
678
+ def forward(self, asr, F0_curve, N, s):
679
+ if self.training:
680
+ downlist = [0, 3, 7]
681
+ F0_down = downlist[random.randint(0, 2)]
682
+ downlist = [0, 3, 7, 15]
683
+ N_down = downlist[random.randint(0, 3)]
684
+ if F0_down:
685
+ F0_curve = (
686
+ nn.functional.conv1d(
687
+ F0_curve.unsqueeze(1),
688
+ torch.ones(1, 1, F0_down).to("cuda"),
689
+ padding=F0_down // 2,
690
+ ).squeeze(1)
691
+ / F0_down
692
+ )
693
+ if N_down:
694
+ N = (
695
+ nn.functional.conv1d(
696
+ N.unsqueeze(1),
697
+ torch.ones(1, 1, N_down).to("cuda"),
698
+ padding=N_down // 2,
699
+ ).squeeze(1)
700
+ / N_down
701
+ )
702
+
703
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
704
+ N = self.N_conv(N.unsqueeze(1))
705
+
706
+ x = torch.cat([asr, F0, N], axis=1)
707
+ x = self.encode(x, s)
708
+
709
+ asr_res = self.asr_res(asr)
710
+
711
+ res = True
712
+ for block in self.decode:
713
+ if res:
714
+ x = torch.cat([x, asr_res, F0, N], axis=1)
715
+ x = block(x, s)
716
+ if block.upsample_type != "none":
717
+ res = False
718
+
719
+ x = self.generator(x, s, F0_curve)
720
+ return x
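A minimal usage sketch for the Decoder above (hypothetical, not taken from the repo): it assumes the default constructor arguments shown in the file, a 512-channel aligned feature tensor, and F0/energy curves sampled at twice the feature frame rate, since F0_conv and N_conv downsample by stride 2:
import torch
decoder = Decoder(dim_in=512, style_dim=64, gen_istft_n_fft=20, gen_istft_hop_size=5).eval()
asr = torch.randn(1, 512, 100)   # aligned text features (B, C, T) -- assumed shape
f0 = torch.rand(1, 200) * 200    # F0 curve over 2*T frames -- assumed shape
n = torch.rand(1, 200)           # energy curve over 2*T frames -- assumed shape
s = torch.randn(1, 64)           # style vector
with torch.no_grad():
    wav = decoder(asr, f0, n, s) # waveform, roughly 600 output samples per input frame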
src/Modules/slmadv.py ADDED
@@ -0,0 +1,256 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SLMAdversarialLoss(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ model,
10
+ wl,
11
+ sampler,
12
+ min_len,
13
+ max_len,
14
+ batch_percentage=0.5,
15
+ skip_update=10,
16
+ sig=1.5,
17
+ ):
18
+ super(SLMAdversarialLoss, self).__init__()
19
+ self.model = model
20
+ self.wl = wl
21
+ self.sampler = sampler
22
+
23
+ self.min_len = min_len
24
+ self.max_len = max_len
25
+ self.batch_percentage = batch_percentage
26
+
27
+ self.sig = sig
28
+ self.skip_update = skip_update
29
+
30
+ def forward(
31
+ self,
32
+ iters,
33
+ y_rec_gt,
34
+ y_rec_gt_pred,
35
+ waves,
36
+ mel_input_length,
37
+ ref_text,
38
+ ref_lengths,
39
+ use_ind,
40
+ s_trg,
41
+ ref_s=None,
42
+ ):
43
+ text_mask = length_to_mask(ref_lengths).to(ref_text.device)
44
+ bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
45
+ d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
46
+
47
+ if use_ind and np.random.rand() < 0.5:
48
+ s_preds = s_trg
49
+ else:
50
+ num_steps = np.random.randint(3, 5)
51
+ if ref_s is not None:
52
+ s_preds = self.sampler(
53
+ noise=torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
54
+ embedding=bert_dur,
55
+ embedding_scale=1,
56
+ features=ref_s, # reference from the same speaker as the embedding
57
+ embedding_mask_proba=0.1,
58
+ num_steps=num_steps,
59
+ ).squeeze(1)
60
+ else:
61
+ s_preds = self.sampler(
62
+ noise=torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
63
+ embedding=bert_dur,
64
+ embedding_scale=1,
65
+ embedding_mask_proba=0.1,
66
+ num_steps=num_steps,
67
+ ).squeeze(1)
68
+
69
+ s_dur = s_preds[:, 128:]
70
+ s = s_preds[:, :128]
71
+
72
+ d, _ = self.model.predictor(
73
+ d_en,
74
+ s_dur,
75
+ ref_lengths,
76
+ torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
77
+ text_mask,
78
+ )
79
+
80
+ bib = 0
81
+
82
+ output_lengths = []
83
+ attn_preds = []
84
+
85
+ # differentiable duration modeling
86
+ for _s2s_pred, _text_length in zip(d, ref_lengths):
87
+ _s2s_pred_org = _s2s_pred[:_text_length, :]
88
+
89
+ _s2s_pred = torch.sigmoid(_s2s_pred_org)
90
+ _dur_pred = _s2s_pred.sum(axis=-1)
91
+
92
+ l = int(torch.round(_s2s_pred.sum()).item())
93
+ t = torch.arange(0, l).expand(l)
94
+
95
+ t = (
96
+ torch.arange(0, l)
97
+ .unsqueeze(0)
98
+ .expand((len(_s2s_pred), l))
99
+ .to(ref_text.device)
100
+ )
101
+ loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
102
+
103
+ h = torch.exp(
104
+ -0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig) ** 2
105
+ )
106
+
107
+ out = torch.nn.functional.conv1d(
108
+ _s2s_pred_org.unsqueeze(0),
109
+ h.unsqueeze(1),
110
+ padding=h.shape[-1] - 1,
111
+ groups=int(_text_length),
112
+ )[..., :l]
113
+ attn_preds.append(F.softmax(out.squeeze(), dim=0))
114
+
115
+ output_lengths.append(l)
116
+
117
+ max_len = max(output_lengths)
118
+
119
+ with torch.no_grad():
120
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
121
+
122
+ s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(
123
+ ref_text.device
124
+ )
125
+ for bib in range(len(output_lengths)):
126
+ s2s_attn[bib, : ref_lengths[bib], : output_lengths[bib]] = attn_preds[bib]
127
+
128
+ asr_pred = t_en @ s2s_attn
129
+
130
+ _, p_pred = self.model.predictor(d_en, s_dur, ref_lengths, s2s_attn, text_mask)
131
+
132
+ mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
133
+ mel_len = min(mel_len, self.max_len // 2)
134
+
135
+ # get clips
136
+
137
+ en = []
138
+ p_en = []
139
+ sp = []
140
+
141
+ F0_fakes = []
142
+ N_fakes = []
143
+
144
+ wav = []
145
+
146
+ for bib in range(len(output_lengths)):
147
+ mel_length_pred = output_lengths[bib]
148
+ mel_length_gt = int(mel_input_length[bib].item() / 2)
149
+ if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
150
+ continue
151
+
152
+ sp.append(s_preds[bib])
153
+
154
+ random_start = np.random.randint(0, mel_length_pred - mel_len)
155
+ en.append(asr_pred[bib, :, random_start : random_start + mel_len])
156
+ p_en.append(p_pred[bib, :, random_start : random_start + mel_len])
157
+
158
+ # get ground truth clips
159
+ random_start = np.random.randint(0, mel_length_gt - mel_len)
160
+ y = waves[bib][
161
+ (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
162
+ ]
163
+ wav.append(torch.from_numpy(y).to(ref_text.device))
164
+
165
+ if len(wav) >= self.batch_percentage * len(
166
+ waves
167
+ ): # prevent OOM due to longer lengths
168
+ break
169
+
170
+ if len(sp) <= 1:
171
+ return None
172
+
173
+ sp = torch.stack(sp)
174
+ wav = torch.stack(wav).float()
175
+ en = torch.stack(en)
176
+ p_en = torch.stack(p_en)
177
+
178
+ F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
179
+ y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
180
+
181
+ # discriminator loss
182
+ if (iters + 1) % self.skip_update == 0:
183
+ if np.random.randint(0, 2) == 0:
184
+ wav = y_rec_gt_pred
185
+ use_rec = True
186
+ else:
187
+ use_rec = False
188
+
189
+ crop_size = min(wav.size(-1), y_pred.size(-1))
190
+ if (
191
+ use_rec
192
+ ): # use reconstructed (shorter lengths), do length invariant regularization
193
+ if wav.size(-1) > y_pred.size(-1):
194
+ real_GP = wav[:, :, :crop_size]
195
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
196
+ out_org = self.wl.discriminator_forward(wav.detach().squeeze())
197
+ loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])
198
+
199
+ if np.random.randint(0, 2) == 0:
200
+ d_loss = self.wl.discriminator(
201
+ real_GP.detach().squeeze(), y_pred.detach().squeeze()
202
+ ).mean()
203
+ else:
204
+ d_loss = self.wl.discriminator(
205
+ wav.detach().squeeze(), y_pred.detach().squeeze()
206
+ ).mean()
207
+ else:
208
+ real_GP = y_pred[:, :, :crop_size]
209
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
210
+ out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
211
+ loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])
212
+
213
+ if np.random.randint(0, 2) == 0:
214
+ d_loss = self.wl.discriminator(
215
+ wav.detach().squeeze(), real_GP.detach().squeeze()
216
+ ).mean()
217
+ else:
218
+ d_loss = self.wl.discriminator(
219
+ wav.detach().squeeze(), y_pred.detach().squeeze()
220
+ ).mean()
221
+
222
+ # regularization (ignore length variation)
223
+ d_loss += loss_reg
224
+
225
+ out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
226
+ out_rec = self.wl.discriminator_forward(
227
+ y_rec_gt_pred.detach().squeeze()
228
+ )
229
+
230
+ # regularization (ignore reconstruction artifacts)
231
+ d_loss += F.l1_loss(out_gt, out_rec)
232
+
233
+ else:
234
+ d_loss = self.wl.discriminator(
235
+ wav.detach().squeeze(), y_pred.detach().squeeze()
236
+ ).mean()
237
+ else:
238
+ d_loss = 0
239
+
240
+ # generator loss
241
+ gen_loss = self.wl.generator(y_pred.squeeze())
242
+
243
+ gen_loss = gen_loss.mean()
244
+
245
+ return d_loss, gen_loss, y_pred.detach().cpu().numpy()
246
+
247
+
248
+ def length_to_mask(lengths):
249
+ mask = (
250
+ torch.arange(lengths.max())
251
+ .unsqueeze(0)
252
+ .expand(lengths.shape[0], -1)
253
+ .type_as(lengths)
254
+ )
255
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
256
+ return mask
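A quick illustration of the length_to_mask helper above: padded positions come back True. Inputs here are assumed for the example:
import torch
lengths = torch.tensor([3, 5])
print(length_to_mask(lengths))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])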
src/Modules/utils.py ADDED
@@ -0,0 +1,14 @@
1
+ def init_weights(m, mean=0.0, std=0.01):
2
+ classname = m.__class__.__name__
3
+ if classname.find("Conv") != -1:
4
+ m.weight.data.normal_(mean, std)
5
+
6
+
7
+ def apply_weight_norm(m):
8
+ classname = m.__class__.__name__
9
+ if classname.find("Conv") != -1:
10
+ weight_norm(m)
11
+
12
+
13
+ def get_padding(kernel_size, dilation=1):
14
+ return int((kernel_size * dilation - dilation) / 2)
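For reference, get_padding above gives the "same"-length padding for a dilated convolution (effective kernel = dilation * (kernel_size - 1) + 1); two illustrative values:
assert get_padding(7) == 3              # kernel 7, dilation 1
assert get_padding(3, dilation=5) == 5  # effective kernel 11 -> pad 5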
src/Utils/ASR/__init__.py ADDED
@@ -0,0 +1 @@
1
+
src/Utils/ASR/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes).
 
src/Utils/ASR/__pycache__/layers.cpython-310.pyc ADDED
Binary file (11.2 kB).
 
src/Utils/ASR/__pycache__/models.cpython-310.pyc ADDED
Binary file (6.15 kB).
 
src/Utils/ASR/config.yml ADDED
@@ -0,0 +1,29 @@
1
+ log_dir: "logs/20201006"
2
+ save_freq: 5
3
+ device: "cuda"
4
+ epochs: 180
5
+ batch_size: 64
6
+ pretrained_model: ""
7
+ train_data: "ASRDataset/train_list.txt"
8
+ val_data: "ASRDataset/val_list.txt"
9
+
10
+ dataset_params:
11
+ data_augmentation: false
12
+
13
+ preprocess_parasm:
14
+ sr: 24000
15
+ spect_params:
16
+ n_fft: 2048
17
+ win_length: 1200
18
+ hop_length: 300
19
+ mel_params:
20
+ n_mels: 80
21
+
22
+ model_params:
23
+ input_dim: 80
24
+ hidden_dim: 256
25
+ n_token: 178
26
+ token_embedding_dim: 512
27
+
28
+ optimizer_params:
29
+ lr: 0.0005
src/Utils/ASR/epoch_00080.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
+ size 94552811
src/Utils/ASR/layers.py ADDED
@@ -0,0 +1,455 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+
12
+ random.seed(0)
13
+
14
+
15
+ def _get_activation_fn(activ):
16
+ if activ == "relu":
17
+ return nn.ReLU()
18
+ elif activ == "lrelu":
19
+ return nn.LeakyReLU(0.2)
20
+ elif activ == "swish":
21
+ return lambda x: x * torch.sigmoid(x)
22
+ else:
23
+ raise RuntimeError(
24
+ "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
25
+ )
26
+
27
+
28
+ class LinearNorm(torch.nn.Module):
29
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
30
+ super(LinearNorm, self).__init__()
31
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
32
+
33
+ torch.nn.init.xavier_uniform_(
34
+ self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.linear_layer(x)
39
+
40
+
41
+ class ConvNorm(torch.nn.Module):
42
+ def __init__(
43
+ self,
44
+ in_channels,
45
+ out_channels,
46
+ kernel_size=1,
47
+ stride=1,
48
+ padding=None,
49
+ dilation=1,
50
+ bias=True,
51
+ w_init_gain="linear",
52
+ param=None,
53
+ ):
54
+ super(ConvNorm, self).__init__()
55
+ if padding is None:
56
+ assert kernel_size % 2 == 1
57
+ padding = int(dilation * (kernel_size - 1) / 2)
58
+
59
+ self.conv = torch.nn.Conv1d(
60
+ in_channels,
61
+ out_channels,
62
+ kernel_size=kernel_size,
63
+ stride=stride,
64
+ padding=padding,
65
+ dilation=dilation,
66
+ bias=bias,
67
+ )
68
+
69
+ torch.nn.init.xavier_uniform_(
70
+ self.conv.weight,
71
+ gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
72
+ )
73
+
74
+ def forward(self, signal):
75
+ conv_signal = self.conv(signal)
76
+ return conv_signal
77
+
78
+
79
+ class CausualConv(nn.Module):
80
+ def __init__(
81
+ self,
82
+ in_channels,
83
+ out_channels,
84
+ kernel_size=1,
85
+ stride=1,
86
+ padding=1,
87
+ dilation=1,
88
+ bias=True,
89
+ w_init_gain="linear",
90
+ param=None,
91
+ ):
92
+ super(CausualConv, self).__init__()
93
+ if padding is None:
94
+ assert kernel_size % 2 == 1
95
+ self.padding = int(dilation * (kernel_size - 1) / 2) * 2
96
+ else:
97
+ self.padding = padding * 2
98
+ self.conv = nn.Conv1d(
99
+ in_channels,
100
+ out_channels,
101
+ kernel_size=kernel_size,
102
+ stride=stride,
103
+ padding=self.padding,
104
+ dilation=dilation,
105
+ bias=bias,
106
+ )
107
+
108
+ torch.nn.init.xavier_uniform_(
109
+ self.conv.weight,
110
+ gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
111
+ )
112
+
113
+ def forward(self, x):
114
+ x = self.conv(x)
115
+ x = x[:, :, : -self.padding]
116
+ return x
117
+
118
+
119
+ class CausualBlock(nn.Module):
120
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
121
+ super(CausualBlock, self).__init__()
122
+ self.blocks = nn.ModuleList(
123
+ [
124
+ self._get_conv(
125
+ hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
126
+ )
127
+ for i in range(n_conv)
128
+ ]
129
+ )
130
+
131
+ def forward(self, x):
132
+ for block in self.blocks:
133
+ res = x
134
+ x = block(x)
135
+ x += res
136
+ return x
137
+
138
+ def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
139
+ layers = [
140
+ CausualConv(
141
+ hidden_dim,
142
+ hidden_dim,
143
+ kernel_size=3,
144
+ padding=dilation,
145
+ dilation=dilation,
146
+ ),
147
+ _get_activation_fn(activ),
148
+ nn.BatchNorm1d(hidden_dim),
149
+ nn.Dropout(p=dropout_p),
150
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
151
+ _get_activation_fn(activ),
152
+ nn.Dropout(p=dropout_p),
153
+ ]
154
+ return nn.Sequential(*layers)
155
+
156
+
157
+ class ConvBlock(nn.Module):
158
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
159
+ super().__init__()
160
+ self._n_groups = 8
161
+ self.blocks = nn.ModuleList(
162
+ [
163
+ self._get_conv(
164
+ hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
165
+ )
166
+ for i in range(n_conv)
167
+ ]
168
+ )
169
+
170
+ def forward(self, x):
171
+ for block in self.blocks:
172
+ res = x
173
+ x = block(x)
174
+ x += res
175
+ return x
176
+
177
+ def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
178
+ layers = [
179
+ ConvNorm(
180
+ hidden_dim,
181
+ hidden_dim,
182
+ kernel_size=3,
183
+ padding=dilation,
184
+ dilation=dilation,
185
+ ),
186
+ _get_activation_fn(activ),
187
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
188
+ nn.Dropout(p=dropout_p),
189
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
190
+ _get_activation_fn(activ),
191
+ nn.Dropout(p=dropout_p),
192
+ ]
193
+ return nn.Sequential(*layers)
194
+
195
+
196
+ class LocationLayer(nn.Module):
197
+ def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
198
+ super(LocationLayer, self).__init__()
199
+ padding = int((attention_kernel_size - 1) / 2)
200
+ self.location_conv = ConvNorm(
201
+ 2,
202
+ attention_n_filters,
203
+ kernel_size=attention_kernel_size,
204
+ padding=padding,
205
+ bias=False,
206
+ stride=1,
207
+ dilation=1,
208
+ )
209
+ self.location_dense = LinearNorm(
210
+ attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
211
+ )
212
+
213
+ def forward(self, attention_weights_cat):
214
+ processed_attention = self.location_conv(attention_weights_cat)
215
+ processed_attention = processed_attention.transpose(1, 2)
216
+ processed_attention = self.location_dense(processed_attention)
217
+ return processed_attention
218
+
219
+
220
+ class Attention(nn.Module):
221
+ def __init__(
222
+ self,
223
+ attention_rnn_dim,
224
+ embedding_dim,
225
+ attention_dim,
226
+ attention_location_n_filters,
227
+ attention_location_kernel_size,
228
+ ):
229
+ super(Attention, self).__init__()
230
+ self.query_layer = LinearNorm(
231
+ attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
232
+ )
233
+ self.memory_layer = LinearNorm(
234
+ embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
235
+ )
236
+ self.v = LinearNorm(attention_dim, 1, bias=False)
237
+ self.location_layer = LocationLayer(
238
+ attention_location_n_filters, attention_location_kernel_size, attention_dim
239
+ )
240
+ self.score_mask_value = -float("inf")
241
+
242
+ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
243
+ """
244
+ PARAMS
245
+ ------
246
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
247
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
248
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
249
+ RETURNS
250
+ -------
251
+ alignment (batch, max_time)
252
+ """
253
+
254
+ processed_query = self.query_layer(query.unsqueeze(1))
255
+ processed_attention_weights = self.location_layer(attention_weights_cat)
256
+ energies = self.v(
257
+ torch.tanh(processed_query + processed_attention_weights + processed_memory)
258
+ )
259
+
260
+ energies = energies.squeeze(-1)
261
+ return energies
262
+
263
+ def forward(
264
+ self,
265
+ attention_hidden_state,
266
+ memory,
267
+ processed_memory,
268
+ attention_weights_cat,
269
+ mask,
270
+ ):
271
+ """
272
+ PARAMS
273
+ ------
274
+ attention_hidden_state: attention rnn last output
275
+ memory: encoder outputs
276
+ processed_memory: processed encoder outputs
277
+ attention_weights_cat: previous and cumulative attention weights
278
+ mask: binary mask for padded data
279
+ """
280
+ alignment = self.get_alignment_energies(
281
+ attention_hidden_state, processed_memory, attention_weights_cat
282
+ )
283
+
284
+ if mask is not None:
285
+ alignment.data.masked_fill_(mask, self.score_mask_value)
286
+
287
+ attention_weights = F.softmax(alignment, dim=1)
288
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
289
+ attention_context = attention_context.squeeze(1)
290
+
291
+ return attention_context, attention_weights
292
+
293
+
294
+ class ForwardAttentionV2(nn.Module):
295
+ def __init__(
296
+ self,
297
+ attention_rnn_dim,
298
+ embedding_dim,
299
+ attention_dim,
300
+ attention_location_n_filters,
301
+ attention_location_kernel_size,
302
+ ):
303
+ super(ForwardAttentionV2, self).__init__()
304
+ self.query_layer = LinearNorm(
305
+ attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
306
+ )
307
+ self.memory_layer = LinearNorm(
308
+ embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
309
+ )
310
+ self.v = LinearNorm(attention_dim, 1, bias=False)
311
+ self.location_layer = LocationLayer(
312
+ attention_location_n_filters, attention_location_kernel_size, attention_dim
313
+ )
314
+ self.score_mask_value = -float(1e20)
315
+
316
+ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
317
+ """
318
+ PARAMS
319
+ ------
320
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
321
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
322
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
323
+ RETURNS
324
+ -------
325
+ alignment (batch, max_time)
326
+ """
327
+
328
+ processed_query = self.query_layer(query.unsqueeze(1))
329
+ processed_attention_weights = self.location_layer(attention_weights_cat)
330
+ energies = self.v(
331
+ torch.tanh(processed_query + processed_attention_weights + processed_memory)
332
+ )
333
+
334
+ energies = energies.squeeze(-1)
335
+ return energies
336
+
337
+ def forward(
338
+ self,
339
+ attention_hidden_state,
340
+ memory,
341
+ processed_memory,
342
+ attention_weights_cat,
343
+ mask,
344
+ log_alpha,
345
+ ):
346
+ """
347
+ PARAMS
348
+ ------
349
+ attention_hidden_state: attention rnn last output
350
+ memory: encoder outputs
351
+ processed_memory: processed encoder outputs
352
+ attention_weights_cat: previous and cumulative attention weights
353
+ mask: binary mask for padded data
354
+ """
355
+ log_energy = self.get_alignment_energies(
356
+ attention_hidden_state, processed_memory, attention_weights_cat
357
+ )
358
+
359
+ # log_energy =
360
+
361
+ if mask is not None:
362
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
363
+
364
+ # attention_weights = F.softmax(alignment, dim=1)
365
+
366
+ # content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
367
+ # log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
368
+
369
+ # log_total_score = log_alpha + content_score
370
+
371
+ # previous_attention_weights = attention_weights_cat[:,0,:]
372
+
373
+ log_alpha_shift_padded = []
374
+ max_time = log_energy.size(1)
375
+ for sft in range(2):
376
+ shifted = log_alpha[:, : max_time - sft]
377
+ shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
378
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
379
+
380
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)
381
+
382
+ log_alpha_new = biased + log_energy
383
+
384
+ attention_weights = F.softmax(log_alpha_new, dim=1)
385
+
386
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
387
+ attention_context = attention_context.squeeze(1)
388
+
389
+ return attention_context, attention_weights, log_alpha_new
390
+
391
+
392
+ class PhaseShuffle2d(nn.Module):
393
+ def __init__(self, n=2):
394
+ super(PhaseShuffle2d, self).__init__()
395
+ self.n = n
396
+ self.random = random.Random(1)
397
+
398
+ def forward(self, x, move=None):
399
+ # x.size = (B, C, M, L)
400
+ if move is None:
401
+ move = self.random.randint(-self.n, self.n)
402
+
403
+ if move == 0:
404
+ return x
405
+ else:
406
+ left = x[:, :, :, :move]
407
+ right = x[:, :, :, move:]
408
+ shuffled = torch.cat([right, left], dim=3)
409
+ return shuffled
410
+
411
+
412
+ class PhaseShuffle1d(nn.Module):
413
+ def __init__(self, n=2):
414
+ super(PhaseShuffle1d, self).__init__()
415
+ self.n = n
416
+ self.random = random.Random(1)
417
+
418
+ def forward(self, x, move=None):
419
+ # x.size = (B, C, M, L)
420
+ if move is None:
421
+ move = self.random.randint(-self.n, self.n)
422
+
423
+ if move == 0:
424
+ return x
425
+ else:
426
+ left = x[:, :, :move]
427
+ right = x[:, :, move:]
428
+ shuffled = torch.cat([right, left], dim=2)
429
+
430
+ return shuffled
431
+
432
+
433
+ class MFCC(nn.Module):
434
+ def __init__(self, n_mfcc=40, n_mels=80):
435
+ super(MFCC, self).__init__()
436
+ self.n_mfcc = n_mfcc
437
+ self.n_mels = n_mels
438
+ self.norm = "ortho"
439
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
440
+ self.register_buffer("dct_mat", dct_mat)
441
+
442
+ def forward(self, mel_specgram):
443
+ if len(mel_specgram.shape) == 2:
444
+ mel_specgram = mel_specgram.unsqueeze(0)
445
+ unsqueezed = True
446
+ else:
447
+ unsqueezed = False
448
+ # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
449
+ # -> (channel, time, n_mfcc).transpose(...)
450
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
451
+
452
+ # unpack batch
453
+ if unsqueezed:
454
+ mfcc = mfcc.squeeze(0)
455
+ return mfcc
src/Utils/ASR/models.py ADDED
@@ -0,0 +1,217 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import TransformerEncoder
5
+ import torch.nn.functional as F
6
+ from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
7
+
8
+
9
+ class ASRCNN(nn.Module):
10
+ def __init__(
11
+ self,
12
+ input_dim=80,
13
+ hidden_dim=256,
14
+ n_token=35,
15
+ n_layers=6,
16
+ token_embedding_dim=256,
17
+ ):
18
+ super().__init__()
19
+ self.n_token = n_token
20
+ self.n_down = 1
21
+ self.to_mfcc = MFCC()
22
+ self.init_cnn = ConvNorm(
23
+ input_dim // 2, hidden_dim, kernel_size=7, padding=3, stride=2
24
+ )
25
+ self.cnns = nn.Sequential(
26
+ *[
27
+ nn.Sequential(
28
+ ConvBlock(hidden_dim),
29
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
30
+ )
31
+ for n in range(n_layers)
32
+ ]
33
+ )
34
+ self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
35
+ self.ctc_linear = nn.Sequential(
36
+ LinearNorm(hidden_dim // 2, hidden_dim),
37
+ nn.ReLU(),
38
+ LinearNorm(hidden_dim, n_token),
39
+ )
40
+ self.asr_s2s = ASRS2S(
41
+ embedding_dim=token_embedding_dim,
42
+ hidden_dim=hidden_dim // 2,
43
+ n_token=n_token,
44
+ )
45
+
46
+ def forward(self, x, src_key_padding_mask=None, text_input=None):
47
+ x = self.to_mfcc(x)
48
+ x = self.init_cnn(x)
49
+ x = self.cnns(x)
50
+ x = self.projection(x)
51
+ x = x.transpose(1, 2)
52
+ ctc_logit = self.ctc_linear(x)
53
+ if text_input is not None:
54
+ _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
55
+ return ctc_logit, s2s_logit, s2s_attn
56
+ else:
57
+ return ctc_logit
58
+
59
+ def get_feature(self, x):
60
+ x = self.to_mfcc(x.squeeze(1))
61
+ x = self.init_cnn(x)
62
+ x = self.cnns(x)
63
+ x = self.projection(x)
64
+ return x
65
+
66
+ def length_to_mask(self, lengths):
67
+ mask = (
68
+ torch.arange(lengths.max())
69
+ .unsqueeze(0)
70
+ .expand(lengths.shape[0], -1)
71
+ .type_as(lengths)
72
+ )
73
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1)).to(lengths.device)
74
+ return mask
75
+
76
+ def get_future_mask(self, out_length, unmask_future_steps=0):
77
+ """
78
+ Args:
79
+ out_length (int): returned mask shape is (out_length, out_length).
80
+ unmask_future_steps (int): unmasking future step size.
81
+ Return:
82
+ mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
83
+ """
84
+ index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
85
+ mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
86
+ return mask
87
+
88
+
89
+ class ASRS2S(nn.Module):
90
+ def __init__(
91
+ self,
92
+ embedding_dim=256,
93
+ hidden_dim=512,
94
+ n_location_filters=32,
95
+ location_kernel_size=63,
96
+ n_token=40,
97
+ ):
98
+ super(ASRS2S, self).__init__()
99
+ self.embedding = nn.Embedding(n_token, embedding_dim)
100
+ val_range = math.sqrt(6 / hidden_dim)
101
+ self.embedding.weight.data.uniform_(-val_range, val_range)
102
+
103
+ self.decoder_rnn_dim = hidden_dim
104
+ self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
105
+ self.attention_layer = Attention(
106
+ self.decoder_rnn_dim,
107
+ hidden_dim,
108
+ hidden_dim,
109
+ n_location_filters,
110
+ location_kernel_size,
111
+ )
112
+ self.decoder_rnn = nn.LSTMCell(
113
+ self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim
114
+ )
115
+ self.project_to_hidden = nn.Sequential(
116
+ LinearNorm(self.decoder_rnn_dim * 2, hidden_dim), nn.Tanh()
117
+ )
118
+ self.sos = 1
119
+ self.eos = 2
120
+
121
+ def initialize_decoder_states(self, memory, mask):
122
+ """
123
+ memory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
124
+ """
125
+ B, L, H = memory.shape
126
+ self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
127
+ self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
128
+ self.attention_weights = torch.zeros((B, L)).type_as(memory)
129
+ self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
130
+ self.attention_context = torch.zeros((B, H)).type_as(memory)
131
+ self.memory = memory
132
+ self.processed_memory = self.attention_layer.memory_layer(memory)
133
+ self.mask = mask
134
+ self.unk_index = 3
135
+ self.random_mask = 0.1
136
+
137
+ def forward(self, memory, memory_mask, text_input):
138
+ """
139
+ memory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
140
+ memory_mask.shape = (B, L, )
141
+ text_input.shape = (B, T)
142
+ """
143
+ self.initialize_decoder_states(memory, memory_mask)
144
+ # text random mask
145
+ random_mask = (torch.rand(text_input.shape) < self.random_mask).to(
146
+ text_input.device
147
+ )
148
+ _text_input = text_input.clone()
149
+ _text_input.masked_fill_(random_mask, self.unk_index)
150
+ decoder_inputs = self.embedding(_text_input).transpose(
151
+ 0, 1
152
+ ) # -> [T, B, channel]
153
+ start_embedding = self.embedding(
154
+ torch.LongTensor([self.sos] * decoder_inputs.size(1)).to(
155
+ decoder_inputs.device
156
+ )
157
+ )
158
+ decoder_inputs = torch.cat(
159
+ (start_embedding.unsqueeze(0), decoder_inputs), dim=0
160
+ )
161
+
162
+ hidden_outputs, logit_outputs, alignments = [], [], []
163
+ while len(hidden_outputs) < decoder_inputs.size(0):
164
+ decoder_input = decoder_inputs[len(hidden_outputs)]
165
+ hidden, logit, attention_weights = self.decode(decoder_input)
166
+ hidden_outputs += [hidden]
167
+ logit_outputs += [logit]
168
+ alignments += [attention_weights]
169
+
170
+ hidden_outputs, logit_outputs, alignments = self.parse_decoder_outputs(
171
+ hidden_outputs, logit_outputs, alignments
172
+ )
173
+
174
+ return hidden_outputs, logit_outputs, alignments
175
+
176
+ def decode(self, decoder_input):
177
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
178
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
179
+ cell_input, (self.decoder_hidden, self.decoder_cell)
180
+ )
181
+
182
+ attention_weights_cat = torch.cat(
183
+ (
184
+ self.attention_weights.unsqueeze(1),
185
+ self.attention_weights_cum.unsqueeze(1),
186
+ ),
187
+ dim=1,
188
+ )
189
+
190
+ self.attention_context, self.attention_weights = self.attention_layer(
191
+ self.decoder_hidden,
192
+ self.memory,
193
+ self.processed_memory,
194
+ attention_weights_cat,
195
+ self.mask,
196
+ )
197
+
198
+ self.attention_weights_cum += self.attention_weights
199
+
200
+ hidden_and_context = torch.cat(
201
+ (self.decoder_hidden, self.attention_context), -1
202
+ )
203
+ hidden = self.project_to_hidden(hidden_and_context)
204
+
205
+ # apply dropout before projecting to the symbol logits
206
+ logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
207
+
208
+ return hidden, logit, self.attention_weights
209
+
210
+ def parse_decoder_outputs(self, hidden, logit, alignments):
211
+ # -> [B, T_out + 1, max_time]
212
+ alignments = torch.stack(alignments).transpose(0, 1)
213
+ # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
214
+ logit = torch.stack(logit).transpose(0, 1).contiguous()
215
+ hidden = torch.stack(hidden).transpose(0, 1).contiguous()
216
+
217
+ return hidden, logit, alignments
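A minimal sketch (assumed shapes; layer sizes taken from src/Utils/ASR/config.yml) of running the ASRCNN above for CTC logits only, without a text input:
import torch
asr_model = ASRCNN(input_dim=80, hidden_dim=256, n_token=178, token_embedding_dim=512).eval()
mels = torch.randn(1, 80, 120)   # (B, n_mels, frames) -- assumed input
with torch.no_grad():
    ctc_logit = asr_model(mels)  # (B, frames // 2, n_token)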
src/Utils/JDC/__init__.py ADDED
@@ -0,0 +1 @@
1
+
src/Utils/JDC/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes).
 
src/Utils/JDC/__pycache__/model.cpython-310.pyc ADDED
Binary file (4.81 kB).
 
src/Utils/JDC/bst.t7 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
+ size 21029926
src/Utils/JDC/model.py ADDED
@@ -0,0 +1,212 @@
+ """
+ Implementation of model from:
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
+ Convolutional Recurrent Neural Networks" (2019)
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
+ """
+ import torch
+ from torch import nn
+
+
+ class JDCNet(nn.Module):
+     """
+     Joint Detection and Classification Network model for singing voice melody.
+     """
+
+     def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
+         super().__init__()
+         self.num_class = num_class
+
+         # input = (b, 1, 31, 513), b = batch size
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
+             ),  # out: (b, 64, 31, 513)
+             nn.BatchNorm2d(num_features=64),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
+         )
+
+         # res blocks
+         self.res_block1 = ResBlock(
+             in_channels=64, out_channels=128
+         )  # (b, 128, 31, 128)
+         self.res_block2 = ResBlock(
+             in_channels=128, out_channels=192
+         )  # (b, 192, 31, 32)
+         self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)
+
+         # pool block
+         self.pool_block = nn.Sequential(
+             nn.BatchNorm2d(num_features=256),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
+             nn.Dropout(p=0.2),
+         )
+
+         # maxpool layers (for auxiliary network inputs)
+         # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
+         self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
+         # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
+         self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
+         # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
+         self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
+
+         # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
+         self.detector_conv = nn.Sequential(
+             nn.Conv2d(640, 256, 1, bias=False),
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Dropout(p=0.2),
+         )
+
+         # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+         self.bilstm_classifier = nn.LSTM(
+             input_size=512, hidden_size=256, batch_first=True, bidirectional=True
+         )  # (b, 31, 512)
+
+         # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+         self.bilstm_detector = nn.LSTM(
+             input_size=512, hidden_size=256, batch_first=True, bidirectional=True
+         )  # (b, 31, 512)
+
+         # input: (b * 31, 512)
+         self.classifier = nn.Linear(
+             in_features=512, out_features=self.num_class
+         )  # (b * 31, num_class)
+
+         # input: (b * 31, 512)
+         self.detector = nn.Linear(
+             in_features=512, out_features=2
+         )  # (b * 31, 2) - binary classifier
+
+         # initialize weights
+         self.apply(self.init_weights)
+
+     def get_feature_GAN(self, x):
+         seq_len = x.shape[-2]
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+
+         return poolblock_out.transpose(-1, -2)
+
+     def get_feature(self, x):
+         seq_len = x.shape[-2]
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+
+         return self.pool_block[2](poolblock_out)
+
+     def forward(self, x):
+         """
+         Returns:
+             classification_prediction, detection_prediction
+             sizes: (b, 31, 722), (b, 31, 2)
+         """
+         ###############################
+         # forward pass for classifier #
+         ###############################
+         seq_len = x.shape[-1]
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+         GAN_feature = poolblock_out.transpose(-1, -2)
+         poolblock_out = self.pool_block[2](poolblock_out)
+
+         # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
+         classifier_out = (
+             poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
+         )
+         classifier_out, _ = self.bilstm_classifier(
+             classifier_out
+         )  # ignore the hidden states
+
+         classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
+         classifier_out = self.classifier(classifier_out)
+         classifier_out = classifier_out.view(
+             (-1, seq_len, self.num_class)
+         )  # (b, 31, num_class)
+
+         # sizes: (b, 31, 722), (b, 31, 2)
+         # classifier output consists of predicted pitch classes per frame
+         # detector output consists of: (isvoice, notvoice) estimates per frame
+         return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
+
+     @staticmethod
+     def init_weights(m):
+         if isinstance(m, nn.Linear):
+             nn.init.kaiming_uniform_(m.weight)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.Conv2d):
+             nn.init.xavier_normal_(m.weight)
+         elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
+             for p in m.parameters():
+                 if p.data is None:
+                     continue
+
+                 if len(p.shape) >= 2:
+                     nn.init.orthogonal_(p.data)
+                 else:
+                     nn.init.normal_(p.data)
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
+         super().__init__()
+         self.downsample = in_channels != out_channels
+
+         # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
+         self.pre_conv = nn.Sequential(
+             nn.BatchNorm2d(num_features=in_channels),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
+         )
+
+         # conv layers
+         self.conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+         )
+
+         # 1 x 1 convolution layer to match the feature dimensions
+         self.conv1by1 = None
+         if self.downsample:
+             self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+
+     def forward(self, x):
+         x = self.pre_conv(x)
+         if self.downsample:
+             x = self.conv(x) + self.conv1by1(x)
+         else:
+             x = self.conv(x) + x
+         return x
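
For readers skimming this diff: below is a minimal smoke-test sketch of the JDCNet pitch extractor added above. It is not part of the commit; the import path, the 80-bin / 192-frame input shape, and the variable names are illustrative assumptions. The network only pools along the frequency axis, so the frame count is free, but the frequency dimension must reduce to 2 after the four pooling stages (80 -> 40 -> 20 -> 10 -> 2) for the 512-wide BiLSTM input to line up. Note also that, despite the docstring, forward returns three tensors; the detector branch is not used in this forward pass.

import torch
from Utils.JDC.model import JDCNet  # assumed import path; adjust if src/ is not on sys.path

# Illustrative only: default 722 pitch classes, dummy mel-style input (batch, 1, n_bins, n_frames)
model = JDCNet(num_class=722)
model.eval()

mel = torch.randn(2, 1, 80, 192)  # 80 frequency bins reduce to width 2 after pooling

with torch.no_grad():
    pitch_out, gan_feature, pooled = model(mel)

print(pitch_out.shape)    # (2, 192, 722) - per-frame pitch-class scores
print(gan_feature.shape)  # (2, 256, 10, 192) - intermediate GAN_feature map
print(pooled.shape)       # (2, 256, 192, 2) - pooled features fed to the BiLSTM
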
src/Utils/PLBERT/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.77 kB).
 
src/Utils/PLBERT/config.yml ADDED
@@ -0,0 +1,30 @@
+ log_dir: "Checkpoint"
+ mixed_precision: "fp16"
+ data_folder: "wikipedia_20220301.en.processed"
+ batch_size: 192
+ save_interval: 5000
+ log_interval: 10
+ num_process: 1 # number of GPUs
+ num_steps: 1000000
+
+ dataset_params:
+     tokenizer: "transfo-xl-wt103"
+     token_separator: " " # token used for phoneme separator (space)
+     token_mask: "M" # token used for phoneme mask (M)
+     word_separator: 3039 # token used for word separator (<formula>)
+     token_maps: "token_maps.pkl" # token map path
+
+     max_mel_length: 512 # max phoneme length
+
+     word_mask_prob: 0.15 # probability to mask the entire word
+     phoneme_mask_prob: 0.1 # probability to mask each phoneme
+     replace_prob: 0.2 # probability to replace phonemes
+
+ model_params:
+     vocab_size: 178
+     hidden_size: 768
+     num_attention_heads: 12
+     intermediate_size: 2048
+     max_position_embeddings: 512
+     num_hidden_layers: 12
+     dropout: 0.1
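
Not part of the commit, but for context: the three probabilities under dataset_params drive a BERT-style corruption of the phoneme sequence during PL-BERT pretraining. The sketch below shows how such a masking pass could look; the function name, argument layout, and corruption order are assumptions (the real preprocessing lives in the PL-BERT training code, not in this repo).

import random

def corrupt_phonemes(phonemes, word_ids, mask_token="M",
                     word_mask_prob=0.15, phoneme_mask_prob=0.1, replace_prob=0.2):
    """Hypothetical masking pass mirroring dataset_params above.

    phonemes: list of phoneme tokens; word_ids: word index each phoneme belongs to.
    """
    out = list(phonemes)
    # whole-word masking: mask every phoneme of a randomly selected word
    masked_words = {w for w in set(word_ids) if random.random() < word_mask_prob}
    for i, w in enumerate(word_ids):
        if w in masked_words:
            out[i] = mask_token
    # independent per-phoneme corruption of the remaining tokens
    for i, p in enumerate(out):
        if p == mask_token:
            continue
        r = random.random()
        if r < phoneme_mask_prob:
            out[i] = mask_token                # mask this phoneme
        elif r < phoneme_mask_prob + replace_prob:
            out[i] = random.choice(phonemes)   # replace with a random phoneme
    return out

# example: two words, "DH AH" and "K AE T"
print(corrupt_phonemes(["DH", "AH", "K", "AE", "T"], [0, 0, 1, 1, 1]))
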
src/Utils/PLBERT/step_1000000.t7 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0
+ size 25185187
src/Utils/PLBERT/util.py ADDED
@@ -0,0 +1,49 @@
+ import os
+ import yaml
+ import torch
+ from transformers import AlbertConfig, AlbertModel
+
+
+ class CustomAlbert(AlbertModel):
+     def forward(self, *args, **kwargs):
+         # Call the original forward method
+         outputs = super().forward(*args, **kwargs)
+
+         # Only return the last_hidden_state
+         return outputs.last_hidden_state
+
+
+ def load_plbert(log_dir):
+     config_path = os.path.join(log_dir, "config.yml")
+     plbert_config = yaml.safe_load(open(config_path))
+
+     albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
+     bert = CustomAlbert(albert_base_configuration)
+
+     files = os.listdir(log_dir)
+     ckpts = []
+     for f in os.listdir(log_dir):
+         if f.startswith("step_"):
+             ckpts.append(f)
+
+     iters = [
+         int(f.split("_")[-1].split(".")[0])
+         for f in ckpts
+         if os.path.isfile(os.path.join(log_dir, f))
+     ]
+     iters = sorted(iters)[-1]
+
+     checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location="cpu")
+     state_dict = checkpoint["net"]
+     from collections import OrderedDict
+
+     new_state_dict = OrderedDict()
+     for k, v in state_dict.items():
+         name = k[7:]  # remove `module.`
+         if name.startswith("encoder."):
+             name = name[8:]  # remove `encoder.`
+         new_state_dict[name] = v
+     del new_state_dict["embeddings.position_ids"]
+     bert.load_state_dict(new_state_dict, strict=False)
+
+     return bert
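
Again not part of the commit: a short usage sketch for the loader above. The directory path, import path, and token ids are illustrative assumptions; all the sketch relies on from the code above is that load_plbert picks the latest step_*.t7 checkpoint next to config.yml and returns an AlbertModel variant whose forward yields only last_hidden_state.

import torch
from Utils.PLBERT.util import load_plbert  # assumed import path; adjust if src/ is not on sys.path

plbert = load_plbert("Utils/PLBERT")  # directory holding config.yml and step_1000000.t7
plbert.eval()

# dummy phoneme-token ids; vocab_size is 178 per the config above
tokens = torch.randint(0, 178, (1, 64))
mask = torch.ones_like(tokens)

with torch.no_grad():
    hidden = plbert(tokens, attention_mask=mask)

print(hidden.shape)  # (1, 64, 768): last_hidden_state only, as CustomAlbert.forward returns
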