diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..e464d502ddae5dc6ed7e1c13176a3f90fa8c4738
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,7 @@
+.git
+.github
+results
+data
+*.filelist
+/data_server/target
+checkpoints
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e5305b2f6446feab7f399ff0f0f3778fb7fb335f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,32 @@
+.DS_Store
+.pgx.*
+.pdm-python
+/fish_speech.egg-info
+__pycache__
+/results
+/data
+/*.test.sh
+*.filelist
+filelists
+/fish_speech/text/cmudict_cache.pickle
+/checkpoints
+/.vscode
+/data_server/target
+/*.npy
+/*.wav
+/*.mp3
+/*.lab
+/results
+/data
+/.idea
+ffmpeg.exe
+ffprobe.exe
+asr-label*
+/.cache
+/fishenv
+/.locale
+/demo-audios
+/references
+/example
+/faster_whisper
+/.gradio
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..18dedd27cd83cdd1e4fe918b09f559d0005b6fb5
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,25 @@
+ci:
+ autoupdate_schedule: monthly
+
+repos:
+ - repo: https://github.com/pycqa/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+ args: [--profile=black]
+
+ - repo: https://github.com/psf/black
+ rev: 24.10.0
+ hooks:
+ - id: black
+
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v5.0.0
+ hooks:
+ - id: end-of-file-fixer
+ - id: check-yaml
+ - id: check-json
+ - id: mixed-line-ending
+ args: ['--fix=lf']
+ - id: check-added-large-files
+ args: ['--maxkb=5000']
diff --git a/.project-root b/.project-root
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..015eb5de8569951255b2d66c251ee20fe9153ace
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,19 @@
+# Read the Docs configuration file for MkDocs projects
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the version of Python and other tools you might need
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.12"
+
+mkdocs:
+ configuration: mkdocs.yml
+
+# Optionally declare the Python requirements required to build your docs
+python:
+ install:
+ - requirements: docs/requirements.txt
diff --git a/API_FLAGS.txt b/API_FLAGS.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4a49f141b9fb8d7d3bbc3d009dc204efd6fef324
--- /dev/null
+++ b/API_FLAGS.txt
@@ -0,0 +1,6 @@
+# --infer
+--api
+--listen 0.0.0.0:8080 \
+--llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+--decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+--decoder-config-name firefly_gan_vq
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..cbe5ad1670406e4402217edfb82d2c56af7e8631
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,437 @@
+Attribution-NonCommercial-ShareAlike 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-ShareAlike 4.0 International Public License
+("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. BY-NC-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution, NonCommercial, and ShareAlike.
+
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ k. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ l. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ m. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ n. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ b. ShareAlike.
+
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-NC-SA Compatible License.
+
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/README.md b/README.md
index 2ab947944a6b4814e0a573a4f0d527f0ec7ac768..0dd76de8edb1712302e0e265cc7be9304e0e0270 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,124 @@
+
+
Fish Speech
+
+**English** | [简体中文](docs/README.zh.md) | [Portuguese](docs/README.pt-BR.md) | [日本語](docs/README.ja.md) | [한국어](docs/README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+This codebase and all models are released under CC-BY-NC-SA-4.0 License. Please refer to [LICENSE](LICENSE) for more details.
+
---
-title: Fishspeech2
-emoji: 🔥
-colorFrom: yellow
-colorTo: blue
-sdk: gradio
-sdk_version: 5.5.0
-app_file: app.py
-pinned: false
-license: cc-by-nc-sa-4.0
----
+## Fish Agent
+We are very excited to announce that our self-developed agent demo is now open source. You can try the agent online at the [demo](https://fish.audio/demo/live) for instant English chat, or run English and Chinese chat locally by following the [docs](https://speech.fish.audio/start_agent/).
+
+It should be noted that the content is released under a **CC BY-NC-SA 4.0 license**. The demo is an early alpha test version: inference speed still needs to be optimized, and there are many bugs waiting to be fixed. If you have found a bug or want to fix one, we would be very happy to receive an issue or a pull request.
+
+## Features
+### Fish Speech
+
+1. **Zero-shot & Few-shot TTS:** Input a 10 to 30-second vocal sample to generate high-quality TTS output. **For detailed guidelines, see [Voice Cloning Best Practices](https://docs.fish.audio/text-to-speech/voice-clone-best-practices).**
+
+2. **Multilingual & Cross-lingual Support:** Simply copy and paste multilingual text into the input box—no need to worry about the language. Currently supports English, Japanese, Korean, Chinese, French, German, Arabic, and Spanish.
+
+3. **No Phoneme Dependency:** The model has strong generalization capabilities and does not rely on phonemes for TTS. It can handle text in any language script.
+
+4. **Highly Accurate:** Achieves a low CER (Character Error Rate) and WER (Word Error Rate) of around 2% for 5-minute English texts.
+
+5. **Fast:** With fish-tech acceleration, the real-time factor is approximately 1:5 on an Nvidia RTX 4060 laptop and 1:15 on an Nvidia RTX 4090.
+
+6. **WebUI Inference:** Features an easy-to-use, Gradio-based web UI compatible with Chrome, Firefox, Edge, and other browsers.
+
+7. **GUI Inference:** Offers a PyQt6 graphical interface that works seamlessly with the API server. Supports Linux, Windows, and macOS. [See GUI](https://github.com/AnyaCoder/fish-speech-gui).
+
+8. **Deploy-Friendly:** Easily set up an inference server with native support for Linux, Windows and MacOS, minimizing speed loss.
+
+### Fish Agent
+1. **Completely End to End:** Automatically integrates ASR and TTS parts, no need to plug-in other models, i.e., true end-to-end, not three-stage (ASR+LLM+TTS).
+
+2. **Timbre Control:** Can use reference audio to control the speech timbre.
+
+3. **Emotional:** The model can generate speech with strong emotion.
+
+## Disclaimer
+
+We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
+
+## Online Demo
+
+[Fish Audio](https://fish.audio)
+
+[Fish Agent](https://fish.audio/demo/live)
+
+## Quick Start for Local Inference
+
+[inference.ipynb](/inference.ipynb)
+
+## Videos
+
+#### V1.4 Demo Video: [Youtube](https://www.youtube.com/watch?v=Ghc8cJdQyKQ)
+
+## Documents
+
+- [English](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/)
+
+## Samples (2024/10/02 V1.4)
+
+- [English](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
+
+## Credits
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## Sponsor
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml
new file mode 100644
index 0000000000000000000000000000000000000000..cbd2ba51ca6c62d967333a86981641e17c3b67ef
--- /dev/null
+++ b/docker-compose.dev.yml
@@ -0,0 +1,18 @@
+version: '3.8'
+
+services:
+ fish-speech:
+ build:
+ context: .
+ dockerfile: dockerfile.dev
+ container_name: fish-speech
+ volumes:
+ - ./:/exp
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: all
+ capabilities: [gpu]
+ command: tail -f /dev/null
diff --git a/dockerfile b/dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..8cd7834ea15b6d255e26bfa9d4df56f589bea486
--- /dev/null
+++ b/dockerfile
@@ -0,0 +1,50 @@
+FROM python:3.12-slim-bookworm AS stage-1
+ARG TARGETARCH
+
+ARG HUGGINGFACE_MODEL=fish-speech-1.4
+ARG HF_ENDPOINT=https://huggingface.co
+
+WORKDIR /opt/fish-speech
+
+RUN set -ex \
+ && pip install huggingface_hub \
+ && HF_ENDPOINT=${HF_ENDPOINT} huggingface-cli download --resume-download fishaudio/${HUGGINGFACE_MODEL} --local-dir checkpoints/${HUGGINGFACE_MODEL}
+
+FROM python:3.12-slim-bookworm
+ARG TARGETARCH
+
+ARG DEPENDENCIES=" \
+ ca-certificates \
+ libsox-dev \
+ build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0 \
+ ffmpeg"
+
+RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
+ --mount=type=cache,target=/var/lib/apt,sharing=locked \
+ set -ex \
+ && rm -f /etc/apt/apt.conf.d/docker-clean \
+ && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \
+ && apt-get update \
+ && apt-get -y install --no-install-recommends ${DEPENDENCIES} \
+ && echo "no" | dpkg-reconfigure dash
+
+WORKDIR /opt/fish-speech
+
+COPY . .
+
+RUN --mount=type=cache,target=/root/.cache,sharing=locked \
+ set -ex \
+ && pip install -e .[stable]
+
+COPY --from=stage-1 /opt/fish-speech/checkpoints /opt/fish-speech/checkpoints
+
+ENV GRADIO_SERVER_NAME="0.0.0.0"
+
+EXPOSE 7860
+
+CMD ["./entrypoint.sh"]
diff --git a/dockerfile.dev b/dockerfile.dev
new file mode 100644
index 0000000000000000000000000000000000000000..ac5d18f6a6053ba758dcbc557a4b8d5d6eacf09b
--- /dev/null
+++ b/dockerfile.dev
@@ -0,0 +1,37 @@
+ARG VERSION=dev
+ARG BASE_IMAGE=ghcr.io/fishaudio/fish-speech:${VERSION}
+
+FROM ${BASE_IMAGE}
+
+ARG TOOLS=" \
+ git \
+ curl \
+ build-essential \
+ ffmpeg \
+ libsm6 \
+ libxext6 \
+ libjpeg-dev \
+ zlib1g-dev \
+ aria2 \
+ zsh \
+ openssh-server \
+ sudo \
+ protobuf-compiler \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0 \
+ cmake"
+
+RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
+ --mount=type=cache,target=/var/lib/apt,sharing=locked \
+ set -ex \
+ && apt-get update \
+ && apt-get -y install --no-install-recommends ${TOOLS}
+
+# Install oh-my-zsh so your terminal looks nice
+RUN sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended
+
+# Set zsh as default shell
+RUN chsh -s /usr/bin/zsh
+ENV SHELL=/usr/bin/zsh
diff --git a/docs/CNAME b/docs/CNAME
new file mode 100644
index 0000000000000000000000000000000000000000..d506fb8b394fa80f3d329ab8450dfc102e839bd1
--- /dev/null
+++ b/docs/CNAME
@@ -0,0 +1 @@
+speech.fish.audio
diff --git a/docs/README.ja.md b/docs/README.ja.md
new file mode 100644
index 0000000000000000000000000000000000000000..e0872e988edec972a874d7c52327d3fa9380c578
--- /dev/null
+++ b/docs/README.ja.md
@@ -0,0 +1,106 @@
+
+
Fish Speech
+
+[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | **日本語** | [한국어](README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+このコードベースとすべてのモデルは、CC-BY-NC-SA-4.0 ライセンスの下でリリースされています。詳細については、[LICENSE](LICENSE)を参照してください。
+
+---
+
+## 機能
+
+1. **ゼロショット & フューショット TTS**:10〜30 秒の音声サンプルを入力して、高品質の TTS 出力を生成します。**詳細は [音声クローンのベストプラクティス](https://docs.fish.audio/text-to-speech/voice-clone-best-practices) を参照してください。**
+2. **多言語 & クロスリンガル対応**:多言語テキストを入力ボックスにコピーペーストするだけで、言語を気にする必要はありません。現在、英語、日本語、韓国語、中国語、フランス語、ドイツ語、アラビア語、スペイン語に対応しています。
+3. **音素依存なし**:このモデルは強力な汎化能力を持ち、TTS に音素を必要としません。あらゆる言語スクリプトに対応可能です。
+4. **高精度**:5 分間の英語テキストに対し、CER(文字誤り率)と WER(単語誤り率)は約 2%の精度を達成します。
+5. **高速**:fish-tech アクセラレーションにより、Nvidia RTX 4060 ラップトップではリアルタイムファクターが約 1:5、Nvidia RTX 4090 では約 1:15 です。
+6. **WebUI 推論**:使いやすい Gradio ベースの Web ユーザーインターフェースを搭載し、Chrome、Firefox、Edge などのブラウザに対応しています。
+7. **GUI 推論**:PyQt6 のグラフィカルインターフェースを提供し、API サーバーとシームレスに連携します。Linux、Windows、macOS に対応しています。[GUI を見る](https://github.com/AnyaCoder/fish-speech-gui)。
+8. **デプロイしやすい**:Linux、Windows、macOS にネイティブ対応した推論サーバーを簡単にセットアップでき、速度の低下を最小限に抑えます。
+
+## 免責事項
+
+コードベースの違法な使用については一切責任を負いません。DMCA(デジタルミレニアム著作権法)およびその他の関連法については、地域の法律を参照してください。
+
+## オンラインデモ
+
+[Fish Audio](https://fish.audio)
+
+## ローカル推論のクイックスタート
+
+[inference.ipynb](/inference.ipynb)
+
+## ビデオ
+
+#### V1.4 デモビデオ: https://www.bilibili.com/video/BV1pu46eVEk7
+
+#### V1.2 デモビデオ: https://www.bilibili.com/video/BV1wz421B71D
+
+#### V1.1 デモビデオ: https://www.bilibili.com/video/BV1zJ4m1K7cj
+
+## ドキュメント
+
+- [英語](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [ポルトガル語 (ブラジル)](https://speech.fish.audio/pt/)
+
+## サンプル (2024/10/02 V1.4)
+
+- [英語](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [ポルトガル語 (ブラジル)](https://speech.fish.audio/pt/samples/)
+
+## クレジット
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## スポンサー
+
+
+
diff --git a/docs/README.ko.md b/docs/README.ko.md
new file mode 100644
index 0000000000000000000000000000000000000000..952ae5f2f1946b33d1f57ea42c3b4a645e039840
--- /dev/null
+++ b/docs/README.ko.md
@@ -0,0 +1,111 @@
+
+
Fish Speech
+
+[English](../README.md) | [简体中文](README.zh.md) | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | **한국어**
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+이 코드베이스와 모든 모델은 CC-BY-NC-SA-4.0 라이선스에 따라 배포됩니다. 자세한 내용은 [LICENSE](LICENSE)를 참조하시길 바랍니다.
+
+---
+
+## 기능
+
+1. **Zero-shot & Few-shot TTS:** 10초에서 30초의 음성 샘플을 입력하여 고품질의 TTS 출력을 생성합니다. **자세한 가이드는 [모범 사례](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)를 참조하시길 바랍니다.**
+
+2. **다국어 및 교차 언어 지원:** 다국어 걱정 없이, 텍스트를 입력창에 복사하여 붙여넣기만 하면 됩니다. 현재 영어, 일본어, 한국어, 중국어, 프랑스어, 독일어, 아랍어, 스페인어를 지원합니다.
+
+3. **음소 의존성 제거:** 이 모델은 강력한 일반화 능력을 가지고 있으며, TTS가 음소에 의존하지 않습니다. 모든 언어 스크립트 텍스트를 손쉽게 처리할 수 있습니다.
+
+4. **높은 정확도:** 영어 텍스트 기준 5분 기준에서 단, 2%의 문자 오류율(CER)과 단어 오류율(WER)을 달성합니다.
+
+5. **빠른 속도:** fish-tech 가속을 통해 실시간 인자(RTF)는 Nvidia RTX 4060 노트북에서는 약 1:5, Nvidia RTX 4090에서는 1:15입니다.
+
+6. **웹 UI 추론:** Chrome, Firefox, Edge 등 다양한 브라우저에서 호환되는 Gradio 기반의 사용하기 쉬운 웹 UI를 제공합니다.
+
+7. **GUI 추론:** PyQt6 그래픽 인터페이스를 제공하여 API 서버와 원활하게 작동합니다. Linux, Windows 및 macOS를 지원합니다. [GUI 참조](https://github.com/AnyaCoder/fish-speech-gui).
+
+8. **배포 친화적:** Linux, Windows, macOS에서 네이티브로 지원되는 추론 서버를 쉽게 설정할 수 있어 속도 손실을 최소화합니다.
+
+## 면책 조항
+
+이 코드베이스의 불법적 사용에 대해 어떠한 책임도 지지 않습니다. DMCA 및 관련 법률에 대한 로컬 법률을 참조하십시오.
+
+## 온라인 데모
+
+[Fish Audio](https://fish.audio)
+
+## 로컬 추론을 위한 빠른 시작
+
+[inference.ipynb](/inference.ipynb)
+
+## 영상
+
+#### V1.4 데모 영상: [Youtube](https://www.youtube.com/watch?v=Ghc8cJdQyKQ)
+
+## 문서
+
+- [English](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/)
+- [한국어](https://speech.fish.audio/ko/)
+
+## Samples (2024/10/02 V1.4)
+
+- [English](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
+- [한국어](https://speech.fish.audio/ko/samples/)
+
+## Credits
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## Sponsor
+
+
+
diff --git a/docs/README.pt-BR.md b/docs/README.pt-BR.md
new file mode 100644
index 0000000000000000000000000000000000000000..443617ce3025cc0a93857f4c064c987c478b3092
--- /dev/null
+++ b/docs/README.pt-BR.md
@@ -0,0 +1,114 @@
+
+
Fish Speech
+
+[English](../README.md) | [简体中文](README.zh.md) | **Portuguese** | [日本語](README.ja.md) | [한국어](README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+Este código-fonte e os modelos são publicados sob a licença CC-BY-NC-SA-4.0. Consulte [LICENSE](LICENSE) para mais detalhes.
+
+---
+
+## Funcionalidades
+
+1. **TTS Zero-shot & Few-shot**: Insira uma amostra vocal de 10 a 30 segundos para gerar saída de TTS de alta qualidade. **Para diretrizes detalhadas, veja [Melhores Práticas para Clonagem de Voz](https://docs.fish.audio/text-to-speech/voice-clone-best-practices).**
+
+2. **Suporte Multilíngue e Interlingual**: Basta copiar e colar o texto multilíngue na caixa de entrada—não se preocupe com o idioma. Atualmente suporta inglês, japonês, coreano, chinês, francês, alemão, árabe e espanhol.
+
+3. **Sem Dependência de Fonemas**: O modelo tem forte capacidade de generalização e não depende de fonemas para TTS. Ele pode lidar com textos em qualquer script de idioma.
+
+4. **Alta Precisão**: Alcança uma CER (Taxa de Erro de Caracteres) e WER (Taxa de Erro de Palavras) de cerca de 2% para textos de 5 minutos em inglês.
+
+5. **Rápido**: Com a aceleração fish-tech, o fator de tempo real é de aproximadamente 1:5 em um laptop Nvidia RTX 4060 e 1:15 em uma Nvidia RTX 4090.
+
+6. **Inferência WebUI**: Apresenta uma interface de usuário web baseada em Gradio, fácil de usar e compatível com navegadores como Chrome, Firefox e Edge.
+
+7. **Inferência GUI**: Oferece uma interface gráfica PyQt6 que funciona perfeitamente com o servidor API. Suporta Linux, Windows e macOS. [Veja o GUI](https://github.com/AnyaCoder/fish-speech-gui).
+
+8. **Fácil de Implantar**: Configura facilmente um servidor de inferência com suporte nativo para Linux, Windows e macOS, minimizando a perda de velocidade.
+
+## Isenção de Responsabilidade
+
+Não nos responsabilizamos por qualquer uso ilegal do código-fonte. Consulte as leis locais sobre DMCA (Digital Millennium Copyright Act) e outras leis relevantes em sua região.
+
+## Demonstração Online
+
+[Fish Audio](https://fish.audio)
+
+## Início Rápido de Inferência Local
+
+[inference.ipynb](/inference.ipynb)
+
+## Vídeos
+
+#### 1.4 Introdução: https://www.bilibili.com/video/BV1pu46eVEk7
+
+#### 1.2 Introdução: https://www.bilibili.com/video/BV1wz421B71D
+
+#### 1.1 Apresentação Técnica: https://www.bilibili.com/video/BV1zJ4m1K7cj
+
+## Documentação
+
+- [Inglês](https://speech.fish.audio/)
+- [Chinês](https://speech.fish.audio/zh/)
+- [Japonês](https://speech.fish.audio/ja/)
+- [Português (Brasil)](https://speech.fish.audio/pt/)
+
+## Exemplos
+
+- [Inglês](https://speech.fish.audio/samples/)
+- [Chinês](https://speech.fish.audio/zh/samples/)
+- [Japonês](https://speech.fish.audio/ja/samples/)
+- [Português (Brasil)](https://speech.fish.audio/pt/samples/)
+
+## Agradecimentos
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## Patrocinadores
+
+
+
diff --git a/docs/README.zh.md b/docs/README.zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee09a64262cd95df2c8b985ae5d5582dd84e9cac
--- /dev/null
+++ b/docs/README.zh.md
@@ -0,0 +1,109 @@
+
+
Fish Speech
+
+[English](../README.md) | **简体中文** | [Portuguese](README.pt-BR.md) | [日本語](README.ja.md) | [한국어](README.ko.md)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+此代码库及模型根据 CC-BY-NC-SA-4.0 许可证发布。请参阅 [LICENSE](LICENSE) 了解更多细节.
+
+---
+
+## 特性
+
+1. **零样本 & 小样本 TTS**:输入 10 到 30 秒的声音样本即可生成高质量的 TTS 输出。**详见 [语音克隆最佳实践指南](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)。**
+2. **多语言 & 跨语言支持**:只需复制并粘贴多语言文本到输入框中,无需担心语言问题。目前支持英语、日语、韩语、中文、法语、德语、阿拉伯语和西班牙语。
+3. **无音素依赖**:模型具备强大的泛化能力,不依赖音素进行 TTS,能够处理任何文字表示的语言。
+4. **高准确率**:在 5 分钟的英文文本上,达到了约 2% 的 CER(字符错误率)和 WER(词错误率)。
+5. **快速**:通过 fish-tech 加速,在 Nvidia RTX 4060 笔记本上的实时因子约为 1:5,在 Nvidia RTX 4090 上约为 1:15。
+6. **WebUI 推理**:提供易于使用的基于 Gradio 的网页用户界面,兼容 Chrome、Firefox、Edge 等浏览器。
+7. **GUI 推理**:提供 PyQt6 图形界面,与 API 服务器无缝协作。支持 Linux、Windows 和 macOS。[查看 GUI](https://github.com/AnyaCoder/fish-speech-gui)。
+8. **易于部署**:轻松设置推理服务器,原生支持 Linux、Windows 和 macOS,最大程度减少速度损失。
+
+## 免责声明
+
+我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+
+## 在线 DEMO
+
+[Fish Audio](https://fish.audio)
+
+## 快速开始本地推理
+
+[inference.ipynb](/inference.ipynb)
+
+## 视频
+
+#### 1.4 介绍: https://www.bilibili.com/video/BV1pu46eVEk7
+
+#### 1.2 介绍: https://www.bilibili.com/video/BV1wz421B71D
+
+#### 1.1 介绍: https://www.bilibili.com/video/BV1zJ4m1K7cj
+
+## 文档
+
+- [English](https://speech.fish.audio/)
+- [中文](https://speech.fish.audio/zh/)
+- [日本語](https://speech.fish.audio/ja/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/)
+
+## 例子 (2024/10/02 V1.4)
+
+- [English](https://speech.fish.audio/samples/)
+- [中文](https://speech.fish.audio/zh/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+- [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/)
+
+## 鸣谢
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## 赞助
+
+
+
diff --git a/docs/assets/figs/VS_1.jpg b/docs/assets/figs/VS_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..41a3f69992edcbbaa85a21695bdc33ff81dc10d6
Binary files /dev/null and b/docs/assets/figs/VS_1.jpg differ
diff --git a/docs/assets/figs/VS_1_pt-BR.png b/docs/assets/figs/VS_1_pt-BR.png
new file mode 100644
index 0000000000000000000000000000000000000000..d7cf5c85cb1cf98d9c716d03575eb0c74d53d572
Binary files /dev/null and b/docs/assets/figs/VS_1_pt-BR.png differ
diff --git a/docs/assets/figs/agent_gradio.png b/docs/assets/figs/agent_gradio.png
new file mode 100644
index 0000000000000000000000000000000000000000..02041bf6caa02f8c598b16bd8b495ef030dc3134
Binary files /dev/null and b/docs/assets/figs/agent_gradio.png differ
diff --git a/docs/assets/figs/diagram.png b/docs/assets/figs/diagram.png
new file mode 100644
index 0000000000000000000000000000000000000000..761b012f0a38ca6effc99eeb3bacfbfe11ffece0
Binary files /dev/null and b/docs/assets/figs/diagram.png differ
diff --git a/docs/assets/figs/diagrama.png b/docs/assets/figs/diagrama.png
new file mode 100644
index 0000000000000000000000000000000000000000..140f926ad9dc3e3e494872f1ca7b7e3f24994c3b
Binary files /dev/null and b/docs/assets/figs/diagrama.png differ
diff --git a/docs/assets/figs/logo-circle.png b/docs/assets/figs/logo-circle.png
new file mode 100644
index 0000000000000000000000000000000000000000..acfa4e3703e74909e4793020c5f3494f03dcb212
Binary files /dev/null and b/docs/assets/figs/logo-circle.png differ
diff --git a/docs/en/finetune.md b/docs/en/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..8b19a8df0ab1ef6d71fb2663beabc46c173e1b33
--- /dev/null
+++ b/docs/en/finetune.md
@@ -0,0 +1,128 @@
+# Fine-tuning
+
+Obviously, when you opened this page, you were not satisfied with the performance of the few-shot pre-trained model. You want to fine-tune a model to improve its performance on your dataset.
+
+In current version, you only need to finetune the 'LLAMA' part.
+
+## Fine-tuning LLAMA
+### 1. Prepare the dataset
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+You need to convert your dataset into the above format and place it under `data`. The audio file can have the extensions `.mp3`, `.wav`, or `.flac`, and the annotation file should have the extensions `.lab`.
+
+!!! info "Dataset Format"
+ The `.lab` annotation file only needs to contain the transcription of the audio, with no special formatting required. For example, if `hi.mp3` says "Hello, goodbye," then the `hi.lab` file would contain a single line of text: "Hello, goodbye."
+
+!!! warning
+ It's recommended to apply loudness normalization to the dataset. You can use [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) to do this.
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+
+### 2. Batch extraction of semantic tokens
+
+Make sure you have downloaded the VQGAN weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+You can then run the following command to extract semantic tokens:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ You can adjust `--num-workers` and `--batch-size` to increase extraction speed, but please make sure not to exceed your GPU memory limit.
+ For the VITS format, you can specify a file list using `--filelist xxx.list`.
+
+This command will create `.npy` files in the `data` directory, as shown below:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. Pack the dataset into protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+After the command finishes executing, you should see the `quantized-dataset-ft.protos` file in the `data` directory.
+
+### 4. Finally, fine-tuning with LoRA
+
+Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+Finally, you can start the fine-tuning by running the following command:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`.
+
+!!! note
+ For Windows users, you can use `trainer.strategy.process_group_backend=gloo` to avoid `nccl` issues.
+
+After training is complete, you can refer to the [inference](inference.md) section to generate speech.
+
+!!! info
+ By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability.
+ If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting.
+
+After training, you need to convert the LoRA weights to regular weights before performing inference.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.4 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.4-yth-lora/
+```
+!!! note
+ You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data.
diff --git a/docs/en/index.md b/docs/en/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..cb57b7757f9dfef5e617dadba0f0cbc2900dfbd2
--- /dev/null
+++ b/docs/en/index.md
@@ -0,0 +1,215 @@
+# Introduction
+
+
+
+!!! warning
+ We assume no responsibility for any illegal use of the codebase. Please refer to the local laws regarding DMCA (Digital Millennium Copyright Act) and other relevant laws in your area.
+ This codebase and all models are released under the CC-BY-NC-SA-4.0 license.
+
+
+
+
+
+## Requirements
+
+- GPU Memory: 4GB (for inference), 8GB (for fine-tuning)
+- System: Linux, Windows
+
+## Windows Setup
+
+Professional Windows users may consider using WSL2 or Docker to run the codebase.
+
+```bash
+# Create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Install pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# Install fish-speech
+pip3 install -e .
+
+# (Enable acceleration) Install triton-windows
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+Non-professional Windows users can consider the following basic methods to run the project without a Linux environment (with model compilation capabilities, i.e., `torch.compile`):
+
+1. Extract the project package.
+2. Click `install_env.bat` to install the environment.
+3. If you want to enable compilation acceleration, follow this step:
+ 1. Download the LLVM compiler from the following links:
+ - [LLVM-17.0.6 (Official Site Download)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6 (Mirror Site Download)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - After downloading `LLVM-17.0.6-win64.exe`, double-click to install, select an appropriate installation location, and most importantly, check the `Add Path to Current User` option to add the environment variable.
+ - Confirm that the installation is complete.
+ 2. Download and install the Microsoft Visual C++ Redistributable to solve potential .dll missing issues:
+ - [MSVC++ 14.40.33810.0 Download](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. Download and install Visual Studio Community Edition to get MSVC++ build tools and resolve LLVM's header file dependencies:
+ - [Visual Studio Download](https://visualstudio.microsoft.com/zh-hans/downloads/)
+ - After installing Visual Studio Installer, download Visual Studio Community 2022.
+ - As shown below, click the `Modify` button and find the `Desktop development with C++` option to select and download.
+ 4. Download and install [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+4. Double-click `start.bat` to open the training inference WebUI management interface. If needed, you can modify the `API_FLAGS` as prompted below.
+
+!!! info "Optional"
+
+ Want to start the inference WebUI?
+
+ Edit the `API_FLAGS.txt` file in the project root directory and modify the first three lines as follows:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "Optional"
+
+ Want to start the API server?
+
+ Edit the `API_FLAGS.txt` file in the project root directory and modify the first three lines as follows:
+
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "Optional"
+
+ Double-click `run_cmd.bat` to enter the conda/python command line environment of this project.
+
+## Linux Setup
+
+See [pyproject.toml](../../pyproject.toml) for details.
+```bash
+# Create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Install pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# (Ubuntu / Debian User) Install sox + ffmpeg
+apt install libsox-dev ffmpeg
+
+# (Ubuntu / Debian User) Install pyaudio
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# Install fish-speech
+pip3 install -e .[stable]
+```
+
+## macos setup
+
+If you want to perform inference on MPS, please add the `--device mps` flag.
+Please refer to [this PR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772) for a comparison of inference speeds.
+
+!!! warning
+ The `compile` option is not officially supported on Apple Silicon devices, so there is no guarantee that inference speed will improve.
+
+```bash
+# create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# install pytorch
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# install fish-speech
+pip install -e .[stable]
+```
+
+## Docker Setup
+
+1. Install NVIDIA Container Toolkit:
+
+ To use GPU for model training and inference in Docker, you need to install NVIDIA Container Toolkit:
+
+ For Ubuntu users:
+
+ ```bash
+ # Add repository
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # Install nvidia-container-toolkit
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # Restart Docker service
+ sudo systemctl restart docker
+ ```
+
+ For users of other Linux distributions, please refer to: [NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+
+2. Pull and run the fish-speech image
+
+ ```shell
+ # Pull the image
+ docker pull fishaudio/fish-speech:latest-dev
+ # Run the image
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # If you need to use a different port, please modify the -p parameter to YourPort:7860
+ ```
+
+3. Download model dependencies
+
+ Make sure you are in the terminal inside the docker container, then download the required `vqgan` and `llama` models from our huggingface repository.
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+ ```
+
+4. Configure environment variables and access WebUI
+
+ In the terminal inside the docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside docker.
+ Then in the terminal inside the docker container, enter `python tools/webui.py` to start the WebUI service.
+
+ If you're using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.
+
+ If it's deployed on a server, replace localhost with your server's IP.
+
+## Changelog
+
+- 2024/09/10: Updated Fish-Speech to 1.4 version, with an increase in dataset size and a change in the quantizer's n_groups from 4 to 8.
+- 2024/07/02: Updated Fish-Speech to 1.2 version, remove VITS Decoder, and greatly enhanced zero-shot ability.
+- 2024/05/10: Updated Fish-Speech to 1.1 version, implement VITS decoder to reduce WER and improve timbre similarity.
+- 2024/04/22: Finished Fish-Speech 1.0 version, significantly modified VQGAN and LLAMA models.
+- 2023/12/28: Added `lora` fine-tuning support.
+- 2023/12/27: Added `gradient checkpointing`, `causal sampling`, and `flash-attn` support.
+- 2023/12/19: Updated webui and HTTP API.
+- 2023/12/18: Updated fine-tuning documentation and related examples.
+- 2023/12/17: Updated `text2semantic` model, supporting phoneme-free mode.
+- 2023/12/13: Beta version released, includes VQGAN model and a language model based on LLAMA (phoneme support only).
+
+## Acknowledgements
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/en/inference.md b/docs/en/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..316c8a9c1e29353db42f3595387ea8ec088897f0
--- /dev/null
+++ b/docs/en/inference.md
@@ -0,0 +1,135 @@
+# Inference
+
+Inference supports command line, HTTP API, and web UI.
+
+!!! note
+    Overall, inference consists of several parts:
+
+ 1. Encode a given ~10 seconds of voice using VQGAN.
+ 2. Input the encoded semantic tokens and the corresponding text into the language model as an example.
+ 3. Given a new piece of text, let the model generate the corresponding semantic tokens.
+ 4. Input the generated semantic tokens into VITS / VQGAN to decode and generate the corresponding voice.
+
+## Command Line Inference
+
+Download the required `vqgan` and `llama` models from our Hugging Face repository.
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+### 1. Generate prompt from voice:
+
+!!! note
+ If you plan to let the model randomly choose a voice timbre, you can skip this step.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+You should get a `fake.npy` file.
+
+### 2. Generate semantic tokens from text:
+
+```bash
+python tools/llama/generate.py \
+ --text "The text you want to convert" \
+ --prompt-text "Your reference text" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4" \
+ --num-samples 2 \
+ --compile
+```
+
+This command will create a `codes_N` file in the working directory, where N is an integer starting from 0.
+
+!!! note
+ You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~500 tokens/second).
+ Correspondingly, if you do not plan to use acceleration, you can comment out the `--compile` parameter.
+
+!!! info
+ For GPUs that do not support bf16, you may need to use the `--half` parameter.
+
+### 3. Generate vocals from semantic tokens:
+
+#### VQGAN Decoder
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## HTTP API Inference
+
+We provide a HTTP API for inference. You can use the following command to start the server:
+
+```bash
+python -m tools.api \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+> If you want to speed up inference, you can add the `--compile` parameter.
+
+After that, you can view and test the API at http://127.0.0.1:8080/.
+
+Below is an example of sending a request using `tools/post_api.py`.
+
+```bash
+python -m tools.post_api \
+ --text "Text to be input" \
+ --reference_audio "Path to reference audio" \
+ --reference_text "Text content of the reference audio" \
+ --streaming True
+```
+
+The above command indicates synthesizing the desired audio according to the reference audio information and returning it in a streaming manner.
+
+The following example demonstrates that you can use **multiple** reference audio paths and reference audio texts at once. Separate them with spaces in the command.
+
+```bash
+python -m tools.post_api \
+ --text "Text to input" \
+ --reference_audio "reference audio path1" "reference audio path2" \
+ --reference_text "reference audio text1" "reference audio text2"\
+ --streaming False \
+ --output "generated" \
+ --format "mp3"
+```
+
+The above command synthesizes the desired `MP3` format audio based on the information from multiple reference audios and saves it as `generated.mp3` in the current directory.
+
+You can also use `--reference_id` (only one can be used) instead of `--reference_audio` and `--reference_text`, provided that you create a `references/` folder in the project root directory, which contains any audio and annotation text.
+The currently supported reference audio has a maximum total duration of 90 seconds.
+
+
+!!! info
+ To learn more about available parameters, you can use the command `python -m tools.post_api -h`
+
+## GUI Inference
+[Download client](https://github.com/AnyaCoder/fish-speech-gui/releases)
+
+## WebUI Inference
+
+You can start the WebUI using the following command:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> If you want to speed up inference, you can add the `--compile` parameter.
+
+!!! note
+ You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
+
+!!! note
+ You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI.
+
+Enjoy!
diff --git a/docs/en/samples.md b/docs/en/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..93f378407a4ef1b69e88baa33cc170f3571217bf
--- /dev/null
+++ b/docs/en/samples.md
@@ -0,0 +1,137 @@
+# Samples
+
+ver 1.4
+
+## Credits
+Special thanks to [Seed-TTS (2024)](https://bytedancespeech.github.io/seedtts_tech_report/) for providing the evaluation data for demonstration.
+
+All prompt audio is from the Seed-TTS effect demo page, and all generated audio is from the first generation of fish-speech version 1.4.
+
+## Zero-shot In-context Learning
+
+
+
+ Language
+ Prompt
+ Same Language Generation
+   Cross-lingual Generation
+
+
+
+
+ EN
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.
+ Your browser does not support the audio element. 顿时,气氛变得沉郁起来。乍看之下,一切的困扰仿佛都围绕在我身边。我皱着眉头,感受着那份压力,但我知道我不能放弃,不能认输。于是,我深吸一口气,心底的声音告诉我:“无论如何,都要冷静下来,重新开始。”
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me.
+ Your browser does not support the audio element. 处理家庭秘密从来都不是一件容易的事。然而,有时候,隐瞒是一种保护形式,旨在保护一些人免受残酷的真相伤害。有一天,我希望你能理解我行为背后的原因。在那之前,安娜,请容忍我。
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. The combinations of different textures and flavors create a perfect harmony. The succulence of the steak, the tartness of the cranberries, the crunch of pine nuts, and creaminess of blue cheese make it a truly delectable delight. Enjoy your culinary adventure!
+ Your browser does not support the audio element. 听着你的话,我心里五味杂陈。虽然我愿意一直在你身边,承担一切不幸,但我知道只有让你自己面对,才能真正让你变得更强大。所以,你要记得,无论面对何种困难,都请你坚强,我会在心里一直支持你的。
+
+
+ ZH
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
+ Your browser does not support the audio element. Suddenly, there was a burst of laughter beside me. I looked at them, stood up straight with high spirit, shook the slightly fleshy arms, and smiled lightly, saying, "The flesh on my body is to hide my bursting charm. Otherwise, wouldn't it scare you?"
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 他闭上眼睛,期望这一切都能过去。然而,当他再次睁开眼睛,眼前的景象让他不禁倒吸一口气。雾气中出现的禁闭岛,陌生又熟悉,充满未知的危险。他握紧拳头,心知他的生活即将发生翻天覆地的改变。
+ Your browser does not support the audio element. He closed his eyes, expecting that all of this could pass. However, when he opened his eyes again, the sight in front of him made him couldn't help but take a deep breath. The closed island that appeared in the fog, strange and familiar, was full of unknown dangers. He tightened his fist, knowing that his life was about to undergo earth-shaking changes.
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 顿时,气氛变得沉郁起来。乍看之下,一切的困扰仿佛都围绕在我身边。我皱着眉头,感受着那份压力,但我知道我不能放弃,不能认输。于是,我深吸一口气,心底的声音告诉我:“无论如何,都要冷静下来,重新开始。”
+ Your browser does not support the audio element. Suddenly, the atmosphere became gloomy. At first glance, all the troubles seemed to surround me. I frowned, feeling that pressure, but I know I can't give up, can't admit defeat. So, I took a deep breath, and the voice in my heart told me, "Anyway, must calm down and start again."
+
+
+
+
+
+
+## Speaker Fine-tune
+
+
+
+
+
+ Text
+ Generated
+
+
+
+
+ Speaker1
+ 好呀,哈哈哈哈哈,喜欢笑的人运气都不会差哦,希望你每天笑口常开~
+ Your browser does not support the audio element.
+
+
+ 哇!恭喜你中了大乐透,八百万可真不少呢!有什么特别的计划或想法吗?
+ Your browser does not support the audio element.
+
+
+ 哼,你这么问是想请本小姐吃饭吗?如果对象是你的话,那也不是不可以。
+ Your browser does not support the audio element.
+
+
+ Speaker2
+ 是呀,他还想换个地球仪哈哈哈,看来给你积累了一些快乐值了,你还想不想再听一个其他的笑话呀?
+ Your browser does not support the audio element.
+
+
+ 嘿嘿,你是不是也想拥有甜甜的恋爱呢?《微微一笑很倾城》是你的不二选择,男女主是校花校草类型,他们通过游戏结识,再到两人见面,全程没有一点误会,真的齁甜,想想都忍不住“姨妈笑”~
+ Your browser does not support the audio element.
+
+
+ 小傻瓜,嗯……算是个很可爱很亲切的名字,有点“独特”哦,不过我有些好奇,你为什么会给我选这个昵称呢?
+ Your browser does not support the audio element.
+
+
+
+
+
+## Content Editing
+
+
+
+ Language
+ Original Text
+ Original Audio
+ Target Text
+ Edited Audio
+
+
+
+ EN
+ They can't order me to stop dreaming. If you dream a thing more than once, it's sure to come true. Have faith in your dreams, and someday your rainbow will come shining through.
+ Your browser does not support the audio element.
+ They can't require me to stop imagining. If you envision a thing more than once, it's bound to come about . Have trust in your visions , and someday your radiance will come beaming through.
+ Your browser does not support the audio element.
+
+
+ Are you familiar with it? Slice the steak and place the strips on top, then garnish with the dried cranberries, pine nuts, and blue cheese. I wonder how people rationalise the decision?
+ Your browser does not support the audio element.
+ Are you acquainted with it? Cut the pork and place the strips on top, then garnish with the dried cherries, almonds, and feta cheese. I query how people justify the choice?
+ Your browser does not support the audio element.
+
+
+ ZH
+ 自古以来,庸君最怕党政了,可圣君他就不怕,不但不怕,反能利用。要我说,你就让明珠索额图互相争宠,只要你心里明白,左右逢源,你就能立于不败之地。
+ Your browser does not support the audio element.
+ 从古至今 ,庸君最怕朝纲了 ,可明 君他就不怕,不但不怕,反能借助 。要我说,你就让李四张三 互相争宠,只要你心里清楚 ,左右周旋 ,你就能处 于不败之境 。
+ Your browser does not support the audio element.
+
+
+ 对,这就是我,万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。
+ Your browser does not support the audio element.
+ 对,这就是我,众人尊崇 的太白金星 ,虽然有点娃娃脸 ,但也遮 不住我迷人 的魅力。
+ Your browser does not support the audio element.
+
+
+
diff --git a/docs/en/start_agent.md b/docs/en/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..d252490463979c27ce8d376ccc125971555317e7
--- /dev/null
+++ b/docs/en/start_agent.md
@@ -0,0 +1,77 @@
+# Start Agent
+
+## Requirements
+
+- GPU memory: At least 8GB (under quantization); 16GB or more is recommended.
+- Disk usage: 10GB
+
+## Download Model
+
+You can get the model by:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+Put them in the 'checkpoints' folder.
+
+You also need the fish-speech model which you can download instructed by [inference](inference.md).
+
+So there will be 2 folders in the checkpoints directory:
+
+The `checkpoints/fish-speech-1.4` and `checkpoints/fish-agent-v0.1-3b`
+
+## Environment Prepare
+
+If you already have Fish-speech, you can use it directly by adding the following instruction:
+```bash
+pip install cachetools
+```
+
+!!! note
+ Please use the Python version below 3.12 for compile.
+
+If you don't have it, please use the commands below to build your environment:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## Launch The Agent Demo.
+
+To build fish-agent, please use the command below under the main folder:
+
+```bash
+python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+The `--compile` arg is only supported on Python < 3.12, and it greatly speeds up token generation.
+
+Note that compilation does not happen right away; it is triggered on first use.
+
+Then open another terminal and use the command:
+
+```bash
+python -m tools.e2e_webui
+```
+
+This will create a Gradio WebUI on the device.
+
+When you first use the model, it will take a short time to compile (if `--compile` is enabled), so please wait patiently.
+
+## Gradio Webui
+
+
+
+
+Have a good time!
+
+## Performance
+
+In our tests, a laptop with an RTX 4060 can just barely run it, but it is very stretched, reaching only about 8 tokens/s. An RTX 4090 reaches around 95 tokens/s with compilation, which is what we recommend.
+
+# About Agent
+
+The demo is an early alpha test version, the inference speed needs to be optimised, and there are a lot of bugs waiting to be fixed. If you've found a bug or want to fix it, we'd be very happy to receive an issue or a pull request.
diff --git a/docs/ja/finetune.md b/docs/ja/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..68db8cbd464064f33d789bcc2b25534fc99785ae
--- /dev/null
+++ b/docs/ja/finetune.md
@@ -0,0 +1,128 @@
+# 微調整
+
+明らかに、このページを開いたとき、few-shot 事前トレーニングモデルのパフォーマンスに満足していなかったことでしょう。データセット上でのパフォーマンスを向上させるためにモデルを微調整したいと考えています。
+
+現在のバージョンでは、「LLAMA」部分のみを微調整する必要があります。
+
+## LLAMAの微調整
+### 1. データセットの準備
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+データセットを上記の形式に変換し、「data」ディレクトリに配置する必要があります。音声ファイルの拡張子は「.mp3」、「.wav」、または「.flac」にすることができ、注釈ファイルの拡張子は「.lab」にする必要があります。
+
+!!! info
+ 標準ファイル `.lab` には、音声の転写テキストのみを含め、特別なフォーマットは必要ありません。例えば、`hi.mp3` で「こんにちは、さようなら」と言っている場合、`hi.lab` ファイルには「こんにちは、さようなら」という一行のテキストを含めるだけです。
+
+!!! warning
+ データセットにラウドネス正規化を適用することをお勧めします。これを行うには、[fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) を使用できます。
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+
+### 2. セマンティックトークンのバッチ抽出
+
+VQGANの重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+次に、次のコマンドを実行してセマンティックトークンを抽出できます。
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ `--num-workers` と `--batch-size` を調整して抽出速度を上げることができますが、GPUメモリの制限を超えないようにしてください。
+ VITS形式の場合、`--filelist xxx.list` を使用してファイルリストを指定できます。
+
+このコマンドは、`data`ディレクトリに`.npy`ファイルを作成します。以下のように表示されます。
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. データセットをprotobufにパックする
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+コマンドの実行が完了すると、`data`ディレクトリに`quantized-dataset-ft.protos`ファイルが表示されます。
+
+### 4. 最後に、LoRAを使用して微調整する
+
+同様に、`LLAMA`の重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+最後に、次のコマンドを実行して微調整を開始できます。
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ `fish_speech/configs/text2semantic_finetune.yaml` を変更して、`batch_size`、`gradient_accumulation_steps` などのトレーニングパラメータを変更し、GPUメモリに適合させることができます。
+
+!!! note
+ Windowsユーザーの場合、`trainer.strategy.process_group_backend=gloo` を使用して `nccl` の問題を回避できます。
+
+トレーニングが完了したら、[推論](inference.md)セクションを参照し、音声を生成します。
+
+!!! info
+ デフォルトでは、モデルは話者の発話パターンのみを学習し、音色は学習しません。音色の安定性を確保するためにプロンプトを使用する必要があります。
+ 音色を学習したい場合は、トレーニングステップ数を増やすことができますが、これにより過学習が発生する可能性があります。
+
+トレーニングが完了したら、推論を行う前にLoRAの重みを通常の重みに変換する必要があります。
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.4 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.4-yth-lora/
+```
+!!! note
+ 他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、分布外(OOD)データでより良いパフォーマンスを発揮します。
diff --git a/docs/ja/index.md b/docs/ja/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..7c5ad2ec59f1ea2d5b54cf98cf6135099ee375fe
--- /dev/null
+++ b/docs/ja/index.md
@@ -0,0 +1,214 @@
+# Fish Speech の紹介
+
+
+
+!!! warning
+ 私たちは、コードベースの違法な使用について一切の責任を負いません。お住まいの地域の DMCA(デジタルミレニアム著作権法)およびその他の関連法を参照してください。
+ このコードベースとモデルは、CC-BY-NC-SA-4.0 ライセンス下でリリースされています。
+
+
+
+
+
+## 要件
+
+- GPU メモリ: 4GB(推論用)、8GB(ファインチューニング用)
+- システム: Linux、Windows
+
+## Windowsセットアップ
+
+プロフェッショナルなWindowsユーザーは、WSL2またはDockerを使用してコードベースを実行することを検討してください。
+
+```bash
+# Python 3.10の仮想環境を作成(virtualenvも使用可能)
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# PyTorchをインストール
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# fish-speechをインストール
+pip3 install -e .
+
+# (アクセラレーションを有効にする) triton-windowsをインストール
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+非プロフェッショナルなWindowsユーザーは、Linux環境なしでプロジェクトを実行するための以下の基本的な方法を検討できます(モデルコンパイル機能、つまり`torch.compile`を使用可能):
+
+1. プロジェクトパッケージを解凍する。
+2. `install_env.bat`をクリックして環境をインストールする。
+3. コンパイルアクセラレーションを有効にしたい場合は、次のステップに従ってください:
+ 1. 以下のリンクからLLVMコンパイラをダウンロード:
+ - [LLVM-17.0.6(公式サイトのダウンロード)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6(ミラーサイトのダウンロード)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - `LLVM-17.0.6-win64.exe`をダウンロードした後、ダブルクリックしてインストールし、適切なインストール場所を選択し、最も重要なのは`Add Path to Current User`オプションを選択して環境変数を追加することです。
+ - インストールが完了したことを確認する。
+ 2. 欠落している .dll の問題を解決するため、Microsoft Visual C++ Redistributable をダウンロードしてインストールする:
+ - [MSVC++ 14.40.33810.0 ダウンロード](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. Visual Studio Community Editionをダウンロードして、MSVC++ビルドツールを取得し、LLVMのヘッダーファイルの依存関係を解決する:
+ - [Visual Studio ダウンロード](https://visualstudio.microsoft.com/ja/downloads/)
+ - Visual Studio Installerをインストールした後、Visual Studio Community 2022をダウンロード。
+ - 下記のように、`Modify`ボタンをクリックし、`C++によるデスクトップ開発`オプションを選択してダウンロード。
+ -
+ 4. [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)をダウンロードしてインストールする。
+4. `start.bat`をダブルクリックして、トレーニング推論WebUI管理インターフェースを開きます。必要に応じて、以下に示すように`API_FLAGS`を修正できます。
+
+
+!!! info "オプション"
+ 推論WebUIを起動しますか?
+ プロジェクトのルートディレクトリにある `API_FLAGS.txt` ファイルを編集し、最初の3行を次のように変更します:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "オプション"
+ APIサーバーを起動しますか?
+ プロジェクトのルートディレクトリにある `API_FLAGS.txt` ファイルを編集し、最初の3行を次のように変更します:
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "オプション"
+ `run_cmd.bat` をダブルクリックして、このプロジェクトの conda/python コマンドライン環境に入ります。
+
+
+
+## Linux セットアップ
+
+詳細については、[pyproject.toml](../../pyproject.toml) を参照してください。
+```bash
+# python 3.10の仮想環境を作成します。virtualenvも使用できます。
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# pytorchをインストールします。
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# (Ubuntu / Debianユーザー) sox + ffmpegをインストールします。
+apt install libsox-dev ffmpeg
+
+# (Ubuntu / Debianユーザー) pyaudio をインストールします。
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# fish-speechをインストールします。
+pip3 install -e .[stable]
+
+```
+
+## macOS セットアップ
+
+推論をMPS上で行う場合は、`--device mps`フラグを追加してください。
+推論速度の比較は[こちらのPR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772)を参考にしてください。
+
+!!! warning
+ AppleSiliconのデバイスでは、compileオプションに正式に対応していませんので、推論速度が向上する保証はありません。
+
+```bash
+# create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# install pytorch
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# install fish-speech
+pip install -e .[stable]
+```
+
+## Docker セットアップ
+
+1. NVIDIA Container Toolkit のインストール:
+
+ Docker で GPU を使用してモデルのトレーニングと推論を行うには、NVIDIA Container Toolkit をインストールする必要があります:
+
+ Ubuntu ユーザーの場合:
+
+ ```bash
+ # リポジトリの追加
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # nvidia-container-toolkit のインストール
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # Docker サービスの再起動
+ sudo systemctl restart docker
+ ```
+
+ 他の Linux ディストリビューションを使用している場合は、以下のインストールガイドを参照してください:[NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)。
+
+2. fish-speech イメージのプルと実行
+
+ ```shell
+ # イメージのプル
+ docker pull fishaudio/fish-speech:latest-dev
+ # イメージの実行
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # 他のポートを使用する場合は、-p パラメータを YourPort:7860 に変更してください
+ ```
+
+3. モデルの依存関係のダウンロード
+
+ Docker コンテナ内のターミナルにいることを確認し、huggingface リポジトリから必要な `vqgan` と `llama` モデルをダウンロードします。
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+ ```
+
+4. 環境変数の設定と WebUI へのアクセス
+
+ Docker コンテナ内のターミナルで、`export GRADIO_SERVER_NAME="0.0.0.0"` と入力して、外部から Docker 内の gradio サービスにアクセスできるようにします。
+ 次に、Docker コンテナ内のターミナルで `python tools/webui.py` と入力して WebUI サービスを起動します。
+
+ WSL または MacOS の場合は、[http://localhost:7860](http://localhost:7860) にアクセスして WebUI インターフェースを開くことができます。
+
+ サーバーにデプロイしている場合は、localhost をサーバーの IP に置き換えてください。
+
+## 変更履歴
+
+- 2024/09/10: Fish-Speech を Ver.1.4 に更新し、データセットのサイズを増加させ、quantizer n_groups を 4 から 8 に変更しました。
+- 2024/07/02: Fish-Speech を Ver.1.2 に更新し、VITS デコーダーを削除し、ゼロショット能力を大幅に強化しました。
+- 2024/05/10: Fish-Speech を Ver.1.1 に更新し、VITS デコーダーを実装して WER を減少させ、音色の類似性を向上させました。
+- 2024/04/22: Fish-Speech Ver.1.0 を完成させ、VQGAN および LLAMA モデルを大幅に修正しました。
+- 2023/12/28: `lora`微調整サポートを追加しました。
+- 2023/12/27: `gradient checkpointing`、`causal sampling`、および`flash-attn`サポートを追加しました。
+- 2023/12/19: webui および HTTP API を更新しました。
+- 2023/12/18: 微調整ドキュメントおよび関連例を更新しました。
+- 2023/12/17: `text2semantic`モデルを更新し、自由音素モードをサポートしました。
+- 2023/12/13: ベータ版をリリースし、VQGAN モデルおよび LLAMA に基づく言語モデル(音素のみサポート)を含みます。
+
+## 謝辞
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/ja/inference.md b/docs/ja/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..c4e6145036be32ecf0dc5afe0658acab125283c2
--- /dev/null
+++ b/docs/ja/inference.md
@@ -0,0 +1,114 @@
+# 推論
+
+推論は、コマンドライン、HTTP API、および Web UI をサポートしています。
+
+!!! note
+ 全体として、推論は次のいくつかの部分で構成されています:
+
+ 1. VQGANを使用して、与えられた約10秒の音声をエンコードします。
+ 2. エンコードされたセマンティックトークンと対応するテキストを例として言語モデルに入力します。
+ 3. 新しいテキストが与えられた場合、モデルに対応するセマンティックトークンを生成させます。
+ 4. 生成されたセマンティックトークンをVITS / VQGANに入力してデコードし、対応する音声を生成します。
+
+## コマンドライン推論
+
+必要な`vqgan`および`llama`モデルを Hugging Face リポジトリからダウンロードします。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+### 1. 音声からプロンプトを生成する:
+
+!!! note
+ モデルにランダムに音声の音色を選ばせる場合、このステップをスキップできます。
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+`fake.npy`ファイルが生成されるはずです。
+
+### 2. テキストからセマンティックトークンを生成する:
+
+```bash
+python tools/llama/generate.py \
+ --text "変換したいテキスト" \
+ --prompt-text "参照テキスト" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4" \
+ --num-samples 2 \
+ --compile
+```
+
+このコマンドは、作業ディレクトリに`codes_N`ファイルを作成します。ここで、N は 0 から始まる整数です。
+
+!!! note
+ `--compile`を使用して CUDA カーネルを融合し、より高速な推論を実現することができます(約 30 トークン/秒 -> 約 500 トークン/秒)。
+ それに対応して、加速を使用しない場合は、`--compile`パラメータをコメントアウトできます。
+
+!!! info
+ bf16 をサポートしていない GPU の場合、`--half`パラメータを使用する必要があるかもしれません。
+
+### 3. セマンティックトークンから音声を生成する:
+
+#### VQGAN デコーダー
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## HTTP API 推論
+
+推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます:
+
+```bash
+python -m tools.api \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+> 推論を高速化したい場合は、`--compile` パラメータを追加できます。
+
+その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。
+
+以下は、`tools/post_api.py` を使用してリクエストを送信する例です。
+
+```bash
+python -m tools.post_api \
+ --text "入力するテキスト" \
+ --reference_audio "参照音声へのパス" \
+ --reference_text "参照音声テキスト" \
+ --streaming True
+```
+
+上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。
+
+!!! info
+ 使用可能なパラメータの詳細については、コマンド` python -m tools.post_api -h `を使用してください
+
+## WebUI 推論
+
+次のコマンドを使用して WebUI を起動できます:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> 推論を高速化したい場合は、`--compile` パラメータを追加できます。
+
+!!! note
+ ラベルファイルと参照音声ファイルをメインディレクトリの `references` フォルダ(自分で作成する必要があります)に事前に保存しておくことで、WebUI で直接呼び出すことができます。
+
+!!! note
+ Gradio 環境変数(`GRADIO_SHARE`、`GRADIO_SERVER_PORT`、`GRADIO_SERVER_NAME`など)を使用して WebUI を構成できます。
+
+お楽しみください!
diff --git a/docs/ja/samples.md b/docs/ja/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..13907736daed5961c66f56524bb0fa7068b76f3a
--- /dev/null
+++ b/docs/ja/samples.md
@@ -0,0 +1,225 @@
+# サンプル
+
+v1.4デモは[こちら](https://speech.fish.audio/samples/)に更新されています
+
+v1.2のサンプルは[Bilibili](https://www.bilibili.com/video/BV1wz421B71D/)で利用可能です。
+
+以下のサンプルはv1.1モデルからのものです。
+
+## 中国語の文1
+```
+人間灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ナヒーダ (原神)
+
+
+
+
+ 鍾離 (原神)
+
+
+
+
+ フリナ (原神)
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+
+## 中国語の文2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ナヒーダ (原神)
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+
+## 中国語の文3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+## 英語の文1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+## 英語の文2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+## 日本語の文1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+## 日本語の文2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
diff --git a/docs/ja/start_agent.md b/docs/ja/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..10cabed6aece755bd25b0cfd2db2bf3540b98975
--- /dev/null
+++ b/docs/ja/start_agent.md
@@ -0,0 +1,80 @@
+# エージェントの開始
+
+!!! note
+ もしあなたがネイティブ・スピーカーで、翻訳に問題があるとお感じでしたら、issueかpull requestをお送りください!
+
+## 要件
+
+- GPUメモリ: 最低8GB(量子化使用時)、16GB以上推奨
+- ディスク使用量: 10GB
+
+## モデルのダウンロード
+
+以下のコマンドでモデルを取得できます:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+これらを'checkpoints'フォルダに配置してください。
+
+また、[inference](inference.md)の手順に従ってfish-speechモデルもダウンロードする必要があります。
+
+checkpointsには2つのフォルダが必要です。
+
+`checkpoints/fish-speech-1.4`と`checkpoints/fish-agent-v0.1-3b`です。
+
+## 環境準備
+
+すでにFish-speechをお持ちの場合は、以下の指示を追加するだけで直接使用できます:
+```bash
+pip install cachetools
+```
+
+!!! note
+ コンパイルにはPythonバージョン3.12未満を使用してください。
+
+お持ちでない場合は、以下のコマンドで環境を構築してください:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## エージェントデモの起動
+
+fish-agentを構築するには、メインフォルダで以下のコマンドを使用してください:
+
+```bash
+python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+`--compile`引数はPython < 3.12でのみサポートされており、トークン生成を大幅に高速化します。
+
+一度にコンパイルは行われません(覚えておいてください)。
+
+次に、別のターミナルを開いて以下のコマンドを使用します:
+
+```bash
+python -m tools.e2e_webui
+```
+
+これにより、デバイス上にGradio WebUIが作成されます。
+
+モデルを初めて使用する際は、(`--compile`がTrueの場合)しばらくコンパイルが行われますので、お待ちください。
+
+## Gradio Webui
+
+
+
+
+お楽しみください!
+
+## パフォーマンス
+
+テストでは、4060搭載のラップトップではかろうじて動作しますが、非常に厳しい状態で、約8トークン/秒程度です。4090ではコンパイル時に約95トークン/秒で、これが推奨環境です。
+
+# エージェントについて
+
+このデモは初期アルファテストバージョンで、推論速度の最適化が必要で、修正を待つバグが多数あります。バグを発見した場合や修正したい場合は、issueやプルリクエストをいただけると大変嬉しく思います。
diff --git a/docs/ko/finetune.md b/docs/ko/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..a13d5e51edb4a11767f55e5c40e0d8434f27d2d7
--- /dev/null
+++ b/docs/ko/finetune.md
@@ -0,0 +1,128 @@
+# 파인튜닝
+
+이 페이지를 열었다는 것은, 사전 학습된 퓨샷(Few-shot) 모델의 성능에 만족하지 못했다는 의미일 것입니다. 데이터셋의 성능을 향상시키기 위해 모델을 파인튜닝하고 싶으시겠죠.
+
+현재 버전에서는 'LLAMA' 부분만 파인튜닝하시면 됩니다.
+
+## LLAMA 파인튜닝
+### 1. 데이터셋 준비
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+위와 같은 형식으로 데이터셋을 변환하여 `data` 디렉토리 안에 배치하세요. 오디오 파일의 확장자는 `.mp3`, `.wav`, `.flac` 중 하나여야 하며, 주석 파일은 `.lab` 확장자를 사용해야 합니다.
+
+!!! info "데이터셋 형식"
+ `.lab` 주석 파일은 오디오의 전사 내용만 포함하면 되며, 특별한 형식이 필요하지 않습니다. 예를 들어, `hi.mp3`에서 "Hello, goodbye"라는 대사를 말한다면, `hi.lab` 파일에는 "Hello, goodbye"라는 한 줄의 텍스트만 있어야 합니다.
+
+!!! warning
+ 데이터셋에 대한 음량 정규화(loudness normalization)를 적용하는 것이 좋습니다. 이를 위해 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess)를 사용할 수 있습니다.
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+### 2. 시맨틱 토큰 배치 추출
+
+VQGAN 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+이후 시맨틱 토큰을 추출하기 위해 아래 명령어를 실행하세요:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ 추출 속도를 높이기 위해 `--num-workers`와 `--batch-size` 값을 조정할 수 있지만, GPU 메모리 한도를 초과하지 않도록 주의하세요.
+ VITS 형식의 경우, `--filelist xxx.list`를 사용하여 파일 목록을 지정할 수 있습니다.
+
+이 명령을 실행하면 `data` 디렉토리 안에 `.npy` 파일이 생성됩니다. 다음과 같이 표시됩니다:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. 데이터셋을 protobuf로 패킹
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+명령이 완료되면 `data` 디렉토리 안에 `quantized-dataset-ft.protos` 파일이 생성됩니다.
+
+### 4. 마지막으로, LoRA를 이용한 파인튜닝
+
+마찬가지로, `LLAMA` 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+마지막으로, 아래 명령어를 실행하여 파인튜닝을 시작할 수 있습니다:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ `batch_size`, `gradient_accumulation_steps` 등의 학습 매개변수를 GPU 메모리에 맞게 조정하려면 `fish_speech/configs/text2semantic_finetune.yaml` 파일을 수정할 수 있습니다.
+
+!!! note
+ Windows 사용자의 경우, `nccl` 문제를 피하려면 `trainer.strategy.process_group_backend=gloo`를 사용할 수 있습니다.
+
+훈련이 완료되면 [추론](inference.md) 섹션을 참고하여 음성을 생성할 수 있습니다.
+
+!!! info
+ 기본적으로 모델은 화자의 말하는 패턴만 학습하고 음색은 학습하지 않습니다. 음색의 안정성을 위해 프롬프트를 사용해야 합니다.
+ 음색을 학습하려면 훈련 단계를 늘릴 수 있지만, 이는 과적합의 위험을 초래할 수 있습니다.
+
+훈련이 끝나면 LoRA 가중치를 일반 가중치로 변환한 후에 추론을 수행해야 합니다.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.4 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.4-yth-lora/
+```
+
+!!! note
+ 다른 체크포인트도 시도해 볼 수 있습니다. 요구 사항에 맞는 가장 초기 체크포인트를 사용하는 것이 좋습니다. 이들은 종종 분포 밖(OOD) 데이터에서 더 좋은 성능을 발휘합니다.
diff --git a/docs/ko/index.md b/docs/ko/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..6af58535f94ca9385262b3dead14826b540fe170
--- /dev/null
+++ b/docs/ko/index.md
@@ -0,0 +1,215 @@
+# 소개
+
+
+
+!!! warning
+ 이 코드베이스의 불법적인 사용에 대해서는 책임을 지지 않습니다. DMCA(Digital Millennium Copyright Act) 및 해당 지역의 관련 법률을 참조하십시오.
+ 이 코드베이스와 모든 모델은 CC-BY-NC-SA-4.0 라이선스에 따라 배포됩니다.
+
+
+
+
+
+## 요구 사항
+
+- GPU 메모리: 4GB (추론용), 8GB (파인튜닝용)
+- 시스템: Linux, Windows
+
+## Windows 설정
+
+고급 Windows 사용자는 WSL2 또는 Docker를 사용하여 코드베이스를 실행하는 것을 고려할 수 있습니다.
+
+```bash
+# 파이썬 3.10 가상 환경 생성, virtualenv도 사용할 수 있습니다.
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# pytorch 설치
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# fish-speech 설치
+pip3 install -e .
+
+# (가속 활성화) triton-windows 설치
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+비전문 Windows 사용자는 Linux 환경 없이 프로젝트를 실행할 수 있는 다음 기본 방법을 고려할 수 있습니다 (모델 컴파일 기능 포함, 즉 `torch.compile`):
+
+1. 프로젝트 패키지 추출.
+2. `install_env.bat`을 클릭하여 환경 설치.
+3. 컴파일 가속을 활성화하려면 아래 단계를 따르세요:
+ 1. LLVM 컴파일러 다운로드:
+ - [LLVM-17.0.6 (공식 사이트)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6 (미러 사이트)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - `LLVM-17.0.6-win64.exe`를 다운로드 후 더블클릭하여 설치하고, 설치 경로 선택 시 `Add Path to Current User` 옵션을 체크하여 환경 변수를 추가합니다.
+ - 설치가 완료되었는지 확인합니다.
+ 2. Microsoft Visual C++ 재배포 가능 패키지를 다운로드하여 .dll 누락 문제 해결:
+ - [MSVC++ 14.40.33810.0 다운로드](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. Visual Studio Community Edition을 다운로드하여 LLVM의 헤더 파일 의존성을 해결:
+ - [Visual Studio 다운로드](https://visualstudio.microsoft.com/zh-hans/downloads/)
+ - Visual Studio Installer를 설치한 후 Visual Studio Community 2022를 다운로드.
+ - `Desktop development with C++` 옵션을 선택하여 설치.
+ 4. [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64) 다운로드 및 설치.
+4. `start.bat`을 더블 클릭하여 훈련 추론 WebUI 관리 인터페이스를 엽니다. 필요한 경우 아래 지침에 따라 `API_FLAGS`를 수정할 수 있습니다.
+
+!!! info "Optional"
+
+ 추론을 위해 WebUI를 사용하고자 하시나요?
+
+ 프로젝트 루트 디렉토리의 `API_FLAGS.txt` 파일을 편집하고 첫 세 줄을 아래와 같이 수정하세요:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "Optional"
+
+ API 서버를 시작하고 싶으신가요?
+
+ 프로젝트 루트 디렉토리의 `API_FLAGS.txt` 파일을 편집하고 첫 세 줄을 아래와 같이 수정하세요:
+
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "Optional"
+
+ `run_cmd.bat`을 더블 클릭하여 이 프로젝트의 conda/python 명령줄 환경에 진입할 수 있습니다.
+
+## Linux 설정
+
+[pyproject.toml](../../pyproject.toml)에서 자세한 내용을 확인하세요.
+```bash
+# 파이썬 3.10 가상 환경 생성, virtualenv도 사용할 수 있습니다.
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# (Ubuntu / Debian 사용자) sox + ffmpeg 설치
+apt install libsox-dev ffmpeg
+
+# (Ubuntu / Debian 사용자) pyaudio 설치
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# pytorch 설치
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# fish-speech 설치
+pip3 install -e .[stable]
+```
+
+## macOS 설정
+
+MPS에서 추론을 수행하려면 `--device mps` 플래그를 추가하세요.
+추론 속도 비교는 [이 PR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772)을 참조하십시오.
+
+!!! warning
+ Apple Silicon 장치에서는 `compile` 옵션이 공식적으로 지원되지 않으므로 추론 속도가 향상된다는 보장은 없습니다.
+
+```bash
+# 파이썬 3.10 가상 환경 생성, virtualenv도 사용할 수 있습니다.
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# pytorch 설치
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# fish-speech 설치
+pip install -e .[stable]
+```
+
+## Docker 설정
+
+1. NVIDIA Container Toolkit 설치:
+
+ Docker에서 모델 훈련 및 추론에 GPU를 사용하려면 NVIDIA Container Toolkit을 설치해야 합니다:
+
+ Ubuntu 사용자:
+
+ ```bash
+ # 저장소 추가
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # nvidia-container-toolkit 설치
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # Docker 서비스 재시작
+ sudo systemctl restart docker
+ ```
+
+ 다른 Linux 배포판 사용자는: [NVIDIA Container Toolkit 설치 가이드](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)를 참조하십시오.
+
+2. fish-speech 이미지 가져오기 및 실행
+
+ ```bash
+ # 이미지 가져오기
+ docker pull fishaudio/fish-speech:latest-dev
+ # 이미지 실행
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # 다른 포트를 사용하려면 -p 매개변수를 YourPort:7860으로 수정하세요
+ ```
+
+3. 모델 종속성 다운로드
+
+ Docker 컨테이너 내부의 터미널에서 아래 명령어를 사용하여 필요한 `vqgan` 및 `llama` 모델을 Huggingface 리포지토리에서 다운로드합니다.
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+ ```
+
+4. 환경 변수 설정 및 WebUI 접근
+
+ Docker 컨테이너 내부의 터미널에서 `export GRADIO_SERVER_NAME="0.0.0.0"`를 입력하여 Docker 내부에서 Gradio 서비스에 외부 접근을 허용합니다.
+ 이후, 터미널에서 `python tools/webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다.
+
+ WSL 또는 macOS를 사용하는 경우 [http://localhost:7860](http://localhost:7860)에서 WebUI 인터페이스를 열 수 있습니다.
+
+ 서버에 배포된 경우, localhost를 서버의 IP로 교체하세요.
+
+## 변경 사항
+
+- 2024/09/10: Fish-Speech 1.4 버전으로 업데이트, 데이터셋 크기 증가 및 양자화기의 n_groups를 4에서 8로 변경.
+- 2024/07/02: Fish-Speech 1.2 버전으로 업데이트, VITS 디코더 제거 및 제로샷 능력 크게 향상.
+- 2024/05/10: Fish-Speech 1.1 버전으로 업데이트, WER 감소 및 음색 유사성을 개선하기 위해 VITS 디코더 구현.
+- 2024/04/22: Fish-Speech 1.0 버전 완료, VQGAN 및 LLAMA 모델 대폭 수정.
+- 2023/12/28: `lora` 파인튜닝 지원 추가.
+- 2023/12/27: `gradient checkpointing`, `causal sampling`, 및 `flash-attn` 지원 추가.
+- 2023/12/19: WebUI 및 HTTP API 업데이트.
+- 2023/12/18: 파인튜닝 문서 및 관련 예시 업데이트.
+- 2023/12/17: `text2semantic` 모델 업데이트, 음소 없는 모드 지원.
+- 2023/12/13: 베타 버전 출시, VQGAN 모델 및 LLAMA 기반 언어 모델(음소 지원만 포함).
+
+## 감사의 말
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/ko/inference.md b/docs/ko/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..65b3ec58f1c7076bf87640d17a68e66d4d4e9559
--- /dev/null
+++ b/docs/ko/inference.md
@@ -0,0 +1,134 @@
+# 추론
+
+추론은 명령줄, HTTP API, 그리고 웹 UI에서 지원됩니다.
+
+!!! note
+ 전체 추론 과정은 다음의 여러 단계로 구성됩니다:
+
+ 1. VQGAN을 사용하여 약 10초 분량의 음성을 인코딩합니다.
+ 2. 인코딩된 시맨틱 토큰과 해당 텍스트를 예시로 언어 모델에 입력합니다.
+ 3. 새로운 텍스트를 입력하면, 모델이 해당하는 시맨틱 토큰을 생성합니다.
+ 4. 생성된 시맨틱 토큰을 VITS / VQGAN에 입력하여 음성을 디코딩하고 생성합니다.
+
+## 명령줄 추론
+
+필요한 `vqgan` 및 `llama` 모델을 Hugging Face 리포지토리에서 다운로드하세요.
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+### 1. 음성에서 프롬프트 생성:
+
+!!! note
+ 모델이 음색을 무작위로 선택하도록 하려면 이 단계를 건너뛸 수 있습니다.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+이 명령을 실행하면 `fake.npy` 파일을 얻게 됩니다.
+
+### 2. 텍스트에서 시맨틱 토큰 생성:
+
+```bash
+python tools/llama/generate.py \
+ --text "변환할 텍스트" \
+ --prompt-text "참고할 텍스트" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4" \
+ --num-samples 2 \
+ --compile
+```
+
+이 명령을 실행하면 작업 디렉토리에 `codes_N` 파일이 생성되며, N은 0부터 시작하는 정수입니다.
+
+!!! note
+ 빠른 추론을 위해 `--compile` 옵션을 사용하여 CUDA 커널을 결합할 수 있습니다 (~초당 30 토큰 -> ~초당 500 토큰).
+ `--compile` 매개변수를 주석 처리하여 가속화 옵션을 사용하지 않을 수도 있습니다.
+
+!!! info
+ bf16을 지원하지 않는 GPU의 경우 `--half` 매개변수를 사용해야 할 수 있습니다.
+
+### 3. 시맨틱 토큰에서 음성 생성:
+
+#### VQGAN 디코더
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## HTTP API 추론
+
+추론을 위한 HTTP API를 제공하고 있습니다. 아래의 명령어로 서버를 시작할 수 있습니다:
+
+```bash
+python -m tools.api \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+추론 속도를 높이고 싶다면 `--compile` 매개변수를 추가할 수 있습니다.
+
+이후, http://127.0.0.1:8080/ 에서 API를 확인하고 테스트할 수 있습니다.
+
+아래는 `tools/post_api.py`를 사용하여 요청을 보내는 예시입니다.
+
+```bash
+python -m tools.post_api \
+ --text "입력할 텍스트" \
+ --reference_audio "참고 음성 경로" \
+ --reference_text "참고 음성의 텍스트 내용" \
+ --streaming True
+```
+
+위 명령은 참고 음성 정보를 바탕으로 원하는 음성을 합성하고, 스트리밍 방식으로 반환합니다.
+
+다음 예시는 여러 개의 참고 음성 경로와 텍스트를 한꺼번에 사용할 수 있음을 보여줍니다. 명령에서 공백으로 구분하여 입력합니다.
+
+```bash
+python -m tools.post_api \
+ --text "입력할 텍스트" \
+ --reference_audio "참고 음성 경로1" "참고 음성 경로2" \
+ --reference_text "참고 음성 텍스트1" "참고 음성 텍스트2"\
+ --streaming False \
+ --output "generated" \
+ --format "mp3"
+```
+
+위 명령어는 여러 참고 음성 정보를 바탕으로 `MP3` 형식의 음성을 합성하여, 현재 디렉토리에 `generated.mp3`로 저장합니다.
+
+`--reference_audio`와 `--reference_text` 대신에 `--reference_id`(하나만 사용 가능)를 사용할 수 있습니다. 프로젝트 루트 디렉토리에 `references/` 폴더를 만들어 해당 음성과 주석 텍스트를 넣어야 합니다. 참고 음성은 최대 90초까지 지원됩니다.
+
+!!! info
+ 제공되는 파라미터는 `python -m tools.post_api -h`를 사용하여 확인할 수 있습니다.
+
+## GUI 추론
+[클라이언트 다운로드](https://github.com/AnyaCoder/fish-speech-gui/releases)
+
+## WebUI 추론
+
+다음 명령으로 WebUI를 시작할 수 있습니다:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+> 추론 속도를 높이고 싶다면 `--compile` 매개변수를 추가할 수 있습니다.
+
+!!! note
+ 라벨 파일과 참고 음성 파일을 미리 메인 디렉토리의 `references` 폴더에 저장해 두면, WebUI에서 바로 호출할 수 있습니다. (해당 폴더는 직접 생성해야 합니다.)
+
+!!! note
+ WebUI를 구성하기 위해 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`과 같은 Gradio 환경 변수를 사용할 수 있습니다.
+
+즐기세요!
diff --git a/docs/ko/samples.md b/docs/ko/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..5286a3647a120f0ba1d60eff6bd9a83575731aaf
--- /dev/null
+++ b/docs/ko/samples.md
@@ -0,0 +1,137 @@
+# 샘플
+
+ver 1.4
+
+## Credits
+[Seed-TTS (2024)](https://bytedancespeech.github.io/seedtts_tech_report/)에 감사드리며, 평가 데이터를 제공해 주셔서 이 데모를 완성할 수 있었습니다.
+
+모든 프롬프트 음성은 Seed-TTS 효과 데모 페이지에서 가져왔으며, 모든 생성된 음성은 fish-speech 버전 1.4에서 첫 번째로 생성된 것입니다.
+
+## 제로샷 인컨텍스트 학습
+- TODO: 한국어 제로샷 인컨텍스트 학습 샘플 추가. (현재는 영어와 중국어 데모만 제공됩니다.)
+
+
+
+
+ 언어
+ 프롬프트
+ 동일 언어 생성
+ 교차 언어 생성
+
+
+
+
+ EN
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.
+ Your browser does not support the audio element. 顿时,气氛变得沉郁起来。乍看之下,一切的困扰仿佛都围绕在我身边。我皱着眉头,感受着那份压力,但我知道我不能放弃,不能认输。于是,我深吸一口气,心底的声音告诉我:“无论如何,都要冷静下来,重新开始。”
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me.
+ Your browser does not support the audio element. 处理家庭秘密从来都不是一件容易的事。然而,有时候,隐瞒是一种保护形式,旨在保护一些人免受残酷的真相伤害。有一天,我希望你能理解我行为背后的原因。在那之前,安娜,请容忍我。
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. The combinations of different textures and flavors create a perfect harmony. The succulence of the steak, the tartness of the cranberries, the crunch of pine nuts, and creaminess of blue cheese make it a truly delectable delight. Enjoy your culinary adventure!
+ Your browser does not support the audio element. 听着你的话,我心里五味杂陈。虽然我愿意一直在你身边,承担一切不幸,但我知道只有让你自己面对,才能真正让你变得更强大。所以,你要记得,无论面对何种困难,都请你坚强,我会在心里一直支持你的。
+
+
+ ZH
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
+ Your browser does not support the audio element. Suddenly, there was a burst of laughter beside me. I looked at them, stood up straight with high spirit, shook the slightly fleshy arms, and smiled lightly, saying, "The flesh on my body is to hide my bursting charm. Otherwise, wouldn't it scare you?"
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 他闭上眼睛,期望这一切都能过去。然而,当他再次睁开眼睛,眼前的景象让他不禁倒吸一口气。雾气中出现的禁闭岛,陌生又熟悉,充满未知的危险。他握紧拳头,心知他的生活即将发生翻天覆地的改变。
+ Your browser does not support the audio element. He closed his eyes, expecting that all of this could pass. However, when he opened his eyes again, the sight in front of him made him couldn't help but take a deep breath. The closed island that appeared in the fog, strange and familiar, was full of unknown dangers. He tightened his fist, knowing that his life was about to undergo earth-shaking changes.
+
+
+ Your browser does not support the audio element.
+ Your browser does not support the audio element. 顿时,气氛变得沉郁起来。乍看之下,一切的困扰仿佛都围绕在我身边。我皱着眉头,感受着那份压力,但我知道我不能放弃,不能认输。于是,我深吸一口气,心底的声音告诉我:“无论如何,都要冷静下来,重新开始。”
+ Your browser does not support the audio element. Suddenly, the atmosphere became gloomy. At first glance, all the troubles seemed to surround me. I frowned, feeling that pressure, but I know I can't give up, can't admit defeat. So, I took a deep breath, and the voice in my heart told me, "Anyway, must calm down and start again."
+
+
+
+
+## 화자 파인튜닝
+
+
+
+
+
+ 텍스트
+ 생성된 음성
+
+
+
+
+ 화자1
+ 好呀,哈哈哈哈哈,喜欢笑的人运气都不会差哦,希望你每天笑口常开~
+ Your browser does not support the audio element.
+
+
+ 哇!恭喜你中了大乐透,八百万可真不少呢!有什么特别的计划或想法吗?
+ Your browser does not support the audio element.
+
+
+ 哼,你这么问是想请本小姐吃饭吗?如果对象是你的话,那也不是不可以。
+ Your browser does not support the audio element.
+
+
+ 화자2
+ 是呀,他还想换个地球仪哈哈哈,看来给你积累了一些快乐值了,你还想不想再听一个其他的笑话呀?
+ Your browser does not support the audio element.
+
+
+ 嘿嘿,你是不是也想拥有甜甜的恋爱呢?《微微一笑很倾城》是你的不二选择,男女主是校花校草类型,他们通过游戏结识,再到两人见面,全程没有一点误会,真的齁甜,想想都忍不住“姨妈笑”~
+ Your browser does not support the audio element.
+
+
+ 小傻瓜,嗯……算是个很可爱很亲切的名字,有点“独特”哦,不过我有些好奇,你为什么会给我选这个昵称呢?
+ Your browser does not support the audio element.
+
+
+
+
+
+## 콘텐츠 편집
+
+
+
+ 언어
+ 원본 텍스트
+ 원본 음성
+ 목표 텍스트
+ 편집된 음성
+
+
+
+ EN
+ They can't order me to stop dreaming. If you dream a thing more than once, it's sure to come true. Have faith in your dreams, and someday your rainbow will come shining through.
+ Your browser does not support the audio element.
+ They can't require me to stop imagining. If you envision a thing more than once, it's bound to come about . Have trust in your visions , and someday your radiance will come beaming through.
+ Your browser does not support the audio element.
+
+
+ Are you familiar with it? Slice the steak and place the strips on top, then garnish with the dried cranberries, pine nuts, and blue cheese. I wonder how people rationalise the decision?
+ Your browser does not support the audio element.
+ Are you acquainted with it? Cut the pork and place the strips on top, then garnish with the dried cherries, almonds, and feta cheese. I query how people justify the choice?
+ Your browser does not support the audio element.
+
+
+ ZH
+ 自古以来,庸君最怕党政了,可圣君他就不怕,不但不怕,反能利用。要我说,你就让明珠索额图互相争宠,只要你心里明白,左右逢源,你就能立于不败之地。
+ Your browser does not support the audio element.
+ 从古至今 ,庸君最怕朝纲了 ,可明 君他就不怕,不但不怕,反能借助 。要我说,你就让李四张三 互相争宠,只要你心里清楚 ,左右周旋 ,你就能处 于不败之境 。
+ Your browser does not support the audio element.
+
+
+ 对,这就是我,万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。
+ Your browser does not support the audio element.
+ 对,这就是我,众人尊崇 的太白金星 ,虽然有点娃娃脸 ,但也遮 不住我迷人 的魅力。
+ Your browser does not support the audio element.
+
+
+
diff --git a/docs/ko/start_agent.md b/docs/ko/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..dedc76984d3d2f0faec17652792650ca5f3de71c
--- /dev/null
+++ b/docs/ko/start_agent.md
@@ -0,0 +1,80 @@
+# 에이전트 시작하기
+
+!!! note
+ 전체 문서는 claude3.5 Sonnet에 의해 번역되었으며, 원어민인 경우 번역에 문제가 있다고 생각되면 이슈나 풀 리퀘스트를 보내주셔서 대단히 감사합니다!
+
+## 요구사항
+
+- GPU 메모리: 최소 8GB(양자화 사용 시), 16GB 이상 권장
+- 디스크 사용량: 10GB
+
+## 모델 다운로드
+
+다음 명령어로 모델을 받을 수 있습니다:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+'checkpoints' 폴더에 파일들을 넣으세요.
+
+또한 [inference](inference.md)에 설명된 대로 fish-speech 모델도 다운로드해야 합니다.
+
+checkpoints에는 2개의 폴더가 있어야 합니다.
+
+`checkpoints/fish-speech-1.4`와 `checkpoints/fish-agent-v0.1-3b`입니다.
+
+## 환경 준비
+
+이미 Fish-speech가 있다면 다음 명령어를 추가하여 바로 사용할 수 있습니다:
+```bash
+pip install cachetools
+```
+
+!!! note "참고"
+ 컴파일을 위해 Python 3.12 미만 버전을 사용해 주세요.
+
+없다면 아래 명령어를 사용하여 환경을 구축하세요:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## 에이전트 데모 실행
+
+fish-agent를 구축하려면 메인 폴더에서 아래 명령어를 사용하세요:
+
+```bash
+python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+`--compile` 인자는 Python < 3.12에서만 지원되며, 토큰 생성 속도를 크게 향상시킵니다.
+
+컴파일은 한 번에 완료되지 않을 수 있으니 유의하세요.
+
+그런 다음 다른 터미널을 열고 다음 명령어를 사용하세요:
+
+```bash
+python -m tools.e2e_webui
+```
+
+이렇게 하면 기기에 Gradio WebUI가 생성됩니다.
+
+모델을 처음 사용할 때는 (`--compile`이 True인 경우) 잠시 컴파일이 진행되므로 기다려 주세요.
+
+## Gradio Webui
+
+
+
+
+즐거운 시간 되세요!
+
+## 성능
+
+테스트 결과, 4060 노트북은 겨우 실행되며 매우 부하가 큰 상태로, 초당 약 8토큰 정도만 처리합니다. 4090은 컴파일 상태에서 초당 약 95토큰을 처리하며, 이것이 저희가 권장하는 사양입니다.
+
+# 에이전트 소개
+
+이 데모는 초기 알파 테스트 버전으로, 추론 속도 최적화가 필요하며 수정해야 할 버그가 많이 있습니다. 버그를 발견하거나 수정하고 싶으시다면 이슈나 풀 리퀘스트를 보내주시면 매우 감사하겠습니다.
diff --git a/docs/pt/finetune.md b/docs/pt/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..f57d92c7609e886f891ab5725eb1c5874047084c
--- /dev/null
+++ b/docs/pt/finetune.md
@@ -0,0 +1,128 @@
+# Ajuste Fino
+
+É óbvio que ao abrir esta página, você não deve estar muito satisfeito com o desempenho do modelo pré-treinado com poucos exemplos. Você pode querer ajustar o modelo para melhorar seu desempenho em seu conjunto de dados.
+
+Na atual versão, a única coisa que você precisa ajustar é a parte do 'LLAMA'.
+
+## Ajuste Fino do LLAMA
+### 1. Preparando o conjunto de dados
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+Você precisa converter seu conjunto de dados para o formato acima e colocá-lo em `data`. O arquivo de áudio pode ter as extensões `.mp3`, `.wav` ou `.flac`, e o arquivo de anotação deve ter a extensão `.lab`.
+
+!!! info
+ O arquivo de anotação `.lab` deve conter apenas a transcrição do áudio, sem a necessidade de formatação especial. Por exemplo, se o arquivo `hi.mp3` disser "Olá, tchau", o arquivo `hi.lab` conterá uma única linha de texto: "Olá, tchau".
+
+!!! warning
+ É recomendado aplicar normalização de volume ao conjunto de dados. Você pode usar o [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) para fazer isso.
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+
+### 2. Extração em lote de tokens semânticos
+
+Certifique-se de ter baixado os pesos do VQGAN. Se não, execute o seguinte comando:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ Você pode ajustar `--num-workers` e `--batch-size` para aumentar a velocidade de extração, mas certifique-se de não exceder o limite de memória da sua GPU.
+ Para o formato VITS, você pode especificar uma lista de arquivos usando `--filelist xxx.list`.
+
+Este comando criará arquivos `.npy` no diretório `data`, como mostrado abaixo:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. Empacotar o conjunto de dados em protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+Após executar o comando, você deverá ver o arquivo `quantized-dataset-ft.protos` no diretório `data`.
+
+### 4. E finalmente, chegamos ao ajuste fino com LoRA
+
+Da mesma forma, certifique-se de ter baixado os pesos do `LLAMA`. Se não, execute o seguinte comando:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+E então, execute o seguinte comando para iniciar o ajuste fino:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ Se quiser, você pode modificar os parâmetros de treinamento, como `batch_size`, `gradient_accumulation_steps`, etc., para se ajustar à memória da sua GPU, modificando `fish_speech/configs/text2semantic_finetune.yaml`.
+
+!!! note
+ Para usuários do Windows, é recomendado usar `trainer.strategy.process_group_backend=gloo` para evitar problemas com `nccl`.
+
+Após concluir o treinamento, consulte a seção [inferência](inference.md).
+
+!!! info
+ Por padrão, o modelo aprenderá apenas os padrões de fala do orador e não o timbre. Ainda pode ser preciso usar prompts para garantir a estabilidade do timbre.
+ Se quiser que ele aprenda o timbre, aumente o número de etapas de treinamento, mas isso pode levar ao overfitting (sobreajuste).
+
+Após o treinamento, é preciso converter os pesos do LoRA em pesos regulares antes de realizar a inferência.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.4 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.4-yth-lora/
+```
+!!! note
+ É possível também tentar outros checkpoints. Sugerimos usar o checkpoint que melhor atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora da distribuição (OOD).
diff --git a/docs/pt/index.md b/docs/pt/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..05e27ff75548a3cbf26a9429ecab9ecdf29fa953
--- /dev/null
+++ b/docs/pt/index.md
@@ -0,0 +1,210 @@
+# Introdução
+
+
+
+!!! warning
+ Não nos responsabilizamos por qualquer uso ilegal do código-fonte. Consulte as leis locais sobre DMCA (Digital Millennium Copyright Act) e outras leis relevantes em sua região.
+ Este repositório de código e os modelos são distribuídos sob a licença CC-BY-NC-SA-4.0.
+
+
+
+
+
+## Requisitos
+
+- Memória da GPU: 4GB (para inferência), 8GB (para ajuste fino)
+- Sistema: Linux, Windows
+
+## Configuração do Windows
+
+Usuários profissionais do Windows podem considerar o uso do WSL2 ou Docker para executar a base de código.
+
+```bash
+# Crie um ambiente virtual Python 3.10, também é possível usar o virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Instale o pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# Instale o fish-speech
+pip3 install -e .
+
+# (Ativar aceleração) Instalar triton-windows
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+Usuários não profissionais do Windows podem considerar os seguintes métodos básicos para executar o projeto sem um ambiente Linux (com capacidades de compilação de modelo, ou seja, `torch.compile`):
+
+1. Extraia o pacote do projeto.
+2. Clique em `install_env.bat` para instalar o ambiente.
+3. Se você quiser ativar a aceleração de compilação, siga estas etapas:
+ 1. Baixe o compilador LLVM nos seguintes links:
+ - [LLVM-17.0.6 (Download do site oficial)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6 (Download do site espelho)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - Após baixar o `LLVM-17.0.6-win64.exe`, clique duas vezes para instalar, selecione um local de instalação apropriado e, o mais importante, marque a opção `Add Path to Current User` para adicionar a variável de ambiente.
+ - Confirme que a instalação foi concluída.
+ 2. Baixe e instale o Microsoft Visual C++ Redistributable para resolver possíveis problemas de arquivos .dll ausentes:
+ - [Download do MSVC++ 14.40.33810.0](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. Baixe e instale o Visual Studio Community Edition para obter as ferramentas de compilação do MSVC++ e resolver as dependências dos arquivos de cabeçalho do LLVM:
+ - [Download do Visual Studio](https://visualstudio.microsoft.com/pt-br/downloads/)
+ - Após instalar o Visual Studio Installer, baixe o Visual Studio Community 2022.
+ - Conforme mostrado abaixo, clique no botão `Modificar`, encontre a opção `Desenvolvimento de área de trabalho com C++` e selecione para fazer o download.
+ 4. Baixe e instale o [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+4. Clique duas vezes em `start.bat` para abrir a interface de gerenciamento WebUI de inferência de treinamento. Se necessário, você pode modificar as `API_FLAGS` conforme mostrado abaixo.
+
+!!! info "Opcional"
+ Você quer iniciar o WebUI de inferência?
+ Edite o arquivo `API_FLAGS.txt` no diretório raiz do projeto e modifique as três primeiras linhas como segue:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "Opcional"
+ Você quer iniciar o servidor de API?
+ Edite o arquivo `API_FLAGS.txt` no diretório raiz do projeto e modifique as três primeiras linhas como segue:
+
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "Opcional"
+ Clique duas vezes em `run_cmd.bat` para entrar no ambiente de linha de comando conda/python deste projeto.
+
+
+## Configuração para Linux
+
+Para mais detalhes, consulte [pyproject.toml](../../pyproject.toml).
+```bash
+# Crie um ambiente virtual python 3.10, você também pode usar virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Instale o pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# Para os Usuários do Ubuntu / Debian: Instale o sox + ffmpeg
+apt install libsox-dev ffmpeg
+
+# Para os Usuários do Ubuntu / Debian: Instale o pyaudio
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# Instale o fish-speech
+pip3 install -e .[stable]
+```
+
+## Configuração para macOS
+
+Se você quiser realizar inferências no MPS, adicione a flag `--device mps`.
+Para uma comparação das velocidades de inferência, consulte [este PR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772).
+
+!!! warning "Aviso"
+ A opção `compile` não é oficialmente suportada em dispositivos Apple Silicon, então não há garantia de que a velocidade de inferência irá melhorar.
+
+```bash
+# create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# install pytorch
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# install fish-speech
+pip install -e .[stable]
+```
+
+## Configuração do Docker
+
+1. Instale o NVIDIA Container Toolkit:
+
+ Para usar a GPU com Docker para treinamento e inferência de modelos, você precisa instalar o NVIDIA Container Toolkit:
+
+ Para usuários Ubuntu:
+
+ ```bash
+ # Adicione o repositório remoto
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # Instale o nvidia-container-toolkit
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # Reinicie o serviço Docker
+ sudo systemctl restart docker
+ ```
+
+ Para usuários de outras distribuições Linux, consulte o guia de instalação: [NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+
+2. Baixe e execute a imagem fish-speech
+
+ ```shell
+ # Baixe a imagem
+ docker pull fishaudio/fish-speech:latest-dev
+ # Execute a imagem
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # Se precisar usar outra porta, modifique o parâmetro -p para YourPort:7860
+ ```
+
+3. Baixe as dependências do modelo
+
+ Certifique-se de estar no terminal do contêiner Docker e, em seguida, baixe os modelos necessários `vqgan` e `llama` do nosso repositório HuggingFace.
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+ ```
+
+4. Configure as variáveis de ambiente e acesse a WebUI
+
+ No terminal do contêiner Docker, digite `export GRADIO_SERVER_NAME="0.0.0.0"` para permitir o acesso externo ao serviço gradio dentro do Docker.
+ Em seguida, no terminal do contêiner Docker, digite `python tools/webui.py` para iniciar o serviço WebUI.
+
+ Se estiver usando WSL ou MacOS, acesse [http://localhost:7860](http://localhost:7860) para abrir a interface WebUI.
+
+ Se estiver implantando em um servidor, substitua localhost pelo IP do seu servidor.
+
+## Histórico de Alterações
+- 10/09/2024: Fish-Speech atualizado para a versão 1.4, aumentado o tamanho do conjunto de dados, quantizer n_groups 4 -> 8.
+- 02/07/2024: Fish-Speech atualizado para a versão 1.2, removido o Decodificador VITS e aprimorado consideravelmente a capacidade de zero-shot.
+- 10/05/2024: Fish-Speech atualizado para a versão 1.1, implementado o decodificador VITS para reduzir a WER e melhorar a similaridade de timbre.
+- 22/04/2024: Finalizada a versão 1.0 do Fish-Speech, modificados significativamente os modelos VQGAN e LLAMA.
+- 28/12/2023: Adicionado suporte para ajuste fino `lora`.
+- 27/12/2023: Adicionado suporte para `gradient checkpointing`, `causal sampling` e `flash-attn`.
+- 19/12/2023: Atualizada a interface web e a API HTTP.
+- 18/12/2023: Atualizada a documentação de ajuste fino e exemplos relacionados.
+- 17/12/2023: Atualizado o modelo `text2semantic`, suportando o modo sem fonemas.
+- 13/12/2023: Versão beta lançada, incluindo o modelo VQGAN e um modelo de linguagem baseado em LLAMA (suporte apenas a fonemas).
+
+## Agradecimentos
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/pt/inference.md b/docs/pt/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..8cbaa4ee84a25619cba78543d63a19af44a4f08c
--- /dev/null
+++ b/docs/pt/inference.md
@@ -0,0 +1,114 @@
+# Inferência
+
+Suporte para inferência por linha de comando, API HTTP e interface web (WebUI).
+
+!!! note
+ O processo de inferência, em geral, consiste em várias partes:
+
+ 1. Codificar cerca de 10 segundos de voz usando VQGAN.
+ 2. Inserir os tokens semânticos codificados e o texto correspondente no modelo de linguagem como um exemplo.
+ 3. Dado um novo trecho de texto, fazer com que o modelo gere os tokens semânticos correspondentes.
+ 4. Inserir os tokens semânticos gerados no VITS / VQGAN para decodificar e gerar a voz correspondente.
+
+## Inferência por Linha de Comando
+
+Baixe os modelos `vqgan` e `llama` necessários do nosso repositório Hugging Face.
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+### 1. Gerar prompt a partir da voz:
+
+!!! note
+ Se quiser permitir que o modelo escolha aleatoriamente um timbre de voz, pule esta etapa.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+Você deverá obter um arquivo `fake.npy`.
+
+### 2. Gerar tokens semânticos a partir do texto:
+
+```bash
+python tools/llama/generate.py \
+ --text "O texto que você deseja converter" \
+ --prompt-text "Seu texto de referência" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4" \
+ --num-samples 2 \
+ --compile
+```
+
+Este comando criará um arquivo `codes_N` no diretório de trabalho, onde N é um número inteiro começando de 0.
+
+!!! note
+ Use `--compile` para fundir kernels CUDA para ter uma inferência mais rápida (~30 tokens/segundo -> ~500 tokens/segundo).
+ Mas, se não planeja usar a aceleração CUDA, comente o parâmetro `--compile`.
+
+!!! info
+ Para GPUs que não suportam bf16, pode ser necessário usar o parâmetro `--half`.
+
+### 3. Gerar vocais a partir de tokens semânticos:
+
+#### Decodificador VQGAN
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## Inferência por API HTTP
+
+Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para iniciar o servidor:
+
+```bash
+python -m tools.api \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+> Para acelerar a inferência, adicione o parâmetro `--compile`.
+
+Depois disso, é possível visualizar e testar a API em http://127.0.0.1:8080/.
+
+Abaixo está um exemplo de envio de uma solicitação usando `tools/post_api.py`.
+
+```bash
+python -m tools.post_api \
+ --text "Texto a ser inserido" \
+ --reference_audio "Caminho para o áudio de referência" \
+ --reference_text "Conteúdo de texto do áudio de referência" \
+ --streaming True
+```
+
+O comando acima indica a síntese do áudio desejada de acordo com as informações do áudio de referência e a retorna em modo de streaming.
+
+!!! info
+ Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.post_api -h`
+
+## Inferência por WebUI
+
+Para iniciar a WebUI de Inferência execute o seguinte comando:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> Para acelerar a inferência, adicione o parâmetro `--compile`.
+
+!!! note
+ Você pode salvar antecipadamente o arquivo de rótulos e o arquivo de áudio de referência na pasta `references` do diretório principal (que você precisa criar), para que possa chamá-los diretamente na WebUI.
+
+!!! note
+ É possível usar variáveis de ambiente do Gradio, como `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`, para configurar a WebUI.
+
+Divirta-se!
diff --git a/docs/pt/samples.md b/docs/pt/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..25042475692d5412ce1f78f4bf29552d974ab457
--- /dev/null
+++ b/docs/pt/samples.md
@@ -0,0 +1,225 @@
+# Amostras
+
+A demonstração da versão 1.4 foi atualizada [aqui](https://speech.fish.audio/samples/)
+
+As amostras da v1.2 estão disponíveis em [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/).
+
+As seguintes amostras são do modelo v1.1.
+
+## Frase em Chinês 1
+```
+人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Nahida (Genshin Impact)
+
+
+
+
+ Zhongli (Genshin Impact)
+
+
+
+
+ Furina (Genshin Impact)
+
+
+
+
+ Orador Aleatório 1
+ -
+
+
+
+ Orador Aleatório 2
+ -
+
+
+
+
+
+
+## Frase em Chinês 2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Nahida (Genshin Impact)
+
+
+
+
+ Orador Aleatório
+ -
+
+
+
+
+
+
+## Frase em Chinês 3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório
+ -
+
+
+
+
+
+## Frase em Inglês 1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório 1
+ -
+
+
+
+ Orador Aleatório 2
+ -
+
+
+
+
+
+## Frase em Inglês 2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório
+ -
+
+
+
+
+
+## Frase em Japonês 1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório 1
+ -
+
+
+
+ Orador Aleatório 2
+ -
+
+
+
+
+
+## Frase em Japonês 2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ Orador
+ Áudio de Entrada
+ Áudio Sintetizado
+
+
+
+
+ Orador Aleatório
+ -
+
+
+
+
diff --git a/docs/pt/start_agent.md b/docs/pt/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..a17321bc89277e7ac074a15e94190d6b4d6e86f5
--- /dev/null
+++ b/docs/pt/start_agent.md
@@ -0,0 +1,80 @@
+# Iniciar Agente
+
+!!! note
+ Todo o documento foi traduzido por claude3.5 Sonnet, se você for um falante nativo e achar a tradução problemática, muito obrigado por nos enviar um problema ou uma solicitação pull!
+
+## Requisitos
+
+- Memória GPU: No mínimo 8GB (com quantização), 16GB ou mais é recomendado.
+- Uso de disco: 10GB
+
+## Download do Modelo
+
+Você pode obter o modelo através de:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+Coloque-os na pasta 'checkpoints'.
+
+Você também precisará do modelo fish-speech que pode ser baixado seguindo as instruções em [inference](inference.md).
+
+Então haverá 2 pastas em checkpoints.
+
+O `checkpoints/fish-speech-1.4` e `checkpoints/fish-agent-v0.1-3b`
+
+## Preparação do Ambiente
+
+Se você já tem o Fish-speech, pode usar diretamente adicionando a seguinte instrução:
+```bash
+pip install cachetools
+```
+
+!!! note "Nota"
+ Por favor, use a versão Python abaixo de 3.12 para compilação.
+
+Se você não tem, use os comandos abaixo para construir seu ambiente:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
+## Iniciar a Demo do Agente
+
+Para construir o fish-agent, use o comando abaixo na pasta principal:
+
+```bash
+python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+O argumento `--compile` só suporta Python < 3.12, o que aumentará muito a velocidade de geração de tokens.
+
+A compilação pode não ser concluída de uma só vez (tenha isso em mente).
+
+Então abra outro terminal e use o comando:
+
+```bash
+python -m tools.e2e_webui
+```
+
+Isso criará uma WebUI Gradio no dispositivo.
+
+Quando você usar o modelo pela primeira vez, ele irá compilar (se `--compile` estiver True) por um curto período, então aguarde com paciência.
+
+## Gradio Webui
+
+
+
+
+Divirta-se!
+
+## Desempenho
+
+Em nossos testes, um laptop com 4060 mal consegue rodar, ficando muito sobrecarregado, gerando apenas cerca de 8 tokens/s. A 4090 gera cerca de 95 tokens/s com compilação, que é o que recomendamos.
+
+# Sobre o Agente
+
+A demo é uma versão alpha inicial de teste, a velocidade de inferência precisa ser otimizada, e há muitos bugs aguardando correção. Se você encontrou um bug ou quer corrigi-lo, ficaremos muito felizes em receber uma issue ou um pull request.
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6e145dbea1b9b26b2bddd7500e3f270b3eb0009
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,3 @@
+mkdocs-material
+mkdocs-static-i18n[material]
+mkdocs[i18n]
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
new file mode 100644
index 0000000000000000000000000000000000000000..a88af87b3cdbfd2d6b05f39877d5821bb7ebe119
--- /dev/null
+++ b/docs/stylesheets/extra.css
@@ -0,0 +1,3 @@
+.md-grid {
+ max-width: 1440px;
+}
diff --git a/docs/zh/finetune.md b/docs/zh/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..f7db80c9ec01551b79315d3af38a7fe6036c79ee
--- /dev/null
+++ b/docs/zh/finetune.md
@@ -0,0 +1,139 @@
+# 微调
+
+显然, 当你打开这个页面的时候, 你已经对预训练模型 zero-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好.
+
+在目前版本,你只需要微调'LLAMA'部分即可.
+
+## LLAMA 微调
+### 1. 准备数据集
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+你需要将数据集转为以上格式, 并放到 `data` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀建议为 `.lab`.
+
+!!! info
+ 标注文件 `.lab` 仅需包含音频的转写文本,无需遵循特殊格式要求。例如,如果 `hi.mp3` 中的内容是“你好,再见。”,那么 `hi.lab` 文件中只需包含一行文本:“你好,再见”。
+
+!!! warning
+ 建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤.
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+### 2. 批量提取语义 token
+
+确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+对于中国大陆用户, 可使用 mirror 下载.
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+随后可运行以下命令来提取语义 token:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+!!! note
+ 你可以调整 `--num-workers` 和 `--batch-size` 来提高提取速度, 但是请注意不要超过你的显存限制.
+
+该命令会在 `data` 目录下创建 `.npy` 文件, 如下所示:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. 打包数据集为 protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+命令执行完毕后, 你应该能在 `data` 目录下看到 `protos` 文件.
+
+
+### 4. 最后, 使用 LoRA 进行微调
+
+同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+对于中国大陆用户, 可使用 mirror 下载.
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+最后, 你可以运行以下命令来启动微调:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ 你可以通过修改 `fish_speech/configs/text2semantic_finetune.yaml` 来修改训练参数如 `batch_size`, `gradient_accumulation_steps` 等, 来适应你的显存.
+
+!!! note
+ 对于 Windows 用户, 你可以使用 `trainer.strategy.process_group_backend=gloo` 来避免 `nccl` 的问题.
+
+训练结束后, 你可以参考 [推理](inference.md) 部分来测试你的模型.
+
+!!! info
+ 默认配置下, 基本只会学到说话人的发音方式, 而不包含音色, 你依然需要使用 prompt 来保证音色的稳定性.
+ 如果你想要学到音色, 请将训练步数调大, 但这有可能会导致过拟合.
+
+训练完成后, 你需要先将 loRA 的权重转为普通权重, 然后再进行推理.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.4 \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.4-yth-lora/
+```
+
+!!! note
+ 你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好.
diff --git a/docs/zh/index.md b/docs/zh/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..f108c0a6c1731e47d158cada149642c6f83e7f6c
--- /dev/null
+++ b/docs/zh/index.md
@@ -0,0 +1,218 @@
+# 介绍
+
+
+
+!!! warning "警告"
+ 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+ 此代码库与所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
+
+
+
+
+
+## 要求
+
+- GPU 内存: 4GB (用于推理), 8GB (用于微调)
+- 系统: Linux, Windows
+
+## Windows 配置
+
+Windows 专业用户可以考虑 WSL2 或 docker 来运行代码库。
+
+```bash
+# 创建一个 python 3.10 虚拟环境, 你也可以用 virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# 安装 pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
+
+# 安装 fish-speech
+pip3 install -e .
+
+# (开启编译加速) 安装 triton-windows
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法(附带模型编译功能,即 `torch.compile`):
+
+1. 解压项目压缩包。
+2. 点击 `install_env.bat` 安装环境。
+3. 若需要开启编译加速则执行这一步:
+ 1. 使用如下链接下载 LLVM 编译器。
+ - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。
+ - 确认安装完成。
+ 2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。
+ - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。
+ - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/)
+ - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022
+ - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载
+ 4. 下载安装 [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+4. 双击 `start.bat` 打开训练推理 WebUI 管理界面. 如有需要,可照下列提示修改`API_FLAGS`.
+
+!!! info "可选"
+
+ 想启动 推理 WebUI 界面?编辑项目根目录下的 `API_FLAGS.txt`, 前三行修改成如下格式:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "可选"
+
+ 想启动 API 服务器?编辑项目根目录下的 `API_FLAGS.txt`, 前三行修改成如下格式:
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "可选"
+
+ 双击 `run_cmd.bat` 进入本项目的 conda/python 命令行环境
+
+## Linux 配置
+
+有关详细信息,请参见 [pyproject.toml](../../pyproject.toml)。
+```bash
+# 创建一个 python 3.10 虚拟环境, 你也可以用 virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# 安装 pytorch
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+
+# (Ubuntu / Debian 用户) 安装 sox + ffmpeg
+apt install libsox-dev ffmpeg
+
+# (Ubuntu / Debian 用户) 安装 pyaudio
+apt install build-essential \
+ cmake \
+ libasound-dev \
+ portaudio19-dev \
+ libportaudio2 \
+ libportaudiocpp0
+
+# 安装 fish-speech
+pip3 install -e .[stable]
+```
+
+## macos 配置
+
+如果您想在 MPS 上进行推理,请添加 `--device mps` 标志。
+有关推理速度的比较,请参考 [此 PR](https://github.com/fishaudio/fish-speech/pull/461#issuecomment-2284277772)。
+
!!! warning "警告"
+ `compile` 选项在 Apple Silicon 设备上尚未正式支持,因此推理速度没有提升的保证。
+
+```bash
+# create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+# install pytorch
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
+# install fish-speech
+pip install -e .[stable]
+```
+
+## Docker 配置
+
+1. 安装 NVIDIA Container Toolkit:
+
+ Docker 如果想使用 GPU 进行模型训练和推理,需要安装 NVIDIA Container Toolkit :
+
+ 对于 Ubuntu 用户:
+
+ ```bash
+ # 添加远程仓库
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
+ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+ # 安装 nvidia-container-toolkit
+ sudo apt-get update
+ sudo apt-get install -y nvidia-container-toolkit
+ # 重启 Docker 服务
+ sudo systemctl restart docker
+ ```
+
+ 对于使用其他 Linux 发行版的用户,安装指南请参考:[NVIDIA Container Toolkit Install-guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)。
+
+ 注:对于中国大陆的用户,您可能需要使用代理来完成相关工具的安装。
+
+2. 拉取并运行 fish-speech 镜像
+
+ ```shell
+ # 拉取镜像
+ docker pull fishaudio/fish-speech:latest-dev
+ # 运行镜像
+ docker run -it \
+ --name fish-speech \
+ --gpus all \
+ -p 7860:7860 \
+ fishaudio/fish-speech:latest-dev \
+ zsh
+ # 如果需要使用其他端口,请修改 -p 参数为 YourPort:7860
+ ```
+
+3. 下载模型依赖
+
+ 确保您在 docker 容器内的终端,然后再从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。
+
+ ```bash
+ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+ ```
+
+ 对于中国大陆用户,可以通过镜像站下载。
+
+ ```bash
+ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+ ```
+
+4. 配置环境变量,访问 WebUI
+
+ 在 docker 容器内的终端,输入 `export GRADIO_SERVER_NAME="0.0.0.0"` ,从而让外部可以访问 docker 内的 gradio 服务。
+ 接着在 docker 容器内的终端,输入 `python tools/webui.py` 即可开启 WebUI 服务。
+
+ 如果是 WSL 或者是 MacOS ,访问 [http://localhost:7860](http://localhost:7860) 即可打开 WebUI 界面。
+
+ 如果是部署在服务器上,更换 localhost 为您的服务器 ip 即可。
+
+## 更新日志
+
+- 2024/09/10: 更新了 Fish-Speech 到 1.4, 增加了数据集大小, quantizer n_groups 4 -> 8.
+- 2024/07/02: 更新了 Fish-Speech 到 1.2 版本,移除 VITS Decoder,同时极大幅度提升 zero-shot 能力.
+- 2024/05/10: 更新了 Fish-Speech 到 1.1 版本,引入了 VITS Decoder 来降低口胡和提高音色相似度.
+- 2024/04/22: 完成了 Fish-Speech 1.0 版本, 大幅修改了 VQGAN 和 LLAMA 模型.
+- 2023/12/28: 添加了 `lora` 微调支持.
+- 2023/12/27: 添加了 `gradient checkpointing`, `causual sampling` 和 `flash-attn` 支持.
+- 2023/12/19: 更新了 Webui 和 HTTP API.
+- 2023/12/18: 更新了微调文档和相关例子.
+- 2023/12/17: 更新了 `text2semantic` 模型, 支持无音素模式.
+- 2023/12/13: 测试版发布, 包含 VQGAN 模型和一个基于 LLAMA 的语言模型 (只支持音素).
+
+## 致谢
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/zh/inference.md b/docs/zh/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..0c679be0be5b2960c9027347b0bc0b842f2b5bfe
--- /dev/null
+++ b/docs/zh/inference.md
@@ -0,0 +1,143 @@
+# 推理
+
+推理支持命令行, http api, 以及 webui 三种方式.
+
+!!! note
+ 总的来说, 推理分为几个部分:
+
+ 1. 给定一段 ~10 秒的语音, 将它用 VQGAN 编码.
+ 2. 将编码后的语义 token 和对应文本输入语言模型作为例子.
+ 3. 给定一段新文本, 让模型生成对应的语义 token.
+ 4. 将生成的语义 token 输入 VQGAN 解码, 生成对应的语音.
+
+## 命令行推理
+
+从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+对于中国大陆用户,可使用 mirror 下载。
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+```
+
+### 1. 从语音生成 prompt:
+
+!!! note
+ 如果你打算让模型随机选择音色, 你可以跳过这一步.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+你应该能得到一个 `fake.npy` 文件.
+
+### 2. 从文本生成语义 token:
+
+```bash
+python tools/llama/generate.py \
+ --text "要转换的文本" \
+ --prompt-text "你的参考文本" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4" \
+ --num-samples 2 \
+ --compile
+```
+
+该命令会在工作目录下创建 `codes_N` 文件, 其中 N 是从 0 开始的整数.
+
+!!! note
+ 您可能希望使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 个 token/秒 -> ~500 个 token/秒).
+ 对应的, 如果你不打算使用加速, 你可以注释掉 `--compile` 参数.
+
+!!! info
+ 对于不支持 bf16 的 GPU, 你可能需要使用 `--half` 参数.
+
+### 3. 从语义 token 生成人声:
+
+#### VQGAN 解码
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+```
+
+## HTTP API 推理
+
+运行以下命令来启动 HTTP 服务:
+
+```bash
+python -m tools.api \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> 如果你想要加速推理,可以加上`--compile`参数。
+
+推荐中国大陆用户运行以下命令来启动 HTTP 服务:
+```bash
+HF_ENDPOINT=https://hf-mirror.com python -m ...(同上)
+```
+
+随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API.
+
+下面是使用`tools/post_api.py`发送请求的示例。
+
+```bash
+python -m tools.post_api \
+ --text "要输入的文本" \
+ --reference_audio "参考音频路径" \
+ --reference_text "参考音频的文本内容" \
+ --streaming True
+```
+
+上面的命令表示按照参考音频的信息,合成所需的音频并流式返回.
+
+下面的示例展示了, 可以一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。
+```bash
+python -m tools.post_api \
+ --text "要输入的文本" \
+ --reference_audio "参考音频路径1" "参考音频路径2" \
+ --reference_text "参考音频的文本内容1" "参考音频的文本内容2"\
+ --streaming False \
+ --output "generated" \
+ --format "mp3"
+```
+
+上面的命令表示按照多个参考音频的信息,合成所需的`MP3`格式音频,并保存为当前目录的`generated.mp3`文件。
+
+还可以用`--reference_id`(仅能用一个)来代替`--reference_audio`和`--reference_text`, 前提是在项目根目录下创建`references/`文件夹,
+里面放上任意对音频与标注文本。 目前支持的参考音频最多加起来总时长90s。
+
+!!! info
+ 要了解有关可用参数的更多信息,可以使用命令`python -m tools.post_api -h`
+
+## GUI 推理
+[下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases)
+
+## WebUI 推理
+
+你可以使用以下命令来启动 WebUI:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+> 如果你想要加速推理,可以加上`--compile`参数。
+
+!!! note
+ 你可以提前将label文件和参考音频文件保存到主目录下的 `references` 文件夹(需要自行创建),这样你可以直接在WebUI中调用它们。
+
+!!! note
+ 你可以使用 Gradio 环境变量, 如 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` 来配置 WebUI.
+
+祝大家玩得开心!
diff --git a/docs/zh/samples.md b/docs/zh/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..bdcce108de01aa6b0da5fdea737fdd41450f6396
--- /dev/null
+++ b/docs/zh/samples.md
@@ -0,0 +1,225 @@
+# 例子
+
+v1.4 演示已更新至[此处](https://speech.fish.audio/samples/)。
+
+v1.2 的样本可以在 [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/) 观看。
+
+以下样本来自 v1.1 版本的模型。
+
+## 中文句子 1
+```
+人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 纳西妲 (原神)
+
+
+
+
+ 钟离 (原神)
+
+
+
+
+ 芙宁娜 (原神)
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+
+## 中文句子 2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 纳西妲 (原神)
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+
+## 中文句子 3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+## 英文句子 1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+## 英文句子 2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+## 日文句子 1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+## 日文句子 2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
diff --git a/docs/zh/start_agent.md b/docs/zh/start_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..799cfadc2fd66b091a600abf92844ed1ee24cb20
--- /dev/null
+++ b/docs/zh/start_agent.md
@@ -0,0 +1,83 @@
+# 启动 Agent
+
+## 要求
+
+- GPU 显存: 至少 8GB(在量化的条件下),推荐 16GB 及以上
+- 硬盘使用量: 10GB
+
+## 下载模型
+
+你可以执行下面的语句来获取模型:
+
+```bash
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
+```
+
+如果你处于国内网络,首先执行:
+
+```bash
+export HF_ENDPOINT=https://hf-mirror.com
+```
+
把它们放进名为 'checkpoints' 的文件夹内。
+
+你同样需要 fish-speech 的模型,关于如何获取 fish-speech 模型请查看[inference](inference.md)。
+
+完成后你的 checkpoints 文件夹中会有两个子文件夹:`checkpoints/fish-speech-1.4` 和 `checkpoints/fish-agent-v0.1-3b`。
+
+## Environment Prepare
+
+如果你已经有了 Fish-Speech 环境,你可以在安装下面的包的前提下直接使用:
+
+```bash
+pip install cachetools
+```
+
!!! note
    请使用低于 3.12 的 Python 版本,以便 compile 功能可用。
+
+如果你没有 Fish-Speech 环境,请执行下面的语句来构造你的环境:
+
+```bash
+sudo apt-get install portaudio19-dev
+
+pip install -e .[stable]
+```
+
## 启动 Agent Demo
+
+你需要使用以下指令来构建 fish-agent
+
+```bash
+python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+```
+
+`--compile`只能在小于 3.12 版本的 Python 使用,这个功能可以极大程度上提高生成速度。
+
你需要注意,compile 需要进行一段时间。
+
+然后启动另一个终端并执行:
+
+```bash
+python -m tools.e2e_webui
+```
+
+这会在设备上创建一个 Gradio WebUI。
+
+每当进行第一轮对话的时候,模型需要 compile 一段时间,请耐心等待
+
+## Gradio Webui
+
+
+
+
+
+玩得开心!
+
+## Performance
+
在我们的测试环境下, 4060 laptop GPU 只能刚刚运行该模型,只有大概 8 tokens/s。 4090 GPU 可以在编译后达到 95 tokens/s,我们推荐使用至少 4080 以上级别的 GPU 来达到较好体验。
+
+# About Agent
+
+该模型仍处于测试阶段。如果你发现了问题,请给我们提 issue 或者 pull request,我们非常感谢。
diff --git a/entrypoint.sh b/entrypoint.sh
new file mode 100755
index 0000000000000000000000000000000000000000..d9e931429835cf454fd1a4e027b23bbee4875b65
--- /dev/null
+++ b/entrypoint.sh
@@ -0,0 +1,10 @@
#!/bin/bash

# Container entrypoint: launch the Gradio WebUI, using the GPU by default.
# Set CUDA_ENABLED=false (any value other than "true") to force CPU inference.
CUDA_ENABLED=${CUDA_ENABLED:-true}
DEVICE=""

if [ "${CUDA_ENABLED}" != "true" ]; then
    DEVICE="--device cpu"
fi

# exec replaces the shell so the Python process becomes PID 1 and
# receives container signals (SIGTERM etc.) directly.
exec python tools/webui.py ${DEVICE}
diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbcf3f33656d180ca87cd14a21ede1544e5a61a3
--- /dev/null
+++ b/fish_speech/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]
diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc95ef2a3723323b2d976001ed1e3c79c00b21a
--- /dev/null
+++ b/fish_speech/callbacks/grad_norm.py
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+ _group_tensors_by_device_and_dtype,
+ _has_foreach_support,
+)
+
+
@torch.no_grad()
def grad_norm(
    parameters: Union[Tensor, list[Tensor]],
    norm_type: float = 2.0,
) -> float:
    """
    Returns the norm of the gradients of the given parameters.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        norm_type (float): type of the used p-norm.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
        NOTE(review): despite the ``-> float`` annotation, this returns a
        0-dim ``torch.Tensor``, or ``None`` when no parameter has a gradient.
    """  # noqa: E501

    if isinstance(parameters, Tensor):
        parameters = [parameters]

    # Only parameters that actually received gradients participate.
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return None

    first_device = grads[0].device
    # Group gradients by (device, dtype) so each group can be reduced with a
    # single fused "foreach" kernel. _group_tensors_by_device_and_dtype is a
    # private torch utility — its return layout may change across versions.
    grouped_grads: dict[
        tuple[torch.device, torch.dtype], list[list[Tensor]]
    ] = _group_tensors_by_device_and_dtype(
        [[g.detach() for g in grads]]
    )  # type: ignore[assignment]

    norms = []
    for (device, _), ([grads], _) in grouped_grads.items():
        if _has_foreach_support(grads, device=device):
            # Fused per-tensor norms in one kernel launch.
            norms.extend(torch._foreach_norm(grads, norm_type))
        else:
            # Fallback: one norm computation per tensor.
            norms.extend([torch.norm(g, norm_type) for g in grads])

    # For p-norms, the norm of the per-tensor norms equals the norm of the
    # concatenated gradient vector, i.e. the total gradient norm.
    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
class GradNormMonitor(Callback):
    """Lightning callback that logs gradient norms after each backward pass.

    Either the whole model or one/more named sub-modules can be monitored;
    each target is logged under ``train[/<sub_module>]/grad_norm``.
    """

    def __init__(
        self,
        norm_type: float = 2.0,
        logging_interval: str = "step",
        sub_module: Optional[Union[str, list[str]]] = None,
    ) -> None:
        """
        Args:
            norm_type (float): type of the used p-norm.
            logging_interval (str): "step" or "epoch".
            sub_module: optional attribute name(s) of sub-modules to monitor
                individually instead of the whole model.
        """
        super().__init__()

        self.norm_type = norm_type
        self.logging_interval = logging_interval
        self.sub_module = sub_module

    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
        """Compute and log gradient norm(s) once gradients are populated.

        Args:
            trainer (Trainer): The trainer object
            model (LightningModule): The current lightningModule
        """
        targets = self.sub_module

        # No sub-modules configured: monitor the full model.
        if targets is None:
            self.log_sub_module_grad_norm(model, model, "")
            return

        # Normalize a single name to a list, then log each sub-module.
        if isinstance(targets, str):
            targets = [targets]

        for name in targets:
            self.log_sub_module_grad_norm(model, getattr(model, name), f"/{name}")

    def log_sub_module_grad_norm(
        self, lightning_model: LightningModule, model: nn.Module, path: str
    ) -> None:
        """Log the total gradient norm of ``model`` as ``train{path}/grad_norm``."""
        total_norm = grad_norm(model.parameters(), self.norm_type)
        if total_norm is None:
            # No parameter received a gradient; nothing to log.
            return

        log_on_step = self.logging_interval == "step"
        lightning_model.log(
            f"train{path}/grad_norm",
            total_norm,
            on_step=log_on_step,
            on_epoch=not log_on_step,
        )
diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..99e6dab54d3f57bce4f6d29a9129a19a523cad75
--- /dev/null
+++ b/fish_speech/configs/base.yaml
@@ -0,0 +1,87 @@
+# Base configuration for training a model
+paths:
+ run_dir: results/${project}
+ ckpt_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+ run:
+ dir: ${paths.run_dir}
+
+# Lightning Trainer
+trainer:
+ _target_: lightning.pytorch.trainer.Trainer
+
+ default_root_dir: ${paths.run_dir}
+ accelerator: gpu
+ num_nodes: 1
+ devices: auto
+ strategy:
+ _target_: lightning.pytorch.strategies.DDPStrategy
+ process_group_backend: nccl # This should be override when training on windows
+
+ precision: bf16-mixed
+
+ # disable validation by epoch end
+ check_val_every_n_epoch: null
+ val_check_interval: 5000
+ max_steps: 100_000
+
+ # Use torch.backends.cudnn.benchmark to speed up training
+ benchmark: true
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: ${paths.ckpt_dir}
+ filename: "step_{step:09d}"
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 5 # save 5 latest checkpoints
+ monitor: step # use step to monitor checkpoints
+ mode: max # save the latest checkpoint with the highest global_step
+ every_n_epochs: null # don't save checkpoints by epoch end
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
+ auto_insert_metric_name: false
+
+ model_summary:
+ _target_: lightning.pytorch.callbacks.ModelSummary
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
+
+ learning_rate_monitor:
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
+ logging_interval: step
+ log_momentum: false
+
+ grad_norm_monitor:
+ _target_: fish_speech.callbacks.GradNormMonitor
+ norm_type: 2
+ logging_interval: step
+
+# Logger
+logger:
+ tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.run_dir}/tensorboard/"
+ name: null
+ log_graph: false
+ default_hp_metric: true
+ prefix: ""
+
+ # wandb:
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # # name: "" # name of the run (normally generated by wandb)
+ # save_dir: "${paths.run_dir}"
+ # offline: False
+ # id: null # pass correct id to resume experiment!
+ # anonymous: null # enable anonymous logging
+ # project: "fish-speech"
+ # log_model: False # upload lightning ckpts
+ # prefix: "" # a string to put at the beginning of metric keys
+ # # entity: "" # set to name of your wandb team
+ # group: ""
+ # tags: ["vq", "hq", "finetune"]
+ # job_type: ""
+
+# Loop
+train: true
+test: false
diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..10aa8d4a522f0859ed8f541f5d48672d84b39c8f
--- /dev/null
+++ b/fish_speech/configs/firefly_gan_vq.yaml
@@ -0,0 +1,33 @@
# FireflyArchitecture: log-mel front-end -> ConvNeXt encoder -> FSQ quantizer,
# with a HiFi-GAN generator head for waveform reconstruction.
_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
spec_transform:
  _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
  sample_rate: 44100
  n_mels: 160
  n_fft: 2048
  hop_length: 512
  win_length: 2048
backbone:
  _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
  input_channels: 160 # matches n_mels above
  depths: [3, 3, 9, 3]
  dims: [128, 256, 384, 512]
  drop_path_rate: 0.2
  kernel_size: 7
head:
  _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
  hop_length: 512
  upsample_rates: [8, 8, 2, 2, 2] # aka. strides; product (8*8*2*2*2 = 512) equals hop_length
  upsample_kernel_sizes: [16, 16, 4, 4, 4]
  resblock_kernel_sizes: [3, 7, 11]
  resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
  num_mels: 512
  upsample_initial_channel: 512
  pre_conv_kernel_size: 13
  post_conv_kernel_size: 13
quantizer:
  _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
  input_dim: 512
  n_groups: 8
  n_codebooks: 1
  levels: [8, 5, 5, 5] # FSQ quantization levels per dimension
  downsample_factor: [2, 2]
diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aecc4d9766a18fe31c55941e01b1f590c95e77c9
--- /dev/null
+++ b/fish_speech/configs/lora/r_8_alpha_16.yaml
@@ -0,0 +1,4 @@
# LoRA hyperparameters: rank 8, alpha 16 (effective scaling alpha/r = 2).
_target_: fish_speech.models.text2semantic.lora.LoraConfig
r: 8
lora_alpha: 16
lora_dropout: 0.01
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4c1993023099e122fc9e004bda55ec075ed5e1b
--- /dev/null
+++ b/fish_speech/configs/text2semantic_finetune.yaml
@@ -0,0 +1,83 @@
+defaults:
+ - base
+ - _self_
+
+project: text2semantic_finetune_dual_ar
+max_length: 4096
+pretrained_ckpt_path: checkpoints/fish-speech-1.4
+
+# Lightning Trainer
+trainer:
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ gradient_clip_algorithm: "norm"
+ max_steps: 1000
+ precision: bf16-true
+ limit_val_batches: 10
+ val_check_interval: 100
+
+# Dataset Configuration
+tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: ${pretrained_ckpt_path}
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+val_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+data:
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 8
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
+ model:
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
+ path: ${pretrained_ckpt_path}
+ load_weights: true
+ max_length: ${max_length}
+ lora_config: null
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.95]
+ eps: 1e-5
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 10
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ every_n_train_steps: ${trainer.val_check_interval}
diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bbc1cdb6c4a1d276ccf922988a7ad13e058d70a
--- /dev/null
+++ b/fish_speech/conversation.py
@@ -0,0 +1,256 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
+
# Special tokens used to structure multi-modal chat transcripts.
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
# Placeholder emitted once per VQ semantic code frame (see Message.encode).
SEMANTIC_TOKEN = "<|semantic|>"
# Placeholder emitted once per mel frame (see Message.encode).
MEL_TOKEN = "<|mel|>"
PHONEME_START_TOKEN = "<|phoneme_start|>"
PHONEME_END_TOKEN = "<|phoneme_end|>"
ALL_SPECIAL_TOKENS = [
    IM_START_TOKEN,
    IM_END_TOKEN,
    SEMANTIC_TOKEN,
    MEL_TOKEN,
    PHONEME_START_TOKEN,
    PHONEME_END_TOKEN,
]

# Codebook entry reserved for padding in VQ code tensors (codes are shifted
# by +1 on encode so real codes never collide with this id).
CODEBOOK_PAD_TOKEN_ID = 0
+
+
class FishTokenizerConfig(PretrainedConfig):
    """Tokenizer config carrying the VQ codebook layout used by this module."""

    # Whether all codebooks share one embedding table; when False, encode()
    # offsets each codebook into its own id range.
    share_codebook_embeddings: bool = True
    # Number of entries per codebook.
    codebook_size: int = 1024
    # Number of parallel codebooks per frame.
    num_codebooks: int = 8
+
+
class FishTokenizerFast(PreTrainedTokenizerFast):
    """Fast tokenizer that carries the VQ codebook layout alongside the vocab."""

    def __init__(self, *args, **kwargs):
        # Pop our custom options BEFORE delegating: the base
        # PreTrainedTokenizerFast does not declare these keywords, and the
        # original order (super().__init__ first, pop afterwards) forwarded
        # them to the base class, where they could be rejected or persisted
        # as stray init kwargs.
        self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
        self.codebook_size = kwargs.pop("codebook_size", 1024)
        self.num_codebooks = kwargs.pop("num_codebooks", 8)
        super().__init__(*args, **kwargs)
+
+
+AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
+
+
@dataclass(kw_only=True)
class BasePart:
    """Marker base class for the heterogeneous parts of a chat message."""

    pass
+
+
@dataclass(kw_only=True)
class VQPart(BasePart):
    """Audio content represented as VQ codes."""

    # presumably shaped (num_codebooks, num_frames) — Message.encode reads
    # codes.shape[1] as the frame count and iterates codebooks over dim 0.
    codes: torch.Tensor
+
+
@dataclass(kw_only=True)
class TextPart(BasePart):
    """Plain-text content of a message."""

    text: str
+
+
@dataclass(kw_only=True)
class MelPart(BasePart):
    """Audio content represented as mel features."""

    # Message.encode reads mels.shape[1] as the frame count.
    mels: torch.Tensor
+
+
@dataclass(kw_only=True)
class EncodedMessage:
    """Tokenized form of a message/conversation plus its multi-modal payloads."""

    tokens: torch.Tensor  # 1-D token id sequence
    labels: torch.Tensor  # same shape as tokens; -100 where loss is ignored
    vq_parts: list[torch.Tensor]  # VQ code tensors, in order of appearance
    mel_parts: list[torch.Tensor]  # mel tensors, in order of appearance
    vq_require_losses: torch.Tensor | None = None  # bool per VQ part; set by Conversation.encode
+
+
@dataclass(kw_only=True)
class Message:
    """A single chat turn: a role plus an ordered list of text/VQ/mel parts."""

    role: Literal["system", "user", "assistant"]
    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
    # Whether to wrap the message in <|im_start|>role ... <|im_end|> markers.
    add_im_start: bool = True
    add_im_end: bool = True
    # When True, this message's tokens contribute to the training loss.
    cal_loss: bool = False

    # By default, ignore the loss of the auto-generated im_start token
    ignore_im_start_loss: bool = True

    def encode(
        self: "Message",
        tokenizer: AutoTokenizer,
    ) -> EncodedMessage:
        """Tokenize this message.

        Text parts are tokenized directly; VQ/mel parts are represented in
        the token stream by one placeholder token per frame, with the actual
        tensors collected into ``vq_parts`` / ``mel_parts``.

        Raises:
            ValueError: if a part is not a TextPart, VQPart or MelPart.
        """
        all_tokens = []
        all_labels = []

        # Multi-modal payloads gathered alongside their placeholder tokens.
        vq_parts = []
        mel_parts = []

        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
            [SEMANTIC_TOKEN, MEL_TOKEN]
        )

        parts = self.parts.copy()
        if self.add_im_start:
            # Use the shared special-token constants instead of re-spelling
            # the literals, keeping this in sync with ALL_SPECIAL_TOKENS.
            parts.insert(0, TextPart(text=f"{IM_START_TOKEN}{self.role}\n"))

        if self.add_im_end:
            parts.append(TextPart(text=IM_END_TOKEN))

        for part in parts:
            if isinstance(part, TextPart):
                tokens = tokenizer.encode(
                    part.text,
                    add_special_tokens=False,
                    truncation=False,
                    return_tensors="pt",
                ).int()[0]
            elif isinstance(part, VQPart):
                # One <|semantic|> placeholder per code frame.
                tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
                # Shift codes by +1 so 0 stays reserved for padding
                # (CODEBOOK_PAD_TOKEN_ID).
                codes = part.codes.clone() + 1

                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
                    # Separate embedding tables: offset each codebook into
                    # its own id range.
                    for i in range(len(codes)):
                        codes[i] += tokenizer.codebook_size * i

                vq_parts.append(codes)
            elif isinstance(part, MelPart):
                # One <|mel|> placeholder per mel frame.
                tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
                mel_parts.append(part.mels)
            else:
                raise ValueError(f"Unsupported part type: {type(part)}")

            all_tokens.append(tokens)
            if self.cal_loss:
                all_labels.append(tokens.clone())
            else:
                all_labels.append(torch.full_like(tokens, -100))

        tokens = torch.cat(all_tokens, dim=0)
        labels = torch.cat(all_labels, dim=0)
        assert tokens.shape == labels.shape

        if self.ignore_im_start_loss and self.add_im_start:
            # Never train on the auto-inserted "<|im_start|>role\n" prefix.
            labels[: len(all_tokens[0])] = -100

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            mel_parts=mel_parts,
        )
+
+
@dataclass
class Conversation:
    """An ordered list of messages encodable as one training/inference sample."""

    messages: list[Message]

    def encode(
        self: "Conversation",
        tokenizer: AutoTokenizer,
        add_shift: bool = True,
    ) -> EncodedMessage:
        """Encode all messages into a single token/label sequence.

        Args:
            tokenizer: tokenizer carrying this module's special tokens.
            add_shift: when True, shift tokens/labels by one position for
                next-token-prediction training.
        """
        # Build the input_ids and labels
        tokens = []
        labels = []
        vq_parts = []
        mel_parts = []
        vq_require_losses = []

        for message in self.messages:
            encoded = message.encode(
                tokenizer,
            )
            tokens.append(encoded.tokens)
            labels.append(encoded.labels)
            vq_parts.extend(encoded.vq_parts)
            mel_parts.extend(encoded.mel_parts)
            # Each VQ part inherits its message's cal_loss flag.
            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))

        tokens = torch.cat(tokens, dim=0)
        labels = torch.cat(labels, dim=0)
        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

        if add_shift:
            # Standard LM shift: predict token t+1 from tokens <= t.
            tokens = tokens[:-1]
            labels = labels[1:]

        # BUGFIX: the original message referenced the undefined global
        # "conversation" (only bound in the __main__ demo), turning a failed
        # assertion into a NameError; report self instead.
        assert tokens.dtype in [
            torch.int,
            torch.long,
        ], f"Invalid dtype: {tokens.dtype}, conv: {self}"

        return EncodedMessage(
            tokens=tokens,
            labels=labels,
            vq_parts=vq_parts,
            mel_parts=mel_parts,
            vq_require_losses=vq_require_losses,
        )

    def encode_for_inference(
        self: "Conversation",
        tokenizer: AutoTokenizer,
        num_codebooks: int,
    ) -> torch.Tensor:
        """Encode for generation.

        Returns a ``(num_codebooks + 1, seq_len)`` int tensor: row 0 holds the
        token ids, rows 1.. hold the VQ codes aligned under their
        ``<|semantic|>`` placeholder columns.

        NOTE: previously annotated as returning EncodedMessage, but a tensor
        has always been returned; the annotation is corrected here.
        """
        encoded = self.encode(tokenizer, add_shift=False)
        tokens = encoded.tokens
        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
        values[0] = tokens

        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
            return values

        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
            [SEMANTIC_TOKEN, MEL_TOKEN]
        )
        vq_parts = encoded.vq_parts
        vq_parts = torch.cat(vq_parts, dim=1)
        # Scatter the codes into the columns occupied by semantic placeholders.
        values[1:, tokens == semantic_id] = vq_parts
        return values

    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
        """Print the conversation: green for loss-ignored tokens, blue for trained."""
        encoded = self.encode(tokenizer, add_shift=False)

        # ANSI color helpers.
        print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
        print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")

        for tok, lab in zip(encoded.tokens, encoded.labels):
            val = tokenizer.decode(tok, skip_special_tokens=False)
            if val == "\n":
                val = "\\n\n"

            if lab == -100:
                print_in_green(val)
            else:
                print_in_blue(val)

        print()
+
+
if __name__ == "__main__":
    # Demo: build a tiny two-turn conversation and inspect its encoding.
    message0 = Message(
        role="user",
        parts=[
            TextPart(text="Hello, how are you?"),
            VQPart(codes=torch.zeros((4, 10))),
        ],
        cal_loss=False,
    )

    message1 = Message(
        role="assistant",
        parts=[TextPart(text="I'm fine, thank you.")],
        cal_loss=True,
    )
    conversation = Conversation([message0, message1])
    # Requires a local checkpoint whose tokenizer has the special tokens.
    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
    conversation.visualize(tokenizer)

    encoded = conversation.encode(tokenizer)
    print(encoded)
    print(tokenizer.batch_decode(encoded.tokens))
new file mode 100644
index 0000000000000000000000000000000000000000..4aa596b95a572ee15c5570cbdb792c9a78e62dfa
--- /dev/null
+++ b/fish_speech/datasets/concat_repeat.py
@@ -0,0 +1,53 @@
+import bisect
+import random
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
class ConcatRepeatDataset(Dataset):
    """Concatenation of map-style datasets, each virtually repeated N times.

    Indexing walks the cumulative (repeated) sizes; within a dataset the
    sample index wraps with modulo, so a dataset repeated k times is
    traversed k times.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        """Running totals of ``len(dataset) * repeat`` per dataset."""
        totals, running = [], 0
        for dataset, repeat in zip(sequence, repeats):
            running += len(dataset) * repeat
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        super().__init__()

        self.datasets = list(datasets)
        self.repeats = repeats

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"

        for dataset in self.datasets:
            assert not isinstance(
                dataset, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"

        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        # Total virtual length including repeats.
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Locate which dataset the virtual index falls into.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = 0 if dataset_idx == 0 else self.cumulative_sizes[dataset_idx - 1]
        dataset = self.datasets[dataset_idx]
        # Modulo wraps the index back into the real dataset for repeats > 1.
        return dataset[(idx - offset) % len(dataset)]
diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto
new file mode 100644
index 0000000000000000000000000000000000000000..5eb26d94aa3be1e21066f2bf38c90d54e85a8379
--- /dev/null
+++ b/fish_speech/datasets/protos/text-data.proto
@@ -0,0 +1,24 @@
// Schema for the length-prefixed training data streams consumed by the
// fish_speech dataset loaders (framing is defined in text_data_stream.py).
syntax = "proto3";

package text_data;

// One stream of semantic (VQ) token ids.
message Semantics {
  repeated uint32 values = 1;
}

// A sentence: alternative text transcripts plus per-codebook semantics.
// NOTE(review): field number 2 is skipped — presumably a removed field;
// do not reuse it.
message Sentence {
  repeated string texts = 1;
  repeated Semantics semantics = 3;
}

// A named group of sentences from one source (e.g. one speaker).
// NOTE(review): field number 3 is skipped — do not reuse it.
message TextData {
  string source = 1;
  string name = 2;
  repeated Sentence sentences = 4;
}

// A sampled subset of a group, as produced by the dataset samplers.
message SampledData {
  string source = 1;
  string name = 2;
  repeated Sentence samples = 3;
}
diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_pb2.py
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: text-data.proto
# Protobuf Python Version: 4.25.1
# NOTE: to change the schema, edit text-data.proto and regenerate with
# `protoc --python_out=. text-data.proto`; never hand-edit this module.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
# When the C descriptor implementation is unavailable, the generated code
# records serialized byte offsets on the pure-Python descriptors.
if _descriptor._USE_C_DESCRIPTORS == False:
    DESCRIPTOR._options = None
    _globals["_SEMANTICS"]._serialized_start = 30
    _globals["_SEMANTICS"]._serialized_end = 57
    _globals["_SENTENCE"]._serialized_start = 59
    _globals["_SENTENCE"]._serialized_end = 125
    _globals["_TEXTDATA"]._serialized_start = 127
    _globals["_TEXTDATA"]._serialized_end = 207
    _globals["_SAMPLEDDATA"]._serialized_start = 209
    _globals["_SAMPLEDDATA"]._serialized_end = 290
# @@protoc_insertion_point(module_scope)
diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec3c25bcd764e8245de47dcdf9686d6adfb5a107
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_stream.py
@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
def read_pb_stream(f):
    """Yield TextData messages from a length-prefixed binary stream.

    Each record is a 4-byte native-endian unsigned length followed by a
    serialized TextData payload.
    """
    while header := f.read(4):
        (size,) = struct.unpack("I", header)
        message = TextData()
        message.ParseFromString(f.read(size))
        yield message
+
+
def write_pb_stream(f, text_data):
    """Append one message to stream f as a 4-byte length prefix + payload."""
    payload = text_data.SerializeToString()
    header = struct.pack("I", len(payload))
    f.write(header)
    f.write(payload)
+
+
def pack_pb_stream(text_data):
    """Return one message as bytes: 4-byte length prefix + serialized payload."""
    payload = text_data.SerializeToString()
    return struct.pack("I", len(payload)) + payload
+
+
def split_pb_stream(f):
    """Yield raw length-prefixed records (header + payload) from stream f.

    Unlike read_pb_stream, the records are not parsed — useful for
    re-sharding streams without deserialization.
    """
    while header := f.read(4):
        (size,) = struct.unpack("I", header)
        yield header + f.read(size)
diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c64e01077ae253bdc4e4d9cd948f8fb50df7418
--- /dev/null
+++ b/fish_speech/datasets/semantic.py
@@ -0,0 +1,496 @@
+import random
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from random import Random
+from typing import Optional, Union
+
+import numpy as np
+import pyarrow.parquet as pq
+import torch
+import torch.nn.functional as F
+from datasets.download.streaming_download_manager import xopen
+from huggingface_hub import HfApi
+from lightning import LightningDataModule
+from torch.distributed import get_rank, get_world_size, is_initialized
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.datasets.protos.text_data_pb2 import SampledData
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.clean import clean_text
+from fish_speech.utils import RankedLogger
+from fish_speech.utils.braceexpand import braceexpand
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
def split_by_rank_worker(files):
    """Shard `files` across DDP ranks and dataloader workers.

    If there are fewer files than shards, the list is repeated until every
    shard receives at least one file, then each rank/worker takes a strided
    slice.

    Args:
        files: list of file paths (or any items) to shard.

    Returns:
        The sub-list owned by the current rank/worker.
    """
    # Guard: an empty list previously raised ZeroDivisionError in the
    # repeat computation below; return it unchanged instead.
    if not files:
        return files

    # We need to know the total number of devices
    # to split the data properly
    total_devices = 1
    if is_initialized():
        total_devices = get_world_size()

    worker_info = get_worker_info()
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP: stride by rank first ...
    if is_initialized():
        files = files[get_rank() :: get_world_size()]

    # ... then stride by worker within the rank.
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
+
+
class AutoTextSemanticInstructionDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text

    For interactive mode, we use the following format (multiple sequences):
    [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]

    For non-interactive mode, we use the following format (one long sequence):
    [INST] text [/INST] ...
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool | float = True,
        causal: bool = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
    ):
        """
        Args:
            proto_files: proto buf files if using local data
            seed: random seed
            interactive_prob: probability to use interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt; a float is
                interpreted as the probability of including it
            causal: use causal sampling when using local data, disable will lead to random sampling
            num_codebooks: number of codebooks, if None, it will be automatically detected
            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
        """

        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causal = causal
        self.num_codebooks = num_codebooks
        self.skip_text_prob = skip_text_prob

        # Placeholder token id used at semantic (VQ) positions in the text row.
        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        # Lazily populated by init_mock_data_server() on first sample.
        self.groups = None

    def init_mock_data_server(self):
        """Load every proto group into memory once for this rank/worker shard."""
        if self.groups is not None:
            return

        # Expand the proto files (brace patterns, files, or directories).
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        # Sort before shuffling so the seeded shuffle is deterministic
        # regardless of filesystem enumeration order.
        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the lines
        Random(self.seed).shuffle(self.groups)
        # Groups are sampled proportionally to their sentence count.
        self.group_weights = [len(i.sentences) for i in self.groups]

    def __iter__(self):
        # Infinite stream of augmented samples.
        # NOTE(review): augment() may return None for an empty group, which
        # is yielded as-is — downstream must tolerate None batch items.
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        """Clean the text and return (cleaned_text, token_count)."""
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        """Pick one group (weighted by size) and sample sentences from it."""
        if self.groups is None:
            self.init_mock_data_server()

        # Shuffle unique lines, estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # choice group based on their number of samples
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causal:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        """Build one training sample: tokens and labels tensors.

        Returns:
            dict with "tokens" and "labels" (num_codebooks + 1, T) tensors,
            or None when the sampled group is empty.
        """
        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        idx = 0
        use_interactive = random.random() < self.interactive_prob

        if use_interactive is False:
            # Random sample based on speaker using a truncated normal distribution
            a = torch.tensor([0], dtype=torch.float32)
            torch.nn.init.trunc_normal_(
                a,
                mean=self.max_length // 2,
                std=self.max_length // 4,
                a=10,
                b=self.max_length,
            )
            remaining_tokens = a.long().item() - 4
        else:
            remaining_tokens = self.max_length

        # Use speaker (float means probability of including it)
        if isinstance(self.use_speaker, float):
            use_speaker = random.random() < self.use_speaker
        else:
            use_speaker = self.use_speaker

        all_tokens, all_labels = [], []
        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)

            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if use_interactive is False:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply speaker for the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if use_speaker else None,
                    skip_text=random.random() < self.skip_text_prob,
                )

                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        # Non-interactive mode packs everything into one long sequence.
        if use_interactive is False:
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if use_speaker else None,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        data = {"tokens": tokens, "labels": labels}

        return data

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        speaker: Optional[str] = None,
        skip_text: bool = False,
    ):
        """Render one chat turn into (tokens, labels) of shape
        (num_codebooks + 1, T): row 0 is text ids, rows 1.. are VQ codes.

        Args:
            sentences: cleaned text pieces, joined with spaces.
            semantics: one semantics list per sentence (per-codebook values).
            speaker: assistant role name; defaults to "assistant".
            skip_text: replace the user text with <|skip_text|> and mask all
                labels (condition-only sample).
        """
        if speaker is None:
            speaker = "assistant"

        cated_sentences = " ".join(sentences)
        if skip_text:
            cated_sentences = "<|skip_text|>"

        final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
        final_text = final_text + f"<|im_start|>{speaker}\n"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        # Pack the text tokens, one <|semantic|> placeholder per VQ frame,
        # and the closing <|im_end|>.
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
        )

        # Codebook bos/padding: 0, eos: 1; stored VQ values are shifted by +1.
        codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 1)

        # One trailing pad per codebook row, aligned with <|im_end|>.
        for book in codes:
            book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)

        tokens = [tokens] + codes

        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        if skip_text:
            # If text is not provided, the sentence is used for condition only, all labels are -100
            torch.fill_(labels, -100)
            return tokens, labels

        # Mask out the tokens for semantic, predict semantic tokens only
        # Since we don't mask out the input tokens, the language modeling still works
        labels[1:, :prompt_length] = -100

        # Shift for next-token prediction: inputs drop the last position,
        # labels drop the first.
        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

        return tokens, labels
+
+
@dataclass
class TextDataCollator:
    """Collate token/label examples into padded, batched tensors.

    When examples carry contrastive negatives ("negative_tokens" /
    "negative_labels"), the batch is expanded to [positives..., negatives...].
    """

    # String annotation keeps the dataclass importable without transformers
    # being resolved at class-creation time.
    tokenizer: "AutoTokenizer"
    max_length: int = 1024

    def __call__(self, examples):
        # BUG FIX: `examples` is a *list* of dicts, so the previous check
        # `"negative_tokens" in examples` compared the string against the
        # dicts themselves and was always False, silently dropping all
        # negative examples. Inspect the first example instead.
        if examples and "negative_tokens" in examples[0]:
            positive_examples = [
                {"tokens": ex["tokens"], "labels": ex["labels"]} for ex in examples
            ]
            negative_examples = [
                {"tokens": ex["negative_tokens"], "labels": ex["negative_labels"]}
                for ex in examples
            ]
            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        """Truncate each example to max_length, pad to the batch max, stack.

        Returns:
            dict with "inputs" (B, C, T), "attention_masks" (B, T) and
            "labels" (B, C, T).
        """
        tokens, attention_masks, labels = [], [], []

        # Longest sequence in the batch, capped at max_length.
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            # NOTE(review): the mask is True on *padding* positions (inverted
            # relative to the usual HF convention) — confirm against the model.
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                # Text row padded with EOS; codebook rows padded with the
                # codebook pad id; labels padded with -100 (ignored by loss).
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
+
+
class InterleaveDataset(IterableDataset):
    """Endlessly interleave several iterable datasets.

    Each item is drawn from a dataset chosen according to `probabilities`;
    exhausted datasets are restarted transparently.
    """

    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        iterators = [iter(source) for source in self.datasets]

        while True:
            # Pick a source dataset according to the configured probabilities.
            chosen = rng.choice(len(self.datasets), p=self.probabilities)
            try:
                item = next(iterators[chosen])
            except StopIteration:
                # Source exhausted — restart it and take its first item.
                iterators[chosen] = iter(self.datasets[chosen])
                item = next(iterators[chosen])
            yield item
+
+
class SemanticDataModule(LightningDataModule):
    """Lightning wrapper exposing train/val dataloaders over semantic datasets."""

    def __init__(
        self,
        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def _build_loader(self, dataset):
        # Both splits share the same collator and loader configuration.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            persistent_workers=True,
        )

    def train_dataloader(self):
        return self._build_loader(self.train_dataset)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset)
+
+
if __name__ == "__main__":
    from tqdm import tqdm

    # Smoke test: decode and print a single augmented sample.
    dataset = AutoTextSemanticInstructionDataset(
        ["data/protos"],
        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
        use_speaker=False,
        interactive_prob=1.0,
        skip_text_prob=0.5,
    )

    sample = next(iter(dataset))
    print(dataset.tokenizer.decode(sample["tokens"][0], skip_special_tokens=False))
diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45583d22efb0feb9dc1e823bae1ef74534b299e
--- /dev/null
+++ b/fish_speech/datasets/vqgan.py
@@ -0,0 +1,147 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import librosa
+import numpy as np
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
+
class VQGANDataset(Dataset):
    """Map-style dataset of raw audio clips listed in a filelist.

    Each line of the filelist is a path relative to the filelist's directory.
    """

    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist_path = Path(filelist)
        base_dir = filelist_path.parent

        # Keep only non-empty lines, resolved relative to the filelist.
        self.files = [
            base_dir / stripped
            for stripped in (
                line.strip()
                for line in filelist_path.read_text(encoding="utf-8").splitlines()
            )
            if stripped
        ]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        """Load one clip, optionally random-slice it, and peak-normalize."""
        audio, _ = librosa.load(self.files[idx], sr=self.sample_rate, mono=True)

        # Randomly slice the clip to at most slice_frames * hop_length samples.
        if self.slice_frames is not None:
            max_samples = self.slice_frames * self.hop_length
            if audio.shape[0] > max_samples:
                offset = np.random.randint(0, audio.shape[0] - max_samples)
                audio = audio[offset : offset + max_samples]

        if len(audio) == 0:
            return None

        # Rescale only when the signal exceeds full scale.
        peak = np.abs(audio).max()
        if peak > 1.0:
            audio = audio / peak

        return {
            "audio": torch.from_numpy(audio),
        }

    def __getitem__(self, idx):
        # Best-effort loading: a bad file yields None (filtered by the collator).
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None
+
+
@dataclass
class VQGANCollator:
    """Drop failed (None) items and pad audio clips to the batch max length."""

    def __call__(self, batch):
        valid = [item for item in batch if item is not None]

        lengths = torch.tensor([len(item["audio"]) for item in valid])
        target_len = lengths.max()

        padded = [
            torch.nn.functional.pad(
                item["audio"], (0, target_len - len(item["audio"]))
            )
            for item in valid
        ]

        return {
            "audios": torch.stack(padded),
            "audio_lengths": lengths,
        }
+
+
class VQGANDataModule(LightningDataModule):
    """Lightning wrapper around train/val VQGANDatasets."""

    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        # Validation falls back to the training batch size when unset.
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def _build_loader(self, dataset, batch_size, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=shuffle,
            persistent_workers=True,
        )

    def train_dataloader(self):
        return self._build_loader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset, self.val_batch_size, shuffle=False)
+
+
if __name__ == "__main__":
    # Smoke test: load one batch and print its shapes.
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # BUG FIX: VQGANCollator only returns "audios" and "audio_lengths";
        # the previous prints of batch["features"] / batch["feature_lengths"]
        # always raised KeyError.
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break
diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..700902b09db20911ef1ad678cbdce5644b84aea2
--- /dev/null
+++ b/fish_speech/i18n/README.md
@@ -0,0 +1,27 @@
+## i18n Folder Attribution
+
+The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
+
+### fish_speech/i18n/core.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
+
+**Initial commit:**
+add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
+
+**Initial author:**
+[@L4Ph](https://github.com/L4Ph)
+
+### fish_speech/i18n/scan.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
+
+**Initial commit:**
+File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
+
+**Initial author:**
+[@towzeur](https://github.com/towzeur)
+
+We appreciate the contributions of the RVC project and its authors.
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..981dbb3b3ecf28043ec9ff5757f947182821a246
--- /dev/null
+++ b/fish_speech/i18n/__init__.py
@@ -0,0 +1,3 @@
+from .core import i18n
+
+__all__ = ["i18n"]
diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd
--- /dev/null
+++ b/fish_speech/i18n/core.py
@@ -0,0 +1,40 @@
+import json
+import locale
+from pathlib import Path
+
+I18N_FILE_PATH = Path(__file__).parent / "locale"
+DEFAULT_LANGUAGE = "en_US"
+
+
def load_language_list(language):
    """Load the key → translation mapping for *language* from the locale dir."""
    locale_file = I18N_FILE_PATH / f"{language}.json"
    with open(locale_file, "r", encoding="utf-8") as f:
        return json.load(f)
+
+
class I18nAuto:
    """Resolve the UI language and translate message keys.

    Resolution order:
      1. a `.locale` file in the working directory (explicit override),
      2. the system default locale,
      3. DEFAULT_LANGUAGE when neither yields a known translation file.
    """

    def __init__(self):
        override_file = Path(".locale")

        if override_file.exists():
            language = override_file.read_text(encoding="utf-8").strip()
        else:
            # getdefaultlocale() can return (None, None) when the system
            # language cannot be identified. NOTE(review): it is deprecated
            # since Python 3.11 — consider locale.getlocale() when the
            # supported Python floor allows.
            language = locale.getdefaultlocale()[0]

        # BUG FIX: previously a None language relied on "None.json" not
        # existing to hit the fallback; make the guard explicit and drop
        # the unidiomatic `exists() is False` comparison.
        if language is None or not (I18N_FILE_PATH / f"{language}.json").exists():
            language = DEFAULT_LANGUAGE

        self.language = language
        self.language_map = load_language_list(language)

    def __call__(self, key):
        """Return the translation for *key*, or *key* itself when missing."""
        return self.language_map.get(key, key)

    def __repr__(self):
        return "Use Language: " + self.language


# Module-level singleton shared by the UI code.
i18n = I18nAuto()
diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json
new file mode 100644
index 0000000000000000000000000000000000000000..d36c774313628fe9d4ee60e816f404c09935e655
--- /dev/null
+++ b/fish_speech/i18n/locale/en_US.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
+ "Add to Processing Area": "Add to Processing Area",
+ "Added path successfully!": "Added path successfully!",
+ "Advanced Config": "Advanced Config",
+ "Base LLAMA Model": "Base LLAMA Model",
+ "Batch Inference": "Batch Inference",
+ "Batch Size": "Batch Size",
+ "Changing with the Model Path": "Changing with the Model Path",
+ "Chinese": "Chinese",
+ "Compile Model": "Compile Model",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
+ "Copy": "Copy",
+ "Data Preprocessing": "Data Preprocessing",
+ "Data Preprocessing Path": "Data Preprocessing Path",
+ "Data Source": "Data Source",
+ "Decoder Model Config": "Decoder Model Config",
+ "Decoder Model Path": "Decoder Model Path",
+ "Disabled": "Disabled",
+ "Enable Reference Audio": "Enable Reference Audio",
+ "English": "English",
+ "Error Message": "Error Message",
+ "File Preprocessing": "File Preprocessing",
+ "Generate": "Generate",
+ "Generated Audio": "Generated Audio",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
+ "Infer interface is closed": "Infer interface is closed",
+ "Inference Configuration": "Inference Configuration",
+ "Inference Server Configuration": "Inference Server Configuration",
+ "Inference Server Error": "Inference Server Error",
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
+ "Initial Learning Rate": "Initial Learning Rate",
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
+ "Input Text": "Input Text",
+ "Invalid path: {}": "Invalid path: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
+ "Japanese": "Japanese",
+ "LLAMA Configuration": "LLAMA Configuration",
+ "LLAMA Model Config": "LLAMA Model Config",
+ "LLAMA Model Path": "LLAMA Model Path",
+ "Labeling Device": "Labeling Device",
+ "LoRA Model to be merged": "LoRA Model to be merged",
+ "Maximum Audio Duration": "Maximum Audio Duration",
+ "Maximum Length per Sample": "Maximum Length per Sample",
+ "Maximum Training Steps": "Maximum Training Steps",
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
+ "Merge": "Merge",
+ "Merge LoRA": "Merge LoRA",
+ "Merge successfully": "Merge successfully",
+ "Minimum Audio Duration": "Minimum Audio Duration",
+ "Model Output Path": "Model Output Path",
+ "Model Size": "Model Size",
+ "Move": "Move",
+ "Move files successfully": "Move files successfully",
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
+ "No selected options": "No selected options",
+ "Number of Workers": "Number of Workers",
+ "Open Inference Server": "Open Inference Server",
+ "Open Labeler WebUI": "Open Labeler WebUI",
+ "Open Tensorboard": "Open Tensorboard",
+ "Opened labeler in browser": "Opened labeler in browser",
+ "Optional Label Language": "Optional Label Language",
+ "Optional online ver": "Optional online ver",
+ "Output Path": "Output Path",
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
+ "Precision": "Precision",
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
+ "Put your text here.": "Put your text here.",
+ "Reference Audio": "Reference Audio",
+ "Reference Text": "Reference Text",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
+ "Remove Selected Data": "Remove Selected Data",
+ "Removed path successfully!": "Removed path successfully!",
+ "Repetition Penalty": "Repetition Penalty",
+ "Save model every n steps": "Save model every n steps",
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
+ "Select VITS ckpt": "Select VITS ckpt",
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
+ "Select source file processing method": "Select source file processing method",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
+ "Selected: {}": "Selected: {}",
+ "Speaker": "Speaker",
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
+ "Start Training": "Start Training",
+ "Streaming Audio": "Streaming Audio",
+ "Streaming Generate": "Streaming Generate",
+ "Tensorboard Host": "Tensorboard Host",
+ "Tensorboard Log Path": "Tensorboard Log Path",
+ "Tensorboard Port": "Tensorboard Port",
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
+ "Training Configuration": "Training Configuration",
+ "Training Error": "Training Error",
+ "Training stopped": "Training stopped",
+ "Type name of the speaker": "Type name of the speaker",
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
+ "Use LoRA": "Use LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
+ "Use filelist": "Use filelist",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
+ "VITS Configuration": "VITS Configuration",
+ "VQGAN Configuration": "VQGAN Configuration",
+ "Validation Batch Size": "Validation Batch Size",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
+ "WebUI Host": "WebUI Host",
+ "WebUI Port": "WebUI Port",
+ "Whisper Model": "Whisper Model",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
+ "latest": "latest",
+ "new": "new",
+ "Realtime Transform Text": "Realtime Transform Text",
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
+ "Text Normalization": "Text Normalization",
+ "Select Example Audio": "Select Example Audio"
+}
diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json
new file mode 100644
index 0000000000000000000000000000000000000000..7a4757967dd0fe3807ba4d354e75ad7a88eb510e
--- /dev/null
+++ b/fish_speech/i18n/locale/es_ES.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
+ "Advanced Config": "Configuración Avanzada",
+ "Base LLAMA Model": "Modelo Base LLAMA",
+ "Batch Inference": "Inferencia por Lote",
+ "Batch Size": "Tamaño del Lote",
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
+ "Chinese": "Chino",
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Preprocesamiento de Datos",
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
+ "Data Source": "Fuente de Datos",
+ "Decoder Model Config": "Configuración del modelo decodificador",
+ "Decoder Model Path": "Ruta del modelo decodificador",
+ "Disabled": "Desactivado",
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
+ "English": "Inglés",
+ "Error Message": "Mensaje de Error",
+ "File Preprocessing": "Preprocesamiento de Archivos",
+ "Generate": "Generar",
+ "Generated Audio": "Audio Generado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
+ "Inference Configuration": "Configuración de Inferencia",
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
+ "Inference Server Error": "Error del Servidor de Inferencia",
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Ruta inválida: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
+ "Japanese": "Japonés",
+ "LLAMA Configuration": "Configuración de LLAMA",
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Etiquetado",
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
+ "Maximum Audio Duration": "Duración máxima de audio",
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
+ "Merge": "Fusionar",
+ "Merge LoRA": "Fusionar LoRA",
+ "Merge successfully": "Fusionado exitosamente",
+ "Minimum Audio Duration": "Duración mínima de audio",
+ "Model Output Path": "Ruta de Salida del Modelo",
+ "Model Size": "Tamaño del Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Archivos movidos exitosamente",
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
+ "No selected options": "No hay opciones seleccionadas",
+ "Number of Workers": "Número de Trabajadores",
+ "Open Inference Server": "Abrir Servidor de Inferencia",
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
+ "Optional online ver": "Ver en línea opcional",
+ "Output Path": "Ruta de Salida",
+ "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
+ "Precision": "Precisión",
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
+ "Put your text here.": "Ponga su texto aquí.",
+ "Reference Audio": "Audio de Referencia",
+ "Reference Text": "Texto de Referencia",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
+ "Repetition Penalty": "Penalización por Repetición",
+ "Save model every n steps": "Guardar modelo cada n pasos",
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
+ "Selected: {}": "Seleccionado: {}",
+ "Speaker": "Hablante",
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
+ "Start Training": "Iniciar Entrenamiento",
+ "Streaming Audio": "Transmisión de Audio",
+ "Streaming Generate": "Síntesis en Flujo",
+ "Tensorboard Host": "Host de Tensorboard",
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
+ "Tensorboard Port": "Puerto de Tensorboard",
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
+ "Training Configuration": "Configuración de Entrenamiento",
+ "Training Error": "Error de Entrenamiento",
+ "Training stopped": "Entrenamiento detenido",
+ "Type name of the speaker": "Escriba el nombre del hablante",
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
+ "Use filelist": "Usar lista de archivos",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
+ "VITS Configuration": "Configuración de VITS",
+ "VQGAN Configuration": "Configuración de VQGAN",
+ "Validation Batch Size": "Tamaño del Lote de Validación",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
+ "WebUI Host": "Host de WebUI",
+ "WebUI Port": "Puerto de WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
+ "latest": "más reciente",
+ "new": "nuevo",
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
+ "Text Normalization": "Normalización de Texto",
+ "Select Example Audio": "Seleccionar Audio de Ejemplo"
+}
diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json
new file mode 100644
index 0000000000000000000000000000000000000000..863b8b0b41da7e504ac0dcc4abf707f1f71a53fa
--- /dev/null
+++ b/fish_speech/i18n/locale/ja_JP.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
+ "Accumulate Gradient Batches": "勾配バッチの累積",
+ "Add to Processing Area": "処理エリアに追加",
+ "Added path successfully!": "パスの追加に成功しました!",
+ "Advanced Config": "詳細設定",
+ "Base LLAMA Model": "基本LLAMAモデル",
+ "Batch Inference": "バッチ推論",
+ "Batch Size": "バッチサイズ",
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
+ "Chinese": "中国語",
+ "Compile Model": "モデルのコンパイル",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
+ "Copy": "コピー",
+ "Data Preprocessing": "データ前処理",
+ "Data Preprocessing Path": "データ前処理パス",
+ "Data Source": "データソース",
+ "Decoder Model Config": "デコーダーモデルの構成",
+ "Decoder Model Path": "デコーダーモデルのパス",
+ "Disabled": "無効",
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
+ "English": "英語",
+ "Error Message": "エラーメッセージ",
+ "File Preprocessing": "ファイル前処理",
+ "Generate": "生成",
+ "Generated Audio": "生成されたオーディオ",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
+ "Infer interface is closed": "推論インターフェースが閉じられています",
+ "Inference Configuration": "推論設定",
+ "Inference Server Configuration": "推論サーバー設定",
+ "Inference Server Error": "推論サーバーエラー",
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
+ "Initial Learning Rate": "初期学習率",
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
+ "Input Text": "入力テキスト",
+ "Invalid path: {}": "無効なパス: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
+ "Japanese": "日本語",
+ "LLAMA Configuration": "LLAMA設定",
+ "LLAMA Model Config": "LLAMAモデル設定",
+ "LLAMA Model Path": "LLAMAモデルパス",
+ "Labeling Device": "ラベリングデバイス",
+ "LoRA Model to be merged": "マージするLoRAモデル",
+ "Maximum Audio Duration": "最大オーディオの長さ",
+ "Maximum Length per Sample": "サンプルあたりの最大長",
+ "Maximum Training Steps": "最大トレーニングステップ数",
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
+ "Merge": "マージ",
+ "Merge LoRA": "LoRAのマージ",
+ "Merge successfully": "マージに成功しました",
+ "Minimum Audio Duration": "最小オーディオの長さ",
+ "Model Output Path": "モデル出力パス",
+ "Model Size": "モデルサイズ",
+ "Move": "移動",
+ "Move files successfully": "ファイルの移動に成功しました",
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
+ "No selected options": "選択されたオプションはありません",
+ "Number of Workers": "ワーカー数",
+ "Open Inference Server": "推論サーバーを開く",
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
+ "Open Tensorboard": "Tensorboardを開く",
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
+ "Optional Label Language": "オプションのラベル言語",
+ "Optional online ver": "オプションのオンラインバージョン",
+ "Output Path": "出力パス",
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
+ "Put your text here.": "ここにテキストを入力してください。",
+ "Reference Audio": "リファレンスオーディオ",
+ "Reference Text": "リファレンステキスト",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
+ "Remove Selected Data": "選択したデータを削除",
+ "Removed path successfully!": "パスの削除に成功しました!",
+ "Repetition Penalty": "反復ペナルティ",
+ "Save model every n steps": "nステップごとにモデルを保存",
+ "Select LLAMA ckpt": "LLAMA チェックポイントを選択",
+ "Select VITS ckpt": "VITS チェックポイントを選択",
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
+ "Select source file processing method": "ソースファイルの処理方法を選択",
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
+ "Selected: {}": "選択済み: {}",
+ "Speaker": "話者",
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
+ "Start Training": "トレーニング開始",
+ "Streaming Audio": "ストリーミングオーディオ",
+ "Streaming Generate": "ストリーミング合成",
+ "Tensorboard Host": "Tensorboardホスト",
+ "Tensorboard Log Path": "Tensorboardログパス",
+ "Tensorboard Port": "Tensorboardポート",
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
+ "Training Configuration": "トレーニング設定",
+ "Training Error": "トレーニングエラー",
+ "Training stopped": "トレーニングが停止しました",
+ "Type name of the speaker": "話者の名前を入力",
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
+ "Use LoRA": "LoRAを使用",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
+ "Use filelist": "ファイルリストを使用",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
+ "VITS Configuration": "VITS の構成",
+ "VQGAN Configuration": "VQGAN の構成",
+ "Validation Batch Size": "検証バッチサイズ",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
+ "WebUI Host": "WebUIホスト",
+ "WebUI Port": "WebUIポート",
+ "Whisper Model": "Whisperモデル",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
+ "latest": "最新",
+ "new": "新規",
+ "Realtime Transform Text": "リアルタイム変換テキスト",
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
+ "Text Normalization": "テキスト正規化",
+ "Select Example Audio": "サンプル音声を選択"
+}
diff --git a/fish_speech/i18n/locale/ko_KR.json b/fish_speech/i18n/locale/ko_KR.json
new file mode 100644
index 0000000000000000000000000000000000000000..180263874b476059870035d4c2b74ce5fa553a8a
--- /dev/null
+++ b/fish_speech/i18n/locale/ko_KR.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
+ "Accumulate Gradient Batches": "그라디언트 배치 누적",
+ "Add to Processing Area": "처리 영역에 추가",
+ "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
+ "Advanced Config": "고급 설정",
+ "Base LLAMA Model": "기본 LLAMA 모델",
+ "Batch Inference": "배치 추론",
+ "Batch Size": "배치 크기",
+ "Changing with the Model Path": "모델 경로에 따라 변경 중",
+ "Chinese": "중국어",
+ "Compile Model": "모델 컴파일",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
+ "Copy": "복사",
+ "Data Preprocessing": "데이터 전처리",
+ "Data Preprocessing Path": "데이터 전처리 경로",
+ "Data Source": "데이터 소스",
+ "Decoder Model Config": "디코더 모델 설정",
+ "Decoder Model Path": "디코더 모델 경로",
+ "Disabled": "비활성화 됨",
+ "Enable Reference Audio": "참고 음성 활성화",
+ "English": "영어",
+ "Error Message": "오류 메시지",
+ "File Preprocessing": "파일 전처리",
+ "Generate": "생성",
+ "Generated Audio": "생성된 오디오",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오에 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
+ "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
+ "Inference Configuration": "추론 설정",
+ "Inference Server Configuration": "추론 서버 설정",
+ "Inference Server Error": "추론 서버 오류",
+ "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
+ "Initial Learning Rate": "초기 학습률",
+ "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
+ "Input Text": "입력 텍스트",
+ "Invalid path: {}": "유효하지 않은 경로: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
+ "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
+ "Japanese": "일본어",
+ "LLAMA Configuration": "LLAMA 설정",
+ "LLAMA Model Config": "LLAMA 모델 설정",
+ "LLAMA Model Path": "LLAMA 모델 경로",
+ "Labeling Device": "라벨링 장치",
+ "LoRA Model to be merged": "병합할 LoRA 모델",
+ "Maximum Audio Duration": "최대 오디오 길이",
+ "Maximum Length per Sample": "샘플당 최대 길이",
+ "Maximum Training Steps": "최대 학습 단계",
+ "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
+ "Merge": "병합",
+ "Merge LoRA": "LoRA 병합",
+ "Merge successfully": "성공적으로 병합 되었습니다.",
+ "Minimum Audio Duration": "최소 오디오 길이",
+ "Model Output Path": "모델 출력 경로",
+ "Model Size": "모델 크기",
+ "Move": "이동",
+ "Move files successfully": "파일이 성공적으로 이동되었습니다.",
+ "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
+ "No selected options": "옵션이 선택되지 않았습니다.",
+ "Number of Workers": "작업자 수",
+ "Open Inference Server": "추론 서버 열기",
+ "Open Labeler WebUI": "라벨러 WebUI 열기",
+ "Open Tensorboard": "Tensorboard 열기",
+ "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
+ "Optional Label Language": "선택적 라벨 언어",
+ "Optional online ver": "온라인 버전 선택",
+ "Output Path": "출력 경로",
+ "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
+ "Precision": "정밀도",
+ "Probability of applying Speaker Condition": "화자 조건 적용 확률",
+ "Put your text here.": "여기에 텍스트를 입력하세요.",
+ "Reference Audio": "참고 오디오",
+ "Reference Text": "참고 텍스트",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
+ "Remove Selected Data": "선택한 데이터 제거",
+ "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
+ "Repetition Penalty": "반복 패널티",
+ "Save model every n steps": "n 단계마다 모델 저장",
+ "Select LLAMA ckpt": "LLAMA ckpt 선택",
+ "Select VITS ckpt": "VITS ckpt 선택",
+ "Select VQGAN ckpt": "VQGAN ckpt 선택",
+ "Select source file processing method": "소스 파일 처리 방법 선택",
+ "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
+ "Selected: {}": "선택됨: {}",
+ "Speaker": "화자",
+ "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
+ "Start Training": "학습 시작",
+ "Streaming Audio": "스트리밍 오디오",
+ "Streaming Generate": "스트리밍 생성",
+ "Tensorboard Host": "Tensorboard 호스트",
+ "Tensorboard Log Path": "Tensorboard 로그 경로",
+ "Tensorboard Port": "Tensorboard 포트",
+ "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
+ "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
+ "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
+ "Training Configuration": "학습 설정",
+ "Training Error": "학습 오류",
+ "Training stopped": "학습이 중지되었습니다.",
+ "Type name of the speaker": "화자의 이름을 입력하세요.",
+ "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
+ "Use LoRA": "LoRA 사용",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
+ "Use filelist": "파일 목록 사용",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
+ "VITS Configuration": "VITS 설정",
+ "VQGAN Configuration": "VQGAN 설정",
+ "Validation Batch Size": "검증 배치 크기",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
+ "WebUI Host": "WebUI 호스트",
+ "WebUI Port": "WebUI 포트",
+ "Whisper Model": "Whisper 모델",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
+ "latest": "최신",
+ "new": "새로운",
+ "Realtime Transform Text": "실시간 텍스트 변환",
+ "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
+ "Text Normalization": "텍스트 정규화",
+ "Select Example Audio": "예시 오디오 선택"
+}
diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json
new file mode 100644
index 0000000000000000000000000000000000000000..385f20272e19053ab9b6cf6463a84c8ece768c68
--- /dev/null
+++ b/fish_speech/i18n/locale/pt_BR.json
@@ -0,0 +1,133 @@
+{
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
+ "Add to Processing Area": "Adicionar à Área de Processamento",
+ "Added path successfully!": "Caminho adicionado com sucesso!",
+ "Advanced Config": "Configuração Avançada",
+ "Base LLAMA Model": "Modelo LLAMA Base",
+ "Batch Inference": "Inferência em Lote",
+ "Batch Size": "Tamanho do Lote",
+ "Changing with the Model Path": "Alterando com o Caminho do Modelo",
+
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Pré-processamento de Dados",
+ "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
+ "Data Source": "Fonte de Dados",
+ "Decoder Model Config": "Configuração do Modelo Decodificador",
+ "Decoder Model Path": "Caminho do Modelo Decodificador",
+ "Disabled": "Desativado",
+ "Enable Initial Prompt": "Habilitar Prompt Inicial",
+ "Enable Reference Audio": "Habilitar Áudio de Referência",
+ "English": "Inglês",
+ "Japanese": "Japonês",
+ "Chinese": "Chinês",
+ "Portuguese": "Português",
+ "Spanish": "Espanhol",
+ "Error Message": "Mensagem de Erro",
+ "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
+ "File Preprocessing": "Pré-processamento de Arquivos",
+ "Generate": "Gerar",
+ "Generated Audio": "Áudio Gerado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
+ "Infer interface is closed": "A interface de inferência foi fechada",
+ "Inference Configuration": "Configuração de Inferência",
+ "Inference Server Configuration": "Configuração do Servidor de Inferência",
+ "Inference Server Error": "Erro do Servidor de Inferência",
+ "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
+ "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
+ "Initial Prompt": "Prompt Inicial",
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
+ "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Caminho inválido: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
+ "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
+ "LLAMA Configuration": "Configuração do LLAMA",
+ "LLAMA Model Config": "Configuração do Modelo LLAMA",
+ "LLAMA Model Path": "Caminho do Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Rotulagem",
+ "LoRA Model to be merged": "Modelo LoRA para mesclagem",
+ "Maximum Length per Sample": "Comprimento Máximo por Amostra",
+ "Maximum Training Steps": "Etapas Máximas de Treinamento",
+ "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
+ "Merge": "Mesclar",
+ "Merge LoRA": "Mesclar LoRA",
+ "Merge successfully": "Mesclado com sucesso",
+ "Model Output Path": "Caminho de Saída do Modelo",
+ "Model Quantization": "Quantização do Modelo",
+ "Model Size": "Tamanho do Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Arquivos movidos com sucesso",
+ "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
+ "No selected options": "Nenhuma opção selecionada",
+ "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
+ "Number of Workers": "Número de Processos",
+ "Open Inference Server": "Abrir Servidor de Inferência",
+ "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
+ "Optional Label Language": "Idioma do Rótulo (Opcional)",
+ "Optional online ver": "Versão online (opcional)",
+ "Output Path": "Caminho de Saída",
+ "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
+ "Post-quantification Precision": "Precisão Pós-quantização",
+ "Precision": "Precisão",
+ "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
+ "Put your text here.": "Insira seu texto aqui.",
+ "Quantify": "Quantizar",
+ "Quantify successfully": "Quantizado com sucesso",
+ "Realtime Transform Text": "Transformar Texto em Tempo Real",
+ "Reference Audio": "Áudio de Referência",
+ "Reference Text": "Texto de Referência",
+ "warning": "Aviso",
+ "Pre-processing begins...": "O pré-processamento começou!",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Remover Dados Selecionados",
+ "Removed path successfully!": "Caminho removido com sucesso!",
+ "Repetition Penalty": "Penalidade de Repetição",
+ "Save model every n steps": "Salvar modelo a cada n etapas",
+ "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
+ "Select source file processing method": "Escolha como processar o arquivo de origem",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
+ "Selected: {}": "Selecionado: {}",
+ "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
+ "Start Training": "Iniciar Treinamento",
+ "Streaming Audio": "Áudio em Streaming",
+ "Streaming Generate": "Geração em Streaming",
+ "Tensorboard Host": "Host do Tensorboard",
+ "Tensorboard Log Path": "Caminho de Log do Tensorboard",
+ "Tensorboard Port": "Porta do Tensorboard",
+ "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
+ "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
+ "Text Normalization": "Normalização de Texto",
+ "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
+ "Training Configuration": "Configuração de Treinamento",
+ "Training Error": "Erro de Treinamento",
+ "Training stopped": "Treinamento interrompido!",
+ "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
+ "Use filelist": "Usar lista de arquivos",
+ "VQGAN Configuration": "Configuração do VQGAN",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
+ "WebUI Host": "Host da WebUI",
+ "WebUI Port": "Porta da WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
+ "auto": "automático",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
+ "latest": "mais recente",
+ "new": "novo",
+ "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
+ "You don't need to train this model!": "Não é necessário treinar este modelo!",
+ "Yes": "Sim",
+ "No": "Não",
+ "version:": "versão:",
+ "author:": "autor:"
+}
diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json
new file mode 100644
index 0000000000000000000000000000000000000000..9068ef0b9a41b9941b37644c6a4c96ec6a5d836e
--- /dev/null
+++ b/fish_speech/i18n/locale/zh_CN.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
+ "Accumulate Gradient Batches": "梯度累积批次",
+ "Add to Processing Area": "加入处理区",
+ "Added path successfully!": "添加路径成功!",
+ "Advanced Config": "高级参数",
+ "Base LLAMA Model": "基础 LLAMA 模型",
+ "Batch Inference": "批量推理",
+ "Batch Size": "批次大小",
+ "Changing with the Model Path": "随模型路径变化",
+ "Chinese": "中文",
+ "Compile Model": "编译模型",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
+ "Copy": "复制",
+ "Data Preprocessing": "数据预处理",
+ "Data Preprocessing Path": "数据预处理路径",
+ "Data Source": "数据源",
+ "Decoder Model Config": "解码器模型配置",
+ "Decoder Model Path": "解码器模型路径",
+ "Disabled": "禁用",
+ "Enable Reference Audio": "启用参考音频",
+ "English": "英文",
+ "Error Message": "错误信息",
+ "File Preprocessing": "文件预处理",
+ "Generate": "生成",
+ "Generated Audio": "音频",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
+ "Infer interface is closed": "推理界面已关闭",
+ "Inference Configuration": "推理配置",
+ "Inference Server Configuration": "推理服务器配置",
+ "Inference Server Error": "推理服务器错误",
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
+ "Initial Learning Rate": "初始学习率",
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
+ "Input Text": "输入文本",
+ "Invalid path: {}": "无效路径: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
+ "Japanese": "日文",
+ "LLAMA Configuration": "LLAMA 配置",
+ "LLAMA Model Config": "LLAMA 模型配置",
+ "LLAMA Model Path": "LLAMA 模型路径",
+ "Labeling Device": "标注加速设备",
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
+ "Maximum Audio Duration": "最大音频时长",
+ "Maximum Length per Sample": "每个样本的最大长度",
+ "Maximum Training Steps": "最大训练步数",
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
+ "Merge": "合并",
+ "Merge LoRA": "合并 LoRA",
+ "Merge successfully": "合并成功",
+ "Minimum Audio Duration": "最小音频时长",
+ "Model Output Path": "模型输出路径",
+ "Model Size": "模型规模",
+ "Move": "移动",
+ "Move files successfully": "移动文件成功",
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
+ "No selected options": "没有选择的选项",
+ "Number of Workers": "数据加载进程数",
+ "Open Inference Server": "打开推理服务器",
+ "Open Labeler WebUI": "打开标注工具",
+ "Open Tensorboard": "打开 Tensorboard",
+ "Opened labeler in browser": "在浏览器中打开标注工具",
+ "Optional Label Language": "[可选] 标注语言",
+ "Optional online ver": "[可选] 使用在线版",
+ "Output Path": "输出路径",
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
+ "Put your text here.": "在此处输入文本.",
+ "Reference Audio": "参考音频",
+ "Reference Text": "参考文本",
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
+ "Remove Selected Data": "移除选中数据",
+ "Removed path successfully!": "移除路径成功!",
+ "Repetition Penalty": "重复惩罚",
+ "Save model every n steps": "每 n 步保存模型",
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
+ "Select VITS ckpt": "选择 VITS 检查点",
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
+ "Select source file processing method": "选择源文件处理方法",
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
+ "Selected: {}": "已选择: {}",
+ "Speaker": "说话人",
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
+ "Start Training": "开始训练",
+ "Streaming Audio": "流式音频",
+ "Streaming Generate": "流式合成",
+ "Tensorboard Host": "Tensorboard 监听地址",
+ "Tensorboard Log Path": "Tensorboard 日志路径",
+ "Tensorboard Port": "Tensorboard 端口",
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
+ "Training Configuration": "训练配置",
+ "Training Error": "训练错误",
+ "Training stopped": "训练已停止",
+ "Type name of the speaker": "输入说话人的名称",
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
+ "Use LoRA": "使用 LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
+ "Use filelist": "使用文件列表",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
+ "VITS Configuration": "VITS 配置",
+ "VQGAN Configuration": "VQGAN 配置",
+ "Validation Batch Size": "验证批次大小",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
+ "WebUI Host": "WebUI 监听地址",
+ "WebUI Port": "WebUI 端口",
+ "Whisper Model": "Whisper 模型",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
+ "latest": "最近的检查点",
+ "new": "创建新的检查点",
+ "Realtime Transform Text": "实时规范化文本",
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
+ "Text Normalization": "文本规范化",
+ "Select Example Audio": "选择参考音频"
+}
diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0194c0f1a31dc95309c64626d13f04751a44ba1
--- /dev/null
+++ b/fish_speech/i18n/scan.py
@@ -0,0 +1,122 @@
+import ast
+import glob
+import json
+from collections import OrderedDict
+from pathlib import Path
+
+from loguru import logger
+
+from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
+
+
def extract_i18n_strings(node):
    """Recursively collect string literals passed to ``i18n(...)`` calls.

    Args:
        node: Any AST node, typically the module returned by ``ast.parse``.

    Returns:
        A list of the literal string arguments, in source order.
    """
    i18n_strings = []

    if (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "i18n"
    ):
        for arg in node.args:
            # ast.Str is deprecated since Python 3.8 and removed in 3.12;
            # use ast.Constant and check the value type instead.
            if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
                i18n_strings.append(arg.value)

    for child_node in ast.iter_child_nodes(node):
        i18n_strings.extend(extract_i18n_strings(child_node))

    return i18n_strings
+
+
# scan the directory for all .py files (recursively)
# for each file, parse the code into an AST
# for each AST, extract the i18n strings

# All i18n keys found in code, in encounter order (may contain duplicates;
# uniqueness is computed below).
strings = []
folders = ["fish_speech", "tools"]
# for filename in glob.iglob("**/*.py", recursive=True):
for folder in folders:
    for f in Path(folder).rglob("*.py"):
        code = f.read_text(encoding="utf-8")
        # Cheap substring pre-filter: skip parsing files that cannot
        # contain any i18n() calls.
        if "i18n(" in code:
            tree = ast.parse(code)
            i18n_strings = extract_i18n_strings(tree)
            logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
            strings.extend(i18n_strings)

code_keys = set(strings)
logger.info(f"Total unique: {len(code_keys)}")


# The default-language JSON acts as the reference ("standard") file.
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
with open(standard_file, "r", encoding="utf-8") as f:
    standard_data = json.load(f, object_pairs_hook=OrderedDict)
standard_keys = set(standard_data.keys())

# Define the standard file name
# Keys present in the JSON but no longer referenced by any i18n() call.
unused_keys = standard_keys - code_keys
logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
for unused_key in unused_keys:
    logger.info(f"\t{unused_key}")

# Keys referenced in code but missing from the JSON.
missing_keys = code_keys - standard_keys
logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
for missing_key in missing_keys:
    logger.info(f"\t{missing_key}")

# Rebuild the standard file from the keys found in code; for the default
# language the "translation" is the key itself.
code_keys_dict = OrderedDict()
for s in strings:
    code_keys_dict[s] = s

# write back
with open(standard_file, "w", encoding="utf-8") as f:
    json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
    f.write("\n")

logger.info(f"Updated {standard_file}")
+
# Define the standard file name
standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"

# Find all JSON files in the directory
dir_path = I18N_FILE_PATH
languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]

# Load the standard file
with open(standard_file, "r", encoding="utf-8") as f:
    standard_data = json.load(f, object_pairs_hook=OrderedDict)

# Loop through each language file
for lang_file in languages:
    # Load the language file
    with open(lang_file, "r", encoding="utf-8") as f:
        lang_data = json.load(f, object_pairs_hook=OrderedDict)

    # Find the difference between the language file and the standard file
    # diff = keys the translation is missing relative to the standard file.
    diff = set(standard_data.keys()) - set(lang_data.keys())

    # miss = stale keys present only in the translation.
    miss = set(lang_data.keys()) - set(standard_data.keys())

    # Add any missing keys to the language file
    # "#!" marks an untranslated placeholder for translators to fill in.
    for key in diff:
        lang_data[key] = "#!" + key
        logger.info(f"Added missing key: {key} to {lang_file}")

    # Del any extra keys to the language file
    for key in miss:
        del lang_data[key]
        logger.info(f"Del extra key: {key} from {lang_file}")

    # Sort the keys of the language file to match the order of the standard file
    lang_data = OrderedDict(
        sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
    )

    # Save the updated language file
    # NOTE(review): sort_keys=True re-sorts alphabetically on dump, which
    # overrides the standard-file ordering computed just above — confirm
    # which ordering is intended.
    with open(lang_file, "w", encoding="utf-8") as f:
        json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
        f.write("\n")

    logger.info(f"Updated {lang_file}")

logger.info("Done")
diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..df970400f8a073be4c4166a697245fabdf6b09b0
--- /dev/null
+++ b/fish_speech/models/text2semantic/lit_module.py
@@ -0,0 +1,202 @@
+from typing import Any, Optional
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
+
+import fish_speech.utils as utils
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import NaiveTransformer
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
class TextToSemantic(L.LightningModule):
    """LightningModule that trains the text-to-semantic transformer.

    Wraps a ``NaiveTransformer`` (or subclass), builds the optimizer and LR
    scheduler from injected factories, computes the combined token/codebook
    cross-entropy loss, and logs losses plus top-5 codebook accuracy.
    """

    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        """
        Args:
            model: Transformer to train.
            optimizer: Factory taking parameter groups and returning an
                optimizer (e.g. a partially-applied constructor).
            lr_scheduler: Factory taking the optimizer and returning a
                scheduler, stepped once per optimizer step.
        """
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        """Delegate straight to the wrapped transformer."""
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        # Save only LoRA parameters
        state_dict = checkpoint["state_dict"]
        use_lora = any("lora" in name for name in state_dict.keys())
        if not use_lora:
            # Not a LoRA run: keep the full state dict untouched.
            return

        # LoRA run: drop base-model weights so the checkpoint contains only
        # the small adapter tensors.
        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Get weight decay parameters
        # Biases, norm scales and embedding tables are excluded from decay.
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                # Step the scheduler every optimizer step, not per epoch.
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        """Shared train/val step: forward pass, loss computation, logging.

        Args:
            batch: Dict with "inputs", "labels" and "attention_masks" tensors.
            batch_idx: Batch index (unused; required by the Lightning API).
            stage: Either "train" or "val"; controls logging granularity.

        Returns:
            The combined token + codebook cross-entropy loss.
        """
        is_train = stage == "train"

        if is_train:
            # Key part to make lora work
            # Otherwise the parameters are merged, which lead to incorrect gradients
            self.model.train()

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Generate labels
        # Row 0 of labels holds the text tokens; -100 marks ignored positions.
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        # Remaining rows hold the codebook targets; .mT swaps the codebook
        # and sequence axes to line up with the logits layout.
        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )

        return loss

    def get_accuracy(self, logits, labels):
        """Top-5 accuracy over codebook logits, ignoring -100 and pad targets."""
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        if mask.sum() == 0:
            # Nothing to score; avoid a 0/0 division.
            return torch.tensor(0.0, device=logits.device)

        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~mask] = 0
        correct = correct.sum()
        accuracy = correct / mask.sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ea15e595f909069ee4f99e090897737f9857b68
--- /dev/null
+++ b/fish_speech/models/text2semantic/llama.py
@@ -0,0 +1,844 @@
+import dataclasses
+import json
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from loguru import logger
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from torch.utils.checkpoint import checkpoint
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import SEMANTIC_TOKEN
+from fish_speech.utils import RankedLogger
+
+from .lora import LoraConfig, setup_lora
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k``."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
+
+
@dataclass
class BaseModelArgs:
    """Hyper-parameters shared by all text2semantic transformer variants.

    Serialized to/from ``config.json``; ``model_type`` selects the concrete
    subclass in ``from_pretrained``.
    """

    model_type: str = "base"

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # Derived in __post_init__ (SwiGLU sizing rule) when left as None.
    intermediate_size: Optional[int] = None
    # -1 means "same as n_head" (resolved in __post_init__).
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    # Initialize the model
    initializer_range: float = 0.02

    # Dummy vars
    is_reward_model: bool = False
    share_codebook_embeddings: bool = True

    def __post_init__(self):
        """Fill in values derived from other fields."""
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            # Llama/SwiGLU sizing: 2/3 of 4*dim, rounded up to a multiple of 256.
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @staticmethod
    def from_pretrained(path: str):
        """Load a config from ``path`` (a config.json file or its directory).

        Returns:
            A ``NaiveModelArgs`` or ``DualARModelArgs``, chosen by the
            "model_type" field.

        Raises:
            ValueError: If "model_type" is not recognized.
        """
        path = Path(path)

        if path.is_dir():
            path = path / "config.json"

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        match data["model_type"]:
            case "naive":
                cls = NaiveModelArgs
            case "dual_ar":
                cls = DualARModelArgs
            case _:
                raise ValueError(f"Unknown model type: {data['model_type']}")

        return cls(**data)

    def save(self, path: str):
        """Write this config as pretty-printed JSON to ``path``."""
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
+
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Config for ``NaiveTransformer``: a single trunk predicts all codebooks."""

    model_type: str = "naive"
+
+
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Config for ``DualARTransformer``.

    The ``fast_*`` fields size the fast (codebook-axis) transformer and
    default to the corresponding slow-trunk values when left as None.
    """

    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None

    def __post_init__(self):
        super().__post_init__()

        # Fall back to the slow-trunk value for any unset fast_* field.
        # NOTE: these use `or`, so an explicit falsy value (e.g. 0) also
        # falls back to the trunk value.
        self.fast_dim = self.fast_dim or self.dim
        self.fast_n_head = self.fast_n_head or self.n_head
        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
        self.fast_head_dim = self.fast_head_dim or self.head_dim
        self.fast_intermediate_size = (
            self.fast_intermediate_size or self.intermediate_size
        )
        # Booleans need an explicit None check: False is a valid override.
        self.fast_attention_qkv_bias = (
            self.fast_attention_qkv_bias
            if self.fast_attention_qkv_bias is not None
            else self.attention_qkv_bias
        )
+
+
class KVCache(nn.Module):
    """Pre-allocated key/value cache for incremental (autoregressive) decoding."""

    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        # Buffers (not parameters): moved with the module, never trained.
        for buffer_name in ("k_cache", "v_cache"):
            self.register_buffer(buffer_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write new key/value slices at ``input_pos`` and return both caches.

        Args:
            input_pos: Positions to fill, shape [S].
            k_val, v_val: New entries, shape [B, H, S, D].

        Returns:
            The full (max-length) key and value cache tensors.
        """
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val

        return self.k_cache, self.v_cache
+
+
@dataclass
class TransformerForwardResult:
    """Decoded forward output: text-token logits plus per-codebook logits."""

    # (batch, seq, vocab_size)
    token_logits: Tensor
    # (batch, seq, num_codebooks, codebook_size)
    codebook_logits: Tensor
+
+
@dataclass
class BaseTransformerForwardResult:
    """Raw trunk output, before codebook decoding."""

    # Text-token logits from the slow transformer head.
    logits: Tensor
    # Last-layer hidden states (input to the codebook heads).
    hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+ def __init__(
+ self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.tokenizer = tokenizer
+
+ self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
+
+ # Slow transformer
+ self.embeddings = nn.Embedding(
+ config.vocab_size,
+ config.dim,
+ )
+ self.codebook_embeddings = nn.Embedding(
+ config.codebook_size * config.num_codebooks,
+ config.dim,
+ )
+ self.layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+ )
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+ if self.config.tie_word_embeddings is False:
+ self.output = nn.Linear(
+ config.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ config.max_seq_len,
+ config.dim // config.n_head,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(
+ torch.ones(
+ config.max_seq_len,
+ config.max_seq_len,
+ dtype=torch.bool,
+ )
+ ),
+ persistent=False,
+ )
+
+ # For kv cache
+ self.max_batch_size = -1
+ self.max_seq_len = -1
+
+ if init_weights:
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+ return
+
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_len = find_multiple(max_seq_len, 8)
+ self.max_seq_len = max_seq_len
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_len,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def embed(self, x: Tensor) -> Tensor:
+ vocab_embeds = [self.embeddings(x[:, 0])]
+ for i in range(self.config.num_codebooks):
+ emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
+ emb[x[:, 0] != self.semantic_token_id] = 0
+ vocab_embeds.append(emb)
+
+ x = torch.stack(vocab_embeds, dim=3)
+ x = x.sum(dim=3)
+
+ return x
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> BaseTransformerForwardResult:
+ seq_len = inp.size(2)
+
+ # Here we want to merge the embeddings of the codebooks
+ x = self.embed(inp)
+
+ freqs_cis = self.freqs_cis[:seq_len]
+
+ # Not that the causal mask here follows the definition of scaled_dot_product_attention
+ # That is, FALSE means masked out
+ # To maintain consistency, key_padding_mask use TRUE to mask out
+ mask = None
+ if key_padding_mask is not None:
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
+ mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+ for layer in self.layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+ else:
+ x = layer(x, freqs_cis, mask)
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def forward_generate(
+ self,
+ x: Tensor,
+ input_pos: Optional[Tensor] = None,
+ return_all: bool = False,
+ ) -> BaseTransformerForwardResult:
+ # This is used for generation, optimized for torch compile
+ assert (
+ self.max_seq_len != -1 and self.max_batch_size != -1
+ ), "Please call setup_caches before forward_generate"
+
+ x = self.embed(x)
+
+ mask = self.causal_mask[
+ None, None, input_pos, : self.max_seq_len
+ ] # (B, N, Q, K)
+ freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.layers:
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+ # If prefill, we only calculate the logits of last token
+ if x.size(1) > 1 and not return_all:
+ x = x[:, -1:]
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @staticmethod
+ def from_pretrained(
+ path: str,
+ load_weights: bool = False,
+ max_length: int | None = None,
+ lora_config: LoraConfig | None = None,
+ rope_base: int | None = None,
+ ) -> "BaseTransformer":
+ config = BaseModelArgs.from_pretrained(str(path))
+ if max_length is not None:
+ config.max_seq_len = max_length
+ log.info(f"Override max_seq_len to {max_length}")
+
+ if rope_base is not None:
+ config.rope_base = rope_base
+ log.info(f"Override rope_base to {rope_base}")
+
+ match config.model_type:
+ case "naive":
+ model_cls = NaiveTransformer
+ case "dual_ar":
+ model_cls = DualARTransformer
+ case _:
+ raise ValueError(f"Unknown model type: {config.model_type}")
+
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
+ log.info(f"Loading model from {path}, config: {config}")
+ model = model_cls(config, tokenizer=tokenizer)
+
+ if lora_config is not None:
+ setup_lora(model, lora_config)
+ log.info(f"LoRA setup: {lora_config}")
+
+ if load_weights is False:
+ log.info("Randomly initialized model")
+ else:
+
+ if "int8" in str(Path(path)):
+ logger.info("Using int8 weight-only quantization!")
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
+
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
+ model = simple_quantizer.convert_for_runtime()
+
+ if "int4" in str(Path(path)):
+ logger.info("Using int4 quantization!")
+ path_comps = path.name.split("-")
+ assert path_comps[-2].startswith("g")
+ groupsize = int(path_comps[-2][1:])
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
+
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+ model = simple_quantizer.convert_for_runtime()
+
+ weights = torch.load(
+ Path(path) / "model.pth",
+ map_location="cpu",
+ mmap=True,
+ weights_only=True,
+ )
+
+ if "state_dict" in weights:
+ logger.warning(
+ "Using a TextToSemantic LightningModule checkpoint, "
+ "please make sure it is a full model, not a LoRA model."
+ )
+ weights = weights["state_dict"]
+
+ if next(iter(weights.keys())).startswith("model."):
+ logger.info(
+ f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
+ )
+ new_weights = OrderedDict()
+ for k, v in weights.items():
+ new_weights[k.replace("model.", "")] = v
+ weights = new_weights
+
+ # Verify the name and shape of parameters since strict=False in load_state_dict.
+ for k, v in model.named_parameters():
+ if k not in weights:
+ logger.warning(f"No weight for {k}")
+ elif v.shape != weights[k].shape:
+ logger.warning(
+ f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
+ )
+
+ err = model.load_state_dict(weights, strict=False, assign=True)
+ log.info(f"Loaded weights with error: {err}")
+
+ return model
+
+ def save_pretrained(self, path: str, drop_lora: bool = False):
+ path = Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+
+ self.config.save(path / "config.json")
+ state_dict = self.state_dict()
+
+ if drop_lora:
+ for key in list(state_dict.keys()):
+ if "lora" not in key:
+ continue
+
+ state_dict.pop(key)
+ log.info(f"Drop LoRA parameter: {key}")
+
+ torch.save(state_dict, path / "model.pth")
+ self.tokenizer.save_pretrained(path)
+
+
class NaiveTransformer(BaseTransformer):
    """Single-stage transformer: the shared trunk's hidden states are
    projected into all codebook logits at every position."""

    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
        # Defer weight init until the codebook head below exists.
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        """Project trunk hidden states into per-codebook logits."""
        normed = self.codebook_norm(result.hidden_states)
        flat_logits = self.codebook_output(normed)
        per_codebook = rearrange(
            flat_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=result.logits,
            codebook_logits=per_codebook,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        trunk_result = super().forward(inp=inp, key_padding_mask=key_padding_mask)
        return self.decode(trunk_result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        return self.decode(super().forward_generate(x, input_pos))
+
+
class DualARTransformer(BaseTransformer):
    """Dual-autoregressive transformer.

    The inherited slow trunk runs over the sequence axis; a second, smaller
    "fast" transformer then runs over the codebook axis at each sequence
    position to predict the codebooks.
    """

    def __init__(self, config: DualARModelArgs, tokenizer: AutoTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

        # The equivalent bs is so large that sdpa doesn't work
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
        )

        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )

        # Rotary table for the fast axis: positions index codebooks, not time.
        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_dim // config.fast_n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate KV caches for both the slow trunk and the fast layers."""
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.fast_dim // self.config.fast_n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Training forward: slow trunk over time, then the fast transformer
        over the codebook axis at every (batch, seq) position."""
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states
        x = self.fast_project_in(x)

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        # Prepend the trunk hidden state as position 0 of the fast sequence.
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        # A position is "padded" when every codebook entry is 0.
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        # Run the fast layers only on non-padded positions.
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(
                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
                )
            else:
                x = layer(x, self.fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        # Scatter the computed logits back into a zero buffer so the output
        # covers every (batch, seq) position again.
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """Incremental fast-transformer step for a single codebook position.

        NOTE(review): the reshape below assumes batch size 1 — confirm
        against generation callers.
        """
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Slow-trunk incremental step; hidden states are projected to the
        fast dimension for subsequent forward_generate_fast calls."""
        x = super().forward_generate(x, input_pos)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x
+
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention followed by a SwiGLU
    feed-forward network, each wrapped in a residual connection."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        # Attention sub-layer with residual shortcut.
        attn_out = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        hidden = x + attn_out
        # Feed-forward sub-layer with residual shortcut.
        return hidden + self.feed_forward(self.ffn_norm(hidden))
+
+
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection, rotary position
    embeddings, grouped-query attention (fewer K/V heads than Q heads) and an
    optional KV cache for incremental decoding."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        # Q gets n_head heads; K and V each get n_local_heads heads (GQA).
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Attached externally (see setup_caches) when doing incremental decoding.
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        # Migrates legacy checkpoints with separate wq/wk/wv into fused wqkv.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Backwards compatibility: fuse split q/k/v weights from old checkpoints.
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attention over x (B, T, D); input_pos indexes the KV cache when one
        is attached (decoding), otherwise the full sequence is processed."""
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand K/V heads to match the number of Q heads (GQA).
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No third party attn_mask here to use flash_attention
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            # Manual fallback used by the fast layers (use_sdpa=False).
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention
        # It's low efficient, but it doesn't raise cuda error

        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean mask: False positions are excluded via -inf bias.
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                # Additive (float) mask.
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        # NOTE: train=True here, but callers pass dropout_p=0.0 when not
        # training, so dropout is effectively disabled during eval.
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value
+
+
class FeedForward(nn.Module):
    """SwiGLU feed-forward network (LLaMA style): w2(silu(w1(x)) * w3(x))."""

    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
+
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescales by 1/RMS of the last dimension,
    with a learned per-channel gain and no mean subtraction or bias."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Reciprocal RMS over the feature dimension.
        inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in fp32 for stability, then cast back to the input dtype.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
+
+
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Precompute the rotary-embedding table.

    Returns a (seq_len, n_elem // 2, 2) bfloat16 tensor holding the cos/sin
    (real/imag) components for each position and frequency.
    """
    half = n_elem // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2)[:half].float() / n_elem))
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # Unit complex numbers e^{i * angle}, stored as (real, imag) pairs.
    unit = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit.real, unit.imag], dim=-1)
    return table.to(dtype=torch.bfloat16)
+
+
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate pairs of channels of x (B, T, H, D) by the complex factors in
    freqs_cis (T, D // 2, 2); returns a tensor with the dtype of x."""
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rot = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)
    # Complex multiplication (a + ib)(c + id) done on the (real, imag) pairs.
    real = pairs[..., 0] * rot[..., 0] - pairs[..., 1] * rot[..., 1]
    imag = pairs[..., 1] * rot[..., 0] + pairs[..., 0] * rot[..., 1]
    rotated = torch.stack([real, imag], dim=-1)
    return rotated.flatten(3).type_as(x)
diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..647ca6fcccf038e17d2cf91a2874281dff3e0938
--- /dev/null
+++ b/fish_speech/models/text2semantic/lora.py
@@ -0,0 +1,92 @@
+from dataclasses import dataclass
+
+import loralib as lora
+
+
@dataclass
class LoraConfig:
    """Hyper-parameters for LoRA fine-tuning, forwarded to loralib layers."""

    # Rank of the low-rank update matrices.
    r: int
    # Scaling factor for the LoRA update (loralib scales the delta by lora_alpha / r).
    lora_alpha: float
    # Dropout applied on the LoRA path during training.
    lora_dropout: float = 0.0
+
+
def setup_lora(model, lora_config):
    """Swap the model's embeddings and linear layers for loralib equivalents
    and freeze everything except the LoRA parameters.

    NOTE(review): the replacement layers are freshly constructed, so their
    base weights are randomly initialized — presumably the pretrained
    state dict is loaded into the model after this call; verify against the
    training entry point.
    """
    # Replace the embedding layer with a LoRA layer
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    model.codebook_embeddings = lora.Embedding(
        num_embeddings=model.codebook_embeddings.num_embeddings,
        embedding_dim=model.codebook_embeddings.embedding_dim,
        padding_idx=model.codebook_embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Replace output layer with a LoRA layer
    linears = [(model, "output")]

    # Replace all linear layers with LoRA layers
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    # Dual-AR models additionally have a fast transformer stack.
    if hasattr(model, "fast_layers"):
        model.fast_embeddings = lora.Embedding(
            num_embeddings=model.fast_embeddings.num_embeddings,
            embedding_dim=model.fast_embeddings.embedding_dim,
            padding_idx=model.fast_embeddings.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

        # Dual-AR model
        linears.append((model, "fast_output"))

        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    for module, layer in linears:
        # NOTE(review): `bias` here receives the bias *tensor* (or None) while
        # nn.Linear expects a bool; a Parameter is truthy so this enables a
        # (new, re-initialized) bias whenever the original had one — confirm
        # this is the intended behavior.
        updated_linear = lora.Linear(
            in_features=getattr(module, layer).in_features,
            out_features=getattr(module, layer).out_features,
            bias=getattr(module, layer).bias,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, layer, updated_linear)

    # Mark only the LoRA layers as trainable
    lora.mark_only_lora_as_trainable(model, bias="none")
+
+
def get_merged_state_dict(model):
    """Return the model's state dict with LoRA parameters merged and removed.

    Putting the model in eval mode makes loralib fold the LoRA deltas into the
    base weights, so the remaining entries are standalone weights.
    """
    model.eval()

    merged = model.state_dict()
    # Drop the (already merged) LoRA-specific entries.
    for key in [k for k in merged if "lora" in k]:
        merged.pop(key)

    return merged
diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py
new file mode 100644
index 0000000000000000000000000000000000000000..91fc9118cc26f4d99171e7db3ee871071a7a296a
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/firefly.py
@@ -0,0 +1,596 @@
+import math
+from functools import partial
+from math import prod
+from typing import Callable
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+from torch.utils.checkpoint import checkpoint
+
+
def sequence_mask(length, max_length=None):
    """Boolean mask of shape (B, T): True where position < length[b]."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
+
+
def init_weights(m, mean=0.0, std=0.01):
    # Normal-initialize conv weights; intended to be used via module.apply().
    # NOTE(review): torch's conv classes are named "Conv1d" (lowercase d), so
    # this "Conv1D" check never matches and the init is currently a no-op.
    # Even if the string were fixed, the convs in this file are weight-normed
    # before .apply(init_weights) is called, so mutating .weight.data would
    # not affect the parametrized weights — confirm intent before changing.
    classname = m.__class__.__name__
    if classname.find("Conv1D") != -1:
        m.weight.data.normal_(mean, std)
+
+
def get_padding(kernel_size, dilation=1):
    """'Same' padding for an odd dilated kernel at stride 1."""
    return dilation * (kernel_size - 1) // 2
+
+
def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
    """Strip `paddings` samples from either end of the last dimension."""
    pad_left, pad_right = paddings
    assert pad_left >= 0 and pad_right >= 0, (pad_left, pad_right)
    assert pad_left + pad_right <= x.shape[-1]
    return x[..., pad_left : x.shape[-1] - pad_right]
+
+
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Extra right padding needed so the final conv frame covers the whole
    signal (see `pad_for_conv1d` in the original encodec implementation)."""
    length = x.shape[-1]
    # Fractional number of frames the conv would produce.
    frames = (length - kernel_size + padding_total) / stride + 1
    # Length that yields exactly ceil(frames) full frames.
    target = (math.ceil(frames) - 1) * stride + (kernel_size - padding_total)
    return target - length
+
+
def pad1d(
    x: torch.Tensor,
    paddings: tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """F.pad wrapper that makes reflect padding safe on inputs shorter than
    the pad amount by temporarily zero-extending the signal on the right
    before reflecting, then trimming the extension afterwards."""
    length = x.shape[-1]
    pad_left, pad_right = paddings
    assert pad_left >= 0 and pad_right >= 0, (pad_left, pad_right)

    if mode != "reflect":
        return F.pad(x, paddings, mode, value)

    # Reflect padding requires the input to be longer than the pad amount.
    max_pad = max(pad_left, pad_right)
    extra_pad = 0
    if length <= max_pad:
        extra_pad = max_pad - length + 1
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    end = padded.shape[-1] - extra_pad
    return padded[..., :end]
+
+
class FishConvNet(nn.Module):
    """Conv1d with causal-style left padding plus a computed right padding so
    every input sample is covered by at least one conv frame."""

    def __init__(
        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        # Effective receptive field of the dilated kernel.
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation

    def forward(self, x):
        base_pad = self.kernel_size - self.stride
        tail_pad = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, base_pad
        )
        x = pad1d(x, (base_pad, tail_pad), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        # Wrap the inner conv with a weight-norm parametrization (fluent API).
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_parametrizations(self, name="weight"):
        # Bake the parametrized weight back in for inference export.
        self.conv = remove_parametrizations(self.conv, name)
        return self
+
+
class FishTransConvNet(nn.Module):
    """ConvTranspose1d whose output is trimmed back to stride * input_length."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        y = self.conv(x)
        overlap = self.kernel_size - self.stride
        # math.ceil of an int is the int itself, so the whole excess is
        # trimmed from the right and the left trim is always zero.
        trim_right = math.ceil(overlap)
        trim_left = overlap - trim_right
        return unpad1d(y, (trim_left, trim_right)).contiguous()

    def weight_norm(self, name="weight", dim=0):
        # Wrap the inner conv with a weight-norm parametrization (fluent API).
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_parametrizations(self, name="weight"):
        # Bake the parametrized weight back in for inference export.
        self.conv = remove_parametrizations(self.conv, name)
        return self
+
+
class ResBlock1(torch.nn.Module):
    """Residual block: three pairs of dilated convolutions with SiLU
    activations, each pair wrapped in a residual connection."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def make_branch():
            # One weight-normed conv per dilation value (exactly three).
            return nn.ModuleList(
                [
                    FishConvNet(
                        channels, channels, kernel_size, stride=1, dilation=dilation[j]
                    ).weight_norm()
                    for j in range(3)
                ]
            )

        self.convs1 = make_branch()
        self.convs1.apply(init_weights)

        self.convs2 = make_branch()
        self.convs2.apply(init_weights)

    def forward(self, x):
        for conv_a, conv_b in zip(self.convs1, self.convs2):
            residual = x
            x = conv_b(F.silu(conv_a(F.silu(x))))
            x = x + residual
        return x

    def remove_parametrizations(self):
        for conv in [*self.convs1, *self.convs2]:
            conv.remove_parametrizations()
+
+
class ParallelBlock(nn.Module):
    """Runs several ResBlock1 branches (different kernel sizes) in parallel
    and averages their outputs."""

    def __init__(
        self,
        channels: int,
        kernel_sizes: tuple[int] = (3, 7, 11),
        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    ):
        super().__init__()

        assert len(kernel_sizes) == len(dilation_sizes)

        self.blocks = nn.ModuleList(
            ResBlock1(channels, k, d) for k, d in zip(kernel_sizes, dilation_sizes)
        )

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.blocks]
        return torch.stack(branch_outputs, dim=0).mean(dim=0)

    def remove_parametrizations(self):
        for branch in self.blocks:
            branch.remove_parametrizations()
+
+
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN style vocoder head: maps mel features (B, num_mels, T) to a
    waveform (B, 1, T * hop_length) via transposed-conv upsampling stages,
    each followed by a multi-kernel residual block, ending in a tanh.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        # Fix: forward() reads `self.checkpointing` when self.training is True,
        # but the attribute was never initialized, raising AttributeError on
        # the first training step. Gradient checkpointing is off by default;
        # callers may enable it by setting the attribute.
        self.checkpointing = False

        self.conv_pre = FishConvNet(
            num_mels,
            upsample_initial_channel,
            pre_conv_kernel_size,
            stride=1,
        ).weight_norm()

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        # NOTE(review): never populated or read in this file; kept to avoid
        # altering the module structure.
        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()

        # Each stage halves the channel count while upsampling time by `u`.
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                FishTransConvNet(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    stride=u,
                ).weight_norm()
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.resblocks.append(
                ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )

        self.activation_post = post_activation()
        self.conv_post = FishConvNet(
            ch, 1, post_conv_kernel_size, stride=1
        ).weight_norm()
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """Upsample mel features to a tanh-bounded waveform."""
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.training and self.checkpointing:
                # Trade compute for memory during training when enabled.
                x = checkpoint(
                    self.resblocks[i],
                    x,
                    use_reentrant=False,
                )
            else:
                x = self.resblocks[i](x)

        x = self.activation_post(x)
        x = self.conv_post(x)
        # tanh keeps the output waveform bounded in [-1, 1].
        return torch.tanh(x)

    def remove_parametrizations(self):
        """Strip weight_norm from all convolutions for inference export."""
        for up in self.ups:
            up.remove_parametrizations()
        for block in self.resblocks:
            block.remove_parametrizations()
        self.conv_pre.remove_parametrizations()
        self.conv_post.remove_parametrizations()
+
+
+# DropPath copied from timm library
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Stochastic depth (per-sample DropPath), as in the timm library.

    Randomly zeroes whole samples of the batch with probability `drop_prob`
    during training; surviving samples are rescaled by 1/keep_prob when
    `scale_by_keep` is set so the expected activation is unchanged.
    """  # noqa: E501

    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        keep_mask.div_(keep_prob)
    return x * keep_mask
+
+
class DropPath(nn.Module):
    """Module wrapper around `drop_path` (per-sample stochastic depth)."""  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Only active when self.training is True.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two layouts.

    channels_last (default) normalizes the trailing dimension, matching
    F.layer_norm; channels_first normalizes dim 1 of (batch, channels, ...)
    inputs with a manual mean/variance computation.
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

        # channels_first: normalize over the channel dimension (dim 1).
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normalized = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None] * normalized + self.bias[:, None]
+
+
+# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
class ConvNeXtBlock(nn.Module):
    r"""1-D ConvNeXt block.

    Depthwise conv -> permute to (N, L, C) -> LayerNorm -> Linear -> GELU ->
    Linear -> permute back, with optional layer scale (gamma), DropPath, and
    a residual connection.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        # Depthwise conv (groups == channels).
        self.dwconv = FishConvNet(
            dim,
            dim,
            kernel_size=kernel_size,
            groups=dim,
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise/1x1 convs implemented as linear layers on (N, L, C).
        self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        # Learned per-channel layer scale; disabled when init value is <= 0.
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        residual = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv2(self.act(self.pwconv1(x)))

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = residual + x

        return x
+
+
class ConvNeXtEncoder(nn.Module):
    """1-D ConvNeXt feature encoder.

    A stem plus per-stage channel-changing layers, each followed by a stack
    of ConvNeXtBlocks with a linearly increasing stochastic-depth rate, and a
    final channels-first LayerNorm. Input (B, input_channels, T) maps to
    (B, dims[-1], T); no temporal downsampling is performed.
    """

    def __init__(
        self,
        input_channels: int = 3,
        depths: list[int] = [3, 3, 9, 3],
        dims: list[int] = [96, 192, 384, 768],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        kernel_size: int = 7,
    ):
        super().__init__()
        assert len(depths) == len(dims)

        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            FishConvNet(
                input_channels,
                dims[0],
                kernel_size=7,
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)

        # 1x1 convs change the channel count between stages (stride 1, so the
        # "downsample" name refers only to the ConvNeXt naming convention).
        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.downsample_layers.append(mid_layer)

        self.stages = nn.ModuleList()
        # Stochastic-depth rate grows linearly from 0 to drop_path_rate.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    ConvNeXtBlock(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                        kernel_size=kernel_size,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Robustness fix: layers created with bias=False have bias=None,
            # and nn.init.constant_(None, 0) raises. Guard before writing.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Encode (B, input_channels, T) features to (B, dims[-1], T)."""
        for i in range(len(self.downsample_layers)):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)

        return self.norm(x)
+
+
class FireflyArchitecture(nn.Module):
    """End-to-end neural audio codec: spectrogram transform -> backbone
    encoder -> (optional) quantizer -> vocoder head."""

    def __init__(
        self,
        backbone: nn.Module,
        head: nn.Module,
        quantizer: nn.Module,
        spec_transform: nn.Module,
    ):
        super().__init__()

        self.backbone = backbone
        self.head = head
        self.quantizer = quantizer
        self.spec_transform = spec_transform
        # Total temporal downsampling applied inside the quantizer stack.
        self.downsample_factor = math.prod(self.quantizer.downsample_factor)

    def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
        """Full autoencoding pass.

        Returns (audio, vq_result) when a quantizer is attached, otherwise
        just the reconstructed audio.
        """
        if self.spec_transform is not None:
            x = self.spec_transform(x)

        x = self.backbone(x)
        if mask is not None:
            x = x * mask

        vq_result = None
        if self.quantizer is not None:
            vq_result = self.quantizer(x)
            x = vq_result.z

            if mask is not None:
                x = x * mask

        x = self.head(x, template=template)

        # Ensure a channel dimension: (B, T) -> (B, 1, T).
        if x.ndim == 2:
            x = x[:, None, :]

        # Bug fix: the original checked the nonexistent attribute `self.vq`
        # (the field assigned in __init__ is `quantizer`), which raised
        # AttributeError on every forward call.
        if self.quantizer is not None:
            return x, vq_result

        return x

    def encode(self, audios, audio_lengths):
        """Encode waveforms into codebook indices.

        Returns (codes, feature_lengths) where feature_lengths accounts for
        the mel hop length and the quantizer's downsampling.
        """
        audios = audios.float()

        mels = self.spec_transform(audios)
        mel_lengths = audio_lengths // self.spec_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode, masking out the padded tail of each item.
        encoded_features = self.backbone(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // self.downsample_factor

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths) -> torch.Tensor:
        """Decode codebook indices back to audio; returns (audio, audio_lengths)."""
        mel_masks = sequence_mask(
            feature_lengths * self.downsample_factor,
            indices.shape[2] * self.downsample_factor,
        )
        mel_masks_float_conv = mel_masks[:, None, :].float()
        audio_lengths = (
            feature_lengths * self.downsample_factor * self.spec_transform.hop_length
        )

        audio_masks = sequence_mask(
            audio_lengths,
            indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
        )
        audio_masks_float_conv = audio_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        x = self.head(z) * audio_masks_float_conv

        return x, audio_lengths

    def remove_parametrizations(self):
        """Strip weight_norm parametrizations (if present) for inference export."""
        if hasattr(self.backbone, "remove_parametrizations"):
            self.backbone.remove_parametrizations()

        if hasattr(self.head, "remove_parametrizations"):
            self.head.remove_parametrizations()

    @property
    def device(self):
        # Device of the first parameter; assumes all parameters share a device.
        return next(self.parameters()).device
diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..954553bbfe0b7b18d348db6c03bf04fc0c916c4f
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/fsq.py
@@ -0,0 +1,116 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from vector_quantize_pytorch import GroupedResidualFSQ
+
+from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
+
+
@dataclass
class FSQResult:
    """Output bundle of DownsampleFiniteScalarQuantize.forward()."""

    # Quantized features upsampled back toward the input resolution —
    # presumably (B, C, T); confirm against the caller.
    z: torch.Tensor
    # Integer codebook indices at the downsampled rate.
    codes: torch.Tensor
    # Pre-quantization downsampled features.
    latents: torch.Tensor
+
+
class DownsampleFiniteScalarQuantize(nn.Module):
    """Grouped residual FSQ wrapped in a strided-conv downsample stack and a
    mirrored transposed-conv upsample stack along the time axis."""

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
        downsample_factor: tuple[int] = (2, 2),
        downsample_dims: tuple[int] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        # One (strided conv, ConvNeXt block) pair per downsampling stage.
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishConvNet(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        # Mirror of the downsample stack, applied in reverse stage order.
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishTransConvNet(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Robustness fix: layers created with bias=False have bias=None
            # (e.g. linears inside the grouped FSQ); nn.init.constant_ would
            # raise on None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, z) -> FSQResult:
        """Quantize (B, C, T) features; result.z is restored to length T."""
        original_shape = z.shape
        z = self.downsample(z)
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # Pad or crop z to match the original temporal length.
        diff = original_shape[-1] - result.z.shape[-1]
        left = diff // 2
        right = diff - left

        if diff > 0:
            result.z = F.pad(result.z, (left, right))
        elif diff < 0:
            # Bug fix: `left` and `right` are both <= 0 here, and the old
            # slice [..., -left:right] produced an EMPTY tensor whenever
            # right == 0 (any odd deficit, e.g. diff == -1). Use an explicit
            # end index so exactly -left samples are cropped from the front
            # and -right from the back.
            result.z = result.z[..., -left : result.z.shape[-1] + right]

        return result

    def encode(self, z):
        """Downsample and quantize; returns indices shaped (B, groups*quantizers, T')."""
        z = self.downsample(z)
        _, indices = self.residual_fsq(z.mT)
        indices = rearrange(indices, "g b l r -> b (g r) l")
        return indices

    def decode(self, indices: torch.Tensor):
        """Inverse of encode(): look up codes and upsample back to features."""
        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        z_q = self.residual_fsq.get_output_from_indices(indices)
        z_q = self.upsample(z_q.mT)
        return z_q
diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90c131d214006875476a161cdfd2dffa8949dac
--- /dev/null
+++ b/fish_speech/models/vqgan/utils.py
@@ -0,0 +1,94 @@
+import matplotlib
+import torch
+from matplotlib import pyplot as plt
+
+matplotlib.use("Agg")
+
+
def convert_pad_shape(pad_shape):
    """Flatten a list of [before, after] pairs, last dimension first, into
    the flat layout expected by ``torch.nn.functional.pad``."""
    return [amount for pair in reversed(pad_shape) for amount in pair]
+
+
def sequence_mask(length, max_length=None):
    """Boolean mask of shape (batch, max_length); True where pos < length[i]."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
+
+
def init_weights(m, mean=0.0, std=0.01):
    """Gaussian-initialize the weight of any module whose class name
    contains "Conv" (HiFi-GAN style initialization)."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
+
+
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a dilated conv "same"-sized for odd kernels."""
    return (kernel_size - 1) * dilation // 2
+
+
def plot_mel(data, titles=None):
    """Render a stack of mel spectrograms, one subplot per entry.

    Tensors are detached and moved to CPU before plotting; returns the
    matplotlib figure.
    """
    if titles is None:
        titles = [None] * len(data)

    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    plt.tight_layout()

    for idx, mel in enumerate(data):
        if isinstance(mel, torch.Tensor):
            mel = mel.float().detach().cpu().numpy()

        ax = axes[idx][0]
        ax.imshow(mel, origin="lower")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[idx], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    return fig
+
+
def slice_segments(x, ids_str, segment_size=4):
    """Gather a fixed-size window per batch item: out[i] = x[i, :, s:s+size]."""
    out = torch.zeros_like(x[:, :, :segment_size])
    for i, start in enumerate(ids_str):
        out[i] = x[i, :, start : start + segment_size]
    return out
+
+
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Pick a uniformly random start per batch item and slice a segment.

    Returns (segments, start_indices).
    """
    batch, _, time = x.size()
    if x_lengths is None:
        x_lengths = time
    max_start = torch.clamp(x_lengths - segment_size + 1, min=0)
    starts = (torch.rand([batch], device=x.device) * max_start).to(dtype=torch.long)
    return slice_segments(x, starts, segment_size), starts
+
+
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
    """WaveNet gate: tanh(first n channels) * sigmoid(remaining channels)."""
    split = n_channels[0]
    return torch.tanh(in_act[:, :split, :]) * torch.sigmoid(in_act[:, split:, :])
+
+
def avg_with_mask(x, mask):
    """Mean of ``x`` over positions where ``mask`` is 1 (mask is broadcast
    over the channel dim when given as (b, t) or (b, 1, t))."""
    assert mask.dtype == torch.float, "Mask should be float"

    if mask.ndim == 2:
        mask = mask[:, None, :]

    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    masked_sum = (x * mask).sum()
    return masked_sum / mask.sum()
diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..43bed6a2210723a7d5e1ea0a48ba61140047ca29
--- /dev/null
+++ b/fish_speech/scheduler.py
@@ -0,0 +1,40 @@
+import math
+
+
+def get_cosine_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int | float,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ final_lr_ratio: float = 0.0,
+):
+ if 0 < num_warmup_steps < 1: # float mode
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ progress = float(current_step - num_warmup_steps) / float(
+ max(1, num_training_steps - num_warmup_steps)
+ )
+
+ return max(
+ final_lr_ratio,
+ 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+ )
+
+
+def get_constant_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int | float,
+ num_training_steps: int | None = None,
+):
+ if 0 < num_warmup_steps < 1: # float mode
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ return 1.0
diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d740bd8eed447d162e55b165965dec17130377ce
--- /dev/null
+++ b/fish_speech/text/__init__.py
@@ -0,0 +1,4 @@
+from .clean import clean_text
+from .spliter import split_text
+
+__all__ = ["clean_text", "split_text"]
diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/.gitignore
@@ -0,0 +1,114 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# JetBrains PyCharm
+.idea
+
+# Customize
+references
+url.txt
+
+# Git
+.git
diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/README.md
@@ -0,0 +1,36 @@
+# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
+
+# Chn Text Norm
+
+This is a repository for Chinese text normalization (no longer maintained).
+
+## Quick Start ##
+
+### Git Clone Repo ###
+
+git clone this repo to the root directory of your project which needs to use it.
+
+ cd /path/to/proj
+ git clone https://github.com/Joee1995/chn-text-norm.git
+
+after that, your doc tree should be:
+```
+proj # root of your project
+|--- chn_text_norm # this chn-text-norm tool
+ |--- text.py
+ |--- ...
+|--- text_normalize.py # your text normalization code
+|--- ...
+```
+
+### How to Use ? ###
+
+ # text_normalize.py
+ from chn_text_norm.text import *
+
+ raw_text = 'your raw text'
+ text = Text(raw_text=raw_text).normalize()
+
+### How to add quantums ###
+
+打开test.py,然后你就知道怎么做了。
diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d8f8eb7fc85d0861f106667d8f4e3e52b54761
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_class.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""基本类
+中文字符类
+中文数字/数位类
+中文数字类
+中文数位类
+中文数字系统类
+中文数学符号类
+*中文其他符号类
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
+
+
class ChineseChar(object):
    """A Chinese character with simplified and traditional forms,
    e.g. simplified '负', traditional '負'. Rendering prefers the
    simplified form and falls back to the traditional one.
    """

    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional
        # Cleanup: the original also assigned `self.__repr__ = self.__str__`.
        # Dunder methods are looked up on the type, not the instance, so that
        # assignment never affected repr(); __repr__ below already delegates.

    def __str__(self):
        # NOTE(review): returns None when both forms are falsy, which would
        # make str()/repr() raise TypeError — assumed not to occur in practice.
        return self.simplified or self.traditional or None

    def __repr__(self):
        return self.__str__()
+
+
class ChineseNumberUnit(ChineseChar):
    """A Chinese place-value unit (十, 百, …, 亿, 兆, …).

    Besides simplified/traditional forms each unit carries "big" variants,
    e.g. '陆' and '陸'. ``power`` is the base-10 exponent the unit denotes.
    """

    def __init__(self, power, simplified, traditional, big_s, big_t):
        super(ChineseNumberUnit, self).__init__(simplified, traditional)
        self.power = power
        self.big_s = big_s
        self.big_t = big_t

    def __str__(self):
        return f"10^{self.power}"

    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
        """Build the unit at position ``index`` from a (simplified,
        traditional) character pair, using the given scale convention."""
        if small_unit:
            # 十/百/千/万: powers 1..4; these reuse the traditional form
            # as their "big" variants.
            return ChineseNumberUnit(
                power=index + 1,
                simplified=value[0],
                traditional=value[1],
                big_s=value[1],
                big_t=value[1],
            )

        # Units of 亿 and above; the exponent depends on the convention.
        if numbering_type == NUMBERING_TYPES[0]:
            power = index + 8  # low: each unit is 10x the previous
        elif numbering_type == NUMBERING_TYPES[1]:
            power = (index + 2) * 4  # mid: each unit is 10^4x the previous
        elif numbering_type == NUMBERING_TYPES[2]:
            power = pow(2, index + 3)  # high: each unit squares the previous
        else:
            raise ValueError(
                f"Counting type should be in {NUMBERING_TYPES} ({numbering_type} provided)."
            )

        return ChineseNumberUnit(
            power=power,
            simplified=value[0],
            traditional=value[1],
            big_s=value[0],
            big_t=value[1],
        )
+ )
+
+
class ChineseNumberDigit(ChineseChar):
    """A single Chinese digit (0-9) with "big" and optional alternate forms."""

    def __init__(
        self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
    ):
        super(ChineseNumberDigit, self).__init__(simplified, traditional)
        self.value = value  # integer value 0..9
        self.big_s = big_s  # "big" simplified form, e.g. 壹
        self.big_t = big_t  # "big" traditional form
        self.alt_s = alt_s  # alternate simplified reading (e.g. 幺 for 一)
        self.alt_t = alt_t  # alternate traditional reading

    def __str__(self):
        return str(self.value)

    @classmethod
    def create(cls, i, v):
        # i is the digit value; v = (simplified, traditional, big_s, big_t).
        return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
class ChineseMath(ChineseChar):
    """A Chinese math symbol (sign or decimal point) together with its
    ASCII counterpart and an evaluating callable."""

    def __init__(self, simplified, traditional, symbol, expression=None):
        super(ChineseMath, self).__init__(simplified, traditional)
        self.symbol = symbol  # ASCII counterpart, e.g. "+", "-", "."
        self.expression = expression  # callable implementing the operation
        # Math symbols have no distinct "big" writing; reuse the plain forms.
        self.big_s = simplified
        self.big_t = traditional
+
+
# Short aliases for the symbol classes, used throughout this package.
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
class NumberSystem(object):
    """Bare container for a Chinese numeral system.

    Attributes (``units``, ``digits``, ``math``) are attached dynamically
    by ``create_system`` in ``basic_util``.
    """
+
+
class MathSymbol(object):
    """The sign/point symbols of a numeral system (simplified/traditional), e.g.
    positive = ['正', '正'], negative = ['负', '負'], point = ['点', '點'].
    """

    def __init__(self, positive, negative, point):
        self.positive = positive
        self.negative = negative
        self.point = point

    def __iter__(self):
        # Yield in declaration order: positive, negative, point.
        yield from self.__dict__.values()
+
+
+# class OtherSymbol(object):
+# """
+# 其他符号
+# """
+#
+# def __init__(self, sil):
+# self.sil = sil
+#
+# def __iter__(self):
+# for v in self.__dict__.values():
+# yield v
diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a65991b9a9d349a0571c80508633951e52749ef
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_constant.py
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""Basic constants.

Chinese digit / unit / symbol character constants used by the
chn_text_norm package.
"""

__author__ = "Zhiyang Zhou "
__data__ = "2019-05-02"

# Digits 0-9 in everyday writing and in the "big" (formal) forms.
CHINESE_DIGIS = "零一二三四五六七八九"
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
# Units below 亿 (十 百 千 万) and their "big" forms.
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
# Units from 亿 upwards; their numeric value depends on NUMBERING_TYPES.
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"

# Alternate readings: 〇 for zero, 幺 for one, 两/兩 for two.
ZERO_ALT = "〇"
ONE_ALT = "幺"
TWO_ALTS = ["两", "兩"]

# Sign and decimal-point characters as [simplified, traditional] pairs.
POSITIVE = ["正", "正"]
NEGATIVE = ["负", "負"]
POINT = ["点", "點"]
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']

# Chinese numbering scale conventions (see basic_util.create_system).
NUMBERING_TYPES = ["low", "mid", "high"]
diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbf6130be87f285eed9998186508ea489d3bac9e
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_util.py
@@ -0,0 +1,342 @@
+# -*- coding: utf-8 -*-
+"""基本方法
+创建中文数字系统 方法
+中文字符串 <=> 数字串 方法
+数字串 <=> 中文字符串 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+from fish_speech.text.chn_text_norm.basic_class import *
+from fish_speech.text.chn_text_norm.basic_constant import *
+
+
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """Build a NumberSystem for the given scale convention (default mid).

    NUMBERING_TYPES = ['low', 'mid', 'high']:
        low:  '兆' = '亿' * '十'  = 10^9,  '京' = '兆' * '十', etc.
        mid:  '兆' = '亿' * '万'  = 10^12, '京' = '兆' * '万', etc.
        high: '兆' = '亿' * '亿'  = 10^16, '京' = '兆' * '兆', etc.
    """
    # Units of '亿' and larger.
    larger_units = [
        CNU.create(i, pair, numbering_type, False)
        for i, pair in enumerate(
            zip(
                LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
                LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
            )
        )
    ]
    # Units '十, 百, 千, 万'.
    smaller_units = [
        CNU.create(i, pair, small_unit=True)
        for i, pair in enumerate(
            zip(
                SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
                SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
            )
        )
    ]
    # Digits 0-9; attach the alternate readings afterwards.
    digits = [
        CND.create(i, quad)
        for i, quad in enumerate(
            zip(
                CHINESE_DIGIS,
                CHINESE_DIGIS,
                BIG_CHINESE_DIGIS_SIMPLIFIED,
                BIG_CHINESE_DIGIS_TRADITIONAL,
            )
        )
    ]
    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]

    # Sign and decimal-point symbols with their evaluators.
    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))

    system = NumberSystem()
    system.units = smaller_units + larger_units
    system.digits = digits
    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
    return system
+
+
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """Convert a Chinese number string to an ASCII digit string.

    e.g. '一万零四百零三点八零五' -> '10403.805'.
    """

    def get_symbol(char, system):
        # Map one character to its unit/digit/math symbol object.
        # NOTE(review): falls through and implicitly returns None for an
        # unknown character, which breaks the callers below — the input is
        # assumed to contain only valid number characters.
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [
                d.traditional,
                d.simplified,
                d.big_s,
                d.big_t,
                d.alt_s,
                d.alt_t,
            ]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m

    def string2symbols(chinese_string, system):
        # Split at the decimal point (if any) and map each character.
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], [
            get_symbol(c, system) for c in dec_string
        ]

    def correct_symbols(integer_symbols, system):
        """Normalize elided/compound unit usages, e.g.
        一百八 -> 一百八十; 一亿一千三百万 -> 一亿 一千万 三百万.
        """

        # A leading 十 implies 一十.
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols

        # A trailing bare digit after a unit inherits the next-lower unit.
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(
                integer_symbols[-2], CNU
            ):
                integer_symbols.append(
                    CNU(integer_symbols[-2].power - 1, None, None, None, None)
                )

        # Combine consecutive units (e.g. 千 followed by 万) by summing powers
        # into the earlier smaller units.
        result = []
        unit_count = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                result.append(s)
                unit_count = 0
            elif isinstance(s, CNU):
                current_unit = CNU(s.power, None, None, None, None)
                unit_count += 1

                if unit_count == 1:
                    result.append(current_unit)
                elif unit_count > 1:
                    for i in range(len(result)):
                        if (
                            isinstance(result[-i - 1], CNU)
                            and result[-i - 1].power < current_unit.power
                        ):
                            result[-i - 1] = CNU(
                                result[-i - 1].power + current_unit.power,
                                None,
                                None,
                                None,
                                None,
                            )
        return result

    def compute_value(integer_symbols):
        """Compute the integer value from the corrected symbol list.

        When the current unit is larger than the previous unit, it multiplies
        ALL previously accumulated terms, e.g. '两千万' = 2000 * 10000,
        not 2000 + 10000.
        """
        value = [0]
        last_power = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                value[-1] = s.value
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
                last_power = s.power
                value.append(0)
        return sum(value)

    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    # Decimal digits are read positionally, so just concatenate their values.
    dec_str = "".join([str(d.value) for d in dec_part])
    if dec_part:
        return "{0}.{1}".format(int_str, dec_str)
    else:
        return int_str
+
+
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    """Convert an ASCII digit string to its Chinese reading.

    Args:
        number_string: digits with at most one decimal point.
        numbering_type: scale convention, one of NUMBERING_TYPES.
        big: use the "big" character forms (壹贰…).
        traditional: emit traditional instead of simplified characters.
        alt_zero / alt_one: use the alternate readings 〇 / 幺.
        alt_two: allow 两 for 2 in front of larger units.
        use_zeros: keep an explicit 零 for skipped places.
        use_units: use place-value units (十百千…); otherwise digit by digit.

    Raises:
        ValueError: if ``number_string`` contains more than one dot.
    """

    def get_value(value_string, use_zeros=True):
        # Recursively split the digit string at the largest unit that fits.

        striped_string = value_string.lstrip("0")

        # record nothing if all zeros
        if not striped_string:
            return []

        # record one digits
        elif len(striped_string) == 1:
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            else:
                return [system.digits[int(striped_string)]]

        # recursively record multiple digits
        else:
            result_unit = next(
                u for u in reversed(system.units) if u.power < len(striped_string)
            )
            result_string = value_string[: -result_unit.power]
            return (
                get_value(result_string)
                + [result_unit]
                + get_value(striped_string[-result_unit.power :])
            )

    system = create_system(numbering_type)

    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
    elif len(int_dec) == 2:
        int_string = int_dec[0]
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string)
        )

    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
    else:
        # Positional reading: one digit symbol per character.
        result_symbols = [system.digits[int(c)] for c in int_string]
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols

    if alt_two:
        # Replace 二 with 两 where it directly precedes a unit above 十
        # and is not itself preceded by 十 (e.g. 两千万 but 一十二万).
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = (
                    result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                )
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                if isinstance(next_symbol, CNU) and isinstance(
                    previous_symbol, (CNU, type(None))
                ):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang

    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = "big_"
        if traditional:
            attr_name += "t"
        else:
            attr_name += "s"
    else:
        if traditional:
            attr_name = "traditional"
        else:
            attr_name = "simplified"

    result = "".join([getattr(s, attr_name) for s in result_symbols])

    # if not use_zeros:
    #     result = result.strip(getattr(system.digits[0], attr_name))

    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s
        )

    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s
        )

    # Prepend 零 when the reading starts with the decimal point (e.g. ".5").
    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result

    # ^10, 11, .., 19: drop the leading 一 (一十 -> 十).
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [
            CHINESE_DIGIS[1],
            BIG_CHINESE_DIGIS_SIMPLIFIED[1],
            BIG_CHINESE_DIGIS_TRADITIONAL[1],
        ]
    ):
        result = result[1:]

    return result
+
+
if __name__ == "__main__":

    # Self-test / demo.
    all_chinese_number_string = (
        CHINESE_DIGIS
        + BIG_CHINESE_DIGIS_SIMPLIFIED
        + BIG_CHINESE_DIGIS_TRADITIONAL
        + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + ZERO_ALT
        + ONE_ALT
        + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
    )

    print("num:", chn2num("一万零四百零三点八零五"))
    print("num:", chn2num("一亿六点三"))
    print("num:", chn2num("一亿零六点三"))
    print("num:", chn2num("两千零一亿六点三"))
    # print('num:', chn2num('一零零八六'))
    print("txt:", num2chn("10260.03", alt_zero=True))
    print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
    print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
    print(
        "txt:",
        num2chn(
            "059523810880",
            alt_one=True,
            alt_two=False,
            # BUGFIX: the original passed use_lzeros=True, use_rzeros=True,
            # which are not parameters of num2chn and raised TypeError.
            # The intended flag is use_zeros (spell out zeros).
            use_zeros=True,
            use_units=False,
        ),
    )

    print(all_chinese_number_string)
diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/cardinal.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""CARDINAL类 (包含小数DECIMAL类)
+纯数 <=> 中文字符串 方法
+中文字符串 <=> 纯数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
class Cardinal:
    """Converter between cardinal number strings and their Chinese reading
    (including decimals)."""

    def __init__(self, cardinal=None, chntext=None):
        self.cardinal = cardinal  # ASCII digit string, e.g. "21357.230"
        self.chntext = chntext  # Chinese spelled-out form

    def chntext2cardinal(self):
        """Chinese reading -> ASCII digit string."""
        return chn2num(self.chntext)

    def cardinal2chntext(self):
        """ASCII digit string -> Chinese reading."""
        return num2chn(self.cardinal)
+
+
if __name__ == "__main__":

    # Self-test: print the Chinese reading of a decimal number.
    print(Cardinal(cardinal="21357.230").cardinal2chntext())
diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py
new file mode 100644
index 0000000000000000000000000000000000000000..77acfdb9a91df0fe3c615a0784f61aad87fbe56e
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/date.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+"""DATE类
+日期 <=> 中文字符串 方法
+中文字符串 <=> 日期 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-07"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.digit import Digit
+
+
class Date:
    """Converts a date string like ``09年3月16日`` into its Chinese reading.

    The year is read digit by digit (Digit); month and day are read as
    cardinals (Cardinal). Missing components are simply omitted.
    """

    def __init__(self, date=None, chntext=None):
        self.date = date  # raw date string, e.g. "09年3月16日"
        self.chntext = chntext  # normalized Chinese reading

    def date2chntext(self):
        """Return the Chinese reading of ``self.date`` (cached on
        ``self.chntext``)."""
        date = self.date
        try:
            year, other = date.strip().split("年", maxsplit=1)
            year = Digit(digit=year).digit2chntext() + "年"
        except ValueError:
            # No 年: the whole string is month/day.
            other = date
            year = ""
        if other:
            try:
                month, day = other.strip().split("月", maxsplit=1)
                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
            except ValueError:
                # BUGFIX: the original fell back to the WHOLE input (`date`),
                # so "09年16日" would re-include the year in the day part.
                # What remains to parse is the part after the year split.
                day = other
                month = ""
            if day:
                # Last char is the 日/号 suffix; read the rest as a cardinal.
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ""
            day = ""
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext
+
+
if __name__ == "__main__":

    # Self-test: print the Chinese reading of a date.
    print(Date(date="09年3月16日").date2chntext())
diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c0cd4ad0c700635f84470bfdacfbdafb4a6185
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/digit.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""DIGIT类
+数字串 <=> 中文字符串 方法
+中文字符串 <=> 数字串 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
class Digit:
    """Reads a digit string positionally with no place values,
    e.g. "2016" -> 二零一六 (year-style reading)."""

    def __init__(self, digit=None, chntext=None):
        self.digit = digit  # ASCII digit string
        self.chntext = chntext  # Chinese reading

    def digit2chntext(self):
        """Digit string -> Chinese, one character per digit."""
        return num2chn(self.digit, alt_two=False, use_units=False)
+
+
if __name__ == "__main__":

    # Self-test: digit-by-digit reading.
    print(Digit(digit="2016").digit2chntext())
diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..b43b6a7feb634d346d59a2b4ab84b77ac88df103
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/fraction.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""FRACTION类
+分数 <=> 中文字符串 方法
+中文字符串 <=> 分数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
class Fraction:
    """Converter between fraction strings ("a/b") and the Chinese form
    "b分之a" (denominator first)."""

    def __init__(self, fraction=None, chntext=None):
        self.fraction = fraction  # ASCII form "numerator/denominator"
        self.chntext = chntext  # Chinese form

    def chntext2fraction(self):
        """Chinese "分母分之分子" -> "numerator/denominator"."""
        denominator, numerator = self.chntext.split("分之")
        return "{}/{}".format(chn2num(numerator), chn2num(denominator))

    def fraction2chntext(self):
        """"a/b" -> Chinese reading with the denominator first."""
        numerator, denominator = self.fraction.split("/")
        return num2chn(denominator) + "分之" + num2chn(numerator)
+
+
if __name__ == "__main__":

    # Self-test: round-trip fraction conversions.
    print(Fraction(fraction="2135/7230").fraction2chntext())
    print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c980d32134e1460e96e5bcbcc73d0d55974d2a
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/money.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+"""MONEY类
+金钱 <=> 中文字符串 方法
+中文字符串 <=> 金钱 方法
+"""
+import re
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-08"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+
+
class Money:
    """Spells out every number inside a money expression,
    e.g. "21.5万元" -> Chinese reading of 21.5 followed by 万元."""

    def __init__(self, money=None, chntext=None):
        self.money = money  # raw expression containing digits
        self.chntext = chntext  # normalized Chinese form

    def money2chntext(self):
        """Replace each numeric run (integer or decimal) in ``self.money``
        with its Chinese cardinal reading."""
        money = self.money
        # findall returns (whole_number, fractional_group) tuples.
        for number, _ in re.findall(r"(\d+(\.\d+)?)", money):
            money = money.replace(
                number, Cardinal(cardinal=number).cardinal2chntext()
            )
        self.chntext = money
        return self.chntext
+
+
if __name__ == "__main__":

    # Self-test: money expressions with units.
    print(Money(money="21.5万元").money2chntext())
    print(Money(money="230块5毛").money2chntext())
diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py
new file mode 100644
index 0000000000000000000000000000000000000000..46abbf545af62eb951d8f6fe40bcf684587f81b0
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/percentage.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+"""PERCENTAGE类
+百分数 <=> 中文字符串 方法
+中文字符串 <=> 百分数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-06"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
class Percentage:
    """Converter between "65.3%" and the Chinese form "百分之六十五点三"."""

    def __init__(self, percentage=None, chntext=None):
        self.percentage = percentage  # ASCII form, e.g. "65.3%"
        self.chntext = chntext  # Chinese form, e.g. "百分之六十五点三"

    def chntext2percentage(self):
        """Chinese "百分之X" -> "X%"."""
        text = self.chntext.strip()
        # BUGFIX: the original used .strip("百分之"), which strips a character
        # SET from BOTH ends — e.g. "百分之一百" lost its trailing 百 and became
        # "1%" instead of "100%". Remove only the literal leading prefix.
        if text.startswith("百分之"):
            text = text[len("百分之"):]
        return chn2num(text) + "%"

    def percentage2chntext(self):
        """"X%" -> Chinese "百分之X"."""
        # "%" is a single character, so strip("%") only removes the suffix
        # (and a malformed leading "%").
        return "百分之" + num2chn(self.percentage.strip().strip("%"))
+
+
if __name__ == "__main__":

    # Self-test: round-trip percentage conversions.
    print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
    print(Percentage(percentage="65.3%").percentage2chntext())
diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72b546db628a3b807dc6235b59b188cae3153ff
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/telephone.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+"""TELEPHONE类
+电话号码 <=> 中文字符串 方法
+中文字符串 <=> 电话号码 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
class TelePhone:
    """Reads telephone numbers digit by digit."""

    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
        self.telephone = telephone  # e.g. "0595-23980880"
        self.raw_chntext = raw_chntext  # reading before separator cleanup
        self.chntext = chntext  # final Chinese reading

    def telephone2chntext(self, fixed=False):
        """Convert ``self.telephone`` to Chinese.

        ``fixed=True`` splits a landline number on "-" (area code);
        otherwise the number is split on whitespace (mobile grouping)
        after stripping a leading "+".
        """
        if fixed:
            parts = self.telephone.split("-")
        else:
            parts = self.telephone.strip("+").split()

        self.raw_chntext = "".join(
            num2chn(part, alt_two=False, use_units=False) for part in parts
        )
        # NOTE(review): the original applied .replace("", ""), a no-op —
        # presumably a pause/separator token was lost upstream. Dropped here;
        # behavior is unchanged.
        self.chntext = self.raw_chntext
        return self.chntext
+
+
if __name__ == "__main__":

    # Self-test: landline number, read digit by digit.
    # BUGFIX: the demo number contains "-" but was converted without
    # fixed=True, so the dash reached num2chn and int("-") raised ValueError.
    print(TelePhone(telephone="0595-23980880").telephone2chntext(fixed=True))
    # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..54086fd933c01e14c3c55cee9adb52eefb58fd31
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/text.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+"""
+TEXT类
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+import re
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.date import Date
+from fish_speech.text.chn_text_norm.digit import Digit
+from fish_speech.text.chn_text_norm.fraction import Fraction
+from fish_speech.text.chn_text_norm.money import Money
+from fish_speech.text.chn_text_norm.percentage import Percentage
+from fish_speech.text.chn_text_norm.telephone import TelePhone
+
+# Currency names accepted by the money pass (single regex alternation group).
+CURRENCY_NAMES = (
+    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
+    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
+)
+# Currency magnitudes and units (亿/万/千/百 combined with 元/块/角/毛/分).
+CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
+# Common measure words used to detect "number + quantifier" phrases.
+COM_QUANTIFIERS = (
+    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
+    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
+    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
+    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
+    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
+    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
+)
+
+
+class Text:
+ """
+ Text类
+ """
+
+ def __init__(self, raw_text, norm_text=None):
+ self.raw_text = "^" + raw_text + "$"
+ self.norm_text = norm_text
+
+ def _particular(self):
+ text = self.norm_text
+ pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('particular')
+ for matcher in matchers:
+ text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+ self.norm_text = text
+ return self.norm_text
+
+ def normalize(self):
+ text = self.raw_text
+
+ # 规范化日期
+ pattern = re.compile(
+ r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
+ )
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('date')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+ # 规范化金钱
+ pattern = re.compile(
+ r"\D+((\d+(\.\d+)?)[多余几]?"
+ + CURRENCY_UNITS
+ + "(\d"
+ + CURRENCY_UNITS
+ + "?)?)"
+ )
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('money')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Money(money=matcher[0]).money2chntext(), 1
+ )
+
+ # 规范化固话/手机号码
+ # 手机
+ # http://www.jihaoba.com/news/show/13680
+ # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
+ # 联通:130、131、132、156、155、186、185、176
+ # 电信:133、153、189、180、181、177
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('telephone')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
+ )
+ # 固话
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fixed telephone')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0],
+ TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
+ 1,
+ )
+
+ # 规范化分数
+ pattern = re.compile(r"(\d+/\d+)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fraction')
+ for matcher in matchers:
+ text = text.replace(
+ matcher, Fraction(fraction=matcher).fraction2chntext(), 1
+ )
+
+ # 规范化百分数
+ text = text.replace("%", "%")
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('percentage')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0],
+ Percentage(percentage=matcher[0]).percentage2chntext(),
+ 1,
+ )
+
+ # 规范化纯数+量词
+ pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal+quantifier')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+ )
+
+ # 规范化数字编号
+ pattern = re.compile(r"(\d{4,32})")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('digit')
+ for matcher in matchers:
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+ # 规范化纯数
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+ )
+
+ self.norm_text = text
+ self._particular()
+
+ return self.norm_text.lstrip("^").rstrip("$")
+
+
+if __name__ == "__main__":
+
+    # Test program: one example per normalization pass.
+    print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
+    print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
+    print(Text(raw_text="分数:32477/76391。").normalize())
+    print(Text(raw_text="百分数:80.03%。").normalize())
+    print(Text(raw_text="编号:31520181154418。").normalize())
+    print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
+    print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
+    print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
+    print(Text(raw_text="特殊:O2O或B2C。").normalize())
diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbaf843d781f113735043319cc00dc2aed5ae382
--- /dev/null
+++ b/fish_speech/text/clean.py
@@ -0,0 +1,62 @@
+import re
+
+# Map CJK / full-width punctuation to ASCII equivalents; bracket pairs and
+# interpuncts map to "" (they are deleted, their inner text is kept).
+SYMBOLS_MAPPING = {
+    "\n": "",
+    "…": ".",
+    "“": "'",
+    "”": "'",
+    "‘": "'",
+    "’": "'",
+    "【": "",
+    "】": "",
+    "[": "",
+    "]": "",
+    "(": "",
+    ")": "",
+    "(": "",
+    ")": "",
+    "・": "",
+    "·": "",
+    "「": "'",
+    "」": "'",
+    "《": "'",
+    "》": "'",
+    "—": "",
+    "~": "",
+    "~": "",
+    ":": ",",
+    ";": ",",
+    ";": ",",
+    ":": ",",
+}
+
+# One alternation matching any key above (each key escaped literally).
+REPLACE_SYMBOL_REGEX = re.compile(
+    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+)
+
+
+# Runs of characters from common emoji code-point blocks.
+EMOJI_REGEX = re.compile(
+    "["
+    "\U0001F600-\U0001F64F"  # emoticons
+    "\U0001F300-\U0001F5FF"  # symbols & pictographs
+    "\U0001F680-\U0001F6FF"  # transport & map symbols
+    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
+    "]+",
+    flags=re.UNICODE,
+)
+
+
+def clean_text(text):
+ # Clean the text
+ text = text.strip()
+
+ # Replace all chinese symbols with their english counterparts
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+
+ # Remove emojis
+ text = EMOJI_REGEX.sub(r"", text)
+
+ # Remove continuous periods (...) and commas (,,,)
+ text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)
+
+ return text
diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4bb995487c4f53818c6b2a16cf0a886b4e02e84
--- /dev/null
+++ b/fish_speech/text/spliter.py
@@ -0,0 +1,130 @@
+import re
+import string
+
+from fish_speech.text.clean import clean_text
+
+
+def utf_8_len(text):
+ return len(text.encode("utf-8"))
+
+
+def break_text(texts, length, splits: set):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if char in splits:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def break_text_by_length(texts, length):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if utf_8_len(curr) >= length:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def add_cleaned(curr, segments):
+ curr = curr.strip()
+ if curr and not all(c.isspace() or c in string.punctuation for c in curr):
+ segments.append(curr)
+
+
+def protect_float(text):
+ # Turns 3.14 into <3_f_14> to prevent splitting
+ return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
+
+
+def unprotect_float(text):
+ # Turns <3_f_14> into 3.14
+ return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
+
+
+def split_text(text, length):
+ text = clean_text(text)
+
+ # Break the text into pieces with following rules:
+ # 1. Split the text at ".", "!", "?" if text is NOT a float
+ # 2. If the text is longer than length, split at ","
+ # 3. If the text is still longer than length, split at " "
+ # 4. If the text is still longer than length, split at any character to length
+
+ texts = [text]
+ texts = map(protect_float, texts)
+ texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"})
+ texts = map(unprotect_float, texts)
+ texts = break_text(texts, length, {",", ","})
+ texts = break_text(texts, length, {" "})
+ texts = list(break_text_by_length(texts, length))
+
+ # Then, merge the texts into segments with length <= length
+ segments = []
+ curr = ""
+
+ for text in texts:
+ if utf_8_len(curr) + utf_8_len(text) <= length:
+ curr += text
+ else:
+ add_cleaned(curr, segments)
+ curr = text
+
+ if curr:
+ add_cleaned(curr, segments)
+
+ return segments
+
+
+if __name__ == "__main__":
+    # Test the split_text function
+
+    text = "This is a test sentence. This is another test sentence. And a third one."
+
+    # Sentence-boundary split, then greedy merge within 50 bytes.
+    assert split_text(text, 50) == [
+        "This is a test sentence.",
+        "This is another test sentence. And a third one.",
+    ]
+    assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
+    assert split_text(" ", 10) == []
+    assert split_text("a", 10) == ["a"]
+
+    text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
+    assert split_text(text, 50) == [
+        "This is a test sentence with only commas,",
+        "and no dots, and no exclamation marks,",
+        "and no question marks, and no newlines.",
+    ]
+
+    text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
+    # First half split at " ", second half split at ","
+    assert split_text(text, 50) == [
+        "This is a test sentence This is a test sentence",
+        "This is a test sentence. This is a test sentence,",
+        "This is a test sentence, This is a test sentence.",
+    ]
+
+    # NOTE(review): the expected last segment ends with an ASCII "." while
+    # the input ends with "。" — this implies a 。->. mapping that is not
+    # visible in clean.py as shown here; confirm against SYMBOLS_MAPPING.
+    text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
+    assert split_text(text, 50) == [
+        "这是一段很长的中文文本,",
+        "而且没有句号,也没有感叹号,",
+        "也没有问号,也没有换行符.",
+    ]
diff --git a/fish_speech/train.py b/fish_speech/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e693f3adc4dda787bdd587aec29f53355f2b1653
--- /dev/null
+++ b/fish_speech/train.py
@@ -0,0 +1,141 @@
+import os
+
+# Disable the libuv TCPStore backend for torch.distributed — presumably for
+# platforms where libuv is unavailable (e.g. Windows); must be set before
+# torch is imported. TODO confirm.
+os.environ["USE_LIBUV"] = "0"
+import sys
+from typing import Optional
+
+import hydra
+import lightning as L
+import pyrootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies import DDPStrategy
+from omegaconf import DictConfig, OmegaConf
+
+# Drop SLURM variables — presumably to stop Lightning's SLURM cluster
+# auto-detection when launched inside an allocation; confirm.
+os.environ.pop("SLURM_NTASKS", None)
+os.environ.pop("SLURM_JOB_NAME", None)
+os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+# Locate the project root via the .project-root marker and add it to
+# sys.path so absolute `fish_speech.*` imports resolve.
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# Register the "eval" resolver so configs may embed Python expressions.
+# NOTE: eval on config values executes arbitrary code; only load trusted
+# configs.
+OmegaConf.register_new_resolver("eval", eval)
+
+import fish_speech.utils as utils
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> tuple[dict, dict]:
+    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+    training.
+    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+    failure. Useful for multiruns, saving info about the crash, etc.
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+    Returns:
+        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+    """  # noqa: E501
+
+    # set seed for random number generators in pytorch, numpy and python.random
+    if cfg.get("seed"):
+        L.seed_everything(cfg.seed, workers=False)
+
+    if cfg.get("deterministic"):
+        torch.use_deterministic_algorithms(True)
+
+    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+    log.info(f"Instantiating model <{cfg.model._target_}>")
+    model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+    log.info("Instantiating callbacks...")
+    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+    log.info("Instantiating loggers...")
+    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer,
+        callbacks=callbacks,
+        logger=logger,
+    )
+
+    # Keep references to everything built so loggers / the task wrapper can
+    # inspect them later.
+    object_dict = {
+        "cfg": cfg,
+        "datamodule": datamodule,
+        "model": model,
+        "callbacks": callbacks,
+        "logger": logger,
+        "trainer": trainer,
+    }
+
+    if logger:
+        log.info("Logging hyperparameters!")
+        utils.log_hyperparameters(object_dict)
+
+    if cfg.get("train"):
+        log.info("Starting training!")
+
+        ckpt_path = cfg.get("ckpt_path")
+        auto_resume = False
+
+        # Auto-resume: the newest checkpoint in the run's ckpt dir takes
+        # precedence over an explicitly configured ckpt_path.
+        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+        if resume_ckpt_path is not None:
+            ckpt_path = resume_ckpt_path
+            auto_resume = True
+
+        if ckpt_path is not None:
+            log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+        # resume weights only is disabled for auto-resume
+        if cfg.get("resume_weights_only") and auto_resume is False:
+            log.info("Resuming weights only!")
+            ckpt = torch.load(ckpt_path, map_location=model.device)
+            if "state_dict" in ckpt:
+                ckpt = ckpt["state_dict"]
+            # strict=False tolerates missing/unexpected keys; the result is
+            # logged below so partial loads stay visible.
+            err = model.load_state_dict(ckpt, strict=False)
+            log.info(f"Error loading state dict: {err}")
+            # Weights were loaded manually above; clear ckpt_path so the
+            # Trainer does not also restore optimizer/scheduler state.
+            ckpt_path = None
+
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+    train_metrics = trainer.callback_metrics
+
+    if cfg.get("test"):
+        log.info("Starting testing!")
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+        if ckpt_path == "":
+            log.warning("Best ckpt not found! Using current weights for testing...")
+            ckpt_path = cfg.get("ckpt_path")
+
+        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+        log.info(f"Best ckpt path: {ckpt_path}")
+
+    test_metrics = trainer.callback_metrics
+
+    # merge train and test metrics
+    metric_dict = {**train_metrics, **test_metrics}
+
+    return metric_dict, object_dict
+
+
+@hydra.main(
+    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig) -> Optional[float]:
+    """Hydra entry point: compose the config and launch training."""
+    # train the model
+    train(cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cf2f23174ddac9bf523730aca2f6a9965d134a
--- /dev/null
+++ b/fish_speech/utils/__init__.py
@@ -0,0 +1,24 @@
+from .braceexpand import braceexpand
+from .context import autocast_exclude_mps
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, set_seed, task_wrapper
+
+# Public API of fish_speech.utils, re-exported from the submodules above.
+__all__ = [
+    "enforce_tags",
+    "extras",
+    "get_metric_value",
+    "RankedLogger",
+    "instantiate_callbacks",
+    "instantiate_loggers",
+    "log_hyperparameters",
+    "print_config_tree",
+    "task_wrapper",
+    "braceexpand",
+    "get_latest_checkpoint",
+    "autocast_exclude_mps",
+    "set_seed",
+]
diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ac739f01f7e10e039c68c1157d6c761064f974
--- /dev/null
+++ b/fish_speech/utils/braceexpand.py
@@ -0,0 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+    """Raised when a pattern contains unbalanced '{' / '}' braces."""
+
+    pass
+
+
+# Characters usable in character ranges like {a..z} (uppercase first).
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+# {m..n} / {m..n..step} integer ranges (negative bounds allowed).
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+# {a..z} / {a..z..step} character ranges.
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+# A backslash-escaped character; used to strip escapes from the output.
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+    """braceexpand(pattern) -> iterator over generated strings
+
+    Returns an iterator over the strings resulting from brace expansion
+    of pattern. This function implements Brace Expansion as described in
+    bash(1), with the following limitations:
+
+    * A pattern containing unbalanced braces will raise an
+      UnbalancedBracesError exception. In bash, unbalanced braces will either
+      be partly expanded or ignored.
+
+    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+      include the characters '[]^_`' between 'Z' and 'a'.
+
+    When escape is True (the default), characters in pattern can be
+    prefixed with a backslash to cause them not to be interpreted as
+    special characters for brace expansion (such as '{', '}', ',').
+    To pass through a a literal backslash, double it ('\\\\').
+
+    When escape is False, backslashes in pattern have no special
+    meaning and will be preserved in the output.
+
+    Examples:
+
+    >>> from braceexpand import braceexpand
+
+    # Integer range
+    >>> list(braceexpand('item{1..3}'))
+    ['item1', 'item2', 'item3']
+
+    # Character range
+    >>> list(braceexpand('{a..c}'))
+    ['a', 'b', 'c']
+
+    # Sequence
+    >>> list(braceexpand('index.html{,.backup}'))
+    ['index.html', 'index.html.backup']
+
+    # Nested patterns
+    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+    # Prefixing an integer with zero causes all numbers to be padded to
+    # the same width.
+    >>> list(braceexpand('{07..10}'))
+    ['07', '08', '09', '10']
+
+    # An optional increment can be specified for ranges.
+    >>> list(braceexpand('{a..g..2}'))
+    ['a', 'c', 'e', 'g']
+
+    # Ranges can go in both directions.
+    >>> list(braceexpand('{4..1}'))
+    ['4', '3', '2', '1']
+
+    # Numbers can be negative
+    >>> list(braceexpand('{2..-1}'))
+    ['2', '1', '0', '-1']
+
+    # Unbalanced braces raise an exception.
+    >>> list(braceexpand('{1{2,3}'))
+    Traceback (most recent call last):
+        ...
+    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+    # By default, the backslash is the escape character.
+    >>> list(braceexpand(r'{1\\{2,3}'))
+    ['1{2', '3']
+
+    # Setting 'escape' to False disables backslash escaping.
+    >>> list(braceexpand(r'\\{1,2}', escape=False))
+    ['\\\\1', '\\\\2']
+
+    """
+    # Expand first, then strip escape backslashes from each result (only
+    # when escape handling is enabled).
+    return (
+        escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+    )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+    """Split *pattern* into literal runs and top-level brace groups, and
+    return an iterator over the cartesian product of all expansions."""
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    while pos < len(pattern):
+        if escape and pattern[pos] == "\\":
+            # Skip the backslash together with the escaped character.
+            pos += 2
+            continue
+        elif pattern[pos] == "{":
+            if bracketdepth == 0 and pos > start:
+                # Literal run preceding this brace group.
+                items.append([pattern[start:pos]])
+                start = pos
+            bracketdepth += 1
+        elif pattern[pos] == "}":
+            bracketdepth -= 1
+            if bracketdepth == 0:
+                # Outermost group closed: parse its body.
+                expr = pattern[start + 1 : pos]
+                item = parse_expression(expr, escape)
+                if item is None:  # not a range or sequence
+                    # Keep the braces literally but still expand the inside.
+                    items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+                else:
+                    items.append(item)
+                start = pos + 1  # skip the closing brace
+        pos += 1
+
+    if bracketdepth != 0:  # unbalanced braces
+        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+    # Trailing literal after the last brace group.
+    if start < pos:
+        items.append([pattern[start:]])
+
+    return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+ int_range_match = int_range_re.match(expr)
+ if int_range_match:
+ return make_int_range(*int_range_match.groups())
+
+ char_range_match = char_range_re.match(expr)
+ if char_range_match:
+ return make_char_range(*char_range_match.groups())
+
+ return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+    """Parse a comma sequence like ``a,b,{c,d}``; return None when *seq*
+    contains no top-level comma (i.e. it is not a sequence)."""
+    # sequence -> chain(*sequence_items)
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    while pos < len(seq):
+        if escape and seq[pos] == "\\":
+            # Skip the escaped character entirely.
+            pos += 2
+            continue
+        elif seq[pos] == "{":
+            bracketdepth += 1
+        elif seq[pos] == "}":
+            bracketdepth -= 1
+        elif seq[pos] == "," and bracketdepth == 0:
+            # Top-level comma: everything since `start` is one alternative,
+            # which may itself contain nested brace patterns.
+            items.append(parse_pattern(seq[start:pos], escape))
+            start = pos + 1  # skip the comma
+        pos += 1
+
+    if bracketdepth != 0:
+        raise UnbalancedBracesError
+    if not items:
+        return None
+
+    # part after the last comma (may be the empty string)
+    items.append(parse_pattern(seq[start:], escape))
+    return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+ if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+ padding = max(len(left), len(right))
+ else:
+ padding = 0
+ step = (int(incr) or 1) if incr else 1
+ start = int(left)
+ end = int(right)
+ r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+ fmt = "%0{}d".format(padding)
+ return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+ step = (int(incr) or 1) if incr else 1
+ start = alphabet.index(left)
+ end = alphabet.index(right)
+ if start < end:
+ return alphabet[start : end + 1 : step]
+ else:
+ end = end or -len(alphabet)
+ return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+    import doctest
+    import sys
+
+    # Run the doctests embedded in braceexpand()'s docstring; exit non-zero
+    # on any failure so CI can detect it.
+    failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+    if failed:
+        sys.exit(1)
diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..f04a99290ab32f7fe5b60656075a2d03af8468d6
--- /dev/null
+++ b/fish_speech/utils/context.py
@@ -0,0 +1,13 @@
+from contextlib import nullcontext
+
+import torch
+
+
+def autocast_exclude_mps(
+ device_type: str, dtype: torch.dtype
+) -> nullcontext | torch.autocast:
+ return (
+ nullcontext()
+ if torch.backends.mps.is_available()
+ else torch.autocast(device_type, dtype)
+ )
diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..78c82640a963fa556657107729f7543d2e7c3510
--- /dev/null
+++ b/fish_speech/utils/file.py
@@ -0,0 +1,16 @@
+import os
+from pathlib import Path
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+ # Find the latest checkpoint
+ ckpt_dir = Path(path)
+
+ if ckpt_dir.exists() is False:
+ return None
+
+ ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+ if len(ckpts) == 0:
+ return None
+
+ return ckpts[-1]
diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ee463924f588a35477937fbe3c3364043bdf3e
--- /dev/null
+++ b/fish_speech/utils/instantiators.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f94f738d1d87404354d086c30ef0ad9ab04cdc
--- /dev/null
+++ b/fish_speech/utils/logger.py
@@ -0,0 +1,55 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+    """A multi-GPU-friendly python command line logger."""
+
+    def __init__(
+        self,
+        name: str = __name__,
+        rank_zero_only: bool = True,
+        extra: Optional[Mapping[str, object]] = None,
+    ) -> None:
+        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+        with their rank prefixed in the log message.
+
+        :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is ``True``.
+        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+        """
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        # Instance flag; distinct from the module-level `rank_zero_only`
+        # helper imported above (which is still used inside log()).
+        self.rank_zero_only = rank_zero_only
+
+    def log(
+        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+    ) -> None:
+        """Delegate a log call to the underlying logger, after prefixing its message with the rank
+        of the process it's being logged from. If `'rank'` is provided, then the log will only
+        occur on that rank/process.
+
+        :param level: The level to log at. Look at `logging.__init__.py` for more information.
+        :param msg: The message to log.
+        :param rank: The rank to log at.
+        :param args: Additional args to pass to the underlying logging function.
+        :param kwargs: Any additional keyword args to pass to the underlying logging function.
+        """
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            # `rank_zero_only.rank` must have been set (by the distributed
+            # setup) before any logging happens; fail loudly otherwise.
+            current_rank = getattr(rank_zero_only, "rank", None)
+            if current_rank is None:
+                raise RuntimeError(
+                    "The `rank_zero_only.rank` needs to be set before use"
+                )
+            msg = rank_prefixed_message(msg, current_rank)
+            if self.rank_zero_only:
+                # Only the rank-0 process emits.
+                if current_rank == 0:
+                    self.logger.log(level, msg, *args, **kwargs)
+            else:
+                # Emit on every rank, or only on the explicitly requested one.
+                if rank is None:
+                    self.logger.log(level, msg, *args, **kwargs)
+                elif current_rank == rank:
+                    self.logger.log(level, msg, *args, **kwargs)
diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e3b0a2519e12845f09e5fbe86dfccbf5b345429
--- /dev/null
+++ b/fish_speech/utils/logging_utils.py
@@ -0,0 +1,48 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a465f54d610779766d51e3d1a020a3b1517fd1f
--- /dev/null
+++ b/fish_speech/utils/rich_utils.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def print_config_tree(
+    cfg: DictConfig,
+    print_order: Sequence[str] = (
+        "data",
+        "model",
+        "callbacks",
+        "logger",
+        "trainer",
+        "paths",
+        "extras",
+    ),
+    resolve: bool = False,
+    save_to_file: bool = False,
+) -> None:
+    """Prints content of DictConfig using Rich library and its tree structure.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+        print_order (Sequence[str], optional): Determines in what order config components are printed.
+        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+        save_to_file (bool, optional): Whether to export config to the hydra output folder.
+    """  # noqa: E501
+
+    style = "dim"
+    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+    queue = []
+
+    # add fields from `print_order` to queue
+    # NOTE(review): `log` here is the fish_speech.utils.logger *module*
+    # (imported as `from fish_speech.utils import logger as log`) and has no
+    # `warning` function — the else-branch below would raise AttributeError.
+    # Confirm and replace with a logger instance.
+    for field in print_order:
+        (
+            queue.append(field)
+            if field in cfg
+            else log.warning(
+                f"Field '{field}' not found in config. "
+                + f"Skipping '{field}' config printing..."
+            )
+        )
+
+    # add all the other fields to queue (not specified in `print_order`)
+    for field in cfg:
+        if field not in queue:
+            queue.append(field)
+
+    # generate config tree from queue
+    for field in queue:
+        branch = tree.add(field, style=style, guide_style=style)
+
+        config_group = cfg[field]
+        if isinstance(config_group, DictConfig):
+            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+        else:
+            branch_content = str(config_group)
+
+        branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+    # print config tree
+    rich.print(tree)
+
+    # save config tree to file
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+            rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+    """Prompts user to input tags from command line if no tags are provided in config."""  # noqa: E501
+
+    if not cfg.get("tags"):
+        # Refuse to prompt interactively from inside a Hydra multirun job.
+        if "id" in HydraConfig().cfg.hydra.job:
+            raise ValueError("Specify tags before launching a multirun!")
+
+        # NOTE(review): `log` is the fish_speech.utils.logger *module* and
+        # has no `warning`/`info` functions — these calls would raise
+        # AttributeError; confirm and use a logger instance instead.
+        log.warning("No tags provided in config. Prompting user to input tags...")
+        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+        tags = [t.strip() for t in tags.split(",") if t != ""]
+
+        with open_dict(cfg):
+            cfg.tags = tags
+
+        log.info(f"Tags: {cfg.tags}")
+
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+            rich.print(cfg.tags, file=file)
diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c3d7a2ab0f707ae92dbde0feb173927720c841
--- /dev/null
+++ b/fish_speech/utils/spectrogram.py
@@ -0,0 +1,122 @@
+import torch
+import torchaudio.functional as F
+from torch import Tensor, nn
+from torchaudio.transforms import MelScale
+
+
+class LinearSpectrogram(nn.Module):
+    """Linear (magnitude) spectrogram via STFT.
+
+    Padding is applied manually with reflect mode so that ``center=False``
+    can be used in ``torch.stft`` while keeping frames aligned to hop_length.
+    """
+
+    def __init__(
+        self,
+        n_fft=2048,
+        win_length=2048,
+        hop_length=512,
+        center=False,
+        mode="pow2_sqrt",
+    ):
+        super().__init__()
+
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.center = center
+        self.mode = mode
+
+        # Buffer so the window follows .to(device)/.half(); persistent=False
+        # keeps it out of checkpoints since it is deterministic.
+        self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+    def forward(self, y: Tensor) -> Tensor:
+        # Accept (B, 1, T) input by dropping the channel dimension.
+        if y.ndim == 3:
+            y = y.squeeze(1)
+
+        # Asymmetric reflect padding: total pad is win_length - hop_length,
+        # split as evenly as possible between left and right.
+        y = torch.nn.functional.pad(
+            y.unsqueeze(1),
+            (
+                (self.win_length - self.hop_length) // 2,
+                (self.win_length - self.hop_length + 1) // 2,
+            ),
+            mode="reflect",
+        ).squeeze(1)
+
+        spec = torch.stft(
+            y,
+            self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+
+        # Complex -> (..., 2) with real/imag in the last dimension.
+        spec = torch.view_as_real(spec)
+
+        if self.mode == "pow2_sqrt":
+            # Magnitude; the small epsilon stabilizes sqrt's gradient at zero.
+            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+        return spec
+
+
+class LogMelSpectrogram(nn.Module):
+    """Log-compressed mel spectrogram built on top of LinearSpectrogram.
+
+    Pipeline: (optional resample) -> magnitude STFT -> mel filterbank ->
+    log compression with a 1e-5 floor.
+    """
+
+    def __init__(
+        self,
+        sample_rate=44100,
+        n_fft=2048,
+        win_length=2048,
+        hop_length=512,
+        n_mels=128,
+        center=False,
+        f_min=0.0,
+        f_max=None,
+    ):
+        super().__init__()
+
+        self.sample_rate = sample_rate
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.center = center
+        self.n_mels = n_mels
+        self.f_min = f_min
+        # Default upper band edge is the Nyquist frequency.
+        self.f_max = f_max or float(sample_rate // 2)
+
+        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+
+        # Slaney-normalized mel filterbank, shape (n_freqs, n_mels).
+        fb = F.melscale_fbanks(
+            n_freqs=self.n_fft // 2 + 1,
+            f_min=self.f_min,
+            f_max=self.f_max,
+            n_mels=self.n_mels,
+            sample_rate=self.sample_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        )
+        # Non-persistent: recomputed in __init__, so no need to checkpoint it.
+        self.register_buffer(
+            "fb",
+            fb,
+            persistent=False,
+        )
+
+    def compress(self, x: Tensor) -> Tensor:
+        # Log with a floor to avoid log(0); inverse of decompress().
+        return torch.log(torch.clamp(x, min=1e-5))
+
+    def decompress(self, x: Tensor) -> Tensor:
+        return torch.exp(x)
+
+    def apply_mel_scale(self, x: Tensor) -> Tensor:
+        # (B, F, T) x (F, M) -> (B, M, T) via transpose-matmul-transpose.
+        return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+    def forward(
+        self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+    ) -> Tensor:
+        # NOTE(review): when return_linear=True this returns a
+        # (mel, compressed-linear) tuple despite the `-> Tensor` annotation.
+        if sample_rate is not None and sample_rate != self.sample_rate:
+            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
+        linear = self.spectrogram(x)
+        x = self.apply_mel_scale(linear)
+        x = self.compress(x)
+
+        if return_linear:
+            return x, self.compress(linear)
+
+        return x
diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a34bdcfedff76c333f50ed8be050d0dd5a8f98a
--- /dev/null
+++ b/fish_speech/utils/utils.py
@@ -0,0 +1,136 @@
+import random
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+
+from .logger import RankedLogger
+from .rich_utils import enforce_tags, print_config_tree
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+    """Applies optional utilities before the task is started.
+
+    Utilities:
+    - Ignoring python warnings
+    - Setting tags from command line
+    - Rich config printing
+
+    Args:
+        cfg: Full run configuration; only its ``extras`` subtree is read here.
+    """
+
+    # return if no `extras` config
+    if not cfg.get("extras"):
+        log.warning("Extras config not found! ")
+        return
+
+    # disable python warnings
+    if cfg.extras.get("ignore_warnings"):
+        log.info("Disabling python warnings! ")
+        warnings.filterwarnings("ignore")
+
+    # prompt user to input tags from command line if none are provided in the config
+    if cfg.extras.get("enforce_tags"):
+        log.info("Enforcing tags! ")
+        enforce_tags(cfg, save_to_file=True)
+
+    # pretty print config tree using Rich library
+    if cfg.extras.get("print_config"):
+        log.info("Printing config tree with Rich! ")
+        print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+    """Optional decorator that controls the failure behavior when executing the task function.
+
+    This wrapper can be used to:
+    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+    - save the exception to a `.log` file
+    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+    - etc. (adjust depending on your needs)
+
+    Example:
+    ```
+    @utils.task_wrapper
+    def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+        ...
+
+        return metric_dict, object_dict
+    ```
+    """  # noqa: E501
+
+    def wrap(cfg: DictConfig):
+        # execute the task
+        try:
+            metric_dict, object_dict = task_func(cfg=cfg)
+
+        # things to do if exception occurs
+        except Exception as ex:
+            # save exception to `.log` file (logger.exception records traceback)
+            log.exception("")
+
+            # some hyperparameter combinations might be invalid or
+            # cause out-of-memory errors so when using hparam search
+            # plugins like Optuna, you might want to disable
+            # raising the below exception to avoid multirun failure
+            raise ex
+
+        # things to always do after either success or exception
+        finally:
+            # display output dir path in terminal
+            # NOTE(review): uses cfg.paths.run_dir while the rich utils use
+            # cfg.paths.output_dir — confirm both keys exist in the paths config.
+            log.info(f"Output dir: {cfg.paths.run_dir}")
+
+            # always close wandb run (even if exception occurs so multirun won't fail)
+            if find_spec("wandb"):  # check if wandb is installed
+                import wandb
+
+                if wandb.run:
+                    log.info("Closing wandb!")
+                    wandb.finish()
+
+        return metric_dict, object_dict
+
+    return wrap
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+    """Safely retrieves value of the metric logged in LightningModule.
+
+    Args:
+        metric_dict: Mapping of metric names to tensor-like values with ``.item()``.
+        metric_name: Metric to fetch; a falsy value skips retrieval.
+
+    Returns:
+        The metric as a python float — or ``None`` when ``metric_name`` is falsy
+        (NOTE(review): annotation says ``float``; callers must tolerate ``None``).
+    """
+
+    if not metric_name:
+        log.info("Metric name is None! Skipping metric value retrieval...")
+        return None
+
+    if metric_name not in metric_dict:
+        # Implicit string concatenation; only the first literal is an f-string,
+        # which is fine since the rest have no placeholders.
+        raise Exception(
+            f"Metric value not found! \n"
+            "Make sure metric name logged in LightningModule is correct!\n"
+            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+        )
+
+    # .item() converts a 0-dim tensor to a plain python number.
+    metric_value = metric_dict[metric_name].item()
+    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+    return metric_value
+
+
+def set_seed(seed: int):
+    """Seed python, numpy and torch RNGs and make cuDNN deterministic.
+
+    Negative seeds are mirrored to positive; values above 2**31 are clamped,
+    presumably to stay within the 32-bit range numpy.random.seed accepts.
+    """
+    if seed < 0:
+        seed = -seed
+    if seed > (1 << 31):
+        seed = 1 << 31
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        # Seed every visible GPU, not just the current device.
+        torch.cuda.manual_seed_all(seed)
+
+    if torch.backends.cudnn.is_available():
+        # Reproducibility over speed: disable cuDNN autotuning.
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..3c7a22ecc31881a65a76369b0fd889330a0874c7
--- /dev/null
+++ b/fish_speech/webui/css/style.css
@@ -0,0 +1,161 @@
+:root {
+ --my-200: #80eeee;
+ --my-50: #ecfdf5;
+ --water-width: 300px;
+ --water-heigh: 300px;
+}
+
+
+/* general styled components */
+.tools {
+ align-items: center;
+ justify-content: center;
+}
+
+.gradio-button {
+ max-width: 2.2em;
+ min-width: 2.2em !important;
+ height: 2.4em;
+ align-self: end;
+ line-height: 1em;
+ border-radius: 0.5em;
+
+}
+
+.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
+ box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
+}
+
+/* replace original footer with ours */
+a{
+ font-weight: bold;
+ cursor: pointer;
+ color: #030C14 !important;
+}
+
+footer {
+ display: none !important;
+}
+
+#footer{
+ text-align: center;
+}
+
+#footer div{
+ display: inline-block;
+}
+
+#footer .versions{
+ font-size: 85%;
+ opacity: 0.85;
+}
+
+/*@keyframes moveBackground {*/
+/* 0% {*/
+/* background-position: 0 0;*/
+/* }*/
+/* 100% {*/
+/* background-position: -100px 100px;*/
+/* }*/
+/*}*/
+@keyframes moveJellyBackground {
+ 0% {
+ background-position: 0% 50%;
+ }
+ 50% {
+ background-position: 100% 50%;
+ }
+ 100% {
+ background-position: 0% 50%;
+ }
+}
+
+.gradio-container {
+ position: absolute;
+ z-index: 10;
+}
+
+
+.quan {
+ position: absolute;
+ bottom: 0;
+ width: var(--water-width);
+ height: var(--water-heigh);
+ border-radius: 0;
+ /*border: 3px solid rgb(246, 247, 248);*/
+ /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
+ z-index: 0;
+
+}
+
+.quan:last-child {
+ margin-right: 0;
+}
+
+.shui {
+ position: absolute;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgb(23, 106, 201);
+ border-radius: 0;
+ overflow: hidden;
+ z-index: 0;
+}
+
+.shui::after {
+
+ content: '';
+ position: absolute;
+ top: 20%;
+ left: 50%;
+ width: 150%;
+ height: 150%;
+ border-radius: 40%;
+ background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
+ animation: shi 5s linear infinite;
+}
+
+@keyframes shi {
+ 0% {
+ transform: translate(-50%, -65%) rotate(0deg);
+ }
+ 100% {
+ transform: translate(-50%, -65%) rotate(360deg);
+ }
+}
+
+.shui::before {
+ content: '';
+ position: absolute;
+ top: 20%;
+ left: 50%;
+ width: 150%;
+ height: 150%;
+ border-radius: 42%;
+ background-color: rgb(240, 228, 228, 0.2);
+ animation: xu 7s linear infinite;
+}
+
+@keyframes xu {
+ 0% {
+ transform: translate(-50%, -60%) rotate(0deg);
+ }
+ 100% {
+ transform: translate(-50%, -60%) rotate(360deg);
+ }
+}
+
+fieldset.data_src div.wrap label {
+ background: #f8bffee0 !important;
+}
+
+.scrollable-component {
+ max-height: 100px;
+ overflow-y: auto;
+}
+
+#file_accordion {
+ max-height: 220px !important;
+}
diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html
new file mode 100644
index 0000000000000000000000000000000000000000..ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615
--- /dev/null
+++ b/fish_speech/webui/html/footer.html
@@ -0,0 +1,11 @@
+
+
+
+{versions}
+
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js
new file mode 100644
index 0000000000000000000000000000000000000000..0637a541a8e704632a42b89bdf1471b26e7bb868
--- /dev/null
+++ b/fish_speech/webui/js/animate.js
@@ -0,0 +1,69 @@
+
+// Force the light theme and type out a letter-by-letter welcome banner at the
+// top of the gradio container. Returns a status string for gradio's js hook.
+function createGradioAnimation() {
+    // Reload once with ?__theme=light when no theme was requested.
+    const params = new URLSearchParams(window.location.search);
+    if (!params.has('__theme')) {
+        params.set('__theme', 'light');
+        window.location.search = params.toString();
+    }
+
+    var gradioApp = document.querySelector('gradio-app');
+    if (gradioApp) {
+
+        // Keep CSS custom properties in sync with style.css defaults.
+        document.documentElement.style.setProperty('--my-200', '#80eeee');
+        document.documentElement.style.setProperty('--my-50', '#ecfdf5');
+
+        // gradioApp.style.position = 'relative';
+        // gradioApp.style.backgroundSize = '200% 200%';
+        // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
+        // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
+        // gradioApp.style.display = 'flex';
+        // gradioApp.style.justifyContent = 'flex-start';
+        // gradioApp.style.flexWrap = 'nowrap';
+        // gradioApp.style.overflowX = 'auto';
+
+        // for (let i = 0; i < 6; i++) {
+        //     var quan = document.createElement('div');
+        //     quan.className = 'quan';
+        //     gradioApp.insertBefore(quan, gradioApp.firstChild);
+        //     quan.id = 'quan' + i.toString();
+        //     quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
+        //     var quanContainer = document.querySelector('.quan');
+        //     if (quanContainer) {
+        //         var shui = document.createElement('div');
+        //         shui.className = 'shui';
+        //         quanContainer.insertBefore(shui, quanContainer.firstChild)
+        //     }
+        // }
+    }
+
+    var container = document.createElement('div');
+    container.id = 'gradio-animation';
+    container.style.fontSize = '2em';
+    container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
+    container.style.fontWeight = 'bold';
+    container.style.textAlign = 'center';
+    container.style.marginBottom = '20px';
+
+    // Typewriter effect: one span per letter, 200ms apart, fading in over 0.5s.
+    var text = 'Welcome to Fish-Speech!';
+    for (var i = 0; i < text.length; i++) {
+        // IIFE captures `i` per iteration (var is function-scoped).
+        (function(i){
+            setTimeout(function(){
+                var letter = document.createElement('span');
+                letter.style.opacity = '0';
+                letter.style.transition = 'opacity 0.5s';
+                letter.innerText = text[i];
+
+                container.appendChild(letter);
+
+                setTimeout(function() {
+                    letter.style.opacity = '1';
+                }, 50);
+            }, i * 200);
+        })(i);
+    }
+
+    var gradioContainer = document.querySelector('.gradio-container');
+    gradioContainer.insertBefore(container, gradioContainer.firstChild);
+
+    return 'Animation created';
+}
diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..790c0e632ce55e099e5578d8824e94b1d1260d6e
--- /dev/null
+++ b/fish_speech/webui/launch_utils.py
@@ -0,0 +1,120 @@
+import importlib.util
+import os
+import subprocess
+import sys
+from functools import lru_cache
+from pathlib import Path
+from typing import Iterable
+
+import gradio as gr
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts, sizes
+
+GIT = (
+ (Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
+ if sys.platform == "win32"
+ else "git"
+)
+GIT = str(GIT)
+
+
+def is_module_installed(module_name: str) -> bool:
+    """Return True when `module_name` is importable, without importing it."""
+    spec = importlib.util.find_spec(module_name)
+    return spec is not None
+
+
+@lru_cache()
+def commit_hash():
+    """Return `git log -1 --format='%h %s'` output, or "" if git fails."""
+    try:
+        return subprocess.check_output(
+            [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
+        ).strip()
+    except Exception:
+        # Not a git checkout / git unavailable — treated as "no version info".
+        return ""
+
+
+def versions_html():
+    """Build the footer HTML snippet with git/python/torch/gradio versions."""
+    import torch
+
+    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+    commit = commit_hash()
+    # The --format string is quoted, so strip the literal quotes and keep
+    # only the short hash before the first space.
+    # NOTE(review): `hash` shadows the builtin; harmless here but rename-worthy.
+    hash = commit.strip("'").split(" ")[0]
+
+    return f"""
+version: {hash}
+ &#x2022;
+python: {python_version}
+ &#x2022;
+torch: {getattr(torch, '__long_version__',torch.__version__)}
+ &#x2022;
+gradio: {gr.__version__}
+ &#x2022;
+author: fishaudio
+"""
+
+
+def version_check(commit):
+    """Compare the local commit against the tip of fishaudio/fish-speech main.
+
+    Best-effort: any network/parse failure is reported and swallowed.
+    """
+    try:
+        import requests
+
+        commits = requests.get(
+            "https://api.github.com/repos/fishaudio/fish-speech/branches/main"
+        ).json()
+        if commit != "" and commits["commit"]["sha"] != commit:
+            print("--------------------------------------------------------")
+            print("| You are not up to date with the most recent release. |")
+            print("| Consider running `git pull` to update.               |")
+            print("--------------------------------------------------------")
+        elif commits["commit"]["sha"] == commit:
+            print("You are up to date with the most recent release.")
+        else:
+            # commit == "" means commit_hash() failed earlier (no git checkout).
+            print("Not a git clone, can't perform version check.")
+    except Exception as e:
+        print("version check failed", e)
+
+
+class Seafoam(Base):
+    """Custom gradio theme: emerald/blue palette with gradient primary buttons."""
+
+    def __init__(
+        self,
+        *,
+        primary_hue: colors.Color | str = colors.emerald,
+        secondary_hue: colors.Color | str = colors.blue,
+        neutral_hue: colors.Color | str = colors.blue,
+        spacing_size: sizes.Size | str = sizes.spacing_md,
+        radius_size: sizes.Size | str = sizes.radius_md,
+        text_size: sizes.Size | str = sizes.text_lg,
+        font: fonts.Font | str | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("Quicksand"),
+            "ui-sans-serif",
+            "sans-serif",
+        ),
+        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("IBM Plex Mono"),
+            "ui-monospace",
+            "monospace",
+        ),
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            secondary_hue=secondary_hue,
+            neutral_hue=neutral_hue,
+            spacing_size=spacing_size,
+            radius_size=radius_size,
+            text_size=text_size,
+            font=font,
+            font_mono=font_mono,
+        )
+        # Override individual theme variables on top of the Base defaults.
+        super().set(
+            button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
+            button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
+            button_primary_text_color="white",
+            button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
+            slider_color="*secondary_300",
+            slider_color_dark="*secondary_600",
+            block_title_text_weight="600",
+            block_border_width="3px",
+            block_shadow="*shadow_drop_lg",
+            # button_shadow="*shadow_drop_lg",
+            button_small_padding="0px",
+            button_large_padding="3px",
+        )
diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py
new file mode 100644
index 0000000000000000000000000000000000000000..c21233eee3e3e99754c68efc2b8809a62217eb53
--- /dev/null
+++ b/fish_speech/webui/manage.py
@@ -0,0 +1,1239 @@
+from __future__ import annotations
+
+import os
+
+os.environ["USE_LIBUV"] = "0"
+import datetime
+import html
+import json
+import platform
+import shutil
+import signal
+import subprocess
+import sys
+from pathlib import Path
+
+import gradio as gr
+import psutil
+import yaml
+from loguru import logger
+from tqdm import tqdm
+
+PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
+sys.path.insert(0, "")
+print(sys.path)
+cur_work_dir = Path(os.getcwd()).resolve()
+print("You are in ", str(cur_work_dir))
+
+from fish_speech.i18n import i18n
+from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
+
+config_path = cur_work_dir / "fish_speech" / "configs"
+vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
+llama_yml_path = config_path / "text2semantic_finetune.yaml"
+
+env = os.environ.copy()
+env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
+
+seafoam = Seafoam()
+
+
+def build_html_error_message(error):
+ return f"""
+
+ {html.escape(error)}
+
+ """
+
+
+def build_html_ok_message(msg):
+ return f"""
+
+ {html.escape(msg)}
+
+ """
+
+
+def build_html_href(link, desc, msg):
+ return f"""
+
+ {html.escape(msg)}
+ {desc}
+
+ """
+
+
+def load_data_in_raw(path):
+ with open(path, "r", encoding="utf-8") as file:
+ data = file.read()
+ return str(data)
+
+
+def kill_proc_tree(pid, including_parent=True):
+    """SIGTERM every descendant of `pid` (children first), then the parent.
+
+    Missing/already-dead processes are silently ignored.
+    """
+    try:
+        parent = psutil.Process(pid)
+    except psutil.NoSuchProcess:
+        # Process already terminated
+        return
+
+    children = parent.children(recursive=True)
+    for child in children:
+        try:
+            os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
+        except OSError:
+            # Child may have exited between enumeration and kill.
+            pass
+    if including_parent:
+        try:
+            os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
+        except OSError:
+            pass
+
+
+system = platform.system()
+p_label = None
+p_infer = None
+p_tensorboard = None
+
+
+def kill_process(pid):
+    """Terminate a process tree cross-platform.
+
+    Windows: `taskkill /t /f` kills the whole tree in one call.
+    Elsewhere: walk the tree via psutil (kill_proc_tree).
+    """
+    if system == "Windows":
+        cmd = "taskkill /t /f /pid %s" % pid
+        # os.system(cmd)
+        subprocess.run(cmd)
+    else:
+        kill_proc_tree(pid)
+
+
+def change_label(if_label):
+    """Start/stop the external asr-label transcription labeler.
+
+    Generator used as a gradio event handler: yields an HTML status message.
+    The subprocess handle is kept in the module-global `p_label`.
+    """
+    global p_label
+    if if_label == True and p_label is None:
+        # Local labeler serves on :3000; the remote URL is a hosted fallback.
+        url = "http://localhost:3000"
+        remote_url = "https://text-labeler.pages.dev/"
+        try:
+            p_label = subprocess.Popen(
+                [
+                    (
+                        "asr-label-linux-x64"
+                        if sys.platform == "linux"
+                        else "asr-label-win-x64.exe"
+                    )
+                ]
+            )
+        except FileNotFoundError:
+            logger.warning("asr-label execution not found!")
+
+        yield build_html_href(
+            link=remote_url,
+            desc=i18n("Optional online ver"),
+            msg=i18n("Opened labeler in browser"),
+        )
+
+    elif if_label == False and p_label is not None:
+        kill_process(p_label.pid)
+        p_label = None
+        yield build_html_ok_message("Nothing")
+
+
+def clean_infer_cache():
+    """Delete gradio's temp-dir audio cache; failures are logged, never raised."""
+    import tempfile
+
+    temp_dir = Path(tempfile.gettempdir())
+    gradio_dir = str(temp_dir / "gradio")
+    try:
+        shutil.rmtree(gradio_dir)
+        logger.info(f"Deleted cached audios: {gradio_dir}")
+    except PermissionError:
+        logger.info(f"Permission denied: Unable to delete {gradio_dir}")
+    except FileNotFoundError:
+        logger.info(f"{gradio_dir} was not found")
+    except Exception as e:
+        logger.info(f"An error occurred: {e}")
+
+
+def change_infer(
+    if_infer,
+    host,
+    port,
+    infer_decoder_model,
+    infer_decoder_config,
+    infer_llama_model,
+    infer_compile,
+):
+    """Start/stop the inference webui (tools/webui.py) as a subprocess.
+
+    Generator used as a gradio event handler: yields an HTML status message.
+    Host/port are passed to the child through GRADIO_SERVER_* env vars.
+    """
+    global p_infer
+    if if_infer == True and p_infer == None:
+        env = os.environ.copy()
+
+        env["GRADIO_SERVER_NAME"] = host
+        env["GRADIO_SERVER_PORT"] = port
+        # Launch the second (inference) process.
+        url = f"http://{host}:{port}"
+        yield build_html_ok_message(
+            i18n("Inferring interface is launched at {}").format(url)
+        )
+
+        # Drop stale gradio temp audio before starting a fresh session.
+        clean_infer_cache()
+
+        p_infer = subprocess.Popen(
+            [
+                PYTHON,
+                "tools/webui.py",
+                "--decoder-checkpoint-path",
+                infer_decoder_model,
+                "--decoder-config-name",
+                infer_decoder_config,
+                "--llama-checkpoint-path",
+                infer_llama_model,
+            ]
+            + (["--compile"] if infer_compile == "Yes" else []),
+            env=env,
+        )
+
+    elif if_infer == False and p_infer is not None:
+        kill_process(p_infer.pid)
+        p_infer = None
+        yield build_html_error_message(i18n("Infer interface is closed"))
+
+
+js = load_data_in_raw("fish_speech/webui/js/animate.js")
+css = load_data_in_raw("fish_speech/webui/css/style.css")
+
+data_pre_output = (cur_work_dir / "data").resolve()
+default_model_output = (cur_work_dir / "results").resolve()
+default_filelist = data_pre_output / "detect.list"
+data_pre_output.mkdir(parents=True, exist_ok=True)
+
+items = []
+dict_items = {}
+
+
+def load_yaml_data_in_fact(yml_path):
+ with open(yml_path, "r", encoding="utf-8") as file:
+ yml = yaml.safe_load(file)
+ return yml
+
+
+def write_yaml_data_in_fact(yml, yml_path):
+ with open(yml_path, "w", encoding="utf-8") as file:
+ yaml.safe_dump(yml, file, allow_unicode=True)
+ return yml
+
+
+def generate_tree(directory, depth=0, max_depth=None, prefix=""):
+    """Render a unicode directory tree (directories first, then files).
+
+    NOTE(review): entries are joined with a trailing space only — a line
+    separator (e.g. '\\n' or '<br />') looks lost; verify how gr.Markdown
+    renders this in new_explorer.
+    """
+    if max_depth is not None and depth > max_depth:
+        return ""
+
+    tree_str = ""
+    files = []
+    directories = []
+    for item in os.listdir(directory):
+        if os.path.isdir(os.path.join(directory, item)):
+            directories.append(item)
+        else:
+            files.append(item)
+
+    # Directories come first, so index < len(directories) identifies them below.
+    entries = directories + files
+    for i, entry in enumerate(entries):
+        connector = "├── " if i < len(entries) - 1 else "└── "
+        tree_str += f"{prefix}{connector}{entry} "
+        if i < len(directories):
+            # Continue the vertical guide only when siblings follow.
+            extension = "│   " if i < len(entries) - 1 else "    "
+            tree_str += generate_tree(
+                os.path.join(directory, entry),
+                depth + 1,
+                max_depth,
+                prefix=prefix + extension,
+            )
+    return tree_str
+
+
+def new_explorer(data_path, max_depth):
+ return gr.Markdown(
+ elem_classes=["scrollable-component"],
+ value=generate_tree(data_path, max_depth=max_depth),
+ )
+
+
+def add_item(
+ folder: str,
+ method: str,
+ label_lang: str,
+ if_initial_prompt: bool,
+ initial_prompt: str | None,
+):
+ folder = folder.strip(" ").strip('"')
+
+ folder_path = Path(folder)
+
+ if folder and folder not in items and data_pre_output not in folder_path.parents:
+ if folder_path.is_dir():
+ items.append(folder)
+ dict_items[folder] = dict(
+ type="folder",
+ method=method,
+ label_lang=label_lang,
+ initial_prompt=initial_prompt if if_initial_prompt else None,
+ )
+ elif folder:
+ err = folder
+ return gr.Checkboxgroup(choices=items), build_html_error_message(
+ i18n("Invalid path: {}").format(err)
+ )
+
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+ logger.info("After Adding: " + formatted_data)
+ gr.Info(formatted_data)
+ return gr.Checkboxgroup(choices=items), build_html_ok_message(
+ i18n("Added path successfully!")
+ )
+
+
+def remove_items(selected_items):
+ global items, dict_items
+ to_remove = [item for item in items if item in selected_items]
+ for item in to_remove:
+ del dict_items[item]
+ items = [item for item in items if item in dict_items.keys()]
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+ logger.info(formatted_data)
+ gr.Warning("After Removing: " + formatted_data)
+ return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
+ i18n("Removed path successfully!")
+ )
+
+
+def show_selected(options):
+ selected_options = ", ".join(options)
+
+ if options:
+ return i18n("Selected: {}").format(selected_options)
+ else:
+ return i18n("No selected options")
+
+
+from pydub import AudioSegment
+
+
+def convert_to_mono_in_place(audio_path: Path):
+    """Downmix a multi-channel audio file to mono, overwriting it in place."""
+    audio = AudioSegment.from_file(audio_path)
+    if audio.channels > 1:
+        mono_audio = audio.set_channels(1)
+        # suffix[1:] drops the dot: ".wav" -> "wav" (pydub format name).
+        mono_audio.export(audio_path, format=audio_path.suffix[1:])
+        logger.info(f"Convert {audio_path} successfully")
+
+
+def list_copy(list_file_path, method):
+    """Stage audio + .lab transcripts referenced by a `wav|speaker|lang|text` filelist.
+
+    Each referenced wav (and its sibling .lab) is copied or moved under
+    data_pre_output/<parent-dir-name>/. When moving, the filelist is rewritten
+    in place so it points at the new locations.
+    """
+    wav_root = data_pre_output
+    lst = []
+    with list_file_path.open("r", encoding="utf-8") as file:
+        for line in tqdm(file, desc="Processing audio/transcript"):
+            wav_path, speaker_name, language, text = line.strip().split("|")
+            original_wav_path = Path(wav_path)
+            target_wav_path = (
+                wav_root / original_wav_path.parent.name / original_wav_path.name
+            )
+            lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
+            # NOTE(review): if the wav already exists, the .lab handling below
+            # is skipped too — confirm that is intended.
+            if target_wav_path.is_file():
+                continue
+            target_wav_path.parent.mkdir(parents=True, exist_ok=True)
+            if method == i18n("Copy"):
+                shutil.copy(original_wav_path, target_wav_path)
+            else:
+                # Move into the target directory, keeping the original name.
+                shutil.move(original_wav_path, target_wav_path.parent)
+            convert_to_mono_in_place(target_wav_path)
+            original_lab_path = original_wav_path.with_suffix(".lab")
+            target_lab_path = (
+                wav_root
+                / original_wav_path.parent.name
+                / original_wav_path.with_suffix(".lab").name
+            )
+            if target_lab_path.is_file():
+                continue
+            if method == i18n("Copy"):
+                shutil.copy(original_lab_path, target_lab_path)
+            else:
+                shutil.move(original_lab_path, target_lab_path.parent)
+
+    # After a move the filelist must point at the staged copies.
+    if method == i18n("Move"):
+        with list_file_path.open("w", encoding="utf-8") as file:
+            file.writelines("\n".join(lst))
+
+    del lst
+    return build_html_ok_message(i18n("Use filelist"))
+
+
+def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
+    """Stage all registered input items into `data_path` and transcribe audio.
+
+    Folder items are copied/moved under data_path, downmixed to mono, then
+    transcribed via tools/whisper_asr.py unless their language is "IGNORE"
+    (presumably meaning transcripts already exist). File items are treated
+    as filelists and handed to list_copy. Returns (HTML status, file tree).
+    """
+    global dict_items
+    data_path = Path(data_path)
+    gr.Warning("Pre-processing begins...")
+    for item, content in dict_items.items():
+        item_path = Path(item)
+        tar_path = data_path / item_path.name
+
+        if content["type"] == "folder" and item_path.is_dir():
+            if content["method"] == i18n("Copy"):
+                os.makedirs(tar_path, exist_ok=True)
+                shutil.copytree(
+                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
+                )
+            elif not tar_path.is_dir():
+                shutil.move(src=str(item_path), dst=str(tar_path))
+
+            # ASR expects mono input; fix channels in place first.
+            for suf in ["wav", "flac", "mp3"]:
+                for audio_path in tar_path.glob(f"**/*.{suf}"):
+                    convert_to_mono_in_place(audio_path)
+
+            cur_lang = content["label_lang"]
+            initial_prompt = content["initial_prompt"]
+
+            transcribe_cmd = [
+                PYTHON,
+                "tools/whisper_asr.py",
+                "--model-size",
+                label_model,
+                "--device",
+                label_device,
+                "--audio-dir",
+                tar_path,
+                "--save-dir",
+                tar_path,
+                "--language",
+                cur_lang,
+            ]
+
+            if initial_prompt is not None:
+                transcribe_cmd += ["--initial-prompt", initial_prompt]
+
+            if cur_lang != "IGNORE":
+                try:
+                    gr.Warning("Begin To Transcribe")
+                    subprocess.run(
+                        transcribe_cmd,
+                        env=env,
+                    )
+                except Exception:
+                    print("Transcription error occurred")
+
+        elif content["type"] == "file" and item_path.is_file():
+            list_copy(item_path, content["method"])
+
+    return build_html_ok_message(i18n("Move files successfully")), new_explorer(
+        data_path, max_depth=max_depth
+    )
+
+
+def generate_folder_name():
+ now = datetime.datetime.now()
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
+ return folder_name
+
+
+def train_process(
+    data_path: str,
+    option: str,
+    # llama config
+    llama_ckpt,
+    llama_base_config,
+    llama_lr,
+    llama_maxsteps,
+    llama_data_num_workers,
+    llama_data_batch_size,
+    llama_data_max_length,
+    llama_precision,
+    llama_check_interval,
+    llama_grad_batches,
+    llama_use_speaker,
+    llama_use_lora,
+):
+    """Run the selected training stage and block until it completes.
+
+    For option == "LLAMA": extract VQ tokens, build the protobuf dataset,
+    then launch fish_speech/train.py with hydra overrides built from the
+    UI parameters. "VQGAN" is a no-op (training skipped).
+
+    NOTE(review): `data_path` and `llama_base_config` are accepted but never
+    used — the globals data_pre_output and the hard-coded config name are
+    used instead; confirm this is intended.
+    """
+
+    # Windows has no NCCL; fall back to gloo for the process group.
+    backend = "nccl" if sys.platform == "linux" else "gloo"
+
+    new_project = generate_folder_name()
+    print("New Project Name: ", new_project)
+
+    if option == "VQGAN":
+        msg = "Skipped VQGAN Training."
+        gr.Warning(msg)
+        logger.info(msg)
+
+    if option == "LLAMA":
+        msg = "LLAMA Training begins..."
+        gr.Warning(msg)
+        logger.info(msg)
+        # Step 1: quantize audio into VQ token files.
+        subprocess.run(
+            [
+                PYTHON,
+                "tools/vqgan/extract_vq.py",
+                str(data_pre_output),
+                "--num-workers",
+                "1",
+                "--batch-size",
+                "16",
+                "--config-name",
+                "firefly_gan_vq",
+                "--checkpoint-path",
+                "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+            ]
+        )
+
+        # Step 2: pack tokens + transcripts into the training dataset.
+        subprocess.run(
+            [
+                PYTHON,
+                "tools/llama/build_dataset.py",
+                "--input",
+                str(data_pre_output),
+                "--text-extension",
+                ".lab",
+                "--num-workers",
+                "16",
+            ]
+        )
+        ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
+        lora_prefix = "lora_" if llama_use_lora else ""
+        llama_name = lora_prefix + "text2semantic_" + new_project
+        # Most recent matching run dir under results/, or the new name if none.
+        latest = next(
+            iter(
+                sorted(
+                    [
+                        str(p.relative_to("results"))
+                        for p in Path("results").glob(lora_prefix + "text2sem*/")
+                    ],
+                    reverse=True,
+                )
+            ),
+            llama_name,
+        )
+        # "new" -> fresh project dir; "latest" -> resume newest; else explicit path.
+        project = (
+            llama_name
+            if llama_ckpt == i18n("new")
+            else (
+                latest
+                if llama_ckpt == i18n("latest")
+                else Path(llama_ckpt).relative_to("results")
+            )
+        )
+        logger.info(project)
+
+        # Validation interval cannot exceed total steps.
+        if llama_check_interval > llama_maxsteps:
+            llama_check_interval = llama_maxsteps
+
+        train_cmd = [
+            PYTHON,
+            "fish_speech/train.py",
+            "--config-name",
+            "text2semantic_finetune",
+            f"project={project}",
+            f"trainer.strategy.process_group_backend={backend}",
+            f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+            f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+            f"model.optimizer.lr={llama_lr}",
+            f"trainer.max_steps={llama_maxsteps}",
+            f"data.num_workers={llama_data_num_workers}",
+            f"data.batch_size={llama_data_batch_size}",
+            f"max_length={llama_data_max_length}",
+            f"trainer.precision={llama_precision}",
+            f"trainer.val_check_interval={llama_check_interval}",
+            f"trainer.accumulate_grad_batches={llama_grad_batches}",
+            f"train_dataset.interactive_prob={llama_use_speaker}",
+        ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
+        logger.info(train_cmd)
+        subprocess.run(train_cmd)
+
+    return build_html_ok_message(i18n("Training stopped"))
+
+
+def tensorboard_process(
+    if_tensorboard: bool,
+    tensorboard_dir: str,
+    host: str,
+    port: str,
+):
+    """Start/stop a tensorboard subprocess for the given log dir.
+
+    Generator used as a gradio event handler: yields an HTML status message.
+    The subprocess handle lives in the module-global `p_tensorboard`.
+    """
+    global p_tensorboard
+    if if_tensorboard == True and p_tensorboard == None:
+        url = f"http://{host}:{port}"
+        yield build_html_ok_message(
+            i18n("Tensorboard interface is launched at {}").format(url)
+        )
+        prefix = ["tensorboard"]
+        # Portable windows bundle ships its own python + tensorboard exe.
+        if Path("fishenv").exists():
+            prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
+
+        p_tensorboard = subprocess.Popen(
+            prefix
+            + [
+                "--logdir",
+                tensorboard_dir,
+                "--host",
+                host,
+                "--port",
+                port,
+                "--reload_interval",
+                "120",
+            ]
+        )
+    elif if_tensorboard == False and p_tensorboard != None:
+        kill_process(p_tensorboard.pid)
+        p_tensorboard = None
+        yield build_html_error_message(i18n("Tensorboard interface is closed"))
+
+
+def fresh_tb_dir():
+ return gr.Dropdown(
+ choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
+ )
+
+
+def list_decoder_models():
+ paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
+ if not paths:
+ logger.warning("No decoder model found")
+ return paths
+
+
+def list_llama_models():
+ choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
+ choices = sorted(choices, reverse=True)
+ if not choices:
+ logger.warning("No LLaMA model found")
+ return choices
+
+
+def list_lora_llama_models():
+ choices = sorted(
+ [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
+ )
+ if not choices:
+ logger.warning("No LoRA LLaMA model found")
+ return choices
+
+
+def fresh_decoder_model():
+ return gr.Dropdown(choices=list_decoder_models())
+
+
+def fresh_llama_ckpt(llama_use_lora):
+ return gr.Dropdown(
+ choices=[i18n("latest"), i18n("new")]
+ + (
+ [str(p) for p in Path("results").glob("text2sem*/")]
+ if not llama_use_lora
+ else [str(p) for p in Path("results").glob("lora_*/")]
+ )
+ )
+
+
+def fresh_llama_model():
+ return gr.Dropdown(choices=list_llama_models())
+
+
+def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
+    """Merge a LoRA checkpoint into a base LLaMA weight via tools/llama/merge_lora.py.
+
+    NOTE(review): `llama_weight` and `lora_llama_config` are validated but never
+    passed to merge_cmd — the script presumably uses its own defaults; confirm.
+    """
+    if (
+        lora_weight is None
+        or not Path(lora_weight).exists()
+        or not Path(llama_weight).exists()
+    ):
+        return build_html_error_message(
+            i18n(
+                "Path error, please check the model file exists in the corresponding path"
+            )
+        )
+    gr.Warning("Merging begins...")
+    merge_cmd = [
+        PYTHON,
+        "tools/llama/merge_lora.py",
+        "--lora-config",
+        "r_8_alpha_16",
+        "--lora-weight",
+        lora_weight,
+        "--output",
+        # Timestamp suffix keeps successive merges from overwriting each other.
+        llama_lora_output + "_" + generate_folder_name(),
+    ]
+    logger.info(merge_cmd)
+    subprocess.run(merge_cmd)
+    return build_html_ok_message(i18n("Merge successfully"))
+
+
def llama_quantify(llama_weight, quantify_mode):
    """Quantize a LLAMA checkpoint with tools/llama/quantize.py.

    Args:
        llama_weight: path of the checkpoint directory to quantize.
        quantify_mode: "int8" or "int4".

    Returns:
        An HTML ok/error message string for the WebUI status area.
    """
    if llama_weight is None or not Path(llama_weight).exists():
        return build_html_error_message(
            i18n(
                "Path error, please check the model file exists in the corresponding path"
            )
        )

    gr.Warning("Quantifying begins...")

    now = generate_folder_name()
    quantify_cmd = [
        PYTHON,
        "tools/llama/quantize.py",
        "--checkpoint-path",
        llama_weight,
        "--mode",
        quantify_mode,
        "--timestamp",
        now,
    ]
    logger.info(quantify_cmd)
    # Bug fix: the exit status was previously ignored, so a failing quantization
    # still reported success (with a path that was never created).
    result = subprocess.run(quantify_cmd)
    if result.returncode != 0:
        return build_html_error_message(i18n("Quantify failed"))

    # quantize.py writes to checkpoints/fs-1.2-<mode>[-g128]-<timestamp>;
    # only int4 uses the group-size (g128) suffix.
    if quantify_mode == "int8":
        suffix = f"fs-1.2-{quantify_mode}-{now}"
    else:
        suffix = f"fs-1.2-{quantify_mode}-g128-{now}"
    quantize_path = str(Path(os.getcwd()) / "checkpoints" / suffix)
    return build_html_ok_message(
        i18n("Quantify successfully") + f"Path: {quantize_path}"
    )
+
+
# Snapshot the shipped YAML configs so widgets below can show their defaults.
init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
init_llama_yml = load_yaml_data_in_fact(llama_yml_path)

# Top-level WebUI layout: left column = preprocessing / training / inference
# tabs, right column = data-source list, paths, file tree and action buttons.
with gr.Blocks(
    head="",
    js=js,
    theme=seafoam,
    analytics_enabled=False,
    title="Fish Speech",
) as demo:
    with gr.Row():
        with gr.Column():
            # --- Tab 1: data preprocessing ---
            with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
                with gr.Row():
                    textbox = gr.Textbox(
                        label="\U0000270F "
                        + i18n("Input Audio & Source Path for Transcription"),
                        info=i18n("Speaker is identified by the folder name"),
                        interactive=True,
                    )
                with gr.Row(equal_height=False):
                    with gr.Column():
                        output_radio = gr.Radio(
                            label="\U0001F4C1 "
                            + i18n("Select source file processing method"),
                            choices=[i18n("Copy"), i18n("Move")],
                            value=i18n("Copy"),
                            interactive=True,
                        )
                    with gr.Column():
                        error = gr.HTML(label=i18n("Error Message"))
                        if_label = gr.Checkbox(
                            label=i18n("Open Labeler WebUI"), scale=0, show_label=True
                        )

                with gr.Row():
                    label_device = gr.Dropdown(
                        label=i18n("Labeling Device"),
                        info=i18n(
                            "It is recommended to use CUDA, if you have low configuration, use CPU"
                        ),
                        choices=["cpu", "cuda"],
                        value="cuda",
                        interactive=True,
                    )
                    label_model = gr.Dropdown(
                        label=i18n("Whisper Model"),
                        info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
                        choices=["large-v3", "medium"],
                        value="large-v3",
                        interactive=True,
                    )
                    # NOTE: a Dropdown despite the "radio" name; kept for compatibility.
                    label_radio = gr.Dropdown(
                        label=i18n("Optional Label Language"),
                        info=i18n(
                            "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
                        ),
                        choices=[
                            (i18n("Chinese"), "zh"),
                            (i18n("English"), "en"),
                            (i18n("Japanese"), "ja"),
                            (i18n("Disabled"), "IGNORE"),
                            (i18n("auto"), "auto"),
                        ],
                        value="IGNORE",
                        interactive=True,
                    )

                with gr.Row():
                    if_initial_prompt = gr.Checkbox(
                        value=False,
                        label=i18n("Enable Initial Prompt"),
                        min_width=120,
                        scale=0,
                    )
                    initial_prompt = gr.Textbox(
                        label=i18n("Initial Prompt"),
                        info=i18n(
                            "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
                        ),
                        placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
                        interactive=False,
                    )

                with gr.Row():
                    add_button = gr.Button(
                        "\U000027A1 " + i18n("Add to Processing Area"),
                        variant="primary",
                    )
                    remove_button = gr.Button(
                        "\U000026D4 " + i18n("Remove Selected Data")
                    )

            # --- Tab 2: training configuration ---
            with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
                with gr.Row():
                    model_type_radio = gr.Radio(
                        label=i18n(
                            "Select the model to be trained (Depending on the Tab page you are on)"
                        ),
                        interactive=False,
                        choices=["VQGAN", "LLAMA"],
                        value="VQGAN",
                    )
                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
                            gr.HTML("You don't need to train this model!")

                        with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
                            with gr.Row(equal_height=False):
                                llama_use_lora = gr.Checkbox(
                                    label=i18n("Use LoRA"),
                                    info=i18n(
                                        "Use LoRA can save GPU memory, but may reduce the quality of the model"
                                    ),
                                    value=True,
                                    interactive=True,
                                )
                                llama_ckpt = gr.Dropdown(
                                    label=i18n("Select LLAMA ckpt"),
                                    choices=[i18n("latest"), i18n("new")]
                                    + [
                                        str(p)
                                        for p in Path("results").glob("text2sem*/")
                                    ]
                                    + [str(p) for p in Path("results").glob("lora*/")],
                                    value=i18n("latest"),
                                    interactive=True,
                                )
                            with gr.Row(equal_height=False):
                                llama_lr_slider = gr.Slider(
                                    label=i18n("Initial Learning Rate"),
                                    info=i18n(
                                        "lr smaller -> usually train slower but more stable"
                                    ),
                                    interactive=True,
                                    minimum=1e-5,
                                    maximum=1e-4,
                                    step=1e-5,
                                    value=5e-5,
                                )
                                llama_maxsteps_slider = gr.Slider(
                                    label=i18n("Maximum Training Steps"),
                                    info=i18n(
                                        "recommend: max_steps = num_audios // batch_size * (2 to 5)"
                                    ),
                                    interactive=True,
                                    minimum=1,
                                    maximum=10000,
                                    step=1,
                                    value=50,
                                )
                            with gr.Row(equal_height=False):
                                llama_base_config = gr.Dropdown(
                                    label=i18n("Model Size"),
                                    choices=[
                                        "text2semantic_finetune",
                                    ],
                                    value="text2semantic_finetune",
                                )
                                llama_data_num_workers_slider = gr.Slider(
                                    label=i18n("Number of Workers"),
                                    minimum=1,
                                    maximum=32,
                                    step=1,
                                    value=4,
                                )
                            with gr.Row(equal_height=False):
                                llama_data_batch_size_slider = gr.Slider(
                                    label=i18n("Batch Size"),
                                    interactive=True,
                                    minimum=1,
                                    maximum=32,
                                    step=1,
                                    value=2,
                                )
                                llama_data_max_length_slider = gr.Slider(
                                    label=i18n("Maximum Length per Sample"),
                                    interactive=True,
                                    minimum=1024,
                                    maximum=4096,
                                    step=128,
                                    value=2048,
                                )
                            with gr.Row(equal_height=False):
                                llama_precision_dropdown = gr.Dropdown(
                                    label=i18n("Precision"),
                                    info=i18n(
                                        "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
                                    ),
                                    interactive=True,
                                    choices=["32", "bf16-true", "16-mixed"],
                                    value="bf16-true",
                                )
                                llama_check_interval_slider = gr.Slider(
                                    label=i18n("Save model every n steps"),
                                    info=i18n(
                                        "make sure that it's not greater than max_steps"
                                    ),
                                    interactive=True,
                                    minimum=1,
                                    maximum=1000,
                                    step=1,
                                    value=50,
                                )
                            with gr.Row(equal_height=False):
                                # Defaults below come from the shipped llama YAML config.
                                llama_grad_batches = gr.Slider(
                                    label=i18n("Accumulate Gradient Batches"),
                                    interactive=True,
                                    minimum=1,
                                    maximum=20,
                                    step=1,
                                    value=init_llama_yml["trainer"][
                                        "accumulate_grad_batches"
                                    ],
                                )
                                llama_use_speaker = gr.Slider(
                                    label=i18n(
                                        "Probability of applying Speaker Condition"
                                    ),
                                    interactive=True,
                                    minimum=0.1,
                                    maximum=1.0,
                                    step=0.05,
                                    value=init_llama_yml["train_dataset"][
                                        "interactive_prob"
                                    ],
                                )

                        with gr.Tab(label=i18n("Merge LoRA"), id=4):
                            with gr.Row(equal_height=False):
                                llama_weight = gr.Dropdown(
                                    label=i18n("Base LLAMA Model"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    choices=[
                                        "checkpoints/fish-speech-1.4/model.pth",
                                    ],
                                    value="checkpoints/fish-speech-1.4/model.pth",
                                    allow_custom_value=True,
                                    interactive=True,
                                )
                            with gr.Row(equal_height=False):
                                lora_weight = gr.Dropdown(
                                    label=i18n("LoRA Model to be merged"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    choices=[
                                        str(p)
                                        for p in Path("results").glob("lora*/**/*.ckpt")
                                    ],
                                    allow_custom_value=True,
                                    interactive=True,
                                )
                                lora_llama_config = gr.Dropdown(
                                    label=i18n("LLAMA Model Config"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    choices=[
                                        "text2semantic_finetune",
                                    ],
                                    value="text2semantic_finetune",
                                    allow_custom_value=True,
                                )
                            with gr.Row(equal_height=False):
                                llama_lora_output = gr.Dropdown(
                                    label=i18n("Output Path"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    value="checkpoints/merged",
                                    choices=["checkpoints/merged"],
                                    allow_custom_value=True,
                                    interactive=True,
                                )
                            with gr.Row(equal_height=False):
                                llama_lora_merge_btn = gr.Button(
                                    value=i18n("Merge"), variant="primary"
                                )

                        with gr.Tab(label=i18n("Model Quantization"), id=5):
                            with gr.Row(equal_height=False):
                                llama_weight_to_quantify = gr.Dropdown(
                                    label=i18n("Base LLAMA Model"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    choices=list_llama_models(),
                                    value="checkpoints/fish-speech-1.4",
                                    allow_custom_value=True,
                                    interactive=True,
                                )
                                quantify_mode = gr.Dropdown(
                                    label=i18n("Post-quantification Precision"),
                                    info=i18n(
                                        "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
                                    ),
                                    choices=["int8", "int4"],
                                    value="int8",
                                    allow_custom_value=False,
                                    interactive=True,
                                )
                            with gr.Row(equal_height=False):
                                llama_quantify_btn = gr.Button(
                                    value=i18n("Quantify"), variant="primary"
                                )

                        with gr.Tab(label="Tensorboard", id=6):
                            with gr.Row(equal_height=False):
                                tb_host = gr.Textbox(
                                    label=i18n("Tensorboard Host"), value="127.0.0.1"
                                )
                                # Port is kept as a string; it is passed straight to the CLI.
                                tb_port = gr.Textbox(
                                    label=i18n("Tensorboard Port"), value="11451"
                                )
                            with gr.Row(equal_height=False):
                                tb_dir = gr.Dropdown(
                                    label=i18n("Tensorboard Log Path"),
                                    allow_custom_value=True,
                                    choices=[
                                        str(p)
                                        for p in Path("results").glob("**/tensorboard/")
                                    ],
                                )
                            with gr.Row(equal_height=False):
                                if_tb = gr.Checkbox(
                                    label=i18n("Open Tensorboard"),
                                )

            # --- Tab 3: inference server configuration ---
            with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
                with gr.Column():
                    with gr.Row():
                        with gr.Accordion(
                            label="\U0001F5A5 "
                            + i18n("Inference Server Configuration"),
                            open=False,
                        ):
                            with gr.Row():
                                infer_host_textbox = gr.Textbox(
                                    label=i18n("WebUI Host"), value="127.0.0.1"
                                )
                                infer_port_textbox = gr.Textbox(
                                    label=i18n("WebUI Port"), value="7862"
                                )
                            with gr.Row():
                                infer_decoder_model = gr.Dropdown(
                                    label=i18n("Decoder Model Path"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    choices=list_decoder_models(),
                                    value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
                                    allow_custom_value=True,
                                )
                                infer_decoder_config = gr.Dropdown(
                                    label=i18n("Decoder Model Config"),
                                    info=i18n("Changing with the Model Path"),
                                    value="firefly_gan_vq",
                                    choices=[
                                        "firefly_gan_vq",
                                    ],
                                    allow_custom_value=True,
                                )
                            with gr.Row():
                                infer_llama_model = gr.Dropdown(
                                    label=i18n("LLAMA Model Path"),
                                    info=i18n(
                                        "Type the path or select from the dropdown"
                                    ),
                                    value="checkpoints/fish-speech-1.4",
                                    choices=list_llama_models(),
                                    allow_custom_value=True,
                                )

                            with gr.Row():
                                # Compile only toggleable when triton is importable.
                                infer_compile = gr.Radio(
                                    label=i18n("Compile Model"),
                                    info=i18n(
                                        "Compile the model can significantly reduce the inference time, but will increase cold start time"
                                    ),
                                    choices=["Yes", "No"],
                                    value=(
                                        "Yes" if (sys.platform == "linux") else "No"
                                    ),
                                    interactive=is_module_installed("triton"),
                                )

                    with gr.Row():
                        infer_checkbox = gr.Checkbox(
                            label=i18n("Open Inference Server")
                        )
                        infer_error = gr.HTML(label=i18n("Inference Server Error"))

        # --- Right column: data sources, paths, file tree and actions ---
        with gr.Column():
            train_error = gr.HTML(label=i18n("Training Error"))
            checkbox_group = gr.CheckboxGroup(
                label="\U0001F4CA " + i18n("Data Source"),
                info=i18n(
                    "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
                ),
                elem_classes=["data_src"],
            )
            train_box = gr.Textbox(
                label=i18n("Data Preprocessing Path"),
                value=str(data_pre_output),
                interactive=False,
            )
            model_box = gr.Textbox(
                label="\U0001F4BE " + i18n("Model Output Path"),
                value=str(default_model_output),
                interactive=False,
            )

            with gr.Accordion(
                i18n(
                    "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
                ),
                elem_classes=["scrollable-component"],
                elem_id="file_accordion",
            ):
                tree_slider = gr.Slider(
                    minimum=0,
                    maximum=3,
                    value=0,
                    step=1,
                    show_label=False,
                    container=False,
                )
                file_markdown = new_explorer(str(data_pre_output), 0)
            with gr.Row(equal_height=False):
                admit_btn = gr.Button(
                    "\U00002705 " + i18n("File Preprocessing"),
                    variant="primary",
                )
                fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
                help_button = gr.Button("\U00002753", scale=0, min_width=80)  # question
                train_btn = gr.Button(i18n("Start Training"), variant="primary")

    footer = load_data_in_raw("fish_speech/webui/html/footer.html")
    footer = footer.format(
        versions=versions_html(),
        api_docs="https://speech.fish.audio/inference/#http-api",
    )
    gr.HTML(footer, elem_id="footer")

    # --- Event wiring ---
    # Switching tab pages drives which model type the training button targets.
    vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
    llama_page.select(lambda: "LLAMA", None, model_type_radio)
    add_button.click(
        fn=add_item,
        inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
        outputs=[checkbox_group, error],
    )
    remove_button.click(
        fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
    )
    checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
    help_button.click(
        fn=None,
        js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
        'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
    )
    if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
    if_initial_prompt.change(
        fn=lambda x: gr.Textbox(value="", interactive=x),
        inputs=[if_initial_prompt],
        outputs=[initial_prompt],
    )
    train_btn.click(
        fn=train_process,
        inputs=[
            train_box,
            model_type_radio,
            # llama config
            llama_ckpt,
            llama_base_config,
            llama_lr_slider,
            llama_maxsteps_slider,
            llama_data_num_workers_slider,
            llama_data_batch_size_slider,
            llama_data_max_length_slider,
            llama_precision_dropdown,
            llama_check_interval_slider,
            llama_grad_batches,
            llama_use_speaker,
            llama_use_lora,
        ],
        outputs=[train_error],
    )
    if_tb.change(
        fn=tensorboard_process,
        inputs=[if_tb, tb_dir, tb_host, tb_port],
        outputs=[train_error],
    )
    # Dropdowns refresh their choices from disk whenever they change.
    tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
    infer_decoder_model.change(
        fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
    )
    infer_llama_model.change(
        fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
    )
    llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
    admit_btn.click(
        fn=check_files,
        inputs=[train_box, tree_slider, label_model, label_device],
        outputs=[error, file_markdown],
    )
    fresh_btn.click(
        fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
    )
    llama_use_lora.change(
        fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
    )
    llama_ckpt.change(
        fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
    )
    lora_weight.change(
        fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
        inputs=[],
        outputs=[lora_weight],
    )
    llama_lora_merge_btn.click(
        fn=llama_lora_merge,
        inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
        outputs=[train_error],
    )
    llama_quantify_btn.click(
        fn=llama_quantify,
        inputs=[llama_weight_to_quantify, quantify_mode],
        outputs=[train_error],
    )
    infer_checkbox.change(
        fn=change_infer,
        inputs=[
            infer_checkbox,
            infer_host_textbox,
            infer_port_textbox,
            infer_decoder_model,
            infer_decoder_config,
            infer_llama_model,
            infer_compile,
        ],
        outputs=[infer_error],
    )

demo.launch(inbrowser=True)
diff --git a/inference.ipynb b/inference.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..e690a80d4f751342569ffe78c43b0b5c327c7f7a
--- /dev/null
+++ b/inference.ipynb
@@ -0,0 +1,214 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fish Speech"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### For Windows User / win用户"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "bat"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!chcp 65001"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### For Linux User / Linux 用户"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import locale\n",
+ "locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Prepare Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For Chinese users, you probably want to use mirror to accelerate downloading\n",
+ "# !set HF_ENDPOINT=https://hf-mirror.com\n",
+ "# !export HF_ENDPOINT=https://hf-mirror.com \n",
+ "\n",
+ "!huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4/"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## WebUI Inference\n",
+ "\n",
+ "> You can use --compile to fuse CUDA kernels for faster inference (10x)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!python tools/webui.py \\\n",
+ " --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n",
+ " --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
+ " # --compile"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Break-down CLI Inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. Encode reference audio: / 从语音生成 prompt: \n",
+ "\n",
+ "You should get a `fake.npy` file.\n",
+ "\n",
+ "你应该能得到一个 `fake.npy` 文件."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "## Enter the path to the audio file here\n",
+ "src_audio = r\"D:\\PythonProject\\vo_hutao_draw_appear.wav\"\n",
+ "\n",
+ "!python tools/vqgan/inference.py \\\n",
+ " -i {src_audio} \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
+ "\n",
+ "from IPython.display import Audio, display\n",
+ "audio = Audio(filename=\"fake.wav\")\n",
+ "display(audio)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. Generate semantic tokens from text: / 从文本生成语义 token:\n",
+ "\n",
+ "> This command will create a codes_N file in the working directory, where N is an integer starting from 0.\n",
+ "\n",
+ "> You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~300 tokens/second).\n",
+ "\n",
+ "> 该命令会在工作目录下创建 codes_N 文件, 其中 N 是从 0 开始的整数.\n",
+ "\n",
+ "> 您可以使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 tokens/秒 -> ~300 tokens/秒)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!python tools/llama/generate.py \\\n",
+ " --text \"hello world\" \\\n",
+ " --prompt-text \"The text corresponding to reference audio\" \\\n",
+ " --prompt-tokens \"fake.npy\" \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.4\" \\\n",
+ " --num-samples 2\n",
+ " # --compile"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. Generate speech from semantic tokens: / 从语义 token 生成人声:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!python tools/vqgan/inference.py \\\n",
+ " -i \"codes_0.npy\" \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth\"\n",
+ "\n",
+ "from IPython.display import Audio, display\n",
+ "audio = Audio(filename=\"fake.wav\")\n",
+ "display(audio)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/install_env.bat b/install_env.bat
new file mode 100644
index 0000000000000000000000000000000000000000..744ddb42c70a84d342fc3a2e357d8d36f7d4c998
--- /dev/null
+++ b/install_env.bat
@@ -0,0 +1,180 @@
@echo off
chcp 65001

REM Toggle to pull Miniconda and HF assets from CN mirrors.
set USE_MIRROR=true
echo "USE_MIRROR: %USE_MIRROR%"
setlocal enabledelayedexpansion

cd /D "%~dp0"

set PATH="%PATH%";%SystemRoot%\system32

echo %PATH%

REM Abort when the install path contains special or CJK characters.
echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
    echo.
    echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && (
        goto end
    )
)

REM Keep temp files inside the project-local env folder.
set TMP=%CD%\fishenv
set TEMP=%CD%\fishenv

REM Drop out of any already-active conda environments (up to 3 levels deep).
(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul

set INSTALL_DIR=%cd%\fishenv
set CONDA_ROOT_PREFIX=%cd%\fishenv\conda
set INSTALL_ENV_DIR=%cd%\fishenv\env
set PIP_CMD=%cd%\fishenv\env\python -m pip
set PYTHON_CMD=%cd%\fishenv\env\python
set API_FLAG_PATH=%~dp0API_FLAGS.txt
set MINICONDA_DOWNLOAD_URL=https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe
if "!USE_MIRROR!" == "true" (
    set MINICONDA_DOWNLOAD_URL=https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe
)
set MINICONDA_CHECKSUM=307194e1f12bbeb52b083634e89cc67db4f7980bd542254b43d3309eaf7cb358
set conda_exists=F

REM Detect an existing project-local Miniconda.
call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1
if "%ERRORLEVEL%" EQU "0" set conda_exists=T

if "%conda_exists%" == "F" (
    echo.
    echo Downloading Miniconda...
    mkdir "%INSTALL_DIR%" 2>nul
    call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe"
    if errorlevel 1 (
        echo.
        echo Failed to download miniconda.
        goto end
    )
    REM Verify the installer's SHA256 against the pinned checksum.
    for /f %%a in ('
        certutil -hashfile "%INSTALL_DIR%\miniconda_installer.exe" sha256
        ^| find /i /v " "
        ^| find /i "%MINICONDA_CHECKSUM%"
    ') do (
        set "hash=%%a"
    )
    if not defined hash (
        echo.
        echo Miniconda hash mismatched!
        del "%INSTALL_DIR%\miniconda_installer.exe"
        goto end
    ) else (
        echo.
        echo Miniconda hash matched successfully.
    )
    echo Downloaded "%CONDA_ROOT_PREFIX%"
    REM Silent, user-local install that leaves PATH and the registry untouched.
    start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX%

    call "%CONDA_ROOT_PREFIX%\_conda.exe" --version
    if errorlevel 1 (
        echo.
        echo Cannot install Miniconda.
        goto end
    ) else (
        echo.
        echo Miniconda Install success.
    )

    del "%INSTALL_DIR%\miniconda_installer.exe"
)

REM Create the python 3.10 env on first run.
if not exist "%INSTALL_ENV_DIR%" (
    echo.
    echo Creating Conda Environment...
    if "!USE_MIRROR!" == "true" (
        call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ python=3.10
    ) else (
        call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10
    )

    if errorlevel 1 (
        echo.
        echo Failed to Create Environment.
        goto end
    )
)

if not exist "%INSTALL_ENV_DIR%\python.exe" (
    echo.
    echo Conda Env does not exist.
    goto end
)

REM Isolate the env from user site-packages and stale python variables.
set PYTHONNOUSERSITE=1
set PYTHONPATH=
set PYTHONHOME=
set "CUDA_PATH=%INSTALL_ENV_DIR%"
set "CUDA_HOME=%CUDA_PATH%"

call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%"

if errorlevel 1 (
    echo.
    echo Failed to activate Env.
    goto end
) else (
    echo.
    echo successfully create env.
)

set "HF_ENDPOINT=https://huggingface.co"
set "no_proxy="
if "%USE_MIRROR%"=="true" (
    set "HF_ENDPOINT=https://hf-mirror.com"
    set "no_proxy=localhost,127.0.0.1,0.0.0.0"
)

echo "HF_ENDPOINT: !HF_ENDPOINT!"
echo "NO_PROXY: !no_proxy!"

%PIP_CMD% install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121

%PIP_CMD% install -e . --upgrade-strategy only-if-needed

REM Windows triton wheel, fetched and checksum-verified by the subroutine below.
call :download_and_install "triton_windows-0.1.0-py3-none-any.whl" ^
    "%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/triton_windows-0.1.0-py3-none-any.whl?download=true" ^
    "2cc998638180f37cf5025ab65e48c7f629aa5a369176cfa32177d2bd9aa26a0a"


endlocal
echo "Environment Check: Success."
:end
pause

goto :EOF

+
:download_and_install
setlocal

REM %1 = wheel file name, %2 = download URL, %3 = expected SHA256 digest.
set "WHEEL_FILE=%1"
set "URL=%2"
set "CHKSUM=%3"

:DOWNLOAD
if not exist "%WHEEL_FILE%" (
    call curl -Lk "%URL%" --output "%WHEEL_FILE%"
)

REM Verify the SHA256 digest. Bug fix: the original used a double-quoted
REM for /f set, which treats the command line as a literal string instead of
REM executing it, so FILE_VALID was always defined and the checksum check was
REM a no-op. Single quotes make for /f actually run certutil | find.
set "FILE_VALID="
for /f "delims=" %%I in ('certutil -hashfile "%WHEEL_FILE%" SHA256 ^| find /i "%CHKSUM%"') do (
    set "FILE_VALID=true"
)

if not defined FILE_VALID (
    echo File checksum does not match, re-downloading...
    del "%WHEEL_FILE%"
    goto DOWNLOAD
)

echo "OK for %WHEEL_FILE%"
%PIP_CMD% install "%WHEEL_FILE%" --no-warn-script-location
del "%WHEEL_FILE%"

endlocal
goto :EOF
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ae2099fe85ba34e2848ad6a1d1bea498155908b6
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,144 @@
+site_name: Fish Speech
+site_description: Targeting SOTA TTS solutions.
+site_url: https://speech.fish.audio
+
+# Repository
+repo_name: fishaudio/fish-speech
+repo_url: https://github.com/fishaudio/fish-speech
+edit_uri: blob/main/docs
+
+# Copyright
+copyright: Copyright © 2023-2024 by Fish Audio
+
+theme:
+ name: material
+ favicon: assets/figs/logo-circle.png
+ language: en
+ features:
+ - content.action.edit
+ - content.action.view
+ - navigation.tracking
+ - navigation.footer
+ # - navigation.tabs
+ - search
+ - search.suggest
+ - search.highlight
+ - search.share
+ - content.code.copy
+ icon:
+ logo: fontawesome/solid/fish
+
+ palette:
+ # Palette toggle for automatic mode
+ - media: "(prefers-color-scheme)"
+ toggle:
+ icon: material/brightness-auto
+ name: Switch to light mode
+
+ # Palette toggle for light mode
+ - media: "(prefers-color-scheme: light)"
+ scheme: default
+ toggle:
+ icon: material/brightness-7
+ name: Switch to dark mode
+ primary: black
+ font:
+ code: Roboto Mono
+
+ # Palette toggle for dark mode
+ - media: "(prefers-color-scheme: dark)"
+ scheme: slate
+ toggle:
+ icon: material/brightness-4
+ name: Switch to light mode
+ primary: black
+ font:
+ code: Roboto Mono
+
+nav:
+ - Introduction: index.md
+ - Finetune: finetune.md
+ - Inference: inference.md
+ - Start Agent: start_agent.md
+ - Samples: samples.md
+
+# Plugins
+plugins:
+ - search:
+ separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])'
+ lang:
+ - en
+ - zh
+ - ja
+ - pt
+ - ko
+ - i18n:
+ docs_structure: folder
+ languages:
+ - locale: en
+ name: English
+ default: true
+ build: true
+ - locale: zh
+ name: 简体中文
+ build: true
+ nav:
+ - 介绍: zh/index.md
+ - 微调: zh/finetune.md
+ - 推理: zh/inference.md
+ - 启动Agent: zh/start_agent.md
+ - 例子: zh/samples.md
+ - locale: ja
+ name: 日本語
+ build: true
+ nav:
+ - Fish Speech の紹介: ja/index.md
+ - 微調整: ja/finetune.md
+ - 推論: ja/inference.md
+ - スタートエージェント: ja/start_agent.md
+ - サンプル: ja/samples.md
+ - locale: pt
+ name: Português (Brasil)
+ build: true
+ nav:
+ - Introdução: pt/index.md
+ - Ajuste Fino: pt/finetune.md
+ - Inferência: pt/inference.md
+ - Agente inicial: pt/start_agent.md
+ - Amostras: pt/samples.md
+ - locale: ko
+ name: 한국어
+ build: true
+ nav:
+ - 소개: ko/index.md
+ - 파인튜닝: ko/finetune.md
+ - 추론: ko/inference.md
+ - 샘플: ko/samples.md
+
markdown_extensions:
  - pymdownx.highlight:
      anchor_linenums: true
      line_spans: __span
      pygments_lang_class: true
  - pymdownx.inlinehilite
  - pymdownx.snippets
  - pymdownx.superfences
  - admonition
  - pymdownx.details
  - attr_list
  - md_in_html
+
+extra_css:
+ - stylesheets/extra.css
+
+extra:
+ social:
+ - icon: fontawesome/brands/discord
+ link: https://discord.gg/Es5qTB9BcN
+ - icon: fontawesome/brands/docker
+ link: https://hub.docker.com/r/fishaudio/fish-speech
+ - icon: fontawesome/brands/qq
+ link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093
+ homepage: https://speech.fish.audio
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aaf59da0e781a7ca2fa58094096ebf819698bc72
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1 @@
+portaudio19-dev
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..225b34d9d704129b212253cf54dc21a257b55edf
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,60 @@
+[project]
+name = "fish-speech"
+version = "0.1.0"
+authors = [
+ {name = "Lengyue", email = "lengyue@lengyue.me"},
+]
+description = "Fish Speech"
+readme = "README.md"
+requires-python = ">=3.10"
+keywords = ["TTS", "Speech"]
+license = {text = "CC BY-NC-SA 4.0"}
+classifiers = [
+ "Programming Language :: Python :: 3",
+]
+dependencies = [
+ "numpy<=1.26.4",
+ "transformers>=4.45.2",
+ "datasets==2.18.0",
+ "lightning>=2.1.0",
+ "hydra-core>=1.3.2",
+ "tensorboard>=2.14.1",
+ "natsort>=8.4.0",
+ "einops>=0.7.0",
+ "librosa>=0.10.1",
+ "rich>=13.5.3",
+ "gradio>5.0.0",
+ "wandb>=0.15.11",
+ "grpcio>=1.58.0",
+ "kui>=1.6.0",
+ "uvicorn>=0.30.0",
+ "loguru>=0.6.0",
+ "loralib>=0.1.2",
+ "natsort>=8.4.0",
+ "pyrootutils>=1.0.4",
+ "vector_quantize_pytorch==1.14.24",
+ "resampy>=0.4.3",
+ "einx[torch]==0.2.2",
+ "zstandard>=0.22.0",
+ "pydub",
+ "pyaudio",
+ "faster_whisper",
+ "modelscope==1.17.1",
+ "funasr==1.1.5",
+ "opencc-python-reimplemented==0.1.7",
+ "silero-vad",
+ "ormsgpack",
+]
+
+[project.optional-dependencies]
+stable = [
+ "torch<=2.4.1",
+ "torchaudio",
+]
+
+[build-system]
+requires = ["setuptools", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["fish_speech", "tools"]
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000000000000000000000000000000000000..ad1493530f7f6d8fa476dbe0b76e6239fce2d7e7
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,6 @@
+{
+ "exclude": [
+ "data",
+ "filelists"
+ ]
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca5152bff0451c3561f6046bc29d97cb581b3c32
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,31 @@
+"numpy<=1.26.4",
+ "transformers>=4.45.2",
+ "datasets==2.18.0",
+ "lightning>=2.1.0",
+ "hydra-core>=1.3.2",
+ "tensorboard>=2.14.1",
+ "natsort>=8.4.0",
+ "einops>=0.7.0",
+ "librosa>=0.10.1",
+ "rich>=13.5.3",
+ "gradio>5.0.0",
+ "wandb>=0.15.11",
+ "grpcio>=1.58.0",
+ "kui>=1.6.0",
+ "uvicorn>=0.30.0",
+ "loguru>=0.6.0",
+ "loralib>=0.1.2",
+ "natsort>=8.4.0",
+ "pyrootutils>=1.0.4",
+ "vector_quantize_pytorch==1.14.24",
+ "resampy>=0.4.3",
+ "einx[torch]==0.2.2",
+ "zstandard>=0.22.0",
+ "pydub",
+ "pyaudio",
+ "faster_whisper",
+ "modelscope==1.17.1",
+ "funasr==1.1.5",
+ "opencc-python-reimplemented==0.1.7",
+ "silero-vad",
+ "ormsgpack",
\ No newline at end of file
diff --git a/run_cmd.bat b/run_cmd.bat
new file mode 100644
index 0000000000000000000000000000000000000000..c2af8a9b6fb75df7b7c81ff5986286845e247fb9
--- /dev/null
+++ b/run_cmd.bat
@@ -0,0 +1,50 @@
@echo off
chcp 65001

REM Bypass any system proxy for local addresses.
set no_proxy="127.0.0.1, 0.0.0.0, localhost"
setlocal enabledelayedexpansion

cd /D "%~dp0"

set PATH="%PATH%";%SystemRoot%\system32


REM Abort when the project path contains special or CJK characters.
echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
    echo.
    echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && (
        goto end
    )
)


REM Keep temp files inside the project-local env folder.
set TMP=%CD%\fishenv
set TEMP=%CD%\fishenv


REM Drop out of any already-active conda environments (up to 3 levels deep).
(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul


set CONDA_ROOT_PREFIX=%cd%\fishenv\conda
set INSTALL_ENV_DIR=%cd%\fishenv\env


REM Isolate the env from user site-packages; put the repo root on PYTHONPATH.
set PYTHONNOUSERSITE=1
set PYTHONPATH=%~dp0
set PYTHONHOME=


call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%"

if errorlevel 1 (
    echo.
    echo Environment activation failed.
    goto end
) else (
    echo.
    echo Environment activation succeeded.
)

REM Run the requested command inside the activated env and keep the shell open.
cmd /k "%*"

:end
pause
diff --git a/start.bat b/start.bat
new file mode 100644
index 0000000000000000000000000000000000000000..40c7f4d3bfd340753079add8e8e6b5db7abc3fcc
--- /dev/null
+++ b/start.bat
@@ -0,0 +1,97 @@
+@echo off
+chcp 65001
+
+set USE_MIRROR=true
+set PYTHONPATH=%~dp0
+set PYTHON_CMD=python
+if exist "fishenv" (
+ set PYTHON_CMD=%cd%\fishenv\env\python
+)
+
+set API_FLAG_PATH=%~dp0API_FLAGS.txt
+set KMP_DUPLICATE_LIB_OK=TRUE
+
+setlocal enabledelayedexpansion
+
+set "HF_ENDPOINT=https://huggingface.co"
+set "no_proxy="
+if "%USE_MIRROR%" == "true" (
+ set "HF_ENDPOINT=https://hf-mirror.com"
+ set "no_proxy=localhost, 127.0.0.1, 0.0.0.0"
+)
+echo "HF_ENDPOINT: !HF_ENDPOINT!"
+echo "NO_PROXY: !no_proxy!"
+
+echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
+ echo.
+ echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && (
+ goto end
+ )
+)
+
+%PYTHON_CMD% .\tools\download_models.py
+
+set "API_FLAGS="
+set "flags="
+
+if exist "%API_FLAG_PATH%" (
+ for /f "usebackq tokens=*" %%a in ("%API_FLAG_PATH%") do (
+ set "line=%%a"
+ if not "!line:~0,1!"=="#" (
+ set "line=!line: =!"
+ set "line=!line:\=!"
+ set "line=!line:= !"
+ if not "!line!"=="" (
+ set "API_FLAGS=!API_FLAGS!!line! "
+ )
+ )
+ )
+)
+
+
+if not "!API_FLAGS!"=="" set "API_FLAGS=!API_FLAGS:~0,-1!"
+
+set "flags="
+
+echo !API_FLAGS! | findstr /C:"--api" >nul 2>&1
+if !errorlevel! equ 0 (
+ echo.
+ echo Start HTTP API...
+ set "mode=api"
+ goto process_flags
+)
+
+echo !API_FLAGS! | findstr /C:"--infer" >nul 2>&1
+if !errorlevel! equ 0 (
+ echo.
+ echo Start WebUI Inference...
+ set "mode=infer"
+ goto process_flags
+)
+
+
+:process_flags
+for %%p in (!API_FLAGS!) do (
+ if not "%%p"=="--!mode!" (
+ set "flags=!flags! %%p"
+ )
+)
+
+if not "!flags!"=="" set "flags=!flags:~1!"
+
+echo Debug: flags = !flags!
+
+if "!mode!"=="api" (
+ %PYTHON_CMD% -m tools.api !flags!
+) else if "!mode!"=="infer" (
+ %PYTHON_CMD% -m tools.webui !flags!
+)
+
+echo.
+echo Next launch the page...
+%PYTHON_CMD% fish_speech\webui\manage.py
+
+
+:end
+endlocal
+pause
diff --git a/tools/api.py b/tools/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbcc5f6398fe144a411b644410e4938cb22245e1
--- /dev/null
+++ b/tools/api.py
@@ -0,0 +1,943 @@
+import io
+import os
+import queue
+import re
+import time
+import traceback
+import wave
+from argparse import ArgumentParser
+from http import HTTPStatus
+from pathlib import Path
+from typing import Annotated, Any
+
+import librosa
+import numpy as np
+import ormsgpack
+import pyrootutils
+import soundfile as sf
+import torch
+import torchaudio
+from baize.datastructures import ContentType
+from kui.asgi import (
+ Body,
+ FactoryClass,
+ HTTPException,
+ HttpRequest,
+ HttpView,
+ JSONResponse,
+ Kui,
+ OpenAPI,
+ StreamResponse,
+ request,
+)
+from kui.asgi.routing import MultimethodRoutes
+from loguru import logger
+from transformers import AutoTokenizer
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+import struct
+from threading import Lock
+
+import httpx
+from cachetools import LRUCache, cached
+from funasr import AutoModel
+from silero_vad import get_speech_timestamps, load_silero_vad
+
+from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
+from fish_speech.models.text2semantic.llama import BaseModelArgs
+
+# from fish_speech.models.vqgan.lit_module import VQGAN
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps, set_seed
+from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
+from tools.llama.generate import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+ launch_thread_safe_queue,
+ launch_thread_safe_queue_agent,
+)
+from tools.schema import (
+ GLOBAL_NUM_SAMPLES,
+ ASRPackRequest,
+ ServeASRRequest,
+ ServeASRResponse,
+ ServeASRSegment,
+ ServeAudioPart,
+ ServeForwardMessage,
+ ServeMessage,
+ ServeRequest,
+ ServeResponse,
+ ServeStreamDelta,
+ ServeStreamResponse,
+ ServeTextPart,
+ ServeTimedASRResponse,
+ ServeTTSRequest,
+ ServeVQGANDecodeRequest,
+ ServeVQGANDecodeResponse,
+ ServeVQGANEncodeRequest,
+ ServeVQGANEncodeResponse,
+ ServeVQPart,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
+
+global_lock = Lock()
+
+# Whether to disable keepalive (which is helpful if the server is in the same cluster)
+DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
+async_client = httpx.AsyncClient(
+ timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
+)
+backends = torchaudio.list_audio_backends()
+
+if "ffmpeg" in backends:
+ backend = "ffmpeg"
+else:
+ backend = "soundfile"
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+ return wav_header_bytes
+
+
+# Define utils for web server
+async def http_execption_handler(exc: HTTPException):
+ return JSONResponse(
+ dict(
+ statusCode=exc.status_code,
+ message=exc.content,
+ error=HTTPStatus(exc.status_code).phrase,
+ ),
+ exc.status_code,
+ exc.headers,
+ )
+
+
+async def other_exception_handler(exc: "Exception"):
+ traceback.print_exc()
+
+ status = HTTPStatus.INTERNAL_SERVER_ERROR
+ return JSONResponse(
+ dict(statusCode=status, message=str(exc), error=status.phrase),
+ status,
+ )
+
+
+def load_audio(reference_audio, sr):
+ if len(reference_audio) > 255 or not Path(reference_audio).exists():
+ audio_data = reference_audio
+ reference_audio = io.BytesIO(audio_data)
+
+ waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
+
+ if waveform.shape[0] > 1:
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+ if original_sr != sr:
+ resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
+ waveform = resampler(waveform)
+
+ audio = waveform.squeeze().numpy()
+ return audio
+
+
+def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
+ if enable_reference_audio and reference_audio is not None:
+ # Load audios, and prepare basic info here
+ reference_audio_content = load_audio(
+ reference_audio, decoder_model.spec_transform.sample_rate
+ )
+
+ audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
+ None, None, :
+ ]
+ audio_lengths = torch.tensor(
+ [audios.shape[2]], device=decoder_model.device, dtype=torch.long
+ )
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ if isinstance(decoder_model, FireflyArchitecture):
+ prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
+
+ logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+ else:
+ prompt_tokens = None
+ logger.info("No reference audio provided")
+
+ return prompt_tokens
+
+
+def decode_vq_tokens(
+ *,
+ decoder_model,
+ codes,
+):
+ feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
+ logger.info(f"VQ features: {codes.shape}")
+
+ if isinstance(decoder_model, FireflyArchitecture):
+ # VQGAN Inference
+ return decoder_model.decode(
+ indices=codes[None],
+ feature_lengths=feature_lengths,
+ )[0].squeeze()
+
+ raise ValueError(f"Unknown model type: {type(decoder_model)}")
+
+
+routes = MultimethodRoutes(base_class=HttpView)
+
+
+def get_content_type(audio_format):
+ if audio_format == "wav":
+ return "audio/wav"
+ elif audio_format == "flac":
+ return "audio/flac"
+ elif audio_format == "mp3":
+ return "audio/mpeg"
+ else:
+ return "application/octet-stream"
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def batch_encode(model, audios: list[bytes | torch.Tensor]):
+ audios = [
+ (
+ torch.from_numpy(
+ librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
+ )[None]
+ if isinstance(audio, bytes)
+ else audio
+ )
+ for audio in audios
+ ]
+
+ # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
+ # raise ValueError("Single audio length is too long (>120s)")
+
+ max_length = max(audio.shape[-1] for audio in audios)
+ print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+
+ lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
+ max_length = lengths.max().item()
+ padded = torch.stack(
+ [
+ torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
+ for audio in audios
+ ]
+ ).to(model.device)
+
+ features, feature_lengths = model.encode(padded, audio_lengths=lengths)
+ features, feature_lengths = features.cpu(), feature_lengths.cpu()
+
+ return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
+
+
+@cached(
+ cache=LRUCache(maxsize=10000),
+ key=lambda model, audios: (model.device, tuple(audios)),
+)
+def cached_vqgan_batch_encode(model, audios: list[bytes]):
+ return batch_encode(model, audios)
+
+
+@routes.http.post("/v1/vqgan/encode")
+def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
+
+ start_time = time.time()
+ tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
+ logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
+
+ return ormsgpack.packb(
+ ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ )
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def vqgan_decode(model, features):
+ lengths = torch.tensor(
+ [feature.shape[-1] for feature in features], device=model.device
+ )
+ max_length = lengths.max().item()
+ padded = torch.stack(
+ [
+ torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
+ for feature in features
+ ]
+ ).to(model.device)
+
+ # If bs too large, we do micro batch decode
+ audios, audio_lengths = [], []
+ for i in range(0, padded.shape[0], 8):
+ audio, audio_length = model.decode(
+ padded[i : i + 8], feature_lengths=lengths[i : i + 8]
+ )
+ audios.append(audio)
+ audio_lengths.append(audio_length)
+ audios = torch.cat(audios, dim=0)
+ audio_lengths = torch.cat(audio_lengths, dim=0)
+ audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
+
+ return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
+
+
+@routes.http.post("/v1/vqgan/decode")
+def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
+ tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
+ start_time = time.time()
+ audios = vqgan_decode(decoder_model, tokens)
+ logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
+ audios = [audio.astype(np.float16).tobytes() for audio in audios]
+ return ormsgpack.packb(
+ ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+ )
+
+
+@torch.no_grad()
+def batch_asr(model, audios, sr, language="auto"):
+ resampled_audios = []
+ for audio in audios:
+ audio = torchaudio.functional.resample(audio, sr, 16000)
+ assert audio.ndim == 1
+ resampled_audios.append(audio)
+
+ with global_lock:
+ res = model.generate(
+ input=resampled_audios,
+ batch_size=len(resampled_audios),
+ language=language,
+ use_itn=True,
+ )
+
+ results = []
+ for r, audio in zip(res, audios):
+ text = r["text"]
+ text = re.sub(r"<\|.*?\|>", "", text)
+ duration = len(audio) / sr * 1000
+ huge_gap = False
+
+ if "timestamp" in r and len(r["timestamp"]) > 2:
+ for timestamp_a, timestamp_b in zip(
+ r["timestamp"][:-1], r["timestamp"][1:]
+ ):
+ # If there is a gap of more than 5 seconds, we consider it as a huge gap
+ if timestamp_b[0] - timestamp_a[1] > 5000:
+ huge_gap = True
+ break
+
+ # Doesn't make sense to have a huge gap at the end
+ if duration - r["timestamp"][-1][1] > 3000:
+ huge_gap = True
+
+ results.append(
+ {
+ "text": text,
+ "duration": duration,
+ "huge_gap": huge_gap,
+ }
+ )
+
+ return results
+
+
+@routes.http.post("/v1/asr")
+def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
+ start_time = time.time()
+ audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
+ audios = [torch.from_numpy(audio).float() for audio in audios]
+
+ if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
+ raise HTTPException(status_code=400, detail="Audio length is too long")
+
+ transcriptions = batch_asr(
+ asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
+ )
+ logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+
+ return ormsgpack.packb(
+ ServeASRResponse(transcriptions=transcriptions),
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ )
+
+
+from fish_speech.conversation import Conversation, Message
+
+
+def execute_request(
+ input_queue: queue.Queue,
+ tokenizer: AutoTokenizer,
+ config: BaseModelArgs,
+ request: ServeRequest,
+ device: str = "cuda:0",
+):
+ semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
+ [SEMANTIC_TOKEN, IM_END_TOKEN]
+ )
+ messages = []
+ for message in request.messages:
+ messages.append(message.to_conversation_message())
+
+ assert len(messages) >= 1, "At least one message is required"
+ # assert messages[-1].role == "user", "The last message must be from the user"
+
+ if messages[-1].role == "user":
+ messages.append(Message(role="assistant", parts=[], add_im_end=False))
+ else:
+ assert (
+ messages[-1].role == "assistant"
+ ), "The last message must be from the assistant"
+ messages[-1].add_im_end = False
+
+ conv = Conversation(messages=messages)
+ prompt = conv.encode_for_inference(
+ tokenizer=tokenizer, num_codebooks=config.num_codebooks
+ ).to(device)
+
+ if request.streaming:
+ for i in range(request.num_samples):
+ yield ServeStreamResponse(
+ sample_id=i,
+ delta=ServeStreamDelta(
+ role="assistant",
+ ),
+ )
+
+ req = {
+ "prompt": prompt,
+ "max_new_tokens": request.max_new_tokens,
+ "im_end_id": im_end_id,
+ "semantic_id": semantic_id,
+ "temperature": request.temperature,
+ "top_p": request.top_p,
+ "repetition_penalty": request.repetition_penalty,
+ "num_samples": request.num_samples,
+ "early_stop_threshold": request.early_stop_threshold,
+ }
+
+ start = time.time()
+ response_queue = queue.Queue()
+ input_queue.put(GenerateRequest(req, response_queue))
+
+ # Decoding
+ decode_buffer = [[] for _ in range(request.num_samples)]
+ parts = [[] for _ in range(request.num_samples)]
+
+ def send_reset_buffer(sample_id):
+ nonlocal decode_buffer
+ if len(decode_buffer[sample_id]) == 0:
+ return
+
+ decoded = tokenizer.decode(decode_buffer[sample_id])
+ part = ServeTextPart(text=decoded)
+
+ if request.streaming:
+ yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
+ else:
+ parts[sample_id].append(part)
+
+ decode_buffer[sample_id] = []
+
+ # Decode process
+ finished = [False for _ in range(request.num_samples)]
+ stats = {}
+ idx = 0
+ while True:
+ response = response_queue.get()
+
+ if response in ["stop", "error"]:
+ break
+
+ for sample_id, tokens in enumerate(response):
+ if finished[sample_id]:
+ continue
+
+ if tokens[0] == im_end_id:
+ finished[sample_id] = True
+ if request.streaming:
+ yield from send_reset_buffer(sample_id)
+ yield ServeStreamResponse(
+ sample_id=sample_id,
+ finish_reason="stop",
+ stats=stats,
+ )
+ continue
+
+ if tokens[0] == semantic_id and request.streaming:
+ yield from send_reset_buffer(sample_id)
+ # Streaming vq
+ _tokens = tokens[1:].clone() - 1
+
+ if config.share_codebook_embeddings is False:
+ for i in range(len(_tokens)):
+ _tokens[i] -= config.codebook_size * i
+
+ yield ServeStreamResponse(
+ sample_id=sample_id,
+ delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
+ )
+ continue
+
+ # Not streaming vq
+ if tokens[0] == semantic_id:
+ yield from send_reset_buffer(sample_id)
+ # None streaming vq
+ if len(parts[sample_id]) == 0 or not isinstance(
+ parts[sample_id][-1], ServeVQPart
+ ):
+ _tokens = tokens[1:].clone() - 1
+
+ if config.share_codebook_embeddings is False:
+ for i in range(len(_tokens)):
+ _tokens[i] -= config.codebook_size * i
+
+ parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
+ else:
+ for codebook_id, value in enumerate(tokens[1:, :]):
+ val = value.item() - 1
+ if config.share_codebook_embeddings is False:
+ val -= config.codebook_size * codebook_id
+
+ parts[sample_id][-1].codes[codebook_id].append(val)
+ continue
+
+ if tokens[0] != semantic_id:
+ # Stream text decode is not supported now
+ decode_buffer[sample_id].append(tokens[0, 0])
+
+ if idx == 0:
+ stats["time_to_first_token"] = (time.time() - start) * 1000
+
+ idx += 1
+
+ for sample_id in range(request.num_samples):
+ yield from send_reset_buffer(sample_id)
+
+ stats["total_time"] = (time.time() - start) * 1000
+ stats["total_tokens"] = idx
+
+ if request.streaming:
+ for sample_id in range(request.num_samples):
+ if finished[sample_id]:
+ continue
+ yield ServeStreamResponse(
+ finish_reason=response, stats=stats, sample_id=sample_id
+ )
+ return
+
+ yield ServeResponse(
+ messages=[
+ ServeMessage(role="assistant", parts=parts[i])
+ for i in range(request.num_samples)
+ ],
+ finish_reason=response,
+ stats=stats,
+ )
+
+
+@routes.http.post("/v1/chat")
+def api_invoke_chat(
+ req: Annotated[ServeRequest, Body(exclusive=True)],
+):
+ """
+ Invoke model and generate audio
+ """
+
+ # This makes torch compile happy
+ assert (
+ req.num_samples == GLOBAL_NUM_SAMPLES
+ ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
+
+ content_type = request.headers.get("Content-Type", "application/json")
+ json_mode = "application/json" in content_type
+
+ async def wrapped_generator():
+ generator = execute_request(llama_queue, tokenizer, config, req, args.device)
+
+ for i in generator:
+ if json_mode:
+ body = i.model_dump_json().encode("utf-8")
+ yield b"data: " + body + b"\n\n"
+ else:
+ body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+ yield struct.pack("I", len(body)) + body
+
+ # Naive mode
+ if req.streaming is False:
+ result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
+
+ if json_mode:
+ return JSONResponse(result.model_dump())
+ else:
+ return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+
+ return StreamResponse(
+ iterable=wrapped_generator(), content_type="text/event-stream"
+ )
+
+
+@torch.inference_mode()
+def inference(req: ServeTTSRequest):
+
+ global prompt_tokens, prompt_texts
+
+ idstr: str | None = req.reference_id
+ if idstr is not None:
+ ref_folder = Path("references") / idstr
+ ref_folder.mkdir(parents=True, exist_ok=True)
+ ref_audios = list_files(
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+ )
+
+ if req.use_memory_cache == "never" or (
+ req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+ ):
+ prompt_tokens = [
+ encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=audio_to_bytes(str(ref_audio)),
+ enable_reference_audio=True,
+ )
+ for ref_audio in ref_audios
+ ]
+ prompt_texts = [
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
+ for ref_audio in ref_audios
+ ]
+ else:
+ logger.info("Use same references")
+
+ else:
+ # Parse reference audio aka prompt
+ refs = req.references
+
+ if req.use_memory_cache == "never" or (
+ req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+ ):
+ prompt_tokens = [
+ encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=ref.audio,
+ enable_reference_audio=True,
+ )
+ for ref in refs
+ ]
+ prompt_texts = [ref.text for ref in refs]
+ else:
+ logger.info("Use same references")
+
+ if req.seed is not None:
+ set_seed(req.seed)
+ logger.warning(f"set seed: {req.seed}")
+
+ # LLAMA Inference
+ request = dict(
+ device=decoder_model.device,
+ max_new_tokens=req.max_new_tokens,
+ text=(
+ req.text
+ if not req.normalize
+ else ChnNormedText(raw_text=req.text).normalize()
+ ),
+ top_p=req.top_p,
+ repetition_penalty=req.repetition_penalty,
+ temperature=req.temperature,
+ compile=args.compile,
+ iterative_prompt=req.chunk_length > 0,
+ chunk_length=req.chunk_length,
+ max_length=4096,
+ prompt_tokens=prompt_tokens,
+ prompt_text=prompt_texts,
+ )
+
+ response_queue = queue.Queue()
+ llama_queue.put(
+ GenerateRequest(
+ request=request,
+ response_queue=response_queue,
+ )
+ )
+
+ if req.streaming:
+ yield wav_chunk_header()
+
+ segments = []
+ while True:
+ result: WrappedGenerateResponse = response_queue.get()
+ if result.status == "error":
+ raise result.response
+ break
+
+ result: GenerateResponse = result.response
+ if result.action == "next":
+ break
+
+ with autocast_exclude_mps(
+ device_type=decoder_model.device.type, dtype=args.precision
+ ):
+ fake_audios = decode_vq_tokens(
+ decoder_model=decoder_model,
+ codes=result.codes,
+ )
+
+ fake_audios = fake_audios.float().cpu().numpy()
+
+ if req.streaming:
+ yield (fake_audios * 32768).astype(np.int16).tobytes()
+ else:
+ segments.append(fake_audios)
+
+ if req.streaming:
+ return
+
+ if len(segments) == 0:
+ raise HTTPException(
+ HTTPStatus.INTERNAL_SERVER_ERROR,
+ content="No audio generated, please check the input text.",
+ )
+
+ fake_audios = np.concatenate(segments, axis=0)
+ yield fake_audios
+
+
+async def inference_async(req: ServeTTSRequest):
+ for chunk in inference(req):
+ yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+ yield buffer
+
+
+@routes.http.post("/v1/tts")
+async def api_invoke_model(
+ req: Annotated[ServeTTSRequest, Body(exclusive=True)],
+):
+ """
+ Invoke model and generate audio
+ """
+
+ if args.max_text_length > 0 and len(req.text) > args.max_text_length:
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content=f"Text is too long, max length is {args.max_text_length}",
+ )
+
+ if req.streaming and req.format != "wav":
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content="Streaming only supports WAV format",
+ )
+
+ if req.streaming:
+ return StreamResponse(
+ iterable=inference_async(req),
+ headers={
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ },
+ content_type=get_content_type(req.format),
+ )
+ else:
+ fake_audios = next(inference(req))
+ buffer = io.BytesIO()
+ sf.write(
+ buffer,
+ fake_audios,
+ decoder_model.spec_transform.sample_rate,
+ format=req.format,
+ )
+
+ return StreamResponse(
+ iterable=buffer_to_async_generator(buffer.getvalue()),
+ headers={
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ },
+ content_type=get_content_type(req.format),
+ )
+
+
+@routes.http.post("/v1/health")
+async def api_health():
+ """
+ Health check
+ """
+
+ return JSONResponse({"status": "ok"})
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+ parser.add_argument("--load-asr-model", action="store_true")
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=str,
+ default="checkpoints/fish-speech-1.4",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=str,
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-text-length", type=int, default=0)
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
+ parser.add_argument("--workers", type=int, default=1)
+
+ return parser.parse_args()
+
+
+# Define Kui app
+openapi = OpenAPI(
+ {
+ "title": "Fish Speech API",
+ "version": "1.4.2",
+ },
+).routes
+
+
+class MsgPackRequest(HttpRequest):
+ async def data(
+ self,
+ ) -> Annotated[
+ Any, ContentType("application/msgpack"), ContentType("application/json")
+ ]:
+ if self.content_type == "application/msgpack":
+ return ormsgpack.unpackb(await self.body)
+
+ elif self.content_type == "application/json":
+ return await self.json
+
+ raise HTTPException(
+ HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+ headers={"Accept": "application/msgpack, application/json"},
+ )
+
+
+app = Kui(
+ routes=routes + openapi[1:], # Remove the default route
+ exception_handlers={
+ HTTPException: http_execption_handler,
+ Exception: other_exception_handler,
+ },
+ factory_class=FactoryClass(http=MsgPackRequest),
+ cors_config={},
+)
+
+
+def load_asr_model(*, device="cuda", hub="ms"):
+ return AutoModel(
+ model="iic/SenseVoiceSmall",
+ device=device,
+ disable_pbar=True,
+ hub=hub,
+ )
+
+
+# Each worker process created by Uvicorn has its own memory space,
+# meaning that models and variables are not shared between processes.
+# Therefore, any global variables (like `llama_queue` or `decoder_model`)
+# will not be shared across workers.
+
+
+# Multi-threading for deep learning can cause issues, such as inconsistent
+# outputs if multiple threads access the same buffers simultaneously.
+# Instead, it's better to use multiprocessing or independent models per thread.
+@app.on_startup
+def initialize_app(app: Kui):
+
+ global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
+
+ prompt_tokens, prompt_texts = [], []
+
+ args = parse_args() # args same as ones in other processes
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ if args.load_asr_model:
+ logger.info(f"Loading ASR model...")
+ asr_model = load_asr_model(device=args.device)
+
+ logger.info("Loading Llama model...")
+
+ if args.mode == "tts":
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+ else:
+ llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+
+ logger.info("Llama model loaded, loading VQ-GAN model...")
+
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("VQ-GAN model loaded, warming up...")
+
+ vad_model = load_silero_vad()
+
+ logger.info("VAD model loaded, warming up...")
+
+ if args.mode == "tts":
+ # Dry run to ensure models work and avoid first-time latency
+ list(
+ inference(
+ ServeTTSRequest(
+ text="Hello world.",
+ references=[],
+ reference_id=None,
+ max_new_tokens=0,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.2,
+ temperature=0.7,
+ emotion=None,
+ format="wav",
+ )
+ )
+ )
+
+ logger.info(f"Warming up done, starting server at http://{args.listen}")
+
+
+if __name__ == "__main__":
+
+ import uvicorn
+
+ args = parse_args()
+ host, port = args.listen.split(":")
+ uvicorn.run(
+ "tools.api:app",
+ host=host,
+ port=int(port),
+ workers=args.workers,
+ log_level="info",
+ )
diff --git a/tools/download_models.py b/tools/download_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e79c34c43b424a8e47c43dd3edf003634fc667e
--- /dev/null
+++ b/tools/download_models.py
@@ -0,0 +1,55 @@
+import os
+
+from huggingface_hub import hf_hub_download
+
+
+# Download
+def check_and_download_files(repo_id, file_list, local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+ for file in file_list:
+ file_path = os.path.join(local_dir, file)
+ if not os.path.exists(file_path):
+ print(f"{file} 不存在,从 Hugging Face 仓库下载...")
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=file,
+ resume_download=True,
+ local_dir=local_dir,
+ local_dir_use_symlinks=False,
+ )
+ else:
+ print(f"{file} 已存在,跳过下载。")
+
+
+# 1st
+repo_id_1 = "fishaudio/fish-speech-1.4"
+local_dir_1 = "./checkpoints/fish-speech-1.4"
+files_1 = [
+ "model.pth",
+ "README.md",
+ "special_tokens_map.json",
+ "tokenizer_config.json",
+ "tokenizer.json",
+ "config.json",
+ "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+]
+
+# 3rd
+repo_id_3 = "fishaudio/fish-speech-1"
+local_dir_3 = "./"
+files_3 = [
+ "ffmpeg.exe",
+ "ffprobe.exe",
+]
+
+# 4th
+repo_id_4 = "SpicyqSama007/fish-speech-packed"
+local_dir_4 = "./"
+files_4 = [
+ "asr-label-win-x64.exe",
+]
+
+check_and_download_files(repo_id_1, files_1, local_dir_1)
+
+check_and_download_files(repo_id_3, files_3, local_dir_3)
+check_and_download_files(repo_id_4, files_4, local_dir_4)
diff --git a/tools/e2e_webui.py b/tools/e2e_webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..37474fbd5645c09fcbe6caac1331672614c5c821
--- /dev/null
+++ b/tools/e2e_webui.py
@@ -0,0 +1,232 @@
+import io
+import re
+import wave
+
+import gradio as gr
+import numpy as np
+
+from .fish_e2e import FishE2EAgent, FishE2EEventType
+from .schema import ServeMessage, ServeTextPart, ServeVQPart
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+ return wav_header_bytes
+
+
+class ChatState:
+    """Mutable per-session conversation state for the Gradio demo."""
+
+    def __init__(self):
+        # Full message history as ServeMessage objects.
+        self.conversation = []
+        # Whether the system text prompt has been injected yet.
+        self.added_systext = False
+        # Whether the system reference audio has been injected yet.
+        self.added_sysaudio = False
+
+    def get_history(self):
+        """Render the conversation as Gradio `type="messages"` dicts, folding
+        the model's "Question: ... Response: ..." format back into the
+        preceding user turn.
+        """
+        results = []
+        for msg in self.conversation:
+            results.append({"role": msg.role, "content": self.repr_message(msg)})
+
+        # Process assistant messages to extract questions and update user messages
+        for i, msg in enumerate(results):
+            if msg["role"] == "assistant":
+                match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
+                if match and i > 0 and results[i - 1]["role"] == "user":
+                    # Update previous user message with extracted question
+                    results[i - 1]["content"] += "\n" + match.group(1)
+                    # Remove the Question/Answer format from assistant message
+                    msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
+        return results
+
+    def repr_message(self, msg: ServeMessage):
+        """Flatten a ServeMessage into display text; VQ (audio) parts
+        contribute an empty string.
+        """
+        response = ""
+        for part in msg.parts:
+            if isinstance(part, ServeTextPart):
+                response += part.text
+            elif isinstance(part, ServeVQPart):
+                # NOTE(review): this f-string is empty, so audio parts are
+                # invisible in the transcript — possibly a placeholder that
+                # was meant to render a token count; confirm intent.
+                response += f""
+        return response
+
+
+def clear_fn():
+ return [], ChatState(), None, None, None
+
+
+async def process_audio_input(
+    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
+):
+    """Run one chat turn: inject system prompt/audio if needed, stream the
+    agent's reply, and yield (history, audio_chunk, None, None) tuples for
+    the Gradio outputs.
+    """
+    if audio_input is None and not text_input:
+        raise gr.Error("No input provided")
+
+    agent = FishE2EAgent()  # Create new agent instance for each request
+
+    # Convert audio input to numpy array
+    if isinstance(audio_input, tuple):
+        sr, audio_data = audio_input
+    elif text_input:
+        sr = 44100
+        audio_data = None
+    else:
+        raise gr.Error("Invalid audio format")
+
+    if isinstance(sys_audio_input, tuple):
+        # NOTE(review): this overwrites `sr` with the *system* audio's sample
+        # rate, which is then also used for the user audio — confirm both are
+        # expected to share one rate.
+        sr, sys_audio_data = sys_audio_input
+    else:
+        sr = 44100
+        sys_audio_data = None
+
+    def append_to_chat_ctx(
+        part: ServeTextPart | ServeVQPart, role: str = "assistant"
+    ) -> None:
+        # Merge consecutive parts from the same role into a single message.
+        if not state.conversation or state.conversation[-1].role != role:
+            state.conversation.append(ServeMessage(role=role, parts=[part]))
+        else:
+            state.conversation[-1].parts.append(part)
+
+    if state.added_systext is False and sys_text_input:
+        state.added_systext = True
+        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
+    if text_input:
+        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
+        audio_data = None  # Typed text takes precedence over recorded audio.
+
+    result_audio = b""  # NOTE(review): never written elsewhere — appears unused.
+    async for event in agent.stream(
+        sys_audio_data,
+        audio_data,
+        sr,
+        1,
+        chat_ctx={
+            "messages": state.conversation,
+            "added_sysaudio": state.added_sysaudio,
+        },
+    ):
+        if event.type == FishE2EEventType.USER_CODES:
+            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
+        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
+            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
+            yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
+        elif event.type == FishE2EEventType.TEXT_SEGMENT:
+            append_to_chat_ctx(ServeTextPart(text=event.text))
+            yield state.get_history(), None, None, None
+
+    yield state.get_history(), None, None, None
+
+
+async def process_text_input(
+ sys_audio_input, sys_text_input, state: ChatState, text_input: str
+):
+ async for event in process_audio_input(
+ sys_audio_input, sys_text_input, None, state, text_input
+ ):
+ yield event
+
+
+def create_demo():
+ with gr.Blocks() as demo:
+ state = gr.State(ChatState())
+
+ with gr.Row():
+ # Left column (70%) for chatbot and notes
+ with gr.Column(scale=7):
+ chatbot = gr.Chatbot(
+ [],
+ elem_id="chatbot",
+ bubble_full_width=False,
+ height=600,
+ type="messages",
+ )
+
+ # notes = gr.Markdown(
+ # """
+ # # Fish Agent
+ # 1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
+ # 2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
+ # 3. Demo为早期灰度测试版本,推理速度尚待优化.
+ # # 特色
+ # 1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
+ # 2. 模型可以使用reference audio控制说话音色.
+ # 3. 可以生成具有较强情感与韵律的音频.
+ # """
+ # )
+ notes = gr.Markdown(
+ """
+ # Fish Agent
+ 1. This demo is Fish Audio's self-researh end-to-end language model, Fish Agent version 3B.
+ 2. You can find the code and weights in our official repo in [gitub](https://github.com/fishaudio/fish-speech) and [hugging face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b), but the content is released under a CC BY-NC-SA 4.0 licence.
+ 3. The demo is an early alpha test version, the inference speed needs to be optimised.
+ # Features
+ 1. The model automatically integrates ASR and TTS parts, no need to plug-in other models, i.e., true end-to-end, not three-stage (ASR+LLM+TTS).
+ 2. The model can use reference audio to control the speech timbre.
+ 3. The model can generate speech with strong emotion.
+ """
+ )
+
+ # Right column (30%) for controls
+ with gr.Column(scale=3):
+ sys_audio_input = gr.Audio(
+ sources=["upload"],
+ type="numpy",
+ label="Give a timbre for your assistant",
+ )
+ sys_text_input = gr.Textbox(
+ label="What is your assistant's role?",
+ value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nAnswer: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
+ type="text",
+ )
+ audio_input = gr.Audio(
+ sources=["microphone"], type="numpy", label="Speak your message"
+ )
+
+ text_input = gr.Textbox(label="Or type your message", type="text")
+
+ output_audio = gr.Audio(
+ label="Assistant's Voice",
+ streaming=True,
+ autoplay=True,
+ interactive=False,
+ )
+
+ send_button = gr.Button("Send", variant="primary")
+ clear_button = gr.Button("Clear")
+
+ # Event handlers
+ audio_input.stop_recording(
+ process_audio_input,
+ inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
+ outputs=[chatbot, output_audio, audio_input, text_input],
+ show_progress=True,
+ )
+
+ send_button.click(
+ process_text_input,
+ inputs=[sys_audio_input, sys_text_input, state, text_input],
+ outputs=[chatbot, output_audio, audio_input, text_input],
+ show_progress=True,
+ )
+
+ text_input.submit(
+ process_text_input,
+ inputs=[sys_audio_input, sys_text_input, state, text_input],
+ outputs=[chatbot, output_audio, audio_input, text_input],
+ show_progress=True,
+ )
+
+ clear_button.click(
+ clear_fn,
+ inputs=[],
+ outputs=[chatbot, state, audio_input, output_audio, text_input],
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+    demo = create_demo()
+    # NOTE(review): share=True exposes the demo via a public Gradio link even
+    # though it binds to 127.0.0.1 — confirm that is intended.
+    demo.launch(server_name="127.0.0.1", server_port=7860, share=True)
diff --git a/tools/extract_model.py b/tools/extract_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..97fe62507b7282890319d8dc1eaa3cbca0e1f60a
--- /dev/null
+++ b/tools/extract_model.py
@@ -0,0 +1,21 @@
+import click
+import torch
+from loguru import logger
+
+
+@click.command()
+@click.argument("model_path")
+@click.argument("output_path")
+def main(model_path, output_path):
+ if model_path == output_path:
+ logger.error("Model path and output path are the same")
+ return
+
+ logger.info(f"Loading model from {model_path}")
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+ torch.save(state_dict, output_path)
+ logger.info(f"Model saved to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/file.py b/tools/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7a0597365252e7aecf887897ff391a061275c3f
--- /dev/null
+++ b/tools/file.py
@@ -0,0 +1,125 @@
+import base64
+from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
+# File suffixes treated as audio when scanning datasets (lower-case, with dot).
+AUDIO_EXTENSIONS = {
+    ".mp3",
+    ".wav",
+    ".flac",
+    ".ogg",
+    ".m4a",
+    ".wma",
+    ".aac",
+    ".aiff",
+    ".aif",
+    ".aifc",
+}
+
+# File suffixes treated as video.
+VIDEO_EXTENSIONS = {
+    ".mp4",
+    ".avi",
+}
+
+
+def audio_to_bytes(file_path):
+ if not file_path or not Path(file_path).exists():
+ return None
+ with open(file_path, "rb") as wav_file:
+ wav = wav_file.read()
+ return wav
+
+
+def read_ref_text(ref_text):
+ path = Path(ref_text)
+ if path.exists() and path.is_file():
+ with path.open("r", encoding="utf-8") as file:
+ return file.read()
+ return ref_text
+
+
+def list_files(
+ path: Union[Path, str],
+ extensions: set[str] = None,
+ recursive: bool = False,
+ sort: bool = True,
+) -> list[Path]:
+ """List files in a directory.
+
+ Args:
+ path (Path): Path to the directory.
+ extensions (set, optional): Extensions to filter. Defaults to None.
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
+ sort (bool, optional): Whether to sort the files. Defaults to True.
+
+ Returns:
+ list: List of files.
+ """
+
+ if isinstance(path, str):
+ path = Path(path)
+
+ if not path.exists():
+ raise FileNotFoundError(f"Directory {path} does not exist.")
+
+ files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
+
+ if sort:
+ files = natsorted(files)
+
+ return files
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+ """
+ Load a Bert-VITS2 style filelist.
+ """
+
+ files = set()
+ results = []
+ count_duplicated, count_not_found = 0, 0
+
+ LANGUAGE_TO_LANGUAGES = {
+ "zh": ["zh", "en"],
+ "jp": ["jp", "en"],
+ "en": ["en"],
+ }
+
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ splits = line.strip().split("|", maxsplit=3)
+ if len(splits) != 4:
+ logger.warning(f"Invalid line: {line}")
+ continue
+
+ filename, speaker, language, text = splits
+ file = Path(filename)
+ language = language.strip().lower()
+
+ if language == "ja":
+ language = "jp"
+
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+ languages = LANGUAGE_TO_LANGUAGES[language]
+
+ if file in files:
+ logger.warning(f"Duplicated file: {file}")
+ count_duplicated += 1
+ continue
+
+ if not file.exists():
+ logger.warning(f"File not found: {file}")
+ count_not_found += 1
+ continue
+
+ results.append((file, speaker, languages, text))
+
+ if count_duplicated > 0:
+ logger.warning(f"Total duplicated files: {count_duplicated}")
+
+ if count_not_found > 0:
+ logger.warning(f"Total files not found: {count_not_found}")
+
+ return results
diff --git a/tools/fish_e2e.py b/tools/fish_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..34b531f022ab0a1913c0c17581b3c73a2e7d03a0
--- /dev/null
+++ b/tools/fish_e2e.py
@@ -0,0 +1,298 @@
+import base64
+import ctypes
+import io
+import json
+import os
+import struct
+from dataclasses import dataclass
+from enum import Enum
+from typing import AsyncGenerator, Union
+
+import httpx
+import numpy as np
+import ormsgpack
+import soundfile as sf
+
+from .schema import (
+ ServeMessage,
+ ServeRequest,
+ ServeTextPart,
+ ServeVQGANDecodeRequest,
+ ServeVQGANEncodeRequest,
+ ServeVQPart,
+)
+
+
+class CustomAudioFrame:
+    """Container for raw int16 PCM audio bytes plus the metadata needed to
+    interpret them (sample rate, channel count, frame count).
+    """
+
+    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
+        # Require at least one full buffer of int16 samples.
+        if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
+            ctypes.c_int16
+        ):
+            raise ValueError(
+                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
+            )
+
+        self._data = bytearray(data)
+        self._sample_rate = sample_rate
+        self._num_channels = num_channels
+        self._samples_per_channel = samples_per_channel
+
+    @property
+    def data(self):
+        # Zero-copy int16 ("h") view over the underlying buffer.
+        return memoryview(self._data).cast("h")
+
+    @property
+    def sample_rate(self):
+        return self._sample_rate
+
+    @property
+    def num_channels(self):
+        return self._num_channels
+
+    @property
+    def samples_per_channel(self):
+        return self._samples_per_channel
+
+    @property
+    def duration(self):
+        # Duration in seconds.
+        return self.samples_per_channel / self.sample_rate
+
+    def __repr__(self):
+        return (
+            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
+            f"num_channels={self.num_channels}, "
+            f"samples_per_channel={self.samples_per_channel}, "
+            f"duration={self.duration:.3f})"
+        )
+
+
+class FishE2EEventType(Enum):
+    """Kinds of events emitted by FishE2EAgent.stream()."""
+
+    SPEECH_SEGMENT = 1
+    TEXT_SEGMENT = 2
+    END_OF_TEXT = 3
+    END_OF_SPEECH = 4
+    ASR_RESULT = 5
+    USER_CODES = 6
+
+
+@dataclass
+class FishE2EEvent:
+    """One streamed event; only the fields relevant to `type` are populated."""
+
+    type: FishE2EEventType
+    # Audio payload for SPEECH_SEGMENT events.
+    # NOTE(review): despite the ndarray annotation, decode_send() assigns a
+    # CustomAudioFrame here — confirm the intended type.
+    frame: np.ndarray | None = None
+    # Text payload for TEXT_SEGMENT / ASR_RESULT events.
+    text: str | None = None
+    # VQ token codes for USER_CODES / SPEECH_SEGMENT events.
+    vq_codes: list[list[int]] | None = None
+
+
+# Shared module-level HTTP client with no timeouts or connection limits.
+# NOTE(review): FishE2EAgent creates its own client in __init__; nothing in
+# this file appears to use this one — confirm before removing.
+client = httpx.AsyncClient(
+    timeout=None,
+    limits=httpx.Limits(
+        max_connections=None,
+        max_keepalive_connections=None,
+        keepalive_expiry=None,
+    ),
+)
+
+
+class FishE2EAgent:
+    """Async client for a local fish-speech server: encodes audio to VQ codes,
+    streams the chat LLM, and decodes returned VQ codes back into audio.
+    """
+
+    def __init__(self):
+        # Hard-coded local endpoints; both target the same server process.
+        self.llm_url = "http://localhost:8080/v1/chat"
+        self.vqgan_url = "http://localhost:8080"
+        self.client = httpx.AsyncClient(timeout=None)
+
+    async def get_codes(self, audio_data, sample_rate):
+        """Encode a PCM array into VQ token codes via the server's VQGAN."""
+        audio_buffer = io.BytesIO()
+        sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
+        audio_buffer.seek(0)
+        # Step 1: Encode audio using VQGAN
+        encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
+        encode_request_bytes = ormsgpack.packb(
+            encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+        )
+        encode_response = await self.client.post(
+            f"{self.vqgan_url}/v1/vqgan/encode",
+            data=encode_request_bytes,
+            headers={"Content-Type": "application/msgpack"},
+        )
+        encode_response_data = ormsgpack.unpackb(encode_response.content)
+        codes = encode_response_data["tokens"][0]
+        return codes
+
+    async def stream(
+        self,
+        system_audio_data: np.ndarray | None,
+        user_audio_data: np.ndarray | None,
+        sample_rate: int,
+        num_channels: int,
+        chat_ctx: dict | None = None,
+    ) -> AsyncGenerator[bytes, None]:
+        """Run one conversational turn, yielding FishE2EEvent objects.
+
+        If `chat_ctx` is None, a fresh system prompt (optionally with system
+        audio) is built; otherwise `chat_ctx["messages"]` is used as history
+        and mutated in place when system audio is first added.
+        """
+
+        if system_audio_data is not None:
+            sys_codes = await self.get_codes(system_audio_data, sample_rate)
+        else:
+            sys_codes = None
+        if user_audio_data is not None:
+            user_codes = await self.get_codes(user_audio_data, sample_rate)
+        # Step 2: Prepare LLM request
+        if chat_ctx is None:
+            sys_parts = [
+                ServeTextPart(
+                    text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
+                ),
+            ]
+            if system_audio_data is not None:
+                sys_parts.append(ServeVQPart(codes=sys_codes))
+            chat_ctx = {
+                "messages": [
+                    ServeMessage(
+                        role="system",
+                        parts=sys_parts,
+                    ),
+                ],
+            }
+        else:
+            # First turn that supplies system audio: append its codes to the
+            # caller's existing system message (the caller's flag tracks this).
+            if chat_ctx["added_sysaudio"] is False and sys_codes:
+                chat_ctx["added_sysaudio"] = True
+                chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
+
+        # Shallow copy: the message objects are still shared with the caller.
+        prev_messages = chat_ctx["messages"].copy()
+        if user_audio_data is not None:
+            yield FishE2EEvent(
+                type=FishE2EEventType.USER_CODES,
+                vq_codes=user_codes,
+            )
+        else:
+            user_codes = None
+
+        request = ServeRequest(
+            messages=prev_messages
+            + (
+                [
+                    ServeMessage(
+                        role="user",
+                        parts=[ServeVQPart(codes=user_codes)],
+                    )
+                ]
+                if user_codes
+                else []
+            ),
+            streaming=True,
+            num_samples=1,
+        )
+
+        # Step 3: Stream LLM response and decode audio
+        buffer = b""
+        vq_codes = []
+        current_vq = False
+
+        async def decode_send():
+            # Flush accumulated VQ codes: decode them to PCM and emit one
+            # SPEECH_SEGMENT event, then reset the accumulator.
+            nonlocal current_vq
+            nonlocal vq_codes
+
+            data = np.concatenate(vq_codes, axis=1).tolist()
+            # Decode VQ codes to audio
+            decode_request = ServeVQGANDecodeRequest(tokens=[data])
+            decode_response = await self.client.post(
+                f"{self.vqgan_url}/v1/vqgan/decode",
+                data=ormsgpack.packb(
+                    decode_request,
+                    option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+                ),
+                headers={"Content-Type": "application/msgpack"},
+            )
+            decode_data = ormsgpack.unpackb(decode_response.content)
+
+            # Convert float16 audio data to int16
+            audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
+            audio_data = (audio_data * 32768).astype(np.int16).tobytes()
+
+            audio_frame = CustomAudioFrame(
+                data=audio_data,
+                samples_per_channel=len(audio_data) // 2,
+                sample_rate=44100,
+                num_channels=1,
+            )
+            yield FishE2EEvent(
+                type=FishE2EEventType.SPEECH_SEGMENT,
+                frame=audio_frame,
+                vq_codes=data,
+            )
+
+            current_vq = False
+            vq_codes = []
+
+        async with self.client.stream(
+            "POST",
+            self.llm_url,
+            data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+            headers={"Content-Type": "application/msgpack"},
+        ) as response:
+
+            async for chunk in response.aiter_bytes():
+                buffer += chunk
+
+                # The stream is framed: a 4-byte unsigned length prefix,
+                # then a msgpack body of that size.
+                while len(buffer) >= 4:
+                    read_length = struct.unpack("I", buffer[:4])[0]
+                    if len(buffer) < 4 + read_length:
+                        break
+
+                    body = buffer[4 : 4 + read_length]
+                    buffer = buffer[4 + read_length :]
+                    data = ormsgpack.unpackb(body)
+
+                    if data["delta"] and data["delta"]["part"]:
+                        # A text part arriving while VQ codes are pending means
+                        # the speech segment ended — flush it first.
+                        if current_vq and data["delta"]["part"]["type"] == "text":
+                            async for event in decode_send():
+                                yield event
+                        if data["delta"]["part"]["type"] == "text":
+                            yield FishE2EEvent(
+                                type=FishE2EEventType.TEXT_SEGMENT,
+                                text=data["delta"]["part"]["text"],
+                            )
+                        elif data["delta"]["part"]["type"] == "vq":
+                            vq_codes.append(np.array(data["delta"]["part"]["codes"]))
+                            current_vq = True
+
+        # Flush any trailing speech segment after the stream closes.
+        if current_vq and vq_codes:
+            async for event in decode_send():
+                yield event
+
+        yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
+        yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
+
+
+# Example usage:
+async def main():
+ import torchaudio
+
+ agent = FishE2EAgent()
+
+ # Replace this with actual audio data loading
+ with open("uz_story_en.m4a", "rb") as f:
+ audio_data = f.read()
+
+ audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
+ audio_data = (audio_data.numpy() * 32768).astype(np.int16)
+
+ stream = agent.stream(audio_data, sample_rate, 1)
+ if os.path.exists("audio_segment.wav"):
+ os.remove("audio_segment.wav")
+
+ async for event in stream:
+ if event.type == FishE2EEventType.SPEECH_SEGMENT:
+ # Handle speech segment (e.g., play audio or save to file)
+ with open("audio_segment.wav", "ab+") as f:
+ f.write(event.frame.data)
+ elif event.type == FishE2EEventType.ASR_RESULT:
+ print(event.text, flush=True)
+ elif event.type == FishE2EEventType.TEXT_SEGMENT:
+ print(event.text, flush=True, end="")
+ elif event.type == FishE2EEventType.END_OF_TEXT:
+ print("\nEnd of text reached.")
+ elif event.type == FishE2EEventType.END_OF_SPEECH:
+ print("End of speech reached.")
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    # Run the example end-to-end conversation against a local server.
+    asyncio.run(main())
diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc5ef120cce2e04b24f0f897e49f022cb1946c97
--- /dev/null
+++ b/tools/llama/build_dataset.py
@@ -0,0 +1,169 @@
+import itertools
+import os
+import re
+from collections import defaultdict
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
+from tools.file import load_filelist
+
+# Limit BLAS/OpenMP threading so the multiprocessing pool below does not
+# oversubscribe the CPU (each worker would otherwise spawn its own threads).
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+
+def task_generator_folder(root: Path, text_extension: str):
+    """Yield (speaker, [(npy_file, [texts]), ...], "folder") groups, one per
+    directory under `root` that contains .npy semantic-token files.
+
+    `text_extension` may be a single extension or an iterable of them; the
+    matching text file(s) must sit next to each .npy file.
+    """
+    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
+    files = sorted(files)
+
+    grouped_files = defaultdict(list)
+    for file in tqdm(files, desc=f"Grouping {root}"):
+        p = str(file.parent)
+        speaker = file.parent.name
+
+        try:
+            if isinstance(text_extension, str):
+                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
+            else:
+                texts = [
+                    file.with_suffix(ext).read_text(encoding="utf-8")
+                    for ext in text_extension
+                ]
+        except Exception as e:
+            logger.error(f"Failed to read text {file}: {e}")
+            continue
+
+        grouped_files[p].append((speaker, file, texts))
+
+    logger.info(
+        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
+    )
+
+    for i in grouped_files.values():
+        # All entries in a folder share the same parent dir; use the first
+        # entry's speaker name for the whole group.
+        subset = [(f, t) for _, f, t in i]
+        yield i[0][0], subset, "folder"
+
+
+def task_generator_filelist(filelist):
+ grouped_files = defaultdict(list)
+ for filename, speaker, _, text in load_filelist(filelist):
+ grouped_files[speaker].append((Path(filename), [text]))
+
+ logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+ for speaker, values in grouped_files.items():
+ yield speaker, values, "filelist"
+
+
+def run_task(task):
+    """Worker: build one serialized TextData proto stream for a
+    (name, subset, source) task produced by the generators above.
+    """
+    name, subset, source = task
+
+    # Parse the files
+    sentences = []
+    for file, texts in subset:
+        np_file = file.with_suffix(".npy")
+        if np_file.exists() is False:
+            logger.warning(f"Can't find {np_file}")
+            continue
+
+        new_texts = []
+
+        for text in texts:
+            # Simple cleaning: replace { xxx } and < xxx > with space
+            text = re.sub(r"\{.*?\}", " ", text)
+            text = re.sub(r"<.*?>", " ", text)
+            text = re.sub(r"\s+", " ", text)
+            new_texts.append(text)
+
+        try:
+            semantics = np.load(np_file)
+        except Exception as e:
+            logger.error(f"Failed to parse {file}: {e}")
+            continue
+
+        if isinstance(semantics, np.ndarray):
+            semantics = semantics.tolist()
+
+        sentences.append(
+            Sentence(
+                texts=new_texts,
+                semantics=[Semantics(values=s) for s in semantics],
+            )
+        )
+
+    # Pack the sentences
+    return pack_pb_stream(
+        TextData(
+            source=source,
+            name=name,
+            sentences=sentences,
+        )
+    )
+
+
+@click.command()
+@click.option(
+    "--input",
+    type=click.Path(path_type=Path),
+    required=True,
+    help="A folder containing the dataset or a filelist",
+    multiple=True,
+)
+@click.option(
+    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
+)
+@click.option("--num-workers", type=int, default=16)
+@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
+@click.option(
+    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
+)
+def main(input, output, num_workers, text_extension, shard_size):
+    """Build sharded .protos dataset files from folders / filelists of
+    quantized (.npy) samples, processing groups in a worker pool.
+    """
+    generator_fns = []
+
+    for f in input:
+        assert f.exists(), f"{f} not found"
+
+        if f.is_dir():
+            generator_fn = task_generator_folder(f, text_extension)
+        else:
+            generator_fn = task_generator_filelist(f)
+
+        generator_fns.append(generator_fn)
+
+    generator_fn = itertools.chain(*generator_fns)
+    output.mkdir(parents=True, exist_ok=True)
+
+    dataset_fp = None
+    tar_idx = 0
+    written_size = 0
+
+    with Pool(num_workers) as p:
+        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
+            # Lazily open the next shard on first write.
+            if dataset_fp is None:
+                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
+
+            dataset_fp.write(result)
+            written_size += len(result)
+
+            # Rotate shards once the size budget is exceeded.
+            if written_size > shard_size * 1024 * 1024:
+                logger.info(f"Finished writing {tar_idx} shards to {output}")
+                dataset_fp.close()
+                dataset_fp = None
+                written_size = 0
+                tar_idx += 1
+
+    if dataset_fp is not None:
+        dataset_fp.close()
+
+    # NOTE(review): if the last shard was closed by the rotation above, this
+    # overcounts by one (tar_idx shards were written, not tar_idx + 1).
+    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/llama/eval_in_context.py b/tools/llama/eval_in_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d70940487388185381246d8210a49a58e55743
--- /dev/null
+++ b/tools/llama/eval_in_context.py
@@ -0,0 +1,171 @@
+import pyrootutils
+import torch
+import torch.nn.functional as F
+from matplotlib import pyplot as plt
+from transformers import AutoTokenizer
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from torch.utils.data import DataLoader
+
+from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
+from tools.llama.generate import load_model
+
+
+def smooth(
+ scalars: list[float], weight: float
+) -> list[float]: # Weight between 0 and 1
+ last = scalars[0] # First value in the plot (first timestep)
+ smoothed = list()
+ for point in scalars:
+ smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
+ smoothed.append(smoothed_val) # Save it
+ last = smoothed_val # Anchor the last smoothed value
+
+ return smoothed
+
+
+@torch.inference_mode()
+def analyze_one_model(loader, config, weight, max_length):
+    """Compute mean per-frame semantic (codebook) loss for one checkpoint.
+
+    Returns (xs, ys, smoothed_ys): frame indices with at least one valid
+    sample, their mean loss, and an EMA-smoothed copy of ys. Only the first
+    10 batches of `loader` are evaluated.
+    """
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = load_model(
+        config,
+        weight,
+        device,
+        torch.bfloat16,
+        max_length,
+        compile=False,
+    )[0]
+
+    current_step = 0
+    model.eval()
+
+    # Running per-frame loss sum and count of valid (non-padding) samples.
+    semantic_loss_sum = torch.zeros(
+        max_length,
+        dtype=torch.float32,
+        device=device,
+    )
+    counter = torch.zeros(
+        max_length,
+        dtype=torch.long,
+        device=device,
+    )
+
+    for batch in loader:
+        batch = {k: v.to(device) for k, v in batch.items()}
+
+        labels = batch["labels"]
+        outputs = model(
+            inp=batch["inputs"],
+            key_padding_mask=batch["attention_masks"],
+        )
+
+        token_logits = outputs.token_logits
+        codebook_logits = outputs.codebook_logits
+
+        # Generate labels
+        base_loss = F.cross_entropy(
+            token_logits.reshape(-1, token_logits.size(-1)),
+            labels[:, 0].reshape(-1),
+            ignore_index=-100,
+            reduction="none",
+        )
+
+        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
+        semantic_loss = F.cross_entropy(
+            codebook_logits.reshape(-1, codebook_logits.size(-1)),
+            codebook_labels.reshape(-1),
+            ignore_index=-100,
+            reduction="none",
+        )
+
+        base_loss = base_loss.reshape(labels[:, 0].shape)
+        semantic_loss = semantic_loss.reshape(codebook_labels.shape)
+
+        semantic_loss_frame = semantic_loss.mean(-1)
+        # NOTE(review): padding is detected by comparing the label sum to
+        # -100 * num_codebooks — a real label combination could collide with
+        # that value; confirm this is acceptable.
+        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
+
+        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
+            semantic_loss_sum[~pad] += loss_sample[~pad]
+            counter[~pad] += 1
+
+        current_step += 1
+        if current_step == 10:
+            break
+
+    # NOTE(review): `semantic_loss` is moved to CPU but unused afterwards;
+    # the aggregation below reads semantic_loss_sum / counter.
+    semantic_loss = semantic_loss.cpu()
+    counter = counter.cpu()
+    xs, ys = [], []
+
+    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
+        if count > 0:
+            xs.append(i)
+            ys.append((loss / count).item())  # for better loss visualization
+
+    smoothed_ys = smooth(ys, 0.95)
+
+    # Unload model
+    del model
+    torch.cuda.empty_cache()
+
+    return xs, ys, smoothed_ys
+
+
+def main():
+    """Plot smoothed per-frame semantic loss for several checkpoints and save
+    the figure to semantic_loss.png.
+    """
+    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+    max_length = 4096
+
+    ds = AutoAugTextDataset(
+        ["data/protos/sft/云天河"],
+        tokenizer=tokenizer,
+        use_speaker=False,
+        interactive_prob=1.0,
+        max_length=max_length,
+    )
+
+    loader = DataLoader(
+        ds,
+        batch_size=8,
+        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
+        num_workers=0,
+        shuffle=False,
+    )
+
+    plt.figure(figsize=(10, 5), dpi=200)
+
+    plt.xlabel("Frame")
+    plt.ylabel("Loss")
+    plt.yscale("log")
+    plt.title("Semantic Loss")
+    plt.grid(which="both", axis="both")
+    plt.xlim(0, max_length)
+
+    # (label, model config name, checkpoint path) triples to compare.
+    tests = [
+        (
+            "pertrain-medium",
+            "dual_ar_2_codebook_medium",
+            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
+        ),
+        (
+            "sft-medium",
+            "dual_ar_2_codebook_medium",
+            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+        ),
+        (
+            "sft-large",
+            "dual_ar_2_codebook_large",
+            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
+        ),
+    ]
+
+    for name, config, weight in tests:
+        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
+        plt.plot(xs, smoothed_ys, label=name)
+
+    plt.legend()
+    plt.savefig("semantic_loss.png")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/llama/generate.py b/tools/llama/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2beba97ad04f1895b342a0ae66328834f5c9e5d
--- /dev/null
+++ b/tools/llama/generate.py
@@ -0,0 +1,1087 @@
+import os
+import queue
+import threading
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal, Optional, Tuple, Union
+
+import click
+import hydra
+import numpy as np
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from loguru import logger
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import BaseModelArgs
+from fish_speech.text import clean_text, split_text
+
# Silence HF tokenizers fork-related parallelism warnings in worker threads.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor knobs: extra autotuning for faster kernels, stable kernel names for profiling.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
+
+
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from fish_speech.models.text2semantic.llama import (
+ BaseTransformer,
+ DualARTransformer,
+ NaiveTransformer,
+)
+
+
def multinomial_sample_one_no_sync(probs_sort):
    """Draw one categorical sample from ``probs_sort`` along the last dim.

    Divides the probabilities by i.i.d. Exponential(1) noise and takes the
    argmax — equivalent to multinomial sampling, but without the CUDA
    synchronization that ``torch.multinomial`` would trigger.
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    winner = torch.argmax(probs_sort / noise, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
+
+
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Convert a 1-D logits vector to sampling probabilities.

    Applies, in order: repetition penalty over ``previous_tokens``, nucleus
    (top-p) filtering, and temperature scaling, then a softmax. Note that
    the repetition penalty writes back into ``logits`` in place.
    """
    if previous_tokens is not None:
        # Push already-emitted tokens toward "less likely": divide positive
        # scores by the penalty, multiply negative ones.
        prev = previous_tokens.long()
        penalized = torch.gather(logits, dim=0, index=prev)
        penalized = torch.where(
            penalized < 0,
            penalized * repetition_penalty,
            penalized / repetition_penalty,
        )
        logits.scatter_(dim=0, index=prev, src=penalized)

    # Nucleus filtering: mask everything past the top_p cumulative mass.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
    )
    remove_sorted = cumulative > top_p
    remove_sorted[0] = False  # keep at least one option
    remove_mask = remove_sorted.scatter(
        dim=0, index=sorted_indices, src=remove_sorted
    )
    filtered = logits.masked_fill(remove_mask, -float("Inf"))

    # Temperature scaling, clamped away from zero to avoid blow-up.
    filtered = filtered / max(temperature, 1e-5)

    return torch.nn.functional.softmax(filtered, dim=-1)
+
+
def multinomial_sample_one_no_sync_agent(probs_sort):
    """Batched counterpart of :func:`multinomial_sample_one_no_sync`.

    Samples one index per row via the exponential-race trick, avoiding a
    CUDA sync; returns int indices with the sampled dim kept.
    """
    exp_noise = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / exp_noise, dim=-1, keepdim=True).to(
        dtype=torch.int
    )
+
+
def logits_to_probs_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Batched variant of :func:`logits_to_probs`, operating on the last dim.

    Same pipeline — repetition penalty (in place on ``logits``), top-p
    filtering, temperature, softmax — but every op uses ``dim=-1`` so it
    works on [batch, vocab] tensors.
    """
    if previous_tokens is not None:
        prev = previous_tokens.long()
        # Penalize repeats per row: shrink positive scores, amplify negatives.
        penalized = torch.gather(logits, dim=-1, index=prev)
        penalized = torch.where(
            penalized < 0,
            penalized * repetition_penalty,
            penalized / repetition_penalty,
        )
        logits.scatter_(dim=-1, index=prev, src=penalized)

    # Top-p (nucleus) filtering along the vocab dimension.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
    )
    remove_sorted = cumulative > top_p
    remove_sorted[..., 0] = False  # keep at least one option
    remove_mask = remove_sorted.scatter(
        dim=-1, index=sorted_indices, src=remove_sorted
    )
    filtered = logits.masked_fill(remove_mask, -float("Inf"))

    # Clamp temperature away from zero before scaling.
    filtered = filtered / max(temperature, 1e-5)

    return torch.nn.functional.softmax(filtered, dim=-1)
+
+
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample the next token from the last position of a [1, T, vocab] logits tensor.

    Returns ``(token_index, probabilities)``.
    """
    last_step_probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    next_index = multinomial_sample_one_no_sync(last_step_probs)
    return next_index, last_step_probs
+
+
def sample_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batched sampling: one token per row from the last position of [B, T, vocab] logits.

    Returns ``(token_indices, probabilities)``.
    """
    batch_probs = logits_to_probs_agent(
        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    next_indices = multinomial_sample_one_no_sync_agent(batch_probs)
    return next_indices, batch_probs
+
+
def decode_one_token_ar_agent(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    semantic_id: int = 32003,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one batched step of a dual-AR model.

    Samples the main (text) token from the slow transformer, then samples each
    codebook autoregressively with the fast transformer. Returns a
    [batch, num_codebooks + 1, 1] tensor; codebook rows are overwritten with
    CODEBOOK_PAD_TOKEN_ID wherever the main token is not ``semantic_id``.
    """
    # print(x, input_pos)
    x = model.forward_generate(x, input_pos)
    logits = x.logits  # [:, -1:]
    hidden_states = x.hidden_states  # [:, -1:]

    # The main token is sampled near-greedily with hard-coded conservative
    # settings; the caller's sampling_kwargs only shape the codebook sampling.
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample_agent(
            logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    # Cleanup the cache: the fast transformer restarts from position 0 each step.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample_agent(
            logits,
            previous_tokens=(
                previous_tokens[:, codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        # Feed the sampled codebook token back in as the next fast-layer input.
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Codebook values only carry meaning on <|semantic|> steps; pad the rest.
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
    )

    # for i in range(codebooks.size(1) - 1):
    #     codebooks[:, i + 1, :] = torch.masked_fill(
    #         codebooks[:, i + 1, :],
    #         codebooks[:, :1, :] != semantic_id,
    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
    #     )

    # print(codebooks)

    return codebooks
+
+
def decode_one_token_naive_agent(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    semantic_id: int = 32003,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one batched step of a naive transformer.

    A single forward pass yields both the main-token logits and all codebook
    logits; one token is sampled per codebook. Returns a
    [batch, num_codebooks + 1, 1] tensor with codebook rows padded with
    CODEBOOK_PAD_TOKEN_ID on non-``semantic_id`` steps.
    """
    x = model.forward_generate(x, input_pos)

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample_agent(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[:, i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    codebooks = torch.stack(codebooks, dim=1)
    # Mask codebook rows where the main token is not <|semantic|>.
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
    )

    return codebooks
+
+
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    semantic_id: int = 0,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step (batch size 1) of a dual-AR model.

    Samples the main token from the slow transformer, then each codebook from
    the fast transformer. Returns a [num_codebooks + 1, 1] tensor; codebook
    rows are padded with CODEBOOK_PAD_TOKEN_ID where the main token is not
    ``semantic_id``.
    """
    x = model.forward_generate(x, input_pos)

    sampling_kwargs_main = sampling_kwargs.copy()
    # sampling_kwargs_main["temperature"] = 0.1
    # sampling_kwargs_main["top_p"] = 0.1
    # sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    x = x.hidden_states

    # Cleanup the cache: the fast transformer restarts from position 0 each step.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
        logits = model.forward_generate_fast(x, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        # Sampled codebook token becomes the next fast-transformer input.
        x = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=0)
    # Mask codebook rows on steps where the main token is not <|semantic|>.
    codebooks[1:, :] = torch.masked_fill(
        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
    )

    return codebooks
+
+
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    semantic_id: int = 0,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step (batch size 1) of a naive transformer.

    Fix: ``generate``/``decode_n_tokens`` always pass ``semantic_id=...``;
    previously it was swallowed into ``**sampling_kwargs``, forwarded to
    ``sample`` and on to ``logits_to_probs``, which raised TypeError. It is
    now an explicit parameter and — consistent with the other
    decode_one_token_* variants — used to pad codebook rows with
    CODEBOOK_PAD_TOKEN_ID on non-semantic steps.

    Returns a [num_codebooks + 1, 1] tensor of sampled tokens.
    """
    x = model.forward_generate(x, input_pos)

    # Main token is sampled near-greedily with hard-coded settings; the
    # caller's sampling_kwargs only shape the codebook sampling.
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    codebooks = torch.stack(codebooks, dim=0)
    # Match the ar/agent variants: codebook rows only carry meaning when the
    # main token is <|semantic|>; pad them otherwise.
    codebooks[1:, :] = torch.masked_fill(
        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
    )

    return codebooks
+
+
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    semantic_id: int = 0,
    **sampling_kwargs,
):
    """Autoregressively decode up to ``num_new_tokens`` steps (batch size 1).

    Maintains a [num_codebooks + 1, max_seq_len] history so the repetition
    penalty can look at a 16-token sliding window; stops early when the main
    token equals ``im_end_id``. Returns the history truncated to the tokens
    actually generated.
    """
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in newer
        # torch; the agent path already uses torch.nn.attention.sdpa_kernel —
        # consider unifying.
        with (
            torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_mem_efficient=False, enable_math=True
            )
            if torch.cuda.is_available()
            else nullcontext()
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_id=semantic_id,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        # Early exit once the model emits the end-of-message token.
        if cur_token[0, 0, -1] == im_end_id:
            break

    return previous_tokens[:, : i + 1]
+
+
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    ``prompt`` is [1 + num_codebooks, T]. Returns the prompt followed by the
    generated tokens in the same layout. Generation is capped by
    ``max_new_tokens`` (0 means "fill to max_seq_len") and stops early on the
    first <|im_end|> main token.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    # Prefill: run the whole prompt once and sample the first new token.
    next_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        semantic_id=semantic_id,
        **sampling_kwargs,
    )
    seq[:, T : T + 1] = next_token

    # Incremental decode for the remaining tokens, one position at a time.
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        semantic_id=semantic_id,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
+
+
def decode_n_tokens_agent(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    im_end_id: int = 4,
    semantic_id: int = 32003,
    decode_one_token=decode_one_token_naive_agent,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """Batched decode loop; yields each step's [B, num_codebooks + 1, 1] token on CPU.

    Tracks which batch rows have emitted <|im_end|> and stops once all have
    finished, or once at least ``early_stop_threshold`` (a fraction in (0, 1))
    of the batch has. Logs the aggregate tokens/s at the end.
    """
    batch_size = cur_token.size(0)
    previous_tokens = torch.zeros(
        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )
    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
    # Some rows may already have ended during prefill.
    finished = finished | (cur_token[:, 0, -1] == im_end_id)
    start_time = time.time()

    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :, :win_size]
        else:
            window = previous_tokens[:, :, i - win_size : i]

        with sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_id=semantic_id,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
        previous_tokens[:, :, i : i + 1] = next_token.view(
            batch_size, model.config.num_codebooks + 1, -1
        )

        # Stream the step to the consumer before checking stop conditions.
        yield cur_token.cpu()

        finished = finished | (cur_token[:, 0, -1] == im_end_id)
        if finished.all() or (
            0 < early_stop_threshold < 1
            and finished.sum() >= round(batch_size * early_stop_threshold)
        ):
            break

    total_time = time.time() - start_time
    generated_tokens = i + 1
    tokens_per_second = (generated_tokens / total_time) * batch_size
    logger.info(
        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
    )
+
+
@torch.no_grad()
@torch.inference_mode()
def generate_agent(
    *,
    model: BaseTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    im_end_id: int = 4,
    semantic_id: int = 32003,
    decode_one_token=decode_one_token_naive_agent,
    num_samples: int = 1,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Generator variant: the [1 + num_codebooks, T] prompt is replicated
    ``num_samples`` times along a batch dim, and each decoded step is yielded
    to the caller (on CPU) as it is produced.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive_agent
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar_agent
    )
    next_token = prefill_decode(
        model,
        prompt,
        input_pos,
        semantic_id=semantic_id,
        **sampling_kwargs,
    ).view(num_samples, codebook_dim, -1)
    yield next_token.cpu()

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    # Stream the remaining tokens from the incremental decode loop.
    yield from decode_n_tokens_agent(
        model,
        next_token,
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        semantic_id=semantic_id,
        decode_one_token=decode_one_token,
        early_stop_threshold=early_stop_threshold,
        **sampling_kwargs,
    )
+
+
def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    """Encode one user turn into the [1 + num_codebooks, T] prompt layout.

    Row 0 holds text-token ids; rows 1..num_codebooks hold codebook values
    (CODEBOOK_PAD_TOKEN_ID on text-only positions). When ``prompt_tokens``
    (a [num_codebooks, T'] code prompt) is given, it is appended as an
    assistant reply of <|semantic|> tokens terminated by <|im_end|>.
    """
    string = clean_text(string)
    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"

    new_tokens = tokenizer.encode(
        string,
        add_special_tokens=False,
        max_length=10**6,
        truncation=False,
    )
    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = (
        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
        * CODEBOOK_PAD_TOKEN_ID
    )
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]

    assert prompt_tokens.ndim == 2
    # Shift codes by +1; generate_long undoes this with `codes - 1` on output.
    data = prompt_tokens + 1

    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
        )
        data = data[:num_codebooks]

    # Add pad token for each codebook
    data = torch.cat(
        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
        dim=1,
    )

    # Since 1.0, we use <|semantic|>
    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # Main row: <|semantic|> markers over the code span, closed by <|im_end|>.
    main_token_ids = (
        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
    )
    main_token_ids[0, -1] = end_token_id

    data = torch.cat((main_token_ids, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt
+
+
def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
    """Load a text2semantic checkpoint and choose the matching one-step decoder.

    Returns ``(model.eval(), decode_one_token)``. The decoder is picked by
    model class (dual-AR vs naive) and mode (agent vs plain), and optionally
    wrapped in torch.compile (inductor on CUDA, aot_eager elsewhere).
    """
    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True
    )

    model = model.to(device=device, dtype=precision)
    logger.info(f"Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = (
            decode_one_token_ar_agent if is_agent else decode_one_token_ar
        )
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = (
            decode_one_token_naive_agent if is_agent else decode_one_token_naive
        )
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token,
            fullgraph=True,
            backend="inductor" if torch.cuda.is_available() else "aot_eager",
            mode="reduce-overhead" if torch.cuda.is_available() else None,
        )

    return model.eval(), decode_one_token
+
+
@dataclass
class GenerateResponse:
    """One streamed item from generate_long: a sampled segment or an end-of-sample marker."""

    # "sample" carries codes/text for one decoded segment; "next" separates samples.
    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None
+
+
def generate_long(
    *,
    model,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    prompt_text: Optional[str | list[str]] = None,
    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
):
    """Generate semantic codes for ``text``, yielding GenerateResponse items.

    The text is split into chunks (when ``iterative_prompt``); each chunk is
    decoded with a rolling context of previously generated segments, so long
    inputs stay within ``max_length``. For every decoded segment a
    ``GenerateResponse(action="sample")`` is yielded; each finished sample
    ends with ``GenerateResponse(action="next")``.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    use_prompt = prompt_text is not None and prompt_tokens is not None
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]

    assert use_prompt is False or len(prompt_text) == len(
        prompt_tokens
    ), "Prompt text and tokens must have the same length"

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]
    encoded_prompts = []

    if use_prompt:
        # Reference (voice) prompts are encoded once and prepended to every chunk.
        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
            encoded_prompts.append(
                encode_tokens(
                    tokenizer,
                    string=t,
                    device=device,
                    prompt_tokens=c,
                    num_codebooks=model.config.num_codebooks,
                )
            )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                device=device,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            # Walk history newest-first to decide how much context fits.
            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024 - sum(
                    t.shape[1] for t in encoded_prompts
                ):
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = encoded_prompts + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Emit only the generated codebook rows, dropping the trailing
            # end-of-message token; undo the +1 shift from encode_tokens.
            codes = y[1:, prompt_length:-1].clone()
            codes = codes - 1
            assert (codes >= 0).all(), f"Negative code found"

            decoded = y[:, prompt_length:-1].clone()
            # But for global encoding, we should keep the token

            global_encoded.append(decoded)
            assert (codes >= 0).all(), f"Negative code found: {codes}"
            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

        # This indicates the end of the current sample
        yield GenerateResponse(action="next")
+
+
@dataclass
class WrappedGenerateResponse:
    """Queue envelope for a worker result: the payload or the raised exception."""

    # "success" wraps a GenerateResponse; "error" wraps the Exception.
    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None
+
+
@dataclass
class GenerateRequest:
    """A generate_long kwargs dict plus the queue to stream results back on."""

    # Keyword arguments forwarded to generate_long (model/decoder are added by the worker).
    request: dict
    response_queue: queue.Queue
+
+
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker thread owning the model and serving GenerateRequest items.

    Blocks until the model is loaded, then returns the request queue. Results
    are streamed back as WrappedGenerateResponse on each request's own
    response_queue; enqueueing ``None`` stops the worker.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        # All model state lives on this thread; callers only touch the queues.
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                # Report the failure to the requester instead of killing the worker.
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue
+
+
def launch_thread_safe_queue_agent(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Agent variant of launch_thread_safe_queue.

    Serves generate_agent requests; raw token tensors are streamed to the
    response queue, terminated by the literal string "stop" on success or
    "error" on failure. Also returns the tokenizer and model config so the
    caller can build prompts.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    config = BaseModelArgs.from_pretrained(checkpoint_path)

    def worker():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile, is_agent=True
        )

        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for token in generate_agent(
                    model=model,
                    decode_one_token=decode_one_token,
                    **kwargs,
                ):
                    response_queue.put(token)

                response_queue.put("stop")
            except Exception as e:
                import traceback

                logger.exception(f"Error in worker: {traceback.format_exc()}")
                response_queue.put("error")

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue, tokenizer, config
+
+
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.4",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=100)
def main(
    text: str,
    prompt_text: Optional[list[str]],
    prompt_tokens: Optional[list[Path]],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
) -> None:
    """CLI: generate semantic codes for `--text` and save them as codes_N.npy per sample."""

    precision = torch.half if half else torch.bfloat16

    # NOTE(review): with multiple=True click passes tuples, so prompt_text is
    # never None here — the length check is the part that actually matters.
    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    if prompt_tokens is not None:
        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]

    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    idx = 0
    codes = []

    # Collect segments until a "next" marker, then dump one codes_N.npy per sample.
    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to codes_{idx}.npy")
            logger.info(f"Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()
diff --git a/tools/llama/merge_lora.py b/tools/llama/merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1bd3cbd725c4eccbe78f711d9718dfb278a6aa7
--- /dev/null
+++ b/tools/llama/merge_lora.py
@@ -0,0 +1,95 @@
+import shutil
+from copy import deepcopy
+from pathlib import Path
+
+import click
+import hydra
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+
+from fish_speech.models.text2semantic.llama import BaseTransformer
+from fish_speech.models.text2semantic.lora import get_merged_state_dict
+
+
@click.command()
@click.option("--lora-config", type=str, default="r_8_alpha_16")
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
@click.option("--lora-weight", type=str, required=True)
@click.option("--output", type=str, required=True)
def merge(lora_config, base_weight, lora_weight, output):
    """Merge LoRA weights into a base checkpoint and validate the merged output.

    Loads the base model with the LoRA config attached, overlays the LoRA
    state dict, saves with the adapters folded in, then re-loads the saved
    weights and checks they actually differ from the original base weights.
    """
    output = Path(output)
    logger.info(
        f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
    )

    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
        cfg = compose(config_name=lora_config)

    lora_config = instantiate(cfg)
    logger.info(f"Loaded lora model with config {lora_config}")

    llama_model = BaseTransformer.from_pretrained(
        path=base_weight,
        load_weights=True,
        lora_config=lora_config,
    )
    logger.info(f"Loaded llama model")

    llama_state_dict = llama_model.state_dict()
    # Keep only base weights; a copy is retained for the post-merge diff check.
    llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
    llama_state_dict_copy = deepcopy(llama_state_dict)
    lora_state_dict = torch.load(lora_weight, map_location="cpu")

    # Unwrap Lightning-style checkpoints that nest weights under "state_dict".
    if "state_dict" in llama_state_dict:
        llama_state_dict = llama_state_dict["state_dict"]

    if "state_dict" in lora_state_dict:
        lora_state_dict = lora_state_dict["state_dict"]

    # remove prefix model.
    if any(k.startswith("model.") for k in llama_state_dict.keys()):
        llama_state_dict = {
            k.replace("model.", ""): v
            for k, v in llama_state_dict.items()
            if k.startswith("model.")
        }
    if any(k.startswith("model.") for k in lora_state_dict.keys()):
        lora_state_dict = {
            k.replace("model.", ""): v
            for k, v in lora_state_dict.items()
            if k.startswith("model.")
        }

    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
    logger.info(f"Found {len(lora_state_dict)} keys in lora model")

    # LoRA entries take precedence over base entries with the same key.
    merged_state_dict = llama_state_dict | lora_state_dict
    llama_model.load_state_dict(merged_state_dict, strict=True)
    logger.info(f"Merged model loaded")

    # Trigger eval mode to merge lora
    llama_model.eval()
    llama_model.save_pretrained(output, drop_lora=True)
    logger.info(f"Saved merged model to {output}, validating")

    new_state_dict = torch.load(output / "model.pth", map_location="cpu")
    original_keys = set(llama_state_dict_copy.keys())
    merged_keys = set(new_state_dict.keys())

    assert original_keys == merged_keys, "Keys should be same"

    # At least one tensor must have changed, otherwise the merge was a no-op.
    for key in original_keys:
        diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
        if diff_l1 != 0:
            break
    else:
        logger.error("Merged model is same as the original model")
        exit(1)

    logger.info("Merged model is different from the original model, check passed")


if __name__ == "__main__":
    merge()
diff --git a/tools/llama/quantize.py b/tools/llama/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e629d944b5d1e262f6c0517480980fcac01dad86
--- /dev/null
+++ b/tools/llama/quantize.py
@@ -0,0 +1,497 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+import datetime
+import shutil
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import time
+from pathlib import Path
+
+import click
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fish_speech.models.text2semantic.llama import find_multiple
+from tools.llama.generate import load_model
+
+##### Quantization Primitives ######
+
+
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    """Symmetrically quantize a dense 2-D tensor per channel (axis 0).

    Returns ``(quant, scales, zero_points)`` where ``quant`` is the quantized
    tensor in ``target_dtype``, ``scales`` has one entry per row (same dtype
    as ``x``), and ``zero_points`` is all zeros (symmetric scheme).
    TODO(future): relax the axis==0 / dense-memory assumptions as needed.
    """
    eps = torch.finfo(torch.float32).eps

    # Per-row extrema, then force zero into the representable range so the
    # symmetric scale always covers 0.
    min_val, max_val = torch.aminmax(x, dim=1)
    min_val_neg = torch.clamp(min_val, max=0.0)
    max_val_pos = torch.clamp(max_val, min=0.0)

    # Widest magnitude per row defines the symmetric scale.
    magnitude = torch.maximum(-min_val_neg, max_val_pos)
    scales = magnitude / (float(quant_max - quant_min) / 2)
    # Keep scales strictly positive and in the original dtype.
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros_like(min_val_neg, dtype=torch.int64)

    # Quantize: scale, round, shift by (zero) zero-points, clamp, cast.
    shifted = torch.round(x / scales.unsqueeze(-1)) + zero_points.unsqueeze(-1)
    quant = torch.clamp(shifted, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points
+
+
def get_group_qparams(w, n_bit=4, groupsize=128):
    """Compute per-group affine qparams for 2-D weight ``w``.

    Returns ``(scales, zeros)`` in bfloat16, each of shape
    ``(w.shape[0], num_groups_per_row)``.
    """
    # GPTQ padding case: a row shorter than the group becomes one group.
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert not torch.isnan(grouped).any()

    hi = grouped.amax(dim=1, keepdim=True)
    lo = grouped.amin(dim=1, keepdim=True)
    levels = 2**n_bit - 1
    # Clamp avoids a zero scale for constant groups.
    scales = (hi - lo).clamp(min=1e-6) / levels
    # Zero-point expressed so that mid-range code 2^(n_bit-1) maps to it.
    zeros = lo + scales * (2 ** (n_bit - 1))

    rows = w.shape[0]
    return (
        scales.to(torch.bfloat16).reshape(rows, -1),
        zeros.to(torch.bfloat16).reshape(rows, -1),
    )
+
+
def pack_scales_and_zeros(scales, zeros):
    """Interleave ``scales`` and ``zeros`` into one ``(groups, rows, 2)`` tensor.

    Both inputs must be bfloat16 with identical shape ``(rows, groups)``.
    """
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16

    rows, groups = scales.size(0), scales.size(1)
    stacked = torch.cat(
        [scales.reshape(rows, groups, 1), zeros.reshape(rows, groups, 1)],
        2,
    )
    # Layout expected downstream: group-major, then row, then (scale, zero).
    return stacked.transpose(0, 1).contiguous()
+
+
def unpack_scales_and_zeros(scales_and_zeros):
    """Split a packed ``(groups, rows, 2)`` float tensor back into (scales, zeros).

    NOTE(review): this asserts float32 input while ``pack_scales_and_zeros``
    emits bfloat16 -- callers apparently convert first; confirm at call sites.
    """
    assert scales_and_zeros.dim() == 3 and scales_and_zeros.shape[2] == 2
    assert scales_and_zeros.dtype == torch.float
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+
+
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize 2-D ``w`` to int32 codes in ``[0, 2^n_bit - 1]`` using given qparams."""
    assert groupsize > 1
    # GPTQ single-column case: one group spans the whole (short) row.
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]

    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert not torch.isnan(grouped).any()

    scales_col = scales.reshape(-1, 1)
    zeros_col = zeros.reshape(-1, 1)
    # Recover the group minimum from the mid-range zero-point convention.
    min_val = zeros_col - scales_col * (2 ** (n_bit - 1))
    max_int = 2**n_bit - 1
    min_int = 0

    codes = ((grouped - min_val) / scales_col).round().clamp_(min_int, max_int)
    return codes.to(torch.int32).reshape_as(w)
+
+
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Derive qparams for ``w`` and return (int32 codes, packed scales/zeros)."""
    scales, zeros = get_group_qparams(w, n_bit, groupsize)
    codes = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
    return codes, pack_scales_and_zeros(scales, zeros)
+
+
def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Dequantize int32 codes back to real values: ``(code - 2^(n_bit-1)) * scale + zero``."""
    assert groupsize > 1
    # GPTQ single-column case mirrors the quantize side.
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    grouped = w_int32.reshape(-1, groupsize)
    scales_col = scales.reshape(-1, 1)
    zeros_col = zeros.reshape(-1, 1)

    centered = grouped - 2 ** (n_bit - 1)
    return (centered * scales_col + zeros_col).reshape_as(w_int32)
+
+
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Unpack the combined scales/zeros tensor and dequantize ``w_int32``."""
    scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
    return group_dequantize_tensor_from_qparams(
        w_int32, scales, zeros, n_bit, groupsize
    )
+
+
class QuantHandler:
    """Abstract interface for a weight-quantization strategy.

    Concrete handlers produce a quantized ``state_dict`` and rewrite the
    module tree so it can consume those weights at inference time.
    """

    def __init__(self, mod):
        # The module (model) this handler will quantize.
        self.mod = mod

    def create_quantized_state_dict(self) -> "StateDict":
        """Return a quantized state dict; implemented by subclasses."""
        pass

    def convert_for_runtime(self) -> "nn.Module":
        """Swap Linear layers for quantized equivalents; implemented by subclasses."""
        pass
+
+
+##### Weight-only int8 per-channel quantized code ######
+
+
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively replace every ``nn.Linear`` in ``module`` (in place)
    with a ``WeightOnlyInt8Linear`` of the same dimensions."""
    for name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_weight_only_int8_per_channel(child)
            continue
        setattr(
            module,
            name,
            WeightOnlyInt8Linear(child.in_features, child.out_features),
        )
+
+
class WeightOnlyInt8QuantHandler:
    """Weight-only int8 symmetric per-channel quantization strategy."""

    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Quantize every Linear weight to int8 and record per-channel scales."""
        state_dict = self.mod.state_dict()
        for name, submodule in self.mod.named_modules():
            if not isinstance(submodule, torch.nn.Linear):
                continue
            quantized, scales, _ = dynamically_quantize_per_channel(
                submodule.weight.float(), -128, 127, torch.int8
            )
            state_dict[f"{name}.weight"] = quantized
            # Scales stored in the weight's original dtype for runtime use.
            state_dict[f"{name}.scales"] = scales.to(submodule.weight.dtype)

        return state_dict

    def convert_for_runtime(self):
        """Rewrite the module tree to use WeightOnlyInt8Linear layers."""
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
+
+
class WeightOnlyInt8Linear(torch.nn.Module):
    """Linear layer storing an int8 weight plus per-output-channel scales.

    ``forward`` dequantizes on the fly: ``(x @ W_int8.T) * scales``.

    Note: a bias is not supported.  The ``bias``/``device``/``dtype``
    parameters exist only for signature compatibility with ``nn.Linear``
    and are ignored (the previously computed-but-unused ``factory_kwargs``
    dead code has been removed).
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Buffers, not Parameters: loaded from a quantized state dict and
        # never trained.
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Cast the int8 weight up to the activation dtype, matmul, then
        # rescale each output channel.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+
+
+##### weight only int4 per channel groupwise quantized code ######
+
+
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Groupwise 4-bit quantize ``weight_bf16`` and pack it for the aten int4 kernel."""
    codes, scales_and_zeros = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    packed = torch.ops.aten._convert_weight_to_int4pack(codes, inner_k_tiles)
    return packed, scales_and_zeros
+
+
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Apply an int4-packed linear layer to ``x``, preserving leading dims."""
    leading_shape = x.size()[:-1]
    flat = x.reshape(-1, x.size(-1))
    out = torch.ops.aten._weight_int4pack_mm(
        flat, weight_int4pack, groupsize, scales_and_zeros
    )
    return out.reshape(leading_shape + (out_features,))
+
+
+def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
+ return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    """Recursively replace eligible ``nn.Linear`` layers with int4 variants.

    A layer is replaced unpadded when its ``in_features`` satisfies the int4
    divisibility constraints; otherwise it is replaced with padding enabled
    when ``padding`` is True, and left untouched when False.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
            continue

        k_ok = _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
        if k_ok or padding:
            setattr(
                module,
                name,
                WeightOnlyInt4Linear(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    groupsize=groupsize,
                    inner_k_tiles=inner_k_tiles,
                    # Pad only when the shape constraints fail.
                    padding=not k_ok,
                ),
            )
+
+
class WeightOnlyInt4QuantHandler:
    """Weight-only int4 affine per-channel groupwise quantization strategy.

    Quantization itself runs on CUDA (the aten int4 packing kernel); the
    resulting tensors are moved back to CPU before being stored.
    """

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        # Constraints imposed by aten._weight_int4pack_mm.
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the model's state dict with each Linear weight replaced by
        packed int4 codes plus a fused scales/zeros tensor."""
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                assert out_features % 8 == 0, "require out_features % 8 == 0"
                print(f"linear: {fqn}, in={in_features}, out={out_features}")

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding:
                        import torch.nn.functional as F

                        print(
                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        # Zero-pad the K dimension up to the next multiple of 1024.
                        padded_in_features = find_multiple(in_features, 1024)
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        # Without padding this layer cannot be int4-packed; skip it.
                        print(
                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                        )
                        continue
                (
                    weight_int4pack,
                    scales_and_zeros,
                ) = prepare_int4_weight_and_scales_and_zeros(
                    weight.to(torch.bfloat16).to("cuda"),
                    self.groupsize,
                    self.inner_k_tiles,
                )
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        """Rewrite the module tree to use WeightOnlyInt4Linear layers."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
+
+
class WeightOnlyInt4Linear(torch.nn.Module):
    """Linear layer whose weight lives in the packed int4 layout consumed by
    ``aten._weight_int4pack_mm`` (groupwise affine quantization, no bias)."""

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        padding: bool = True,
    ) -> None:
        super().__init__()
        self.padding = padding
        if padding:
            # Remember the caller-visible K; pad the stored K up so the int4
            # kernel's divisibility constraints hold.
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"

        # Packed weight shape mandated by aten._convert_weight_to_int4pack.
        packed_shape = (
            out_features // 8,
            in_features // (inner_k_tiles * 16),
            32,
            inner_k_tiles // 2,
        )
        self.register_buffer("weight", torch.empty(packed_shape, dtype=torch.int32))
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # The int4 kernel computes in bfloat16.
        x = input.to(torch.bfloat16)
        if self.padding:
            import torch.nn.functional as F

            x = F.pad(x, pad=(0, self.in_features - self.origin_in_features))
        return linear_forward_int4(
            x, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )
+
+
def generate_folder_name():
    """Return a timestamp-based folder name, e.g. ``20240131_235959``."""
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+
@click.command()
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.4",
)
@click.option(
    "--mode", type=str, default="int8", help="type of quantization to perform"
)
@click.option(
    "--groupsize", type=int, default=128, help="Group size for int4 quantization."
)
@click.option("--timestamp", type=str, default="None", help="When to do quantization")
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
    """Quantize a text2semantic checkpoint to int8 or int4 weight-only form.

    The quantized weights are written to a fresh copy of the checkpoint
    directory under ``checkpoints/`` so the original stays untouched.
    The (large, unquantized) VQGAN generator is dropped from the copy.

    Changes vs. previous revision: the mode is validated *before* the slow
    model load, the duplicated copy/cleanup logic is factored out, and the
    error message no longer advertises the unsupported "int4-gpptq" mode.
    """

    def prepare_destination(suffix: str) -> Path:
        # Copy the whole checkpoint dir, drop the VQGAN weights from the
        # copy, and return the target path for the quantized model.pth.
        dst_name = Path(f"checkpoints/fs-1.2-{suffix}-{now}")
        shutil.copytree(str(checkpoint_path.resolve()), str(dst_name.resolve()))
        if (dst_name / vq_model).exists():
            (dst_name / vq_model).unlink()
        return dst_name / "model.pth"

    # Fail fast on an unknown mode, before the expensive model load.
    if mode not in ("int8", "int4"):
        raise ValueError(
            f"Invalid quantization mode {mode} needs to be one of [int8, int4]"
        )

    device = "cpu"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    model, _ = load_model(
        checkpoint_path=checkpoint_path,
        device=device,
        precision=precision,
        compile=False,
    )
    vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
    now = timestamp if timestamp != "None" else generate_folder_name()

    if mode == "int8":
        print(
            "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
        )
        quant_handler = WeightOnlyInt8QuantHandler(model)
        suffix = "int8"
    else:
        print(
            "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
        )
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        suffix = f"int4-g{groupsize}"

    quantized_state_dict = quant_handler.create_quantized_state_dict()
    quantize_path = prepare_destination(suffix)

    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")


if __name__ == "__main__":
    quantize()
diff --git a/tools/llama/rebuild_tokenizer.py b/tools/llama/rebuild_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea64fa6788833000c8dc41e3d570dd5b250fb14b
--- /dev/null
+++ b/tools/llama/rebuild_tokenizer.py
@@ -0,0 +1,57 @@
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Initialize a byte-level BPE tokenizer (no learned merges -- see trainer below).
tokenizer = Tokenizer(models.BPE())

# Customize pre-tokenization and decoding
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

# Don't train the tokenizer: vocab_size=0 means the vocabulary is just the
# byte-level alphabet plus the special control tokens listed here.
trainer = trainers.BpeTrainer(
    vocab_size=0,
    min_frequency=2,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    special_tokens=[
        "<|begin_of_sequence|>",
        "<|end_of_sequence|>",
        "<|im_start|>",
        "<|im_sep|>",  # system, user, assistant, etc.
        "<|im_end|>",
        "<|semantic|>",  # audio features
        "<|pad|>",
    ],
)

# Message layout the tokens support:
# <|im_start|>user<|im_sep|>...<|im_end|>
# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
tokenizer.train_from_iterator([], trainer=trainer)

# Smoke test: vocab size plus an encode/decode round trip on mixed-script text.
print(len(tokenizer.get_vocab()))
x = tokenizer.encode(
    "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
).ids
print(x, len(x))
print(tokenizer.decode(x, skip_special_tokens=True))


# Wrap in a HF fast tokenizer so it can be saved / pushed like any HF asset.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    pad_token="<|pad|>",
    bos_token="<|begin_of_sequence|>",
    eos_token="<|end_of_sequence|>",
)

# Try tokenizing a new sequence
sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
encoded = tokenizer(sequence).input_ids

print("Test encoding....")
print(f"\tSentence: {sequence}")
print(f"\tEncoded: {encoded}")
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
print(f"\tDecoded: {tokenizer.decode(encoded)}")

# NOTE: requires Hugging Face credentials; uploads to the private repo.
tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
diff --git a/tools/msgpack_api.py b/tools/msgpack_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..896cccf76f3d23c7bb7db6490f887f6618ce48b0
--- /dev/null
+++ b/tools/msgpack_api.py
@@ -0,0 +1,95 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import httpx
+import ormsgpack
+
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
+
+api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
+
+
def audio_request():
    """Stream a TTS request to a local server and write the audio to hello.wav.

    Reference resolution priority on the server: ref_id > references.
    """
    with open("lengyue.wav", "rb") as audio_f:
        reference_audio = audio_f.read()
    with open("lengyue.lab", "r", encoding="utf-8") as text_f:
        reference_text = text_f.read()

    request = ServeTTSRequest(
        text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
        # reference_id="114514",
        references=[
            ServeReferenceAudio(audio=reference_audio, text=reference_text)
        ],
        streaming=True,
    )

    key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")

    with (
        httpx.Client() as client,
        open("hello.wav", "wb") as out_file,
    ):
        with client.stream(
            "POST",
            "http://127.0.0.1:8080/v1/tts",
            content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
            headers={
                "authorization": f"Bearer {key}",
                "content-type": "application/msgpack",
            },
            timeout=None,  # streaming synthesis can take a while
        ) as response:
            for chunk in response.iter_bytes():
                out_file.write(chunk)
+
+
def asr_request(audio_path: Path):
    """POST ``audio_path`` to the fish.audio ASR endpoint and print the result."""
    with open(str(audio_path), "rb") as audio_file:
        audio_data = audio_file.read()

    request_data = {
        "audio": audio_data,
        "language": "en",  # Optional: specify the language
        "ignore_timestamps": False,  # Optional: set to True to ignore precise timestamps
    }

    with httpx.Client() as client:
        response = client.post(
            "https://api.fish.audio/v1/asr",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/msgpack",
            },
            content=ormsgpack.packb(request_data),
        )

    # Parse the JSON body and report the transcript plus per-segment timings.
    result = response.json()

    print(f"Transcribed text: {result['text']}")
    print(f"Audio duration: {result['duration']} seconds")

    for segment in result["segments"]:
        print(f"Segment: {segment['text']}")
        print(f"Start time: {segment['start']}, End time: {segment['end']}")
+
+
def parse_args():
    """Parse CLI arguments; --audio_path points at the file to transcribe."""
    parser = ArgumentParser()
    parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
    return parser.parse_args()
+
+
# CLI entry point: run ASR against the configured audio file.
if __name__ == "__main__":
    args = parse_args()

    asr_request(args.audio_path)
diff --git a/tools/post_api.py b/tools/post_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..f319d12f1ec0515831ac9c667d6a1df8a500cc4b
--- /dev/null
+++ b/tools/post_api.py
@@ -0,0 +1,227 @@
+import argparse
+import base64
+import wave
+
+import ormsgpack
+import pyaudio
+import requests
+from pydub import AudioSegment
+from pydub.playback import play
+
+from tools.file import audio_to_bytes, read_ref_text
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
+
+
def _str2bool(value) -> bool:
    """Convert a CLI string to a real boolean.

    argparse's ``type=bool`` treats every non-empty string -- including
    "False" -- as True, so ``--streaming False`` silently enabled streaming.
    This converter accepts the usual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build and parse the CLI arguments for the TTS client."""

    parser = argparse.ArgumentParser(
        description="Send a WAV file and text to a server and receive synthesized audio.",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--url",
        "-u",
        type=str,
        default="http://127.0.0.1:8080/v1/tts",
        help="URL of the server",
    )
    parser.add_argument(
        "--text", "-t", type=str, required=True, help="Text to be synthesized"
    )
    parser.add_argument(
        "--reference_id",
        "-id",
        type=str,
        default=None,
        help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
    )
    parser.add_argument(
        "--reference_audio",
        "-ra",
        type=str,
        nargs="+",
        default=None,
        help="Path to the audio file",
    )
    parser.add_argument(
        "--reference_text",
        "-rt",
        type=str,
        nargs="+",
        default=None,
        help="Reference text for voice synthesis",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default="generated_audio",
        help="Output audio file name",
    )
    parser.add_argument(
        "--play",
        type=_str2bool,
        default=True,
        help="Whether to play audio after receiving data",
    )
    parser.add_argument("--normalize", type=_str2bool, default=True)
    parser.add_argument(
        "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
    )
    parser.add_argument(
        "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
    )
    parser.add_argument("--opus_bitrate", type=int, default=-1000)
    parser.add_argument(
        "--latency",
        type=str,
        default="normal",
        choices=["normal", "balanced"],
        help="Used in api.fish.audio/v1/tts",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=0,
        help="Maximum new tokens to generate. \n0 means no limit.",
    )
    parser.add_argument(
        "--chunk_length", type=int, default=200, help="Chunk length for synthesis"
    )
    parser.add_argument(
        "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty for synthesis",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for sampling"
    )

    parser.add_argument(
        "--streaming", type=_str2bool, default=False, help="Enable streaming response"
    )
    parser.add_argument(
        "--channels", type=int, default=1, help="Number of audio channels"
    )
    parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
    parser.add_argument(
        "--use_memory_cache",
        type=str,
        default="never",
        choices=["on-demand", "never"],
        help="Cache encoded references codes in memory.\n"
        "If `on-demand`, the server will use cached encodings\n "
        "instead of encoding reference audio again.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="`None` means randomized inference, otherwise deterministic.\n"
        "It can't be used for fixing a timbre.",
    )

    return parser.parse_args()
+
+
# Client entry point: build a ServeTTSRequest from the CLI args, POST it as
# msgpack, and either stream the PCM response to the speakers (and a WAV
# file) or save the complete audio body and optionally play it.
if __name__ == "__main__":

    args = parse_args()

    idstr: str | None = args.reference_id
    # priority: ref_id > [{text, audio},...]
    if idstr is None:
        ref_audios = args.reference_audio
        ref_texts = args.reference_text
        if ref_audios is None:
            byte_audios = []
        else:
            byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
        if ref_texts is None:
            ref_texts = []
        else:
            ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
    else:
        # A reference_id was given; the server resolves it, so no local
        # reference audio/text is sent.
        byte_audios = []
        ref_texts = []
        pass  # in api.py

    data = {
        "text": args.text,
        "references": [
            ServeReferenceAudio(audio=ref_audio, text=ref_text)
            for ref_text, ref_audio in zip(ref_texts, byte_audios)
        ],
        "reference_id": idstr,
        "normalize": args.normalize,
        "format": args.format,
        "mp3_bitrate": args.mp3_bitrate,
        "opus_bitrate": args.opus_bitrate,
        "max_new_tokens": args.max_new_tokens,
        "chunk_length": args.chunk_length,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
        "temperature": args.temperature,
        "streaming": args.streaming,
        "use_memory_cache": args.use_memory_cache,
        "seed": args.seed,
    }

    # Validate against the shared schema before serializing.
    pydantic_data = ServeTTSRequest(**data)

    response = requests.post(
        args.url,
        data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
        stream=args.streaming,
        headers={
            "authorization": "Bearer YOUR_API_KEY",
            "content-type": "application/msgpack",
        },
    )

    if response.status_code == 200:
        if args.streaming:
            # Play chunks as they arrive and mirror them into a WAV file.
            p = pyaudio.PyAudio()
            audio_format = pyaudio.paInt16  # Assuming 16-bit PCM format
            stream = p.open(
                format=audio_format, channels=args.channels, rate=args.rate, output=True
            )

            wf = wave.open(f"{args.output}.wav", "wb")
            wf.setnchannels(args.channels)
            wf.setsampwidth(p.get_sample_size(audio_format))
            wf.setframerate(args.rate)

            stream_stopped_flag = False

            try:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        stream.write(chunk)
                        wf.writeframesraw(chunk)
                    else:
                        # Empty chunk marks end of audio; stop playback once.
                        if not stream_stopped_flag:
                            stream.stop_stream()
                            stream_stopped_flag = True
            finally:
                # Always release the audio device and finalize the WAV header.
                stream.close()
                p.terminate()
                wf.close()
        else:
            # Non-streaming: save the whole body, then optionally play it.
            audio_content = response.content
            audio_path = f"{args.output}.{args.format}"
            with open(audio_path, "wb") as audio_file:
                audio_file.write(audio_content)

            audio = AudioSegment.from_file(audio_path, format=args.format)
            if args.play:
                play(audio)
            print(f"Audio has been saved to '{audio_path}'.")
    else:
        print(f"Request failed with status code {response.status_code}")
        print(response.json())
new file mode 100644
index 0000000000000000000000000000000000000000..0698a00932be264faf94bffc07b609b64a702b8f
--- /dev/null
+++ b/tools/schema.py
@@ -0,0 +1,187 @@
+import os
+import queue
+from dataclasses import dataclass
+from typing import Annotated, Literal, Optional
+
+import torch
+from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
+from pydantic.functional_validators import SkipValidation
+
+from fish_speech.conversation import Message, TextPart, VQPart
+
+GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
+
+
class ServeVQPart(BaseModel):
    """Message part carrying VQ (semantic) codes, one inner list per codebook."""

    type: Literal["vq"] = "vq"
    # SkipValidation: codes can be large; per-element validation is skipped for speed.
    codes: SkipValidation[list[list[int]]]
+
+
class ServeTextPart(BaseModel):
    """Message part carrying plain text."""

    type: Literal["text"] = "text"
    text: str
+
+
class ServeAudioPart(BaseModel):
    """Message part carrying raw audio bytes."""

    type: Literal["audio"] = "audio"
    audio: bytes
+
+
@dataclass
class ASRPackRequest:
    """Internal work item for the ASR worker: audio to transcribe plus the
    queue on which the result is delivered back to the requester."""

    audio: torch.Tensor
    result_queue: queue.Queue
    language: str
+
+
class ServeASRRequest(BaseModel):
    """Batch ASR request."""

    # The audio should be an uncompressed PCM float16 audio
    audios: list[bytes]
    sample_rate: int = 44100
    # "auto" lets the server detect the language.
    language: Literal["zh", "en", "ja", "auto"] = "auto"
+
+
class ServeASRTranscription(BaseModel):
    """One transcription result."""

    text: str
    # Audio duration in seconds.
    duration: float
    # NOTE(review): presumably flags a large silent gap detected in the audio
    # -- confirm against the ASR server implementation.
    huge_gap: bool
+
+
class ServeASRSegment(BaseModel):
    """A single timed segment of a transcription (times in seconds)."""

    text: str
    start: float
    end: float
+
+
class ServeTimedASRResponse(BaseModel):
    """Transcription with per-segment timestamps."""

    text: str
    segments: list[ServeASRSegment]
    duration: float
+
+
class ServeASRResponse(BaseModel):
    """Batch ASR response: one transcription per input audio."""

    transcriptions: list[ServeASRTranscription]
+
+
class ServeMessage(BaseModel):
    """A chat message made of text and/or VQ-code parts."""

    role: Literal["system", "assistant", "user"]
    parts: list[ServeVQPart | ServeTextPart]

    def to_conversation_message(self):
        """Convert this wire-format message into an internal ``Message``."""
        converted = []
        for part in self.parts:
            if isinstance(part, ServeTextPart):
                converted.append(TextPart(text=part.text))
            elif isinstance(part, ServeVQPart):
                # VQ codes travel as plain ints; re-materialize as a tensor.
                converted.append(
                    VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
                )
            else:
                raise ValueError(f"Unsupported part type: {part}")

        new_message = Message(role=self.role, parts=[])
        new_message.parts = converted
        return new_message
+
+
class ServeRequest(BaseModel):
    """Chat-style generation request: messages plus sampling parameters."""

    # At least one message is required.
    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    num_samples: int = 1
    # NOTE(review): semantics inferred from the name (stop remaining samples
    # once this fraction has finished) -- confirm against the serving loop.
    early_stop_threshold: float = 1.0
+
+
class ServeVQGANEncodeRequest(BaseModel):
    """Request to encode audio files into VQ tokens."""

    # The audio here should be in wav, mp3, etc
    audios: list[bytes]
+
+
class ServeVQGANEncodeResponse(BaseModel):
    """VQ tokens per input audio: [audio][codebook][frame]."""

    # SkipValidation: token lists can be large; skip per-element validation.
    tokens: SkipValidation[list[list[list[int]]]]
+
+
class ServeVQGANDecodeRequest(BaseModel):
    """Request to decode VQ tokens back into audio: [audio][codebook][frame]."""

    tokens: SkipValidation[list[list[list[int]]]]
+
+
class ServeVQGANDecodeResponse(BaseModel):
    """Decoded audio per input token sequence."""

    # The audio here should be in PCM float16 format
    audios: list[bytes]
+
+
class ServeReferenceAudio(BaseModel):
    """A reference (audio, transcript) pair for in-context voice cloning.

    NOTE(review): this class is redefined later in this module with an added
    ``__repr__``; because the later definition shadows this one, this copy is
    effectively dead code and should eventually be removed.
    """

    audio: bytes
    text: str
+
+
class ServeForwardMessage(BaseModel):
    """Plain role/content message forwarded to an upstream chat model."""

    role: str
    content: str
+
+
class ServeResponse(BaseModel):
    """Full (non-streaming) generation response."""

    messages: list[ServeMessage]
    # None while generation is still in progress.
    finish_reason: Literal["stop", "error"] | None = None
    # Mutable default is safe here: pydantic deep-copies field defaults per instance.
    stats: dict[str, int | float | str] = {}
+
+
class ServeStreamDelta(BaseModel):
    """Incremental update in a streaming response: role and/or one new part."""

    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None
+
+
class ServeStreamResponse(BaseModel):
    """One streaming event: which sample it belongs to, its delta, and
    terminal metadata once the sample finishes."""

    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None
+
+
class ServeReferenceAudio(BaseModel):
    """A reference (audio, transcript) pair for in-context voice cloning.

    This definition shadows the earlier, identical one in this module; the
    ``__repr__`` keeps large audio payloads out of logs.
    """

    audio: bytes
    text: str

    def __repr__(self) -> str:
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
+
+
class ServeChatRequestV1(BaseModel):
    """v1 chat request: text messages plus optional audio input, with TTS
    settings for the spoken reply."""

    model: str = "llama3-8b"
    messages: list[ServeForwardMessage] = []
    # Optional audio input accompanying the chat turn.
    audio: bytes | None = None
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    voice: str = "jessica"
    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
    # Bitrate in kbit/s for compressed TTS output.
    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
+
+
class ServeTTSRequest(BaseModel):
    """Request body for the /v1/tts endpoint.

    NOTE(review): ``mp3_bitrate`` was previously declared twice in this class
    (first as ``Literal[64, 128, 192] = 128``, later as ``Optional[int] = 64``).
    Only the later declaration took effect at runtime, so the dead duplicate
    has been removed and the effective type/default kept in the original
    field position.
    """

    text: str
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    mp3_bitrate: Optional[int] = 64
    # References audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on-demand", "never"] = "never"
    # Normalize text for en & zh, this increase stability for numbers
    normalize: bool = True
    opus_bitrate: Optional[int] = -1000
    # Balance mode will reduce latency to 300ms, but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"
    # not usually used below
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
diff --git a/tools/sensevoice/README.md b/tools/sensevoice/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9a2078aa2d96dfafb445384316f2041d9e819e63
--- /dev/null
+++ b/tools/sensevoice/README.md
@@ -0,0 +1,59 @@
+# FunASR Command Line Interface
+
+This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
+
+## Requirements
+
+- Python >= 3.10
+- PyTorch <= 2.3.1
+- ffmpeg, pydub, audio-separator[gpu].
+
+## Installation
+
+Install the required packages:
+
+```bash
+pip install -e .[stable]
+```
+
+Make sure you have `ffmpeg` installed and available in your `PATH`.
+
+## Usage
+
+### Basic Usage
+
+To run the tool with default settings:
+
+```bash
python tools/sensevoice/fun_asr.py --audio-dir <audio-dir> --save-dir <save-dir>
+```
+
+## Options
+
+| Option | Description |
+| :-----------------------: | :---------------------------------------------------------------------------: |
+| --audio-dir | Directory containing audio or video files. |
+| --save-dir | Directory to save processed audio files. |
+| --device | Device to use for processing. Options: cuda (default) or cpu. |
+| --language | Language of the transcription. Default is auto. |
+| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
+| --punc | Enable punctuation prediction. |
+| --denoise | Enable noise reduction (vocal separation). |
+
+## Example
+
+To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
+```
+
+## Additional Notes
+
- The tool supports both audio and video files. Videos will be converted to audio automatically.
+- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
+- The script will automatically create necessary directories in the `--save-dir`.
+
+## Troubleshooting
+
+If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
diff --git a/tools/sensevoice/__init__.py b/tools/sensevoice/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tools/sensevoice/auto_model.py b/tools/sensevoice/auto_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2e186617fe889500d01d95eccdafc5c0248b84
--- /dev/null
+++ b/tools/sensevoice/auto_model.py
@@ -0,0 +1,573 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import copy
+import json
+import logging
+import os.path
+import random
+import re
+import string
+import time
+
+import numpy as np
+import torch
+from funasr.download.download_model_from_hub import download_model
+from funasr.download.file import download_from_url
+from funasr.register import tables
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import export_utils, misc
+from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
+from funasr.utils.misc import deep_update
+from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
+from tqdm import tqdm
+
+from .vad_utils import merge_vad, slice_padding_audio_samples
+
+try:
+ from funasr.models.campplus.cluster_backend import ClusterBackend
+ from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
+except:
+ pass
+
+
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """Normalize heterogeneous input into parallel ``(key_list, data_list)``.

    ``data_in`` may be a URL, a media-file path, a filelist
    (.scp/.txt/.json/.jsonl/.text), a list/tuple of inputs, raw bytes,
    raw text or sample arrays.  Every data item is paired with a key:
    the filename stem when one is available, otherwise a random
    13-character id.  ``input_len`` is accepted for API symmetry but is
    not used here.
    """
    data_list = []
    key_list = []
    # Extensions treated as filelists rather than as single media files.
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]

    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str):
        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
            # Fetch remote input to a local file first.
            data_in = download_from_url(data_in)

    if isinstance(data_in, str) and os.path.exists(
        data_in
    ):  # wav_path; filelist: wav.scp, file.jsonl;text.txt;
        _, file_extension = os.path.splitext(data_in)
        file_extension = file_extension.lower()
        if file_extension in filelist:  # filelist: wav.scp, file.jsonl;text.txt;
            with open(data_in, encoding="utf-8") as fin:
                for line in fin:
                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                    if data_in.endswith(
                        ".jsonl"
                    ):  # file.jsonl: json.dumps({"source": data})
                        lines = json.loads(line.strip())
                        data = lines["source"]
                        key = data["key"] if "key" in data else key
                    else:  # filelist, wav.scp, text.txt: id \t data or data
                        lines = line.strip().split(maxsplit=1)
                        data = lines[1] if len(lines) > 1 else lines[0]
                        key = lines[0] if len(lines) > 1 else key

                    data_list.append(data)
                    key_list.append(key)
        else:
            # Single media file: derive the key from the filename.
            if key is None:
                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key = misc.extract_filename_without_extension(data_in)
            data_list = [data_in]
            key_list = [key]
    elif isinstance(data_in, (list, tuple)):
        if data_type is not None and isinstance(
            data_type, (list, tuple)
        ):  # mutiple inputs
            data_list_tmp = []
            for data_in_i, data_type_i in zip(data_in, data_type):
                # NOTE(review): key_list is overwritten on every iteration, so
                # only the keys of the last modality survive — confirm intended.
                key_list, data_list_i = prepare_data_iterator(
                    data_in=data_in_i, data_type=data_type_i
                )
                data_list_tmp.append(data_list_i)
            data_list = []
            # Zip the modalities item-wise: one tuple per sample.
            for item in zip(*data_list_tmp):
                data_list.append(item)
        else:
            # [audio sample point, fbank, text]
            data_list = data_in
            key_list = []
            for data_i in data_in:
                if isinstance(data_i, str) and os.path.exists(data_i):
                    key = misc.extract_filename_without_extension(data_i)
                else:
                    if key is None:
                        key = "rand_key_" + "".join(
                            random.choice(chars) for _ in range(13)
                        )
                key_list.append(key)

    else:  # raw text; audio sample point, fbank; bytes
        if isinstance(data_in, bytes):  # audio bytes
            data_in = load_bytes(data_in)
        if key is None:
            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
        data_list = [data_in]
        key_list = [key]

    return key_list, data_list
+
+
class AutoModel:
    """Orchestration wrapper around FunASR models.

    Builds the main ASR model plus optional VAD, punctuation and speaker
    sub-models from one kwargs dict, and exposes batched inference
    helpers that stitch their results together.
    """

    def __init__(self, **kwargs):
        """Build the requested model stack.

        Recognized keys include ``model``, ``vad_model``/``vad_kwargs``,
        ``punc_model``/``punc_kwargs``, ``spk_model``/``spk_kwargs``,
        ``device``, ``log_level`` and ``disable_update``; everything else
        is forwarded to :meth:`build_model`.
        """

        try:
            from funasr.utils.version_checker import check_for_update

            print(
                "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
            )
            check_for_update(disable=kwargs.get("disable_update", False))
        except:
            # Version check is best-effort; never block model construction.
            pass

        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
        logging.basicConfig(level=log_level)

        model, kwargs = self.build_model(**kwargs)

        # if vad_model is not None, build vad model else None
        vad_model = kwargs.get("vad_model", None)
        vad_kwargs = (
            {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
        )
        if vad_model is not None:
            logging.info("Building VAD model.")
            vad_kwargs["model"] = vad_model
            vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
            vad_kwargs["device"] = kwargs["device"]
            vad_model, vad_kwargs = self.build_model(**vad_kwargs)

        # if punc_model is not None, build punc model else None
        punc_model = kwargs.get("punc_model", None)
        punc_kwargs = (
            {}
            if kwargs.get("punc_kwargs", {}) is None
            else kwargs.get("punc_kwargs", {})
        )
        if punc_model is not None:
            logging.info("Building punc model.")
            punc_kwargs["model"] = punc_model
            punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
            punc_kwargs["device"] = kwargs["device"]
            punc_model, punc_kwargs = self.build_model(**punc_kwargs)

        # if spk_model is not None, build spk model else None
        spk_model = kwargs.get("spk_model", None)
        spk_kwargs = (
            {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
        )
        if spk_model is not None:
            logging.info("Building SPK model.")
            spk_kwargs["model"] = spk_model
            spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
            spk_kwargs["device"] = kwargs["device"]
            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
            # NOTE(review): ClusterBackend comes from the optional campplus
            # import at module top; if that import failed this raises
            # NameError — confirm spk_model is only used with campplus installed.
            self.cb_model = ClusterBackend().to(kwargs["device"])
            spk_mode = kwargs.get("spk_mode", "punc_segment")
            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
                logging.error(
                    "spk_mode should be one of default, vad_segment and punc_segment."
                )
            self.spk_mode = spk_mode

        self.kwargs = kwargs
        self.model = model
        self.vad_model = vad_model
        self.vad_kwargs = vad_kwargs
        self.punc_model = punc_model
        self.punc_kwargs = punc_kwargs
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs
        self.model_path = kwargs.get("model_path")

    @staticmethod
    def build_model(**kwargs):
        """Download (if needed) and instantiate one model.

        Resolves tokenizer and frontend classes from the FunASR registry,
        builds the model, loads pretrained weights and moves it to the
        selected device/dtype.  Returns ``(model, kwargs)`` with kwargs
        enriched by the resolved components.
        """
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info(
                "download models from model hub: {}".format(kwargs.get("hub", "ms"))
            )
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        # Fall back to CPU (batch size 1) when CUDA is unavailable or disabled.
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device

        torch.set_num_threads(kwargs.get("ncpu", 4))

        # build tokenizer
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
            tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
            kwargs["token_list"] = (
                tokenizer.token_list if hasattr(tokenizer, "token_list") else None
            )
            kwargs["token_list"] = (
                tokenizer.get_vocab()
                if hasattr(tokenizer, "get_vocab")
                else kwargs["token_list"]
            )
            vocab_size = (
                len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
            )
            if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
                vocab_size = tokenizer.get_vocab_size()
        else:
            vocab_size = -1
        kwargs["tokenizer"] = tokenizer

        # build frontend
        frontend = kwargs.get("frontend", None)
        kwargs["input_size"] = None
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs.get("frontend_conf", {}))
            kwargs["input_size"] = (
                frontend.output_size() if hasattr(frontend, "output_size") else None
            )
        kwargs["frontend"] = frontend
        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        assert model_class is not None, f'{kwargs["model"]} is not registered'
        model_conf = {}
        deep_update(model_conf, kwargs.get("model_conf", {}))
        deep_update(model_conf, kwargs)
        model = model_class(**model_conf, vocab_size=vocab_size)

        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
            if os.path.exists(init_param):
                logging.info(f"Loading pretrained params from {init_param}")
                load_pretrained_model(
                    model=model,
                    path=init_param,
                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                    oss_bucket=kwargs.get("oss_bucket", None),
                    scope_map=kwargs.get("scope_map", []),
                    excludes=kwargs.get("excludes", None),
                )
            else:
                print(f"error, init_param does not exist!: {init_param}")

        # fp16
        if kwargs.get("fp16", False):
            model.to(torch.float16)
        elif kwargs.get("bf16", False):
            model.to(torch.bfloat16)
        model.to(device)

        if not kwargs.get("disable_log", True):
            tables.print()

        return model, kwargs

    def __call__(self, *args, **cfg):
        """Run the raw model forward with stored kwargs overridden by *cfg*."""
        kwargs = self.kwargs
        deep_update(kwargs, cfg)
        # NOTE(review): kwargs is passed positionally here (not **kwargs) —
        # confirm the model's forward really expects a dict argument.
        res = self.model(*args, kwargs)
        return res

    def generate(self, input, input_len=None, **cfg):
        """Transcribe *input*, routing through VAD when a VAD model was built."""
        if self.vad_model is None:
            return self.inference(input, input_len=input_len, **cfg)

        else:
            return self.inference_with_vad(input, input_len=input_len, **cfg)

    def inference(
        self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
    ):
        """Batched inference over *input*; returns a list of result dicts.

        *model*/*kwargs* default to the main ASR model but may be swapped
        for the VAD, punctuation or speaker sub-models.
        """
        kwargs = self.kwargs if kwargs is None else kwargs
        if "cache" in kwargs:
            kwargs.pop("cache")
        deep_update(kwargs, cfg)
        model = self.model if model is None else model
        model.eval()

        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
        #     batch_size = 1

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
        )

        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
        disable_pbar = self.kwargs.get("disable_pbar", False)
        pbar = (
            tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
            if not disable_pbar
            else None
        )
        time_speech_total = 0.0
        time_escape_total = 0.0
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}

            if (end_idx - beg_idx) == 1 and kwargs.get(
                "data_type", None
            ) == "fbank":  # fbank
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len

            time1 = time.perf_counter()
            with torch.no_grad():
                res = model.inference(**batch, **kwargs)
                # NOTE(review): if res is not a list/tuple, `results` below is
                # unbound and extend() raises NameError — confirm all models
                # return (results, meta) sequences.
                if isinstance(res, (list, tuple)):
                    results = res[0] if len(res) > 0 else [{"text": ""}]
                    meta_data = res[1] if len(res) > 1 else {}
            time2 = time.perf_counter()

            asr_result_list.extend(results)

            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
            batch_data_time = meta_data.get("batch_data_time", -1)
            time_escape = time2 - time1
            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
            speed_stats["forward"] = f"{time_escape:0.3f}"
            speed_stats["batch_size"] = f"{len(results)}"
            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
            description = f"{speed_stats}, "
            if pbar:
                pbar.update(end_idx - beg_idx)
                pbar.set_description(description)
            time_speech_total += batch_data_time
            time_escape_total += time_escape

        if pbar:
            # pbar.update(1)
            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
        torch.cuda.empty_cache()
        return asr_result_list

    def vad(self, input, input_len=None, **cfg):
        """Run the VAD model; returns ``(elapsed_seconds, vad_results)``."""
        kwargs = self.kwargs
        # step.1: compute the vad model
        deep_update(self.vad_kwargs, cfg)
        beg_vad = time.time()
        res = self.inference(
            input,
            input_len=input_len,
            model=self.vad_model,
            kwargs=self.vad_kwargs,
            **cfg,
        )
        end_vad = time.time()
        # FIX(gcf): concat the vad clips for sense vocie model for better aed
        if cfg.get("merge_vad", False):
            for i in range(len(res)):
                res[i]["value"] = merge_vad(
                    res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
                )
        elapsed = end_vad - beg_vad
        return elapsed, res

    def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
        """ASR over precomputed VAD segments.

        Slices each input along *vad_res* intervals (ms), batches segments
        by duration, runs ASR (plus optional speaker embedding and
        punctuation models) and restores the original segment order.
        Each returned dict carries ``text``, ``emo`` and ``interval``.
        """

        kwargs = self.kwargs

        # step.2 compute asr model
        model = self.model
        deep_update(kwargs, cfg)
        # Batch budget expressed in milliseconds of audio.
        batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
        kwargs["batch_size"] = batch_size

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None)
        )
        results_ret_list = []
        time_speech_total_all_samples = 1e-6

        beg_total = time.time()
        pbar_total = (
            tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
            if not kwargs.get("disable_pbar", False)
            else None
        )

        for i in range(len(vad_res)):
            key = vad_res[i]["key"]
            vadsegments = vad_res[i]["value"]
            input_i = data_list[i]
            fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
            speech = load_audio_text_image_video(
                input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
            )
            speech_lengths = len(speech)
            n = len(vadsegments)
            # Sort segments by duration so similar lengths batch together.
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []

            if not len(sorted_data):
                results_ret_list.append({"key": key, "text": "", "timestamp": []})
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue

            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                batch_size = max(
                    batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
                )

            if kwargs["device"] == "cpu":
                batch_size = 0

            beg_idx = 0
            beg_asr_total = time.time()
            time_speech_total_per_sample = speech_lengths / 16000
            time_speech_total_all_samples += time_speech_total_per_sample

            # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)

            all_segments = []
            max_len_in_batch = 0
            end_idx = 1

            for j, _ in enumerate(range(0, n)):
                # pbar_sample.update(1)
                sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
                potential_batch_length = max(max_len_in_batch, sample_length) * (
                    j + 1 - beg_idx
                )
                # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
                # Keep growing the batch until duration limits are hit.
                if (
                    j < n - 1
                    and sample_length < batch_size_threshold_ms
                    and potential_batch_length < batch_size
                ):
                    max_len_in_batch = max(max_len_in_batch, sample_length)
                    end_idx += 1
                    continue

                speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
                    speech, speech_lengths, sorted_data[beg_idx:end_idx]
                )
                results = self.inference(
                    speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
                )

                for _b in range(len(speech_j)):
                    results[_b]["interval"] = intervals[_b]

                if self.spk_model is not None:
                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                    for _b in range(len(speech_j)):
                        vad_segments = [
                            [
                                sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
                                sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
                                np.array(speech_j[_b]),
                            ]
                        ]
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
                        speech_b = [i[2] for i in segments]
                        spk_res = self.inference(
                            speech_b,
                            input_len=None,
                            model=self.spk_model,
                            kwargs=kwargs,
                            **cfg,
                        )
                        results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]

                beg_idx = end_idx
                end_idx += 1
                max_len_in_batch = sample_length
                if len(results) < 1:
                    continue
                results_sorted.extend(results)

            # end_asr_total = time.time()
            # time_escape_total_per_sample = end_asr_total - beg_asr_total
            # pbar_sample.update(1)
            # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
            #                      f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
            #                      f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")

            # Undo the duration sort: emit results in original segment order.
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
                cur = results_sorted[j]
                # Split "<|tag|>" emotion/event markers out of the text.
                pattern = r"<\|([^|]+)\|>"
                emotion_string = re.findall(pattern, cur["text"])
                cur["text"] = re.sub(pattern, "", cur["text"])
                cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
                if self.punc_model is not None and len(cur["text"].strip()) > 0:
                    deep_update(self.punc_kwargs, cfg)
                    punc_res = self.inference(
                        cur["text"],
                        model=self.punc_model,
                        kwargs=self.punc_kwargs,
                        **cfg,
                    )
                    cur["text"] = punc_res[0]["text"]

                restored_data[index] = cur

            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            if pbar_total:
                pbar_total.update(1)
                pbar_total.set_description(
                    f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                    f"time_speech: {time_speech_total_per_sample: 0.3f}, "
                    f"time_escape: {time_escape_total_per_sample:0.3f}"
                )

        # end_total = time.time()
        # time_escape_total_all_samples = end_total - beg_total
        # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
        #              f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
        #              f"time_escape_all: {time_escape_total_all_samples:0.3f}")
        # NOTE(review): only the last sample's restored_data is returned —
        # confirm callers always pass a single input (fun_asr.py does).
        return restored_data

    def export(self, input=None, **cfg):
        """Export the model (default ONNX) using *input* as example data.

        :param input: example input(s) fed to the exporter.
        :param cfg: overrides; recognized keys include ``device`` (default
            "cpu"), ``type`` (export format, default "onnx"), ``quantize``,
            ``fallback_num``, ``calib_num``, ``opset_version``.
        :return: directory containing the exported artifacts.
        """

        device = cfg.get("device", "cpu")
        model = self.model.to(device=device)
        kwargs = self.kwargs
        deep_update(kwargs, cfg)
        kwargs["device"] = device
        del kwargs["model"]
        model.eval()

        # (shadows the builtin `type`; kept for upstream compatibility)
        type = kwargs.get("type", "onnx")

        key_list, data_list = prepare_data_iterator(
            input, input_len=None, data_type=kwargs.get("data_type", None), key=None
        )

        with torch.no_grad():
            export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)

        return export_dir
diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..6789316d5186db69c021758094649553c3638f66
--- /dev/null
+++ b/tools/sensevoice/fun_asr.py
@@ -0,0 +1,332 @@
+import gc
+import os
+import re
+
+from audio_separator.separator import Separator
+
+os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
+os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
+import json
+import subprocess
+from pathlib import Path
+
+import click
+import torch
+from loguru import logger
+from pydub import AudioSegment
+from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
+from tools.sensevoice.auto_model import AutoModel
+
+
def uvr5_cli(
    audio_dir: Path,
    output_folder: Path,
    audio_files: list[Path] | None = None,
    output_format: str = "flac",
    model: str = "BS-Roformer-Viperx-1297.ckpt",
):
    """Separate vocals from instrumentals with a UVR5 Roformer model.

    Runs the separator over *audio_files* (or every audio file found
    under *audio_dir* when None), writing stems to *output_folder*.
    Returns ``(output_paths, checkpoint_filename)``.
    """
    # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
    separator = Separator(
        model_file_dir=os.environ["UVR5_CACHE"],
        output_dir=output_folder,
        output_format=output_format,
    )
    # Map the friendly checkpoint aliases to the actual model filenames.
    alias_to_ckpt = {
        "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
        "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
        "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
        "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
    }
    ckpt_name = alias_to_ckpt[model]
    separator.load_model(ckpt_name)

    if audio_files is None:
        audio_files = list_files(
            path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
        )

    total = len(audio_files)
    print(f"{total} audio files found")

    outputs = []
    for audio in tqdm(audio_files, desc="Denoising: "):
        # Absolute entries in audio_files win over audio_dir when joined.
        produced = separator.separate(str(audio_dir / audio))
        if isinstance(produced, str):
            outputs.append(produced)
        elif isinstance(produced, list):
            outputs.extend(produced)

    # Free the separator and any cached GPU memory before returning.
    del separator
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return outputs, ckpt_name
+
+
def get_sample_rate(media_path: Path):
    """Return the sample rate (as a string) of the first audio stream.

    Probes *media_path* with ffprobe; falls back to "44100" when no
    audio stream is reported.
    """
    probe = subprocess.run(
        [
            "ffprobe",
            "-v",
            "quiet",
            "-print_format",
            "json",
            "-show_streams",
            str(media_path),
        ],
        capture_output=True,
        text=True,
        check=True,
    )
    streams = json.loads(probe.stdout).get("streams", [])
    audio_stream = next(
        (s for s in streams if s.get("codec_type") == "audio"), None
    )
    if audio_stream is not None:
        return audio_stream.get("sample_rate")
    return "44100"  # Default sample rate if not found
+
+
def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
    """Convert *src_path* to a mono audio file via ffmpeg, keeping its sample rate.

    When *src_path* and *out_path* resolve to the same file, ffmpeg cannot
    edit in place, so the result is written next to it with a
    ``_<sample_rate>`` suffix on the stem instead.

    :param src_path: input audio/video file.
    :param out_path: requested output location (parents are created).
    :param out_fmt: "wav" (pcm_s16le) or anything else for flac encoding.
    :return: the Path that was actually written.
    :raises subprocess.CalledProcessError: if ffprobe/ffmpeg fail.
    """
    sr = get_sample_rate(src_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if src_path.resolve() == out_path.resolve():
        # Divert in-place conversions to "<stem>_<sr><ext>".
        output = out_path.with_stem(out_path.stem + f"_{sr}")
    else:
        output = out_path
    subprocess.run(
        [
            "ffmpeg",
            "-loglevel",
            "error",
            "-i",
            str(src_path),
            "-acodec",
            "pcm_s16le" if out_fmt == "wav" else "flac",
            "-ar",
            sr,
            "-ac",
            "1",
            "-y",
            str(output),
        ],
        check=True,
    )
    # BUG FIX: previously this always returned out_path, which in the
    # in-place case is a file that was never written; return the real
    # output location instead.
    return output
+
+
def convert_video_to_audio(video_path: Path, audio_dir: Path):
    """Ensure an audio track exists for *video_path* and return its path.

    If a previously separated "<stem>_(Vocals)*" audio file already sits
    next to the video, reuse it; otherwise extract a mono WAV.
    """
    target_dir = audio_dir / video_path.relative_to(audio_dir).parent
    existing_vocals = [
        candidate
        for candidate in target_dir.glob(f"{video_path.stem}_(Vocals)*.*")
        if candidate.suffix in AUDIO_EXTENSIONS
    ]
    if existing_vocals:
        return existing_vocals[0]
    wav_path = target_dir / f"{video_path.stem}.wav"
    convert_to_mono(video_path, wav_path)
    return wav_path
+
+
@click.command()
@click.option("--audio-dir", required=True, help="Directory containing audio files")
@click.option(
    "--save-dir", required=True, help="Directory to save processed audio files"
)
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
@click.option("--language", default="auto", help="Language of the transcription")
@click.option(
    "--max_single_segment_time",
    default=20000,
    type=int,
    help="Maximum of Output single audio duration(ms)",
)
@click.option("--fsmn-vad/--silero-vad", default=False)
@click.option("--punc/--no-punc", default=False)
@click.option("--denoise/--no-denoise", default=False)
@click.option("--save_emo/--no_save_emo", default=False)
def main(
    audio_dir: str,
    save_dir: str,
    device: str,
    language: str,
    max_single_segment_time: int,
    fsmn_vad: bool,
    punc: bool,
    denoise: bool,
    save_emo: bool,
):
    """Segment + transcribe everything under --audio-dir into --save-dir.

    Pipeline: convert videos to audio, optionally denoise (vocal
    separation, destructive: originals under --audio-dir are deleted),
    run VAD (fsmn or silero) and SenseVoice ASR, then export per-segment
    audio clips with matching .lab transcripts (and .emo tags).
    """

    audios_path = Path(audio_dir)
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    # Extract audio tracks from any videos found under audio_dir.
    video_files = list_files(
        path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
    )
    v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]

    if denoise:
        VOCAL = "_(Vocals)"
        # Everything not already a separated vocal stem.
        original_files = [
            p
            for p in audios_path.glob("**/*")
            if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
        ]

        _, cur_model = uvr5_cli(
            audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
        )
        # WARNING: deletes the instrumental stems AND the original inputs.
        need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
        need_remove.extend(original_files)
        for _ in need_remove:
            _.unlink()
        vocal_files = [
            p
            for p in audios_path.glob("**/*")
            if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
        ]
        for f in vocal_files:
            fn, ext = f.stem, f.suffix

            # Strip the "_<model name>" suffix the separator appends.
            v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
            if v_pos != -1:
                new_fn = fn[: v_pos + len(VOCAL)]
                new_f = f.with_name(new_fn + ext)
                f = f.rename(new_f)
            # Re-encode to mono flac (written to a "_<sr>" sibling, since
            # src == dst) and drop the intermediate stem file.
            convert_to_mono(f, f, "flac")
            f.unlink()

    audio_files = list_files(
        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
    )

    logger.info("Loading / Downloading Funasr model...")

    model_dir = "iic/SenseVoiceSmall"

    vad_model = "fsmn-vad" if fsmn_vad else None
    vad_kwargs = {"max_single_segment_time": max_single_segment_time}
    punc_model = "ct-punc" if punc else None

    manager = AutoModel(
        model=model_dir,
        trust_remote_code=False,
        vad_model=vad_model,
        vad_kwargs=vad_kwargs,
        punc_model=punc_model,
        device=device,
    )

    # Without fsmn-vad, fall back to silero for segment detection.
    if not fsmn_vad and vad_model is None:
        vad_model = load_silero_vad()

    logger.info("Model loaded.")

    # Output clips are named "<stem>_NNN.<ext>"; skip files matching that
    # pattern so a re-run does not re-slice previous output.
    pattern = re.compile(r"_\d{3}\.")

    for file_path in tqdm(audio_files, desc="Processing audio file"):

        if pattern.search(file_path.name):
            # logger.info(f"Skipping {file_path} as it has already been processed.")
            continue

        file_stem = file_path.stem
        file_suffix = file_path.suffix

        # Mirror the input directory layout under save_dir.
        rel_path = Path(file_path).relative_to(audio_dir)
        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)

        audio = AudioSegment.from_file(file_path)

        cfg = dict(
            cache={},
            language=language,  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=False,
            batch_size_s=60,
        )

        if fsmn_vad:
            elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
        else:
            wav = read_audio(
                str(file_path)
            )  # backend (sox, soundfile, or ffmpeg) required!
            audio_key = file_path.stem
            audio_val = []
            speech_timestamps = get_speech_timestamps(
                wav,
                vad_model,
                max_speech_duration_s=max_single_segment_time // 1000,
                return_seconds=True,
            )

            # Convert silero's second-based timestamps to milliseconds.
            audio_val = [
                [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
                for timestamp in speech_timestamps
            ]
            vad_res = []
            vad_res.append(dict(key=audio_key, value=audio_val))

        res = manager.inference_with_vadres(
            input=str(file_path), vad_res=vad_res, **cfg
        )

        for i, info in enumerate(res):
            [start_ms, end_ms] = info["interval"]
            text = info["text"]
            emo = info["emo"]
            sliced_audio = audio[start_ms:end_ms]
            audio_save_path = (
                save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
            )
            sliced_audio.export(audio_save_path, format=file_suffix[1:])
            print(f"Exported {audio_save_path}: {text}")

            # .lab transcript alongside each clip.
            transcript_save_path = (
                save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
            )
            with open(
                transcript_save_path,
                "w",
                encoding="utf-8",
            ) as f:
                f.write(text)

            if save_emo:
                # Emotion/event tags ("<|HAPPY|>" etc.) in a sibling .emo file.
                emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
                with open(
                    emo_save_path,
                    "w",
                    encoding="utf-8",
                ) as f:
                    f.write(emo)

        # In-place mode: remove the source once its clips are written.
        if audios_path.resolve() == save_path.resolve():
            file_path.unlink()
+
+
if __name__ == "__main__":
    # Run the click CLI. The debug snippet that used to follow an
    # unconditional exit() here was unreachable and referenced an
    # undefined name (SenseVoiceSmall), so it has been removed.
    main()
diff --git a/tools/sensevoice/vad_utils.py b/tools/sensevoice/vad_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bef75ed8c2841701fff44f7130e91ef8dfdf8cc
--- /dev/null
+++ b/tools/sensevoice/vad_utils.py
@@ -0,0 +1,61 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut fbank frames for each VAD segment and pad them into one batch.

    Segment times are in ms and converted to frame indices at 16 frames/ms.
    Returns ``(padded_feats, lengths)`` where lengths is an int tensor.
    """
    segment_feats = []
    segment_lens = []
    for segment in vad_segments:
        start = int(segment[0][0] * 16)
        # Clamp the end to the available frames.
        stop = min(int(segment[0][1] * 16), speech_lengths[0])
        segment_feats.append(speech[0, start:stop])
        segment_lens.append(stop - start)
    padded = pad_sequence(segment_feats, batch_first=True, padding_value=0.0)
    lengths = torch.Tensor(segment_lens).int()
    return padded, lengths
+
+
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Slice raw audio samples for each VAD segment.

    Segment times are in ms, mapped to sample indices at 16 samples/ms
    (16 kHz). Returns ``(clips, clip_lengths, intervals_ms)``.
    """
    clips = []
    clip_lengths = []
    intervals = []
    for segment in vad_segments:
        start = int(segment[0][0] * 16)
        # Never read past the end of the recording.
        stop = min(int(segment[0][1] * 16), speech_lengths)
        clips.append(speech[start:stop])
        clip_lengths.append(stop - start)
        intervals.append([start // 16, stop // 16])

    return clips, clip_lengths, intervals
+
+
def merge_vad(vad_result, max_length=15000, min_length=0):
    """Merge adjacent VAD segments into chunks of roughly *max_length* ms.

    Collects all segment boundaries, then greedily closes a chunk each
    time the next boundary would push it past *max_length*; chunks not
    longer than *min_length* are dropped. Inputs with 0 or 1 segments
    are returned unchanged.
    """
    if len(vad_result) <= 1:
        return vad_result
    boundaries = sorted({seg[0] for seg in vad_result} | {seg[1] for seg in vad_result})
    if not boundaries:
        return []
    merged = []
    chunk_start = 0
    for cur, nxt in zip(boundaries, boundaries[1:]):
        # Keep extending the current chunk while it stays under budget.
        if nxt - chunk_start < max_length:
            continue
        if cur - chunk_start > min_length:
            merged.append([chunk_start, cur])
        # if time - bg < max_length * 1.5:
        #     new_result.append([bg, time])
        # else:
        #     split_num = int(time - bg) // max_length + 1
        #     spl_l = int(time - bg) // split_num
        #     for j in range(split_num):
        #         new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
        chunk_start = cur
    merged.append([chunk_start, boundaries[-1]])
    return merged
diff --git a/tools/smart_pad.py b/tools/smart_pad.py
new file mode 100644
index 0000000000000000000000000000000000000000..de9dc154f26b2869a7e34f7d4cd95db741ee4c6a
--- /dev/null
+++ b/tools/smart_pad.py
@@ -0,0 +1,60 @@
+import random
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import librosa
+import torch.nn.functional as F
+import torchaudio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files
+
# Silence threshold: -50 dBFS expressed as a linear RMS amplitude.
threshold = 10 ** (-50 / 20.0)
+
+
def process(file):
    """Normalize the silence at both ends of one audio file, in place.

    Pads the tail to a random 0.3-0.7 s of silence when the file ends
    too abruptly, and trims leading silence down to ~20 ms, then
    overwrites the original file.
    """
    # torchaudio must have the sox backend available here.
    waveform, sample_rate = torchaudio.load(str(file), backend="sox")
    if waveform.size(0) > 1:
        # Downmix multi-channel audio to mono.
        waveform = waveform.mean(dim=0, keepdim=True)

    # Frame-wise RMS loudness (hop of 512 samples).
    loudness = librosa.feature.rms(
        y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
    )[0]

    # Walk backwards to the last frame above the silence threshold.
    for i in range(len(loudness) - 1, 0, -1):
        if loudness[i] > threshold:
            break

    end_silent_time = (len(loudness) - i) * 512 / sample_rate

    if end_silent_time <= 0.3:
        # Too little trailing silence: zero-pad to a random 0.3-0.7 s tail.
        random_time = random.uniform(0.3, 0.7) - end_silent_time
        waveform = F.pad(
            waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
        )

    # Walk forwards to the first frame above the threshold.
    for i in range(len(loudness)):
        if loudness[i] > threshold:
            break

    start_silent_time = i * 512 / sample_rate

    if start_silent_time > 0.02:
        # Keep only ~20 ms of leading silence.
        waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]

    # NOTE(review): for an entirely silent file neither loop breaks early,
    # so the leading trim can remove nearly the whole signal — confirm
    # such inputs never reach this tool.
    torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
+
+
@click.command()
@click.argument("source", type=Path)
@click.option("--num-workers", type=int, default=12)
def main(source, num_workers):
    """Pad/trim every audio file under SOURCE in parallel, in place."""
    audio_files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))

    with Pool(num_workers) as pool:
        # Consume the iterator so the progress bar advances as work finishes.
        for _ in tqdm(pool.imap_unordered(process, audio_files), total=len(audio_files)):
            pass
+
+
# CLI entry point.
if __name__ == "__main__":
    main()
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..d24a5f39566c47ea0cb1fc506d463e9c95c3efbc
--- /dev/null
+++ b/tools/vqgan/create_train_split.py
@@ -0,0 +1,83 @@
+import math
+from pathlib import Path
+from random import Random
+
+import click
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+
@click.command()
@click.argument("root", type=click.Path(exists=True, path_type=Path))
@click.option("--val-ratio", type=float, default=None)
@click.option("--val-count", type=int, default=None)
@click.option("--filelist", default=None, type=Path)
@click.option("--min-duration", default=None, type=float)
@click.option("--max-duration", default=None, type=float)
def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
    """Split audio under *root* into train/val filelists of root-relative paths."""
    if filelist:
        files = [i[0] for i in load_filelist(filelist)]
    else:
        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)

    if min_duration is None and max_duration is None:
        # No duration filtering requested: keep every file.
        filtered_files = [str(file.relative_to(root)) for file in files]
    else:
        filtered_files = []
        for file in tqdm(files):
            try:
                audio = AudioSegment.from_file(str(file))
                duration = len(audio) / 1000.0

                if min_duration is not None and duration < min_duration:
                    logger.info(
                        f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
                    )
                    continue

                if max_duration is not None and duration > max_duration:
                    logger.info(
                        f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
                    )
                    continue

                filtered_files.append(str(file.relative_to(root)))
            except Exception as e:
                # Unreadable files are logged and dropped rather than aborting.
                logger.info(f"Error processing {file}: {e}")

    logger.info(
        f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
    )

    # Fixed seed so reruns produce the same split.
    Random(42).shuffle(filtered_files)

    # Resolve the validation-set size; --val-count and --val-ratio are exclusive.
    if val_count is not None and val_ratio is not None:
        logger.error("Cannot specify both val_count and val_ratio")
        return

    if val_count is None and val_ratio is None:
        logger.info("Validation ratio and count not specified, using min(20%, 100)")
        val_size = min(100, math.ceil(len(filtered_files) * 0.2))
    elif val_count is not None:
        if not (1 <= val_count <= len(filtered_files)):
            logger.error("val_count must be between 1 and number of files")
            return
        val_size = val_count
    else:
        val_size = math.ceil(len(filtered_files) * val_ratio)

    logger.info(f"Using {val_size} files for validation")

    # Validation set is the shuffled prefix; training set is the remainder.
    (root / "vq_train_filelist.txt").write_text(
        "\n".join(filtered_files[val_size:]), encoding="utf-8"
    )
    (root / "vq_val_filelist.txt").write_text(
        "\n".join(filtered_files[:val_size]), encoding="utf-8"
    )

    logger.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e8e1431b690a8f35c5bd155b06337082b34ee6
--- /dev/null
+++ b/tools/vqgan/extract_vq.py
@@ -0,0 +1,233 @@
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
# register eval resolver
OmegaConf.register_new_resolver("eval", eval)
# This script encodes audio files into VQ token sequences with the VQ-GAN
# encoder and saves them as .npy files next to each clip. These token files
# are the training data for the text-to-semantic model.
# (The old comment mentioning Whisper was stale; no ASR happens here.)

backends = torchaudio.list_audio_backends()

# Prefer ffmpeg decoding when torchaudio supports it; otherwise soundfile.
if "ffmpeg" in backends:
    backend = "ffmpeg"
else:
    backend = "soundfile"

# Worker rank / world size. Set by SLURM, or by main() below when it re-execs
# this script once per local worker.
RANK = int(os.environ.get("SLURM_PROCID", 0))
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))

# Loguru format that tags every line with the worker's rank.
logger_format = (
    "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
    "{level: <8} | "
    "{name} :{function} :{line} | "
    "{extra[rank]} - {message} "
)
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
logger.remove()
logger.add(sys.stderr, format=logger_format)
+
+
@lru_cache(maxsize=1)
def get_model(
    config_name: str = "firefly_gan_vq",
    checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    device: str | torch.device = "cuda",
):
    """Instantiate the VQ-GAN generator from its Hydra config and load weights.

    Cached (maxsize=1) so repeated calls inside one worker reuse the model.

    Args:
        config_name: Hydra config under fish_speech/configs.
        checkpoint_path: Path to the generator checkpoint.
        device: Device the model (and checkpoint tensors) are placed on.

    Returns:
        The generator in eval mode on *device*.
    """
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)

    model = instantiate(cfg)
    state_dict = torch.load(
        checkpoint_path,
        map_location=device,
    )
    # Lightning-style checkpoints nest the weights under "state_dict".
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # Full training checkpoints prefix generator weights with "generator.";
    # strip the prefix and drop everything else (e.g. discriminator weights).
    if any("generator" in k for k in state_dict):
        state_dict = {
            k.replace("generator.", ""): v
            for k, v in state_dict.items()
            if "generator." in k
        }

    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.to(device)

    # Fixed: was a placeholder-free f-string.
    logger.info("Loaded model")
    return model
+
+
@torch.inference_mode()
def process_batch(files: list[Path], model) -> float:
    """Encode a batch of audio files to VQ indices, saving one .npy per file.

    Files that fail to load are logged and skipped.

    Returns:
        Total seconds of audio successfully encoded in this batch.
    """
    wavs = []
    audio_lengths = []
    new_files = []
    max_length = total_time = 0

    for file in files:
        try:
            wav, sr = torchaudio.load(
                str(file), backend=backend
            )  # Need to install libsox-dev
        except Exception as e:
            logger.error(f"Error reading {file}: {e}")
            continue

        # Downmix multi-channel audio to mono.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # Resample on the model's device. Fixed: this was a hard-coded
        # .cuda(), which crashed the CPU-only workers that main() spawns
        # with CUDA_VISIBLE_DEVICES="".
        wav = torchaudio.functional.resample(
            wav.to(model.device), sr, model.spec_transform.sample_rate
        )[0]
        total_time += len(wav) / model.spec_transform.sample_rate
        max_length = max(max_length, len(wav))

        wavs.append(wav)
        audio_lengths.append(len(wav))
        new_files.append(file)

    files = new_files

    # Every file in the batch failed to load: bail out before torch.stack([])
    # (which would raise) and report zero processed time.
    if not wavs:
        return total_time

    # Pad to max length
    for i, wav in enumerate(wavs):
        wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")

    audios = torch.stack(wavs, dim=0)[:, None]
    audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)

    # Calculate lengths
    indices, feature_lengths = model.encode(audios, audio_lengths)

    # Save to disk
    outputs = indices.cpu().numpy()

    for file, length, feature, audio_length in zip(
        files, feature_lengths, outputs, audio_lengths
    ):
        # Truncate the padded feature back to this file's true length.
        feature = feature[:, :length]

        # (T,)
        with open(file.with_suffix(".npy"), "wb") as f:
            np.save(f, feature)

    return total_time
+
+
@click.command()
@click.argument("folder")
@click.option("--num-workers", default=1)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
@click.option("--batch-size", default=64)
@click.option("--filelist", default=None, type=Path)
def main(
    folder: str,
    num_workers: int,
    config_name: str,
    checkpoint_path: str,
    batch_size: int,
    filelist: Path,
):
    """Extract VQ tokens for all audio under *folder* (or in --filelist).

    With --num-workers > 1 this process acts as a launcher: it re-execs
    itself once per worker with SLURM_PROCID/SLURM_NTASKS set, then waits.
    Each worker processes the files[RANK::WORLD_SIZE] slice.
    """
    if num_workers > 1 and WORLD_SIZE != num_workers:
        # Launcher path: refuse to nest the built-in launcher inside SLURM.
        assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"

        logger.info(f"Spawning {num_workers} workers")

        if torch.cuda.is_available():
            visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
            if visible_devices is None:
                visible_devices = list(range(torch.cuda.device_count()))
            else:
                visible_devices = visible_devices.split(",")
        else:
            # Set to empty string to avoid using GPU
            visible_devices = [""]

        processes = []
        for i in range(num_workers):
            # Each worker gets one GPU (round-robin) and a fake SLURM identity
            # so the module-level RANK/WORLD_SIZE pick up the right slice.
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
            env["SLURM_PROCID"] = str(i)
            env["SLURM_NTASKS"] = str(num_workers)

            processes.append(
                sp.Popen(
                    [sys.executable] + sys.argv.copy(),
                    env=env,
                )
            )

        for p in processes:
            p.wait()

        logger.info(f"All workers finished")
        return

    # This is a worker
    logger.info(f"Starting worker")
    if filelist:
        files = [i[0] for i in load_filelist(filelist)]
    else:
        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)

    print(f"Found {len(files)} files")
    # Skip files whose .npy output already exists (resumable runs).
    files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]

    total_files = len(files)
    # Strided sharding across workers.
    files = files[RANK::WORLD_SIZE]
    logger.info(f"Processing {len(files)}/{total_files} files")

    # Batch processing
    total_time = 0
    begin_time = time.time()
    processed_files = 0
    model = get_model(config_name, checkpoint_path)

    for n_batch, idx in enumerate(range(0, len(files), batch_size)):
        batch = files[idx : idx + batch_size]
        batch_time = process_batch(batch, model)

        total_time += batch_time
        processed_files += len(batch)

        # Log throughput / ETA every 10 batches.
        if (n_batch + 1) % 10 == 0:
            eta = (
                (time.time() - begin_time)
                / processed_files
                * (len(files) - processed_files)
            )
            logger.info(
                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
                + f"ETA: {timedelta(seconds=round(eta))}s"
            )

    logger.info(
        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
    )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/inference.py b/tools/vqgan/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c69d99d80922e43e594f34e593fcfbbc321d25c
--- /dev/null
+++ b/tools/vqgan/inference.py
@@ -0,0 +1,121 @@
+from pathlib import Path
+
+import click
+import hydra
+import numpy as np
+import soundfile as sf
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tools.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
def load_model(config_name, checkpoint_path, device="cuda"):
    """Build the VQ-GAN model from its Hydra config and load checkpoint weights."""
    # Hydra keeps global state; clear it so repeated loads do not collide.
    hydra.core.global_hydra.GlobalHydra.instance().clear()

    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)
    model = instantiate(cfg)

    weights = torch.load(
        checkpoint_path, map_location=device, mmap=True, weights_only=True
    )
    # Lightning-style checkpoints wrap the tensors in a "state_dict" entry.
    if "state_dict" in weights:
        weights = weights["state_dict"]

    # Training checkpoints prefix generator weights with "generator."; keep
    # only those entries, with the prefix stripped.
    if any("generator" in key for key in weights):
        weights = {
            key.replace("generator.", ""): value
            for key, value in weights.items()
            if "generator." in key
        }

    result = model.load_state_dict(weights, strict=False, assign=True)
    model.eval()
    model.to(device)

    logger.info(f"Loaded model: {result}")
    return model
+
+
@torch.no_grad()
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="test.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
@click.option(
    "--device",
    "-d",
    default="cuda",
)
def main(input_path, output_path, config_name, checkpoint_path, device):
    """Round-trip audio through the VQ-GAN, or decode precomputed codes.

    Audio input: encode to indices (also saved as <output-path>.npy), then
    decode back to audio. A .npy input skips encoding and only decodes.
    """
    model = load_model(config_name, checkpoint_path, device=device)

    if input_path.suffix in AUDIO_EXTENSIONS:
        logger.info(f"Processing in-place reconstruction of {input_path}")

        # Load audio
        audio, sr = torchaudio.load(str(input_path))
        # Downmix to mono and resample to the model's expected rate.
        if audio.shape[0] > 1:
            audio = audio.mean(0, keepdim=True)
        audio = torchaudio.functional.resample(
            audio, sr, model.spec_transform.sample_rate
        )

        audios = audio[None].to(device)
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
        indices = model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Generated indices of shape {indices.shape}")

        # Save indices
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = np.load(input_path)
        indices = torch.from_numpy(indices).to(device).long()
        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
    else:
        raise ValueError(f"Unknown input type: {input_path}")

    # Restore
    feature_lengths = torch.tensor([indices.shape[1]], device=device)
    fake_audios, _ = model.decode(
        indices=indices[None], feature_lengths=feature_lengths
    )
    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate

    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
    )

    # Save audio
    fake_audio = fake_audios[0, 0].float().cpu().numpy()
    sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
    logger.info(f"Saved audio to {output_path}")
+ logger.info(f"Saved audio to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/webui.py b/tools/webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..6863c63ad7869f8b7edde3de6402af462cb80d96
--- /dev/null
+++ b/tools/webui.py
@@ -0,0 +1,548 @@
+import gc
+import html
+import io
+import os
+import queue
+import wave
+from argparse import ArgumentParser
+from functools import partial
+from pathlib import Path
+
+import gradio as gr
+import librosa
+import numpy as np
+import pyrootutils
+import torch
+from loguru import logger
+from transformers import AutoTokenizer
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+
+from fish_speech.i18n import i18n
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps, set_seed
+from tools.api import decode_vq_tokens, encode_reference
+from tools.file import AUDIO_EXTENSIONS, list_files
+from tools.llama.generate import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+ launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
+
# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"


# Markdown banner shown at the top of the Gradio page.
HEADER_MD = f"""# Fish Speech

{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}

{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}

{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}

{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
"""

TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
# NOTE(review): SPACE_IMPORTED is never set or read in this file — presumably
# a leftover flag for Hugging Face Spaces deployments; confirm before removing.
SPACE_IMPORTED = False
+
+
def build_html_error_message(error):
    """Render *error* as HTML-escaped text safe to inject into the Gradio UI."""
    return f"""

    {html.escape(str(error))}

    """
+
+
@torch.inference_mode()
def inference(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed="0",
    streaming=False,
):
    """Generate speech for *text*, optionally conditioned on a reference audio.

    Generator yielding (streaming_wav_bytes, (sample_rate, audio), error_html)
    triples; exactly one slot is populated per yield. Relies on module-level
    globals set in __main__: args, llama_queue, decoder_model.
    """
    # Enforce the CLI-configured maximum text length (0 disables the check).
    # Fixed: this used `return (tuple)` — in a generator the returned value is
    # swallowed by StopIteration and never reaches the caller; yield instead.
    if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
        yield (
            None,
            None,
            i18n("Text is too long, please keep it under {} characters.").format(
                args.max_gradio_length
            ),
        )
        return

    # Seed 0 means randomized inference; any other value fixes the RNG state.
    seed = int(seed)
    if seed != 0:
        set_seed(seed)
        logger.warning(f"set seed: {seed}")

    # Parse reference audio aka prompt
    prompt_tokens = encode_reference(
        decoder_model=decoder_model,
        reference_audio=reference_audio,
        enable_reference_audio=enable_reference_audio,
    )

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=max_new_tokens,
        text=text,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=args.compile,
        iterative_prompt=chunk_length > 0,
        chunk_length=chunk_length,
        max_length=2048,
        prompt_tokens=prompt_tokens if enable_reference_audio else None,
        prompt_text=reference_text if enable_reference_audio else None,
    )

    # Hand the request to the Llama worker thread; responses arrive on a
    # private queue so concurrent requests do not interleave.
    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if streaming:
        # Send a WAV header first so the client can start playback immediately.
        yield wav_chunk_header(), None, None

    segments = []

    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            yield None, None, build_html_error_message(result.response)
            break

        result: GenerateResponse = result.response
        if result.action == "next":
            # The worker signals that the final segment has been produced.
            break

        # Decode VQ tokens to a waveform (autocast is skipped on MPS).
        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()
        segments.append(fake_audios)

        if streaming:
            # Stream each segment as 16-bit PCM prefixed by a WAV header.
            wav_header = wav_chunk_header()
            audio_data = (fake_audios * 32768).astype(np.int16).tobytes()
            yield wav_header + audio_data, None, None

    if len(segments) == 0:
        # Fixed: same generator-return bug as above — the error message was
        # silently discarded and next() raised StopIteration in the caller.
        yield (
            None,
            None,
            build_html_error_message(
                i18n("No audio generated, please check the input text.")
            ),
        )
        return

    # No matter streaming or not, we need to return the final audio
    audio = np.concatenate(segments, axis=0)
    yield None, (decoder_model.spec_transform.sample_rate, audio), None

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
+
+
# Streaming variant reuses the same generator with streaming=True.
inference_stream = partial(inference, streaming=True)

# Fixed number of output slots in the batch-inference UI.
n_audios = 4

# Filled by build_app() with the Gradio output components for those slots.
global_audio_list = []
global_error_list = []
+
+
def inference_wrapper(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    seed,
    batch_infer_num,
):
    """Run non-streaming inference *batch_infer_num* times for the batch tab.

    Returns (None, *audio_components, *error_components), padding with hidden
    components up to the fixed n_audios slots.
    """
    audios = []
    errors = []

    for _ in range(batch_infer_num):
        result = inference(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            seed,
        )

        # inference() is a generator; in non-streaming mode the first yielded
        # triple carries either the final audio or an error message.
        _, audio_data, error_message = next(result)

        audios.append(
            gr.Audio(value=audio_data if audio_data else None, visible=True),
        )
        errors.append(
            gr.HTML(value=error_message if error_message else None, visible=True),
        )

    # Hide the unused output slots.
    for _ in range(batch_infer_num, n_audios):
        audios.append(
            gr.Audio(value=None, visible=False),
        )
        errors.append(
            gr.HTML(value=None, visible=False),
        )

    return None, *audios, *errors
+
+
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Return the bytes of a WAV header describing zero data frames.

    Used to prefix raw PCM chunks so streaming clients parse them as WAV.
    """
    buffer = io.BytesIO()

    writer = wave.open(buffer, "wb")
    writer.setnchannels(channels)
    writer.setsampwidth(bit_depth // 8)
    writer.setframerate(sample_rate)
    # Closing flushes the (frameless) header into the buffer.
    writer.close()

    header_bytes = buffer.getvalue()
    buffer.close()
    return header_bytes
+
+
def normalize_text(user_input, use_normalization):
    """Optionally run Chinese text normalization over the raw user input."""
    if not use_normalization:
        return user_input
    return ChnNormedText(raw_text=user_input).normalize()
+
+
def update_examples():
    """Rescan the references/ directory and refresh the example-audio dropdown."""
    refs_dir = Path("references")
    refs_dir.mkdir(parents=True, exist_ok=True)
    found = list_files(refs_dir, AUDIO_EXTENSIONS, recursive=True)
    # Keep an empty entry so the user can clear the selection.
    return gr.Dropdown(choices=[*found, ""])
+
+
def build_app():
    """Construct the Gradio Blocks UI and wire up its event handlers.

    Appends the per-slot output components to the module-level
    global_audio_list / global_error_list so inference_wrapper can fill them.
    """
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % args.theme,
        )

        # Inference
        with gr.Row():
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
                )
                refined_text = gr.Textbox(
                    label=i18n("Realtime Transform Text"),
                    placeholder=i18n(
                        "Normalization Result Preview (Currently Only Chinese)"
                    ),
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    if_refine_text = gr.Checkbox(
                        label=i18n("Text Normalization"),
                        value=False,
                        scale=1,
                    )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=50,
                                    maximum=300,
                                    value=200,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,  # 0 means no limit
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.5,
                                    value=1.2,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )
                                seed = gr.Textbox(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    placeholder="any 32-bit-integer",
                                    value="0",
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )
                            with gr.Row():
                                enable_reference_audio = gr.Checkbox(
                                    label=i18n("Enable Reference Audio"),
                                )

                            with gr.Row():
                                example_audio_dropdown = gr.Dropdown(
                                    label=i18n("Select Example Audio"),
                                    choices=[""],
                                    value="",
                                    interactive=True,
                                    allow_custom_value=True,
                                )
                            with gr.Row():
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )
                            with gr.Row():
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )
                        with gr.Tab(label=i18n("Batch Inference")):
                            with gr.Row():
                                batch_infer_num = gr.Slider(
                                    label="Batch infer nums",
                                    minimum=1,
                                    maximum=n_audios,
                                    step=1,
                                    value=1,
                                )

            with gr.Column(scale=3):
                # One (error, audio) output pair per batch slot; only the
                # first pair is visible until batch inference reveals more.
                for _ in range(n_audios):
                    with gr.Row():
                        error = gr.HTML(
                            label=i18n("Error Message"),
                            visible=True if _ == 0 else False,
                        )
                        global_error_list.append(error)
                    with gr.Row():
                        audio = gr.Audio(
                            label=i18n("Generated Audio"),
                            type="numpy",
                            interactive=False,
                            visible=True if _ == 0 else False,
                        )
                        global_audio_list.append(audio)

                with gr.Row():
                    stream_audio = gr.Audio(
                        label=i18n("Streaming Audio"),
                        streaming=True,
                        autoplay=True,
                        interactive=False,
                        show_download_button=True,
                    )
                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + i18n("Generate"), variant="primary"
                        )
                        generate_stream = gr.Button(
                            value="\U0001F3A7 " + i18n("Streaming Generate"),
                            variant="primary",
                        )

        # Live-normalize the input text into the preview box.
        text.input(
            fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
        )

        def select_example_audio(audio_path):
            # Load the matching .lab transcript (if any) for a chosen example.
            audio_path = Path(audio_path)
            if audio_path.is_file():
                lab_file = Path(audio_path.with_suffix(".lab"))

                if lab_file.exists():
                    lab_content = lab_file.read_text(encoding="utf-8").strip()
                else:
                    lab_content = ""

                return str(audio_path), lab_content, True
            return None, "", False

        # Connect the dropdown to update reference audio and text
        example_audio_dropdown.change(
            fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
        ).then(
            fn=select_example_audio,
            inputs=[example_audio_dropdown],
            outputs=[reference_audio, reference_text, enable_reference_audio],
        )

        # # Submit
        generate.click(
            inference_wrapper,
            [
                refined_text,
                enable_reference_audio,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                batch_infer_num,
            ],
            [stream_audio, *global_audio_list, *global_error_list],
            concurrency_limit=1,
        )

        generate_stream.click(
            inference_stream,
            [
                refined_text,
                enable_reference_audio,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
            ],
            [stream_audio, global_audio_list[0], global_error_list[0]],
            concurrency_limit=1,
        )
    return app
+
+
def parse_args():
    """Parse web-UI command line arguments."""
    parser = ArgumentParser()
    # Directory containing the Llama (text-to-semantic) checkpoint.
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.4",
    )
    # VQ-GAN decoder weights used to turn semantic codes back into audio.
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    # --half selects float16 precision; otherwise bfloat16 (see __main__).
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    # 0 disables the input-length limit enforced in inference().
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()
+
+
if __name__ == "__main__":
    args = parse_args()
    # --half selects float16; default precision is bfloat16.
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    # Background worker thread that serializes Llama generation requests;
    # inference() communicates with it via this queue.
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            text="Hello, world!",
            enable_reference_audio=False,
            reference_audio=None,
            reference_text="",
            max_new_tokens=0,
            chunk_length=200,
            top_p=0.7,
            repetition_penalty=1.2,
            temperature=0.7,
        )
    )

    logger.info("Warming up done, launching the web UI...")

    app = build_app()
    app.launch(show_api=True)
diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..42e7de8a185880d3f2afd368d6df3429488465a4
--- /dev/null
+++ b/tools/whisper_asr.py
@@ -0,0 +1,176 @@
+"""
+Used to transcribe all audio files in one folder into another folder.
+e.g.
+Directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
+to transcribe the first speaker.
+
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
+to transcribe the second speaker.
+
+Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
+"""
+
+import re
+from pathlib import Path
+
+import click
+import soundfile as sf
+from faster_whisper import WhisperModel
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files
+
+
@click.command()
@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
@click.option(
    "--compute-type",
    default="float16",
    help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
)
@click.option("--audio-dir", required=True, help="Directory containing audio files")
@click.option(
    "--save-dir", required=True, help="Directory to save processed audio files"
)
@click.option(
    "--sample-rate",
    default=44100,
    type=int,
    help="Output sample rate, default to input sample rate",
)
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
@click.option("--language", default="auto", help="Language of the transcription")
@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
def main(
    model_size,
    compute_type,
    audio_dir,
    save_dir,
    sample_rate,
    device,
    language,
    initial_prompt,
):
    """Transcribe all audio under --audio-dir with Faster-Whisper.

    Mirrors the directory layout into --save-dir, re-exporting each clip and
    writing a .lab transcript alongside it.
    """
    logger.info("Loading / Downloading Faster Whisper model...")

    model = WhisperModel(
        model_size,
        device=device,
        compute_type=compute_type,
        download_root="faster_whisper",
    )

    logger.info("Model loaded.")

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    audio_files = list_files(
        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
    )

    # NOTE(review): --sample-rate is accepted but currently unused; audio is
    # re-exported at its original rate.
    for file_path in tqdm(audio_files, desc="Processing audio file"):
        file_stem = file_path.stem
        file_suffix = file_path.suffix

        # Mirror the input directory structure under save_dir.
        rel_path = Path(file_path).relative_to(audio_dir)
        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)

        audio = AudioSegment.from_file(file_path)

        segments, info = model.transcribe(
            file_path,
            beam_size=5,
            language=None if language == "auto" else language,
            initial_prompt=initial_prompt,
        )

        print(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )
        print("Total len(ms): ", len(audio))

        # Join segment texts with ", " and finish with a period.
        whole_text = None
        for segment in segments:
            # Renamed from `id`/`text` to avoid shadowing builtins.
            seg_id, start, end, seg_text = (
                segment.id,
                segment.start,
                segment.end,
                segment.text,
            )
            print("Segment %03d [%.2fs -> %.2fs] %s" % (seg_id, start, end, seg_text))
            if not whole_text:
                whole_text = seg_text
            else:
                whole_text += ", " + seg_text

        # Fixed: when no speech was recognized, whole_text stayed None and
        # `whole_text += "."` raised TypeError; fall back to an empty string.
        whole_text = (whole_text or "") + "."

        audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
        audio.export(audio_save_path, format=file_suffix[1:])
        print(f"Exported {audio_save_path}")

        transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
        with open(
            transcript_save_path,
            "w",
            encoding="utf-8",
        ) as f:
            f.write(whole_text)
+
+
if __name__ == "__main__":
    main()
    # NOTE(review): everything below is unreachable. click's main() raises
    # SystemExit when it finishes, and the exit(0) here would terminate the
    # script anyway. This looks like leftover manual-testing code for
    # per-segment export; consider deleting it.
    exit(0)

    audio = AudioSegment.from_wav(
        r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
    )

    model_size = "large-v3"

    model = WhisperModel(
        model_size,
        device="cuda",
        compute_type="float16",
        download_root="faster_whisper",
    )

    segments, info = model.transcribe(
        r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
        beam_size=5,
    )

    print(
        "Detected language '%s' with probability %f"
        % (info.language, info.language_probability)
    )
    print("Total len(ms): ", len(audio))

    for i, segment in enumerate(segments):
        print(
            "Segment %03d [%.2fs -> %.2fs] %s"
            % (i, segment.start, segment.end, segment.text)
        )
        start_ms = int(segment.start * 1000)
        end_ms = int(segment.end * 1000)
        segment_audio = audio[start_ms:end_ms]
        segment_audio.export(f"segment_{i:03d}.wav", format="wav")
        print(f"Exported segment_{i:03d}.wav")

    print("All segments have been exported.")