Upload 27 files
Browse filesPodcaster Dynamic site using Zonos-v0.1, a leading open-weight text-to-speech model trained on more than 200k hours of varied multilingual speech, delivering expressiveness and quality, for TTS podcast creation.
- .gitattributes +2 -0
- Zonos-main/.DS_Store +0 -0
- Zonos-main/.gitignore +13 -0
- Zonos-main/.python-version +1 -0
- Zonos-main/CONDITIONING_README.md +120 -0
- Zonos-main/Dockerfile +11 -0
- Zonos-main/LICENSE +202 -0
- Zonos-main/README.md +159 -0
- Zonos-main/assets/ArchitectureDiagram.png +3 -0
- Zonos-main/assets/ZonosHeader.png +0 -0
- Zonos-main/assets/exampleaudio.mp3 +3 -0
- Zonos-main/assets/silence_100ms.wav +0 -0
- Zonos-main/docker-compose.yml +16 -0
- Zonos-main/gradio_interface.py +419 -0
- Zonos-main/pyproject.toml +39 -0
- Zonos-main/sample.py +21 -0
- Zonos-main/uv.lock +0 -0
- Zonos-main/zonos/autoencoder.py +27 -0
- Zonos-main/zonos/backbone/__init__.py +12 -0
- Zonos-main/zonos/backbone/_mamba_ssm.py +57 -0
- Zonos-main/zonos/backbone/_torch.py +152 -0
- Zonos-main/zonos/codebook_pattern.py +12 -0
- Zonos-main/zonos/conditioning.py +405 -0
- Zonos-main/zonos/config.py +62 -0
- Zonos-main/zonos/model.py +315 -0
- Zonos-main/zonos/sampling.py +182 -0
- Zonos-main/zonos/speaker_cloning.py +412 -0
- Zonos-main/zonos/utils.py +39 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
Zonos-main/assets/ArchitectureDiagram.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
Zonos-main/assets/exampleaudio.mp3 filter=lfs diff=lfs merge=lfs -text
|
Zonos-main/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Zonos-main/.gitignore
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python-generated files
|
2 |
+
__pycache__/
|
3 |
+
*.py[oc]
|
4 |
+
build/
|
5 |
+
dist/
|
6 |
+
wheels/
|
7 |
+
*.egg-info
|
8 |
+
|
9 |
+
# Virtual environments
|
10 |
+
.venv
|
11 |
+
|
12 |
+
# Misc.
|
13 |
+
.ipynb_checkpoints/
|
Zonos-main/.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.12
|
Zonos-main/CONDITIONING_README.md
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Conditioning explanations
|
2 |
+
Here we will list out all the conditionings the model accepts as well as a short description and some tips for optimal use. For conditionings with a learned unconditional, they can be set to that to allow the model to infer an appropriate setting.
|
3 |
+
### espeak
|
4 |
+
- **Type:** `EspeakPhonemeConditioner`
|
5 |
+
- **Description:**
|
6 |
+
Responsible for cleaning, phonemicizing, tokenizing, and embedding the text provided to the model. This is the text pre-processing pipeline. If you would like to change how a word is pronounced or enter raw phonemes you can do that here.
|
7 |
+
|
8 |
+
Supported by transformer and hybrid models.
|
9 |
+
---
|
10 |
+
### speaker
|
11 |
+
- **Type:** `PassthroughConditioner`
|
12 |
+
- **Attributes:**
|
13 |
+
- **cond_dim:** `128`
|
14 |
+
- **uncond_type:** `learned`
|
15 |
+
- **projection:** `linear`
|
16 |
+
- **Description:**
|
17 |
+
An embedded representation of the speakers voice. We use [these](https://huggingface.co/Zyphra/Zonos-v0.1-speaker-embedding) speaker embedding models. It can capture a surprising amount of detail from the reference clip and supports arbitrary length input. Try to input clean reference clips containing only speech. It can be valid to concatenate multiple clean samples from the same speaker into one long sample and may lead to better cloning. If the speaker clip is very long, it is advisable to cut out long speech-free background music segments if they exist. If the reference clip is yielding noisy outputs with denoising enabled we recommend doing source separation before cloning.
|
18 |
+
|
19 |
+
Supported by transformer and hybrid models.
|
20 |
+
---
|
21 |
+
### emotion
|
22 |
+
- **Type:** `FourierConditioner`
|
23 |
+
- **Attributes:**
|
24 |
+
- **input_dim:** `8`
|
25 |
+
- **uncond_type:** `learned`
|
26 |
+
- **Description:**
|
27 |
+
Encodes emotion in an 8D vector. Included emotions are Happiness, Sadness, Disgust, Fear, Surprise, Anger, Other, Neutral in that order. This vector tends to be entangled with various other conditioning inputs. More notably, it's entangled with text based on the text sentiment (eg. Angry texts will be more effectively conditioned to be angry, but if you try to make it sound sad it will be a lot less effective). It's also entangled with pitch standard deviation since larger values there tend to correlate to more emotional utterances. It's also heavily correlated with VQScore and DNSMOS as these conditionings favor neutral speech. It's also possible to do a form of "negative prompting" by doing CFG where the unconditional branch is set to a highly neutral emotion vector instead of the true unconditional value, doing this will exaggerate the emotions as it pushes the model away from being neutral.
|
28 |
+
|
29 |
+
Supported by transformer and hybrid models.
|
30 |
+
---
|
31 |
+
### fmax
|
32 |
+
- **Type:** `FourierConditioner`
|
33 |
+
- **Attributes:**
|
34 |
+
- **min_val:** `0`
|
35 |
+
- **max_val:** `24000`
|
36 |
+
- **uncond_type:** `learned`
|
37 |
+
- **Description:**
|
38 |
+
Specifies the max frequency of the audio. For best results select 22050 or 24000 as these correspond to 44.1 and 48KHz audio respectively. They should not be any different in terms of actual max frequency since the model's sampling rate is 44.1KHz but they represent different slices of data which lead to slightly different voicing. Selecting a lower value generally produces lower-quality results both in terms of acoustics and voicing.
|
39 |
+
|
40 |
+
For voice cloning it is recommended to use 22050.
|
41 |
+
|
42 |
+
Supported by transformer and hybrid models.
|
43 |
+
---
|
44 |
+
### pitch_std
|
45 |
+
- **Type:** `FourierConditioner`
|
46 |
+
- **Attributes:**
|
47 |
+
- **min_val:** `0`
|
48 |
+
- **max_val:** `400`
|
49 |
+
- **uncond_type:** `learned`
|
50 |
+
- **Description:**
|
51 |
+
Specifies the standard deviation of the pitch of the output audio. Wider variations of pitch tend to be more correlated with expressive speech. Good values are from 20-45 for normal speech and 60-150 for expressive speech. Higher than that generally tend to be crazier samples.
|
52 |
+
|
53 |
+
Supported by transformer and hybrid models.
|
54 |
+
---
|
55 |
+
### speaking_rate
|
56 |
+
- **Type:** `FourierConditioner`
|
57 |
+
- **Attributes:**
|
58 |
+
- **min_val:** `0`
|
59 |
+
- **max_val:** `40`
|
60 |
+
- **uncond_type:** `learned`
|
61 |
+
- **Description:**
|
62 |
+
Specifies the number of phonemes to be read per second. When entering a long text, it is advisable to adjust the speaking rate such that the number of phonemes is readable within the generation length. For example, if your generation length is 10 seconds, and your input is 300 phonemes, you would want either 30 phonemes per second (which is very very fast) or to generate a longer sample. The model's maximum is 30 seconds. Please note that unrealistic speaking rates can be OOD for the model and create undesirable effects, so at the 30-second limit, it can be better to cut the text short and do multiple generations than to feed the model the entire prompt and have an unrealistically low speaking rate.
|
63 |
+
|
64 |
+
Supported by transformer and hybrid models.
|
65 |
+
---
|
66 |
+
### language_id
|
67 |
+
- **Type:** `IntegerConditioner`
|
68 |
+
- **Attributes:**
|
69 |
+
- **min_val:** `-1`
|
70 |
+
- **max_val:** `126`
|
71 |
+
- **uncond_type:** `learned`
|
72 |
+
- **Description:**
|
73 |
+
Indicates which language the output should be in. A mapping for these values can be found in the [conditioning section](https://github.com/Zyphra/Zonos/blob/3807c8e04bd4beaadb9502b3df1ffa4b0350e3f7/zonos/conditioning.py#L308C1-L376C21) of Zonos.
|
74 |
+
|
75 |
+
Supported by transformer and hybrid models.
|
76 |
+
---
|
77 |
+
### vqscore_8
|
78 |
+
- **Type:** `FourierConditioner`
|
79 |
+
- **Attributes:**
|
80 |
+
- **input_dim:** `8`
|
81 |
+
- **min_val:** `0.5`
|
82 |
+
- **max_val:** `0.8`
|
83 |
+
- **uncond_type:** `learned`
|
84 |
+
- **Description:**
|
85 |
+
Encodes the desired [VQScore](https://github.com/JasonSWFu/VQscore) value for the output audio. VQScore is an unsupervised speech quality (cleanliness) estimation method that we found has superior generalization and reduced biases compared to supervised methods like DNSMOS. A good value for our model is 0.78 for high-quality speech. The eight dimensions correspond to consecutive 1/8th chunks of the audio. (eg. for an 8-second output, the first dimension represents the quality of the first second only). For inference, we generally set all 8 dimensions to the same value. This has an unfortunately strong correlation with expressiveness, so for expressive speech, we recommend setting it to unconditional.
|
86 |
+
|
87 |
+
Only applicable for the hybrid model.
|
88 |
+
---
|
89 |
+
### ctc_loss
|
90 |
+
- **Type:** `FourierConditioner`
|
91 |
+
- **Attributes:**
|
92 |
+
- **min_val:** `-1.0`
|
93 |
+
- **max_val:** `1000`
|
94 |
+
- **uncond_type:** `learned`
|
95 |
+
- **Description:**
|
96 |
+
Encodes loss values from a [CTC](https://en.wikipedia.org/wiki/Connectionist_temporal_classification) (Connectionist Temporal Classification) setup, this indicates how well the training-time transcription matched with the audio according to a CTC model. For inference always use low values (eg. 0.0 or 1.0)
|
97 |
+
|
98 |
+
Only applicable for the hybrid model.
|
99 |
+
---
|
100 |
+
### dnsmos_ovrl
|
101 |
+
- **Type:** `FourierConditioner`
|
102 |
+
- **Attributes:**
|
103 |
+
- **min_val:** `1`
|
104 |
+
- **max_val:** `5`
|
105 |
+
- **uncond_type:** `learned`
|
106 |
+
- **Description:**
|
107 |
+
A [MOS](https://arxiv.org/abs/2110.01763) score for the output audio. This is similar to VQScore and tends to have a stronger entanglement with emotions. It additionally has a strong entanglement with languages. Set to 4.0 for very clean and neutral English speech, else we recommend setting it to unconditional.
|
108 |
+
|
109 |
+
Only applicable for the hybrid model.
|
110 |
+
---
|
111 |
+
### speaker_noised
|
112 |
+
- **Type:** `IntegerConditioner`
|
113 |
+
- **Attributes:**
|
114 |
+
- **min_val:** `0`
|
115 |
+
- **max_val:** `1`
|
116 |
+
- **uncond_type:** `learned`
|
117 |
+
- **Description:**
|
118 |
+
Indicates if the speaker embedding is noisy or not. If checked this lets the model clean (denoise) the input speaker embedding. When this is set to True, VQScore and DNSMOS will have a lot more power to clean the speaker embedding, so for very noisy input samples we recommend setting this to True and specifying a high VQScore value. If your speaker cloning outputs sound echo-y or do weird things, setting this to True will help.
|
119 |
+
|
120 |
+
Only applicable for the hybrid model.
|
Zonos-main/Dockerfile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
|
2 |
+
RUN pip install uv
|
3 |
+
|
4 |
+
RUN apt update && \
|
5 |
+
apt install -y espeak-ng && \
|
6 |
+
rm -rf /var/lib/apt/lists/*
|
7 |
+
|
8 |
+
WORKDIR /app
|
9 |
+
COPY . ./
|
10 |
+
|
11 |
+
RUN uv pip install --system -e . && uv pip install --system -e .[compile]
|
Zonos-main/LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
Zonos-main/README.md
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Zonos-v0.1
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
<img src="assets/ZonosHeader.png"
|
5 |
+
alt="Alt text"
|
6 |
+
style="width: 500px;
|
7 |
+
height: auto;
|
8 |
+
object-position: center top;">
|
9 |
+
</div>
|
10 |
+
|
11 |
+
<div align="center">
|
12 |
+
<a href="https://discord.gg/gTW9JwST8q" target="_blank">
|
13 |
+
<img src="https://img.shields.io/badge/Join%20Our%20Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white" alt="Discord">
|
14 |
+
</a>
|
15 |
+
</div>
|
16 |
+
|
17 |
+
---
|
18 |
+
|
19 |
+
Zonos-v0.1 is a leading open-weight text-to-speech model trained on more than 200k hours of varied multilingual speech, delivering expressiveness and quality on par with—or even surpassing—top TTS providers.
|
20 |
+
|
21 |
+
Our model enables highly natural speech generation from text prompts when given a speaker embedding or audio prefix, and can accurately perform speech cloning when given a reference clip spanning just a few seconds. The conditioning setup also allows for fine control over speaking rate, pitch variation, audio quality, and emotions such as happiness, fear, sadness, and anger. The model outputs speech natively at 44kHz.
|
22 |
+
|
23 |
+
##### For more details and speech samples, check out our blog [here](https://www.zyphra.com/post/beta-release-of-zonos-v0-1)
|
24 |
+
|
25 |
+
##### We also have a hosted version available at [playground.zyphra.com/audio](https://playground.zyphra.com/audio)
|
26 |
+
|
27 |
+
---
|
28 |
+
|
29 |
+
Zonos follows a straightforward architecture: text normalization and phonemization via eSpeak, followed by DAC token prediction through a transformer or hybrid backbone. An overview of the architecture can be seen below.
|
30 |
+
|
31 |
+
<div align="center">
|
32 |
+
<img src="assets/ArchitectureDiagram.png"
|
33 |
+
alt="Alt text"
|
34 |
+
style="width: 1000px;
|
35 |
+
height: auto;
|
36 |
+
object-position: center top;">
|
37 |
+
</div>
|
38 |
+
|
39 |
+
---
|
40 |
+
|
41 |
+
## Usage
|
42 |
+
|
43 |
+
### Python
|
44 |
+
|
45 |
+
```python
|
46 |
+
import torch
|
47 |
+
import torchaudio
|
48 |
+
from zonos.model import Zonos
|
49 |
+
from zonos.conditioning import make_cond_dict
|
50 |
+
from zonos.utils import DEFAULT_DEVICE as device
|
51 |
+
|
52 |
+
# model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device=device)
|
53 |
+
model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
|
54 |
+
|
55 |
+
wav, sampling_rate = torchaudio.load("assets/exampleaudio.mp3")
|
56 |
+
speaker = model.make_speaker_embedding(wav, sampling_rate)
|
57 |
+
|
58 |
+
cond_dict = make_cond_dict(text="Hello, world!", speaker=speaker, language="en-us")
|
59 |
+
conditioning = model.prepare_conditioning(cond_dict)
|
60 |
+
|
61 |
+
codes = model.generate(conditioning)
|
62 |
+
|
63 |
+
wavs = model.autoencoder.decode(codes).cpu()
|
64 |
+
torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)
|
65 |
+
```
|
66 |
+
|
67 |
+
### Gradio interface (recommended)
|
68 |
+
|
69 |
+
```bash
|
70 |
+
uv run gradio_interface.py
|
71 |
+
# python gradio_interface.py
|
72 |
+
```
|
73 |
+
|
74 |
+
This should produce a `sample.wav` file in your project root directory.
|
75 |
+
|
76 |
+
_For repeated sampling we highly recommend using the gradio interface instead, as the minimal example needs to load the model every time it is run._
|
77 |
+
|
78 |
+
## Features
|
79 |
+
|
80 |
+
- Zero-shot TTS with voice cloning: Input desired text and a 10-30s speaker sample to generate high quality TTS output
|
81 |
+
- Audio prefix inputs: Add text plus an audio prefix for even richer speaker matching. Audio prefixes can be used to elicit behaviours such as whispering which can otherwise be challenging to replicate when cloning from speaker embeddings
|
82 |
+
- Multilingual support: Zonos-v0.1 supports English, Japanese, Chinese, French, and German
|
83 |
+
- Audio quality and emotion control: Zonos offers fine-grained control of many aspects of the generated audio. These include speaking rate, pitch, maximum frequency, audio quality, and various emotions such as happiness, anger, sadness, and fear.
|
84 |
+
- Fast: our model runs with a real-time factor of ~2x on an RTX 4090 (i.e. generates 2 seconds of audio per 1 second of compute time)
|
85 |
+
- Gradio WebUI: Zonos comes packaged with an easy to use gradio interface to generate speech
|
86 |
+
- Simple installation and deployment: Zonos can be installed and deployed simply using the docker file packaged with our repository.
|
87 |
+
|
88 |
+
## Installation
|
89 |
+
|
90 |
+
#### System requirements
|
91 |
+
|
92 |
+
- **Operating System:** Linux (preferably Ubuntu 22.04/24.04), macOS
|
93 |
+
- **GPU:** 6GB+ VRAM, Hybrid additionally requires a 3000-series or newer Nvidia GPU
|
94 |
+
|
95 |
+
Note: Zonos can also run on CPU provided there is enough free RAM. However, this will be a lot slower than running on a dedicated GPU, and likely won't be sufficient for interactive use.
|
96 |
+
|
97 |
+
For experimental windows support check out [this fork](https://github.com/sdbds/Zonos-for-windows).
|
98 |
+
|
99 |
+
See also [Docker Installation](#docker-installation)
|
100 |
+
|
101 |
+
#### System dependencies
|
102 |
+
|
103 |
+
Zonos depends on the eSpeak library phonemization. You can install it on Ubuntu with the following command:
|
104 |
+
|
105 |
+
```bash
|
106 |
+
apt install -y espeak-ng # For Ubuntu
|
107 |
+
# brew install espeak-ng # For MacOS
|
108 |
+
```
|
109 |
+
|
110 |
+
#### Python dependencies
|
111 |
+
|
112 |
+
We highly recommend using a recent version of [uv](https://docs.astral.sh/uv/#installation) for installation. If you don't have uv installed, you can install it via pip: `pip install -U uv`.
|
113 |
+
|
114 |
+
##### Installing into a new uv virtual environment (recommended)
|
115 |
+
|
116 |
+
```bash
|
117 |
+
uv sync
|
118 |
+
uv sync --extra compile # optional but needed to run the hybrid
|
119 |
+
uv pip install -e .
|
120 |
+
```
|
121 |
+
|
122 |
+
##### Installing into the system/actived environment using uv
|
123 |
+
|
124 |
+
```bash
|
125 |
+
uv pip install -e .
|
126 |
+
uv pip install -e .[compile] # optional but needed to run the hybrid
|
127 |
+
```
|
128 |
+
|
129 |
+
##### Installing into the system/actived environment using pip
|
130 |
+
|
131 |
+
```bash
|
132 |
+
pip install -e .
|
133 |
+
pip install --no-build-isolation -e .[compile] # optional but needed to run the hybrid
|
134 |
+
```
|
135 |
+
|
136 |
+
##### Confirm that it's working
|
137 |
+
|
138 |
+
For convenience we provide a minimal example to check that the installation works:
|
139 |
+
|
140 |
+
```bash
|
141 |
+
uv run sample.py
|
142 |
+
# python sample.py
|
143 |
+
```
|
144 |
+
|
145 |
+
## Docker installation
|
146 |
+
|
147 |
+
```bash
|
148 |
+
git clone https://github.com/Zyphra/Zonos.git
|
149 |
+
cd Zonos
|
150 |
+
|
151 |
+
# For gradio
|
152 |
+
docker compose up
|
153 |
+
|
154 |
+
# Or for development you can do
|
155 |
+
docker build -t zonos .
|
156 |
+
docker run -it --gpus=all --net=host -v /path/to/Zonos:/Zonos -t zonos
|
157 |
+
cd /Zonos
|
158 |
+
python sample.py # this will generate a sample.wav in /Zonos
|
159 |
+
```
|
Zonos-main/assets/ArchitectureDiagram.png
ADDED
![]() |
Git LFS Details
|
Zonos-main/assets/ZonosHeader.png
ADDED
![]() |
Zonos-main/assets/exampleaudio.mp3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c0931115d43cd4e672dc137d2099c0b4e103d7207a2a42e957e41b2af30e4ae
|
3 |
+
size 819820
|
Zonos-main/assets/silence_100ms.wav
ADDED
Binary file (9.43 kB). View file
|
|
Zonos-main/docker-compose.yml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: '3.8'
|
2 |
+
|
3 |
+
services:
|
4 |
+
zonos:
|
5 |
+
build:
|
6 |
+
context: .
|
7 |
+
dockerfile: Dockerfile
|
8 |
+
container_name: zonos_container
|
9 |
+
runtime: nvidia
|
10 |
+
network_mode: "host"
|
11 |
+
stdin_open: true
|
12 |
+
tty: true
|
13 |
+
command: ["python3", "gradio_interface.py"]
|
14 |
+
environment:
|
15 |
+
- NVIDIA_VISIBLE_DEVICES=0
|
16 |
+
- GRADIO_SHARE=False
|
Zonos-main/gradio_interface.py
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import gradio as gr
|
4 |
+
from os import getenv
|
5 |
+
|
6 |
+
from zonos.model import Zonos, DEFAULT_BACKBONE_CLS as ZonosBackbone
|
7 |
+
from zonos.conditioning import make_cond_dict, supported_language_codes
|
8 |
+
from zonos.utils import DEFAULT_DEVICE as device
|
9 |
+
|
10 |
+
CURRENT_MODEL_TYPE = None
|
11 |
+
CURRENT_MODEL = None
|
12 |
+
|
13 |
+
SPEAKER_EMBEDDING = None
|
14 |
+
SPEAKER_AUDIO_PATH = None
|
15 |
+
|
16 |
+
|
17 |
+
def load_model_if_needed(model_choice: str):
|
18 |
+
global CURRENT_MODEL_TYPE, CURRENT_MODEL
|
19 |
+
if CURRENT_MODEL_TYPE != model_choice:
|
20 |
+
if CURRENT_MODEL is not None:
|
21 |
+
del CURRENT_MODEL
|
22 |
+
torch.cuda.empty_cache()
|
23 |
+
print(f"Loading {model_choice} model...")
|
24 |
+
CURRENT_MODEL = Zonos.from_pretrained(model_choice, device=device)
|
25 |
+
CURRENT_MODEL.requires_grad_(False).eval()
|
26 |
+
CURRENT_MODEL_TYPE = model_choice
|
27 |
+
print(f"{model_choice} model loaded successfully!")
|
28 |
+
return CURRENT_MODEL
|
29 |
+
|
30 |
+
|
31 |
+
def update_ui(model_choice):
|
32 |
+
"""
|
33 |
+
Dynamically show/hide UI elements based on the model's conditioners.
|
34 |
+
We do NOT display 'language_id' or 'ctc_loss' even if they exist in the model.
|
35 |
+
"""
|
36 |
+
model = load_model_if_needed(model_choice)
|
37 |
+
cond_names = [c.name for c in model.prefix_conditioner.conditioners]
|
38 |
+
print("Conditioners in this model:", cond_names)
|
39 |
+
|
40 |
+
text_update = gr.update(visible=("espeak" in cond_names))
|
41 |
+
language_update = gr.update(visible=("espeak" in cond_names))
|
42 |
+
speaker_audio_update = gr.update(visible=("speaker" in cond_names))
|
43 |
+
prefix_audio_update = gr.update(visible=True)
|
44 |
+
emotion1_update = gr.update(visible=("emotion" in cond_names))
|
45 |
+
emotion2_update = gr.update(visible=("emotion" in cond_names))
|
46 |
+
emotion3_update = gr.update(visible=("emotion" in cond_names))
|
47 |
+
emotion4_update = gr.update(visible=("emotion" in cond_names))
|
48 |
+
emotion5_update = gr.update(visible=("emotion" in cond_names))
|
49 |
+
emotion6_update = gr.update(visible=("emotion" in cond_names))
|
50 |
+
emotion7_update = gr.update(visible=("emotion" in cond_names))
|
51 |
+
emotion8_update = gr.update(visible=("emotion" in cond_names))
|
52 |
+
vq_single_slider_update = gr.update(visible=("vqscore_8" in cond_names))
|
53 |
+
fmax_slider_update = gr.update(visible=("fmax" in cond_names))
|
54 |
+
pitch_std_slider_update = gr.update(visible=("pitch_std" in cond_names))
|
55 |
+
speaking_rate_slider_update = gr.update(visible=("speaking_rate" in cond_names))
|
56 |
+
dnsmos_slider_update = gr.update(visible=("dnsmos_ovrl" in cond_names))
|
57 |
+
speaker_noised_checkbox_update = gr.update(visible=("speaker_noised" in cond_names))
|
58 |
+
unconditional_keys_update = gr.update(
|
59 |
+
choices=[name for name in cond_names if name not in ("espeak", "language_id")]
|
60 |
+
)
|
61 |
+
|
62 |
+
return (
|
63 |
+
text_update,
|
64 |
+
language_update,
|
65 |
+
speaker_audio_update,
|
66 |
+
prefix_audio_update,
|
67 |
+
emotion1_update,
|
68 |
+
emotion2_update,
|
69 |
+
emotion3_update,
|
70 |
+
emotion4_update,
|
71 |
+
emotion5_update,
|
72 |
+
emotion6_update,
|
73 |
+
emotion7_update,
|
74 |
+
emotion8_update,
|
75 |
+
vq_single_slider_update,
|
76 |
+
fmax_slider_update,
|
77 |
+
pitch_std_slider_update,
|
78 |
+
speaking_rate_slider_update,
|
79 |
+
dnsmos_slider_update,
|
80 |
+
speaker_noised_checkbox_update,
|
81 |
+
unconditional_keys_update,
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
def generate_audio(
|
86 |
+
model_choice,
|
87 |
+
text,
|
88 |
+
language,
|
89 |
+
speaker_audio,
|
90 |
+
prefix_audio,
|
91 |
+
e1,
|
92 |
+
e2,
|
93 |
+
e3,
|
94 |
+
e4,
|
95 |
+
e5,
|
96 |
+
e6,
|
97 |
+
e7,
|
98 |
+
e8,
|
99 |
+
vq_single,
|
100 |
+
fmax,
|
101 |
+
pitch_std,
|
102 |
+
speaking_rate,
|
103 |
+
dnsmos_ovrl,
|
104 |
+
speaker_noised,
|
105 |
+
cfg_scale,
|
106 |
+
top_p,
|
107 |
+
top_k,
|
108 |
+
min_p,
|
109 |
+
linear,
|
110 |
+
confidence,
|
111 |
+
quadratic,
|
112 |
+
seed,
|
113 |
+
randomize_seed,
|
114 |
+
unconditional_keys,
|
115 |
+
progress=gr.Progress(),
|
116 |
+
):
|
117 |
+
"""
|
118 |
+
Generates audio based on the provided UI parameters.
|
119 |
+
We do NOT use language_id or ctc_loss even if the model has them.
|
120 |
+
"""
|
121 |
+
selected_model = load_model_if_needed(model_choice)
|
122 |
+
|
123 |
+
speaker_noised_bool = bool(speaker_noised)
|
124 |
+
fmax = float(fmax)
|
125 |
+
pitch_std = float(pitch_std)
|
126 |
+
speaking_rate = float(speaking_rate)
|
127 |
+
dnsmos_ovrl = float(dnsmos_ovrl)
|
128 |
+
cfg_scale = float(cfg_scale)
|
129 |
+
top_p = float(top_p)
|
130 |
+
top_k = int(top_k)
|
131 |
+
min_p = float(min_p)
|
132 |
+
linear = float(linear)
|
133 |
+
confidence = float(confidence)
|
134 |
+
quadratic = float(quadratic)
|
135 |
+
seed = int(seed)
|
136 |
+
max_new_tokens = 86 * 30
|
137 |
+
|
138 |
+
# This is a bit ew, but works for now.
|
139 |
+
global SPEAKER_AUDIO_PATH, SPEAKER_EMBEDDING
|
140 |
+
|
141 |
+
if randomize_seed:
|
142 |
+
seed = torch.randint(0, 2**32 - 1, (1,)).item()
|
143 |
+
torch.manual_seed(seed)
|
144 |
+
|
145 |
+
if speaker_audio is not None and "speaker" not in unconditional_keys:
|
146 |
+
if speaker_audio != SPEAKER_AUDIO_PATH:
|
147 |
+
print("Recomputed speaker embedding")
|
148 |
+
wav, sr = torchaudio.load(speaker_audio)
|
149 |
+
SPEAKER_EMBEDDING = selected_model.make_speaker_embedding(wav, sr)
|
150 |
+
SPEAKER_EMBEDDING = SPEAKER_EMBEDDING.to(device, dtype=torch.bfloat16)
|
151 |
+
SPEAKER_AUDIO_PATH = speaker_audio
|
152 |
+
|
153 |
+
audio_prefix_codes = None
|
154 |
+
if prefix_audio is not None:
|
155 |
+
wav_prefix, sr_prefix = torchaudio.load(prefix_audio)
|
156 |
+
wav_prefix = wav_prefix.mean(0, keepdim=True)
|
157 |
+
wav_prefix = selected_model.autoencoder.preprocess(wav_prefix, sr_prefix)
|
158 |
+
wav_prefix = wav_prefix.to(device, dtype=torch.float32)
|
159 |
+
audio_prefix_codes = selected_model.autoencoder.encode(wav_prefix.unsqueeze(0))
|
160 |
+
|
161 |
+
emotion_tensor = torch.tensor(list(map(float, [e1, e2, e3, e4, e5, e6, e7, e8])), device=device)
|
162 |
+
|
163 |
+
vq_val = float(vq_single)
|
164 |
+
vq_tensor = torch.tensor([vq_val] * 8, device=device).unsqueeze(0)
|
165 |
+
|
166 |
+
cond_dict = make_cond_dict(
|
167 |
+
text=text,
|
168 |
+
language=language,
|
169 |
+
speaker=SPEAKER_EMBEDDING,
|
170 |
+
emotion=emotion_tensor,
|
171 |
+
vqscore_8=vq_tensor,
|
172 |
+
fmax=fmax,
|
173 |
+
pitch_std=pitch_std,
|
174 |
+
speaking_rate=speaking_rate,
|
175 |
+
dnsmos_ovrl=dnsmos_ovrl,
|
176 |
+
speaker_noised=speaker_noised_bool,
|
177 |
+
device=device,
|
178 |
+
unconditional_keys=unconditional_keys,
|
179 |
+
)
|
180 |
+
conditioning = selected_model.prepare_conditioning(cond_dict)
|
181 |
+
|
182 |
+
estimated_generation_duration = 30 * len(text) / 400
|
183 |
+
estimated_total_steps = int(estimated_generation_duration * 86)
|
184 |
+
|
185 |
+
def update_progress(_frame: torch.Tensor, step: int, _total_steps: int) -> bool:
|
186 |
+
progress((step, estimated_total_steps))
|
187 |
+
return True
|
188 |
+
|
189 |
+
codes = selected_model.generate(
|
190 |
+
prefix_conditioning=conditioning,
|
191 |
+
audio_prefix_codes=audio_prefix_codes,
|
192 |
+
max_new_tokens=max_new_tokens,
|
193 |
+
cfg_scale=cfg_scale,
|
194 |
+
batch_size=1,
|
195 |
+
sampling_params=dict(top_p=top_p, top_k=top_k, min_p=min_p, linear=linear, conf=confidence, quad=quadratic),
|
196 |
+
callback=update_progress,
|
197 |
+
)
|
198 |
+
|
199 |
+
wav_out = selected_model.autoencoder.decode(codes).cpu().detach()
|
200 |
+
sr_out = selected_model.autoencoder.sampling_rate
|
201 |
+
if wav_out.dim() == 2 and wav_out.size(0) > 1:
|
202 |
+
wav_out = wav_out[0:1, :]
|
203 |
+
return (sr_out, wav_out.squeeze().numpy()), seed
|
204 |
+
|
205 |
+
|
206 |
+
def build_interface():
|
207 |
+
supported_models = []
|
208 |
+
if "transformer" in ZonosBackbone.supported_architectures:
|
209 |
+
supported_models.append("Zyphra/Zonos-v0.1-transformer")
|
210 |
+
|
211 |
+
if "hybrid" in ZonosBackbone.supported_architectures:
|
212 |
+
supported_models.append("Zyphra/Zonos-v0.1-hybrid")
|
213 |
+
else:
|
214 |
+
print(
|
215 |
+
"| The current ZonosBackbone does not support the hybrid architecture, meaning only the transformer model will be available in the model selector.\n"
|
216 |
+
"| This probably means the mamba-ssm library has not been installed."
|
217 |
+
)
|
218 |
+
|
219 |
+
with gr.Blocks() as demo:
|
220 |
+
with gr.Row():
|
221 |
+
with gr.Column():
|
222 |
+
model_choice = gr.Dropdown(
|
223 |
+
choices=supported_models,
|
224 |
+
value=supported_models[0],
|
225 |
+
label="Zonos Model Type",
|
226 |
+
info="Select the model variant to use.",
|
227 |
+
)
|
228 |
+
text = gr.Textbox(
|
229 |
+
label="Text to Synthesize",
|
230 |
+
value="Zonos uses eSpeak for text to phoneme conversion!",
|
231 |
+
lines=4,
|
232 |
+
max_length=500, # approximately
|
233 |
+
)
|
234 |
+
language = gr.Dropdown(
|
235 |
+
choices=supported_language_codes,
|
236 |
+
value="en-us",
|
237 |
+
label="Language Code",
|
238 |
+
info="Select a language code.",
|
239 |
+
)
|
240 |
+
prefix_audio = gr.Audio(
|
241 |
+
value="assets/silence_100ms.wav",
|
242 |
+
label="Optional Prefix Audio (continue from this audio)",
|
243 |
+
type="filepath",
|
244 |
+
)
|
245 |
+
with gr.Column():
|
246 |
+
speaker_audio = gr.Audio(
|
247 |
+
label="Optional Speaker Audio (for cloning)",
|
248 |
+
type="filepath",
|
249 |
+
)
|
250 |
+
speaker_noised_checkbox = gr.Checkbox(label="Denoise Speaker?", value=False)
|
251 |
+
|
252 |
+
with gr.Row():
|
253 |
+
with gr.Column():
|
254 |
+
gr.Markdown("## Conditioning Parameters")
|
255 |
+
dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="DNSMOS Overall")
|
256 |
+
fmax_slider = gr.Slider(0, 24000, value=24000, step=1, label="Fmax (Hz)")
|
257 |
+
vq_single_slider = gr.Slider(0.5, 0.8, 0.78, 0.01, label="VQ Score")
|
258 |
+
pitch_std_slider = gr.Slider(0.0, 300.0, value=45.0, step=1, label="Pitch Std")
|
259 |
+
speaking_rate_slider = gr.Slider(5.0, 30.0, value=15.0, step=0.5, label="Speaking Rate")
|
260 |
+
|
261 |
+
with gr.Column():
|
262 |
+
gr.Markdown("## Generation Parameters")
|
263 |
+
cfg_scale_slider = gr.Slider(1.0, 5.0, 2.0, 0.1, label="CFG Scale")
|
264 |
+
seed_number = gr.Number(label="Seed", value=420, precision=0)
|
265 |
+
randomize_seed_toggle = gr.Checkbox(label="Randomize Seed (before generation)", value=True)
|
266 |
+
|
267 |
+
with gr.Accordion("Sampling", open=False):
|
268 |
+
with gr.Row():
|
269 |
+
with gr.Column():
|
270 |
+
gr.Markdown("### NovelAi's unified sampler")
|
271 |
+
linear_slider = gr.Slider(-2.0, 2.0, 0.5, 0.01, label="Linear (set to 0 to disable unified sampling)", info="High values make the output less random.")
|
272 |
+
#Conf's theoretical range is between -2 * Quad and 0.
|
273 |
+
confidence_slider = gr.Slider(-2.0, 2.0, 0.40, 0.01, label="Confidence", info="Low values make random outputs more random.")
|
274 |
+
quadratic_slider = gr.Slider(-2.0, 2.0, 0.00, 0.01, label="Quadratic", info="High values make low probablities much lower.")
|
275 |
+
with gr.Column():
|
276 |
+
gr.Markdown("### Legacy sampling")
|
277 |
+
top_p_slider = gr.Slider(0.0, 1.0, 0, 0.01, label="Top P")
|
278 |
+
min_k_slider = gr.Slider(0.0, 1024, 0, 1, label="Min K")
|
279 |
+
min_p_slider = gr.Slider(0.0, 1.0, 0, 0.01, label="Min P")
|
280 |
+
|
281 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
282 |
+
gr.Markdown(
|
283 |
+
"### Unconditional Toggles\n"
|
284 |
+
"Checking a box will make the model ignore the corresponding conditioning value and make it unconditional.\n"
|
285 |
+
'Practically this means the given conditioning feature will be unconstrained and "filled in automatically".'
|
286 |
+
)
|
287 |
+
with gr.Row():
|
288 |
+
unconditional_keys = gr.CheckboxGroup(
|
289 |
+
[
|
290 |
+
"speaker",
|
291 |
+
"emotion",
|
292 |
+
"vqscore_8",
|
293 |
+
"fmax",
|
294 |
+
"pitch_std",
|
295 |
+
"speaking_rate",
|
296 |
+
"dnsmos_ovrl",
|
297 |
+
"speaker_noised",
|
298 |
+
],
|
299 |
+
value=["emotion"],
|
300 |
+
label="Unconditional Keys",
|
301 |
+
)
|
302 |
+
|
303 |
+
gr.Markdown(
|
304 |
+
"### Emotion Sliders\n"
|
305 |
+
"Warning: The way these sliders work is not intuitive and may require some trial and error to get the desired effect.\n"
|
306 |
+
"Certain configurations can cause the model to become unstable. Setting emotion to unconditional may help."
|
307 |
+
)
|
308 |
+
with gr.Row():
|
309 |
+
emotion1 = gr.Slider(0.0, 1.0, 1.0, 0.05, label="Happiness")
|
310 |
+
emotion2 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Sadness")
|
311 |
+
emotion3 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Disgust")
|
312 |
+
emotion4 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Fear")
|
313 |
+
with gr.Row():
|
314 |
+
emotion5 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Surprise")
|
315 |
+
emotion6 = gr.Slider(0.0, 1.0, 0.05, 0.05, label="Anger")
|
316 |
+
emotion7 = gr.Slider(0.0, 1.0, 0.1, 0.05, label="Other")
|
317 |
+
emotion8 = gr.Slider(0.0, 1.0, 0.2, 0.05, label="Neutral")
|
318 |
+
|
319 |
+
with gr.Column():
|
320 |
+
generate_button = gr.Button("Generate Audio")
|
321 |
+
output_audio = gr.Audio(label="Generated Audio", type="numpy", autoplay=True)
|
322 |
+
|
323 |
+
model_choice.change(
|
324 |
+
fn=update_ui,
|
325 |
+
inputs=[model_choice],
|
326 |
+
outputs=[
|
327 |
+
text,
|
328 |
+
language,
|
329 |
+
speaker_audio,
|
330 |
+
prefix_audio,
|
331 |
+
emotion1,
|
332 |
+
emotion2,
|
333 |
+
emotion3,
|
334 |
+
emotion4,
|
335 |
+
emotion5,
|
336 |
+
emotion6,
|
337 |
+
emotion7,
|
338 |
+
emotion8,
|
339 |
+
vq_single_slider,
|
340 |
+
fmax_slider,
|
341 |
+
pitch_std_slider,
|
342 |
+
speaking_rate_slider,
|
343 |
+
dnsmos_slider,
|
344 |
+
speaker_noised_checkbox,
|
345 |
+
unconditional_keys,
|
346 |
+
],
|
347 |
+
)
|
348 |
+
|
349 |
+
# On page load, trigger the same UI refresh
|
350 |
+
demo.load(
|
351 |
+
fn=update_ui,
|
352 |
+
inputs=[model_choice],
|
353 |
+
outputs=[
|
354 |
+
text,
|
355 |
+
language,
|
356 |
+
speaker_audio,
|
357 |
+
prefix_audio,
|
358 |
+
emotion1,
|
359 |
+
emotion2,
|
360 |
+
emotion3,
|
361 |
+
emotion4,
|
362 |
+
emotion5,
|
363 |
+
emotion6,
|
364 |
+
emotion7,
|
365 |
+
emotion8,
|
366 |
+
vq_single_slider,
|
367 |
+
fmax_slider,
|
368 |
+
pitch_std_slider,
|
369 |
+
speaking_rate_slider,
|
370 |
+
dnsmos_slider,
|
371 |
+
speaker_noised_checkbox,
|
372 |
+
unconditional_keys,
|
373 |
+
],
|
374 |
+
)
|
375 |
+
|
376 |
+
# Generate audio on button click
|
377 |
+
generate_button.click(
|
378 |
+
fn=generate_audio,
|
379 |
+
inputs=[
|
380 |
+
model_choice,
|
381 |
+
text,
|
382 |
+
language,
|
383 |
+
speaker_audio,
|
384 |
+
prefix_audio,
|
385 |
+
emotion1,
|
386 |
+
emotion2,
|
387 |
+
emotion3,
|
388 |
+
emotion4,
|
389 |
+
emotion5,
|
390 |
+
emotion6,
|
391 |
+
emotion7,
|
392 |
+
emotion8,
|
393 |
+
vq_single_slider,
|
394 |
+
fmax_slider,
|
395 |
+
pitch_std_slider,
|
396 |
+
speaking_rate_slider,
|
397 |
+
dnsmos_slider,
|
398 |
+
speaker_noised_checkbox,
|
399 |
+
cfg_scale_slider,
|
400 |
+
top_p_slider,
|
401 |
+
min_k_slider,
|
402 |
+
min_p_slider,
|
403 |
+
linear_slider,
|
404 |
+
confidence_slider,
|
405 |
+
quadratic_slider,
|
406 |
+
seed_number,
|
407 |
+
randomize_seed_toggle,
|
408 |
+
unconditional_keys,
|
409 |
+
],
|
410 |
+
outputs=[output_audio, seed_number],
|
411 |
+
)
|
412 |
+
|
413 |
+
return demo
|
414 |
+
|
415 |
+
|
416 |
+
if __name__ == "__main__":
|
417 |
+
demo = build_interface()
|
418 |
+
share = getenv("GRADIO_SHARE", "False").lower() in ("true", "1", "t")
|
419 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=share)
|
Zonos-main/pyproject.toml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "zonos"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Text-to-speech by Zyphra"
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = ">=3.10"
|
7 |
+
dependencies = [
|
8 |
+
"torch>=2.5.1",
|
9 |
+
"setuptools",
|
10 |
+
"packaging",
|
11 |
+
"inflect>=7.5.0",
|
12 |
+
"kanjize>=1.5.0",
|
13 |
+
"numpy>=2.2.2",
|
14 |
+
"phonemizer>=3.3.0",
|
15 |
+
"sudachidict-full>=20241021",
|
16 |
+
"sudachipy>=0.6.10",
|
17 |
+
"torchaudio>=2.5.1",
|
18 |
+
"transformers>=4.48.1",
|
19 |
+
"soundfile>=0.13.1",
|
20 |
+
"huggingface-hub>=0.28.1",
|
21 |
+
"gradio>=5.15.0",
|
22 |
+
]
|
23 |
+
|
24 |
+
# These are technically optional, but mamba-ssm is required to run hybrid models.
|
25 |
+
[project.optional-dependencies]
|
26 |
+
compile = [
|
27 |
+
"flash-attn>=2.7.3",
|
28 |
+
"mamba-ssm>=2.2.4",
|
29 |
+
"causal-conv1d>=1.5.0.post8",
|
30 |
+
]
|
31 |
+
|
32 |
+
[tool.setuptools.packages.find]
|
33 |
+
include = ["zonos"]
|
34 |
+
|
35 |
+
[tool.uv]
|
36 |
+
no-build-isolation-package = ["flash-attn", "mamba-ssm", "causal-conv1d"]
|
37 |
+
|
38 |
+
[tool.ruff]
|
39 |
+
line-length = 120
|
Zonos-main/sample.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
from zonos.model import Zonos
|
4 |
+
from zonos.conditioning import make_cond_dict
|
5 |
+
from zonos.utils import DEFAULT_DEVICE as device
|
6 |
+
|
7 |
+
# model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-hybrid", device=device)
|
8 |
+
model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
|
9 |
+
|
10 |
+
wav, sampling_rate = torchaudio.load("assets/exampleaudio.mp3")
|
11 |
+
speaker = model.make_speaker_embedding(wav, sampling_rate)
|
12 |
+
|
13 |
+
torch.manual_seed(421)
|
14 |
+
|
15 |
+
cond_dict = make_cond_dict(text="Hello, world!", speaker=speaker, language="en-us")
|
16 |
+
conditioning = model.prepare_conditioning(cond_dict)
|
17 |
+
|
18 |
+
codes = model.generate(conditioning)
|
19 |
+
|
20 |
+
wavs = model.autoencoder.decode(codes).cpu()
|
21 |
+
torchaudio.save("sample.wav", wavs[0], model.autoencoder.sampling_rate)
|
Zonos-main/uv.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Zonos-main/zonos/autoencoder.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
from transformers.models.dac import DacModel
|
6 |
+
|
7 |
+
|
8 |
+
class DACAutoencoder:
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.dac = DacModel.from_pretrained("descript/dac_44khz")
|
12 |
+
self.dac.eval().requires_grad_(False)
|
13 |
+
self.codebook_size = self.dac.config.codebook_size
|
14 |
+
self.num_codebooks = self.dac.quantizer.n_codebooks
|
15 |
+
self.sampling_rate = self.dac.config.sampling_rate
|
16 |
+
|
17 |
+
def preprocess(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
|
18 |
+
wav = torchaudio.functional.resample(wav, sr, 44_100)
|
19 |
+
right_pad = math.ceil(wav.shape[-1] / 512) * 512 - wav.shape[-1]
|
20 |
+
return torch.nn.functional.pad(wav, (0, right_pad))
|
21 |
+
|
22 |
+
def encode(self, wav: torch.Tensor) -> torch.Tensor:
|
23 |
+
return self.dac.encode(wav).audio_codes
|
24 |
+
|
25 |
+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
26 |
+
with torch.autocast(self.dac.device.type, torch.float16, enabled=self.dac.device.type != "cpu"):
|
27 |
+
return self.dac.decode(audio_codes=codes).audio_values.unsqueeze(1).float()
|
Zonos-main/zonos/backbone/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BACKBONES = {}
|
2 |
+
|
3 |
+
try:
|
4 |
+
from ._mamba_ssm import MambaSSMZonosBackbone
|
5 |
+
|
6 |
+
BACKBONES["mamba_ssm"] = MambaSSMZonosBackbone
|
7 |
+
except ImportError:
|
8 |
+
pass
|
9 |
+
|
10 |
+
from ._torch import TorchZonosBackbone
|
11 |
+
|
12 |
+
BACKBONES["torch"] = TorchZonosBackbone
|
Zonos-main/zonos/backbone/_mamba_ssm.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from mamba_ssm.models.mixer_seq_simple import create_block
|
4 |
+
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
|
5 |
+
|
6 |
+
from zonos.config import BackboneConfig, InferenceParams
|
7 |
+
|
8 |
+
|
9 |
+
class MambaSSMZonosBackbone(nn.Module):
|
10 |
+
supported_architectures = ["transformer", "hybrid"]
|
11 |
+
|
12 |
+
def __init__(self, config: BackboneConfig):
|
13 |
+
super().__init__()
|
14 |
+
self.config = config
|
15 |
+
|
16 |
+
self.layers = nn.ModuleList(
|
17 |
+
[
|
18 |
+
create_block(
|
19 |
+
d_model=config.d_model,
|
20 |
+
d_intermediate=config.d_intermediate
|
21 |
+
if (i not in config.attn_layer_idx)
|
22 |
+
else config.attn_mlp_d_intermediate,
|
23 |
+
ssm_cfg=config.ssm_cfg,
|
24 |
+
layer_idx=i,
|
25 |
+
attn_layer_idx=config.attn_layer_idx,
|
26 |
+
attn_cfg=config.attn_cfg,
|
27 |
+
norm_epsilon=config.norm_epsilon,
|
28 |
+
residual_in_fp32=config.residual_in_fp32,
|
29 |
+
fused_add_norm=True,
|
30 |
+
rms_norm=config.rms_norm,
|
31 |
+
)
|
32 |
+
for i in range(config.n_layer)
|
33 |
+
]
|
34 |
+
)
|
35 |
+
|
36 |
+
self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
|
37 |
+
|
38 |
+
def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
|
39 |
+
return {
|
40 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
|
41 |
+
for i, layer in enumerate(self.layers)
|
42 |
+
}
|
43 |
+
|
44 |
+
def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
|
45 |
+
residual = None
|
46 |
+
for layer in self.layers:
|
47 |
+
hidden_states, residual = layer(hidden_states, residual, inference_params)
|
48 |
+
|
49 |
+
return layer_norm_fn(
|
50 |
+
hidden_states,
|
51 |
+
self.norm_f.weight,
|
52 |
+
self.norm_f.bias,
|
53 |
+
residual,
|
54 |
+
eps=self.norm_f.eps,
|
55 |
+
residual_in_fp32=self.config.residual_in_fp32,
|
56 |
+
is_rms_norm=self.config.rms_norm,
|
57 |
+
)
|
Zonos-main/zonos/backbone/_torch.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/095b2229ee3a40e379c11f05b94bd6923db63b4b/model.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from zonos.config import BackboneConfig, InferenceParams
|
7 |
+
|
8 |
+
|
9 |
+
def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000) -> torch.Tensor:
|
10 |
+
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
|
11 |
+
t = torch.arange(seq_len, device=freqs.device)
|
12 |
+
freqs = torch.outer(t, freqs)
|
13 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
14 |
+
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
|
15 |
+
return cache
|
16 |
+
|
17 |
+
|
18 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
19 |
+
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
|
20 |
+
freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
|
21 |
+
x_out2 = torch.stack(
|
22 |
+
[
|
23 |
+
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
24 |
+
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
25 |
+
],
|
26 |
+
-1,
|
27 |
+
)
|
28 |
+
|
29 |
+
x_out2 = x_out2.flatten(3)
|
30 |
+
return x_out2.type_as(x)
|
31 |
+
|
32 |
+
|
33 |
+
def _update_kv_cache(
|
34 |
+
k: torch.Tensor, v: torch.Tensor, inference_params: InferenceParams, layer_idx: int
|
35 |
+
) -> torch.Tensor:
|
36 |
+
"""k/v: (batch_size, seqlen, nheads, head_dim) or (batch_size, 1, nheads, head_dim)"""
|
37 |
+
assert layer_idx in inference_params.key_value_memory_dict
|
38 |
+
kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
|
39 |
+
# Adjust key and value for inference
|
40 |
+
batch_start = inference_params.batch_size_offset
|
41 |
+
batch_end = batch_start + k.shape[0]
|
42 |
+
sequence_start = inference_params.seqlen_offset
|
43 |
+
sequence_end = sequence_start + k.shape[1]
|
44 |
+
assert batch_end <= kv_cache.shape[0]
|
45 |
+
assert sequence_end <= kv_cache.shape[1]
|
46 |
+
assert kv_cache is not None
|
47 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, 0, ...] = k
|
48 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, 1, ...] = v
|
49 |
+
return kv_cache[batch_start:batch_end, :sequence_end, ...]
|
50 |
+
|
51 |
+
|
52 |
+
class TorchZonosBackbone(nn.Module):
|
53 |
+
supported_architectures = ["transformer"]
|
54 |
+
freqs_cis: torch.Tensor
|
55 |
+
|
56 |
+
def __init__(self, config: BackboneConfig):
|
57 |
+
assert not config.ssm_cfg, "This backbone implementation only supports the Transformer model."
|
58 |
+
super().__init__()
|
59 |
+
self.config = config
|
60 |
+
|
61 |
+
self.layers = nn.ModuleList(TransformerBlock(config, i) for i in range(config.n_layer))
|
62 |
+
self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
|
63 |
+
|
64 |
+
def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
|
65 |
+
# TODO: This function should be pure
|
66 |
+
head_dim = self.config.d_model // self.config.attn_cfg["num_heads"]
|
67 |
+
self.freqs_cis = precompute_freqs_cis(16384, head_dim)
|
68 |
+
return {
|
69 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
|
70 |
+
for i, layer in enumerate(self.layers)
|
71 |
+
}
|
72 |
+
|
73 |
+
def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams) -> torch.Tensor:
|
74 |
+
input_pos = torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
|
75 |
+
input_pos = input_pos + inference_params.lengths_per_sample.unsqueeze(-1)
|
76 |
+
|
77 |
+
freqs_cis = self.freqs_cis[input_pos].expand(hidden_states.shape[0], -1, -1, -1)
|
78 |
+
for i, layer in enumerate(self.layers):
|
79 |
+
hidden_states = layer(hidden_states, inference_params, freqs_cis)
|
80 |
+
return self.norm_f(hidden_states)
|
81 |
+
|
82 |
+
|
83 |
+
class TransformerBlock(nn.Module):
|
84 |
+
def __init__(self, config: BackboneConfig, layer_idx: int) -> None:
|
85 |
+
super().__init__()
|
86 |
+
self.config = config
|
87 |
+
|
88 |
+
self.norm = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
|
89 |
+
self.mixer = Attention(config, layer_idx)
|
90 |
+
self.norm2 = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
|
91 |
+
self.mlp = FeedForward(config)
|
92 |
+
|
93 |
+
self.num_heads_kv = config.attn_cfg["num_heads_kv"]
|
94 |
+
self.head_dim = config.d_model // config.attn_cfg["num_heads"]
|
95 |
+
|
96 |
+
def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
|
97 |
+
return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype), None
|
98 |
+
|
99 |
+
def forward(self, x: torch.Tensor, inference_params: InferenceParams, freqs_cis: torch.Tensor) -> torch.Tensor:
|
100 |
+
x = x + self.mixer(self.norm(x), inference_params, freqs_cis)
|
101 |
+
x = x + self.mlp(self.norm2(x))
|
102 |
+
return x
|
103 |
+
|
104 |
+
|
105 |
+
class Attention(nn.Module):
|
106 |
+
def __init__(self, config: BackboneConfig, layer_idx: int):
|
107 |
+
super().__init__()
|
108 |
+
self.num_heads = config.attn_cfg["num_heads"]
|
109 |
+
self.num_heads_kv = config.attn_cfg["num_heads_kv"]
|
110 |
+
self.head_dim = config.d_model // self.num_heads
|
111 |
+
self.layer_idx = layer_idx
|
112 |
+
|
113 |
+
total_head_dim = (self.num_heads + 2 * self.num_heads_kv) * self.head_dim
|
114 |
+
self.in_proj = nn.Linear(config.d_model, total_head_dim, bias=False)
|
115 |
+
self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.d_model, bias=False)
|
116 |
+
|
117 |
+
def forward(self, x: torch.Tensor, inference_params: InferenceParams, freqs_cis: torch.Tensor) -> torch.Tensor:
|
118 |
+
batch_size, seqlen, _ = x.shape
|
119 |
+
|
120 |
+
q_size = self.num_heads * self.head_dim
|
121 |
+
kv_size = self.num_heads_kv * self.head_dim
|
122 |
+
q, k, v = self.in_proj(x).split([q_size, kv_size, kv_size], dim=-1)
|
123 |
+
|
124 |
+
q = q.view(batch_size, seqlen, self.num_heads, self.head_dim)
|
125 |
+
k = k.view(batch_size, seqlen, self.num_heads_kv, self.head_dim)
|
126 |
+
v = v.view(batch_size, seqlen, self.num_heads_kv, self.head_dim)
|
127 |
+
|
128 |
+
q = apply_rotary_emb(q, freqs_cis)
|
129 |
+
k = apply_rotary_emb(k, freqs_cis)
|
130 |
+
|
131 |
+
kv = _update_kv_cache(k, v, inference_params, self.layer_idx)
|
132 |
+
k, v = kv.unbind(dim=-3)
|
133 |
+
|
134 |
+
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
135 |
+
|
136 |
+
y = F.scaled_dot_product_attention(q, k, v, is_causal=seqlen > 1, enable_gqa=True)
|
137 |
+
|
138 |
+
y = y.transpose(1, 2).contiguous().view(batch_size, seqlen, q_size)
|
139 |
+
|
140 |
+
y = self.out_proj(y)
|
141 |
+
return y
|
142 |
+
|
143 |
+
|
144 |
+
class FeedForward(nn.Module):
|
145 |
+
def __init__(self, config: BackboneConfig) -> None:
|
146 |
+
super().__init__()
|
147 |
+
self.fc1 = nn.Linear(config.d_model, 2 * config.attn_mlp_d_intermediate, bias=False)
|
148 |
+
self.fc2 = nn.Linear(config.attn_mlp_d_intermediate, config.d_model, bias=False)
|
149 |
+
|
150 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
151 |
+
y, gate = self.fc1(x).chunk(2, dim=-1)
|
152 |
+
return self.fc2(y * F.silu(gate))
|
Zonos-main/zonos/codebook_pattern.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
|
6 |
+
codes = F.pad(codes, (0, codes.shape[1]), value=mask_token)
|
7 |
+
return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
|
8 |
+
|
9 |
+
|
10 |
+
def revert_delay_pattern(codes: torch.Tensor):
|
11 |
+
_, n_q, seq_len = codes.shape
|
12 |
+
return torch.stack([codes[:, k, k + 1 : seq_len - n_q + k + 1] for k in range(n_q)], dim=1)
|
Zonos-main/zonos/conditioning.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import cache
|
2 |
+
from typing import Any, Literal, Iterable
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from zonos.config import PrefixConditionerConfig
|
8 |
+
from zonos.utils import DEFAULT_DEVICE
|
9 |
+
|
10 |
+
|
11 |
+
class Conditioner(nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
output_dim: int,
|
15 |
+
name: str,
|
16 |
+
cond_dim: int | None = None,
|
17 |
+
projection: Literal["none", "linear", "mlp"] = "none",
|
18 |
+
uncond_type: Literal["learned", "none"] = "none",
|
19 |
+
**kwargs,
|
20 |
+
):
|
21 |
+
super().__init__()
|
22 |
+
self.name = name
|
23 |
+
self.output_dim = output_dim
|
24 |
+
self.cond_dim = cond_dim = cond_dim or output_dim
|
25 |
+
|
26 |
+
if projection == "linear":
|
27 |
+
self.project = nn.Linear(cond_dim, output_dim)
|
28 |
+
elif projection == "mlp":
|
29 |
+
self.project = nn.Sequential(
|
30 |
+
nn.Linear(cond_dim, output_dim),
|
31 |
+
nn.SiLU(),
|
32 |
+
nn.Linear(output_dim, output_dim),
|
33 |
+
)
|
34 |
+
else:
|
35 |
+
self.project = nn.Identity()
|
36 |
+
|
37 |
+
self.uncond_vector = None
|
38 |
+
if uncond_type == "learned":
|
39 |
+
self.uncond_vector = nn.Parameter(torch.zeros(output_dim))
|
40 |
+
|
41 |
+
def apply_cond(self, *inputs: Any) -> torch.Tensor:
|
42 |
+
raise NotImplementedError()
|
43 |
+
|
44 |
+
def forward(self, inputs: tuple[Any, ...] | None) -> torch.Tensor:
|
45 |
+
if inputs is None:
|
46 |
+
assert self.uncond_vector is not None
|
47 |
+
return self.uncond_vector.data.view(1, 1, -1)
|
48 |
+
|
49 |
+
cond = self.apply_cond(*inputs)
|
50 |
+
cond = self.project(cond)
|
51 |
+
return cond
|
52 |
+
|
53 |
+
|
54 |
+
# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
|
55 |
+
import os
|
56 |
+
import sys
|
57 |
+
import re
|
58 |
+
import unicodedata
|
59 |
+
|
60 |
+
import inflect
|
61 |
+
import torch
|
62 |
+
import torch.nn as nn
|
63 |
+
from kanjize import number2kanji
|
64 |
+
from phonemizer.backend import EspeakBackend
|
65 |
+
from sudachipy import Dictionary, SplitMode
|
66 |
+
|
67 |
+
if sys.platform == "darwin":
|
68 |
+
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "/opt/homebrew/lib/libespeak-ng.dylib"
|
69 |
+
|
70 |
+
# --- Number normalization code from https://github.com/daniilrobnikov/vits2/blob/main/text/normalize_numbers.py ---
|
71 |
+
|
72 |
+
_inflect = inflect.engine()
|
73 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
74 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
75 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
76 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
77 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
78 |
+
_number_re = re.compile(r"[0-9]+")
|
79 |
+
|
80 |
+
|
81 |
+
def _remove_commas(m: re.Match) -> str:
|
82 |
+
return m.group(1).replace(",", "")
|
83 |
+
|
84 |
+
|
85 |
+
def _expand_decimal_point(m: re.Match) -> str:
|
86 |
+
return m.group(1).replace(".", " point ")
|
87 |
+
|
88 |
+
|
89 |
+
def _expand_dollars(m: re.Match) -> str:
|
90 |
+
match = m.group(1)
|
91 |
+
parts = match.split(".")
|
92 |
+
if len(parts) > 2:
|
93 |
+
return match + " dollars" # Unexpected format
|
94 |
+
dollars = int(parts[0]) if parts[0] else 0
|
95 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
96 |
+
if dollars and cents:
|
97 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
98 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
99 |
+
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
100 |
+
elif dollars:
|
101 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
102 |
+
return "%s %s" % (dollars, dollar_unit)
|
103 |
+
elif cents:
|
104 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
105 |
+
return "%s %s" % (cents, cent_unit)
|
106 |
+
else:
|
107 |
+
return "zero dollars"
|
108 |
+
|
109 |
+
|
110 |
+
def _expand_ordinal(m: re.Match) -> str:
|
111 |
+
return _inflect.number_to_words(m.group(0))
|
112 |
+
|
113 |
+
|
114 |
+
def _expand_number(m: re.Match) -> str:
|
115 |
+
num = int(m.group(0))
|
116 |
+
if num > 1000 and num < 3000:
|
117 |
+
if num == 2000:
|
118 |
+
return "two thousand"
|
119 |
+
elif num > 2000 and num < 2010:
|
120 |
+
return "two thousand " + _inflect.number_to_words(num % 100)
|
121 |
+
elif num % 100 == 0:
|
122 |
+
return _inflect.number_to_words(num // 100) + " hundred"
|
123 |
+
else:
|
124 |
+
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
125 |
+
else:
|
126 |
+
return _inflect.number_to_words(num, andword="")
|
127 |
+
|
128 |
+
|
129 |
+
def normalize_numbers(text: str) -> str:
|
130 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
131 |
+
text = re.sub(_pounds_re, r"\1 pounds", text)
|
132 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
133 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
134 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
135 |
+
text = re.sub(_number_re, _expand_number, text)
|
136 |
+
return text
|
137 |
+
|
138 |
+
|
139 |
+
# --- Number normalization code end ---
|
140 |
+
|
141 |
+
|
142 |
+
PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
|
143 |
+
SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
|
144 |
+
|
145 |
+
_punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
|
146 |
+
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
147 |
+
_letters_ipa = (
|
148 |
+
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌː��ʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
149 |
+
)
|
150 |
+
|
151 |
+
symbols = [*_punctuation, *_letters, *_letters_ipa]
|
152 |
+
_symbol_to_id = {s: i for i, s in enumerate(symbols, start=len(SPECIAL_TOKEN_IDS))}
|
153 |
+
|
154 |
+
|
155 |
+
def _get_symbol_id(s: str) -> int:
|
156 |
+
return _symbol_to_id.get(s, 1)
|
157 |
+
|
158 |
+
|
159 |
+
def get_symbol_ids(text: str) -> list[int]:
|
160 |
+
return list(map(_get_symbol_id, text))
|
161 |
+
|
162 |
+
|
163 |
+
def tokenize_phonemes(phonemes: list[str]) -> tuple[torch.Tensor, list[int]]:
|
164 |
+
phoneme_ids = [[BOS_ID, *get_symbol_ids(phonemes), EOS_ID] for phonemes in phonemes]
|
165 |
+
lengths = list(map(len, phoneme_ids))
|
166 |
+
longest = max(lengths)
|
167 |
+
phoneme_ids = [[PAD_ID] * (longest - len(ids)) + ids for ids in phoneme_ids]
|
168 |
+
return torch.tensor(phoneme_ids), lengths
|
169 |
+
|
170 |
+
|
171 |
+
def normalize_jp_text(text: str, tokenizer=Dictionary(dict="full").create()) -> str:
|
172 |
+
text = unicodedata.normalize("NFKC", text)
|
173 |
+
text = re.sub(r"\d+", lambda m: number2kanji(int(m[0])), text)
|
174 |
+
final_text = " ".join([x.reading_form() for x in tokenizer.tokenize(text, SplitMode.A)])
|
175 |
+
return final_text
|
176 |
+
|
177 |
+
|
178 |
+
def clean(texts: list[str], languages: list[str]) -> list[str]:
|
179 |
+
texts_out = []
|
180 |
+
for text, language in zip(texts, languages):
|
181 |
+
if "ja" in language:
|
182 |
+
text = normalize_jp_text(text)
|
183 |
+
else:
|
184 |
+
text = normalize_numbers(text)
|
185 |
+
texts_out.append(text)
|
186 |
+
return texts_out
|
187 |
+
|
188 |
+
|
189 |
+
@cache
|
190 |
+
def get_backend(language: str) -> "EspeakBackend":
|
191 |
+
import logging
|
192 |
+
|
193 |
+
from phonemizer.backend import EspeakBackend
|
194 |
+
|
195 |
+
logger = logging.getLogger("phonemizer")
|
196 |
+
backend = EspeakBackend(
|
197 |
+
language,
|
198 |
+
preserve_punctuation=True,
|
199 |
+
with_stress=True,
|
200 |
+
punctuation_marks=_punctuation,
|
201 |
+
logger=logger,
|
202 |
+
)
|
203 |
+
logger.setLevel(logging.ERROR)
|
204 |
+
return backend
|
205 |
+
|
206 |
+
|
207 |
+
def phonemize(texts: list[str], languages: list[str]) -> list[str]:
|
208 |
+
texts = clean(texts, languages)
|
209 |
+
|
210 |
+
batch_phonemes = []
|
211 |
+
for text, language in zip(texts, languages):
|
212 |
+
backend = get_backend(language)
|
213 |
+
phonemes = backend.phonemize([text], strip=True)
|
214 |
+
batch_phonemes.append(phonemes[0])
|
215 |
+
|
216 |
+
return batch_phonemes
|
217 |
+
|
218 |
+
|
219 |
+
class EspeakPhonemeConditioner(Conditioner):
|
220 |
+
def __init__(self, output_dim: int, **kwargs):
|
221 |
+
super().__init__(output_dim, **kwargs)
|
222 |
+
self.phoneme_embedder = nn.Embedding(len(SPECIAL_TOKEN_IDS) + len(symbols), output_dim)
|
223 |
+
|
224 |
+
def apply_cond(self, texts: list[str], languages: list[str]) -> torch.Tensor:
|
225 |
+
"""
|
226 |
+
Args:
|
227 |
+
texts: list of texts to convert to phonemes
|
228 |
+
languages: ISO 639-1 -or otherwise eSpeak compatible- language code
|
229 |
+
"""
|
230 |
+
device = self.phoneme_embedder.weight.device
|
231 |
+
|
232 |
+
phonemes = phonemize(texts, languages)
|
233 |
+
phoneme_ids, _ = tokenize_phonemes(phonemes)
|
234 |
+
phoneme_embeds = self.phoneme_embedder(phoneme_ids.to(device))
|
235 |
+
|
236 |
+
return phoneme_embeds
|
237 |
+
|
238 |
+
|
239 |
+
# ------- ESPEAK CONTAINMENT ZONE ------------------------------------------------------------------------------------------------------------------------------------------------
|
240 |
+
|
241 |
+
|
242 |
+
class FourierConditioner(Conditioner):
|
243 |
+
def __init__(
|
244 |
+
self,
|
245 |
+
output_dim: int,
|
246 |
+
input_dim: int = 1,
|
247 |
+
std: float = 1.0,
|
248 |
+
min_val: float = 0.0,
|
249 |
+
max_val: float = 1.0,
|
250 |
+
**kwargs,
|
251 |
+
):
|
252 |
+
assert output_dim % 2 == 0
|
253 |
+
super().__init__(output_dim, **kwargs)
|
254 |
+
self.register_buffer("weight", torch.randn([output_dim // 2, input_dim]) * std)
|
255 |
+
self.input_dim, self.min_val, self.max_val = input_dim, min_val, max_val
|
256 |
+
|
257 |
+
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
258 |
+
assert x.shape[-1] == self.input_dim
|
259 |
+
x = (x - self.min_val) / (self.max_val - self.min_val) # [batch_size, seq_len, input_dim]
|
260 |
+
f = 2 * torch.pi * x.to(self.weight.dtype) @ self.weight.T # [batch_size, seq_len, output_dim // 2]
|
261 |
+
return torch.cat([f.cos(), f.sin()], dim=-1) # [batch_size, seq_len, output_dim]
|
262 |
+
|
263 |
+
|
264 |
+
class IntegerConditioner(Conditioner):
|
265 |
+
def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512, **kwargs):
|
266 |
+
super().__init__(output_dim, **kwargs)
|
267 |
+
self.min_val = min_val
|
268 |
+
self.max_val = max_val
|
269 |
+
self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim)
|
270 |
+
|
271 |
+
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
272 |
+
assert x.shape[-1] == 1
|
273 |
+
return self.int_embedder(x.squeeze(-1) - self.min_val) # [batch_size, seq_len, output_dim]
|
274 |
+
|
275 |
+
|
276 |
+
class PassthroughConditioner(Conditioner):
|
277 |
+
def __init__(self, output_dim: int, **kwargs):
|
278 |
+
super().__init__(output_dim, **kwargs)
|
279 |
+
|
280 |
+
def apply_cond(self, x: torch.Tensor) -> torch.Tensor:
|
281 |
+
assert x.shape[-1] == self.cond_dim
|
282 |
+
return x
|
283 |
+
|
284 |
+
|
285 |
+
_cond_cls_map = {
|
286 |
+
"PassthroughConditioner": PassthroughConditioner,
|
287 |
+
"EspeakPhonemeConditioner": EspeakPhonemeConditioner,
|
288 |
+
"FourierConditioner": FourierConditioner,
|
289 |
+
"IntegerConditioner": IntegerConditioner,
|
290 |
+
}
|
291 |
+
|
292 |
+
|
293 |
+
def build_conditioners(conditioners: list[dict], output_dim: int) -> list[Conditioner]:
|
294 |
+
return [_cond_cls_map[config["type"]](output_dim, **config) for config in conditioners]
|
295 |
+
|
296 |
+
|
297 |
+
class PrefixConditioner(Conditioner):
|
298 |
+
def __init__(self, config: PrefixConditionerConfig, output_dim: int):
|
299 |
+
super().__init__(output_dim, "prefix", projection=config.projection)
|
300 |
+
self.conditioners = nn.ModuleList(build_conditioners(config.conditioners, output_dim))
|
301 |
+
self.norm = nn.LayerNorm(output_dim)
|
302 |
+
self.required_keys = {c.name for c in self.conditioners if c.uncond_vector is None}
|
303 |
+
|
304 |
+
def forward(self, cond_dict: dict) -> torch.Tensor:
|
305 |
+
if not set(cond_dict).issuperset(self.required_keys):
|
306 |
+
raise ValueError(f"Missing required keys: {self.required_keys - set(cond_dict)}")
|
307 |
+
conds = []
|
308 |
+
for conditioner in self.conditioners:
|
309 |
+
conds.append(conditioner(cond_dict.get(conditioner.name)))
|
310 |
+
max_bsz = max(map(len, conds))
|
311 |
+
assert all(c.shape[0] in (max_bsz, 1) for c in conds)
|
312 |
+
conds = [c.expand(max_bsz, -1, -1) for c in conds]
|
313 |
+
return self.norm(self.project(torch.cat(conds, dim=-2)))
|
314 |
+
|
315 |
+
|
316 |
+
supported_language_codes = [
|
317 |
+
'af', 'am', 'an', 'ar', 'as', 'az', 'ba', 'bg', 'bn', 'bpy', 'bs', 'ca', 'cmn',
|
318 |
+
'cs', 'cy', 'da', 'de', 'el', 'en-029', 'en-gb', 'en-gb-scotland', 'en-gb-x-gbclan',
|
319 |
+
'en-gb-x-gbcwmd', 'en-gb-x-rp', 'en-us', 'eo', 'es', 'es-419', 'et', 'eu', 'fa',
|
320 |
+
'fa-latn', 'fi', 'fr-be', 'fr-ch', 'fr-fr', 'ga', 'gd', 'gn', 'grc', 'gu', 'hak',
|
321 |
+
'hi', 'hr', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'is', 'it', 'ja', 'jbo', 'ka',
|
322 |
+
'kk', 'kl', 'kn', 'ko', 'kok', 'ku', 'ky', 'la', 'lfn', 'lt', 'lv', 'mi', 'mk',
|
323 |
+
'ml', 'mr', 'ms', 'mt', 'my', 'nb', 'nci', 'ne', 'nl', 'om', 'or', 'pa', 'pap',
|
324 |
+
'pl', 'pt', 'pt-br', 'py', 'quc', 'ro', 'ru', 'ru-lv', 'sd', 'shn', 'si', 'sk',
|
325 |
+
'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'tn', 'tr', 'tt', 'ur', 'uz', 'vi',
|
326 |
+
'vi-vn-x-central', 'vi-vn-x-south', 'yue'
|
327 |
+
] # fmt: off
|
328 |
+
|
329 |
+
|
330 |
+
def make_cond_dict(
|
331 |
+
text: str = "It would be nice to have time for testing, indeed.",
|
332 |
+
language: str = "en-us",
|
333 |
+
speaker: torch.Tensor | None = None,
|
334 |
+
|
335 |
+
# Emotion vector from 0.0 to 1.0
|
336 |
+
# Is entangled with pitch_std because more emotion => more pitch variation
|
337 |
+
# VQScore and DNSMOS because they favor neutral speech
|
338 |
+
#
|
339 |
+
# Happiness, Sadness, Disgust, Fear, Surprise, Anger, Other, Neutral
|
340 |
+
emotion: list[float] = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
|
341 |
+
|
342 |
+
# Maximum frequency (0 to 24000), should be 22050 or 24000 for 44.1 or 48 kHz audio
|
343 |
+
# For voice cloning use 22050
|
344 |
+
fmax: float = 22050.0,
|
345 |
+
|
346 |
+
# Standard deviation for pitch (0 to 400), should be
|
347 |
+
# 20-45 for normal speech,
|
348 |
+
# 60-150 for expressive speech,
|
349 |
+
# higher values => crazier samples
|
350 |
+
pitch_std: float = 20.0,
|
351 |
+
|
352 |
+
# Speaking rate in phonemes per minute (0 to 40). 30 is very fast, 10 is slow.
|
353 |
+
speaking_rate: float = 15.0,
|
354 |
+
|
355 |
+
# Target VoiceQualityScore for the generated speech (0.5 to 0.8).
|
356 |
+
# A list of values must be provided which represent each 1/8th of the audio.
|
357 |
+
# You should unset for expressive speech.
|
358 |
+
# According to discord Chat this is only used for the hybrid model
|
359 |
+
vqscore_8: list[float] = [0.78] * 8,
|
360 |
+
|
361 |
+
# CTC target loss
|
362 |
+
# Only used for the hybrid model
|
363 |
+
ctc_loss: float = 0.0,
|
364 |
+
# Only used for the hybrid model
|
365 |
+
dnsmos_ovrl: float = 4.0,
|
366 |
+
# Only used for the hybrid model
|
367 |
+
speaker_noised: bool = False,
|
368 |
+
unconditional_keys: Iterable[str] = {"vqscore_8", "dnsmos_ovrl"},
|
369 |
+
device: torch.device | str = DEFAULT_DEVICE,
|
370 |
+
) -> dict:
|
371 |
+
"""
|
372 |
+
A helper to build the 'cond_dict' that the model expects.
|
373 |
+
By default, it will generate a random speaker embedding
|
374 |
+
"""
|
375 |
+
assert language.lower() in supported_language_codes, "Please pick a supported language"
|
376 |
+
|
377 |
+
language_code_to_id = {lang: i for i, lang in enumerate(supported_language_codes)}
|
378 |
+
|
379 |
+
cond_dict = {
|
380 |
+
"espeak": ([text], [language]),
|
381 |
+
"speaker": speaker,
|
382 |
+
"emotion": emotion,
|
383 |
+
"fmax": fmax,
|
384 |
+
"pitch_std": pitch_std,
|
385 |
+
"speaking_rate": speaking_rate,
|
386 |
+
"language_id": language_code_to_id[language],
|
387 |
+
"vqscore_8": vqscore_8,
|
388 |
+
"ctc_loss": ctc_loss,
|
389 |
+
"dnsmos_ovrl": dnsmos_ovrl,
|
390 |
+
"speaker_noised": int(speaker_noised),
|
391 |
+
}
|
392 |
+
|
393 |
+
for k in unconditional_keys:
|
394 |
+
cond_dict.pop(k, None)
|
395 |
+
|
396 |
+
for k, v in cond_dict.items():
|
397 |
+
if isinstance(v, (float, int, list)):
|
398 |
+
v = torch.tensor(v)
|
399 |
+
if isinstance(v, torch.Tensor):
|
400 |
+
cond_dict[k] = v.view(1, 1, -1).to(device)
|
401 |
+
|
402 |
+
if k == "emotion":
|
403 |
+
cond_dict[k] /= cond_dict[k].sum(dim=-1)
|
404 |
+
|
405 |
+
return cond_dict
|
Zonos-main/zonos/config.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Literal
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
# https://github.com/state-spaces/mamba/blob//mamba_ssm/utils/generation.py#L18
|
8 |
+
@dataclass
|
9 |
+
class InferenceParams:
|
10 |
+
"""Inference parameters that are passed to the main model in order
|
11 |
+
to efficienly calculate and store the context during inference."""
|
12 |
+
|
13 |
+
max_seqlen: int
|
14 |
+
max_batch_size: int
|
15 |
+
seqlen_offset: int = 0
|
16 |
+
batch_size_offset: int = 0
|
17 |
+
key_value_memory_dict: dict = field(default_factory=dict)
|
18 |
+
lengths_per_sample: torch.Tensor | None = None
|
19 |
+
|
20 |
+
def reset(self, max_seqlen, max_batch_size):
|
21 |
+
self.max_seqlen = max_seqlen
|
22 |
+
self.max_batch_size = max_batch_size
|
23 |
+
self.seqlen_offset = 0
|
24 |
+
if self.lengths_per_sample is not None:
|
25 |
+
self.lengths_per_sample.zero_()
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class BackboneConfig:
|
30 |
+
d_model: int = 1024
|
31 |
+
d_intermediate: int = 0
|
32 |
+
attn_mlp_d_intermediate: int = 0
|
33 |
+
n_layer: int = 16
|
34 |
+
ssm_cfg: dict = field(default_factory=dict)
|
35 |
+
attn_layer_idx: list = field(default_factory=list)
|
36 |
+
attn_cfg: dict = field(default_factory=dict)
|
37 |
+
rms_norm: bool = False
|
38 |
+
residual_in_fp32: bool = False
|
39 |
+
norm_epsilon: float = 1e-5
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class PrefixConditionerConfig:
|
44 |
+
conditioners: list[dict]
|
45 |
+
projection: Literal["none", "linear", "mlp"]
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass
|
49 |
+
class ZonosConfig:
|
50 |
+
backbone: BackboneConfig
|
51 |
+
prefix_conditioner: PrefixConditionerConfig
|
52 |
+
eos_token_id: int = 1024
|
53 |
+
masked_token_id: int = 1025
|
54 |
+
pad_vocab_to_multiple_of: int = 8
|
55 |
+
|
56 |
+
@classmethod
|
57 |
+
def from_dict(cls, d: dict) -> "ZonosConfig":
|
58 |
+
d = d.copy()
|
59 |
+
backbone_config = BackboneConfig(**d.pop("backbone"))
|
60 |
+
prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
|
61 |
+
config = cls(backbone_config, prefix_conditioner_config, **d)
|
62 |
+
return config
|
Zonos-main/zonos/model.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
import safetensors
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from zonos.autoencoder import DACAutoencoder
|
11 |
+
from zonos.backbone import BACKBONES
|
12 |
+
from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern
|
13 |
+
from zonos.conditioning import PrefixConditioner
|
14 |
+
from zonos.config import InferenceParams, ZonosConfig
|
15 |
+
from zonos.sampling import sample_from_logits
|
16 |
+
from zonos.speaker_cloning import SpeakerEmbeddingLDA
|
17 |
+
from zonos.utils import DEFAULT_DEVICE, find_multiple, pad_weight_
|
18 |
+
|
19 |
+
DEFAULT_BACKBONE_CLS = next(iter(BACKBONES.values()))
|
20 |
+
|
21 |
+
|
22 |
+
class Zonos(nn.Module):
|
23 |
+
def __init__(self, config: ZonosConfig, backbone_cls=DEFAULT_BACKBONE_CLS):
|
24 |
+
super().__init__()
|
25 |
+
self.config = config
|
26 |
+
dim = config.backbone.d_model
|
27 |
+
self.eos_token_id = config.eos_token_id
|
28 |
+
self.masked_token_id = config.masked_token_id
|
29 |
+
|
30 |
+
self.autoencoder = DACAutoencoder()
|
31 |
+
self.backbone = backbone_cls(config.backbone)
|
32 |
+
self.prefix_conditioner = PrefixConditioner(config.prefix_conditioner, dim)
|
33 |
+
self.spk_clone_model = None
|
34 |
+
|
35 |
+
# TODO: pad to multiple of at least 8
|
36 |
+
self.embeddings = nn.ModuleList([nn.Embedding(1026, dim) for _ in range(self.autoencoder.num_codebooks)])
|
37 |
+
self.heads = nn.ModuleList([nn.Linear(dim, 1025, bias=False) for _ in range(self.autoencoder.num_codebooks)])
|
38 |
+
|
39 |
+
self._cg_graph = None
|
40 |
+
self._cg_batch_size = None
|
41 |
+
self._cg_input_ids = None
|
42 |
+
self._cg_logits = None
|
43 |
+
self._cg_inference_params = None
|
44 |
+
self._cg_scale = None
|
45 |
+
|
46 |
+
if config.pad_vocab_to_multiple_of:
|
47 |
+
self.register_load_state_dict_post_hook(self._pad_embeddings_and_heads)
|
48 |
+
|
49 |
+
def _pad_embeddings_and_heads(self, *args, **kwargs):
|
50 |
+
for w in [*self.embeddings, *self.heads]:
|
51 |
+
pad_weight_(w, self.config.pad_vocab_to_multiple_of)
|
52 |
+
|
53 |
+
@property
|
54 |
+
def device(self) -> torch.device:
|
55 |
+
return next(self.parameters()).device
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def from_pretrained(
|
59 |
+
cls, repo_id: str, revision: str | None = None, device: str = DEFAULT_DEVICE, **kwargs
|
60 |
+
) -> "Zonos":
|
61 |
+
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
62 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
63 |
+
return cls.from_local(config_path, model_path, device, **kwargs)
|
64 |
+
|
65 |
+
@classmethod
|
66 |
+
def from_local(
|
67 |
+
cls, config_path: str, model_path: str, device: str = DEFAULT_DEVICE, backbone: str | None = None
|
68 |
+
) -> "Zonos":
|
69 |
+
config = ZonosConfig.from_dict(json.load(open(config_path)))
|
70 |
+
if backbone:
|
71 |
+
backbone_cls = BACKBONES[backbone]
|
72 |
+
else:
|
73 |
+
is_transformer = not bool(config.backbone.ssm_cfg)
|
74 |
+
backbone_cls = DEFAULT_BACKBONE_CLS
|
75 |
+
# Preferentially route to pure torch backbone for increased performance and lower latency.
|
76 |
+
if is_transformer and "torch" in BACKBONES:
|
77 |
+
backbone_cls = BACKBONES["torch"]
|
78 |
+
|
79 |
+
model = cls(config, backbone_cls).to(device, torch.bfloat16)
|
80 |
+
model.autoencoder.dac.to(device)
|
81 |
+
|
82 |
+
sd = model.state_dict()
|
83 |
+
with safetensors.safe_open(model_path, framework="pt") as f:
|
84 |
+
for k in f.keys():
|
85 |
+
sd[k] = f.get_tensor(k)
|
86 |
+
model.load_state_dict(sd)
|
87 |
+
|
88 |
+
return model
|
89 |
+
|
90 |
+
def make_speaker_embedding(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
|
91 |
+
"""Generate a speaker embedding from an audio clip."""
|
92 |
+
if self.spk_clone_model is None:
|
93 |
+
self.spk_clone_model = SpeakerEmbeddingLDA()
|
94 |
+
_, spk_embedding = self.spk_clone_model(wav.to(self.spk_clone_model.device), sr)
|
95 |
+
return spk_embedding.unsqueeze(0).bfloat16()
|
96 |
+
|
97 |
+
def embed_codes(self, codes: torch.Tensor) -> torch.Tensor:
|
98 |
+
return sum(emb(codes[:, i]) for i, emb in enumerate(self.embeddings))
|
99 |
+
|
100 |
+
def apply_heads(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
101 |
+
return torch.stack([head(hidden_states) for head in self.heads], dim=1)
|
102 |
+
|
103 |
+
def _compute_logits(
|
104 |
+
self, hidden_states: torch.Tensor, inference_params: InferenceParams, cfg_scale: float
|
105 |
+
) -> torch.Tensor:
|
106 |
+
"""
|
107 |
+
Pass `hidden_states` into `backbone` and `multi_head`, applying
|
108 |
+
classifier-free guidance if `cfg_scale != 1.0`.
|
109 |
+
"""
|
110 |
+
last_hidden_states = self.backbone(hidden_states, inference_params)[:, -1, :].unsqueeze(1)
|
111 |
+
logits = self.apply_heads(last_hidden_states).squeeze(2).float()
|
112 |
+
if cfg_scale != 1.0:
|
113 |
+
cond_logits, uncond_logits = logits.chunk(2)
|
114 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|
115 |
+
logits[..., 1025:].fill_(-torch.inf) # ensures padding is ignored
|
116 |
+
return logits
|
117 |
+
|
118 |
+
def _decode_one_token(
|
119 |
+
self,
|
120 |
+
input_ids: torch.Tensor,
|
121 |
+
inference_params: InferenceParams,
|
122 |
+
cfg_scale: float,
|
123 |
+
allow_cudagraphs: bool = True,
|
124 |
+
) -> torch.Tensor:
|
125 |
+
"""
|
126 |
+
Single-step decode. Prepares the hidden states, possibly replicates them
|
127 |
+
for CFG, and then delegates to `_compute_logits`.
|
128 |
+
|
129 |
+
Below we wrap this function with a simple CUDA Graph capturing mechanism,
|
130 |
+
doing 3 warmup steps if needed and then capturing or replaying the graph.
|
131 |
+
We only recapture if the batch size changes.
|
132 |
+
"""
|
133 |
+
# TODO: support cfg_scale==1
|
134 |
+
if cfg_scale == 1.0:
|
135 |
+
hidden_states = self.embed_codes(input_ids)
|
136 |
+
return self._compute_logits(hidden_states, inference_params, cfg_scale)
|
137 |
+
|
138 |
+
bsz = input_ids.size(0)
|
139 |
+
|
140 |
+
if not allow_cudagraphs or input_ids.device.type != "cuda":
|
141 |
+
hidden_states_local = self.embed_codes(input_ids)
|
142 |
+
hidden_states_local = hidden_states_local.repeat(2, 1, 1)
|
143 |
+
return self._compute_logits(hidden_states_local, inference_params, cfg_scale)
|
144 |
+
|
145 |
+
need_capture = (self._cg_graph is None) or (self._cg_batch_size != bsz)
|
146 |
+
|
147 |
+
if need_capture:
|
148 |
+
self._cg_graph = None
|
149 |
+
|
150 |
+
self._cg_batch_size = bsz
|
151 |
+
self._cg_inference_params = inference_params
|
152 |
+
self._cg_scale = cfg_scale
|
153 |
+
|
154 |
+
for _ in range(3):
|
155 |
+
hidden_states = self.embed_codes(input_ids)
|
156 |
+
hidden_states = hidden_states.repeat(2, 1, 1) # because cfg != 1.0
|
157 |
+
logits = self._compute_logits(hidden_states, inference_params, cfg_scale)
|
158 |
+
|
159 |
+
self._cg_input_ids = input_ids.clone()
|
160 |
+
self._cg_logits = torch.empty_like(logits)
|
161 |
+
|
162 |
+
g = torch.cuda.CUDAGraph()
|
163 |
+
|
164 |
+
def capture_region():
|
165 |
+
hidden_states_local = self.embed_codes(self._cg_input_ids)
|
166 |
+
hidden_states_local = hidden_states_local.repeat(2, 1, 1)
|
167 |
+
self._cg_logits = self._compute_logits(hidden_states_local, self._cg_inference_params, self._cg_scale)
|
168 |
+
|
169 |
+
with torch.cuda.graph(g):
|
170 |
+
capture_region()
|
171 |
+
|
172 |
+
self._cg_graph = g
|
173 |
+
|
174 |
+
else:
|
175 |
+
self._cg_input_ids.copy_(input_ids)
|
176 |
+
|
177 |
+
self._cg_graph.replay()
|
178 |
+
|
179 |
+
return self._cg_logits
|
180 |
+
|
181 |
+
def _prefill(
|
182 |
+
self,
|
183 |
+
prefix_hidden_states: torch.Tensor,
|
184 |
+
input_ids: torch.Tensor,
|
185 |
+
inference_params: InferenceParams,
|
186 |
+
cfg_scale: float,
|
187 |
+
) -> torch.Tensor:
|
188 |
+
"""
|
189 |
+
"Prefill" mode: we already have `prefix_hidden_states`, and we want
|
190 |
+
to append new embeddings, then compute the logits.
|
191 |
+
"""
|
192 |
+
# Replicate input_ids if CFG is enabled
|
193 |
+
if cfg_scale != 1.0:
|
194 |
+
input_ids = input_ids.expand(prefix_hidden_states.shape[0], -1, -1)
|
195 |
+
hidden_states = torch.cat([prefix_hidden_states, self.embed_codes(input_ids)], dim=1)
|
196 |
+
return self._compute_logits(hidden_states, inference_params, cfg_scale)
|
197 |
+
|
198 |
+
def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
|
199 |
+
max_seqlen = find_multiple(max_seqlen, 8)
|
200 |
+
key_value_memory_dict = self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
|
201 |
+
lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32)
|
202 |
+
return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
|
203 |
+
|
204 |
+
def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
|
205 |
+
if uncond_dict is None:
|
206 |
+
uncond_dict = {k: cond_dict[k] for k in self.prefix_conditioner.required_keys}
|
207 |
+
return torch.cat(
|
208 |
+
[
|
209 |
+
self.prefix_conditioner(cond_dict),
|
210 |
+
self.prefix_conditioner(uncond_dict),
|
211 |
+
]
|
212 |
+
)
|
213 |
+
|
214 |
+
def can_use_cudagraphs(self) -> bool:
|
215 |
+
# Only the mamba-ssm backbone supports CUDA Graphs at the moment
|
216 |
+
return self.device.type == "cuda" and "_mamba_ssm" in str(self.backbone.__class__)
|
217 |
+
|
218 |
+
@torch.inference_mode()
|
219 |
+
def generate(
|
220 |
+
self,
|
221 |
+
prefix_conditioning: torch.Tensor, # [bsz, cond_seq_len, d_model]
|
222 |
+
audio_prefix_codes: torch.Tensor | None = None, # [bsz, 9, prefix_audio_seq_len]
|
223 |
+
max_new_tokens: int = 86 * 30,
|
224 |
+
cfg_scale: float = 2.0,
|
225 |
+
batch_size: int = 1,
|
226 |
+
sampling_params: dict = dict(min_p=0.1),
|
227 |
+
progress_bar: bool = True,
|
228 |
+
disable_torch_compile: bool = False,
|
229 |
+
callback: Callable[[torch.Tensor, int, int], bool] | None = None,
|
230 |
+
):
|
231 |
+
assert cfg_scale != 1, "TODO: add support for cfg_scale=1"
|
232 |
+
prefix_audio_len = 0 if audio_prefix_codes is None else audio_prefix_codes.shape[2]
|
233 |
+
device = self.device
|
234 |
+
|
235 |
+
# Use CUDA Graphs if supported, and torch.compile otherwise.
|
236 |
+
cg = self.can_use_cudagraphs()
|
237 |
+
decode_one_token = self._decode_one_token
|
238 |
+
decode_one_token = torch.compile(decode_one_token, dynamic=True, disable=cg or disable_torch_compile)
|
239 |
+
|
240 |
+
unknown_token = -1
|
241 |
+
audio_seq_len = prefix_audio_len + max_new_tokens
|
242 |
+
seq_len = prefix_conditioning.shape[1] + audio_seq_len + 9
|
243 |
+
|
244 |
+
with torch.device(device):
|
245 |
+
inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)
|
246 |
+
codes = torch.full((batch_size, 9, audio_seq_len), unknown_token)
|
247 |
+
|
248 |
+
if audio_prefix_codes is not None:
|
249 |
+
codes[..., :prefix_audio_len] = audio_prefix_codes
|
250 |
+
|
251 |
+
delayed_codes = apply_delay_pattern(codes, self.masked_token_id)
|
252 |
+
|
253 |
+
delayed_prefix_audio_codes = delayed_codes[..., : prefix_audio_len + 1]
|
254 |
+
|
255 |
+
logits = self._prefill(prefix_conditioning, delayed_prefix_audio_codes, inference_params, cfg_scale)
|
256 |
+
next_token = sample_from_logits(logits, **sampling_params)
|
257 |
+
|
258 |
+
offset = delayed_prefix_audio_codes.shape[2]
|
259 |
+
frame = delayed_codes[..., offset : offset + 1]
|
260 |
+
frame.masked_scatter_(frame == unknown_token, next_token)
|
261 |
+
|
262 |
+
prefix_length = prefix_conditioning.shape[1] + prefix_audio_len + 1
|
263 |
+
inference_params.seqlen_offset += prefix_length
|
264 |
+
inference_params.lengths_per_sample[:] += prefix_length
|
265 |
+
|
266 |
+
logit_bias = torch.zeros_like(logits)
|
267 |
+
logit_bias[:, 1:, self.eos_token_id] = -torch.inf # only allow codebook 0 to predict EOS
|
268 |
+
|
269 |
+
stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
270 |
+
max_steps = delayed_codes.shape[2] - offset
|
271 |
+
remaining_steps = torch.full((batch_size,), max_steps, device=device)
|
272 |
+
progress = tqdm(total=max_steps, desc="Generating", disable=not progress_bar)
|
273 |
+
cfg_scale = torch.tensor(cfg_scale)
|
274 |
+
|
275 |
+
step = 0
|
276 |
+
while torch.max(remaining_steps) > 0:
|
277 |
+
offset += 1
|
278 |
+
input_ids = delayed_codes[..., offset - 1 : offset]
|
279 |
+
logits = decode_one_token(input_ids, inference_params, cfg_scale, allow_cudagraphs=cg)
|
280 |
+
logits += logit_bias
|
281 |
+
|
282 |
+
next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)
|
283 |
+
eos_in_cb0 = next_token[:, 0] == self.eos_token_id
|
284 |
+
|
285 |
+
remaining_steps[eos_in_cb0[:, 0]] = torch.minimum(remaining_steps[eos_in_cb0[:, 0]], torch.tensor(9))
|
286 |
+
stopping |= eos_in_cb0[:, 0]
|
287 |
+
|
288 |
+
eos_codebook_idx = 9 - remaining_steps
|
289 |
+
eos_codebook_idx = torch.clamp(eos_codebook_idx, max=9 - 1)
|
290 |
+
for i in range(next_token.shape[0]):
|
291 |
+
if stopping[i]:
|
292 |
+
idx = eos_codebook_idx[i].item()
|
293 |
+
next_token[i, :idx] = self.masked_token_id
|
294 |
+
next_token[i, idx] = self.eos_token_id
|
295 |
+
|
296 |
+
frame = delayed_codes[..., offset : offset + 1]
|
297 |
+
frame.masked_scatter_(frame == unknown_token, next_token)
|
298 |
+
inference_params.seqlen_offset += 1
|
299 |
+
inference_params.lengths_per_sample[:] += 1
|
300 |
+
|
301 |
+
remaining_steps -= 1
|
302 |
+
|
303 |
+
progress.update()
|
304 |
+
step += 1
|
305 |
+
|
306 |
+
if callback is not None and not callback(frame, step, max_steps):
|
307 |
+
break
|
308 |
+
|
309 |
+
out_codes = revert_delay_pattern(delayed_codes)
|
310 |
+
out_codes.masked_fill_(out_codes >= 1024, 0)
|
311 |
+
out_codes = out_codes[..., : offset - 9]
|
312 |
+
|
313 |
+
self._cg_graph = None # reset cuda graph to avoid cache changes
|
314 |
+
|
315 |
+
return out_codes
|
Zonos-main/zonos/sampling.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
5 |
+
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
6 |
+
|
7 |
+
Args:
|
8 |
+
input (torch.Tensor): The input tensor containing probabilities.
|
9 |
+
num_samples (int): Number of samples to draw.
|
10 |
+
replacement (bool): Whether to draw with replacement or not.
|
11 |
+
Keywords args:
|
12 |
+
generator (torch.Generator): A pseudorandom number generator for sampling.
|
13 |
+
Returns:
|
14 |
+
torch.Tensor: Last dimension contains num_samples indices
|
15 |
+
sampled from the multinomial probability distribution
|
16 |
+
located in the last dimension of tensor input.
|
17 |
+
"""
|
18 |
+
|
19 |
+
if num_samples == 1:
|
20 |
+
q = torch.empty_like(input).exponential_(1, generator=generator)
|
21 |
+
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
|
22 |
+
|
23 |
+
input_ = input.reshape(-1, input.shape[-1])
|
24 |
+
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
25 |
+
output = output_.reshape(*list(input.shape[:-1]), -1)
|
26 |
+
return output
|
27 |
+
|
28 |
+
|
29 |
+
def apply_unified(probs: torch.Tensor, linear: float, conf: float, quad: float) -> torch.Tensor:
|
30 |
+
"""Sample next token using unified sampling approach that combines linear scaling, confidence, and quadratic terms.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
34 |
+
linear (float): Linear scaling factor applied to log probabilities.
|
35 |
+
conf (float): Confidence factor that scales the entropy term.
|
36 |
+
quad (float): Quadratic penalty factor applied to squared log probabilities.
|
37 |
+
Returns:
|
38 |
+
torch.Tensor: Modified probability distribution after applying unified sampling.
|
39 |
+
"""
|
40 |
+
logprobs = torch.log(probs.clamp_min(1e-20))
|
41 |
+
entropy = -torch.sum(probs * logprobs, dim=-1, keepdim=True)
|
42 |
+
raw = logprobs * (linear + entropy * conf) - logprobs**2 * quad
|
43 |
+
return raw.softmax(dim=-1)
|
44 |
+
|
45 |
+
def apply_top_k(
|
46 |
+
probs: torch.Tensor,
|
47 |
+
k: int,
|
48 |
+
) -> torch.Tensor:
|
49 |
+
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
53 |
+
k (int): The k in “top-k”.
|
54 |
+
Returns:
|
55 |
+
torch.Tensor: Sampled tokens.
|
56 |
+
"""
|
57 |
+
v, _ = torch.topk(probs, min(k, probs.size(-1)))
|
58 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
59 |
+
probs = torch.where(probs < pivot, 0.0, probs)
|
60 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
61 |
+
return probs
|
62 |
+
|
63 |
+
|
64 |
+
def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
65 |
+
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
69 |
+
p (int): The p in “top-p”.
|
70 |
+
Returns:
|
71 |
+
torch.Tensor: Sampled tokens.
|
72 |
+
"""
|
73 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
74 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
75 |
+
mask = probs_sum - probs_sort > p
|
76 |
+
probs_sort *= (~mask).float()
|
77 |
+
probs = probs.scatter(-1, probs_idx, probs_sort)
|
78 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
79 |
+
return probs
|
80 |
+
|
81 |
+
|
82 |
+
def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
|
83 |
+
"""Sample next token using min-p sampling.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
scores (torch.FloatTensor): Input logits with token candidates on the last dimension.
|
87 |
+
min_p (float): Minimum token probability, scaled by the probability of the most likely token.
|
88 |
+
Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
|
89 |
+
Returns:
|
90 |
+
torch.Tensor: Sampled tokens.
|
91 |
+
"""
|
92 |
+
top_probs, _ = probs.max(dim=-1, keepdim=True)
|
93 |
+
tokens_to_remove = probs < (min_p * top_probs)
|
94 |
+
probs = probs.masked_fill(tokens_to_remove, 0.0)
|
95 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
96 |
+
return probs
|
97 |
+
|
98 |
+
|
99 |
+
def modify_logit_for_repetition_penalty(
|
100 |
+
logits: torch.Tensor,
|
101 |
+
generated_tokens: torch.Tensor,
|
102 |
+
repetition_penalty: float,
|
103 |
+
repetition_penalty_window: int,
|
104 |
+
):
|
105 |
+
"""See https://arxiv.org/abs/1909.05858
|
106 |
+
Apply repetition penalty over a sliding window of the last `repetition_penalty_window` tokens.
|
107 |
+
logits: (batch_size, n_codebooks, vocab_size)
|
108 |
+
generated_tokens: (batch_size, n_codebooks, seq_len)
|
109 |
+
"""
|
110 |
+
generated_tokens = generated_tokens[..., -repetition_penalty_window:]
|
111 |
+
generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
|
112 |
+
rp = torch.full_like(logits, repetition_penalty)
|
113 |
+
factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
|
114 |
+
return torch.where(logits <= 0, logits * factors, logits / factors)
|
115 |
+
|
116 |
+
|
117 |
+
def sample_from_logits(
|
118 |
+
logits: torch.Tensor,
|
119 |
+
temperature: float = 1.0,
|
120 |
+
top_p: float = 0.0,
|
121 |
+
top_k: int = 0,
|
122 |
+
min_p: float = 0.0,
|
123 |
+
linear: float = 0.0,
|
124 |
+
conf: float = 0.0,
|
125 |
+
quad: float = 0.0,
|
126 |
+
generated_tokens: torch.Tensor | None = None,
|
127 |
+
repetition_penalty: float = 3.0,
|
128 |
+
repetition_penalty_window: int = 2,
|
129 |
+
) -> torch.Tensor:
|
130 |
+
"""Sample next token from logits using either top_k/p/min_p OR using NovelAI's Unified Sampler.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
logits (torch.Tensor): Input logits with token candidates on the last dimension.
|
134 |
+
|
135 |
+
temperature (float): Randomness of the sampling. Lower temperature results in more deterministic samples.
|
136 |
+
To disable sampling entirely, set it to 0. For NovelAI's Unified Sampler, set it to 1.0
|
137 |
+
|
138 |
+
top_p (float): Only sample from the most probable tokens whose cumulative probability is less than p.
|
139 |
+
This is called nucleus sampling. Must be between 0 and 1. Typical values are in the 0.1-0.9 range.
|
140 |
+
|
141 |
+
Set to 0 to disable.
|
142 |
+
|
143 |
+
top_k (int): Only sample from the top k most probable tokens. Set to 0 to disable.
|
144 |
+
|
145 |
+
min_p (float): Minimum token probability, scaled by the probability of the most likely token.
|
146 |
+
Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
|
147 |
+
If too high, no token might be sampled leading to silence (?)
|
148 |
+
|
149 |
+
linear (float): NovelAI's Unified Sampler -> 0.0 to 1.0, default from gradio 0.5
|
150 |
+
|
151 |
+
Set Linear between 0 and 1 according to how unusual you want tokens to be.
|
152 |
+
Lower numbers will produce more unusual/creative outputs,
|
153 |
+
but you will have to reroll or edit more.
|
154 |
+
|
155 |
+
conf (float): Confidence - Low values make random outputs more random. -> -2.0 * Quad to 2.0, default from gradio 0.4
|
156 |
+
|
157 |
+
As a starting point, set Quad = 1/3 - Linear * 4 / 15, and Conf = -Quad / 2.
|
158 |
+
|
159 |
+
quad (float): Quadratic - High values make low probablities much lower. -> -2.0 to 2.0, default from gradio 0.0
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
torch.Tensor: Sampled tokens.
|
163 |
+
"""
|
164 |
+
if repetition_penalty != 1.0 and generated_tokens is not None:
|
165 |
+
logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)
|
166 |
+
|
167 |
+
if temperature > 0:
|
168 |
+
probs = torch.softmax(logits / temperature, dim=-1)
|
169 |
+
if linear > 0.0:
|
170 |
+
probs = apply_unified(probs, linear, conf, quad)
|
171 |
+
if top_p > 0:
|
172 |
+
probs = apply_top_p(probs, top_p)
|
173 |
+
if top_k > 0:
|
174 |
+
probs = apply_top_k(probs, top_k)
|
175 |
+
if min_p > 0:
|
176 |
+
probs = apply_min_p(probs, min_p)
|
177 |
+
|
178 |
+
next_token = multinomial(probs, num_samples=1)
|
179 |
+
else:
|
180 |
+
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
181 |
+
|
182 |
+
return next_token # [batch_size, num_codebooks, 1]
|
Zonos-main/zonos/speaker_cloning.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from functools import cache
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
|
10 |
+
from zonos.utils import DEFAULT_DEVICE
|
11 |
+
|
12 |
+
|
13 |
+
class logFbankCal(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
sample_rate: int = 16_000,
|
17 |
+
n_fft: int = 512,
|
18 |
+
win_length: float = 0.025,
|
19 |
+
hop_length: float = 0.01,
|
20 |
+
n_mels: int = 80,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
self.fbankCal = torchaudio.transforms.MelSpectrogram(
|
24 |
+
sample_rate=sample_rate,
|
25 |
+
n_fft=n_fft,
|
26 |
+
win_length=int(win_length * sample_rate),
|
27 |
+
hop_length=int(hop_length * sample_rate),
|
28 |
+
n_mels=n_mels,
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
out = self.fbankCal(x)
|
33 |
+
out = torch.log(out + 1e-6)
|
34 |
+
out = out - out.mean(axis=2).unsqueeze(dim=2)
|
35 |
+
return out
|
36 |
+
|
37 |
+
|
38 |
+
class ASP(nn.Module):
|
39 |
+
# Attentive statistics pooling
|
40 |
+
def __init__(self, in_planes, acoustic_dim):
|
41 |
+
super(ASP, self).__init__()
|
42 |
+
outmap_size = int(acoustic_dim / 8)
|
43 |
+
self.out_dim = in_planes * 8 * outmap_size * 2
|
44 |
+
|
45 |
+
self.attention = nn.Sequential(
|
46 |
+
nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
|
47 |
+
nn.ReLU(),
|
48 |
+
nn.BatchNorm1d(128),
|
49 |
+
nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
|
50 |
+
nn.Softmax(dim=2),
|
51 |
+
)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
x = x.reshape(x.size()[0], -1, x.size()[-1])
|
55 |
+
w = self.attention(x)
|
56 |
+
mu = torch.sum(x * w, dim=2)
|
57 |
+
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
|
58 |
+
x = torch.cat((mu, sg), 1)
|
59 |
+
|
60 |
+
x = x.view(x.size()[0], -1)
|
61 |
+
return x
|
62 |
+
|
63 |
+
|
64 |
+
class SimAMBasicBlock(nn.Module):
|
65 |
+
expansion = 1
|
66 |
+
|
67 |
+
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
68 |
+
super(SimAMBasicBlock, self).__init__()
|
69 |
+
self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
70 |
+
self.bn1 = NormLayer(planes)
|
71 |
+
self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
72 |
+
self.bn2 = NormLayer(planes)
|
73 |
+
self.relu = nn.ReLU(inplace=True)
|
74 |
+
self.sigmoid = nn.Sigmoid()
|
75 |
+
|
76 |
+
self.downsample = nn.Sequential()
|
77 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
78 |
+
self.downsample = nn.Sequential(
|
79 |
+
ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
80 |
+
NormLayer(self.expansion * planes),
|
81 |
+
)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
85 |
+
out = self.bn2(self.conv2(out))
|
86 |
+
out = self.SimAM(out)
|
87 |
+
out += self.downsample(x)
|
88 |
+
out = self.relu(out)
|
89 |
+
return out
|
90 |
+
|
91 |
+
def SimAM(self, X, lambda_p=1e-4):
|
92 |
+
n = X.shape[2] * X.shape[3] - 1
|
93 |
+
d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
|
94 |
+
v = d.sum(dim=[2, 3], keepdim=True) / n
|
95 |
+
E_inv = d / (4 * (v + lambda_p)) + 0.5
|
96 |
+
return X * self.sigmoid(E_inv)
|
97 |
+
|
98 |
+
|
99 |
+
class BasicBlock(nn.Module):
|
100 |
+
expansion = 1
|
101 |
+
|
102 |
+
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
103 |
+
super(BasicBlock, self).__init__()
|
104 |
+
self.conv1 = ConvLayer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
105 |
+
self.bn1 = NormLayer(planes)
|
106 |
+
self.conv2 = ConvLayer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
107 |
+
self.bn2 = NormLayer(planes)
|
108 |
+
self.relu = nn.ReLU(inplace=True)
|
109 |
+
|
110 |
+
self.downsample = nn.Sequential()
|
111 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
112 |
+
self.downsample = nn.Sequential(
|
113 |
+
ConvLayer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
114 |
+
NormLayer(self.expansion * planes),
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
119 |
+
out = self.bn2(self.conv2(out))
|
120 |
+
out += self.downsample(x)
|
121 |
+
out = self.relu(out)
|
122 |
+
return out
|
123 |
+
|
124 |
+
|
125 |
+
class Bottleneck(nn.Module):
|
126 |
+
expansion = 4
|
127 |
+
|
128 |
+
def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1):
|
129 |
+
super(Bottleneck, self).__init__()
|
130 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
131 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
132 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
133 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
134 |
+
self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
|
135 |
+
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
136 |
+
|
137 |
+
self.shortcut = nn.Sequential()
|
138 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
139 |
+
self.shortcut = nn.Sequential(
|
140 |
+
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
141 |
+
nn.BatchNorm2d(self.expansion * planes),
|
142 |
+
)
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
146 |
+
out = F.relu(self.bn2(self.conv2(out)))
|
147 |
+
out = self.bn3(self.conv3(out))
|
148 |
+
out += self.shortcut(x)
|
149 |
+
out = F.relu(out)
|
150 |
+
return out
|
151 |
+
|
152 |
+
|
153 |
+
class ResNet(nn.Module):
|
154 |
+
def __init__(self, in_planes, block, num_blocks, in_ch=1, feat_dim="2d", **kwargs):
|
155 |
+
super(ResNet, self).__init__()
|
156 |
+
if feat_dim == "1d":
|
157 |
+
self.NormLayer = nn.BatchNorm1d
|
158 |
+
self.ConvLayer = nn.Conv1d
|
159 |
+
elif feat_dim == "2d":
|
160 |
+
self.NormLayer = nn.BatchNorm2d
|
161 |
+
self.ConvLayer = nn.Conv2d
|
162 |
+
elif feat_dim == "3d":
|
163 |
+
self.NormLayer = nn.BatchNorm3d
|
164 |
+
self.ConvLayer = nn.Conv3d
|
165 |
+
else:
|
166 |
+
print("error")
|
167 |
+
|
168 |
+
self.in_planes = in_planes
|
169 |
+
|
170 |
+
self.conv1 = self.ConvLayer(in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
|
171 |
+
self.bn1 = self.NormLayer(in_planes)
|
172 |
+
self.relu = nn.ReLU(inplace=True)
|
173 |
+
self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1, block_id=1)
|
174 |
+
self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2, block_id=2)
|
175 |
+
self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2, block_id=3)
|
176 |
+
self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2, block_id=4)
|
177 |
+
|
178 |
+
def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
|
179 |
+
strides = [stride] + [1] * (num_blocks - 1)
|
180 |
+
layers = []
|
181 |
+
for stride in strides:
|
182 |
+
layers.append(block(self.ConvLayer, self.NormLayer, self.in_planes, planes, stride, block_id))
|
183 |
+
self.in_planes = planes * block.expansion
|
184 |
+
return nn.Sequential(*layers)
|
185 |
+
|
186 |
+
def forward(self, x):
|
187 |
+
x = self.relu(self.bn1(self.conv1(x)))
|
188 |
+
x = self.layer1(x)
|
189 |
+
x = self.layer2(x)
|
190 |
+
x = self.layer3(x)
|
191 |
+
x = self.layer4(x)
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
def ResNet293(in_planes: int, **kwargs):
|
196 |
+
return ResNet(in_planes, SimAMBasicBlock, [10, 20, 64, 3], **kwargs)
|
197 |
+
|
198 |
+
|
199 |
+
class ResNet293_based(nn.Module):
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
in_planes: int = 64,
|
203 |
+
embd_dim: int = 256,
|
204 |
+
acoustic_dim: int = 80,
|
205 |
+
featCal=None,
|
206 |
+
dropout: float = 0,
|
207 |
+
**kwargs,
|
208 |
+
):
|
209 |
+
super(ResNet293_based, self).__init__()
|
210 |
+
self.featCal = featCal
|
211 |
+
self.front = ResNet293(in_planes)
|
212 |
+
block_expansion = SimAMBasicBlock.expansion
|
213 |
+
self.pooling = ASP(in_planes * block_expansion, acoustic_dim)
|
214 |
+
self.bottleneck = nn.Linear(self.pooling.out_dim, embd_dim)
|
215 |
+
self.drop = nn.Dropout(dropout) if dropout else None
|
216 |
+
|
217 |
+
def forward(self, x):
|
218 |
+
x = self.featCal(x)
|
219 |
+
x = self.front(x.unsqueeze(dim=1))
|
220 |
+
x = self.pooling(x)
|
221 |
+
if self.drop:
|
222 |
+
x = self.drop(x)
|
223 |
+
x = self.bottleneck(x)
|
224 |
+
return x
|
225 |
+
|
226 |
+
|
227 |
+
class SEModule(nn.Module):
|
228 |
+
def __init__(self, channels, bottleneck=128):
|
229 |
+
super(SEModule, self).__init__()
|
230 |
+
self.se = nn.Sequential(
|
231 |
+
nn.AdaptiveAvgPool1d(1),
|
232 |
+
nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
|
233 |
+
nn.ReLU(),
|
234 |
+
# nn.BatchNorm1d(bottleneck), # Removed
|
235 |
+
nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
|
236 |
+
nn.Sigmoid(),
|
237 |
+
)
|
238 |
+
|
239 |
+
def forward(self, input):
|
240 |
+
x = self.se(input)
|
241 |
+
return input * x
|
242 |
+
|
243 |
+
|
244 |
+
class Bottle2neck(nn.Module):
|
245 |
+
def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
|
246 |
+
super(Bottle2neck, self).__init__()
|
247 |
+
width = int(math.floor(planes / scale))
|
248 |
+
self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
|
249 |
+
self.bn1 = nn.BatchNorm1d(width * scale)
|
250 |
+
self.nums = scale - 1
|
251 |
+
convs = []
|
252 |
+
bns = []
|
253 |
+
num_pad = math.floor(kernel_size / 2) * dilation
|
254 |
+
for i in range(self.nums):
|
255 |
+
convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
|
256 |
+
bns.append(nn.BatchNorm1d(width))
|
257 |
+
self.convs = nn.ModuleList(convs)
|
258 |
+
self.bns = nn.ModuleList(bns)
|
259 |
+
self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
|
260 |
+
self.bn3 = nn.BatchNorm1d(planes)
|
261 |
+
self.relu = nn.ReLU()
|
262 |
+
self.width = width
|
263 |
+
self.se = SEModule(planes)
|
264 |
+
|
265 |
+
def forward(self, x):
|
266 |
+
residual = x
|
267 |
+
out = self.conv1(x)
|
268 |
+
out = self.relu(out)
|
269 |
+
out = self.bn1(out)
|
270 |
+
|
271 |
+
spx = torch.split(out, self.width, 1)
|
272 |
+
for i in range(self.nums):
|
273 |
+
if i == 0:
|
274 |
+
sp = spx[i]
|
275 |
+
else:
|
276 |
+
sp = sp + spx[i]
|
277 |
+
sp = self.convs[i](sp)
|
278 |
+
sp = self.relu(sp)
|
279 |
+
sp = self.bns[i](sp)
|
280 |
+
if i == 0:
|
281 |
+
out = sp
|
282 |
+
else:
|
283 |
+
out = torch.cat((out, sp), 1)
|
284 |
+
out = torch.cat((out, spx[self.nums]), 1)
|
285 |
+
|
286 |
+
out = self.conv3(out)
|
287 |
+
out = self.relu(out)
|
288 |
+
out = self.bn3(out)
|
289 |
+
|
290 |
+
out = self.se(out)
|
291 |
+
out += residual
|
292 |
+
return out
|
293 |
+
|
294 |
+
|
295 |
+
class ECAPA_TDNN(nn.Module):
|
296 |
+
def __init__(self, C, featCal):
|
297 |
+
super(ECAPA_TDNN, self).__init__()
|
298 |
+
self.featCal = featCal
|
299 |
+
self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
|
300 |
+
self.relu = nn.ReLU()
|
301 |
+
self.bn1 = nn.BatchNorm1d(C)
|
302 |
+
self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
|
303 |
+
self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
|
304 |
+
self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
|
305 |
+
# I fixed the shape of the output from MFA layer, that is close to the setting from ECAPA paper.
|
306 |
+
self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
|
307 |
+
self.attention = nn.Sequential(
|
308 |
+
nn.Conv1d(4608, 256, kernel_size=1),
|
309 |
+
nn.ReLU(),
|
310 |
+
nn.BatchNorm1d(256),
|
311 |
+
nn.Tanh(), # Added
|
312 |
+
nn.Conv1d(256, 1536, kernel_size=1),
|
313 |
+
nn.Softmax(dim=2),
|
314 |
+
)
|
315 |
+
self.bn5 = nn.BatchNorm1d(3072)
|
316 |
+
self.fc6 = nn.Linear(3072, 192)
|
317 |
+
self.bn6 = nn.BatchNorm1d(192)
|
318 |
+
|
319 |
+
def forward(self, x):
|
320 |
+
x = self.featCal(x)
|
321 |
+
x = self.conv1(x)
|
322 |
+
x = self.relu(x)
|
323 |
+
x = self.bn1(x)
|
324 |
+
|
325 |
+
x1 = self.layer1(x)
|
326 |
+
x2 = self.layer2(x + x1)
|
327 |
+
x3 = self.layer3(x + x1 + x2)
|
328 |
+
|
329 |
+
x = self.layer4(torch.cat((x1, x2, x3), dim=1))
|
330 |
+
x = self.relu(x)
|
331 |
+
|
332 |
+
t = x.size()[-1]
|
333 |
+
|
334 |
+
global_x = torch.cat(
|
335 |
+
(
|
336 |
+
x,
|
337 |
+
torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
|
338 |
+
torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
|
339 |
+
),
|
340 |
+
dim=1,
|
341 |
+
)
|
342 |
+
|
343 |
+
w = self.attention(global_x)
|
344 |
+
|
345 |
+
mu = torch.sum(x * w, dim=2)
|
346 |
+
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))
|
347 |
+
|
348 |
+
x = torch.cat((mu, sg), 1)
|
349 |
+
x = self.bn5(x)
|
350 |
+
x = self.fc6(x)
|
351 |
+
x = self.bn6(x)
|
352 |
+
|
353 |
+
return x
|
354 |
+
|
355 |
+
|
356 |
+
class SpeakerEmbedding(nn.Module):
|
357 |
+
def __init__(self, ckpt_path: str = "ResNet293_SimAM_ASP_base.pt", device: str = DEFAULT_DEVICE):
|
358 |
+
super().__init__()
|
359 |
+
self.device = device
|
360 |
+
with torch.device(device):
|
361 |
+
self.model = ResNet293_based()
|
362 |
+
state_dict = torch.load(ckpt_path, weights_only=True, mmap=True, map_location="cpu")
|
363 |
+
self.model.load_state_dict(state_dict)
|
364 |
+
self.model.featCal = logFbankCal()
|
365 |
+
|
366 |
+
self.requires_grad_(False).eval()
|
367 |
+
|
368 |
+
@property
|
369 |
+
def dtype(self):
|
370 |
+
return next(self.parameters()).dtype
|
371 |
+
|
372 |
+
@cache
|
373 |
+
def _get_resampler(self, orig_sample_rate: int):
|
374 |
+
return torchaudio.transforms.Resample(orig_sample_rate, 16_000).to(self.device)
|
375 |
+
|
376 |
+
def prepare_input(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
377 |
+
assert wav.ndim < 3
|
378 |
+
if wav.ndim == 2:
|
379 |
+
wav = wav.mean(0, keepdim=True)
|
380 |
+
wav = self._get_resampler(sample_rate)(wav)
|
381 |
+
return wav
|
382 |
+
|
383 |
+
def forward(self, wav: torch.Tensor, sample_rate: int):
|
384 |
+
wav = self.prepare_input(wav, sample_rate).to(self.device, self.dtype)
|
385 |
+
return self.model(wav).to(wav.device)
|
386 |
+
|
387 |
+
|
388 |
+
class SpeakerEmbeddingLDA(nn.Module):
|
389 |
+
def __init__(self, device: str = DEFAULT_DEVICE):
|
390 |
+
super().__init__()
|
391 |
+
spk_model_path = hf_hub_download(
|
392 |
+
repo_id="Zyphra/Zonos-v0.1-speaker-embedding",
|
393 |
+
filename="ResNet293_SimAM_ASP_base.pt",
|
394 |
+
)
|
395 |
+
lda_spk_model_path = hf_hub_download(
|
396 |
+
repo_id="Zyphra/Zonos-v0.1-speaker-embedding",
|
397 |
+
filename="ResNet293_SimAM_ASP_base_LDA-128.pt",
|
398 |
+
)
|
399 |
+
|
400 |
+
self.device = device
|
401 |
+
with torch.device(device):
|
402 |
+
self.model = SpeakerEmbedding(spk_model_path, device)
|
403 |
+
lda_sd = torch.load(lda_spk_model_path, weights_only=True)
|
404 |
+
out_features, in_features = lda_sd["weight"].shape
|
405 |
+
self.lda = nn.Linear(in_features, out_features, bias=True, dtype=torch.float32)
|
406 |
+
self.lda.load_state_dict(lda_sd)
|
407 |
+
|
408 |
+
self.requires_grad_(False).eval()
|
409 |
+
|
410 |
+
def forward(self, wav: torch.Tensor, sample_rate: int):
|
411 |
+
emb = self.model(wav, sample_rate).to(torch.float32)
|
412 |
+
return emb, self.lda(emb)
|
Zonos-main/zonos/utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def find_multiple(n: int, k: int) -> int:
|
7 |
+
if k == 0 or n % k == 0:
|
8 |
+
return n
|
9 |
+
return n + k - (n % k)
|
10 |
+
|
11 |
+
|
12 |
+
def pad_weight_(w: nn.Embedding | nn.Linear, multiple: int):
|
13 |
+
"""Pad the weight of an embedding or linear layer to a multiple of `multiple`."""
|
14 |
+
if isinstance(w, nn.Embedding):
|
15 |
+
# Pad input dim
|
16 |
+
if w.weight.shape[1] % multiple == 0:
|
17 |
+
return
|
18 |
+
w.weight.data = F.pad(w.weight.data, (0, 0, 0, w.weight.shape[1] % multiple))
|
19 |
+
w.num_embeddings, w.embedding_dim = w.weight.shape
|
20 |
+
elif isinstance(w, nn.Linear):
|
21 |
+
# Pad output dim
|
22 |
+
if w.weight.shape[0] % multiple == 0:
|
23 |
+
return
|
24 |
+
w.weight.data = F.pad(w.weight.data, (0, 0, 0, w.weight.shape[0] % multiple))
|
25 |
+
w.out_features, w.in_features = w.weight.shape
|
26 |
+
else:
|
27 |
+
raise ValueError(f"Unsupported weight type: {type(w)}")
|
28 |
+
|
29 |
+
|
30 |
+
def get_device() -> torch.device:
|
31 |
+
if torch.cuda.is_available():
|
32 |
+
return torch.device(torch.cuda.current_device())
|
33 |
+
# MPS breaks for whatever reason. Uncomment when it's working.
|
34 |
+
# if torch.mps.is_available():
|
35 |
+
# return torch.device("mps")
|
36 |
+
return torch.device("cpu")
|
37 |
+
|
38 |
+
|
39 |
+
DEFAULT_DEVICE = get_device()
|