Styner1 committed verified commit 20a5020 (parent: 0d2d6b2)

Upload 27 files


Podcaster: a dynamic site for TTS podcast creation, built on Zonos-v0.1, a leading open-weight text-to-speech model trained on more than 200k hours of varied multilingual speech and delivering high expressiveness and quality.

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ Zonos-main/assets/ArchitectureDiagram.png filter=lfs diff=lfs merge=lfs -text
+ Zonos-main/assets/exampleaudio.mp3 filter=lfs diff=lfs merge=lfs -text
Zonos-main/.DS_Store ADDED
Binary file (6.15 kB).
 
Zonos-main/.gitignore ADDED
@@ -0,0 +1,13 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ # Misc.
+ .ipynb_checkpoints/
Zonos-main/.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
Zonos-main/CONDITIONING_README.md ADDED
@@ -0,0 +1,120 @@
+ # Conditioning explanations
+ This document lists all the conditionings the model accepts, along with a short description and some tips for optimal use. Conditionings with a learned unconditional can be set to that unconditional value to let the model infer an appropriate setting on its own.
+ ### espeak
+ - **Type:** `EspeakPhonemeConditioner`
+ - **Description:**
+ Responsible for cleaning, phonemizing, tokenizing, and embedding the text provided to the model. This is the text pre-processing pipeline. If you would like to change how a word is pronounced, or to enter raw phonemes, you can do that here.
+
+ Supported by transformer and hybrid models.
+ ---
+ ### speaker
+ - **Type:** `PassthroughConditioner`
+ - **Attributes:**
+   - **cond_dim:** `128`
+   - **uncond_type:** `learned`
+   - **projection:** `linear`
+ - **Description:**
+ An embedded representation of the speaker's voice. We use [these](https://huggingface.co/Zyphra/Zonos-v0.1-speaker-embedding) speaker embedding models. It can capture a surprising amount of detail from the reference clip and supports arbitrary-length input. Try to input clean reference clips containing only speech. Concatenating multiple clean samples from the same speaker into one long sample is valid and may lead to better cloning. If the speaker clip is very long, it is advisable to cut out long speech-free background-music segments if they exist. If the reference clip yields noisy outputs even with denoising enabled, we recommend doing source separation before cloning.
+
+ Supported by transformer and hybrid models.
+ ---
+ ### emotion
+ - **Type:** `FourierConditioner`
+ - **Attributes:**
+   - **input_dim:** `8`
+   - **uncond_type:** `learned`
+ - **Description:**
+ Encodes emotion in an 8D vector. The included emotions are Happiness, Sadness, Disgust, Fear, Surprise, Anger, Other, and Neutral, in that order. This vector tends to be entangled with various other conditioning inputs. Most notably, it is entangled with the sentiment of the text (e.g. angry text is conditioned to sound angry quite effectively, but trying to make the same text sound sad is much less effective). It is also entangled with pitch standard deviation, since larger values there tend to correlate with more emotional utterances, and it is heavily correlated with VQScore and DNSMOS, as those conditionings favor neutral speech. Finally, a form of "negative prompting" is possible by running CFG with the unconditional branch set to a highly neutral emotion vector instead of the true unconditional value; doing this exaggerates the emotions because it pushes the model away from neutrality.
+
+ Supported by transformer and hybrid models.
+ ---
+ ### fmax
+ - **Type:** `FourierConditioner`
+ - **Attributes:**
+   - **min_val:** `0`
+   - **max_val:** `24000`
+   - **uncond_type:** `learned`
+ - **Description:**
+ Specifies the maximum frequency of the audio. For best results select 22050 or 24000, as these correspond to 44.1 kHz and 48 kHz audio respectively. They should not differ in actual maximum frequency, since the model's sampling rate is 44.1 kHz, but they represent different slices of the training data, which leads to slightly different voicing. Selecting a lower value generally produces lower-quality results, both acoustically and in voicing.
+
+ For voice cloning it is recommended to use 22050.
+
+ Supported by transformer and hybrid models.
+ ---
+ ### pitch_std
+ - **Type:** `FourierConditioner`
+ - **Attributes:**
+   - **min_val:** `0`
+   - **max_val:** `400`
+   - **uncond_type:** `learned`
+ - **Description:**
+ Specifies the standard deviation of the pitch of the output audio. Wider pitch variation tends to correlate with more expressive speech. Good values are roughly 20-45 for normal speech and 60-150 for expressive speech. Values higher than that generally produce increasingly erratic samples.
+
+ Supported by transformer and hybrid models.
+ ---
+ ### speaking_rate
+ - **Type:** `FourierConditioner`
+ - **Attributes:**
+   - **min_val:** `0`
+   - **max_val:** `40`
+   - **uncond_type:** `learned`
+ - **Description:**
+ Specifies the number of phonemes to be read per second. When entering a long text, adjust the speaking rate so that the number of phonemes is readable within the generation length. For example, if your generation length is 10 seconds and your input is 300 phonemes, you would need 30 phonemes per second (which is very fast), or you should generate a longer sample. The model's maximum generation length is 30 seconds. Note that unrealistic speaking rates can be out of distribution for the model and create undesirable effects, so at the 30-second limit it can be better to cut the text short and do multiple generations than to feed the model the entire prompt with an unrealistically low speaking rate.
+
+ Supported by transformer and hybrid models.
+ ---
+ ### language_id
+ - **Type:** `IntegerConditioner`
+ - **Attributes:**
+   - **min_val:** `-1`
+   - **max_val:** `126`
+   - **uncond_type:** `learned`
+ - **Description:**
+ Indicates which language the output should be in. A mapping for these values can be found in the [conditioning section](https://github.com/Zyphra/Zonos/blob/3807c8e04bd4beaadb9502b3df1ffa4b0350e3f7/zonos/conditioning.py#L308C1-L376C21) of Zonos.
+
+ Supported by transformer and hybrid models.
+ ---
+ ### vqscore_8
+ - **Type:** `FourierConditioner`
+ - **Attributes:**
+   - **input_dim:** `8`
+   - **min_val:** `0.5`
+   - **max_val:** `0.8`
+   - **uncond_type:** `learned`
+ - **Description:**
+ Encodes the desired [VQScore](https://github.com/JasonSWFu/VQscore) value for the output audio. VQScore is an unsupervised speech-quality (cleanliness) estimation method that we found to have superior generalization and reduced bias compared to supervised methods like DNSMOS. A good value for our model is 0.78 for high-quality speech. The eight dimensions correspond to consecutive 1/8th chunks of the audio (e.g. for an 8-second output, the first dimension represents the quality of the first second only). For inference we generally set all 8 dimensions to the same value. VQScore has an unfortunately strong correlation with expressiveness, so for expressive speech we recommend setting it to unconditional.
+
+ Only applicable for the hybrid model.
+ ---
+ ### ctc_loss
+ - **Type:** `FourierConditioner`
+ - **Attributes:**
+   - **min_val:** `-1.0`
+   - **max_val:** `1000`
+   - **uncond_type:** `learned`
+ - **Description:**
+ Encodes loss values from a [CTC](https://en.wikipedia.org/wiki/Connectionist_temporal_classification) (Connectionist Temporal Classification) setup; this indicates how well the training-time transcription matched the audio according to a CTC model. For inference, always use low values (e.g. 0.0 or 1.0).
+
+ Only applicable for the hybrid model.
+ ---
+ ### dnsmos_ovrl
+ - **Type:** `FourierConditioner`
+ - **Attributes:**
+   - **min_val:** `1`
+   - **max_val:** `5`
+   - **uncond_type:** `learned`
+ - **Description:**
+ A [MOS](https://arxiv.org/abs/2110.01763) score for the output audio. This is similar to VQScore and tends to be more strongly entangled with emotions. It additionally has a strong entanglement with languages. Set it to 4.0 for very clean and neutral English speech; otherwise we recommend setting it to unconditional.
+
+ Only applicable for the hybrid model.
+ ---
+ ### speaker_noised
+ - **Type:** `IntegerConditioner`
+ - **Attributes:**
+   - **min_val:** `0`
+   - **max_val:** `1`
+   - **uncond_type:** `learned`
+ - **Description:**
+ Indicates whether the speaker embedding is noisy. If enabled, this lets the model clean (denoise) the input speaker embedding. When set to True, VQScore and DNSMOS have much more power to clean the speaker embedding, so for very noisy input samples we recommend setting this to True and specifying a high VQScore value. If your speaker-cloning outputs sound echo-y or do weird things, setting this to True will help.
+
+ Only applicable for the hybrid model.
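The conditioning values described above are passed to the model through `make_cond_dict`, the same helper used by the repository's `sample.py` and `gradio_interface.py`. The sketch below is illustrative only: it assumes `model` and `speaker` have already been created as in `sample.py`, and the specific numbers (emotion weights, `pitch_std`, `speaking_rate`, `fmax`) are example settings picked from the ranges recommended above, not canonical values.

```python
import torch
from zonos.conditioning import make_cond_dict

# 8-D emotion vector, in the documented order:
# happiness, sadness, disgust, fear, surprise, anger, other, neutral
emotion = torch.tensor([0.05, 0.05, 0.05, 0.05, 0.05, 0.8, 0.1, 0.2])

# Rough speaking-rate check: ~300 phonemes at 15 phonemes/s is ~20 s of audio,
# which fits inside the model's 30-second generation limit.
cond_dict = make_cond_dict(
    text="Welcome back to the show!",   # hypothetical podcast line
    language="en-us",
    speaker=speaker,                    # from model.make_speaker_embedding(...)
    emotion=emotion,
    pitch_std=100.0,                    # 60-150 suggested for expressive speech
    speaking_rate=15.0,
    fmax=22050.0,                       # recommended for voice cloning
    unconditional_keys=["vqscore_8", "dnsmos_ovrl"],  # let the model infer these
)
conditioning = model.prepare_conditioning(cond_dict)
codes = model.generate(conditioning)
```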
Zonos-main/Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
+ RUN pip install uv
+
+ RUN apt update && \
+     apt install -y espeak-ng && \
+     rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+ COPY . ./
+
+ RUN uv pip install --system -e . && uv pip install --system -e .[compile]
Zonos-main/LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
Zonos-main/README.md ADDED
@@ -0,0 +1,159 @@
+ # Zonos-v0.1
+
+ <div align="center">
+ <img src="assets/ZonosHeader.png"
+      alt="Zonos header"
+      style="width: 500px;
+             height: auto;
+             object-position: center top;">
+ </div>
+
+ <div align="center">
+ <a href="https://discord.gg/gTW9JwST8q" target="_blank">
+ <img src="https://img.shields.io/badge/Join%20Our%20Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white" alt="Discord">
+ </a>
+ </div>
+
+ ---
+
+ Zonos-v0.1 is a leading open-weight text-to-speech model trained on more than 200k hours of varied multilingual speech, delivering expressiveness and quality on par with—or even surpassing—top TTS providers.
+
+ Our model enables highly natural speech generation from text prompts when given a speaker embedding or audio prefix, and can accurately perform speech cloning when given a reference clip spanning just a few seconds. The conditioning setup also allows for fine control over speaking rate, pitch variation, audio quality, and emotions such as happiness, fear, sadness, and anger. The model outputs speech natively at 44 kHz.
+
+ ##### For more details and speech samples, check out our blog [here](https://www.zyphra.com/post/beta-release-of-zonos-v0-1)
+
+ ##### We also have a hosted version available at [playground.zyphra.com/audio](https://playground.zyphra.com/audio)
+
+ ---
+
+ Zonos follows a straightforward architecture: text normalization and phonemization via eSpeak, followed by DAC token prediction through a transformer or hybrid backbone. An overview of the architecture can be seen below.
+
+ <div align="center">
+ <img src="assets/ArchitectureDiagram.png"
+      alt="Zonos architecture diagram"
+      style="width: 1000px;
+             height: auto;
+             object-position: center top;">
+ </div>
+
+ ---
+
+ ## Usage
+
+ ### Python
+
+ ```python
+ import torch
+ import torchaudio
+ from zonos.model import Zonos
+ from zonos.conditioning import make_cond_dict
+ from zonos.utils import DEFAULT_DEVICE as device
+
+ # model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device=device)
+ model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
+
+ wav, sampling_rate = torchaudio.load("assets/exampleaudio.mp3")
+ speaker = model.make_speaker_embedding(wav, sampling_rate)
+
+ cond_dict = make_cond_dict(text="Hello, world!", speaker=speaker, language="en-us")
+ conditioning = model.prepare_conditioning(cond_dict)
+
+ codes = model.generate(conditioning)
+
+ wavs = model.autoencoder.decode(codes).cpu()
+ torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)
+ ```
+
+ This should produce a `sample.wav` file in your project root directory.
+
+ _For repeated sampling we highly recommend using the Gradio interface instead, as the minimal example needs to load the model every time it is run._
+
+ ### Gradio interface (recommended)
+
+ ```bash
+ uv run gradio_interface.py
+ # python gradio_interface.py
+ ```
+
+ ## Features
+
+ - Zero-shot TTS with voice cloning: input the desired text and a 10-30s speaker sample to generate high-quality TTS output
+ - Audio prefix inputs: add text plus an audio prefix for even richer speaker matching. Audio prefixes can be used to elicit behaviours such as whispering, which can otherwise be challenging to replicate when cloning from speaker embeddings
+ - Multilingual support: Zonos-v0.1 supports English, Japanese, Chinese, French, and German
+ - Audio quality and emotion control: Zonos offers fine-grained control over many aspects of the generated audio, including speaking rate, pitch, maximum frequency, audio quality, and various emotions such as happiness, anger, sadness, and fear
+ - Fast: the model runs with a real-time factor of ~2x on an RTX 4090 (i.e. it generates 2 seconds of audio per 1 second of compute time)
+ - Gradio WebUI: Zonos comes packaged with an easy-to-use Gradio interface for generating speech
+ - Simple installation and deployment: Zonos can be installed and deployed simply using the Dockerfile packaged with our repository
+
+ ## Installation
+
+ #### System requirements
+
+ - **Operating System:** Linux (preferably Ubuntu 22.04/24.04), macOS
+ - **GPU:** 6GB+ VRAM; the hybrid model additionally requires a 3000-series or newer Nvidia GPU
+
+ Note: Zonos can also run on CPU provided there is enough free RAM. However, this will be a lot slower than running on a dedicated GPU, and likely won't be fast enough for interactive use.
+
+ For experimental Windows support check out [this fork](https://github.com/sdbds/Zonos-for-windows).
+
+ See also [Docker Installation](#docker-installation).
+
+ #### System dependencies
+
+ Zonos depends on the eSpeak library for phonemization. You can install it with the following command:
+
+ ```bash
+ apt install -y espeak-ng # For Ubuntu
+ # brew install espeak-ng # For macOS
+ ```
+
+ #### Python dependencies
+
+ We highly recommend using a recent version of [uv](https://docs.astral.sh/uv/#installation) for installation. If you don't have uv installed, you can install it via pip: `pip install -U uv`.
+
+ ##### Installing into a new uv virtual environment (recommended)
+
+ ```bash
+ uv sync
+ uv sync --extra compile # optional but needed to run the hybrid
+ uv pip install -e .
+ ```
+
+ ##### Installing into the system/activated environment using uv
+
+ ```bash
+ uv pip install -e .
+ uv pip install -e .[compile] # optional but needed to run the hybrid
+ ```
+
+ ##### Installing into the system/activated environment using pip
+
+ ```bash
+ pip install -e .
+ pip install --no-build-isolation -e .[compile] # optional but needed to run the hybrid
+ ```
+
+ ##### Confirm that it's working
+
+ For convenience we provide a minimal example to check that the installation works:
+
+ ```bash
+ uv run sample.py
+ # python sample.py
+ ```
+
+ ## Docker installation
+
+ ```bash
+ git clone https://github.com/Zyphra/Zonos.git
+ cd Zonos
+
+ # For gradio
+ docker compose up
+
+ # Or for development you can do
+ docker build -t zonos .
+ docker run -it --gpus=all --net=host -v /path/to/Zonos:/Zonos -t zonos
+ cd /Zonos
+ python sample.py # this will generate a sample.wav in /Zonos
+ ```
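Since this Space uses Zonos for podcast creation, a natural extension of the minimal example above is generating several dialogue segments and stitching them together into one episode. The sketch below is an assumption-laden illustration built only from the API calls shown above (`make_cond_dict`, `prepare_conditioning`, `generate`, `autoencoder.decode`); the script lines are hypothetical placeholders, and a real two-host setup would use a different reference clip per speaker.

```python
import torch
import torchaudio
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
from zonos.utils import DEFAULT_DEVICE as device

model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)

# Hypothetical script: (reference clip for the voice, line to speak)
script = [
    ("assets/exampleaudio.mp3", "Welcome to the show. Today we look at open-weight TTS."),
    ("assets/exampleaudio.mp3", "Thanks! Let's start with how speaker cloning works."),
]

segments = []
for ref_clip, line in script:
    wav, sr = torchaudio.load(ref_clip)
    speaker = model.make_speaker_embedding(wav, sr)
    cond = model.prepare_conditioning(make_cond_dict(text=line, speaker=speaker, language="en-us"))
    codes = model.generate(cond)
    segments.append(model.autoencoder.decode(codes).cpu()[0])  # (1, samples)

episode = torch.cat(segments, dim=-1)  # concatenate segments along the time axis
torchaudio.save("episode.wav", episode, model.autoencoder.sampling_rate)
```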
Zonos-main/assets/ArchitectureDiagram.png ADDED

Git LFS Details

  • SHA256: 8946541130ba1eae6cc6d22e53c48ccedd39191c6ae72e65affaca9dec47b41f
  • Pointer size: 131 Bytes
  • Size of remote file: 751 kB
Zonos-main/assets/ZonosHeader.png ADDED
Zonos-main/assets/exampleaudio.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c0931115d43cd4e672dc137d2099c0b4e103d7207a2a42e957e41b2af30e4ae
+ size 819820
Zonos-main/assets/silence_100ms.wav ADDED
Binary file (9.43 kB).
 
Zonos-main/docker-compose.yml ADDED
@@ -0,0 +1,16 @@
+ version: '3.8'
+
+ services:
+   zonos:
+     build:
+       context: .
+       dockerfile: Dockerfile
+     container_name: zonos_container
+     runtime: nvidia
+     network_mode: "host"
+     stdin_open: true
+     tty: true
+     command: ["python3", "gradio_interface.py"]
+     environment:
+       - NVIDIA_VISIBLE_DEVICES=0
+       - GRADIO_SHARE=False
Zonos-main/gradio_interface.py ADDED
@@ -0,0 +1,419 @@
1
+ import torch
2
+ import torchaudio
3
+ import gradio as gr
4
+ from os import getenv
5
+
6
+ from zonos.model import Zonos, DEFAULT_BACKBONE_CLS as ZonosBackbone
7
+ from zonos.conditioning import make_cond_dict, supported_language_codes
8
+ from zonos.utils import DEFAULT_DEVICE as device
9
+
10
+ CURRENT_MODEL_TYPE = None
11
+ CURRENT_MODEL = None
12
+
13
+ SPEAKER_EMBEDDING = None
14
+ SPEAKER_AUDIO_PATH = None
15
+
16
+
17
+ def load_model_if_needed(model_choice: str):
18
+ global CURRENT_MODEL_TYPE, CURRENT_MODEL
19
+ if CURRENT_MODEL_TYPE != model_choice:
20
+ if CURRENT_MODEL is not None:
21
+ del CURRENT_MODEL
22
+ torch.cuda.empty_cache()
23
+ print(f"Loading {model_choice} model...")
24
+ CURRENT_MODEL = Zonos.from_pretrained(model_choice, device=device)
25
+ CURRENT_MODEL.requires_grad_(False).eval()
26
+ CURRENT_MODEL_TYPE = model_choice
27
+ print(f"{model_choice} model loaded successfully!")
28
+ return CURRENT_MODEL
29
+
30
+
31
+ def update_ui(model_choice):
32
+ """
33
+ Dynamically show/hide UI elements based on the model's conditioners.
34
+ We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
35
+ """
36
+ model = load_model_if_needed(model_choice)
37
+ cond_names = [c.name for c in model.prefix_conditioner.conditioners]
38
+ print("Conditioners in this model:", cond_names)
39
+
40
+ text_update = gr.update(visible=("espeak" in cond_names))
41
+ language_update = gr.update(visible=("espeak" in cond_names))
42
+ speaker_audio_update = gr.update(visible=("speaker" in cond_names))
43
+ prefix_audio_update = gr.update(visible=True)
44
+ emotion1_update = gr.update(visible=("emotion" in cond_names))
45
+ emotion2_update = gr.update(visible=("emotion" in cond_names))
46
+ emotion3_update = gr.update(visible=("emotion" in cond_names))
47
+ emotion4_update = gr.update(visible=("emotion" in cond_names))
48
+ emotion5_update = gr.update(visible=("emotion" in cond_names))
49
+ emotion6_update = gr.update(visible=("emotion" in cond_names))
50
+ emotion7_update = gr.update(visible=("emotion" in cond_names))
51
+ emotion8_update = gr.update(visible=("emotion" in cond_names))
52
+ vq_single_slider_update = gr.update(visible=("vqscore_8" in cond_names))
53
+ fmax_slider_update = gr.update(visible=("fmax" in cond_names))
54
+ pitch_std_slider_update = gr.update(visible=("pitch_std" in cond_names))
55
+ speaking_rate_slider_update = gr.update(visible=("speaking_rate" in cond_names))
56
+ dnsmos_slider_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
57
+ speaker_noised_checkbox_update = gr.update(visible=("speaker_noised" in cond_names))
58
+ unconditional_keys_update = gr.update(
59
+ choices=[name for name in cond_names if name not in ("espeak", "language_id")]
60
+ )
61
+
62
+ return (
63
+ text_update,
64
+ language_update,
65
+ speaker_audio_update,
66
+ prefix_audio_update,
67
+ emotion1_update,
68
+ emotion2_update,
69
+ emotion3_update,
70
+ emotion4_update,
71
+ emotion5_update,
72
+ emotion6_update,
73
+ emotion7_update,
74
+ emotion8_update,
75
+ vq_single_slider_update,
76
+ fmax_slider_update,
77
+ pitch_std_slider_update,
78
+ speaking_rate_slider_update,
79
+ dnsmos_slider_update,
80
+ speaker_noised_checkbox_update,
81
+ unconditional_keys_update,
82
+ )
83
+
84
+
85
+ def generate_audio(
86
+ model_choice,
87
+ text,
88
+ language,
89
+ speaker_audio,
90
+ prefix_audio,
91
+ e1,
92
+ e2,
93
+ e3,
94
+ e4,
95
+ e5,
96
+ e6,
97
+ e7,
98
+ e8,
99
+ vq_single,
100
+ fmax,
101
+ pitch_std,
102
+ speaking_rate,
103
+ dnsmos_ovrl,
104
+ speaker_noised,
105
+ cfg_scale,
106
+ top_p,
107
+ top_k,
108
+ min_p,
109
+ linear,
110
+ confidence,
111
+ quadratic,
112
+ seed,
113
+ randomize_seed,
114
+ unconditional_keys,
115
+ progress=gr.Progress(),
116
+ ):
117
+ """
118
+ Generates audio based on the provided UI parameters.
119
+ We do NOT use language_id or ctc_loss even if the model has them.
120
+ """
121
+ selected_model = load_model_if_needed(model_choice)
122
+
123
+ speaker_noised_bool = bool(speaker_noised)
124
+ fmax = float(fmax)
125
+ pitch_std = float(pitch_std)
126
+ speaking_rate = float(speaking_rate)
127
+ dnsmos_ovrl = float(dnsmos_ovrl)
128
+ cfg_scale = float(cfg_scale)
129
+ top_p = float(top_p)
130
+ top_k = int(top_k)
131
+ min_p = float(min_p)
132
+ linear = float(linear)
133
+ confidence = float(confidence)
134
+ quadratic = float(quadratic)
135
+ seed = int(seed)
136
+ max_new_tokens = 86 * 30
137
+
138
+ # This is a bit ew, but works for now.
139
+ global SPEAKER_AUDIO_PATH, SPEAKER_EMBEDDING
140
+
141
+ if randomize_seed:
142
+ seed = torch.randint(0, 2**32 - 1, (1,)).item()
143
+ torch.manual_seed(seed)
144
+
145
+ if speaker_audio is not None and "speaker" not in unconditional_keys:
146
+ if speaker_audio != SPEAKER_AUDIO_PATH:
147
+ print("Recomputed speaker embedding")
148
+ wav, sr = torchaudio.load(speaker_audio)
149
+ SPEAKER_EMBEDDING = selected_model.make_speaker_embedding(wav, sr)
150
+ SPEAKER_EMBEDDING = SPEAKER_EMBEDDING.to(device, dtype=torch.bfloat16)
151
+ SPEAKER_AUDIO_PATH = speaker_audio
152
+
153
+ audio_prefix_codes = None
154
+ if prefix_audio is not None:
155
+ wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
156
+ wav_prefix = wav_prefix.mean(0, keepdim=True)
157
+ wav_prefix = selected_model.autoencoder.preprocess(wav_prefix, sr_prefix)
158
+ wav_prefix = wav_prefix.to(device, dtype=torch.float32)
159
+ audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))
160
+
161
+ emotion_tensor = torch.tensor(list(map(float, [e1, e2, e3, e4, e5, e6, e7, e8])), device=device)
162
+
163
+ vq_val = float(vq_single)
164
+ vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)
165
+
166
+ cond_dict = make_cond_dict(
167
+ text=text,
168
+ language=language,
169
+ speaker=SPEAKER_EMBEDDING,
170
+ emotion=emotion_tensor,
171
+ vqscore_8=vq_tensor,
172
+ fmax=fmax,
173
+ pitch_std=pitch_std,
174
+ speaking_rate=speaking_rate,
175
+ dnsmos_ovrl=dnsmos_ovrl,
176
+ speaker_noised=speaker_noised_bool,
177
+ device=device,
178
+ unconditional_keys=unconditional_keys,
179
+ )
180
+ conditioning = selected_model.prepare_conditioning(cond_dict)
181
+
182
+ estimated_generation_duration = 30 * len(text) / 400
183
+ estimated_total_steps = int(estimated_generation_duration * 86)
184
+
185
+ def update_progress(_frame: torch.Tensor, step: int, _total_steps: int) -> bool:
186
+ progress((step, estimated_total_steps))
187
+ return True
188
+
189
+ codes = selected_model.generate(
190
+ prefix_conditioning=conditioning,
191
+ audio_prefix_codes=audio_prefix_codes,
192
+ max_new_tokens=max_new_tokens,
193
+ cfg_scale=cfg_scale,
194
+ batch_size=1,
195
+ sampling_params=dict(top_p=top_p, top_k=top_k, min_p=min_p, linear=linear, conf=confidence, quad=quadratic),
196
+ callback=update_progress,
197
+ )
198
+
199
+ wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
200
+ sr_out = selected_model.autoencoder.sampling_rate
201
+ if wav_out.dim() == 2 and wav_out.size(0) > 1:
202
+ wav_out = wav_out[0:1, :]
203
+ return (sr_out, wav_out.squeeze().numpy()), seed
204
+
205
+
206
+ def build_interface():
207
+ supported_models = []
208
+ if "transformer" in ZonosBackbone.supported_architectures:
209
+ supported_models.append("Zyphra/Zonos-v0.1-transformer")
210
+
211
+ if "hybrid" in ZonosBackbone.supported_architectures:
212
+ supported_models.append("Zyphra/Zonos-v0.1-hybrid")
213
+ else:
214
+ print(
215
+ "| The current ZonosBackbone does not support the hybrid architecture, meaning only the transformer model will be available in the model selector.\n"
216
+ "| This probably means the mamba-ssm library has not been installed."
217
+ )
218
+
219
+ with gr.Blocks() as demo:
220
+ with gr.Row():
221
+ with gr.Column():
222
+ model_choice = gr.Dropdown(
223
+ choices=supported_models,
224
+ value=supported_models[0],
225
+ label="Zonos Model Type",
226
+ info="Select the model variant to use.",
227
+ )
228
+ text = gr.Textbox(
229
+ label="Text to Synthesize",
230
+ value="Zonos uses eSpeak for text to phoneme conversion!",
231
+ lines=4,
232
+ max_length=500, # approximately
233
+ )
234
+ language = gr.Dropdown(
235
+ choices=supported_language_codes,
236
+ value="en-us",
237
+ label="Language Code",
238
+ info="Select a language code.",
239
+ )
240
+ prefix_audio = gr.Audio(
241
+ value="assets/silence_100ms.wav",
242
+ label="Optional Prefix Audio (continue from this audio)",
243
+ type="filepath",
244
+ )
245
+ with gr.Column():
246
+ speaker_audio = gr.Audio(
247
+ label="Optional Speaker Audio (for cloning)",
248
+ type="filepath",
249
+ )
250
+ speaker_noised_checkbox = gr.Checkbox(label="Denoise Speaker?", value=False)
251
+
252
+ with gr.Row():
253
+ with gr.Column():
254
+ gr.Markdown("## Conditioning Parameters")
255
+ dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="DNSMOS Overall")
256
+ fmax_slider = gr.Slider(0, 24000, value=24000, step=1, label="Fmax (Hz)")
257
+ vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="VQ Score")
258
+ pitch_std_slider = gr.Slider(0.0, 300.0, value=45.0, step=1, label="Pitch Std")
259
+ speaking_rate_slider = gr.Slider(5.0, 30.0, value=15.0, step=0.5, label="Speaking Rate")
260
+
261
+ with gr.Column():
262
+ gr.Markdown("## Generation Parameters")
263
+ cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
264
+ seed_number = gr.Number(label="Seed", value=420, precision=0)
265
+ randomize_seed_toggle = gr.Checkbox(label="Randomize Seed (before generation)", value=True)
266
+
267
+ with gr.Accordion("Sampling", open=False):
268
+ with gr.Row():
269
+ with gr.Column():
270
+ gr.Markdown("### NovelAi's unified sampler")
271
+ linear_slider = gr.Slider(-2.0, 2.0, 0.5, 0.01, label="Linear (set to 0 to disable unified sampling)", info="High values make the output less random.")
272
+ #Conf's theoretical range is between -2 * Quad and 0.
273
+ confidence_slider = gr.Slider(-2.0, 2.0, 0.40, 0.01, label="Confidence", info="Low values make random outputs more random.")
274
+ quadratic_slider = gr.Slider(-2.0, 2.0, 0.00, 0.01, label="Quadratic", info="High values make low probablities much lower.")
275
+ with gr.Column():
276
+ gr.Markdown("### Legacy sampling")
277
+ top_p_slider = gr.Slider(0.0, 1.0, 0, 0.01, label="Top P")
278
+ min_k_slider = gr.Slider(0.0, 1024, 0, 1, label="Min K")
279
+ min_p_slider = gr.Slider(0.0, 1.0, 0, 0.01, label="Min P")
280
+
281
+ with gr.Accordion("Advanced Parameters", open=False):
282
+ gr.Markdown(
283
+ "### Unconditional Toggles\n"
284
+ "Checking a box will make the model ignore the corresponding conditioning value and make it unconditional.\n"
285
+ 'Practically this means the given conditioning feature will be unconstrained and "filled in automatically".'
286
+ )
287
+ with gr.Row():
288
+ unconditional_keys = gr.CheckboxGroup(
289
+ [
290
+ "speaker",
291
+ "emotion",
292
+ "vqscore_8",
293
+ "fmax",
294
+ "pitch_std",
295
+ "speaking_rate",
296
+ "dnsmos_ovrl",
297
+ "speaker_noised",
298
+ ],
299
+ value=["emotion"],
300
+ label="Unconditional Keys",
301
+ )
302
+
303
+ gr.Markdown(
304
+ "### Emotion Sliders\n"
305
+ "Warning: The way these sliders work is not intuitive and may require some trial and error to get the desired effect.\n"
306
+ "Certain configurations can cause the model to become unstable. Setting emotion to unconditional may help."
307
+ )
308
+ with gr.Row():
309
+ emotion1 = gr.Slider(0.0, 1.0, 1.0, 0.05, label="Happiness")
310
+ emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness")
311
+ emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust")
312
+ emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear")
313
+ with gr.Row():
314
+ emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise")
315
+ emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger")
316
+ emotion7 = gr.Slider(0.0, 1.0, 0.1, 0.05, label="Other")
317
+ emotion8 = gr.Slider(0.0, 1.0, 0.2, 0.05, label="Neutral")
318
+
319
+ with gr.Column():
320
+ generate_button = gr.Button("Generate Audio")
321
+ output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True)
322
+
323
+ model_choice.change(
324
+ fn=update_ui,
325
+ inputs=[model_choice],
326
+ outputs=[
327
+ text,
328
+ language,
329
+ speaker_audio,
330
+ prefix_audio,
331
+ emotion1,
332
+ emotion2,
333
+ emotion3,
334
+ emotion4,
335
+ emotion5,
336
+ emotion6,
337
+ emotion7,
338
+ emotion8,
339
+ vq_single_slider,
340
+ fmax_slider,
341
+ pitch_std_slider,
342
+ speaking_rate_slider,
343
+ dnsmos_slider,
344
+ speaker_noised_checkbox,
345
+ unconditional_keys,
346
+ ],
347
+ )
348
+
349
+ # On page load, trigger the same UI refresh
350
+ demo.load(
351
+ fn=update_ui,
352
+ inputs=[model_choice],
353
+ outputs=[
354
+ text,
355
+ language,
356
+ speaker_audio,
357
+ prefix_audio,
358
+ emotion1,
359
+ emotion2,
360
+ emotion3,
361
+ emotion4,
362
+ emotion5,
363
+ emotion6,
364
+ emotion7,
365
+ emotion8,
366
+ vq_single_slider,
367
+ fmax_slider,
368
+ pitch_std_slider,
369
+ speaking_rate_slider,
370
+ dnsmos_slider,
371
+ speaker_noised_checkbox,
372
+ unconditional_keys,
373
+ ],
374
+ )
375
+
376
+ # Generate audio on button click
377
+ generate_button.click(
378
+ fn=generate_audio,
379
+ inputs=[
380
+ model_choice,
381
+ text,
382
+ language,
383
+ speaker_audio,
384
+ prefix_audio,
385
+ emotion1,
386
+ emotion2,
387
+ emotion3,
388
+ emotion4,
389
+ emotion5,
390
+ emotion6,
391
+ emotion7,
392
+ emotion8,
393
+ vq_single_slider,
394
+ fmax_slider,
395
+ pitch_std_slider,
396
+ speaking_rate_slider,
397
+ dnsmos_slider,
398
+ speaker_noised_checkbox,
399
+ cfg_scale_slider,
400
+ top_p_slider,
401
+ min_k_slider,
402
+ min_p_slider,
403
+ linear_slider,
404
+ confidence_slider,
405
+ quadratic_slider,
406
+ seed_number,
407
+ randomize_seed_toggle,
408
+ unconditional_keys,
409
+ ],
410
+ outputs=[output_audio, seed_number],
411
+ )
412
+
413
+ return demo
414
+
415
+
416
+ if __name__ == "__main__":
417
+ demo = build_interface()
418
+ share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
419
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
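The interface reads `GRADIO_SHARE` from the environment in its `__main__` block, so a public share link can be requested without editing the script. A minimal sketch of launching it programmatically, assuming the project root is on `sys.path` so `gradio_interface` can be imported:

```python
import os

# Request a public Gradio share link; mirrors the env handling in gradio_interface.py.
os.environ["GRADIO_SHARE"] = "true"

from gradio_interface import build_interface  # assumes the project root is importable

demo = build_interface()
share = os.getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
```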
Zonos-main/pyproject.toml ADDED
@@ -0,0 +1,39 @@
+ [project]
+ name = "zonos"
+ version = "0.1.0"
+ description = "Text-to-speech by Zyphra"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "torch>=2.5.1",
+     "setuptools",
+     "packaging",
+     "inflect>=7.5.0",
+     "kanjize>=1.5.0",
+     "numpy>=2.2.2",
+     "phonemizer>=3.3.0",
+     "sudachidict-full>=20241021",
+     "sudachipy>=0.6.10",
+     "torchaudio>=2.5.1",
+     "transformers>=4.48.1",
+     "soundfile>=0.13.1",
+     "huggingface-hub>=0.28.1",
+     "gradio>=5.15.0",
+ ]
+
+ # These are technically optional, but mamba-ssm is required to run hybrid models.
+ [project.optional-dependencies]
+ compile = [
+     "flash-attn>=2.7.3",
+     "mamba-ssm>=2.2.4",
+     "causal-conv1d>=1.5.0.post8",
+ ]
+
+ [tool.setuptools.packages.find]
+ include = ["zonos"]
+
+ [tool.uv]
+ no-build-isolation-package = ["flash-attn", "mamba-ssm", "causal-conv1d"]
+
+ [tool.ruff]
+ line-length = 120
Zonos-main/sample.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+ import torchaudio
+ from zonos.model import Zonos
+ from zonos.conditioning import make_cond_dict
+ from zonos.utils import DEFAULT_DEVICE as device
+
+ # model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device=device)
+ model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
+
+ wav, sampling_rate = torchaudio.load("assets/exampleaudio.mp3")
+ speaker = model.make_speaker_embedding(wav, sampling_rate)
+
+ torch.manual_seed(421)
+
+ cond_dict = make_cond_dict(text="Hello, world!", speaker=speaker, language="en-us")
+ conditioning = model.prepare_conditioning(cond_dict)
+
+ codes = model.generate(conditioning)
+
+ wavs = model.autoencoder.decode(codes).cpu()
+ torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)
Zonos-main/uv.lock ADDED
The diff for this file is too large to render.
 
Zonos-main/zonos/autoencoder.py ADDED
@@ -0,0 +1,27 @@
+ import math
+
+ import torch
+ import torchaudio
+ from transformers.models.dac import DacModel
+
+
+ class DACAutoencoder:
+     def __init__(self):
+         super().__init__()
+         self.dac = DacModel.from_pretrained("descript/dac_44khz")
+         self.dac.eval().requires_grad_(False)
+         self.codebook_size = self.dac.config.codebook_size
+         self.num_codebooks = self.dac.quantizer.n_codebooks
+         self.sampling_rate = self.dac.config.sampling_rate
+
+     def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
+         wav = torchaudio.functional.resample(wav, sr, 44_100)
+         right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
+         return torch.nn.functional.pad(wav, (0, right_pad))
+
+     def encode(self, wav: torch.Tensor) -> torch.Tensor:
+         return self.dac.encode(wav).audio_codes
+
+     def decode(self, codes: torch.Tensor) -> torch.Tensor:
+         with torch.autocast(self.dac.device.type, torch.float16, enabled=self.dac.device.type != "cpu"):
+             return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1).float()
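As a quick illustration of how this autoencoder is used elsewhere in the repo (for example for audio prefixes in `gradio_interface.py`), the sketch below round-trips a clip through preprocess, encode, and decode. The file path is a placeholder; the example assumes a CPU run is acceptable, since `DacModel` loads on CPU by default here.

```python
import torchaudio
from zonos.autoencoder import DACAutoencoder

autoencoder = DACAutoencoder()

wav, sr = torchaudio.load("assets/exampleaudio.mp3")   # (channels, samples)
wav = wav.mean(0, keepdim=True)                        # mono, as in gradio_interface.py
wav = autoencoder.preprocess(wav, sr)                  # resample to 44.1 kHz, pad to a multiple of 512

codes = autoencoder.encode(wav.unsqueeze(0))           # (batch, n_codebooks, frames)
recon = autoencoder.decode(codes)                      # (batch, 1, samples), float32

torchaudio.save("roundtrip.wav", recon[0], autoencoder.sampling_rate)
```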
Zonos-main/zonos/backbone/__init__.py ADDED
@@ -0,0 +1,12 @@
+ BACKBONES = {}
+
+ try:
+     from ._mamba_ssm import MambaSSMZonosBackbone
+
+     BACKBONES["mamba_ssm"] = MambaSSMZonosBackbone
+ except ImportError:
+     pass
+
+ from ._torch import TorchZonosBackbone
+
+ BACKBONES["torch"] = TorchZonosBackbone
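A short sketch of how this registry can be queried to pick a backbone at runtime; the preference order shown is an assumption for illustration, not something the package mandates.

```python
from zonos.backbone import BACKBONES

# Prefer the mamba-ssm backbone when the optional [compile] extras are installed,
# otherwise fall back to the pure-PyTorch transformer backbone.
name = "mamba_ssm" if "mamba_ssm" in BACKBONES else "torch"
backbone_cls = BACKBONES[name]
print(f"Using backbone: {name}, supported architectures: {backbone_cls.supported_architectures}")
```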
Zonos-main/zonos/backbone/_mamba_ssm.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import torch.nn as nn
+ from mamba_ssm.models.mixer_seq_simple import create_block
+ from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
+
+ from zonos.config import BackboneConfig, InferenceParams
+
+
+ class MambaSSMZonosBackbone(nn.Module):
+     supported_architectures = ["transformer", "hybrid"]
+
+     def __init__(self, config: BackboneConfig):
+         super().__init__()
+         self.config = config
+
+         self.layers = nn.ModuleList(
+             [
+                 create_block(
+                     d_model=config.d_model,
+                     d_intermediate=config.d_intermediate
+                     if (i not in config.attn_layer_idx)
+                     else config.attn_mlp_d_intermediate,
+                     ssm_cfg=config.ssm_cfg,
+                     layer_idx=i,
+                     attn_layer_idx=config.attn_layer_idx,
+                     attn_cfg=config.attn_cfg,
+                     norm_epsilon=config.norm_epsilon,
+                     residual_in_fp32=config.residual_in_fp32,
+                     fused_add_norm=True,
+                     rms_norm=config.rms_norm,
+                 )
+                 for i in range(config.n_layer)
+             ]
+         )
+
+         self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
+
+     def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
+         return {
+             i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
+             for i, layer in enumerate(self.layers)
+         }
+
+     def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
+         residual = None
+         for layer in self.layers:
+             hidden_states, residual = layer(hidden_states, residual, inference_params)
+
+         return layer_norm_fn(
+             hidden_states,
+             self.norm_f.weight,
+             self.norm_f.bias,
+             residual,
+             eps=self.norm_f.eps,
+             residual_in_fp32=self.config.residual_in_fp32,
+             is_rms_norm=self.config.rms_norm,
+         )
Zonos-main/zonos/backbone/_torch.py ADDED
@@ -0,0 +1,152 @@
1
+ # Based on gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/095b2229ee3a40e379c11f05b94bd6923db63b4b/model.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ from zonos.config import BackboneConfig, InferenceParams
7
+
8
+
9
+ def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000) -> torch.Tensor:
10
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
11
+ t = torch.arange(seq_len, device=freqs.device)
12
+ freqs = torch.outer(t, freqs)
13
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
14
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
15
+ return cache
16
+
17
+
18
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
19
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
20
+ freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
21
+ x_out2 = torch.stack(
22
+ [
23
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
24
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
25
+ ],
26
+ -1,
27
+ )
28
+
29
+ x_out2 = x_out2.flatten(3)
30
+ return x_out2.type_as(x)
31
+
32
+
33
+ def _update_kv_cache(
34
+ k: torch.Tensor, v: torch.Tensor, inference_params: InferenceParams, layer_idx: int
35
+ ) -> torch.Tensor:
36
+ """k/v: (batch_size, seqlen, nheads, head_dim) or (batch_size, 1, nheads, head_dim)"""
37
+ assert layer_idx in inference_params.key_value_memory_dict
38
+ kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
39
+ # Adjust key and value for inference
40
+ batch_start = inference_params.batch_size_offset
41
+ batch_end = batch_start + k.shape[0]
42
+ sequence_start = inference_params.seqlen_offset
43
+ sequence_end = sequence_start + k.shape[1]
44
+ assert batch_end <= kv_cache.shape[0]
45
+ assert sequence_end <= kv_cache.shape[1]
46
+ assert kv_cache is not None
47
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, 0, ...] = k
48
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, 1, ...] = v
49
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
50
+
51
+
52
+ class TorchZonosBackbone(nn.Module):
53
+ supported_architectures = ["transformer"]
54
+ freqs_cis: torch.Tensor
55
+
56
+ def __init__(self, config: BackboneConfig):
57
+ assert not config.ssm_cfg, "This backbone implementation only supports the Transformer model."
58
+ super().__init__()
59
+ self.config = config
60
+
61
+ self.layers = nn.ModuleList(TransformerBlock(config, i) for i in range(config.n_layer))
62
+ self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
63
+
64
+ def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
65
+ # TODO: This function should be pure
66
+ head_dim = self.config.d_model // self.config.attn_cfg["num_heads"]
67
+ self.freqs_cis = precompute_freqs_cis(16384, head_dim)
68
+ return {
69
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
70
+ for i, layer in enumerate(self.layers)
71
+ }
72
+
73
+ def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams) -> torch.Tensor:
74
+ input_pos = torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
75
+ input_pos = input_pos + inference_params.lengths_per_sample.unsqueeze(-1)
76
+
77
+ freqs_cis = self.freqs_cis[input_pos].expand(hidden_states.shape[0], -1, -1, -1)
78
+ for i, layer in enumerate(self.layers):
79
+ hidden_states = layer(hidden_states, inference_params, freqs_cis)
80
+ return self.norm_f(hidden_states)
81
+
82
+
83
+ class TransformerBlock(nn.Module):
84
+ def __init__(self, config: BackboneConfig, layer_idx: int) -> None:
85
+ super().__init__()
86
+ self.config = config
87
+
88
+ self.norm = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
89
+ self.mixer = Attention(config, layer_idx)
90
+ self.norm2 = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
91
+ self.mlp = FeedForward(config)
92
+
93
+ self.num_heads_kv = config.attn_cfg["num_heads_kv"]
94
+ self.head_dim = config.d_model // config.attn_cfg["num_heads"]
95
+
96
+ def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
97
+ return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype), None
98
+
99
+ def forward(self, x: torch.Tensor, inference_params: InferenceParams, freqs_cis: torch.Tensor) -> torch.Tensor:
100
+ x = x + self.mixer(self.norm(x), inference_params, freqs_cis)
101
+ x = x + self.mlp(self.norm2(x))
102
+ return x
103
+
104
+
105
+ class Attention(nn.Module):
106
+ def __init__(self, config: BackboneConfig, layer_idx: int):
107
+ super().__init__()
108
+ self.num_heads = config.attn_cfg["num_heads"]
109
+ self.num_heads_kv = config.attn_cfg["num_heads_kv"]
110
+ self.head_dim = config.d_model // self.num_heads
111
+ self.layer_idx = layer_idx
112
+
113
+ total_head_dim = (self.num_heads + 2 * self.num_heads_kv) * self.head_dim
114
+ self.in_proj = nn.Linear(config.d_model, total_head_dim, bias=False)
115
+ self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.d_model, bias=False)
116
+
117
+ def forward(self, x: torch.Tensor, inference_params: InferenceParams, freqs_cis: torch.Tensor) -> torch.Tensor:
118
+ batch_size, seqlen, _ = x.shape
119
+
120
+ q_size = self.num_heads * self.head_dim
121
+ kv_size = self.num_heads_kv * self.head_dim
122
+ q, k, v = self.in_proj(x).split([q_size, kv_size, kv_size], dim=-1)
123
+
124
+ q = q.view(batch_size, seqlen, self.num_heads, self.head_dim)
125
+ k = k.view(batch_size, seqlen, self.num_heads_kv, self.head_dim)
126
+ v = v.view(batch_size, seqlen, self.num_heads_kv, self.head_dim)
127
+
128
+ q = apply_rotary_emb(q, freqs_cis)
129
+ k = apply_rotary_emb(k, freqs_cis)
130
+
131
+ kv = _update_kv_cache(k, v, inference_params, self.layer_idx)
132
+ k, v = kv.unbind(dim=-3)
133
+
134
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
135
+
136
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=seqlen > 1, enable_gqa=True)
137
+
138
+ y = y.transpose(1, 2).contiguous().view(batch_size, seqlen, q_size)
139
+
140
+ y = self.out_proj(y)
141
+ return y
142
+
143
+
144
+ class FeedForward(nn.Module):
145
+ def __init__(self, config: BackboneConfig) -> None:
146
+ super().__init__()
147
+ self.fc1 = nn.Linear(config.d_model, 2 * config.attn_mlp_d_intermediate, bias=False)
148
+ self.fc2 = nn.Linear(config.attn_mlp_d_intermediate, config.d_model, bias=False)
149
+
150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
151
+ y, gate = self.fc1(x).chunk(2, dim=-1)
152
+ return self.fc2(y * F.silu(gate))
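
The Attention block above relies on PyTorch's fused kernel with grouped-query attention (more query heads than key/value heads). A minimal, self-contained sketch of that call, not part of the uploaded files; the head counts and tensor sizes are illustrative, and `enable_gqa` assumes PyTorch >= 2.5:

```python
import torch
import torch.nn.functional as F

batch, seqlen, head_dim = 2, 16, 64
num_heads, num_heads_kv = 8, 2  # every group of 4 query heads shares one KV head

q = torch.randn(batch, num_heads, seqlen, head_dim)
k = torch.randn(batch, num_heads_kv, seqlen, head_dim)
v = torch.randn(batch, num_heads_kv, seqlen, head_dim)

# enable_gqa broadcasts the 2 KV heads across the 8 query heads
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(y.shape)  # torch.Size([2, 8, 16, 64]) -- output keeps the query head count
```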
Zonos-main/zonos/codebook_pattern.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
6
+ codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
7
+ return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
8
+
9
+
10
+ def revert_delay_pattern(codes: torch.Tensor):
11
+ _, n_q, seq_len = codes.shape
12
+ return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
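
A quick sanity check for the delay pattern, not part of the upload: `apply_delay_pattern` shifts codebook k right by k+1 positions and pads with the mask token, and `revert_delay_pattern` undoes the shift exactly:

```python
import torch
from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern

codes = torch.arange(2 * 9 * 5).reshape(2, 9, 5)       # [batch, n_q=9, seq_len=5]
delayed = apply_delay_pattern(codes, mask_token=1025)  # [2, 9, 14]
restored = revert_delay_pattern(delayed)               # [2, 9, 5]
assert torch.equal(restored, codes)
```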
Zonos-main/zonos/conditioning.py ADDED
@@ -0,0 +1,405 @@
1
+ from functools import cache
2
+ from typing import Any, Literal, Iterable
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from zonos.config import PrefixConditionerConfig
8
+ from zonos.utils import DEFAULT_DEVICE
9
+
10
+
11
+ class Conditioner(nn.Module):
12
+ def __init__(
13
+ self,
14
+ output_dim: int,
15
+ name: str,
16
+ cond_dim: int | None = None,
17
+ projection: Literal["none", "linear", "mlp"] = "none",
18
+ uncond_type: Literal["learned", "none"] = "none",
19
+ **kwargs,
20
+ ):
21
+ super().__init__()
22
+ self.name = name
23
+ self.output_dim = output_dim
24
+ self.cond_dim = cond_dim = cond_dim or output_dim
25
+
26
+ if projection == "linear":
27
+ self.project = nn.Linear(cond_dim, output_dim)
28
+ elif projection == "mlp":
29
+ self.project = nn.Sequential(
30
+ nn.Linear(cond_dim, output_dim),
31
+ nn.SiLU(),
32
+ nn.Linear(output_dim, output_dim),
33
+ )
34
+ else:
35
+ self.project = nn.Identity()
36
+
37
+ self.uncond_vector = None
38
+ if uncond_type == "learned":
39
+ self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
40
+
41
+ def apply_cond(self, *inputs: Any) -> torch.Tensor:
42
+ raise NotImplementedError()
43
+
44
+ def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
45
+ if inputs is None:
46
+ assert self.uncond_vector is not None
47
+ return self.uncond_vector.data.view(1, 1, -1)
48
+
49
+ cond = self.apply_cond(*inputs)
50
+ cond = self.project(cond)
51
+ return cond
52
+
53
+
54
+ # ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
55
+ import os
56
+ import sys
57
+ import re
58
+ import unicodedata
59
+
60
+ import inflect
61
+ import torch
62
+ import torch.nn as nn
63
+ from kanjize import number2kanji
64
+ from phonemizer.backend import EspeakBackend
65
+ from sudachipy import Dictionary, SplitMode
66
+
67
+ if sys.platform == "darwin":
68
+ os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "/opt/homebrew/lib/libespeak-ng.dylib"
69
+
70
+ # --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
71
+
72
+ _inflect = inflect.engine()
73
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
74
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
75
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
76
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
77
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
78
+ _number_re = re.compile(r"[0-9]+")
79
+
80
+
81
+ def _remove_commas(m: re.Match) -> str:
82
+ return m.group(1).replace(",", "")
83
+
84
+
85
+ def _expand_decimal_point(m: re.Match) -> str:
86
+ return m.group(1).replace(".", " point ")
87
+
88
+
89
+ def _expand_dollars(m: re.Match) -> str:
90
+ match = m.group(1)
91
+ parts = match.split(".")
92
+ if len(parts) > 2:
93
+ return match + " dollars" # Unexpected format
94
+ dollars = int(parts[0]) if parts[0] else 0
95
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
96
+ if dollars and cents:
97
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
98
+ cent_unit = "cent" if cents == 1 else "cents"
99
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
100
+ elif dollars:
101
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
102
+ return "%s %s" % (dollars, dollar_unit)
103
+ elif cents:
104
+ cent_unit = "cent" if cents == 1 else "cents"
105
+ return "%s %s" % (cents, cent_unit)
106
+ else:
107
+ return "zero dollars"
108
+
109
+
110
+ def _expand_ordinal(m: re.Match) -> str:
111
+ return _inflect.number_to_words(m.group(0))
112
+
113
+
114
+ def _expand_number(m: re.Match) -> str:
115
+ num = int(m.group(0))
116
+ if num > 1000 and num < 3000:
117
+ if num == 2000:
118
+ return "two thousand"
119
+ elif num > 2000 and num < 2010:
120
+ return "two thousand " + _inflect.number_to_words(num % 100)
121
+ elif num % 100 == 0:
122
+ return _inflect.number_to_words(num // 100) + " hundred"
123
+ else:
124
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
125
+ else:
126
+ return _inflect.number_to_words(num, andword="")
127
+
128
+
129
+ def normalize_numbers(text: str) -> str:
130
+ text = re.sub(_comma_number_re, _remove_commas, text)
131
+ text = re.sub(_pounds_re, r"\1 pounds", text)
132
+ text = re.sub(_dollars_re, _expand_dollars, text)
133
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
134
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
135
+ text = re.sub(_number_re, _expand_number, text)
136
+ return text
137
+
138
+
139
+ # --- Number normalization code end ---
140
+
141
+
142
+ PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
143
+ SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
144
+
145
+ _punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
146
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
147
+ _letters_ipa = (
148
+ "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌː��ʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
149
+ )
150
+
151
+ symbols = [*_punctuation, *_letters, *_letters_ipa]
152
+ _symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
153
+
154
+
155
+ def _get_symbol_id(s: str) -> int:
156
+ return _symbol_to_id.get(s, 1)
157
+
158
+
159
+ def get_symbol_ids(text: str) -> list[int]:
160
+ return list(map(_get_symbol_id, text))
161
+
162
+
163
+ def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
164
+ phoneme_ids = [[BOS_ID, *get_symbol_ids(phonemes), EOS_ID] for phonemes in phonemes]
165
+ lengths = list(map(len, phoneme_ids))
166
+ longest = max(lengths)
167
+ phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
168
+ return torch.tensor(phoneme_ids), lengths
169
+
170
+
171
+ def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
172
+ text = unicodedata.normalize("NFKC", text)
173
+ text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
174
+ final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
175
+ return final_text
176
+
177
+
178
+ def clean(texts: list[str], languages: list[str]) -> list[str]:
179
+ texts_out = []
180
+ for text, language in zip(texts, languages):
181
+ if "ja" in language:
182
+ text = normalize_jp_text(text)
183
+ else:
184
+ text = normalize_numbers(text)
185
+ texts_out.append(text)
186
+ return texts_out
187
+
188
+
189
+ @cache
190
+ def get_backend(language: str) -> "EspeakBackend":
191
+ import logging
192
+
193
+ from phonemizer.backend import EspeakBackend
194
+
195
+ logger = logging.getLogger("phonemizer")
196
+ backend = EspeakBackend(
197
+ language,
198
+ preserve_punctuation=True,
199
+ with_stress=True,
200
+ punctuation_marks=_punctuation,
201
+ logger=logger,
202
+ )
203
+ logger.setLevel(logging.ERROR)
204
+ return backend
205
+
206
+
207
+ def phonemize(texts: list[str], languages: list[str]) -> list[str]:
208
+ texts = clean(texts, languages)
209
+
210
+ batch_phonemes = []
211
+ for text, language in zip(texts, languages):
212
+ backend = get_backend(language)
213
+ phonemes = backend.phonemize([text], strip=True)
214
+ batch_phonemes.append(phonemes[0])
215
+
216
+ return batch_phonemes
217
+
218
+
219
+ class EspeakPhonemeConditioner(Conditioner):
220
+ def __init__(self, output_dim: int, **kwargs):
221
+ super().__init__(output_dim, **kwargs)
222
+ self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
223
+
224
+ def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
225
+ """
226
+ Args:
227
+ texts: list of texts to convert to phonemes
228
+ languages: ISO 639-1 (or otherwise eSpeak-compatible) language codes
229
+ """
230
+ device = self.phoneme_embedder.weight.device
231
+
232
+ phonemes = phonemize(texts, languages)
233
+ phoneme_ids, _ = tokenize_phonemes(phonemes)
234
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
235
+
236
+ return phoneme_embeds
237
+
238
+
239
+ # ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
240
+
241
+
242
+ class FourierConditioner(Conditioner):
243
+ def __init__(
244
+ self,
245
+ output_dim: int,
246
+ input_dim: int = 1,
247
+ std: float = 1.0,
248
+ min_val: float = 0.0,
249
+ max_val: float = 1.0,
250
+ **kwargs,
251
+ ):
252
+ assert output_dim % 2 == 0
253
+ super().__init__(output_dim, **kwargs)
254
+ self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
255
+ self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
256
+
257
+ def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
258
+ assert x.shape[-1] == self.input_dim
259
+ x = (x - self.min_val) / (self.max_val - self.min_val) # [batch_size, seq_len, input_dim]
260
+ f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T # [batch_size, seq_len, output_dim // 2]
261
+ return torch.cat([f.cos(), f.sin()], dim=-1) # [batch_size, seq_len, output_dim]
262
+
263
+
264
+ class IntegerConditioner(Conditioner):
265
+ def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
266
+ super().__init__(output_dim, **kwargs)
267
+ self.min_val = min_val
268
+ self.max_val = max_val
269
+ self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
270
+
271
+ def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
272
+ assert x.shape[-1] == 1
273
+ return self.int_embedder(x.squeeze(-1) - self.min_val) # [batch_size, seq_len, output_dim]
274
+
275
+
276
+ class PassthroughConditioner(Conditioner):
277
+ def __init__(self, output_dim: int, **kwargs):
278
+ super().__init__(output_dim, **kwargs)
279
+
280
+ def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
281
+ assert x.shape[-1] == self.cond_dim
282
+ return x
283
+
284
+
285
+ _cond_cls_map = {
286
+ "PassthroughConditioner": PassthroughConditioner,
287
+ "EspeakPhonemeConditioner": EspeakPhonemeConditioner,
288
+ "FourierConditioner": FourierConditioner,
289
+ "IntegerConditioner": IntegerConditioner,
290
+ }
291
+
292
+
293
+ def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
294
+ return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
295
+
296
+
297
+ class PrefixConditioner(Conditioner):
298
+ def __init__(self, config: PrefixConditionerConfig, output_dim: int):
299
+ super().__init__(output_dim, "prefix", projection=config.projection)
300
+ self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
301
+ self.norm = nn.LayerNorm(output_dim)
302
+ self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
303
+
304
+ def forward(self, cond_dict: dict) -> torch.Tensor:
305
+ if not set(cond_dict).issuperset(self.required_keys):
306
+ raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
307
+ conds = []
308
+ for conditioner in self.conditioners:
309
+ conds.append(conditioner(cond_dict.get(conditioner.name)))
310
+ max_bsz = max(map(len, conds))
311
+ assert all(c.shape[0] in (max_bsz, 1) for c in conds)
312
+ conds = [c.expand(max_bsz, -1, -1) for c in conds]
313
+ return self.norm(self.project(torch.cat(conds, dim=-2)))
314
+
315
+
316
+ supported_language_codes = [
317
+ 'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
318
+ 'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
319
+ 'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
320
+ 'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
321
+ 'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
322
+ 'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
323
+ 'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
324
+ 'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
325
+ 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
326
+ 'vi-vn-x-central', 'vi-vn-x-south', 'yue'
327
+ ] # fmt: off
328
+
329
+
330
+ def make_cond_dict(
331
+ text: str = "It would be nice to have time for testing, indeed.",
332
+ language: str = "en-us",
333
+ speaker: torch.Tensor | None = None,
334
+
335
+ # Emotion vector from 0.0 to 1.0
336
+ # Is entangled with pitch_std because more emotion => more pitch variation
337
+ # and with VQScore and DNSMOS because they favor neutral speech
338
+ #
339
+ # Happiness, Sadness, Disgust, Fear, Surprise, Anger, Other, Neutral
340
+ emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
341
+
342
+ # Maximum frequency (0 to 24000), should be 22050 or 24000 for 44.1 or 48 kHz audio
343
+ # For voice cloning use 22050
344
+ fmax: float = 22050.0,
345
+
346
+ # Standard deviation for pitch (0 to 400), should be
347
+ # 20-45 for normal speech,
348
+ # 60-150 for expressive speech,
349
+ # higher values => crazier samples
350
+ pitch_std: float = 20.0,
351
+
352
+ # Speaking rate in phonemes per second (0 to 40). 30 is very fast, 10 is slow.
353
+ speaking_rate: float = 15.0,
354
+
355
+ # Target VoiceQualityScore for the generated speech (0.5 to 0.8).
356
+ # A list of values must be provided, one for each 1/8th of the audio.
357
+ # You should unset this for expressive speech.
358
+ # According to Discord chat, this is only used for the hybrid model
359
+ vqscore_8: list[float] = [0.78] * 8,
360
+
361
+ # CTC target loss
362
+ # Only used for the hybrid model
363
+ ctc_loss: float = 0.0,
364
+ # Only used for the hybrid model
365
+ dnsmos_ovrl: float = 4.0,
366
+ # Only used for the hybrid model
367
+ speaker_noised: bool = False,
368
+ unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
369
+ device: torch.device | str = DEFAULT_DEVICE,
370
+ ) -> dict:
371
+ """
372
+ A helper to build the 'cond_dict' that the model expects.
373
+ By default, the speaker is left as None, so speaker conditioning falls back to its learned unconditional vector.
374
+ """
375
+ assert language.lower() in supported_language_codes, "Please pick a supported language"
376
+
377
+ language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}
378
+
379
+ cond_dict = {
380
+ "espeak": ([text], [language]),
381
+ "speaker": speaker,
382
+ "emotion": emotion,
383
+ "fmax": fmax,
384
+ "pitch_std": pitch_std,
385
+ "speaking_rate": speaking_rate,
386
+ "language_id": language_code_to_id[language],
387
+ "vqscore_8": vqscore_8,
388
+ "ctc_loss": ctc_loss,
389
+ "dnsmos_ovrl": dnsmos_ovrl,
390
+ "speaker_noised": int(speaker_noised),
391
+ }
392
+
393
+ for k in unconditional_keys:
394
+ cond_dict.pop(k, None)
395
+
396
+ for k, v in cond_dict.items():
397
+ if isinstance(v, (float, int, list)):
398
+ v = torch.tensor(v)
399
+ if isinstance(v, torch.Tensor):
400
+ cond_dict[k] = v.view(1, 1, -1).to(device)
401
+
402
+ if k == "emotion":
403
+ cond_dict[k] /= cond_dict[k].sum(dim=-1)
404
+
405
+ return cond_dict
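
A usage sketch for `make_cond_dict`, not part of the upload; the text and parameter values are placeholders. Keys listed in `unconditional_keys` are dropped so their conditioners fall back to the learned unconditional:

```python
from zonos.conditioning import make_cond_dict

cond_dict = make_cond_dict(
    text="Hello from Zonos.",
    language="en-us",
    pitch_std=45.0,  # placeholder: somewhat more expressive than the 20.0 default
    unconditional_keys={"vqscore_8", "dnsmos_ovrl", "speaker"},  # let the model infer these
)
print(sorted(cond_dict))
# ['ctc_loss', 'emotion', 'espeak', 'fmax', 'language_id', 'pitch_std', 'speaker_noised', 'speaking_rate']
```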
Zonos-main/zonos/config.py ADDED
@@ -0,0 +1,62 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Literal
3
+
4
+ import torch
5
+
6
+
7
+ # https://github.com/state-spaces/mamba/blob//mamba_ssm/utils/generation.py#L18
8
+ @dataclass
9
+ class InferenceParams:
10
+ """Inference parameters that are passed to the main model in order
11
+ to efficiently calculate and store the context during inference."""
12
+
13
+ max_seqlen: int
14
+ max_batch_size: int
15
+ seqlen_offset: int = 0
16
+ batch_size_offset: int = 0
17
+ key_value_memory_dict: dict = field(default_factory=dict)
18
+ lengths_per_sample: torch.Tensor | None = None
19
+
20
+ def reset(self, max_seqlen, max_batch_size):
21
+ self.max_seqlen = max_seqlen
22
+ self.max_batch_size = max_batch_size
23
+ self.seqlen_offset = 0
24
+ if self.lengths_per_sample is not None:
25
+ self.lengths_per_sample.zero_()
26
+
27
+
28
+ @dataclass
29
+ class BackboneConfig:
30
+ d_model: int = 1024
31
+ d_intermediate: int = 0
32
+ attn_mlp_d_intermediate: int = 0
33
+ n_layer: int = 16
34
+ ssm_cfg: dict = field(default_factory=dict)
35
+ attn_layer_idx: list = field(default_factory=list)
36
+ attn_cfg: dict = field(default_factory=dict)
37
+ rms_norm: bool = False
38
+ residual_in_fp32: bool = False
39
+ norm_epsilon: float = 1e-5
40
+
41
+
42
+ @dataclass
43
+ class PrefixConditionerConfig:
44
+ conditioners: list[dict]
45
+ projection: Literal["none", "linear", "mlp"]
46
+
47
+
48
+ @dataclass
49
+ class ZonosConfig:
50
+ backbone: BackboneConfig
51
+ prefix_conditioner: PrefixConditionerConfig
52
+ eos_token_id: int = 1024
53
+ masked_token_id: int = 1025
54
+ pad_vocab_to_multiple_of: int = 8
55
+
56
+ @classmethod
57
+ def from_dict(cls, d: dict) -> "ZonosConfig":
58
+ d = d.copy()
59
+ backbone_config = BackboneConfig(**d.pop("backbone"))
60
+ prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
61
+ config = cls(backbone_config, prefix_conditioner_config, **d)
62
+ return config
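
A minimal sketch of `ZonosConfig.from_dict`, not part of the upload and mirroring what `from_pretrained` does with the downloaded `config.json`; the numbers here are placeholders, not the shipped configuration:

```python
from zonos.config import ZonosConfig

cfg = ZonosConfig.from_dict(
    {
        "backbone": {"d_model": 1024, "n_layer": 16, "attn_cfg": {"num_heads": 16, "num_heads_kv": 4}},
        "prefix_conditioner": {"conditioners": [], "projection": "none"},
    }
)
print(cfg.backbone.d_model, cfg.eos_token_id, cfg.masked_token_id)  # 1024 1024 1025
```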
Zonos-main/zonos/model.py ADDED
@@ -0,0 +1,315 @@
1
+ import json
2
+ from typing import Callable
3
+
4
+ import safetensors
5
+ import torch
6
+ import torch.nn as nn
7
+ from huggingface_hub import hf_hub_download
8
+ from tqdm import tqdm
9
+
10
+ from zonos.autoencoder import DACAutoencoder
11
+ from zonos.backbone import BACKBONES
12
+ from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern
13
+ from zonos.conditioning import PrefixConditioner
14
+ from zonos.config import InferenceParams, ZonosConfig
15
+ from zonos.sampling import sample_from_logits
16
+ from zonos.speaker_cloning import SpeakerEmbeddingLDA
17
+ from zonos.utils import DEFAULT_DEVICE, find_multiple, pad_weight_
18
+
19
+ DEFAULT_BACKBONE_CLS = next(iter(BACKBONES.values()))
20
+
21
+
22
+ class Zonos(nn.Module):
23
+ def __init__(self, config: ZonosConfig, backbone_cls=DEFAULT_BACKBONE_CLS):
24
+ super().__init__()
25
+ self.config = config
26
+ dim = config.backbone.d_model
27
+ self.eos_token_id = config.eos_token_id
28
+ self.masked_token_id = config.masked_token_id
29
+
30
+ self.autoencoder = DACAutoencoder()
31
+ self.backbone = backbone_cls(config.backbone)
32
+ self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
33
+ self.spk_clone_model = None
34
+
35
+ # TODO: pad to multiple of at least 8
36
+ self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
37
+ self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])
38
+
39
+ self._cg_graph = None
40
+ self._cg_batch_size = None
41
+ self._cg_input_ids = None
42
+ self._cg_logits = None
43
+ self._cg_inference_params = None
44
+ self._cg_scale = None
45
+
46
+ if config.pad_vocab_to_multiple_of:
47
+ self.register_load_state_dict_post_hook(self._pad_embeddings_and_heads)
48
+
49
+ def _pad_embeddings_and_heads(self, *args, **kwargs):
50
+ for w in [*self.embeddings, *self.heads]:
51
+ pad_weight_(w, self.config.pad_vocab_to_multiple_of)
52
+
53
+ @property
54
+ def device(self) -> torch.device:
55
+ return next(self.parameters()).device
56
+
57
+ @classmethod
58
+ def from_pretrained(
59
+ cls, repo_id: str, revision: str | None = None, device: str = DEFAULT_DEVICE, **kwargs
60
+ ) -> "Zonos":
61
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
62
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
63
+ return cls.from_local(config_path, model_path, device, **kwargs)
64
+
65
+ @classmethod
66
+ def from_local(
67
+ cls, config_path: str, model_path: str, device: str = DEFAULT_DEVICE, backbone: str | None = None
68
+ ) -> "Zonos":
69
+ config = ZonosConfig.from_dict(json.load(open(config_path)))
70
+ if backbone:
71
+ backbone_cls = BACKBONES[backbone]
72
+ else:
73
+ is_transformer = not bool(config.backbone.ssm_cfg)
74
+ backbone_cls = DEFAULT_BACKBONE_CLS
75
+ # Preferentially route to pure torch backbone for increased performance and lower latency.
76
+ if is_transformer and "torch" in BACKBONES:
77
+ backbone_cls = BACKBONES["torch"]
78
+
79
+ model = cls(config, backbone_cls).to(device, torch.bfloat16)
80
+ model.autoencoder.dac.to(device)
81
+
82
+ sd = model.state_dict()
83
+ with safetensors.safe_open(model_path, framework="pt") as f:
84
+ for k in f.keys():
85
+ sd[k] = f.get_tensor(k)
86
+ model.load_state_dict(sd)
87
+
88
+ return model
89
+
90
+ def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
91
+ """Generate a speaker embedding from an audio clip."""
92
+ if self.spk_clone_model is None:
93
+ self.spk_clone_model = SpeakerEmbeddingLDA()
94
+ _, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
95
+ return spk_embedding.unsqueeze(0).bfloat16()
96
+
97
+ def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
98
+ return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))
99
+
100
+ def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
101
+ return torch.stack([head(hidden_states) for head in self.heads], dim=1)
102
+
103
+ def _compute_logits(
104
+ self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
105
+ ) -> torch.Tensor:
106
+ """
107
+ Pass `hidden_states` into `backbone` and `multi_head`, applying
108
+ classifier-free guidance if `cfg_scale != 1.0`.
109
+ """
110
+ last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
111
+ logits = self.apply_heads(last_hidden_states).squeeze(2).float()
112
+ if cfg_scale != 1.0:
113
+ cond_logits, uncond_logits = logits.chunk(2)
114
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
115
+ logits[..., 1025:].fill_(-torch.inf) # ensures padding is ignored
116
+ return logits
117
+
118
+ def _decode_one_token(
119
+ self,
120
+ input_ids: torch.Tensor,
121
+ inference_params: InferenceParams,
122
+ cfg_scale: float,
123
+ allow_cudagraphs: bool = True,
124
+ ) -> torch.Tensor:
125
+ """
126
+ Single-step decode. Prepares the hidden states, possibly replicates them
127
+ for CFG, and then delegates to `_compute_logits`.
128
+
129
+ Below we wrap this function with a simple CUDA Graph capturing mechanism,
130
+ doing 3 warmup steps if needed and then capturing or replaying the graph.
131
+ We only recapture if the batch size changes.
132
+ """
133
+ # TODO: support cfg_scale==1
134
+ if cfg_scale == 1.0:
135
+ hidden_states = self.embed_codes(input_ids)
136
+ return self._compute_logits(hidden_states, inference_params, cfg_scale)
137
+
138
+ bsz = input_ids.size(0)
139
+
140
+ if not allow_cudagraphs or input_ids.device.type != "cuda":
141
+ hidden_states_local = self.embed_codes(input_ids)
142
+ hidden_states_local = hidden_states_local.repeat(2, 1, 1)
143
+ return self._compute_logits(hidden_states_local, inference_params, cfg_scale)
144
+
145
+ need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
146
+
147
+ if need_capture:
148
+ self._cg_graph = None
149
+
150
+ self._cg_batch_size = bsz
151
+ self._cg_inference_params = inference_params
152
+ self._cg_scale = cfg_scale
153
+
154
+ for _ in range(3):
155
+ hidden_states = self.embed_codes(input_ids)
156
+ hidden_states = hidden_states.repeat(2, 1, 1) # because cfg != 1.0
157
+ logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
158
+
159
+ self._cg_input_ids = input_ids.clone()
160
+ self._cg_logits = torch.empty_like(logits)
161
+
162
+ g = torch.cuda.CUDAGraph()
163
+
164
+ def capture_region():
165
+ hidden_states_local = self.embed_codes(self._cg_input_ids)
166
+ hidden_states_local = hidden_states_local.repeat(2, 1, 1)
167
+ self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
168
+
169
+ with torch.cuda.graph(g):
170
+ capture_region()
171
+
172
+ self._cg_graph = g
173
+
174
+ else:
175
+ self._cg_input_ids.copy_(input_ids)
176
+
177
+ self._cg_graph.replay()
178
+
179
+ return self._cg_logits
180
+
181
+ def _prefill(
182
+ self,
183
+ prefix_hidden_states: torch.Tensor,
184
+ input_ids: torch.Tensor,
185
+ inference_params: InferenceParams,
186
+ cfg_scale: float,
187
+ ) -> torch.Tensor:
188
+ """
189
+ "Prefill" mode: we already have `prefix_hidden_states`, and we want
190
+ to append new embeddings, then compute the logits.
191
+ """
192
+ # Replicate input_ids if CFG is enabled
193
+ if cfg_scale != 1.0:
194
+ input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
195
+ hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
196
+ return self._compute_logits(hidden_states, inference_params, cfg_scale)
197
+
198
+ def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
199
+ max_seqlen = find_multiple(max_seqlen, 8)
200
+ key_value_memory_dict = self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
201
+ lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32)
202
+ return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
203
+
204
+ def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
205
+ if uncond_dict is None:
206
+ uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
207
+ return torch.cat(
208
+ [
209
+ self.prefix_conditioner(cond_dict),
210
+ self.prefix_conditioner(uncond_dict),
211
+ ]
212
+ )
213
+
214
+ def can_use_cudagraphs(self) -> bool:
215
+ # Only the mamba-ssm backbone supports CUDA Graphs at the moment
216
+ return self.device.type == "cuda" and "_mamba_ssm" in str(self.backbone.__class__)
217
+
218
+ @torch.inference_mode()
219
+ def generate(
220
+ self,
221
+ prefix_conditioning: torch.Tensor, # [bsz, cond_seq_len, d_model]
222
+ audio_prefix_codes: torch.Tensor | None = None, # [bsz, 9, prefix_audio_seq_len]
223
+ max_new_tokens: int = 86 * 30,
224
+ cfg_scale: float = 2.0,
225
+ batch_size: int = 1,
226
+ sampling_params: dict = dict(min_p=0.1),
227
+ progress_bar: bool = True,
228
+ disable_torch_compile: bool = False,
229
+ callback: Callable[[torch.Tensor, int, int], bool] | None = None,
230
+ ):
231
+ assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
232
+ prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]
233
+ device = self.device
234
+
235
+ # Use CUDA Graphs if supported, and torch.compile otherwise.
236
+ cg = self.can_use_cudagraphs()
237
+ decode_one_token = self._decode_one_token
238
+ decode_one_token = torch.compile(decode_one_token, dynamic=True, disable=cg or disable_torch_compile)
239
+
240
+ unknown_token = -1
241
+ audio_seq_len = prefix_audio_len + max_new_tokens
242
+ seq_len = prefix_conditioning.shape[1] + audio_seq_len + 9
243
+
244
+ with torch.device(device):
245
+ inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)
246
+ codes = torch.full((batch_size, 9, audio_seq_len), unknown_token)
247
+
248
+ if audio_prefix_codes is not None:
249
+ codes[..., :prefix_audio_len] = audio_prefix_codes
250
+
251
+ delayed_codes = apply_delay_pattern(codes, self.masked_token_id)
252
+
253
+ delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]
254
+
255
+ logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
256
+ next_token = sample_from_logits(logits, **sampling_params)
257
+
258
+ offset = delayed_prefix_audio_codes.shape[2]
259
+ frame = delayed_codes[..., offset : offset + 1]
260
+ frame.masked_scatter_(frame == unknown_token, next_token)
261
+
262
+ prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
263
+ inference_params.seqlen_offset += prefix_length
264
+ inference_params.lengths_per_sample[:] += prefix_length
265
+
266
+ logit_bias = torch.zeros_like(logits)
267
+ logit_bias[:, 1:, self.eos_token_id] = -torch.inf # only allow codebook 0 to predict EOS
268
+
269
+ stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
270
+ max_steps = delayed_codes.shape[2] - offset
271
+ remaining_steps = torch.full((batch_size,), max_steps, device=device)
272
+ progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)
273
+ cfg_scale = torch.tensor(cfg_scale)
274
+
275
+ step = 0
276
+ while torch.max(remaining_steps) > 0:
277
+ offset += 1
278
+ input_ids = delayed_codes[..., offset - 1 : offset]
279
+ logits = decode_one_token(input_ids, inference_params, cfg_scale, allow_cudagraphs=cg)
280
+ logits += logit_bias
281
+
282
+ next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
283
+ eos_in_cb0 = next_token[:, 0] == self.eos_token_id
284
+
285
+ remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
286
+ stopping |= eos_in_cb0[:, 0]
287
+
288
+ eos_codebook_idx = 9 - remaining_steps
289
+ eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
290
+ for i in range(next_token.shape[0]):
291
+ if stopping[i]:
292
+ idx = eos_codebook_idx[i].item()
293
+ next_token[i, :idx] = self.masked_token_id
294
+ next_token[i, idx] = self.eos_token_id
295
+
296
+ frame = delayed_codes[..., offset : offset + 1]
297
+ frame.masked_scatter_(frame == unknown_token, next_token)
298
+ inference_params.seqlen_offset += 1
299
+ inference_params.lengths_per_sample[:] += 1
300
+
301
+ remaining_steps -= 1
302
+
303
+ progress.update()
304
+ step += 1
305
+
306
+ if callback is not None and not callback(frame, step, max_steps):
307
+ break
308
+
309
+ out_codes = revert_delay_pattern(delayed_codes)
310
+ out_codes.masked_fill_(out_codes >= 1024, 0)
311
+ out_codes = out_codes[..., : offset - 9]
312
+
313
+ self._cg_graph = None # reset cuda graph to avoid cache changes
314
+
315
+ return out_codes
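
An end-to-end generation sketch, not part of the upload. The repo id and the autoencoder's `decode()` / `sampling_rate` attributes follow the upstream Zonos sample script and should be treated as assumptions here; file names are placeholders:

```python
import torchaudio
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict

model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer")  # uses DEFAULT_DEVICE

wav, sr = torchaudio.load("reference.wav")          # clean clip of the target speaker
speaker = model.make_speaker_embedding(wav, sr)

cond_dict = make_cond_dict(text="Welcome back to the podcast.", language="en-us", speaker=speaker)
conditioning = model.prepare_conditioning(cond_dict)

codes = model.generate(conditioning)                # [1, 9, seq_len] DAC codes
wav_out = model.autoencoder.decode(codes).cpu()     # decoded audio, assumed [1, 1, num_samples]
torchaudio.save("podcast_intro.wav", wav_out[0], model.autoencoder.sampling_rate)
```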
Zonos-main/zonos/sampling.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch
2
+
3
+
4
+ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
5
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
6
+
7
+ Args:
8
+ input (torch.Tensor): The input tensor containing probabilities.
9
+ num_samples (int): Number of samples to draw.
10
+ replacement (bool): Whether to draw with replacement or not.
11
+ Keyword args:
12
+ generator (torch.Generator): A pseudorandom number generator for sampling.
13
+ Returns:
14
+ torch.Tensor: Last dimension contains num_samples indices
15
+ sampled from the multinomial probability distribution
16
+ located in the last dimension of tensor input.
17
+ """
18
+
19
+ if num_samples == 1:
20
+ q = torch.empty_like(input).exponential_(1, generator=generator)
21
+ return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
22
+
23
+ input_ = input.reshape(-1, input.shape[-1])
24
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
25
+ output = output_.reshape(*list(input.shape[:-1]), -1)
26
+ return output
27
+
28
+
29
+ def apply_unified(probs: torch.Tensor, linear: float, conf: float, quad: float) -> torch.Tensor:
30
+ """Sample next token using unified sampling approach that combines linear scaling, confidence, and quadratic terms.
31
+
32
+ Args:
33
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
34
+ linear (float): Linear scaling factor applied to log probabilities.
35
+ conf (float): Confidence factor that scales the entropy term.
36
+ quad (float): Quadratic penalty factor applied to squared log probabilities.
37
+ Returns:
38
+ torch.Tensor: Modified probability distribution after applying unified sampling.
39
+ """
40
+ logprobs = torch.log(probs.clamp_min(1e-20))
41
+ entropy = -torch.sum(probs * logprobs, dim=-1, keepdim=True)
42
+ raw = logprobs * (linear + entropy * conf) - logprobs**2 * quad
43
+ return raw.softmax(dim=-1)
44
+
45
+ def apply_top_k(
46
+ probs: torch.Tensor,
47
+ k: int,
48
+ ) -> torch.Tensor:
49
+ """Sample next token from top K values along the last dimension of the input probs tensor.
50
+
51
+ Args:
52
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
53
+ k (int): The k in “top-k”.
54
+ Returns:
55
+ torch.Tensor: Renormalized probabilities after top-k filtering.
56
+ """
57
+ v, _ = torch.topk(probs, min(k, probs.size(-1)))
58
+ pivot = v.select(-1, -1).unsqueeze(-1)
59
+ probs = torch.where(probs < pivot, 0.0, probs)
60
+ probs.div_(probs.sum(dim=-1, keepdim=True))
61
+ return probs
62
+
63
+
64
+ def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
65
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
66
+
67
+ Args:
68
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
69
+ p (int): The p in “top-p”.
70
+ Returns:
71
+ torch.Tensor: Renormalized probabilities after top-p (nucleus) filtering.
72
+ """
73
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
74
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
75
+ mask = probs_sum - probs_sort > p
76
+ probs_sort *= (~mask).float()
77
+ probs = probs.scatter(-1, probs_idx, probs_sort)
78
+ probs.div_(probs.sum(dim=-1, keepdim=True))
79
+ return probs
80
+
81
+
82
+ def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
83
+ """Sample next token using min-p sampling.
84
+
85
+ Args:
86
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
87
+ min_p (float): Minimum token probability, scaled by the probability of the most likely token.
88
+ Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
89
+ Returns:
90
+ torch.Tensor: Renormalized probabilities after min-p filtering.
91
+ """
92
+ top_probs, _ = probs.max(dim=-1, keepdim=True)
93
+ tokens_to_remove = probs < (min_p * top_probs)
94
+ probs = probs.masked_fill(tokens_to_remove, 0.0)
95
+ probs.div_(probs.sum(dim=-1, keepdim=True))
96
+ return probs
97
+
98
+
99
+ def modify_logit_for_repetition_penalty(
100
+ logits: torch.Tensor,
101
+ generated_tokens: torch.Tensor,
102
+ repetition_penalty: float,
103
+ repetition_penalty_window: int,
104
+ ):
105
+ """See https://arxiv.org/abs/1909.05858
106
+ Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
107
+ logits: (batch_size, n_codebooks, vocab_size)
108
+ generated_tokens: (batch_size, n_codebooks, seq_len)
109
+ """
110
+ generated_tokens = generated_tokens[..., -repetition_penalty_window:]
111
+ generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
112
+ rp = torch.full_like(logits, repetition_penalty)
113
+ factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
114
+ return torch.where(logits <= 0, logits * factors, logits / factors)
115
+
116
+
117
+ def sample_from_logits(
118
+ logits: torch.Tensor,
119
+ temperature: float = 1.0,
120
+ top_p: float = 0.0,
121
+ top_k: int = 0,
122
+ min_p: float = 0.0,
123
+ linear: float = 0.0,
124
+ conf: float = 0.0,
125
+ quad: float = 0.0,
126
+ generated_tokens: torch.Tensor | None = None,
127
+ repetition_penalty: float = 3.0,
128
+ repetition_penalty_window: int = 2,
129
+ ) -> torch.Tensor:
130
+ """Sample next token from logits using either top_k/p/min_p OR using NovelAI's Unified Sampler.
131
+
132
+ Args:
133
+ logits (torch.Tensor): Input logits with token candidates on the last dimension.
134
+
135
+ temperature (float): Randomness of the sampling. Lower temperature results in more deterministic samples.
136
+ To disable sampling entirely, set it to 0. For NovelAI's Unified Sampler, set it to 1.0
137
+
138
+ top_p (float): Only sample from the most probable tokens whose cumulative probability is less than p.
139
+ This is called nucleus sampling. Must be between 0 and 1. Typical values are in the 0.1-0.9 range.
140
+
141
+ Set to 0 to disable.
142
+
143
+ top_k (int): Only sample from the top k most probable tokens. Set to 0 to disable.
144
+
145
+ min_p (float): Minimum token probability, scaled by the probability of the most likely token.
146
+ Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
147
+ If too high, no token might be sampled leading to silence (?)
148
+
149
+ linear (float): NovelAI's Unified Sampler -> 0.0 to 1.0, default from gradio 0.5
150
+
151
+ Set Linear between 0 and 1 according to how unusual you want tokens to be.
152
+ Lower numbers will produce more unusual/creative outputs,
153
+ but you will have to reroll or edit more.
154
+
155
+ conf (float): Confidence - Low values make random outputs more random. -> -2.0 * Quad to 2.0, default from gradio 0.4
156
+
157
+ As a starting point, set Quad = 1/3 - Linear * 4 / 15, and Conf = -Quad / 2.
158
+
159
+ quad (float): Quadratic - High values make low probabilities much lower. -> -2.0 to 2.0, default from gradio 0.0
160
+
161
+ Returns:
162
+ torch.Tensor: Sampled tokens.
163
+ """
164
+ if repetition_penalty != 1.0 and generated_tokens is not None:
165
+ logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)
166
+
167
+ if temperature > 0:
168
+ probs = torch.softmax(logits / temperature, dim=-1)
169
+ if linear > 0.0:
170
+ probs = apply_unified(probs, linear, conf, quad)
171
+ if top_p > 0:
172
+ probs = apply_top_p(probs, top_p)
173
+ if top_k > 0:
174
+ probs = apply_top_k(probs, top_k)
175
+ if min_p > 0:
176
+ probs = apply_min_p(probs, min_p)
177
+
178
+ next_token = multinomial(probs, num_samples=1)
179
+ else:
180
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
181
+
182
+ return next_token # [batch_size, num_codebooks, 1]
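
A toy sketch of `sample_from_logits`, not part of the upload, showing the default min-p path used by `generate()` alongside the unified-sampler path; the logits are random:

```python
import torch
from zonos.sampling import sample_from_logits

logits = torch.randn(1, 9, 1026)                           # [batch, n_codebooks, vocab]
tok_min_p = sample_from_logits(logits, min_p=0.1)          # matches generate()'s default sampling_params
tok_unified = sample_from_logits(logits, linear=0.5, conf=0.4, quad=0.0)
print(tok_min_p.shape, tok_unified.shape)                  # torch.Size([1, 9, 1]) for both
```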
Zonos-main/zonos/speaker_cloning.py ADDED
@@ -0,0 +1,412 @@
1
+ import math
2
+ from functools import cache
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ from zonos.utils import DEFAULT_DEVICE
11
+
12
+
13
+ class logFbankCal(nn.Module):
14
+ def __init__(
15
+ self,
16
+ sample_rate: int = 16_000,
17
+ n_fft: int = 512,
18
+ win_length: float = 0.025,
19
+ hop_length: float = 0.01,
20
+ n_mels: int = 80,
21
+ ):
22
+ super().__init__()
23
+ self.fbankCal = torchaudio.transforms.MelSpectrogram(
24
+ sample_rate=sample_rate,
25
+ n_fft=n_fft,
26
+ win_length=int(win_length * sample_rate),
27
+ hop_length=int(hop_length * sample_rate),
28
+ n_mels=n_mels,
29
+ )
30
+
31
+ def forward(self, x):
32
+ out = self.fbankCal(x)
33
+ out = torch.log(out + 1e-6)
34
+ out = out - out.mean(axis=2).unsqueeze(dim=2)
35
+ return out
36
+
37
+
38
+ class ASP(nn.Module):
39
+ # Attentive statistics pooling
40
+ def __init__(self, in_planes, acoustic_dim):
41
+ super(ASP, self).__init__()
42
+ outmap_size = int(acoustic_dim / 8)
43
+ self.out_dim = in_planes * 8 * outmap_size * 2
44
+
45
+ self.attention = nn.Sequential(
46
+ nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
47
+ nn.ReLU(),
48
+ nn.BatchNorm1d(128),
49
+ nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
50
+ nn.Softmax(dim=2),
51
+ )
52
+
53
+ def forward(self, x):
54
+ x = x.reshape(x.size()[0], -1, x.size()[-1])
55
+ w = self.attention(x)
56
+ mu = torch.sum(x * w, dim=2)
57
+ sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
58
+ x = torch.cat((mu, sg), 1)
59
+
60
+ x = x.view(x.size()[0], -1)
61
+ return x
62
+
63
+
64
+ class SimAMBasicBlock(nn.Module):
65
+ expansion = 1
66
+
67
+ def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
68
+ super(SimAMBasicBlock, self).__init__()
69
+ self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
70
+ self.bn1 = NormLayer(planes)
71
+ self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
72
+ self.bn2 = NormLayer(planes)
73
+ self.relu = nn.ReLU(inplace=True)
74
+ self.sigmoid = nn.Sigmoid()
75
+
76
+ self.downsample = nn.Sequential()
77
+ if stride != 1 or in_planes != self.expansion * planes:
78
+ self.downsample = nn.Sequential(
79
+ ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
80
+ NormLayer(self.expansion * planes),
81
+ )
82
+
83
+ def forward(self, x):
84
+ out = self.relu(self.bn1(self.conv1(x)))
85
+ out = self.bn2(self.conv2(out))
86
+ out = self.SimAM(out)
87
+ out += self.downsample(x)
88
+ out = self.relu(out)
89
+ return out
90
+
91
+ def SimAM(self, X, lambda_p=1e-4):
92
+ n = X.shape[2] * X.shape[3] - 1
93
+ d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
94
+ v = d.sum(dim=[2, 3], keepdim=True) / n
95
+ E_inv = d / (4 * (v + lambda_p)) + 0.5
96
+ return X * self.sigmoid(E_inv)
97
+
98
+
99
+ class BasicBlock(nn.Module):
100
+ expansion = 1
101
+
102
+ def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
103
+ super(BasicBlock, self).__init__()
104
+ self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
105
+ self.bn1 = NormLayer(planes)
106
+ self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
107
+ self.bn2 = NormLayer(planes)
108
+ self.relu = nn.ReLU(inplace=True)
109
+
110
+ self.downsample = nn.Sequential()
111
+ if stride != 1 or in_planes != self.expansion * planes:
112
+ self.downsample = nn.Sequential(
113
+ ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
114
+ NormLayer(self.expansion * planes),
115
+ )
116
+
117
+ def forward(self, x):
118
+ out = self.relu(self.bn1(self.conv1(x)))
119
+ out = self.bn2(self.conv2(out))
120
+ out += self.downsample(x)
121
+ out = self.relu(out)
122
+ return out
123
+
124
+
125
+ class Bottleneck(nn.Module):
126
+ expansion = 4
127
+
128
+ def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
129
+ super(Bottleneck, self).__init__()
130
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
131
+ self.bn1 = nn.BatchNorm2d(planes)
132
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
133
+ self.bn2 = nn.BatchNorm2d(planes)
134
+ self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
135
+ self.bn3 = nn.BatchNorm2d(self.expansion * planes)
136
+
137
+ self.shortcut = nn.Sequential()
138
+ if stride != 1 or in_planes != self.expansion * planes:
139
+ self.shortcut = nn.Sequential(
140
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
141
+ nn.BatchNorm2d(self.expansion * planes),
142
+ )
143
+
144
+ def forward(self, x):
145
+ out = F.relu(self.bn1(self.conv1(x)))
146
+ out = F.relu(self.bn2(self.conv2(out)))
147
+ out = self.bn3(self.conv3(out))
148
+ out += self.shortcut(x)
149
+ out = F.relu(out)
150
+ return out
151
+
152
+
153
+ class ResNet(nn.Module):
154
+ def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
155
+ super(ResNet, self).__init__()
156
+ if feat_dim == "1d":
157
+ self.NormLayer = nn.BatchNorm1d
158
+ self.ConvLayer = nn.Conv1d
159
+ elif feat_dim == "2d":
160
+ self.NormLayer = nn.BatchNorm2d
161
+ self.ConvLayer = nn.Conv2d
162
+ elif feat_dim == "3d":
163
+ self.NormLayer = nn.BatchNorm3d
164
+ self.ConvLayer = nn.Conv3d
165
+ else:
166
+ print("error")
167
+
168
+ self.in_planes = in_planes
169
+
170
+ self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
171
+ self.bn1 = self.NormLayer(in_planes)
172
+ self.relu = nn.ReLU(inplace=True)
173
+ self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
174
+ self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
175
+ self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
176
+ self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)
177
+
178
+ def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
179
+ strides = [stride] + [1] * (num_blocks - 1)
180
+ layers = []
181
+ for stride in strides:
182
+ layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
183
+ self.in_planes = planes * block.expansion
184
+ return nn.Sequential(*layers)
185
+
186
+ def forward(self, x):
187
+ x = self.relu(self.bn1(self.conv1(x)))
188
+ x = self.layer1(x)
189
+ x = self.layer2(x)
190
+ x = self.layer3(x)
191
+ x = self.layer4(x)
192
+ return x
193
+
194
+
195
+ def ResNet293(in_planes: int, **kwargs):
196
+ return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)
197
+
198
+
199
+ class ResNet293_based(nn.Module):
200
+ def __init__(
201
+ self,
202
+ in_planes: int = 64,
203
+ embd_dim: int = 256,
204
+ acoustic_dim: int = 80,
205
+ featCal=None,
206
+ dropout: float = 0,
207
+ **kwargs,
208
+ ):
209
+ super(ResNet293_based, self).__init__()
210
+ self.featCal = featCal
211
+ self.front = ResNet293(in_planes)
212
+ block_expansion = SimAMBasicBlock.expansion
213
+ self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
214
+ self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
215
+ self.drop = nn.Dropout(dropout) if dropout else None
216
+
217
+ def forward(self, x):
218
+ x = self.featCal(x)
219
+ x = self.front(x.unsqueeze(dim=1))
220
+ x = self.pooling(x)
221
+ if self.drop:
222
+ x = self.drop(x)
223
+ x = self.bottleneck(x)
224
+ return x
225
+
226
+
227
+ class SEModule(nn.Module):
228
+ def __init__(self, channels, bottleneck=128):
229
+ super(SEModule, self).__init__()
230
+ self.se = nn.Sequential(
231
+ nn.AdaptiveAvgPool1d(1),
232
+ nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
233
+ nn.ReLU(),
234
+ # nn.BatchNorm1d(bottleneck), # Removed
235
+ nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
236
+ nn.Sigmoid(),
237
+ )
238
+
239
+ def forward(self, input):
240
+ x = self.se(input)
241
+ return input * x
242
+
243
+
244
+ class Bottle2neck(nn.Module):
245
+ def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
246
+ super(Bottle2neck, self).__init__()
247
+ width = int(math.floor(planes / scale))
248
+ self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
249
+ self.bn1 = nn.BatchNorm1d(width * scale)
250
+ self.nums = scale - 1
251
+ convs = []
252
+ bns = []
253
+ num_pad = math.floor(kernel_size / 2) * dilation
254
+ for i in range(self.nums):
255
+ convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
256
+ bns.append(nn.BatchNorm1d(width))
257
+ self.convs = nn.ModuleList(convs)
258
+ self.bns = nn.ModuleList(bns)
259
+ self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
260
+ self.bn3 = nn.BatchNorm1d(planes)
261
+ self.relu = nn.ReLU()
262
+ self.width = width
263
+ self.se = SEModule(planes)
264
+
265
+ def forward(self, x):
266
+ residual = x
267
+ out = self.conv1(x)
268
+ out = self.relu(out)
269
+ out = self.bn1(out)
270
+
271
+ spx = torch.split(out, self.width, 1)
272
+ for i in range(self.nums):
273
+ if i == 0:
274
+ sp = spx[i]
275
+ else:
276
+ sp = sp + spx[i]
277
+ sp = self.convs[i](sp)
278
+ sp = self.relu(sp)
279
+ sp = self.bns[i](sp)
280
+ if i == 0:
281
+ out = sp
282
+ else:
283
+ out = torch.cat((out, sp), 1)
284
+ out = torch.cat((out, spx[self.nums]), 1)
285
+
286
+ out = self.conv3(out)
287
+ out = self.relu(out)
288
+ out = self.bn3(out)
289
+
290
+ out = self.se(out)
291
+ out += residual
292
+ return out
293
+
294
+
295
+ class ECAPA_TDNN(nn.Module):
296
+ def __init__(self, C, featCal):
297
+ super(ECAPA_TDNN, self).__init__()
298
+ self.featCal = featCal
299
+ self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
300
+ self.relu = nn.ReLU()
301
+ self.bn1 = nn.BatchNorm1d(C)
302
+ self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
303
+ self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
304
+ self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
305
+ # I fixed the shape of the output from the MFA layer; this is close to the setting in the ECAPA paper.
306
+ self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
307
+ self.attention = nn.Sequential(
308
+ nn.Conv1d(4608, 256, kernel_size=1),
309
+ nn.ReLU(),
310
+ nn.BatchNorm1d(256),
311
+ nn.Tanh(), # Added
312
+ nn.Conv1d(256, 1536, kernel_size=1),
313
+ nn.Softmax(dim=2),
314
+ )
315
+ self.bn5 = nn.BatchNorm1d(3072)
316
+ self.fc6 = nn.Linear(3072, 192)
317
+ self.bn6 = nn.BatchNorm1d(192)
318
+
319
+ def forward(self, x):
320
+ x = self.featCal(x)
321
+ x = self.conv1(x)
322
+ x = self.relu(x)
323
+ x = self.bn1(x)
324
+
325
+ x1 = self.layer1(x)
326
+ x2 = self.layer2(x + x1)
327
+ x3 = self.layer3(x + x1 + x2)
328
+
329
+ x = self.layer4(torch.cat((x1, x2, x3), dim=1))
330
+ x = self.relu(x)
331
+
332
+ t = x.size()[-1]
333
+
334
+ global_x = torch.cat(
335
+ (
336
+ x,
337
+ torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
338
+ torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
339
+ ),
340
+ dim=1,
341
+ )
342
+
343
+ w = self.attention(global_x)
344
+
345
+ mu = torch.sum(x * w, dim=2)
346
+ sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))
347
+
348
+ x = torch.cat((mu, sg), 1)
349
+ x = self.bn5(x)
350
+ x = self.fc6(x)
351
+ x = self.bn6(x)
352
+
353
+ return x
354
+
355
+
356
+ class SpeakerEmbedding(nn.Module):
357
+ def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = DEFAULT_DEVICE):
358
+ super().__init__()
359
+ self.device = device
360
+ with torch.device(device):
361
+ self.model = ResNet293_based()
362
+ state_dict = torch.load(ckpt_path, weights_only=True, mmap=True, map_location="cpu")
363
+ self.model.load_state_dict(state_dict)
364
+ self.model.featCal = logFbankCal()
365
+
366
+ self.requires_grad_(False).eval()
367
+
368
+ @property
369
+ def dtype(self):
370
+ return next(self.parameters()).dtype
371
+
372
+ @cache
373
+ def _get_resampler(self, orig_sample_rate: int):
374
+ return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)
375
+
376
+ def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
377
+ assert wav.ndim < 3
378
+ if wav.ndim == 2:
379
+ wav = wav.mean(0, keepdim=True)
380
+ wav = self._get_resampler(sample_rate)(wav)
381
+ return wav
382
+
383
+ def forward(self, wav: torch.Tensor, sample_rate: int):
384
+ wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
385
+ return self.model(wav).to(wav.device)
386
+
387
+
388
+ class SpeakerEmbeddingLDA(nn.Module):
389
+ def __init__(self, device: str = DEFAULT_DEVICE):
390
+ super().__init__()
391
+ spk_model_path = hf_hub_download(
392
+ repo_id="Zyphra/Zonos-v0.1-speaker-embedding",
393
+ filename="ResNet293_SimAM_ASP_base.pt",
394
+ )
395
+ lda_spk_model_path = hf_hub_download(
396
+ repo_id="Zyphra/Zonos-v0.1-speaker-embedding",
397
+ filename="ResNet293_SimAM_ASP_base_LDA-128.pt",
398
+ )
399
+
400
+ self.device = device
401
+ with torch.device(device):
402
+ self.model = SpeakerEmbedding(spk_model_path, device)
403
+ lda_sd = torch.load(lda_spk_model_path, weights_only=True)
404
+ out_features, in_features = lda_sd["weight"].shape
405
+ self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
406
+ self.lda.load_state_dict(lda_sd)
407
+
408
+ self.requires_grad_(False).eval()
409
+
410
+ def forward(self, wav: torch.Tensor, sample_rate: int):
411
+ emb = self.model(wav, sample_rate).to(torch.float32)
412
+ return emb, self.lda(emb)
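
A usage sketch for the speaker-embedding stack, not part of the upload; `reference.wav` is a placeholder, and the printed shapes assume the default 256-dim base embedding with the 128-dim LDA projection:

```python
import torchaudio
from zonos.speaker_cloning import SpeakerEmbeddingLDA

spk_model = SpeakerEmbeddingLDA()              # downloads both checkpoints from the HF hub
wav, sr = torchaudio.load("reference.wav")     # mono or stereo; resampled to 16 kHz internally
emb, lda_emb = spk_model(wav, sr)
print(emb.shape, lda_emb.shape)                # e.g. torch.Size([1, 256]) torch.Size([1, 128])
```

This is the pair that `Zonos.make_speaker_embedding` wraps; the LDA output is what feeds the 128-dim speaker PassthroughConditioner.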
Zonos-main/zonos/utils.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def find_multiple(n: int, k: int) -> int:
7
+ if k == 0 or n % k == 0:
8
+ return n
9
+ return n + k - (n % k)
10
+
11
+
12
+ def pad_weight_(w: nn.Embedding | nn.Linear, multiple: int):
13
+ """Pad the weight of an embedding or linear layer to a multiple of `multiple`."""
14
+ if isinstance(w, nn.Embedding):
15
+ # Pad input dim
16
+ if w.weight.shape[1] % multiple == 0:
17
+ return
18
+ w.weight.data = F.pad(w.weight.data, (0, 0, 0, w.weight.shape[1] % multiple))
19
+ w.num_embeddings, w.embedding_dim = w.weight.shape
20
+ elif isinstance(w, nn.Linear):
21
+ # Pad output dim
22
+ if w.weight.shape[0] % multiple == 0:
23
+ return
24
+ w.weight.data = F.pad(w.weight.data, (0, 0, 0, w.weight.shape[0] % multiple))
25
+ w.out_features, w.in_features = w.weight.shape
26
+ else:
27
+ raise ValueError(f"Unsupported weight type: {type(w)}")
28
+
29
+
30
+ def get_device() -> torch.device:
31
+ if torch.cuda.is_available():
32
+ return torch.device(torch.cuda.current_device())
33
+ # MPS breaks for whatever reason. Uncomment when it's working.
34
+ # if torch.mps.is_available():
35
+ # return torch.device("mps")
36
+ return torch.device("cpu")
37
+
38
+
39
+ DEFAULT_DEVICE = get_device()
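
A tiny illustration of `find_multiple`, not part of the upload; `setup_cache` uses it to round the KV-cache length up to a multiple of 8:

```python
from zonos.utils import find_multiple

print(find_multiple(2581, 8))  # 2584
print(find_multiple(2584, 8))  # 2584 (already a multiple)
```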