diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f0d925dde1ba202159bebd54eb60185e3749f514
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,45 @@
+.vscode
+
+# Pylance
+pyrightconfig.json
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+syn_out/
+checkpoints/
+.gradio
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5082284c8c2d5e8da6ffe04f46b88534506eeba8
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Resemble AI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
index b4fc38653ff16015dba9788fbadcd4d96c0d6237..ed8c705fea104945b44e012fff4b808ebc6c4104 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,80 @@
----
-title: Chatterbox
-emoji: 📚
-colorFrom: red
-colorTo: purple
-sdk: gradio
-sdk_version: 5.31.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: chatterbox
+app_file: gradio_tts_app.py
+sdk: gradio
+sdk_version: 4.44.1
+---
+
+
+
+# Chatterbox TTS
+
+[Demo page](https://resemble-ai.github.io/chatterbox_demopage/)
+[Hugging Face Space](https://huggingface.co/spaces/ResembleAI/Chatterbox)
+[Evaluation results](https://podonos.com/resembleai/chatterbox)
+
+_Made with ♥️ by [Resemble AI](https://resemble.ai)_
+
+We're excited to introduce Chatterbox, [Resemble AI's](https://resemble.ai) first production-grade open source TTS model. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations.
+
+Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life. It's also the first open source TTS model to support **emotion exaggeration control**, a powerful feature that makes your voices stand out. Try it now on our [Hugging Face Gradio app](https://huggingface.co/spaces/ResembleAI/Chatterbox).
+
+If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced TTS service (link). It delivers reliable performance with ultra-low, sub-200 ms latency, ideal for production use in agents, applications, or interactive media.
+
+# Key Details
+- SoTA zero-shot TTS
+- 0.5B Llama backbone
+- Unique exaggeration/intensity control
+- Ultra-stable with alignment-informed inference
+- Trained on 0.5M hours of cleaned data
+- Watermarked outputs
+- Easy voice conversion script
+- [Outperforms ElevenLabs](https://podonos.com/resembleai/chatterbox)
+
+# Tips
+- **General Use (TTS and Voice Agents):**
+ - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts.
+ - If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.
+
+- **Expressive or Dramatic Speech:**
+ - Try lower `cfg_weight` values (e.g. `~0.3`) and increase `exaggeration` to around `0.7` or higher.
+  - Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing (see the snippet below).
+
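+A minimal sketch of these settings in code; it reuses the same `generate` keyword arguments shown in the usage example and the bundled scripts:
+
+```python
+import torchaudio as ta
+from chatterbox.tts import ChatterboxTTS
+
+model = ChatterboxTTS.from_pretrained(device="cuda")
+
+# Expressive delivery: raise exaggeration and lower cfg_weight for slower, more deliberate pacing
+wav = model.generate(
+    "You have got to be kidding me!",
+    audio_prompt_path="YOUR_FILE.wav",  # placeholder path to a reference voice clip
+    exaggeration=0.7,
+    cfg_weight=0.3,
+)
+ta.save("expressive.wav", wav, model.sr)
+```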
+
+# Installation
+```
+pip install chatterbox-tts
+```
+
+
+# Usage
+```python
+import torchaudio as ta
+from chatterbox.tts import ChatterboxTTS
+
+model = ChatterboxTTS.from_pretrained(device="cuda")
+
+text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
+wav = model.generate(text)
+ta.save("test-1.wav", wav, model.sr)
+
+# If you want to synthesize with a different voice, specify the audio prompt
+AUDIO_PROMPT_PATH="YOUR_FILE.wav"
+wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
+ta.save("test-2.wav", wav, model.sr)
+```
+See `example_tts.py` for more examples.
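+
+Chatterbox also runs on Apple Silicon; `example_for_mac.py` has a full script (including a `torch.load` patch for checkpoint loading). A minimal sketch of device selection:
+
+```python
+import torch
+from chatterbox.tts import ChatterboxTTS
+
+# Prefer CUDA, then Apple MPS, then CPU
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
+else:
+    device = "cpu"
+
+model = ChatterboxTTS.from_pretrained(device=device)
+```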
+
+# Acknowledgements
+- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
+- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
+- [HiFT-GAN](https://github.com/yl4579/HiFTNet)
+- [Llama 3](https://github.com/meta-llama/llama3)
+- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)
+
+# Built-in PerTh Watermarking for Responsible AI
+
+Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.
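+
+If you want to verify that a clip carries the watermark, the snippet below is a minimal sketch. It assumes the `resemble-perth` dependency is importable as `perth` and exposes `PerthImplicitWatermarker` with a `get_watermark(audio, sample_rate=...)` method; treat the exact API as an assumption and check the Perth repository linked above.
+
+```python
+import librosa
+import perth  # assumed import name for the resemble-perth package
+
+# Load audio that may have been generated by Chatterbox
+audio, sr = librosa.load("test-1.wav", sr=None)
+
+# Extract the neural watermark (assumed API; a value near 1.0 indicates the watermark was found)
+watermarker = perth.PerthImplicitWatermarker()
+watermark = watermarker.get_watermark(audio, sample_rate=sr)
+print(f"Extracted watermark: {watermark}")
+```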
+
+# Disclaimer
+Don't use this model to do bad things. Prompts are sourced from freely available data on the internet.
diff --git a/example_for_mac.py b/example_for_mac.py
new file mode 100644
index 0000000000000000000000000000000000000000..f90d1a808aff5ea0aebb327a8a4ff08059b3d392
--- /dev/null
+++ b/example_for_mac.py
@@ -0,0 +1,28 @@
+import torch
+import torchaudio as ta
+from chatterbox.tts import ChatterboxTTS
+
+# Detect device (Mac with M1/M2/M3/M4)
+device = "mps" if torch.backends.mps.is_available() else "cpu"
+map_location = torch.device(device)
+
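+# Patch torch.load so checkpoints default to loading onto the detected device (no CUDA on most Macs).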
+torch_load_original = torch.load
+def patched_torch_load(*args, **kwargs):
+ if 'map_location' not in kwargs:
+ kwargs['map_location'] = map_location
+ return torch_load_original(*args, **kwargs)
+
+torch.load = patched_torch_load
+
+model = ChatterboxTTS.from_pretrained(device=device)
+text = "Today is the day. I want to move like a titan at dawn, sweat like a god forging lightning. No more excuses. From now on, my mornings will be temples of discipline. I am going to work out like the gods… every damn day."
+
+# If you want to synthesize with a different voice, specify the audio prompt
+AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
+wav = model.generate(
+ text,
+ audio_prompt_path=AUDIO_PROMPT_PATH,
+ exaggeration=2.0,
+ cfg_weight=0.5
+ )
+ta.save("test-2.wav", wav, model.sr)
diff --git a/example_tts.py b/example_tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2d255cfe0bd16ca989d3384e48aa07bc4330f65
--- /dev/null
+++ b/example_tts.py
@@ -0,0 +1,9 @@
+import torchaudio as ta
+from chatterbox.tts import ChatterboxTTS
+
+model = ChatterboxTTS.from_pretrained(device="cuda")
+
+text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
+wav = model.generate(text)
+ta.save("test-1.wav", wav, model.sr)
+
diff --git a/example_vc.py b/example_vc.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb3582cc10803137601fd012645e30fd11cf64b7
--- /dev/null
+++ b/example_vc.py
@@ -0,0 +1,6 @@
+import torchaudio as ta
+from chatterbox.vc import ChatterboxVC
+
+model = ChatterboxVC.from_pretrained("cuda")
+wav = model.generate("tests/trimmed_8b7f38b1.wav")
+ta.save("testvc.wav", wav, model.sr)
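+
+# To convert toward a specific reference voice instead of the default, pass target_voice_path
+# (the same keyword used in gradio_vc_app.py); the path below is a placeholder:
+# wav = model.generate("tests/trimmed_8b7f38b1.wav", target_voice_path="YOUR_TARGET_VOICE.wav")
+# ta.save("testvc_target.wav", wav, model.sr)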
diff --git a/gradio_tts_app.py b/gradio_tts_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bff58d81539fd1c0688a4b04e914c143143748e
--- /dev/null
+++ b/gradio_tts_app.py
@@ -0,0 +1,80 @@
+import random
+import numpy as np
+import torch
+import gradio as gr
+from chatterbox.tts import ChatterboxTTS
+
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def set_seed(seed: int):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+
+
+def load_model():
+ model = ChatterboxTTS.from_pretrained(DEVICE)
+ return model
+
+
+def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw):
+ if model is None:
+ model = ChatterboxTTS.from_pretrained(DEVICE)
+
+ if seed_num != 0:
+ set_seed(int(seed_num))
+
+ wav = model.generate(
+ text,
+ audio_prompt_path=audio_prompt_path,
+ exaggeration=exaggeration,
+ temperature=temperature,
+ cfg_weight=cfgw,
+ )
+ return (model.sr, wav.squeeze(0).numpy())
+
+
+with gr.Blocks() as demo:
+ model_state = gr.State(None) # Loaded once per session/user
+
+ with gr.Row():
+ with gr.Column():
+ text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
+ ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value=None)
+ exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5)
+ cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
+
+ with gr.Accordion("More options", open=False):
+ seed_num = gr.Number(value=0, label="Random seed (0 for random)")
+ temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
+
+ run_btn = gr.Button("Generate", variant="primary")
+
+ with gr.Column():
+ audio_output = gr.Audio(label="Output Audio")
+
+ demo.load(fn=load_model, inputs=[], outputs=model_state)
+
+ run_btn.click(
+ fn=generate,
+ inputs=[
+ model_state,
+ text,
+ ref_wav,
+ exaggeration,
+ temp,
+ seed_num,
+ cfg_weight,
+ ],
+ outputs=audio_output,
+ )
+
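+# Run locally with `python gradio_tts_app.py`; share=True in launch() also creates a temporary public link.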
+if __name__ == "__main__":
+ demo.queue(
+ max_size=50,
+ default_concurrency_limit=1,
+ ).launch(share=True)
diff --git a/gradio_vc_app.py b/gradio_vc_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf8de721df606cd99526359da3e1fb8c93538364
--- /dev/null
+++ b/gradio_vc_app.py
@@ -0,0 +1,27 @@
+import torch
+import gradio as gr
+from chatterbox.vc import ChatterboxVC
+
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+model = ChatterboxVC.from_pretrained(DEVICE)
+def generate(audio, target_voice_path):
+ wav = model.generate(
+ audio, target_voice_path=target_voice_path,
+ )
+ return model.sr, wav.squeeze(0).numpy()
+
+
+demo = gr.Interface(
+ generate,
+ [
+ gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input audio file"),
+ gr.Audio(sources=["upload", "microphone"], type="filepath", label="Target voice audio file (if none, the default voice is used)", value=None),
+ ],
+ "audio",
+)
+
+if __name__ == "__main__":
+ demo.launch()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..9d2d936aa77a65e658b83e8f5fc8522c3726b02b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,34 @@
+[project]
+name = "chatterbox-tts"
+version = "0.1.1"
+description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI"
+readme = "README.md"
+requires-python = ">=3.8"
+license = {file = "LICENSE"}
+authors = [
+ {name = "resemble-ai", email = "engineering@resemble.ai"}
+]
+dependencies = [
+ "numpy==1.26.0",
+ "resampy==0.4.3",
+ "librosa==0.10.0",
+ "s3tokenizer",
+ "torch==2.6.0",
+ "torchaudio==2.6.0",
+ "transformers==4.46.3",
+ "diffusers==0.29.0",
+ "resemble-perth==1.0.1",
+ "omegaconf==2.3.0",
+ "conformer==0.3.2",
+]
+
+[project.urls]
+Homepage = "https://github.com/resemble-ai/chatterbox"
+Repository = "https://github.com/resemble-ai/chatterbox"
+
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/src/chatterbox/__init__.py b/src/chatterbox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20b4391a0b756a4a35eae163c8b70533494ef3a1
--- /dev/null
+++ b/src/chatterbox/__init__.py
@@ -0,0 +1,2 @@
+from .tts import ChatterboxTTS
+from .vc import ChatterboxVC
diff --git a/src/chatterbox/models/s3gen/__init__.py b/src/chatterbox/models/s3gen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dbf6ada0dfa7c54ba9dd43d63f4878595c59321
--- /dev/null
+++ b/src/chatterbox/models/s3gen/__init__.py
@@ -0,0 +1,2 @@
+from .s3gen import S3Token2Wav as S3Gen
+from .const import S3GEN_SR
diff --git a/src/chatterbox/models/s3gen/const.py b/src/chatterbox/models/s3gen/const.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ba4f14c910e2734e23b775041d60a19c1371663
--- /dev/null
+++ b/src/chatterbox/models/s3gen/const.py
@@ -0,0 +1 @@
+S3GEN_SR = 24000
diff --git a/src/chatterbox/models/s3gen/decoder.py b/src/chatterbox/models/s3gen/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a64b1c6c7c7b6958ce2c1b28fa01598334c51da7
--- /dev/null
+++ b/src/chatterbox/models/s3gen/decoder.py
@@ -0,0 +1,317 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import pack, rearrange, repeat
+
+from .utils.mask import add_optional_chunk_mask
+from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
+ TimestepEmbedding, Upsample1D
+from .matcha.transformer import BasicTransformerBlock
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ assert mask.dtype == torch.bool
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+ mask = mask.to(dtype)
+ # attention mask bias
+ # NOTE(Mddct): torch.finfo jit issues
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+ mask = (1.0 - mask) * -1.0e+10
+ return mask
+
+
+
+class Transpose(torch.nn.Module):
+ def __init__(self, dim0: int, dim1: int):
+ super().__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x: torch.Tensor):
+ x = torch.transpose(x, self.dim0, self.dim1)
+ return x
+
+
+class CausalBlock1D(Block1D):
+ def __init__(self, dim: int, dim_out: int):
+ super(CausalBlock1D, self).__init__(dim, dim_out)
+ self.block = torch.nn.Sequential(
+ CausalConv1d(dim, dim_out, 3),
+ Transpose(1, 2),
+ nn.LayerNorm(dim_out),
+ Transpose(1, 2),
+ nn.Mish(),
+ )
+
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
+ output = self.block(x * mask)
+ return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+ self.block1 = CausalBlock1D(dim, dim_out)
+ self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+class CausalConv1d(torch.nn.Conv1d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None
+ ) -> None:
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
+ kernel_size, stride,
+ padding=0, dilation=dilation,
+ groups=groups, bias=bias,
+ padding_mode=padding_mode,
+ device=device, dtype=dtype)
+ assert stride == 1
+ self.causal_padding = (kernel_size - 1, 0)
+
+ def forward(self, x: torch.Tensor):
+ x = F.pad(x, self.causal_padding)
+ x = super(CausalConv1d, self).forward(x)
+ return x
+
+
+class ConditionalDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=320,
+ out_channels=80,
+ causal=True,
+ channels=[256],
+ dropout=0.0,
+ attention_head_dim=64,
+ n_blocks=4,
+ num_mid_blocks=12,
+ num_heads=8,
+ act_fn="gelu",
+ ):
+ """
+        This decoder requires an input with the same shape as the target. So, if your text content
+        is shorter or longer than the output, please re-sample it before feeding it to the decoder.
+ """
+ super().__init__()
+ channels = tuple(channels)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.causal = causal
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
+ time_embed_dim = channels[0] * 4
+ self.time_mlp = TimestepEmbedding(
+ in_channels=in_channels,
+ time_embed_dim=time_embed_dim,
+ act_fn="silu",
+ )
+ self.down_blocks = nn.ModuleList([])
+ self.mid_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ # NOTE jrm: `static_chunk_size` is missing?
+ self.static_chunk_size = 0
+
+ output_channel = in_channels
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_last = i == len(channels) - 1
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ downsample = (
+ Downsample1D(output_channel) if not is_last else
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ )
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+ for _ in range(num_mid_blocks):
+ input_channel = channels[-1]
+ out_channels = channels[-1]
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+ channels = channels[::-1] + (channels[0],)
+ for i in range(len(channels) - 1):
+ input_channel = channels[i] * 2
+ output_channel = channels[i + 1]
+ is_last = i == len(channels) - 2
+ resnet = CausalResnetBlock1D(
+ dim=input_channel,
+ dim_out=output_channel,
+ time_emb_dim=time_embed_dim,
+ ) if self.causal else ResnetBlock1D(
+ dim=input_channel,
+ dim_out=output_channel,
+ time_emb_dim=time_embed_dim,
+ )
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ upsample = (
+ Upsample1D(output_channel, use_conv_transpose=True)
+ if not is_last
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ )
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv1d):
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.GroupNorm):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
+ """Forward pass of the UNet1DConditional model.
+
+ Args:
+ x (torch.Tensor): shape (batch_size, in_channels, time)
+ mask (_type_): shape (batch_size, 1, time)
+ t (_type_): shape (batch_size)
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+ cond (_type_, optional): placeholder for future use. Defaults to None.
+
+ Raises:
+ ValueError: _description_
+ ValueError: _description_
+
+ Returns:
+ _type_: _description_
+ """
+
+ t = self.time_embeddings(t).to(t.dtype)
+ t = self.time_mlp(t)
+
+ x = pack([x, mu], "b * t")[0]
+
+ if spks is not None:
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+ x = pack([x, spks], "b * t")[0]
+ if cond is not None:
+ x = pack([x, cond], "b * t")[0]
+
+ hiddens = []
+ masks = [mask]
+ for resnet, transformer_blocks, downsample in self.down_blocks:
+ mask_down = masks[-1]
+ x = resnet(x, mask_down, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ hiddens.append(x) # Save hidden states for skip connections
+ x = downsample(x * mask_down)
+ masks.append(mask_down[:, :, ::2])
+ masks = masks[:-1]
+ mask_mid = masks[-1]
+
+ for resnet, transformer_blocks in self.mid_blocks:
+ x = resnet(x, mask_mid, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+
+ for resnet, transformer_blocks, upsample in self.up_blocks:
+ mask_up = masks.pop()
+ skip = hiddens.pop()
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+ x = resnet(x, mask_up, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ x = upsample(x * mask_up)
+ x = self.final_block(x, mask_up)
+ output = self.final_proj(x * mask_up)
+ return output * mask
diff --git a/src/chatterbox/models/s3gen/f0_predictor.py b/src/chatterbox/models/s3gen/f0_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f73689c6dee76112215bd4be9355d8961312edad
--- /dev/null
+++ b/src/chatterbox/models/s3gen/f0_predictor.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
+
+
+class ConvRNNF0Predictor(nn.Module):
+ def __init__(self,
+ num_class: int = 1,
+ in_channels: int = 80,
+ cond_channels: int = 512
+ ):
+ super().__init__()
+
+ self.num_class = num_class
+ self.condnet = nn.Sequential(
+ weight_norm(
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ )
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.condnet(x)
+ x = x.transpose(1, 2)
+ return torch.abs(self.classifier(x).squeeze(-1))
diff --git a/src/chatterbox/models/s3gen/flow.py b/src/chatterbox/models/s3gen/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..fee2ec9da57b13f4486f7d17d44c20430d5f3eaf
--- /dev/null
+++ b/src/chatterbox/models/s3gen/flow.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import random
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from omegaconf import DictConfig
+from .utils.mask import make_pad_mask
+
+
+class MaskedDiffWithXvec(torch.nn.Module):
+ def __init__(self,
+ input_size: int = 512,
+ output_size: int = 80,
+ spk_embed_dim: int = 192,
+ output_type: str = "mel",
+ vocab_size: int = 4096,
+ input_frame_rate: int = 50,
+ only_mask_loss: bool = True,
+ encoder: torch.nn.Module = None,
+ length_regulator: torch.nn.Module = None,
+ decoder: torch.nn.Module = None,
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+ super().__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+ self.decoder_conf = decoder_conf
+ self.mel_feat_conf = mel_feat_conf
+ self.vocab_size = vocab_size
+ self.output_type = output_type
+ self.input_frame_rate = input_frame_rate
+ logging.info(f"input frame rate={self.input_frame_rate}")
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+ self.encoder = encoder
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+ self.decoder = decoder
+ self.length_regulator = length_regulator
+ self.only_mask_loss = only_mask_loss
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ token = batch['speech_token'].to(device)
+ token_len = batch['speech_token_len'].to(device)
+ feat = batch['speech_feat'].to(device)
+ feat_len = batch['speech_feat_len'].to(device)
+ embedding = batch['embedding'].to(device)
+
+ # xvec projection
+ embedding = F.normalize(embedding, dim=1)
+ embedding = self.spk_embed_affine_layer(embedding)
+
+ # concat text and prompt_text
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+ # text encode
+ h, h_lengths = self.encoder(token, token_len)
+ h = self.encoder_proj(h)
+ h, h_lengths = self.length_regulator(h, feat_len)
+
+ # get conditions
+ conds = torch.zeros(feat.shape, device=token.device)
+ for i, j in enumerate(feat_len):
+ if random.random() < 0.5:
+ continue
+ index = random.randint(0, int(0.3 * j))
+ conds[i, :index] = feat[i, :index]
+ conds = conds.transpose(1, 2)
+
+ mask = (~make_pad_mask(feat_len)).to(h)
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
+ loss, _ = self.decoder.compute_loss(
+ feat.transpose(1, 2).contiguous(),
+ mask.unsqueeze(1),
+ h.transpose(1, 2).contiguous(),
+ embedding,
+ cond=conds
+ )
+ return {'loss': loss}
+
+ @torch.inference_mode()
+ def inference(self,
+ token,
+ token_len,
+ prompt_token,
+ prompt_token_len,
+ prompt_feat,
+ prompt_feat_len,
+ embedding,
+ flow_cache):
+ if self.fp16 is True:
+ prompt_feat = prompt_feat.half()
+ embedding = embedding.half()
+
+ assert token.shape[0] == 1
+ # xvec projection
+ embedding = F.normalize(embedding, dim=1)
+ embedding = self.spk_embed_affine_layer(embedding)
+
+ # concat text and prompt_text
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+ # text encode
+ h, h_lengths = self.encoder(token, token_len)
+ h = self.encoder_proj(h)
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
+
+ # get conditions
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+ conds[:, :mel_len1] = prompt_feat
+ conds = conds.transpose(1, 2)
+
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+ feat, flow_cache = self.decoder(
+ mu=h.transpose(1, 2).contiguous(),
+ mask=mask.unsqueeze(1),
+ spks=embedding,
+ cond=conds,
+ n_timesteps=10,
+ prompt_len=mel_len1,
+ flow_cache=flow_cache
+ )
+ feat = feat[:, :, mel_len1:]
+ assert feat.shape[2] == mel_len2
+ return feat.float(), flow_cache
+
+
+class CausalMaskedDiffWithXvec(torch.nn.Module):
+ def __init__(self,
+ input_size: int = 512,
+ output_size: int = 80,
+ spk_embed_dim: int = 192,
+ output_type: str = "mel",
+ vocab_size: int = 6561,
+ input_frame_rate: int = 25,
+ only_mask_loss: bool = True,
+ token_mel_ratio: int = 2,
+ pre_lookahead_len: int = 3,
+ encoder: torch.nn.Module = None,
+ decoder: torch.nn.Module = None,
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+ super().__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+ self.decoder_conf = decoder_conf
+ self.mel_feat_conf = mel_feat_conf
+ self.vocab_size = vocab_size
+ self.output_type = output_type
+ self.input_frame_rate = input_frame_rate
+ logging.info(f"input frame rate={self.input_frame_rate}")
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+ self.encoder = encoder
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+ self.decoder = decoder
+ self.only_mask_loss = only_mask_loss
+ self.token_mel_ratio = token_mel_ratio
+ self.pre_lookahead_len = pre_lookahead_len
+
+ # FIXME: this was missing - just putting it in as false
+ self.fp16 = False
+
+ @torch.inference_mode()
+ def inference(self,
+ token,
+ token_len,
+ prompt_token,
+ prompt_token_len,
+ prompt_feat,
+ prompt_feat_len,
+ embedding,
+ finalize):
+ if self.fp16 is True:
+ prompt_feat = prompt_feat.half()
+ embedding = embedding.half()
+
+ assert token.shape[0] == 1
+ # xvec projection
+ embedding = F.normalize(embedding, dim=1)
+ embedding = self.spk_embed_affine_layer(embedding)
+
+ # concat text and prompt_text
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+ # text encode
+ h, h_lengths = self.encoder(token, token_len)
+ if finalize is False:
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+ h = self.encoder_proj(h)
+
+ # get conditions
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+ conds[:, :mel_len1] = prompt_feat
+ conds = conds.transpose(1, 2)
+
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+ feat, _ = self.decoder(
+ mu=h.transpose(1, 2).contiguous(),
+ mask=mask.unsqueeze(1),
+ spks=embedding,
+ cond=conds,
+ n_timesteps=10
+ )
+ feat = feat[:, :, mel_len1:]
+ assert feat.shape[2] == mel_len2
+ return feat.float(), None # NOTE jrm: why are they returning None here?
diff --git a/src/chatterbox/models/s3gen/flow_matching.py b/src/chatterbox/models/s3gen/flow_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..74fc66f2aa5acbd3005d40d6165f3966484cecbf
--- /dev/null
+++ b/src/chatterbox/models/s3gen/flow_matching.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+import torch
+import torch.nn.functional as F
+from .matcha.flow_matching import BASECFM
+from omegaconf import OmegaConf
+
+
+CFM_PARAMS = OmegaConf.create({
+ "sigma_min": 1e-06,
+ "solver": "euler",
+ "t_scheduler": "cosine",
+ "training_cfg_rate": 0.2,
+ "inference_cfg_rate": 0.7,
+ "reg_loss_type": "l1"
+})
+
+
+class ConditionalCFM(BASECFM):
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+ super().__init__(
+ n_feats=in_channels,
+ cfm_params=cfm_params,
+ n_spks=n_spks,
+ spk_emb_dim=spk_emb_dim,
+ )
+ self.t_scheduler = cfm_params.t_scheduler
+ self.training_cfg_rate = cfm_params.training_cfg_rate
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
+ # Just change the architecture of the estimator here
+ self.estimator = estimator
+ self.lock = threading.Lock()
+
+ @torch.inference_mode()
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
+ """Forward diffusion
+
+ Args:
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ n_timesteps (int): number of diffusion steps
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: Not used but kept for future purposes
+
+ Returns:
+ sample: generated mel-spectrogram
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
+ cache_size = flow_cache.shape[2]
+ # fix prompt and overlap part mu and z
+ if cache_size != 0:
+ z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+ mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+ flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
+
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+ if self.t_scheduler == 'cosine':
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
+
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
+ """
+ Fixed euler solver for ODEs.
+ Args:
+ x (torch.Tensor): random noise
+ t_span (torch.Tensor): n_timesteps interpolated
+ shape: (n_timesteps + 1,)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: Not used but kept for future purposes
+ """
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+ t = t.unsqueeze(dim=0)
+
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+        # Or in the future we might add a return_all_steps flag
+ sol = []
+
+        # Do not use concat; it may change the memory format and cause TensorRT inference to return wrong results!
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+ for step in range(1, len(t_span)):
+ # Classifier-Free Guidance inference introduced in VoiceBox
+ x_in[:] = x
+ mask_in[:] = mask
+ mu_in[0] = mu
+ t_in[:] = t.unsqueeze(0)
+ spks_in[0] = spks
+ cond_in[0] = cond
+ dphi_dt = self.forward_estimator(
+ x_in, mask_in,
+ mu_in, t_in,
+ spks_in,
+ cond_in
+ )
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+ x = x + dt * dphi_dt
+ t = t + dt
+ sol.append(x)
+ if step < len(t_span) - 1:
+ dt = t_span[step + 1] - t
+
+ return sol[-1].float()
+
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
+ if isinstance(self.estimator, torch.nn.Module):
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
+ else:
+ with self.lock:
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+ self.estimator.set_input_shape('t', (2,))
+ self.estimator.set_input_shape('spks', (2, 80))
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+ # run trt engine
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
+ mask.contiguous().data_ptr(),
+ mu.contiguous().data_ptr(),
+ t.contiguous().data_ptr(),
+ spks.contiguous().data_ptr(),
+ cond.contiguous().data_ptr(),
+ x.data_ptr()])
+ return x
+
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+ """Computes diffusion loss
+
+ Args:
+ x1 (torch.Tensor): Target
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): target mask
+ shape: (batch_size, 1, mel_timesteps)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+
+ Returns:
+ loss: conditional flow matching loss
+ y: conditional flow
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ b, _, t = mu.shape
+
+ # random timestep
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+ if self.t_scheduler == 'cosine':
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
+ # sample noise p(x_0)
+ z = torch.randn_like(x1)
+
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+ u = x1 - (1 - self.sigma_min) * z
+
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
+ if self.training_cfg_rate > 0:
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
+ mu = mu * cfg_mask.view(-1, 1, 1)
+ spks = spks * cfg_mask.view(-1, 1)
+ cond = cond * cfg_mask.view(-1, 1, 1)
+
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
+ return loss, y
+
+
+class CausalConditionalCFM(ConditionalCFM):
+ def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+ @torch.inference_mode()
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+ """Forward diffusion
+
+ Args:
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ n_timesteps (int): number of diffusion steps
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: Not used but kept for future purposes
+
+ Returns:
+ sample: generated mel-spectrogram
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
+ # fix prompt and overlap part mu and z
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+ if self.t_scheduler == 'cosine':
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
diff --git a/src/chatterbox/models/s3gen/hifigan.py b/src/chatterbox/models/s3gen/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ecf73e5a50f2b6fea0b911a1e5da9b01d3fa0ea
--- /dev/null
+++ b/src/chatterbox/models/s3gen/hifigan.py
@@ -0,0 +1,474 @@
+# jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py
+# most modules should be reusable, but I found their SineGen changed a bit.
+
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""HIFI-GAN"""
+
+from typing import Dict, Optional, List
+import numpy as np
+from scipy.signal import get_window
+import torch
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn import ConvTranspose1d
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.distributions.uniform import Uniform
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+        >>> a1 = Snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake ∶= x + 1/a * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+"""hifigan based generator implementation.
+
+This code is modified from https://github.com/jik876/hifi-gan,
+https://github.com/kan-bayashi/ParallelWaveGAN, and
+https://github.com/NVIDIA/BigVGAN.
+
+"""
+
+
+class ResBlock(torch.nn.Module):
+ """Residual block module in HiFiGAN/BigVGAN."""
+ def __init__(
+ self,
+ channels: int = 512,
+ kernel_size: int = 3,
+ dilations: List[int] = [1, 3, 5],
+ ):
+ super(ResBlock, self).__init__()
+ self.convs1 = nn.ModuleList()
+ self.convs2 = nn.ModuleList()
+
+ for dilation in dilations:
+ self.convs1.append(
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation,
+ padding=get_padding(kernel_size, dilation)
+ )
+ )
+ )
+ self.convs2.append(
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1)
+ )
+ )
+ )
+ self.convs1.apply(init_weights)
+ self.convs2.apply(init_weights)
+ self.activations1 = nn.ModuleList([
+ Snake(channels, alpha_logscale=False)
+ for _ in range(len(self.convs1))
+ ])
+ self.activations2 = nn.ModuleList([
+ Snake(channels, alpha_logscale=False)
+ for _ in range(len(self.convs2))
+ ])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for idx in range(len(self.convs1)):
+ xt = self.activations1[idx](x)
+ xt = self.convs1[idx](xt)
+ xt = self.activations2[idx](xt)
+ xt = self.convs2[idx](xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for idx in range(len(self.convs1)):
+ remove_weight_norm(self.convs1[idx])
+ remove_weight_norm(self.convs2[idx])
+
+
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine waveform (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_threshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SineGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ @torch.no_grad()
+ def forward(self, f0):
+ """
+ :param f0: [B, 1, sample_len], Hz
+ :return: [B, 1, sample_len]
+ """
+
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
+ for i in range(self.harmonic_num + 1):
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
+
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
+ u_dist = Uniform(low=-np.pi, high=np.pi)
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
+ phase_vec[:, 0, :] = 0
+
+ # generate sine waveforms
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
+
+ # generate uv signal
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+    voiced_threshold: threshold to set U/V given F0 (default: 0)
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ """
+ # source for harmonic branch
+ with torch.no_grad():
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
+ sine_wavs = sine_wavs.transpose(1, 2)
+ uv = uv.transpose(1, 2)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
+class HiFTGenerator(nn.Module):
+ """
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
+ https://arxiv.org/abs/2309.09493
+ """
+ def __init__(
+ self,
+ in_channels: int = 80,
+ base_channels: int = 512,
+ nb_harmonics: int = 8,
+ sampling_rate: int = 22050,
+ nsf_alpha: float = 0.1,
+ nsf_sigma: float = 0.003,
+ nsf_voiced_threshold: float = 10,
+ upsample_rates: List[int] = [8, 8],
+ upsample_kernel_sizes: List[int] = [16, 16],
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ source_resblock_kernel_sizes: List[int] = [7, 11],
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
+ lrelu_slope: float = 0.1,
+ audio_limit: float = 0.99,
+ f0_predictor: torch.nn.Module = None,
+ ):
+ super(HiFTGenerator, self).__init__()
+
+ self.out_channels = 1
+ self.nb_harmonics = nb_harmonics
+ self.sampling_rate = sampling_rate
+ self.istft_params = istft_params
+ self.lrelu_slope = lrelu_slope
+ self.audio_limit = audio_limit
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.m_source = SourceModuleHnNSF(
+ sampling_rate=sampling_rate,
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
+ harmonic_num=nb_harmonics,
+ sine_amp=nsf_alpha,
+ add_noise_std=nsf_sigma,
+ voiced_threshod=nsf_voiced_threshold)
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
+
+ self.conv_pre = weight_norm(
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
+ )
+
+ # Up
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ base_channels // (2**i),
+ base_channels // (2**(i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ # Down
+ self.source_downs = nn.ModuleList()
+ self.source_resblocks = nn.ModuleList()
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
+ downsample_cum_rates = np.cumprod(downsample_rates)
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
+ if u == 1:
+ self.source_downs.append(
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
+ )
+ else:
+ self.source_downs.append(
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
+ )
+
+ self.source_resblocks.append(
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = base_channels // (2**(i + 1))
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(ResBlock(ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
+ self.f0_predictor = f0_predictor
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+ self.m_source.remove_weight_norm()
+ for l in self.source_downs:
+ remove_weight_norm(l)
+ for l in self.source_resblocks:
+ l.remove_weight_norm()
+
+ def _stft(self, x):
+ spec = torch.stft(
+ x,
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
+ return_complex=True)
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
+ return spec[..., 0], spec[..., 1]
+
+ def _istft(self, magnitude, phase):
+ magnitude = torch.clip(magnitude, max=1e2)
+ real = magnitude * torch.cos(phase)
+ img = magnitude * torch.sin(phase)
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+ return inverse_transform
+
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x = self.ups[i](x)
+
+ if i == self.num_upsamples - 1:
+ x = self.reflection_pad(x)
+
+ # fusion
+ si = self.source_downs[i](s_stft)
+ si = self.source_resblocks[i](si)
+ x = x + si
+
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
+        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # NOTE: applying sin() here is technically redundant
+
+ x = self._istft(magnitude, phase)
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
+ return x
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
+ # mel->f0
+ f0 = self.f0_predictor(speech_feat)
+ # f0->source
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+ s, _, _ = self.m_source(s)
+ s = s.transpose(1, 2)
+ # mel+source->speech
+ generated_speech = self.decode(x=speech_feat, s=s)
+ return generated_speech, f0
+
+ @torch.inference_mode()
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+ # mel->f0
+ f0 = self.f0_predictor(speech_feat)
+ # f0->source
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+ s, _, _ = self.m_source(s)
+ s = s.transpose(1, 2)
+ # use cache_source to avoid glitch
+ if cache_source.shape[2] != 0:
+ s[:, :, :cache_source.shape[2]] = cache_source
+ generated_speech = self.decode(x=speech_feat, s=s)
+ return generated_speech, s
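+
+# Illustrative inference sketch (not part of the original code; the 80 mel bins and
+# 200-frame length below are assumptions for demonstration only):
+#   hift = HiFTGenerator(..., f0_predictor=ConvRNNF0Predictor())  # constructor args as configured by the caller
+#   mel = torch.randn(1, 80, 200)                                 # (B, n_mels, T_mel)
+#   wav, source = hift.inference(speech_feat=mel)                 # waveform + source excitation (reusable as cache_source)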
diff --git a/src/chatterbox/models/s3gen/matcha/decoder.py b/src/chatterbox/models/s3gen/matcha/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7cb54724acb124376baad6298f7dec095c21ae3
--- /dev/null
+++ b/src/chatterbox/models/s3gen/matcha/decoder.py
@@ -0,0 +1,443 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from conformer import ConformerBlock
+from diffusers.models.activations import get_activation
+from einops import pack, rearrange, repeat
+
+from .transformer import BasicTransformerBlock
+
+
+class SinusoidalPosEmb(torch.nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
+
+ def forward(self, x, scale=1000):
+ if x.ndim < 1:
+ x = x.unsqueeze(0)
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
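+# Illustrative example (not in the original source): embed two flow-matching timesteps.
+#   pos_emb = SinusoidalPosEmb(128)
+#   t = torch.tensor([0.25, 0.5])     # timesteps in [0, 1]
+#   emb = pos_emb(t)                  # -> shape (2, 128)
+
+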
+class Block1D(torch.nn.Module):
+ def __init__(self, dim, dim_out, groups=8):
+ super().__init__()
+ self.block = torch.nn.Sequential(
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
+ torch.nn.GroupNorm(groups, dim_out),
+ nn.Mish(),
+ )
+
+ def forward(self, x, mask):
+ output = self.block(x * mask)
+ return output * mask
+
+
+class ResnetBlock1D(torch.nn.Module):
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
+ super().__init__()
+ self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+
+ self.block1 = Block1D(dim, dim_out, groups=groups)
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
+
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
+
+ def forward(self, x, mask, time_emb):
+ h = self.block1(x, mask)
+ h += self.mlp(time_emb).unsqueeze(-1)
+ h = self.block2(h, mask)
+ output = h + self.res_conv(x * mask)
+ return output
+
+
+class Downsample1D(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ self.act = get_activation(act_fn)
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+ if post_act_fn is None:
+ self.post_act = None
+ else:
+ self.post_act = get_activation(post_act_fn)
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class Upsample1D(nn.Module):
+ """A 1D upsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ use_conv_transpose (`bool`, default `False`):
+ option to use a convolution transpose.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ self.conv = None
+ if use_conv_transpose:
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, inputs):
+ assert inputs.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(inputs)
+
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
+
+ if self.use_conv:
+ outputs = self.conv(outputs)
+
+ return outputs
+
+
+class ConformerWrapper(ConformerBlock):
+ def __init__( # pylint: disable=useless-super-delegation
+ self,
+ *,
+ dim,
+ dim_head=64,
+ heads=8,
+ ff_mult=4,
+ conv_expansion_factor=2,
+ conv_kernel_size=31,
+ attn_dropout=0,
+ ff_dropout=0,
+ conv_dropout=0,
+ conv_causal=False,
+ ):
+ super().__init__(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ ff_mult=ff_mult,
+ conv_expansion_factor=conv_expansion_factor,
+ conv_kernel_size=conv_kernel_size,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ conv_dropout=conv_dropout,
+ conv_causal=conv_causal,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ timestep=None,
+ ):
+ return super().forward(x=hidden_states, mask=attention_mask.bool())
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ channels=(256, 256),
+ dropout=0.05,
+ attention_head_dim=64,
+ n_blocks=1,
+ num_mid_blocks=2,
+ num_heads=4,
+ act_fn="snake",
+ down_block_type="transformer",
+ mid_block_type="transformer",
+ up_block_type="transformer",
+ ):
+ super().__init__()
+ channels = tuple(channels)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
+ time_embed_dim = channels[0] * 4
+ self.time_mlp = TimestepEmbedding(
+ in_channels=in_channels,
+ time_embed_dim=time_embed_dim,
+ act_fn="silu",
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ output_channel = in_channels
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_last = i == len(channels) - 1
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+ transformer_blocks = nn.ModuleList(
+ [
+ self.get_block(
+ down_block_type,
+ output_channel,
+ attention_head_dim,
+ num_heads,
+ dropout,
+ act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ downsample = (
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ )
+
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+ for i in range(num_mid_blocks):
+ input_channel = channels[-1]
+ out_channels = channels[-1]
+
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+ transformer_blocks = nn.ModuleList(
+ [
+ self.get_block(
+ mid_block_type,
+ output_channel,
+ attention_head_dim,
+ num_heads,
+ dropout,
+ act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+ channels = channels[::-1] + (channels[0],)
+ for i in range(len(channels) - 1):
+ input_channel = channels[i]
+ output_channel = channels[i + 1]
+ is_last = i == len(channels) - 2
+
+ resnet = ResnetBlock1D(
+ dim=2 * input_channel,
+ dim_out=output_channel,
+ time_emb_dim=time_embed_dim,
+ )
+ transformer_blocks = nn.ModuleList(
+ [
+ self.get_block(
+ up_block_type,
+ output_channel,
+ attention_head_dim,
+ num_heads,
+ dropout,
+ act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ upsample = (
+ Upsample1D(output_channel, use_conv_transpose=True)
+ if not is_last
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ )
+
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+
+ self.final_block = Block1D(channels[-1], channels[-1])
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+
+ self.initialize_weights()
+ # nn.init.normal_(self.final_proj.weight)
+
+ @staticmethod
+ def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
+ if block_type == "conformer":
+ block = ConformerWrapper(
+ dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_heads,
+ ff_mult=1,
+ conv_expansion_factor=2,
+ ff_dropout=dropout,
+ attn_dropout=dropout,
+ conv_dropout=dropout,
+ conv_kernel_size=31,
+ )
+ elif block_type == "transformer":
+ block = BasicTransformerBlock(
+ dim=dim,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ else:
+ raise ValueError(f"Unknown block type {block_type}")
+
+ return block
+
+ def initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv1d):
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ elif isinstance(m, nn.GroupNorm):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): noisy sample, shape (batch_size, in_channels, time)
+            mask (torch.Tensor): shape (batch_size, 1, time)
+            mu (torch.Tensor): conditioning features, packed with `x` along the channel axis;
+                same time length as `x`
+            t (torch.Tensor): timesteps, shape (batch_size,)
+            spks (torch.Tensor, optional): speaker conditioning, shape (batch_size, condition_channels).
+                Defaults to None.
+            cond (torch.Tensor, optional): placeholder for future use. Defaults to None.
+
+        Returns:
+            torch.Tensor: shape (batch_size, out_channels, time)
+        """
+
+ t = self.time_embeddings(t)
+ t = self.time_mlp(t)
+
+ x = pack([x, mu], "b * t")[0]
+
+ if spks is not None:
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+ x = pack([x, spks], "b * t")[0]
+
+ hiddens = []
+ masks = [mask]
+ for resnet, transformer_blocks, downsample in self.down_blocks:
+ mask_down = masks[-1]
+ x = resnet(x, mask_down, t)
+ x = rearrange(x, "b c t -> b t c")
+ mask_down = rearrange(mask_down, "b 1 t -> b t")
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=mask_down,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t")
+ mask_down = rearrange(mask_down, "b t -> b 1 t")
+ hiddens.append(x) # Save hidden states for skip connections
+ x = downsample(x * mask_down)
+ masks.append(mask_down[:, :, ::2])
+
+ masks = masks[:-1]
+ mask_mid = masks[-1]
+
+ for resnet, transformer_blocks in self.mid_blocks:
+ x = resnet(x, mask_mid, t)
+ x = rearrange(x, "b c t -> b t c")
+ mask_mid = rearrange(mask_mid, "b 1 t -> b t")
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=mask_mid,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t")
+ mask_mid = rearrange(mask_mid, "b t -> b 1 t")
+
+ for resnet, transformer_blocks, upsample in self.up_blocks:
+ mask_up = masks.pop()
+ x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
+ x = rearrange(x, "b c t -> b t c")
+ mask_up = rearrange(mask_up, "b 1 t -> b t")
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=mask_up,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t")
+ mask_up = rearrange(mask_up, "b t -> b 1 t")
+ x = upsample(x * mask_up)
+
+ x = self.final_block(x, mask_up)
+ output = self.final_proj(x * mask_up)
+
+ return output * mask
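+
+# Channel-accounting note (derived from forward() above): `x` and `mu` are packed along the
+# channel axis (plus `spks`, repeated over time, when given), so `in_channels` must equal
+# their combined channel count. A hypothetical sketch, not a configuration from this repo:
+#   dec = Decoder(in_channels=160, out_channels=80, channels=(256, 256), act_fn="gelu")
+#   x = torch.randn(2, 80, 200); mu = torch.randn(2, 80, 200)
+#   mask = torch.ones(2, 1, 200); t = torch.rand(2)
+#   out = dec(x, mask, mu, t)         # -> (2, 80, 200)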
diff --git a/src/chatterbox/models/s3gen/matcha/flow_matching.py b/src/chatterbox/models/s3gen/matcha/flow_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..c91118250b978770f2badd366877f929092ca102
--- /dev/null
+++ b/src/chatterbox/models/s3gen/matcha/flow_matching.py
@@ -0,0 +1,129 @@
+from abc import ABC
+
+import torch
+import torch.nn.functional as F
+
+from .decoder import Decoder
+
+
+class BASECFM(torch.nn.Module, ABC):
+ def __init__(
+ self,
+ n_feats,
+ cfm_params,
+ n_spks=1,
+ spk_emb_dim=128,
+ ):
+ super().__init__()
+ self.n_feats = n_feats
+ self.n_spks = n_spks
+ self.spk_emb_dim = spk_emb_dim
+ self.solver = cfm_params.solver
+ if hasattr(cfm_params, "sigma_min"):
+ self.sigma_min = cfm_params.sigma_min
+ else:
+ self.sigma_min = 1e-4
+
+ self.estimator = None
+
+ @torch.inference_mode()
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+ """Forward diffusion
+
+ Args:
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ n_timesteps (int): number of diffusion steps
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: Not used but kept for future purposes
+
+ Returns:
+ sample: generated mel-spectrogram
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ z = torch.randn_like(mu) * temperature
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
+ """
+ Fixed euler solver for ODEs.
+ Args:
+ x (torch.Tensor): random noise
+ t_span (torch.Tensor): n_timesteps interpolated
+ shape: (n_timesteps + 1,)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: Not used but kept for future purposes
+ """
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+        # Keep intermediate states so they can be inspected (e.g. plotted while debugging);
+        # a return_all_steps flag may be added in the future.
+ sol = []
+
+ for step in range(1, len(t_span)):
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
+
+ x = x + dt * dphi_dt
+ t = t + dt
+ sol.append(x)
+ if step < len(t_span) - 1:
+ dt = t_span[step + 1] - t
+
+ return sol[-1]
+
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+ """Computes diffusion loss
+
+ Args:
+ x1 (torch.Tensor): Target
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): target mask
+ shape: (batch_size, 1, mel_timesteps)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+
+ Returns:
+ loss: conditional flow matching loss
+ y: conditional flow
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ b, _, t = mu.shape
+
+ # random timestep
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+ # sample noise p(x_0)
+ z = torch.randn_like(x1)
+
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+ u = x1 - (1 - self.sigma_min) * z
+
+ loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
+ torch.sum(mask) * u.shape[1]
+ )
+ return loss, y
+
+
+class CFM(BASECFM):
+ def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
+ super().__init__(
+ n_feats=in_channels,
+ cfm_params=cfm_params,
+ n_spks=n_spks,
+ spk_emb_dim=spk_emb_dim,
+ )
+
+ in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
+ # Just change the architecture of the estimator here
+ self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
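+
+# Minimal sketch of driving the Euler solver (illustrative values only, not the repo's
+# configuration; `cfm_params` only needs `.solver` and, optionally, `.sigma_min`):
+#   from types import SimpleNamespace
+#   cfm_params = SimpleNamespace(solver="euler", sigma_min=1e-4)
+#   decoder_params = dict(channels=(256, 256), act_fn="gelu")
+#   cfm = CFM(in_channels=160, out_channel=80,   # 160 because the estimator packs x and mu (80 mels each)
+#             cfm_params=cfm_params, decoder_params=decoder_params)
+#   mu, mask = torch.randn(1, 80, 200), torch.ones(1, 1, 200)
+#   mel = cfm(mu, mask, n_timesteps=10)          # -> (1, 80, 200)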
diff --git a/src/chatterbox/models/s3gen/matcha/text_encoder.py b/src/chatterbox/models/s3gen/matcha/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..12b82060bfd04d0cc15720f2fd49c2550164ae01
--- /dev/null
+++ b/src/chatterbox/models/s3gen/matcha/text_encoder.py
@@ -0,0 +1,413 @@
+""" from https://github.com/jaywalnut310/glow-tts """
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-4):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ n_dims = len(x.shape)
+ mean = torch.mean(x, 1, keepdim=True)
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+ shape = [1, -1] + [1] * (n_dims - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+
+ self.conv_layers = torch.nn.ModuleList()
+ self.norm_layers = torch.nn.ModuleList()
+ self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(
+ torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class DurationPredictor(nn.Module):
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.p_dropout = p_dropout
+
+ self.drop = torch.nn.Dropout(p_dropout)
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+ self.norm_1 = LayerNorm(filter_channels)
+ self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+ self.norm_2 = LayerNorm(filter_channels)
+ self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(x * x_mask)
+ x = torch.relu(x)
+ x = self.norm_1(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ x = torch.relu(x)
+ x = self.norm_2(x)
+ x = self.drop(x)
+ x = self.proj(x * x_mask)
+ return x * x_mask
+
+
+class RotaryPositionalEmbeddings(nn.Module):
+ """
+ ## RoPE module
+
+ Rotary encoding transforms pairs of features by rotating in the 2D plane.
+ That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
+ Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
+ by an angle depending on the position of the token.
+ """
+
+ def __init__(self, d: int, base: int = 10_000):
+ r"""
+ * `d` is the number of features $d$
+ * `base` is the constant used for calculating $\Theta$
+ """
+ super().__init__()
+
+ self.base = base
+ self.d = int(d)
+ self.cos_cached = None
+ self.sin_cached = None
+
+ def _build_cache(self, x: torch.Tensor):
+ r"""
+ Cache $\cos$ and $\sin$ values
+ """
+ # Return if cache is already built
+ if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
+ return
+
+ # Get sequence length
+ seq_len = x.shape[0]
+
+ # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+ theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
+ seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
+
+ # Calculate the product of position index and $\theta_i$
+ idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
+
+ # Concatenate so that for row $m$ we have
+ # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
+ idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
+ # Cache them
+ self.cos_cached = idx_theta2.cos()[:, None, None, :]
+ self.sin_cached = idx_theta2.sin()[:, None, None, :]
+
+ def _neg_half(self, x: torch.Tensor):
+ # $\frac{d}{2}$
+ d_2 = self.d // 2
+
+ # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+ return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
+ def forward(self, x: torch.Tensor):
+ """
+ * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
+ """
+ # Cache $\cos$ and $\sin$ values
+ x = rearrange(x, "b h t d -> t b h d")
+
+ self._build_cache(x)
+
+ # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
+ x_rope, x_pass = x[..., : self.d], x[..., self.d :]
+
+ # Calculate
+ # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+ neg_half_x = self._neg_half(x_rope)
+
+ x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
+
+ return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
+
+
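+# Illustrative example (not in the original source):
+#   rope = RotaryPositionalEmbeddings(d=32)   # rotate the first 32 of 64 head features
+#   q = torch.randn(2, 4, 100, 64)            # (batch, heads, time, head_dim)
+#   q = rope(q)                               # same shape; features [32:] pass through unchanged
+
+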
+class MultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels,
+ out_channels,
+ n_heads,
+ heads_share=True,
+ p_dropout=0.0,
+ proximal_bias=False,
+ proximal_init=False,
+ ):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.heads_share = heads_share
+ self.proximal_bias = proximal_bias
+ self.p_dropout = p_dropout
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
+
+ # from https://nn.labml.ai/transformers/rope/index.html
+ self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
+ self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
+
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
+ self.drop = torch.nn.Dropout(p_dropout)
+
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
+ if proximal_init:
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
+ key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
+ value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
+
+ query = self.query_rotary_pe(query)
+ key = self.key_rotary_pe(key)
+
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
+ return output, p_attn
+
+ @staticmethod
+ def _attention_bias_proximal(length):
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+ self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
+ self.drop = torch.nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(x * x_mask)
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ return x * x_mask
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+
+ self.drop = torch.nn.Dropout(p_dropout)
+ self.attn_layers = torch.nn.ModuleList()
+ self.norm_layers_1 = torch.nn.ModuleList()
+ self.ffn_layers = torch.nn.ModuleList()
+ self.norm_layers_2 = torch.nn.ModuleList()
+ for _ in range(self.n_layers):
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ for i in range(self.n_layers):
+ x = x * x_mask
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class TextEncoder(nn.Module):
+ def __init__(
+ self,
+ encoder_type,
+ encoder_params,
+ duration_predictor_params,
+ n_vocab,
+ n_spks=1,
+ spk_emb_dim=128,
+ ):
+ super().__init__()
+ self.encoder_type = encoder_type
+ self.n_vocab = n_vocab
+ self.n_feats = encoder_params.n_feats
+ self.n_channels = encoder_params.n_channels
+ self.spk_emb_dim = spk_emb_dim
+ self.n_spks = n_spks
+
+ self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
+ torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
+
+ if encoder_params.prenet:
+ self.prenet = ConvReluNorm(
+ self.n_channels,
+ self.n_channels,
+ self.n_channels,
+ kernel_size=5,
+ n_layers=3,
+ p_dropout=0.5,
+ )
+ else:
+ self.prenet = lambda x, x_mask: x
+
+ self.encoder = Encoder(
+ encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
+ encoder_params.filter_channels,
+ encoder_params.n_heads,
+ encoder_params.n_layers,
+ encoder_params.kernel_size,
+ encoder_params.p_dropout,
+ )
+
+ self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
+ self.proj_w = DurationPredictor(
+ self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
+ duration_predictor_params.filter_channels_dp,
+ duration_predictor_params.kernel_size,
+ duration_predictor_params.p_dropout,
+ )
+
+ def forward(self, x, x_lengths, spks=None):
+ """Run forward pass to the transformer based encoder and duration predictor
+
+ Args:
+ x (torch.Tensor): text input
+ shape: (batch_size, max_text_length)
+ x_lengths (torch.Tensor): text input lengths
+ shape: (batch_size,)
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size,)
+
+ Returns:
+ mu (torch.Tensor): average output of the encoder
+ shape: (batch_size, n_feats, max_text_length)
+ logw (torch.Tensor): log duration predicted by the duration predictor
+ shape: (batch_size, 1, max_text_length)
+ x_mask (torch.Tensor): mask for the text input
+ shape: (batch_size, 1, max_text_length)
+ """
+ x = self.emb(x) * math.sqrt(self.n_channels)
+ x = torch.transpose(x, 1, -1)
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+ x = self.prenet(x, x_mask)
+ if self.n_spks > 1:
+ x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
+ x = self.encoder(x, x_mask)
+ mu = self.proj_m(x) * x_mask
+
+ x_dp = torch.detach(x)
+ logw = self.proj_w(x_dp, x_mask)
+
+ return mu, logw, x_mask
diff --git a/src/chatterbox/models/s3gen/matcha/transformer.py b/src/chatterbox/models/s3gen/matcha/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f6762b7b8dbb61c8063516f44a5109887f38760
--- /dev/null
+++ b/src/chatterbox/models/s3gen/matcha/transformer.py
@@ -0,0 +1,316 @@
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+from diffusers.models.attention import (
+ GEGLU,
+ GELU,
+ AdaLayerNorm,
+ AdaLayerNormZero,
+ ApproximateGELU,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.models.lora import LoRACompatibleLinear
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+
+
+class SnakeBeta(nn.Module):
+ """
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+        - Input: (*, in_features)
+        - Output: (*, out_features); the input first goes through a linear projection
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+        >>> a1 = SnakeBeta(256, 256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ """
+
+ def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+ """
+ Initialization.
+ INPUT:
+            - in_features: input dimension of the internal linear projection
+            - out_features: output dimension; alpha and beta each have this many entries
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ """
+ super().__init__()
+ self.in_features = out_features if isinstance(out_features, list) else [out_features]
+ self.proj = LoRACompatibleLinear(in_features, out_features)
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ """
+ Forward pass of the function.
+        Projects the input linearly, then applies the SnakeBeta nonlinearity elementwise:
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
+ """
+ x = self.proj(x)
+ if self.alpha_logscale:
+ alpha = torch.exp(self.alpha)
+ beta = torch.exp(self.beta)
+ else:
+ alpha = self.alpha
+ beta = self.beta
+
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
+
+ return x
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+        elif activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+ elif activation_fn == "snakebeta":
+ act_fn = SnakeBeta(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states):
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ )
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ # scale_qk=False, # uncomment this to not to use flash attention
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 1. Self-Attention
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
diff --git a/src/chatterbox/models/s3gen/s3gen.py b/src/chatterbox/models/s3gen/s3gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..c61134436a6fe2af1802d04f7400355c0b727455
--- /dev/null
+++ b/src/chatterbox/models/s3gen/s3gen.py
@@ -0,0 +1,305 @@
+# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import numpy as np
+import torch
+import torchaudio as ta
+from functools import lru_cache
+from typing import Optional
+from omegaconf import DictConfig
+
+from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
+from .const import S3GEN_SR
+from .flow import CausalMaskedDiffWithXvec
+from .xvector import CAMPPlus
+from .utils.mel import mel_spectrogram
+from .f0_predictor import ConvRNNF0Predictor
+from .hifigan import HiFTGenerator
+from .transformer.upsample_encoder import UpsampleConformerEncoder
+from .flow_matching import CausalConditionalCFM
+from .decoder import ConditionalDecoder
+
+
+def drop_invalid_tokens(x):
+ assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
+ return x[x < SPEECH_VOCAB_SIZE]
+
+
+# TODO: global resampler cache
+@lru_cache(100)
+def get_resampler(src_sr, dst_sr, device):
+ return ta.transforms.Resample(src_sr, dst_sr).to(device)
+
+
+class S3Token2Mel(torch.nn.Module):
+ """
+ CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
+
+ TODO: make these modules configurable?
+ """
+ def __init__(self):
+ super().__init__()
+ self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
+ self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
+ self.speaker_encoder = CAMPPlus() # use default args
+
+ encoder = UpsampleConformerEncoder(
+ output_size=512,
+ attention_heads=8,
+ linear_units=2048,
+ num_blocks=6,
+ dropout_rate=0.1,
+ positional_dropout_rate=0.1,
+ attention_dropout_rate=0.1,
+ normalize_before=True,
+ input_layer='linear',
+ pos_enc_layer_type='rel_pos_espnet',
+ selfattention_layer_type='rel_selfattn',
+ input_size=512,
+ use_cnn_module=False,
+ macaron_style=False,
+ )
+
+ estimator = ConditionalDecoder(
+ in_channels=320,
+ out_channels=80,
+ causal=True,
+ channels=[256],
+ dropout=0.0,
+ attention_head_dim=64,
+ n_blocks=4,
+ num_mid_blocks=12,
+ num_heads=8,
+ act_fn='gelu',
+ )
+ cfm_params = DictConfig({
+ "sigma_min": 1e-06,
+ "solver": 'euler',
+ "t_scheduler": 'cosine',
+ "training_cfg_rate": 0.2,
+ "inference_cfg_rate": 0.7,
+ "reg_loss_type": 'l1',
+ })
+ decoder = CausalConditionalCFM(
+ spk_emb_dim=80,
+ cfm_params=cfm_params,
+ estimator=estimator,
+ )
+
+ self.flow = CausalMaskedDiffWithXvec(
+ encoder=encoder,
+ decoder=decoder
+ )
+
+ self.resamplers = {}
+
+ @property
+ def device(self):
+ params = self.tokenizer.parameters()
+ return next(params).device
+
+ def embed_ref(
+ self,
+ ref_wav: torch.Tensor,
+ ref_sr: int,
+ device="auto",
+ ref_fade_out=True,
+ ):
+ device = self.device if device == "auto" else device
+ if isinstance(ref_wav, np.ndarray):
+ ref_wav = torch.from_numpy(ref_wav).float()
+
+ if ref_wav.device != device:
+ ref_wav = ref_wav.to(device)
+
+ if len(ref_wav.shape) == 1:
+ ref_wav = ref_wav.unsqueeze(0) # (B, L)
+
+ if ref_wav.size(1) > 10 * ref_sr:
+ print("WARNING: cosydec received ref longer than 10s")
+
+ ref_wav_24 = ref_wav
+ if ref_sr != S3GEN_SR:
+ ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
+
+ ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
+ ref_mels_24_len = None
+
+ # Resample to 16kHz
+ ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
+
+ # Speaker embedding
+ ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
+
+ # Tokenize 16khz reference
+ ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
+
+        # Enforce mel_len == 2 * speech_token_len (a mismatch happens when the input is not padded to a multiple of 40ms)
+ if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
+ logging.warning(
+ "Reference mel length is not equal to 2 * reference token length.\n"
+ )
+ ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
+ ref_speech_token_lens[0] = ref_speech_tokens.shape[1]
+
+ return dict(
+ prompt_token=ref_speech_tokens.to(device),
+ prompt_token_len=ref_speech_token_lens,
+ prompt_feat=ref_mels_24,
+ prompt_feat_len=ref_mels_24_len,
+ embedding=ref_x_vector,
+ )
+
+ def forward(
+ self,
+ speech_tokens: torch.LongTensor,
+ # locally-computed ref embedding (mutex with ref_dict)
+ ref_wav: Optional[torch.Tensor],
+ ref_sr: Optional[int],
+ # pre-computed ref embedding (prod API)
+ ref_dict: Optional[dict] = None,
+ finalize: bool = False,
+ ):
+ """
+        Generate mel-spectrograms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.
+
+ NOTE:
+ - The speaker encoder accepts 16 kHz waveform.
+ - S3TokenizerV2 accepts 16 kHz waveform.
+ - The mel-spectrogram for the reference assumes 24 kHz input signal.
+ - This function is designed for batch_size=1 only.
+
+ Args
+ ----
+ - `speech_tokens`: S3 speech tokens [B=1, T]
+ - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
+ - `ref_sr`: reference sample rate
+ - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
+ """
+ assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"
+
+ if ref_dict is None:
+ ref_dict = self.embed_ref(ref_wav, ref_sr)
+ else:
+ # type/device casting (all values will be numpy if it's from a prod API call)
+ for rk in list(ref_dict):
+ if isinstance(ref_dict[rk], np.ndarray):
+ ref_dict[rk] = torch.from_numpy(ref_dict[rk])
+ if torch.is_tensor(ref_dict[rk]):
+ ref_dict[rk] = ref_dict[rk].to(self.device)
+
+ if len(speech_tokens.shape) == 1:
+ speech_tokens = speech_tokens.unsqueeze(0)
+
+ # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
+ speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
+
+ output_mels, _ = self.flow.inference(
+ token=speech_tokens,
+ token_len=speech_token_lens,
+ finalize=finalize,
+ **ref_dict,
+ )
+ return output_mels
+
+
+class S3Token2Wav(S3Token2Mel):
+ """
+    The decoder of CosyVoice2 chains a token-to-mel (CFM) module with a mel-to-waveform (HiFiGAN) module.
+
+ TODO: make these modules configurable?
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ f0_predictor = ConvRNNF0Predictor()
+ self.mel2wav = HiFTGenerator(
+ sampling_rate=S3GEN_SR,
+ upsample_rates=[8, 5, 3],
+ upsample_kernel_sizes=[16, 11, 7],
+ source_resblock_kernel_sizes=[7, 7, 11],
+ source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ f0_predictor=f0_predictor,
+ )
+
+ # silence out a few ms and fade audio in to reduce artifacts
+ n_trim = S3GEN_SR // 50 # 20ms = half of a frame
+ trim_fade = torch.zeros(2 * n_trim)
+ trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
+ self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
+
+ def forward(
+ self,
+ speech_tokens,
+ # locally-computed ref embedding (mutex with ref_dict)
+ ref_wav: Optional[torch.Tensor],
+ ref_sr: Optional[int],
+ # pre-computed ref embedding (prod API)
+ ref_dict: Optional[dict] = None,
+ finalize: bool = False
+ ):
+ output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
+
+ # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
+ hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
+
+ output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)
+
+ if not self.training:
+ # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
+ output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
+
+ return output_wavs
+
+ @torch.inference_mode()
+ def flow_inference(
+ self,
+ speech_tokens,
+ # locally-computed ref embedding (mutex with ref_dict)
+ ref_wav: Optional[torch.Tensor] = None,
+ ref_sr: Optional[int] = None,
+ # pre-computed ref embedding (prod API)
+ ref_dict: Optional[dict] = None,
+ finalize: bool = False,
+ ):
+ return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
+
+ @torch.inference_mode()
+ def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
+ if cache_source is None:
+ cache_source = torch.zeros(1, 1, 0).to(self.device)
+ return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
+
+ @torch.inference_mode()
+ def inference(
+ self,
+ speech_tokens,
+ # locally-computed ref embedding (mutex with ref_dict)
+ ref_wav: Optional[torch.Tensor] = None,
+ ref_sr: Optional[int] = None,
+ # pre-computed ref embedding (prod API)
+ ref_dict: Optional[dict] = None,
+ cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
+ finalize: bool = True,
+ ):
+ output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
+ output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
+
+ # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
+ output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
+
+ return output_wavs, output_sources
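+
+# Illustrative end-to-end sketch (modules here are randomly initialized; trained weights
+# are loaded by the surrounding package, not in this file):
+#   s3gen = S3Token2Wav().eval()
+#   ref_wav, ref_sr = ta.load("reference.wav")             # mono reference, ideally under ~10 s
+#   tokens = torch.randint(0, SPEECH_VOCAB_SIZE, (1, 50))  # S3 speech tokens, shape (B=1, T)
+#   wavs, sources = s3gen.inference(tokens, ref_wav=ref_wav, ref_sr=ref_sr)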
diff --git a/src/chatterbox/models/s3gen/transformer/__init__.py b/src/chatterbox/models/s3gen/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/chatterbox/models/s3gen/transformer/activation.py b/src/chatterbox/models/s3gen/transformer/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4604ce7e164c5f224767a6e144b3fbfce63a80ec
--- /dev/null
+++ b/src/chatterbox/models/s3gen/transformer/activation.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
+# 2020 Northwestern Polytechnical University (Pengcheng Guo)
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Swish() activation function for Conformer."""
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Swish(torch.nn.Module):
+    """Construct a Swish object."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Return Swish activation function."""
+ return x * torch.sigmoid(x)
+
+
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+        >>> a1 = Snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake ∶= x + 1/a * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
diff --git a/src/chatterbox/models/s3gen/transformer/attention.py b/src/chatterbox/models/s3gen/transformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..a382c0af9a4e68d030e9288f5fd3824a5511f8b0
--- /dev/null
+++ b/src/chatterbox/models/s3gen/transformer/attention.py
@@ -0,0 +1,330 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Multi-Head Attention layer definition."""
+
+import math
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self,
+ n_head: int,
+ n_feat: int,
+ dropout_rate: float,
+ key_bias: bool = True):
+        """Construct a MultiHeadedAttention object."""
+ super().__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor, size
+ (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor, size
+ (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor, size
+ (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value, size
+ (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score, size
+ (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+ # 1st chunk to ease the onnx export.]
+ # 2. pytorch training
+ if mask.size(2) > 0: # time2 > 0
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ # For last chunk, time2 might be larger than scores.size(-1)
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
+ scores = scores.masked_fill(mask, -float('inf'))
+ attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0) # (batch, head, time1, time2)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+ # 1. onnx(16/-1, -1/-1, 16/0)
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+ self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ 1.When applying cross attention between decoder and encoder,
+ the batch padding mask for input is in (#batch, 1, T) shape.
+ 2.When applying self attention of encoder,
+ the mask is in (#batch, T, T) shape.
+ 3.When applying self attention of decoder,
+ the mask is in (#batch, L, L) shape.
+ 4.If the different position in decoder see different block
+ of the encoder, such as Mocha, the passed in mask could be
+ in (#batch, L, T) shape. But there is no such case in current
+ CosyVoice.
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will always be `True`
+ # and we will always do splitting and
+ # concatenation (this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors(see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache,
+ cache.size(-1) // 2,
+ dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask), new_cache
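+
+ # Usage sketch for the cache path (self-attention, no mask; shapes assume
+ # n_head=4, n_feat=256, so d_k=64 and the cache stores K and V side by side):
+ # >>> mha = MultiHeadedAttention(4, 256, dropout_rate=0.0)
+ # >>> x1 = torch.randn(1, 10, 256)
+ # >>> out1, cache = mha(x1, x1, x1) # cache: (1, 4, 10, 128)
+ # >>> x2 = torch.randn(1, 5, 256)
+ # >>> out2, cache = mha(x2, x2, x2, cache=cache) # cache: (1, 4, 15, 128)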
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self,
+ n_head: int,
+ n_feat: int,
+ dropout_rate: float,
+ key_bias: bool = True):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+ """Compute relative positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+ time1 means the length of query vector.
+
+ Returns:
+ torch.Tensor: Output tensor.
+
+ """
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+ device=x.device,
+ dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(x.size()[0],
+ x.size()[1],
+ x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)[
+ :, :, :, : x.size(-1) // 2 + 1
+ ] # only keep the positions from 0 to time2
+ return x
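+
+ # Shape sketch: for time1 == 7 the input is (B, H, 7, 13) (13 == 2*7 - 1);
+ # after padding, reshaping and keeping 13 // 2 + 1 == 7 columns, the
+ # output is (B, H, 7, 7), i.e. one score per (query, key) pair.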
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, time2, size).
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will always be `True`
+ # and we will always do splitting and
+ # concatenation (this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors(see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache,
+ cache.size(-1) // 2,
+ dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
+ if matrix_ac.shape != matrix_bd.shape:
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask), new_cache
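+
+ # Usage sketch (pos_emb as produced by an espnet-style relative encoding,
+ # i.e. length 2*time1 - 1; shapes assume n_head=4, n_feat=256):
+ # >>> attn = RelPositionMultiHeadedAttention(4, 256, dropout_rate=0.0)
+ # >>> x = torch.randn(1, 10, 256)
+ # >>> pos_emb = torch.randn(1, 19, 256)
+ # >>> out, cache = attn(x, x, x, pos_emb=pos_emb) # out: (1, 10, 256)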
diff --git a/src/chatterbox/models/s3gen/transformer/convolution.py b/src/chatterbox/models/s3gen/transformer/convolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..1249400455c7080297d35bf078953bf2c80e829b
--- /dev/null
+++ b/src/chatterbox/models/s3gen/transformer/convolution.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""ConvolutionModule definition."""
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model."""
+
+ def __init__(self,
+ channels: int,
+ kernel_size: int = 15,
+ activation: nn.Module = nn.ReLU(),
+ norm: str = "batch_norm",
+ causal: bool = False,
+ bias: bool = True):
+ """Construct an ConvolutionModule object.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (int): Whether use causal convolution or not
+ """
+ super().__init__()
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ # self.lorder is used to distinguish if it's a causal convolution,
+ # if self.lorder > 0: it's a causal convolution, the input will be
+ # padded with self.lorder frames on the left in forward.
+ # else: it's a symmetrical convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+ # kernel_size should be an odd number for non-causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+
+ assert norm in ['batch_norm', 'layer_norm']
+ if norm == "batch_norm":
+ self.use_layer_norm = False
+ self.norm = nn.BatchNorm1d(channels)
+ else:
+ self.use_layer_norm = True
+ self.norm = nn.LayerNorm(channels)
+
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+ (0, 0, 0) means fake mask.
+ cache (torch.Tensor): left context cache, it is only
+ used in causal convolution (#batch, channels, cache_t),
+ (0, 0, 0) means fake cache.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2) # (#batch, channels, time)
+
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ if self.lorder > 0:
+ if cache.size(2) == 0: # cache_t == 0
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
+ else:
+ assert cache.size(0) == x.size(0) # equal batch
+ assert cache.size(1) == x.size(1) # equal channel
+ x = torch.cat((cache, x), dim=2)
+ assert (x.size(2) > self.lorder)
+ new_cache = x[:, :, -self.lorder:]
+ else:
+ # It's better we just return None if no cache is required,
+ # However, for JIT export, here we just fake one tensor instead of
+ # None.
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.activation(self.norm(x))
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ return x.transpose(1, 2), new_cache
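+
+ # Usage sketch for the causal/streaming path: with kernel_size=15 the layer
+ # keeps lorder=14 frames of left context as its cache.
+ # >>> conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
+ # >>> y, new_cache = conv(torch.randn(2, 50, 256)) # y: (2, 50, 256), cache: (2, 256, 14)
+ # >>> y2, _ = conv(torch.randn(2, 50, 256), cache=new_cache)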
diff --git a/src/chatterbox/models/s3gen/transformer/embedding.py b/src/chatterbox/models/s3gen/transformer/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..298cf47b34cc6de3b7fcec7924f77da04be11e71
--- /dev/null
+++ b/src/chatterbox/models/s3gen/transformer/embedding.py
@@ -0,0 +1,294 @@
+# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Positonal Encoding Module."""
+
+import math
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+class PositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+
+ :param int d_model: embedding dim
+ :param float dropout_rate: dropout rate
+ :param int max_len: maximum input length
+
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
+ """
+
+ def __init__(self,
+ d_model: int,
+ dropout_rate: float,
+ max_len: int = 5000,
+ reverse: bool = False):
+ """Construct an PositionalEncoding object."""
+ super().__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
+
+ self.pe = torch.zeros(self.max_len, self.d_model)
+ position = torch.arange(0, self.max_len,
+ dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
+ self.pe[:, 0::2] = torch.sin(position * div_term)
+ self.pe[:, 1::2] = torch.cos(position * div_term)
+ self.pe = self.pe.unsqueeze(0)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
+ offset (int, torch.tensor): position offset
+
+ Returns:
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+ torch.Tensor: for compatibility to RelPositionalEncoding
+ """
+
+ self.pe = self.pe.to(x.device)
+ pos_emb = self.position_encoding(offset, x.size(1), False)
+ x = x * self.xscale + pos_emb
+ return self.dropout(x), self.dropout(pos_emb)
+
+ def position_encoding(self,
+ offset: Union[int, torch.Tensor],
+ size: int,
+ apply_dropout: bool = True) -> torch.Tensor:
+ """ For getting encoding in a streaming fashion
+
+ Attention!!!!!
+ we apply dropout only once at the whole utterance level in a none
+ streaming way, but will call this function several times with
+ increasing input size in a streaming scenario, so the dropout will
+ be applied several times.
+
+ Args:
+ offset (int or torch.tensor): start offset
+ size (int): required size of position encoding
+
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ # How to subscript a Union type:
+ # https://github.com/pytorch/pytorch/issues/69434
+ if isinstance(offset, int):
+ assert offset + size <= self.max_len
+ pos_emb = self.pe[:, offset:offset + size]
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
+ assert offset + size <= self.max_len
+ pos_emb = self.pe[:, offset:offset + size]
+ else: # for batched streaming decoding on GPU
+ assert torch.max(offset) + size <= self.max_len
+ index = offset.unsqueeze(1) + \
+ torch.arange(0, size).to(offset.device) # B X T
+ flag = index > 0
+ # remove negative offset
+ index = index * flag
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
+
+ if apply_dropout:
+ pos_emb = self.dropout(pos_emb)
+ return pos_emb
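+
+ # Usage sketch: absolute encodings can be fetched chunk by chunk for streaming.
+ # >>> pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
+ # >>> y, pos_emb = pe(torch.randn(1, 20, 256)) # both (1, 20, 256)
+ # >>> nxt = pe.position_encoding(offset=20, size=10) # (1, 10, 256)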
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.pe = self.pe.to(x.device)
+ x = x * self.xscale
+ pos_emb = self.position_encoding(offset, x.size(1), False)
+ return self.dropout(x), self.dropout(pos_emb)
+
+
+class WhisperPositionalEncoding(PositionalEncoding):
+ """ Sinusoids position encoding used in openai-whisper.encoder
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
+ super().__init__(d_model, dropout_rate, max_len)
+ self.xscale = 1.0
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment *
+ torch.arange(d_model // 2))
+ scaled_time = torch.arange(max_len)[:, np.newaxis] * \
+ inv_timescales[np.newaxis, :]
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+ delattr(self, "pe")
+ self.register_buffer("pe", pe.unsqueeze(0))
+
+
+class LearnablePositionalEncoding(PositionalEncoding):
+ """ Learnable position encoding used in openai-whisper.decoder
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
+ super().__init__(d_model, dropout_rate, max_len)
+ # NOTE(xcsong): overwrite self.pe & self.xscale
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
+ self.xscale = 1.0
+
+
+class NoPositionalEncoding(torch.nn.Module):
+ """ No position encoding
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float):
+ super().__init__()
+ self.d_model = d_model
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """ Just return zero vector for interface compatibility
+ """
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
+ return self.dropout(x), pos_emb
+
+ def position_encoding(self, offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ return torch.zeros(1, size, self.d_model)
+
+
+class EspnetRelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module (new implementation).
+
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+ """Construct an PositionalEncoding object."""
+ super(EspnetRelPositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x: torch.Tensor):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` means to the position of query vector and `j` means the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
+ pe_positive = torch.zeros(x.size(1), self.d_model)
+ pe_negative = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ # Reverse the order of positive indices and concat both positive and
+ # negative indices. This is used to support the shifting trick
+ # as in https://arxiv.org/abs/1901.02860
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
+ return self.dropout(x), self.dropout(pos_emb)
+
+ def position_encoding(self,
+ offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ """ For getting encoding in a streaming fashion
+
+ Attention!!!!!
+ we apply dropout only once at the whole utterance level in a none
+ streaming way, but will call this function several times with
+ increasing input size in a streaming scenario, so the dropout will
+ be applied several times.
+
+ Args:
+ offset (int or torch.tensor): start offset
+ size (int): required size of position encoding
+
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
+ ]
+ return pos_emb
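+
+ # Usage sketch: relative positions span both directions, so for an input of
+ # length T the returned pos_emb has length 2*T - 1.
+ # >>> enc = EspnetRelPositionalEncoding(d_model=256, dropout_rate=0.0)
+ # >>> y, pos_emb = enc(torch.randn(1, 10, 256)) # y: (1, 10, 256), pos_emb: (1, 19, 256)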
diff --git a/src/chatterbox/models/s3gen/transformer/encoder_layer.py b/src/chatterbox/models/s3gen/transformer/encoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2577e78ffefcbda4ee33a5db97db6cb2ee3070c0
--- /dev/null
+++ b/src/chatterbox/models/s3gen/transformer/encoder_layer.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder self-attention layer definition."""
+
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class TransformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ dropout_rate: float,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = nn.LayerNorm(size, eps=1e-12)
+ self.norm2 = nn.LayerNorm(size, eps=1e-12)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): just for interface compatibility
+ to ConformerEncoderLayer
+ mask_pad (torch.Tensor): not used in the transformer layer,
+ just kept for a unified API with the conformer layer.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2), not used here, it's for interface
+ compatibility to ConformerEncoderLayer.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
+
+ """
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ return x, mask, new_att_cache, fake_cnn_cache
+
+
+class ConformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+ instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+ `ConvolutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: Optional[nn.Module] = None,
+ feed_forward_macaron: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ dropout_rate: float = 0.1,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
+ self.norm_final = nn.LayerNorm(
+ size, eps=1e-12) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): positional encoding, must not be None
+ for ConformerEncoderLayer.
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
+ (#batch, 1,time), (0, 0, 0) means fake mask.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2)
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+ """
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
+ att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ # Fake new cnn cache here, and then change it in conv_module
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+ x = residual + self.dropout(x)
+
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ return x, mask, new_att_cache, new_cnn_cache
diff --git a/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py b/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..70693a844d9f576833150b59bc67521b8639e3df
--- /dev/null
+++ b/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Positionwise feed forward layer definition."""
+
+import torch
+
+
+class PositionwiseFeedForward(torch.nn.Module):
+ """Positionwise feed forward layer.
+
+ FeedForward is applied at each position of the sequence.
+ The output dim is the same as the input dim.
+
+ Args:
+ idim (int): Input dimension.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ ):
+ """Construct a PositionwiseFeedForward object."""
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
+ self.activation = activation
+ self.dropout = torch.nn.Dropout(dropout_rate)
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+ """
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
+
+
+class MoEFFNLayer(torch.nn.Module):
+ """
+ Mixture-of-experts layer built from positionwise feed forward experts.
+ See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
+ The output dim is the same as the input dim.
+
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
+ Args:
+ n_expert: number of experts.
+ n_expert_per_token: the actual number of experts used for each frame
+ idim (int): Input dimension.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ n_expert: int,
+ n_expert_per_token: int,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ ):
+ super(MoEFFNLayer, self).__init__()
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
+ self.experts = torch.nn.ModuleList(
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate,
+ activation) for _ in range(n_expert))
+ self.n_expert_per_token = n_expert_per_token
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Foward function.
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+
+ """
+ B, L, D = xs.size(
+ ) # batch size, sequence length, embedding dimension (idim)
+ xs = xs.view(-1, D) # (B*L, D)
+ router = self.gate(xs) # (B*L, n_expert)
+ logits, indices = torch.topk(
+ router, self.n_expert_per_token
+ ) # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
+ weights = torch.nn.functional.softmax(
+ logits, dim=1,
+ dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
+ output = torch.zeros_like(xs) # (B*L, D)
+ for i, expert in enumerate(self.experts):
+ mask = indices == i
+ batch_idx, ith_expert = torch.where(mask)
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
+ xs[batch_idx])
+ return output.view(B, L, D)
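+
+ # Usage sketch: top-2 routing over 4 experts, output keeps the input shape.
+ # >>> moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=256,
+ # ... hidden_units=1024, dropout_rate=0.0)
+ # >>> y = moe(torch.randn(2, 10, 256)) # y.shape == (2, 10, 256)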
diff --git a/src/chatterbox/models/s3gen/transformer/subsampling.py b/src/chatterbox/models/s3gen/transformer/subsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e1a98f9731433b673f3e09183d6245cdbd6073
--- /dev/null
+++ b/src/chatterbox/models/s3gen/transformer/subsampling.py
@@ -0,0 +1,383 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Subsampling layer definition."""
+
+from typing import Tuple, Union
+
+import torch
+
+
+class BaseSubsampling(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def position_encoding(self, offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ return self.pos_enc.position_encoding(offset, size)
+
+
+class EmbedinigNoSubsampling(BaseSubsampling):
+ """Embedding input without subsampling
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ super().__init__()
+ self.embed = torch.nn.Embedding(idim, odim)
+ self.pos_enc = pos_enc_class
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time .
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time .
+
+ """
+ x = self.embed(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
+
+
+class LinearNoSubsampling(BaseSubsampling):
+ """Linear transform the input without subsampling
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ torch.nn.LayerNorm(odim, eps=1e-5),
+ torch.nn.Dropout(dropout_rate),
+ )
+ self.pos_enc = pos_enc_class
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time .
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time .
+
+ """
+ x = self.out(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
+
+
+class Conv1dSubsampling2(BaseSubsampling):
+ """Convolutional 1D subsampling (to 1/2 length).
+ It is designed for Whisper, ref:
+ https://github.com/openai/whisper/blob/main/whisper/model.py
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv1dSubsampling2 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
+ torch.nn.GELU(),
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
+ torch.nn.GELU(),
+ )
+ self.pos_enc = pos_enc_class
+ # The right context for every conv layer is computed by:
+ # (kernel_size - 1) * frame_rate_of_this_layer
+ self.subsampling_rate = 2
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
+ self.right_context = 4
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 2.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 2.
+ torch.Tensor: positional encoding
+
+ """
+ time = x.size(1)
+ x = x.transpose(1, 2) # (b, f, t)
+ x = self.conv(x)
+ x = x.transpose(1, 2) # (b, t, f)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
+
+
+class Conv2dSubsampling4(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling4 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ )
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
+ self.pos_enc = pos_enc_class
+ # The right context for every conv layer is computed by:
+ # (kernel_size - 1) * frame_rate_of_this_layer
+ self.subsampling_rate = 4
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
+ self.right_context = 6
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 4.
+ torch.Tensor: positional encoding
+
+ """
+ x = x.unsqueeze(1) # (b, c=1, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
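+
+ # Shape sketch, assuming `pos_enc` is a PositionalEncoding-style module (e.g.
+ # the one in embedding.py): two stride-2 convs give roughly time // 4.
+ # >>> sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.0,
+ # ... pos_enc_class=pos_enc)
+ # >>> y, pos_emb, y_mask = sub(torch.randn(1, 100, 80),
+ # ... torch.ones(1, 1, 100, dtype=torch.bool))
+ # >>> # y: (1, 24, 256), y_mask: (1, 1, 24)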
+
+
+class Conv2dSubsampling6(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/6 length).
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc (torch.nn.Module): Custom position encoding layer.
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling6 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 5, 3),
+ torch.nn.ReLU(),
+ )
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
+ odim)
+ self.pos_enc = pos_enc_class
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
+ self.subsampling_rate = 6
+ self.right_context = 10
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 6.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 6.
+ torch.Tensor: positional encoding
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
+
+
+class Conv2dSubsampling8(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/8 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling8 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ )
+ self.linear = torch.nn.Linear(
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 8
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
+ self.right_context = 14
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 8.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 8.
+ torch.Tensor: positional encoding
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
+
+
+class LegacyLinearNoSubsampling(BaseSubsampling):
+ """Linear transform the input without subsampling
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ torch.nn.LayerNorm(odim, eps=1e-5),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ )
+ self.pos_enc = pos_enc_class
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time .
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time .
+
+ """
+ x = self.out(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
diff --git a/src/chatterbox/models/s3gen/transformer/upsample_encoder.py b/src/chatterbox/models/s3gen/transformer/upsample_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1e20b16d223cc197ac9fda44b0fd084d561186b
--- /dev/null
+++ b/src/chatterbox/models/s3gen/transformer/upsample_encoder.py
@@ -0,0 +1,318 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder definition."""
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .convolution import ConvolutionModule
+from .encoder_layer import ConformerEncoderLayer
+from .positionwise_feed_forward import PositionwiseFeedForward
+from ..utils.class_utils import (
+ COSYVOICE_EMB_CLASSES,
+ COSYVOICE_SUBSAMPLE_CLASSES,
+ COSYVOICE_ATTENTION_CLASSES,
+ COSYVOICE_ACTIVATION_CLASSES,
+)
+from ..utils.mask import make_pad_mask
+from ..utils.mask import add_optional_chunk_mask
+
+
+class Upsample1D(nn.Module):
+ """A 1D upsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ use_conv_transpose (`bool`, default `False`):
+ option to use a convolution transpose.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ """
+
+ def __init__(self, channels: int, out_channels: int, stride: int = 2):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels
+ self.stride = stride
+ # In this mode, first repeat-interpolate, then conv with stride=1
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
+
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
+ outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
+ outputs = self.conv(outputs)
+ return outputs, input_lengths * self.stride
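+
+ # Usage sketch: nearest-neighbour 2x upsampling, then a length-preserving conv
+ # (kernel 5, left-padded by 4), so 50 input frames become 100.
+ # >>> up = Upsample1D(channels=512, out_channels=512, stride=2)
+ # >>> y, y_lens = up(torch.randn(1, 512, 50), torch.tensor([50]))
+ # >>> # y: (1, 512, 100), y_lens: tensor([100])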
+
+
+class PreLookaheadLayer(nn.Module):
+ def __init__(self, channels: int, pre_lookahead_len: int = 1):
+ super().__init__()
+ self.channels = channels
+ self.pre_lookahead_len = pre_lookahead_len
+ self.conv1 = nn.Conv1d(
+ channels, channels,
+ kernel_size=pre_lookahead_len + 1,
+ stride=1, padding=0,
+ )
+ self.conv2 = nn.Conv1d(
+ channels, channels,
+ kernel_size=3, stride=1, padding=0,
+ )
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ """
+ inputs: (batch_size, seq_len, channels)
+ """
+ outputs = inputs.transpose(1, 2).contiguous()
+ # look ahead
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
+ outputs = F.leaky_relu(self.conv1(outputs))
+ # outputs
+ outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
+ outputs = self.conv2(outputs)
+ outputs = outputs.transpose(1, 2).contiguous()
+
+ # residual connection
+ outputs = outputs + inputs
+ return outputs
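+
+ # Usage sketch: right-padding by pre_lookahead_len lets conv1 peek that many
+ # future frames while the output keeps the input length.
+ # >>> layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
+ # >>> y = layer(torch.randn(1, 50, 512)) # y.shape == (1, 50, 512)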
+
+
+class UpsampleConformerEncoder(torch.nn.Module):
+
+ def __init__(
+ self,
+ input_size: int = 512,
+ output_size: int = 512,
+ attention_heads: int = 8,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.1,
+ input_layer: str = "linear",
+ pos_enc_layer_type: str = "rel_pos_espnet",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ positionwise_conv_kernel_size: int = 1,
+ macaron_style: bool = False,
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = False,
+ cnn_module_kernel: int = 15,
+ causal: bool = False,
+ cnn_module_norm: str = "batch_norm",
+ key_bias: bool = True,
+ gradient_checkpointing: bool = False,
+ ):
+ """
+ Args:
+ input_size (int): input dim
+ output_size (int): dimension of attention
+ attention_heads (int): the number of heads of multi head attention
+ linear_units (int): the hidden units number of position-wise feed
+ forward
+ num_blocks (int): the number of decoder blocks
+ dropout_rate (float): dropout rate
+ attention_dropout_rate (float): dropout rate in attention
+ positional_dropout_rate (float): dropout rate after adding
+ positional encoding
+ input_layer (str): input layer type.
+ optional [linear, conv2d, conv2d6, conv2d8]
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
+ optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
+ normalize_before (bool):
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ static_chunk_size (int): chunk size for static chunk training and
+ decoding
+ use_dynamic_chunk (bool): whether to use dynamic chunk size for
+ training or not. You can only use a fixed chunk (chunk_size > 0)
+ or a dynamic chunk size (use_dynamic_chunk = True)
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
+ dynamic chunk training
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
+ gradient_checkpointing: rerunning a forward-pass segment for each
+ checkpointed segment during backward.
+ """
+ super().__init__()
+ self._output_size = output_size
+
+ self.global_cmvn = global_cmvn
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+ input_size,
+ output_size,
+ dropout_rate,
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+ positional_dropout_rate),
+ )
+
+ self.normalize_before = normalize_before
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+ self.static_chunk_size = static_chunk_size
+ self.use_dynamic_chunk = use_dynamic_chunk
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
+ self.gradient_checkpointing = gradient_checkpointing
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
+ # self-attention module definition
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ key_bias,
+ )
+ # feed-forward module definition
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ # convolution module definition
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
+ cnn_module_norm, causal)
+ self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
+ self.encoders = torch.nn.ModuleList([
+ ConformerEncoderLayer(
+ output_size,
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+ *encoder_selfattn_layer_args),
+ PositionwiseFeedForward(*positionwise_layer_args),
+ PositionwiseFeedForward(
+ *positionwise_layer_args) if macaron_style else None,
+ ConvolutionModule(
+ *convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ normalize_before,
+ ) for _ in range(num_blocks)
+ ])
+ self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
+ self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+ input_size,
+ output_size,
+ dropout_rate,
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+ positional_dropout_rate),
+ )
+ self.up_encoders = torch.nn.ModuleList([
+ ConformerEncoderLayer(
+ output_size,
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+ *encoder_selfattn_layer_args),
+ PositionwiseFeedForward(*positionwise_layer_args),
+ PositionwiseFeedForward(
+ *positionwise_layer_args) if macaron_style else None,
+ ConvolutionModule(
+ *convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ normalize_before,
+ ) for _ in range(4)
+ ])
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs: torch.Tensor,
+ xs_lens: torch.Tensor,
+ decoding_chunk_size: int = 0,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Embed positions in tensor.
+
+ Args:
+ xs: padded input tensor (B, T, D)
+ xs_lens: input length (B)
+ decoding_chunk_size: decoding chunk size for dynamic chunk
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ Returns:
+ encoder output tensor xs, and subsampled masks
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+ masks: torch.Tensor batch padding mask after subsample
+ (B, 1, T' ~= T/subsample_rate)
+ NOTE(xcsong):
+ We pass the `__call__` method of the modules instead of `forward` to the
+ checkpointing API because `__call__` attaches all the hooks of the module.
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+ """
+ T = xs.size(1)
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ xs, pos_emb, masks = self.embed(xs, masks)
+ mask_pad = masks # (B, 1, T/subsample_rate)
+ chunk_masks = add_optional_chunk_mask(xs, masks,
+ self.use_dynamic_chunk,
+ self.use_dynamic_left_chunk,
+ decoding_chunk_size,
+ self.static_chunk_size,
+ num_decoding_left_chunks)
+ # lookahead + conformer encoder
+ xs = self.pre_lookahead_layer(xs)
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+ # upsample + conformer encoder
+ xs = xs.transpose(1, 2).contiguous()
+ xs, xs_lens = self.up_layer(xs, xs_lens)
+ xs = xs.transpose(1, 2).contiguous()
+ T = xs.size(1)
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+ xs, pos_emb, masks = self.up_embed(xs, masks)
+ mask_pad = masks # (B, 1, T/subsample_rate)
+ chunk_masks = add_optional_chunk_mask(xs, masks,
+ self.use_dynamic_chunk,
+ self.use_dynamic_left_chunk,
+ decoding_chunk_size,
+ self.static_chunk_size * self.up_layer.stride,
+ num_decoding_left_chunks)
+ xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+ # Here we assume the mask is not changed in encoder layers, so just
+ # return the masks before encoder layers, and the masks will be used
+ # for cross attention with decoder later
+ return xs, masks
+
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor) -> torch.Tensor:
+ for layer in self.encoders:
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+ return xs
+
+ def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor) -> torch.Tensor:
+ for layer in self.up_encoders:
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+ return xs
diff --git a/src/chatterbox/models/s3gen/utils/class_utils.py b/src/chatterbox/models/s3gen/utils/class_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec32256e49ab132a5b8e1e44de9947c88e64b865
--- /dev/null
+++ b/src/chatterbox/models/s3gen/utils/class_utils.py
@@ -0,0 +1,71 @@
+# Copyright [2023-11-28]
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from ..transformer.activation import Swish
+from ..transformer.subsampling import (
+ LinearNoSubsampling,
+ EmbedinigNoSubsampling,
+ Conv1dSubsampling2,
+ Conv2dSubsampling4,
+ Conv2dSubsampling6,
+ Conv2dSubsampling8,
+)
+from ..transformer.embedding import (
+ PositionalEncoding,
+ RelPositionalEncoding,
+ WhisperPositionalEncoding,
+ LearnablePositionalEncoding,
+ NoPositionalEncoding)
+from ..transformer.attention import (MultiHeadedAttention,
+ RelPositionMultiHeadedAttention)
+from ..transformer.embedding import EspnetRelPositionalEncoding
+from ..transformer.subsampling import LegacyLinearNoSubsampling
+
+
+COSYVOICE_ACTIVATION_CLASSES = {
+ "hardtanh": torch.nn.Hardtanh,
+ "tanh": torch.nn.Tanh,
+ "relu": torch.nn.ReLU,
+ "selu": torch.nn.SELU,
+ "swish": getattr(torch.nn, "SiLU", Swish),
+ "gelu": torch.nn.GELU,
+}
+
+COSYVOICE_SUBSAMPLE_CLASSES = {
+ "linear": LinearNoSubsampling,
+ "linear_legacy": LegacyLinearNoSubsampling,
+ "embed": EmbedinigNoSubsampling,
+ "conv1d2": Conv1dSubsampling2,
+ "conv2d": Conv2dSubsampling4,
+ "conv2d6": Conv2dSubsampling6,
+ "conv2d8": Conv2dSubsampling8,
+ 'paraformer_dummy': torch.nn.Identity
+}
+
+COSYVOICE_EMB_CLASSES = {
+ "embed": PositionalEncoding,
+ "abs_pos": PositionalEncoding,
+ "rel_pos": RelPositionalEncoding,
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
+ "no_pos": NoPositionalEncoding,
+ "abs_pos_whisper": WhisperPositionalEncoding,
+ "embed_learnable_pe": LearnablePositionalEncoding,
+}
+
+COSYVOICE_ATTENTION_CLASSES = {
+ "selfattn": MultiHeadedAttention,
+ "rel_selfattn": RelPositionMultiHeadedAttention,
+}
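+
+
+# Illustrative lookup sketch (added for clarity, not part of the original module):
+# the encoders in this package resolve their building blocks from these registries
+# by name, e.g.
+#
+#   act = COSYVOICE_ACTIVATION_CLASSES["swish"]()              # torch.nn.SiLU()
+#   pos = COSYVOICE_EMB_CLASSES["rel_pos_espnet"](512, 0.1)    # EspnetRelPositionalEncoding(dim, dropout)
+#   att_cls = COSYVOICE_ATTENTION_CLASSES["rel_selfattn"]      # instantiated as
+#                                                              # att_cls(heads, dim, dropout, key_bias)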
diff --git a/src/chatterbox/models/s3gen/utils/mask.py b/src/chatterbox/models/s3gen/utils/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..979f054763937b2473f449dde99561752136267a
--- /dev/null
+++ b/src/chatterbox/models/s3gen/utils/mask.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import torch
+
+'''
+def subsequent_mask(
+ size: int,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size).
+
+ This mask is used only in decoder which works in an auto-regressive mode.
+ This means the current step could only do attention with its left steps.
+
+ In encoder, fully attention is used when streaming is not necessary and
+ the sequence is not long. In this case, no attention mask is needed.
+
+ When streaming is need, chunk-based attention is used in encoder. See
+ subsequent_chunk_mask for the chunk-based attention mask.
+
+ Args:
+ size (int): size of mask
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
+ dtype (torch.device): result dtype
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_mask(3)
+ [[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]
+ """
+ ret = torch.ones(size, size, device=device, dtype=torch.bool)
+ return torch.tril(ret)
+'''
+
+
+def subsequent_chunk_mask(
+ size: int,
+ chunk_size: int,
+ num_left_chunks: int = -1,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size) with chunk size,
+ this is for streaming encoder
+
+ Args:
+ size (int): size of mask
+ chunk_size (int): size of chunk
+ num_left_chunks (int): number of left chunks
+ <0: use full chunk
+ >=0: use num_left_chunks
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_chunk_mask(4, 2)
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 1],
+ [1, 1, 1, 1]]
+ """
+    # NOTE: this modified implementation meets ONNX export requirements, but it
+    # does not support num_left_chunks. It will no longer be needed once the
+    # inference cache is implemented and will be removed then.
+ pos_idx = torch.arange(size, device=device)
+ block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+ return ret
+
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+ masks: torch.Tensor,
+ use_dynamic_chunk: bool,
+ use_dynamic_left_chunk: bool,
+ decoding_chunk_size: int,
+ static_chunk_size: int,
+ num_decoding_left_chunks: int,
+ enable_full_context: bool = True):
+ """ Apply optional mask for encoder.
+
+ Args:
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
+ mask (torch.Tensor): mask for xs, (B, 1, L)
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+ training.
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ static_chunk_size (int): chunk size for static chunk training/decoding
+ if it's greater than 0, if use_dynamic_chunk is true,
+ this parameter will be ignored
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ enable_full_context (bool):
+ True: chunk size is either [1, 25] or full context(max_len)
+ False: chunk size ~ U[1, 25]
+
+ Returns:
+ torch.Tensor: chunk mask of the input xs.
+ """
+ # Whether to use chunk mask or not
+ if use_dynamic_chunk:
+ max_len = xs.size(1)
+ if decoding_chunk_size < 0:
+ chunk_size = max_len
+ num_left_chunks = -1
+ elif decoding_chunk_size > 0:
+ chunk_size = decoding_chunk_size
+ num_left_chunks = num_decoding_left_chunks
+ else:
+ # chunk size is either [1, 25] or full context(max_len).
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
+ # delay, the maximum frame is 100 / 4 = 25.
+ chunk_size = torch.randint(1, max_len, (1, )).item()
+ num_left_chunks = -1
+ if chunk_size > max_len // 2 and enable_full_context:
+ chunk_size = max_len
+ else:
+ chunk_size = chunk_size % 25 + 1
+ if use_dynamic_left_chunk:
+ max_left_chunks = (max_len - 1) // chunk_size
+ num_left_chunks = torch.randint(0, max_left_chunks,
+ (1, )).item()
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ elif static_chunk_size > 0:
+ num_left_chunks = num_decoding_left_chunks
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ else:
+ chunk_masks = masks
+ assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        logging.warning('chunk_masks is all False at some timesteps; forcing them to True. '
+                        'Make sure these positions are masked in future computations!')
+        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
+ return chunk_masks
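+
+
+# Illustrative sketch (added comment, not part of the original code): for a plain
+# non-streaming forward pass, use_dynamic_chunk=False and static_chunk_size=0, so
+# the padding mask is returned unchanged, e.g.
+#
+#   xs = torch.zeros(2, 6, 80)
+#   masks = torch.ones(2, 1, 6, dtype=torch.bool)
+#   add_optional_chunk_mask(xs, masks, False, False, 0, 0, -1)   # -> masks, (2, 1, 6)
+#
+# With static_chunk_size=2 the result would instead be masks & subsequent_chunk_mask(6, 2),
+# i.e. a (2, 6, 6) block-causal mask.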
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+ """Make mask tensor containing indices of padded part.
+
+ See description of make_non_pad_mask.
+
+ Args:
+ lengths (torch.Tensor): Batch of lengths (B,).
+ Returns:
+ torch.Tensor: Mask tensor containing indices of padded part.
+
+ Examples:
+ >>> lengths = [5, 3, 2]
+ >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+ """
+ batch_size = lengths.size(0)
+ max_len = max_len if max_len > 0 else lengths.max().item()
+ seq_range = torch.arange(0,
+ max_len,
+ dtype=torch.int64,
+ device=lengths.device)
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+ seq_length_expand = lengths.unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+ return mask
diff --git a/src/chatterbox/models/s3gen/utils/mel.py b/src/chatterbox/models/s3gen/utils/mel.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6d46ee0635999b47ccd73c3411df9d624efd8b8
--- /dev/null
+++ b/src/chatterbox/models/s3gen/utils/mel.py
@@ -0,0 +1,81 @@
+"""mel-spectrogram extraction in Matcha-TTS"""
+from librosa.filters import mel as librosa_mel_fn
+import torch
+import numpy as np
+
+
+# NOTE: the original implementation declared these as global variables
+mel_basis = {}
+hann_window = {}
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+"""
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+ n_fft: 1920
+ num_mels: 80
+ sampling_rate: 24000
+ hop_size: 480
+ win_size: 1920
+ fmin: 0
+ fmax: 8000
+ center: False
+
+"""
+
+def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920,
+ fmin=0, fmax=8000, center=False):
+ """Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py
+ Set default values according to Cosyvoice's config.
+ """
+
+ if isinstance(y, np.ndarray):
+ y = torch.tensor(y).float()
+
+ if len(y.shape) == 1:
+ y = y[None, ]
+
+ if torch.min(y) < -1.0:
+ print("min value is ", torch.min(y))
+ if torch.max(y) > 1.0:
+ print("max value is ", torch.max(y))
+
+ global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+ )
+ y = y.squeeze(1)
+
+ spec = torch.view_as_real(
+ torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[str(y.device)],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ )
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
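+
+
+# Rough usage sketch (added for illustration; not part of the original file). With the
+# CosyVoice defaults above, one second of 24 kHz audio yields about 50 mel frames:
+#
+#   wav = torch.zeros(1, 24000)          # (B, T) float waveform in [-1, 1]
+#   mel = mel_spectrogram(wav)           # (1, 80, 50): 80 mel bins, hop of 480 samples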
diff --git a/src/chatterbox/models/s3gen/xvector.py b/src/chatterbox/models/s3gen/xvector.py
new file mode 100644
index 0000000000000000000000000000000000000000..930d946affb33155369133c86128a43970a2a5f0
--- /dev/null
+++ b/src/chatterbox/models/s3gen/xvector.py
@@ -0,0 +1,428 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
+
+
+from collections import OrderedDict
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+import torchaudio.compliance.kaldi as Kaldi
+
+
+def pad_list(xs, pad_value):
+ """Perform padding for the list of tensors.
+
+ Args:
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+ pad_value (float): Value for padding.
+
+ Returns:
+ Tensor: Padded tensor (B, Tmax, `*`).
+
+ Examples:
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+ >>> x
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+ >>> pad_list(x, 0)
+ tensor([[1., 1., 1., 1.],
+ [1., 1., 0., 0.],
+ [1., 0., 0., 0.]])
+
+ """
+ n_batch = len(xs)
+ max_len = max(x.size(0) for x in xs)
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+ for i in range(n_batch):
+ pad[i, : xs[i].size(0)] = xs[i]
+
+ return pad
+
+
+def extract_feature(audio):
+ features = []
+ feature_times = []
+ feature_lengths = []
+ for au in audio:
+ feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80)
+ feature = feature - feature.mean(dim=0, keepdim=True)
+ features.append(feature)
+ feature_times.append(au.shape[0])
+ feature_lengths.append(feature.shape[0])
+ # padding for batch inference
+ features_padded = pad_list(features, pad_value=0)
+ # features = torch.cat(features)
+ return features_padded, feature_lengths, feature_times
+
+
+class BasicResBlock(torch.nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicResBlock, self).__init__()
+ self.conv1 = torch.nn.Conv2d(
+ in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
+ )
+ self.bn1 = torch.nn.BatchNorm2d(planes)
+ self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = torch.nn.BatchNorm2d(planes)
+
+ self.shortcut = torch.nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = torch.nn.Sequential(
+ torch.nn.Conv2d(
+ in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=(stride, 1),
+ bias=False,
+ ),
+ torch.nn.BatchNorm2d(self.expansion * planes),
+ )
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class FCM(torch.nn.Module):
+ def __init__(self, block=BasicResBlock, num_blocks=[2, 2], m_channels=32, feat_dim=80):
+ super(FCM, self).__init__()
+ self.in_planes = m_channels
+ self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = torch.nn.BatchNorm2d(m_channels)
+
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
+ self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
+
+ self.conv2 = torch.nn.Conv2d(
+ m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
+ )
+ self.bn2 = torch.nn.BatchNorm2d(m_channels)
+ self.out_channels = m_channels * (feat_dim // 8)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return torch.nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = x.unsqueeze(1)
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = F.relu(self.bn2(self.conv2(out)))
+
+ shape = out.shape
+ out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
+ return out
+
+
+def get_nonlinear(config_str, channels):
+ nonlinear = torch.nn.Sequential()
+ for name in config_str.split("-"):
+ if name == "relu":
+ nonlinear.add_module("relu", torch.nn.ReLU(inplace=True))
+ elif name == "prelu":
+ nonlinear.add_module("prelu", torch.nn.PReLU(channels))
+ elif name == "batchnorm":
+ nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels))
+ elif name == "batchnorm_":
+ nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels, affine=False))
+ else:
+ raise ValueError("Unexpected module ({}).".format(name))
+ return nonlinear
+
+
+def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
+ mean = x.mean(dim=dim)
+ std = x.std(dim=dim, unbiased=unbiased)
+ stats = torch.cat([mean, std], dim=-1)
+ if keepdim:
+ stats = stats.unsqueeze(dim=dim)
+ return stats
+
+
+class StatsPool(torch.nn.Module):
+ def forward(self, x):
+ return statistics_pooling(x)
+
+
+class TDNNLayer(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=False,
+ config_str="batchnorm-relu",
+ ):
+ super(TDNNLayer, self).__init__()
+ if padding < 0:
+ assert (
+ kernel_size % 2 == 1
+ ), "Expect equal paddings, but got even kernel size ({})".format(kernel_size)
+ padding = (kernel_size - 1) // 2 * dilation
+ self.linear = torch.nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+ self.nonlinear = get_nonlinear(config_str, out_channels)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = self.nonlinear(x)
+ return x
+
+
+class CAMLayer(torch.nn.Module):
+ def __init__(
+ self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2
+ ):
+ super(CAMLayer, self).__init__()
+ self.linear_local = torch.nn.Conv1d(
+ bn_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+ self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1)
+ self.relu = torch.nn.ReLU(inplace=True)
+ self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1)
+ self.sigmoid = torch.nn.Sigmoid()
+
+ def forward(self, x):
+ y = self.linear_local(x)
+ context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
+ context = self.relu(self.linear1(context))
+ m = self.sigmoid(self.linear2(context))
+ return y * m
+
+ def seg_pooling(self, x, seg_len=100, stype="avg"):
+ if stype == "avg":
+ seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
+ elif stype == "max":
+ seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
+ else:
+ raise ValueError("Wrong segment pooling type.")
+ shape = seg.shape
+ seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
+ seg = seg[..., : x.shape[-1]]
+ return seg
+
+
+class CAMDenseTDNNLayer(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ bn_channels,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ bias=False,
+ config_str="batchnorm-relu",
+ memory_efficient=False,
+ ):
+ super(CAMDenseTDNNLayer, self).__init__()
+ assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format(
+ kernel_size
+ )
+ padding = (kernel_size - 1) // 2 * dilation
+ self.memory_efficient = memory_efficient
+ self.nonlinear1 = get_nonlinear(config_str, in_channels)
+ self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False)
+ self.nonlinear2 = get_nonlinear(config_str, bn_channels)
+ self.cam_layer = CAMLayer(
+ bn_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ def bn_function(self, x):
+ return self.linear1(self.nonlinear1(x))
+
+ def forward(self, x):
+ if self.training and self.memory_efficient:
+ x = cp.checkpoint(self.bn_function, x)
+ else:
+ x = self.bn_function(x)
+ x = self.cam_layer(self.nonlinear2(x))
+ return x
+
+
+class CAMDenseTDNNBlock(torch.nn.ModuleList):
+ def __init__(
+ self,
+ num_layers,
+ in_channels,
+ out_channels,
+ bn_channels,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ bias=False,
+ config_str="batchnorm-relu",
+ memory_efficient=False,
+ ):
+ super(CAMDenseTDNNBlock, self).__init__()
+ for i in range(num_layers):
+ layer = CAMDenseTDNNLayer(
+ in_channels=in_channels + i * out_channels,
+ out_channels=out_channels,
+ bn_channels=bn_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ bias=bias,
+ config_str=config_str,
+ memory_efficient=memory_efficient,
+ )
+ self.add_module("tdnnd%d" % (i + 1), layer)
+
+ def forward(self, x):
+ for layer in self:
+ x = torch.cat([x, layer(x)], dim=1)
+ return x
+
+
+class TransitLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"):
+ super(TransitLayer, self).__init__()
+ self.nonlinear = get_nonlinear(config_str, in_channels)
+ self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+
+ def forward(self, x):
+ x = self.nonlinear(x)
+ x = self.linear(x)
+ return x
+
+
+class DenseLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"):
+ super(DenseLayer, self).__init__()
+ self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+ self.nonlinear = get_nonlinear(config_str, out_channels)
+
+ def forward(self, x):
+ if len(x.shape) == 2:
+ x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
+ else:
+ x = self.linear(x)
+ x = self.nonlinear(x)
+ return x
+
+# @tables.register("model_classes", "CAMPPlus")
+class CAMPPlus(torch.nn.Module):
+ def __init__(
+ self,
+ feat_dim=80,
+ embedding_size=192,
+ growth_rate=32,
+ bn_size=4,
+ init_channels=128,
+ config_str="batchnorm-relu",
+ memory_efficient=True,
+ output_level="segment",
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.head = FCM(feat_dim=feat_dim)
+ channels = self.head.out_channels
+ self.output_level = output_level
+
+ self.xvector = torch.nn.Sequential(
+ OrderedDict(
+ [
+ (
+ "tdnn",
+ TDNNLayer(
+ channels,
+ init_channels,
+ 5,
+ stride=2,
+ dilation=1,
+ padding=-1,
+ config_str=config_str,
+ ),
+ ),
+ ]
+ )
+ )
+ channels = init_channels
+ for i, (num_layers, kernel_size, dilation) in enumerate(
+ zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
+ ):
+ block = CAMDenseTDNNBlock(
+ num_layers=num_layers,
+ in_channels=channels,
+ out_channels=growth_rate,
+ bn_channels=bn_size * growth_rate,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ config_str=config_str,
+ memory_efficient=memory_efficient,
+ )
+ self.xvector.add_module("block%d" % (i + 1), block)
+ channels = channels + num_layers * growth_rate
+ self.xvector.add_module(
+ "transit%d" % (i + 1),
+ TransitLayer(channels, channels // 2, bias=False, config_str=config_str),
+ )
+ channels //= 2
+
+ self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
+
+ if self.output_level == "segment":
+ self.xvector.add_module("stats", StatsPool())
+ self.xvector.add_module(
+ "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
+ )
+ else:
+ assert (
+ self.output_level == "frame"
+ ), "`output_level` should be set to 'segment' or 'frame'. "
+
+ for m in self.modules():
+ if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
+ torch.nn.init.kaiming_normal_(m.weight.data)
+ if m.bias is not None:
+ torch.nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
+ x = self.head(x)
+ x = self.xvector(x)
+ if self.output_level == "frame":
+ x = x.transpose(1, 2)
+ return x
+
+ def inference(self, audio_list):
+ speech, speech_lengths, speech_times = extract_feature(audio_list)
+ results = self.forward(speech.to(torch.float32))
+ return results
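+
+
+# Rough usage sketch (added for illustration; not part of the original file). CAMPPlus
+# consumes 16 kHz waveforms and, at the default "segment" output level, returns one
+# 192-dim speaker embedding per utterance:
+#
+#   model = CAMPPlus(feat_dim=80, embedding_size=192).eval()
+#   with torch.no_grad():
+#       emb = model.inference([torch.randn(16000)])   # list of 1-D 16 kHz waveforms
+#   # emb.shape == (1, 192)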
diff --git a/src/chatterbox/models/s3tokenizer/__init__.py b/src/chatterbox/models/s3tokenizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b204b3a4d3dbb4656bc0bab46ca9e1e3721f0a0
--- /dev/null
+++ b/src/chatterbox/models/s3tokenizer/__init__.py
@@ -0,0 +1,30 @@
+from .s3tokenizer import (
+ S3_SR,
+ S3_HOP,
+ S3_TOKEN_HOP,
+ S3_TOKEN_RATE,
+ SPEECH_VOCAB_SIZE,
+ S3Tokenizer,
+)
+
+
+SOS = SPEECH_VOCAB_SIZE
+EOS = SPEECH_VOCAB_SIZE + 1
+
+
+
+def drop_invalid_tokens(x):
+ """Drop SoS and EoS"""
+ assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now"
+ if SOS in x:
+ s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1
+ else:
+ s = 0
+
+ if EOS in x:
+ e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0)
+ else:
+ e = None
+
+ x = x[s: e]
+ return x
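+
+
+# Illustrative example (added comment, not part of the original file): SOS/EOS sit just
+# past the speech vocabulary (6561 and 6562), so a decoded sequence is trimmed like
+#
+#   seq = torch.tensor([SOS, 42, 7, 19, EOS])
+#   drop_invalid_tokens(seq)              # -> tensor([42,  7, 19])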
diff --git a/src/chatterbox/models/s3tokenizer/s3tokenizer.py b/src/chatterbox/models/s3tokenizer/s3tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffd74421276d1c33a6dd710b99f97783c44779f3
--- /dev/null
+++ b/src/chatterbox/models/s3tokenizer/s3tokenizer.py
@@ -0,0 +1,168 @@
+from typing import List, Tuple
+
+import numpy as np
+import librosa
+import torch
+import torch.nn.functional as F
+from s3tokenizer.utils import padding
+from s3tokenizer.model_v2 import (
+ S3TokenizerV2,
+ ModelConfig,
+)
+
+
+# Sampling rate of the inputs to S3TokenizerV2
+S3_SR = 16_000
+S3_HOP = 160 # 100 frames/sec
+S3_TOKEN_HOP = 640 # 25 tokens/sec
+S3_TOKEN_RATE = 25
+SPEECH_VOCAB_SIZE = 6561
+
+
+class S3Tokenizer(S3TokenizerV2):
+ """
+ s3tokenizer.S3TokenizerV2 with the following changes:
+ - a more integrated `forward`
+    - compute `log_mel_spectrogram` using `_mel_filters` and `window`, which are registered as buffers
+ """
+
+ ignore_state_dict_missing = ("_mel_filters", "window")
+
+ def __init__(
+ self,
+ name: str="speech_tokenizer_v2_25hz",
+ config: ModelConfig = ModelConfig()
+ ):
+ super().__init__(name)
+
+ self.n_fft = 400
+ _mel_filters = librosa.filters.mel(
+ sr=S3_SR,
+ n_fft=self.n_fft,
+ n_mels=config.n_mels
+ )
+ self.register_buffer(
+ "_mel_filters",
+ torch.FloatTensor(_mel_filters),
+ )
+
+ self.register_buffer(
+ "window",
+ torch.hann_window(self.n_fft),
+ )
+
+ def pad(self, wavs, sr) -> List[torch.Tensor]:
+ """
+        Given a list of wavs with the same `sample_rate`, pad them so that the length is a multiple of 40 ms (S3 runs at 25 tokens/sec).
+ """
+ processed_wavs = []
+ for wav in wavs:
+ if isinstance(wav, np.ndarray):
+ wav = torch.from_numpy(wav)
+ if wav.dim() == 1:
+ wav = wav.unsqueeze(0)
+
+ n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
+ n_tokens = np.ceil(n_tokens)
+ intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
+ intended_wav_len = int(intended_wav_len)
+ wav = torch.nn.functional.pad(
+ wav,
+ (0, intended_wav_len - wav.shape[-1]),
+ mode="constant",
+ value=0
+ )
+ processed_wavs.append(wav)
+ return processed_wavs
+
+ def _prepare_audio(self, wavs):
+ """Prepare a list of audios for s3tokenizer processing."""
+ processed_wavs = []
+ for wav in wavs:
+ if isinstance(wav, np.ndarray):
+ wav = torch.from_numpy(wav)
+ if wav.dim() == 1:
+ wav = wav.unsqueeze(0)
+
+ processed_wavs.append(wav)
+ return processed_wavs
+
+ @torch.no_grad()
+ def forward(
+ self,
+ wavs: torch.Tensor,
+ accelerator: 'Accelerator'=None,
+ max_len: int=None,
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
+ """
+ NOTE: mel-spec has a hop size of 160 points (100 frame/sec).
+ FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected.
+
+ Args
+ ----
+ - `wavs`: 16 kHz speech audio
+        - `max_len`: max length to truncate the output token sequence to (at 25 tokens/sec).
+        NOTE: please pad the waveform if a longer sequence is needed.
+ """
+ processed_wavs = self._prepare_audio(wavs)
+ mels, mel_lens = [], []
+ for wav in processed_wavs:
+ wav = wav.to(self.device)
+ mel = self.log_mel_spectrogram(wav) # [B=1, F, T]
+ if max_len is not None:
+ mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens
+ mels.append(mel.squeeze(0))
+
+ mels, mel_lens = padding(mels)
+ if accelerator is None:
+ tokenizer = self
+ else:
+ tokenizer = accelerator.unwrap_model(self)
+
+ speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
+ return (
+ speech_tokens.long().detach(),
+ speech_token_lens.long().detach(),
+ )
+
+ def log_mel_spectrogram(
+ self,
+ audio: torch.Tensor,
+ padding: int = 0,
+ ):
+ """
+        Compute the log-Mel spectrogram of the given audio.
+
+ Parameters
+ ----------
+        audio: torch.Tensor, shape = (*)
+            A NumPy array or Tensor containing the audio waveform,
+            sampled at 16 kHz
+
+ padding: int
+ Number of zero samples to pad to the right
+
+ Returns
+ -------
+ torch.Tensor, shape = (128, n_frames)
+ A Tensor that contains the Mel spectrogram
+ """
+ if not torch.is_tensor(audio):
+ audio = torch.from_numpy(audio)
+
+ audio = audio.to(self.device)
+ if padding > 0:
+ audio = F.pad(audio, (0, padding))
+ stft = torch.stft(
+ audio, self.n_fft, S3_HOP,
+ window=self.window.to(self.device),
+ return_complex=True
+ )
+ magnitudes = stft[..., :-1].abs()**2
+
+ mel_spec = self._mel_filters.to(self.device) @ magnitudes
+
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+ log_spec = (log_spec + 4.0) / 4.0
+ return log_spec
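+
+
+# Rough usage sketch (added for illustration; shapes only -- in practice the weights are
+# loaded from the Chatterbox checkpoint before use):
+#
+#   tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
+#   wav_16k = torch.randn(1, S3_SR)                   # 1 second of 16 kHz audio
+#   tokens, token_lens = tokenizer([wav_16k])         # forward() takes a list of wavs
+#   # tokens: (1, ~25) int64 speech tokens at 25 tokens/sec, token_lens: (1,)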
diff --git a/src/chatterbox/models/t3/__init__.py b/src/chatterbox/models/t3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..da4af4071dd66528659820d432762762b0da8c64
--- /dev/null
+++ b/src/chatterbox/models/t3/__init__.py
@@ -0,0 +1 @@
+from .t3 import T3
diff --git a/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py b/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b067538fed05653aedbb71f08bd0cd80d0ce9cf
--- /dev/null
+++ b/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2025 Resemble AI
+# Author: John Meade, Jeremy Hsu
+# MIT License
+import logging
+import torch
+from dataclasses import dataclass
+from types import MethodType
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AlignmentAnalysisResult:
+ # was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
+ false_start: bool
+ # was this frame detected as being part of a long tail with potential hallucinations?
+ long_tail: bool
+ # was this frame detected as repeating existing text content?
+ repetition: bool
+ # was the alignment position of this frame too far from the previous frame?
+ discontinuity: bool
+ # has inference reached the end of the text tokens? eg, this remains false if inference stops early
+ complete: bool
+ # approximate position in the text token sequence. Can be used for generating online timestamps.
+ position: int
+
+
+class AlignmentStreamAnalyzer:
+ def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
+ """
+ Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
+        activation maps. This module exploits this to perform online integrity checks while streaming.
+ A hook is injected into the specified attention layer, and heuristics are used to determine alignment
+ position, repetition, etc.
+
+ NOTE: currently requires no queues.
+ """
+ # self.queue = queue
+ self.text_tokens_slice = (i, j) = text_tokens_slice
+ self.eos_idx = eos_idx
+ self.alignment = torch.zeros(0, j-i)
+ # self.alignment_bin = torch.zeros(0, j-i)
+ self.curr_frame_pos = 0
+ self.text_position = 0
+
+ self.started = False
+ self.started_at = None
+
+ self.complete = False
+ self.completed_at = None
+
+ # Using `output_attentions=True` is incompatible with optimized attention kernels, so
+ # using it for all layers slows things down too much. We can apply it to just one layer
+ # by intercepting the kwargs and adding a forward hook (credit: jrm)
+ self.last_aligned_attn = None
+ self._add_attention_spy(tfmr, alignment_layer_idx)
+
+ def _add_attention_spy(self, tfmr, alignment_layer_idx):
+ """
+ Adds a forward hook to a specific attention layer to collect outputs.
+ Using `output_attentions=True` is incompatible with optimized attention kernels, so
+ using it for all layers slows things down too much.
+ (credit: jrm)
+ """
+
+ def attention_forward_hook(module, input, output):
+ """
+ See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
+ NOTE:
+ - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
+            - the attention weights (`output[1]`) have shape [B, H, T0, T0] on the first step, and [B, H, 1, T0+i] on each later step i.
+ """
+ step_attention = output[1].cpu() # (B, 16, N, N)
+ self.last_aligned_attn = step_attention[0].mean(0) # (N, N)
+
+ target_layer = tfmr.layers[alignment_layer_idx].self_attn
+ hook_handle = target_layer.register_forward_hook(attention_forward_hook)
+
+ # Backup original forward
+ original_forward = target_layer.forward
+ def patched_forward(self, *args, **kwargs):
+ kwargs['output_attentions'] = True
+ return original_forward(*args, **kwargs)
+
+ # TODO: how to unpatch it?
+ target_layer.forward = MethodType(patched_forward, target_layer)
+
+ def step(self, logits):
+ """
+        Analyzes the alignment of the current frame and potentially modifies the logits to force or suppress an EOS token.
+ """
+ # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
+ aligned_attn = self.last_aligned_attn # (N, N)
+ i, j = self.text_tokens_slice
+ if self.curr_frame_pos == 0:
+ # first chunk has conditioning info, text tokens, and BOS token
+ A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S)
+ else:
+ # subsequent chunks have 1 frame due to KV-caching
+ A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S)
+
+ # TODO: monotonic masking; could have issue b/c spaces are often skipped.
+ A_chunk[:, self.curr_frame_pos + 1:] = 0
+
+
+ self.alignment = torch.cat((self.alignment, A_chunk), dim=0)
+
+ A = self.alignment
+ T, S = A.shape
+
+ # update position
+ cur_text_posn = A_chunk[-1].argmax()
+ discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient!
+ if not discontinuity:
+ self.text_position = cur_text_posn
+
+ # Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
+ # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
+ # and there are some strong activations in the first few tokens.
+ false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
+ self.started = not false_start
+ if self.started and self.started_at is None:
+ self.started_at = T
+
+ # Is generation likely complete?
+ self.complete = self.complete or self.text_position >= S - 3
+ if self.complete and self.completed_at is None:
+ self.completed_at = T
+
+ # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
+ # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens.
+ last_text_token_duration = A[15:, -3:].sum()
+
+ # Activations for the final token that last too long are likely hallucinations.
+ long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms
+
+ # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
+ repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
+
+ # If a bad ending is detected, force emit EOS by modifying logits
+ # NOTE: this means logits may be inconsistent with latents!
+ if long_tail or repetition:
+            logger.warning(f"forcing EOS token, {long_tail=}, {repetition=}")
+ # (±2**15 is safe for all dtypes >= 16bit)
+ logits = -(2**15) * torch.ones_like(logits)
+ logits[..., self.eos_idx] = 2**15
+
+ # Suppress EoS to prevent early termination
+ if cur_text_posn < S - 3: # FIXME: arbitrary
+ logits[..., self.eos_idx] = -2**15
+
+ self.curr_frame_pos += 1
+ return logits
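+
+
+# Rough integration sketch (added for illustration; names like `len_cond`, `len_text`
+# and `stop_speech_token` are placeholders, the real wiring lives in the T3 inference
+# code): the analyzer hooks one attention layer of the backbone, then filters the raw
+# logits at every sampling step.
+#
+#   analyzer = AlignmentStreamAnalyzer(
+#       tfmr, None,
+#       text_tokens_slice=(len_cond, len_cond + len_text),
+#       alignment_layer_idx=9,
+#       eos_idx=stop_speech_token,
+#   )
+#   ...
+#   logits = analyzer.step(logits)   # may force or suppress the EOS token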
diff --git a/src/chatterbox/models/t3/inference/t3_hf_backend.py b/src/chatterbox/models/t3/inference/t3_hf_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c577112877634313fe31ede4b5f3b5bf0d016cc
--- /dev/null
+++ b/src/chatterbox/models/t3/inference/t3_hf_backend.py
@@ -0,0 +1,116 @@
+from typing import Optional
+
+import torch
+from torch import nn as nn
+from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
+
+class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
+ """
+ Override some HuggingFace interface methods so we can use the standard `generate` method with our
+ custom embedding / logit layers.
+
+ NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights!
+ """
+
+ def __init__(
+ self,
+ config: LlamaConfig,
+ llama: LlamaModel,
+ *,
+ speech_enc,
+ speech_head,
+ latents_queue=None,
+ logits_queue=None,
+ alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
+ ):
+ super().__init__(config)
+ self.model = llama
+ self.speech_enc = speech_enc
+ self.speech_head = speech_head
+ self._added_cond = False
+ self.alignment_stream_analyzer = alignment_stream_analyzer
+
+ @torch.inference_mode()
+ def prepare_inputs_for_generation(
+ self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None,
+ # This argument was introduced in some recent version of transformers (>=4.29.1)
+ cache_position=None
+ ):
+ """
+ This is a method used by huggingface's generate() method.
+ Overridden here to apply our custom speech token embedding layer.
+
+ :param input_ids: (B, S) int64 tensors of input tokens.
+        :param decoder_cond: (B, T, C) float32 tensor of conditioning embeddings (prefixed to the speech token embeddings on the first step)
+ """
+
+ # Make use of the kv cache: only the last input ID is new, we trim away all the ones before
+ if not use_cache:
+ past_key_values = None
+ if past_key_values is not None:
+ input_ids = input_ids[:, -1:]
+
+ # custom speech token embedding layer
+ inputs_embeds = self.speech_enc(input_ids)
+
+ # prefix decoder conditioning if applicable
+ if not self._added_cond:
+ assert past_key_values is not None # should be first step
+ if decoder_cond.size(0) != inputs_embeds.size(0):
+ decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1)
+ inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1)
+ self._added_cond = True
+
+ return {
+ "inputs_embeds": inputs_embeds,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ }
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ past_key_values: Optional[torch.Tensor]=None,
+ use_cache=True,
+ output_attentions=False,
+ output_hidden_states=True,
+ return_dict=True,
+ ):
+ """
+ This is a method used by huggingface's generate() method.
+ Overridden here to apply our custom layer norm and speech logit projection layers.
+
+ :param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given,
+ S should be 1.
+ """
+ is_large_input = inputs_embeds.size(1) != 1
+ has_cache = past_key_values is not None and len(past_key_values) > 0
+ assert not (is_large_input and has_cache)
+ assert return_dict
+ assert output_hidden_states
+
+ tfmr_out = self.model(
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+ hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim)
+
+ logits = self.speech_head(hidden_states)
+ # assert inputs_embeds.size(0) == 1 # (disabled for CFG)
+
+ # NOTE: hallucination handler may modify logits to force emit an EOS token
+ # logits = self.alignment_stream_analyzer.step(logits)
+
+ return CausalLMOutputWithCrossAttentions(
+ logits=logits,
+ past_key_values=tfmr_out.past_key_values,
+ hidden_states=tfmr_out.hidden_states,
+ attentions=tfmr_out.attentions,
+ )
diff --git a/src/chatterbox/models/t3/llama_configs.py b/src/chatterbox/models/t3/llama_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa7e1e0457db4a9f54a39c2299e8113ce0373f34
--- /dev/null
+++ b/src/chatterbox/models/t3/llama_configs.py
@@ -0,0 +1,37 @@
+LLAMA_520M_CONFIG_DICT = dict(
+ # Arbitrary small number that won't cause problems when loading.
+ # These param are unused due to custom input layers.
+ vocab_size=8,
+ # default params needed for loading most pretrained 1B weights
+ max_position_embeddings=131072,
+ hidden_size=1024,
+ intermediate_size=4096,
+ num_hidden_layers=30,
+ num_attention_heads=16,
+ attn_implementation="sdpa",
+ head_dim=64,
+ tie_word_embeddings=False,
+ hidden_act="silu",
+ attention_bias=False,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ mlp_bias=False,
+ model_type="llama",
+ num_key_value_heads=16,
+ pretraining_tp=1,
+ rms_norm_eps=1e-05,
+ rope_scaling=dict(
+ factor=8.0,
+ high_freq_factor=4.0,
+ low_freq_factor=1.0,
+ original_max_position_embeddings=8192,
+ rope_type="llama3"
+ ),
+ rope_theta=500000.0,
+ torch_dtype="bfloat16",
+ use_cache=True,
+)
+
+LLAMA_CONFIGS = {
+ "Llama_520M": LLAMA_520M_CONFIG_DICT,
+}
diff --git a/src/chatterbox/models/t3/modules/cond_enc.py b/src/chatterbox/models/t3/modules/cond_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..08d869e99f33836e0a4f871b28b90ac79719e7ee
--- /dev/null
+++ b/src/chatterbox/models/t3/modules/cond_enc.py
@@ -0,0 +1,97 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from .perceiver import Perceiver
+from .t3_config import T3Config
+
+
+@dataclass
+class T3Cond:
+ """
+ Dataclass container for most / all conditioning info.
+ TODO: serialization methods aren't used, keeping them around for convenience
+ """
+
+ speaker_emb: Tensor
+ clap_emb: Optional[Tensor] = None
+ cond_prompt_speech_tokens: Optional[Tensor] = None
+ cond_prompt_speech_emb: Optional[Tensor] = None
+ emotion_adv: Optional[Tensor] = 0.5
+
+ def to(self, *, device=None, dtype=None):
+ "Cast to a device and dtype. Dtype casting is ignored for long/int tensors."
+ for k, v in self.__dict__.items():
+ if torch.is_tensor(v):
+ is_fp = type(v.view(-1)[0].item()) is not int
+ setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
+ return self
+
+ def save(self, fpath):
+ torch.save(self.__dict__, fpath)
+
+ @staticmethod
+ def load(fpath, map_location="cpu"):
+ kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
+ return T3Cond(**kwargs)
+
+
+class T3CondEnc(nn.Module):
+ """
+ Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.
+ """
+
+ def __init__(self, hp: T3Config):
+ super().__init__()
+ self.hp = hp
+ if hp.encoder_type == "voice_encoder":
+ self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
+ else:
+ raise NotImplementedError(str(hp.encoder_type))
+
+ # emotion adv
+ self.emotion_adv_fc = None
+ if hp.emotion_adv:
+ self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)
+
+ # perceiver resampler
+ self.perceiver = None
+ if hp.use_perceiver_resampler:
+ self.perceiver = Perceiver()
+
+ def forward(self, cond: T3Cond):
+ # Validate
+ assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
+ "no embeddings for cond_prompt_speech_tokens"
+
+ # Speaker embedding projection
+ cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None] # (B, 1, dim)
+ empty = torch.zeros_like(cond_spkr[:, :0]) # (B, 0, dim)
+
+ # TODO CLAP
+ assert cond.clap_emb is None, "clap_embed not implemented"
+ cond_clap = empty # (B, 0, dim)
+
+ # Cond prompt
+ cond_prompt_speech_emb = cond.cond_prompt_speech_emb
+ if cond_prompt_speech_emb is None:
+ cond_prompt_speech_emb = empty # (B, 0, dim)
+ elif self.hp.use_perceiver_resampler:
+ cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)
+
+ # Emotion Adv: must provide a value if this model uses emotion conditioning
+ cond_emotion_adv = empty # (B, 0, dim)
+ if self.hp.emotion_adv:
+ assert cond.emotion_adv is not None
+ cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1))
+
+ # Concat and return
+ cond_embeds = torch.cat((
+ cond_spkr,
+ cond_clap,
+ cond_prompt_speech_emb,
+ cond_emotion_adv,
+ ), dim=1)
+ return cond_embeds
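+
+
+# Rough shape sketch (added for illustration, using the default T3Config where
+# n_channels == 1024): without a speech prompt the conditioning sequence is just a
+# speaker slot plus an emotion slot.
+#
+#   hp = T3Config()
+#   enc = T3CondEnc(hp)
+#   cond = T3Cond(speaker_emb=torch.randn(1, hp.speaker_embed_size),
+#                 emotion_adv=0.5 * torch.ones(1, 1, 1))
+#   enc(cond).shape   # -> (1, 2, 1024); a perceiver-resampled prompt would add 32 more slots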
diff --git a/src/chatterbox/models/t3/modules/learned_pos_emb.py b/src/chatterbox/models/t3/modules/learned_pos_emb.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ca6d6565e5c269c9b8a43a55492ab9651d739c9
--- /dev/null
+++ b/src/chatterbox/models/t3/modules/learned_pos_emb.py
@@ -0,0 +1,32 @@
+from typing import Union
+
+import torch
+from torch import nn, Tensor
+
+
+class LearnedPositionEmbeddings(nn.Module):
+ def __init__(self, seq_len, model_dim, init=.02):
+ super().__init__()
+ self.emb = nn.Embedding(seq_len, model_dim)
+ # Initializing this way is standard for GPT-2
+ self.emb.weight.data.normal_(mean=0.0, std=init)
+
+ def forward(self, x):
+ """
+ Returns positional embeddings for index 0 up to the length of x
+ """
+ sl = x.shape[1]
+ return self.emb(torch.arange(0, sl, device=x.device))
+
+ def get_fixed_embedding(self, idx: 'Union[int, Tensor]'):
+ """
+ Args:
+ idx: scalar int or an integer tensor of shape (T,) or (B, T)
+ Returns:
+ positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input
+ """
+ device = self.emb.weight.device
+ idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device)
+ idx = torch.atleast_2d(idx)
+ assert idx.ndim == 2
+ return self.emb(idx) # (B, T, dim)
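+
+
+# Quick illustration (added comment, not part of the original file):
+#
+#   pe = LearnedPositionEmbeddings(seq_len=2050, model_dim=1024)
+#   pe(torch.zeros(2, 7, dtype=torch.long)).shape    # -> (7, 1024), broadcast when added
+#   pe.get_fixed_embedding(5).shape                  # -> (1, 1, 1024)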
diff --git a/src/chatterbox/models/t3/modules/perceiver.py b/src/chatterbox/models/t3/modules/perceiver.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d8290dd50708b182818fae4b9a617d46fc9c354
--- /dev/null
+++ b/src/chatterbox/models/t3/modules/perceiver.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2025 Resemble AI
+# Author: Manmay Nakhashi
+# MIT License
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
+ super().__init__()
+ self.scale = scale
+ self.causal = causal
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
+ ret = 0
+ n = -relative_position
+ if not causal:
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+ else:
+ n = torch.max(n, torch.zeros_like(n))
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, qk_dots):
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
+ rel_pos = k_pos[None, :] - q_pos[:, None]
+ rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
+ max_distance=self.max_distance)
+ values = self.relative_attention_bias(rp_bucket)
+ bias = rearrange(values, 'i j h -> () h i j')
+ return qk_dots + (bias * self.scale)
+
+
+class AttentionQKV(nn.Module):
+ def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False):
+ super().__init__()
+ self.n_heads = n_heads
+ self.head_dim = head_dim
+ self.scale = scale if scale is not None else head_dim ** -0.5
+ self.flash = flash
+ self.dropout_rate = dropout_rate
+ self.dropout = nn.Dropout(dropout_rate)
+ self.flash_config = self.setup_flash_config() if flash else None
+
+ def setup_flash_config(self):
+ # Setup flash attention configuration
+ flash_config = {
+ 'enable_flash': True,
+ 'enable_math': True,
+ 'enable_mem_efficient': True
+ }
+ return flash_config
+
+ def forward(self, q, k, v, mask=None):
+ q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]]
+ if self.flash:
+ out = self.flash_attention(q, k, v, mask=mask)
+ else:
+ out = self.scaled_dot_product_attention(q, k, v, mask=mask)
+
+ return self.combine_heads(out)
+
+ def scaled_dot_product_attention(self, q, k, v, mask=None):
+ sim = torch.einsum("bhlt,bhls->bhts", q, k) * self.scale
+ if mask is not None:
+ sim = sim.masked_fill(mask == 0, float('-inf'))
+ attn = torch.softmax(sim, dim=-1)
+ attn = self.dropout(attn)
+ return torch.einsum("bhts,bhls->bhlt", attn, v)
+
+ def flash_attention(self, q, k, v, mask=None):
+ config = self.flash_config if self.flash_config else {}
+ with torch.backends.cuda.sdp_kernel(**config):
+ out = F.scaled_dot_product_attention(
+ q, k, v,
+ attn_mask=mask,
+ dropout_p=self.dropout_rate if self.training else 0.
+ )
+ return out
+
+ def split_heads(self, x):
+ bs, length, _ = x.shape
+ x = x.view(bs, length, self.n_heads, self.head_dim)
+ return x.permute(0, 2, 1, 3)
+
+ def combine_heads(self, x):
+ bs, _, length, _ = x.shape
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x.view(bs, length, -1)
+
+
+class AttentionBlock2(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other,
+ using AttentionQKV and separate linear transformations for Q, K, and V.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ relative_pos_embeddings=False,
+ flash_attention=True,
+ dropout_rate=0.2,
+ scale=None
+ ):
+ super().__init__()
+ self.channels = channels
+
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+
+ self.norm = nn.LayerNorm(channels)
+
+ # Separate linear layers for Q, K, and V
+ self.to_q = nn.Linear(channels, channels)
+ self.to_k = nn.Linear(channels, channels)
+ self.to_v = nn.Linear(channels, channels)
+
+ self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale)
+
+ self.proj_out = nn.Linear(channels, channels)
+
+ if relative_pos_embeddings:
+ self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
+ else:
+ self.relative_pos_embeddings = None
+
+ def forward(self, x1, x2, mask=None):
+ b1, c1, *spatial1 = x1.shape
+ b2, c2, *spatial2 = x2.shape
+
+ x1_norm = self.norm(x1)
+ x2_norm = self.norm(x2)
+
+ q = self.to_q(x1_norm)
+ k = self.to_k(x2_norm)
+ v = self.to_v(x2_norm)
+
+ h = self.attention(q, k, v, mask=mask)
+ h = self.proj_out(h)
+
+ return (x1 + h).reshape(b1, c1, *spatial1)
+
+
+class Perceiver(nn.Module):
+ """Inspired by https://arxiv.org/abs/2103.03206"""
+ def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4):
+ """
+ Initialize the perceiver module.
+
+ :param pre_attention_query_token: Number of query tokens for pre-attention
+ :param pre_attention_query_size: Size of each query token
+ :param embedding_dim: Dimension of the embedding space
+ :param num_attn_heads: Number of attention heads
+ """
+ super().__init__()
+
+ # Initialize the pre-attention query parameter
+ self.pre_attention_query = torch.nn.Parameter(
+ torch.empty(1, pre_attention_query_token, pre_attention_query_size)
+ )
+
+        # Compute the bound for Xavier-style uniform initialization
+ query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token))
+
+ # Initialize the pre-attention query with uniform distribution
+ self.pre_attention_query.data.uniform_(-query_variance, query_variance)
+
+ # Initialize the attention block
+ self.attn = AttentionBlock2(embedding_dim, num_attn_heads)
+
+ def forward(self, h):
+ """
+ Forward pass of the perceiver module.
+ :param h: Input tensor
+ :return: Output after applying attention mechanisms
+ """
+ # Expand the pre-attention query to match the batch size of the input
+ query_ = self.pre_attention_query.expand(h.shape[0], -1, -1)
+ # Apply the first attention mechanism (cross-attention)
+ pre_att = self.attn(query_, h)
+ # Apply the second attention mechanism (self-attention)
+ attn = self.attn(pre_att, pre_att)
+ return attn
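+
+
+# Rough shape sketch (added for illustration): the resampler maps a variable-length
+# prompt embedding sequence to a fixed 32-token summary.
+#
+#   p = Perceiver()                        # 32 query tokens, dim 1024, 4 heads
+#   p(torch.randn(2, 150, 1024)).shape     # -> (2, 32, 1024)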
diff --git a/src/chatterbox/models/t3/modules/t3_config.py b/src/chatterbox/models/t3/modules/t3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b4339622ab5980b9c808d02e4614cd1afea118
--- /dev/null
+++ b/src/chatterbox/models/t3/modules/t3_config.py
@@ -0,0 +1,27 @@
+from ..llama_configs import LLAMA_CONFIGS
+
+
+class T3Config:
+ start_text_token = 255
+ stop_text_token = 0
+ text_tokens_dict_size = 704
+ max_text_tokens = 2048
+
+ start_speech_token = 6561
+ stop_speech_token = 6562
+ speech_tokens_dict_size = 8194
+ max_speech_tokens = 4096
+
+ llama_config_name = "Llama_520M"
+ input_pos_emb = "learned"
+ speech_cond_prompt_len = 150
+
+ # For T3CondEnc
+ encoder_type = "voice_encoder"
+ speaker_embed_size = 256
+ use_perceiver_resampler = True
+ emotion_adv = True
+
+ @property
+ def n_channels(self):
+ return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
diff --git a/src/chatterbox/models/t3/t3.py b/src/chatterbox/models/t3/t3.py
new file mode 100644
index 0000000000000000000000000000000000000000..32fc6bc877cbd872881cd9e33dd9a1277c6b9d31
--- /dev/null
+++ b/src/chatterbox/models/t3/t3.py
@@ -0,0 +1,372 @@
+# Copyright (c) 2025 Resemble AI
+# MIT License
+import logging
+from typing import Union, Optional, List
+
+from tqdm import tqdm
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from transformers import LlamaModel, LlamaConfig
+from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
+
+from .modules.learned_pos_emb import LearnedPositionEmbeddings
+
+from .modules.cond_enc import T3CondEnc, T3Cond
+from .modules.t3_config import T3Config
+from .llama_configs import LLAMA_CONFIGS
+from .inference.t3_hf_backend import T3HuggingfaceBackend
+from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
+
+
+logger = logging.getLogger(__name__)
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def _ensure_BOT_EOT(text_tokens: Tensor, hp):
+ B = text_tokens.size(0)
+ assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
+ assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"
+
+
+class T3(nn.Module):
+ """
+ Token-To-Token (T3) TTS model using huggingface transformer models as backbones,
+ * tokenization, including start / stop tokens are always added externally to this class
+ * conditioning data like CLAP, emotion, etc are all in a separate file for more modularity
+ * careful! this class assumes relative positional encoding -- with absolute PE, we would at
+ least want to reset the position to 0 when speech tokens begin, and optionally use a
+ different PE embedding space for speech.
+ """
+
+ def __init__(self, hp=T3Config()):
+ super().__init__()
+ self.hp = hp
+ self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
+ self.tfmr = LlamaModel(self.cfg)
+ self.dim = self.cfg.hidden_size
+ self.deepspeed_patch_applied = False
+
+ # conditioning / embedding
+ self.cond_enc = T3CondEnc(hp)
+ self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
+ self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
+
+ # custom position embedding
+ if hp.input_pos_emb == "learned":
+ max_text_seq_len = hp.max_text_tokens + 2
+ self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)
+
+ max_mel_seq_len = hp.max_speech_tokens + 2 + 2
+ self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)
+
+ # logit projection
+ self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
+ self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
+ self.compiled = False
+
+ @property
+ def device(self):
+ return self.speech_head.weight.device
+
+ def prepare_conditioning(self, t3_cond: T3Cond):
+ """
+ Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
+ """
+ if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
+ t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
+ self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
+ return self.cond_enc(t3_cond) # (B, len_cond, dim)
+
+ def prepare_input_embeds(
+ self,
+ *,
+ t3_cond: T3Cond,
+ text_tokens: torch.LongTensor,
+ speech_tokens: torch.LongTensor,
+ ):
+        # prepare input embeddings (skip backbone transformer embeddings)
+ cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
+ text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
+ text_emb[1].zero_() # CFG uncond
+
+ speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
+ if self.hp.input_pos_emb == "learned":
+ text_emb = text_emb + self.text_pos_emb(text_tokens)
+ speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
+ len_cond = cond_emb.size(1)
+
+ if cond_emb.size(0) != text_emb.size(0):
+ cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
+
+ # concat
+ embeds = torch.stack([
+ torch.cat((ce, te, se))
+ for ce, te, se in zip(cond_emb, text_emb, speech_emb)
+ ]) # (B, length, dim)
+ return embeds, len_cond
+
+ def forward(
+ self,
+ *,
+ t3_cond: T3Cond,
+ text_tokens: torch.LongTensor,
+ text_token_lens: torch.LongTensor,
+ speech_tokens: torch.LongTensor,
+ speech_token_lens: torch.LongTensor,
+ training=False,
+ ):
+ _ensure_BOT_EOT(text_tokens, self.hp)
+
+ # prepare custom input embeds
+ embeds, len_cond = self.prepare_input_embeds(
+ t3_cond=t3_cond,
+ text_tokens=text_tokens,
+ speech_tokens=speech_tokens,
+ )
+
+        # backbone transformer forward
+ tfmr_out = self.tfmr.forward(
+ input_ids=None,
+ # position_ids=position_ids, # TODO? ROPE should be fine?
+ inputs_embeds=embeds,
+ output_hidden_states=True,
+ return_dict=True,
+ use_cache=(not training),
+ )
+ hidden_states = tfmr_out.hidden_states[-1] # final tfmr layer output, (B, seq, dim)
+
+ # post-processing: splice out text and speech parts of hidden states
+ len_text = text_tokens.size(1)
+ len_speech = speech_tokens.size(1)
+ B, _, dim = hidden_states.shape
+ device, dtype = hidden_states.device, hidden_states.dtype
+ text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device)
+ speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device)
+ ttl, stl = text_token_lens, speech_token_lens
+ for i in range(B):
+ text_end = len_cond + ttl[i].item()
+ speech_start = len_cond + text_tokens.size(1)
+ speech_end = speech_start + stl[i].item()
+ text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
+ speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]
+
+ # logit projection
+ text_logits = self.text_head(text_latents)
+ speech_logits = self.speech_head(speech_latents)
+
+ return AttrDict(
+ text_logits=text_logits,
+ text_latents=text_latents,
+ speech_logits=speech_logits,
+ speech_latents=speech_latents,
+ hidden_states=hidden_states,
+ )
+
+ def loss(
+ self,
+ *,
+ t3_cond: T3Cond,
+ text_tokens: torch.LongTensor,
+ text_token_lens: torch.LongTensor,
+ speech_tokens: torch.LongTensor,
+ speech_token_lens: torch.LongTensor,
+ ):
+ "training method"
+ len_text = text_tokens.size(1)
+ len_speech = speech_tokens.size(1)
+ assert len_text == text_token_lens.max()
+ assert len_speech == speech_token_lens.max()
+
+ out = self.forward(
+ t3_cond=t3_cond,
+ text_tokens=text_tokens,
+ text_token_lens=text_token_lens,
+ speech_tokens=speech_tokens,
+ speech_token_lens=speech_token_lens,
+ training=True,
+ ) # (B, seq, vocab_size)
+
+ # Calc CCE losses
+ IGNORE_ID = -100
+ device = out.text_logits.device
+ mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None] # (B, len_text)
+ mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None] # (B, len_speech)
+ masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
+ masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
+        # F.cross_entropy expects the class dim at position 1: (B, T, V) -> (B, V, T)
+        loss_text = F.cross_entropy(out.text_logits.transpose(1, 2), masked_text, ignore_index=IGNORE_ID)
+        loss_speech = F.cross_entropy(out.speech_logits.transpose(1, 2), masked_speech, ignore_index=IGNORE_ID)
+
+ return loss_text, loss_speech
+
+ @torch.inference_mode()
+ def inference(
+ self,
+ *,
+ t3_cond: T3Cond,
+ text_tokens: Tensor,
+ initial_speech_tokens: Optional[Tensor]=None,
+
+ # misc conditioning
+ prepend_prompt_speech_tokens: Optional[Tensor]=None,
+
+ # HF generate args
+ num_return_sequences=1,
+ max_new_tokens=None,
+ stop_on_eos=True,
+ do_sample=True,
+ temperature=0.8,
+ top_p=0.8,
+ length_penalty=1.0,
+ repetition_penalty=2.0,
+ cfg_weight=0,
+ ):
+ """
+ Args:
+ text_tokens: a 1D (unbatched) or 2D (batched) tensor.
+ """
+ # Validate / sanitize inputs
+ assert prepend_prompt_speech_tokens is None, "not implemented"
+ _ensure_BOT_EOT(text_tokens, self.hp)
+ text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
+
+ # Default initial speech to a single start-of-speech token
+ if initial_speech_tokens is None:
+ initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
+
+ # Prepare custom input embeds
+ embeds, len_cond = self.prepare_input_embeds(
+ t3_cond=t3_cond,
+ text_tokens=text_tokens,
+ speech_tokens=initial_speech_tokens,
+ )
+
+ # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
+ # Note the llama-specific logic. Other tfmr types can be added later.
+
+ self.compiled = False
+
+ # TODO? synchronize the expensive compile function
+ # with self.compile_lock:
+ if not self.compiled:
+ alignment_stream_analyzer = AlignmentStreamAnalyzer(
+ self.tfmr,
+ None,
+ text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
+ alignment_layer_idx=9, # TODO: hparam or something?
+ eos_idx=self.hp.stop_speech_token,
+ )
+ patched_model = T3HuggingfaceBackend(
+ config=self.cfg,
+ llama=self.tfmr,
+ speech_enc=self.speech_emb,
+ speech_head=self.speech_head,
+ alignment_stream_analyzer=alignment_stream_analyzer,
+ )
+ self.patched_model = patched_model
+ self.compiled = True
+
+ # # Run normal generate method, which calls our custom extended methods
+ # return self.patched_model.generate(
+ # inputs=initial_speech_tokens,
+ # decoder_cond=embeds,
+ # bos_token_id=self.hp.start_speech_token,
+ # eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
+ # pad_token_id=self.hp.stop_speech_token,
+ # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
+ # num_return_sequences=num_return_sequences,
+ # temperature=temperature,
+ # top_p=top_p,
+ # length_penalty=length_penalty,
+ # repetition_penalty=repetition_penalty,
+ # do_sample=do_sample,
+ # # cache_implementation=None if not self.compiled else "static",
+ # )
+
+ device = embeds.device
+
+ bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
+ bos_embed = self.speech_emb(bos_token) # shape: (B, 1, embed_dim)
+ bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
+
+ # batch_size=2 for CFG
+ bos_embed = torch.cat([bos_embed, bos_embed])
+
+ # Combine condition and BOS token for the initial input
+ inputs_embeds = torch.cat([embeds, bos_embed], dim=1)
+
+ # Track generated token ids; start with the BOS token.
+ generated_ids = bos_token.clone()
+ predicted = [] # To store the predicted tokens
+
+ # Instantiate the logits processors.
+ top_p_warper = TopPLogitsWarper(top_p=top_p)
+ repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+
+ # ---- Initial Forward Pass (no kv_cache yet) ----
+ output = self.patched_model(
+ inputs_embeds=inputs_embeds,
+ past_key_values=None,
+ use_cache=True,
+ output_attentions=True,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ # Initialize kv_cache with the full context.
+ past = output.past_key_values
+
+        # ---- Generation Loop using kv_cache ----
+        max_new_tokens = max_new_tokens or self.hp.max_speech_tokens
+        for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
+ logits = output.logits[:, -1, :]
+
+ # CFG
+ logits_cond = logits[0:1]
+ logits_uncond = logits[1:2]
+ logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)
+ logits = logits.squeeze(1)
+
+ # Apply temperature scaling.
+ if temperature != 1.0:
+ logits = logits / temperature
+
+            # Apply repetition penalty and top-p filtering.
+ logits = repetition_penalty_processor(generated_ids, logits)
+ logits = top_p_warper(None, logits)
+
+ # Convert logits to probabilities and sample the next token.
+ probs = torch.softmax(logits, dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1)
+
+ predicted.append(next_token)
+ generated_ids = torch.cat([generated_ids, next_token], dim=1)
+
+ # Check for EOS token.
+ if next_token.view(-1) == self.hp.stop_speech_token:
+ break
+
+ # Get embedding for the new token.
+ next_token_embed = self.speech_emb(next_token)
+ next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
+
+ # For CFG
+ next_token_embed = torch.cat([next_token_embed, next_token_embed])
+
+ # Forward pass with only the new token and the cached past.
+ output = self.patched_model(
+ inputs_embeds=next_token_embed,
+ past_key_values=past,
+ output_attentions=True,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ # Update the kv_cache.
+ past = output.past_key_values
+
+ # Concatenate all predicted tokens along the sequence dimension.
+ predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
+ return predicted_tokens
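
The sampling loop above combines the conditional and unconditional rows of the CFG batch before temperature, repetition penalty, and top-p are applied. A standalone sketch of just that guidance step, with illustrative tensors:

```python
import torch

cfg_weight, temperature = 0.5, 0.8
logits = torch.randn(2, 8194)                 # rows: [conditional, unconditional], vocab = speech_tokens_dict_size
logits_cond, logits_uncond = logits[0:1], logits[1:2]

# Same rule as in T3.inference: push the conditional logits away from the unconditional ones.
guided = logits_cond + cfg_weight * (logits_cond - logits_uncond)

probs = torch.softmax(guided / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.shape)                       # torch.Size([1, 1])
```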
diff --git a/src/chatterbox/models/tokenizers/__init__.py b/src/chatterbox/models/tokenizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0cc915b99cd5a374be1426122e2f8ad984f9db7
--- /dev/null
+++ b/src/chatterbox/models/tokenizers/__init__.py
@@ -0,0 +1 @@
+from .tokenizer import EnTokenizer
diff --git a/src/chatterbox/models/tokenizers/tokenizer.py b/src/chatterbox/models/tokenizers/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c1e00122d82bacfe1f790153e7ecc4a392d65e4
--- /dev/null
+++ b/src/chatterbox/models/tokenizers/tokenizer.py
@@ -0,0 +1,50 @@
+import logging
+
+import torch
+from tokenizers import Tokenizer
+
+
+# Special tokens
+SOT = "[START]"
+EOT = "[STOP]"
+UNK = "[UNK]"
+SPACE = "[SPACE]"
+SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]
+
+logger = logging.getLogger(__name__)
+
+class EnTokenizer:
+ def __init__(self, vocab_file_path):
+ self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
+ self.check_vocabset_sot_eot()
+
+ def check_vocabset_sot_eot(self):
+ voc = self.tokenizer.get_vocab()
+ assert SOT in voc
+ assert EOT in voc
+
+ def text_to_tokens(self, text: str):
+ text_tokens = self.encode(text)
+ text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
+ return text_tokens
+
+    def encode(self, txt: str, verbose=False):
+        """
+        Replace spaces with the [SPACE] token, then encode the text with the underlying Tokenizer.
+        """
+ txt = txt.replace(' ', SPACE)
+ code = self.tokenizer.encode(txt)
+ ids = code.ids
+ return ids
+
+ def decode(self, seq):
+ if isinstance(seq, torch.Tensor):
+ seq = seq.cpu().numpy()
+
+        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
+ txt = txt.replace(' ', '')
+ txt = txt.replace(SPACE, ' ')
+ txt = txt.replace(EOT, '')
+ txt = txt.replace(UNK, '')
+ return txt
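
A round-trip sketch of the space handling above; the `tokenizer.json` path is an assumption and the text must be covered by the trained vocabulary:

```python
from chatterbox.models.tokenizers import EnTokenizer

tok = EnTokenizer("checkpoints/tokenizer.json")  # hypothetical local path to the vocab file
ids = tok.encode("Hello world")                  # spaces become [SPACE] before encoding
print(tok.decode(ids))                           # [SPACE] maps back to " "; [STOP]/[UNK] are stripped -> "Hello world"
text_tokens = tok.text_to_tokens("Hello world")  # (1, N) IntTensor, as consumed downstream
```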
diff --git a/src/chatterbox/models/voice_encoder/__init__.py b/src/chatterbox/models/voice_encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af714d4a03df9463412ae40bdfb6de37d2e781a
--- /dev/null
+++ b/src/chatterbox/models/voice_encoder/__init__.py
@@ -0,0 +1 @@
+from .voice_encoder import VoiceEncoder, VoiceEncConfig
diff --git a/src/chatterbox/models/voice_encoder/config.py b/src/chatterbox/models/voice_encoder/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..47f48bc8ec6ea5149a19b7baea3e1be1b45e82af
--- /dev/null
+++ b/src/chatterbox/models/voice_encoder/config.py
@@ -0,0 +1,18 @@
+class VoiceEncConfig:
+ num_mels = 40
+ sample_rate = 16000
+ speaker_embed_size = 256
+ ve_hidden_size = 256
+ flatten_lstm_params = False
+ n_fft = 400
+ hop_size = 160
+ win_size = 400
+ fmax = 8000
+ fmin = 0
+ preemphasis = 0.
+ mel_power = 2.0
+ mel_type = "amp"
+ normalized_mels = False
+ ve_partial_frames = 160
+ ve_final_relu = True
+ stft_magnitude_min = 1e-4
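
For reference, the timing implied by these values: one mel frame spans `hop_size / sample_rate` seconds, so a partial utterance of `ve_partial_frames` frames covers 1.6 s. A quick check:

```python
from chatterbox.models.voice_encoder import VoiceEncConfig

hp = VoiceEncConfig()
frame_ms = 1000 * hp.hop_size / hp.sample_rate       # 160 / 16000  -> 10.0 ms per mel frame
partial_s = hp.ve_partial_frames * frame_ms / 1000   # 160 frames   -> 1.6 s per partial utterance
print(frame_ms, partial_s)
```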
diff --git a/src/chatterbox/models/voice_encoder/melspec.py b/src/chatterbox/models/voice_encoder/melspec.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a8a6d61c88934f883121391221e7fc5d74c195
--- /dev/null
+++ b/src/chatterbox/models/voice_encoder/melspec.py
@@ -0,0 +1,78 @@
+from functools import lru_cache
+
+from scipy import signal
+import numpy as np
+import librosa
+
+
+@lru_cache()
+def mel_basis(hp):
+ assert hp.fmax <= hp.sample_rate // 2
+ return librosa.filters.mel(
+ sr=hp.sample_rate,
+ n_fft=hp.n_fft,
+ n_mels=hp.num_mels,
+ fmin=hp.fmin,
+ fmax=hp.fmax) # -> (nmel, nfreq)
+
+
+def preemphasis(wav, hp):
+ assert hp.preemphasis != 0
+ wav = signal.lfilter([1, -hp.preemphasis], [1], wav)
+ wav = np.clip(wav, -1, 1)
+ return wav
+
+
+def melspectrogram(wav, hp, pad=True):
+ # Run through pre-emphasis
+ if hp.preemphasis > 0:
+ wav = preemphasis(wav, hp)
+ assert np.abs(wav).max() - 1 < 1e-07
+
+ # Do the stft
+ spec_complex = _stft(wav, hp, pad=pad)
+
+ # Get the magnitudes
+ spec_magnitudes = np.abs(spec_complex)
+
+ if hp.mel_power != 1.0:
+ spec_magnitudes **= hp.mel_power
+
+ # Get the mel and convert magnitudes->db
+ mel = np.dot(mel_basis(hp), spec_magnitudes)
+ if hp.mel_type == "db":
+ mel = _amp_to_db(mel, hp)
+
+ # Normalise the mel from db to 0,1
+ if hp.normalized_mels:
+ mel = _normalize(mel, hp).astype(np.float32)
+
+ assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check
+ return mel # (M, T)
+
+
+def _stft(y, hp, pad=True):
+    # NOTE: after librosa 0.8 the pad mode defaults to constant; we set it to reflect
+    # for historical consistency and to match the streaming version
+ return librosa.stft(
+ y,
+ n_fft=hp.n_fft,
+ hop_length=hp.hop_size,
+ win_length=hp.win_size,
+ center=pad,
+ pad_mode="reflect",
+ )
+
+
+def _amp_to_db(x, hp):
+ return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x))
+
+
+def _db_to_amp(x):
+ return np.power(10.0, x * 0.05)
+
+
+def _normalize(s, hp, headroom_db=15):
+ min_level_db = 20 * np.log10(hp.stft_magnitude_min)
+ s = (s - min_level_db) / (-min_level_db + headroom_db)
+ return s
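
A shape check for the pipeline above, using the voice-encoder defaults and one second of dummy audio kept within [-1, 1]:

```python
import numpy as np
from chatterbox.models.voice_encoder.config import VoiceEncConfig
from chatterbox.models.voice_encoder.melspec import melspectrogram

hp = VoiceEncConfig()
wav = np.random.uniform(-0.5, 0.5, hp.sample_rate).astype(np.float32)  # 1 s at 16 kHz
mel = melspectrogram(wav, hp)
print(mel.shape)  # (num_mels, 1 + len(wav) // hop_size) -> (40, 101)
```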
diff --git a/src/chatterbox/models/voice_encoder/voice_encoder.py b/src/chatterbox/models/voice_encoder/voice_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ed2df76e90b315e3329215f9e56eadc75a48f3
--- /dev/null
+++ b/src/chatterbox/models/voice_encoder/voice_encoder.py
@@ -0,0 +1,274 @@
+# Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning
+# MIT License
+from typing import List, Union, Optional
+
+import numpy as np
+from numpy.lib.stride_tricks import as_strided
+import librosa
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from .config import VoiceEncConfig
+from .melspec import melspectrogram
+
+
+def pack(arrays, seq_len: int=None, pad_value=0):
+ """
+ Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of
+ shape (B, T, ...) by padding each individual array on the right.
+
+ :param arrays: a list of array-like objects of matching shapes except for the first axis.
+ :param seq_len: the value of T. It must be the maximum of the lengths Ti of the arrays at
+ minimum. Will default to that value if None.
+ :param pad_value: the value to pad the arrays with.
+ :return: a (B, T, ...) tensor
+ """
+ if seq_len is None:
+ seq_len = max(len(array) for array in arrays)
+ else:
+ assert seq_len >= max(len(array) for array in arrays)
+
+ # Convert lists to np.array
+ if isinstance(arrays[0], list):
+ arrays = [np.array(array) for array in arrays]
+
+ # Convert to tensor and handle device
+ device = None
+ if isinstance(arrays[0], torch.Tensor):
+ tensors = arrays
+ device = tensors[0].device
+ else:
+ tensors = [torch.as_tensor(array) for array in arrays]
+
+ # Fill the packed tensor with the array data
+ packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:])
+ packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device)
+
+ for i, tensor in enumerate(tensors):
+ packed_tensor[i, :tensor.size(0)] = tensor
+
+ return packed_tensor
+
+
+def get_num_wins(
+ n_frames: int,
+ step: int,
+ min_coverage: float,
+ hp: VoiceEncConfig,
+):
+ assert n_frames > 0
+ win_size = hp.ve_partial_frames
+ n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step)
+ if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage:
+ n_wins += 1
+ target_n = win_size + step * (n_wins - 1)
+ return n_wins, target_n
+
+
+def get_frame_step(
+ overlap: float,
+ rate: float,
+ hp: VoiceEncConfig,
+):
+ # Compute how many frames separate two partial utterances
+ assert 0 <= overlap < 1
+ if rate is None:
+ frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap)))
+ else:
+ frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames))
+ assert 0 < frame_step <= hp.ve_partial_frames
+ return frame_step
+
+
+def stride_as_partials(
+ mel: np.ndarray,
+ hp: VoiceEncConfig,
+ overlap=0.5,
+ rate: float=None,
+ min_coverage=0.8,
+):
+ """
+ Takes unscaled mels in (T, M) format
+ TODO: doc
+ """
+ assert 0 < min_coverage <= 1
+ frame_step = get_frame_step(overlap, rate, hp)
+
+ # Compute how many partials can fit in the mel
+ n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp)
+
+ # Trim or pad the mel spectrogram to match the number of partials
+ if target_len > len(mel):
+ mel = np.concatenate((mel, np.full((target_len - len(mel), hp.num_mels), 0)))
+ elif target_len < len(mel):
+ mel = mel[:target_len]
+
+ # Ensure the numpy array data is float32 and contiguous in memory
+ mel = mel.astype(np.float32, order="C")
+
+    # Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping each other,
+ # where N is the number of partials, P is the number of frames of each partial and M the
+ # number of channels of the mel spectrograms.
+ shape = (n_partials, hp.ve_partial_frames, hp.num_mels)
+ strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1])
+ partials = as_strided(mel, shape, strides)
+ return partials
+
+
+class VoiceEncoder(nn.Module):
+ def __init__(self, hp=VoiceEncConfig()):
+ super().__init__()
+
+ self.hp = hp
+
+ # Network definition
+ self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True)
+ if hp.flatten_lstm_params:
+ self.lstm.flatten_parameters()
+ self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size)
+
+ # Cosine similarity scaling (fixed initial parameter values)
+ self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True)
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True)
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, mels: torch.FloatTensor):
+ """
+ Computes the embeddings of a batch of partial utterances.
+
+ :param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor
+ of shape (B, T, M) where T is hp.ve_partial_frames
+ :return: the embeddings as a float32 tensor of shape (B, E) where E is
+            hp.speaker_embed_size. Embeddings are L2-normed and thus lie in the range [-1, 1].
+ """
+ if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1):
+ raise Exception(f"Mels outside [0, 1]. Min={mels.min()}, Max={mels.max()}")
+
+ # Pass the input through the LSTM layers
+ _, (hidden, _) = self.lstm(mels)
+
+ # Project the final hidden state
+ raw_embeds = self.proj(hidden[-1])
+ if self.hp.ve_final_relu:
+ raw_embeds = F.relu(raw_embeds)
+
+ # L2 normalize the embeddings.
+ return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)
+
+ def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None):
+ """
+ Computes the embeddings of a batch of full utterances with gradients.
+
+ :param mels: (B, T, M) unscaled mels
+ :return: (B, E) embeddings on CPU
+ """
+ mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens
+
+ # Compute where to split the utterances into partials
+ frame_step = get_frame_step(overlap, rate, self.hp)
+ n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens))
+
+ # Possibly pad the mels to reach the target lengths
+ len_diff = max(target_lens) - mels.size(1)
+ if len_diff > 0:
+ pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32)
+ mels = torch.cat((mels, pad.to(mels.device)), dim=1)
+
+ # Group all partials together so that we can batch them easily
+ partials = [
+ mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames]
+ for mel, n_partial in zip(mels, n_partials) for i in range(n_partial)
+ ]
+ assert all(partials[0].shape == partial.shape for partial in partials)
+ partials = torch.stack(partials)
+
+ # Forward the partials
+ n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials))))
+ partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu()
+
+ # Reduce the partial embeds into full embeds and L2-normalize them
+ slices = np.concatenate(([0], np.cumsum(n_partials)))
+ raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])]
+ raw_embeds = torch.stack(raw_embeds)
+ embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True)
+
+ return embeds
+
+ @staticmethod
+ def utt_to_spk_embed(utt_embeds: np.ndarray):
+ """
+        Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalizes it to get a
+ speaker embedding.
+ """
+ assert utt_embeds.ndim == 2
+ utt_embeds = np.mean(utt_embeds, axis=0)
+ return utt_embeds / np.linalg.norm(utt_embeds, 2)
+
+ @staticmethod
+ def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray):
+ """
+ Cosine similarity for L2-normalized utterance embeddings or speaker embeddings
+ """
+ embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x)
+ embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y)
+ return embeds_x @ embeds_y
+
+ def embeds_from_mels(
+ self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs
+ ):
+ """
+ Convenience function for deriving utterance or speaker embeddings from mel spectrograms.
+
+ :param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays.
+ :param mel_lens: if passing mels as a tensor, individual mel lengths
+ :param as_spk: whether to return utterance embeddings or a single speaker embedding
+ :param kwargs: args for inference()
+
+        :returns: embeds as a (B, E) float32 numpy array if as_spk is False, else as an (E,) array
+ """
+ # Load mels in memory and pack them
+ if isinstance(mels, List):
+ mels = [np.asarray(mel) for mel in mels]
+ assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format"
+ mel_lens = [mel.shape[0] for mel in mels]
+ mels = pack(mels)
+
+ # Embed them
+ with torch.inference_mode():
+ utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy()
+
+ return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds
+
+ def embeds_from_wavs(
+ self,
+ wavs: List[np.ndarray],
+ sample_rate,
+ as_spk=False,
+ batch_size=32,
+ trim_top_db: Optional[float]=20,
+ **kwargs
+ ):
+ """
+ Wrapper around embeds_from_mels
+
+ :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation
+ """
+ if sample_rate != self.hp.sample_rate:
+ wavs = [
+ librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast")
+ for wav in wavs
+ ]
+
+ if trim_top_db:
+ wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs]
+
+ if "rate" not in kwargs:
+ kwargs["rate"] = 1.3 # Resemble's default value.
+
+ mels = [melspectrogram(w, self.hp).T for w in wavs]
+
+ return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)
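
A usage sketch of the encoder with random weights and dummy audio (load `ve.pt` for meaningful embeddings); shapes follow from the config above:

```python
import numpy as np
from chatterbox.models.voice_encoder import VoiceEncoder

ve = VoiceEncoder().eval()  # randomly initialised here; load ve.pt in practice
wavs = [np.random.uniform(-0.5, 0.5, 32000).astype(np.float32) for _ in range(2)]  # two 2 s clips at 16 kHz
utt_embeds = ve.embeds_from_wavs(wavs, sample_rate=16000)       # (2, 256) L2-normed utterance embeddings
spk_embed = VoiceEncoder.utt_to_spk_embed(utt_embeds)           # (256,) speaker embedding
print(VoiceEncoder.voice_similarity(utt_embeds[0], utt_embeds[1]))
```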
diff --git a/src/chatterbox/tts.py b/src/chatterbox/tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..3371dae1f0b0291c643b0977039e35628e3242b4
--- /dev/null
+++ b/src/chatterbox/tts.py
@@ -0,0 +1,244 @@
+from dataclasses import dataclass
+from pathlib import Path
+
+import librosa
+import torch
+import perth
+import torch.nn.functional as F
+from huggingface_hub import hf_hub_download
+
+from .models.t3 import T3
+from .models.s3tokenizer import S3_SR, drop_invalid_tokens
+from .models.s3gen import S3GEN_SR, S3Gen
+from .models.tokenizers import EnTokenizer
+from .models.voice_encoder import VoiceEncoder
+from .models.t3.modules.cond_enc import T3Cond
+
+
+REPO_ID = "ResembleAI/chatterbox"
+
+
+def punc_norm(text: str) -> str:
+ """
+    Quick cleanup of punctuation for LLM-generated text or
+    text containing characters rarely seen in the dataset.
+ """
+ if len(text) == 0:
+ return "You need to add some text for me to talk."
+
+ # Capitalise first letter
+ if text[0].islower():
+ text = text[0].upper() + text[1:]
+
+ # Remove multiple space chars
+ text = " ".join(text.split())
+
+ # Replace uncommon/llm punc
+ punc_to_replace = [
+ ("...", ", "),
+ ("…", ", "),
+ (":", ","),
+ (" - ", ", "),
+ (";", ", "),
+ ("—", "-"),
+ ("–", "-"),
+ (" ,", ","),
+ ("“", "\""),
+ ("”", "\""),
+ ("‘", "'"),
+ ("’", "'"),
+ ]
+ for old_char_sequence, new_char in punc_to_replace:
+ text = text.replace(old_char_sequence, new_char)
+
+ # Add full stop if no ending punc
+ text = text.rstrip(" ")
+ sentence_enders = {".", "!", "?", "-", ","}
+ if not any(text.endswith(p) for p in sentence_enders):
+ text += "."
+
+ return text
+
+
+@dataclass
+class Conditionals:
+ """
+ Conditionals for T3 and S3Gen
+ - T3 conditionals:
+ - speaker_emb
+ - clap_emb
+ - cond_prompt_speech_tokens
+ - cond_prompt_speech_emb
+ - emotion_adv
+ - S3Gen conditionals:
+ - prompt_token
+ - prompt_token_len
+ - prompt_feat
+ - prompt_feat_len
+ - embedding
+ """
+ t3: T3Cond
+ gen: dict
+
+ def to(self, device):
+ self.t3 = self.t3.to(device=device)
+ for k, v in self.gen.items():
+ if torch.is_tensor(v):
+ self.gen[k] = v.to(device=device)
+ return self
+
+ def save(self, fpath: Path):
+ arg_dict = dict(
+ t3=self.t3.__dict__,
+ gen=self.gen
+ )
+ torch.save(arg_dict, fpath)
+
+ @classmethod
+ def load(cls, fpath, map_location="cpu"):
+ kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
+ return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
+
+
+class ChatterboxTTS:
+ ENC_COND_LEN = 6 * S3_SR
+ DEC_COND_LEN = 10 * S3GEN_SR
+
+ def __init__(
+ self,
+ t3: T3,
+ s3gen: S3Gen,
+ ve: VoiceEncoder,
+ tokenizer: EnTokenizer,
+ device: str,
+ conds: Conditionals = None,
+ ):
+ self.sr = S3GEN_SR # sample rate of synthesized audio
+ self.t3 = t3
+ self.s3gen = s3gen
+ self.ve = ve
+ self.tokenizer = tokenizer
+ self.device = device
+ self.conds = conds
+ self.watermarker = perth.PerthImplicitWatermarker()
+
+ @classmethod
+ def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
+ ckpt_dir = Path(ckpt_dir)
+
+ ve = VoiceEncoder()
+ ve.load_state_dict(
+ torch.load(ckpt_dir / "ve.pt")
+ )
+ ve.to(device).eval()
+
+ t3 = T3()
+ t3_state = torch.load(ckpt_dir / "t3_cfg.pt")
+ if "model" in t3_state.keys():
+ t3_state = t3_state["model"][0]
+ t3.load_state_dict(t3_state)
+ t3.to(device).eval()
+
+ s3gen = S3Gen()
+ s3gen.load_state_dict(
+ torch.load(ckpt_dir / "s3gen.pt")
+ )
+ s3gen.to(device).eval()
+
+ tokenizer = EnTokenizer(
+ str(ckpt_dir / "tokenizer.json")
+ )
+
+ conds = None
+ if (builtin_voice := ckpt_dir / "conds.pt").exists():
+ conds = Conditionals.load(builtin_voice).to(device)
+
+ return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
+
+ @classmethod
+ def from_pretrained(cls, device) -> 'ChatterboxTTS':
+ for fpath in ["ve.pt", "t3_cfg.pt", "s3gen.pt", "tokenizer.json", "conds.pt"]:
+ local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
+
+ return cls.from_local(Path(local_path).parent, device)
+
+ def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
+ ## Load reference wav
+ s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
+
+ ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
+
+ s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
+ s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
+
+ # Speech cond prompt tokens
+ if plen := self.t3.hp.speech_cond_prompt_len:
+ s3_tokzr = self.s3gen.tokenizer
+ t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
+ t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
+
+ # Voice-encoder speaker embedding
+ ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
+ ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
+
+ t3_cond = T3Cond(
+ speaker_emb=ve_embed,
+ cond_prompt_speech_tokens=t3_cond_prompt_tokens,
+ emotion_adv=exaggeration * torch.ones(1, 1, 1),
+ ).to(device=self.device)
+ self.conds = Conditionals(t3_cond, s3gen_ref_dict)
+
+ def generate(
+ self,
+ text,
+ audio_prompt_path=None,
+ exaggeration=0.5,
+ cfg_weight=0.5,
+ temperature=0.8,
+ ):
+ if audio_prompt_path:
+ self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
+ else:
+ assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
+
+ # Update exaggeration if needed
+ if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]:
+ _cond: T3Cond = self.conds.t3
+ self.conds.t3 = T3Cond(
+ speaker_emb=_cond.speaker_emb,
+ cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
+ emotion_adv=exaggeration * torch.ones(1, 1, 1),
+ ).to(device=self.device)
+
+ # Norm and tokenize text
+ text = punc_norm(text)
+ text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)
+ text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
+
+ sot = self.t3.hp.start_text_token
+ eot = self.t3.hp.stop_text_token
+ text_tokens = F.pad(text_tokens, (1, 0), value=sot)
+ text_tokens = F.pad(text_tokens, (0, 1), value=eot)
+
+ with torch.inference_mode():
+ speech_tokens = self.t3.inference(
+ t3_cond=self.conds.t3,
+ text_tokens=text_tokens,
+ max_new_tokens=1000, # TODO: use the value in config
+ temperature=temperature,
+ cfg_weight=cfg_weight,
+ )
+ # Extract only the conditional batch.
+ speech_tokens = speech_tokens[0]
+
+ # TODO: output becomes 1D
+ speech_tokens = drop_invalid_tokens(speech_tokens)
+ speech_tokens = speech_tokens.to(self.device)
+
+ wav, _ = self.s3gen.inference(
+ speech_tokens=speech_tokens,
+ ref_dict=self.conds.gen,
+ )
+ wav = wav.squeeze(0).detach().cpu().numpy()
+ watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
+ return torch.from_numpy(watermarked_wav).unsqueeze(0)
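
End-to-end usage sketch for the class above. `torchaudio` is used only for saving and is an assumption; the reference wav path is a placeholder:

```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

model = ChatterboxTTS.from_pretrained(device="cuda")  # downloads ve.pt, t3_cfg.pt, s3gen.pt, tokenizer.json, conds.pt
wav = model.generate(
    "This is a quick test of the synthesizer.",
    audio_prompt_path="reference.wav",                # optional; omit to use the built-in voice from conds.pt
    exaggeration=0.7,
    cfg_weight=0.5,
)
ta.save("output.wav", wav, model.sr)                  # wav is a (1, T) float tensor at model.sr
```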
diff --git a/src/chatterbox/vc.py b/src/chatterbox/vc.py
new file mode 100644
index 0000000000000000000000000000000000000000..629b686b6ed1e7e55654d45f54ac5f9b83953de5
--- /dev/null
+++ b/src/chatterbox/vc.py
@@ -0,0 +1,88 @@
+from pathlib import Path
+
+import librosa
+import torch
+import perth
+from huggingface_hub import hf_hub_download
+
+from .models.s3tokenizer import S3_SR
+from .models.s3gen import S3GEN_SR, S3Gen
+
+
+REPO_ID = "ResembleAI/chatterbox"
+
+
+class ChatterboxVC:
+ ENC_COND_LEN = 6 * S3_SR
+ DEC_COND_LEN = 10 * S3GEN_SR
+
+ def __init__(
+ self,
+ s3gen: S3Gen,
+ device: str,
+ ref_dict: dict=None,
+ ):
+ self.sr = S3GEN_SR
+ self.s3gen = s3gen
+ self.device = device
+ self.watermarker = perth.PerthImplicitWatermarker()
+ if ref_dict is None:
+ self.ref_dict = None
+ else:
+ self.ref_dict = {
+ k: v.to(device) if torch.is_tensor(v) else v
+ for k, v in ref_dict.items()
+ }
+
+ @classmethod
+ def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC':
+ ckpt_dir = Path(ckpt_dir)
+ ref_dict = None
+ if (builtin_voice := ckpt_dir / "conds.pt").exists():
+ states = torch.load(builtin_voice)
+ ref_dict = states['gen']
+
+ s3gen = S3Gen()
+ s3gen.load_state_dict(
+ torch.load(ckpt_dir / "s3gen.pt")
+ )
+ s3gen.to(device).eval()
+
+ return cls(s3gen, device, ref_dict=ref_dict)
+
+ @classmethod
+ def from_pretrained(cls, device) -> 'ChatterboxVC':
+ for fpath in ["s3gen.pt", "conds.pt"]:
+ local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
+
+ return cls.from_local(Path(local_path).parent, device)
+
+ def set_target_voice(self, wav_fpath):
+ ## Load reference wav
+ s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
+
+ s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
+ self.ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
+
+ def generate(
+ self,
+ audio,
+ target_voice_path=None,
+ ):
+ if target_voice_path:
+ self.set_target_voice(target_voice_path)
+ else:
+            assert self.ref_dict is not None, "Please call `set_target_voice` first or specify `target_voice_path`"
+
+ with torch.inference_mode():
+ audio_16, _ = librosa.load(audio, sr=S3_SR)
+ audio_16 = torch.from_numpy(audio_16).float().to(self.device)[None, ]
+
+ s3_tokens, _ = self.s3gen.tokenizer(audio_16)
+ wav, _ = self.s3gen.inference(
+ speech_tokens=s3_tokens,
+ ref_dict=self.ref_dict,
+ )
+ wav = wav.squeeze(0).detach().cpu().numpy()
+ watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
+ return torch.from_numpy(watermarked_wav).unsqueeze(0)
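
And the analogous sketch for voice conversion; file names are placeholders and `torchaudio` is again only used for saving:

```python
import torchaudio as ta
from chatterbox.vc import ChatterboxVC

model = ChatterboxVC.from_pretrained(device="cuda")   # downloads s3gen.pt and conds.pt
wav = model.generate(
    "source_speech.wav",                              # audio whose content is kept
    target_voice_path="target_speaker.wav",           # optional; omit to use the built-in voice from conds.pt
)
ta.save("converted.wav", wav, model.sr)
```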
diff --git a/voice_conversion.py b/voice_conversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ef8ecbbfee24ff2e72efd83cc6d0810b672cd84
--- /dev/null
+++ b/voice_conversion.py
@@ -0,0 +1,76 @@
+from tqdm import tqdm
+import sys
+import torch
+import shutil
+import perth
+from pathlib import Path
+import argparse
+import os
+import librosa
+import soundfile as sf
+from chatterbox.models.s3tokenizer import S3_SR
+from chatterbox.models.s3gen import S3GEN_SR, S3Gen
+
+AUDIO_EXTENSIONS = ["wav", "mp3", "flac", "opus"]
+
+
+@torch.inference_mode()
+def main():
+ parser = argparse.ArgumentParser(description="Voice Conversion")
+ parser.add_argument("input", type=str, help="Path to input (a sample or folder of samples).")
+ parser.add_argument("target_speaker", type=str, help="Path to the sample for the target speaker.")
+ parser.add_argument("-o", "--output_folder", type=str, default="vc_outputs")
+ parser.add_argument("-g", "--gpu_id", type=int, default=None)
+ parser.add_argument("--no-watermark", action="store_true", help="Skip watermarking")
+ args = parser.parse_args()
+
+ # Folders
+ input = Path(args.input)
+ output_folder = Path(args.output_folder)
+ output_orig_folder = output_folder / "input"
+ output_vc_folder = output_folder / "output"
+ ref_folder = output_vc_folder / "target"
+ output_orig_folder.mkdir(exist_ok=True, parents=True)
+ output_vc_folder.mkdir(exist_ok=True)
+ ref_folder.mkdir(exist_ok=True)
+
+ device = torch.device("cpu" if args.gpu_id is None else f"cuda:{args.gpu_id}")
+
+ ## s3gen
+ s3g_fp = "checkpoints/s3gen.pt"
+ s3gen = S3Gen()
+ s3gen.load_state_dict(torch.load(s3g_fp))
+ s3gen.to(device)
+ s3gen.eval()
+
+ wav_fpaths = []
+ if input.is_dir():
+ for ext in AUDIO_EXTENSIONS:
+ wav_fpaths += list(input.glob(f"*.{ext}"))
+ else:
+ wav_fpaths.append(input)
+
+ assert wav_fpaths, f"Didn't find any audio in '{input}'"
+
+ ref_24, _ = librosa.load(args.target_speaker, sr=S3GEN_SR, duration=10)
+ ref_24 = torch.tensor(ref_24).float()
+ shutil.copy(args.target_speaker, ref_folder / Path(args.target_speaker).name)
+ if not args.no_watermark:
+ watermarker = perth.PerthImplicitWatermarker()
+ for wav_fpath in tqdm(wav_fpaths):
+ shutil.copy(wav_fpath, output_orig_folder / wav_fpath.name)
+
+ audio_16, _ = librosa.load(str(wav_fpath), sr=S3_SR)
+ audio_16 = torch.tensor(audio_16).float().to(device)[None, ]
+ s3_tokens, _ = s3gen.tokenizer(audio_16)
+
+ wav = s3gen(s3_tokens.to(device), ref_24, S3GEN_SR)
+ wav = wav.view(-1).cpu().numpy()
+ if not args.no_watermark:
+ wav = watermarker.apply_watermark(wav, sample_rate=S3GEN_SR)
+ save_path = output_vc_folder / wav_fpath.name
+ sf.write(str(save_path), wav, samplerate=S3GEN_SR)
+
+
+if __name__ == "__main__":
+ main()