Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
593f3bc
0
Parent(s):
first commit for huggingface space
Browse files- Dockerfile +18 -0
- LICENSE +202 -0
- app.py +94 -0
- readme.md +186 -0
- requirements.txt +17 -0
- tts/frontend_function.py +175 -0
- tts/gradio_api.py +94 -0
- tts/infer_cli.py +278 -0
- tts/modules/aligner/whisper_small.py +318 -0
- tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- tts/modules/ar_dur/commons/layers.py +64 -0
- tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- tts/modules/ar_dur/commons/seq_utils.py +342 -0
- tts/modules/ar_dur/commons/transformer.py +767 -0
- tts/modules/llm_dit/cfm.py +309 -0
- tts/modules/llm_dit/dit.py +180 -0
- tts/modules/llm_dit/time_embedding.py +44 -0
- tts/modules/llm_dit/transformer.py +230 -0
- tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- tts/utils/audio_utils/align.py +36 -0
- tts/utils/audio_utils/io.py +95 -0
- tts/utils/audio_utils/plot.py +90 -0
- tts/utils/commons/ckpt_utils.py +171 -0
- tts/utils/commons/hparams.py +215 -0
- tts/utils/text_utils/dict.json +1 -0
- tts/utils/text_utils/ph_tone_convert.py +94 -0
- tts/utils/text_utils/split_text.py +90 -0
- tts/utils/text_utils/text_encoder.py +280 -0
Dockerfile
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
RUN apt-get update && apt-get install -y \
|
6 |
+
curl \
|
7 |
+
python3 \
|
8 |
+
python3-pip \
|
9 |
+
ffmpeg \
|
10 |
+
&& apt-get clean
|
11 |
+
|
12 |
+
COPY requirements.txt /app/
|
13 |
+
|
14 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
15 |
+
|
16 |
+
COPY . /app/
|
17 |
+
|
18 |
+
CMD ["python", "-m", "tts.gradio_api"]
|
LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [2025] ByteDance
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
app.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import multiprocessing as mp
|
16 |
+
import torch
|
17 |
+
import os
|
18 |
+
from functools import partial
|
19 |
+
import gradio as gr
|
20 |
+
import traceback
|
21 |
+
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
|
22 |
+
|
23 |
+
|
24 |
+
def model_worker(input_queue, output_queue, device_id):
|
25 |
+
device = None
|
26 |
+
if device_id is not None:
|
27 |
+
device = torch.device(f'cuda:{device_id}')
|
28 |
+
infer_pipe = MegaTTS3DiTInfer(device=device)
|
29 |
+
os.system(f'pkill -f "voidgpu{device_id}"')
|
30 |
+
|
31 |
+
while True:
|
32 |
+
task = input_queue.get()
|
33 |
+
inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
|
34 |
+
try:
|
35 |
+
convert_to_wav(inp_audio_path)
|
36 |
+
wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
|
37 |
+
cut_wav(wav_path, max_len=28)
|
38 |
+
with open(wav_path, 'rb') as file:
|
39 |
+
file_content = file.read()
|
40 |
+
resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
|
41 |
+
wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
|
42 |
+
output_queue.put(wav_bytes)
|
43 |
+
except Exception as e:
|
44 |
+
traceback.print_exc()
|
45 |
+
print(task, str(e))
|
46 |
+
output_queue.put(None)
|
47 |
+
|
48 |
+
|
49 |
+
def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
|
50 |
+
print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
|
51 |
+
input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
|
52 |
+
res = output_queue.get()
|
53 |
+
if res is not None:
|
54 |
+
return res
|
55 |
+
else:
|
56 |
+
print("")
|
57 |
+
return None
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == '__main__':
|
61 |
+
mp.set_start_method('spawn', force=True)
|
62 |
+
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
|
63 |
+
if devices != '':
|
64 |
+
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
|
65 |
+
for d in devices:
|
66 |
+
os.system(f'pkill -f "voidgpu{d}"')
|
67 |
+
else:
|
68 |
+
devices = None
|
69 |
+
|
70 |
+
num_workers = 1
|
71 |
+
input_queue = mp.Queue()
|
72 |
+
output_queue = mp.Queue()
|
73 |
+
processes = []
|
74 |
+
|
75 |
+
print("Start open workers")
|
76 |
+
for i in range(num_workers):
|
77 |
+
p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
|
78 |
+
p.start()
|
79 |
+
processes.append(p)
|
80 |
+
|
81 |
+
api_interface = gr.Interface(fn=
|
82 |
+
partial(main, processes=processes, input_queue=input_queue,
|
83 |
+
output_queue=output_queue),
|
84 |
+
inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
|
85 |
+
gr.Number(label="infer timestep", value=32),
|
86 |
+
gr.Number(label="Intelligibility Weight", value=1.4),
|
87 |
+
gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")],
|
88 |
+
title="MegaTTS3",
|
89 |
+
description="Upload a speech clip as a reference for timbre, " +
|
90 |
+
"upload the pre-extracted latent file, "+
|
91 |
+
"input the target text, and receive the cloned voice.", concurrency_limit=1)
|
92 |
+
api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
|
93 |
+
for p in processes:
|
94 |
+
p.join()
|
readme.md
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<h1>
|
3 |
+
MegaTTS 3 <img src="./assets/fig/Hi.gif" width="40px">
|
4 |
+
</h1>
|
5 |
+
<p>
|
6 |
+
Official PyTorch Implementation<br>
|
7 |
+
</p>
|
8 |
+
<p></p>
|
9 |
+
<img src="https://img.shields.io/badge/Bytedance-%230077B5.svg?&style=flat-square&logo=bytedance&logoColor=white" />
|
10 |
+
<img src="https://img.shields.io/badge/Zhejiang University-%230077B5.svg?&style=flat-square&logo=&logoColor=white" />
|
11 |
+
</div>
|
12 |
+
|
13 |
+
## Key features
|
14 |
+
- 🚀**Lightweight and Efficient:** The backbone of the TTS Diffusion Transformer has only 0.45B parameters.
|
15 |
+
- 🎧**Ultra High-Quality Voice Cloning:** See the demo video below! We also report results of recent TTS models on the Seed test sets in the following table. 🎉Submit a sample on [link2](https://drive.google.com/drive/folders/1gCWL1y_2xu9nIFhUX_OW5MbcFuB7J5Cl?usp=sharing) to receive voice latents you can use locally.
|
16 |
+
- 🌍**Bilingual Support:** Supports both Chinese and English, and code-switching.
|
17 |
+
- ✍️**Controllable:** Supports accent intensity control ✅ and fine-grained pronunciation/duration adjustment (coming soon).
|
18 |
+
|
19 |
+
[MegaTTS 3 Demo Video](https://github.com/user-attachments/assets/0174c111-f392-4376-a34b-0b5b8164aacc)
|
20 |
+
|
21 |
+
<div style='width:100%;text-align:center'>
|
22 |
+
<img src="./assets/fig/table_tts.png" width="550px">
|
23 |
+
</div>
|
24 |
+
|
25 |
+
## 🎯Roadmap
|
26 |
+
|
27 |
+
- **[2025-03-22]** Our project has been released!
|
28 |
+
|
29 |
+
|
30 |
+
## Installation
|
31 |
+
``` sh
|
32 |
+
# Clone the repository
|
33 |
+
git clone https://github.com/bytedance/MegaTTS3
|
34 |
+
cd MegaTTS3
|
35 |
+
```
|
36 |
+
**Requirements (for Linux)**
|
37 |
+
``` sh
|
38 |
+
|
39 |
+
# Create a python 3.10 conda env (you could also use virtualenv)
|
40 |
+
conda create -n megatts3-env python=3.10
|
41 |
+
conda activate megatts3-env
|
42 |
+
pip install -r requirements.txt
|
43 |
+
|
44 |
+
# Set the root directory
|
45 |
+
export PYTHONPATH="/path/to/MegaTTS3:$PYTHONPATH"
|
46 |
+
|
47 |
+
# [Optional] Set GPU
|
48 |
+
export CUDA_VISIBLE_DEVICES=0
|
49 |
+
|
50 |
+
# If you encounter bugs with pydantic in inference, you should check if the versions of pydantic and gradio are matched.
|
51 |
+
# [Note] if you encounter bugs related with httpx, please check that whether your environmental variable "no_proxy" has patterns like "::"
|
52 |
+
```
|
53 |
+
|
54 |
+
**Requirements (for Windows)**
|
55 |
+
``` sh
|
56 |
+
# [The Windows version is currently under testing]
|
57 |
+
# Comment below dependence in requirements.txt:
|
58 |
+
# # WeTextProcessing==1.0.4.1
|
59 |
+
|
60 |
+
# Create a python 3.10 conda env (you could also use virtualenv)
|
61 |
+
conda create -n megatts3-env python=3.10
|
62 |
+
conda activate megatts3-env
|
63 |
+
pip install -r requirements.txt
|
64 |
+
conda install -y -c conda-forge pynini==2.1.5
|
65 |
+
pip install WeTextProcessing==1.0.3
|
66 |
+
|
67 |
+
# [Optional] If you want GPU inference, you may need to install specific version of PyTorch for your GPU from https://pytorch.org/.
|
68 |
+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
|
69 |
+
|
70 |
+
# [Note] if you encounter bugs related with `ffprobe` or `ffmpeg`, you can install it through `conda install -c conda-forge ffmpeg`
|
71 |
+
|
72 |
+
# Set environment variable for root directory
|
73 |
+
set PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # Windows
|
74 |
+
$env:PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # Powershell on Windows
|
75 |
+
conda env config vars set PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # For conda users
|
76 |
+
|
77 |
+
# [Optional] Set GPU
|
78 |
+
set CUDA_VISIBLE_DEVICES=0 # Windows
|
79 |
+
$env:CUDA_VISIBLE_DEVICES=0 # Powershell on Windows
|
80 |
+
|
81 |
+
```
|
82 |
+
|
83 |
+
**Requirements (for Docker)**
|
84 |
+
``` sh
|
85 |
+
# [The Docker version is currently under testing]
|
86 |
+
# ! You should download the pretrained checkpoint before running the following command
|
87 |
+
docker build . -t megatts3:latest
|
88 |
+
|
89 |
+
# For GPU inference
|
90 |
+
docker run -it -p 7929:7929 --gpus all -e CUDA_VISIBLE_DEVICES=0 megatts3:latest
|
91 |
+
# For CPU inference
|
92 |
+
docker run -it -p 7929:7929 megatts3:latest
|
93 |
+
|
94 |
+
# Visit http://0.0.0.0:7929/ for gradio.
|
95 |
+
```
|
96 |
+
|
97 |
+
|
98 |
+
**Model Download**
|
99 |
+
|
100 |
+
The pretrained checkpoint can be found at [Google Drive](https://drive.google.com/drive/folders/1CidiSqtHgJTBDAHQ746_on_YR0boHDYB?usp=sharing) or [Huggingface](https://huggingface.co/ByteDance/MegaTTS3). Please download them and put them to ``./checkpoints/xxx``.
|
101 |
+
|
102 |
+
> [!IMPORTANT]
|
103 |
+
> For security issues, we do not upload the parameters of WaveVAE encoder to the above links. You can only use the pre-extracted latents from [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing) for inference. If you want to synthesize speech for speaker A, you need "A.wav" and "A.npy" in the same directory. If you have any questions or suggestions for our model, please email us.
|
104 |
+
>
|
105 |
+
> This project is primarily intended for academic purposes. For academic datasets requiring evaluation, you may upload them to the voice request queue in [link2](https://drive.google.com/drive/folders/1gCWL1y_2xu9nIFhUX_OW5MbcFuB7J5Cl?usp=sharing) (within 24s for each clip). After verifying that your uploaded voices are free from safety issues, we will upload their latent files to [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing) as soon as possible.
|
106 |
+
>
|
107 |
+
> In the coming days, we will also prepare and release the latent representations for some common TTS benchmarks.
|
108 |
+
|
109 |
+
## Inference
|
110 |
+
|
111 |
+
**Command-Line Usage (Standard)**
|
112 |
+
``` bash
|
113 |
+
# p_w (intelligibility weight), t_w (similarity weight). Typically, prompt with more noises requires higher p_w and t_w
|
114 |
+
python tts/infer_cli.py --input_wav 'assets/Chinese_prompt.wav' --input_text "另一边的桌上,一位读书人嗤之以鼻道,'佛子三藏,神子燕小鱼是什么样的人物,李家的那个李子夜如何与他们相提并论?'" --output_dir ./gen
|
115 |
+
|
116 |
+
# As long as audio volume and pronunciation are appropriate, increasing --t_w within reasonable ranges (2.0~5.0)
|
117 |
+
# will increase the generated speech's expressiveness and similarity (especially for some emotional cases).
|
118 |
+
python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text 'As his long promised tariff threat turned into reality this week, top human advisers began fielding a wave of calls from business leaders, particularly in the automotive sector, along with lawmakers who were sounding the alarm.' --output_dir ./gen --p_w 2.0 --t_w 3.0
|
119 |
+
```
|
120 |
+
**Command-Line Usage (for TTS with Accents)**
|
121 |
+
``` bash
|
122 |
+
# When p_w (intelligibility weight) ≈ 1.0, the generated audio closely retains the speaker’s original accent. As p_w increases, it shifts toward standard pronunciation.
|
123 |
+
# t_w (similarity weight) is typically set 0–3 points higher than p_w for optimal results.
|
124 |
+
# Useful for accented TTS or solving the accent problems in cross-lingual TTS.
|
125 |
+
python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text '这是一条有口音的音频。' --output_dir ./gen --p_w 1.0 --t_w 3.0
|
126 |
+
|
127 |
+
python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text '这条音频的发音标准一些了吗?' --output_dir ./gen --p_w 2.5 --t_w 2.5
|
128 |
+
```
|
129 |
+
|
130 |
+
**Web UI Usage**
|
131 |
+
``` bash
|
132 |
+
# We also support cpu inference, but it may take about 30 seconds (for 10 inference steps).
|
133 |
+
python tts/gradio_api.py
|
134 |
+
```
|
135 |
+
|
136 |
+
## Submodules
|
137 |
+
> [!TIP]
|
138 |
+
> In addition to TTS, some submodules in this project may also have additional usages.
|
139 |
+
> See ``./tts/frontend_fuction.py`` and ``./tts/infer_cli.py`` for example code.
|
140 |
+
|
141 |
+
### Aligner
|
142 |
+
**Description:** a robust speech-text aligner model trained using pseudo-labels generated by a large number of MFA expert models.
|
143 |
+
|
144 |
+
**Usage**: 1) Prepare the finetuning dataset for our model; 2) Filter the large-scale speech dataset (if the aligner fails to align a certain speech clip, it is likely to be noisy); 3) Phoneme recognition; 4) Speech segmentation.
|
145 |
+
|
146 |
+
### Graphme-to-Phoneme Model
|
147 |
+
**Description:** a Qwen2.5-0.5B model finetuned for robust graphme-to-phoneme conversion.
|
148 |
+
|
149 |
+
**Usage**: Graphme-to-phoneme conversion.
|
150 |
+
|
151 |
+
### WaveVAE
|
152 |
+
**Description:** a strong waveform VAE that can compress 24 kHz speeche into 25 Hz acoustic latent and reconstruct the original wave almost losslessly.
|
153 |
+
|
154 |
+
**Usage:** 1) Acoustic latents can provide a more compact and discriminative training target for speech synthesis models compared to mel-spectrograms, accelerating convergence; 2) Used as acoustic latents for voice conversion; 3) High-quality vocoder.
|
155 |
+
|
156 |
+
<div style='width:100%;text-align:center'>
|
157 |
+
<img src="./assets/fig/table_wavvae.png" width="650px">
|
158 |
+
</div>
|
159 |
+
|
160 |
+
|
161 |
+
## Security
|
162 |
+
If you discover a potential security issue in this project, or think you may
|
163 |
+
have discovered a security issue, we ask that you notify Bytedance Security via our [security center](https://security.bytedance.com/src) or [[email protected]]([email protected]).
|
164 |
+
|
165 |
+
Please do **not** create a public GitHub issue.
|
166 |
+
|
167 |
+
## License
|
168 |
+
This project is licensed under the [Apache-2.0 License](LICENSE).
|
169 |
+
|
170 |
+
## Citation
|
171 |
+
This repo contains forced-align version of `Sparse Alignment Enhanced Latent Diffusion Transformer for Zero-Shot Speech Synthesis` and the WavVAE is mainly based on `Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling`. Compared to the model described in paper, the repository includes additional models. These models not only enhance the stability and cloning capabilities of the algorithm but can also be independently utilized to serve a wider range of scenarios.
|
172 |
+
```
|
173 |
+
@article{jiang2025sparse,
|
174 |
+
title={Sparse Alignment Enhanced Latent Diffusion Transformer for Zero-Shot Speech Synthesis},
|
175 |
+
author={Jiang, Ziyue and Ren, Yi and Li, Ruiqi and Ji, Shengpeng and Ye, Zhenhui and Zhang, Chen and Jionghao, Bai and Yang, Xiaoda and Zuo, Jialong and Zhang, Yu and others},
|
176 |
+
journal={arXiv preprint arXiv:2502.18924},
|
177 |
+
year={2025}
|
178 |
+
}
|
179 |
+
|
180 |
+
@article{ji2024wavtokenizer,
|
181 |
+
title={Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling},
|
182 |
+
author={Ji, Shengpeng and Jiang, Ziyue and Wang, Wen and Chen, Yifu and Fang, Minghui and Zuo, Jialong and Yang, Qian and Cheng, Xize and Wang, Zehan and Li, Ruiqi and others},
|
183 |
+
journal={arXiv preprint arXiv:2408.16532},
|
184 |
+
year={2024}
|
185 |
+
}
|
186 |
+
```
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.3.0
|
2 |
+
torchaudio==2.3.0
|
3 |
+
setproctitle==1.3.3
|
4 |
+
attrdict==2.0.1
|
5 |
+
librosa==0.10.2.post1
|
6 |
+
langdetect==1.0.9
|
7 |
+
pydub==0.25.1
|
8 |
+
pyloudnorm==0.1.1
|
9 |
+
modelscope==1.22.2
|
10 |
+
WeTextProcessing==1.0.4.1
|
11 |
+
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
|
12 |
+
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
|
13 |
+
x-transformers==1.44.4
|
14 |
+
torchdiffeq==0.2.5
|
15 |
+
openai-whisper==20240930
|
16 |
+
httpx==0.28.1
|
17 |
+
gradio==5.23.1
|
tts/frontend_function.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import whisper
|
18 |
+
import librosa
|
19 |
+
from copy import deepcopy
|
20 |
+
from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
|
21 |
+
from tts.utils.audio_utils.align import mel2token_to_dur
|
22 |
+
|
23 |
+
''' Graphme to phoneme function '''
|
24 |
+
def g2p(self, text_inp):
|
25 |
+
# prepare inputs
|
26 |
+
txt_token = self.g2p_tokenizer('<BOT>' + text_inp + '<BOS>')['input_ids']
|
27 |
+
input_ids = torch.LongTensor([txt_token+[145+self.speech_start_idx]]).to(self.device)
|
28 |
+
|
29 |
+
# model forward
|
30 |
+
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
|
31 |
+
outputs = self.g2p_model.generate(input_ids, max_new_tokens=256, do_sample=True, top_k=1, eos_token_id=800+1+self.speech_start_idx)
|
32 |
+
|
33 |
+
# process outputs
|
34 |
+
ph_tokens = outputs[:, len(txt_token):-1]-self.speech_start_idx
|
35 |
+
ph_pred, tone_pred = split_ph(ph_tokens[0])
|
36 |
+
ph_pred, tone_pred = ph_pred[None, :].to(self.device), tone_pred[None, :].to(self.device)
|
37 |
+
return ph_pred, tone_pred
|
38 |
+
|
39 |
+
''' Get phoneme2mel align of prompt speech '''
|
40 |
+
def align(self, wav):
|
41 |
+
with torch.inference_mode():
|
42 |
+
whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
|
43 |
+
mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
|
44 |
+
prompt_max_frame = mel.size(2) // self.fm * self.fm
|
45 |
+
mel = mel[:, :, :prompt_max_frame]
|
46 |
+
token = torch.LongTensor([[798]]).to(self.device)
|
47 |
+
audio_features = self.aligner_lm.embed_audio(mel)
|
48 |
+
for i in range(768):
|
49 |
+
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
|
50 |
+
logits = self.aligner_lm.logits(token, audio_features, None)
|
51 |
+
token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None]
|
52 |
+
token = torch.cat([token, token_pred], dim=1)
|
53 |
+
if token_pred[0] == 799:
|
54 |
+
break
|
55 |
+
alignment_tokens = token
|
56 |
+
|
57 |
+
ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1])
|
58 |
+
ph_ref = torch.Tensor(ph_ref)[None].to(self.device)
|
59 |
+
tone_ref = torch.Tensor(tone_ref)[None].to(self.device)
|
60 |
+
if dur_ref.sum() < prompt_max_frame:
|
61 |
+
dur_ref[-1] += prompt_max_frame - dur_ref.sum()
|
62 |
+
elif dur_ref.sum() > prompt_max_frame:
|
63 |
+
len_diff = dur_ref.sum() - prompt_max_frame
|
64 |
+
while True:
|
65 |
+
for i in range(len(dur_ref)):
|
66 |
+
dur_ref[i] -= 1
|
67 |
+
len_diff -= 1
|
68 |
+
if len_diff == 0:
|
69 |
+
break
|
70 |
+
if len_diff == 0:
|
71 |
+
break
|
72 |
+
mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device)
|
73 |
+
mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1)//self.fm*self.fm]
|
74 |
+
return ph_ref, tone_ref, mel2ph_ref
|
75 |
+
|
76 |
+
''' Duration Prompting '''
|
77 |
+
def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
|
78 |
+
dur_tokens_2d_ = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp(
|
79 |
+
max=self.hp_dur_model['dur_code_size'] - 1) + 1
|
80 |
+
|
81 |
+
ctx_dur_tokens = dur_tokens_2d_.clone().flatten(0, 1).to(self.device)
|
82 |
+
txt_tokens_flat_ = ph_ref.flatten(0, 1)
|
83 |
+
ctx_dur_tokens = ctx_dur_tokens[txt_tokens_flat_ > 0][None]
|
84 |
+
|
85 |
+
last_dur_pos_prompt = ctx_dur_tokens.shape[1]
|
86 |
+
dur_spk_pos_ids_flat = range(0, last_dur_pos_prompt)
|
87 |
+
dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
|
88 |
+
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
|
89 |
+
_, incremental_state_dur_prompt = self.dur_model.infer(
|
90 |
+
ph_ref, {'tone': tone_ref}, None, None, None,
|
91 |
+
ctx_vqcodes=ctx_dur_tokens, spk_pos_ids_flat=dur_spk_pos_ids_flat, return_state=True)
|
92 |
+
return incremental_state_dur_prompt, ctx_dur_tokens
|
93 |
+
|
94 |
+
''' Duration Prediction '''
|
95 |
+
def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final):
|
96 |
+
last_dur_token = ctx_dur_tokens[:, -1:]
|
97 |
+
last_dur_pos_prompt = ctx_dur_tokens.shape[1]
|
98 |
+
incremental_state_dur = deepcopy(incremental_state_dur_prompt)
|
99 |
+
txt_len = ph_pred.shape[1]
|
100 |
+
dur_spk_pos_ids_flat = range(last_dur_pos_prompt, last_dur_pos_prompt + txt_len)
|
101 |
+
dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
|
102 |
+
last_dur_pos_prompt = last_dur_pos_prompt + txt_len
|
103 |
+
|
104 |
+
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
|
105 |
+
dur_pred = self.dur_model.infer(
|
106 |
+
ph_pred, {'tone': tone_pred}, None, None, None,
|
107 |
+
incremental_state=incremental_state_dur,
|
108 |
+
first_decoder_inp=last_dur_token,
|
109 |
+
spk_pos_ids_flat=dur_spk_pos_ids_flat,
|
110 |
+
)
|
111 |
+
|
112 |
+
dur_pred = dur_pred - 1
|
113 |
+
dur_pred = dur_pred.clamp(0, self.hp_dur_model['dur_code_size'] - 1)
|
114 |
+
# if is_final:
|
115 |
+
# dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
|
116 |
+
# else:
|
117 |
+
# dur_pred[:, -1] = dur_pred[:, -1].clamp(48, 128)
|
118 |
+
# if seg_i > 0:
|
119 |
+
# dur_pred[:, 0] = 0
|
120 |
+
# ['。', '!', '?', 'sil']
|
121 |
+
for sil_token in [148, 153, 166, 145]:
|
122 |
+
dur_pred[ph_pred==sil_token].clamp_min(32)
|
123 |
+
# [',', ';']
|
124 |
+
for sil_token in [163, 165]:
|
125 |
+
dur_pred[ph_pred==sil_token].clamp_min(16)
|
126 |
+
if not is_final:
|
127 |
+
# add 0.32ms for crossfade
|
128 |
+
dur_pred[:, -1] = dur_pred[:, -1] + 32
|
129 |
+
else:
|
130 |
+
dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
|
131 |
+
|
132 |
+
''' DiT target speech generation '''
|
133 |
+
dur_disturb_choice = (torch.rand_like(dur_pred.float()) > 0.5).float()
|
134 |
+
dur_disturb_r = 1 + torch.rand_like(dur_pred.float()) * dur_disturb
|
135 |
+
dur_pred = dur_pred * dur_disturb_r * dur_disturb_choice + \
|
136 |
+
dur_pred / dur_disturb_r * (1 - dur_disturb_choice)
|
137 |
+
dur_pred = torch.round(dur_pred * dur_alpha).clamp(0, 127)
|
138 |
+
if is_first:
|
139 |
+
dur_pred[:, 0] = 8
|
140 |
+
|
141 |
+
dur_sum = dur_pred.sum()
|
142 |
+
npad = self.fm - dur_sum % self.fm
|
143 |
+
if npad < self.fm:
|
144 |
+
dur_pred[:, -1] += npad
|
145 |
+
mel2ph_pred = self.length_regulator(dur_pred).to(self.device)
|
146 |
+
return mel2ph_pred
|
147 |
+
|
148 |
+
def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent):
|
149 |
+
# Prepare duration token
|
150 |
+
mel2ph_pred = torch.cat((mel2ph_ref, mel2ph_pred+ph_ref.size(1)), dim=1)
|
151 |
+
mel2ph_pred = mel2ph_pred[:, :mel2ph_pred.size(1)//self.fm*self.fm].repeat(3, 1)
|
152 |
+
# Prepare phone and tone token
|
153 |
+
ph_pred = torch.cat((ph_ref, ph_pred), dim=1)
|
154 |
+
tone_pred = torch.cat((tone_ref, tone_pred), dim=1)
|
155 |
+
# Disable the English tone (set them to 3)"""
|
156 |
+
en_tone_idx = ~((tone_pred == 4) | ( (11 <= tone_pred) & (tone_pred <= 15)) | (tone_pred == 0))
|
157 |
+
tone_pred[en_tone_idx] = 3
|
158 |
+
|
159 |
+
# Prepare cfg inputs
|
160 |
+
ph_seq = torch.cat([ph_pred, ph_pred, torch.full(ph_pred.size(), self.cfg_mask_token_phone, device=self.device)], 0)
|
161 |
+
tone_seq = torch.cat([tone_pred, tone_pred, torch.full(tone_pred.size(), self.cfg_mask_token_tone, device=self.device)], 0)
|
162 |
+
target_size = mel2ph_pred.size(1)//self.vae_stride
|
163 |
+
vae_latent_ = vae_latent.repeat(3, 1, 1)
|
164 |
+
ctx_mask = torch.ones_like(vae_latent_[:, :, 0:1])
|
165 |
+
vae_latent_ = F.pad(vae_latent_, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
|
166 |
+
vae_latent_[1:] = 0.0
|
167 |
+
ctx_mask = F.pad(ctx_mask, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
|
168 |
+
|
169 |
+
return {
|
170 |
+
'phone': ph_seq,
|
171 |
+
'tone': tone_seq,
|
172 |
+
"lat_ctx": vae_latent_ * ctx_mask,
|
173 |
+
"ctx_mask": ctx_mask,
|
174 |
+
"dur": mel2ph_pred,
|
175 |
+
}
|
tts/gradio_api.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import multiprocessing as mp
|
16 |
+
import torch
|
17 |
+
import os
|
18 |
+
from functools import partial
|
19 |
+
import gradio as gr
|
20 |
+
import traceback
|
21 |
+
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
|
22 |
+
|
23 |
+
|
24 |
+
def model_worker(input_queue, output_queue, device_id):
|
25 |
+
device = None
|
26 |
+
if device_id is not None:
|
27 |
+
device = torch.device(f'cuda:{device_id}')
|
28 |
+
infer_pipe = MegaTTS3DiTInfer(device=device)
|
29 |
+
os.system(f'pkill -f "voidgpu{device_id}"')
|
30 |
+
|
31 |
+
while True:
|
32 |
+
task = input_queue.get()
|
33 |
+
inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
|
34 |
+
try:
|
35 |
+
convert_to_wav(inp_audio_path)
|
36 |
+
wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
|
37 |
+
cut_wav(wav_path, max_len=28)
|
38 |
+
with open(wav_path, 'rb') as file:
|
39 |
+
file_content = file.read()
|
40 |
+
resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
|
41 |
+
wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
|
42 |
+
output_queue.put(wav_bytes)
|
43 |
+
except Exception as e:
|
44 |
+
traceback.print_exc()
|
45 |
+
print(task, str(e))
|
46 |
+
output_queue.put(None)
|
47 |
+
|
48 |
+
|
49 |
+
def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
|
50 |
+
print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
|
51 |
+
input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
|
52 |
+
res = output_queue.get()
|
53 |
+
if res is not None:
|
54 |
+
return res
|
55 |
+
else:
|
56 |
+
print("")
|
57 |
+
return None
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == '__main__':
|
61 |
+
mp.set_start_method('spawn', force=True)
|
62 |
+
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
|
63 |
+
if devices != '':
|
64 |
+
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
|
65 |
+
for d in devices:
|
66 |
+
os.system(f'pkill -f "voidgpu{d}"')
|
67 |
+
else:
|
68 |
+
devices = None
|
69 |
+
|
70 |
+
num_workers = 1
|
71 |
+
input_queue = mp.Queue()
|
72 |
+
output_queue = mp.Queue()
|
73 |
+
processes = []
|
74 |
+
|
75 |
+
print("Start open workers")
|
76 |
+
for i in range(num_workers):
|
77 |
+
p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
|
78 |
+
p.start()
|
79 |
+
processes.append(p)
|
80 |
+
|
81 |
+
api_interface = gr.Interface(fn=
|
82 |
+
partial(main, processes=processes, input_queue=input_queue,
|
83 |
+
output_queue=output_queue),
|
84 |
+
inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
|
85 |
+
gr.Number(label="infer timestep", value=32),
|
86 |
+
gr.Number(label="Intelligibility Weight", value=1.4),
|
87 |
+
gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")],
|
88 |
+
title="MegaTTS3",
|
89 |
+
description="Upload a speech clip as a reference for timbre, " +
|
90 |
+
"upload the pre-extracted latent file, "+
|
91 |
+
"input the target text, and receive the cloned voice.", concurrency_limit=1)
|
92 |
+
api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
|
93 |
+
for p in processes:
|
94 |
+
p.join()
|
tts/infer_cli.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
import argparse
|
18 |
+
import librosa
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
23 |
+
from tn.english.normalizer import Normalizer as EnNormalizer
|
24 |
+
from langdetect import detect as classify_language
|
25 |
+
from pydub import AudioSegment
|
26 |
+
import pyloudnorm as pyln
|
27 |
+
|
28 |
+
from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
|
29 |
+
from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
|
30 |
+
from tts.utils.audio_utils.io import save_wav, to_wav_bytes, convert_to_wav_bytes, combine_audio_segments
|
31 |
+
from tts.utils.commons.ckpt_utils import load_ckpt
|
32 |
+
from tts.utils.commons.hparams import set_hparams, hparams
|
33 |
+
from tts.utils.text_utils.text_encoder import TokenTextEncoder
|
34 |
+
from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english
|
35 |
+
from tts.utils.commons.hparams import hparams, set_hparams
|
36 |
+
|
37 |
+
|
38 |
+
if "TOKENIZERS_PARALLELISM" not in os.environ:
|
39 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
40 |
+
|
41 |
+
def convert_to_wav(wav_path):
|
42 |
+
# Check if the file exists
|
43 |
+
if not os.path.exists(wav_path):
|
44 |
+
print(f"The file '{wav_path}' does not exist.")
|
45 |
+
return
|
46 |
+
|
47 |
+
# Check if the file already has a .wav extension
|
48 |
+
if not wav_path.endswith(".wav"):
|
49 |
+
# Define the output path with a .wav extension
|
50 |
+
out_path = os.path.splitext(wav_path)[0] + ".wav"
|
51 |
+
|
52 |
+
# Load the audio file using pydub and convert it to WAV
|
53 |
+
audio = AudioSegment.from_file(wav_path)
|
54 |
+
audio.export(out_path, format="wav")
|
55 |
+
|
56 |
+
print(f"Converted '{wav_path}' to '{out_path}'")
|
57 |
+
|
58 |
+
|
59 |
+
def cut_wav(wav_path, max_len=28):
|
60 |
+
audio = AudioSegment.from_file(wav_path)
|
61 |
+
audio = audio[:int(max_len * 1000)]
|
62 |
+
audio.export(wav_path, format="wav")
|
63 |
+
|
64 |
+
class MegaTTS3DiTInfer():
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
device=None,
|
68 |
+
ckpt_root='./checkpoints',
|
69 |
+
dit_exp_name='diffusion_transformer',
|
70 |
+
frontend_exp_name='aligner_lm',
|
71 |
+
wavvae_exp_name='wavvae',
|
72 |
+
dur_ckpt_path='duration_lm',
|
73 |
+
g2p_exp_name='g2p',
|
74 |
+
precision=torch.float16,
|
75 |
+
**kwargs
|
76 |
+
):
|
77 |
+
self.sr = 24000
|
78 |
+
self.fm = 8
|
79 |
+
if device is None:
|
80 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
81 |
+
self.device = device
|
82 |
+
self.precision = precision
|
83 |
+
|
84 |
+
# build models
|
85 |
+
self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
|
86 |
+
self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
|
87 |
+
self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
|
88 |
+
self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
|
89 |
+
self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
|
90 |
+
self.build_model(self.device)
|
91 |
+
|
92 |
+
# init text normalizer
|
93 |
+
self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
|
94 |
+
self.en_normalizer = EnNormalizer(overwrite_cache=False)
|
95 |
+
# loudness meter
|
96 |
+
self.loudness_meter = pyln.Meter(self.sr)
|
97 |
+
|
98 |
+
def build_model(self, device):
|
99 |
+
set_hparams(exp_name=self.dit_exp_name, print_hparams=False)
|
100 |
+
|
101 |
+
''' Load Dict '''
|
102 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
103 |
+
ling_dict = json.load(open(f"{current_dir}/utils/text_utils/dict.json", encoding='utf-8-sig'))
|
104 |
+
self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='<UNK>') for k in ['phone', 'tone']}
|
105 |
+
self.token_encoder = token_encoder = self.ling_dict['phone']
|
106 |
+
ph_dict_size = len(token_encoder)
|
107 |
+
|
108 |
+
''' Load Duration LM '''
|
109 |
+
from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
|
110 |
+
hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
|
111 |
+
hp_dur_model['frames_multiple'] = hparams['frames_multiple']
|
112 |
+
self.dur_model = ARDurPredictor(
|
113 |
+
hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
|
114 |
+
hp_dur_model['dur_model_layers'], ph_dict_size,
|
115 |
+
hp_dur_model['dur_code_size'],
|
116 |
+
use_rot_embed=hp_dur_model.get('use_rot_embed', False))
|
117 |
+
self.length_regulator = LengthRegulator()
|
118 |
+
load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
|
119 |
+
self.dur_model.eval()
|
120 |
+
self.dur_model.to(device)
|
121 |
+
|
122 |
+
''' Load Diffusion Transformer '''
|
123 |
+
from tts.modules.llm_dit.dit import Diffusion
|
124 |
+
self.dit = Diffusion()
|
125 |
+
load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
|
126 |
+
self.dit.eval()
|
127 |
+
self.dit.to(device)
|
128 |
+
self.cfg_mask_token_phone = 302 - 1
|
129 |
+
self.cfg_mask_token_tone = 32 - 1
|
130 |
+
|
131 |
+
''' Load Frontend LM '''
|
132 |
+
from tts.modules.aligner.whisper_small import Whisper
|
133 |
+
self.aligner_lm = Whisper()
|
134 |
+
load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
|
135 |
+
self.aligner_lm.eval()
|
136 |
+
self.aligner_lm.to(device)
|
137 |
+
self.kv_cache = None
|
138 |
+
self.hooks = None
|
139 |
+
|
140 |
+
''' Load G2P LM'''
|
141 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
142 |
+
g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
|
143 |
+
g2p_tokenizer.padding_side = "right"
|
144 |
+
self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
|
145 |
+
self.g2p_tokenizer = g2p_tokenizer
|
146 |
+
self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]
|
147 |
+
|
148 |
+
''' Wav VAE '''
|
149 |
+
self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
|
150 |
+
from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3
|
151 |
+
self.wavvae = WavVAE_V3(hparams=hp_wavvae)
|
152 |
+
if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
|
153 |
+
load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
|
154 |
+
self.has_vae_encoder = True
|
155 |
+
else:
|
156 |
+
load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
|
157 |
+
self.has_vae_encoder = False
|
158 |
+
self.wavvae.eval()
|
159 |
+
self.wavvae.to(device)
|
160 |
+
self.vae_stride = hp_wavvae.get('vae_stride', 4)
|
161 |
+
self.hop_size = hp_wavvae.get('hop_size', 4)
|
162 |
+
|
163 |
+
def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
|
164 |
+
wav_bytes = convert_to_wav_bytes(audio_bytes)
|
165 |
+
|
166 |
+
''' Load wav '''
|
167 |
+
wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
|
168 |
+
# Pad wav if necessary
|
169 |
+
ws = hparams['win_size']
|
170 |
+
if len(wav) % ws < ws - 1:
|
171 |
+
wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
|
172 |
+
wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
|
173 |
+
self.loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))
|
174 |
+
|
175 |
+
''' obtain alignments with aligner_lm '''
|
176 |
+
ph_ref, tone_ref, mel2ph_ref = align(self, wav)
|
177 |
+
|
178 |
+
with torch.inference_mode():
|
179 |
+
''' Forward WaveVAE to obtain: prompt latent '''
|
180 |
+
if self.has_vae_encoder:
|
181 |
+
wav = torch.FloatTensor(wav)[None].to(self.device)
|
182 |
+
vae_latent = self.wavvae.encode_latent(wav)
|
183 |
+
vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
|
184 |
+
else:
|
185 |
+
assert latent_file is not None, "Please provide latent_file in WaveVAE decoder-only mode"
|
186 |
+
vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
|
187 |
+
vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
|
188 |
+
|
189 |
+
''' Duration Prompting '''
|
190 |
+
self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None
|
191 |
+
incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)
|
192 |
+
|
193 |
+
return {
|
194 |
+
'ph_ref': ph_ref,
|
195 |
+
'tone_ref': tone_ref,
|
196 |
+
'mel2ph_ref': mel2ph_ref,
|
197 |
+
'vae_latent': vae_latent,
|
198 |
+
'incremental_state_dur_prompt': incremental_state_dur_prompt,
|
199 |
+
'ctx_dur_tokens': ctx_dur_tokens,
|
200 |
+
}
|
201 |
+
|
202 |
+
def forward(self, resource_context, input_text, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):
|
203 |
+
device = self.device
|
204 |
+
|
205 |
+
ph_ref = resource_context['ph_ref'].to(device)
|
206 |
+
tone_ref = resource_context['tone_ref'].to(device)
|
207 |
+
mel2ph_ref = resource_context['mel2ph_ref'].to(device)
|
208 |
+
vae_latent = resource_context['vae_latent'].to(device)
|
209 |
+
ctx_dur_tokens = resource_context['ctx_dur_tokens'].to(device)
|
210 |
+
incremental_state_dur_prompt = resource_context['incremental_state_dur_prompt']
|
211 |
+
|
212 |
+
with torch.inference_mode():
|
213 |
+
''' Generating '''
|
214 |
+
wav_pred_ = []
|
215 |
+
language_type = classify_language(input_text)
|
216 |
+
if language_type == 'en':
|
217 |
+
input_text = self.en_normalizer.normalize(input_text)
|
218 |
+
text_segs = chunk_text_english(input_text, max_chars=130)
|
219 |
+
else:
|
220 |
+
input_text = self.zh_normalizer.normalize(input_text)
|
221 |
+
text_segs = chunk_text_chinese(input_text, limit=60)
|
222 |
+
|
223 |
+
for seg_i, text in enumerate(text_segs):
|
224 |
+
''' G2P '''
|
225 |
+
ph_pred, tone_pred = g2p(self, text)
|
226 |
+
|
227 |
+
''' Duration Prediction '''
|
228 |
+
mel2ph_pred = dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first=seg_i==0, is_final=seg_i==len(text_segs)-1)
|
229 |
+
|
230 |
+
inputs = prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent)
|
231 |
+
# Speech dit inference
|
232 |
+
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
|
233 |
+
x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()
|
234 |
+
|
235 |
+
# WavVAE decode
|
236 |
+
x[:, :vae_latent.size(1)] = vae_latent
|
237 |
+
wav_pred = self.wavvae.decode(x)[0,0].to(torch.float32)
|
238 |
+
|
239 |
+
''' Post-processing '''
|
240 |
+
# Trim prompt wav
|
241 |
+
wav_pred = wav_pred[vae_latent.size(1)*self.vae_stride*self.hop_size:].cpu().numpy()
|
242 |
+
# Norm generated wav to prompt wav's level
|
243 |
+
meter = pyln.Meter(self.sr) # create BS.1770 meter
|
244 |
+
loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
|
245 |
+
wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt)
|
246 |
+
if np.abs(wav_pred).max() >= 1:
|
247 |
+
wav_pred = wav_pred / np.abs(wav_pred).max() * 0.95
|
248 |
+
|
249 |
+
# Apply hamming window
|
250 |
+
wav_pred_.append(wav_pred)
|
251 |
+
|
252 |
+
wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
|
253 |
+
return to_wav_bytes(wav_pred, self.sr)
|
254 |
+
|
255 |
+
|
256 |
+
if __name__ == '__main__':
|
257 |
+
parser = argparse.ArgumentParser()
|
258 |
+
parser.add_argument('--input_wav', type=str)
|
259 |
+
parser.add_argument('--input_text', type=str)
|
260 |
+
parser.add_argument('--output_dir', type=str)
|
261 |
+
parser.add_argument('--time_step', type=int, default=32, help='Inference steps of Diffusion Transformer')
|
262 |
+
parser.add_argument('--p_w', type=float, default=1.6, help='Intelligibility Weight')
|
263 |
+
parser.add_argument('--t_w', type=float, default=2.5, help='Similarity Weight')
|
264 |
+
args = parser.parse_args()
|
265 |
+
wav_path, input_text, out_path, time_step, p_w, t_w = args.input_wav, args.input_text, args.output_dir, args.time_step, args.p_w, args.t_w
|
266 |
+
|
267 |
+
infer_ins = MegaTTS3DiTInfer()
|
268 |
+
|
269 |
+
with open(wav_path, 'rb') as file:
|
270 |
+
file_content = file.read()
|
271 |
+
|
272 |
+
print(f"| Start processing {wav_path}+{input_text}")
|
273 |
+
resource_context = infer_ins.preprocess(file_content, latent_file=wav_path.replace('.wav', '.npy'))
|
274 |
+
wav_bytes = infer_ins.forward(resource_context, input_text, time_step=time_step, p_w=p_w, t_w=t_w)
|
275 |
+
|
276 |
+
print(f"| Saving results to {out_path}/[P]{input_text[:20]}.wav")
|
277 |
+
os.makedirs(out_path, exist_ok=True)
|
278 |
+
save_wav(wav_bytes, f'{out_path}/[P]{input_text[:20]}.wav')
|
tts/modules/aligner/whisper_small.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 OpenAI
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# Copyright (c) [2022] [OpenAI]
|
24 |
+
# Copyright (c) [2025] [Ziyue Jiang]
|
25 |
+
# SPDX-License-Identifier: MIT
|
26 |
+
# This file has been modified by Ziyue Jiang on 2025/03/19
|
27 |
+
# Original file was released under MIT, with the full license text # available at https://github.com/openai/whisper/blob/v20240930/LICENSE.
|
28 |
+
# This modified file is released under the same license.
|
29 |
+
|
30 |
+
from contextlib import contextmanager
|
31 |
+
from typing import Dict, Iterable, Optional, Tuple
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
import torch
|
35 |
+
import torch.nn.functional as F
|
36 |
+
from torch import Tensor, nn
|
37 |
+
|
38 |
+
from torch.nn.functional import scaled_dot_product_attention
|
39 |
+
SDPA_AVAILABLE = True
|
40 |
+
|
41 |
+
|
42 |
+
class LayerNorm(nn.LayerNorm):
|
43 |
+
def forward(self, x: Tensor) -> Tensor:
|
44 |
+
return super().forward(x.float()).type(x.dtype)
|
45 |
+
|
46 |
+
|
47 |
+
class Linear(nn.Linear):
|
48 |
+
def forward(self, x: Tensor) -> Tensor:
|
49 |
+
return F.linear(
|
50 |
+
x,
|
51 |
+
self.weight.to(x.dtype),
|
52 |
+
None if self.bias is None else self.bias.to(x.dtype),
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
class Conv1d(nn.Conv1d):
|
57 |
+
def _conv_forward(
|
58 |
+
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
59 |
+
) -> Tensor:
|
60 |
+
return super()._conv_forward(
|
61 |
+
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
def sinusoids(length, channels, max_timescale=10000):
|
66 |
+
"""Returns sinusoids for positional embedding"""
|
67 |
+
assert channels % 2 == 0
|
68 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
69 |
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
70 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
71 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
72 |
+
|
73 |
+
|
74 |
+
@contextmanager
|
75 |
+
def disable_sdpa():
|
76 |
+
prev_state = MultiHeadAttention.use_sdpa
|
77 |
+
try:
|
78 |
+
MultiHeadAttention.use_sdpa = False
|
79 |
+
yield
|
80 |
+
finally:
|
81 |
+
MultiHeadAttention.use_sdpa = prev_state
|
82 |
+
|
83 |
+
|
84 |
+
class MultiHeadAttention(nn.Module):
|
85 |
+
use_sdpa = True
|
86 |
+
|
87 |
+
def __init__(self, n_state: int, n_head: int):
|
88 |
+
super().__init__()
|
89 |
+
self.n_head = n_head
|
90 |
+
self.query = Linear(n_state, n_state)
|
91 |
+
self.key = Linear(n_state, n_state, bias=False)
|
92 |
+
self.value = Linear(n_state, n_state)
|
93 |
+
self.out = Linear(n_state, n_state)
|
94 |
+
|
95 |
+
def forward(
|
96 |
+
self,
|
97 |
+
x: Tensor,
|
98 |
+
xa: Optional[Tensor] = None,
|
99 |
+
mask: Optional[Tensor] = None,
|
100 |
+
kv_cache: Optional[dict] = None,
|
101 |
+
casual: Optional[bool] = None
|
102 |
+
):
|
103 |
+
q = self.query(x)
|
104 |
+
|
105 |
+
if kv_cache is None or xa is None or self.key not in kv_cache:
|
106 |
+
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
107 |
+
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
108 |
+
k = self.key(x if xa is None else xa)
|
109 |
+
v = self.value(x if xa is None else xa)
|
110 |
+
else:
|
111 |
+
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
112 |
+
k = kv_cache[self.key]
|
113 |
+
v = kv_cache[self.value]
|
114 |
+
|
115 |
+
wv = self.qkv_attention(q, k, v, mask, casual)
|
116 |
+
return self.out(wv)
|
117 |
+
|
118 |
+
def qkv_attention(
|
119 |
+
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, casual: Optional[bool] = None
|
120 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
121 |
+
n_batch, n_ctx, n_state = q.shape
|
122 |
+
scale = (n_state // self.n_head) ** -0.25
|
123 |
+
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
124 |
+
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
125 |
+
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
126 |
+
|
127 |
+
a = scaled_dot_product_attention(
|
128 |
+
q, k, v, is_causal=casual and n_ctx > 1, attn_mask=mask[:, None, None, :] if mask is not None else None
|
129 |
+
)
|
130 |
+
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
|
131 |
+
return out
|
132 |
+
|
133 |
+
|
134 |
+
class ResidualAttentionBlock(nn.Module):
|
135 |
+
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
136 |
+
super().__init__()
|
137 |
+
|
138 |
+
self.attn = MultiHeadAttention(n_state, n_head)
|
139 |
+
self.attn_ln = LayerNorm(n_state)
|
140 |
+
|
141 |
+
self.cross_attn = (
|
142 |
+
MultiHeadAttention(n_state, n_head) if cross_attention else None
|
143 |
+
)
|
144 |
+
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
145 |
+
|
146 |
+
n_mlp = n_state * 4
|
147 |
+
self.mlp = nn.Sequential(
|
148 |
+
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
149 |
+
)
|
150 |
+
self.mlp_ln = LayerNorm(n_state)
|
151 |
+
|
152 |
+
def forward(
|
153 |
+
self,
|
154 |
+
x: Tensor,
|
155 |
+
xa: Optional[Tensor] = None,
|
156 |
+
mask: Optional[Tensor] = None,
|
157 |
+
kv_cache: Optional[dict] = None,
|
158 |
+
casual: Optional[bool] = None,
|
159 |
+
):
|
160 |
+
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, casual=casual)
|
161 |
+
if self.cross_attn:
|
162 |
+
# TODO: Cross attention mask
|
163 |
+
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, casual=False)
|
164 |
+
x = x + self.mlp(self.mlp_ln(x))
|
165 |
+
return x
|
166 |
+
|
167 |
+
|
168 |
+
class AudioEncoder(nn.Module):
|
169 |
+
def __init__(
|
170 |
+
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
171 |
+
):
|
172 |
+
super().__init__()
|
173 |
+
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
174 |
+
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
175 |
+
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
176 |
+
|
177 |
+
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
178 |
+
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
179 |
+
)
|
180 |
+
self.ln_post = LayerNorm(n_state)
|
181 |
+
|
182 |
+
def forward(self, x: Tensor, attn_mask: Tensor):
|
183 |
+
"""
|
184 |
+
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
185 |
+
the mel spectrogram of the audio
|
186 |
+
"""
|
187 |
+
x = F.gelu(self.conv1(x))
|
188 |
+
x = F.gelu(self.conv2(x))
|
189 |
+
x = x.permute(0, 2, 1)
|
190 |
+
|
191 |
+
# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
192 |
+
x = (x + self.positional_embedding[:x.size(1)]).to(x.dtype)
|
193 |
+
|
194 |
+
for block in self.blocks:
|
195 |
+
x = block(x, mask=attn_mask, casual=False)
|
196 |
+
|
197 |
+
x = self.ln_post(x)
|
198 |
+
return x
|
199 |
+
|
200 |
+
|
201 |
+
class TextDecoder(nn.Module):
|
202 |
+
def __init__(
|
203 |
+
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
204 |
+
):
|
205 |
+
super().__init__()
|
206 |
+
|
207 |
+
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
208 |
+
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
209 |
+
|
210 |
+
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
211 |
+
[
|
212 |
+
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
213 |
+
for _ in range(n_layer)
|
214 |
+
]
|
215 |
+
)
|
216 |
+
self.ln = LayerNorm(n_state)
|
217 |
+
|
218 |
+
self.out_proj = nn.Linear(n_state, n_vocab)
|
219 |
+
|
220 |
+
def forward(self, x: Tensor, attn_mask: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
221 |
+
"""
|
222 |
+
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
223 |
+
the text tokens
|
224 |
+
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
225 |
+
the encoded audio features to be attended on
|
226 |
+
"""
|
227 |
+
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
228 |
+
x = (
|
229 |
+
self.token_embedding(x)
|
230 |
+
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
231 |
+
)
|
232 |
+
x = x.to(xa.dtype)
|
233 |
+
|
234 |
+
for block in self.blocks:
|
235 |
+
x = block(x, xa, mask=attn_mask, kv_cache=kv_cache, casual=True)
|
236 |
+
|
237 |
+
x = self.ln(x)
|
238 |
+
# logits = (
|
239 |
+
# x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
240 |
+
# ).float()
|
241 |
+
logits = self.out_proj(x)
|
242 |
+
|
243 |
+
return logits
|
244 |
+
|
245 |
+
|
246 |
+
class Whisper(nn.Module):
|
247 |
+
def __init__(self):
|
248 |
+
super().__init__()
|
249 |
+
self.n_vocab = 6800
|
250 |
+
self.n_text_layer = 6
|
251 |
+
self.n_text_head = 8
|
252 |
+
self.n_text_ctx = 2048
|
253 |
+
|
254 |
+
self.encoder = AudioEncoder(
|
255 |
+
n_mels=80, n_ctx=3000, n_state=512, n_head=8, n_layer=6,
|
256 |
+
)
|
257 |
+
self.decoder = TextDecoder(
|
258 |
+
n_vocab=6800, n_ctx=2048, n_state=512, n_head=8, n_layer=6,
|
259 |
+
)
|
260 |
+
|
261 |
+
def embed_audio(self, mel: torch.Tensor):
|
262 |
+
return self.encoder(mel, None)
|
263 |
+
|
264 |
+
def logits(self, tokens, audio_features, kv_cache=None):
|
265 |
+
return self.decoder(tokens, None, audio_features, kv_cache=kv_cache)
|
266 |
+
|
267 |
+
def forward(
|
268 |
+
self, mel, mel_len, token, token_len
|
269 |
+
) -> Dict[str, torch.Tensor]:
|
270 |
+
attn_mask_enc = self.sequence_mask(mel_len//2, device=mel.device) > 0
|
271 |
+
attn_mask_dec = self.sequence_mask(token_len, device=mel.device) > 0
|
272 |
+
return self.decoder(token, attn_mask_dec, self.encoder(mel, attn_mask_enc))
|
273 |
+
|
274 |
+
@property
|
275 |
+
def device(self):
|
276 |
+
return next(self.parameters()).device
|
277 |
+
|
278 |
+
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
279 |
+
"""
|
280 |
+
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
281 |
+
tensors calculated for the previous positions. This method returns a dictionary that stores
|
282 |
+
all caches, and the necessary hooks for the key and value projection modules that save the
|
283 |
+
intermediate tensors to be reused during later calculations.
|
284 |
+
|
285 |
+
Returns
|
286 |
+
-------
|
287 |
+
cache : Dict[nn.Module, torch.Tensor]
|
288 |
+
A dictionary object mapping the key/value projection modules to its cache
|
289 |
+
hooks : List[RemovableHandle]
|
290 |
+
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
291 |
+
"""
|
292 |
+
cache = {**cache} if cache is not None else {}
|
293 |
+
hooks = []
|
294 |
+
|
295 |
+
def save_to_cache(module, _, output):
|
296 |
+
if module not in cache or output.shape[1] > self.n_text_ctx:
|
297 |
+
# save as-is, for the first token or cross attention
|
298 |
+
cache[module] = output
|
299 |
+
else:
|
300 |
+
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
301 |
+
return cache[module]
|
302 |
+
|
303 |
+
def install_hooks(layer: nn.Module):
|
304 |
+
if isinstance(layer, MultiHeadAttention):
|
305 |
+
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
306 |
+
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
307 |
+
|
308 |
+
self.decoder.apply(install_hooks)
|
309 |
+
return cache, hooks
|
310 |
+
|
311 |
+
def sequence_mask(self, seq_lens, max_len=None, device='cpu'):
|
312 |
+
b = seq_lens.shape[0]
|
313 |
+
if max_len is None:
|
314 |
+
max_len = seq_lens.max()
|
315 |
+
mask = torch.arange(max_len).unsqueeze(0).to(device) # [1, t]
|
316 |
+
mask = mask < (seq_lens.unsqueeze(1)) # [1, t] + [b, 1] = [b, t]
|
317 |
+
mask = mask.float()
|
318 |
+
return mask
|
tts/modules/ar_dur/ar_dur_predictor.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import random
|
16 |
+
from copy import deepcopy
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import Linear
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
|
25 |
+
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
|
26 |
+
from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
|
27 |
+
from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
|
28 |
+
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
|
29 |
+
|
30 |
+
FS_ENCODERS = {
|
31 |
+
'rel_fft': lambda hp, dict_size: RelTransformerEncoder(
|
32 |
+
dict_size, hp['hidden_size'], hp['hidden_size'],
|
33 |
+
hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
|
34 |
+
hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']),
|
35 |
+
}
|
36 |
+
|
37 |
+
def fill_with_neg_inf2(t):
|
38 |
+
"""FP16-compatible function that fills a tensor with -inf."""
|
39 |
+
return t.float().fill_(-1e8).type_as(t)
|
40 |
+
|
41 |
+
def expand_states(h, mel2token):
|
42 |
+
h = F.pad(h, [0, 0, 1, 0])
|
43 |
+
mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
|
44 |
+
h = torch.gather(h, 1, mel2token_) # [B, T, H]
|
45 |
+
return h
|
46 |
+
|
47 |
+
|
48 |
+
class CodePredictor(nn.Module):
|
49 |
+
def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
|
50 |
+
super().__init__()
|
51 |
+
self.hparams = deepcopy(hparams)
|
52 |
+
self.hparams['hidden_size'] = hidden_size
|
53 |
+
self.hidden_size = hidden_size
|
54 |
+
char_dict_size = hparams.get('char_dict_size', 4000)
|
55 |
+
if not hparams.get('lm_use_enc'):
|
56 |
+
self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
|
57 |
+
if hparams.get('mega_use_char', True):
|
58 |
+
self.char_encoder = nn.Embedding(char_dict_size,
|
59 |
+
self.hidden_size, padding_idx=0)
|
60 |
+
else:
|
61 |
+
self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
|
62 |
+
if hparams.get('mega_use_char', True):
|
63 |
+
self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
|
64 |
+
if hparams['use_ph_pos_embed']:
|
65 |
+
self.ph_pos_embed = PosEmb(self.hidden_size)
|
66 |
+
|
67 |
+
self.char_empty_embed = nn.Embedding(1, self.hidden_size)
|
68 |
+
if hparams.get('use_bert_input'):
|
69 |
+
self.bert_input_proj = nn.Linear(768, self.hidden_size)
|
70 |
+
self.ling_label_embed_layers = nn.ModuleDict()
|
71 |
+
for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
|
72 |
+
self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)
|
73 |
+
|
74 |
+
self.dec_hidden_size = dec_hidden_size
|
75 |
+
self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
|
76 |
+
self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
|
77 |
+
self.use_pos_embed = hparams.get('use_pos_embed', False)
|
78 |
+
if self.use_pos_embed:
|
79 |
+
self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
|
80 |
+
self.use_post_ln = hparams.get('use_post_ln', False)
|
81 |
+
self.layers = None
|
82 |
+
if not self.use_post_ln:
|
83 |
+
self.layer_norm = LayerNorm(dec_hidden_size)
|
84 |
+
self.code_size = code_size
|
85 |
+
self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)
|
86 |
+
|
87 |
+
def forward_ling_encoder(
|
88 |
+
self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
|
89 |
+
ph_tokens = txt_tokens
|
90 |
+
hparams = self.hparams
|
91 |
+
ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1]
|
92 |
+
x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)
|
93 |
+
|
94 |
+
# enc_ph
|
95 |
+
if not hparams.get('lm_use_enc'):
|
96 |
+
x_ph = self.encoder(ph_tokens)
|
97 |
+
x_ph = x_ph + sum(
|
98 |
+
[self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
|
99 |
+
if len(hparams['ling_labels']) > 0 else 0
|
100 |
+
x_ph = x_ph + x_spk
|
101 |
+
else:
|
102 |
+
# enc_ph
|
103 |
+
ph_enc_oembed = sum(
|
104 |
+
[self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
|
105 |
+
if len(hparams['ling_labels']) > 0 else 0
|
106 |
+
ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
|
107 |
+
torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
|
108 |
+
ph_enc_oembed = ph_enc_oembed + x_spk
|
109 |
+
ph_enc_oembed = ph_enc_oembed * ph_nonpadding
|
110 |
+
x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)
|
111 |
+
|
112 |
+
# enc_char
|
113 |
+
if char_tokens is not None and ph2char is not None:
|
114 |
+
char_nonpadding = (char_tokens > 0).float()[:, :, None]
|
115 |
+
x_char = self.char_encoder(char_tokens)
|
116 |
+
empty_char = (ph2char > 100000).long()
|
117 |
+
ph2char = ph2char * (1 - empty_char)
|
118 |
+
x_char_phlevel = \
|
119 |
+
expand_states(x_char * char_nonpadding, ph2char) \
|
120 |
+
* (1 - empty_char)[..., None] + \
|
121 |
+
self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
|
122 |
+
else:
|
123 |
+
x_char_phlevel = 0
|
124 |
+
# x_ling
|
125 |
+
x_ling = x_ph + x_char_phlevel
|
126 |
+
x_ling = x_ling * ph_nonpadding
|
127 |
+
x_ling = self.enc_proj(x_ling)
|
128 |
+
return x_ling
|
129 |
+
|
130 |
+
def sample_one_step(self, vq_pred):
|
131 |
+
hparams = self.hparams
|
132 |
+
if hparams.get('infer_top_k'):
|
133 |
+
top_k = hparams.get('infer_top_k')
|
134 |
+
temperature = hparams.get('infer_temperature', 1)
|
135 |
+
vq_pred = vq_pred[:, -1] / temperature
|
136 |
+
# optionally crop the logits to only the top k options
|
137 |
+
if top_k is not None:
|
138 |
+
v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1)))
|
139 |
+
vq_pred[vq_pred < v[:, [-1]]] = -float('Inf')
|
140 |
+
# apply softmax to convert logits to (normalized) probabilities
|
141 |
+
probs = F.softmax(vq_pred, dim=-1)
|
142 |
+
# sample from the distribution
|
143 |
+
vq_pred = torch.multinomial(probs, num_samples=1)
|
144 |
+
else:
|
145 |
+
vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1)
|
146 |
+
return vq_pred
|
147 |
+
|
148 |
+
def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
|
149 |
+
# add spk embed
|
150 |
+
style_embed = 0
|
151 |
+
if self.hparams['use_spk_embed']:
|
152 |
+
style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :]
|
153 |
+
if self.hparams['use_spk_id']:
|
154 |
+
style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :]
|
155 |
+
if self.hparams['use_spk_enc']:
|
156 |
+
style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :]
|
157 |
+
return style_embed
|
158 |
+
|
159 |
+
def buffered_future_mask(self, tensor):
|
160 |
+
dim = tensor.size(0)
|
161 |
+
if (
|
162 |
+
not hasattr(self, '_future_mask')
|
163 |
+
or self._future_mask is None
|
164 |
+
or self._future_mask.device != tensor.device
|
165 |
+
or self._future_mask.size(0) < dim
|
166 |
+
):
|
167 |
+
self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1)
|
168 |
+
return self._future_mask[:dim, :dim]
|
169 |
+
|
170 |
+
|
171 |
+
class ARDurPredictor(CodePredictor):
|
172 |
+
def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True,
|
173 |
+
op_version=1):
|
174 |
+
super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
|
175 |
+
self.use_rot_embed = use_rot_embed
|
176 |
+
bias = hparams.get('lm_bias', True)
|
177 |
+
if self.use_rot_embed:
|
178 |
+
self.layers = nn.ModuleList([])
|
179 |
+
self.layers.extend([
|
180 |
+
RotTransformerDecoderLayer(
|
181 |
+
dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
|
182 |
+
post_ln=self.use_post_ln, op_version=op_version, bias=bias)
|
183 |
+
for _ in range(lm_num_layers)
|
184 |
+
])
|
185 |
+
if hparams['dur_model_type'] == 'ar_mse':
|
186 |
+
self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
|
187 |
+
else:
|
188 |
+
self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)
|
189 |
+
|
190 |
+
def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
|
191 |
+
prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
|
192 |
+
incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
|
193 |
+
prompt_length=None, cache_size=20, streaming=False):
|
194 |
+
x = self.code_emb(prev_code)
|
195 |
+
if x_ling is None:
|
196 |
+
x_ling = self.forward_ling_encoder(
|
197 |
+
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
|
198 |
+
x_ling = x_ling.flatten(0, 1)
|
199 |
+
txt_tokens = txt_tokens.flatten(0, 1)
|
200 |
+
x_ling = x_ling[txt_tokens > 0][None]
|
201 |
+
|
202 |
+
# run decoder
|
203 |
+
self_attn_padding_mask = None
|
204 |
+
if self.use_pos_embed:
|
205 |
+
positions = self.embed_positions(
|
206 |
+
prev_code,
|
207 |
+
incremental_state=incremental_state
|
208 |
+
)
|
209 |
+
if incremental_state is not None:
|
210 |
+
x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]]
|
211 |
+
if spk_pos_ids_flat is not None:
|
212 |
+
spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]]
|
213 |
+
x = x[:, -1:]
|
214 |
+
if self.use_pos_embed:
|
215 |
+
positions = positions[:, -1:]
|
216 |
+
if streaming:
|
217 |
+
# Shift Pos: query pos is min(cache_size, idx)
|
218 |
+
spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device),
|
219 |
+
spk_pos_ids_flat)
|
220 |
+
|
221 |
+
# # B x T x C -> T x B x C
|
222 |
+
if self.use_pos_embed:
|
223 |
+
x = x + positions
|
224 |
+
x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
|
225 |
+
T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
|
226 |
+
x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
|
227 |
+
x = x + x_ling
|
228 |
+
x = x.transpose(0, 1)
|
229 |
+
|
230 |
+
for idx, layer in enumerate(self.layers):
|
231 |
+
if incremental_state is None:
|
232 |
+
self_attn_mask = self.buffered_future_mask(x)
|
233 |
+
if attn_mask is not None:
|
234 |
+
self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
|
235 |
+
self_attn_mask = self_attn_mask.clamp_min(-1e8)
|
236 |
+
else:
|
237 |
+
self_attn_mask = None
|
238 |
+
|
239 |
+
x, attn_weights = layer(
|
240 |
+
x,
|
241 |
+
incremental_state=incremental_state,
|
242 |
+
self_attn_mask=self_attn_mask,
|
243 |
+
self_attn_padding_mask=self_attn_padding_mask,
|
244 |
+
spk_pos_ids_flat=spk_pos_ids_flat
|
245 |
+
)
|
246 |
+
|
247 |
+
if streaming and incremental_state != {}:
|
248 |
+
for k, v in incremental_state.items():
|
249 |
+
if 'attn_state' in k:
|
250 |
+
prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value']
|
251 |
+
cur_length = prev_key.shape[2]
|
252 |
+
if cur_length - prompt_length > cache_size:
|
253 |
+
prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2)
|
254 |
+
prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]),
|
255 |
+
dim=2)
|
256 |
+
incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value
|
257 |
+
|
258 |
+
if not self.use_post_ln:
|
259 |
+
x = self.layer_norm(x)
|
260 |
+
# T x B x C -> B x T x C
|
261 |
+
x = x.transpose(0, 1)
|
262 |
+
x = self.project_out_dim(x)
|
263 |
+
return x
|
264 |
+
|
265 |
+
def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
|
266 |
+
spk_id=None, spk_embed=None, mels_timbre=None,
|
267 |
+
incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
|
268 |
+
first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
|
269 |
+
if incremental_state is None:
|
270 |
+
incremental_state = {}
|
271 |
+
x_ling = self.forward_ling_encoder(
|
272 |
+
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
|
273 |
+
spk_id, spk_embed, mels_timbre)
|
274 |
+
x_ling = x_ling.flatten(0, 1)
|
275 |
+
txt_tokens_ori = txt_tokens
|
276 |
+
txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
|
277 |
+
x_ling = x_ling[txt_tokens > 0][None]
|
278 |
+
txt_tokens = txt_tokens[txt_tokens > 0][None]
|
279 |
+
|
280 |
+
decoded = torch.zeros_like(txt_tokens)
|
281 |
+
decoded = F.pad(decoded, [1, 0], value=self.code_size + 1)
|
282 |
+
if incremental_state != {}:
|
283 |
+
if first_decoder_inp is None:
|
284 |
+
assert ctx_vqcodes is not None
|
285 |
+
decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
|
286 |
+
ctx_vqcodes = None
|
287 |
+
else:
|
288 |
+
decoded[:, :1] = first_decoder_inp
|
289 |
+
probs = []
|
290 |
+
for step in range(decoded.shape[1] - 1):
|
291 |
+
vq_pred = self(txt_tokens, None, None, None, None,
|
292 |
+
decoded[:, :step + 1], None, None, None,
|
293 |
+
incremental_state=incremental_state, x_ling=x_ling,
|
294 |
+
spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
|
295 |
+
probs.append(vq_pred.cpu())
|
296 |
+
if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
|
297 |
+
if self.hparams['dur_model_type'] == 'ar_mse':
|
298 |
+
d = vq_pred[:, -1, 0]
|
299 |
+
if dur_disturb > 0 and step >= 1:
|
300 |
+
if random.random() > 0.5:
|
301 |
+
d = d * (1 + random.random() * dur_disturb)
|
302 |
+
else:
|
303 |
+
d = d / (1 + random.random() * dur_disturb)
|
304 |
+
d = torch.clamp_max(d, self.code_size - 1)
|
305 |
+
vq_pred = torch.round(d).long()
|
306 |
+
else:
|
307 |
+
vq_pred = self.sample_one_step(vq_pred)
|
308 |
+
decoded[:, step + 1] = torch.clamp_min(vq_pred, 1)
|
309 |
+
if step == 0:
|
310 |
+
decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min)
|
311 |
+
else:
|
312 |
+
decoded[:, step + 1] = ctx_vqcodes[:, step]
|
313 |
+
decoded = decoded[:, 1:]
|
314 |
+
decoded_2d = torch.zeros_like(txt_tokens_ori)
|
315 |
+
decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded
|
316 |
+
if return_state:
|
317 |
+
return decoded_2d, incremental_state
|
318 |
+
if return_probs:
|
319 |
+
return decoded_2d, torch.cat(probs, 1)
|
320 |
+
return decoded_2d
|
321 |
+
|
322 |
+
def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
|
323 |
+
spk_id=None, spk_embed=None, mels_timbre=None,
|
324 |
+
incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
|
325 |
+
**kwargs):
|
326 |
+
if incremental_state is None:
|
327 |
+
incremental_state = {}
|
328 |
+
x_ling = self.forward_ling_encoder(
|
329 |
+
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
|
330 |
+
spk_id, spk_embed, mels_timbre)
|
331 |
+
x_ling = x_ling.flatten(0, 1)
|
332 |
+
txt_tokens_ori = txt_tokens
|
333 |
+
txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
|
334 |
+
x_ling = x_ling[txt_tokens > 0][None]
|
335 |
+
txt_tokens = txt_tokens[txt_tokens > 0][None]
|
336 |
+
|
337 |
+
vq_decoded = torch.zeros_like(txt_tokens)
|
338 |
+
vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1)
|
339 |
+
if incremental_state != {}:
|
340 |
+
assert ctx_vqcodes is not None
|
341 |
+
vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
|
342 |
+
ctx_vqcodes = None
|
343 |
+
prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2]
|
344 |
+
for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
|
345 |
+
vq_pred = self(txt_tokens, None, None, None, None,
|
346 |
+
vq_decoded[:, :step + 1], None, None, None,
|
347 |
+
incremental_state=incremental_state, x_ling=x_ling,
|
348 |
+
spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
|
349 |
+
if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
|
350 |
+
if self.hparams['dur_model_type'] == 'ar_mse':
|
351 |
+
vq_pred = torch.round(vq_pred[:, -1, 0]).long()
|
352 |
+
else:
|
353 |
+
vq_pred = self.sample_one_step(vq_pred)
|
354 |
+
vq_decoded[:, step + 1] = vq_pred
|
355 |
+
else:
|
356 |
+
vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
|
357 |
+
vq_decoded = vq_decoded[:, 1:]
|
358 |
+
vq_decoded_2d = torch.zeros_like(txt_tokens_ori)
|
359 |
+
vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded
|
360 |
+
if return_state:
|
361 |
+
return vq_decoded_2d, incremental_state
|
362 |
+
return vq_decoded_2d
|
tts/modules/ar_dur/commons/layers.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
|
19 |
+
class LayerNorm(torch.nn.LayerNorm):
|
20 |
+
"""Layer normalization module.
|
21 |
+
:param int nout: output dim size
|
22 |
+
:param int dim: dimension to be normalized
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, nout, dim=-1, eps=1e-5):
|
26 |
+
"""Construct an LayerNorm object."""
|
27 |
+
super(LayerNorm, self).__init__(nout, eps=eps)
|
28 |
+
self.dim = dim
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
"""Apply layer normalization.
|
32 |
+
:param torch.Tensor x: input tensor
|
33 |
+
:return: layer normalized tensor
|
34 |
+
:rtype torch.Tensor
|
35 |
+
"""
|
36 |
+
if self.dim == -1:
|
37 |
+
return super(LayerNorm, self).forward(x)
|
38 |
+
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
39 |
+
|
40 |
+
|
41 |
+
class Reshape(nn.Module):
|
42 |
+
def __init__(self, *args):
|
43 |
+
super(Reshape, self).__init__()
|
44 |
+
self.shape = args
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
return x.view(self.shape)
|
48 |
+
|
49 |
+
|
50 |
+
class Permute(nn.Module):
|
51 |
+
def __init__(self, *args):
|
52 |
+
super(Permute, self).__init__()
|
53 |
+
self.args = args
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
return x.permute(self.args)
|
57 |
+
|
58 |
+
|
59 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
|
60 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
61 |
+
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
62 |
+
if padding_idx is not None:
|
63 |
+
nn.init.constant_(m.weight[padding_idx], 0)
|
64 |
+
return m
|
tts/modules/ar_dur/commons/nar_tts_modules.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
|
23 |
+
class LengthRegulator(torch.nn.Module):
|
24 |
+
def __init__(self, pad_value=0.0):
|
25 |
+
super(LengthRegulator, self).__init__()
|
26 |
+
self.pad_value = pad_value
|
27 |
+
|
28 |
+
def forward(self, dur, dur_padding=None, alpha=1.0):
|
29 |
+
"""
|
30 |
+
Example (no batch dim version):
|
31 |
+
1. dur = [2,2,3]
|
32 |
+
2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
|
33 |
+
3. token_mask = [[1,1,0,0,0,0,0],
|
34 |
+
[0,0,1,1,0,0,0],
|
35 |
+
[0,0,0,0,1,1,1]]
|
36 |
+
4. token_idx * token_mask = [[1,1,0,0,0,0,0],
|
37 |
+
[0,0,2,2,0,0,0],
|
38 |
+
[0,0,0,0,3,3,3]]
|
39 |
+
5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
|
40 |
+
|
41 |
+
:param dur: Batch of durations of each frame (B, T_txt)
|
42 |
+
:param dur_padding: Batch of padding of each frame (B, T_txt)
|
43 |
+
:param alpha: duration rescale coefficient
|
44 |
+
:return:
|
45 |
+
mel2ph (B, T_speech)
|
46 |
+
assert alpha > 0
|
47 |
+
"""
|
48 |
+
dur = torch.round(dur.float() * alpha).long()
|
49 |
+
if dur_padding is not None:
|
50 |
+
dur = dur * (1 - dur_padding.long())
|
51 |
+
token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
|
52 |
+
dur_cumsum = torch.cumsum(dur, 1)
|
53 |
+
dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
|
54 |
+
|
55 |
+
pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
|
56 |
+
token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
|
57 |
+
mel2token = (token_idx * token_mask.long()).sum(1)
|
58 |
+
return mel2token
|
59 |
+
|
60 |
+
|
61 |
+
class PosEmb(nn.Module):
|
62 |
+
def __init__(self, dim):
|
63 |
+
super().__init__()
|
64 |
+
self.dim = dim
|
65 |
+
half_dim = self.dim // 2
|
66 |
+
emb = math.log(10000) / (half_dim - 1)
|
67 |
+
emb = torch.exp(torch.arange(half_dim) * -emb)
|
68 |
+
self.emb = emb # TODO
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
emb = x[:, :, None] * self.emb[None, None, :].to(x.device)
|
72 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
73 |
+
return emb
|
tts/modules/ar_dur/commons/rel_transformer.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
from torch.nn import functional as F
|
19 |
+
|
20 |
+
from tts.modules.ar_dur.commons.layers import Embedding
|
21 |
+
|
22 |
+
|
23 |
+
def convert_pad_shape(pad_shape):
|
24 |
+
l = pad_shape[::-1]
|
25 |
+
pad_shape = [item for sublist in l for item in sublist]
|
26 |
+
return pad_shape
|
27 |
+
|
28 |
+
|
29 |
+
def shift_1d(x):
|
30 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
def sequence_mask(length, max_length=None):
|
35 |
+
if max_length is None:
|
36 |
+
max_length = length.max()
|
37 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
38 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
39 |
+
|
40 |
+
|
41 |
+
class Encoder(nn.Module):
|
42 |
+
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
|
43 |
+
window_size=None, block_length=None, pre_ln=False, **kwargs):
|
44 |
+
super().__init__()
|
45 |
+
self.hidden_channels = hidden_channels
|
46 |
+
self.filter_channels = filter_channels
|
47 |
+
self.n_heads = n_heads
|
48 |
+
self.n_layers = n_layers
|
49 |
+
self.kernel_size = kernel_size
|
50 |
+
self.p_dropout = p_dropout
|
51 |
+
self.window_size = window_size
|
52 |
+
self.block_length = block_length
|
53 |
+
self.pre_ln = pre_ln
|
54 |
+
|
55 |
+
self.drop = nn.Dropout(p_dropout)
|
56 |
+
self.attn_layers = nn.ModuleList()
|
57 |
+
self.norm_layers_1 = nn.ModuleList()
|
58 |
+
self.ffn_layers = nn.ModuleList()
|
59 |
+
self.norm_layers_2 = nn.ModuleList()
|
60 |
+
for i in range(self.n_layers):
|
61 |
+
self.attn_layers.append(
|
62 |
+
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
|
63 |
+
p_dropout=p_dropout, block_length=block_length))
|
64 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
65 |
+
self.ffn_layers.append(
|
66 |
+
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
|
67 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
68 |
+
if pre_ln:
|
69 |
+
self.last_ln = LayerNorm(hidden_channels)
|
70 |
+
|
71 |
+
def forward(self, x, x_mask, attn_mask=1):
|
72 |
+
if isinstance(attn_mask, torch.Tensor):
|
73 |
+
attn_mask = attn_mask[:, None]
|
74 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
|
75 |
+
for i in range(self.n_layers):
|
76 |
+
x = x * x_mask
|
77 |
+
x_ = x
|
78 |
+
if self.pre_ln:
|
79 |
+
x = self.norm_layers_1[i](x)
|
80 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
81 |
+
y = self.drop(y)
|
82 |
+
x = x_ + y
|
83 |
+
if not self.pre_ln:
|
84 |
+
x = self.norm_layers_1[i](x)
|
85 |
+
|
86 |
+
x_ = x
|
87 |
+
if self.pre_ln:
|
88 |
+
x = self.norm_layers_2[i](x)
|
89 |
+
y = self.ffn_layers[i](x, x_mask)
|
90 |
+
y = self.drop(y)
|
91 |
+
x = x_ + y
|
92 |
+
if not self.pre_ln:
|
93 |
+
x = self.norm_layers_2[i](x)
|
94 |
+
if self.pre_ln:
|
95 |
+
x = self.last_ln(x)
|
96 |
+
x = x * x_mask
|
97 |
+
return x
|
98 |
+
|
99 |
+
|
100 |
+
class MultiHeadAttention(nn.Module):
|
101 |
+
def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
|
102 |
+
block_length=None, proximal_bias=False, proximal_init=False):
|
103 |
+
super().__init__()
|
104 |
+
assert channels % n_heads == 0
|
105 |
+
|
106 |
+
self.channels = channels
|
107 |
+
self.out_channels = out_channels
|
108 |
+
self.n_heads = n_heads
|
109 |
+
self.window_size = window_size
|
110 |
+
self.heads_share = heads_share
|
111 |
+
self.block_length = block_length
|
112 |
+
self.proximal_bias = proximal_bias
|
113 |
+
self.p_dropout = p_dropout
|
114 |
+
self.attn = None
|
115 |
+
|
116 |
+
self.k_channels = channels // n_heads
|
117 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
118 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
119 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
120 |
+
if window_size is not None:
|
121 |
+
n_heads_rel = 1 if heads_share else n_heads
|
122 |
+
rel_stddev = self.k_channels ** -0.5
|
123 |
+
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
124 |
+
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
125 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
126 |
+
self.drop = nn.Dropout(p_dropout)
|
127 |
+
|
128 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
129 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
130 |
+
if proximal_init:
|
131 |
+
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
|
132 |
+
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
|
133 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
134 |
+
|
135 |
+
def forward(self, x, c, attn_mask=None):
|
136 |
+
q = self.conv_q(x)
|
137 |
+
k = self.conv_k(c)
|
138 |
+
v = self.conv_v(c)
|
139 |
+
|
140 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
141 |
+
|
142 |
+
x = self.conv_o(x)
|
143 |
+
return x
|
144 |
+
|
145 |
+
def attention(self, query, key, value, mask=None):
|
146 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
147 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
148 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
149 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
150 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
151 |
+
|
152 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
|
153 |
+
if self.window_size is not None:
|
154 |
+
assert t_s == t_t, "Relative attention is only available for self-attention."
|
155 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
156 |
+
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
|
157 |
+
rel_logits = self._relative_position_to_absolute_position(rel_logits)
|
158 |
+
scores_local = rel_logits / math.sqrt(self.k_channels)
|
159 |
+
scores = scores + scores_local
|
160 |
+
if self.proximal_bias:
|
161 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
162 |
+
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
163 |
+
if mask is not None:
|
164 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
165 |
+
if self.block_length is not None:
|
166 |
+
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
|
167 |
+
scores = scores * block_mask + -1e4 * (1 - block_mask)
|
168 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
169 |
+
p_attn = self.drop(p_attn)
|
170 |
+
output = torch.matmul(p_attn, value)
|
171 |
+
if self.window_size is not None:
|
172 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
173 |
+
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
174 |
+
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
175 |
+
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
176 |
+
return output, p_attn
|
177 |
+
|
178 |
+
def _matmul_with_relative_values(self, x, y):
|
179 |
+
"""
|
180 |
+
x: [b, h, l, m]
|
181 |
+
y: [h or 1, m, d]
|
182 |
+
ret: [b, h, l, d]
|
183 |
+
"""
|
184 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
185 |
+
return ret
|
186 |
+
|
187 |
+
def _matmul_with_relative_keys(self, x, y):
|
188 |
+
"""
|
189 |
+
x: [b, h, l, d]
|
190 |
+
y: [h or 1, m, d]
|
191 |
+
ret: [b, h, l, m]
|
192 |
+
"""
|
193 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
194 |
+
return ret
|
195 |
+
|
196 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
197 |
+
max_relative_position = 2 * self.window_size + 1
|
198 |
+
# Pad first before slice to avoid using cond ops.
|
199 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
200 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
201 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
202 |
+
if pad_length > 0:
|
203 |
+
padded_relative_embeddings = F.pad(
|
204 |
+
relative_embeddings,
|
205 |
+
convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
|
206 |
+
else:
|
207 |
+
padded_relative_embeddings = relative_embeddings
|
208 |
+
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
|
209 |
+
return used_relative_embeddings
|
210 |
+
|
211 |
+
def _relative_position_to_absolute_position(self, x):
|
212 |
+
"""
|
213 |
+
x: [b, h, l, 2*l-1]
|
214 |
+
ret: [b, h, l, l]
|
215 |
+
"""
|
216 |
+
batch, heads, length, _ = x.size()
|
217 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
218 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
219 |
+
|
220 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
221 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
222 |
+
x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
|
223 |
+
|
224 |
+
# Reshape and slice out the padded elements.
|
225 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
|
226 |
+
return x_final
|
227 |
+
|
228 |
+
def _absolute_position_to_relative_position(self, x):
|
229 |
+
"""
|
230 |
+
x: [b, h, l, l]
|
231 |
+
ret: [b, h, l, 2*l-1]
|
232 |
+
"""
|
233 |
+
batch, heads, length, _ = x.size()
|
234 |
+
# padd along column
|
235 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
|
236 |
+
x_flat = x.view([batch, heads, -1])
|
237 |
+
# add 0's in the beginning that will skew the elements after reshape
|
238 |
+
x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
239 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
240 |
+
return x_final
|
241 |
+
|
242 |
+
def _attention_bias_proximal(self, length):
|
243 |
+
"""Bias for self-attention to encourage attention to close positions.
|
244 |
+
Args:
|
245 |
+
length: an integer scalar.
|
246 |
+
Returns:
|
247 |
+
a Tensor with shape [1, 1, length, length]
|
248 |
+
"""
|
249 |
+
r = torch.arange(length, dtype=torch.float32)
|
250 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
251 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
252 |
+
|
253 |
+
|
254 |
+
class FFN(nn.Module):
|
255 |
+
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
|
256 |
+
super().__init__()
|
257 |
+
self.in_channels = in_channels
|
258 |
+
self.out_channels = out_channels
|
259 |
+
self.filter_channels = filter_channels
|
260 |
+
self.kernel_size = kernel_size
|
261 |
+
self.p_dropout = p_dropout
|
262 |
+
self.activation = activation
|
263 |
+
|
264 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
265 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
|
266 |
+
self.drop = nn.Dropout(p_dropout)
|
267 |
+
|
268 |
+
def forward(self, x, x_mask):
|
269 |
+
x = self.conv_1(x * x_mask)
|
270 |
+
if self.activation == "gelu":
|
271 |
+
x = x * torch.sigmoid(1.702 * x)
|
272 |
+
else:
|
273 |
+
x = torch.relu(x)
|
274 |
+
x = self.drop(x)
|
275 |
+
x = self.conv_2(x * x_mask)
|
276 |
+
return x * x_mask
|
277 |
+
|
278 |
+
|
279 |
+
class LayerNorm(nn.Module):
|
280 |
+
def __init__(self, channels, eps=1e-4):
|
281 |
+
super().__init__()
|
282 |
+
self.channels = channels
|
283 |
+
self.eps = eps
|
284 |
+
|
285 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
286 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
n_dims = len(x.shape)
|
290 |
+
mean = torch.mean(x, 1, keepdim=True)
|
291 |
+
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
|
292 |
+
|
293 |
+
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
294 |
+
|
295 |
+
shape = [1, -1] + [1] * (n_dims - 2)
|
296 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
297 |
+
return x
|
298 |
+
|
299 |
+
|
300 |
+
class ConvReluNorm(nn.Module):
|
301 |
+
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
302 |
+
super().__init__()
|
303 |
+
self.in_channels = in_channels
|
304 |
+
self.hidden_channels = hidden_channels
|
305 |
+
self.out_channels = out_channels
|
306 |
+
self.kernel_size = kernel_size
|
307 |
+
self.n_layers = n_layers
|
308 |
+
self.p_dropout = p_dropout
|
309 |
+
assert n_layers > 1, "Number of layers should be larger than 0."
|
310 |
+
|
311 |
+
self.conv_layers = nn.ModuleList()
|
312 |
+
self.norm_layers = nn.ModuleList()
|
313 |
+
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
314 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
315 |
+
self.relu_drop = nn.Sequential(
|
316 |
+
nn.ReLU(),
|
317 |
+
nn.Dropout(p_dropout))
|
318 |
+
for _ in range(n_layers - 1):
|
319 |
+
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
320 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
321 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
322 |
+
self.proj.weight.data.zero_()
|
323 |
+
self.proj.bias.data.zero_()
|
324 |
+
|
325 |
+
def forward(self, x, x_mask):
|
326 |
+
x_org = x
|
327 |
+
for i in range(self.n_layers):
|
328 |
+
x = self.conv_layers[i](x * x_mask)
|
329 |
+
x = self.norm_layers[i](x)
|
330 |
+
x = self.relu_drop(x)
|
331 |
+
x = x_org + self.proj(x)
|
332 |
+
return x * x_mask
|
333 |
+
|
334 |
+
|
335 |
+
class RelTransformerEncoder(nn.Module):
|
336 |
+
def __init__(self,
|
337 |
+
n_vocab,
|
338 |
+
out_channels,
|
339 |
+
hidden_channels,
|
340 |
+
filter_channels,
|
341 |
+
n_heads,
|
342 |
+
n_layers,
|
343 |
+
kernel_size,
|
344 |
+
p_dropout=0.0,
|
345 |
+
window_size=4,
|
346 |
+
block_length=None,
|
347 |
+
in_channels=None,
|
348 |
+
prenet=True,
|
349 |
+
pre_ln=True,
|
350 |
+
):
|
351 |
+
|
352 |
+
super().__init__()
|
353 |
+
|
354 |
+
self.n_vocab = n_vocab
|
355 |
+
self.out_channels = out_channels
|
356 |
+
self.hidden_channels = hidden_channels
|
357 |
+
self.filter_channels = filter_channels
|
358 |
+
self.n_heads = n_heads
|
359 |
+
self.n_layers = n_layers
|
360 |
+
self.kernel_size = kernel_size
|
361 |
+
self.p_dropout = p_dropout
|
362 |
+
self.window_size = window_size
|
363 |
+
self.block_length = block_length
|
364 |
+
self.prenet = prenet
|
365 |
+
if n_vocab > 0:
|
366 |
+
self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
|
367 |
+
|
368 |
+
if prenet:
|
369 |
+
if in_channels is None:
|
370 |
+
in_channels = hidden_channels
|
371 |
+
self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
|
372 |
+
kernel_size=5, n_layers=3, p_dropout=0)
|
373 |
+
if in_channels is not None and in_channels != hidden_channels:
|
374 |
+
self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
|
375 |
+
self.encoder = Encoder(
|
376 |
+
hidden_channels,
|
377 |
+
filter_channels,
|
378 |
+
n_heads,
|
379 |
+
n_layers,
|
380 |
+
kernel_size,
|
381 |
+
p_dropout,
|
382 |
+
window_size=window_size,
|
383 |
+
block_length=block_length,
|
384 |
+
pre_ln=pre_ln,
|
385 |
+
)
|
386 |
+
|
387 |
+
def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
|
388 |
+
if self.n_vocab > 0:
|
389 |
+
x_lengths = (x > 0).long().sum(-1)
|
390 |
+
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
391 |
+
else:
|
392 |
+
x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
|
393 |
+
x = x + other_embeds
|
394 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
395 |
+
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
396 |
+
|
397 |
+
if self.prenet:
|
398 |
+
x = self.pre(x, x_mask)
|
399 |
+
self.prenet_out = x.transpose(1, 2)
|
400 |
+
if hasattr(self, 'encoder_inp_proj'):
|
401 |
+
x = self.encoder_inp_proj(x) * x_mask
|
402 |
+
x = self.encoder(x, x_mask, attn_mask)
|
403 |
+
return x.transpose(1, 2)
|
tts/modules/ar_dur/commons/rot_transformer.py
ADDED
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import torch
|
17 |
+
from typing import Optional, Tuple
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn import Parameter, Linear
|
20 |
+
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
|
21 |
+
from tts.modules.ar_dur.commons.transformer import TransformerFFNLayer, MultiheadAttention
|
22 |
+
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
|
23 |
+
import torch.nn.functional as F
|
24 |
+
|
25 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 3000
|
26 |
+
DEFAULT_MAX_TARGET_POSITIONS = 3000
|
27 |
+
|
28 |
+
|
29 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
30 |
+
"""This module produces sinusoidal positional embeddings of any length.
|
31 |
+
|
32 |
+
Padding symbols are ignored.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
36 |
+
super().__init__()
|
37 |
+
self.embedding_dim = embedding_dim
|
38 |
+
self.padding_idx = padding_idx
|
39 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
40 |
+
init_size,
|
41 |
+
embedding_dim,
|
42 |
+
padding_idx,
|
43 |
+
)
|
44 |
+
self.register_buffer('_float_tensor', torch.FloatTensor(1))
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
|
48 |
+
"""Build sinusoidal embeddings.
|
49 |
+
|
50 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
51 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
52 |
+
"""
|
53 |
+
half_dim = embedding_dim // 2
|
54 |
+
emb = math.log(10000) / (half_dim - 1)
|
55 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
56 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
57 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
58 |
+
if embedding_dim % 2 == 1:
|
59 |
+
# zero pad
|
60 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
61 |
+
if padding_idx is not None:
|
62 |
+
emb[padding_idx, :] = 0
|
63 |
+
return emb
|
64 |
+
|
65 |
+
def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
|
66 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
67 |
+
bsz, seq_len = input.shape[:2]
|
68 |
+
max_pos = self.padding_idx + 1 + seq_len
|
69 |
+
if self.weights is None or max_pos > self.weights.size(0):
|
70 |
+
# recompute/expand embeddings if needed
|
71 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
72 |
+
max_pos,
|
73 |
+
self.embedding_dim,
|
74 |
+
self.padding_idx,
|
75 |
+
)
|
76 |
+
self.weights = self.weights.to(self._float_tensor)
|
77 |
+
|
78 |
+
if incremental_state is not None:
|
79 |
+
# positions is the same for every token when decoding a single step
|
80 |
+
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
|
81 |
+
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
|
82 |
+
|
83 |
+
positions = make_positions(input, self.padding_idx) if positions is None else positions
|
84 |
+
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
|
85 |
+
|
86 |
+
def max_positions(self):
|
87 |
+
"""Maximum number of supported positions."""
|
88 |
+
return int(1e5) # an arbitrary large number
|
89 |
+
|
90 |
+
|
91 |
+
class RotaryEmbeddings(nn.Module):
|
92 |
+
cos: torch.Tensor
|
93 |
+
sin: torch.Tensor
|
94 |
+
theta: torch.Tensor
|
95 |
+
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
width: int,
|
99 |
+
*,
|
100 |
+
seq_len: int = 40000,
|
101 |
+
base: int = 10000,
|
102 |
+
device: Optional[torch.device] = None,
|
103 |
+
):
|
104 |
+
"""Rotary embeddings (Su et al., 2021) layer. The rotary embedding
|
105 |
+
will be precomputed for up to 'seq _len' positions. The embedding
|
106 |
+
will be recomputed when a longer sequence is found in the input.
|
107 |
+
|
108 |
+
:param width:
|
109 |
+
Rotary embedding dimensionality, must be even.
|
110 |
+
:param seq_len:
|
111 |
+
Number of positons to initially precompute.
|
112 |
+
:param base:
|
113 |
+
The base used for Θ_i, determines the cycle length of the
|
114 |
+
embeddings.
|
115 |
+
:param device: Device on which the module is to be initialized.
|
116 |
+
"""
|
117 |
+
super().__init__()
|
118 |
+
|
119 |
+
if width % 2:
|
120 |
+
raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
|
121 |
+
|
122 |
+
# Ignore allocations on the meta device as we don't persist our buffer,
|
123 |
+
# i.e., we don't expect the backing tensor to be replaced with pretrained weights.
|
124 |
+
if device is not None and device.type == "meta":
|
125 |
+
device = None
|
126 |
+
# Θ_i = 10000^(-2(i-1)/d)
|
127 |
+
theta = torch.pow(
|
128 |
+
base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
|
129 |
+
)
|
130 |
+
self.register_buffer("theta", theta, persistent=False)
|
131 |
+
|
132 |
+
self._create_rotary_embed(width=width, length=seq_len)
|
133 |
+
|
134 |
+
def _create_rotary_embed(self, *, width: int, length: int):
|
135 |
+
# mΘ
|
136 |
+
position = torch.arange(length, device=self.theta.device).unsqueeze(1)
|
137 |
+
m_theta = position * self.theta.unsqueeze(0)
|
138 |
+
|
139 |
+
# We apply both sin and cos twice (see Eq 15, 34), but the ordering
|
140 |
+
# is changed for compatibility with most common implementations.
|
141 |
+
m_theta = torch.cat([m_theta, m_theta], dim=-1)
|
142 |
+
|
143 |
+
re_cos = m_theta.cos().view([length, width])
|
144 |
+
re_sin = m_theta.sin().view([length, width])
|
145 |
+
|
146 |
+
self.register_buffer("cos", re_cos, persistent=False)
|
147 |
+
self.register_buffer("sin", re_sin, persistent=False)
|
148 |
+
|
149 |
+
def _rotate(self, input: torch.Tensor):
|
150 |
+
"""Rotate the input tensor by half of its innermost width.
|
151 |
+
|
152 |
+
input (Tensor): array to rotate.
|
153 |
+
RETURNS (Tensor): rotated array.
|
154 |
+
|
155 |
+
Shapes:
|
156 |
+
input - (..., width)
|
157 |
+
output - (..., width)
|
158 |
+
"""
|
159 |
+
half_idx = input.shape[-1] // 2
|
160 |
+
input_1 = -input[..., half_idx:]
|
161 |
+
input_2 = input[..., :half_idx]
|
162 |
+
return torch.cat([input_1, input_2], dim=-1)
|
163 |
+
|
164 |
+
def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
|
165 |
+
"""
|
166 |
+
Apply rotary embeddings to an array.
|
167 |
+
|
168 |
+
:param input: Array to apply the rotary embeddings to.
|
169 |
+
:param positions: positions of the inputs. If no positions are
|
170 |
+
provided, they are assumed to be [0, seq_len).
|
171 |
+
:return: Array with the rotary embeddings applied.
|
172 |
+
|
173 |
+
Shapes:
|
174 |
+
input - (batch_size, num_heads, seq_len, width_per_head)
|
175 |
+
positions - (batch_size, seq_len)
|
176 |
+
output - (batch_size, num_heads, seq_len, width_per_head)
|
177 |
+
"""
|
178 |
+
batch_size, _, seq_len, width = input.shape
|
179 |
+
|
180 |
+
if positions is None:
|
181 |
+
# Fastpath: positions from [0..seq_len), avoid indexing.
|
182 |
+
if self.cos.size(-2) < seq_len:
|
183 |
+
self._create_rotary_embed(width=width, length=seq_len)
|
184 |
+
rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
|
185 |
+
rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
|
186 |
+
else:
|
187 |
+
max_len = int(positions.max()) + 1
|
188 |
+
if self.cos.size(-2) < max_len:
|
189 |
+
self._create_rotary_embed(width=width, length=max_len)
|
190 |
+
|
191 |
+
# Flatten positions to index cos/sin arrays, then unflatten.
|
192 |
+
#
|
193 |
+
# Example shapes:
|
194 |
+
#
|
195 |
+
# positions_flat - (batch_size * seq_len)
|
196 |
+
# self.cos - (max_len, width)
|
197 |
+
# rot_cos - (batch_size, seq_len, width)
|
198 |
+
positions_flat = positions.view(-1)
|
199 |
+
rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
|
200 |
+
rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)
|
201 |
+
|
202 |
+
# Eq 34 with ordering changed for compatibility.
|
203 |
+
return rot_cos * input + rot_sin * self._rotate(input)
|
204 |
+
|
205 |
+
|
206 |
+
class RotMultiheadAttention(MultiheadAttention):
|
207 |
+
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
|
208 |
+
add_bias_kv=False, add_zero_attn=False, self_attention=False,
|
209 |
+
encoder_decoder_attention=False):
|
210 |
+
super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
|
211 |
+
add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
|
212 |
+
encoder_decoder_attention=encoder_decoder_attention)
|
213 |
+
self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
|
214 |
+
|
215 |
+
def forward(
|
216 |
+
self,
|
217 |
+
query, key, value,
|
218 |
+
spk_pos_ids_flat=None,
|
219 |
+
key_padding_mask=None,
|
220 |
+
incremental_state=None,
|
221 |
+
need_weights=True,
|
222 |
+
static_kv=False,
|
223 |
+
attn_mask=None,
|
224 |
+
before_softmax=False,
|
225 |
+
need_head_weights=False,
|
226 |
+
enc_dec_attn_constraint_mask=None,
|
227 |
+
reset_attn_weight=None
|
228 |
+
):
|
229 |
+
"""Input shape: Time x Batch x Channel
|
230 |
+
|
231 |
+
Args:
|
232 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
233 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
234 |
+
padding elements are indicated by 1s.
|
235 |
+
need_weights (bool, optional): return the attention weights,
|
236 |
+
averaged over heads (default: False).
|
237 |
+
attn_mask (ByteTensor, optional): typically used to
|
238 |
+
implement causal attention, where the mask prevents the
|
239 |
+
attention from looking forward in time (default: None).
|
240 |
+
before_softmax (bool, optional): return the raw attention
|
241 |
+
weights and values before the attention softmax.
|
242 |
+
need_head_weights (bool, optional): return the attention
|
243 |
+
weights for each head. Implies *need_weights*. Default:
|
244 |
+
return the average attention weights over all heads.
|
245 |
+
"""
|
246 |
+
if need_head_weights:
|
247 |
+
need_weights = True
|
248 |
+
|
249 |
+
tgt_len, bsz, embed_dim = query.size()
|
250 |
+
assert embed_dim == self.embed_dim
|
251 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
252 |
+
|
253 |
+
if incremental_state is not None:
|
254 |
+
saved_state = self._get_input_buffer(incremental_state)
|
255 |
+
if 'prev_key' in saved_state:
|
256 |
+
# previous time steps are cached - no need to recompute
|
257 |
+
# key and value if they are static
|
258 |
+
if static_kv:
|
259 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
260 |
+
key = value = None
|
261 |
+
else:
|
262 |
+
saved_state = None
|
263 |
+
|
264 |
+
if self.self_attention:
|
265 |
+
# self-attention
|
266 |
+
q, k, v = self.in_proj_qkv(query)
|
267 |
+
elif self.encoder_decoder_attention:
|
268 |
+
# encoder-decoder attention
|
269 |
+
q = self.in_proj_q(query)
|
270 |
+
if key is None:
|
271 |
+
assert value is None
|
272 |
+
k = v = None
|
273 |
+
else:
|
274 |
+
k = self.in_proj_k(key)
|
275 |
+
v = self.in_proj_v(key)
|
276 |
+
else:
|
277 |
+
q = self.in_proj_q(query)
|
278 |
+
k = self.in_proj_k(key)
|
279 |
+
v = self.in_proj_v(value)
|
280 |
+
q = q * self.scaling
|
281 |
+
|
282 |
+
if self.bias_k is not None:
|
283 |
+
assert self.bias_v is not None
|
284 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
285 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
286 |
+
if attn_mask is not None:
|
287 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
288 |
+
if key_padding_mask is not None:
|
289 |
+
key_padding_mask = torch.cat(
|
290 |
+
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
|
291 |
+
|
292 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
293 |
+
if k is not None:
|
294 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
295 |
+
if v is not None:
|
296 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
297 |
+
|
298 |
+
# Apply rot embedding and store incremental_state
|
299 |
+
q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
|
300 |
+
if saved_state is not None:
|
301 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
302 |
+
if 'prev_key' in saved_state:
|
303 |
+
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
|
304 |
+
if static_kv:
|
305 |
+
k = prev_key
|
306 |
+
else:
|
307 |
+
k = torch.cat((prev_key, k), dim=1)
|
308 |
+
if 'prev_value' in saved_state:
|
309 |
+
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
|
310 |
+
if static_kv:
|
311 |
+
v = prev_value
|
312 |
+
else:
|
313 |
+
v = torch.cat((prev_value, v), dim=1)
|
314 |
+
saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
|
315 |
+
bsz, self.num_heads, -1, self.head_dim)
|
316 |
+
self._set_input_buffer(incremental_state, saved_state)
|
317 |
+
if incremental_state is not None:
|
318 |
+
key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
|
319 |
+
else:
|
320 |
+
key_pos = spk_pos_ids_flat
|
321 |
+
k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
|
322 |
+
|
323 |
+
src_len = k.size(1)
|
324 |
+
|
325 |
+
# This is part of a workaround to get around fork/join parallelism
|
326 |
+
# not supporting Optional types.
|
327 |
+
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
|
328 |
+
key_padding_mask = None
|
329 |
+
|
330 |
+
if key_padding_mask is not None:
|
331 |
+
assert key_padding_mask.size(0) == bsz
|
332 |
+
assert key_padding_mask.size(1) == src_len
|
333 |
+
|
334 |
+
if self.add_zero_attn:
|
335 |
+
src_len += 1
|
336 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
337 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
338 |
+
if attn_mask is not None:
|
339 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
340 |
+
if key_padding_mask is not None:
|
341 |
+
key_padding_mask = torch.cat(
|
342 |
+
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
|
343 |
+
|
344 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
345 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
346 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
347 |
+
|
348 |
+
if attn_mask is not None:
|
349 |
+
if len(attn_mask.shape) == 2:
|
350 |
+
attn_mask = attn_mask.unsqueeze(0)
|
351 |
+
elif len(attn_mask.shape) == 3:
|
352 |
+
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
|
353 |
+
bsz * self.num_heads, tgt_len, src_len)
|
354 |
+
attn_weights = attn_weights + attn_mask
|
355 |
+
|
356 |
+
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
|
357 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
358 |
+
attn_weights = attn_weights.masked_fill(
|
359 |
+
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
|
360 |
+
-1e8,
|
361 |
+
)
|
362 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
363 |
+
|
364 |
+
if key_padding_mask is not None:
|
365 |
+
# don't attend to padding symbols
|
366 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
367 |
+
attn_weights = attn_weights.masked_fill(
|
368 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
369 |
+
-1e8,
|
370 |
+
)
|
371 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
372 |
+
|
373 |
+
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
374 |
+
|
375 |
+
if before_softmax:
|
376 |
+
return attn_weights, v
|
377 |
+
|
378 |
+
attn_weights_float = softmax(attn_weights, dim=-1)
|
379 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
380 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
381 |
+
|
382 |
+
if reset_attn_weight is not None:
|
383 |
+
if reset_attn_weight:
|
384 |
+
self.last_attn_probs = attn_probs.detach()
|
385 |
+
else:
|
386 |
+
assert self.last_attn_probs is not None
|
387 |
+
attn_probs = self.last_attn_probs
|
388 |
+
attn = torch.bmm(attn_probs, v)
|
389 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
390 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
391 |
+
attn = self.out_proj(attn)
|
392 |
+
|
393 |
+
if need_weights:
|
394 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
395 |
+
if not need_head_weights:
|
396 |
+
# average attention weights over heads
|
397 |
+
attn_weights = attn_weights.mean(dim=0)
|
398 |
+
else:
|
399 |
+
attn_weights = None
|
400 |
+
|
401 |
+
return attn, (attn_weights, attn_logits)
|
402 |
+
|
403 |
+
|
404 |
+
class RotMultiheadAttention2(MultiheadAttention):
|
405 |
+
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
|
406 |
+
add_bias_kv=False, add_zero_attn=False, self_attention=False,
|
407 |
+
encoder_decoder_attention=False):
|
408 |
+
super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
|
409 |
+
add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
|
410 |
+
encoder_decoder_attention=encoder_decoder_attention)
|
411 |
+
self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
|
412 |
+
|
413 |
+
def forward(
|
414 |
+
self,
|
415 |
+
query, key, value,
|
416 |
+
spk_pos_ids_flat=None,
|
417 |
+
key_padding_mask=None,
|
418 |
+
incremental_state=None,
|
419 |
+
need_weights=True,
|
420 |
+
static_kv=False,
|
421 |
+
attn_mask=None,
|
422 |
+
before_softmax=False,
|
423 |
+
need_head_weights=False,
|
424 |
+
enc_dec_attn_constraint_mask=None,
|
425 |
+
reset_attn_weight=None
|
426 |
+
):
|
427 |
+
"""Input shape: Time x Batch x Channel
|
428 |
+
|
429 |
+
Args:
|
430 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
431 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
432 |
+
padding elements are indicated by 1s.
|
433 |
+
need_weights (bool, optional): return the attention weights,
|
434 |
+
averaged over heads (default: False).
|
435 |
+
attn_mask (ByteTensor, optional): typically used to
|
436 |
+
implement causal attention, where the mask prevents the
|
437 |
+
attention from looking forward in time (default: None).
|
438 |
+
before_softmax (bool, optional): return the raw attention
|
439 |
+
weights and values before the attention softmax.
|
440 |
+
need_head_weights (bool, optional): return the attention
|
441 |
+
weights for each head. Implies *need_weights*. Default:
|
442 |
+
return the average attention weights over all heads.
|
443 |
+
"""
|
444 |
+
if need_head_weights:
|
445 |
+
need_weights = True
|
446 |
+
|
447 |
+
tgt_len, bsz, embed_dim = query.size()
|
448 |
+
assert embed_dim == self.embed_dim
|
449 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
450 |
+
|
451 |
+
if incremental_state is not None:
|
452 |
+
saved_state = self._get_input_buffer(incremental_state)
|
453 |
+
if 'prev_key' in saved_state:
|
454 |
+
# previous time steps are cached - no need to recompute
|
455 |
+
# key and value if they are static
|
456 |
+
if static_kv:
|
457 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
458 |
+
key = value = None
|
459 |
+
else:
|
460 |
+
saved_state = None
|
461 |
+
|
462 |
+
if self.self_attention:
|
463 |
+
# self-attention
|
464 |
+
q, k, v = self.in_proj_qkv(query)
|
465 |
+
elif self.encoder_decoder_attention:
|
466 |
+
# encoder-decoder attention
|
467 |
+
q = self.in_proj_q(query)
|
468 |
+
if key is None:
|
469 |
+
assert value is None
|
470 |
+
k = v = None
|
471 |
+
else:
|
472 |
+
k = self.in_proj_k(key)
|
473 |
+
v = self.in_proj_v(key)
|
474 |
+
else:
|
475 |
+
q = self.in_proj_q(query)
|
476 |
+
k = self.in_proj_k(key)
|
477 |
+
v = self.in_proj_v(value)
|
478 |
+
|
479 |
+
if self.bias_k is not None:
|
480 |
+
assert self.bias_v is not None
|
481 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
482 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
483 |
+
if attn_mask is not None:
|
484 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
485 |
+
if key_padding_mask is not None:
|
486 |
+
key_padding_mask = torch.cat(
|
487 |
+
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
|
488 |
+
|
489 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
490 |
+
if k is not None:
|
491 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
492 |
+
if v is not None:
|
493 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
494 |
+
|
495 |
+
# Apply rot embedding and store incremental_state
|
496 |
+
q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
|
497 |
+
if saved_state is not None:
|
498 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
499 |
+
if 'prev_key' in saved_state:
|
500 |
+
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
|
501 |
+
if static_kv:
|
502 |
+
k = prev_key
|
503 |
+
else:
|
504 |
+
k = torch.cat((prev_key, k), dim=1)
|
505 |
+
if 'prev_value' in saved_state:
|
506 |
+
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
|
507 |
+
if static_kv:
|
508 |
+
v = prev_value
|
509 |
+
else:
|
510 |
+
v = torch.cat((prev_value, v), dim=1)
|
511 |
+
saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
|
512 |
+
bsz, self.num_heads, -1, self.head_dim)
|
513 |
+
self._set_input_buffer(incremental_state, saved_state)
|
514 |
+
key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
|
515 |
+
k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
|
516 |
+
|
517 |
+
src_len = k.size(1)
|
518 |
+
|
519 |
+
# This is part of a workaround to get around fork/join parallelism
|
520 |
+
# not supporting Optional types.
|
521 |
+
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
|
522 |
+
key_padding_mask = None
|
523 |
+
|
524 |
+
if key_padding_mask is not None:
|
525 |
+
assert key_padding_mask.size(0) == bsz
|
526 |
+
assert key_padding_mask.size(1) == src_len
|
527 |
+
|
528 |
+
if attn_mask is not None:
|
529 |
+
if len(attn_mask.shape) == 2:
|
530 |
+
attn_mask = attn_mask.unsqueeze(0)
|
531 |
+
elif len(attn_mask.shape) == 3:
|
532 |
+
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
|
533 |
+
bsz * self.num_heads, tgt_len, src_len)
|
534 |
+
attn = torch.nn.functional.scaled_dot_product_attention(
|
535 |
+
q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False)
|
536 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
537 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
538 |
+
attn_logits = None
|
539 |
+
attn_weights = None
|
540 |
+
return attn, (attn_weights, attn_logits)
|
541 |
+
|
542 |
+
|
543 |
+
class RotDecSALayer(nn.Module):
|
544 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
|
545 |
+
kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, bias=True):
|
546 |
+
super().__init__()
|
547 |
+
self.c = c
|
548 |
+
self.dropout = dropout
|
549 |
+
self.layer_norm1 = LayerNorm(c)
|
550 |
+
self.self_attn = RotMultiheadAttention(
|
551 |
+
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
|
552 |
+
)
|
553 |
+
self.layer_norm2 = LayerNorm(c)
|
554 |
+
self.ffn = TransformerFFNLayer(
|
555 |
+
c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size,
|
556 |
+
dropout=relu_dropout, act=act, bias=bias)
|
557 |
+
self.post_ln = post_ln
|
558 |
+
|
559 |
+
def forward(
|
560 |
+
self,
|
561 |
+
x,
|
562 |
+
encoder_out=None,
|
563 |
+
encoder_padding_mask=None,
|
564 |
+
incremental_state=None,
|
565 |
+
self_attn_mask=None,
|
566 |
+
self_attn_padding_mask=None,
|
567 |
+
attn_out=None,
|
568 |
+
reset_attn_weight=None,
|
569 |
+
spk_pos_ids_flat=None,
|
570 |
+
**kwargs,
|
571 |
+
):
|
572 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
573 |
+
if layer_norm_training is not None:
|
574 |
+
self.layer_norm1.training = layer_norm_training
|
575 |
+
self.layer_norm2.training = layer_norm_training
|
576 |
+
residual = x
|
577 |
+
if not self.post_ln:
|
578 |
+
x = self.layer_norm1(x)
|
579 |
+
|
580 |
+
x, (attn_weights, _) = self.self_attn(
|
581 |
+
query=x,
|
582 |
+
key=x,
|
583 |
+
value=x,
|
584 |
+
key_padding_mask=self_attn_padding_mask,
|
585 |
+
incremental_state=incremental_state,
|
586 |
+
attn_mask=self_attn_mask,
|
587 |
+
spk_pos_ids_flat=spk_pos_ids_flat
|
588 |
+
)
|
589 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
590 |
+
x = residual + x
|
591 |
+
if self.post_ln:
|
592 |
+
x = self.layer_norm1(x)
|
593 |
+
|
594 |
+
residual = x
|
595 |
+
if not self.post_ln:
|
596 |
+
x = self.layer_norm2(x)
|
597 |
+
x = self.ffn(x, incremental_state=incremental_state)
|
598 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
599 |
+
x = residual + x
|
600 |
+
if self.post_ln:
|
601 |
+
x = self.layer_norm2(x)
|
602 |
+
return x, attn_weights
|
603 |
+
|
604 |
+
def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
|
605 |
+
self.encoder_attn.clear_buffer(incremental_state)
|
606 |
+
self.ffn.clear_buffer(incremental_state)
|
607 |
+
|
608 |
+
def set_buffer(self, name, tensor, incremental_state):
|
609 |
+
return set_incremental_state(self, incremental_state, name, tensor)
|
610 |
+
|
611 |
+
|
612 |
+
class RotDecSALayer2(RotDecSALayer):
|
613 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9,
|
614 |
+
ffn_hidden_size=1024, act='gelu', post_ln=False):
|
615 |
+
super().__init__(c, num_heads, dropout, attention_dropout, relu_dropout, kernel_size, ffn_hidden_size, act,
|
616 |
+
post_ln)
|
617 |
+
self.self_attn = RotMultiheadAttention2(
|
618 |
+
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
|
619 |
+
)
|
620 |
+
|
621 |
+
|
622 |
+
class RotTransformerDecoderLayer(nn.Module):
|
623 |
+
def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
|
624 |
+
op_version=1, bias=True):
|
625 |
+
super().__init__()
|
626 |
+
self.hidden_size = hidden_size
|
627 |
+
self.dropout = dropout
|
628 |
+
self.num_heads = num_heads
|
629 |
+
if op_version == 1:
|
630 |
+
self.op = RotDecSALayer(
|
631 |
+
hidden_size, num_heads, dropout=dropout,
|
632 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
633 |
+
kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
|
634 |
+
post_ln=post_ln, bias=bias)
|
635 |
+
else:
|
636 |
+
self.op = RotDecSALayer2(
|
637 |
+
hidden_size, num_heads, dropout=dropout,
|
638 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
639 |
+
kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
|
640 |
+
post_ln=post_ln)
|
641 |
+
|
642 |
+
def forward(self, x, **kwargs):
|
643 |
+
return self.op(x, **kwargs)
|
644 |
+
|
645 |
+
def clear_buffer(self, *args):
|
646 |
+
return self.op.clear_buffer(*args)
|
647 |
+
|
648 |
+
def set_buffer(self, *args):
|
649 |
+
return self.op.set_buffer(*args)
|
tts/modules/ar_dur/commons/seq_utils.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from collections import defaultdict
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
|
20 |
+
def make_positions(tensor, padding_idx):
|
21 |
+
"""Replace non-padding symbols with their position numbers.
|
22 |
+
|
23 |
+
Position numbers begin at padding_idx+1. Padding symbols are ignored.
|
24 |
+
"""
|
25 |
+
# The series of casts and type-conversions here are carefully
|
26 |
+
# balanced to both work with ONNX export and XLA. In particular XLA
|
27 |
+
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
|
28 |
+
# how to handle the dtype kwarg in cumsum.
|
29 |
+
mask = tensor.ne(padding_idx).int()
|
30 |
+
return (
|
31 |
+
torch.cumsum(mask, dim=1).type_as(mask) * mask
|
32 |
+
).long() + padding_idx
|
33 |
+
|
34 |
+
|
35 |
+
def softmax(x, dim):
|
36 |
+
return F.softmax(x, dim=dim, dtype=torch.float32)
|
37 |
+
|
38 |
+
|
39 |
+
def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
|
40 |
+
if maxlen is None:
|
41 |
+
maxlen = lengths.max()
|
42 |
+
mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
|
43 |
+
mask.type(dtype)
|
44 |
+
return mask
|
45 |
+
|
46 |
+
|
47 |
+
def weights_nonzero_speech(target):
|
48 |
+
# target : B x T x mel
|
49 |
+
# Assign weight 1.0 to all labels except for padding (id=0).
|
50 |
+
dim = target.size(-1)
|
51 |
+
return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
|
52 |
+
|
53 |
+
|
54 |
+
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
|
55 |
+
|
56 |
+
|
57 |
+
def _get_full_incremental_state_key(module_instance, key):
|
58 |
+
module_name = module_instance.__class__.__name__
|
59 |
+
|
60 |
+
# assign a unique ID to each module instance, so that incremental state is
|
61 |
+
# not shared across module instances
|
62 |
+
if not hasattr(module_instance, '_instance_id'):
|
63 |
+
INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
|
64 |
+
module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
|
65 |
+
|
66 |
+
return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
|
67 |
+
|
68 |
+
|
69 |
+
def get_incremental_state(module, incremental_state, key):
|
70 |
+
"""Helper for getting incremental state for an nn.Module."""
|
71 |
+
full_key = _get_full_incremental_state_key(module, key)
|
72 |
+
if incremental_state is None or full_key not in incremental_state:
|
73 |
+
return None
|
74 |
+
return incremental_state[full_key]
|
75 |
+
|
76 |
+
|
77 |
+
def set_incremental_state(module, incremental_state, key, value):
|
78 |
+
"""Helper for setting incremental state for an nn.Module."""
|
79 |
+
if incremental_state is not None:
|
80 |
+
full_key = _get_full_incremental_state_key(module, key)
|
81 |
+
incremental_state[full_key] = value
|
82 |
+
|
83 |
+
|
84 |
+
def fill_with_neg_inf(t):
|
85 |
+
"""FP16-compatible function that fills a tensor with -inf."""
|
86 |
+
return t.float().fill_(float('-inf')).type_as(t)
|
87 |
+
|
88 |
+
|
89 |
+
def fill_with_neg_inf2(t):
|
90 |
+
"""FP16-compatible function that fills a tensor with -inf."""
|
91 |
+
return t.float().fill_(-1e8).type_as(t)
|
92 |
+
|
93 |
+
|
94 |
+
def select_attn(attn_logits, type='best'):
|
95 |
+
"""
|
96 |
+
|
97 |
+
:param attn_logits: [n_layers, B, n_head, T_sp, T_txt]
|
98 |
+
:return:
|
99 |
+
"""
|
100 |
+
encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
|
101 |
+
# [n_layers * n_head, B, T_sp, T_txt]
|
102 |
+
encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
|
103 |
+
if type == 'best':
|
104 |
+
indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
|
105 |
+
encdec_attn = encdec_attn.gather(
|
106 |
+
0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
|
107 |
+
return encdec_attn
|
108 |
+
elif type == 'mean':
|
109 |
+
return encdec_attn.mean(0)
|
110 |
+
|
111 |
+
|
112 |
+
def make_pad_mask(lengths, xs=None, length_dim=-1):
|
113 |
+
"""Make mask tensor containing indices of padded part.
|
114 |
+
Args:
|
115 |
+
lengths (LongTensor or List): Batch of lengths (B,).
|
116 |
+
xs (Tensor, optional): The reference tensor.
|
117 |
+
If set, masks will be the same shape as this tensor.
|
118 |
+
length_dim (int, optional): Dimension indicator of the above tensor.
|
119 |
+
See the example.
|
120 |
+
Returns:
|
121 |
+
Tensor: Mask tensor containing indices of padded part.
|
122 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
123 |
+
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
124 |
+
Examples:
|
125 |
+
With only lengths.
|
126 |
+
>>> lengths = [5, 3, 2]
|
127 |
+
>>> make_non_pad_mask(lengths)
|
128 |
+
masks = [[0, 0, 0, 0 ,0],
|
129 |
+
[0, 0, 0, 1, 1],
|
130 |
+
[0, 0, 1, 1, 1]]
|
131 |
+
With the reference tensor.
|
132 |
+
>>> xs = torch.zeros((3, 2, 4))
|
133 |
+
>>> make_pad_mask(lengths, xs)
|
134 |
+
tensor([[[0, 0, 0, 0],
|
135 |
+
[0, 0, 0, 0]],
|
136 |
+
[[0, 0, 0, 1],
|
137 |
+
[0, 0, 0, 1]],
|
138 |
+
[[0, 0, 1, 1],
|
139 |
+
[0, 0, 1, 1]]], dtype=torch.uint8)
|
140 |
+
>>> xs = torch.zeros((3, 2, 6))
|
141 |
+
>>> make_pad_mask(lengths, xs)
|
142 |
+
tensor([[[0, 0, 0, 0, 0, 1],
|
143 |
+
[0, 0, 0, 0, 0, 1]],
|
144 |
+
[[0, 0, 0, 1, 1, 1],
|
145 |
+
[0, 0, 0, 1, 1, 1]],
|
146 |
+
[[0, 0, 1, 1, 1, 1],
|
147 |
+
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
148 |
+
With the reference tensor and dimension indicator.
|
149 |
+
>>> xs = torch.zeros((3, 6, 6))
|
150 |
+
>>> make_pad_mask(lengths, xs, 1)
|
151 |
+
tensor([[[0, 0, 0, 0, 0, 0],
|
152 |
+
[0, 0, 0, 0, 0, 0],
|
153 |
+
[0, 0, 0, 0, 0, 0],
|
154 |
+
[0, 0, 0, 0, 0, 0],
|
155 |
+
[0, 0, 0, 0, 0, 0],
|
156 |
+
[1, 1, 1, 1, 1, 1]],
|
157 |
+
[[0, 0, 0, 0, 0, 0],
|
158 |
+
[0, 0, 0, 0, 0, 0],
|
159 |
+
[0, 0, 0, 0, 0, 0],
|
160 |
+
[1, 1, 1, 1, 1, 1],
|
161 |
+
[1, 1, 1, 1, 1, 1],
|
162 |
+
[1, 1, 1, 1, 1, 1]],
|
163 |
+
[[0, 0, 0, 0, 0, 0],
|
164 |
+
[0, 0, 0, 0, 0, 0],
|
165 |
+
[1, 1, 1, 1, 1, 1],
|
166 |
+
[1, 1, 1, 1, 1, 1],
|
167 |
+
[1, 1, 1, 1, 1, 1],
|
168 |
+
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
169 |
+
>>> make_pad_mask(lengths, xs, 2)
|
170 |
+
tensor([[[0, 0, 0, 0, 0, 1],
|
171 |
+
[0, 0, 0, 0, 0, 1],
|
172 |
+
[0, 0, 0, 0, 0, 1],
|
173 |
+
[0, 0, 0, 0, 0, 1],
|
174 |
+
[0, 0, 0, 0, 0, 1],
|
175 |
+
[0, 0, 0, 0, 0, 1]],
|
176 |
+
[[0, 0, 0, 1, 1, 1],
|
177 |
+
[0, 0, 0, 1, 1, 1],
|
178 |
+
[0, 0, 0, 1, 1, 1],
|
179 |
+
[0, 0, 0, 1, 1, 1],
|
180 |
+
[0, 0, 0, 1, 1, 1],
|
181 |
+
[0, 0, 0, 1, 1, 1]],
|
182 |
+
[[0, 0, 1, 1, 1, 1],
|
183 |
+
[0, 0, 1, 1, 1, 1],
|
184 |
+
[0, 0, 1, 1, 1, 1],
|
185 |
+
[0, 0, 1, 1, 1, 1],
|
186 |
+
[0, 0, 1, 1, 1, 1],
|
187 |
+
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
188 |
+
"""
|
189 |
+
if length_dim == 0:
|
190 |
+
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
|
191 |
+
|
192 |
+
if not isinstance(lengths, list):
|
193 |
+
lengths = lengths.tolist()
|
194 |
+
bs = int(len(lengths))
|
195 |
+
if xs is None:
|
196 |
+
maxlen = int(max(lengths))
|
197 |
+
else:
|
198 |
+
maxlen = xs.size(length_dim)
|
199 |
+
|
200 |
+
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
201 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
202 |
+
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
203 |
+
mask = seq_range_expand >= seq_length_expand
|
204 |
+
|
205 |
+
if xs is not None:
|
206 |
+
assert xs.size(0) == bs, (xs.size(0), bs)
|
207 |
+
|
208 |
+
if length_dim < 0:
|
209 |
+
length_dim = xs.dim() + length_dim
|
210 |
+
# ind = (:, None, ..., None, :, , None, ..., None)
|
211 |
+
ind = tuple(
|
212 |
+
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
|
213 |
+
)
|
214 |
+
mask = mask[ind].expand_as(xs).to(xs.device)
|
215 |
+
return mask
|
216 |
+
|
217 |
+
|
218 |
+
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
|
219 |
+
"""Make mask tensor containing indices of non-padded part.
|
220 |
+
Args:
|
221 |
+
lengths (LongTensor or List): Batch of lengths (B,).
|
222 |
+
xs (Tensor, optional): The reference tensor.
|
223 |
+
If set, masks will be the same shape as this tensor.
|
224 |
+
length_dim (int, optional): Dimension indicator of the above tensor.
|
225 |
+
See the example.
|
226 |
+
Returns:
|
227 |
+
ByteTensor: mask tensor containing indices of padded part.
|
228 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
229 |
+
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
230 |
+
Examples:
|
231 |
+
With only lengths.
|
232 |
+
>>> lengths = [5, 3, 2]
|
233 |
+
>>> make_non_pad_mask(lengths)
|
234 |
+
masks = [[1, 1, 1, 1 ,1],
|
235 |
+
[1, 1, 1, 0, 0],
|
236 |
+
[1, 1, 0, 0, 0]]
|
237 |
+
With the reference tensor.
|
238 |
+
>>> xs = torch.zeros((3, 2, 4))
|
239 |
+
>>> make_non_pad_mask(lengths, xs)
|
240 |
+
tensor([[[1, 1, 1, 1],
|
241 |
+
[1, 1, 1, 1]],
|
242 |
+
[[1, 1, 1, 0],
|
243 |
+
[1, 1, 1, 0]],
|
244 |
+
[[1, 1, 0, 0],
|
245 |
+
[1, 1, 0, 0]]], dtype=torch.uint8)
|
246 |
+
>>> xs = torch.zeros((3, 2, 6))
|
247 |
+
>>> make_non_pad_mask(lengths, xs)
|
248 |
+
tensor([[[1, 1, 1, 1, 1, 0],
|
249 |
+
[1, 1, 1, 1, 1, 0]],
|
250 |
+
[[1, 1, 1, 0, 0, 0],
|
251 |
+
[1, 1, 1, 0, 0, 0]],
|
252 |
+
[[1, 1, 0, 0, 0, 0],
|
253 |
+
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
254 |
+
With the reference tensor and dimension indicator.
|
255 |
+
>>> xs = torch.zeros((3, 6, 6))
|
256 |
+
>>> make_non_pad_mask(lengths, xs, 1)
|
257 |
+
tensor([[[1, 1, 1, 1, 1, 1],
|
258 |
+
[1, 1, 1, 1, 1, 1],
|
259 |
+
[1, 1, 1, 1, 1, 1],
|
260 |
+
[1, 1, 1, 1, 1, 1],
|
261 |
+
[1, 1, 1, 1, 1, 1],
|
262 |
+
[0, 0, 0, 0, 0, 0]],
|
263 |
+
[[1, 1, 1, 1, 1, 1],
|
264 |
+
[1, 1, 1, 1, 1, 1],
|
265 |
+
[1, 1, 1, 1, 1, 1],
|
266 |
+
[0, 0, 0, 0, 0, 0],
|
267 |
+
[0, 0, 0, 0, 0, 0],
|
268 |
+
[0, 0, 0, 0, 0, 0]],
|
269 |
+
[[1, 1, 1, 1, 1, 1],
|
270 |
+
[1, 1, 1, 1, 1, 1],
|
271 |
+
[0, 0, 0, 0, 0, 0],
|
272 |
+
[0, 0, 0, 0, 0, 0],
|
273 |
+
[0, 0, 0, 0, 0, 0],
|
274 |
+
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
|
275 |
+
>>> make_non_pad_mask(lengths, xs, 2)
|
276 |
+
tensor([[[1, 1, 1, 1, 1, 0],
|
277 |
+
[1, 1, 1, 1, 1, 0],
|
278 |
+
[1, 1, 1, 1, 1, 0],
|
279 |
+
[1, 1, 1, 1, 1, 0],
|
280 |
+
[1, 1, 1, 1, 1, 0],
|
281 |
+
[1, 1, 1, 1, 1, 0]],
|
282 |
+
[[1, 1, 1, 0, 0, 0],
|
283 |
+
[1, 1, 1, 0, 0, 0],
|
284 |
+
[1, 1, 1, 0, 0, 0],
|
285 |
+
[1, 1, 1, 0, 0, 0],
|
286 |
+
[1, 1, 1, 0, 0, 0],
|
287 |
+
[1, 1, 1, 0, 0, 0]],
|
288 |
+
[[1, 1, 0, 0, 0, 0],
|
289 |
+
[1, 1, 0, 0, 0, 0],
|
290 |
+
[1, 1, 0, 0, 0, 0],
|
291 |
+
[1, 1, 0, 0, 0, 0],
|
292 |
+
[1, 1, 0, 0, 0, 0],
|
293 |
+
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
294 |
+
"""
|
295 |
+
return ~make_pad_mask(lengths, xs, length_dim)
|
296 |
+
|
297 |
+
|
298 |
+
def get_mask_from_lengths(lengths):
|
299 |
+
max_len = torch.max(lengths).item()
|
300 |
+
ids = torch.arange(0, max_len).to(lengths.device)
|
301 |
+
mask = (ids < lengths.unsqueeze(1)).bool()
|
302 |
+
return mask
|
303 |
+
|
304 |
+
|
305 |
+
def group_hidden_by_segs(h, seg_ids, max_len):
|
306 |
+
"""
|
307 |
+
|
308 |
+
:param h: [B, T, H]
|
309 |
+
:param seg_ids: [B, T]
|
310 |
+
:return: h_ph: [B, T_ph, H]
|
311 |
+
"""
|
312 |
+
B, T, H = h.shape
|
313 |
+
h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
|
314 |
+
all_ones = h.new_ones(h.shape[:2])
|
315 |
+
cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
|
316 |
+
h_gby_segs = h_gby_segs[:, 1:]
|
317 |
+
cnt_gby_segs = cnt_gby_segs[:, 1:]
|
318 |
+
h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
|
319 |
+
return h_gby_segs, cnt_gby_segs
|
320 |
+
|
321 |
+
def expand_by_repeat_times(source_encoding, lengths):
|
322 |
+
"""
|
323 |
+
source_encoding: [T, C]
|
324 |
+
lengths, list of int, [T,], how many times each token should repeat
|
325 |
+
return:
|
326 |
+
expanded_encoding: [T_expand, C]
|
327 |
+
"""
|
328 |
+
hid_dim = source_encoding.shape[1]
|
329 |
+
out2source = []
|
330 |
+
for i, length in enumerate(lengths):
|
331 |
+
out2source += [i for _ in range(length)]
|
332 |
+
out2source = torch.LongTensor(out2source).to(source_encoding.device)
|
333 |
+
out2source_ = out2source[:, None].repeat([1, hid_dim])
|
334 |
+
expanded_encoding = torch.gather(source_encoding, 0, out2source_) # [B, T, H]
|
335 |
+
return expanded_encoding
|
336 |
+
|
337 |
+
|
338 |
+
def expand_word2ph(word_encoding, ph2word):
|
339 |
+
word_encoding = F.pad(word_encoding,[0,0,1,0])
|
340 |
+
ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
|
341 |
+
out = torch.gather(word_encoding, 1, ph2word_) # [B, T, H]
|
342 |
+
return out
|
tts/modules/ar_dur/commons/transformer.py
ADDED
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
from torch.nn import Parameter, Linear
|
19 |
+
from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
|
20 |
+
from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 3000
|
24 |
+
DEFAULT_MAX_TARGET_POSITIONS = 3000
|
25 |
+
|
26 |
+
|
27 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
28 |
+
"""This module produces sinusoidal positional embeddings of any length.
|
29 |
+
|
30 |
+
Padding symbols are ignored.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
34 |
+
super().__init__()
|
35 |
+
self.embedding_dim = embedding_dim
|
36 |
+
self.padding_idx = padding_idx
|
37 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
38 |
+
init_size,
|
39 |
+
embedding_dim,
|
40 |
+
padding_idx,
|
41 |
+
)
|
42 |
+
self.register_buffer('_float_tensor', torch.FloatTensor(1))
|
43 |
+
|
44 |
+
@staticmethod
|
45 |
+
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
|
46 |
+
"""Build sinusoidal embeddings.
|
47 |
+
|
48 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
49 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
50 |
+
"""
|
51 |
+
half_dim = embedding_dim // 2
|
52 |
+
emb = math.log(10000) / (half_dim - 1)
|
53 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
54 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
55 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
56 |
+
if embedding_dim % 2 == 1:
|
57 |
+
# zero pad
|
58 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
59 |
+
if padding_idx is not None:
|
60 |
+
emb[padding_idx, :] = 0
|
61 |
+
return emb
|
62 |
+
|
63 |
+
def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
|
64 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
65 |
+
bsz, seq_len = input.shape[:2]
|
66 |
+
max_pos = self.padding_idx + 1 + seq_len
|
67 |
+
if self.weights is None or max_pos > self.weights.size(0):
|
68 |
+
# recompute/expand embeddings if needed
|
69 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
70 |
+
max_pos,
|
71 |
+
self.embedding_dim,
|
72 |
+
self.padding_idx,
|
73 |
+
)
|
74 |
+
self.weights = self.weights.to(self._float_tensor)
|
75 |
+
|
76 |
+
if incremental_state is not None:
|
77 |
+
# positions is the same for every token when decoding a single step
|
78 |
+
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
|
79 |
+
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
|
80 |
+
|
81 |
+
positions = make_positions(input, self.padding_idx) if positions is None else positions
|
82 |
+
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
|
83 |
+
|
84 |
+
def max_positions(self):
|
85 |
+
"""Maximum number of supported positions."""
|
86 |
+
return int(1e5) # an arbitrary large number
|
87 |
+
|
88 |
+
|
89 |
+
class TransformerFFNLayer(nn.Module):
|
90 |
+
def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu', bias=True):
|
91 |
+
super().__init__()
|
92 |
+
self.kernel_size = kernel_size
|
93 |
+
self.dropout = dropout
|
94 |
+
self.act = act
|
95 |
+
if padding == 'SAME':
|
96 |
+
self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size,
|
97 |
+
padding=kernel_size // 2, bias=bias)
|
98 |
+
elif padding == 'LEFT':
|
99 |
+
self.ffn_1 = nn.Sequential(
|
100 |
+
nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
|
101 |
+
nn.Conv1d(hidden_size, filter_size, kernel_size, bias=bias)
|
102 |
+
)
|
103 |
+
self.ffn_2 = Linear(filter_size, hidden_size, bias=bias)
|
104 |
+
|
105 |
+
def forward(self, x, incremental_state=None):
|
106 |
+
# x: T x B x C
|
107 |
+
if incremental_state is not None:
|
108 |
+
saved_state = self._get_input_buffer(incremental_state)
|
109 |
+
if 'prev_input' in saved_state:
|
110 |
+
prev_input = saved_state['prev_input']
|
111 |
+
x = torch.cat((prev_input, x), dim=0)
|
112 |
+
x = x[-self.kernel_size:]
|
113 |
+
saved_state['prev_input'] = x
|
114 |
+
self._set_input_buffer(incremental_state, saved_state)
|
115 |
+
|
116 |
+
x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
|
117 |
+
x = x * self.kernel_size ** -0.5
|
118 |
+
|
119 |
+
if incremental_state is not None:
|
120 |
+
x = x[-1:]
|
121 |
+
if self.act == 'gelu':
|
122 |
+
x = F.gelu(x)
|
123 |
+
if self.act == 'relu':
|
124 |
+
x = F.relu(x)
|
125 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
126 |
+
x = self.ffn_2(x)
|
127 |
+
return x
|
128 |
+
|
129 |
+
def _get_input_buffer(self, incremental_state):
|
130 |
+
return get_incremental_state(
|
131 |
+
self,
|
132 |
+
incremental_state,
|
133 |
+
'f',
|
134 |
+
) or {}
|
135 |
+
|
136 |
+
def _set_input_buffer(self, incremental_state, buffer):
|
137 |
+
set_incremental_state(
|
138 |
+
self,
|
139 |
+
incremental_state,
|
140 |
+
'f',
|
141 |
+
buffer,
|
142 |
+
)
|
143 |
+
|
144 |
+
def clear_buffer(self, incremental_state):
|
145 |
+
if incremental_state is not None:
|
146 |
+
saved_state = self._get_input_buffer(incremental_state)
|
147 |
+
if 'prev_input' in saved_state:
|
148 |
+
del saved_state['prev_input']
|
149 |
+
self._set_input_buffer(incremental_state, saved_state)
|
150 |
+
|
151 |
+
|
152 |
+
class MultiheadAttention(nn.Module):
|
153 |
+
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
|
154 |
+
add_bias_kv=False, add_zero_attn=False, self_attention=False,
|
155 |
+
encoder_decoder_attention=False):
|
156 |
+
super().__init__()
|
157 |
+
self.embed_dim = embed_dim
|
158 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
159 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
160 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
161 |
+
|
162 |
+
self.num_heads = num_heads
|
163 |
+
self.dropout = dropout
|
164 |
+
self.head_dim = embed_dim // num_heads
|
165 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
166 |
+
self.scaling = self.head_dim ** -0.5
|
167 |
+
|
168 |
+
self.self_attention = self_attention
|
169 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
170 |
+
|
171 |
+
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
|
172 |
+
'value to be of the same size'
|
173 |
+
|
174 |
+
if self.qkv_same_dim:
|
175 |
+
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
|
176 |
+
else:
|
177 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
178 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
179 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
180 |
+
|
181 |
+
if bias:
|
182 |
+
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
|
183 |
+
else:
|
184 |
+
self.register_parameter('in_proj_bias', None)
|
185 |
+
|
186 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
187 |
+
|
188 |
+
if add_bias_kv:
|
189 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
190 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
191 |
+
else:
|
192 |
+
self.bias_k = self.bias_v = None
|
193 |
+
|
194 |
+
self.add_zero_attn = add_zero_attn
|
195 |
+
|
196 |
+
self.reset_parameters()
|
197 |
+
|
198 |
+
self.enable_torch_version = False
|
199 |
+
self.last_attn_probs = None
|
200 |
+
|
201 |
+
def reset_parameters(self):
|
202 |
+
if self.qkv_same_dim:
|
203 |
+
nn.init.xavier_uniform_(self.in_proj_weight)
|
204 |
+
else:
|
205 |
+
nn.init.xavier_uniform_(self.k_proj_weight)
|
206 |
+
nn.init.xavier_uniform_(self.v_proj_weight)
|
207 |
+
nn.init.xavier_uniform_(self.q_proj_weight)
|
208 |
+
|
209 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
210 |
+
if self.in_proj_bias is not None:
|
211 |
+
nn.init.constant_(self.in_proj_bias, 0.)
|
212 |
+
nn.init.constant_(self.out_proj.bias, 0.)
|
213 |
+
if self.bias_k is not None:
|
214 |
+
nn.init.xavier_normal_(self.bias_k)
|
215 |
+
if self.bias_v is not None:
|
216 |
+
nn.init.xavier_normal_(self.bias_v)
|
217 |
+
|
218 |
+
def forward(
|
219 |
+
self,
|
220 |
+
query, key, value,
|
221 |
+
key_padding_mask=None,
|
222 |
+
incremental_state=None,
|
223 |
+
need_weights=True,
|
224 |
+
static_kv=False,
|
225 |
+
attn_mask=None,
|
226 |
+
before_softmax=False,
|
227 |
+
need_head_weights=False,
|
228 |
+
enc_dec_attn_constraint_mask=None,
|
229 |
+
reset_attn_weight=None
|
230 |
+
):
|
231 |
+
"""Input shape: Time x Batch x Channel
|
232 |
+
|
233 |
+
Args:
|
234 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
235 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
236 |
+
padding elements are indicated by 1s.
|
237 |
+
need_weights (bool, optional): return the attention weights,
|
238 |
+
averaged over heads (default: False).
|
239 |
+
attn_mask (ByteTensor, optional): typically used to
|
240 |
+
implement causal attention, where the mask prevents the
|
241 |
+
attention from looking forward in time (default: None).
|
242 |
+
before_softmax (bool, optional): return the raw attention
|
243 |
+
weights and values before the attention softmax.
|
244 |
+
need_head_weights (bool, optional): return the attention
|
245 |
+
weights for each head. Implies *need_weights*. Default:
|
246 |
+
return the average attention weights over all heads.
|
247 |
+
"""
|
248 |
+
if need_head_weights:
|
249 |
+
need_weights = True
|
250 |
+
|
251 |
+
tgt_len, bsz, embed_dim = query.size()
|
252 |
+
assert embed_dim == self.embed_dim
|
253 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
254 |
+
|
255 |
+
if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
|
256 |
+
if self.qkv_same_dim:
|
257 |
+
return F.multi_head_attention_forward(query, key, value,
|
258 |
+
self.embed_dim, self.num_heads,
|
259 |
+
self.in_proj_weight,
|
260 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
261 |
+
self.add_zero_attn, self.dropout,
|
262 |
+
self.out_proj.weight, self.out_proj.bias,
|
263 |
+
self.training, key_padding_mask, need_weights,
|
264 |
+
attn_mask)
|
265 |
+
else:
|
266 |
+
return F.multi_head_attention_forward(query, key, value,
|
267 |
+
self.embed_dim, self.num_heads,
|
268 |
+
torch.empty([0]),
|
269 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
270 |
+
self.add_zero_attn, self.dropout,
|
271 |
+
self.out_proj.weight, self.out_proj.bias,
|
272 |
+
self.training, key_padding_mask, need_weights,
|
273 |
+
attn_mask, use_separate_proj_weight=True,
|
274 |
+
q_proj_weight=self.q_proj_weight,
|
275 |
+
k_proj_weight=self.k_proj_weight,
|
276 |
+
v_proj_weight=self.v_proj_weight)
|
277 |
+
|
278 |
+
if incremental_state is not None:
|
279 |
+
saved_state = self._get_input_buffer(incremental_state)
|
280 |
+
if 'prev_key' in saved_state:
|
281 |
+
# previous time steps are cached - no need to recompute
|
282 |
+
# key and value if they are static
|
283 |
+
if static_kv:
|
284 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
285 |
+
key = value = None
|
286 |
+
else:
|
287 |
+
saved_state = None
|
288 |
+
|
289 |
+
if self.self_attention:
|
290 |
+
# self-attention
|
291 |
+
q, k, v = self.in_proj_qkv(query)
|
292 |
+
elif self.encoder_decoder_attention:
|
293 |
+
# encoder-decoder attention
|
294 |
+
q = self.in_proj_q(query)
|
295 |
+
if key is None:
|
296 |
+
assert value is None
|
297 |
+
k = v = None
|
298 |
+
else:
|
299 |
+
k = self.in_proj_k(key)
|
300 |
+
v = self.in_proj_v(key)
|
301 |
+
|
302 |
+
else:
|
303 |
+
q = self.in_proj_q(query)
|
304 |
+
k = self.in_proj_k(key)
|
305 |
+
v = self.in_proj_v(value)
|
306 |
+
q = q * self.scaling
|
307 |
+
|
308 |
+
if self.bias_k is not None:
|
309 |
+
assert self.bias_v is not None
|
310 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
311 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
312 |
+
if attn_mask is not None:
|
313 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
314 |
+
if key_padding_mask is not None:
|
315 |
+
key_padding_mask = torch.cat(
|
316 |
+
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
|
317 |
+
|
318 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
319 |
+
if k is not None:
|
320 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
321 |
+
if v is not None:
|
322 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
323 |
+
|
324 |
+
if saved_state is not None:
|
325 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
326 |
+
if 'prev_key' in saved_state:
|
327 |
+
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
|
328 |
+
if static_kv:
|
329 |
+
k = prev_key
|
330 |
+
else:
|
331 |
+
k = torch.cat((prev_key, k), dim=1)
|
332 |
+
if 'prev_value' in saved_state:
|
333 |
+
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
|
334 |
+
if static_kv:
|
335 |
+
v = prev_value
|
336 |
+
else:
|
337 |
+
v = torch.cat((prev_value, v), dim=1)
|
338 |
+
if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
|
339 |
+
prev_key_padding_mask = saved_state['prev_key_padding_mask']
|
340 |
+
if static_kv:
|
341 |
+
key_padding_mask = prev_key_padding_mask
|
342 |
+
else:
|
343 |
+
key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
|
344 |
+
|
345 |
+
saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
346 |
+
saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
347 |
+
saved_state['prev_key_padding_mask'] = key_padding_mask
|
348 |
+
|
349 |
+
self._set_input_buffer(incremental_state, saved_state)
|
350 |
+
|
351 |
+
src_len = k.size(1)
|
352 |
+
|
353 |
+
# This is part of a workaround to get around fork/join parallelism
|
354 |
+
# not supporting Optional types.
|
355 |
+
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
|
356 |
+
key_padding_mask = None
|
357 |
+
|
358 |
+
if key_padding_mask is not None:
|
359 |
+
assert key_padding_mask.size(0) == bsz
|
360 |
+
assert key_padding_mask.size(1) == src_len
|
361 |
+
|
362 |
+
if self.add_zero_attn:
|
363 |
+
src_len += 1
|
364 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
365 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
366 |
+
if attn_mask is not None:
|
367 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
368 |
+
if key_padding_mask is not None:
|
369 |
+
key_padding_mask = torch.cat(
|
370 |
+
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
|
371 |
+
|
372 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
373 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
374 |
+
|
375 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
376 |
+
|
377 |
+
if attn_mask is not None:
|
378 |
+
if len(attn_mask.shape) == 2:
|
379 |
+
attn_mask = attn_mask.unsqueeze(0)
|
380 |
+
elif len(attn_mask.shape) == 3:
|
381 |
+
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
|
382 |
+
bsz * self.num_heads, tgt_len, src_len)
|
383 |
+
attn_weights = attn_weights + attn_mask
|
384 |
+
|
385 |
+
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
|
386 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
387 |
+
attn_weights = attn_weights.masked_fill(
|
388 |
+
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
|
389 |
+
-1e8,
|
390 |
+
)
|
391 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
392 |
+
|
393 |
+
if key_padding_mask is not None:
|
394 |
+
# don't attend to padding symbols
|
395 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
396 |
+
attn_weights = attn_weights.masked_fill(
|
397 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
398 |
+
-1e8,
|
399 |
+
)
|
400 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
401 |
+
|
402 |
+
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
403 |
+
|
404 |
+
if before_softmax:
|
405 |
+
return attn_weights, v
|
406 |
+
|
407 |
+
attn_weights_float = softmax(attn_weights, dim=-1)
|
408 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
409 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
410 |
+
|
411 |
+
if reset_attn_weight is not None:
|
412 |
+
if reset_attn_weight:
|
413 |
+
self.last_attn_probs = attn_probs.detach()
|
414 |
+
else:
|
415 |
+
assert self.last_attn_probs is not None
|
416 |
+
attn_probs = self.last_attn_probs
|
417 |
+
attn = torch.bmm(attn_probs, v)
|
418 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
419 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
420 |
+
attn = self.out_proj(attn)
|
421 |
+
|
422 |
+
if need_weights:
|
423 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
424 |
+
if not need_head_weights:
|
425 |
+
# average attention weights over heads
|
426 |
+
attn_weights = attn_weights.mean(dim=0)
|
427 |
+
else:
|
428 |
+
attn_weights = None
|
429 |
+
|
430 |
+
return attn, (attn_weights, attn_logits)
|
431 |
+
|
432 |
+
def in_proj_qkv(self, query):
|
433 |
+
return self._in_proj(query).chunk(3, dim=-1)
|
434 |
+
|
435 |
+
def in_proj_q(self, query):
|
436 |
+
if self.qkv_same_dim:
|
437 |
+
return self._in_proj(query, end=self.embed_dim)
|
438 |
+
else:
|
439 |
+
bias = self.in_proj_bias
|
440 |
+
if bias is not None:
|
441 |
+
bias = bias[:self.embed_dim]
|
442 |
+
return F.linear(query, self.q_proj_weight, bias)
|
443 |
+
|
444 |
+
def in_proj_k(self, key):
|
445 |
+
if self.qkv_same_dim:
|
446 |
+
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
|
447 |
+
else:
|
448 |
+
weight = self.k_proj_weight
|
449 |
+
bias = self.in_proj_bias
|
450 |
+
if bias is not None:
|
451 |
+
bias = bias[self.embed_dim:2 * self.embed_dim]
|
452 |
+
return F.linear(key, weight, bias)
|
453 |
+
|
454 |
+
def in_proj_v(self, value):
|
455 |
+
if self.qkv_same_dim:
|
456 |
+
return self._in_proj(value, start=2 * self.embed_dim)
|
457 |
+
else:
|
458 |
+
weight = self.v_proj_weight
|
459 |
+
bias = self.in_proj_bias
|
460 |
+
if bias is not None:
|
461 |
+
bias = bias[2 * self.embed_dim:]
|
462 |
+
return F.linear(value, weight, bias)
|
463 |
+
|
464 |
+
def _in_proj(self, input, start=0, end=None):
|
465 |
+
weight = self.in_proj_weight
|
466 |
+
bias = self.in_proj_bias
|
467 |
+
weight = weight[start:end, :]
|
468 |
+
if bias is not None:
|
469 |
+
bias = bias[start:end]
|
470 |
+
return F.linear(input, weight, bias)
|
471 |
+
|
472 |
+
def _get_input_buffer(self, incremental_state):
|
473 |
+
return get_incremental_state(
|
474 |
+
self,
|
475 |
+
incremental_state,
|
476 |
+
'attn_state',
|
477 |
+
) or {}
|
478 |
+
|
479 |
+
def _set_input_buffer(self, incremental_state, buffer):
|
480 |
+
set_incremental_state(
|
481 |
+
self,
|
482 |
+
incremental_state,
|
483 |
+
'attn_state',
|
484 |
+
buffer,
|
485 |
+
)
|
486 |
+
|
487 |
+
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
|
488 |
+
return attn_weights
|
489 |
+
|
490 |
+
def clear_buffer(self, incremental_state=None):
|
491 |
+
if incremental_state is not None:
|
492 |
+
saved_state = self._get_input_buffer(incremental_state)
|
493 |
+
if 'prev_key' in saved_state:
|
494 |
+
del saved_state['prev_key']
|
495 |
+
if 'prev_value' in saved_state:
|
496 |
+
del saved_state['prev_value']
|
497 |
+
self._set_input_buffer(incremental_state, saved_state)
|
498 |
+
|
499 |
+
|
500 |
+
class EncSALayer(nn.Module):
|
501 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
|
502 |
+
relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
|
503 |
+
ffn_hidden_size=1024):
|
504 |
+
super().__init__()
|
505 |
+
self.c = c
|
506 |
+
self.dropout = dropout
|
507 |
+
self.num_heads = num_heads
|
508 |
+
if num_heads > 0:
|
509 |
+
self.layer_norm1 = LayerNorm(c)
|
510 |
+
self.self_attn = MultiheadAttention(
|
511 |
+
self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
|
512 |
+
self.layer_norm2 = LayerNorm(c)
|
513 |
+
self.ffn = TransformerFFNLayer(
|
514 |
+
c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
|
515 |
+
|
516 |
+
def forward(self, x, encoder_padding_mask=None, **kwargs):
|
517 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
518 |
+
if layer_norm_training is not None:
|
519 |
+
self.layer_norm1.training = layer_norm_training
|
520 |
+
self.layer_norm2.training = layer_norm_training
|
521 |
+
if self.num_heads > 0:
|
522 |
+
residual = x
|
523 |
+
x = self.layer_norm1(x)
|
524 |
+
x, _, = self.self_attn(
|
525 |
+
query=x,
|
526 |
+
key=x,
|
527 |
+
value=x,
|
528 |
+
key_padding_mask=encoder_padding_mask
|
529 |
+
)
|
530 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
531 |
+
x = residual + x
|
532 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
533 |
+
|
534 |
+
residual = x
|
535 |
+
x = self.layer_norm2(x)
|
536 |
+
x = self.ffn(x)
|
537 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
538 |
+
x = residual + x
|
539 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
540 |
+
return x
|
541 |
+
|
542 |
+
|
543 |
+
class DecSALayer(nn.Module):
|
544 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
|
545 |
+
kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
|
546 |
+
super().__init__()
|
547 |
+
self.c = c
|
548 |
+
self.dropout = dropout
|
549 |
+
self.layer_norm1 = LayerNorm(c)
|
550 |
+
self.self_attn = MultiheadAttention(
|
551 |
+
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
|
552 |
+
)
|
553 |
+
self.layer_norm2 = LayerNorm(c)
|
554 |
+
self.encoder_attn = MultiheadAttention(
|
555 |
+
c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
|
556 |
+
)
|
557 |
+
self.layer_norm3 = LayerNorm(c)
|
558 |
+
self.ffn = TransformerFFNLayer(
|
559 |
+
c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
|
560 |
+
self.post_ln = post_ln
|
561 |
+
|
562 |
+
def forward(
|
563 |
+
self,
|
564 |
+
x,
|
565 |
+
encoder_out=None,
|
566 |
+
encoder_padding_mask=None,
|
567 |
+
incremental_state=None,
|
568 |
+
self_attn_mask=None,
|
569 |
+
self_attn_padding_mask=None,
|
570 |
+
attn_out=None,
|
571 |
+
reset_attn_weight=None,
|
572 |
+
**kwargs,
|
573 |
+
):
|
574 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
575 |
+
if layer_norm_training is not None:
|
576 |
+
self.layer_norm1.training = layer_norm_training
|
577 |
+
self.layer_norm2.training = layer_norm_training
|
578 |
+
self.layer_norm3.training = layer_norm_training
|
579 |
+
residual = x
|
580 |
+
if not self.post_ln:
|
581 |
+
x = self.layer_norm1(x)
|
582 |
+
x, _ = self.self_attn(
|
583 |
+
query=x,
|
584 |
+
key=x,
|
585 |
+
value=x,
|
586 |
+
key_padding_mask=self_attn_padding_mask,
|
587 |
+
incremental_state=incremental_state,
|
588 |
+
attn_mask=self_attn_mask
|
589 |
+
)
|
590 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
591 |
+
x = residual + x
|
592 |
+
if self.post_ln:
|
593 |
+
x = self.layer_norm1(x)
|
594 |
+
|
595 |
+
attn_logits = None
|
596 |
+
if encoder_out is not None or attn_out is not None:
|
597 |
+
residual = x
|
598 |
+
if not self.post_ln:
|
599 |
+
x = self.layer_norm2(x)
|
600 |
+
if encoder_out is not None:
|
601 |
+
x, attn = self.encoder_attn(
|
602 |
+
query=x,
|
603 |
+
key=encoder_out,
|
604 |
+
value=encoder_out,
|
605 |
+
key_padding_mask=encoder_padding_mask,
|
606 |
+
incremental_state=incremental_state,
|
607 |
+
static_kv=True,
|
608 |
+
enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
|
609 |
+
'enc_dec_attn_constraint_mask'),
|
610 |
+
reset_attn_weight=reset_attn_weight
|
611 |
+
)
|
612 |
+
attn_logits = attn[1]
|
613 |
+
elif attn_out is not None:
|
614 |
+
x = self.encoder_attn.in_proj_v(attn_out)
|
615 |
+
if encoder_out is not None or attn_out is not None:
|
616 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
617 |
+
x = residual + x
|
618 |
+
if self.post_ln:
|
619 |
+
x = self.layer_norm2(x)
|
620 |
+
|
621 |
+
residual = x
|
622 |
+
if not self.post_ln:
|
623 |
+
x = self.layer_norm3(x)
|
624 |
+
x = self.ffn(x, incremental_state=incremental_state)
|
625 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
626 |
+
x = residual + x
|
627 |
+
if self.post_ln:
|
628 |
+
x = self.layer_norm3(x)
|
629 |
+
return x, attn_logits
|
630 |
+
|
631 |
+
def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
|
632 |
+
self.encoder_attn.clear_buffer(incremental_state)
|
633 |
+
self.ffn.clear_buffer(incremental_state)
|
634 |
+
|
635 |
+
def set_buffer(self, name, tensor, incremental_state):
|
636 |
+
return set_incremental_state(self, incremental_state, name, tensor)
|
637 |
+
|
638 |
+
|
639 |
+
class TransformerEncoderLayer(nn.Module):
|
640 |
+
def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
|
641 |
+
super().__init__()
|
642 |
+
self.hidden_size = hidden_size
|
643 |
+
self.dropout = dropout
|
644 |
+
self.num_heads = num_heads
|
645 |
+
self.op = EncSALayer(
|
646 |
+
hidden_size, num_heads, dropout=dropout,
|
647 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
648 |
+
kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size)
|
649 |
+
|
650 |
+
def forward(self, x, **kwargs):
|
651 |
+
return self.op(x, **kwargs)
|
652 |
+
|
653 |
+
|
654 |
+
class TransformerDecoderLayer(nn.Module):
|
655 |
+
def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
|
656 |
+
super().__init__()
|
657 |
+
self.hidden_size = hidden_size
|
658 |
+
self.dropout = dropout
|
659 |
+
self.num_heads = num_heads
|
660 |
+
self.op = DecSALayer(
|
661 |
+
hidden_size, num_heads, dropout=dropout,
|
662 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
663 |
+
kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
|
664 |
+
post_ln=post_ln)
|
665 |
+
|
666 |
+
def forward(self, x, **kwargs):
|
667 |
+
return self.op(x, **kwargs)
|
668 |
+
|
669 |
+
def clear_buffer(self, *args):
|
670 |
+
return self.op.clear_buffer(*args)
|
671 |
+
|
672 |
+
def set_buffer(self, *args):
|
673 |
+
return self.op.set_buffer(*args)
|
674 |
+
|
675 |
+
|
676 |
+
class FFTBlocks(nn.Module):
|
677 |
+
def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
|
678 |
+
num_heads=2, use_pos_embed=True, use_last_norm=True,
|
679 |
+
use_pos_embed_alpha=True, ffn_hidden_size=1024):
|
680 |
+
super().__init__()
|
681 |
+
self.num_layers = num_layers
|
682 |
+
embed_dim = self.hidden_size = hidden_size
|
683 |
+
self.dropout = dropout
|
684 |
+
self.use_pos_embed = use_pos_embed
|
685 |
+
self.use_last_norm = use_last_norm
|
686 |
+
if use_pos_embed:
|
687 |
+
self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
|
688 |
+
self.padding_idx = 0
|
689 |
+
self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
|
690 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
691 |
+
embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
|
692 |
+
)
|
693 |
+
|
694 |
+
self.layers = nn.ModuleList([])
|
695 |
+
self.layers.extend([
|
696 |
+
TransformerEncoderLayer(self.hidden_size, self.dropout,
|
697 |
+
kernel_size=ffn_kernel_size, num_heads=num_heads,
|
698 |
+
ffn_hidden_size=ffn_hidden_size)
|
699 |
+
for _ in range(self.num_layers)
|
700 |
+
])
|
701 |
+
if self.use_last_norm:
|
702 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
703 |
+
else:
|
704 |
+
self.layer_norm = None
|
705 |
+
|
706 |
+
def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
|
707 |
+
"""
|
708 |
+
:param x: [B, T, C]
|
709 |
+
:param padding_mask: [B, T]
|
710 |
+
:return: [B, T, C] or [L, B, T, C]
|
711 |
+
"""
|
712 |
+
padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
|
713 |
+
nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
|
714 |
+
if self.use_pos_embed:
|
715 |
+
positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
|
716 |
+
x = x + positions
|
717 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
718 |
+
# B x T x C -> T x B x C
|
719 |
+
x = x.transpose(0, 1) * nonpadding_mask_TB
|
720 |
+
hiddens = []
|
721 |
+
for layer in self.layers:
|
722 |
+
x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
|
723 |
+
hiddens.append(x)
|
724 |
+
if self.use_last_norm:
|
725 |
+
x = self.layer_norm(x) * nonpadding_mask_TB
|
726 |
+
if return_hiddens:
|
727 |
+
x = torch.stack(hiddens, 0) # [L, T, B, C]
|
728 |
+
x = x.transpose(1, 2) # [L, B, T, C]
|
729 |
+
else:
|
730 |
+
x = x.transpose(0, 1) # [B, T, C]
|
731 |
+
return x
|
732 |
+
|
733 |
+
|
734 |
+
class FastSpeechEncoder(FFTBlocks):
|
735 |
+
def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
|
736 |
+
dropout=0.0, num_heads=2, ffn_hidden_size=1024):
|
737 |
+
super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
|
738 |
+
use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size)
|
739 |
+
self.embed_tokens = Embedding(dict_size, hidden_size, 0)
|
740 |
+
self.embed_scale = math.sqrt(hidden_size)
|
741 |
+
self.padding_idx = 0
|
742 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
743 |
+
hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
|
744 |
+
)
|
745 |
+
|
746 |
+
def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
|
747 |
+
"""
|
748 |
+
|
749 |
+
:param txt_tokens: [B, T]
|
750 |
+
:return: {
|
751 |
+
'encoder_out': [B x T x C]
|
752 |
+
}
|
753 |
+
"""
|
754 |
+
encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
|
755 |
+
x = self.forward_embedding(txt_tokens) + other_embeds # [B, T, H]
|
756 |
+
if self.num_layers > 0:
|
757 |
+
x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
|
758 |
+
return x
|
759 |
+
|
760 |
+
def forward_embedding(self, txt_tokens):
|
761 |
+
# embed tokens and positions
|
762 |
+
x = self.embed_scale * self.embed_tokens(txt_tokens)
|
763 |
+
if self.use_pos_embed:
|
764 |
+
positions = self.embed_positions(txt_tokens)
|
765 |
+
x = x + positions
|
766 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
767 |
+
return x
|
tts/modules/llm_dit/cfm.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2023 Alexander Tong
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# Copyright (c) [2023] [Alexander Tong]
|
24 |
+
# Copyright (c) [2025] [Ziyue Jiang]
|
25 |
+
# SPDX-License-Identifier: MIT
|
26 |
+
# This file has been modified by Ziyue Jiang on 2025/03/19
|
27 |
+
# Original file was released under MIT, with the full license text # available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
|
28 |
+
# This modified file is released under the same license.
|
29 |
+
|
30 |
+
import math
|
31 |
+
import torch
|
32 |
+
from typing import Union
|
33 |
+
from torch.distributions import LogisticNormal
|
34 |
+
|
35 |
+
|
36 |
+
class LogitNormalTrainingTimesteps:
|
37 |
+
def __init__(self, T=1000.0, loc=0.0, scale=1.0):
|
38 |
+
assert T > 0
|
39 |
+
self.T = T
|
40 |
+
self.dist = LogisticNormal(loc, scale)
|
41 |
+
|
42 |
+
def sample(self, size, device):
|
43 |
+
t = self.dist.sample(size)[..., 0].to(device)
|
44 |
+
return t
|
45 |
+
|
46 |
+
|
47 |
+
def pad_t_like_x(t, x):
|
48 |
+
"""Function to reshape the time vector t by the number of dimensions of x.
|
49 |
+
|
50 |
+
Parameters
|
51 |
+
----------
|
52 |
+
x : Tensor, shape (bs, *dim)
|
53 |
+
represents the source minibatch
|
54 |
+
t : FloatTensor, shape (bs)
|
55 |
+
|
56 |
+
Returns
|
57 |
+
-------
|
58 |
+
t : Tensor, shape (bs, number of x dimensions)
|
59 |
+
|
60 |
+
Example
|
61 |
+
-------
|
62 |
+
x: Tensor (bs, C, W, H)
|
63 |
+
t: Vector (bs)
|
64 |
+
pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
|
65 |
+
"""
|
66 |
+
if isinstance(t, (float, int)):
|
67 |
+
return t
|
68 |
+
return t.reshape(-1, *([1] * (x.dim() - 1)))
|
69 |
+
|
70 |
+
|
71 |
+
class ConditionalFlowMatcher:
|
72 |
+
"""Base class for conditional flow matching methods. This class implements the independent
|
73 |
+
conditional flow matching methods from [1] and serves as a parent class for all other flow
|
74 |
+
matching methods.
|
75 |
+
|
76 |
+
It implements:
|
77 |
+
- Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
|
78 |
+
- conditional flow matching ut(x1|x0) = x1 - x0
|
79 |
+
- score function $\nabla log p_t(x|x0, x1)$
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, sigma: Union[float, int] = 0.0):
|
83 |
+
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
|
84 |
+
|
85 |
+
Parameters
|
86 |
+
----------
|
87 |
+
sigma : Union[float, int]
|
88 |
+
"""
|
89 |
+
self.sigma = sigma
|
90 |
+
self.time_sampler = LogitNormalTrainingTimesteps()
|
91 |
+
|
92 |
+
def compute_mu_t(self, x0, x1, t):
|
93 |
+
"""
|
94 |
+
Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
|
95 |
+
|
96 |
+
Parameters
|
97 |
+
----------
|
98 |
+
x0 : Tensor, shape (bs, *dim)
|
99 |
+
represents the source minibatch
|
100 |
+
x1 : Tensor, shape (bs, *dim)
|
101 |
+
represents the target minibatch
|
102 |
+
t : FloatTensor, shape (bs)
|
103 |
+
|
104 |
+
Returns
|
105 |
+
-------
|
106 |
+
mean mu_t: t * x1 + (1 - t) * x0
|
107 |
+
|
108 |
+
References
|
109 |
+
----------
|
110 |
+
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
|
111 |
+
"""
|
112 |
+
t = pad_t_like_x(t, x0)
|
113 |
+
return t * x1 + (1 - t) * x0
|
114 |
+
|
115 |
+
def compute_sigma_t(self, t):
|
116 |
+
"""
|
117 |
+
Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
|
118 |
+
|
119 |
+
Parameters
|
120 |
+
----------
|
121 |
+
t : FloatTensor, shape (bs)
|
122 |
+
|
123 |
+
Returns
|
124 |
+
-------
|
125 |
+
standard deviation sigma
|
126 |
+
|
127 |
+
References
|
128 |
+
----------
|
129 |
+
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
|
130 |
+
"""
|
131 |
+
del t
|
132 |
+
return self.sigma
|
133 |
+
|
134 |
+
def sample_xt(self, x0, x1, t, epsilon):
|
135 |
+
"""
|
136 |
+
Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
|
137 |
+
|
138 |
+
Parameters
|
139 |
+
----------
|
140 |
+
x0 : Tensor, shape (bs, *dim)
|
141 |
+
represents the source minibatch
|
142 |
+
x1 : Tensor, shape (bs, *dim)
|
143 |
+
represents the target minibatch
|
144 |
+
t : FloatTensor, shape (bs)
|
145 |
+
epsilon : Tensor, shape (bs, *dim)
|
146 |
+
noise sample from N(0, 1)
|
147 |
+
|
148 |
+
Returns
|
149 |
+
-------
|
150 |
+
xt : Tensor, shape (bs, *dim)
|
151 |
+
|
152 |
+
References
|
153 |
+
----------
|
154 |
+
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
|
155 |
+
"""
|
156 |
+
mu_t = self.compute_mu_t(x0, x1, t)
|
157 |
+
sigma_t = self.compute_sigma_t(t)
|
158 |
+
sigma_t = pad_t_like_x(sigma_t, x0)
|
159 |
+
return mu_t + sigma_t * epsilon
|
160 |
+
|
161 |
+
def compute_conditional_flow(self, x0, x1, t, xt):
|
162 |
+
"""
|
163 |
+
Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
|
164 |
+
|
165 |
+
Parameters
|
166 |
+
----------
|
167 |
+
x0 : Tensor, shape (bs, *dim)
|
168 |
+
represents the source minibatch
|
169 |
+
x1 : Tensor, shape (bs, *dim)
|
170 |
+
represents the target minibatch
|
171 |
+
t : FloatTensor, shape (bs)
|
172 |
+
xt : Tensor, shape (bs, *dim)
|
173 |
+
represents the samples drawn from probability path pt
|
174 |
+
|
175 |
+
Returns
|
176 |
+
-------
|
177 |
+
ut : conditional vector field ut(x1|x0) = x1 - x0
|
178 |
+
|
179 |
+
References
|
180 |
+
----------
|
181 |
+
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
|
182 |
+
"""
|
183 |
+
del t, xt
|
184 |
+
return x1 - x0
|
185 |
+
|
186 |
+
def sample_noise_like(self, x):
|
187 |
+
return torch.randn_like(x)
|
188 |
+
|
189 |
+
def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
|
190 |
+
"""
|
191 |
+
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
|
192 |
+
and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
|
193 |
+
|
194 |
+
Parameters
|
195 |
+
----------
|
196 |
+
x0 : Tensor, shape (bs, *dim)
|
197 |
+
represents the source minibatch
|
198 |
+
x1 : Tensor, shape (bs, *dim)
|
199 |
+
represents the target minibatch
|
200 |
+
(optionally) t : Tensor, shape (bs)
|
201 |
+
represents the time levels
|
202 |
+
if None, drawn from uniform [0,1]
|
203 |
+
return_noise : bool
|
204 |
+
return the noise sample epsilon
|
205 |
+
|
206 |
+
|
207 |
+
Returns
|
208 |
+
-------
|
209 |
+
t : FloatTensor, shape (bs)
|
210 |
+
xt : Tensor, shape (bs, *dim)
|
211 |
+
represents the samples drawn from probability path pt
|
212 |
+
ut : conditional vector field ut(x1|x0) = x1 - x0
|
213 |
+
(optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon
|
214 |
+
|
215 |
+
References
|
216 |
+
----------
|
217 |
+
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
|
218 |
+
"""
|
219 |
+
if t is None:
|
220 |
+
# t = torch.rand(x0.shape[0]).type_as(x0)
|
221 |
+
t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)
|
222 |
+
|
223 |
+
assert len(t) == x0.shape[0], "t has to have batch size dimension"
|
224 |
+
|
225 |
+
eps = self.sample_noise_like(x0)
|
226 |
+
xt = self.sample_xt(x0, x1, t, eps)
|
227 |
+
ut = self.compute_conditional_flow(x0, x1, t, xt)
|
228 |
+
if return_noise:
|
229 |
+
return t, xt, ut, eps
|
230 |
+
else:
|
231 |
+
return t, xt, ut
|
232 |
+
|
233 |
+
def compute_lambda(self, t):
|
234 |
+
"""Compute the lambda function, see Eq.(23) [3].
|
235 |
+
|
236 |
+
Parameters
|
237 |
+
----------
|
238 |
+
t : FloatTensor, shape (bs)
|
239 |
+
|
240 |
+
Returns
|
241 |
+
-------
|
242 |
+
lambda : score weighting function
|
243 |
+
|
244 |
+
References
|
245 |
+
----------
|
246 |
+
[4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
|
247 |
+
"""
|
248 |
+
sigma_t = self.compute_sigma_t(t)
|
249 |
+
return 2 * sigma_t / (self.sigma**2 + 1e-8)
|
250 |
+
|
251 |
+
|
252 |
+
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
|
253 |
+
"""Albergo et al. 2023 trigonometric interpolants class. This class inherits the
|
254 |
+
ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in
|
255 |
+
order to compute [3]'s trigonometric interpolants.
|
256 |
+
|
257 |
+
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
|
258 |
+
"""
|
259 |
+
|
260 |
+
def compute_mu_t(self, x0, x1, t):
|
261 |
+
r"""Compute the mean of the probability path (Eq.5) from [3].
|
262 |
+
|
263 |
+
Parameters
|
264 |
+
----------
|
265 |
+
x0 : Tensor, shape (bs, *dim)
|
266 |
+
represents the source minibatch
|
267 |
+
x1 : Tensor, shape (bs, *dim)
|
268 |
+
represents the target minibatch
|
269 |
+
t : FloatTensor, shape (bs)
|
270 |
+
|
271 |
+
Returns
|
272 |
+
-------
|
273 |
+
mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1
|
274 |
+
|
275 |
+
References
|
276 |
+
----------
|
277 |
+
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
|
278 |
+
"""
|
279 |
+
t = pad_t_like_x(t, x0)
|
280 |
+
return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1
|
281 |
+
|
282 |
+
def compute_conditional_flow(self, x0, x1, t, xt):
|
283 |
+
r"""Compute the conditional vector field similar to [3].
|
284 |
+
|
285 |
+
ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
|
286 |
+
see Eq.(21) [3].
|
287 |
+
|
288 |
+
Parameters
|
289 |
+
----------
|
290 |
+
x0 : Tensor, shape (bs, *dim)
|
291 |
+
represents the source minibatch
|
292 |
+
x1 : Tensor, shape (bs, *dim)
|
293 |
+
represents the target minibatch
|
294 |
+
t : FloatTensor, shape (bs)
|
295 |
+
xt : Tensor, shape (bs, *dim)
|
296 |
+
represents the samples drawn from probability path pt
|
297 |
+
|
298 |
+
Returns
|
299 |
+
-------
|
300 |
+
ut : conditional vector field
|
301 |
+
ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(\pi*t/2) x0)
|
302 |
+
|
303 |
+
References
|
304 |
+
----------
|
305 |
+
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
|
306 |
+
"""
|
307 |
+
del xt
|
308 |
+
t = pad_t_like_x(t, x0)
|
309 |
+
return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0)
|
tts/modules/llm_dit/dit.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
|
19 |
+
from tts.modules.ar_dur.commons.layers import Embedding
|
20 |
+
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
|
21 |
+
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
|
22 |
+
from tts.modules.ar_dur.ar_dur_predictor import expand_states
|
23 |
+
from tts.modules.llm_dit.transformer import Transformer
|
24 |
+
from tts.modules.llm_dit.time_embedding import TimestepEmbedding
|
25 |
+
|
26 |
+
|
27 |
+
class Diffusion(nn.Module):
|
28 |
+
def __init__(self):
|
29 |
+
super().__init__()
|
30 |
+
# Hparams
|
31 |
+
# cond dim
|
32 |
+
self.local_cond_dim = 512
|
33 |
+
self.ctx_mask_dim = 16
|
34 |
+
self.in_channels = 32
|
35 |
+
self.out_channels = 32
|
36 |
+
# LLM
|
37 |
+
self.encoder_dim = 1024
|
38 |
+
self.encoder_n_layers = 24
|
39 |
+
self.encoder_n_heads = 16
|
40 |
+
self.max_seq_len = 16384
|
41 |
+
self.multiple_of = 256
|
42 |
+
|
43 |
+
self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
|
44 |
+
self.local_cond_project = nn.Linear(
|
45 |
+
self.out_channels + self.ctx_mask_dim, self.local_cond_dim)
|
46 |
+
|
47 |
+
self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)
|
48 |
+
|
49 |
+
self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
|
50 |
+
self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
|
51 |
+
self.postnet = nn.Linear(self.encoder_dim, self.out_channels)
|
52 |
+
|
53 |
+
self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
|
54 |
+
# The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
|
55 |
+
# which is licensed under the MIT License.
|
56 |
+
self.f5_time_embed = TimestepEmbedding(self.encoder_dim)
|
57 |
+
|
58 |
+
# text encoder
|
59 |
+
self.ph_encoder = RelTransformerEncoder(
|
60 |
+
302, self.encoder_dim, self.encoder_dim,
|
61 |
+
self.encoder_dim * 2, 4, 6,
|
62 |
+
3, 0.0, prenet=True, pre_ln=True)
|
63 |
+
self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
|
64 |
+
self.ph_pos_embed = PosEmb(self.encoder_dim)
|
65 |
+
self.ling_pre_net = torch.nn.Sequential(*[
|
66 |
+
torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
|
67 |
+
for i, s in enumerate([2, 2])
|
68 |
+
])
|
69 |
+
|
70 |
+
def forward(self, inputs, sigmas=None, x_noisy=None):
|
71 |
+
ctx_mask = inputs['ctx_mask']
|
72 |
+
ctx_feature = inputs['lat_ctx'] * ctx_mask
|
73 |
+
|
74 |
+
""" local conditioning (prompt_latent + spk_embed) """
|
75 |
+
ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
|
76 |
+
# ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
|
77 |
+
local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
|
78 |
+
local_cond = self.local_cond_project(local_cond)
|
79 |
+
|
80 |
+
""" diffusion target latent """
|
81 |
+
x = inputs['lat']
|
82 |
+
|
83 |
+
# Here, x is x1 in CFM
|
84 |
+
x0 = torch.randn_like(x)
|
85 |
+
t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
|
86 |
+
|
87 |
+
# define noisy_input and target
|
88 |
+
t = t.bfloat16()
|
89 |
+
x_noisy = (xt * (1 - ctx_mask)).bfloat16()
|
90 |
+
target = ut
|
91 |
+
|
92 |
+
# concat condition.
|
93 |
+
x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
|
94 |
+
x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
|
95 |
+
x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
|
96 |
+
encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
|
97 |
+
pred = self.postnet(encoder_out)
|
98 |
+
|
99 |
+
return pred, target
|
100 |
+
|
101 |
+
def forward_ling_encoder(self, txt_tokens, tone_tokens):
|
102 |
+
ph_tokens = txt_tokens
|
103 |
+
ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1]
|
104 |
+
|
105 |
+
# enc_ph
|
106 |
+
ph_enc_oembed = self.tone_embed(tone_tokens)
|
107 |
+
ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
|
108 |
+
torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
|
109 |
+
ph_enc_oembed = ph_enc_oembed
|
110 |
+
ph_enc_oembed = ph_enc_oembed * ph_nonpadding
|
111 |
+
x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
|
112 |
+
return x_ling
|
113 |
+
|
114 |
+
def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0,1.0]):
|
115 |
+
""" When we use torchdiffeq, we need to include the CFG process inside _forward() """
|
116 |
+
x = x * (1 - ctx_mask)
|
117 |
+
x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
|
118 |
+
pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
|
119 |
+
pred = self.postnet(pred_v)
|
120 |
+
|
121 |
+
""" Perform multi-cond CFG """
|
122 |
+
cond_spk_txt, cond_txt, uncond = pred.chunk(3)
|
123 |
+
pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
|
124 |
+
return pred
|
125 |
+
|
126 |
+
@torch.no_grad()
|
127 |
+
def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs):
|
128 |
+
# txt embedding
|
129 |
+
x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
|
130 |
+
x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)
|
131 |
+
|
132 |
+
# speaker embedding
|
133 |
+
ctx_feature = inputs['lat_ctx']
|
134 |
+
ctx_feature[1:, :, :] = 0 # prefix spk cfg
|
135 |
+
ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])
|
136 |
+
|
137 |
+
# local conditioning.
|
138 |
+
local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
|
139 |
+
local_cond = self.local_cond_project(local_cond)
|
140 |
+
|
141 |
+
''' Euler ODE solver '''
|
142 |
+
bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
|
143 |
+
# Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
|
144 |
+
# which is licensed under the MIT License.
|
145 |
+
sway_sampling_coef = -1.0
|
146 |
+
t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
|
147 |
+
if sway_sampling_coef is not None:
|
148 |
+
t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)
|
149 |
+
|
150 |
+
# AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
|
151 |
+
def amo_sampling(z_t, t, t_next, v):
|
152 |
+
# Upcast to avoid precision issues when computing prev_sample
|
153 |
+
z_t = z_t.to(torch.float32)
|
154 |
+
|
155 |
+
# Constant definition in Algorithm 1
|
156 |
+
s = t_next
|
157 |
+
c = 3
|
158 |
+
|
159 |
+
# Line 7 in Algorithm 1
|
160 |
+
o = min(t_next + c * (t_next - t), 1)
|
161 |
+
pred_z_o = z_t + (o - t) * v
|
162 |
+
|
163 |
+
# Line 11 in Algorithm 1
|
164 |
+
a = s / o
|
165 |
+
b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
|
166 |
+
noise_i = torch.randn(size=z_t.shape, device=z_t.device)
|
167 |
+
z_t_next = a * pred_z_o + b * noise_i
|
168 |
+
return z_t_next.to(v.dtype)
|
169 |
+
|
170 |
+
x = torch.randn([1, frm_len, self.out_channels], device=device)
|
171 |
+
for step_index in range(timesteps):
|
172 |
+
x = x.to(torch.float32)
|
173 |
+
sigma = t_schedule[step_index].to(x_ling.dtype)
|
174 |
+
sigma_next = t_schedule[step_index + 1]
|
175 |
+
model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
|
176 |
+
x = amo_sampling(x, sigma, sigma_next, model_out)
|
177 |
+
# Cast sample back to model compatible dtype
|
178 |
+
x = x.to(model_out.dtype)
|
179 |
+
|
180 |
+
return x
|
tts/modules/llm_dit/time_embedding.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
|
20 |
+
class SinusPositionEmbedding(nn.Module):
|
21 |
+
def __init__(self, dim):
|
22 |
+
super().__init__()
|
23 |
+
self.dim = dim
|
24 |
+
|
25 |
+
def forward(self, x, scale=1000):
|
26 |
+
device = x.device
|
27 |
+
half_dim = self.dim // 2
|
28 |
+
emb = math.log(10000) / (half_dim - 1)
|
29 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
30 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
31 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
32 |
+
return emb
|
33 |
+
|
34 |
+
class TimestepEmbedding(nn.Module):
|
35 |
+
def __init__(self, dim, freq_embed_dim=256):
|
36 |
+
super().__init__()
|
37 |
+
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
38 |
+
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
39 |
+
|
40 |
+
def forward(self, timestep): # noqa: F821
|
41 |
+
time_hidden = self.time_embed(timestep)
|
42 |
+
time_hidden = time_hidden.to(timestep.dtype)
|
43 |
+
time = self.time_mlp(time_hidden) # b d
|
44 |
+
return time
|
tts/modules/llm_dit/transformer.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Any, Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
|
23 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
24 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
25 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
26 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
27 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
28 |
+
return freqs_cis
|
29 |
+
|
30 |
+
|
31 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
32 |
+
ndim = x.ndim
|
33 |
+
assert 0 <= 1 < ndim
|
34 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
35 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
36 |
+
return freqs_cis.view(*shape)
|
37 |
+
|
38 |
+
|
39 |
+
def apply_rotary_emb(
|
40 |
+
xq: torch.Tensor,
|
41 |
+
xk: torch.Tensor,
|
42 |
+
freqs_cis: torch.Tensor,
|
43 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
44 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
45 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
46 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
47 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
48 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
49 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
50 |
+
|
51 |
+
|
52 |
+
class AdaLNZero(nn.Module):
|
53 |
+
def __init__(self, dim):
|
54 |
+
super().__init__()
|
55 |
+
self.silu = nn.SiLU()
|
56 |
+
self.linear = nn.Linear(dim, dim * 6)
|
57 |
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
58 |
+
|
59 |
+
def forward(self, x, emb=None):
|
60 |
+
emb = self.linear(self.silu(emb))
|
61 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
62 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
63 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
64 |
+
|
65 |
+
|
66 |
+
class AdaLNZero_Out(nn.Module):
|
67 |
+
def __init__(self, dim):
|
68 |
+
super().__init__()
|
69 |
+
self.silu = nn.SiLU()
|
70 |
+
self.linear = nn.Linear(dim, dim * 2)
|
71 |
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
72 |
+
|
73 |
+
def forward(self, x, emb):
|
74 |
+
emb = self.linear(self.silu(emb))
|
75 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
76 |
+
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
class Attention(nn.Module):
|
81 |
+
def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
|
82 |
+
super().__init__()
|
83 |
+
self.encoder_n_kv_heads = encoder_n_heads
|
84 |
+
model_parallel_size = 1
|
85 |
+
self.n_local_heads = encoder_n_heads // model_parallel_size
|
86 |
+
self.n_local_kv_heads = self.encoder_n_kv_heads // model_parallel_size
|
87 |
+
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
88 |
+
self.head_dim = encoder_dim // encoder_n_heads
|
89 |
+
|
90 |
+
self.wq = nn.Linear(
|
91 |
+
encoder_dim,
|
92 |
+
encoder_n_heads * self.head_dim,
|
93 |
+
)
|
94 |
+
self.wk = nn.Linear(
|
95 |
+
encoder_dim,
|
96 |
+
self.encoder_n_kv_heads * self.head_dim,
|
97 |
+
)
|
98 |
+
self.wv = nn.Linear(
|
99 |
+
encoder_dim,
|
100 |
+
self.encoder_n_kv_heads * self.head_dim,
|
101 |
+
)
|
102 |
+
self.wo = nn.Linear(
|
103 |
+
encoder_n_heads * self.head_dim,
|
104 |
+
encoder_dim,
|
105 |
+
)
|
106 |
+
|
107 |
+
def forward(
|
108 |
+
self,
|
109 |
+
x: torch.Tensor,
|
110 |
+
start_pos: int,
|
111 |
+
freqs_cis: torch.Tensor,
|
112 |
+
mask: Optional[torch.Tensor],
|
113 |
+
):
|
114 |
+
bsz, seqlen, _ = x.shape
|
115 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
116 |
+
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
117 |
+
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
118 |
+
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
119 |
+
|
120 |
+
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
121 |
+
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
122 |
+
keys = xk.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
123 |
+
values = xv.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
124 |
+
|
125 |
+
output = F.scaled_dot_product_attention(xq, keys, values, mask[:, None, None, :], is_causal=False)
|
126 |
+
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
127 |
+
return self.wo(output)
|
128 |
+
|
129 |
+
|
130 |
+
class FeedForward(nn.Module):
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
dim: int,
|
134 |
+
hidden_dim: int,
|
135 |
+
multiple_of: int,
|
136 |
+
ffn_dim_multiplier: Optional[float],
|
137 |
+
):
|
138 |
+
super().__init__()
|
139 |
+
if ffn_dim_multiplier is not None:
|
140 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
141 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
142 |
+
|
143 |
+
self.w1 = nn.Linear(
|
144 |
+
dim, hidden_dim
|
145 |
+
)
|
146 |
+
self.w2 = nn.Linear(
|
147 |
+
hidden_dim, dim
|
148 |
+
)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
return self.w2(F.silu(self.w1(x)))
|
152 |
+
|
153 |
+
|
154 |
+
class TransformerBlock(nn.Module):
|
155 |
+
def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
|
156 |
+
super().__init__()
|
157 |
+
self.encoder_n_heads = encoder_n_heads
|
158 |
+
self.encoder_dim = encoder_dim
|
159 |
+
self.head_dim = encoder_dim // encoder_n_heads
|
160 |
+
self.attention = Attention(encoder_dim, encoder_n_heads, max_seq_len)
|
161 |
+
self.feed_forward = FeedForward(
|
162 |
+
dim=encoder_dim,
|
163 |
+
hidden_dim=2 * encoder_dim,
|
164 |
+
multiple_of=256,
|
165 |
+
ffn_dim_multiplier=None,
|
166 |
+
)
|
167 |
+
self.attention_norm = AdaLNZero(encoder_dim)
|
168 |
+
self.ffn_norm = nn.LayerNorm(encoder_dim, elementwise_affine=False, eps=1e-6)
|
169 |
+
|
170 |
+
def forward(
|
171 |
+
self,
|
172 |
+
x: torch.Tensor,
|
173 |
+
t: torch.Tensor,
|
174 |
+
start_pos: int,
|
175 |
+
freqs_cis: torch.Tensor,
|
176 |
+
mask: Optional[torch.Tensor],
|
177 |
+
):
|
178 |
+
"""
|
179 |
+
Perform a forward pass through the TransformerBlock.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
x (torch.Tensor): Input tensor.
|
183 |
+
start_pos (int): Starting position for attention caching.
|
184 |
+
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
185 |
+
mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
torch.Tensor: Output tensor after applying attention and feedforward layers.
|
189 |
+
|
190 |
+
"""
|
191 |
+
# pre-norm & modulation for attention input
|
192 |
+
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=t)
|
193 |
+
|
194 |
+
# attention
|
195 |
+
attn_output = self.attention(norm, start_pos, freqs_cis, mask=mask)
|
196 |
+
|
197 |
+
# process attention output for input x
|
198 |
+
h = x + gate_msa.unsqueeze(1) * attn_output
|
199 |
+
|
200 |
+
norm = self.ffn_norm(h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
201 |
+
ff_output = self.feed_forward(norm)
|
202 |
+
out = h + gate_mlp.unsqueeze(1) * ff_output
|
203 |
+
|
204 |
+
return out
|
205 |
+
|
206 |
+
|
207 |
+
class Transformer(nn.Module):
|
208 |
+
def __init__(self, encoder_n_layers, encoder_dim, encoder_n_heads, max_seq_len):
|
209 |
+
super().__init__()
|
210 |
+
# Decoder
|
211 |
+
self.layers = torch.nn.ModuleList()
|
212 |
+
for _ in range(encoder_n_layers):
|
213 |
+
self.layers.append(TransformerBlock(encoder_dim, encoder_n_heads, max_seq_len))
|
214 |
+
|
215 |
+
self.norm = AdaLNZero_Out(encoder_dim)
|
216 |
+
self.out_proj = nn.Linear(encoder_dim, encoder_dim)
|
217 |
+
|
218 |
+
# Rope embedding
|
219 |
+
freqs_cis = precompute_freqs_cis(
|
220 |
+
encoder_dim // encoder_n_heads, max_seq_len
|
221 |
+
)
|
222 |
+
self.register_buffer("freqs_cis", torch.view_as_real(freqs_cis), persistent=False)
|
223 |
+
|
224 |
+
def forward(self, x, t, attn_mask, start_pos=0):
|
225 |
+
freqs_cis = torch.view_as_complex(self.freqs_cis.float())[start_pos: start_pos + x.size(1)]
|
226 |
+
for i, layer in enumerate(self.layers):
|
227 |
+
x = layer(x, t, start_pos, freqs_cis, attn_mask)
|
228 |
+
x = self.norm(x, t)
|
229 |
+
x = self.out_proj(x)
|
230 |
+
return x
|
tts/modules/wavvae/decoder/diag_gaussian.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
class DiagonalGaussianDistribution(object):
|
19 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
20 |
+
self.parameters = parameters
|
21 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
22 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
23 |
+
self.deterministic = deterministic
|
24 |
+
self.std = torch.exp(0.5 * self.logvar)
|
25 |
+
self.var = torch.exp(self.logvar)
|
26 |
+
if self.deterministic:
|
27 |
+
self.var = self.std = torch.zeros_like(
|
28 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
29 |
+
)
|
30 |
+
|
31 |
+
def sample(self, generator=None) -> torch.Tensor:
|
32 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
33 |
+
sample = torch.randn(
|
34 |
+
self.mean.shape,
|
35 |
+
generator=generator,
|
36 |
+
device=self.parameters.device,
|
37 |
+
dtype=self.parameters.dtype,
|
38 |
+
)
|
39 |
+
x = self.mean + self.std * sample
|
40 |
+
return x
|
41 |
+
|
42 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
43 |
+
if self.deterministic:
|
44 |
+
return torch.Tensor([0.0])
|
45 |
+
else:
|
46 |
+
if other is None:
|
47 |
+
return 0.5 * torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
|
48 |
+
else:
|
49 |
+
return 0.5 * (
|
50 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
51 |
+
+ self.var / other.var
|
52 |
+
- 1.0
|
53 |
+
- self.logvar
|
54 |
+
+ other.logvar
|
55 |
+
)
|
56 |
+
|
57 |
+
def nll(self, sample, dims) -> torch.Tensor:
|
58 |
+
if self.deterministic:
|
59 |
+
return torch.Tensor([0.0])
|
60 |
+
logtwopi = np.log(2.0 * np.pi)
|
61 |
+
return 0.5 * torch.sum(
|
62 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
63 |
+
dim=dims,
|
64 |
+
)
|
65 |
+
|
66 |
+
def mode(self) -> torch.Tensor:
|
67 |
+
return self.mean
|
tts/modules/wavvae/decoder/hifigan_modules.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch
|
18 |
+
import torch.utils.data
|
19 |
+
from librosa.filters import mel as librosa_mel_fn
|
20 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
21 |
+
from torch.nn import Conv1d
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
|
25 |
+
def init_weights(m, mean=0.0, std=0.01):
|
26 |
+
classname = m.__class__.__name__
|
27 |
+
if classname.find("Conv") != -1:
|
28 |
+
m.weight.data.normal_(mean, std)
|
29 |
+
|
30 |
+
|
31 |
+
def get_padding(kernel_size, dilation=1):
|
32 |
+
return int((kernel_size*dilation - dilation)/2)
|
33 |
+
|
34 |
+
|
35 |
+
class Upsample(nn.Module):
|
36 |
+
def __init__(self, mult, r):
|
37 |
+
super(Upsample, self).__init__()
|
38 |
+
self.r = r
|
39 |
+
self.upsample = nn.Sequential(nn.Upsample(mode="nearest", scale_factor=r),
|
40 |
+
nn.LeakyReLU(0.2),
|
41 |
+
nn.ReflectionPad1d(3),
|
42 |
+
nn.utils.weight_norm(nn.Conv1d(mult, mult // 2, kernel_size=7, stride=1))
|
43 |
+
)
|
44 |
+
r_kernel = r if r >= 5 else 5
|
45 |
+
self.trans_upsample = nn.Sequential(nn.LeakyReLU(0.2),
|
46 |
+
nn.utils.weight_norm(nn.ConvTranspose1d(mult, mult // 2,
|
47 |
+
kernel_size=r_kernel * 2, stride=r,
|
48 |
+
padding=r_kernel - r // 2,
|
49 |
+
output_padding=r % 2)
|
50 |
+
))
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = torch.sin(x) + x
|
54 |
+
out1 = self.upsample(x)
|
55 |
+
out2 = self.trans_upsample(x)
|
56 |
+
return out1 + out2
|
57 |
+
|
58 |
+
|
59 |
+
class Downsample(nn.Module):
|
60 |
+
def __init__(self, mult, r):
|
61 |
+
super(Downsample, self).__init__()
|
62 |
+
self.r = r
|
63 |
+
r_kernel = r if r >= 5 else 5
|
64 |
+
self.trans_downsample = nn.Sequential(nn.LeakyReLU(0.2),
|
65 |
+
nn.utils.weight_norm(nn.Conv1d(mult, mult * 2,
|
66 |
+
kernel_size=r_kernel * 2, stride=r,
|
67 |
+
padding=r_kernel - r // 2)
|
68 |
+
))
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
out = self.trans_downsample(x)
|
72 |
+
return out
|
73 |
+
|
74 |
+
|
75 |
+
def weights_init(m):
|
76 |
+
classname = m.__class__.__name__
|
77 |
+
if classname.find("Conv") != -1:
|
78 |
+
m.weight.data.normal_(0.0, 0.02)
|
79 |
+
elif classname.find("BatchNorm2d") != -1:
|
80 |
+
m.weight.data.normal_(1.0, 0.02)
|
81 |
+
m.bias.data.fill_(0)
|
82 |
+
|
83 |
+
|
84 |
+
def weights_zero_init(m):
|
85 |
+
classname = m.__class__.__name__
|
86 |
+
if classname.find("Conv") != -1:
|
87 |
+
m.weight.data.fill_(0.0)
|
88 |
+
m.bias.data.fill_(0.0)
|
89 |
+
|
90 |
+
|
91 |
+
def WNConv1d(*args, **kwargs):
|
92 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
93 |
+
|
94 |
+
|
95 |
+
def WNConvTranspose1d(*args, **kwargs):
|
96 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
97 |
+
|
98 |
+
|
99 |
+
class Audio2Mel(nn.Module):
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
hop_length=300,
|
103 |
+
sampling_rate=24000,
|
104 |
+
n_mel_channels=80,
|
105 |
+
mel_fmin=0.,
|
106 |
+
mel_fmax=None,
|
107 |
+
frame_size=0.05,
|
108 |
+
device='cpu'
|
109 |
+
):
|
110 |
+
super().__init__()
|
111 |
+
##############################################
|
112 |
+
# FFT Parameters #
|
113 |
+
##############################################
|
114 |
+
|
115 |
+
self.n_fft = int(np.power(2., np.ceil(np.log(sampling_rate * frame_size) / np.log(2))))
|
116 |
+
window = torch.hann_window(int(sampling_rate * frame_size)).float()
|
117 |
+
mel_basis = librosa_mel_fn(
|
118 |
+
sampling_rate, self.n_fft, n_mel_channels, mel_fmin, mel_fmax
|
119 |
+
) # Mel filter (by librosa)
|
120 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
121 |
+
self.register_buffer("mel_basis", mel_basis)
|
122 |
+
self.register_buffer("window", window)
|
123 |
+
|
124 |
+
self.hop_length = hop_length
|
125 |
+
self.win_length = int(sampling_rate * frame_size)
|
126 |
+
self.sampling_rate = sampling_rate
|
127 |
+
self.n_mel_channels = n_mel_channels
|
128 |
+
|
129 |
+
def forward(self, audio):
|
130 |
+
fft = torch.stft(
|
131 |
+
audio.squeeze(1),
|
132 |
+
n_fft=self.n_fft,
|
133 |
+
hop_length=self.hop_length,
|
134 |
+
win_length=self.win_length,
|
135 |
+
window=self.window,
|
136 |
+
center=True,
|
137 |
+
)
|
138 |
+
real_part, imag_part = fft.unbind(-1)
|
139 |
+
magnitude = torch.sqrt(torch.clamp(real_part ** 2 + imag_part ** 2, min=1e-5))
|
140 |
+
mel_output = torch.matmul(self.mel_basis, magnitude)
|
141 |
+
|
142 |
+
log_mel_spec = 20 * torch.log10(torch.clamp(mel_output, min=1e-5)) - 20
|
143 |
+
norm_mel = (log_mel_spec + 115.) / 115.
|
144 |
+
mel_comp = torch.clamp(norm_mel * 8. - 4., -4., 4.)
|
145 |
+
|
146 |
+
return mel_comp
|
147 |
+
|
148 |
+
|
149 |
+
class ResnetBlock(nn.Module):
|
150 |
+
def __init__(self, dim, dilation=1, dim_in=None):
|
151 |
+
super().__init__()
|
152 |
+
if dim_in is None:
|
153 |
+
dim_in = dim
|
154 |
+
|
155 |
+
self.block = nn.Sequential(
|
156 |
+
nn.LeakyReLU(0.2),
|
157 |
+
nn.ReflectionPad1d(dilation),
|
158 |
+
WNConv1d(dim_in, dim, kernel_size=3, dilation=dilation),
|
159 |
+
nn.LeakyReLU(0.2),
|
160 |
+
WNConv1d(dim, dim, kernel_size=1),
|
161 |
+
)
|
162 |
+
self.shortcut = WNConv1d(dim_in, dim, kernel_size=1)
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
return self.shortcut(x) + self.block(x)
|
166 |
+
|
167 |
+
|
168 |
+
'''
|
169 |
+
参照hifigan(https://arxiv.org/pdf/2010.05646.pdf)v2结构
|
170 |
+
多尺度主要是kernel_size不同,3组并行卷积模块,每个卷积模块内部采用不同的串行dilation size,且中间交叉正常无dilation卷积层
|
171 |
+
'''
|
172 |
+
|
173 |
+
|
174 |
+
class ResBlockMRFV2(torch.nn.Module):
|
175 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
176 |
+
super(ResBlockMRFV2, self).__init__()
|
177 |
+
self.convs1 = nn.ModuleList([
|
178 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
179 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
180 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
181 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
182 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
183 |
+
padding=get_padding(kernel_size, dilation[2])))
|
184 |
+
])
|
185 |
+
self.convs1.apply(init_weights)
|
186 |
+
|
187 |
+
self.convs2 = nn.ModuleList([
|
188 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
189 |
+
padding=get_padding(kernel_size, 1))),
|
190 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
191 |
+
padding=get_padding(kernel_size, 1))),
|
192 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
193 |
+
padding=get_padding(kernel_size, 1)))
|
194 |
+
])
|
195 |
+
self.convs2.apply(init_weights)
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
199 |
+
xt = F.leaky_relu(x, 0.2)
|
200 |
+
xt = c1(xt)
|
201 |
+
xt = F.leaky_relu(xt, 0.2)
|
202 |
+
xt = c2(xt)
|
203 |
+
x = xt + x
|
204 |
+
return x
|
205 |
+
|
206 |
+
def remove_weight_norm(self):
|
207 |
+
for l in self.convs1:
|
208 |
+
remove_weight_norm(l)
|
209 |
+
for l in self.convs2:
|
210 |
+
remove_weight_norm(l)
|
211 |
+
|
212 |
+
|
213 |
+
class ResBlockMRFV2Inter(torch.nn.Module):
|
214 |
+
def __init__(self, channels, kernel_size=3):
|
215 |
+
super(ResBlockMRFV2Inter, self).__init__()
|
216 |
+
self.block1 = ResBlockMRFV2(channels)
|
217 |
+
self.block2 = ResBlockMRFV2(channels, 7)
|
218 |
+
self.block3 = ResBlockMRFV2(channels, 11)
|
219 |
+
|
220 |
+
def forward(self, x):
|
221 |
+
xs = self.block1(x)
|
222 |
+
xs += self.block2(x)
|
223 |
+
xs += self.block3(x)
|
224 |
+
x = xs / 3
|
225 |
+
return x
|
226 |
+
|
227 |
+
|
228 |
+
class Generator(nn.Module):
|
229 |
+
def __init__(self, input_size_, ngf, n_residual_layers, num_band, args, ratios=[5, 5, 4, 3], onnx_export=False,
|
230 |
+
device='cpu'):
|
231 |
+
super().__init__()
|
232 |
+
self.hop_length = args.frame_shift
|
233 |
+
self.args = args
|
234 |
+
self.onnx_export = onnx_export
|
235 |
+
|
236 |
+
# ------------- Define upsample layers ----------------
|
237 |
+
mult = int(2 ** len(ratios))
|
238 |
+
model_up = []
|
239 |
+
input_size = input_size_
|
240 |
+
model_up += [
|
241 |
+
nn.ReflectionPad1d(3),
|
242 |
+
WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0),
|
243 |
+
]
|
244 |
+
|
245 |
+
# Upsample to raw audio scale
|
246 |
+
for i, r in enumerate(ratios):
|
247 |
+
model_up += [Upsample(mult * ngf, r)]
|
248 |
+
model_up += [ResBlockMRFV2Inter(mult * ngf // 2)]
|
249 |
+
mult //= 2
|
250 |
+
|
251 |
+
model_up += [
|
252 |
+
nn.LeakyReLU(0.2),
|
253 |
+
nn.ReflectionPad1d(3),
|
254 |
+
WNConv1d(ngf, num_band, kernel_size=7, padding=0),
|
255 |
+
nn.Tanh(),
|
256 |
+
]
|
257 |
+
if not args.use_tanh:
|
258 |
+
model_up[-1] = nn.Conv1d(num_band, num_band, 1)
|
259 |
+
model_up[-2].apply(weights_zero_init)
|
260 |
+
|
261 |
+
self.model_up = nn.Sequential(*model_up)
|
262 |
+
|
263 |
+
self.apply(weights_init)
|
264 |
+
|
265 |
+
def forward(self, mel, step=None):
|
266 |
+
# mel input: (batch_size, seq_num, 80)
|
267 |
+
if self.onnx_export:
|
268 |
+
mel = mel.transpose(1, 2)
|
269 |
+
# on onnx, for engineering, mel input: (batch_size, 80, seq_num)
|
270 |
+
|
271 |
+
# Between Down and up
|
272 |
+
x = mel
|
273 |
+
|
274 |
+
# Upsample pipline
|
275 |
+
cnt_after_upsample = 0
|
276 |
+
|
277 |
+
for i, m in enumerate(self.model_up):
|
278 |
+
x = m(x)
|
279 |
+
|
280 |
+
if type(m) == Upsample:
|
281 |
+
cnt_after_upsample += 1
|
282 |
+
|
283 |
+
return x
|
tts/modules/wavvae/decoder/seanet_encoder.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import List
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from tts.modules.wavvae.encoder.common_modules.seanet import SEANetEncoder
|
20 |
+
|
21 |
+
class Encoder(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
dowmsamples: List[int] = [6, 5, 5, 4, 2],
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
# breakpoint()
|
29 |
+
self.frame_rate = 25 # not use
|
30 |
+
self.encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
|
31 |
+
dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
|
32 |
+
kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
|
33 |
+
true_skip=False, compress=2)
|
34 |
+
|
35 |
+
def forward(self, audio: torch.Tensor):
|
36 |
+
audio = audio.unsqueeze(1) # audio(16,24000)
|
37 |
+
emb = self.encoder(audio)
|
38 |
+
return emb
|
tts/modules/wavvae/decoder/wavvae_v3.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
|
20 |
+
from tts.modules.wavvae.decoder.seanet_encoder import Encoder
|
21 |
+
from tts.modules.wavvae.decoder.diag_gaussian import DiagonalGaussianDistribution
|
22 |
+
from tts.modules.wavvae.decoder.hifigan_modules import Generator, Upsample
|
23 |
+
|
24 |
+
|
25 |
+
class WavVAE_V3(nn.Module):
|
26 |
+
def __init__(self, hparams=None):
|
27 |
+
super().__init__()
|
28 |
+
self.encoder = Encoder(dowmsamples=[6, 5, 4, 4, 2])
|
29 |
+
self.proj_to_z = nn.Linear(512, 64)
|
30 |
+
self.proj_to_decoder = nn.Linear(32, 320)
|
31 |
+
|
32 |
+
config_path = hparams['melgan_config']
|
33 |
+
args = argparse.Namespace()
|
34 |
+
args.__dict__.update(config_path)
|
35 |
+
self.latent_upsampler = Upsample(320, 4)
|
36 |
+
self.decoder = Generator(
|
37 |
+
input_size_=160, ngf=128, n_residual_layers=4,
|
38 |
+
num_band=1, args=args, ratios=[5,4,4,3])
|
39 |
+
|
40 |
+
''' encode waveform into 25 hz latent representation '''
|
41 |
+
def encode_latent(self, audio):
|
42 |
+
posterior = self.encode(audio)
|
43 |
+
latent = posterior.sample().permute(0, 2, 1) # (b,t,latent_channel)
|
44 |
+
return latent
|
45 |
+
|
46 |
+
def encode(self, audio):
|
47 |
+
x = self.encoder(audio).permute(0, 2, 1)
|
48 |
+
x = self.proj_to_z(x).permute(0, 2, 1)
|
49 |
+
poseterior = DiagonalGaussianDistribution(x)
|
50 |
+
return poseterior
|
51 |
+
|
52 |
+
def decode(self, latent):
|
53 |
+
latent = self.proj_to_decoder(latent).permute(0, 2, 1)
|
54 |
+
return self.decoder(self.latent_upsampler(latent))
|
55 |
+
|
56 |
+
def forward(self, audio):
|
57 |
+
posterior = self.encode(audio)
|
58 |
+
latent = posterior.sample().permute(0, 2, 1) # (b, t, latent_channel)
|
59 |
+
recon_wav = self.decode(latent)
|
60 |
+
return recon_wav, posterior
|
tts/modules/wavvae/encoder/common_modules/conv.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
|
24 |
+
# Copyright (c) [2025] [Ziyue Jiang]
|
25 |
+
# SPDX-License-Identifier: MIT
|
26 |
+
# This file has been modified by Ziyue Jiang on 2025/03/19
|
27 |
+
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
|
28 |
+
# This modified file is released under the same license.
|
29 |
+
|
30 |
+
"""Convolutional layers wrappers and utilities."""
|
31 |
+
|
32 |
+
import math
|
33 |
+
import typing as tp
|
34 |
+
import warnings
|
35 |
+
import einops
|
36 |
+
|
37 |
+
import torch
|
38 |
+
from torch import nn
|
39 |
+
from torch.nn import functional as F
|
40 |
+
from torch.nn.utils import spectral_norm, weight_norm
|
41 |
+
|
42 |
+
|
43 |
+
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
44 |
+
'time_layer_norm', 'layer_norm', 'time_group_norm'])
|
45 |
+
|
46 |
+
|
47 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
|
48 |
+
assert norm in CONV_NORMALIZATIONS
|
49 |
+
if norm == 'weight_norm':
|
50 |
+
return weight_norm(module)
|
51 |
+
elif norm == 'spectral_norm':
|
52 |
+
return spectral_norm(module)
|
53 |
+
else:
|
54 |
+
return module
|
55 |
+
|
56 |
+
|
57 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
|
58 |
+
assert norm in CONV_NORMALIZATIONS
|
59 |
+
if norm == 'layer_norm':
|
60 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
61 |
+
return ConvLayerNorm(module.out_channels, **norm_kwargs)
|
62 |
+
elif norm == 'time_group_norm':
|
63 |
+
if causal:
|
64 |
+
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
65 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
66 |
+
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
67 |
+
else:
|
68 |
+
return nn.Identity()
|
69 |
+
|
70 |
+
|
71 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
|
72 |
+
padding_total: int = 0) -> int:
|
73 |
+
length = x.shape[-1]
|
74 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
75 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
76 |
+
return ideal_length - length
|
77 |
+
|
78 |
+
|
79 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
|
80 |
+
length = x.shape[-1]
|
81 |
+
padding_left, padding_right = paddings
|
82 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
83 |
+
if mode == 'reflect':
|
84 |
+
max_pad = max(padding_left, padding_right)
|
85 |
+
extra_pad = 0
|
86 |
+
if length <= max_pad:
|
87 |
+
extra_pad = max_pad - length + 1
|
88 |
+
x = F.pad(x, (0, extra_pad))
|
89 |
+
padded = F.pad(x, paddings, mode, value)
|
90 |
+
end = padded.shape[-1] - extra_pad
|
91 |
+
return padded[..., :end]
|
92 |
+
else:
|
93 |
+
return F.pad(x, paddings, mode, value)
|
94 |
+
|
95 |
+
|
96 |
+
class ConvLayerNorm(nn.LayerNorm):
|
97 |
+
def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
|
98 |
+
super().__init__(normalized_shape, **kwargs)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
x = einops.rearrange(x, 'b ... t -> b t ...')
|
102 |
+
x = super().forward(x)
|
103 |
+
x = einops.rearrange(x, 'b t ... -> b ... t')
|
104 |
+
return
|
105 |
+
|
106 |
+
|
107 |
+
class NormConv1d(nn.Module):
|
108 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
109 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
110 |
+
super().__init__()
|
111 |
+
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
112 |
+
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
113 |
+
self.norm_type = norm
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
x = self.conv(x)
|
117 |
+
x = self.norm(x)
|
118 |
+
return x
|
119 |
+
|
120 |
+
|
121 |
+
class SConv1d(nn.Module):
|
122 |
+
def __init__(self, in_channels: int, out_channels: int,
|
123 |
+
kernel_size: int, stride: int = 1, dilation: int = 1,
|
124 |
+
groups: int = 1, bias: bool = True, causal: bool = False,
|
125 |
+
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
126 |
+
pad_mode: str = 'reflect'):
|
127 |
+
super().__init__()
|
128 |
+
# warn user on unusual setup between dilation and stride
|
129 |
+
if stride > 1 and dilation > 1:
|
130 |
+
warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
|
131 |
+
f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
|
132 |
+
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
133 |
+
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
134 |
+
norm=norm, norm_kwargs=norm_kwargs)
|
135 |
+
self.causal = causal
|
136 |
+
self.pad_mode = pad_mode
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
B, C, T = x.shape
|
140 |
+
kernel_size = self.conv.conv.kernel_size[0]
|
141 |
+
stride = self.conv.conv.stride[0]
|
142 |
+
dilation = self.conv.conv.dilation[0]
|
143 |
+
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
|
144 |
+
padding_total = kernel_size - stride
|
145 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
146 |
+
if self.causal:
|
147 |
+
# Left padding for causal
|
148 |
+
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
149 |
+
else:
|
150 |
+
# Asymmetric padding required for odd strides
|
151 |
+
padding_right = padding_total // 2
|
152 |
+
padding_left = padding_total - padding_right
|
153 |
+
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
154 |
+
return self.conv(x)
|
tts/modules/wavvae/encoder/common_modules/lstm.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
|
24 |
+
# Copyright (c) [2025] [Ziyue Jiang]
|
25 |
+
# SPDX-License-Identifier: MIT
|
26 |
+
# This file has been modified by Ziyue Jiang on 2025/03/19
|
27 |
+
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
|
28 |
+
# This modified file is released under the same license.
|
29 |
+
|
30 |
+
"""LSTM layers module."""
|
31 |
+
from torch import nn
|
32 |
+
|
33 |
+
|
34 |
+
class SLSTM(nn.Module):
|
35 |
+
"""
|
36 |
+
LSTM without worrying about the hidden state, nor the layout of the data.
|
37 |
+
Expects input as convolutional layout.
|
38 |
+
"""
|
39 |
+
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
|
40 |
+
super().__init__()
|
41 |
+
self.skip = skip
|
42 |
+
self.lstm = nn.LSTM(dimension, dimension, num_layers)
|
43 |
+
|
44 |
+
# 修改transpose顺序
|
45 |
+
def forward(self, x):
|
46 |
+
x1 = x.permute(2, 0, 1)
|
47 |
+
y, _ = self.lstm(x1)
|
48 |
+
y = y.permute(1, 2, 0)
|
49 |
+
if self.skip:
|
50 |
+
y = y + x
|
51 |
+
return y
|
tts/modules/wavvae/encoder/common_modules/seanet.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
|
24 |
+
# Copyright (c) [2025] [Ziyue Jiang]
|
25 |
+
# SPDX-License-Identifier: MIT
|
26 |
+
# This file has been modified by Ziyue Jiang on 2025/03/19
|
27 |
+
# Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
|
28 |
+
# This modified file is released under the same license.
|
29 |
+
|
30 |
+
"""Encodec SEANet-based encoder and decoder implementation."""
|
31 |
+
|
32 |
+
import typing as tp
|
33 |
+
|
34 |
+
import numpy as np
|
35 |
+
import torch.nn as nn
|
36 |
+
|
37 |
+
from .conv import SConv1d
|
38 |
+
from .lstm import SLSTM
|
39 |
+
|
40 |
+
|
41 |
+
class SEANetResnetBlock(nn.Module):
|
42 |
+
def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
|
43 |
+
activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
44 |
+
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
|
45 |
+
pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
|
46 |
+
super().__init__()
|
47 |
+
assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
|
48 |
+
act = getattr(nn, activation)
|
49 |
+
hidden = dim // compress
|
50 |
+
block = []
|
51 |
+
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
52 |
+
in_chs = dim if i == 0 else hidden
|
53 |
+
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
54 |
+
block += [
|
55 |
+
act(**activation_params),
|
56 |
+
SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
|
57 |
+
norm=norm, norm_kwargs=norm_params,
|
58 |
+
causal=causal, pad_mode=pad_mode),
|
59 |
+
]
|
60 |
+
self.block = nn.Sequential(*block)
|
61 |
+
self.shortcut: nn.Module
|
62 |
+
if true_skip:
|
63 |
+
self.shortcut = nn.Identity()
|
64 |
+
else:
|
65 |
+
self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
|
66 |
+
causal=causal, pad_mode=pad_mode)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
return self.shortcut(x) + self.block(x)
|
70 |
+
|
71 |
+
|
72 |
+
class SEANetEncoder(nn.Module):
|
73 |
+
def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
|
74 |
+
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
75 |
+
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
|
76 |
+
last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
|
77 |
+
pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
|
78 |
+
super().__init__()
|
79 |
+
self.channels = channels
|
80 |
+
self.dimension = dimension
|
81 |
+
self.n_filters = n_filters
|
82 |
+
self.ratios = list(reversed(ratios))
|
83 |
+
del ratios
|
84 |
+
self.n_residual_layers = n_residual_layers
|
85 |
+
self.hop_length = np.prod(self.ratios)
|
86 |
+
|
87 |
+
act = getattr(nn, activation)
|
88 |
+
mult = 1
|
89 |
+
model: tp.List[nn.Module] = [
|
90 |
+
SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
|
91 |
+
causal=causal, pad_mode=pad_mode)
|
92 |
+
]
|
93 |
+
# Downsample to raw audio scale
|
94 |
+
for i, ratio in enumerate(self.ratios):
|
95 |
+
# Add residual layers
|
96 |
+
for j in range(n_residual_layers):
|
97 |
+
model += [
|
98 |
+
SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
|
99 |
+
dilations=[dilation_base ** j, 1],
|
100 |
+
norm=norm, norm_params=norm_params,
|
101 |
+
activation=activation, activation_params=activation_params,
|
102 |
+
causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
|
103 |
+
|
104 |
+
# Add downsampling layers
|
105 |
+
model += [
|
106 |
+
act(**activation_params),
|
107 |
+
SConv1d(mult * n_filters, mult * n_filters * 2,
|
108 |
+
kernel_size=ratio * 2, stride=ratio,
|
109 |
+
norm=norm, norm_kwargs=norm_params,
|
110 |
+
causal=causal, pad_mode=pad_mode),
|
111 |
+
]
|
112 |
+
mult *= 2
|
113 |
+
|
114 |
+
if lstm:
|
115 |
+
model += [SLSTM(mult * n_filters, num_layers=lstm)]
|
116 |
+
|
117 |
+
model += [
|
118 |
+
act(**activation_params),
|
119 |
+
SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
|
120 |
+
causal=causal, pad_mode=pad_mode)
|
121 |
+
]
|
122 |
+
|
123 |
+
self.model = nn.Sequential(*model)
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
return self.model(x)
|
tts/utils/audio_utils/align.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
def mel2token_to_dur(mel2token, T_txt=None, max_dur=None):
|
18 |
+
is_torch = isinstance(mel2token, torch.Tensor)
|
19 |
+
has_batch_dim = True
|
20 |
+
if not is_torch:
|
21 |
+
mel2token = torch.LongTensor(mel2token)
|
22 |
+
if T_txt is None:
|
23 |
+
T_txt = mel2token.max()
|
24 |
+
if len(mel2token.shape) == 1:
|
25 |
+
mel2token = mel2token[None, ...]
|
26 |
+
has_batch_dim = False
|
27 |
+
B, _ = mel2token.shape
|
28 |
+
dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token))
|
29 |
+
dur = dur[:, 1:]
|
30 |
+
if max_dur is not None:
|
31 |
+
dur = dur.clamp(max=max_dur)
|
32 |
+
if not is_torch:
|
33 |
+
dur = dur.numpy()
|
34 |
+
if not has_batch_dim:
|
35 |
+
dur = dur[0]
|
36 |
+
return dur
|
tts/utils/audio_utils/io.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import io
|
16 |
+
import os
|
17 |
+
import subprocess
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
from scipy.io import wavfile
|
21 |
+
import pyloudnorm as pyln
|
22 |
+
from pydub import AudioSegment
|
23 |
+
|
24 |
+
|
25 |
+
def to_wav_bytes(wav, sr, norm=False):
|
26 |
+
wav = wav.astype(float)
|
27 |
+
if norm:
|
28 |
+
meter = pyln.Meter(sr) # create BS.1770 meter
|
29 |
+
loudness = meter.integrated_loudness(wav)
|
30 |
+
wav = pyln.normalize.loudness(wav, loudness, -18.0)
|
31 |
+
if np.abs(wav).max() >= 1:
|
32 |
+
wav = wav / np.abs(wav).max() * 0.95
|
33 |
+
wav = wav * 32767
|
34 |
+
bytes_io = io.BytesIO()
|
35 |
+
wavfile.write(bytes_io, sr, wav.astype(np.int16))
|
36 |
+
return bytes_io.getvalue()
|
37 |
+
|
38 |
+
|
39 |
+
def save_wav(wav_bytes, path):
|
40 |
+
with open(path[:-4] + '.wav', 'wb') as file:
|
41 |
+
file.write(wav_bytes)
|
42 |
+
if path[-4:] == '.mp3':
|
43 |
+
to_mp3(path[:-4])
|
44 |
+
|
45 |
+
|
46 |
+
def to_mp3(out_path):
|
47 |
+
if out_path[-4:] == '.wav':
|
48 |
+
out_path = out_path[:-4]
|
49 |
+
subprocess.check_call(
|
50 |
+
f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -b:a 192k -y -hide_banner -async 1 "{out_path}.mp3"',
|
51 |
+
shell=True, stdin=subprocess.PIPE)
|
52 |
+
subprocess.check_call(f'rm -f "{out_path}.wav"', shell=True)
|
53 |
+
|
54 |
+
|
55 |
+
def convert_to_wav(wav_path):
|
56 |
+
# Check if the file exists
|
57 |
+
if not os.path.exists(wav_path):
|
58 |
+
print(f"The file '{wav_path}' does not exist.")
|
59 |
+
return
|
60 |
+
|
61 |
+
# Check if the file already has a .wav extension
|
62 |
+
if not wav_path.endswith(".wav"):
|
63 |
+
# Define the output path with a .wav extension
|
64 |
+
out_path = os.path.splitext(wav_path)[0] + ".wav"
|
65 |
+
|
66 |
+
# Load the audio file using pydub and convert it to WAV
|
67 |
+
audio = AudioSegment.from_file(wav_path)
|
68 |
+
audio.export(out_path, format="wav")
|
69 |
+
|
70 |
+
print(f"Converted '{wav_path}' to '{out_path}'")
|
71 |
+
|
72 |
+
|
73 |
+
def convert_to_wav_bytes(audio_binary):
|
74 |
+
# Load the audio binary using pydub and convert it to WAV
|
75 |
+
audio = AudioSegment.from_file(io.BytesIO(audio_binary))
|
76 |
+
wav_bytes = io.BytesIO()
|
77 |
+
audio.export(wav_bytes, format="wav")
|
78 |
+
wav_bytes.seek(0)
|
79 |
+
return wav_bytes
|
80 |
+
|
81 |
+
|
82 |
+
''' Smoothly combine audio segments using crossfade transitions." '''
|
83 |
+
def combine_audio_segments(segments, crossfade_duration=0.16, sr=24000):
|
84 |
+
window_length = int(sr * crossfade_duration)
|
85 |
+
hanning_window = np.hanning(2 * window_length)
|
86 |
+
# Combine
|
87 |
+
for i, segment in enumerate(segments):
|
88 |
+
if i == 0:
|
89 |
+
combined_audio = segment
|
90 |
+
else:
|
91 |
+
overlap = combined_audio[-window_length:] * hanning_window[window_length:] + segment[:window_length] * hanning_window[:window_length]
|
92 |
+
combined_audio = np.concatenate(
|
93 |
+
[combined_audio[:-window_length], overlap, segment[window_length:]]
|
94 |
+
)
|
95 |
+
return combined_audio
|
tts/utils/audio_utils/plot.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import matplotlib
|
16 |
+
|
17 |
+
matplotlib.use('Agg')
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
LINE_COLORS = ['w', 'r', 'orange', 'k', 'cyan', 'm', 'b', 'lime', 'g', 'brown', 'navy']
|
23 |
+
|
24 |
+
|
25 |
+
def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_info=None, figsize=(12, 6)):
|
26 |
+
if isinstance(spec, torch.Tensor):
|
27 |
+
spec = spec.cpu().numpy()
|
28 |
+
H = spec.shape[1] // 2
|
29 |
+
fig = plt.figure(figsize=figsize)
|
30 |
+
plt.title(title)
|
31 |
+
plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
|
32 |
+
|
33 |
+
if dur_info is not None:
|
34 |
+
assert isinstance(dur_info, dict)
|
35 |
+
txt = dur_info['txt']
|
36 |
+
dur_gt = dur_info['dur_gt']
|
37 |
+
if isinstance(dur_gt, torch.Tensor):
|
38 |
+
dur_gt = dur_gt.cpu().numpy()
|
39 |
+
dur_gt = np.cumsum(dur_gt).astype(int)
|
40 |
+
for i in range(len(dur_gt)):
|
41 |
+
shift = (i % 8) + 1
|
42 |
+
plt.text(dur_gt[i], shift * 4, txt[i])
|
43 |
+
plt.vlines(dur_gt[i], 0, H // 2, colors='b') # blue is gt
|
44 |
+
plt.xlim(0, dur_gt[-1])
|
45 |
+
if 'dur_pred' in dur_info:
|
46 |
+
dur_pred = dur_info['dur_pred']
|
47 |
+
if isinstance(dur_pred, torch.Tensor):
|
48 |
+
dur_pred = dur_pred.cpu().numpy()
|
49 |
+
dur_pred = np.cumsum(dur_pred).astype(int)
|
50 |
+
for i in range(len(dur_pred)):
|
51 |
+
shift = (i % 8) + 1
|
52 |
+
plt.text(dur_pred[i], H + shift * 4, txt[i])
|
53 |
+
plt.vlines(dur_pred[i], H, H * 1.5, colors='r') # red is pred
|
54 |
+
plt.xlim(0, max(dur_gt[-1], dur_pred[-1]))
|
55 |
+
if f0s is not None:
|
56 |
+
ax = plt.gca()
|
57 |
+
ax2 = ax.twinx()
|
58 |
+
# ax.set_xticks()
|
59 |
+
|
60 |
+
if not isinstance(f0s, dict):
|
61 |
+
f0s = {'f0': f0s}
|
62 |
+
for i, (k, f0) in enumerate(f0s.items()):
|
63 |
+
if f0 is not None:
|
64 |
+
if isinstance(f0, torch.Tensor):
|
65 |
+
f0 = f0.cpu().numpy()
|
66 |
+
ax2.plot(
|
67 |
+
np.arange(len(f0)) + 0.5, f0, label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.5)
|
68 |
+
ax2.set_ylim(0, 1000)
|
69 |
+
ax2.legend()
|
70 |
+
return fig
|
71 |
+
|
72 |
+
|
73 |
+
def align_to_figure(align, dur_info):
|
74 |
+
if isinstance(align, torch.Tensor):
|
75 |
+
align = align.cpu().numpy()
|
76 |
+
H = align.shape[1]
|
77 |
+
fig = plt.figure(figsize=(12, 6))
|
78 |
+
plt.pcolor(align.T, vmin=0, vmax=1)
|
79 |
+
if dur_info is not None:
|
80 |
+
assert isinstance(dur_info, dict)
|
81 |
+
txt = dur_info['txt']
|
82 |
+
dur_gt = dur_info['dur_gt']
|
83 |
+
if isinstance(dur_gt, torch.Tensor):
|
84 |
+
dur_gt = dur_gt.cpu().numpy()
|
85 |
+
dur_gt = np.cumsum(dur_gt).astype(int) // 2
|
86 |
+
for i in range(len(dur_gt)):
|
87 |
+
plt.text(dur_gt[i], i, txt[i], color='red')
|
88 |
+
plt.vlines(dur_gt[i], 0, H, colors='b') # blue is gt
|
89 |
+
# plt.xlim(0, dur_gt[-1])
|
90 |
+
return fig
|
tts/utils/commons/ckpt_utils.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import contextlib
|
16 |
+
import glob
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
import subprocess
|
20 |
+
import traceback
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch.nn.parallel import DistributedDataParallel
|
24 |
+
import torch.distributed as dist
|
25 |
+
|
26 |
+
|
27 |
+
@contextlib.contextmanager
|
28 |
+
def dist_load(path):
|
29 |
+
if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'):
|
30 |
+
yield path
|
31 |
+
else:
|
32 |
+
from tts.utils.commons.hparams import hparams
|
33 |
+
from tts.utils.commons.trainer import LOCAL_RANK
|
34 |
+
tmpdir = '/dev/shm'
|
35 |
+
assert len(os.path.basename(path)) > 0
|
36 |
+
shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
|
37 |
+
if LOCAL_RANK == 0:
|
38 |
+
subprocess.check_call(
|
39 |
+
f'mkdir -p {os.path.dirname(shm_ckpt_path)}; '
|
40 |
+
f'cp -Lr {path} {shm_ckpt_path}', shell=True)
|
41 |
+
dist.barrier()
|
42 |
+
yield shm_ckpt_path
|
43 |
+
dist.barrier()
|
44 |
+
if LOCAL_RANK == 0:
|
45 |
+
subprocess.check_call(f'rm -rf {shm_ckpt_path}', shell=True)
|
46 |
+
|
47 |
+
|
48 |
+
def torch_load_dist(path, map_location='cpu'):
|
49 |
+
with dist_load(path) as tmp_path:
|
50 |
+
checkpoint = torch.load(tmp_path, map_location=map_location)
|
51 |
+
return checkpoint
|
52 |
+
|
53 |
+
|
54 |
+
def get_last_checkpoint(work_dir, steps=None):
|
55 |
+
checkpoint = None
|
56 |
+
last_ckpt_path = None
|
57 |
+
ckpt_paths = get_all_ckpts(work_dir, steps)
|
58 |
+
if len(ckpt_paths) > 0:
|
59 |
+
last_ckpt_path = ckpt_paths[0]
|
60 |
+
checkpoint = torch_load_dist(last_ckpt_path, map_location='cpu')
|
61 |
+
return checkpoint, last_ckpt_path
|
62 |
+
|
63 |
+
|
64 |
+
def get_all_ckpts(work_dir, steps=None):
|
65 |
+
if steps is None or steps == 0:
|
66 |
+
ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
|
67 |
+
else:
|
68 |
+
ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
|
69 |
+
return sorted(glob.glob(ckpt_path_pattern),
|
70 |
+
key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0]))
|
71 |
+
|
72 |
+
|
73 |
+
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
|
74 |
+
silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True):
|
75 |
+
if checkpoint is None:
|
76 |
+
if os.path.isfile(ckpt_base_dir):
|
77 |
+
base_dir = os.path.dirname(ckpt_base_dir)
|
78 |
+
ckpt_path = ckpt_base_dir
|
79 |
+
checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
|
80 |
+
else:
|
81 |
+
base_dir = ckpt_base_dir
|
82 |
+
if load_opt:
|
83 |
+
checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
|
84 |
+
else:
|
85 |
+
ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
|
86 |
+
if os.path.exists(ckpt_path):
|
87 |
+
checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
|
88 |
+
else:
|
89 |
+
checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
|
90 |
+
if checkpoint is not None:
|
91 |
+
state_dict_all = {
|
92 |
+
k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
|
93 |
+
if not isinstance(cur_model, list):
|
94 |
+
cur_models = [cur_model]
|
95 |
+
model_names = [model_name]
|
96 |
+
else:
|
97 |
+
cur_models = cur_model
|
98 |
+
model_names = model_name
|
99 |
+
for model_name, cur_model in zip(model_names, cur_models):
|
100 |
+
if isinstance(cur_model, DistributedDataParallel):
|
101 |
+
cur_model = cur_model.module
|
102 |
+
device = next(cur_model.parameters()).device
|
103 |
+
if '.' not in model_name:
|
104 |
+
state_dict = state_dict_all[model_name]
|
105 |
+
else:
|
106 |
+
base_model_name = model_name.split('.')[0]
|
107 |
+
rest_model_name = model_name[len(base_model_name) + 1:]
|
108 |
+
state_dict = {
|
109 |
+
k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
|
110 |
+
if k.startswith(f'{rest_model_name}.')}
|
111 |
+
state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
112 |
+
if not strict and delete_unmatch:
|
113 |
+
try:
|
114 |
+
cur_model.load_state_dict(state_dict, strict=True)
|
115 |
+
if not silent:
|
116 |
+
print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
|
117 |
+
except:
|
118 |
+
cur_model_state_dict = cur_model.state_dict()
|
119 |
+
cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
|
120 |
+
cur_model_state_dict.items()}
|
121 |
+
unmatched_keys = []
|
122 |
+
for key, param in state_dict.items():
|
123 |
+
if key in cur_model_state_dict:
|
124 |
+
new_param = cur_model_state_dict[key]
|
125 |
+
if new_param.shape != param.shape:
|
126 |
+
unmatched_keys.append(key)
|
127 |
+
print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
|
128 |
+
"ckpt model: ", param.shape)
|
129 |
+
for key in unmatched_keys:
|
130 |
+
del state_dict[key]
|
131 |
+
load_results = cur_model.load_state_dict(state_dict, strict=strict)
|
132 |
+
cur_model.to(device)
|
133 |
+
if not silent:
|
134 |
+
print(f"| loaded '{model_name}' from '{ckpt_path}'.")
|
135 |
+
missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
|
136 |
+
print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
|
137 |
+
if load_opt:
|
138 |
+
optimizer_states = checkpoint['optimizer_states']
|
139 |
+
assert len(opts) == len(optimizer_states)
|
140 |
+
for optimizer, opt_state in zip(opts, optimizer_states):
|
141 |
+
opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
|
142 |
+
if optimizer is None:
|
143 |
+
return
|
144 |
+
try:
|
145 |
+
optimizer.load_state_dict(opt_state)
|
146 |
+
for i, state in enumerate(optimizer.state.values()):
|
147 |
+
for k, v in state.items():
|
148 |
+
if isinstance(v, torch.Tensor):
|
149 |
+
state[k] = v.to(device)
|
150 |
+
except ValueError:
|
151 |
+
print(f"| WARMING: optimizer {optimizer} parameters not match !!!")
|
152 |
+
return checkpoint.get('global_step', 0)
|
153 |
+
else:
|
154 |
+
e_msg = f"| ckpt not found in {base_dir}."
|
155 |
+
if force:
|
156 |
+
assert False, e_msg
|
157 |
+
else:
|
158 |
+
print(e_msg)
|
159 |
+
|
160 |
+
|
161 |
+
def load_with_size_mismatch(model, state_dict, prefix=""):
|
162 |
+
current_model_dict = model.state_dict()
|
163 |
+
cm_keys = current_model_dict.keys()
|
164 |
+
mismatch_keys = {k.replace(prefix, "") for k, v in state_dict.items() if k.replace(prefix, "") in cm_keys and v.size() != current_model_dict[k.replace(prefix, "")].size()}
|
165 |
+
new_state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if k.replace(prefix, "") in cm_keys and v.size() == current_model_dict[k.replace(prefix, "")].size()}
|
166 |
+
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
|
167 |
+
print(f"| mismatch keys: ", mismatch_keys)
|
168 |
+
if len(missing_keys) > 0:
|
169 |
+
print(f"| missing_keys in dit: {missing_keys}")
|
170 |
+
if len(unexpected_keys) > 0:
|
171 |
+
print(f"| unexpected_keys in dit: {unexpected_keys}")
|
tts/utils/commons/hparams.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
|
20 |
+
import yaml
|
21 |
+
|
22 |
+
global_print_hparams = True
|
23 |
+
hparams = {}
|
24 |
+
|
25 |
+
|
26 |
+
class Args:
|
27 |
+
def __init__(self, **kwargs):
|
28 |
+
for k, v in kwargs.items():
|
29 |
+
self.__setattr__(k, v)
|
30 |
+
|
31 |
+
|
32 |
+
def override_config(old_config: dict, new_config: dict):
|
33 |
+
if new_config.get('__replace', False):
|
34 |
+
old_config.clear()
|
35 |
+
for k, v in new_config.items():
|
36 |
+
if isinstance(v, dict) and k in old_config:
|
37 |
+
override_config(old_config[k], new_config[k])
|
38 |
+
else:
|
39 |
+
old_config[k] = v
|
40 |
+
|
41 |
+
|
42 |
+
def traverse_dict(d, func, ctx):
|
43 |
+
for k in list(d.keys()):
|
44 |
+
v = d[k]
|
45 |
+
if isinstance(v, dict):
|
46 |
+
traverse_dict(v, func, ctx)
|
47 |
+
else:
|
48 |
+
d[k] = func(v, ctx)
|
49 |
+
|
50 |
+
|
51 |
+
def parse_config(v, context=None):
|
52 |
+
if context is None:
|
53 |
+
context = {}
|
54 |
+
|
55 |
+
if isinstance(v, str):
|
56 |
+
if v.startswith('^'):
|
57 |
+
return load_config(v[1:], [], set())
|
58 |
+
|
59 |
+
match = re.match(r"\${(.*)}", v)
|
60 |
+
if match:
|
61 |
+
expression = match.group(1)
|
62 |
+
return eval(expression, {}, context)
|
63 |
+
return v
|
64 |
+
|
65 |
+
|
66 |
+
def remove_meta_key(d):
|
67 |
+
for k in list(d.keys()):
|
68 |
+
v = d[k]
|
69 |
+
if isinstance(v, dict):
|
70 |
+
remove_meta_key(v)
|
71 |
+
else:
|
72 |
+
if k[:2] == '__':
|
73 |
+
del d[k]
|
74 |
+
|
75 |
+
|
76 |
+
def load_config(config_fn, config_chains, loaded_configs):
|
77 |
+
# deep first inheritance and avoid the second visit of one node
|
78 |
+
if not os.path.exists(config_fn):
|
79 |
+
print(f"| WARN: {config_fn} not exist.", )
|
80 |
+
return {}
|
81 |
+
with open(config_fn) as f:
|
82 |
+
hparams_ = yaml.safe_load(f)
|
83 |
+
loaded_configs.add(config_fn)
|
84 |
+
|
85 |
+
if 'base_config' in hparams_:
|
86 |
+
ret_hparams = {}
|
87 |
+
if not isinstance(hparams_['base_config'], list):
|
88 |
+
hparams_['base_config'] = [hparams_['base_config']]
|
89 |
+
for c in hparams_['base_config']:
|
90 |
+
if c.startswith('.'):
|
91 |
+
c = f'{os.path.dirname(config_fn)}/{c}'
|
92 |
+
c = os.path.normpath(c)
|
93 |
+
if c not in loaded_configs:
|
94 |
+
override_config(ret_hparams, load_config(c, config_chains, loaded_configs))
|
95 |
+
override_config(ret_hparams, hparams_)
|
96 |
+
else:
|
97 |
+
ret_hparams = hparams_
|
98 |
+
|
99 |
+
config_chains.append(config_fn)
|
100 |
+
return ret_hparams
|
101 |
+
|
102 |
+
|
103 |
+
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
|
104 |
+
if config == '' and exp_name == '':
|
105 |
+
parser = argparse.ArgumentParser(description='')
|
106 |
+
parser.add_argument('--config', type=str, default='',
|
107 |
+
help='location of the data corpus')
|
108 |
+
parser.add_argument('--exp_name', type=str, default='', help='exp_name')
|
109 |
+
parser.add_argument('-hp', '--hparams', type=str, default='',
|
110 |
+
help='location of the data corpus')
|
111 |
+
parser.add_argument('--infer', action='store_true', help='infer')
|
112 |
+
parser.add_argument('--validate', action='store_true', help='validate')
|
113 |
+
parser.add_argument('--reset', action='store_true', help='reset hparams')
|
114 |
+
parser.add_argument('--remove', action='store_true', help='remove old ckpt')
|
115 |
+
parser.add_argument('--debug', action='store_true', help='debug')
|
116 |
+
parser.add_argument('--start_rank', type=int, default=-1,
|
117 |
+
help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
|
118 |
+
parser.add_argument('--world_size', type=int, default=-1,
|
119 |
+
help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
|
120 |
+
parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
|
121 |
+
parser.add_argument('--master_addr', type=str, default='', help='')
|
122 |
+
parser.add_argument('--ddp_dir', type=str, default='', help='')
|
123 |
+
|
124 |
+
args, unknown = parser.parse_known_args()
|
125 |
+
if print_hparams:
|
126 |
+
print("| set_hparams Unknow hparams: ", unknown)
|
127 |
+
else:
|
128 |
+
args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
|
129 |
+
infer=False, validate=False, reset=False, debug=False, remove=False,
|
130 |
+
start_rank=-1, world_size=-1, init_method='tcp', ddp_dir='', master_addr='')
|
131 |
+
global hparams
|
132 |
+
assert args.config != '' or args.exp_name != ''
|
133 |
+
if args.config != '':
|
134 |
+
assert os.path.exists(args.config), f"{args.config} not exists"
|
135 |
+
|
136 |
+
saved_hparams = {}
|
137 |
+
args_work_dir = ''
|
138 |
+
if args.exp_name != '':
|
139 |
+
args_work_dir = f'{args.exp_name}'
|
140 |
+
ckpt_config_path = f'{args_work_dir}/config.yaml'
|
141 |
+
if os.path.exists(ckpt_config_path):
|
142 |
+
with open(ckpt_config_path) as f:
|
143 |
+
saved_hparams_ = yaml.safe_load(f)
|
144 |
+
if saved_hparams_ is not None:
|
145 |
+
saved_hparams.update(saved_hparams_)
|
146 |
+
hparams_ = {}
|
147 |
+
config_chains = []
|
148 |
+
if args.config != '':
|
149 |
+
hparams_.update(load_config(args.config, config_chains, set()))
|
150 |
+
if len(config_chains) > 1 and print_hparams:
|
151 |
+
print('| Hparams chains: ', config_chains)
|
152 |
+
if not args.reset:
|
153 |
+
hparams_.update(saved_hparams)
|
154 |
+
traverse_dict(hparams_, parse_config, hparams_)
|
155 |
+
hparams_['work_dir'] = args_work_dir
|
156 |
+
|
157 |
+
# Support config overriding in command line. Support list type config overriding.
|
158 |
+
# Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
|
159 |
+
if args.hparams != "":
|
160 |
+
for new_hparam in args.hparams.split(","):
|
161 |
+
k, v = new_hparam.split("=")
|
162 |
+
v = v.strip("\'\" ")
|
163 |
+
config_node = hparams_
|
164 |
+
for k_ in k.split(".")[:-1]:
|
165 |
+
config_node = config_node[k_]
|
166 |
+
k = k.split(".")[-1]
|
167 |
+
if k in config_node:
|
168 |
+
if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
|
169 |
+
if type(config_node[k]) == list:
|
170 |
+
v = v.replace(" ", ",").replace('^', "\"")
|
171 |
+
if '|' in v:
|
172 |
+
tp = type(config_node[k][0]) if len(config_node[k]) else str
|
173 |
+
config_node[k] = [tp(x) for x in v.split("|") if x != '']
|
174 |
+
continue
|
175 |
+
config_node[k] = eval(v)
|
176 |
+
else:
|
177 |
+
config_node[k] = type(config_node[k])(v)
|
178 |
+
else:
|
179 |
+
config_node[k] = v
|
180 |
+
try:
|
181 |
+
config_node[k] = float(v)
|
182 |
+
except:
|
183 |
+
pass
|
184 |
+
try:
|
185 |
+
config_node[k] = int(v)
|
186 |
+
except:
|
187 |
+
pass
|
188 |
+
if v.lower() in ['false', 'true']:
|
189 |
+
config_node[k] = v.lower() == 'true'
|
190 |
+
|
191 |
+
if args_work_dir != '' and not args.infer:
|
192 |
+
os.makedirs(hparams_['work_dir'], exist_ok=True)
|
193 |
+
|
194 |
+
hparams_['infer'] = args.infer
|
195 |
+
hparams_['debug'] = args.debug
|
196 |
+
hparams_['validate'] = args.validate
|
197 |
+
hparams_['exp_name'] = args.exp_name
|
198 |
+
|
199 |
+
hparams_['start_rank'] = args.start_rank # useful for multi-machine training
|
200 |
+
hparams_['world_size'] = args.world_size
|
201 |
+
hparams_['init_method'] = args.init_method
|
202 |
+
hparams_['ddp_dir'] = args.ddp_dir
|
203 |
+
hparams_['master_addr'] = args.master_addr
|
204 |
+
|
205 |
+
remove_meta_key(hparams_)
|
206 |
+
global global_print_hparams
|
207 |
+
if global_hparams:
|
208 |
+
hparams.clear()
|
209 |
+
hparams.update(hparams_)
|
210 |
+
if print_hparams and global_print_hparams and global_hparams:
|
211 |
+
print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True))
|
212 |
+
# for i, (k, v) in enumerate(sorted(hparams_.items())):
|
213 |
+
# print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
|
214 |
+
global_print_hparams = False
|
215 |
+
return hparams_
|
tts/utils/text_utils/dict.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"phone": ["C0a", "C0ai", "C0air", "C0an", "C0ang", "C0angr", "C0anr", "C0ao", "C0aor", "C0ar", "C0b", "C0c", "C0ch", "C0d", "C0e", "C0ei", "C0eir", "C0en", "C0eng", "C0engr", "C0enr", "C0er", "C0f", "C0g", "C0h", "C0i", "C0ia", "C0ian", "C0iang", "C0iangr", "C0ianr", "C0iao", "C0iaor", "C0iar", "C0ie", "C0ier", "C0ii", "C0iii", "C0iiir", "C0iir", "C0in", "C0ing", "C0ingr", "C0inr", "C0io", "C0iong", "C0iongr", "C0iou", "C0iour", "C0ir", "C0j", "C0k", "C0l", "C0m", "C0n", "C0ng", "C0o", "C0ong", "C0ongr", "C0or", "C0ou", "C0our", "C0p", "C0q", "C0r", "C0s", "C0sh", "C0t", "C0u", "C0ua", "C0uai", "C0uair", "C0uan", "C0uang", "C0uangr", "C0uanr", "C0uar", "C0uei", "C0ueir", "C0uen", "C0ueng", "C0uengr", "C0uenr", "C0uo", "C0uor", "C0ur", "C0v", "C0van", "C0vanr", "C0ve", "C0ver", "C0vn", "C0vnr", "C0vr", "C0x", "C0z", "C0zh", "C0_", "E0aa", "E0ae", "E0ah", "E0ao", "E0aw", "E0ax", "E0ay", "E0b", "E0ch", "E0d", "E0dh", "E0eh", "E0ehr", "E0er", "E0ey", "E0f", "E0g", "E0hh", "E0ih", "E0iy", "E0iyr", "E0jh", "E0k", "E0l", "E0m", "E0n", "E0ng", "E0oh", "E0ow", "E0oy", "E0p", "E0r", "E0s", "E0sh", "E0t", "E0th", "E0uh", "E0uw", "E0uwr", "E0v", "E0w", "E0y", "E0z", "E0zh", "sil", "…", "、", "。", "《", "》", "【", "】", "!", """, "#", "$", "%", "'", "''", "(", ")", "*", ",", ":", ";", "?", "\", "^", "_", "`", "{", "}", "~"], "tone": ["0", "1", "10", "11", "12", "13", "15", "17", "2", "3", "4", "5", "6", "7", "8", "9"], "wordCategory": ["0", "B", "E", "M", "S"], "prosody": ["0", "1", "2", "3", "4"], "focus": ["0", "1"], "intonation": ["0", "1", "2"], "phraseAccent": ["0", "H-", "L-"], "boundaryTone": ["0", "H%", "L%"], "accentType": ["!H*", "0", "H*", "L*", "L*+H", "L+H*"]}
|
tts/utils/text_utils/ph_tone_convert.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
def map_phone_to_tokendict(item, pad_bos_eos=True):
|
19 |
+
# Merge Chinese phone and tone (Original dict ends at 173, i.e., ph_dict_size=173). 146~173 is punctuations.
|
20 |
+
phone = item['txt_token'].clone()
|
21 |
+
merged_phone = item['txt_token'].clone()
|
22 |
+
tone_tmp = item['tone'].clone()
|
23 |
+
# In tone_dict, tone_1 is 4, tone_2 is 11, tone_3 is 12, tone_4 is 13, tone_5 is 14, tone_6 is 15
|
24 |
+
tone_tmp[tone_tmp==4] = 1
|
25 |
+
tone_tmp[tone_tmp==11] = 2
|
26 |
+
tone_tmp[tone_tmp==12] = 3
|
27 |
+
tone_tmp[tone_tmp==13] = 4
|
28 |
+
tone_tmp[tone_tmp==14] = 5
|
29 |
+
tone_tmp[tone_tmp==15] = 6
|
30 |
+
# Chinese phones lie in 3~100 in the phone_dict, we map them to 200~788
|
31 |
+
ch_phone_idx = (phone >= 3) & (phone <= 100)
|
32 |
+
merged_phone[ch_phone_idx] = (merged_phone[ch_phone_idx] - 3) * 6 + 200 + tone_tmp[ch_phone_idx]
|
33 |
+
|
34 |
+
if pad_bos_eos:
|
35 |
+
merged_phone = F.pad(merged_phone, (1, 0), mode='constant', value=798)
|
36 |
+
merged_phone = F.pad(merged_phone, (0, 1), mode='constant', value=799)
|
37 |
+
return merged_phone
|
38 |
+
|
39 |
+
def split_ph_timestamp(ph_timestamp):
|
40 |
+
''' Input: ph_timestamp, shape [T] '''
|
41 |
+
|
42 |
+
# Map the timestamp of each phone back to its original frame-level lengths
|
43 |
+
ph_timestamp[ph_timestamp >= 800] -= 800
|
44 |
+
|
45 |
+
ph_list = []
|
46 |
+
tone_list = []
|
47 |
+
dur_list = []
|
48 |
+
cur_timestamp = 0
|
49 |
+
for idx, item in enumerate(ph_timestamp):
|
50 |
+
if idx % 2 == 0:
|
51 |
+
# Map Chinese phones back to its original phone_dict
|
52 |
+
if (200 <= item <= 788):
|
53 |
+
ph = (item - 200 - 1) // 6 + 3
|
54 |
+
tone = (item - 200 - 1) % 6 + 1
|
55 |
+
if tone == 1:
|
56 |
+
tone = 4
|
57 |
+
else:
|
58 |
+
tone = tone + 9
|
59 |
+
# Set English tone to '3'
|
60 |
+
else:
|
61 |
+
ph = item
|
62 |
+
tone = 3
|
63 |
+
ph_list.append(ph)
|
64 |
+
tone_list.append(tone)
|
65 |
+
else:
|
66 |
+
dur_list.append((item - cur_timestamp))
|
67 |
+
cur_timestamp = item
|
68 |
+
assert len(ph_list) == len(dur_list), f"{len(ph_list)}, {len(dur_list)}"
|
69 |
+
ph_seq, tone_seq, dur_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list), torch.LongTensor(dur_list)
|
70 |
+
return ph_seq, tone_seq, dur_seq, ph_timestamp[-1]
|
71 |
+
|
72 |
+
def split_ph(ph_seq):
|
73 |
+
''' Input: ph_timestamp, shape [T] '''
|
74 |
+
ph_list = []
|
75 |
+
tone_list = []
|
76 |
+
for idx, item in enumerate(ph_seq):
|
77 |
+
# Map Chinese phones back to its original phone_dict
|
78 |
+
if (200 <= item <= 788):
|
79 |
+
ph = (item - 200 - 1) // 6 + 3
|
80 |
+
tone = (item - 200 - 1) % 6 + 1
|
81 |
+
if tone == 1:
|
82 |
+
tone = 4
|
83 |
+
else:
|
84 |
+
tone = tone + 9
|
85 |
+
# Set English tone to '3'
|
86 |
+
else:
|
87 |
+
ph = item
|
88 |
+
tone = 3
|
89 |
+
ph_list.append(ph)
|
90 |
+
tone_list.append(tone)
|
91 |
+
|
92 |
+
assert len(ph_list) == len(tone_list)
|
93 |
+
ph_seq, tone_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list)
|
94 |
+
return ph_seq, tone_seq
|
tts/utils/text_utils/split_text.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import re
|
16 |
+
|
17 |
+
def chunk_text_chinese(text, limit=60):
|
18 |
+
# 中文字符匹配
|
19 |
+
chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
|
20 |
+
# 标点符号匹配
|
21 |
+
punctuation = ",。!?;:,\.!?;"
|
22 |
+
|
23 |
+
result = [] # 存储断句结果
|
24 |
+
current_chunk = [] # 当前片段
|
25 |
+
chinese_count = 0 # 中文字符计数
|
26 |
+
|
27 |
+
i = 0
|
28 |
+
while i < len(text):
|
29 |
+
char = text[i]
|
30 |
+
current_chunk.append(char)
|
31 |
+
if chinese_pattern.match(char):
|
32 |
+
chinese_count += 1
|
33 |
+
|
34 |
+
if chinese_count >= limit: # 达到限制字符数
|
35 |
+
# 从当前位置往前找最近的标点符号
|
36 |
+
for j in range(len(current_chunk) - 1, -1, -1):
|
37 |
+
if current_chunk[j] in punctuation:
|
38 |
+
result.append(''.join(current_chunk[:j + 1]))
|
39 |
+
current_chunk = current_chunk[j + 1:]
|
40 |
+
chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
|
41 |
+
break
|
42 |
+
else:
|
43 |
+
# 如果前面没有标点符号,则继续找后面的标点符号
|
44 |
+
for k in range(i + 1, len(text)):
|
45 |
+
if text[k] in punctuation:
|
46 |
+
result.append(''.join(current_chunk)+text[i+1:k+1])
|
47 |
+
current_chunk = []
|
48 |
+
chinese_count = 0
|
49 |
+
i = k
|
50 |
+
break
|
51 |
+
i+=1
|
52 |
+
|
53 |
+
# 添加最后剩余的部分
|
54 |
+
if current_chunk:
|
55 |
+
result.append(''.join(current_chunk))
|
56 |
+
|
57 |
+
return result
|
58 |
+
|
59 |
+
def chunk_text_english(text, max_chars=130):
|
60 |
+
"""
|
61 |
+
Splits the input text into chunks, each with a maximum number of characters.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
text (str): The text to be split.
|
65 |
+
max_chars (int): The maximum number of characters per chunk.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
List[str]: A list of text chunks.
|
69 |
+
"""
|
70 |
+
chunks = []
|
71 |
+
current_chunk = ""
|
72 |
+
# Split the text into sentences based on punctuation followed by whitespace
|
73 |
+
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
74 |
+
|
75 |
+
for sentence in sentences:
|
76 |
+
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
77 |
+
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
78 |
+
else:
|
79 |
+
if current_chunk:
|
80 |
+
chunks.append(current_chunk.strip())
|
81 |
+
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
82 |
+
|
83 |
+
if current_chunk:
|
84 |
+
chunks.append(current_chunk.strip())
|
85 |
+
|
86 |
+
return chunks
|
87 |
+
|
88 |
+
if __name__ == '__main__':
|
89 |
+
print(chunk_text_chinese("哇塞!家人们,你们太好运了。我居然发现了一个宝藏零食大礼包,简直适合所有人的口味!有香辣的,让你舌尖跳舞;有盐焗的,咸香可口;还有五香的,香气四溢。就连怀孕的姐妹都吃得津津有味!整整三十包啊!什么手撕蟹柳、辣子鸡、嫩豆干、手撕素肉、鹌鹑蛋、小肉枣肠、猪肉腐、魔芋、魔芋丝等等,应有尽有。香辣土豆爽辣过瘾,各种素肉嚼劲十足,鹌鹑蛋营养美味,真的太多太多啦,...家人们,现在价格太划算了,赶紧下单。"))
|
90 |
+
print(chunk_text_english("Washington CNN When President Donald Trump declared in the House Chamber this week that executives at the nation’s top automakers were “so excited” about their prospects amid his new tariff regime, it did not entirely reflect the conversation he’d held with them earlier that day."))
|
tts/utils/text_utils/text_encoder.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 ByteDance and/or its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import re
|
17 |
+
import six
|
18 |
+
from six.moves import range # pylint: disable=redefined-builtin
|
19 |
+
|
20 |
+
PAD = "<pad>"
|
21 |
+
EOS = "<EOS>"
|
22 |
+
UNK = "<UNK>"
|
23 |
+
SEG = "|"
|
24 |
+
PUNCS = '!,.?;:'
|
25 |
+
RESERVED_TOKENS = [PAD, EOS, UNK]
|
26 |
+
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
|
27 |
+
PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0
|
28 |
+
EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1
|
29 |
+
UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2
|
30 |
+
|
31 |
+
if six.PY2:
|
32 |
+
RESERVED_TOKENS_BYTES = RESERVED_TOKENS
|
33 |
+
else:
|
34 |
+
RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
|
35 |
+
|
36 |
+
# Regular expression for unescaping token strings.
|
37 |
+
# '\u' is converted to '_'
|
38 |
+
# '\\' is converted to '\'
|
39 |
+
# '\213;' is converted to unichr(213)
|
40 |
+
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
|
41 |
+
_ESCAPE_CHARS = set(u"\\_u;0123456789")
|
42 |
+
|
43 |
+
|
44 |
+
def strip_ids(ids, ids_to_strip):
|
45 |
+
"""Strip ids_to_strip from the end ids."""
|
46 |
+
ids = list(ids)
|
47 |
+
while ids and ids[-1] in ids_to_strip:
|
48 |
+
ids.pop()
|
49 |
+
return ids
|
50 |
+
|
51 |
+
|
52 |
+
class TextEncoder(object):
|
53 |
+
"""Base class for converting from ints to/from human readable strings."""
|
54 |
+
|
55 |
+
def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
|
56 |
+
self._num_reserved_ids = num_reserved_ids
|
57 |
+
|
58 |
+
@property
|
59 |
+
def num_reserved_ids(self):
|
60 |
+
return self._num_reserved_ids
|
61 |
+
|
62 |
+
def encode(self, s):
|
63 |
+
"""Transform a human-readable string into a sequence of int ids.
|
64 |
+
|
65 |
+
The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
|
66 |
+
num_reserved_ids) are reserved.
|
67 |
+
|
68 |
+
EOS is not appended.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
s: human-readable string to be converted.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
ids: list of integers
|
75 |
+
"""
|
76 |
+
return [int(w) + self._num_reserved_ids for w in s.split()]
|
77 |
+
|
78 |
+
def decode(self, ids, strip_extraneous=False):
|
79 |
+
"""Transform a sequence of int ids into a human-readable string.
|
80 |
+
|
81 |
+
EOS is not expected in ids.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
ids: list of integers to be converted.
|
85 |
+
strip_extraneous: bool, whether to strip off extraneous tokens
|
86 |
+
(EOS and PAD).
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
s: human-readable string.
|
90 |
+
"""
|
91 |
+
if strip_extraneous:
|
92 |
+
ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
|
93 |
+
return " ".join(self.decode_list(ids))
|
94 |
+
|
95 |
+
def decode_list(self, ids):
|
96 |
+
"""Transform a sequence of int ids into a their string versions.
|
97 |
+
|
98 |
+
This method supports transforming individual input/output ids to their
|
99 |
+
string versions so that sequence to/from text conversions can be visualized
|
100 |
+
in a human readable format.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
ids: list of integers to be converted.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
strs: list of human-readable string.
|
107 |
+
"""
|
108 |
+
decoded_ids = []
|
109 |
+
for id_ in ids:
|
110 |
+
if 0 <= id_ < self._num_reserved_ids:
|
111 |
+
decoded_ids.append(RESERVED_TOKENS[int(id_)])
|
112 |
+
else:
|
113 |
+
decoded_ids.append(id_ - self._num_reserved_ids)
|
114 |
+
return [str(d) for d in decoded_ids]
|
115 |
+
|
116 |
+
@property
|
117 |
+
def vocab_size(self):
|
118 |
+
raise NotImplementedError()
|
119 |
+
|
120 |
+
|
121 |
+
class TokenTextEncoder(TextEncoder):
|
122 |
+
"""Encoder based on a user-supplied vocabulary (file or list)."""
|
123 |
+
|
124 |
+
def __init__(self,
|
125 |
+
vocab_filename,
|
126 |
+
reverse=False,
|
127 |
+
vocab_list=None,
|
128 |
+
replace_oov=None,
|
129 |
+
num_reserved_ids=NUM_RESERVED_TOKENS):
|
130 |
+
"""Initialize from a file or list, one token per line.
|
131 |
+
|
132 |
+
Handling of reserved tokens works as follows:
|
133 |
+
- When initializing from a list, we add reserved tokens to the vocab.
|
134 |
+
- When initializing from a file, we do not add reserved tokens to the vocab.
|
135 |
+
- When saving vocab files, we save reserved tokens to the file.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
vocab_filename: If not None, the full filename to read vocab from. If this
|
139 |
+
is not None, then vocab_list should be None.
|
140 |
+
reverse: Boolean indicating if tokens should be reversed during encoding
|
141 |
+
and decoding.
|
142 |
+
vocab_list: If not None, a list of elements of the vocabulary. If this is
|
143 |
+
not None, then vocab_filename should be None.
|
144 |
+
replace_oov: If not None, every out-of-vocabulary token seen when
|
145 |
+
encoding will be replaced by this string (which must be in vocab).
|
146 |
+
num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
|
147 |
+
"""
|
148 |
+
super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
|
149 |
+
self._reverse = reverse
|
150 |
+
self._replace_oov = replace_oov
|
151 |
+
if vocab_filename:
|
152 |
+
self._init_vocab_from_file(vocab_filename)
|
153 |
+
else:
|
154 |
+
assert vocab_list is not None
|
155 |
+
self._init_vocab_from_list(vocab_list)
|
156 |
+
self.pad_index = self.token_to_id[PAD]
|
157 |
+
self.eos_index = self.token_to_id[EOS]
|
158 |
+
self.unk_index = self.token_to_id[UNK]
|
159 |
+
self.seg_index = self.token_to_id[SEG] if SEG in self.token_to_id else self.eos_index
|
160 |
+
|
161 |
+
def encode(self, s):
|
162 |
+
"""Converts a space-separated string of tokens to a list of ids."""
|
163 |
+
if isinstance(s, str):
|
164 |
+
sentence = s
|
165 |
+
tokens = sentence.strip().split()
|
166 |
+
else:
|
167 |
+
tokens = s
|
168 |
+
if self._replace_oov is not None:
|
169 |
+
tokens = [t if t in self.token_to_id else self._replace_oov
|
170 |
+
for t in tokens]
|
171 |
+
ret = [self.token_to_id[tok] for tok in tokens]
|
172 |
+
return ret[::-1] if self._reverse else ret
|
173 |
+
|
174 |
+
def decode(self, ids, strip_eos=False, strip_padding=False):
|
175 |
+
if strip_padding and self.pad() in list(ids):
|
176 |
+
pad_pos = list(ids).index(self.pad())
|
177 |
+
ids = ids[:pad_pos]
|
178 |
+
if strip_eos and self.eos() in list(ids):
|
179 |
+
eos_pos = list(ids).index(self.eos())
|
180 |
+
ids = ids[:eos_pos]
|
181 |
+
return " ".join(self.decode_list(ids))
|
182 |
+
|
183 |
+
def decode_list(self, ids):
|
184 |
+
seq = reversed(ids) if self._reverse else ids
|
185 |
+
return [self._safe_id_to_token(i) for i in seq]
|
186 |
+
|
187 |
+
@property
|
188 |
+
def vocab_size(self):
|
189 |
+
return len(self.id_to_token)
|
190 |
+
|
191 |
+
def __len__(self):
|
192 |
+
return self.vocab_size
|
193 |
+
|
194 |
+
def _safe_id_to_token(self, idx):
|
195 |
+
return self.id_to_token.get(idx, "ID_%d" % idx)
|
196 |
+
|
197 |
+
def _init_vocab_from_file(self, filename):
|
198 |
+
"""Load vocab from a file.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
filename: The file to load vocabulary from.
|
202 |
+
"""
|
203 |
+
with open(filename) as f:
|
204 |
+
tokens = [token.strip() for token in f.readlines()]
|
205 |
+
|
206 |
+
def token_gen():
|
207 |
+
for token in tokens:
|
208 |
+
yield token
|
209 |
+
|
210 |
+
self._init_vocab(token_gen(), add_reserved_tokens=False)
|
211 |
+
|
212 |
+
def _init_vocab_from_list(self, vocab_list):
|
213 |
+
"""Initialize tokens from a list of tokens.
|
214 |
+
|
215 |
+
It is ok if reserved tokens appear in the vocab list. They will be
|
216 |
+
removed. The set of tokens in vocab_list should be unique.
|
217 |
+
|
218 |
+
Args:
|
219 |
+
vocab_list: A list of tokens.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def token_gen():
|
223 |
+
for token in vocab_list:
|
224 |
+
if token not in RESERVED_TOKENS:
|
225 |
+
yield token
|
226 |
+
|
227 |
+
self._init_vocab(token_gen())
|
228 |
+
|
229 |
+
def _init_vocab(self, token_generator, add_reserved_tokens=True):
|
230 |
+
"""Initialize vocabulary with tokens from token_generator."""
|
231 |
+
|
232 |
+
self.id_to_token = {}
|
233 |
+
non_reserved_start_index = 0
|
234 |
+
|
235 |
+
if add_reserved_tokens:
|
236 |
+
self.id_to_token.update(enumerate(RESERVED_TOKENS))
|
237 |
+
non_reserved_start_index = len(RESERVED_TOKENS)
|
238 |
+
|
239 |
+
self.id_to_token.update(
|
240 |
+
enumerate(token_generator, start=non_reserved_start_index))
|
241 |
+
|
242 |
+
# _token_to_id is the reverse of _id_to_token
|
243 |
+
self.token_to_id = dict((v, k) for k, v in six.iteritems(self.id_to_token))
|
244 |
+
|
245 |
+
def pad(self):
|
246 |
+
return self.pad_index
|
247 |
+
|
248 |
+
def eos(self):
|
249 |
+
return self.eos_index
|
250 |
+
|
251 |
+
def unk(self):
|
252 |
+
return self.unk_index
|
253 |
+
|
254 |
+
def seg(self):
|
255 |
+
return self.seg_index
|
256 |
+
|
257 |
+
def store_to_file(self, filename):
|
258 |
+
"""Write vocab file to disk.
|
259 |
+
|
260 |
+
Vocab files have one token per line. The file ends in a newline. Reserved
|
261 |
+
tokens are written to the vocab file as well.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
filename: Full path of the file to store the vocab to.
|
265 |
+
"""
|
266 |
+
with open(filename, "w") as f:
|
267 |
+
for i in range(len(self.id_to_token)):
|
268 |
+
f.write(self.id_to_token[i] + "\n")
|
269 |
+
|
270 |
+
def sil_phonemes(self):
|
271 |
+
return [p for p in self.id_to_token.values() if is_sil_phoneme(p)]
|
272 |
+
|
273 |
+
|
274 |
+
def build_token_encoder(token_list_file):
|
275 |
+
token_list = json.load(open(token_list_file))
|
276 |
+
return TokenTextEncoder(None, vocab_list=token_list, replace_oov='<UNK>')
|
277 |
+
|
278 |
+
|
279 |
+
def is_sil_phoneme(p):
|
280 |
+
return p == '' or not p[0].isalpha() or p == 'sil' or p == 'sp' or p == 'XX'
|