ZiyueJiang commited on
Commit
593f3bc
·
0 Parent(s):

first commit for huggingface space

Browse files
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ curl \
7
+ python3 \
8
+ python3-pip \
9
+ ffmpeg \
10
+ && apt-get clean
11
+
12
+ COPY requirements.txt /app/
13
+
14
+ RUN pip install --no-cache-dir -r requirements.txt
15
+
16
+ COPY . /app/
17
+
18
+ CMD ["python", "-m", "tts.gradio_api"]
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [2025] ByteDance
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import multiprocessing as mp
16
+ import torch
17
+ import os
18
+ from functools import partial
19
+ import gradio as gr
20
+ import traceback
21
+ from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
22
+
23
+
24
+ def model_worker(input_queue, output_queue, device_id):
25
+ device = None
26
+ if device_id is not None:
27
+ device = torch.device(f'cuda:{device_id}')
28
+ infer_pipe = MegaTTS3DiTInfer(device=device)
29
+ os.system(f'pkill -f "voidgpu{device_id}"')
30
+
31
+ while True:
32
+ task = input_queue.get()
33
+ inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
34
+ try:
35
+ convert_to_wav(inp_audio_path)
36
+ wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
37
+ cut_wav(wav_path, max_len=28)
38
+ with open(wav_path, 'rb') as file:
39
+ file_content = file.read()
40
+ resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
41
+ wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
42
+ output_queue.put(wav_bytes)
43
+ except Exception as e:
44
+ traceback.print_exc()
45
+ print(task, str(e))
46
+ output_queue.put(None)
47
+
48
+
49
+ def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
50
+ print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
51
+ input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
52
+ res = output_queue.get()
53
+ if res is not None:
54
+ return res
55
+ else:
56
+ print("")
57
+ return None
58
+
59
+
60
+ if __name__ == '__main__':
61
+ mp.set_start_method('spawn', force=True)
62
+ devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
63
+ if devices != '':
64
+ devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
65
+ for d in devices:
66
+ os.system(f'pkill -f "voidgpu{d}"')
67
+ else:
68
+ devices = None
69
+
70
+ num_workers = 1
71
+ input_queue = mp.Queue()
72
+ output_queue = mp.Queue()
73
+ processes = []
74
+
75
+ print("Start open workers")
76
+ for i in range(num_workers):
77
+ p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
78
+ p.start()
79
+ processes.append(p)
80
+
81
+ api_interface = gr.Interface(fn=
82
+ partial(main, processes=processes, input_queue=input_queue,
83
+ output_queue=output_queue),
84
+ inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
85
+ gr.Number(label="infer timestep", value=32),
86
+ gr.Number(label="Intelligibility Weight", value=1.4),
87
+ gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")],
88
+ title="MegaTTS3",
89
+ description="Upload a speech clip as a reference for timbre, " +
90
+ "upload the pre-extracted latent file, "+
91
+ "input the target text, and receive the cloned voice.", concurrency_limit=1)
92
+ api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
93
+ for p in processes:
94
+ p.join()
readme.md ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>
3
+ MegaTTS 3 <img src="./assets/fig/Hi.gif" width="40px">
4
+ </h1>
5
+ <p>
6
+ Official PyTorch Implementation<br>
7
+ </p>
8
+ <p></p>
9
+ <img src="https://img.shields.io/badge/Bytedance-%230077B5.svg?&style=flat-square&logo=bytedance&logoColor=white" />
10
+ <img src="https://img.shields.io/badge/Zhejiang University-%230077B5.svg?&style=flat-square&logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCA1MTIgNTEyIj48IS0tIUZvbnQgQXdlc29tZSBGcmVlIDYuNy4yIGJ5IEBmb250YXdlc29tZSAtIGh0dHBzOi8vZm9udGF3ZXNvbWUuY29tIExpY2Vuc2UgLSBodHRwczovL2ZvbnRhd2Vzb21lLmNvbS9saWNlbnNlL2ZyZWUgQ29weXJpZ2h0IDIwMjUgRm9udGljb25zLCBJbmMuLS0+PHBhdGggZmlsbD0iI2ZmZmZmZiIgZD0iTTI0My40IDIuNmwtMjI0IDk2Yy0xNCA2LTIxLjggMjEtMTguNyAzNS44UzE2LjggMTYwIDMyIDE2MGwwIDhjMCAxMy4zIDEwLjcgMjQgMjQgMjRsNDAwIDBjMTMuMyAwIDI0LTEwLjcgMjQtMjRsMC04YzE1LjIgMCAyOC4zLTEwLjcgMzEuMy0yNS42cy00LjgtMjkuOS0xOC43LTM1LjhsLTIyNC05NmMtOC0zLjQtMTcuMi0zLjQtMjUuMiAwek0xMjggMjI0bC02NCAwIDAgMTk2LjNjLS42IC4zLTEuMiAuNy0xLjggMS4xbC00OCAzMmMtMTEuNyA3LjgtMTcgMjIuNC0xMi45IDM1LjlTMTcuOSA1MTIgMzIgNTEybDQ0OCAwYzE0LjEgMCAyNi41LTkuMiAzMC42LTIyLjdzLTEuMS0yOC4xLTEyLjktMzUuOWwtNDgtMzJjLS42LS40LTEuMi0uNy0xLjgtMS4xTDQ0OCAyMjRsLTY0IDAgMCAxOTItNDAgMCAwLTE5Mi02NCAwIDAgMTkyLTQ4IDAgMC0xOTItNjQgMCAwIDE5Mi00MCAwIDAtMTkyek0yNTYgNjRhMzIgMzIgMCAxIDEgMCA2NCAzMiAzMiAwIDEgMSAwLTY0eiIvPjwvc3ZnPg==&logoColor=white" />
11
+ </div>
12
+
13
+ ## Key features
14
+ - 🚀**Lightweight and Efficient:** The backbone of the TTS Diffusion Transformer has only 0.45B parameters.
15
+ - 🎧**Ultra High-Quality Voice Cloning:** See the demo video below! We also report results of recent TTS models on the Seed test sets in the following table. 🎉Submit a sample on [link2](https://drive.google.com/drive/folders/1gCWL1y_2xu9nIFhUX_OW5MbcFuB7J5Cl?usp=sharing) to receive voice latents you can use locally.
16
+ - 🌍**Bilingual Support:** Supports both Chinese and English, and code-switching.
17
+ - ✍️**Controllable:** Supports accent intensity control ✅ and fine-grained pronunciation/duration adjustment (coming soon).
18
+
19
+ [MegaTTS 3 Demo Video](https://github.com/user-attachments/assets/0174c111-f392-4376-a34b-0b5b8164aacc)
20
+
21
+ <div style='width:100%;text-align:center'>
22
+ <img src="./assets/fig/table_tts.png" width="550px">
23
+ </div>
24
+
25
+ ## 🎯Roadmap
26
+
27
+ - **[2025-03-22]** Our project has been released!
28
+
29
+
30
+ ## Installation
31
+ ``` sh
32
+ # Clone the repository
33
+ git clone https://github.com/bytedance/MegaTTS3
34
+ cd MegaTTS3
35
+ ```
36
+ **Requirements (for Linux)**
37
+ ``` sh
38
+
39
+ # Create a python 3.10 conda env (you could also use virtualenv)
40
+ conda create -n megatts3-env python=3.10
41
+ conda activate megatts3-env
42
+ pip install -r requirements.txt
43
+
44
+ # Set the root directory
45
+ export PYTHONPATH="/path/to/MegaTTS3:$PYTHONPATH"
46
+
47
+ # [Optional] Set GPU
48
+ export CUDA_VISIBLE_DEVICES=0
49
+
50
+ # If you encounter bugs with pydantic in inference, you should check if the versions of pydantic and gradio are matched.
51
+ # [Note] if you encounter bugs related with httpx, please check that whether your environmental variable "no_proxy" has patterns like "::"
52
+ ```
53
+
54
+ **Requirements (for Windows)**
55
+ ``` sh
56
+ # [The Windows version is currently under testing]
57
+ # Comment below dependence in requirements.txt:
58
+ # # WeTextProcessing==1.0.4.1
59
+
60
+ # Create a python 3.10 conda env (you could also use virtualenv)
61
+ conda create -n megatts3-env python=3.10
62
+ conda activate megatts3-env
63
+ pip install -r requirements.txt
64
+ conda install -y -c conda-forge pynini==2.1.5
65
+ pip install WeTextProcessing==1.0.3
66
+
67
+ # [Optional] If you want GPU inference, you may need to install specific version of PyTorch for your GPU from https://pytorch.org/.
68
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
69
+
70
+ # [Note] if you encounter bugs related with `ffprobe` or `ffmpeg`, you can install it through `conda install -c conda-forge ffmpeg`
71
+
72
+ # Set environment variable for root directory
73
+ set PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # Windows
74
+ $env:PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # Powershell on Windows
75
+ conda env config vars set PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # For conda users
76
+
77
+ # [Optional] Set GPU
78
+ set CUDA_VISIBLE_DEVICES=0 # Windows
79
+ $env:CUDA_VISIBLE_DEVICES=0 # Powershell on Windows
80
+
81
+ ```
82
+
83
+ **Requirements (for Docker)**
84
+ ``` sh
85
+ # [The Docker version is currently under testing]
86
+ # ! You should download the pretrained checkpoint before running the following command
87
+ docker build . -t megatts3:latest
88
+
89
+ # For GPU inference
90
+ docker run -it -p 7929:7929 --gpus all -e CUDA_VISIBLE_DEVICES=0 megatts3:latest
91
+ # For CPU inference
92
+ docker run -it -p 7929:7929 megatts3:latest
93
+
94
+ # Visit http://0.0.0.0:7929/ for gradio.
95
+ ```
96
+
97
+
98
+ **Model Download**
99
+
100
+ The pretrained checkpoint can be found at [Google Drive](https://drive.google.com/drive/folders/1CidiSqtHgJTBDAHQ746_on_YR0boHDYB?usp=sharing) or [Huggingface](https://huggingface.co/ByteDance/MegaTTS3). Please download them and put them to ``./checkpoints/xxx``.
101
+
102
+ > [!IMPORTANT]
103
+ > For security issues, we do not upload the parameters of WaveVAE encoder to the above links. You can only use the pre-extracted latents from [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing) for inference. If you want to synthesize speech for speaker A, you need "A.wav" and "A.npy" in the same directory. If you have any questions or suggestions for our model, please email us.
104
+ >
105
+ > This project is primarily intended for academic purposes. For academic datasets requiring evaluation, you may upload them to the voice request queue in [link2](https://drive.google.com/drive/folders/1gCWL1y_2xu9nIFhUX_OW5MbcFuB7J5Cl?usp=sharing) (within 24s for each clip). After verifying that your uploaded voices are free from safety issues, we will upload their latent files to [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing) as soon as possible.
106
+ >
107
+ > In the coming days, we will also prepare and release the latent representations for some common TTS benchmarks.
108
+
109
+ ## Inference
110
+
111
+ **Command-Line Usage (Standard)**
112
+ ``` bash
113
+ # p_w (intelligibility weight), t_w (similarity weight). Typically, prompt with more noises requires higher p_w and t_w
114
+ python tts/infer_cli.py --input_wav 'assets/Chinese_prompt.wav' --input_text "另一边的桌上,一位读书人嗤之以鼻道,'佛子三藏,神子燕小鱼是什么样的人物,李家的那个李子夜如何与他们相提并论?'" --output_dir ./gen
115
+
116
+ # As long as audio volume and pronunciation are appropriate, increasing --t_w within reasonable ranges (2.0~5.0)
117
+ # will increase the generated speech's expressiveness and similarity (especially for some emotional cases).
118
+ python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text 'As his long promised tariff threat turned into reality this week, top human advisers began fielding a wave of calls from business leaders, particularly in the automotive sector, along with lawmakers who were sounding the alarm.' --output_dir ./gen --p_w 2.0 --t_w 3.0
119
+ ```
120
+ **Command-Line Usage (for TTS with Accents)**
121
+ ``` bash
122
+ # When p_w (intelligibility weight) ≈ 1.0, the generated audio closely retains the speaker’s original accent. As p_w increases, it shifts toward standard pronunciation.
123
+ # t_w (similarity weight) is typically set 0–3 points higher than p_w for optimal results.
124
+ # Useful for accented TTS or solving the accent problems in cross-lingual TTS.
125
+ python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text '这是一条有口音的音频。' --output_dir ./gen --p_w 1.0 --t_w 3.0
126
+
127
+ python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text '这条音频的发音标准一些了吗?' --output_dir ./gen --p_w 2.5 --t_w 2.5
128
+ ```
129
+
130
+ **Web UI Usage**
131
+ ``` bash
132
+ # We also support cpu inference, but it may take about 30 seconds (for 10 inference steps).
133
+ python tts/gradio_api.py
134
+ ```
135
+
136
+ ## Submodules
137
+ > [!TIP]
138
+ > In addition to TTS, some submodules in this project may also have additional usages.
139
+ > See ``./tts/frontend_fuction.py`` and ``./tts/infer_cli.py`` for example code.
140
+
141
+ ### Aligner
142
+ **Description:** a robust speech-text aligner model trained using pseudo-labels generated by a large number of MFA expert models.
143
+
144
+ **Usage**: 1) Prepare the finetuning dataset for our model; 2) Filter the large-scale speech dataset (if the aligner fails to align a certain speech clip, it is likely to be noisy); 3) Phoneme recognition; 4) Speech segmentation.
145
+
146
+ ### Graphme-to-Phoneme Model
147
+ **Description:** a Qwen2.5-0.5B model finetuned for robust graphme-to-phoneme conversion.
148
+
149
+ **Usage**: Graphme-to-phoneme conversion.
150
+
151
+ ### WaveVAE
152
+ **Description:** a strong waveform VAE that can compress 24 kHz speeche into 25 Hz acoustic latent and reconstruct the original wave almost losslessly.
153
+
154
+ **Usage:** 1) Acoustic latents can provide a more compact and discriminative training target for speech synthesis models compared to mel-spectrograms, accelerating convergence; 2) Used as acoustic latents for voice conversion; 3) High-quality vocoder.
155
+
156
+ <div style='width:100%;text-align:center'>
157
+ <img src="./assets/fig/table_wavvae.png" width="650px">
158
+ </div>
159
+
160
+
161
+ ## Security
162
+ If you discover a potential security issue in this project, or think you may
163
+ have discovered a security issue, we ask that you notify Bytedance Security via our [security center](https://security.bytedance.com/src) or [[email protected]]([email protected]).
164
+
165
+ Please do **not** create a public GitHub issue.
166
+
167
+ ## License
168
+ This project is licensed under the [Apache-2.0 License](LICENSE).
169
+
170
+ ## Citation
171
+ This repo contains forced-align version of `Sparse Alignment Enhanced Latent Diffusion Transformer for Zero-Shot Speech Synthesis` and the WavVAE is mainly based on `Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling`. Compared to the model described in paper, the repository includes additional models. These models not only enhance the stability and cloning capabilities of the algorithm but can also be independently utilized to serve a wider range of scenarios.
172
+ ```
173
+ @article{jiang2025sparse,
174
+ title={Sparse Alignment Enhanced Latent Diffusion Transformer for Zero-Shot Speech Synthesis},
175
+ author={Jiang, Ziyue and Ren, Yi and Li, Ruiqi and Ji, Shengpeng and Ye, Zhenhui and Zhang, Chen and Jionghao, Bai and Yang, Xiaoda and Zuo, Jialong and Zhang, Yu and others},
176
+ journal={arXiv preprint arXiv:2502.18924},
177
+ year={2025}
178
+ }
179
+
180
+ @article{ji2024wavtokenizer,
181
+ title={Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling},
182
+ author={Ji, Shengpeng and Jiang, Ziyue and Wang, Wen and Chen, Yifu and Fang, Minghui and Zuo, Jialong and Yang, Qian and Cheng, Xize and Wang, Zehan and Li, Ruiqi and others},
183
+ journal={arXiv preprint arXiv:2408.16532},
184
+ year={2024}
185
+ }
186
+ ```
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.3.0
2
+ torchaudio==2.3.0
3
+ setproctitle==1.3.3
4
+ attrdict==2.0.1
5
+ librosa==0.10.2.post1
6
+ langdetect==1.0.9
7
+ pydub==0.25.1
8
+ pyloudnorm==0.1.1
9
+ modelscope==1.22.2
10
+ WeTextProcessing==1.0.4.1
11
+ transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
12
+ transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
13
+ x-transformers==1.44.4
14
+ torchdiffeq==0.2.5
15
+ openai-whisper==20240930
16
+ httpx==0.28.1
17
+ gradio==5.23.1
tts/frontend_function.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import whisper
18
+ import librosa
19
+ from copy import deepcopy
20
+ from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
21
+ from tts.utils.audio_utils.align import mel2token_to_dur
22
+
23
+ ''' Graphme to phoneme function '''
24
+ def g2p(self, text_inp):
25
+ # prepare inputs
26
+ txt_token = self.g2p_tokenizer('<BOT>' + text_inp + '<BOS>')['input_ids']
27
+ input_ids = torch.LongTensor([txt_token+[145+self.speech_start_idx]]).to(self.device)
28
+
29
+ # model forward
30
+ with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
31
+ outputs = self.g2p_model.generate(input_ids, max_new_tokens=256, do_sample=True, top_k=1, eos_token_id=800+1+self.speech_start_idx)
32
+
33
+ # process outputs
34
+ ph_tokens = outputs[:, len(txt_token):-1]-self.speech_start_idx
35
+ ph_pred, tone_pred = split_ph(ph_tokens[0])
36
+ ph_pred, tone_pred = ph_pred[None, :].to(self.device), tone_pred[None, :].to(self.device)
37
+ return ph_pred, tone_pred
38
+
39
+ ''' Get phoneme2mel align of prompt speech '''
40
+ def align(self, wav):
41
+ with torch.inference_mode():
42
+ whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
43
+ mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
44
+ prompt_max_frame = mel.size(2) // self.fm * self.fm
45
+ mel = mel[:, :, :prompt_max_frame]
46
+ token = torch.LongTensor([[798]]).to(self.device)
47
+ audio_features = self.aligner_lm.embed_audio(mel)
48
+ for i in range(768):
49
+ with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
50
+ logits = self.aligner_lm.logits(token, audio_features, None)
51
+ token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None]
52
+ token = torch.cat([token, token_pred], dim=1)
53
+ if token_pred[0] == 799:
54
+ break
55
+ alignment_tokens = token
56
+
57
+ ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1])
58
+ ph_ref = torch.Tensor(ph_ref)[None].to(self.device)
59
+ tone_ref = torch.Tensor(tone_ref)[None].to(self.device)
60
+ if dur_ref.sum() < prompt_max_frame:
61
+ dur_ref[-1] += prompt_max_frame - dur_ref.sum()
62
+ elif dur_ref.sum() > prompt_max_frame:
63
+ len_diff = dur_ref.sum() - prompt_max_frame
64
+ while True:
65
+ for i in range(len(dur_ref)):
66
+ dur_ref[i] -= 1
67
+ len_diff -= 1
68
+ if len_diff == 0:
69
+ break
70
+ if len_diff == 0:
71
+ break
72
+ mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device)
73
+ mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1)//self.fm*self.fm]
74
+ return ph_ref, tone_ref, mel2ph_ref
75
+
76
+ ''' Duration Prompting '''
77
+ def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
78
+ dur_tokens_2d_ = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp(
79
+ max=self.hp_dur_model['dur_code_size'] - 1) + 1
80
+
81
+ ctx_dur_tokens = dur_tokens_2d_.clone().flatten(0, 1).to(self.device)
82
+ txt_tokens_flat_ = ph_ref.flatten(0, 1)
83
+ ctx_dur_tokens = ctx_dur_tokens[txt_tokens_flat_ > 0][None]
84
+
85
+ last_dur_pos_prompt = ctx_dur_tokens.shape[1]
86
+ dur_spk_pos_ids_flat = range(0, last_dur_pos_prompt)
87
+ dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
88
+ with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
89
+ _, incremental_state_dur_prompt = self.dur_model.infer(
90
+ ph_ref, {'tone': tone_ref}, None, None, None,
91
+ ctx_vqcodes=ctx_dur_tokens, spk_pos_ids_flat=dur_spk_pos_ids_flat, return_state=True)
92
+ return incremental_state_dur_prompt, ctx_dur_tokens
93
+
94
+ ''' Duration Prediction '''
95
+ def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final):
96
+ last_dur_token = ctx_dur_tokens[:, -1:]
97
+ last_dur_pos_prompt = ctx_dur_tokens.shape[1]
98
+ incremental_state_dur = deepcopy(incremental_state_dur_prompt)
99
+ txt_len = ph_pred.shape[1]
100
+ dur_spk_pos_ids_flat = range(last_dur_pos_prompt, last_dur_pos_prompt + txt_len)
101
+ dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
102
+ last_dur_pos_prompt = last_dur_pos_prompt + txt_len
103
+
104
+ with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
105
+ dur_pred = self.dur_model.infer(
106
+ ph_pred, {'tone': tone_pred}, None, None, None,
107
+ incremental_state=incremental_state_dur,
108
+ first_decoder_inp=last_dur_token,
109
+ spk_pos_ids_flat=dur_spk_pos_ids_flat,
110
+ )
111
+
112
+ dur_pred = dur_pred - 1
113
+ dur_pred = dur_pred.clamp(0, self.hp_dur_model['dur_code_size'] - 1)
114
+ # if is_final:
115
+ # dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
116
+ # else:
117
+ # dur_pred[:, -1] = dur_pred[:, -1].clamp(48, 128)
118
+ # if seg_i > 0:
119
+ # dur_pred[:, 0] = 0
120
+ # ['。', '!', '?', 'sil']
121
+ for sil_token in [148, 153, 166, 145]:
122
+ dur_pred[ph_pred==sil_token].clamp_min(32)
123
+ # [',', ';']
124
+ for sil_token in [163, 165]:
125
+ dur_pred[ph_pred==sil_token].clamp_min(16)
126
+ if not is_final:
127
+ # add 0.32ms for crossfade
128
+ dur_pred[:, -1] = dur_pred[:, -1] + 32
129
+ else:
130
+ dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
131
+
132
+ ''' DiT target speech generation '''
133
+ dur_disturb_choice = (torch.rand_like(dur_pred.float()) > 0.5).float()
134
+ dur_disturb_r = 1 + torch.rand_like(dur_pred.float()) * dur_disturb
135
+ dur_pred = dur_pred * dur_disturb_r * dur_disturb_choice + \
136
+ dur_pred / dur_disturb_r * (1 - dur_disturb_choice)
137
+ dur_pred = torch.round(dur_pred * dur_alpha).clamp(0, 127)
138
+ if is_first:
139
+ dur_pred[:, 0] = 8
140
+
141
+ dur_sum = dur_pred.sum()
142
+ npad = self.fm - dur_sum % self.fm
143
+ if npad < self.fm:
144
+ dur_pred[:, -1] += npad
145
+ mel2ph_pred = self.length_regulator(dur_pred).to(self.device)
146
+ return mel2ph_pred
147
+
148
+ def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent):
149
+ # Prepare duration token
150
+ mel2ph_pred = torch.cat((mel2ph_ref, mel2ph_pred+ph_ref.size(1)), dim=1)
151
+ mel2ph_pred = mel2ph_pred[:, :mel2ph_pred.size(1)//self.fm*self.fm].repeat(3, 1)
152
+ # Prepare phone and tone token
153
+ ph_pred = torch.cat((ph_ref, ph_pred), dim=1)
154
+ tone_pred = torch.cat((tone_ref, tone_pred), dim=1)
155
+ # Disable the English tone (set them to 3)"""
156
+ en_tone_idx = ~((tone_pred == 4) | ( (11 <= tone_pred) & (tone_pred <= 15)) | (tone_pred == 0))
157
+ tone_pred[en_tone_idx] = 3
158
+
159
+ # Prepare cfg inputs
160
+ ph_seq = torch.cat([ph_pred, ph_pred, torch.full(ph_pred.size(), self.cfg_mask_token_phone, device=self.device)], 0)
161
+ tone_seq = torch.cat([tone_pred, tone_pred, torch.full(tone_pred.size(), self.cfg_mask_token_tone, device=self.device)], 0)
162
+ target_size = mel2ph_pred.size(1)//self.vae_stride
163
+ vae_latent_ = vae_latent.repeat(3, 1, 1)
164
+ ctx_mask = torch.ones_like(vae_latent_[:, :, 0:1])
165
+ vae_latent_ = F.pad(vae_latent_, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
166
+ vae_latent_[1:] = 0.0
167
+ ctx_mask = F.pad(ctx_mask, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
168
+
169
+ return {
170
+ 'phone': ph_seq,
171
+ 'tone': tone_seq,
172
+ "lat_ctx": vae_latent_ * ctx_mask,
173
+ "ctx_mask": ctx_mask,
174
+ "dur": mel2ph_pred,
175
+ }
tts/gradio_api.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import multiprocessing as mp
16
+ import torch
17
+ import os
18
+ from functools import partial
19
+ import gradio as gr
20
+ import traceback
21
+ from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
22
+
23
+
24
+ def model_worker(input_queue, output_queue, device_id):
25
+ device = None
26
+ if device_id is not None:
27
+ device = torch.device(f'cuda:{device_id}')
28
+ infer_pipe = MegaTTS3DiTInfer(device=device)
29
+ os.system(f'pkill -f "voidgpu{device_id}"')
30
+
31
+ while True:
32
+ task = input_queue.get()
33
+ inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
34
+ try:
35
+ convert_to_wav(inp_audio_path)
36
+ wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
37
+ cut_wav(wav_path, max_len=28)
38
+ with open(wav_path, 'rb') as file:
39
+ file_content = file.read()
40
+ resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
41
+ wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
42
+ output_queue.put(wav_bytes)
43
+ except Exception as e:
44
+ traceback.print_exc()
45
+ print(task, str(e))
46
+ output_queue.put(None)
47
+
48
+
49
+ def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
50
+ print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
51
+ input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
52
+ res = output_queue.get()
53
+ if res is not None:
54
+ return res
55
+ else:
56
+ print("")
57
+ return None
58
+
59
+
60
+ if __name__ == '__main__':
61
+ mp.set_start_method('spawn', force=True)
62
+ devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
63
+ if devices != '':
64
+ devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
65
+ for d in devices:
66
+ os.system(f'pkill -f "voidgpu{d}"')
67
+ else:
68
+ devices = None
69
+
70
+ num_workers = 1
71
+ input_queue = mp.Queue()
72
+ output_queue = mp.Queue()
73
+ processes = []
74
+
75
+ print("Start open workers")
76
+ for i in range(num_workers):
77
+ p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
78
+ p.start()
79
+ processes.append(p)
80
+
81
+ api_interface = gr.Interface(fn=
82
+ partial(main, processes=processes, input_queue=input_queue,
83
+ output_queue=output_queue),
84
+ inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
85
+ gr.Number(label="infer timestep", value=32),
86
+ gr.Number(label="Intelligibility Weight", value=1.4),
87
+ gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")],
88
+ title="MegaTTS3",
89
+ description="Upload a speech clip as a reference for timbre, " +
90
+ "upload the pre-extracted latent file, "+
91
+ "input the target text, and receive the cloned voice.", concurrency_limit=1)
92
+ api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
93
+ for p in processes:
94
+ p.join()
tts/infer_cli.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ import argparse
18
+ import librosa
19
+ import numpy as np
20
+ import torch
21
+
22
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
23
+ from tn.english.normalizer import Normalizer as EnNormalizer
24
+ from langdetect import detect as classify_language
25
+ from pydub import AudioSegment
26
+ import pyloudnorm as pyln
27
+
28
+ from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
29
+ from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
30
+ from tts.utils.audio_utils.io import save_wav, to_wav_bytes, convert_to_wav_bytes, combine_audio_segments
31
+ from tts.utils.commons.ckpt_utils import load_ckpt
32
+ from tts.utils.commons.hparams import set_hparams, hparams
33
+ from tts.utils.text_utils.text_encoder import TokenTextEncoder
34
+ from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english
35
+ from tts.utils.commons.hparams import hparams, set_hparams
36
+
37
+
38
+ if "TOKENIZERS_PARALLELISM" not in os.environ:
39
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
40
+
41
+ def convert_to_wav(wav_path):
42
+ # Check if the file exists
43
+ if not os.path.exists(wav_path):
44
+ print(f"The file '{wav_path}' does not exist.")
45
+ return
46
+
47
+ # Check if the file already has a .wav extension
48
+ if not wav_path.endswith(".wav"):
49
+ # Define the output path with a .wav extension
50
+ out_path = os.path.splitext(wav_path)[0] + ".wav"
51
+
52
+ # Load the audio file using pydub and convert it to WAV
53
+ audio = AudioSegment.from_file(wav_path)
54
+ audio.export(out_path, format="wav")
55
+
56
+ print(f"Converted '{wav_path}' to '{out_path}'")
57
+
58
+
59
+ def cut_wav(wav_path, max_len=28):
60
+ audio = AudioSegment.from_file(wav_path)
61
+ audio = audio[:int(max_len * 1000)]
62
+ audio.export(wav_path, format="wav")
63
+
64
+ class MegaTTS3DiTInfer():
65
+ def __init__(
66
+ self,
67
+ device=None,
68
+ ckpt_root='./checkpoints',
69
+ dit_exp_name='diffusion_transformer',
70
+ frontend_exp_name='aligner_lm',
71
+ wavvae_exp_name='wavvae',
72
+ dur_ckpt_path='duration_lm',
73
+ g2p_exp_name='g2p',
74
+ precision=torch.float16,
75
+ **kwargs
76
+ ):
77
+ self.sr = 24000
78
+ self.fm = 8
79
+ if device is None:
80
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
81
+ self.device = device
82
+ self.precision = precision
83
+
84
+ # build models
85
+ self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
86
+ self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
87
+ self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
88
+ self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
89
+ self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
90
+ self.build_model(self.device)
91
+
92
+ # init text normalizer
93
+ self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
94
+ self.en_normalizer = EnNormalizer(overwrite_cache=False)
95
+ # loudness meter
96
+ self.loudness_meter = pyln.Meter(self.sr)
97
+
98
+ def build_model(self, device):
99
+ set_hparams(exp_name=self.dit_exp_name, print_hparams=False)
100
+
101
+ ''' Load Dict '''
102
+ current_dir = os.path.dirname(os.path.abspath(__file__))
103
+ ling_dict = json.load(open(f"{current_dir}/utils/text_utils/dict.json", encoding='utf-8-sig'))
104
+ self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='<UNK>') for k in ['phone', 'tone']}
105
+ self.token_encoder = token_encoder = self.ling_dict['phone']
106
+ ph_dict_size = len(token_encoder)
107
+
108
+ ''' Load Duration LM '''
109
+ from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
110
+ hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
111
+ hp_dur_model['frames_multiple'] = hparams['frames_multiple']
112
+ self.dur_model = ARDurPredictor(
113
+ hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
114
+ hp_dur_model['dur_model_layers'], ph_dict_size,
115
+ hp_dur_model['dur_code_size'],
116
+ use_rot_embed=hp_dur_model.get('use_rot_embed', False))
117
+ self.length_regulator = LengthRegulator()
118
+ load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
119
+ self.dur_model.eval()
120
+ self.dur_model.to(device)
121
+
122
+ ''' Load Diffusion Transformer '''
123
+ from tts.modules.llm_dit.dit import Diffusion
124
+ self.dit = Diffusion()
125
+ load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
126
+ self.dit.eval()
127
+ self.dit.to(device)
128
+ self.cfg_mask_token_phone = 302 - 1
129
+ self.cfg_mask_token_tone = 32 - 1
130
+
131
+ ''' Load Frontend LM '''
132
+ from tts.modules.aligner.whisper_small import Whisper
133
+ self.aligner_lm = Whisper()
134
+ load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
135
+ self.aligner_lm.eval()
136
+ self.aligner_lm.to(device)
137
+ self.kv_cache = None
138
+ self.hooks = None
139
+
140
+ ''' Load G2P LM'''
141
+ from transformers import AutoTokenizer, AutoModelForCausalLM
142
+ g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
143
+ g2p_tokenizer.padding_side = "right"
144
+ self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
145
+ self.g2p_tokenizer = g2p_tokenizer
146
+ self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]
147
+
148
+ ''' Wav VAE '''
149
+ self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
150
+ from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3
151
+ self.wavvae = WavVAE_V3(hparams=hp_wavvae)
152
+ if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
153
+ load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
154
+ self.has_vae_encoder = True
155
+ else:
156
+ load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
157
+ self.has_vae_encoder = False
158
+ self.wavvae.eval()
159
+ self.wavvae.to(device)
160
+ self.vae_stride = hp_wavvae.get('vae_stride', 4)
161
+ self.hop_size = hp_wavvae.get('hop_size', 4)
162
+
163
+ def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
164
+ wav_bytes = convert_to_wav_bytes(audio_bytes)
165
+
166
+ ''' Load wav '''
167
+ wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
168
+ # Pad wav if necessary
169
+ ws = hparams['win_size']
170
+ if len(wav) % ws < ws - 1:
171
+ wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
172
+ wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
173
+ self.loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))
174
+
175
+ ''' obtain alignments with aligner_lm '''
176
+ ph_ref, tone_ref, mel2ph_ref = align(self, wav)
177
+
178
+ with torch.inference_mode():
179
+ ''' Forward WaveVAE to obtain: prompt latent '''
180
+ if self.has_vae_encoder:
181
+ wav = torch.FloatTensor(wav)[None].to(self.device)
182
+ vae_latent = self.wavvae.encode_latent(wav)
183
+ vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
184
+ else:
185
+ assert latent_file is not None, "Please provide latent_file in WaveVAE decoder-only mode"
186
+ vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
187
+ vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
188
+
189
+ ''' Duration Prompting '''
190
+ self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None
191
+ incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)
192
+
193
+ return {
194
+ 'ph_ref': ph_ref,
195
+ 'tone_ref': tone_ref,
196
+ 'mel2ph_ref': mel2ph_ref,
197
+ 'vae_latent': vae_latent,
198
+ 'incremental_state_dur_prompt': incremental_state_dur_prompt,
199
+ 'ctx_dur_tokens': ctx_dur_tokens,
200
+ }
201
+
202
+ def forward(self, resource_context, input_text, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):
203
+ device = self.device
204
+
205
+ ph_ref = resource_context['ph_ref'].to(device)
206
+ tone_ref = resource_context['tone_ref'].to(device)
207
+ mel2ph_ref = resource_context['mel2ph_ref'].to(device)
208
+ vae_latent = resource_context['vae_latent'].to(device)
209
+ ctx_dur_tokens = resource_context['ctx_dur_tokens'].to(device)
210
+ incremental_state_dur_prompt = resource_context['incremental_state_dur_prompt']
211
+
212
+ with torch.inference_mode():
213
+ ''' Generating '''
214
+ wav_pred_ = []
215
+ language_type = classify_language(input_text)
216
+ if language_type == 'en':
217
+ input_text = self.en_normalizer.normalize(input_text)
218
+ text_segs = chunk_text_english(input_text, max_chars=130)
219
+ else:
220
+ input_text = self.zh_normalizer.normalize(input_text)
221
+ text_segs = chunk_text_chinese(input_text, limit=60)
222
+
223
+ for seg_i, text in enumerate(text_segs):
224
+ ''' G2P '''
225
+ ph_pred, tone_pred = g2p(self, text)
226
+
227
+ ''' Duration Prediction '''
228
+ mel2ph_pred = dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first=seg_i==0, is_final=seg_i==len(text_segs)-1)
229
+
230
+ inputs = prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent)
231
+ # Speech dit inference
232
+ with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
233
+ x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()
234
+
235
+ # WavVAE decode
236
+ x[:, :vae_latent.size(1)] = vae_latent
237
+ wav_pred = self.wavvae.decode(x)[0,0].to(torch.float32)
238
+
239
+ ''' Post-processing '''
240
+ # Trim prompt wav
241
+ wav_pred = wav_pred[vae_latent.size(1)*self.vae_stride*self.hop_size:].cpu().numpy()
242
+ # Norm generated wav to prompt wav's level
243
+ meter = pyln.Meter(self.sr) # create BS.1770 meter
244
+ loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
245
+ wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt)
246
+ if np.abs(wav_pred).max() >= 1:
247
+ wav_pred = wav_pred / np.abs(wav_pred).max() * 0.95
248
+
249
+ # Apply hamming window
250
+ wav_pred_.append(wav_pred)
251
+
252
+ wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
253
+ return to_wav_bytes(wav_pred, self.sr)
254
+
255
+
256
+ if __name__ == '__main__':
257
+ parser = argparse.ArgumentParser()
258
+ parser.add_argument('--input_wav', type=str)
259
+ parser.add_argument('--input_text', type=str)
260
+ parser.add_argument('--output_dir', type=str)
261
+ parser.add_argument('--time_step', type=int, default=32, help='Inference steps of Diffusion Transformer')
262
+ parser.add_argument('--p_w', type=float, default=1.6, help='Intelligibility Weight')
263
+ parser.add_argument('--t_w', type=float, default=2.5, help='Similarity Weight')
264
+ args = parser.parse_args()
265
+ wav_path, input_text, out_path, time_step, p_w, t_w = args.input_wav, args.input_text, args.output_dir, args.time_step, args.p_w, args.t_w
266
+
267
+ infer_ins = MegaTTS3DiTInfer()
268
+
269
+ with open(wav_path, 'rb') as file:
270
+ file_content = file.read()
271
+
272
+ print(f"| Start processing {wav_path}+{input_text}")
273
+ resource_context = infer_ins.preprocess(file_content, latent_file=wav_path.replace('.wav', '.npy'))
274
+ wav_bytes = infer_ins.forward(resource_context, input_text, time_step=time_step, p_w=p_w, t_w=t_w)
275
+
276
+ print(f"| Saving results to {out_path}/[P]{input_text[:20]}.wav")
277
+ os.makedirs(out_path, exist_ok=True)
278
+ save_wav(wav_bytes, f'{out_path}/[P]{input_text[:20]}.wav')
tts/modules/aligner/whisper_small.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 OpenAI
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # Copyright (c) [2022] [OpenAI]
24
+ # Copyright (c) [2025] [Ziyue Jiang]
25
+ # SPDX-License-Identifier: MIT
26
+ # This file has been modified by Ziyue Jiang on 2025/03/19
27
+ # Original file was released under MIT, with the full license text # available at https://github.com/openai/whisper/blob/v20240930/LICENSE.
28
+ # This modified file is released under the same license.
29
+
30
+ from contextlib import contextmanager
31
+ from typing import Dict, Iterable, Optional, Tuple
32
+
33
+ import numpy as np
34
+ import torch
35
+ import torch.nn.functional as F
36
+ from torch import Tensor, nn
37
+
38
+ from torch.nn.functional import scaled_dot_product_attention
39
+ SDPA_AVAILABLE = True
40
+
41
+
42
+ class LayerNorm(nn.LayerNorm):
43
+ def forward(self, x: Tensor) -> Tensor:
44
+ return super().forward(x.float()).type(x.dtype)
45
+
46
+
47
+ class Linear(nn.Linear):
48
+ def forward(self, x: Tensor) -> Tensor:
49
+ return F.linear(
50
+ x,
51
+ self.weight.to(x.dtype),
52
+ None if self.bias is None else self.bias.to(x.dtype),
53
+ )
54
+
55
+
56
+ class Conv1d(nn.Conv1d):
57
+ def _conv_forward(
58
+ self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
59
+ ) -> Tensor:
60
+ return super()._conv_forward(
61
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
62
+ )
63
+
64
+
65
+ def sinusoids(length, channels, max_timescale=10000):
66
+ """Returns sinusoids for positional embedding"""
67
+ assert channels % 2 == 0
68
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
69
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
70
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
71
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
72
+
73
+
74
+ @contextmanager
75
+ def disable_sdpa():
76
+ prev_state = MultiHeadAttention.use_sdpa
77
+ try:
78
+ MultiHeadAttention.use_sdpa = False
79
+ yield
80
+ finally:
81
+ MultiHeadAttention.use_sdpa = prev_state
82
+
83
+
84
+ class MultiHeadAttention(nn.Module):
85
+ use_sdpa = True
86
+
87
+ def __init__(self, n_state: int, n_head: int):
88
+ super().__init__()
89
+ self.n_head = n_head
90
+ self.query = Linear(n_state, n_state)
91
+ self.key = Linear(n_state, n_state, bias=False)
92
+ self.value = Linear(n_state, n_state)
93
+ self.out = Linear(n_state, n_state)
94
+
95
+ def forward(
96
+ self,
97
+ x: Tensor,
98
+ xa: Optional[Tensor] = None,
99
+ mask: Optional[Tensor] = None,
100
+ kv_cache: Optional[dict] = None,
101
+ casual: Optional[bool] = None
102
+ ):
103
+ q = self.query(x)
104
+
105
+ if kv_cache is None or xa is None or self.key not in kv_cache:
106
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
107
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
108
+ k = self.key(x if xa is None else xa)
109
+ v = self.value(x if xa is None else xa)
110
+ else:
111
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
112
+ k = kv_cache[self.key]
113
+ v = kv_cache[self.value]
114
+
115
+ wv = self.qkv_attention(q, k, v, mask, casual)
116
+ return self.out(wv)
117
+
118
+ def qkv_attention(
119
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, casual: Optional[bool] = None
120
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
121
+ n_batch, n_ctx, n_state = q.shape
122
+ scale = (n_state // self.n_head) ** -0.25
123
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
124
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
125
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
126
+
127
+ a = scaled_dot_product_attention(
128
+ q, k, v, is_causal=casual and n_ctx > 1, attn_mask=mask[:, None, None, :] if mask is not None else None
129
+ )
130
+ out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
131
+ return out
132
+
133
+
134
+ class ResidualAttentionBlock(nn.Module):
135
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
136
+ super().__init__()
137
+
138
+ self.attn = MultiHeadAttention(n_state, n_head)
139
+ self.attn_ln = LayerNorm(n_state)
140
+
141
+ self.cross_attn = (
142
+ MultiHeadAttention(n_state, n_head) if cross_attention else None
143
+ )
144
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
145
+
146
+ n_mlp = n_state * 4
147
+ self.mlp = nn.Sequential(
148
+ Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
149
+ )
150
+ self.mlp_ln = LayerNorm(n_state)
151
+
152
+ def forward(
153
+ self,
154
+ x: Tensor,
155
+ xa: Optional[Tensor] = None,
156
+ mask: Optional[Tensor] = None,
157
+ kv_cache: Optional[dict] = None,
158
+ casual: Optional[bool] = None,
159
+ ):
160
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, casual=casual)
161
+ if self.cross_attn:
162
+ # TODO: Cross attention mask
163
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, casual=False)
164
+ x = x + self.mlp(self.mlp_ln(x))
165
+ return x
166
+
167
+
168
+ class AudioEncoder(nn.Module):
169
+ def __init__(
170
+ self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
171
+ ):
172
+ super().__init__()
173
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
174
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
175
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
176
+
177
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
178
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
179
+ )
180
+ self.ln_post = LayerNorm(n_state)
181
+
182
+ def forward(self, x: Tensor, attn_mask: Tensor):
183
+ """
184
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
185
+ the mel spectrogram of the audio
186
+ """
187
+ x = F.gelu(self.conv1(x))
188
+ x = F.gelu(self.conv2(x))
189
+ x = x.permute(0, 2, 1)
190
+
191
+ # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
192
+ x = (x + self.positional_embedding[:x.size(1)]).to(x.dtype)
193
+
194
+ for block in self.blocks:
195
+ x = block(x, mask=attn_mask, casual=False)
196
+
197
+ x = self.ln_post(x)
198
+ return x
199
+
200
+
201
+ class TextDecoder(nn.Module):
202
+ def __init__(
203
+ self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
204
+ ):
205
+ super().__init__()
206
+
207
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
208
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
209
+
210
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
211
+ [
212
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
213
+ for _ in range(n_layer)
214
+ ]
215
+ )
216
+ self.ln = LayerNorm(n_state)
217
+
218
+ self.out_proj = nn.Linear(n_state, n_vocab)
219
+
220
+ def forward(self, x: Tensor, attn_mask: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
221
+ """
222
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
223
+ the text tokens
224
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
225
+ the encoded audio features to be attended on
226
+ """
227
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
228
+ x = (
229
+ self.token_embedding(x)
230
+ + self.positional_embedding[offset : offset + x.shape[-1]]
231
+ )
232
+ x = x.to(xa.dtype)
233
+
234
+ for block in self.blocks:
235
+ x = block(x, xa, mask=attn_mask, kv_cache=kv_cache, casual=True)
236
+
237
+ x = self.ln(x)
238
+ # logits = (
239
+ # x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
240
+ # ).float()
241
+ logits = self.out_proj(x)
242
+
243
+ return logits
244
+
245
+
246
+ class Whisper(nn.Module):
247
+ def __init__(self):
248
+ super().__init__()
249
+ self.n_vocab = 6800
250
+ self.n_text_layer = 6
251
+ self.n_text_head = 8
252
+ self.n_text_ctx = 2048
253
+
254
+ self.encoder = AudioEncoder(
255
+ n_mels=80, n_ctx=3000, n_state=512, n_head=8, n_layer=6,
256
+ )
257
+ self.decoder = TextDecoder(
258
+ n_vocab=6800, n_ctx=2048, n_state=512, n_head=8, n_layer=6,
259
+ )
260
+
261
+ def embed_audio(self, mel: torch.Tensor):
262
+ return self.encoder(mel, None)
263
+
264
+ def logits(self, tokens, audio_features, kv_cache=None):
265
+ return self.decoder(tokens, None, audio_features, kv_cache=kv_cache)
266
+
267
+ def forward(
268
+ self, mel, mel_len, token, token_len
269
+ ) -> Dict[str, torch.Tensor]:
270
+ attn_mask_enc = self.sequence_mask(mel_len//2, device=mel.device) > 0
271
+ attn_mask_dec = self.sequence_mask(token_len, device=mel.device) > 0
272
+ return self.decoder(token, attn_mask_dec, self.encoder(mel, attn_mask_enc))
273
+
274
+ @property
275
+ def device(self):
276
+ return next(self.parameters()).device
277
+
278
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
279
+ """
280
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
281
+ tensors calculated for the previous positions. This method returns a dictionary that stores
282
+ all caches, and the necessary hooks for the key and value projection modules that save the
283
+ intermediate tensors to be reused during later calculations.
284
+
285
+ Returns
286
+ -------
287
+ cache : Dict[nn.Module, torch.Tensor]
288
+ A dictionary object mapping the key/value projection modules to its cache
289
+ hooks : List[RemovableHandle]
290
+ List of PyTorch RemovableHandle objects to stop the hooks to be called
291
+ """
292
+ cache = {**cache} if cache is not None else {}
293
+ hooks = []
294
+
295
+ def save_to_cache(module, _, output):
296
+ if module not in cache or output.shape[1] > self.n_text_ctx:
297
+ # save as-is, for the first token or cross attention
298
+ cache[module] = output
299
+ else:
300
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
301
+ return cache[module]
302
+
303
+ def install_hooks(layer: nn.Module):
304
+ if isinstance(layer, MultiHeadAttention):
305
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
306
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
307
+
308
+ self.decoder.apply(install_hooks)
309
+ return cache, hooks
310
+
311
+ def sequence_mask(self, seq_lens, max_len=None, device='cpu'):
312
+ b = seq_lens.shape[0]
313
+ if max_len is None:
314
+ max_len = seq_lens.max()
315
+ mask = torch.arange(max_len).unsqueeze(0).to(device) # [1, t]
316
+ mask = mask < (seq_lens.unsqueeze(1)) # [1, t] + [b, 1] = [b, t]
317
+ mask = mask.float()
318
+ return mask
tts/modules/ar_dur/ar_dur_predictor.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ from copy import deepcopy
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+ from torch.nn import Linear
22
+ from tqdm import tqdm
23
+
24
+ from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
25
+ from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
26
+ from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
27
+ from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
28
+ from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
29
+
30
+ FS_ENCODERS = {
31
+ 'rel_fft': lambda hp, dict_size: RelTransformerEncoder(
32
+ dict_size, hp['hidden_size'], hp['hidden_size'],
33
+ hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
34
+ hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']),
35
+ }
36
+
37
+ def fill_with_neg_inf2(t):
38
+ """FP16-compatible function that fills a tensor with -inf."""
39
+ return t.float().fill_(-1e8).type_as(t)
40
+
41
+ def expand_states(h, mel2token):
42
+ h = F.pad(h, [0, 0, 1, 0])
43
+ mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
44
+ h = torch.gather(h, 1, mel2token_) # [B, T, H]
45
+ return h
46
+
47
+
48
+ class CodePredictor(nn.Module):
49
+ def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
50
+ super().__init__()
51
+ self.hparams = deepcopy(hparams)
52
+ self.hparams['hidden_size'] = hidden_size
53
+ self.hidden_size = hidden_size
54
+ char_dict_size = hparams.get('char_dict_size', 4000)
55
+ if not hparams.get('lm_use_enc'):
56
+ self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
57
+ if hparams.get('mega_use_char', True):
58
+ self.char_encoder = nn.Embedding(char_dict_size,
59
+ self.hidden_size, padding_idx=0)
60
+ else:
61
+ self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
62
+ if hparams.get('mega_use_char', True):
63
+ self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
64
+ if hparams['use_ph_pos_embed']:
65
+ self.ph_pos_embed = PosEmb(self.hidden_size)
66
+
67
+ self.char_empty_embed = nn.Embedding(1, self.hidden_size)
68
+ if hparams.get('use_bert_input'):
69
+ self.bert_input_proj = nn.Linear(768, self.hidden_size)
70
+ self.ling_label_embed_layers = nn.ModuleDict()
71
+ for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
72
+ self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)
73
+
74
+ self.dec_hidden_size = dec_hidden_size
75
+ self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
76
+ self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
77
+ self.use_pos_embed = hparams.get('use_pos_embed', False)
78
+ if self.use_pos_embed:
79
+ self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
80
+ self.use_post_ln = hparams.get('use_post_ln', False)
81
+ self.layers = None
82
+ if not self.use_post_ln:
83
+ self.layer_norm = LayerNorm(dec_hidden_size)
84
+ self.code_size = code_size
85
+ self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)
86
+
87
+ def forward_ling_encoder(
88
+ self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
89
+ ph_tokens = txt_tokens
90
+ hparams = self.hparams
91
+ ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1]
92
+ x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)
93
+
94
+ # enc_ph
95
+ if not hparams.get('lm_use_enc'):
96
+ x_ph = self.encoder(ph_tokens)
97
+ x_ph = x_ph + sum(
98
+ [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
99
+ if len(hparams['ling_labels']) > 0 else 0
100
+ x_ph = x_ph + x_spk
101
+ else:
102
+ # enc_ph
103
+ ph_enc_oembed = sum(
104
+ [self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
105
+ if len(hparams['ling_labels']) > 0 else 0
106
+ ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
107
+ torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
108
+ ph_enc_oembed = ph_enc_oembed + x_spk
109
+ ph_enc_oembed = ph_enc_oembed * ph_nonpadding
110
+ x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)
111
+
112
+ # enc_char
113
+ if char_tokens is not None and ph2char is not None:
114
+ char_nonpadding = (char_tokens > 0).float()[:, :, None]
115
+ x_char = self.char_encoder(char_tokens)
116
+ empty_char = (ph2char > 100000).long()
117
+ ph2char = ph2char * (1 - empty_char)
118
+ x_char_phlevel = \
119
+ expand_states(x_char * char_nonpadding, ph2char) \
120
+ * (1 - empty_char)[..., None] + \
121
+ self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
122
+ else:
123
+ x_char_phlevel = 0
124
+ # x_ling
125
+ x_ling = x_ph + x_char_phlevel
126
+ x_ling = x_ling * ph_nonpadding
127
+ x_ling = self.enc_proj(x_ling)
128
+ return x_ling
129
+
130
+ def sample_one_step(self, vq_pred):
131
+ hparams = self.hparams
132
+ if hparams.get('infer_top_k'):
133
+ top_k = hparams.get('infer_top_k')
134
+ temperature = hparams.get('infer_temperature', 1)
135
+ vq_pred = vq_pred[:, -1] / temperature
136
+ # optionally crop the logits to only the top k options
137
+ if top_k is not None:
138
+ v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1)))
139
+ vq_pred[vq_pred < v[:, [-1]]] = -float('Inf')
140
+ # apply softmax to convert logits to (normalized) probabilities
141
+ probs = F.softmax(vq_pred, dim=-1)
142
+ # sample from the distribution
143
+ vq_pred = torch.multinomial(probs, num_samples=1)
144
+ else:
145
+ vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1)
146
+ return vq_pred
147
+
148
+ def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
149
+ # add spk embed
150
+ style_embed = 0
151
+ if self.hparams['use_spk_embed']:
152
+ style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :]
153
+ if self.hparams['use_spk_id']:
154
+ style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :]
155
+ if self.hparams['use_spk_enc']:
156
+ style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :]
157
+ return style_embed
158
+
159
+ def buffered_future_mask(self, tensor):
160
+ dim = tensor.size(0)
161
+ if (
162
+ not hasattr(self, '_future_mask')
163
+ or self._future_mask is None
164
+ or self._future_mask.device != tensor.device
165
+ or self._future_mask.size(0) < dim
166
+ ):
167
+ self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1)
168
+ return self._future_mask[:dim, :dim]
169
+
170
+
171
+ class ARDurPredictor(CodePredictor):
172
+ def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True,
173
+ op_version=1):
174
+ super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
175
+ self.use_rot_embed = use_rot_embed
176
+ bias = hparams.get('lm_bias', True)
177
+ if self.use_rot_embed:
178
+ self.layers = nn.ModuleList([])
179
+ self.layers.extend([
180
+ RotTransformerDecoderLayer(
181
+ dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
182
+ post_ln=self.use_post_ln, op_version=op_version, bias=bias)
183
+ for _ in range(lm_num_layers)
184
+ ])
185
+ if hparams['dur_model_type'] == 'ar_mse':
186
+ self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
187
+ else:
188
+ self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)
189
+
190
+ def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
191
+ prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
192
+ incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
193
+ prompt_length=None, cache_size=20, streaming=False):
194
+ x = self.code_emb(prev_code)
195
+ if x_ling is None:
196
+ x_ling = self.forward_ling_encoder(
197
+ txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
198
+ x_ling = x_ling.flatten(0, 1)
199
+ txt_tokens = txt_tokens.flatten(0, 1)
200
+ x_ling = x_ling[txt_tokens > 0][None]
201
+
202
+ # run decoder
203
+ self_attn_padding_mask = None
204
+ if self.use_pos_embed:
205
+ positions = self.embed_positions(
206
+ prev_code,
207
+ incremental_state=incremental_state
208
+ )
209
+ if incremental_state is not None:
210
+ x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]]
211
+ if spk_pos_ids_flat is not None:
212
+ spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]]
213
+ x = x[:, -1:]
214
+ if self.use_pos_embed:
215
+ positions = positions[:, -1:]
216
+ if streaming:
217
+ # Shift Pos: query pos is min(cache_size, idx)
218
+ spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device),
219
+ spk_pos_ids_flat)
220
+
221
+ # # B x T x C -> T x B x C
222
+ if self.use_pos_embed:
223
+ x = x + positions
224
+ x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
225
+ T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
226
+ x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
227
+ x = x + x_ling
228
+ x = x.transpose(0, 1)
229
+
230
+ for idx, layer in enumerate(self.layers):
231
+ if incremental_state is None:
232
+ self_attn_mask = self.buffered_future_mask(x)
233
+ if attn_mask is not None:
234
+ self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
235
+ self_attn_mask = self_attn_mask.clamp_min(-1e8)
236
+ else:
237
+ self_attn_mask = None
238
+
239
+ x, attn_weights = layer(
240
+ x,
241
+ incremental_state=incremental_state,
242
+ self_attn_mask=self_attn_mask,
243
+ self_attn_padding_mask=self_attn_padding_mask,
244
+ spk_pos_ids_flat=spk_pos_ids_flat
245
+ )
246
+
247
+ if streaming and incremental_state != {}:
248
+ for k, v in incremental_state.items():
249
+ if 'attn_state' in k:
250
+ prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value']
251
+ cur_length = prev_key.shape[2]
252
+ if cur_length - prompt_length > cache_size:
253
+ prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2)
254
+ prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]),
255
+ dim=2)
256
+ incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value
257
+
258
+ if not self.use_post_ln:
259
+ x = self.layer_norm(x)
260
+ # T x B x C -> B x T x C
261
+ x = x.transpose(0, 1)
262
+ x = self.project_out_dim(x)
263
+ return x
264
+
265
+ def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
266
+ spk_id=None, spk_embed=None, mels_timbre=None,
267
+ incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
268
+ first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
269
+ if incremental_state is None:
270
+ incremental_state = {}
271
+ x_ling = self.forward_ling_encoder(
272
+ txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
273
+ spk_id, spk_embed, mels_timbre)
274
+ x_ling = x_ling.flatten(0, 1)
275
+ txt_tokens_ori = txt_tokens
276
+ txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
277
+ x_ling = x_ling[txt_tokens > 0][None]
278
+ txt_tokens = txt_tokens[txt_tokens > 0][None]
279
+
280
+ decoded = torch.zeros_like(txt_tokens)
281
+ decoded = F.pad(decoded, [1, 0], value=self.code_size + 1)
282
+ if incremental_state != {}:
283
+ if first_decoder_inp is None:
284
+ assert ctx_vqcodes is not None
285
+ decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
286
+ ctx_vqcodes = None
287
+ else:
288
+ decoded[:, :1] = first_decoder_inp
289
+ probs = []
290
+ for step in range(decoded.shape[1] - 1):
291
+ vq_pred = self(txt_tokens, None, None, None, None,
292
+ decoded[:, :step + 1], None, None, None,
293
+ incremental_state=incremental_state, x_ling=x_ling,
294
+ spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
295
+ probs.append(vq_pred.cpu())
296
+ if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
297
+ if self.hparams['dur_model_type'] == 'ar_mse':
298
+ d = vq_pred[:, -1, 0]
299
+ if dur_disturb > 0 and step >= 1:
300
+ if random.random() > 0.5:
301
+ d = d * (1 + random.random() * dur_disturb)
302
+ else:
303
+ d = d / (1 + random.random() * dur_disturb)
304
+ d = torch.clamp_max(d, self.code_size - 1)
305
+ vq_pred = torch.round(d).long()
306
+ else:
307
+ vq_pred = self.sample_one_step(vq_pred)
308
+ decoded[:, step + 1] = torch.clamp_min(vq_pred, 1)
309
+ if step == 0:
310
+ decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min)
311
+ else:
312
+ decoded[:, step + 1] = ctx_vqcodes[:, step]
313
+ decoded = decoded[:, 1:]
314
+ decoded_2d = torch.zeros_like(txt_tokens_ori)
315
+ decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded
316
+ if return_state:
317
+ return decoded_2d, incremental_state
318
+ if return_probs:
319
+ return decoded_2d, torch.cat(probs, 1)
320
+ return decoded_2d
321
+
322
+ def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
323
+ spk_id=None, spk_embed=None, mels_timbre=None,
324
+ incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
325
+ **kwargs):
326
+ if incremental_state is None:
327
+ incremental_state = {}
328
+ x_ling = self.forward_ling_encoder(
329
+ txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
330
+ spk_id, spk_embed, mels_timbre)
331
+ x_ling = x_ling.flatten(0, 1)
332
+ txt_tokens_ori = txt_tokens
333
+ txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
334
+ x_ling = x_ling[txt_tokens > 0][None]
335
+ txt_tokens = txt_tokens[txt_tokens > 0][None]
336
+
337
+ vq_decoded = torch.zeros_like(txt_tokens)
338
+ vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1)
339
+ if incremental_state != {}:
340
+ assert ctx_vqcodes is not None
341
+ vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
342
+ ctx_vqcodes = None
343
+ prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2]
344
+ for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
345
+ vq_pred = self(txt_tokens, None, None, None, None,
346
+ vq_decoded[:, :step + 1], None, None, None,
347
+ incremental_state=incremental_state, x_ling=x_ling,
348
+ spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
349
+ if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
350
+ if self.hparams['dur_model_type'] == 'ar_mse':
351
+ vq_pred = torch.round(vq_pred[:, -1, 0]).long()
352
+ else:
353
+ vq_pred = self.sample_one_step(vq_pred)
354
+ vq_decoded[:, step + 1] = vq_pred
355
+ else:
356
+ vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
357
+ vq_decoded = vq_decoded[:, 1:]
358
+ vq_decoded_2d = torch.zeros_like(txt_tokens_ori)
359
+ vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded
360
+ if return_state:
361
+ return vq_decoded_2d, incremental_state
362
+ return vq_decoded_2d
tts/modules/ar_dur/commons/layers.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch import nn
17
+
18
+
19
+ class LayerNorm(torch.nn.LayerNorm):
20
+ """Layer normalization module.
21
+ :param int nout: output dim size
22
+ :param int dim: dimension to be normalized
23
+ """
24
+
25
+ def __init__(self, nout, dim=-1, eps=1e-5):
26
+ """Construct an LayerNorm object."""
27
+ super(LayerNorm, self).__init__(nout, eps=eps)
28
+ self.dim = dim
29
+
30
+ def forward(self, x):
31
+ """Apply layer normalization.
32
+ :param torch.Tensor x: input tensor
33
+ :return: layer normalized tensor
34
+ :rtype torch.Tensor
35
+ """
36
+ if self.dim == -1:
37
+ return super(LayerNorm, self).forward(x)
38
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
39
+
40
+
41
+ class Reshape(nn.Module):
42
+ def __init__(self, *args):
43
+ super(Reshape, self).__init__()
44
+ self.shape = args
45
+
46
+ def forward(self, x):
47
+ return x.view(self.shape)
48
+
49
+
50
+ class Permute(nn.Module):
51
+ def __init__(self, *args):
52
+ super(Permute, self).__init__()
53
+ self.args = args
54
+
55
+ def forward(self, x):
56
+ return x.permute(self.args)
57
+
58
+
59
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None):
60
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
61
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
62
+ if padding_idx is not None:
63
+ nn.init.constant_(m.weight[padding_idx], 0)
64
+ return m
tts/modules/ar_dur/commons/nar_tts_modules.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ import torch.nn.functional as F
21
+
22
+
23
+ class LengthRegulator(torch.nn.Module):
24
+ def __init__(self, pad_value=0.0):
25
+ super(LengthRegulator, self).__init__()
26
+ self.pad_value = pad_value
27
+
28
+ def forward(self, dur, dur_padding=None, alpha=1.0):
29
+ """
30
+ Example (no batch dim version):
31
+ 1. dur = [2,2,3]
32
+ 2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
33
+ 3. token_mask = [[1,1,0,0,0,0,0],
34
+ [0,0,1,1,0,0,0],
35
+ [0,0,0,0,1,1,1]]
36
+ 4. token_idx * token_mask = [[1,1,0,0,0,0,0],
37
+ [0,0,2,2,0,0,0],
38
+ [0,0,0,0,3,3,3]]
39
+ 5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
40
+
41
+ :param dur: Batch of durations of each frame (B, T_txt)
42
+ :param dur_padding: Batch of padding of each frame (B, T_txt)
43
+ :param alpha: duration rescale coefficient
44
+ :return:
45
+ mel2ph (B, T_speech)
46
+ assert alpha > 0
47
+ """
48
+ dur = torch.round(dur.float() * alpha).long()
49
+ if dur_padding is not None:
50
+ dur = dur * (1 - dur_padding.long())
51
+ token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
52
+ dur_cumsum = torch.cumsum(dur, 1)
53
+ dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
54
+
55
+ pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
56
+ token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
57
+ mel2token = (token_idx * token_mask.long()).sum(1)
58
+ return mel2token
59
+
60
+
61
+ class PosEmb(nn.Module):
62
+ def __init__(self, dim):
63
+ super().__init__()
64
+ self.dim = dim
65
+ half_dim = self.dim // 2
66
+ emb = math.log(10000) / (half_dim - 1)
67
+ emb = torch.exp(torch.arange(half_dim) * -emb)
68
+ self.emb = emb # TODO
69
+
70
+ def forward(self, x):
71
+ emb = x[:, :, None] * self.emb[None, None, :].to(x.device)
72
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
73
+ return emb
tts/modules/ar_dur/commons/rel_transformer.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import torch
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+
20
+ from tts.modules.ar_dur.commons.layers import Embedding
21
+
22
+
23
+ def convert_pad_shape(pad_shape):
24
+ l = pad_shape[::-1]
25
+ pad_shape = [item for sublist in l for item in sublist]
26
+ return pad_shape
27
+
28
+
29
+ def shift_1d(x):
30
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
31
+ return x
32
+
33
+
34
+ def sequence_mask(length, max_length=None):
35
+ if max_length is None:
36
+ max_length = length.max()
37
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
38
+ return x.unsqueeze(0) < length.unsqueeze(1)
39
+
40
+
41
+ class Encoder(nn.Module):
42
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
43
+ window_size=None, block_length=None, pre_ln=False, **kwargs):
44
+ super().__init__()
45
+ self.hidden_channels = hidden_channels
46
+ self.filter_channels = filter_channels
47
+ self.n_heads = n_heads
48
+ self.n_layers = n_layers
49
+ self.kernel_size = kernel_size
50
+ self.p_dropout = p_dropout
51
+ self.window_size = window_size
52
+ self.block_length = block_length
53
+ self.pre_ln = pre_ln
54
+
55
+ self.drop = nn.Dropout(p_dropout)
56
+ self.attn_layers = nn.ModuleList()
57
+ self.norm_layers_1 = nn.ModuleList()
58
+ self.ffn_layers = nn.ModuleList()
59
+ self.norm_layers_2 = nn.ModuleList()
60
+ for i in range(self.n_layers):
61
+ self.attn_layers.append(
62
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
63
+ p_dropout=p_dropout, block_length=block_length))
64
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
65
+ self.ffn_layers.append(
66
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
67
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
68
+ if pre_ln:
69
+ self.last_ln = LayerNorm(hidden_channels)
70
+
71
+ def forward(self, x, x_mask, attn_mask=1):
72
+ if isinstance(attn_mask, torch.Tensor):
73
+ attn_mask = attn_mask[:, None]
74
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
75
+ for i in range(self.n_layers):
76
+ x = x * x_mask
77
+ x_ = x
78
+ if self.pre_ln:
79
+ x = self.norm_layers_1[i](x)
80
+ y = self.attn_layers[i](x, x, attn_mask)
81
+ y = self.drop(y)
82
+ x = x_ + y
83
+ if not self.pre_ln:
84
+ x = self.norm_layers_1[i](x)
85
+
86
+ x_ = x
87
+ if self.pre_ln:
88
+ x = self.norm_layers_2[i](x)
89
+ y = self.ffn_layers[i](x, x_mask)
90
+ y = self.drop(y)
91
+ x = x_ + y
92
+ if not self.pre_ln:
93
+ x = self.norm_layers_2[i](x)
94
+ if self.pre_ln:
95
+ x = self.last_ln(x)
96
+ x = x * x_mask
97
+ return x
98
+
99
+
100
+ class MultiHeadAttention(nn.Module):
101
+ def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
102
+ block_length=None, proximal_bias=False, proximal_init=False):
103
+ super().__init__()
104
+ assert channels % n_heads == 0
105
+
106
+ self.channels = channels
107
+ self.out_channels = out_channels
108
+ self.n_heads = n_heads
109
+ self.window_size = window_size
110
+ self.heads_share = heads_share
111
+ self.block_length = block_length
112
+ self.proximal_bias = proximal_bias
113
+ self.p_dropout = p_dropout
114
+ self.attn = None
115
+
116
+ self.k_channels = channels // n_heads
117
+ self.conv_q = nn.Conv1d(channels, channels, 1)
118
+ self.conv_k = nn.Conv1d(channels, channels, 1)
119
+ self.conv_v = nn.Conv1d(channels, channels, 1)
120
+ if window_size is not None:
121
+ n_heads_rel = 1 if heads_share else n_heads
122
+ rel_stddev = self.k_channels ** -0.5
123
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
124
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
125
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
126
+ self.drop = nn.Dropout(p_dropout)
127
+
128
+ nn.init.xavier_uniform_(self.conv_q.weight)
129
+ nn.init.xavier_uniform_(self.conv_k.weight)
130
+ if proximal_init:
131
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
132
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
133
+ nn.init.xavier_uniform_(self.conv_v.weight)
134
+
135
+ def forward(self, x, c, attn_mask=None):
136
+ q = self.conv_q(x)
137
+ k = self.conv_k(c)
138
+ v = self.conv_v(c)
139
+
140
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
141
+
142
+ x = self.conv_o(x)
143
+ return x
144
+
145
+ def attention(self, query, key, value, mask=None):
146
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
147
+ b, d, t_s, t_t = (*key.size(), query.size(2))
148
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
149
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
150
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
151
+
152
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
153
+ if self.window_size is not None:
154
+ assert t_s == t_t, "Relative attention is only available for self-attention."
155
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
156
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
157
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
158
+ scores_local = rel_logits / math.sqrt(self.k_channels)
159
+ scores = scores + scores_local
160
+ if self.proximal_bias:
161
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
162
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
163
+ if mask is not None:
164
+ scores = scores.masked_fill(mask == 0, -1e4)
165
+ if self.block_length is not None:
166
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
167
+ scores = scores * block_mask + -1e4 * (1 - block_mask)
168
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
169
+ p_attn = self.drop(p_attn)
170
+ output = torch.matmul(p_attn, value)
171
+ if self.window_size is not None:
172
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
173
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
174
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
175
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
176
+ return output, p_attn
177
+
178
+ def _matmul_with_relative_values(self, x, y):
179
+ """
180
+ x: [b, h, l, m]
181
+ y: [h or 1, m, d]
182
+ ret: [b, h, l, d]
183
+ """
184
+ ret = torch.matmul(x, y.unsqueeze(0))
185
+ return ret
186
+
187
+ def _matmul_with_relative_keys(self, x, y):
188
+ """
189
+ x: [b, h, l, d]
190
+ y: [h or 1, m, d]
191
+ ret: [b, h, l, m]
192
+ """
193
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
194
+ return ret
195
+
196
+ def _get_relative_embeddings(self, relative_embeddings, length):
197
+ max_relative_position = 2 * self.window_size + 1
198
+ # Pad first before slice to avoid using cond ops.
199
+ pad_length = max(length - (self.window_size + 1), 0)
200
+ slice_start_position = max((self.window_size + 1) - length, 0)
201
+ slice_end_position = slice_start_position + 2 * length - 1
202
+ if pad_length > 0:
203
+ padded_relative_embeddings = F.pad(
204
+ relative_embeddings,
205
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
206
+ else:
207
+ padded_relative_embeddings = relative_embeddings
208
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
209
+ return used_relative_embeddings
210
+
211
+ def _relative_position_to_absolute_position(self, x):
212
+ """
213
+ x: [b, h, l, 2*l-1]
214
+ ret: [b, h, l, l]
215
+ """
216
+ batch, heads, length, _ = x.size()
217
+ # Concat columns of pad to shift from relative to absolute indexing.
218
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
219
+
220
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
221
+ x_flat = x.view([batch, heads, length * 2 * length])
222
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
223
+
224
+ # Reshape and slice out the padded elements.
225
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
226
+ return x_final
227
+
228
+ def _absolute_position_to_relative_position(self, x):
229
+ """
230
+ x: [b, h, l, l]
231
+ ret: [b, h, l, 2*l-1]
232
+ """
233
+ batch, heads, length, _ = x.size()
234
+ # padd along column
235
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
236
+ x_flat = x.view([batch, heads, -1])
237
+ # add 0's in the beginning that will skew the elements after reshape
238
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
239
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
240
+ return x_final
241
+
242
+ def _attention_bias_proximal(self, length):
243
+ """Bias for self-attention to encourage attention to close positions.
244
+ Args:
245
+ length: an integer scalar.
246
+ Returns:
247
+ a Tensor with shape [1, 1, length, length]
248
+ """
249
+ r = torch.arange(length, dtype=torch.float32)
250
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
251
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
252
+
253
+
254
+ class FFN(nn.Module):
255
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
256
+ super().__init__()
257
+ self.in_channels = in_channels
258
+ self.out_channels = out_channels
259
+ self.filter_channels = filter_channels
260
+ self.kernel_size = kernel_size
261
+ self.p_dropout = p_dropout
262
+ self.activation = activation
263
+
264
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
265
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
266
+ self.drop = nn.Dropout(p_dropout)
267
+
268
+ def forward(self, x, x_mask):
269
+ x = self.conv_1(x * x_mask)
270
+ if self.activation == "gelu":
271
+ x = x * torch.sigmoid(1.702 * x)
272
+ else:
273
+ x = torch.relu(x)
274
+ x = self.drop(x)
275
+ x = self.conv_2(x * x_mask)
276
+ return x * x_mask
277
+
278
+
279
+ class LayerNorm(nn.Module):
280
+ def __init__(self, channels, eps=1e-4):
281
+ super().__init__()
282
+ self.channels = channels
283
+ self.eps = eps
284
+
285
+ self.gamma = nn.Parameter(torch.ones(channels))
286
+ self.beta = nn.Parameter(torch.zeros(channels))
287
+
288
+ def forward(self, x):
289
+ n_dims = len(x.shape)
290
+ mean = torch.mean(x, 1, keepdim=True)
291
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
292
+
293
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
294
+
295
+ shape = [1, -1] + [1] * (n_dims - 2)
296
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
297
+ return x
298
+
299
+
300
+ class ConvReluNorm(nn.Module):
301
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
302
+ super().__init__()
303
+ self.in_channels = in_channels
304
+ self.hidden_channels = hidden_channels
305
+ self.out_channels = out_channels
306
+ self.kernel_size = kernel_size
307
+ self.n_layers = n_layers
308
+ self.p_dropout = p_dropout
309
+ assert n_layers > 1, "Number of layers should be larger than 0."
310
+
311
+ self.conv_layers = nn.ModuleList()
312
+ self.norm_layers = nn.ModuleList()
313
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
314
+ self.norm_layers.append(LayerNorm(hidden_channels))
315
+ self.relu_drop = nn.Sequential(
316
+ nn.ReLU(),
317
+ nn.Dropout(p_dropout))
318
+ for _ in range(n_layers - 1):
319
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
320
+ self.norm_layers.append(LayerNorm(hidden_channels))
321
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
322
+ self.proj.weight.data.zero_()
323
+ self.proj.bias.data.zero_()
324
+
325
+ def forward(self, x, x_mask):
326
+ x_org = x
327
+ for i in range(self.n_layers):
328
+ x = self.conv_layers[i](x * x_mask)
329
+ x = self.norm_layers[i](x)
330
+ x = self.relu_drop(x)
331
+ x = x_org + self.proj(x)
332
+ return x * x_mask
333
+
334
+
335
+ class RelTransformerEncoder(nn.Module):
336
+ def __init__(self,
337
+ n_vocab,
338
+ out_channels,
339
+ hidden_channels,
340
+ filter_channels,
341
+ n_heads,
342
+ n_layers,
343
+ kernel_size,
344
+ p_dropout=0.0,
345
+ window_size=4,
346
+ block_length=None,
347
+ in_channels=None,
348
+ prenet=True,
349
+ pre_ln=True,
350
+ ):
351
+
352
+ super().__init__()
353
+
354
+ self.n_vocab = n_vocab
355
+ self.out_channels = out_channels
356
+ self.hidden_channels = hidden_channels
357
+ self.filter_channels = filter_channels
358
+ self.n_heads = n_heads
359
+ self.n_layers = n_layers
360
+ self.kernel_size = kernel_size
361
+ self.p_dropout = p_dropout
362
+ self.window_size = window_size
363
+ self.block_length = block_length
364
+ self.prenet = prenet
365
+ if n_vocab > 0:
366
+ self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
367
+
368
+ if prenet:
369
+ if in_channels is None:
370
+ in_channels = hidden_channels
371
+ self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
372
+ kernel_size=5, n_layers=3, p_dropout=0)
373
+ if in_channels is not None and in_channels != hidden_channels:
374
+ self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
375
+ self.encoder = Encoder(
376
+ hidden_channels,
377
+ filter_channels,
378
+ n_heads,
379
+ n_layers,
380
+ kernel_size,
381
+ p_dropout,
382
+ window_size=window_size,
383
+ block_length=block_length,
384
+ pre_ln=pre_ln,
385
+ )
386
+
387
+ def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
388
+ if self.n_vocab > 0:
389
+ x_lengths = (x > 0).long().sum(-1)
390
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
391
+ else:
392
+ x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
393
+ x = x + other_embeds
394
+ x = torch.transpose(x, 1, -1) # [b, h, t]
395
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
396
+
397
+ if self.prenet:
398
+ x = self.pre(x, x_mask)
399
+ self.prenet_out = x.transpose(1, 2)
400
+ if hasattr(self, 'encoder_inp_proj'):
401
+ x = self.encoder_inp_proj(x) * x_mask
402
+ x = self.encoder(x, x_mask, attn_mask)
403
+ return x.transpose(1, 2)
tts/modules/ar_dur/commons/rot_transformer.py ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import torch
17
+ from typing import Optional, Tuple
18
+ from torch import nn
19
+ from torch.nn import Parameter, Linear
20
+ from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
21
+ from tts.modules.ar_dur.commons.transformer import TransformerFFNLayer, MultiheadAttention
22
+ from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
23
+ import torch.nn.functional as F
24
+
25
+ DEFAULT_MAX_SOURCE_POSITIONS = 3000
26
+ DEFAULT_MAX_TARGET_POSITIONS = 3000
27
+
28
+
29
+ class SinusoidalPositionalEmbedding(nn.Module):
30
+ """This module produces sinusoidal positional embeddings of any length.
31
+
32
+ Padding symbols are ignored.
33
+ """
34
+
35
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
36
+ super().__init__()
37
+ self.embedding_dim = embedding_dim
38
+ self.padding_idx = padding_idx
39
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
40
+ init_size,
41
+ embedding_dim,
42
+ padding_idx,
43
+ )
44
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
45
+
46
+ @staticmethod
47
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
48
+ """Build sinusoidal embeddings.
49
+
50
+ This matches the implementation in tensor2tensor, but differs slightly
51
+ from the description in Section 3.5 of "Attention Is All You Need".
52
+ """
53
+ half_dim = embedding_dim // 2
54
+ emb = math.log(10000) / (half_dim - 1)
55
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
56
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
57
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
58
+ if embedding_dim % 2 == 1:
59
+ # zero pad
60
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
61
+ if padding_idx is not None:
62
+ emb[padding_idx, :] = 0
63
+ return emb
64
+
65
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
66
+ """Input is expected to be of size [bsz x seqlen]."""
67
+ bsz, seq_len = input.shape[:2]
68
+ max_pos = self.padding_idx + 1 + seq_len
69
+ if self.weights is None or max_pos > self.weights.size(0):
70
+ # recompute/expand embeddings if needed
71
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
72
+ max_pos,
73
+ self.embedding_dim,
74
+ self.padding_idx,
75
+ )
76
+ self.weights = self.weights.to(self._float_tensor)
77
+
78
+ if incremental_state is not None:
79
+ # positions is the same for every token when decoding a single step
80
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
81
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
82
+
83
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
84
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
85
+
86
+ def max_positions(self):
87
+ """Maximum number of supported positions."""
88
+ return int(1e5) # an arbitrary large number
89
+
90
+
91
+ class RotaryEmbeddings(nn.Module):
92
+ cos: torch.Tensor
93
+ sin: torch.Tensor
94
+ theta: torch.Tensor
95
+
96
+ def __init__(
97
+ self,
98
+ width: int,
99
+ *,
100
+ seq_len: int = 40000,
101
+ base: int = 10000,
102
+ device: Optional[torch.device] = None,
103
+ ):
104
+ """Rotary embeddings (Su et al., 2021) layer. The rotary embedding
105
+ will be precomputed for up to 'seq _len' positions. The embedding
106
+ will be recomputed when a longer sequence is found in the input.
107
+
108
+ :param width:
109
+ Rotary embedding dimensionality, must be even.
110
+ :param seq_len:
111
+ Number of positons to initially precompute.
112
+ :param base:
113
+ The base used for Θ_i, determines the cycle length of the
114
+ embeddings.
115
+ :param device: Device on which the module is to be initialized.
116
+ """
117
+ super().__init__()
118
+
119
+ if width % 2:
120
+ raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
121
+
122
+ # Ignore allocations on the meta device as we don't persist our buffer,
123
+ # i.e., we don't expect the backing tensor to be replaced with pretrained weights.
124
+ if device is not None and device.type == "meta":
125
+ device = None
126
+ # Θ_i = 10000^(-2(i-1)/d)
127
+ theta = torch.pow(
128
+ base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
129
+ )
130
+ self.register_buffer("theta", theta, persistent=False)
131
+
132
+ self._create_rotary_embed(width=width, length=seq_len)
133
+
134
+ def _create_rotary_embed(self, *, width: int, length: int):
135
+ # mΘ
136
+ position = torch.arange(length, device=self.theta.device).unsqueeze(1)
137
+ m_theta = position * self.theta.unsqueeze(0)
138
+
139
+ # We apply both sin and cos twice (see Eq 15, 34), but the ordering
140
+ # is changed for compatibility with most common implementations.
141
+ m_theta = torch.cat([m_theta, m_theta], dim=-1)
142
+
143
+ re_cos = m_theta.cos().view([length, width])
144
+ re_sin = m_theta.sin().view([length, width])
145
+
146
+ self.register_buffer("cos", re_cos, persistent=False)
147
+ self.register_buffer("sin", re_sin, persistent=False)
148
+
149
+ def _rotate(self, input: torch.Tensor):
150
+ """Rotate the input tensor by half of its innermost width.
151
+
152
+ input (Tensor): array to rotate.
153
+ RETURNS (Tensor): rotated array.
154
+
155
+ Shapes:
156
+ input - (..., width)
157
+ output - (..., width)
158
+ """
159
+ half_idx = input.shape[-1] // 2
160
+ input_1 = -input[..., half_idx:]
161
+ input_2 = input[..., :half_idx]
162
+ return torch.cat([input_1, input_2], dim=-1)
163
+
164
+ def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
165
+ """
166
+ Apply rotary embeddings to an array.
167
+
168
+ :param input: Array to apply the rotary embeddings to.
169
+ :param positions: positions of the inputs. If no positions are
170
+ provided, they are assumed to be [0, seq_len).
171
+ :return: Array with the rotary embeddings applied.
172
+
173
+ Shapes:
174
+ input - (batch_size, num_heads, seq_len, width_per_head)
175
+ positions - (batch_size, seq_len)
176
+ output - (batch_size, num_heads, seq_len, width_per_head)
177
+ """
178
+ batch_size, _, seq_len, width = input.shape
179
+
180
+ if positions is None:
181
+ # Fastpath: positions from [0..seq_len), avoid indexing.
182
+ if self.cos.size(-2) < seq_len:
183
+ self._create_rotary_embed(width=width, length=seq_len)
184
+ rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
185
+ rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
186
+ else:
187
+ max_len = int(positions.max()) + 1
188
+ if self.cos.size(-2) < max_len:
189
+ self._create_rotary_embed(width=width, length=max_len)
190
+
191
+ # Flatten positions to index cos/sin arrays, then unflatten.
192
+ #
193
+ # Example shapes:
194
+ #
195
+ # positions_flat - (batch_size * seq_len)
196
+ # self.cos - (max_len, width)
197
+ # rot_cos - (batch_size, seq_len, width)
198
+ positions_flat = positions.view(-1)
199
+ rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
200
+ rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)
201
+
202
+ # Eq 34 with ordering changed for compatibility.
203
+ return rot_cos * input + rot_sin * self._rotate(input)
204
+
205
+
206
+ class RotMultiheadAttention(MultiheadAttention):
207
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
208
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
209
+ encoder_decoder_attention=False):
210
+ super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
211
+ add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
212
+ encoder_decoder_attention=encoder_decoder_attention)
213
+ self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
214
+
215
+ def forward(
216
+ self,
217
+ query, key, value,
218
+ spk_pos_ids_flat=None,
219
+ key_padding_mask=None,
220
+ incremental_state=None,
221
+ need_weights=True,
222
+ static_kv=False,
223
+ attn_mask=None,
224
+ before_softmax=False,
225
+ need_head_weights=False,
226
+ enc_dec_attn_constraint_mask=None,
227
+ reset_attn_weight=None
228
+ ):
229
+ """Input shape: Time x Batch x Channel
230
+
231
+ Args:
232
+ key_padding_mask (ByteTensor, optional): mask to exclude
233
+ keys that are pads, of shape `(batch, src_len)`, where
234
+ padding elements are indicated by 1s.
235
+ need_weights (bool, optional): return the attention weights,
236
+ averaged over heads (default: False).
237
+ attn_mask (ByteTensor, optional): typically used to
238
+ implement causal attention, where the mask prevents the
239
+ attention from looking forward in time (default: None).
240
+ before_softmax (bool, optional): return the raw attention
241
+ weights and values before the attention softmax.
242
+ need_head_weights (bool, optional): return the attention
243
+ weights for each head. Implies *need_weights*. Default:
244
+ return the average attention weights over all heads.
245
+ """
246
+ if need_head_weights:
247
+ need_weights = True
248
+
249
+ tgt_len, bsz, embed_dim = query.size()
250
+ assert embed_dim == self.embed_dim
251
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
252
+
253
+ if incremental_state is not None:
254
+ saved_state = self._get_input_buffer(incremental_state)
255
+ if 'prev_key' in saved_state:
256
+ # previous time steps are cached - no need to recompute
257
+ # key and value if they are static
258
+ if static_kv:
259
+ assert self.encoder_decoder_attention and not self.self_attention
260
+ key = value = None
261
+ else:
262
+ saved_state = None
263
+
264
+ if self.self_attention:
265
+ # self-attention
266
+ q, k, v = self.in_proj_qkv(query)
267
+ elif self.encoder_decoder_attention:
268
+ # encoder-decoder attention
269
+ q = self.in_proj_q(query)
270
+ if key is None:
271
+ assert value is None
272
+ k = v = None
273
+ else:
274
+ k = self.in_proj_k(key)
275
+ v = self.in_proj_v(key)
276
+ else:
277
+ q = self.in_proj_q(query)
278
+ k = self.in_proj_k(key)
279
+ v = self.in_proj_v(value)
280
+ q = q * self.scaling
281
+
282
+ if self.bias_k is not None:
283
+ assert self.bias_v is not None
284
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
285
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
286
+ if attn_mask is not None:
287
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
288
+ if key_padding_mask is not None:
289
+ key_padding_mask = torch.cat(
290
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
291
+
292
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
293
+ if k is not None:
294
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
295
+ if v is not None:
296
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
297
+
298
+ # Apply rot embedding and store incremental_state
299
+ q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
300
+ if saved_state is not None:
301
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
302
+ if 'prev_key' in saved_state:
303
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
304
+ if static_kv:
305
+ k = prev_key
306
+ else:
307
+ k = torch.cat((prev_key, k), dim=1)
308
+ if 'prev_value' in saved_state:
309
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
310
+ if static_kv:
311
+ v = prev_value
312
+ else:
313
+ v = torch.cat((prev_value, v), dim=1)
314
+ saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
315
+ bsz, self.num_heads, -1, self.head_dim)
316
+ self._set_input_buffer(incremental_state, saved_state)
317
+ if incremental_state is not None:
318
+ key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
319
+ else:
320
+ key_pos = spk_pos_ids_flat
321
+ k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
322
+
323
+ src_len = k.size(1)
324
+
325
+ # This is part of a workaround to get around fork/join parallelism
326
+ # not supporting Optional types.
327
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
328
+ key_padding_mask = None
329
+
330
+ if key_padding_mask is not None:
331
+ assert key_padding_mask.size(0) == bsz
332
+ assert key_padding_mask.size(1) == src_len
333
+
334
+ if self.add_zero_attn:
335
+ src_len += 1
336
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
337
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
338
+ if attn_mask is not None:
339
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
340
+ if key_padding_mask is not None:
341
+ key_padding_mask = torch.cat(
342
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
343
+
344
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
345
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
346
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
347
+
348
+ if attn_mask is not None:
349
+ if len(attn_mask.shape) == 2:
350
+ attn_mask = attn_mask.unsqueeze(0)
351
+ elif len(attn_mask.shape) == 3:
352
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
353
+ bsz * self.num_heads, tgt_len, src_len)
354
+ attn_weights = attn_weights + attn_mask
355
+
356
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
357
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
358
+ attn_weights = attn_weights.masked_fill(
359
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
360
+ -1e8,
361
+ )
362
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
363
+
364
+ if key_padding_mask is not None:
365
+ # don't attend to padding symbols
366
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
367
+ attn_weights = attn_weights.masked_fill(
368
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
369
+ -1e8,
370
+ )
371
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
372
+
373
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
374
+
375
+ if before_softmax:
376
+ return attn_weights, v
377
+
378
+ attn_weights_float = softmax(attn_weights, dim=-1)
379
+ attn_weights = attn_weights_float.type_as(attn_weights)
380
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
381
+
382
+ if reset_attn_weight is not None:
383
+ if reset_attn_weight:
384
+ self.last_attn_probs = attn_probs.detach()
385
+ else:
386
+ assert self.last_attn_probs is not None
387
+ attn_probs = self.last_attn_probs
388
+ attn = torch.bmm(attn_probs, v)
389
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
390
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
391
+ attn = self.out_proj(attn)
392
+
393
+ if need_weights:
394
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
395
+ if not need_head_weights:
396
+ # average attention weights over heads
397
+ attn_weights = attn_weights.mean(dim=0)
398
+ else:
399
+ attn_weights = None
400
+
401
+ return attn, (attn_weights, attn_logits)
402
+
403
+
404
+ class RotMultiheadAttention2(MultiheadAttention):
405
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
406
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
407
+ encoder_decoder_attention=False):
408
+ super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
409
+ add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
410
+ encoder_decoder_attention=encoder_decoder_attention)
411
+ self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
412
+
413
+ def forward(
414
+ self,
415
+ query, key, value,
416
+ spk_pos_ids_flat=None,
417
+ key_padding_mask=None,
418
+ incremental_state=None,
419
+ need_weights=True,
420
+ static_kv=False,
421
+ attn_mask=None,
422
+ before_softmax=False,
423
+ need_head_weights=False,
424
+ enc_dec_attn_constraint_mask=None,
425
+ reset_attn_weight=None
426
+ ):
427
+ """Input shape: Time x Batch x Channel
428
+
429
+ Args:
430
+ key_padding_mask (ByteTensor, optional): mask to exclude
431
+ keys that are pads, of shape `(batch, src_len)`, where
432
+ padding elements are indicated by 1s.
433
+ need_weights (bool, optional): return the attention weights,
434
+ averaged over heads (default: False).
435
+ attn_mask (ByteTensor, optional): typically used to
436
+ implement causal attention, where the mask prevents the
437
+ attention from looking forward in time (default: None).
438
+ before_softmax (bool, optional): return the raw attention
439
+ weights and values before the attention softmax.
440
+ need_head_weights (bool, optional): return the attention
441
+ weights for each head. Implies *need_weights*. Default:
442
+ return the average attention weights over all heads.
443
+ """
444
+ if need_head_weights:
445
+ need_weights = True
446
+
447
+ tgt_len, bsz, embed_dim = query.size()
448
+ assert embed_dim == self.embed_dim
449
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
450
+
451
+ if incremental_state is not None:
452
+ saved_state = self._get_input_buffer(incremental_state)
453
+ if 'prev_key' in saved_state:
454
+ # previous time steps are cached - no need to recompute
455
+ # key and value if they are static
456
+ if static_kv:
457
+ assert self.encoder_decoder_attention and not self.self_attention
458
+ key = value = None
459
+ else:
460
+ saved_state = None
461
+
462
+ if self.self_attention:
463
+ # self-attention
464
+ q, k, v = self.in_proj_qkv(query)
465
+ elif self.encoder_decoder_attention:
466
+ # encoder-decoder attention
467
+ q = self.in_proj_q(query)
468
+ if key is None:
469
+ assert value is None
470
+ k = v = None
471
+ else:
472
+ k = self.in_proj_k(key)
473
+ v = self.in_proj_v(key)
474
+ else:
475
+ q = self.in_proj_q(query)
476
+ k = self.in_proj_k(key)
477
+ v = self.in_proj_v(value)
478
+
479
+ if self.bias_k is not None:
480
+ assert self.bias_v is not None
481
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
482
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
483
+ if attn_mask is not None:
484
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
485
+ if key_padding_mask is not None:
486
+ key_padding_mask = torch.cat(
487
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
488
+
489
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
490
+ if k is not None:
491
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
492
+ if v is not None:
493
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
494
+
495
+ # Apply rot embedding and store incremental_state
496
+ q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
497
+ if saved_state is not None:
498
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
499
+ if 'prev_key' in saved_state:
500
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
501
+ if static_kv:
502
+ k = prev_key
503
+ else:
504
+ k = torch.cat((prev_key, k), dim=1)
505
+ if 'prev_value' in saved_state:
506
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
507
+ if static_kv:
508
+ v = prev_value
509
+ else:
510
+ v = torch.cat((prev_value, v), dim=1)
511
+ saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
512
+ bsz, self.num_heads, -1, self.head_dim)
513
+ self._set_input_buffer(incremental_state, saved_state)
514
+ key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
515
+ k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
516
+
517
+ src_len = k.size(1)
518
+
519
+ # This is part of a workaround to get around fork/join parallelism
520
+ # not supporting Optional types.
521
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
522
+ key_padding_mask = None
523
+
524
+ if key_padding_mask is not None:
525
+ assert key_padding_mask.size(0) == bsz
526
+ assert key_padding_mask.size(1) == src_len
527
+
528
+ if attn_mask is not None:
529
+ if len(attn_mask.shape) == 2:
530
+ attn_mask = attn_mask.unsqueeze(0)
531
+ elif len(attn_mask.shape) == 3:
532
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
533
+ bsz * self.num_heads, tgt_len, src_len)
534
+ attn = torch.nn.functional.scaled_dot_product_attention(
535
+ q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False)
536
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
537
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
538
+ attn_logits = None
539
+ attn_weights = None
540
+ return attn, (attn_weights, attn_logits)
541
+
542
+
543
+ class RotDecSALayer(nn.Module):
544
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
545
+ kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, bias=True):
546
+ super().__init__()
547
+ self.c = c
548
+ self.dropout = dropout
549
+ self.layer_norm1 = LayerNorm(c)
550
+ self.self_attn = RotMultiheadAttention(
551
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
552
+ )
553
+ self.layer_norm2 = LayerNorm(c)
554
+ self.ffn = TransformerFFNLayer(
555
+ c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size,
556
+ dropout=relu_dropout, act=act, bias=bias)
557
+ self.post_ln = post_ln
558
+
559
+ def forward(
560
+ self,
561
+ x,
562
+ encoder_out=None,
563
+ encoder_padding_mask=None,
564
+ incremental_state=None,
565
+ self_attn_mask=None,
566
+ self_attn_padding_mask=None,
567
+ attn_out=None,
568
+ reset_attn_weight=None,
569
+ spk_pos_ids_flat=None,
570
+ **kwargs,
571
+ ):
572
+ layer_norm_training = kwargs.get('layer_norm_training', None)
573
+ if layer_norm_training is not None:
574
+ self.layer_norm1.training = layer_norm_training
575
+ self.layer_norm2.training = layer_norm_training
576
+ residual = x
577
+ if not self.post_ln:
578
+ x = self.layer_norm1(x)
579
+
580
+ x, (attn_weights, _) = self.self_attn(
581
+ query=x,
582
+ key=x,
583
+ value=x,
584
+ key_padding_mask=self_attn_padding_mask,
585
+ incremental_state=incremental_state,
586
+ attn_mask=self_attn_mask,
587
+ spk_pos_ids_flat=spk_pos_ids_flat
588
+ )
589
+ x = F.dropout(x, self.dropout, training=self.training)
590
+ x = residual + x
591
+ if self.post_ln:
592
+ x = self.layer_norm1(x)
593
+
594
+ residual = x
595
+ if not self.post_ln:
596
+ x = self.layer_norm2(x)
597
+ x = self.ffn(x, incremental_state=incremental_state)
598
+ x = F.dropout(x, self.dropout, training=self.training)
599
+ x = residual + x
600
+ if self.post_ln:
601
+ x = self.layer_norm2(x)
602
+ return x, attn_weights
603
+
604
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
605
+ self.encoder_attn.clear_buffer(incremental_state)
606
+ self.ffn.clear_buffer(incremental_state)
607
+
608
+ def set_buffer(self, name, tensor, incremental_state):
609
+ return set_incremental_state(self, incremental_state, name, tensor)
610
+
611
+
612
+ class RotDecSALayer2(RotDecSALayer):
613
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9,
614
+ ffn_hidden_size=1024, act='gelu', post_ln=False):
615
+ super().__init__(c, num_heads, dropout, attention_dropout, relu_dropout, kernel_size, ffn_hidden_size, act,
616
+ post_ln)
617
+ self.self_attn = RotMultiheadAttention2(
618
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
619
+ )
620
+
621
+
622
+ class RotTransformerDecoderLayer(nn.Module):
623
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
624
+ op_version=1, bias=True):
625
+ super().__init__()
626
+ self.hidden_size = hidden_size
627
+ self.dropout = dropout
628
+ self.num_heads = num_heads
629
+ if op_version == 1:
630
+ self.op = RotDecSALayer(
631
+ hidden_size, num_heads, dropout=dropout,
632
+ attention_dropout=0.0, relu_dropout=dropout,
633
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
634
+ post_ln=post_ln, bias=bias)
635
+ else:
636
+ self.op = RotDecSALayer2(
637
+ hidden_size, num_heads, dropout=dropout,
638
+ attention_dropout=0.0, relu_dropout=dropout,
639
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
640
+ post_ln=post_ln)
641
+
642
+ def forward(self, x, **kwargs):
643
+ return self.op(x, **kwargs)
644
+
645
+ def clear_buffer(self, *args):
646
+ return self.op.clear_buffer(*args)
647
+
648
+ def set_buffer(self, *args):
649
+ return self.op.set_buffer(*args)
tts/modules/ar_dur/commons/seq_utils.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import defaultdict
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+
20
+ def make_positions(tensor, padding_idx):
21
+ """Replace non-padding symbols with their position numbers.
22
+
23
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
24
+ """
25
+ # The series of casts and type-conversions here are carefully
26
+ # balanced to both work with ONNX export and XLA. In particular XLA
27
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
28
+ # how to handle the dtype kwarg in cumsum.
29
+ mask = tensor.ne(padding_idx).int()
30
+ return (
31
+ torch.cumsum(mask, dim=1).type_as(mask) * mask
32
+ ).long() + padding_idx
33
+
34
+
35
+ def softmax(x, dim):
36
+ return F.softmax(x, dim=dim, dtype=torch.float32)
37
+
38
+
39
+ def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
40
+ if maxlen is None:
41
+ maxlen = lengths.max()
42
+ mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
43
+ mask.type(dtype)
44
+ return mask
45
+
46
+
47
+ def weights_nonzero_speech(target):
48
+ # target : B x T x mel
49
+ # Assign weight 1.0 to all labels except for padding (id=0).
50
+ dim = target.size(-1)
51
+ return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
52
+
53
+
54
+ INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
55
+
56
+
57
+ def _get_full_incremental_state_key(module_instance, key):
58
+ module_name = module_instance.__class__.__name__
59
+
60
+ # assign a unique ID to each module instance, so that incremental state is
61
+ # not shared across module instances
62
+ if not hasattr(module_instance, '_instance_id'):
63
+ INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
64
+ module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
65
+
66
+ return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
67
+
68
+
69
+ def get_incremental_state(module, incremental_state, key):
70
+ """Helper for getting incremental state for an nn.Module."""
71
+ full_key = _get_full_incremental_state_key(module, key)
72
+ if incremental_state is None or full_key not in incremental_state:
73
+ return None
74
+ return incremental_state[full_key]
75
+
76
+
77
+ def set_incremental_state(module, incremental_state, key, value):
78
+ """Helper for setting incremental state for an nn.Module."""
79
+ if incremental_state is not None:
80
+ full_key = _get_full_incremental_state_key(module, key)
81
+ incremental_state[full_key] = value
82
+
83
+
84
+ def fill_with_neg_inf(t):
85
+ """FP16-compatible function that fills a tensor with -inf."""
86
+ return t.float().fill_(float('-inf')).type_as(t)
87
+
88
+
89
+ def fill_with_neg_inf2(t):
90
+ """FP16-compatible function that fills a tensor with -inf."""
91
+ return t.float().fill_(-1e8).type_as(t)
92
+
93
+
94
+ def select_attn(attn_logits, type='best'):
95
+ """
96
+
97
+ :param attn_logits: [n_layers, B, n_head, T_sp, T_txt]
98
+ :return:
99
+ """
100
+ encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
101
+ # [n_layers * n_head, B, T_sp, T_txt]
102
+ encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
103
+ if type == 'best':
104
+ indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
105
+ encdec_attn = encdec_attn.gather(
106
+ 0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
107
+ return encdec_attn
108
+ elif type == 'mean':
109
+ return encdec_attn.mean(0)
110
+
111
+
112
+ def make_pad_mask(lengths, xs=None, length_dim=-1):
113
+ """Make mask tensor containing indices of padded part.
114
+ Args:
115
+ lengths (LongTensor or List): Batch of lengths (B,).
116
+ xs (Tensor, optional): The reference tensor.
117
+ If set, masks will be the same shape as this tensor.
118
+ length_dim (int, optional): Dimension indicator of the above tensor.
119
+ See the example.
120
+ Returns:
121
+ Tensor: Mask tensor containing indices of padded part.
122
+ dtype=torch.uint8 in PyTorch 1.2-
123
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
124
+ Examples:
125
+ With only lengths.
126
+ >>> lengths = [5, 3, 2]
127
+ >>> make_non_pad_mask(lengths)
128
+ masks = [[0, 0, 0, 0 ,0],
129
+ [0, 0, 0, 1, 1],
130
+ [0, 0, 1, 1, 1]]
131
+ With the reference tensor.
132
+ >>> xs = torch.zeros((3, 2, 4))
133
+ >>> make_pad_mask(lengths, xs)
134
+ tensor([[[0, 0, 0, 0],
135
+ [0, 0, 0, 0]],
136
+ [[0, 0, 0, 1],
137
+ [0, 0, 0, 1]],
138
+ [[0, 0, 1, 1],
139
+ [0, 0, 1, 1]]], dtype=torch.uint8)
140
+ >>> xs = torch.zeros((3, 2, 6))
141
+ >>> make_pad_mask(lengths, xs)
142
+ tensor([[[0, 0, 0, 0, 0, 1],
143
+ [0, 0, 0, 0, 0, 1]],
144
+ [[0, 0, 0, 1, 1, 1],
145
+ [0, 0, 0, 1, 1, 1]],
146
+ [[0, 0, 1, 1, 1, 1],
147
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
148
+ With the reference tensor and dimension indicator.
149
+ >>> xs = torch.zeros((3, 6, 6))
150
+ >>> make_pad_mask(lengths, xs, 1)
151
+ tensor([[[0, 0, 0, 0, 0, 0],
152
+ [0, 0, 0, 0, 0, 0],
153
+ [0, 0, 0, 0, 0, 0],
154
+ [0, 0, 0, 0, 0, 0],
155
+ [0, 0, 0, 0, 0, 0],
156
+ [1, 1, 1, 1, 1, 1]],
157
+ [[0, 0, 0, 0, 0, 0],
158
+ [0, 0, 0, 0, 0, 0],
159
+ [0, 0, 0, 0, 0, 0],
160
+ [1, 1, 1, 1, 1, 1],
161
+ [1, 1, 1, 1, 1, 1],
162
+ [1, 1, 1, 1, 1, 1]],
163
+ [[0, 0, 0, 0, 0, 0],
164
+ [0, 0, 0, 0, 0, 0],
165
+ [1, 1, 1, 1, 1, 1],
166
+ [1, 1, 1, 1, 1, 1],
167
+ [1, 1, 1, 1, 1, 1],
168
+ [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
169
+ >>> make_pad_mask(lengths, xs, 2)
170
+ tensor([[[0, 0, 0, 0, 0, 1],
171
+ [0, 0, 0, 0, 0, 1],
172
+ [0, 0, 0, 0, 0, 1],
173
+ [0, 0, 0, 0, 0, 1],
174
+ [0, 0, 0, 0, 0, 1],
175
+ [0, 0, 0, 0, 0, 1]],
176
+ [[0, 0, 0, 1, 1, 1],
177
+ [0, 0, 0, 1, 1, 1],
178
+ [0, 0, 0, 1, 1, 1],
179
+ [0, 0, 0, 1, 1, 1],
180
+ [0, 0, 0, 1, 1, 1],
181
+ [0, 0, 0, 1, 1, 1]],
182
+ [[0, 0, 1, 1, 1, 1],
183
+ [0, 0, 1, 1, 1, 1],
184
+ [0, 0, 1, 1, 1, 1],
185
+ [0, 0, 1, 1, 1, 1],
186
+ [0, 0, 1, 1, 1, 1],
187
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
188
+ """
189
+ if length_dim == 0:
190
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
191
+
192
+ if not isinstance(lengths, list):
193
+ lengths = lengths.tolist()
194
+ bs = int(len(lengths))
195
+ if xs is None:
196
+ maxlen = int(max(lengths))
197
+ else:
198
+ maxlen = xs.size(length_dim)
199
+
200
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
201
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
202
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
203
+ mask = seq_range_expand >= seq_length_expand
204
+
205
+ if xs is not None:
206
+ assert xs.size(0) == bs, (xs.size(0), bs)
207
+
208
+ if length_dim < 0:
209
+ length_dim = xs.dim() + length_dim
210
+ # ind = (:, None, ..., None, :, , None, ..., None)
211
+ ind = tuple(
212
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
213
+ )
214
+ mask = mask[ind].expand_as(xs).to(xs.device)
215
+ return mask
216
+
217
+
218
+ def make_non_pad_mask(lengths, xs=None, length_dim=-1):
219
+ """Make mask tensor containing indices of non-padded part.
220
+ Args:
221
+ lengths (LongTensor or List): Batch of lengths (B,).
222
+ xs (Tensor, optional): The reference tensor.
223
+ If set, masks will be the same shape as this tensor.
224
+ length_dim (int, optional): Dimension indicator of the above tensor.
225
+ See the example.
226
+ Returns:
227
+ ByteTensor: mask tensor containing indices of padded part.
228
+ dtype=torch.uint8 in PyTorch 1.2-
229
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
230
+ Examples:
231
+ With only lengths.
232
+ >>> lengths = [5, 3, 2]
233
+ >>> make_non_pad_mask(lengths)
234
+ masks = [[1, 1, 1, 1 ,1],
235
+ [1, 1, 1, 0, 0],
236
+ [1, 1, 0, 0, 0]]
237
+ With the reference tensor.
238
+ >>> xs = torch.zeros((3, 2, 4))
239
+ >>> make_non_pad_mask(lengths, xs)
240
+ tensor([[[1, 1, 1, 1],
241
+ [1, 1, 1, 1]],
242
+ [[1, 1, 1, 0],
243
+ [1, 1, 1, 0]],
244
+ [[1, 1, 0, 0],
245
+ [1, 1, 0, 0]]], dtype=torch.uint8)
246
+ >>> xs = torch.zeros((3, 2, 6))
247
+ >>> make_non_pad_mask(lengths, xs)
248
+ tensor([[[1, 1, 1, 1, 1, 0],
249
+ [1, 1, 1, 1, 1, 0]],
250
+ [[1, 1, 1, 0, 0, 0],
251
+ [1, 1, 1, 0, 0, 0]],
252
+ [[1, 1, 0, 0, 0, 0],
253
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
254
+ With the reference tensor and dimension indicator.
255
+ >>> xs = torch.zeros((3, 6, 6))
256
+ >>> make_non_pad_mask(lengths, xs, 1)
257
+ tensor([[[1, 1, 1, 1, 1, 1],
258
+ [1, 1, 1, 1, 1, 1],
259
+ [1, 1, 1, 1, 1, 1],
260
+ [1, 1, 1, 1, 1, 1],
261
+ [1, 1, 1, 1, 1, 1],
262
+ [0, 0, 0, 0, 0, 0]],
263
+ [[1, 1, 1, 1, 1, 1],
264
+ [1, 1, 1, 1, 1, 1],
265
+ [1, 1, 1, 1, 1, 1],
266
+ [0, 0, 0, 0, 0, 0],
267
+ [0, 0, 0, 0, 0, 0],
268
+ [0, 0, 0, 0, 0, 0]],
269
+ [[1, 1, 1, 1, 1, 1],
270
+ [1, 1, 1, 1, 1, 1],
271
+ [0, 0, 0, 0, 0, 0],
272
+ [0, 0, 0, 0, 0, 0],
273
+ [0, 0, 0, 0, 0, 0],
274
+ [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
275
+ >>> make_non_pad_mask(lengths, xs, 2)
276
+ tensor([[[1, 1, 1, 1, 1, 0],
277
+ [1, 1, 1, 1, 1, 0],
278
+ [1, 1, 1, 1, 1, 0],
279
+ [1, 1, 1, 1, 1, 0],
280
+ [1, 1, 1, 1, 1, 0],
281
+ [1, 1, 1, 1, 1, 0]],
282
+ [[1, 1, 1, 0, 0, 0],
283
+ [1, 1, 1, 0, 0, 0],
284
+ [1, 1, 1, 0, 0, 0],
285
+ [1, 1, 1, 0, 0, 0],
286
+ [1, 1, 1, 0, 0, 0],
287
+ [1, 1, 1, 0, 0, 0]],
288
+ [[1, 1, 0, 0, 0, 0],
289
+ [1, 1, 0, 0, 0, 0],
290
+ [1, 1, 0, 0, 0, 0],
291
+ [1, 1, 0, 0, 0, 0],
292
+ [1, 1, 0, 0, 0, 0],
293
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
294
+ """
295
+ return ~make_pad_mask(lengths, xs, length_dim)
296
+
297
+
298
+ def get_mask_from_lengths(lengths):
299
+ max_len = torch.max(lengths).item()
300
+ ids = torch.arange(0, max_len).to(lengths.device)
301
+ mask = (ids < lengths.unsqueeze(1)).bool()
302
+ return mask
303
+
304
+
305
+ def group_hidden_by_segs(h, seg_ids, max_len):
306
+ """
307
+
308
+ :param h: [B, T, H]
309
+ :param seg_ids: [B, T]
310
+ :return: h_ph: [B, T_ph, H]
311
+ """
312
+ B, T, H = h.shape
313
+ h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
314
+ all_ones = h.new_ones(h.shape[:2])
315
+ cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
316
+ h_gby_segs = h_gby_segs[:, 1:]
317
+ cnt_gby_segs = cnt_gby_segs[:, 1:]
318
+ h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
319
+ return h_gby_segs, cnt_gby_segs
320
+
321
+ def expand_by_repeat_times(source_encoding, lengths):
322
+ """
323
+ source_encoding: [T, C]
324
+ lengths, list of int, [T,], how many times each token should repeat
325
+ return:
326
+ expanded_encoding: [T_expand, C]
327
+ """
328
+ hid_dim = source_encoding.shape[1]
329
+ out2source = []
330
+ for i, length in enumerate(lengths):
331
+ out2source += [i for _ in range(length)]
332
+ out2source = torch.LongTensor(out2source).to(source_encoding.device)
333
+ out2source_ = out2source[:, None].repeat([1, hid_dim])
334
+ expanded_encoding = torch.gather(source_encoding, 0, out2source_) # [B, T, H]
335
+ return expanded_encoding
336
+
337
+
338
+ def expand_word2ph(word_encoding, ph2word):
339
+ word_encoding = F.pad(word_encoding,[0,0,1,0])
340
+ ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
341
+ out = torch.gather(word_encoding, 1, ph2word_) # [B, T, H]
342
+ return out
tts/modules/ar_dur/commons/transformer.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import torch
17
+ from torch import nn
18
+ from torch.nn import Parameter, Linear
19
+ from tts.modules.ar_dur.commons.layers import LayerNorm, Embedding
20
+ from tts.modules.ar_dur.commons.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
21
+ import torch.nn.functional as F
22
+
23
+ DEFAULT_MAX_SOURCE_POSITIONS = 3000
24
+ DEFAULT_MAX_TARGET_POSITIONS = 3000
25
+
26
+
27
+ class SinusoidalPositionalEmbedding(nn.Module):
28
+ """This module produces sinusoidal positional embeddings of any length.
29
+
30
+ Padding symbols are ignored.
31
+ """
32
+
33
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
34
+ super().__init__()
35
+ self.embedding_dim = embedding_dim
36
+ self.padding_idx = padding_idx
37
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
38
+ init_size,
39
+ embedding_dim,
40
+ padding_idx,
41
+ )
42
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
43
+
44
+ @staticmethod
45
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
46
+ """Build sinusoidal embeddings.
47
+
48
+ This matches the implementation in tensor2tensor, but differs slightly
49
+ from the description in Section 3.5 of "Attention Is All You Need".
50
+ """
51
+ half_dim = embedding_dim // 2
52
+ emb = math.log(10000) / (half_dim - 1)
53
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
54
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
55
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
56
+ if embedding_dim % 2 == 1:
57
+ # zero pad
58
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
59
+ if padding_idx is not None:
60
+ emb[padding_idx, :] = 0
61
+ return emb
62
+
63
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
64
+ """Input is expected to be of size [bsz x seqlen]."""
65
+ bsz, seq_len = input.shape[:2]
66
+ max_pos = self.padding_idx + 1 + seq_len
67
+ if self.weights is None or max_pos > self.weights.size(0):
68
+ # recompute/expand embeddings if needed
69
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
70
+ max_pos,
71
+ self.embedding_dim,
72
+ self.padding_idx,
73
+ )
74
+ self.weights = self.weights.to(self._float_tensor)
75
+
76
+ if incremental_state is not None:
77
+ # positions is the same for every token when decoding a single step
78
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
79
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
80
+
81
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
82
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
83
+
84
+ def max_positions(self):
85
+ """Maximum number of supported positions."""
86
+ return int(1e5) # an arbitrary large number
87
+
88
+
89
+ class TransformerFFNLayer(nn.Module):
90
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu', bias=True):
91
+ super().__init__()
92
+ self.kernel_size = kernel_size
93
+ self.dropout = dropout
94
+ self.act = act
95
+ if padding == 'SAME':
96
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size,
97
+ padding=kernel_size // 2, bias=bias)
98
+ elif padding == 'LEFT':
99
+ self.ffn_1 = nn.Sequential(
100
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
101
+ nn.Conv1d(hidden_size, filter_size, kernel_size, bias=bias)
102
+ )
103
+ self.ffn_2 = Linear(filter_size, hidden_size, bias=bias)
104
+
105
+ def forward(self, x, incremental_state=None):
106
+ # x: T x B x C
107
+ if incremental_state is not None:
108
+ saved_state = self._get_input_buffer(incremental_state)
109
+ if 'prev_input' in saved_state:
110
+ prev_input = saved_state['prev_input']
111
+ x = torch.cat((prev_input, x), dim=0)
112
+ x = x[-self.kernel_size:]
113
+ saved_state['prev_input'] = x
114
+ self._set_input_buffer(incremental_state, saved_state)
115
+
116
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
117
+ x = x * self.kernel_size ** -0.5
118
+
119
+ if incremental_state is not None:
120
+ x = x[-1:]
121
+ if self.act == 'gelu':
122
+ x = F.gelu(x)
123
+ if self.act == 'relu':
124
+ x = F.relu(x)
125
+ x = F.dropout(x, self.dropout, training=self.training)
126
+ x = self.ffn_2(x)
127
+ return x
128
+
129
+ def _get_input_buffer(self, incremental_state):
130
+ return get_incremental_state(
131
+ self,
132
+ incremental_state,
133
+ 'f',
134
+ ) or {}
135
+
136
+ def _set_input_buffer(self, incremental_state, buffer):
137
+ set_incremental_state(
138
+ self,
139
+ incremental_state,
140
+ 'f',
141
+ buffer,
142
+ )
143
+
144
+ def clear_buffer(self, incremental_state):
145
+ if incremental_state is not None:
146
+ saved_state = self._get_input_buffer(incremental_state)
147
+ if 'prev_input' in saved_state:
148
+ del saved_state['prev_input']
149
+ self._set_input_buffer(incremental_state, saved_state)
150
+
151
+
152
+ class MultiheadAttention(nn.Module):
153
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
154
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
155
+ encoder_decoder_attention=False):
156
+ super().__init__()
157
+ self.embed_dim = embed_dim
158
+ self.kdim = kdim if kdim is not None else embed_dim
159
+ self.vdim = vdim if vdim is not None else embed_dim
160
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
161
+
162
+ self.num_heads = num_heads
163
+ self.dropout = dropout
164
+ self.head_dim = embed_dim // num_heads
165
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
166
+ self.scaling = self.head_dim ** -0.5
167
+
168
+ self.self_attention = self_attention
169
+ self.encoder_decoder_attention = encoder_decoder_attention
170
+
171
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
172
+ 'value to be of the same size'
173
+
174
+ if self.qkv_same_dim:
175
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
176
+ else:
177
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
178
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
179
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
180
+
181
+ if bias:
182
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
183
+ else:
184
+ self.register_parameter('in_proj_bias', None)
185
+
186
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
187
+
188
+ if add_bias_kv:
189
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
190
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
191
+ else:
192
+ self.bias_k = self.bias_v = None
193
+
194
+ self.add_zero_attn = add_zero_attn
195
+
196
+ self.reset_parameters()
197
+
198
+ self.enable_torch_version = False
199
+ self.last_attn_probs = None
200
+
201
+ def reset_parameters(self):
202
+ if self.qkv_same_dim:
203
+ nn.init.xavier_uniform_(self.in_proj_weight)
204
+ else:
205
+ nn.init.xavier_uniform_(self.k_proj_weight)
206
+ nn.init.xavier_uniform_(self.v_proj_weight)
207
+ nn.init.xavier_uniform_(self.q_proj_weight)
208
+
209
+ nn.init.xavier_uniform_(self.out_proj.weight)
210
+ if self.in_proj_bias is not None:
211
+ nn.init.constant_(self.in_proj_bias, 0.)
212
+ nn.init.constant_(self.out_proj.bias, 0.)
213
+ if self.bias_k is not None:
214
+ nn.init.xavier_normal_(self.bias_k)
215
+ if self.bias_v is not None:
216
+ nn.init.xavier_normal_(self.bias_v)
217
+
218
+ def forward(
219
+ self,
220
+ query, key, value,
221
+ key_padding_mask=None,
222
+ incremental_state=None,
223
+ need_weights=True,
224
+ static_kv=False,
225
+ attn_mask=None,
226
+ before_softmax=False,
227
+ need_head_weights=False,
228
+ enc_dec_attn_constraint_mask=None,
229
+ reset_attn_weight=None
230
+ ):
231
+ """Input shape: Time x Batch x Channel
232
+
233
+ Args:
234
+ key_padding_mask (ByteTensor, optional): mask to exclude
235
+ keys that are pads, of shape `(batch, src_len)`, where
236
+ padding elements are indicated by 1s.
237
+ need_weights (bool, optional): return the attention weights,
238
+ averaged over heads (default: False).
239
+ attn_mask (ByteTensor, optional): typically used to
240
+ implement causal attention, where the mask prevents the
241
+ attention from looking forward in time (default: None).
242
+ before_softmax (bool, optional): return the raw attention
243
+ weights and values before the attention softmax.
244
+ need_head_weights (bool, optional): return the attention
245
+ weights for each head. Implies *need_weights*. Default:
246
+ return the average attention weights over all heads.
247
+ """
248
+ if need_head_weights:
249
+ need_weights = True
250
+
251
+ tgt_len, bsz, embed_dim = query.size()
252
+ assert embed_dim == self.embed_dim
253
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
254
+
255
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
256
+ if self.qkv_same_dim:
257
+ return F.multi_head_attention_forward(query, key, value,
258
+ self.embed_dim, self.num_heads,
259
+ self.in_proj_weight,
260
+ self.in_proj_bias, self.bias_k, self.bias_v,
261
+ self.add_zero_attn, self.dropout,
262
+ self.out_proj.weight, self.out_proj.bias,
263
+ self.training, key_padding_mask, need_weights,
264
+ attn_mask)
265
+ else:
266
+ return F.multi_head_attention_forward(query, key, value,
267
+ self.embed_dim, self.num_heads,
268
+ torch.empty([0]),
269
+ self.in_proj_bias, self.bias_k, self.bias_v,
270
+ self.add_zero_attn, self.dropout,
271
+ self.out_proj.weight, self.out_proj.bias,
272
+ self.training, key_padding_mask, need_weights,
273
+ attn_mask, use_separate_proj_weight=True,
274
+ q_proj_weight=self.q_proj_weight,
275
+ k_proj_weight=self.k_proj_weight,
276
+ v_proj_weight=self.v_proj_weight)
277
+
278
+ if incremental_state is not None:
279
+ saved_state = self._get_input_buffer(incremental_state)
280
+ if 'prev_key' in saved_state:
281
+ # previous time steps are cached - no need to recompute
282
+ # key and value if they are static
283
+ if static_kv:
284
+ assert self.encoder_decoder_attention and not self.self_attention
285
+ key = value = None
286
+ else:
287
+ saved_state = None
288
+
289
+ if self.self_attention:
290
+ # self-attention
291
+ q, k, v = self.in_proj_qkv(query)
292
+ elif self.encoder_decoder_attention:
293
+ # encoder-decoder attention
294
+ q = self.in_proj_q(query)
295
+ if key is None:
296
+ assert value is None
297
+ k = v = None
298
+ else:
299
+ k = self.in_proj_k(key)
300
+ v = self.in_proj_v(key)
301
+
302
+ else:
303
+ q = self.in_proj_q(query)
304
+ k = self.in_proj_k(key)
305
+ v = self.in_proj_v(value)
306
+ q = q * self.scaling
307
+
308
+ if self.bias_k is not None:
309
+ assert self.bias_v is not None
310
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
311
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
312
+ if attn_mask is not None:
313
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
314
+ if key_padding_mask is not None:
315
+ key_padding_mask = torch.cat(
316
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
317
+
318
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
319
+ if k is not None:
320
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
321
+ if v is not None:
322
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
323
+
324
+ if saved_state is not None:
325
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
326
+ if 'prev_key' in saved_state:
327
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
328
+ if static_kv:
329
+ k = prev_key
330
+ else:
331
+ k = torch.cat((prev_key, k), dim=1)
332
+ if 'prev_value' in saved_state:
333
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
334
+ if static_kv:
335
+ v = prev_value
336
+ else:
337
+ v = torch.cat((prev_value, v), dim=1)
338
+ if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
339
+ prev_key_padding_mask = saved_state['prev_key_padding_mask']
340
+ if static_kv:
341
+ key_padding_mask = prev_key_padding_mask
342
+ else:
343
+ key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
344
+
345
+ saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
346
+ saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
347
+ saved_state['prev_key_padding_mask'] = key_padding_mask
348
+
349
+ self._set_input_buffer(incremental_state, saved_state)
350
+
351
+ src_len = k.size(1)
352
+
353
+ # This is part of a workaround to get around fork/join parallelism
354
+ # not supporting Optional types.
355
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
356
+ key_padding_mask = None
357
+
358
+ if key_padding_mask is not None:
359
+ assert key_padding_mask.size(0) == bsz
360
+ assert key_padding_mask.size(1) == src_len
361
+
362
+ if self.add_zero_attn:
363
+ src_len += 1
364
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
365
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
366
+ if attn_mask is not None:
367
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
368
+ if key_padding_mask is not None:
369
+ key_padding_mask = torch.cat(
370
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
371
+
372
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
373
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
374
+
375
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
376
+
377
+ if attn_mask is not None:
378
+ if len(attn_mask.shape) == 2:
379
+ attn_mask = attn_mask.unsqueeze(0)
380
+ elif len(attn_mask.shape) == 3:
381
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
382
+ bsz * self.num_heads, tgt_len, src_len)
383
+ attn_weights = attn_weights + attn_mask
384
+
385
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
386
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
387
+ attn_weights = attn_weights.masked_fill(
388
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
389
+ -1e8,
390
+ )
391
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
392
+
393
+ if key_padding_mask is not None:
394
+ # don't attend to padding symbols
395
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
396
+ attn_weights = attn_weights.masked_fill(
397
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
398
+ -1e8,
399
+ )
400
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
401
+
402
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
403
+
404
+ if before_softmax:
405
+ return attn_weights, v
406
+
407
+ attn_weights_float = softmax(attn_weights, dim=-1)
408
+ attn_weights = attn_weights_float.type_as(attn_weights)
409
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
410
+
411
+ if reset_attn_weight is not None:
412
+ if reset_attn_weight:
413
+ self.last_attn_probs = attn_probs.detach()
414
+ else:
415
+ assert self.last_attn_probs is not None
416
+ attn_probs = self.last_attn_probs
417
+ attn = torch.bmm(attn_probs, v)
418
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
419
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
420
+ attn = self.out_proj(attn)
421
+
422
+ if need_weights:
423
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
424
+ if not need_head_weights:
425
+ # average attention weights over heads
426
+ attn_weights = attn_weights.mean(dim=0)
427
+ else:
428
+ attn_weights = None
429
+
430
+ return attn, (attn_weights, attn_logits)
431
+
432
+ def in_proj_qkv(self, query):
433
+ return self._in_proj(query).chunk(3, dim=-1)
434
+
435
+ def in_proj_q(self, query):
436
+ if self.qkv_same_dim:
437
+ return self._in_proj(query, end=self.embed_dim)
438
+ else:
439
+ bias = self.in_proj_bias
440
+ if bias is not None:
441
+ bias = bias[:self.embed_dim]
442
+ return F.linear(query, self.q_proj_weight, bias)
443
+
444
+ def in_proj_k(self, key):
445
+ if self.qkv_same_dim:
446
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
447
+ else:
448
+ weight = self.k_proj_weight
449
+ bias = self.in_proj_bias
450
+ if bias is not None:
451
+ bias = bias[self.embed_dim:2 * self.embed_dim]
452
+ return F.linear(key, weight, bias)
453
+
454
+ def in_proj_v(self, value):
455
+ if self.qkv_same_dim:
456
+ return self._in_proj(value, start=2 * self.embed_dim)
457
+ else:
458
+ weight = self.v_proj_weight
459
+ bias = self.in_proj_bias
460
+ if bias is not None:
461
+ bias = bias[2 * self.embed_dim:]
462
+ return F.linear(value, weight, bias)
463
+
464
+ def _in_proj(self, input, start=0, end=None):
465
+ weight = self.in_proj_weight
466
+ bias = self.in_proj_bias
467
+ weight = weight[start:end, :]
468
+ if bias is not None:
469
+ bias = bias[start:end]
470
+ return F.linear(input, weight, bias)
471
+
472
+ def _get_input_buffer(self, incremental_state):
473
+ return get_incremental_state(
474
+ self,
475
+ incremental_state,
476
+ 'attn_state',
477
+ ) or {}
478
+
479
+ def _set_input_buffer(self, incremental_state, buffer):
480
+ set_incremental_state(
481
+ self,
482
+ incremental_state,
483
+ 'attn_state',
484
+ buffer,
485
+ )
486
+
487
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
488
+ return attn_weights
489
+
490
+ def clear_buffer(self, incremental_state=None):
491
+ if incremental_state is not None:
492
+ saved_state = self._get_input_buffer(incremental_state)
493
+ if 'prev_key' in saved_state:
494
+ del saved_state['prev_key']
495
+ if 'prev_value' in saved_state:
496
+ del saved_state['prev_value']
497
+ self._set_input_buffer(incremental_state, saved_state)
498
+
499
+
500
+ class EncSALayer(nn.Module):
501
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
502
+ relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
503
+ ffn_hidden_size=1024):
504
+ super().__init__()
505
+ self.c = c
506
+ self.dropout = dropout
507
+ self.num_heads = num_heads
508
+ if num_heads > 0:
509
+ self.layer_norm1 = LayerNorm(c)
510
+ self.self_attn = MultiheadAttention(
511
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
512
+ self.layer_norm2 = LayerNorm(c)
513
+ self.ffn = TransformerFFNLayer(
514
+ c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
515
+
516
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
517
+ layer_norm_training = kwargs.get('layer_norm_training', None)
518
+ if layer_norm_training is not None:
519
+ self.layer_norm1.training = layer_norm_training
520
+ self.layer_norm2.training = layer_norm_training
521
+ if self.num_heads > 0:
522
+ residual = x
523
+ x = self.layer_norm1(x)
524
+ x, _, = self.self_attn(
525
+ query=x,
526
+ key=x,
527
+ value=x,
528
+ key_padding_mask=encoder_padding_mask
529
+ )
530
+ x = F.dropout(x, self.dropout, training=self.training)
531
+ x = residual + x
532
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
533
+
534
+ residual = x
535
+ x = self.layer_norm2(x)
536
+ x = self.ffn(x)
537
+ x = F.dropout(x, self.dropout, training=self.training)
538
+ x = residual + x
539
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
540
+ return x
541
+
542
+
543
+ class DecSALayer(nn.Module):
544
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
545
+ kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
546
+ super().__init__()
547
+ self.c = c
548
+ self.dropout = dropout
549
+ self.layer_norm1 = LayerNorm(c)
550
+ self.self_attn = MultiheadAttention(
551
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
552
+ )
553
+ self.layer_norm2 = LayerNorm(c)
554
+ self.encoder_attn = MultiheadAttention(
555
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
556
+ )
557
+ self.layer_norm3 = LayerNorm(c)
558
+ self.ffn = TransformerFFNLayer(
559
+ c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
560
+ self.post_ln = post_ln
561
+
562
+ def forward(
563
+ self,
564
+ x,
565
+ encoder_out=None,
566
+ encoder_padding_mask=None,
567
+ incremental_state=None,
568
+ self_attn_mask=None,
569
+ self_attn_padding_mask=None,
570
+ attn_out=None,
571
+ reset_attn_weight=None,
572
+ **kwargs,
573
+ ):
574
+ layer_norm_training = kwargs.get('layer_norm_training', None)
575
+ if layer_norm_training is not None:
576
+ self.layer_norm1.training = layer_norm_training
577
+ self.layer_norm2.training = layer_norm_training
578
+ self.layer_norm3.training = layer_norm_training
579
+ residual = x
580
+ if not self.post_ln:
581
+ x = self.layer_norm1(x)
582
+ x, _ = self.self_attn(
583
+ query=x,
584
+ key=x,
585
+ value=x,
586
+ key_padding_mask=self_attn_padding_mask,
587
+ incremental_state=incremental_state,
588
+ attn_mask=self_attn_mask
589
+ )
590
+ x = F.dropout(x, self.dropout, training=self.training)
591
+ x = residual + x
592
+ if self.post_ln:
593
+ x = self.layer_norm1(x)
594
+
595
+ attn_logits = None
596
+ if encoder_out is not None or attn_out is not None:
597
+ residual = x
598
+ if not self.post_ln:
599
+ x = self.layer_norm2(x)
600
+ if encoder_out is not None:
601
+ x, attn = self.encoder_attn(
602
+ query=x,
603
+ key=encoder_out,
604
+ value=encoder_out,
605
+ key_padding_mask=encoder_padding_mask,
606
+ incremental_state=incremental_state,
607
+ static_kv=True,
608
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
609
+ 'enc_dec_attn_constraint_mask'),
610
+ reset_attn_weight=reset_attn_weight
611
+ )
612
+ attn_logits = attn[1]
613
+ elif attn_out is not None:
614
+ x = self.encoder_attn.in_proj_v(attn_out)
615
+ if encoder_out is not None or attn_out is not None:
616
+ x = F.dropout(x, self.dropout, training=self.training)
617
+ x = residual + x
618
+ if self.post_ln:
619
+ x = self.layer_norm2(x)
620
+
621
+ residual = x
622
+ if not self.post_ln:
623
+ x = self.layer_norm3(x)
624
+ x = self.ffn(x, incremental_state=incremental_state)
625
+ x = F.dropout(x, self.dropout, training=self.training)
626
+ x = residual + x
627
+ if self.post_ln:
628
+ x = self.layer_norm3(x)
629
+ return x, attn_logits
630
+
631
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
632
+ self.encoder_attn.clear_buffer(incremental_state)
633
+ self.ffn.clear_buffer(incremental_state)
634
+
635
+ def set_buffer(self, name, tensor, incremental_state):
636
+ return set_incremental_state(self, incremental_state, name, tensor)
637
+
638
+
639
+ class TransformerEncoderLayer(nn.Module):
640
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
641
+ super().__init__()
642
+ self.hidden_size = hidden_size
643
+ self.dropout = dropout
644
+ self.num_heads = num_heads
645
+ self.op = EncSALayer(
646
+ hidden_size, num_heads, dropout=dropout,
647
+ attention_dropout=0.0, relu_dropout=dropout,
648
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size)
649
+
650
+ def forward(self, x, **kwargs):
651
+ return self.op(x, **kwargs)
652
+
653
+
654
+ class TransformerDecoderLayer(nn.Module):
655
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
656
+ super().__init__()
657
+ self.hidden_size = hidden_size
658
+ self.dropout = dropout
659
+ self.num_heads = num_heads
660
+ self.op = DecSALayer(
661
+ hidden_size, num_heads, dropout=dropout,
662
+ attention_dropout=0.0, relu_dropout=dropout,
663
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
664
+ post_ln=post_ln)
665
+
666
+ def forward(self, x, **kwargs):
667
+ return self.op(x, **kwargs)
668
+
669
+ def clear_buffer(self, *args):
670
+ return self.op.clear_buffer(*args)
671
+
672
+ def set_buffer(self, *args):
673
+ return self.op.set_buffer(*args)
674
+
675
+
676
+ class FFTBlocks(nn.Module):
677
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
678
+ num_heads=2, use_pos_embed=True, use_last_norm=True,
679
+ use_pos_embed_alpha=True, ffn_hidden_size=1024):
680
+ super().__init__()
681
+ self.num_layers = num_layers
682
+ embed_dim = self.hidden_size = hidden_size
683
+ self.dropout = dropout
684
+ self.use_pos_embed = use_pos_embed
685
+ self.use_last_norm = use_last_norm
686
+ if use_pos_embed:
687
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
688
+ self.padding_idx = 0
689
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
690
+ self.embed_positions = SinusoidalPositionalEmbedding(
691
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
692
+ )
693
+
694
+ self.layers = nn.ModuleList([])
695
+ self.layers.extend([
696
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
697
+ kernel_size=ffn_kernel_size, num_heads=num_heads,
698
+ ffn_hidden_size=ffn_hidden_size)
699
+ for _ in range(self.num_layers)
700
+ ])
701
+ if self.use_last_norm:
702
+ self.layer_norm = nn.LayerNorm(embed_dim)
703
+ else:
704
+ self.layer_norm = None
705
+
706
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
707
+ """
708
+ :param x: [B, T, C]
709
+ :param padding_mask: [B, T]
710
+ :return: [B, T, C] or [L, B, T, C]
711
+ """
712
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
713
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
714
+ if self.use_pos_embed:
715
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
716
+ x = x + positions
717
+ x = F.dropout(x, p=self.dropout, training=self.training)
718
+ # B x T x C -> T x B x C
719
+ x = x.transpose(0, 1) * nonpadding_mask_TB
720
+ hiddens = []
721
+ for layer in self.layers:
722
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
723
+ hiddens.append(x)
724
+ if self.use_last_norm:
725
+ x = self.layer_norm(x) * nonpadding_mask_TB
726
+ if return_hiddens:
727
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
728
+ x = x.transpose(1, 2) # [L, B, T, C]
729
+ else:
730
+ x = x.transpose(0, 1) # [B, T, C]
731
+ return x
732
+
733
+
734
+ class FastSpeechEncoder(FFTBlocks):
735
+ def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
736
+ dropout=0.0, num_heads=2, ffn_hidden_size=1024):
737
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
738
+ use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size)
739
+ self.embed_tokens = Embedding(dict_size, hidden_size, 0)
740
+ self.embed_scale = math.sqrt(hidden_size)
741
+ self.padding_idx = 0
742
+ self.embed_positions = SinusoidalPositionalEmbedding(
743
+ hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
744
+ )
745
+
746
+ def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
747
+ """
748
+
749
+ :param txt_tokens: [B, T]
750
+ :return: {
751
+ 'encoder_out': [B x T x C]
752
+ }
753
+ """
754
+ encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
755
+ x = self.forward_embedding(txt_tokens) + other_embeds # [B, T, H]
756
+ if self.num_layers > 0:
757
+ x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
758
+ return x
759
+
760
+ def forward_embedding(self, txt_tokens):
761
+ # embed tokens and positions
762
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
763
+ if self.use_pos_embed:
764
+ positions = self.embed_positions(txt_tokens)
765
+ x = x + positions
766
+ x = F.dropout(x, p=self.dropout, training=self.training)
767
+ return x
tts/modules/llm_dit/cfm.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2023 Alexander Tong
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # Copyright (c) [2023] [Alexander Tong]
24
+ # Copyright (c) [2025] [Ziyue Jiang]
25
+ # SPDX-License-Identifier: MIT
26
+ # This file has been modified by Ziyue Jiang on 2025/03/19
27
+ # Original file was released under MIT, with the full license text # available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
28
+ # This modified file is released under the same license.
29
+
30
+ import math
31
+ import torch
32
+ from typing import Union
33
+ from torch.distributions import LogisticNormal
34
+
35
+
36
+ class LogitNormalTrainingTimesteps:
37
+ def __init__(self, T=1000.0, loc=0.0, scale=1.0):
38
+ assert T > 0
39
+ self.T = T
40
+ self.dist = LogisticNormal(loc, scale)
41
+
42
+ def sample(self, size, device):
43
+ t = self.dist.sample(size)[..., 0].to(device)
44
+ return t
45
+
46
+
47
+ def pad_t_like_x(t, x):
48
+ """Function to reshape the time vector t by the number of dimensions of x.
49
+
50
+ Parameters
51
+ ----------
52
+ x : Tensor, shape (bs, *dim)
53
+ represents the source minibatch
54
+ t : FloatTensor, shape (bs)
55
+
56
+ Returns
57
+ -------
58
+ t : Tensor, shape (bs, number of x dimensions)
59
+
60
+ Example
61
+ -------
62
+ x: Tensor (bs, C, W, H)
63
+ t: Vector (bs)
64
+ pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
65
+ """
66
+ if isinstance(t, (float, int)):
67
+ return t
68
+ return t.reshape(-1, *([1] * (x.dim() - 1)))
69
+
70
+
71
+ class ConditionalFlowMatcher:
72
+ """Base class for conditional flow matching methods. This class implements the independent
73
+ conditional flow matching methods from [1] and serves as a parent class for all other flow
74
+ matching methods.
75
+
76
+ It implements:
77
+ - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
78
+ - conditional flow matching ut(x1|x0) = x1 - x0
79
+ - score function $\nabla log p_t(x|x0, x1)$
80
+ """
81
+
82
+ def __init__(self, sigma: Union[float, int] = 0.0):
83
+ r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
84
+
85
+ Parameters
86
+ ----------
87
+ sigma : Union[float, int]
88
+ """
89
+ self.sigma = sigma
90
+ self.time_sampler = LogitNormalTrainingTimesteps()
91
+
92
+ def compute_mu_t(self, x0, x1, t):
93
+ """
94
+ Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
95
+
96
+ Parameters
97
+ ----------
98
+ x0 : Tensor, shape (bs, *dim)
99
+ represents the source minibatch
100
+ x1 : Tensor, shape (bs, *dim)
101
+ represents the target minibatch
102
+ t : FloatTensor, shape (bs)
103
+
104
+ Returns
105
+ -------
106
+ mean mu_t: t * x1 + (1 - t) * x0
107
+
108
+ References
109
+ ----------
110
+ [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
111
+ """
112
+ t = pad_t_like_x(t, x0)
113
+ return t * x1 + (1 - t) * x0
114
+
115
+ def compute_sigma_t(self, t):
116
+ """
117
+ Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
118
+
119
+ Parameters
120
+ ----------
121
+ t : FloatTensor, shape (bs)
122
+
123
+ Returns
124
+ -------
125
+ standard deviation sigma
126
+
127
+ References
128
+ ----------
129
+ [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
130
+ """
131
+ del t
132
+ return self.sigma
133
+
134
+ def sample_xt(self, x0, x1, t, epsilon):
135
+ """
136
+ Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].
137
+
138
+ Parameters
139
+ ----------
140
+ x0 : Tensor, shape (bs, *dim)
141
+ represents the source minibatch
142
+ x1 : Tensor, shape (bs, *dim)
143
+ represents the target minibatch
144
+ t : FloatTensor, shape (bs)
145
+ epsilon : Tensor, shape (bs, *dim)
146
+ noise sample from N(0, 1)
147
+
148
+ Returns
149
+ -------
150
+ xt : Tensor, shape (bs, *dim)
151
+
152
+ References
153
+ ----------
154
+ [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
155
+ """
156
+ mu_t = self.compute_mu_t(x0, x1, t)
157
+ sigma_t = self.compute_sigma_t(t)
158
+ sigma_t = pad_t_like_x(sigma_t, x0)
159
+ return mu_t + sigma_t * epsilon
160
+
161
+ def compute_conditional_flow(self, x0, x1, t, xt):
162
+ """
163
+ Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
164
+
165
+ Parameters
166
+ ----------
167
+ x0 : Tensor, shape (bs, *dim)
168
+ represents the source minibatch
169
+ x1 : Tensor, shape (bs, *dim)
170
+ represents the target minibatch
171
+ t : FloatTensor, shape (bs)
172
+ xt : Tensor, shape (bs, *dim)
173
+ represents the samples drawn from probability path pt
174
+
175
+ Returns
176
+ -------
177
+ ut : conditional vector field ut(x1|x0) = x1 - x0
178
+
179
+ References
180
+ ----------
181
+ [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
182
+ """
183
+ del t, xt
184
+ return x1 - x0
185
+
186
+ def sample_noise_like(self, x):
187
+ return torch.randn_like(x)
188
+
189
+ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
190
+ """
191
+ Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
192
+ and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
193
+
194
+ Parameters
195
+ ----------
196
+ x0 : Tensor, shape (bs, *dim)
197
+ represents the source minibatch
198
+ x1 : Tensor, shape (bs, *dim)
199
+ represents the target minibatch
200
+ (optionally) t : Tensor, shape (bs)
201
+ represents the time levels
202
+ if None, drawn from uniform [0,1]
203
+ return_noise : bool
204
+ return the noise sample epsilon
205
+
206
+
207
+ Returns
208
+ -------
209
+ t : FloatTensor, shape (bs)
210
+ xt : Tensor, shape (bs, *dim)
211
+ represents the samples drawn from probability path pt
212
+ ut : conditional vector field ut(x1|x0) = x1 - x0
213
+ (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon
214
+
215
+ References
216
+ ----------
217
+ [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
218
+ """
219
+ if t is None:
220
+ # t = torch.rand(x0.shape[0]).type_as(x0)
221
+ t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)
222
+
223
+ assert len(t) == x0.shape[0], "t has to have batch size dimension"
224
+
225
+ eps = self.sample_noise_like(x0)
226
+ xt = self.sample_xt(x0, x1, t, eps)
227
+ ut = self.compute_conditional_flow(x0, x1, t, xt)
228
+ if return_noise:
229
+ return t, xt, ut, eps
230
+ else:
231
+ return t, xt, ut
232
+
233
+ def compute_lambda(self, t):
234
+ """Compute the lambda function, see Eq.(23) [3].
235
+
236
+ Parameters
237
+ ----------
238
+ t : FloatTensor, shape (bs)
239
+
240
+ Returns
241
+ -------
242
+ lambda : score weighting function
243
+
244
+ References
245
+ ----------
246
+ [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
247
+ """
248
+ sigma_t = self.compute_sigma_t(t)
249
+ return 2 * sigma_t / (self.sigma**2 + 1e-8)
250
+
251
+
252
+ class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
253
+ """Albergo et al. 2023 trigonometric interpolants class. This class inherits the
254
+ ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in
255
+ order to compute [3]'s trigonometric interpolants.
256
+
257
+ [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
258
+ """
259
+
260
+ def compute_mu_t(self, x0, x1, t):
261
+ r"""Compute the mean of the probability path (Eq.5) from [3].
262
+
263
+ Parameters
264
+ ----------
265
+ x0 : Tensor, shape (bs, *dim)
266
+ represents the source minibatch
267
+ x1 : Tensor, shape (bs, *dim)
268
+ represents the target minibatch
269
+ t : FloatTensor, shape (bs)
270
+
271
+ Returns
272
+ -------
273
+ mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1
274
+
275
+ References
276
+ ----------
277
+ [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
278
+ """
279
+ t = pad_t_like_x(t, x0)
280
+ return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1
281
+
282
+ def compute_conditional_flow(self, x0, x1, t, xt):
283
+ r"""Compute the conditional vector field similar to [3].
284
+
285
+ ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
286
+ see Eq.(21) [3].
287
+
288
+ Parameters
289
+ ----------
290
+ x0 : Tensor, shape (bs, *dim)
291
+ represents the source minibatch
292
+ x1 : Tensor, shape (bs, *dim)
293
+ represents the target minibatch
294
+ t : FloatTensor, shape (bs)
295
+ xt : Tensor, shape (bs, *dim)
296
+ represents the samples drawn from probability path pt
297
+
298
+ Returns
299
+ -------
300
+ ut : conditional vector field
301
+ ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(\pi*t/2) x0)
302
+
303
+ References
304
+ ----------
305
+ [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
306
+ """
307
+ del xt
308
+ t = pad_t_like_x(t, x0)
309
+ return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0)
tts/modules/llm_dit/dit.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch import nn
17
+
18
+ from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
19
+ from tts.modules.ar_dur.commons.layers import Embedding
20
+ from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
21
+ from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
22
+ from tts.modules.ar_dur.ar_dur_predictor import expand_states
23
+ from tts.modules.llm_dit.transformer import Transformer
24
+ from tts.modules.llm_dit.time_embedding import TimestepEmbedding
25
+
26
+
27
+ class Diffusion(nn.Module):
28
+ def __init__(self):
29
+ super().__init__()
30
+ # Hparams
31
+ # cond dim
32
+ self.local_cond_dim = 512
33
+ self.ctx_mask_dim = 16
34
+ self.in_channels = 32
35
+ self.out_channels = 32
36
+ # LLM
37
+ self.encoder_dim = 1024
38
+ self.encoder_n_layers = 24
39
+ self.encoder_n_heads = 16
40
+ self.max_seq_len = 16384
41
+ self.multiple_of = 256
42
+
43
+ self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
44
+ self.local_cond_project = nn.Linear(
45
+ self.out_channels + self.ctx_mask_dim, self.local_cond_dim)
46
+
47
+ self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)
48
+
49
+ self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
50
+ self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
51
+ self.postnet = nn.Linear(self.encoder_dim, self.out_channels)
52
+
53
+ self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
54
+ # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
55
+ # which is licensed under the MIT License.
56
+ self.f5_time_embed = TimestepEmbedding(self.encoder_dim)
57
+
58
+ # text encoder
59
+ self.ph_encoder = RelTransformerEncoder(
60
+ 302, self.encoder_dim, self.encoder_dim,
61
+ self.encoder_dim * 2, 4, 6,
62
+ 3, 0.0, prenet=True, pre_ln=True)
63
+ self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
64
+ self.ph_pos_embed = PosEmb(self.encoder_dim)
65
+ self.ling_pre_net = torch.nn.Sequential(*[
66
+ torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
67
+ for i, s in enumerate([2, 2])
68
+ ])
69
+
70
+ def forward(self, inputs, sigmas=None, x_noisy=None):
71
+ ctx_mask = inputs['ctx_mask']
72
+ ctx_feature = inputs['lat_ctx'] * ctx_mask
73
+
74
+ """ local conditioning (prompt_latent + spk_embed) """
75
+ ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
76
+ # ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
77
+ local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
78
+ local_cond = self.local_cond_project(local_cond)
79
+
80
+ """ diffusion target latent """
81
+ x = inputs['lat']
82
+
83
+ # Here, x is x1 in CFM
84
+ x0 = torch.randn_like(x)
85
+ t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
86
+
87
+ # define noisy_input and target
88
+ t = t.bfloat16()
89
+ x_noisy = (xt * (1 - ctx_mask)).bfloat16()
90
+ target = ut
91
+
92
+ # concat condition.
93
+ x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
94
+ x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
95
+ x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
96
+ encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
97
+ pred = self.postnet(encoder_out)
98
+
99
+ return pred, target
100
+
101
+ def forward_ling_encoder(self, txt_tokens, tone_tokens):
102
+ ph_tokens = txt_tokens
103
+ ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1]
104
+
105
+ # enc_ph
106
+ ph_enc_oembed = self.tone_embed(tone_tokens)
107
+ ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
108
+ torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
109
+ ph_enc_oembed = ph_enc_oembed
110
+ ph_enc_oembed = ph_enc_oembed * ph_nonpadding
111
+ x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
112
+ return x_ling
113
+
114
+ def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=[1.0,1.0]):
115
+ """ When we use torchdiffeq, we need to include the CFG process inside _forward() """
116
+ x = x * (1 - ctx_mask)
117
+ x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
118
+ pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
119
+ pred = self.postnet(pred_v)
120
+
121
+ """ Perform multi-cond CFG """
122
+ cond_spk_txt, cond_txt, uncond = pred.chunk(3)
123
+ pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
124
+ return pred
125
+
126
+ @torch.no_grad()
127
+ def inference(self, inputs, timesteps=20, seq_cfg_w=[1.0, 1.0], **kwargs):
128
+ # txt embedding
129
+ x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
130
+ x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)
131
+
132
+ # speaker embedding
133
+ ctx_feature = inputs['lat_ctx']
134
+ ctx_feature[1:, :, :] = 0 # prefix spk cfg
135
+ ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])
136
+
137
+ # local conditioning.
138
+ local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
139
+ local_cond = self.local_cond_project(local_cond)
140
+
141
+ ''' Euler ODE solver '''
142
+ bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
143
+ # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
144
+ # which is licensed under the MIT License.
145
+ sway_sampling_coef = -1.0
146
+ t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
147
+ if sway_sampling_coef is not None:
148
+ t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)
149
+
150
+ # AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
151
+ def amo_sampling(z_t, t, t_next, v):
152
+ # Upcast to avoid precision issues when computing prev_sample
153
+ z_t = z_t.to(torch.float32)
154
+
155
+ # Constant definition in Algorithm 1
156
+ s = t_next
157
+ c = 3
158
+
159
+ # Line 7 in Algorithm 1
160
+ o = min(t_next + c * (t_next - t), 1)
161
+ pred_z_o = z_t + (o - t) * v
162
+
163
+ # Line 11 in Algorithm 1
164
+ a = s / o
165
+ b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
166
+ noise_i = torch.randn(size=z_t.shape, device=z_t.device)
167
+ z_t_next = a * pred_z_o + b * noise_i
168
+ return z_t_next.to(v.dtype)
169
+
170
+ x = torch.randn([1, frm_len, self.out_channels], device=device)
171
+ for step_index in range(timesteps):
172
+ x = x.to(torch.float32)
173
+ sigma = t_schedule[step_index].to(x_ling.dtype)
174
+ sigma_next = t_schedule[step_index + 1]
175
+ model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
176
+ x = amo_sampling(x, sigma, sigma_next, model_out)
177
+ # Cast sample back to model compatible dtype
178
+ x = x.to(model_out.dtype)
179
+
180
+ return x
tts/modules/llm_dit/time_embedding.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import torch
17
+ from torch import nn
18
+
19
+
20
+ class SinusPositionEmbedding(nn.Module):
21
+ def __init__(self, dim):
22
+ super().__init__()
23
+ self.dim = dim
24
+
25
+ def forward(self, x, scale=1000):
26
+ device = x.device
27
+ half_dim = self.dim // 2
28
+ emb = math.log(10000) / (half_dim - 1)
29
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
30
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
31
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
32
+ return emb
33
+
34
+ class TimestepEmbedding(nn.Module):
35
+ def __init__(self, dim, freq_embed_dim=256):
36
+ super().__init__()
37
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
38
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
39
+
40
+ def forward(self, timestep): # noqa: F821
41
+ time_hidden = self.time_embed(timestep)
42
+ time_hidden = time_hidden.to(timestep.dtype)
43
+ time = self.time_mlp(time_hidden) # b d
44
+ return time
tts/modules/llm_dit/transformer.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+
23
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
24
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
25
+ t = torch.arange(end, device=freqs.device) # type: ignore
26
+ freqs = torch.outer(t, freqs).float() # type: ignore
27
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
28
+ return freqs_cis
29
+
30
+
31
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
32
+ ndim = x.ndim
33
+ assert 0 <= 1 < ndim
34
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
35
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
36
+ return freqs_cis.view(*shape)
37
+
38
+
39
+ def apply_rotary_emb(
40
+ xq: torch.Tensor,
41
+ xk: torch.Tensor,
42
+ freqs_cis: torch.Tensor,
43
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
44
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
45
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
46
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
47
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
48
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
49
+ return xq_out.type_as(xq), xk_out.type_as(xk)
50
+
51
+
52
+ class AdaLNZero(nn.Module):
53
+ def __init__(self, dim):
54
+ super().__init__()
55
+ self.silu = nn.SiLU()
56
+ self.linear = nn.Linear(dim, dim * 6)
57
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
58
+
59
+ def forward(self, x, emb=None):
60
+ emb = self.linear(self.silu(emb))
61
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
62
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
63
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
64
+
65
+
66
+ class AdaLNZero_Out(nn.Module):
67
+ def __init__(self, dim):
68
+ super().__init__()
69
+ self.silu = nn.SiLU()
70
+ self.linear = nn.Linear(dim, dim * 2)
71
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
72
+
73
+ def forward(self, x, emb):
74
+ emb = self.linear(self.silu(emb))
75
+ scale, shift = torch.chunk(emb, 2, dim=1)
76
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
77
+ return x
78
+
79
+
80
+ class Attention(nn.Module):
81
+ def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
82
+ super().__init__()
83
+ self.encoder_n_kv_heads = encoder_n_heads
84
+ model_parallel_size = 1
85
+ self.n_local_heads = encoder_n_heads // model_parallel_size
86
+ self.n_local_kv_heads = self.encoder_n_kv_heads // model_parallel_size
87
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
88
+ self.head_dim = encoder_dim // encoder_n_heads
89
+
90
+ self.wq = nn.Linear(
91
+ encoder_dim,
92
+ encoder_n_heads * self.head_dim,
93
+ )
94
+ self.wk = nn.Linear(
95
+ encoder_dim,
96
+ self.encoder_n_kv_heads * self.head_dim,
97
+ )
98
+ self.wv = nn.Linear(
99
+ encoder_dim,
100
+ self.encoder_n_kv_heads * self.head_dim,
101
+ )
102
+ self.wo = nn.Linear(
103
+ encoder_n_heads * self.head_dim,
104
+ encoder_dim,
105
+ )
106
+
107
+ def forward(
108
+ self,
109
+ x: torch.Tensor,
110
+ start_pos: int,
111
+ freqs_cis: torch.Tensor,
112
+ mask: Optional[torch.Tensor],
113
+ ):
114
+ bsz, seqlen, _ = x.shape
115
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
116
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
117
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
118
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
119
+
120
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
121
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
122
+ keys = xk.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
123
+ values = xv.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
124
+
125
+ output = F.scaled_dot_product_attention(xq, keys, values, mask[:, None, None, :], is_causal=False)
126
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
127
+ return self.wo(output)
128
+
129
+
130
+ class FeedForward(nn.Module):
131
+ def __init__(
132
+ self,
133
+ dim: int,
134
+ hidden_dim: int,
135
+ multiple_of: int,
136
+ ffn_dim_multiplier: Optional[float],
137
+ ):
138
+ super().__init__()
139
+ if ffn_dim_multiplier is not None:
140
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
141
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
142
+
143
+ self.w1 = nn.Linear(
144
+ dim, hidden_dim
145
+ )
146
+ self.w2 = nn.Linear(
147
+ hidden_dim, dim
148
+ )
149
+
150
+ def forward(self, x):
151
+ return self.w2(F.silu(self.w1(x)))
152
+
153
+
154
+ class TransformerBlock(nn.Module):
155
+ def __init__(self, encoder_dim, encoder_n_heads, max_seq_len):
156
+ super().__init__()
157
+ self.encoder_n_heads = encoder_n_heads
158
+ self.encoder_dim = encoder_dim
159
+ self.head_dim = encoder_dim // encoder_n_heads
160
+ self.attention = Attention(encoder_dim, encoder_n_heads, max_seq_len)
161
+ self.feed_forward = FeedForward(
162
+ dim=encoder_dim,
163
+ hidden_dim=2 * encoder_dim,
164
+ multiple_of=256,
165
+ ffn_dim_multiplier=None,
166
+ )
167
+ self.attention_norm = AdaLNZero(encoder_dim)
168
+ self.ffn_norm = nn.LayerNorm(encoder_dim, elementwise_affine=False, eps=1e-6)
169
+
170
+ def forward(
171
+ self,
172
+ x: torch.Tensor,
173
+ t: torch.Tensor,
174
+ start_pos: int,
175
+ freqs_cis: torch.Tensor,
176
+ mask: Optional[torch.Tensor],
177
+ ):
178
+ """
179
+ Perform a forward pass through the TransformerBlock.
180
+
181
+ Args:
182
+ x (torch.Tensor): Input tensor.
183
+ start_pos (int): Starting position for attention caching.
184
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
185
+ mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
186
+
187
+ Returns:
188
+ torch.Tensor: Output tensor after applying attention and feedforward layers.
189
+
190
+ """
191
+ # pre-norm & modulation for attention input
192
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=t)
193
+
194
+ # attention
195
+ attn_output = self.attention(norm, start_pos, freqs_cis, mask=mask)
196
+
197
+ # process attention output for input x
198
+ h = x + gate_msa.unsqueeze(1) * attn_output
199
+
200
+ norm = self.ffn_norm(h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
201
+ ff_output = self.feed_forward(norm)
202
+ out = h + gate_mlp.unsqueeze(1) * ff_output
203
+
204
+ return out
205
+
206
+
207
+ class Transformer(nn.Module):
208
+ def __init__(self, encoder_n_layers, encoder_dim, encoder_n_heads, max_seq_len):
209
+ super().__init__()
210
+ # Decoder
211
+ self.layers = torch.nn.ModuleList()
212
+ for _ in range(encoder_n_layers):
213
+ self.layers.append(TransformerBlock(encoder_dim, encoder_n_heads, max_seq_len))
214
+
215
+ self.norm = AdaLNZero_Out(encoder_dim)
216
+ self.out_proj = nn.Linear(encoder_dim, encoder_dim)
217
+
218
+ # Rope embedding
219
+ freqs_cis = precompute_freqs_cis(
220
+ encoder_dim // encoder_n_heads, max_seq_len
221
+ )
222
+ self.register_buffer("freqs_cis", torch.view_as_real(freqs_cis), persistent=False)
223
+
224
+ def forward(self, x, t, attn_mask, start_pos=0):
225
+ freqs_cis = torch.view_as_complex(self.freqs_cis.float())[start_pos: start_pos + x.size(1)]
226
+ for i, layer in enumerate(self.layers):
227
+ x = layer(x, t, start_pos, freqs_cis, attn_mask)
228
+ x = self.norm(x, t)
229
+ x = self.out_proj(x)
230
+ return x
tts/modules/wavvae/decoder/diag_gaussian.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import numpy as np
17
+
18
+ class DiagonalGaussianDistribution(object):
19
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
20
+ self.parameters = parameters
21
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
22
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
23
+ self.deterministic = deterministic
24
+ self.std = torch.exp(0.5 * self.logvar)
25
+ self.var = torch.exp(self.logvar)
26
+ if self.deterministic:
27
+ self.var = self.std = torch.zeros_like(
28
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
29
+ )
30
+
31
+ def sample(self, generator=None) -> torch.Tensor:
32
+ # make sure sample is on the same device as the parameters and has same dtype
33
+ sample = torch.randn(
34
+ self.mean.shape,
35
+ generator=generator,
36
+ device=self.parameters.device,
37
+ dtype=self.parameters.dtype,
38
+ )
39
+ x = self.mean + self.std * sample
40
+ return x
41
+
42
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
43
+ if self.deterministic:
44
+ return torch.Tensor([0.0])
45
+ else:
46
+ if other is None:
47
+ return 0.5 * torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
48
+ else:
49
+ return 0.5 * (
50
+ torch.pow(self.mean - other.mean, 2) / other.var
51
+ + self.var / other.var
52
+ - 1.0
53
+ - self.logvar
54
+ + other.logvar
55
+ )
56
+
57
+ def nll(self, sample, dims) -> torch.Tensor:
58
+ if self.deterministic:
59
+ return torch.Tensor([0.0])
60
+ logtwopi = np.log(2.0 * np.pi)
61
+ return 0.5 * torch.sum(
62
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
63
+ dim=dims,
64
+ )
65
+
66
+ def mode(self) -> torch.Tensor:
67
+ return self.mean
tts/modules/wavvae/decoder/hifigan_modules.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch
18
+ import torch.utils.data
19
+ from librosa.filters import mel as librosa_mel_fn
20
+ from torch.nn.utils import weight_norm, remove_weight_norm
21
+ from torch.nn import Conv1d
22
+ import numpy as np
23
+
24
+
25
+ def init_weights(m, mean=0.0, std=0.01):
26
+ classname = m.__class__.__name__
27
+ if classname.find("Conv") != -1:
28
+ m.weight.data.normal_(mean, std)
29
+
30
+
31
+ def get_padding(kernel_size, dilation=1):
32
+ return int((kernel_size*dilation - dilation)/2)
33
+
34
+
35
+ class Upsample(nn.Module):
36
+ def __init__(self, mult, r):
37
+ super(Upsample, self).__init__()
38
+ self.r = r
39
+ self.upsample = nn.Sequential(nn.Upsample(mode="nearest", scale_factor=r),
40
+ nn.LeakyReLU(0.2),
41
+ nn.ReflectionPad1d(3),
42
+ nn.utils.weight_norm(nn.Conv1d(mult, mult // 2, kernel_size=7, stride=1))
43
+ )
44
+ r_kernel = r if r >= 5 else 5
45
+ self.trans_upsample = nn.Sequential(nn.LeakyReLU(0.2),
46
+ nn.utils.weight_norm(nn.ConvTranspose1d(mult, mult // 2,
47
+ kernel_size=r_kernel * 2, stride=r,
48
+ padding=r_kernel - r // 2,
49
+ output_padding=r % 2)
50
+ ))
51
+
52
+ def forward(self, x):
53
+ x = torch.sin(x) + x
54
+ out1 = self.upsample(x)
55
+ out2 = self.trans_upsample(x)
56
+ return out1 + out2
57
+
58
+
59
+ class Downsample(nn.Module):
60
+ def __init__(self, mult, r):
61
+ super(Downsample, self).__init__()
62
+ self.r = r
63
+ r_kernel = r if r >= 5 else 5
64
+ self.trans_downsample = nn.Sequential(nn.LeakyReLU(0.2),
65
+ nn.utils.weight_norm(nn.Conv1d(mult, mult * 2,
66
+ kernel_size=r_kernel * 2, stride=r,
67
+ padding=r_kernel - r // 2)
68
+ ))
69
+
70
+ def forward(self, x):
71
+ out = self.trans_downsample(x)
72
+ return out
73
+
74
+
75
+ def weights_init(m):
76
+ classname = m.__class__.__name__
77
+ if classname.find("Conv") != -1:
78
+ m.weight.data.normal_(0.0, 0.02)
79
+ elif classname.find("BatchNorm2d") != -1:
80
+ m.weight.data.normal_(1.0, 0.02)
81
+ m.bias.data.fill_(0)
82
+
83
+
84
+ def weights_zero_init(m):
85
+ classname = m.__class__.__name__
86
+ if classname.find("Conv") != -1:
87
+ m.weight.data.fill_(0.0)
88
+ m.bias.data.fill_(0.0)
89
+
90
+
91
+ def WNConv1d(*args, **kwargs):
92
+ return weight_norm(nn.Conv1d(*args, **kwargs))
93
+
94
+
95
+ def WNConvTranspose1d(*args, **kwargs):
96
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
97
+
98
+
99
+ class Audio2Mel(nn.Module):
100
+ def __init__(
101
+ self,
102
+ hop_length=300,
103
+ sampling_rate=24000,
104
+ n_mel_channels=80,
105
+ mel_fmin=0.,
106
+ mel_fmax=None,
107
+ frame_size=0.05,
108
+ device='cpu'
109
+ ):
110
+ super().__init__()
111
+ ##############################################
112
+ # FFT Parameters #
113
+ ##############################################
114
+
115
+ self.n_fft = int(np.power(2., np.ceil(np.log(sampling_rate * frame_size) / np.log(2))))
116
+ window = torch.hann_window(int(sampling_rate * frame_size)).float()
117
+ mel_basis = librosa_mel_fn(
118
+ sampling_rate, self.n_fft, n_mel_channels, mel_fmin, mel_fmax
119
+ ) # Mel filter (by librosa)
120
+ mel_basis = torch.from_numpy(mel_basis).float()
121
+ self.register_buffer("mel_basis", mel_basis)
122
+ self.register_buffer("window", window)
123
+
124
+ self.hop_length = hop_length
125
+ self.win_length = int(sampling_rate * frame_size)
126
+ self.sampling_rate = sampling_rate
127
+ self.n_mel_channels = n_mel_channels
128
+
129
+ def forward(self, audio):
130
+ fft = torch.stft(
131
+ audio.squeeze(1),
132
+ n_fft=self.n_fft,
133
+ hop_length=self.hop_length,
134
+ win_length=self.win_length,
135
+ window=self.window,
136
+ center=True,
137
+ )
138
+ real_part, imag_part = fft.unbind(-1)
139
+ magnitude = torch.sqrt(torch.clamp(real_part ** 2 + imag_part ** 2, min=1e-5))
140
+ mel_output = torch.matmul(self.mel_basis, magnitude)
141
+
142
+ log_mel_spec = 20 * torch.log10(torch.clamp(mel_output, min=1e-5)) - 20
143
+ norm_mel = (log_mel_spec + 115.) / 115.
144
+ mel_comp = torch.clamp(norm_mel * 8. - 4., -4., 4.)
145
+
146
+ return mel_comp
147
+
148
+
149
+ class ResnetBlock(nn.Module):
150
+ def __init__(self, dim, dilation=1, dim_in=None):
151
+ super().__init__()
152
+ if dim_in is None:
153
+ dim_in = dim
154
+
155
+ self.block = nn.Sequential(
156
+ nn.LeakyReLU(0.2),
157
+ nn.ReflectionPad1d(dilation),
158
+ WNConv1d(dim_in, dim, kernel_size=3, dilation=dilation),
159
+ nn.LeakyReLU(0.2),
160
+ WNConv1d(dim, dim, kernel_size=1),
161
+ )
162
+ self.shortcut = WNConv1d(dim_in, dim, kernel_size=1)
163
+
164
+ def forward(self, x):
165
+ return self.shortcut(x) + self.block(x)
166
+
167
+
168
+ '''
169
+ 参照hifigan(https://arxiv.org/pdf/2010.05646.pdf)v2结构
170
+ 多尺度主要是kernel_size不同,3组并行卷积模块,每个卷积模块内部采用不同的串行dilation size,且中间交叉正常无dilation卷积层
171
+ '''
172
+
173
+
174
+ class ResBlockMRFV2(torch.nn.Module):
175
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
176
+ super(ResBlockMRFV2, self).__init__()
177
+ self.convs1 = nn.ModuleList([
178
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
179
+ padding=get_padding(kernel_size, dilation[0]))),
180
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
181
+ padding=get_padding(kernel_size, dilation[1]))),
182
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
183
+ padding=get_padding(kernel_size, dilation[2])))
184
+ ])
185
+ self.convs1.apply(init_weights)
186
+
187
+ self.convs2 = nn.ModuleList([
188
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
189
+ padding=get_padding(kernel_size, 1))),
190
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
191
+ padding=get_padding(kernel_size, 1))),
192
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
193
+ padding=get_padding(kernel_size, 1)))
194
+ ])
195
+ self.convs2.apply(init_weights)
196
+
197
+ def forward(self, x):
198
+ for c1, c2 in zip(self.convs1, self.convs2):
199
+ xt = F.leaky_relu(x, 0.2)
200
+ xt = c1(xt)
201
+ xt = F.leaky_relu(xt, 0.2)
202
+ xt = c2(xt)
203
+ x = xt + x
204
+ return x
205
+
206
+ def remove_weight_norm(self):
207
+ for l in self.convs1:
208
+ remove_weight_norm(l)
209
+ for l in self.convs2:
210
+ remove_weight_norm(l)
211
+
212
+
213
+ class ResBlockMRFV2Inter(torch.nn.Module):
214
+ def __init__(self, channels, kernel_size=3):
215
+ super(ResBlockMRFV2Inter, self).__init__()
216
+ self.block1 = ResBlockMRFV2(channels)
217
+ self.block2 = ResBlockMRFV2(channels, 7)
218
+ self.block3 = ResBlockMRFV2(channels, 11)
219
+
220
+ def forward(self, x):
221
+ xs = self.block1(x)
222
+ xs += self.block2(x)
223
+ xs += self.block3(x)
224
+ x = xs / 3
225
+ return x
226
+
227
+
228
+ class Generator(nn.Module):
229
+ def __init__(self, input_size_, ngf, n_residual_layers, num_band, args, ratios=[5, 5, 4, 3], onnx_export=False,
230
+ device='cpu'):
231
+ super().__init__()
232
+ self.hop_length = args.frame_shift
233
+ self.args = args
234
+ self.onnx_export = onnx_export
235
+
236
+ # ------------- Define upsample layers ----------------
237
+ mult = int(2 ** len(ratios))
238
+ model_up = []
239
+ input_size = input_size_
240
+ model_up += [
241
+ nn.ReflectionPad1d(3),
242
+ WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0),
243
+ ]
244
+
245
+ # Upsample to raw audio scale
246
+ for i, r in enumerate(ratios):
247
+ model_up += [Upsample(mult * ngf, r)]
248
+ model_up += [ResBlockMRFV2Inter(mult * ngf // 2)]
249
+ mult //= 2
250
+
251
+ model_up += [
252
+ nn.LeakyReLU(0.2),
253
+ nn.ReflectionPad1d(3),
254
+ WNConv1d(ngf, num_band, kernel_size=7, padding=0),
255
+ nn.Tanh(),
256
+ ]
257
+ if not args.use_tanh:
258
+ model_up[-1] = nn.Conv1d(num_band, num_band, 1)
259
+ model_up[-2].apply(weights_zero_init)
260
+
261
+ self.model_up = nn.Sequential(*model_up)
262
+
263
+ self.apply(weights_init)
264
+
265
+ def forward(self, mel, step=None):
266
+ # mel input: (batch_size, seq_num, 80)
267
+ if self.onnx_export:
268
+ mel = mel.transpose(1, 2)
269
+ # on onnx, for engineering, mel input: (batch_size, 80, seq_num)
270
+
271
+ # Between Down and up
272
+ x = mel
273
+
274
+ # Upsample pipline
275
+ cnt_after_upsample = 0
276
+
277
+ for i, m in enumerate(self.model_up):
278
+ x = m(x)
279
+
280
+ if type(m) == Upsample:
281
+ cnt_after_upsample += 1
282
+
283
+ return x
tts/modules/wavvae/decoder/seanet_encoder.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List
16
+
17
+ import torch
18
+ from torch import nn
19
+ from tts.modules.wavvae.encoder.common_modules.seanet import SEANetEncoder
20
+
21
+ class Encoder(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dowmsamples: List[int] = [6, 5, 5, 4, 2],
25
+ ):
26
+ super().__init__()
27
+
28
+ # breakpoint()
29
+ self.frame_rate = 25 # not use
30
+ self.encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
31
+ dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
32
+ kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
33
+ true_skip=False, compress=2)
34
+
35
+ def forward(self, audio: torch.Tensor):
36
+ audio = audio.unsqueeze(1) # audio(16,24000)
37
+ emb = self.encoder(audio)
38
+ return emb
tts/modules/wavvae/decoder/wavvae_v3.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+
20
+ from tts.modules.wavvae.decoder.seanet_encoder import Encoder
21
+ from tts.modules.wavvae.decoder.diag_gaussian import DiagonalGaussianDistribution
22
+ from tts.modules.wavvae.decoder.hifigan_modules import Generator, Upsample
23
+
24
+
25
+ class WavVAE_V3(nn.Module):
26
+ def __init__(self, hparams=None):
27
+ super().__init__()
28
+ self.encoder = Encoder(dowmsamples=[6, 5, 4, 4, 2])
29
+ self.proj_to_z = nn.Linear(512, 64)
30
+ self.proj_to_decoder = nn.Linear(32, 320)
31
+
32
+ config_path = hparams['melgan_config']
33
+ args = argparse.Namespace()
34
+ args.__dict__.update(config_path)
35
+ self.latent_upsampler = Upsample(320, 4)
36
+ self.decoder = Generator(
37
+ input_size_=160, ngf=128, n_residual_layers=4,
38
+ num_band=1, args=args, ratios=[5,4,4,3])
39
+
40
+ ''' encode waveform into 25 hz latent representation '''
41
+ def encode_latent(self, audio):
42
+ posterior = self.encode(audio)
43
+ latent = posterior.sample().permute(0, 2, 1) # (b,t,latent_channel)
44
+ return latent
45
+
46
+ def encode(self, audio):
47
+ x = self.encoder(audio).permute(0, 2, 1)
48
+ x = self.proj_to_z(x).permute(0, 2, 1)
49
+ poseterior = DiagonalGaussianDistribution(x)
50
+ return poseterior
51
+
52
+ def decode(self, latent):
53
+ latent = self.proj_to_decoder(latent).permute(0, 2, 1)
54
+ return self.decoder(self.latent_upsampler(latent))
55
+
56
+ def forward(self, audio):
57
+ posterior = self.encode(audio)
58
+ latent = posterior.sample().permute(0, 2, 1) # (b, t, latent_channel)
59
+ recon_wav = self.decode(latent)
60
+ return recon_wav, posterior
tts/modules/wavvae/encoder/common_modules/conv.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
24
+ # Copyright (c) [2025] [Ziyue Jiang]
25
+ # SPDX-License-Identifier: MIT
26
+ # This file has been modified by Ziyue Jiang on 2025/03/19
27
+ # Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
28
+ # This modified file is released under the same license.
29
+
30
+ """Convolutional layers wrappers and utilities."""
31
+
32
+ import math
33
+ import typing as tp
34
+ import warnings
35
+ import einops
36
+
37
+ import torch
38
+ from torch import nn
39
+ from torch.nn import functional as F
40
+ from torch.nn.utils import spectral_norm, weight_norm
41
+
42
+
43
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
44
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
45
+
46
+
47
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
48
+ assert norm in CONV_NORMALIZATIONS
49
+ if norm == 'weight_norm':
50
+ return weight_norm(module)
51
+ elif norm == 'spectral_norm':
52
+ return spectral_norm(module)
53
+ else:
54
+ return module
55
+
56
+
57
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
58
+ assert norm in CONV_NORMALIZATIONS
59
+ if norm == 'layer_norm':
60
+ assert isinstance(module, nn.modules.conv._ConvNd)
61
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
62
+ elif norm == 'time_group_norm':
63
+ if causal:
64
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
65
+ assert isinstance(module, nn.modules.conv._ConvNd)
66
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
67
+ else:
68
+ return nn.Identity()
69
+
70
+
71
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
72
+ padding_total: int = 0) -> int:
73
+ length = x.shape[-1]
74
+ n_frames = (length - kernel_size + padding_total) / stride + 1
75
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
76
+ return ideal_length - length
77
+
78
+
79
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
80
+ length = x.shape[-1]
81
+ padding_left, padding_right = paddings
82
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
83
+ if mode == 'reflect':
84
+ max_pad = max(padding_left, padding_right)
85
+ extra_pad = 0
86
+ if length <= max_pad:
87
+ extra_pad = max_pad - length + 1
88
+ x = F.pad(x, (0, extra_pad))
89
+ padded = F.pad(x, paddings, mode, value)
90
+ end = padded.shape[-1] - extra_pad
91
+ return padded[..., :end]
92
+ else:
93
+ return F.pad(x, paddings, mode, value)
94
+
95
+
96
+ class ConvLayerNorm(nn.LayerNorm):
97
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
98
+ super().__init__(normalized_shape, **kwargs)
99
+
100
+ def forward(self, x):
101
+ x = einops.rearrange(x, 'b ... t -> b t ...')
102
+ x = super().forward(x)
103
+ x = einops.rearrange(x, 'b t ... -> b ... t')
104
+ return
105
+
106
+
107
+ class NormConv1d(nn.Module):
108
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
109
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
110
+ super().__init__()
111
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
112
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
113
+ self.norm_type = norm
114
+
115
+ def forward(self, x):
116
+ x = self.conv(x)
117
+ x = self.norm(x)
118
+ return x
119
+
120
+
121
+ class SConv1d(nn.Module):
122
+ def __init__(self, in_channels: int, out_channels: int,
123
+ kernel_size: int, stride: int = 1, dilation: int = 1,
124
+ groups: int = 1, bias: bool = True, causal: bool = False,
125
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
126
+ pad_mode: str = 'reflect'):
127
+ super().__init__()
128
+ # warn user on unusual setup between dilation and stride
129
+ if stride > 1 and dilation > 1:
130
+ warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
131
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
132
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
133
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
134
+ norm=norm, norm_kwargs=norm_kwargs)
135
+ self.causal = causal
136
+ self.pad_mode = pad_mode
137
+
138
+ def forward(self, x):
139
+ B, C, T = x.shape
140
+ kernel_size = self.conv.conv.kernel_size[0]
141
+ stride = self.conv.conv.stride[0]
142
+ dilation = self.conv.conv.dilation[0]
143
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
144
+ padding_total = kernel_size - stride
145
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
146
+ if self.causal:
147
+ # Left padding for causal
148
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
149
+ else:
150
+ # Asymmetric padding required for odd strides
151
+ padding_right = padding_total // 2
152
+ padding_left = padding_total - padding_right
153
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
154
+ return self.conv(x)
tts/modules/wavvae/encoder/common_modules/lstm.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
24
+ # Copyright (c) [2025] [Ziyue Jiang]
25
+ # SPDX-License-Identifier: MIT
26
+ # This file has been modified by Ziyue Jiang on 2025/03/19
27
+ # Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
28
+ # This modified file is released under the same license.
29
+
30
+ """LSTM layers module."""
31
+ from torch import nn
32
+
33
+
34
+ class SLSTM(nn.Module):
35
+ """
36
+ LSTM without worrying about the hidden state, nor the layout of the data.
37
+ Expects input as convolutional layout.
38
+ """
39
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
40
+ super().__init__()
41
+ self.skip = skip
42
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
43
+
44
+ # 修改transpose顺序
45
+ def forward(self, x):
46
+ x1 = x.permute(2, 0, 1)
47
+ y, _ = self.lstm(x1)
48
+ y = y.permute(1, 2, 0)
49
+ if self.skip:
50
+ y = y + x
51
+ return y
tts/modules/wavvae/encoder/common_modules/seanet.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # Copyright (c) [2023] [Meta Platforms, Inc. and affiliates.]
24
+ # Copyright (c) [2025] [Ziyue Jiang]
25
+ # SPDX-License-Identifier: MIT
26
+ # This file has been modified by Ziyue Jiang on 2025/03/19
27
+ # Original file was released under MIT, with the full license text # available at https://github.com/facebookresearch/encodec/blob/gh-pages/LICENSE.
28
+ # This modified file is released under the same license.
29
+
30
+ """Encodec SEANet-based encoder and decoder implementation."""
31
+
32
+ import typing as tp
33
+
34
+ import numpy as np
35
+ import torch.nn as nn
36
+
37
+ from .conv import SConv1d
38
+ from .lstm import SLSTM
39
+
40
+
41
+ class SEANetResnetBlock(nn.Module):
42
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
43
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
44
+ norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
45
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
46
+ super().__init__()
47
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
48
+ act = getattr(nn, activation)
49
+ hidden = dim // compress
50
+ block = []
51
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
52
+ in_chs = dim if i == 0 else hidden
53
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
54
+ block += [
55
+ act(**activation_params),
56
+ SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
57
+ norm=norm, norm_kwargs=norm_params,
58
+ causal=causal, pad_mode=pad_mode),
59
+ ]
60
+ self.block = nn.Sequential(*block)
61
+ self.shortcut: nn.Module
62
+ if true_skip:
63
+ self.shortcut = nn.Identity()
64
+ else:
65
+ self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
66
+ causal=causal, pad_mode=pad_mode)
67
+
68
+ def forward(self, x):
69
+ return self.shortcut(x) + self.block(x)
70
+
71
+
72
+ class SEANetEncoder(nn.Module):
73
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
74
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
75
+ norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
76
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
77
+ pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
78
+ super().__init__()
79
+ self.channels = channels
80
+ self.dimension = dimension
81
+ self.n_filters = n_filters
82
+ self.ratios = list(reversed(ratios))
83
+ del ratios
84
+ self.n_residual_layers = n_residual_layers
85
+ self.hop_length = np.prod(self.ratios)
86
+
87
+ act = getattr(nn, activation)
88
+ mult = 1
89
+ model: tp.List[nn.Module] = [
90
+ SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
91
+ causal=causal, pad_mode=pad_mode)
92
+ ]
93
+ # Downsample to raw audio scale
94
+ for i, ratio in enumerate(self.ratios):
95
+ # Add residual layers
96
+ for j in range(n_residual_layers):
97
+ model += [
98
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
99
+ dilations=[dilation_base ** j, 1],
100
+ norm=norm, norm_params=norm_params,
101
+ activation=activation, activation_params=activation_params,
102
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
103
+
104
+ # Add downsampling layers
105
+ model += [
106
+ act(**activation_params),
107
+ SConv1d(mult * n_filters, mult * n_filters * 2,
108
+ kernel_size=ratio * 2, stride=ratio,
109
+ norm=norm, norm_kwargs=norm_params,
110
+ causal=causal, pad_mode=pad_mode),
111
+ ]
112
+ mult *= 2
113
+
114
+ if lstm:
115
+ model += [SLSTM(mult * n_filters, num_layers=lstm)]
116
+
117
+ model += [
118
+ act(**activation_params),
119
+ SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
120
+ causal=causal, pad_mode=pad_mode)
121
+ ]
122
+
123
+ self.model = nn.Sequential(*model)
124
+
125
+ def forward(self, x):
126
+ return self.model(x)
tts/utils/audio_utils/align.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ def mel2token_to_dur(mel2token, T_txt=None, max_dur=None):
18
+ is_torch = isinstance(mel2token, torch.Tensor)
19
+ has_batch_dim = True
20
+ if not is_torch:
21
+ mel2token = torch.LongTensor(mel2token)
22
+ if T_txt is None:
23
+ T_txt = mel2token.max()
24
+ if len(mel2token.shape) == 1:
25
+ mel2token = mel2token[None, ...]
26
+ has_batch_dim = False
27
+ B, _ = mel2token.shape
28
+ dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token))
29
+ dur = dur[:, 1:]
30
+ if max_dur is not None:
31
+ dur = dur.clamp(max=max_dur)
32
+ if not is_torch:
33
+ dur = dur.numpy()
34
+ if not has_batch_dim:
35
+ dur = dur[0]
36
+ return dur
tts/utils/audio_utils/io.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import io
16
+ import os
17
+ import subprocess
18
+
19
+ import numpy as np
20
+ from scipy.io import wavfile
21
+ import pyloudnorm as pyln
22
+ from pydub import AudioSegment
23
+
24
+
25
+ def to_wav_bytes(wav, sr, norm=False):
26
+ wav = wav.astype(float)
27
+ if norm:
28
+ meter = pyln.Meter(sr) # create BS.1770 meter
29
+ loudness = meter.integrated_loudness(wav)
30
+ wav = pyln.normalize.loudness(wav, loudness, -18.0)
31
+ if np.abs(wav).max() >= 1:
32
+ wav = wav / np.abs(wav).max() * 0.95
33
+ wav = wav * 32767
34
+ bytes_io = io.BytesIO()
35
+ wavfile.write(bytes_io, sr, wav.astype(np.int16))
36
+ return bytes_io.getvalue()
37
+
38
+
39
+ def save_wav(wav_bytes, path):
40
+ with open(path[:-4] + '.wav', 'wb') as file:
41
+ file.write(wav_bytes)
42
+ if path[-4:] == '.mp3':
43
+ to_mp3(path[:-4])
44
+
45
+
46
+ def to_mp3(out_path):
47
+ if out_path[-4:] == '.wav':
48
+ out_path = out_path[:-4]
49
+ subprocess.check_call(
50
+ f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -b:a 192k -y -hide_banner -async 1 "{out_path}.mp3"',
51
+ shell=True, stdin=subprocess.PIPE)
52
+ subprocess.check_call(f'rm -f "{out_path}.wav"', shell=True)
53
+
54
+
55
+ def convert_to_wav(wav_path):
56
+ # Check if the file exists
57
+ if not os.path.exists(wav_path):
58
+ print(f"The file '{wav_path}' does not exist.")
59
+ return
60
+
61
+ # Check if the file already has a .wav extension
62
+ if not wav_path.endswith(".wav"):
63
+ # Define the output path with a .wav extension
64
+ out_path = os.path.splitext(wav_path)[0] + ".wav"
65
+
66
+ # Load the audio file using pydub and convert it to WAV
67
+ audio = AudioSegment.from_file(wav_path)
68
+ audio.export(out_path, format="wav")
69
+
70
+ print(f"Converted '{wav_path}' to '{out_path}'")
71
+
72
+
73
+ def convert_to_wav_bytes(audio_binary):
74
+ # Load the audio binary using pydub and convert it to WAV
75
+ audio = AudioSegment.from_file(io.BytesIO(audio_binary))
76
+ wav_bytes = io.BytesIO()
77
+ audio.export(wav_bytes, format="wav")
78
+ wav_bytes.seek(0)
79
+ return wav_bytes
80
+
81
+
82
+ ''' Smoothly combine audio segments using crossfade transitions." '''
83
+ def combine_audio_segments(segments, crossfade_duration=0.16, sr=24000):
84
+ window_length = int(sr * crossfade_duration)
85
+ hanning_window = np.hanning(2 * window_length)
86
+ # Combine
87
+ for i, segment in enumerate(segments):
88
+ if i == 0:
89
+ combined_audio = segment
90
+ else:
91
+ overlap = combined_audio[-window_length:] * hanning_window[window_length:] + segment[:window_length] * hanning_window[:window_length]
92
+ combined_audio = np.concatenate(
93
+ [combined_audio[:-window_length], overlap, segment[window_length:]]
94
+ )
95
+ return combined_audio
tts/utils/audio_utils/plot.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import matplotlib
16
+
17
+ matplotlib.use('Agg')
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import torch
21
+
22
+ LINE_COLORS = ['w', 'r', 'orange', 'k', 'cyan', 'm', 'b', 'lime', 'g', 'brown', 'navy']
23
+
24
+
25
+ def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_info=None, figsize=(12, 6)):
26
+ if isinstance(spec, torch.Tensor):
27
+ spec = spec.cpu().numpy()
28
+ H = spec.shape[1] // 2
29
+ fig = plt.figure(figsize=figsize)
30
+ plt.title(title)
31
+ plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
32
+
33
+ if dur_info is not None:
34
+ assert isinstance(dur_info, dict)
35
+ txt = dur_info['txt']
36
+ dur_gt = dur_info['dur_gt']
37
+ if isinstance(dur_gt, torch.Tensor):
38
+ dur_gt = dur_gt.cpu().numpy()
39
+ dur_gt = np.cumsum(dur_gt).astype(int)
40
+ for i in range(len(dur_gt)):
41
+ shift = (i % 8) + 1
42
+ plt.text(dur_gt[i], shift * 4, txt[i])
43
+ plt.vlines(dur_gt[i], 0, H // 2, colors='b') # blue is gt
44
+ plt.xlim(0, dur_gt[-1])
45
+ if 'dur_pred' in dur_info:
46
+ dur_pred = dur_info['dur_pred']
47
+ if isinstance(dur_pred, torch.Tensor):
48
+ dur_pred = dur_pred.cpu().numpy()
49
+ dur_pred = np.cumsum(dur_pred).astype(int)
50
+ for i in range(len(dur_pred)):
51
+ shift = (i % 8) + 1
52
+ plt.text(dur_pred[i], H + shift * 4, txt[i])
53
+ plt.vlines(dur_pred[i], H, H * 1.5, colors='r') # red is pred
54
+ plt.xlim(0, max(dur_gt[-1], dur_pred[-1]))
55
+ if f0s is not None:
56
+ ax = plt.gca()
57
+ ax2 = ax.twinx()
58
+ # ax.set_xticks()
59
+
60
+ if not isinstance(f0s, dict):
61
+ f0s = {'f0': f0s}
62
+ for i, (k, f0) in enumerate(f0s.items()):
63
+ if f0 is not None:
64
+ if isinstance(f0, torch.Tensor):
65
+ f0 = f0.cpu().numpy()
66
+ ax2.plot(
67
+ np.arange(len(f0)) + 0.5, f0, label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.5)
68
+ ax2.set_ylim(0, 1000)
69
+ ax2.legend()
70
+ return fig
71
+
72
+
73
+ def align_to_figure(align, dur_info):
74
+ if isinstance(align, torch.Tensor):
75
+ align = align.cpu().numpy()
76
+ H = align.shape[1]
77
+ fig = plt.figure(figsize=(12, 6))
78
+ plt.pcolor(align.T, vmin=0, vmax=1)
79
+ if dur_info is not None:
80
+ assert isinstance(dur_info, dict)
81
+ txt = dur_info['txt']
82
+ dur_gt = dur_info['dur_gt']
83
+ if isinstance(dur_gt, torch.Tensor):
84
+ dur_gt = dur_gt.cpu().numpy()
85
+ dur_gt = np.cumsum(dur_gt).astype(int) // 2
86
+ for i in range(len(dur_gt)):
87
+ plt.text(dur_gt[i], i, txt[i], color='red')
88
+ plt.vlines(dur_gt[i], 0, H, colors='b') # blue is gt
89
+ # plt.xlim(0, dur_gt[-1])
90
+ return fig
tts/utils/commons/ckpt_utils.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import contextlib
16
+ import glob
17
+ import os
18
+ import re
19
+ import subprocess
20
+ import traceback
21
+
22
+ import torch
23
+ from torch.nn.parallel import DistributedDataParallel
24
+ import torch.distributed as dist
25
+
26
+
27
+ @contextlib.contextmanager
28
+ def dist_load(path):
29
+ if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'):
30
+ yield path
31
+ else:
32
+ from tts.utils.commons.hparams import hparams
33
+ from tts.utils.commons.trainer import LOCAL_RANK
34
+ tmpdir = '/dev/shm'
35
+ assert len(os.path.basename(path)) > 0
36
+ shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
37
+ if LOCAL_RANK == 0:
38
+ subprocess.check_call(
39
+ f'mkdir -p {os.path.dirname(shm_ckpt_path)}; '
40
+ f'cp -Lr {path} {shm_ckpt_path}', shell=True)
41
+ dist.barrier()
42
+ yield shm_ckpt_path
43
+ dist.barrier()
44
+ if LOCAL_RANK == 0:
45
+ subprocess.check_call(f'rm -rf {shm_ckpt_path}', shell=True)
46
+
47
+
48
+ def torch_load_dist(path, map_location='cpu'):
49
+ with dist_load(path) as tmp_path:
50
+ checkpoint = torch.load(tmp_path, map_location=map_location)
51
+ return checkpoint
52
+
53
+
54
+ def get_last_checkpoint(work_dir, steps=None):
55
+ checkpoint = None
56
+ last_ckpt_path = None
57
+ ckpt_paths = get_all_ckpts(work_dir, steps)
58
+ if len(ckpt_paths) > 0:
59
+ last_ckpt_path = ckpt_paths[0]
60
+ checkpoint = torch_load_dist(last_ckpt_path, map_location='cpu')
61
+ return checkpoint, last_ckpt_path
62
+
63
+
64
+ def get_all_ckpts(work_dir, steps=None):
65
+ if steps is None or steps == 0:
66
+ ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
67
+ else:
68
+ ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
69
+ return sorted(glob.glob(ckpt_path_pattern),
70
+ key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0]))
71
+
72
+
73
+ def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
74
+ silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True):
75
+ if checkpoint is None:
76
+ if os.path.isfile(ckpt_base_dir):
77
+ base_dir = os.path.dirname(ckpt_base_dir)
78
+ ckpt_path = ckpt_base_dir
79
+ checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
80
+ else:
81
+ base_dir = ckpt_base_dir
82
+ if load_opt:
83
+ checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
84
+ else:
85
+ ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
86
+ if os.path.exists(ckpt_path):
87
+ checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
88
+ else:
89
+ checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
90
+ if checkpoint is not None:
91
+ state_dict_all = {
92
+ k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
93
+ if not isinstance(cur_model, list):
94
+ cur_models = [cur_model]
95
+ model_names = [model_name]
96
+ else:
97
+ cur_models = cur_model
98
+ model_names = model_name
99
+ for model_name, cur_model in zip(model_names, cur_models):
100
+ if isinstance(cur_model, DistributedDataParallel):
101
+ cur_model = cur_model.module
102
+ device = next(cur_model.parameters()).device
103
+ if '.' not in model_name:
104
+ state_dict = state_dict_all[model_name]
105
+ else:
106
+ base_model_name = model_name.split('.')[0]
107
+ rest_model_name = model_name[len(base_model_name) + 1:]
108
+ state_dict = {
109
+ k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
110
+ if k.startswith(f'{rest_model_name}.')}
111
+ state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
112
+ if not strict and delete_unmatch:
113
+ try:
114
+ cur_model.load_state_dict(state_dict, strict=True)
115
+ if not silent:
116
+ print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
117
+ except:
118
+ cur_model_state_dict = cur_model.state_dict()
119
+ cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
120
+ cur_model_state_dict.items()}
121
+ unmatched_keys = []
122
+ for key, param in state_dict.items():
123
+ if key in cur_model_state_dict:
124
+ new_param = cur_model_state_dict[key]
125
+ if new_param.shape != param.shape:
126
+ unmatched_keys.append(key)
127
+ print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
128
+ "ckpt model: ", param.shape)
129
+ for key in unmatched_keys:
130
+ del state_dict[key]
131
+ load_results = cur_model.load_state_dict(state_dict, strict=strict)
132
+ cur_model.to(device)
133
+ if not silent:
134
+ print(f"| loaded '{model_name}' from '{ckpt_path}'.")
135
+ missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
136
+ print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
137
+ if load_opt:
138
+ optimizer_states = checkpoint['optimizer_states']
139
+ assert len(opts) == len(optimizer_states)
140
+ for optimizer, opt_state in zip(opts, optimizer_states):
141
+ opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
142
+ if optimizer is None:
143
+ return
144
+ try:
145
+ optimizer.load_state_dict(opt_state)
146
+ for i, state in enumerate(optimizer.state.values()):
147
+ for k, v in state.items():
148
+ if isinstance(v, torch.Tensor):
149
+ state[k] = v.to(device)
150
+ except ValueError:
151
+ print(f"| WARMING: optimizer {optimizer} parameters not match !!!")
152
+ return checkpoint.get('global_step', 0)
153
+ else:
154
+ e_msg = f"| ckpt not found in {base_dir}."
155
+ if force:
156
+ assert False, e_msg
157
+ else:
158
+ print(e_msg)
159
+
160
+
161
+ def load_with_size_mismatch(model, state_dict, prefix=""):
162
+ current_model_dict = model.state_dict()
163
+ cm_keys = current_model_dict.keys()
164
+ mismatch_keys = {k.replace(prefix, "") for k, v in state_dict.items() if k.replace(prefix, "") in cm_keys and v.size() != current_model_dict[k.replace(prefix, "")].size()}
165
+ new_state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if k.replace(prefix, "") in cm_keys and v.size() == current_model_dict[k.replace(prefix, "")].size()}
166
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
167
+ print(f"| mismatch keys: ", mismatch_keys)
168
+ if len(missing_keys) > 0:
169
+ print(f"| missing_keys in dit: {missing_keys}")
170
+ if len(unexpected_keys) > 0:
171
+ print(f"| unexpected_keys in dit: {unexpected_keys}")
tts/utils/commons/hparams.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ import re
19
+
20
+ import yaml
21
+
22
+ global_print_hparams = True
23
+ hparams = {}
24
+
25
+
26
+ class Args:
27
+ def __init__(self, **kwargs):
28
+ for k, v in kwargs.items():
29
+ self.__setattr__(k, v)
30
+
31
+
32
+ def override_config(old_config: dict, new_config: dict):
33
+ if new_config.get('__replace', False):
34
+ old_config.clear()
35
+ for k, v in new_config.items():
36
+ if isinstance(v, dict) and k in old_config:
37
+ override_config(old_config[k], new_config[k])
38
+ else:
39
+ old_config[k] = v
40
+
41
+
42
+ def traverse_dict(d, func, ctx):
43
+ for k in list(d.keys()):
44
+ v = d[k]
45
+ if isinstance(v, dict):
46
+ traverse_dict(v, func, ctx)
47
+ else:
48
+ d[k] = func(v, ctx)
49
+
50
+
51
+ def parse_config(v, context=None):
52
+ if context is None:
53
+ context = {}
54
+
55
+ if isinstance(v, str):
56
+ if v.startswith('^'):
57
+ return load_config(v[1:], [], set())
58
+
59
+ match = re.match(r"\${(.*)}", v)
60
+ if match:
61
+ expression = match.group(1)
62
+ return eval(expression, {}, context)
63
+ return v
64
+
65
+
66
+ def remove_meta_key(d):
67
+ for k in list(d.keys()):
68
+ v = d[k]
69
+ if isinstance(v, dict):
70
+ remove_meta_key(v)
71
+ else:
72
+ if k[:2] == '__':
73
+ del d[k]
74
+
75
+
76
+ def load_config(config_fn, config_chains, loaded_configs):
77
+ # deep first inheritance and avoid the second visit of one node
78
+ if not os.path.exists(config_fn):
79
+ print(f"| WARN: {config_fn} not exist.", )
80
+ return {}
81
+ with open(config_fn) as f:
82
+ hparams_ = yaml.safe_load(f)
83
+ loaded_configs.add(config_fn)
84
+
85
+ if 'base_config' in hparams_:
86
+ ret_hparams = {}
87
+ if not isinstance(hparams_['base_config'], list):
88
+ hparams_['base_config'] = [hparams_['base_config']]
89
+ for c in hparams_['base_config']:
90
+ if c.startswith('.'):
91
+ c = f'{os.path.dirname(config_fn)}/{c}'
92
+ c = os.path.normpath(c)
93
+ if c not in loaded_configs:
94
+ override_config(ret_hparams, load_config(c, config_chains, loaded_configs))
95
+ override_config(ret_hparams, hparams_)
96
+ else:
97
+ ret_hparams = hparams_
98
+
99
+ config_chains.append(config_fn)
100
+ return ret_hparams
101
+
102
+
103
+ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
104
+ if config == '' and exp_name == '':
105
+ parser = argparse.ArgumentParser(description='')
106
+ parser.add_argument('--config', type=str, default='',
107
+ help='location of the data corpus')
108
+ parser.add_argument('--exp_name', type=str, default='', help='exp_name')
109
+ parser.add_argument('-hp', '--hparams', type=str, default='',
110
+ help='location of the data corpus')
111
+ parser.add_argument('--infer', action='store_true', help='infer')
112
+ parser.add_argument('--validate', action='store_true', help='validate')
113
+ parser.add_argument('--reset', action='store_true', help='reset hparams')
114
+ parser.add_argument('--remove', action='store_true', help='remove old ckpt')
115
+ parser.add_argument('--debug', action='store_true', help='debug')
116
+ parser.add_argument('--start_rank', type=int, default=-1,
117
+ help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
118
+ parser.add_argument('--world_size', type=int, default=-1,
119
+ help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
120
+ parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
121
+ parser.add_argument('--master_addr', type=str, default='', help='')
122
+ parser.add_argument('--ddp_dir', type=str, default='', help='')
123
+
124
+ args, unknown = parser.parse_known_args()
125
+ if print_hparams:
126
+ print("| set_hparams Unknow hparams: ", unknown)
127
+ else:
128
+ args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
129
+ infer=False, validate=False, reset=False, debug=False, remove=False,
130
+ start_rank=-1, world_size=-1, init_method='tcp', ddp_dir='', master_addr='')
131
+ global hparams
132
+ assert args.config != '' or args.exp_name != ''
133
+ if args.config != '':
134
+ assert os.path.exists(args.config), f"{args.config} not exists"
135
+
136
+ saved_hparams = {}
137
+ args_work_dir = ''
138
+ if args.exp_name != '':
139
+ args_work_dir = f'{args.exp_name}'
140
+ ckpt_config_path = f'{args_work_dir}/config.yaml'
141
+ if os.path.exists(ckpt_config_path):
142
+ with open(ckpt_config_path) as f:
143
+ saved_hparams_ = yaml.safe_load(f)
144
+ if saved_hparams_ is not None:
145
+ saved_hparams.update(saved_hparams_)
146
+ hparams_ = {}
147
+ config_chains = []
148
+ if args.config != '':
149
+ hparams_.update(load_config(args.config, config_chains, set()))
150
+ if len(config_chains) > 1 and print_hparams:
151
+ print('| Hparams chains: ', config_chains)
152
+ if not args.reset:
153
+ hparams_.update(saved_hparams)
154
+ traverse_dict(hparams_, parse_config, hparams_)
155
+ hparams_['work_dir'] = args_work_dir
156
+
157
+ # Support config overriding in command line. Support list type config overriding.
158
+ # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
159
+ if args.hparams != "":
160
+ for new_hparam in args.hparams.split(","):
161
+ k, v = new_hparam.split("=")
162
+ v = v.strip("\'\" ")
163
+ config_node = hparams_
164
+ for k_ in k.split(".")[:-1]:
165
+ config_node = config_node[k_]
166
+ k = k.split(".")[-1]
167
+ if k in config_node:
168
+ if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
169
+ if type(config_node[k]) == list:
170
+ v = v.replace(" ", ",").replace('^', "\"")
171
+ if '|' in v:
172
+ tp = type(config_node[k][0]) if len(config_node[k]) else str
173
+ config_node[k] = [tp(x) for x in v.split("|") if x != '']
174
+ continue
175
+ config_node[k] = eval(v)
176
+ else:
177
+ config_node[k] = type(config_node[k])(v)
178
+ else:
179
+ config_node[k] = v
180
+ try:
181
+ config_node[k] = float(v)
182
+ except:
183
+ pass
184
+ try:
185
+ config_node[k] = int(v)
186
+ except:
187
+ pass
188
+ if v.lower() in ['false', 'true']:
189
+ config_node[k] = v.lower() == 'true'
190
+
191
+ if args_work_dir != '' and not args.infer:
192
+ os.makedirs(hparams_['work_dir'], exist_ok=True)
193
+
194
+ hparams_['infer'] = args.infer
195
+ hparams_['debug'] = args.debug
196
+ hparams_['validate'] = args.validate
197
+ hparams_['exp_name'] = args.exp_name
198
+
199
+ hparams_['start_rank'] = args.start_rank # useful for multi-machine training
200
+ hparams_['world_size'] = args.world_size
201
+ hparams_['init_method'] = args.init_method
202
+ hparams_['ddp_dir'] = args.ddp_dir
203
+ hparams_['master_addr'] = args.master_addr
204
+
205
+ remove_meta_key(hparams_)
206
+ global global_print_hparams
207
+ if global_hparams:
208
+ hparams.clear()
209
+ hparams.update(hparams_)
210
+ if print_hparams and global_print_hparams and global_hparams:
211
+ print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True))
212
+ # for i, (k, v) in enumerate(sorted(hparams_.items())):
213
+ # print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
214
+ global_print_hparams = False
215
+ return hparams_
tts/utils/text_utils/dict.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"phone": ["C0a", "C0ai", "C0air", "C0an", "C0ang", "C0angr", "C0anr", "C0ao", "C0aor", "C0ar", "C0b", "C0c", "C0ch", "C0d", "C0e", "C0ei", "C0eir", "C0en", "C0eng", "C0engr", "C0enr", "C0er", "C0f", "C0g", "C0h", "C0i", "C0ia", "C0ian", "C0iang", "C0iangr", "C0ianr", "C0iao", "C0iaor", "C0iar", "C0ie", "C0ier", "C0ii", "C0iii", "C0iiir", "C0iir", "C0in", "C0ing", "C0ingr", "C0inr", "C0io", "C0iong", "C0iongr", "C0iou", "C0iour", "C0ir", "C0j", "C0k", "C0l", "C0m", "C0n", "C0ng", "C0o", "C0ong", "C0ongr", "C0or", "C0ou", "C0our", "C0p", "C0q", "C0r", "C0s", "C0sh", "C0t", "C0u", "C0ua", "C0uai", "C0uair", "C0uan", "C0uang", "C0uangr", "C0uanr", "C0uar", "C0uei", "C0ueir", "C0uen", "C0ueng", "C0uengr", "C0uenr", "C0uo", "C0uor", "C0ur", "C0v", "C0van", "C0vanr", "C0ve", "C0ver", "C0vn", "C0vnr", "C0vr", "C0x", "C0z", "C0zh", "C0_", "E0aa", "E0ae", "E0ah", "E0ao", "E0aw", "E0ax", "E0ay", "E0b", "E0ch", "E0d", "E0dh", "E0eh", "E0ehr", "E0er", "E0ey", "E0f", "E0g", "E0hh", "E0ih", "E0iy", "E0iyr", "E0jh", "E0k", "E0l", "E0m", "E0n", "E0ng", "E0oh", "E0ow", "E0oy", "E0p", "E0r", "E0s", "E0sh", "E0t", "E0th", "E0uh", "E0uw", "E0uwr", "E0v", "E0w", "E0y", "E0z", "E0zh", "sil", "…", "、", "。", "《", "》", "【", "】", "!", """, "#", "$", "%", "'", "''", "(", ")", "*", ",", ":", ";", "?", "\", "^", "_", "`", "{", "}", "~"], "tone": ["0", "1", "10", "11", "12", "13", "15", "17", "2", "3", "4", "5", "6", "7", "8", "9"], "wordCategory": ["0", "B", "E", "M", "S"], "prosody": ["0", "1", "2", "3", "4"], "focus": ["0", "1"], "intonation": ["0", "1", "2"], "phraseAccent": ["0", "H-", "L-"], "boundaryTone": ["0", "H%", "L%"], "accentType": ["!H*", "0", "H*", "L*", "L*+H", "L+H*"]}
tts/utils/text_utils/ph_tone_convert.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+ def map_phone_to_tokendict(item, pad_bos_eos=True):
19
+ # Merge Chinese phone and tone (Original dict ends at 173, i.e., ph_dict_size=173). 146~173 is punctuations.
20
+ phone = item['txt_token'].clone()
21
+ merged_phone = item['txt_token'].clone()
22
+ tone_tmp = item['tone'].clone()
23
+ # In tone_dict, tone_1 is 4, tone_2 is 11, tone_3 is 12, tone_4 is 13, tone_5 is 14, tone_6 is 15
24
+ tone_tmp[tone_tmp==4] = 1
25
+ tone_tmp[tone_tmp==11] = 2
26
+ tone_tmp[tone_tmp==12] = 3
27
+ tone_tmp[tone_tmp==13] = 4
28
+ tone_tmp[tone_tmp==14] = 5
29
+ tone_tmp[tone_tmp==15] = 6
30
+ # Chinese phones lie in 3~100 in the phone_dict, we map them to 200~788
31
+ ch_phone_idx = (phone >= 3) & (phone <= 100)
32
+ merged_phone[ch_phone_idx] = (merged_phone[ch_phone_idx] - 3) * 6 + 200 + tone_tmp[ch_phone_idx]
33
+
34
+ if pad_bos_eos:
35
+ merged_phone = F.pad(merged_phone, (1, 0), mode='constant', value=798)
36
+ merged_phone = F.pad(merged_phone, (0, 1), mode='constant', value=799)
37
+ return merged_phone
38
+
39
+ def split_ph_timestamp(ph_timestamp):
40
+ ''' Input: ph_timestamp, shape [T] '''
41
+
42
+ # Map the timestamp of each phone back to its original frame-level lengths
43
+ ph_timestamp[ph_timestamp >= 800] -= 800
44
+
45
+ ph_list = []
46
+ tone_list = []
47
+ dur_list = []
48
+ cur_timestamp = 0
49
+ for idx, item in enumerate(ph_timestamp):
50
+ if idx % 2 == 0:
51
+ # Map Chinese phones back to its original phone_dict
52
+ if (200 <= item <= 788):
53
+ ph = (item - 200 - 1) // 6 + 3
54
+ tone = (item - 200 - 1) % 6 + 1
55
+ if tone == 1:
56
+ tone = 4
57
+ else:
58
+ tone = tone + 9
59
+ # Set English tone to '3'
60
+ else:
61
+ ph = item
62
+ tone = 3
63
+ ph_list.append(ph)
64
+ tone_list.append(tone)
65
+ else:
66
+ dur_list.append((item - cur_timestamp))
67
+ cur_timestamp = item
68
+ assert len(ph_list) == len(dur_list), f"{len(ph_list)}, {len(dur_list)}"
69
+ ph_seq, tone_seq, dur_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list), torch.LongTensor(dur_list)
70
+ return ph_seq, tone_seq, dur_seq, ph_timestamp[-1]
71
+
72
+ def split_ph(ph_seq):
73
+ ''' Input: ph_timestamp, shape [T] '''
74
+ ph_list = []
75
+ tone_list = []
76
+ for idx, item in enumerate(ph_seq):
77
+ # Map Chinese phones back to its original phone_dict
78
+ if (200 <= item <= 788):
79
+ ph = (item - 200 - 1) // 6 + 3
80
+ tone = (item - 200 - 1) % 6 + 1
81
+ if tone == 1:
82
+ tone = 4
83
+ else:
84
+ tone = tone + 9
85
+ # Set English tone to '3'
86
+ else:
87
+ ph = item
88
+ tone = 3
89
+ ph_list.append(ph)
90
+ tone_list.append(tone)
91
+
92
+ assert len(ph_list) == len(tone_list)
93
+ ph_seq, tone_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list)
94
+ return ph_seq, tone_seq
tts/utils/text_utils/split_text.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+
17
+ def chunk_text_chinese(text, limit=60):
18
+ # 中文字符匹配
19
+ chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
20
+ # 标点符号匹配
21
+ punctuation = ",。!?;:,\.!?;"
22
+
23
+ result = [] # 存储断句结果
24
+ current_chunk = [] # 当前片段
25
+ chinese_count = 0 # 中文字符计数
26
+
27
+ i = 0
28
+ while i < len(text):
29
+ char = text[i]
30
+ current_chunk.append(char)
31
+ if chinese_pattern.match(char):
32
+ chinese_count += 1
33
+
34
+ if chinese_count >= limit: # 达到限制字符数
35
+ # 从当前位置往前找最近的标点符号
36
+ for j in range(len(current_chunk) - 1, -1, -1):
37
+ if current_chunk[j] in punctuation:
38
+ result.append(''.join(current_chunk[:j + 1]))
39
+ current_chunk = current_chunk[j + 1:]
40
+ chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
41
+ break
42
+ else:
43
+ # 如果前面没有标点符号,则继续找后面的标点符号
44
+ for k in range(i + 1, len(text)):
45
+ if text[k] in punctuation:
46
+ result.append(''.join(current_chunk)+text[i+1:k+1])
47
+ current_chunk = []
48
+ chinese_count = 0
49
+ i = k
50
+ break
51
+ i+=1
52
+
53
+ # 添加最后剩余的部分
54
+ if current_chunk:
55
+ result.append(''.join(current_chunk))
56
+
57
+ return result
58
+
59
+ def chunk_text_english(text, max_chars=130):
60
+ """
61
+ Splits the input text into chunks, each with a maximum number of characters.
62
+
63
+ Args:
64
+ text (str): The text to be split.
65
+ max_chars (int): The maximum number of characters per chunk.
66
+
67
+ Returns:
68
+ List[str]: A list of text chunks.
69
+ """
70
+ chunks = []
71
+ current_chunk = ""
72
+ # Split the text into sentences based on punctuation followed by whitespace
73
+ sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
74
+
75
+ for sentence in sentences:
76
+ if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
77
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
78
+ else:
79
+ if current_chunk:
80
+ chunks.append(current_chunk.strip())
81
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
82
+
83
+ if current_chunk:
84
+ chunks.append(current_chunk.strip())
85
+
86
+ return chunks
87
+
88
+ if __name__ == '__main__':
89
+ print(chunk_text_chinese("哇塞!家人们,你们太好运了。我居然发现了一个宝藏零食大礼包,简直适合所有人的口味!有香辣的,让你舌尖跳舞;有盐焗的,咸香可口;还有五香的,香气四溢。就连怀孕的姐妹都吃得津津有味!整整三十包啊!什么手撕蟹柳、辣子鸡、嫩豆干、手撕素肉、鹌鹑蛋、小肉枣肠、猪肉腐、魔芋、魔芋丝等等,应有尽有。香辣土豆爽辣过瘾,各种素肉嚼劲十足,鹌鹑蛋营养美味,真的太多太多啦,...家人们,现在价格太划算了,赶紧下单。"))
90
+ print(chunk_text_english("Washington CNN When President Donald Trump declared in the House Chamber this week that executives at the nation’s top automakers were “so excited” about their prospects amid his new tariff regime, it did not entirely reflect the conversation he’d held with them earlier that day."))
tts/utils/text_utils/text_encoder.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import re
17
+ import six
18
+ from six.moves import range # pylint: disable=redefined-builtin
19
+
20
+ PAD = "<pad>"
21
+ EOS = "<EOS>"
22
+ UNK = "<UNK>"
23
+ SEG = "|"
24
+ PUNCS = '!,.?;:'
25
+ RESERVED_TOKENS = [PAD, EOS, UNK]
26
+ NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
27
+ PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0
28
+ EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1
29
+ UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2
30
+
31
+ if six.PY2:
32
+ RESERVED_TOKENS_BYTES = RESERVED_TOKENS
33
+ else:
34
+ RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
35
+
36
+ # Regular expression for unescaping token strings.
37
+ # '\u' is converted to '_'
38
+ # '\\' is converted to '\'
39
+ # '\213;' is converted to unichr(213)
40
+ _UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
41
+ _ESCAPE_CHARS = set(u"\\_u;0123456789")
42
+
43
+
44
+ def strip_ids(ids, ids_to_strip):
45
+ """Strip ids_to_strip from the end ids."""
46
+ ids = list(ids)
47
+ while ids and ids[-1] in ids_to_strip:
48
+ ids.pop()
49
+ return ids
50
+
51
+
52
+ class TextEncoder(object):
53
+ """Base class for converting from ints to/from human readable strings."""
54
+
55
+ def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
56
+ self._num_reserved_ids = num_reserved_ids
57
+
58
+ @property
59
+ def num_reserved_ids(self):
60
+ return self._num_reserved_ids
61
+
62
+ def encode(self, s):
63
+ """Transform a human-readable string into a sequence of int ids.
64
+
65
+ The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
66
+ num_reserved_ids) are reserved.
67
+
68
+ EOS is not appended.
69
+
70
+ Args:
71
+ s: human-readable string to be converted.
72
+
73
+ Returns:
74
+ ids: list of integers
75
+ """
76
+ return [int(w) + self._num_reserved_ids for w in s.split()]
77
+
78
+ def decode(self, ids, strip_extraneous=False):
79
+ """Transform a sequence of int ids into a human-readable string.
80
+
81
+ EOS is not expected in ids.
82
+
83
+ Args:
84
+ ids: list of integers to be converted.
85
+ strip_extraneous: bool, whether to strip off extraneous tokens
86
+ (EOS and PAD).
87
+
88
+ Returns:
89
+ s: human-readable string.
90
+ """
91
+ if strip_extraneous:
92
+ ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
93
+ return " ".join(self.decode_list(ids))
94
+
95
+ def decode_list(self, ids):
96
+ """Transform a sequence of int ids into a their string versions.
97
+
98
+ This method supports transforming individual input/output ids to their
99
+ string versions so that sequence to/from text conversions can be visualized
100
+ in a human readable format.
101
+
102
+ Args:
103
+ ids: list of integers to be converted.
104
+
105
+ Returns:
106
+ strs: list of human-readable string.
107
+ """
108
+ decoded_ids = []
109
+ for id_ in ids:
110
+ if 0 <= id_ < self._num_reserved_ids:
111
+ decoded_ids.append(RESERVED_TOKENS[int(id_)])
112
+ else:
113
+ decoded_ids.append(id_ - self._num_reserved_ids)
114
+ return [str(d) for d in decoded_ids]
115
+
116
+ @property
117
+ def vocab_size(self):
118
+ raise NotImplementedError()
119
+
120
+
121
+ class TokenTextEncoder(TextEncoder):
122
+ """Encoder based on a user-supplied vocabulary (file or list)."""
123
+
124
+ def __init__(self,
125
+ vocab_filename,
126
+ reverse=False,
127
+ vocab_list=None,
128
+ replace_oov=None,
129
+ num_reserved_ids=NUM_RESERVED_TOKENS):
130
+ """Initialize from a file or list, one token per line.
131
+
132
+ Handling of reserved tokens works as follows:
133
+ - When initializing from a list, we add reserved tokens to the vocab.
134
+ - When initializing from a file, we do not add reserved tokens to the vocab.
135
+ - When saving vocab files, we save reserved tokens to the file.
136
+
137
+ Args:
138
+ vocab_filename: If not None, the full filename to read vocab from. If this
139
+ is not None, then vocab_list should be None.
140
+ reverse: Boolean indicating if tokens should be reversed during encoding
141
+ and decoding.
142
+ vocab_list: If not None, a list of elements of the vocabulary. If this is
143
+ not None, then vocab_filename should be None.
144
+ replace_oov: If not None, every out-of-vocabulary token seen when
145
+ encoding will be replaced by this string (which must be in vocab).
146
+ num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
147
+ """
148
+ super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
149
+ self._reverse = reverse
150
+ self._replace_oov = replace_oov
151
+ if vocab_filename:
152
+ self._init_vocab_from_file(vocab_filename)
153
+ else:
154
+ assert vocab_list is not None
155
+ self._init_vocab_from_list(vocab_list)
156
+ self.pad_index = self.token_to_id[PAD]
157
+ self.eos_index = self.token_to_id[EOS]
158
+ self.unk_index = self.token_to_id[UNK]
159
+ self.seg_index = self.token_to_id[SEG] if SEG in self.token_to_id else self.eos_index
160
+
161
+ def encode(self, s):
162
+ """Converts a space-separated string of tokens to a list of ids."""
163
+ if isinstance(s, str):
164
+ sentence = s
165
+ tokens = sentence.strip().split()
166
+ else:
167
+ tokens = s
168
+ if self._replace_oov is not None:
169
+ tokens = [t if t in self.token_to_id else self._replace_oov
170
+ for t in tokens]
171
+ ret = [self.token_to_id[tok] for tok in tokens]
172
+ return ret[::-1] if self._reverse else ret
173
+
174
+ def decode(self, ids, strip_eos=False, strip_padding=False):
175
+ if strip_padding and self.pad() in list(ids):
176
+ pad_pos = list(ids).index(self.pad())
177
+ ids = ids[:pad_pos]
178
+ if strip_eos and self.eos() in list(ids):
179
+ eos_pos = list(ids).index(self.eos())
180
+ ids = ids[:eos_pos]
181
+ return " ".join(self.decode_list(ids))
182
+
183
+ def decode_list(self, ids):
184
+ seq = reversed(ids) if self._reverse else ids
185
+ return [self._safe_id_to_token(i) for i in seq]
186
+
187
+ @property
188
+ def vocab_size(self):
189
+ return len(self.id_to_token)
190
+
191
+ def __len__(self):
192
+ return self.vocab_size
193
+
194
+ def _safe_id_to_token(self, idx):
195
+ return self.id_to_token.get(idx, "ID_%d" % idx)
196
+
197
+ def _init_vocab_from_file(self, filename):
198
+ """Load vocab from a file.
199
+
200
+ Args:
201
+ filename: The file to load vocabulary from.
202
+ """
203
+ with open(filename) as f:
204
+ tokens = [token.strip() for token in f.readlines()]
205
+
206
+ def token_gen():
207
+ for token in tokens:
208
+ yield token
209
+
210
+ self._init_vocab(token_gen(), add_reserved_tokens=False)
211
+
212
+ def _init_vocab_from_list(self, vocab_list):
213
+ """Initialize tokens from a list of tokens.
214
+
215
+ It is ok if reserved tokens appear in the vocab list. They will be
216
+ removed. The set of tokens in vocab_list should be unique.
217
+
218
+ Args:
219
+ vocab_list: A list of tokens.
220
+ """
221
+
222
+ def token_gen():
223
+ for token in vocab_list:
224
+ if token not in RESERVED_TOKENS:
225
+ yield token
226
+
227
+ self._init_vocab(token_gen())
228
+
229
+ def _init_vocab(self, token_generator, add_reserved_tokens=True):
230
+ """Initialize vocabulary with tokens from token_generator."""
231
+
232
+ self.id_to_token = {}
233
+ non_reserved_start_index = 0
234
+
235
+ if add_reserved_tokens:
236
+ self.id_to_token.update(enumerate(RESERVED_TOKENS))
237
+ non_reserved_start_index = len(RESERVED_TOKENS)
238
+
239
+ self.id_to_token.update(
240
+ enumerate(token_generator, start=non_reserved_start_index))
241
+
242
+ # _token_to_id is the reverse of _id_to_token
243
+ self.token_to_id = dict((v, k) for k, v in six.iteritems(self.id_to_token))
244
+
245
+ def pad(self):
246
+ return self.pad_index
247
+
248
+ def eos(self):
249
+ return self.eos_index
250
+
251
+ def unk(self):
252
+ return self.unk_index
253
+
254
+ def seg(self):
255
+ return self.seg_index
256
+
257
+ def store_to_file(self, filename):
258
+ """Write vocab file to disk.
259
+
260
+ Vocab files have one token per line. The file ends in a newline. Reserved
261
+ tokens are written to the vocab file as well.
262
+
263
+ Args:
264
+ filename: Full path of the file to store the vocab to.
265
+ """
266
+ with open(filename, "w") as f:
267
+ for i in range(len(self.id_to_token)):
268
+ f.write(self.id_to_token[i] + "\n")
269
+
270
+ def sil_phonemes(self):
271
+ return [p for p in self.id_to_token.values() if is_sil_phoneme(p)]
272
+
273
+
274
+ def build_token_encoder(token_list_file):
275
+ token_list = json.load(open(token_list_file))
276
+ return TokenTextEncoder(None, vocab_list=token_list, replace_oov='<UNK>')
277
+
278
+
279
+ def is_sil_phoneme(p):
280
+ return p == '' or not p[0].isalpha() or p == 'sil' or p == 'sp' or p == 'XX'