diff --git a/app.py b/app.py
index 8731f6999b2098ef841c1495adfb0d53832d1604..542a98adfb4f594f23e889c91f6d3506ed8e0426 100644
--- a/app.py
+++ b/app.py
@@ -1,193 +1,129 @@
# coding=utf-8
-import os
-import librosa
-import base64
import io
-import gradio as gr
-import re
-
import numpy as np
-import torch
import torchaudio
-from modelscope import HubApi
-
-api = HubApi()
-
-key = os.environ["apikey"] if "apikey" in os.environ else ""
-try:
- api.login(key)
-except:
- pass
-
-from funasr import AutoModel
-
-model = "iic/SenseVoiceSmall"
-model = AutoModel(model=model,
- vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
- vad_kwargs={"max_single_segment_time": 30000},
- trust_remote_code=True,
- )
-
-import re
-
-emo_dict = {
- "<|HAPPY|>": "๐",
- "<|SAD|>": "๐",
- "<|ANGRY|>": "๐ก",
- "<|NEUTRAL|>": "",
- "<|FEARFUL|>": "๐ฐ",
- "<|DISGUSTED|>": "๐คข",
- "<|SURPRISED|>": "๐ฎ",
-}
-
-event_dict = {
- "<|BGM|>": "๐ผ",
- "<|Speech|>": "",
- "<|Applause|>": "๐",
- "<|Laughter|>": "๐",
- "<|Cry|>": "๐ญ",
- "<|Sneeze|>": "๐คง",
- "<|Breath|>": "",
- "<|Cough|>": "๐คง",
-}
-
-emoji_dict = {
- "<|nospeech|><|Event_UNK|>": "โ",
- "<|zh|>": "",
- "<|en|>": "",
- "<|yue|>": "",
- "<|ja|>": "",
- "<|ko|>": "",
- "<|nospeech|>": "",
- "<|HAPPY|>": "๐",
- "<|SAD|>": "๐",
- "<|ANGRY|>": "๐ก",
- "<|NEUTRAL|>": "",
- "<|BGM|>": "๐ผ",
- "<|Speech|>": "",
- "<|Applause|>": "๐",
- "<|Laughter|>": "๐",
- "<|FEARFUL|>": "๐ฐ",
- "<|DISGUSTED|>": "๐คข",
- "<|SURPRISED|>": "๐ฎ",
- "<|Cry|>": "๐ญ",
- "<|EMO_UNKNOWN|>": "",
- "<|Sneeze|>": "๐คง",
- "<|Breath|>": "",
- "<|Cough|>": "๐ท",
- "<|Sing|>": "",
- "<|Speech_Noise|>": "",
- "<|withitn|>": "",
- "<|woitn|>": "",
- "<|GBG|>": "",
- "<|Event_UNK|>": "",
-}
-
-lang_dict = {
- "<|zh|>": "<|lang|>",
- "<|en|>": "<|lang|>",
- "<|yue|>": "<|lang|>",
- "<|ja|>": "<|lang|>",
- "<|ko|>": "<|lang|>",
- "<|nospeech|>": "<|lang|>",
-}
-
-
-emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
-event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷",}
-
-notes = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
-
-def format_str(s):
- for sptk in emoji_dict:
- s = s.replace(sptk, emoji_dict[sptk])
- return s
-
-
-def format_str_v2(s):
- sptk_dict = {}
- for sptk in emoji_dict:
- sptk_dict[sptk] = s.count(sptk)
- s = s.replace(sptk, "")
- emo = "<|NEUTRAL|>"
- for e in emo_dict:
- if sptk_dict[e] > sptk_dict[emo]:
- emo = e
- for e in event_dict:
- if sptk_dict[e] > 0:
- s = event_dict[e] + s
- s = s + emo_dict[emo]
-
- for emoji in emo_set.union(event_set):
- s = s.replace(" " + emoji, emoji)
- s = s.replace(emoji + " ", emoji)
- return s.strip()
-
-def format_str_v3(s):
- def get_emo(s):
- return s[-1] if s[-1] in emo_set else None
- def get_event(s):
- return s[0] if s[0] in event_set else None
-
- s = s.replace("<|nospeech|><|Event_UNK|>", "โ")
- for lang in lang_dict:
- s = s.replace(lang, "<|lang|>")
- s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
- new_s = " " + s_list[0]
- cur_ent_event = get_event(new_s)
- for i in range(1, len(s_list)):
- if len(s_list[i]) == 0:
- continue
- if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
- s_list[i] = s_list[i][1:]
- #else:
- cur_ent_event = get_event(s_list[i])
- if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
- new_s = new_s[:-1]
- new_s += s_list[i].strip().lstrip()
- new_s = new_s.replace("The.", " ")
- return new_s.strip()
-
-def model_inference(input_wav, language, fs=16000):
- # task_abbr = {"Speech Recognition": "ASR", "Rich Text Transcription": ("ASR", "AED", "SER")}
- language_abbr = {"auto": "auto", "zh": "zh", "en": "en", "yue": "yue", "ja": "ja", "ko": "ko",
- "nospeech": "nospeech"}
-
- # task = "Speech Recognition" if task is None else task
- language = "auto" if len(language) < 1 else language
- selected_language = language_abbr[language]
- # selected_task = task_abbr.get(task)
-
- # print(f"input_wav: {type(input_wav)}, {input_wav[1].shape}, {input_wav}")
-
- if isinstance(input_wav, tuple):
- fs, input_wav = input_wav
- input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
- if len(input_wav.shape) > 1:
- input_wav = input_wav.mean(-1)
- if fs != 16000:
- print(f"audio_fs: {fs}")
- resampler = torchaudio.transforms.Resample(fs, 16000)
- input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
- input_wav = resampler(input_wav_t[None, :])[0, :].numpy()
-
-
- merge_vad = True
- print(f"language: {language}, merge_vad: {merge_vad}")
- text = model.generate(input=input_wav,
- cache={},
- language=language,
- use_itn=True,
- batch_size_s=300, merge_vad=merge_vad)
-
- print(text)
- text = text[0]["text"]
- text = format_str_v3(text)
- print(text)
-
- return text
+import torch
+import soundfile as sf
+import gradio as gr
+import spaces
+from inspiremusic.cli.inference import InspireMusicUnified, set_env_variables
+import os
+import sys
+import argparse
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ description='Run inference with your model')
+ parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long",
+ help='Model name')
+
+ parser.add_argument('-d', '--model_dir',
+ help='Model folder path')
+
+ parser.add_argument('-t', '--text',
+ default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.",
+ help='Prompt text')
+
+ parser.add_argument('-a', '--audio_prompt', default=None,
+ help='Prompt audio')
+
+ parser.add_argument('-c', '--chorus', default="intro",
+ help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
+
+ parser.add_argument('--fast', type=bool, default=False,
+ help='Enable fast inference mode (without flow matching)')
+
+ parser.add_argument('-g', '--gpu', type=int, default=0,
+ help='GPU ID for this rank, -1 for CPU')
+
+ parser.add_argument('--task', default='text-to-music',
+ choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'],
+ help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
+
+ parser.add_argument('-r', '--result_dir', default="exp/inspiremusic",
+ help='Directory to save generated audio')
+
+ parser.add_argument('-o', '--output_fn', default="output_audio",
+ help='Output file name')
+
+ parser.add_argument('-f', '--format', type=str, default="wav",
+ choices=["wav", "mp3", "m4a", "flac"],
+ help='Format of output audio')
+
+ parser.add_argument('--sample_rate', type=int, default=24000,
+ help='Sampling rate of input audio')
+
+ parser.add_argument('--output_sample_rate', type=int, default=48000,
+ choices=[24000, 48000],
+ help='Sampling rate of generated output audio')
+
+ parser.add_argument('-s', '--time_start', type=float, default=0.0,
+ help='Start time in seconds')
+
+ parser.add_argument('-e', '--time_end', type=float, default=30.0,
+ help='End time in seconds')
+
+ parser.add_argument('--max_audio_prompt_length', type=float, default=5.0,
+ help='Maximum audio prompt length in seconds')
+
+ parser.add_argument('--min_generate_audio_seconds', type=float,
+ default=10.0,
+ help='Minimum generated audio length in seconds')
+
+ parser.add_argument('--max_generate_audio_seconds', type=float,
+ default=30.0,
+ help='Maximum generated audio length in seconds')
+
+ parser.add_argument('--fp16', type=bool, default=True,
+ help='Inference with fp16 model')
+
+ parser.add_argument('--fade_out', type=bool, default=True,
+ help='Apply fade out effect to generated audio')
+
+ parser.add_argument('--fade_out_duration', type=float, default=1.0,
+ help='Fade out duration in seconds')
+
+ parser.add_argument('--trim', type=bool, default=False,
+ help='Trim the silence ending of generated audio')
+
+ args = parser.parse_args()
+
+ if not args.model_dir:
+ args.model_dir = os.path.join("./pretrained_models", args.model_name)
+
+ print(args)
+ return args
+
+def InspireMusic(args):
+ set_env_variables()
+ model = InspireMusicUnified(model_name=args.model_name,
+ model_dir=args.model_dir,
+ min_generate_audio_seconds=args.min_generate_audio_seconds,
+ max_generate_audio_seconds=args.max_generate_audio_seconds,
+ sample_rate=args.sample_rate,
+ output_sample_rate=args.output_sample_rate,
+ load_jit=True,
+ load_onnx=False,
+ fast=args.fast,
+ fp16=args.fp16,
+ gpu=args.gpu,
+ result_dir=args.result_dir)
+
+ model.inference(task=args.task,
+ text=args.text,
+ audio_prompt=args.audio_prompt,
+ chorus=args.chorus,
+ time_start=args.time_start,
+ time_end=args.time_end,
+ output_fn=args.output_fn,
+ max_audio_prompt_length=args.max_audio_prompt_length,
+ fade_out_duration=args.fade_out_duration,
+ output_format=args.format,
+ fade_out_mode=args.fade_out,
+ trim=args.trim)
+ return os.path.join(args.result_dir, f"{args.output_fn}.{args.format}")
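+# Minimal usage sketch for the helpers above (values are illustrative; assumes the
+# default "./pretrained_models/InspireMusic-1.5B-Long" checkpoint is available):
+#   args = get_args()
+#   args.task = "text-to-music"
+#   args.text = "Relaxing instrumental jazz with a touch of Bossa Nova."
+#   wav_path = InspireMusic(args)   # e.g. exp/inspiremusic/output_audio.wav
+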
audio_examples = [
["example/inspiremusic/inspiremusic_01.wav", "text-to-music"],
@@ -218,7 +154,7 @@ description = """
- `The instrumental rock piece features a prominent bass guitar, delivering a pure and energetic sound.`
- `A serene blend of instrumental and light pop, featuring soothing melodies and a gentle, soulful keyboard performance.`
-Recommended select audio duration is below 30 seconds. For audio longer than 30 seconds, local deployment is recommended, github repo.
+The recommended audio prompt duration is 5 seconds, and the generated audio length is below 30 seconds. To generate audio longer than 30 seconds, local deployment via the GitHub repo is recommended.
"""
@@ -232,86 +168,92 @@ html_content = """
Code
Demo
Models
- Modelscope Model:
- Huggingface Model
+ Modelscope Model:
+ Huggingface Model
"""
-# Custom HTML and CSS code for the centered sample table
-centered_table_html = """
-
-
-
-
- Samples |
- InspireMusic |
- Text-to-Music |
-
-
-
- normal mode |
- Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance. |
-
-
-
- fast mode |
- The instrumental piece exudes a playful and whimsical atmosphere, likely featuring lively and rhythmic elements. The music seems to be inspired by nature and animals, creating an engaging and light-hearted experience. |
-
-
-
-"""
-
-
-def launch():
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
- # gr.Markdown(description)
- gr.HTML(html_content)
- with gr.Column():
- with gr.Row():
- with gr.Column():
- text_inputs = gr.Textbox(
- label="Input Text",
- placeholder="Enter the text you want to generate music, e.g., Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.",
- lines=3
- )
- fn_button = gr.Button("Start", variant="primary")
- audio_inputs = gr.Audio(
- label="Upload prompt audio",
- )
- with gr.Column():
- with gr.Accordion("Configuration"):
- # task_inputs = gr.Radio(choices=["Speech Recognition", "Rich Text Transcription"],
- # value="Speech Recognition", label="Task")
- task_inputs = gr.Dropdown(choices=["text-to-music", "music-continuation"],
- value="text-to-music",
- label="Task")
- inference_mode_inputs = gr.Dropdown(choices=["normal", "fast"],
- value="normal",
- label="Inference Mode")
- cfg_input = gr.Slider(3, 10, step=1, label="CFG value")
- audio_length = gr.Textbox(value="30",
- label="Duration in seconds")
-
- gr.Examples(examples=audio_examples,
- inputs=[text_inputs, audio_inputs, task_inputs],
- examples_per_page=5)
-
- audio_output = gr.Audio(label="Audio Output")
-
- fn_button.click(model_inference, inputs=[text_inputs, audio_inputs, task_inputs], outputs=audio_output)
-
- # with gr.Accordion("More examples"):
- # gr.HTML(centered_table_html)
- demo.launch()
-
-
-if __name__ == "__main__":
- # iface.launch()
- launch()
+def music_generation(task, text=None, audio=None):
+ args = get_args()
+ args.task = task
+    if text:
+        args.text = text
+    if audio:
+        args.audio_prompt = audio
+ generate_audio_path = InspireMusic(args)
+ return generate_audio_path
+
+demo = gr.Blocks()
+
+t2m_demo = gr.Interface(
+ fn=music_generation,
+ inputs = [
+ gr.Dropdown(["Text-To-Music"], value="text-to-music", multiselect=False, info="Choose a task."),
+ gr.Text(label="Input Text"),
+ ],
+ outputs = [
+ gr.Audio(label="Generated Music", type="generated audio filepath"),
+ ],
+ title = "InspireMusic: A Unified Framework for Music, Song, Audio Generation.",
+ description = ("InspireMusic ([Github Repo](https://github.com/FunAudioLLM/InspireMusic)) is a fundamental AIGC toolkit and models designed for music, song, and audio generation using PyTorch."
+ "To try it, simply type text to generation music, or click one of the examples. "),
+ article = ("InspireMusic
"
+ "WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling
"),
+    examples = [
+        ["text-to-music", "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance."],
+        ["text-to-music", "The instrumental rock piece features a prominent bass guitar, delivering a pure and energetic sound."],
+    ],
+ cache_examples = True,
+)
+
+con_demo = gr.Interface(
+ fn=music_generation,
+ inputs = [
+ gr.Dropdown(["Music Continuation"], value="continuation", multiselect=False, info="Choose a task."),
+ gr.Text(label="Input Text"),
+ gr.Audio(label="Input Audio Prompt", type="audio prompt filepath"),
+ ],
+ outputs = [
+ gr.Audio(label="Generated Music", type="generated audio filepath"),
+ ],
+ title = "InspireMusic: A Unified Framework for Music, Song, Audio Generation.",
+ description = ("InspireMusic ([Github Repo](https://github.com/FunAudioLLM/InspireMusic)) is a fundamental AIGC toolkit and models designed for music, song, and audio generation using PyTorch."
+ "To try it, simply type text to generation music, or click one of the examples. "),
+ article = ("InspireMusic
"
+ "WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling
"),
+    examples = [
+        ["continuation", "Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.", "example/inspiremusic/inspiremusic_01.wav"],
+        ["continuation", "A serene blend of instrumental and light pop, featuring soothing melodies and a gentle, soulful keyboard performance.", "example/ras/chorus/chorus_01.wav"],
+    ],
+ cache_examples = True,
+)
+
+
+with demo:
+ gr.TabbedInterface([t2m_demo, con_demo,],
+ ["Task 1: Text-to-Music",
+ "Task 2: Music Continuation"])
+ # gr.TabbedInterface([t2m_demo, con_demo, fast_demo], ["Task 1: Text-to-Music", "Task 2: Music Continuation", "Task 3: Without Flow Matching"])
+
+demo.launch()
diff --git a/inspiremusic/__init__.py b/inspiremusic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inspiremusic/bin/export_jit.py b/inspiremusic/bin/export_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..f68203f61246f1167f68682b5a1ff8fc5929f521
--- /dev/null
+++ b/inspiremusic/bin/export_jit.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import torch
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from inspiremusic.cli.inspiremusic import InspireMusic
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='export your model for deployment')
+ parser.add_argument('--model_dir',
+ type=str,
+ default='pretrained_models/InspireMusic',
+ help='local path')
+ args = parser.parse_args()
+ print(args)
+ return args
+
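+# Usage sketch (the default --model_dir above is an assumption about your checkout layout):
+#   python inspiremusic/bin/export_jit.py --model_dir pretrained_models/InspireMusic
+# main() below writes llm.text_encoder.fp16.zip, llm.llm.fp16.zip and
+# flow.encoder.fp32.zip into model_dir, intended to be picked up later when load_jit=True.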
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+
+ inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False)
+
+ # 1. export llm text_encoder
+ llm_text_encoder = inspiremusic.model.llm.text_encoder.half()
+ script = torch.jit.script(llm_text_encoder)
+ script = torch.jit.freeze(script)
+ script = torch.jit.optimize_for_inference(script)
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+
+ # 2. export llm llm
+ llm_llm = inspiremusic.model.llm.llm.half()
+ script = torch.jit.script(llm_llm)
+ script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
+ script = torch.jit.optimize_for_inference(script)
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+
+ # 3. export flow encoder
+ flow_encoder = inspiremusic.model.flow.encoder
+ script = torch.jit.script(flow_encoder)
+ script = torch.jit.freeze(script)
+ script = torch.jit.optimize_for_inference(script)
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/inspiremusic/bin/export_onnx.py b/inspiremusic/bin/export_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..659ee13e4fc495757d11dfc558bf2b1629f35089
--- /dev/null
+++ b/inspiremusic/bin/export_onnx.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import onnxruntime
+import random
+import torch
+from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from inspiremusic.cli.inspiremusic import InspireMusic
+
+
+def get_dummy_input(batch_size, seq_len, out_channels, device):
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ return x, mask, mu, t, spks, cond
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='export your model for deployment')
+ parser.add_argument('--model_dir',
+ type=str,
+ default='pretrained_models/InspireMusic',
+ help='local path')
+ args = parser.parse_args()
+ print(args)
+ return args
+
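+# Usage sketch (adjust --model_dir to your local checkpoint path):
+#   python inspiremusic/bin/export_onnx.py --model_dir pretrained_models/InspireMusic
+# main() exports flow.decoder.estimator.fp32.onnx and then cross-checks ONNX Runtime
+# outputs against PyTorch on random dummy batches (rtol=1e-2, atol=1e-4).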
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False)
+
+ # 1. export flow decoder estimator
+ estimator = inspiremusic.model.flow.decoder.estimator
+
+ device = inspiremusic.model.device
+ batch_size, seq_len = 1, 256
+ out_channels = inspiremusic.model.flow.decoder.estimator.out_channels
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+ torch.onnx.export(
+ estimator,
+ (x, mask, mu, t, spks, cond),
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ export_params=True,
+ opset_version=18,
+ do_constant_folding=True,
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+ output_names=['estimator_out'],
+ dynamic_axes={
+ 'x': {0: 'batch_size', 2: 'seq_len'},
+ 'mask': {0: 'batch_size', 2: 'seq_len'},
+ 'mu': {0: 'batch_size', 2: 'seq_len'},
+ 'cond': {0: 'batch_size', 2: 'seq_len'},
+ 't': {0: 'batch_size'},
+ 'spks': {0: 'batch_size'},
+ 'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+ }
+ )
+
+ # 2. test computation consistency
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ sess_options=option, providers=providers)
+
+ for _ in tqdm(range(10)):
+ x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
+ ort_inputs = {
+ 'x': x.cpu().numpy(),
+ 'mask': mask.cpu().numpy(),
+ 'mu': mu.cpu().numpy(),
+ 't': t.cpu().numpy(),
+ 'spks': spks.cpu().numpy(),
+ 'cond': cond.cpu().numpy()
+ }
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/inspiremusic/bin/flow_only_infer.py b/inspiremusic/bin/flow_only_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d84e9c3671546071b02e1e78d071f8e0ded78b94
--- /dev/null
+++ b/inspiremusic/bin/flow_only_infer.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from inspiremusic.cli.model import InspireMusicModel
+from inspiremusic.dataset.dataset import Dataset
+from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS
+
+def get_args():
+ parser = argparse.ArgumentParser(description='inference only with flow model')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
+ parser.add_argument('--flow_model', required=True, help='flow model file')
+ parser.add_argument('--llm_model', default=None,required=False, help='llm model file')
+
+ parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file')
+ parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
+ parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. random, verse, chorus, intro.')
+ parser.add_argument('--sample_rate', type=int, default=48000, required=False,
+ help='sampling rate of generated audio')
+ parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False,
+ help='the minimum generated audio length in seconds')
+ parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
+ help='the maximum generated audio length in seconds')
+ parser.add_argument('--gpu',
+ type=int,
+ default=-1,
+ help='gpu id for this rank, -1 for cpu')
+    parser.add_argument('--result_dir', required=True, help='directory to save generated audio')
+ args = parser.parse_args()
+ print(args)
+ return args
+
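+# Invocation sketch (all paths below are illustrative, not shipped defaults):
+#   python inspiremusic/bin/flow_only_infer.py --config exp/config.yaml \
+#       --prompt_data data/test/data.list --flow_model exp/flow.pt \
+#       --music_tokenizer pretrained_models/InspireMusic/music_tokenizer \
+#       --wavtokenizer pretrained_models/InspireMusic/wavtokenizer \
+#       --result_dir exp/flow_only --gpu 0
+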
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+ # Init inspiremusic models from configs
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+ device = torch.device('cuda' if use_cuda else 'cpu')
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f)
+
+ model = InspireMusicModel(None, configs['flow'], configs['hift'], configs['wavtokenizer'])
+ model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
+
+ if args.llm_model is None:
+ model.llm = None
+ else:
+ model.llm = model.llm.to(torch.float32)
+
+ if args.flow_model is None:
+ model.flow = None
+
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False)
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+ del configs
+ os.makedirs(args.result_dir, exist_ok=True)
+ fn = os.path.join(args.result_dir, 'wav.scp')
+ f = open(fn, 'w')
+ with torch.no_grad():
+ for _, batch in tqdm(enumerate(test_data_loader)):
+ utts = batch["utts"]
+ assert len(utts) == 1, "inference mode only support batchsize 1"
+
+ if "semantic_token" in batch:
+ token = batch["semantic_token"].to(device)
+ token_len = batch["semantic_token_len"].to(device)
+ else:
+ if audio_token is None:
+ token = None
+ token_len = None
+ else:
+ token = audio_token.view(audio_token.size(0),-1,4)[:,:,0]
+ token_len = audio_token_len / 4
+
+ text_token = batch["text_token"].to(device)
+ text_token_len = batch["text_token_len"].to(device)
+ text = batch["text"]
+
+ if "time_start" not in batch.keys():
+ batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64)
+ if "time_end" not in batch.keys():
+ batch["time_end"] = torch.randint(args.min_generate_audio_seconds, args.max_generate_audio_seconds, (1,)).to(torch.float64)
+ elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds:
+ batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
+
+ if "chorus" not in batch.keys():
+ batch["chorus"] = torch.randint(1, 5, (1,))
+
+ if args.chorus == "random":
+ batch["chorus"] = torch.randint(1, 5, (1,))
+ elif args.chorus == "intro":
+ batch["chorus"] = torch.Tensor([0])
+ elif "verse" in args.chorus:
+ batch["chorus"] = torch.Tensor([1])
+ elif args.chorus == "chorus":
+ batch["chorus"] = torch.Tensor([2])
+ elif args.chorus == "outro":
+ batch["chorus"] = torch.Tensor([4])
+
+ time_start = batch["time_start"].to(device)
+ time_end = batch["time_end"].to(device)
+ chorus = batch["chorus"].to(torch.int)
+
+ text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
+ chorus = chorus.to(device)
+
+ model_input = {"text": text, "audio_token": token, "audio_token_len": token_len,
+ "text_token": text_token, "text_token_len": text_token_len,
+ "embeddings": [time_start, time_end, chorus], "raw_text":text}
+
+ music_audios = []
+ for model_output in model.inference(**model_input):
+ music_audios.append(model_output['music_audio'])
+
+ music_key = utts[0]
+ music_fn = os.path.join(args.result_dir, '{}.wav'.format(music_key))
+ torchaudio.save(music_fn, music_audios[0], sample_rate=args.sample_rate)
+ f.write('{} {}\n'.format(music_key, music_fn))
+ f.flush()
+ f.close()
+ logging.info('Result wav.scp saved in {}'.format(fn))
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/inspiremusic/bin/inference.py b/inspiremusic/bin/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7429fd34487c7b749cdca6afecd4b3a5dcc4bad
--- /dev/null
+++ b/inspiremusic/bin/inference.py
@@ -0,0 +1,266 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from inspiremusic.cli.model import InspireMusicModel
+from inspiremusic.dataset.dataset import Dataset
+import time
+from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
+from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+def get_args():
+ parser = argparse.ArgumentParser(description='inference only with your model')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
+ parser.add_argument('--flow_model', default=None, required=False, help='flow model file')
+    parser.add_argument('--llm_model', default=None, required=False, help='llm model file')
+ parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file')
+ parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
+ parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. random, verse, chorus, intro.')
+ parser.add_argument('--fast', action='store_true', required=False, help='True: fast inference mode, without flow matching for fast inference. False: normal inference mode, with flow matching for high quality.')
+ parser.add_argument('--fp16', default=True, type=bool, required=False, help='inference with fp16 model')
+ parser.add_argument('--fade_out', default=True, type=bool, required=False, help='add fade out effect to generated audio')
+ parser.add_argument('--fade_out_duration', default=1.0, type=float, required=False, help='fade out duration in seconds')
+ parser.add_argument('--trim', default=False, type=bool, required=False, help='trim the silence ending of generated audio')
+    parser.add_argument('--format', type=str, default="wav", required=False,
+                        choices=["wav", "mp3", "m4a", "flac"],
+                        help='format of output audio')
+ parser.add_argument('--sample_rate', type=int, default=24000, required=False,
+ help='sampling rate of input audio')
+ parser.add_argument('--output_sample_rate', type=int, default=48000, required=False, choices=[24000, 48000],
+ help='sampling rate of generated output audio')
+ parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False,
+ help='the minimum generated audio length in seconds')
+ parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
+ help='the maximum generated audio length in seconds')
+ parser.add_argument('--gpu',
+ type=int,
+ default=0,
+ help='gpu id for this rank, -1 for cpu')
+ parser.add_argument('--task',
+ default='text-to-music',
+ choices=['text-to-music', 'continuation', "reconstruct", "super_resolution"],
+ help='choose inference task type. text-to-music: text-to-music task. continuation: music continuation task. reconstruct: reconstruction of original music. super_resolution: convert original 24kHz music into 48kHz music.')
+    parser.add_argument('--result_dir', required=True, help='directory to save generated audio')
+ args = parser.parse_args()
+ print(args)
+ return args
+
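+# Invocation sketch (paths are illustrative; see get_args above for the required flags):
+#   python inspiremusic/bin/inference.py --task text-to-music --config exp/config.yaml \
+#       --prompt_data data/test/data.list --llm_model exp/llm.pt --flow_model exp/flow.pt \
+#       --music_tokenizer pretrained_models/InspireMusic/music_tokenizer \
+#       --wavtokenizer pretrained_models/InspireMusic/wavtokenizer \
+#       --result_dir exp/inspiremusic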
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+ if args.fast:
+ args.output_sample_rate = 24000
+
+ min_generate_audio_length = int(args.output_sample_rate * args.min_generate_audio_seconds)
+ max_generate_audio_length = int(args.output_sample_rate * args.max_generate_audio_seconds)
+ assert args.min_generate_audio_seconds <= args.max_generate_audio_seconds
+
+ # Init inspiremusic models from configs
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+ device = torch.device('cuda' if use_cuda else 'cpu')
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f)
+
+ model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.fast, args.fp16)
+
+ model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
+
+ if args.llm_model is None:
+ model.llm = None
+ else:
+ model.llm = model.llm.to(torch.float32)
+
+ if args.flow_model is None:
+ model.flow = None
+
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False)
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+ del configs
+ os.makedirs(args.result_dir, exist_ok=True)
+ fn = os.path.join(args.result_dir, 'wav.scp')
+ f = open(fn, 'w')
+ caption_fn = os.path.join(args.result_dir, 'captions.txt')
+ caption_f = open(caption_fn, 'w')
+
+ with torch.no_grad():
+ for _, batch in tqdm(enumerate(test_data_loader)):
+ utts = batch["utts"]
+
+ assert len(utts) == 1, "inference mode only support batchsize 1"
+ text_token = batch["text_token"].to(device)
+ text_token_len = batch["text_token_len"].to(device)
+
+ if "time_start" not in batch.keys():
+ batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64)
+
+ if batch["time_start"].numpy()[0] > 300:
+ batch["time_start"] = torch.Tensor([0]).to(torch.float64)
+
+ if "time_end" not in batch.keys():
+ batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
+ else:
+ if (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds:
+ batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
+ elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) > args.max_generate_audio_seconds:
+ batch["time_end"] = torch.Tensor([(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds)]).to(torch.float64)
+
+ if "chorus" not in batch.keys():
+ batch["chorus"] = torch.randint(1, 5, (1,))
+
+ if args.chorus == "random":
+ batch["chorus"] = torch.randint(1, 5, (1,))
+ elif args.chorus == "intro":
+ batch["chorus"] = torch.Tensor([0])
+ elif "verse" in args.chorus:
+ batch["chorus"] = torch.Tensor([1])
+ elif args.chorus == "chorus":
+ batch["chorus"] = torch.Tensor([2])
+ elif args.chorus == "outro":
+ batch["chorus"] = torch.Tensor([4])
+ else:
+ batch["chorus"] = batch["chorus"]
+
+ time_start = batch["time_start"].to(device)
+ time_end = batch["time_end"].to(device)
+ chorus = batch["chorus"].to(torch.int)
+
+ text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
+ chorus = chorus.to(device)
+
+ if batch["acoustic_token"] is None:
+ audio_token = None
+ audio_token_len = None
+ else:
+ audio_token = batch["acoustic_token"].to(device)
+ audio_token_len = batch["acoustic_token_len"].to(device)
+
+ text = batch["text"]
+
+ if "semantic_token" in batch:
+ token = batch["semantic_token"].to(device)
+ token_len = batch["semantic_token_len"].to(device)
+ else:
+ if audio_token is None:
+ token = None
+ token_len = None
+ else:
+ token = audio_token.view(audio_token.size(0), -1, 4)[:, :, 0]
+ token_len = audio_token_len / 4
+
+ if args.task in ['text-to-music', 'continuation']:
+ # text to music, music continuation
+ model_input = {"text": text, "audio_token": token,
+ "audio_token_len": token_len,
+ "text_token": text_token,
+ "text_token_len": text_token_len,
+ "embeddings": [time_start, time_end, chorus],
+ "raw_text": text,
+ "sample_rate": args.output_sample_rate,
+ "duration_to_gen": args.max_generate_audio_seconds,
+ "task": args.task}
+ elif args.task in ['reconstruct', 'super_resolution']:
+ # audio reconstruction, audio super resolution
+ model_input = {"text": text, "audio_token": audio_token,
+ "audio_token_len": audio_token_len,
+ "text_token": text_token,
+ "text_token_len": text_token_len,
+ "embeddings": [time_start, time_end, chorus],
+ "raw_text": text,
+ "sample_rate": args.output_sample_rate,
+ "duration_to_gen": args.max_generate_audio_seconds,
+ "task": args.task}
+ else:
+ # zero-shot
+ model_input = {'text' : text,
+ 'text_len' : text_token_len,
+ 'prompt_text' : text_token,
+ 'prompt_text_len' : text_token_len,
+ 'llm_prompt_audio_token' : token,
+ 'llm_prompt_audio_token_len' : token_len,
+ 'flow_prompt_audio_token' : audio_token,
+ 'flow_prompt_audio_token_len': audio_token_len,
+ 'prompt_audio_feat' : audio_feat,
+ 'prompt_audio_feat_len' : audio_feat_len,
+ "embeddings" : [time_start,
+ time_end,
+ chorus]}
+
+ music_key = utts[0]
+ music_audios = []
+ music_fn = os.path.join(args.result_dir, f'{music_key}.{args.format}')
+ bench_start = time.time()
+
+ for model_output in model.inference(**model_input):
+ music_audios.append(model_output['music_audio'])
+ bench_end = time.time()
+ if args.trim:
+ music_audio = trim_audio(music_audios[0],
+ sample_rate=args.output_sample_rate,
+ threshold=0.05,
+ min_silence_duration=0.8)
+ else:
+ music_audio = music_audios[0]
+ if music_audio.shape[0] != 0:
+ if music_audio.shape[1] > max_generate_audio_length:
+ music_audio = music_audio[:, :max_generate_audio_length]
+ if music_audio.shape[1] >= min_generate_audio_length:
+ try:
+ if args.fade_out:
+ music_audio = fade_out(music_audio, args.output_sample_rate, args.fade_out_duration)
+ music_audio = music_audio.repeat(2, 1)
+ if args.format in ["wav", "flac"]:
+ torchaudio.save(music_fn, music_audio, sample_rate=args.output_sample_rate, encoding="PCM_S", bits_per_sample=24)
+ elif args.format in ["mp3", "m4a"]:
+ torchaudio.backend.sox_io_backend.save(filepath=music_fn, src=music_audio, sample_rate=args.output_sample_rate, format=args.format)
+ else:
+ logging.info(f"Format is not supported. Please choose from wav, mp3, m4a, flac.")
+ except Exception as e:
+ logging.info(f"Error saving file: {e}")
+ raise
+
+ audio_duration = music_audio.shape[1] / args.output_sample_rate
+ rtf = (bench_end - bench_start) / audio_duration
+ logging.info(f"processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}")
+ f.write('{} {}\n'.format(music_key, music_fn))
+ f.flush()
+ caption_f.write('{}\t{}\n'.format(music_key, text_prompt))
+ caption_f.flush()
+            else:
+                logging.info(f"Generated audio length {music_audio.shape[1]} is shorter than min_generate_audio_length.")
+        else:
+            logging.info(f"Generated audio is empty, dim = {music_audio.shape[0]}.")
+ f.close()
+ logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/inspiremusic/bin/train.py b/inspiremusic/bin/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a9bae52670a4d7d9b48d944f3e96ac086ca762
--- /dev/null
+++ b/inspiremusic/bin/train.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import torch
+import torch.distributed as dist
+import deepspeed
+import glob
+import os
+from hyperpyyaml import load_hyperpyyaml
+from torch.cuda.amp import GradScaler, autocast
+from torch.distributed.elastic.multiprocessing.errors import record
+from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
+from inspiremusic.utils.executor import Executor
+from inspiremusic.utils.train_utils import (
+ init_distributed,
+ init_dataset_and_dataloader,
+ init_optimizer_and_scheduler,
+ init_summarywriter, save_model,
+ wrap_cuda_model, check_modify_and_save_config)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='training your network')
+ parser.add_argument('--train_engine',
+ default='torch_ddp',
+ choices=['torch_ddp', 'deepspeed'],
+ help='Engine for paralleled training')
+ parser.add_argument('--model', required=True, help='model which will be trained')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--train_data', required=True, help='train data file')
+ parser.add_argument('--cv_data', required=True, help='cv data file')
+ parser.add_argument('--checkpoint', help='checkpoint model')
+ parser.add_argument('--model_dir', required=True, help='save model dir')
+ parser.add_argument('--tensorboard_dir',
+ default='tensorboard',
+ help='tensorboard log dir')
+ parser.add_argument('--ddp.dist_backend',
+ dest='dist_backend',
+ default='nccl',
+ choices=['nccl', 'gloo'],
+ help='distributed backend')
+ parser.add_argument('--num_workers',
+ default=0,
+ type=int,
+ help='number of subprocess workers for reading')
+ parser.add_argument('--prefetch',
+ default=100,
+ type=int,
+ help='prefetch number')
+ parser.add_argument('--pin_memory',
+ action='store_true',
+ default=True,
+ help='Use pinned memory buffers used for reading')
+ parser.add_argument('--deepspeed.save_states',
+ dest='save_states',
+ default='model_only',
+ choices=['model_only', 'model+optimizer'],
+ help='save model/optimizer states')
+ parser.add_argument('--timeout',
+ default=30,
+ type=int,
+ help='timeout (in seconds) of inspiremusic_join.')
+ parser.add_argument('--fp16',
+ action='store_true',
+ default=False,
+ help='Enable fp16 mixed precision training')
+ parser.add_argument('--lora',
+ action='store_true',
+ default=False,
+ help='Enable LoRA training')
+ parser.add_argument('--lora_rank',
+ default=4,
+ type=int,
+ help='LoRA rank')
+ parser.add_argument('--lora_alpha',
+ default=16,
+ type=int,
+ help='LoRA alpha')
+ parser.add_argument('--lora_dropout',
+ default=0.1,
+ type=float,
+ help='LoRA dropout rate')
+ parser.add_argument('--lora_target_modules',
+ nargs='+',
+ default=["k_proj","v_proj"],
+ help='Target modules to apply LoRA (e.g., ["q_proj", "v_proj"])')
+
+ parser = deepspeed.add_config_arguments(parser)
+ args = parser.parse_args()
+ return args
+
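+# Launch sketch (single node; config and data paths are illustrative):
+#   torchrun --nproc_per_node=1 inspiremusic/bin/train.py --train_engine torch_ddp \
+#       --model llm --config exp/config.yaml --train_data data/train.list \
+#       --cv_data data/dev.list --model_dir exp/llm --tensorboard_dir exp/llm/tensorboard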
+
+@record
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides=override_dict)
+ configs['train_conf'].update(vars(args))
+
+ # Init env for ddp
+ init_distributed(args)
+
+ # Get dataset & dataloader
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+ init_dataset_and_dataloader(args, configs)
+
+    # Do some sanity checks and save config to args.model_dir
+ configs = check_modify_and_save_config(args, configs)
+
+ # Tensorboard summary
+ writer = init_summarywriter(args)
+
+ # load checkpoint
+ model = configs[args.model]
+
+ if args.checkpoint is not None:
+ model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+ else:
+ # Find and load the latest checkpoint
+ checkpoint_files = glob.glob(os.path.join(args.model_dir, '*.pt'))
+
+ if checkpoint_files:
+ latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
+ logging.info(f"Loaded latest checkpoint from {latest_checkpoint}")
+
+ model.load_state_dict(torch.load(latest_checkpoint, map_location='cpu'))
+
+ if args.lora:
+ logging.info("Applying LoRA to the model...")
+ if not args.lora_target_modules:
+ raise ValueError("No target modules specified for LoRA. Please provide --lora_target_modules.")
+ lora_config = LoraConfig(
+ task_type="CAUSAL_LM", # Change to appropriate task type
+ inference_mode=False,
+ r=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ target_modules=args.lora_target_modules
+ )
+ model.llm.model = get_peft_model(model.llm.model, lora_config)
+ # Optionally freeze the base model
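+        # (With PEFT, get_peft_model leaves only the LoRA adapter weights trainable,
+        # so the base LLM parameters are effectively frozen by default.)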
+ else:
+ logging.info("LoRA is not enabled. Training the full model.")
+
+ # Dispatch model from cpu to gpu
+ model = wrap_cuda_model(args, model)
+
+ # Get optimizer & scheduler
+ model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+
+ # Initialize AMP for torch_ddp if fp16 is enabled
+ scaler = None
+ if args.fp16:
+ scaler = GradScaler()
+ logging.info("Initialized AMP GradScaler for mixed precision training.")
+
+ # Save init checkpoints
+ info_dict = deepcopy(configs['train_conf'])
+
+ # Get executor
+ executor = Executor()
+
+ # Start training loop
+ for epoch in range(info_dict['max_epoch']):
+ executor.epoch = epoch
+ train_dataset.set_epoch(epoch)
+ dist.barrier()
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+ executor.train_one_epoch(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=scaler)
+ dist.destroy_process_group(group_join)
+
+if __name__ == '__main__':
+ main()
diff --git a/inspiremusic/cli/__init__.py b/inspiremusic/cli/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inspiremusic/cli/frontend.py b/inspiremusic/cli/frontend.py
new file mode 100644
index 0000000000000000000000000000000000000000..d717c3a5c77cbf3c9cc2513f75b3045ad26d9149
--- /dev/null
+++ b/inspiremusic/cli/frontend.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import onnxruntime
+import torch
+import numpy as np
+import whisper
+from typing import Callable
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import os
+import re
+import inflect
+from inspiremusic.cli.model import InspireMusicModel
+from inspiremusic.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
+
+class InspireMusicFrontEnd:
+ def __init__(self,
+ configs: Callable,
+ get_tokenizer: Callable,
+ llm_model: str,
+ flow_model: str,
+ music_tokenizer_dir: str,
+ audio_tokenizer_dir: str,
+ instruct: bool = False,
+ fast: bool = False,
+ fp16: bool = True,
+ allowed_special: str = 'all'):
+ self.tokenizer = get_tokenizer()
+ self.audio_tokenizer_dir = audio_tokenizer_dir
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ self.bandwidth_id = torch.tensor([0]).to(self.device)
+ self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device)
+
+ self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16)
+ self.model = self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir)
+
+ self.instruct = instruct
+ self.allowed_special = allowed_special
+ self.inflect_parser = inflect.engine()
+
+ def _extract_text_token(self, text):
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+ return text_token, text_token_len
+
+ def _extract_audio_token(self, audio, sample_rate=24000):
+ audio = torch.tensor(audio, dtype=torch.float32, device=self.device)
+ _, audio_token = self.wavtokenizer.encode_infer(audio, bandwidth_id=self.bandwidth_id)
+ audio_token = audio_token.squeeze(0)
+ audio_token_len = torch.tensor([audio_token.shape[1]], dtype=torch.int32, device=self.device)
+ return audio_token, audio_token_len
+
+ def text_normalize(self, text, split=True):
+ text = text.strip()
+ if contains_chinese(text):
+ text = text.replace("\n", "")
+ text = replace_blank(text)
+ text = replace_corner_mark(text)
+ text = text.replace(".", "ใ")
+ text = text.replace(" - ", "๏ผ")
+ text = remove_bracket(text)
+ text = re.sub(r'[๏ผ,]+$', 'ใ', text)
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+ token_min_n=60, merge_len=20, comma_split=False))
+ else:
+ text = spell_out_number(text, self.inflect_parser)
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+ token_min_n=60, merge_len=20, comma_split=False))
+ if split is False:
+ return text
+ return texts
+
+ def frontend_text_to_music(self, text, time_start, time_end, chorus):
+ text_token, text_token_len = self._extract_text_token(text)
+ model_input = {"text": text, "audio_token": None, "audio_token_len": None,
+ "text_token": text_token, "text_token_len": text_token_len,
+ "embeddings": [time_start, time_end, chorus], "raw_text":text}
+ return model_input
+
+ def frontend_continuation(self, text, audio, time_start, time_end, chorus, target_sr=24000):
+ if text is None:
+ text_token = None
+ text_token_len = None
+ else:
+ text_token, text_token_len = self._extract_text_token(text)
+ audio_token, audio_token_len = self._extract_audio_token(audio, target_sr)
+ model_input = {"text": text, "audio_token": audio_token, "audio_token_len": audio_token_len,
+ "text_token": text_token, "text_token_len": text_token_len,
+ "embeddings": [time_start, time_end, chorus], "raw_text":text}
+ return model_input
+
diff --git a/inspiremusic/cli/inference.py b/inspiremusic/cli/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff0ad0994a5989f4082e94af776fb38590f550a5
--- /dev/null
+++ b/inspiremusic/cli/inference.py
@@ -0,0 +1,296 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import torchaudio
+import time
+import logging
+import argparse
+
+from modelscope import snapshot_download
+from inspiremusic.cli.inspiremusic import InspireMusic
+from inspiremusic.utils.file_utils import logging
+import torch
+from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
+from transformers import AutoModel
+
+def set_env_variables():
+ os.environ['PYTHONIOENCODING'] = 'UTF-8'
+ os.environ['TOKENIZERS_PARALLELISM'] = 'False'
+ current_working_dir = os.getcwd()
+ main_root = os.path.realpath(os.path.join(current_working_dir, '../../'))
+ bin_dir = os.path.join(main_root, 'inspiremusic')
+ third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS')
+ python_path = f"{main_root}:{bin_dir}:{third_party_matcha_tts_path}:{os.environ.get('PYTHONPATH', '')}"
+ os.environ['PYTHONPATH'] = python_path
+ sys.path.extend([main_root, third_party_matcha_tts_path])
+
+class InspireMusicUnified:
+ def __init__(self,
+ model_name: str = "InspireMusic-1.5B-Long",
+ model_dir: str = None,
+ min_generate_audio_seconds: float = 10.0,
+ max_generate_audio_seconds: float = 30.0,
+ sample_rate: int = 24000,
+ output_sample_rate: int = 48000,
+ load_jit: bool = True,
+ load_onnx: bool = False,
+ fast: bool = False,
+ fp16: bool = True,
+ gpu: int = 0,
+ result_dir: str = None):
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
+
+ # Set model_dir or default to downloading if it doesn't exist
+ self.model_dir = model_dir or f"../../pretrained_models/{model_name}"
+ if not os.path.exists(self.model_dir):
+ self.model_dir = snapshot_download(f"iic/{model_name}", cache_dir=self.model_dir)
+
+ self.sample_rate = sample_rate
+ self.output_sample_rate = 24000 if fast else output_sample_rate
+ self.result_dir = result_dir or f"exp/{model_name}"
+ os.makedirs(self.result_dir, exist_ok=True)
+
+ self.min_generate_audio_seconds = min_generate_audio_seconds
+ self.max_generate_audio_seconds = max_generate_audio_seconds
+ self.min_generate_audio_length = int(self.output_sample_rate * self.min_generate_audio_seconds)
+ self.max_generate_audio_length = int(self.output_sample_rate * self.max_generate_audio_seconds)
+ assert self.min_generate_audio_seconds <= self.max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds"
+
+ use_cuda = gpu >= 0 and torch.cuda.is_available()
+ self.device = torch.device('cuda' if use_cuda else 'cpu')
+ self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, fast=fast, fp16=fp16)
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ @torch.inference_mode()
+ def inference(self,
+ task: str = 'text-to-music',
+ text: str = None,
+ audio_prompt: str = None, # audio prompt file path
+ chorus: str = "verse",
+ time_start: float = 0.0,
+ time_end: float = 30.0,
+ output_fn: str = "output_audio",
+ max_audio_prompt_length: float = 5.0,
+ fade_out_duration: float = 1.0,
+ output_format: str = "wav",
+ fade_out_mode: bool = True,
+ trim: bool = False,
+ ):
+
+ with torch.no_grad():
+ text_prompt = f"<|{time_start}|><|{chorus}|><|{text}|><|{time_end}|>"
+ chorus_dict = {"random": torch.randint(1, 5, (1,)).item(), "intro" : 0, "verse": 1, "chorus": 2, "outro": 4}
+ chorus = chorus_dict.get(chorus, 1)
+ chorus = torch.tensor([chorus], dtype=torch.int).to(self.device)
+
+ time_start_tensor = torch.tensor([time_start], dtype=torch.float64).to(self.device)
+ time_end_tensor = torch.tensor([time_end], dtype=torch.float64).to(self.device)
+
+ music_fn = os.path.join(self.result_dir, f'{output_fn}.{output_format}')
+
+ bench_start = time.time()
+
+ if task == 'text-to-music':
+ model_input = {
+ "text" : text,
+ "audio_prompt" : audio_prompt,
+ "time_start" : time_start_tensor,
+ "time_end" : time_end_tensor,
+ "chorus" : chorus,
+ "task" : task,
+ "stream" : False,
+ "duration_to_gen": self.max_generate_audio_seconds,
+ "sr" : self.sample_rate
+ }
+ elif task == 'continuation':
+ if audio_prompt is not None:
+ audio, _ = process_audio(audio_prompt, self.sample_rate)
+ if audio.size(1) < self.sample_rate:
+                    logging.warning("Input prompt audio is shorter than 1 second; please provide a longer audio prompt.")
+ audio = None
+ else:
+ max_audio_prompt_length_samples = int(max_audio_prompt_length * self.sample_rate)
+ audio = audio[:, :max_audio_prompt_length_samples] # Trimming prompt audio
+
+ model_input = {
+ "text" : text,
+ "audio_prompt" : audio,
+ "time_start" : time_start_tensor,
+ "time_end" : time_end_tensor,
+ "chorus" : chorus,
+ "task" : task,
+ "stream" : False,
+ "duration_to_gen": self.max_generate_audio_seconds,
+ "sr" : self.sample_rate
+ }
+
+ music_audios = []
+ for model_output in self.model.cli_inference(**model_input):
+ music_audios.append(model_output['music_audio'])
+
+ bench_end = time.time()
+
+ if trim:
+ music_audio = trim_audio(music_audios[0],
+ sample_rate=self.output_sample_rate,
+ threshold=0.05,
+ min_silence_duration=0.8)
+ else:
+ music_audio = music_audios[0]
+
+ if music_audio.shape[0] != 0:
+ if music_audio.shape[1] > self.max_generate_audio_length:
+ music_audio = music_audio[:, :self.max_generate_audio_length]
+
+ if music_audio.shape[1] >= self.min_generate_audio_length:
+ try:
+ if fade_out_mode:
+ music_audio = fade_out(music_audio, self.output_sample_rate, fade_out_duration)
+
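+                        # duplicate the mono channel so the saved file is stereo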
+ music_audio = music_audio.repeat(2, 1)
+
+ if output_format in ["wav", "flac"]:
+ torchaudio.save(music_fn, music_audio,
+ sample_rate=self.output_sample_rate,
+ encoding="PCM_S",
+ bits_per_sample=24)
+ elif output_format in ["mp3", "m4a"]:
+ torchaudio.backend.sox_io_backend.save(
+ filepath=music_fn, src=music_audio,
+ sample_rate=self.output_sample_rate,
+ format=output_format)
+ else:
+                            logging.error(f"Output format {output_format} is not supported. Please choose from wav, mp3, m4a, flac.")
+
+ except Exception as e:
+ logging.error(f"Error saving file: {e}")
+ raise
+
+ audio_duration = music_audio.shape[1] / self.output_sample_rate
+ rtf = (bench_end - bench_start) / audio_duration
+ logging.info(f"Processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}")
+
+ else:
+                logging.error("Generated audio is empty; no file was saved.")
+
+def get_args():
+ parser = argparse.ArgumentParser(description='Run inference with your model')
+ parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long",
+ help='Model name')
+
+ parser.add_argument('-d', '--model_dir',
+ help='Model folder path')
+
+ parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.",
+ help='Prompt text')
+
+ parser.add_argument('-a', '--audio_prompt', default=None,
+ help='Prompt audio')
+
+ parser.add_argument('-c', '--chorus', default="intro",
+ help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
+
+    # argparse's type=bool treats any non-empty string as True, so booleans are parsed explicitly
+    parser.add_argument('-f', '--fast', type=lambda s: str(s).lower() in ('true', '1', 'yes'),
+                        default=False,
+                        help='Enable fast inference mode (without flow matching)')
+
+ parser.add_argument('-g', '--gpu', type=int, default=0,
+ help='GPU ID for this rank, -1 for CPU')
+
+ parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'],
+ help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
+
+ parser.add_argument('-r', '--result_dir', default="exp/inspiremusic",
+ help='Directory to save generated audio')
+
+ parser.add_argument('-o', '--output_fn', default="output_audio",
+ help='Output file name')
+
+ parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"],
+ help='Format of output audio')
+
+ parser.add_argument('--sample_rate', type=int, default=24000,
+ help='Sampling rate of input audio')
+
+ parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000],
+ help='Sampling rate of generated output audio')
+
+ parser.add_argument('-s', '--time_start', type=float, default=0.0,
+ help='Start time in seconds')
+
+ parser.add_argument('-e', '--time_end', type=float, default=30.0,
+ help='End time in seconds')
+
+ parser.add_argument('--max_audio_prompt_length', type=float, default=5.0,
+ help='Maximum audio prompt length in seconds')
+
+ parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
+ help='Minimum generated audio length in seconds')
+
+ parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0,
+ help='Maximum generated audio length in seconds')
+
+    parser.add_argument('--fp16', type=lambda s: str(s).lower() in ('true', '1', 'yes'),
+                        default=True,
+                        help='Inference with fp16 model')
+
+    parser.add_argument('--fade_out', type=lambda s: str(s).lower() in ('true', '1', 'yes'),
+                        default=True,
+                        help='Apply fade out effect to generated audio')
+
+    parser.add_argument('--fade_out_duration', type=float, default=1.0,
+                        help='Fade out duration in seconds')
+
+    parser.add_argument('--trim', type=lambda s: str(s).lower() in ('true', '1', 'yes'),
+                        default=False,
+                        help='Trim the silence ending of generated audio')
+
+ args = parser.parse_args()
+
+ if not args.model_dir:
+ args.model_dir = os.path.join("../../pretrained_models", args.model_name)
+
+ print(args)
+ return args
+
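+# Example invocation (illustrative script name and arguments):
+#   python inference.py -m InspireMusic-1.5B-Long --task text-to-music \
+#       -t "A calm lo-fi piano track" -r exp/demo -o demo_track --format wav -g 0
+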
+def main():
+ set_env_variables()
+ args = get_args()
+ model = InspireMusicUnified(model_name = args.model_name,
+ model_dir = args.model_dir,
+ min_generate_audio_seconds = args.min_generate_audio_seconds,
+ max_generate_audio_seconds = args.max_generate_audio_seconds,
+ sample_rate = args.sample_rate,
+ output_sample_rate = args.output_sample_rate,
+ load_jit = True,
+ load_onnx = False,
+ fast = args.fast,
+ fp16 = args.fp16,
+ gpu = args.gpu,
+ result_dir = args.result_dir)
+
+ model.inference(task = args.task,
+ text = args.text,
+ audio_prompt = args.audio_prompt,
+ chorus = args.chorus,
+ time_start = args.time_start,
+ time_end = args.time_end,
+ output_fn = args.output_fn,
+ max_audio_prompt_length = args.max_audio_prompt_length,
+ fade_out_duration = args.fade_out_duration,
+ output_format = args.format,
+ fade_out_mode = args.fade_out,
+ trim = args.trim)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/inspiremusic/cli/inspiremusic.py b/inspiremusic/cli/inspiremusic.py
new file mode 100644
index 0000000000000000000000000000000000000000..c70f07fe1a37be37f1d2910436847ee03aad97e9
--- /dev/null
+++ b/inspiremusic/cli/inspiremusic.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+from tqdm import tqdm
+from hyperpyyaml import load_hyperpyyaml
+from modelscope import snapshot_download
+from inspiremusic.cli.frontend import InspireMusicFrontEnd
+from inspiremusic.cli.model import InspireMusicModel
+from inspiremusic.utils.file_utils import logging
+import torch
+
+class InspireMusic:
+ def __init__(self, model_dir, load_jit=True, load_onnx=False, fast = False, fp16=True):
+ instruct = True if '-Instruct' in model_dir else False
+ self.model_dir = model_dir
+ if not os.path.exists(model_dir):
+ model_dir = snapshot_download(model_dir)
+ with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f:
+ configs = load_hyperpyyaml(f)
+
+ self.frontend = InspireMusicFrontEnd(configs,
+ configs['get_tokenizer'],
+ '{}/llm.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
+ '{}/music_tokenizer/'.format(model_dir),
+ '{}/wavtokenizer/'.format(model_dir),
+ instruct,
+ fast,
+ fp16,
+ configs['allowed_special'])
+
+ self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16)
+ self.model.load('{}/llm.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
+ '{}/music_tokenizer/'.format(model_dir),
+ '{}/wavtokenizer/model.pt'.format(model_dir))
+ del configs
+
+ @torch.inference_mode()
+ def inference(self, task, text, audio, time_start, time_end, chorus, stream=False, sr=24000):
+ if task == "text-to-music":
+ for i in tqdm(self.frontend.text_normalize(text, split=True)):
+ model_input = self.frontend.frontend_text_to_music(i, time_start, time_end, chorus)
+ start_time = time.time()
+ logging.info('prompt text {}'.format(i))
+ for model_output in self.model.inference(**model_input, stream=stream):
+ music_audios_len = model_output['music_audio'].shape[1] / sr
+ logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
+ yield model_output
+ start_time = time.time()
+
+ elif task == "continuation":
+ if text is None:
+ if audio is not None:
+ for i in tqdm(audio):
+                        model_input = self.frontend.frontend_continuation(None, i, time_start, time_end, chorus, sr)
+ start_time = time.time()
+ logging.info('prompt text {}'.format(i))
+ for model_output in self.model.continuation_inference(**model_input, stream=stream):
+ music_audios_len = model_output['music_audio'].shape[1] / sr
+ logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
+ yield model_output
+ start_time = time.time()
+ else:
+ if audio is not None:
+ for i in tqdm(self.frontend.text_normalize(text, split=True)):
+                        model_input = self.frontend.frontend_continuation(i, audio, time_start, time_end, chorus, sr)
+ start_time = time.time()
+ logging.info('prompt text {}'.format(i))
+ for model_output in self.model.continuation_inference(**model_input, stream=stream):
+ music_audios_len = model_output['music_audio'].shape[1] / sr
+ logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
+ yield model_output
+ start_time = time.time()
+ else:
+                    logging.warning("Please provide either text or an audio prompt.")
+ else:
+            logging.warning("Only text-to-music and music continuation tasks are currently supported.")
+
+ @torch.inference_mode()
+ def cli_inference(self, text, audio_prompt, time_start, time_end, chorus, task, stream=False, duration_to_gen=30, sr=24000):
+ if task == "text-to-music":
+ model_input = self.frontend.frontend_text_to_music(text, time_start, time_end, chorus)
+ logging.info('prompt text {}'.format(text))
+ elif task == "continuation":
+ model_input = self.frontend.frontend_continuation(text, audio_prompt, time_start, time_end, chorus, sr)
+ logging.info('prompt audio length: {}'.format(len(audio_prompt)))
+
+ start_time = time.time()
+ for model_output in self.model.inference(**model_input, duration_to_gen=duration_to_gen, task=task):
+ music_audios_len = model_output['music_audio'].shape[1] / sr
+ logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len))
+ yield model_output
+ start_time = time.time()
+
+ @torch.inference_mode()
+ def inference_zero_shot(self, text, prompt_text, prompt_audio_16k, stream=False, sr=24000):
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False)
+ for i in tqdm(self.frontend.text_normalize(text, split=True)):
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_audio_16k)
+ start_time = time.time()
+ logging.info('prompt text {}'.format(i))
+ for model_output in self.model.inference(**model_input, stream=stream):
+ audio_len = model_output['music_audio'].shape[1] / sr
+ logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len))
+ yield model_output
+ start_time = time.time()
+ @torch.inference_mode()
+ def inference_instruct(self, text, spk_id, instruct_text, stream=False, sr=24000):
+ if self.frontend.instruct is False:
+ raise ValueError('{} do not support instruct inference'.format(self.model_dir))
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
+ for i in tqdm(self.frontend.text_normalize(text, split=True)):
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+ start_time = time.time()
+ logging.info('prompt text {}'.format(i))
+ for model_output in self.model.inference(**model_input, stream=stream):
+ audio_len = model_output['music_audio'].shape[1] / sr
+ logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len))
+ yield model_output
+ start_time = time.time()
diff --git a/inspiremusic/cli/model.py b/inspiremusic/cli/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c23da6284d607f89c337ff712f9d7d5978b998
--- /dev/null
+++ b/inspiremusic/cli/model.py
@@ -0,0 +1,297 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import numpy as np
+import threading
+import time
+from contextlib import nullcontext
+import uuid
+from inspiremusic.utils.common import fade_in_out
+from inspiremusic.music_tokenizer.vqvae import VQVAE
+from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
+from torch.cuda.amp import autocast
+import logging
+import os
+
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+class InspireMusicModel:
+
+ def __init__(self,
+ llm: torch.nn.Module,
+ flow: torch.nn.Module,
+ music_tokenizer: torch.nn.Module,
+ wavtokenizer: torch.nn.Module,
+ fast: bool = False,
+ fp16: bool = True,
+ ):
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.llm = llm
+ self.flow = flow
+ self.music_tokenizer = music_tokenizer
+ self.wavtokenizer = wavtokenizer
+ self.fp16 = fp16
+ self.token_min_hop_len = 100
+ self.token_max_hop_len = 200
+ self.token_overlap_len = 20
+ # mel fade in out
+ self.mel_overlap_len = 34
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
+ # hift cache
+ self.mel_cache_len = 20
+ self.source_cache_len = int(self.mel_cache_len * 256)
+ # rtf and decoding related
+ self.stream_scale_factor = 1
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
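+        # run the LLM on a dedicated CUDA stream when a GPU is available, otherwise use a no-op context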
+ self.lock = threading.Lock()
+ # dict used to store session related variable
+ self.music_token_dict = {}
+ self.llm_end_dict = {}
+ self.mel_overlap_dict = {}
+ self.fast = fast
+ self.generator = "hifi"
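+        # "hifi": flow matching + music_tokenizer (VQ-VAE) vocoder;
+        # "wavtokenizer": fast semantic-token decoding without flow matching (used when fast=True)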
+
+ def load(self, llm_model, flow_model, hift_model, wavtokenizer_model):
+ if llm_model is not None:
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+ self.llm.to(self.device).eval()
+ else:
+ self.llm = None
+ if flow_model is not None:
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+ self.flow.to(self.device).eval()
+ if hift_model is not None:
+ if ".pt" not in hift_model:
+ self.music_tokenizer = VQVAE( hift_model + '/config.json',
+ hift_model + '/model.pt', with_encoder=True)
+ else:
+ self.music_tokenizer = VQVAE(os.path.dirname(hift_model) + '/config.json',
+ hift_model, with_encoder=True)
+ self.music_tokenizer.to(self.device).eval()
+ if wavtokenizer_model is not None:
+ if ".pt" not in wavtokenizer_model:
+ self.wavtokenizer = WavTokenizer.from_pretrained_feat( wavtokenizer_model + '/config.yaml',
+ wavtokenizer_model + '/model.pt')
+ else:
+ self.wavtokenizer = WavTokenizer.from_pretrained_feat( os.path.dirname(wavtokenizer_model) + '/config.yaml',
+ wavtokenizer_model )
+ self.wavtokenizer.to(self.device)
+
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+ assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
+ self.llm.text_encoder = llm_text_encoder
+ llm_llm = torch.jit.load(llm_llm_model)
+ self.llm.llm = llm_llm
+ flow_encoder = torch.jit.load(flow_encoder_model)
+ self.flow.encoder = flow_encoder
+
+ def load_onnx(self, flow_decoder_estimator_model):
+ import onnxruntime
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+ del self.flow.decoder.estimator
+ self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+ def llm_job(self, text, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, uuid, duration_to_gen, task):
+ with self.llm_context:
+ local_res = []
+ with autocast(enabled=self.fp16):
+ inference_kwargs = {
+ 'text': text.to(self.device),
+ 'text_len': torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+ 'prompt_text': prompt_text.to(self.device),
+ 'prompt_text_len': torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+ 'prompt_audio_token': llm_prompt_audio_token.to(self.device),
+ 'prompt_audio_token_len': torch.tensor([llm_prompt_audio_token.shape[1]], dtype=torch.int32).to(self.device),
+ 'embeddings': embeddings,
+ 'duration_to_gen': duration_to_gen,
+ 'task': task
+ }
+
+ if audio_token is not None:
+ inference_kwargs['audio_token'] = audio_token.to(self.device)
+ else:
+ inference_kwargs['audio_token'] = torch.Tensor([0]).to(self.device)
+
+ if audio_token_len is not None:
+ inference_kwargs['audio_token_len'] = audio_token_len.to(self.device)
+ else:
+ inference_kwargs['audio_token_len'] = torch.Tensor([0]).to(self.device)
+
+ for i in self.llm.inference(**inference_kwargs):
+ local_res.append(i)
+
+ self.music_token_dict[uuid] = local_res
+ self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, token_len, uuid, sample_rate, finalize=False, flow_cfg=None):
+ if flow_cfg is not None:
+ codec_embed = self.flow.inference_cfg(token=token.to(self.device),
+ token_len=token_len.to(self.device),
+ sample_rate=sample_rate
+ )
+ else:
+ codec_embed = self.flow.inference(token=token.to(self.device),
+ token_len=token_len.to(self.device),
+ sample_rate=sample_rate
+ )
+ # use music_tokenizer decoder
+ wav = self.music_tokenizer.generator(codec_embed)
+ wav = wav.squeeze(0).cpu().detach()
+ return wav
+
+ def acoustictoken2wav(self, token):
+ # use music_tokenizer to generate waveform from token
+ token = token.view(token.size(0), -1, 4)
+        codec_embed = self.music_tokenizer.quantizer.embed(token.long().to(self.device))
+ wav = self.music_tokenizer.generator(codec_embed)
+ wav = wav.squeeze(0).cpu().detach()
+ return wav
+
+ def semantictoken2wav(self, token):
+ # fast mode, use wavtokenizer decoder
+        new_tensor = token.detach().clone().to(self.device).unsqueeze(0)
+ features = self.wavtokenizer.codes_to_features(new_tensor)
+ bandwidth_id = torch.tensor([0]).to(self.device)
+ wav = self.wavtokenizer.to(self.device).decode(features, bandwidth_id=bandwidth_id)
+ wav = wav.cpu().detach()
+ return wav
+
+ @torch.inference_mode()
+ def inference(self, text, audio_token, audio_token_len, text_token, text_token_len, embeddings=None,
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+ llm_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32),
+ flow_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32),
+ prompt_audio_feat=torch.zeros(1, 0, 80), sample_rate=48000, duration_to_gen = 30, task="continuation", trim = True, stream=False, **kwargs):
+
+ # this_uuid is used to track variables related to this inference thread
+ # support tasks:
+ # text to music task
+ # music continuation task
+ # require either audio input only or text and audio inputs
+
+ this_uuid = str(uuid.uuid1())
+
+ if self.llm:
+ with self.lock:
+ self.music_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+
+ p = threading.Thread(target=self.llm_job, args=(text_token, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, this_uuid, duration_to_gen, task))
+ p.start()
+
+ if stream is True:
+ token_hop_len = self.token_min_hop_len
+ while True:
+ time.sleep(0.1)
+ if len(self.music_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    # decode the music tokens accumulated so far for this streaming chunk
+                    this_music_token = torch.concat(self.music_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                    this_music_audio = self.token2wav(token=this_music_token,
+                                                      token_len=torch.LongTensor([this_music_token.size(1)]),
+                                                      uuid=this_uuid,
+                                                      sample_rate=sample_rate,
+                                                      finalize=False)
+ yield {'music_audio': this_music_audio.cpu()}
+ with self.lock:
+ self.music_token_dict[this_uuid] = self.music_token_dict[this_uuid][token_hop_len:]
+ # increase token_hop_len for better audio quality
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+ if self.llm_end_dict[this_uuid] is True and len(self.music_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+ break
+ p.join()
+ # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+ this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
+            # synthesize the remaining tokens in one final chunk
+            this_music_audio = self.token2wav(token=this_music_token,
+                                              token_len=torch.LongTensor([this_music_token.size(1)]),
+                                              uuid=this_uuid,
+                                              sample_rate=sample_rate,
+                                              finalize=True)
+ yield {'music_audio': this_music_audio.cpu()}
+ else:
+ # deal with all tokens
+ if self.fast:
+ if task == "reconstruct":
+                    assert audio_token is not None, "the reconstruct task requires an input acoustic token"
+ this_music_token = audio_token
+ this_music_audio = self.acoustictoken2wav(token=this_music_token)
+ else:
+ if self.llm:
+ p.join()
+                    logging.debug(f"number of generated token chunks: {len(self.music_token_dict[this_uuid])}")
+                    this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
+                    logging.debug(f"generated music token shape: {this_music_token.shape}")
+ else:
+ this_music_token = text_token
+
+ logging.info("using wavtokenizer generator without flow matching")
+ this_music_audio = self.semantictoken2wav(token=this_music_token)
+                logging.debug(f"decoded audio shape: {this_music_audio.shape}")
+
+ else:
+ if self.llm:
+ p.join()
+ if len(self.music_token_dict[this_uuid]) != 0:
+ this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1)
+ else:
+                        logging.warning(f"No music tokens were generated for UUID: {this_uuid}")
+ else:
+ this_music_token = text_token
+ logging.info(f"LLM generated audio token length: {this_music_token.shape[1]}")
+ logging.info(f"using flow matching and {self.generator} generator")
+
+ if self.generator == "hifi":
+                    if (embeddings[1] - embeddings[0]) < 1:
+                        logging.info(f"Requested audio length {(embeddings[1] - embeddings[0])}s is too short; please request a longer duration.")
+                    elif (embeddings[1] - embeddings[0]) <= duration_to_gen:
+                        if trim:
+                            # assumes a token frame rate of 75 tokens per second
+                            trim_length = int((embeddings[1] - embeddings[0]) * 75)
+                            this_music_token = this_music_token[:, :trim_length]
+                            logging.info(f"After trimming, generated audio token length: {this_music_token.shape[1]}")
+
+ this_music_audio = self.token2wav(token=this_music_token,
+ token_len=torch.LongTensor([this_music_token.size(1)]),
+ uuid=this_uuid,
+ sample_rate=sample_rate,
+ finalize=True)
+ logging.info(f"Generated audio sequence length: {this_music_audio.shape[1]}")
+ elif self.generator == "wavtokenizer":
+                    if (embeddings[1] - embeddings[0]) < 1:
+                        logging.info(f"Requested audio length {(embeddings[1] - embeddings[0])}s is too short; please request a longer duration.")
+                    elif (embeddings[1] - embeddings[0]) < duration_to_gen:
+                        if trim:
+                            # assumes a token frame rate of 75 tokens per second
+                            trim_length = int((embeddings[1] - embeddings[0]) * 75)
+                            this_music_token = this_music_token[:, :trim_length]
+                            logging.info(f"After trimming, generated audio token length: {this_music_token.shape[1]}")
+
+ this_music_audio = self.semantictoken2wav(token=this_music_token)
+
+ yield {'music_audio': this_music_audio.cpu()}
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
\ No newline at end of file
diff --git a/inspiremusic/dataset/__init__.py b/inspiremusic/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inspiremusic/dataset/dataset.py b/inspiremusic/dataset/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..188872bf21cfbf55bd3df12ff3071e62c8b49f06
--- /dev/null
+++ b/inspiremusic/dataset/dataset.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+# 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import json
+import math
+from functools import partial
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import IterableDataset
+from inspiremusic.utils.file_utils import read_lists, read_json_lists
+
+class Processor(IterableDataset):
+
+ def __init__(self, source, f, *args, **kw):
+ assert callable(f)
+ self.source = source
+ self.f = f
+ self.args = args
+ self.kw = kw
+
+ def set_epoch(self, epoch):
+ self.source.set_epoch(epoch)
+
+ def __iter__(self):
+ """ Return an iterator over the source dataset processed by the
+ given processor.
+ """
+ assert self.source is not None
+ assert callable(self.f)
+ return self.f(iter(self.source), *self.args, **self.kw)
+
+ def apply(self, f):
+ assert callable(f)
+ return Processor(self, f, *self.args, **self.kw)
+
+
+class DistributedSampler:
+
+ def __init__(self, shuffle=True, partition=True):
+ self.epoch = -1
+ self.update()
+ self.shuffle = shuffle
+ self.partition = partition
+
+ def update(self):
+ assert dist.is_available()
+ if dist.is_initialized():
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ else:
+ self.rank = 0
+ self.world_size = 1
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is None:
+ self.worker_id = 0
+ self.num_workers = 1
+ else:
+ self.worker_id = worker_info.id
+ self.num_workers = worker_info.num_workers
+ return dict(rank=self.rank,
+ world_size=self.world_size,
+ worker_id=self.worker_id,
+ num_workers=self.num_workers)
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def sample(self, data):
+ """ Sample data according to rank/world_size/num_workers
+
+ Args:
+ data(List): input data list
+
+ Returns:
+ List: data list after sample
+ """
+ data = list(range(len(data)))
+ # force datalist even
+
+ if self.partition:
+ if self.shuffle:
+ random.Random(self.epoch).shuffle(data)
+ if len(data) < self.world_size:
+                    print(f"data list size {len(data)} is smaller than world size {self.world_size}; repeating the list")
+ data = data * math.ceil(self.world_size / len(data))
+ data = data[:self.world_size]
+ data = data[self.rank::self.world_size]
+ if len(data) < self.num_workers:
+ data = data * math.ceil(self.num_workers / len(data))
+ data = data[:self.num_workers]
+ data = data[self.worker_id::self.num_workers]
+ return data
+
+
+class DataList(IterableDataset):
+
+ def __init__(self, lists, shuffle=True, partition=True):
+ self.lists = lists
+ self.sampler = DistributedSampler(shuffle, partition)
+
+ def set_epoch(self, epoch):
+ self.sampler.set_epoch(epoch)
+
+ def __iter__(self):
+ sampler_info = self.sampler.update()
+ indexes = self.sampler.sample(self.lists)
+ for index in indexes:
+ data = dict(src=self.lists[index])
+ data.update(sampler_info)
+ yield data
+
+
+def Dataset(data_list_file,
+ data_pipeline,
+ mode='train',
+ shuffle=True,
+ partition=True
+ ):
+    """ Construct a dataset from arguments.
+
+    There are two shuffle stages in the Dataset. The first is a global
+    shuffle at the data-list (shard/raw file) level. The second is a
+    buffered shuffle at the training-sample level.
+
+    Args:
+        data_list_file(str): file listing the data shards/entries
+        data_pipeline(List[Callable]): processing functions applied in order
+        mode(str): one of train/inference/processing
+        shuffle(bool): whether to shuffle the data list
+        partition(bool): whether to partition data in terms of rank
+ """
+ assert mode in ['train', 'inference', 'processing']
+ lists = read_lists(data_list_file)
+
+ dataset = DataList(lists,
+ shuffle=shuffle,
+ partition=partition)
+
+ for func in data_pipeline:
+ dataset = Processor(dataset, func, mode=mode)
+
+ return dataset
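+
+# Minimal usage sketch (illustrative; assumes a data list file and processor functions
+# from inspiremusic.dataset.processor combined with functools.partial):
+#   pipeline = [processor.parquet_opener,
+#               partial(processor.tokenize, get_tokenizer=get_tokenizer, allowed_special='all'),
+#               partial(processor.shuffle, shuffle_size=1000),
+#               partial(processor.sort, sort_size=500),
+#               partial(processor.batch, batch_type='dynamic', max_frames_in_batch=12000),
+#               processor.padding]
+#   train_dataset = Dataset('data/train.data.list', data_pipeline=pipeline, mode='train')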
diff --git a/inspiremusic/dataset/processor.py b/inspiremusic/dataset/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..21572b593e8e36b96202a1264b62cd73c7b6ecf0
--- /dev/null
+++ b/inspiremusic/dataset/processor.py
@@ -0,0 +1,595 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import random
+
+import pyarrow.parquet as pq
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+import torch.nn.functional as F
+import numpy as np
+import re
+
+torchaudio.set_audio_backend('soundfile')
+
+AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
+CHORUS = {"intro": 0, "chorus": 1, "verse1": 2, "verse2": 3, "verse": 2,
+ "outro": 4}
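+# integer ids for song-structure tags; a bare "verse" shares the id of "verse1"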
+
+metadata_pattern = re.compile(r'^\[(ti|ar|al|by|offset):.*\]$')
+timestamp_pattern = re.compile(r'^\[\d{2}:\d{2}\.\d{2}\](.*)$')
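+# LRC-style lyric lines: metadata tags like [ti:...] / [ar:...] and per-line [mm:ss.xx] timestamps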
+
+
+def parquet_opener(data, mode='train', audio_data={}):
+    """ Given parquet urls or local files, open each table and yield one sample per row.
+        Inplace operation.
+
+        Args:
+            data(Iterable[{src}]): dicts whose 'src' field is a parquet url or local file path
+
+        Returns:
+            Iterable[{src, ...row fields}]
+ """
+ for sample in data:
+ assert 'src' in sample
+
+ url = sample['src']
+ try:
+ df = pq.read_table(url).to_pandas()
+ for i in df.index:
+ sample.update(dict(df.loc[i]))
+ yield {**sample}
+ except Exception as ex:
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
+
+
+def clean_lyrics(data, mode="train"):
+ for sample in data:
+ lyrics = sample["text"]
+ cleaned = []
+ for line in lyrics.splitlines():
+ if metadata_pattern.match(line):
+ continue
+ timestamp_match = timestamp_pattern.match(line)
+ if timestamp_match:
+ lyric = timestamp_match.group(1).strip()
+ if lyric:
+ cleaned.append(lyric)
+ else:
+ if line.strip():
+ cleaned.append(line.strip())
+ sample["text"] = '\n'.join(cleaned)
+ yield sample
+
+
+def cut_by_length(data, max_length=8000, num_times=4, mode="train"):
+ for sample in data:
+ if "semantic_token" in sample:
+ sample["semantic_token"] = [
+ sample["semantic_token"][0][:max_length]]
+ if "acoustic_token" not in sample:
+ sample["acoustic_token"] = sample["speech_token"]
+ sample["acoustic_token"] = sample["acoustic_token"][
+ :max_length * num_times]
+
+ yield sample
+
+
+def filter(data,
+           max_length=22500,  # 22500 frames, roughly 5 minutes
+ max_acoustic_length=45000,
+ min_length=10,
+ min_acoustic_length=150,
+ token_max_length=200,
+ token_min_length=1,
+ min_output_input_ratio=0.0005,
+ max_output_input_ratio=1,
+ mode='train'):
+ """ Filter sample according to feature and label length
+ Inplace operation.
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+        max_length: drop utterances longer than max_length (in 10 ms frames)
+        min_length: drop utterances shorter than min_length (in 10 ms frames)
+        token_max_length: drop utterances with more than token_max_length
+            tokens, especially when using char units for English modeling
+        token_min_length: drop utterances with fewer than
+            token_min_length tokens
+        min_output_input_ratio: minimum ratio of
+            token_length / feats_length (10 ms frames)
+        max_output_input_ratio: maximum ratio of
+            token_length / feats_length (10 ms frames)
+
+ Returns:
+ Iterable[{key, wav, label, sample_rate}]
+ """
+ if mode == "train":
+ for sample in data:
+ if "semantic_token" in sample:
+ new_sample_frames = sample['semantic_token'][0].shape[0]
+            else:
+                # fall back to the speech token count when no semantic token is present
+                new_sample_frames = len(sample['speech_token'])
+
+ if "text_token" in sample:
+ new_sample_frames += len(sample['text_token'])
+
+ if new_sample_frames > max_length or new_sample_frames < min_length:
+                logging.info(f"skipped one item, length={new_sample_frames}")
+ continue
+
+ sample["chorus"] = sample["chorus"].split(",")
+ if not isinstance(sample["time_start"], np.ndarray):
+ sample["time_start"] = [sample["time_start"]]
+ sample["time_end"] = [sample["time_end"]]
+ for i, t in enumerate(sample["chorus"]):
+ if sample["chorus"][i] == "verse":
+ sample["chorus"][i] = "verse1"
+
+ yield sample
+
+ if mode == "train_flow":
+ for sample in data:
+ if "semantic_token" in sample:
+ new_sample_frames = sample['semantic_token'][0].shape[0]
+ if "acoustic_token" in sample:
+ target_sample_frames = sample['acoustic_token'][0].shape[0]
+
+ if new_sample_frames > max_length or new_sample_frames < min_acoustic_length or new_sample_frames < min_length or target_sample_frames > max_acoustic_length:
+                logging.info(f"skipped one item, length={new_sample_frames}, target_length={target_sample_frames}")
+ continue
+
+ yield sample
+
+ elif mode == "inference":
+ for sample in data:
+ yield sample
+
+
+def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
+ """ Resample data.
+ Inplace operation.
+
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+ resample_rate: target resample rate
+
+ Returns:
+ Iterable[{key, wav, label, sample_rate}]
+ """
+ for sample in data:
+ assert 'sample_rate' in sample
+ assert 'speech' in sample
+ sample_rate = sample['sample_rate']
+ waveform = sample['speech']
+ if sample_rate != resample_rate:
+ if sample_rate < min_sample_rate:
+ continue
+ sample['sample_rate'] = resample_rate
+ sample['speech'] = torchaudio.transforms.Resample(
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
+ max_val = sample['speech'].abs().max()
+ if max_val > 1:
+ sample['speech'] /= max_val
+ yield sample
+
+
+def truncate(data, truncate_length=24576, mode='train'):
+ """ Truncate data.
+
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+ truncate_length: truncate length
+
+ Returns:
+ Iterable[{key, wav, label, sample_rate}]
+ """
+ for sample in data:
+ waveform = sample['audio']
+ if waveform.shape[1] > truncate_length:
+ start = random.randint(0, waveform.shape[1] - truncate_length)
+ waveform = waveform[:, start: start + truncate_length]
+ else:
+ waveform = torch.concat([waveform, torch.zeros(1, truncate_length -
+ waveform.shape[1])],
+ dim=1)
+ sample['audio'] = waveform
+ yield sample
+
+
+def upsample(data, resample_rate=48000, min_sample_rate=16000, mode='train',
+ n_codebook=4):
+ """ Resample data.
+ Inplace operation.
+
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+ resample_rate: target resample rate
+
+ Returns:
+ Iterable[{key, wav, label, sample_rate}]
+ """
+ for sample in data:
+ assert 'semantic_token' in sample
+ # TODO: unify data processing key names
+ if 'acoustic_token' not in sample:
+ continue
+
+ if 'sample_rate' in sample.keys():
+ sample_rate = sample['sample_rate']
+ else:
+ sample_rate = 24000
+ token = np.array(sample['semantic_token'][0][:-1])
+
+ # Calculate the repetition factor for resampling
+ repetition_factor = int(n_codebook * resample_rate / sample_rate)
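+        # e.g. 4 codebooks with 24 kHz tokens upsampled to 48 kHz -> each token repeated 8 times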
+ if sample_rate != resample_rate:
+ if sample_rate < min_sample_rate:
+ continue
+ sample['sample_rate'] = resample_rate
+ sample['semantic_token'] = np.array(
+ [np.repeat(token, repetition_factor)])
+
+ yield sample
+
+def compute_fbank(data,
+ feat_extractor,
+ mode='train'):
+ """ Extract fbank
+
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+
+ Returns:
+ Iterable[{key, feat, label}]
+ """
+ for sample in data:
+ assert 'sample_rate' in sample
+ assert 'speech' in sample
+ assert 'utt' in sample
+ assert 'text_token' in sample
+ waveform = sample['speech']
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+ sample['speech_feat'] = mat
+ del sample['speech']
+ yield sample
+
+
+def parse_embedding(data, normalize, mode='train'):
+ """ Parse utt_embedding/spk_embedding
+
+ Args:
+ data: Iterable[{key, wav, label, sample_rate}]
+
+ Returns:
+ Iterable[{key, feat, label}]
+ """
+
+ for sample in data:
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'],
+ dtype=torch.float32)
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'],
+ dtype=torch.float32)
+ if normalize:
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'],
+ dim=0)
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'],
+ dim=0)
+ yield sample
+
+def tokenize(data, get_tokenizer, allowed_special, mode='train'):
+ """ Decode text to chars or BPE
+ Inplace operation
+
+ Args:
+ data: Iterable[{key, wav, txt, sample_rate}]
+
+ Returns:
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
+ """
+ tokenizer = get_tokenizer()
+
+ for sample in data:
+ assert 'text' in sample
+ sample['text_token'] = tokenizer.encode(sample['text'],
+ allowed_special=allowed_special)
+ yield sample
+
+
+def shuffle(data, shuffle_size=10000, mode='train'):
+ """ Local shuffle the data
+
+ Args:
+ data: Iterable[{key, feat, label}]
+ shuffle_size: buffer size for shuffle
+
+ Returns:
+ Iterable[{key, feat, label}]
+ """
+ buf = []
+ for sample in data:
+ buf.append(sample)
+ if len(buf) >= shuffle_size:
+ random.shuffle(buf)
+ for x in buf:
+ yield x
+ buf = []
+ # The sample left over
+ random.shuffle(buf)
+ for x in buf:
+ yield x
+
+
+def sort(data, sort_size=500, mode='train'):
+ """ Sort the data by feature length.
+ Sort is used after shuffle and before batch, so we can group
+ utts with similar lengths into a batch, and `sort_size` should
+ be less than `shuffle_size`
+
+ Args:
+ data: Iterable[{key, feat, label}]
+ sort_size: buffer size for sort
+
+ Returns:
+ Iterable[{key, feat, label}]
+ """
+
+ buf = []
+ for sample in data:
+ if sample["chorus"] == "verse":
+ sample["chorus"] = "verse1"
+
+ if sample["acoustic_token"].shape[0] == 1:
+ sample["acoustic_token"] = np.concatenate(
+ sample["acoustic_token"][0])
+ else:
+ sample["acoustic_token"] = np.concatenate(sample["acoustic_token"])
+
+ sample["acoustic_token"] = torch.from_numpy(sample["acoustic_token"])
+ buf.append(sample)
+ if len(buf) >= sort_size:
+ buf.sort(key=lambda x: x['acoustic_token'].size(0))
+ for x in buf:
+ yield x
+ buf = []
+ # The sample left over
+ buf.sort(key=lambda x: x['acoustic_token'].size(0))
+ for x in buf:
+ yield x
+
+
+def static_batch(data, batch_size=32):
+ """ Static batch the data by `batch_size`
+
+ Args:
+ data: Iterable[{key, feat, label}]
+ batch_size: batch size
+
+ Returns:
+ Iterable[List[{key, feat, label}]]
+ """
+ buf = []
+ data_empty = True
+ for sample in data:
+ data_empty = False
+ buf.append(sample)
+ if len(buf) >= batch_size:
+ yield buf
+ buf = []
+ if data_empty:
+ raise ValueError("data is empty")
+ if len(buf) > 0:
+ yield buf
+
+
+def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
+ """ Dynamic batch the data until the total frames in batch
+ reach `max_frames_in_batch`
+
+ Args:
+ data: Iterable[{key, feat, label}]
+ max_frames_in_batch: max_frames in one batch
+
+ Returns:
+ Iterable[List[{key, feat, label}]]
+ """
+ buf = []
+ longest_frames = 0
+ for sample in data:
+ assert 'acoustic_token' in sample
+ assert isinstance(sample['acoustic_token'], torch.Tensor)
+
+ if 'semantic_token' in sample:
+ new_sample_frames = sample['semantic_token'][0].shape[0]
+        else:
+            # fall back to the acoustic token length when no semantic token is present
+            new_sample_frames = sample['acoustic_token'].size(0)
+
+ if "text_token" in sample:
+ new_sample_frames += len(sample['text_token'])
+
+ longest_frames = max(longest_frames, new_sample_frames)
+ frames_after_padding = longest_frames * (len(buf) + 1)
+
+ if frames_after_padding > max_frames_in_batch:
+ if len(buf) > 0:
+ yield buf
+ buf = [sample]
+ longest_frames = new_sample_frames
+ else:
+ buf.append(sample)
+ if len(buf) > 0:
+ yield buf
+
+
+def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000,
+ mode='train'):
+ """ Wrapper for static/dynamic batch
+ """
+ if mode == 'inference':
+ return static_batch(data, 1)
+ elif mode == 'processing':
+ return static_batch(data, batch_size)
+ else:
+ if batch_type == 'static':
+ return static_batch(data, batch_size)
+ elif batch_type == 'dynamic':
+ return dynamic_batch(data, max_frames_in_batch)
+ else:
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
+
+
+def padding(data, mode='train'):
+ """ Padding the data into training data
+
+ Args:
+ data: Iterable[List[{key, feat, label}]]
+
+ Returns:
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
+ """
+ if mode == "train":
+ for sample in data:
+ assert isinstance(sample, list)
+ if len(sample) != 0:
+ acoustic_feat_len = torch.tensor(
+ [x['acoustic_token'].size(0) for x in sample],
+ dtype=torch.int32)
+ order = torch.argsort(acoustic_feat_len, descending=True)
+ utts = [sample[i]['utt'] for i in order]
+ acoustic_token = [
+ sample[i]['acoustic_token'].clone().to(torch.int32) for i in
+ order]
+ acoustic_token_len = torch.tensor(
+ [i.size(0) for i in acoustic_token], dtype=torch.int32)
+
+ acoustic_token = pad_sequence(acoustic_token,
+ batch_first=True,
+ padding_value=0)
+
+ text = [sample[i]['text'] for i in order]
+ text_token = [torch.tensor(sample[i]['text_token']).long() for i
+ in order]
+ text_token_len = torch.tensor([i.size(0) for i in text_token],
+ dtype=torch.int32)
+ text_token = pad_sequence(text_token, batch_first=True,
+ padding_value=0)
+ time_start = torch.tensor(
+ [sample[i]['time_start'] for i in order])
+ time_end = torch.tensor([sample[i]['time_end'] for i in order])
+
+ if isinstance(sample[0]['chorus'], str):
+ chorus = torch.tensor(
+ [CHORUS[sample[i]['chorus']] for i in order])
+ else:
+ chorus = [
+ torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
+ for i in order]
+ chorus = pad_sequence(chorus, batch_first=True,
+ padding_value=-1)
+
+ batch = {
+ "utts" : utts,
+ "acoustic_token" : acoustic_token,
+ "acoustic_token_len": acoustic_token_len,
+ "time_start" : time_start,
+ "time_end" : time_end,
+ "chorus" : chorus,
+ "text" : text,
+ "text_token" : text_token,
+ "text_token_len" : text_token_len,
+ }
+
+ if "semantic_token" in sample[0]:
+ semantic_token = [
+ torch.tensor(sample[i]['semantic_token'][0],
+ dtype=torch.int32) for i in order]
+ semantic_token_len = torch.tensor(
+ [i.size(0) for i in semantic_token],
+ dtype=torch.int32)
+ semantic_token = pad_sequence(semantic_token,
+ batch_first=True,
+ padding_value=0)
+ batch.update({"semantic_token" : semantic_token,
+ "semantic_token_len": semantic_token_len})
+
+ yield batch
+ else:
+                logging.warning("sample is empty!")
+
+ elif mode == "inference":
+ for sample in data:
+ assert isinstance(sample, list)
+ utts = [sample[i]['utt'] for i in range(len(sample))]
+ text = [sample[i]['text'] for i in range(len(sample))]
+ text_token = [torch.tensor(sample[i]['text_token']).long() for i in
+ range(len(sample))]
+ text_token_len = torch.tensor([i.size(0) for i in text_token],
+ dtype=torch.int32)
+ text_token = pad_sequence(text_token, batch_first=True,
+ padding_value=0)
+ time_start = torch.tensor(
+ [sample[i]['time_start'] for i in range(len(sample))])
+ time_end = torch.tensor(
+ [sample[i]['time_end'] for i in range(len(sample))])
+
+ if isinstance(sample[0]['chorus'], str):
+ chorus = torch.tensor([CHORUS[sample[i]['chorus']] for i in
+ range(len(sample))])
+ else:
+ chorus = [torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
+ for i in range(len(sample))]
+ chorus = pad_sequence(chorus, batch_first=True,
+ padding_value=-1)
+
+ if "acoustic_token" in sample[0]:
+ acoustic_token = [
+ sample[i]['acoustic_token'].clone().to(torch.int32) for i in
+ range(len(sample))]
+ acoustic_token_len = torch.tensor(
+ [i.size(0) for i in acoustic_token], dtype=torch.int32)
+ acoustic_token = pad_sequence(acoustic_token,
+ batch_first=True,
+ padding_value=0)
+ else:
+ acoustic_token = None
+ acoustic_token_len = None
+
+ batch = {
+ "utts" : utts,
+ "acoustic_token" : acoustic_token,
+ "acoustic_token_len": acoustic_token_len,
+ "time_start" : time_start,
+ "time_end" : time_end,
+ "chorus" : chorus,
+ "text" : text,
+ "text_token" : text_token,
+ "text_token_len" : text_token_len,
+ }
+
+ if "semantic_token" in sample[0]:
+ semantic_token = [torch.tensor(sample[i]['semantic_token'][0],
+ dtype=torch.int32) for i in
+ range(len(sample))]
+ semantic_token_len = torch.tensor(
+ [i.size(0) for i in semantic_token], dtype=torch.int32)
+ semantic_token = pad_sequence(semantic_token,
+ batch_first=True,
+ padding_value=0)
+ batch.update({"semantic_token" : semantic_token,
+ "semantic_token_len": semantic_token_len})
+
+ yield batch
diff --git a/inspiremusic/flow/decoder.py b/inspiremusic/flow/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d1dff57ba44a1b4952d6836729233368c131fe8
--- /dev/null
+++ b/inspiremusic/flow/decoder.py
@@ -0,0 +1,277 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import pack, rearrange, repeat
+from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
+from matcha.models.components.transformer import BasicTransformerBlock
+
+class Transpose(torch.nn.Module):
+ def __init__(self, dim0: int, dim1: int):
+ super().__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x: torch.Tensor):
+ x = torch.transpose(x, self.dim0, self.dim1)
+ return x
+
+class CausalBlock1D(Block1D):
+ def __init__(self, dim: int, dim_out: int):
+ super(CausalBlock1D, self).__init__(dim, dim_out)
+ self.block = torch.nn.Sequential(
+ CausalConv1d(dim, dim_out, 3),
+ Transpose(1, 2),
+ nn.LayerNorm(dim_out),
+ Transpose(1, 2),
+ nn.Mish(),
+ )
+
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
+ output = self.block(x * mask)
+ return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+ self.block1 = CausalBlock1D(dim, dim_out)
+ self.block2 = CausalBlock1D(dim_out, dim_out)
+
+class CausalConv1d(torch.nn.Conv1d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None
+ ) -> None:
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
+ kernel_size, stride,
+ padding=0, dilation=dilation,
+ groups=groups, bias=bias,
+ padding_mode=padding_mode,
+ device=device, dtype=dtype)
+ assert stride == 1
+ self.causal_padding = (kernel_size - 1, 0)
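+        # pad only on the left so each output frame depends solely on current and past inputs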
+
+ def forward(self, x: torch.Tensor):
+ x = F.pad(x, self.causal_padding)
+ x = super(CausalConv1d, self).forward(x)
+ return x
+
+class ConditionalDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ channels=(256, 256),
+ dropout=0.05,
+ attention_head_dim=64,
+ n_blocks=1,
+ num_mid_blocks=2,
+ num_heads=4,
+ act_fn="snake",
+ ):
+ """
+        This decoder requires an input with the same shape as the target. If your text content
+        is shorter or longer than the output, please resample it before feeding it to the decoder.
+ """
+ super().__init__()
+ channels = tuple(channels)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
+ time_embed_dim = channels[0] * 4
+ self.time_mlp = TimestepEmbedding(
+ in_channels=in_channels,
+ time_embed_dim=time_embed_dim,
+ act_fn="silu",
+ )
+ self.down_blocks = nn.ModuleList([])
+ self.mid_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ output_channel = in_channels
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_last = i == len(channels) - 1
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ downsample = (
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ )
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+ for _ in range(num_mid_blocks):
+ input_channel = channels[-1]
+ out_channels = channels[-1]
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
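+        # reverse the channel tuple for the up-sampling path and repeat the first entry
+        # so the final up block keeps channels[0] output features for the skip connections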
+ channels = channels[::-1] + (channels[0],)
+ for i in range(len(channels) - 1):
+ input_channel = channels[i] * 2
+ output_channel = channels[i + 1]
+ is_last = i == len(channels) - 2
+ resnet = ResnetBlock1D(
+ dim=input_channel,
+ dim_out=output_channel,
+ time_emb_dim=time_embed_dim,
+ )
+ transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ dim=output_channel,
+ num_attention_heads=num_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=act_fn,
+ )
+ for _ in range(n_blocks)
+ ]
+ )
+ upsample = (
+ Upsample1D(output_channel, use_conv_transpose=True)
+ if not is_last
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+ )
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+ self.final_block = Block1D(channels[-1], channels[-1])
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv1d):
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.GroupNorm):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
+ """Forward pass of the UNet1DConditional model.
+
+ Args:
+            x (torch.Tensor): noisy input, shape (batch_size, in_channels, time)
+            mask (torch.Tensor): shape (batch_size, 1, time)
+            mu (torch.Tensor): conditioning features, shape (batch_size, channels, time)
+            t (torch.Tensor): diffusion timestep, shape (batch_size,)
+            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
+            cond (torch.Tensor, optional): placeholder for future use. Defaults to None.
+
+        Returns:
+            torch.Tensor: output of shape (batch_size, out_channels, time)
+ """
+
+ t = self.time_embeddings(t).to(t.dtype)
+ t = self.time_mlp(t)
+ x = pack([x, mu], "b * t")[0]
+ if spks is not None:
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+ x = pack([x, spks], "b * t")[0]
+ if cond is not None:
+ x = pack([x, cond], "b * t")[0]
+ hiddens = []
+ masks = [mask]
+ for resnet, transformer_blocks, downsample in self.down_blocks:
+ mask_down = masks[-1]
+ x = resnet(x, mask_down, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ hiddens.append(x) # Save hidden states for skip connections
+ x = downsample(x * mask_down)
+ masks.append(mask_down[:, :, ::2])
+ masks = masks[:-1]
+ mask_mid = masks[-1]
+
+ for resnet, transformer_blocks in self.mid_blocks:
+ x = resnet(x, mask_mid, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+
+ for resnet, transformer_blocks, upsample in self.up_blocks:
+ mask_up = masks.pop()
+ skip = hiddens.pop()
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+ x = resnet(x, mask_up, t)
+ x = rearrange(x, "b c t -> b t c").contiguous()
+ attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+ for transformer_block in transformer_blocks:
+ x = transformer_block(
+ hidden_states=x,
+ attention_mask=attn_mask,
+ timestep=t,
+ )
+ x = rearrange(x, "b t c -> b c t").contiguous()
+ x = upsample(x * mask_up)
+ x = self.final_block(x, mask_up)
+ output = self.final_proj(x * mask_up)
+ return output * mask
diff --git a/inspiremusic/flow/flow.py b/inspiremusic/flow/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..c67534ed8a1465c1e34a911263f57172801507d8
--- /dev/null
+++ b/inspiremusic/flow/flow.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import random
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from omegaconf import DictConfig
+from inspiremusic.utils.mask import make_pad_mask
+from inspiremusic.music_tokenizer.vqvae import VQVAE
+
+class MaskedDiff(torch.nn.Module):
+ def __init__(self,
+ input_size: int = 512,
+ output_size: int = 128,
+ output_type: str = "mel",
+ vocab_size: int = 4096,
+ input_frame_rate: int = 50,
+ only_mask_loss: bool = True,
+ encoder: torch.nn.Module = None,
+ length_regulator: torch.nn.Module = None,
+ decoder: torch.nn.Module = None,
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80,
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000,
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000},
+ generator_model_dir: str = "../../pretrained_models/InspireMusic-Base/music_tokenizer",
+ num_codebooks: int = 4
+ ):
+ super().__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+ self.decoder_conf = decoder_conf
+ self.mel_feat_conf = mel_feat_conf
+ self.vocab_size = vocab_size
+ self.output_type = output_type
+ self.input_frame_rate = input_frame_rate
+ logging.info(f"input frame rate={self.input_frame_rate}")
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
+
+ self.encoder = encoder
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+ self.decoder = decoder
+ self.length_regulator = length_regulator
+ self.only_mask_loss = only_mask_loss
+        self.quantizer = VQVAE(f'{generator_model_dir}/config.json',
+                               f'{generator_model_dir}/model.pt', with_encoder=True).quantizer
+ self.quantizer.eval()
+ self.num_codebooks = num_codebooks
+ self.cond = None
+ self.interpolate = False
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+
+ audio_token = batch['acoustic_token'].to(device)
+ audio_token_len = batch['acoustic_token_len'].to(device)
+        audio_token = audio_token.view(audio_token.size(0), -1, self.num_codebooks)
+        if "semantic_token" not in batch:
+            token = audio_token[:, :, 0]
+            token_len = (audio_token_len / self.num_codebooks).long()
+
+ else:
+ token = batch['semantic_token'].to(device)
+ token_len = batch['semantic_token_len'].to(device)
+
+ with torch.no_grad():
+ feat = self.quantizer.embed(audio_token)
+ feat_len = (audio_token_len/self.num_codebooks).long()
+
+ token = self.input_embedding(token)
+ h, h_lengths = self.encoder(token, token_len)
+ h, h_lengths = self.length_regulator(h, feat_len)
+
+ # get conditions
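+        # when conditioning is enabled, keep a random prefix (up to 30% of each utterance,
+        # dropped entirely with probability 0.5) of the ground-truth features as the condition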
+ if self.cond:
+ conds = torch.zeros(feat.shape, device=token.device)
+ for i, j in enumerate(feat_len):
+ if random.random() < 0.5:
+ continue
+ index = random.randint(0, int(0.3 * j))
+ conds[i, :index] = feat[i, :index]
+ conds = conds.transpose(1, 2)
+ else:
+ conds = None
+
+ mask = (~make_pad_mask(feat_len)).to(h)
+
+ loss, _ = self.decoder.compute_loss(
+ feat,
+ mask.unsqueeze(1),
+ h.transpose(1, 2).contiguous(),
+ None,
+ cond=conds
+ )
+
+ return {'loss': loss}
+
+ @torch.inference_mode()
+ def inference(self,
+ token,
+ token_len,
+ sample_rate):
+ assert token.shape[0] == 1
+
+ token = self.input_embedding(torch.clamp(token, min=0))
+ h, h_lengths = self.encoder(token, token_len)
+
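+        # 48 kHz output uses twice the feature frame rate, so the regulated length is doubled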
+ if sample_rate == 48000:
+ token_len = 2 * token_len
+
+ h, h_lengths = self.length_regulator(h, token_len)
+
+ # get conditions
+ conds = None
+
+ mask = (~make_pad_mask(token_len)).to(h)
+ feat = self.decoder(
+ mu=h.transpose(1, 2).contiguous(),
+ mask=mask.unsqueeze(1),
+ spks=None,
+ cond=conds,
+ n_timesteps=10
+ )
+ return feat
\ No newline at end of file
diff --git a/inspiremusic/flow/flow_matching.py b/inspiremusic/flow/flow_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..e803c6c59f4c2daa1fd1d0a7ee374706507e78c3
--- /dev/null
+++ b/inspiremusic/flow/flow_matching.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn.functional as F
+try:
+    import onnxruntime  # optional: only needed when the estimator is an exported ONNX session
+except ImportError:
+    onnxruntime = None
+from matcha.models.components.flow_matching import BASECFM
+
+
+class ConditionalCFM(BASECFM):
+ def __init__(self, in_channels, cfm_params, estimator: torch.nn.Module = None):
+ super().__init__(
+ n_feats=in_channels,
+ cfm_params=cfm_params,
+ )
+ self.t_scheduler = cfm_params.t_scheduler
+ self.training_cfg_rate = cfm_params.training_cfg_rate
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
+ # Just change the architecture of the estimator here
+ self.estimator = estimator
+
+ @torch.inference_mode()
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+ """Forward diffusion
+
+ Args:
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ n_timesteps (int): number of diffusion steps
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: Not used but kept for future purposes
+
+ Returns:
+ sample: generated mel-spectrogram
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ z = torch.randn_like(mu) * temperature
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+ if self.t_scheduler == 'cosine':
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
+ """
+        Fixed-step Euler solver for the flow ODE.
+ Args:
+ x (torch.Tensor): random noise
+ t_span (torch.Tensor): n_timesteps interpolated
+ shape: (n_timesteps + 1,)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: Not used but kept for future purposes
+ """
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+ t = t.unsqueeze(dim=0)
+
+        # keep intermediate states so they can be inspected (or returned via a future return_all_steps flag)
+ sol = []
+
+ for step in range(1, len(t_span)):
+ dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
+ # Classifier-Free Guidance inference introduced in VoiceBox
+ if self.inference_cfg_rate > 0:
+ cfg_dphi_dt = self.forward_estimator(
+ x, mask,
+ torch.zeros_like(mu), t,
+ torch.zeros_like(spks) if spks is not None else None,
+ torch.zeros_like(cond) if cond is not None else None
+ )
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
+ self.inference_cfg_rate * cfg_dphi_dt)
+ x = x + dt * dphi_dt
+ t = t + dt
+ sol.append(x)
+ if step < len(t_span) - 1:
+ dt = t_span[step + 1] - t
+
+ return sol[-1]
+
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
+ if isinstance(self.estimator, torch.nn.Module):
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
+ elif isinstance(self.estimator, onnxruntime.InferenceSession):
+ ort_inputs = {
+ 'x': x.cpu().numpy(),
+ 'mask': mask.cpu().numpy(),
+ 'mu': mu.cpu().numpy(),
+ 't': t.cpu().numpy(),
+ 'spks': spks.cpu().numpy(),
+ 'cond': cond.cpu().numpy()
+ }
+ output = self.estimator.run(None, ort_inputs)[0]
+ return torch.tensor(output, dtype=x.dtype, device=x.device)
+ else:
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+ self.estimator.set_input_shape('t', (2,))
+ self.estimator.set_input_shape('spks', (2, 80))
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+ # run trt engine
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
+ mask.contiguous().data_ptr(),
+ mu.contiguous().data_ptr(),
+ t.contiguous().data_ptr(),
+ spks.contiguous().data_ptr(),
+ cond.contiguous().data_ptr(),
+ x.data_ptr()])
+ return x
+
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+ """Computes diffusion loss
+
+ Args:
+ x1 (torch.Tensor): Target
+                shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): target mask
+ shape: (batch_size, 1, mel_timesteps)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+
+ Returns:
+ loss: conditional flow matching loss
+ y: conditional flow
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ b, _, t = mu.shape
+
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+ if self.t_scheduler == 'cosine':
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
+
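+        # optimal-transport conditional flow: y interpolates between noise z and target x1 at time t,
+        # and u = dy/dt is the constant target velocity the estimator learns to predict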
+ z = torch.randn_like(x1)
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+ u = x1 - (1 - self.sigma_min) * z
+
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
+ if self.training_cfg_rate > 0:
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
+ mu = mu * cfg_mask.view(-1, 1, 1)
+ if cond is not None:
+ cond = cond * cfg_mask.view(-1, 1, 1)
+
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
+ return loss, y
+
diff --git a/inspiremusic/flow/length_regulator.py b/inspiremusic/flow/length_regulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..05b74a9403a526c65dd05f0e558c62084b1772fe
--- /dev/null
+++ b/inspiremusic/flow/length_regulator.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+import torch.nn as nn
+import torch
+from torch.nn import functional as F
+from inspiremusic.utils.mask import make_pad_mask
+
+
+class InterpolateRegulator(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ sampling_ratios: Tuple,
+ out_channels: int = None,
+ groups: int = 1,
+ ):
+ super().__init__()
+ self.sampling_ratios = sampling_ratios
+ out_channels = out_channels or channels
+ model = nn.ModuleList([])
+ if len(sampling_ratios) > 0:
+ for _ in sampling_ratios:
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
+ norm = nn.GroupNorm(groups, channels)
+ act = nn.Mish()
+ model.extend([module, norm, act])
+ model.append(
+ nn.Conv1d(channels, out_channels, 1, 1)
+ )
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x, ylens=None):
+ # x in (B, T, D)
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
+ out = self.model(x).transpose(1, 2).contiguous()
+ olens = ylens
+ return out * mask, olens
+
+ def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
+        # in inference mode, interpolate the prompt tokens and the generated tokens (head/mid/tail) separately, so we get a clean separation point in the mel
+ # x in (B, T, D)
+ if x2.shape[1] > 40:
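+            # 20 edge tokens correspond to int(20 / input_frame_rate * 22050 / 256) mel frames
+            # (assuming 22.05 kHz audio with a hop size of 256)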
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+ mode='linear')
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+ else:
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+ if x1.shape[1] != 0:
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+ x = torch.concat([x1, x2], dim=2)
+ else:
+ x = x2
+ out = self.model(x).transpose(1, 2).contiguous()
+ return out, mel_len1 + mel_len2
diff --git a/inspiremusic/hifigan/discriminator.py b/inspiremusic/hifigan/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc784599f3493e20830290a9cd182789c0428d5
--- /dev/null
+++ b/inspiremusic/hifigan/discriminator.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+from typing import List, Optional, Tuple
+from einops import rearrange
+from torchaudio.transforms import Spectrogram
+
+
+class MultipleDiscriminator(nn.Module):
+ def __init__(
+ self, mpd: nn.Module, mrd: nn.Module
+ ):
+ super().__init__()
+ self.mpd = mpd
+ self.mrd = mrd
+
+ def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
+ y_d_rs += this_y_d_rs
+ y_d_gs += this_y_d_gs
+ fmap_rs += this_fmap_rs
+ fmap_gs += this_fmap_gs
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
+ y_d_rs += this_y_d_rs
+ y_d_gs += this_y_d_gs
+ fmap_rs += this_fmap_rs
+ fmap_gs += this_fmap_gs
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiResolutionDiscriminator(nn.Module):
+ def __init__(
+ self,
+ fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+ num_embeddings: Optional[int] = None,
+ ):
+ """
+ Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+ Args:
+ fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+ Defaults to None.
+ """
+
+ super().__init__()
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
+ )
+
+ def forward(
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for d in self.discriminators:
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+ def __init__(
+ self,
+ window_length: int,
+ num_embeddings: Optional[int] = None,
+ channels: int = 32,
+ hop_factor: float = 0.25,
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
+ ):
+ super().__init__()
+ self.window_length = window_length
+ self.hop_factor = hop_factor
+ self.spec_fn = Spectrogram(
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+ )
+ n_fft = window_length // 2 + 1
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+ self.bands = bands
+ convs = lambda: nn.ModuleList(
+ [
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
+ ]
+ )
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+ if num_embeddings is not None:
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+ torch.nn.init.zeros_(self.emb.weight)
+
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
+
+ def spectrogram(self, x):
+ # Remove DC offset
+ x = x - x.mean(dim=-1, keepdims=True)
+ # Peak normalize the volume of input audio
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+ x = self.spec_fn(x)
+ x = torch.view_as_real(x)
+ x = rearrange(x, "b f t c -> b c t f")
+ # Split into bands
+ x_bands = [x[..., b[0]: b[1]] for b in self.bands]
+ return x_bands
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+ x_bands = self.spectrogram(x)
+ fmap = []
+ x = []
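+        # run each frequency band through its own conv stack; intermediate activations are kept for the feature-matching loss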
+ for band, stack in zip(x_bands, self.band_convs):
+ for i, layer in enumerate(stack):
+ band = layer(band)
+ band = torch.nn.functional.leaky_relu(band, 0.1)
+ if i > 0:
+ fmap.append(band)
+ x.append(band)
+ x = torch.cat(x, dim=-1)
+ if cond_embedding_id is not None:
+ emb = self.emb(cond_embedding_id)
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+ else:
+ h = 0
+ x = self.conv_post(x)
+ fmap.append(x)
+ x += h
+
+ return x, fmap
diff --git a/inspiremusic/hifigan/f0_predictor.py b/inspiremusic/hifigan/f0_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..331394ddc5c9240a2f66186ab0ce263d80ceeac0
--- /dev/null
+++ b/inspiremusic/hifigan/f0_predictor.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+
+
+class ConvRNNF0Predictor(nn.Module):
+ def __init__(self,
+ num_class: int = 1,
+ in_channels: int = 80,
+ cond_channels: int = 512
+ ):
+ super().__init__()
+
+ self.num_class = num_class
+ self.condnet = nn.Sequential(
+ weight_norm(
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ )
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
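+        # (B, in_channels, T) mel features -> (B, T) F0 contour; abs() keeps the prediction non-negative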
+ x = self.condnet(x)
+ x = x.transpose(1, 2)
+ return torch.abs(self.classifier(x).squeeze(-1))
diff --git a/inspiremusic/hifigan/generator.py b/inspiremusic/hifigan/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..033758c05d55b9a1e23f79c7b551c6000762ee26
--- /dev/null
+++ b/inspiremusic/hifigan/generator.py
@@ -0,0 +1,411 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""HIFI-GAN"""
+
+from typing import Dict, Optional, List
+import numpy as np
+from scipy.signal import get_window
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn import ConvTranspose1d
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils import weight_norm
+from torch.distributions.uniform import Uniform
+
+from inspiremusic.transformer.activation import Snake
+from inspiremusic.utils.common import get_padding
+from inspiremusic.utils.common import init_weights
+
+
+"""hifigan based generator implementation.
+
+This code is modified from https://github.com/jik876/hifi-gan
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
+ https://github.com/NVIDIA/BigVGAN
+
+"""
+
+
+class ResBlock(torch.nn.Module):
+ """Residual block module in HiFiGAN/BigVGAN."""
+ def __init__(
+ self,
+ channels: int = 512,
+ kernel_size: int = 3,
+ dilations: List[int] = [1, 3, 5],
+ ):
+ super(ResBlock, self).__init__()
+ self.convs1 = nn.ModuleList()
+ self.convs2 = nn.ModuleList()
+
+ for dilation in dilations:
+ self.convs1.append(
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation,
+ padding=get_padding(kernel_size, dilation)
+ )
+ )
+ )
+ self.convs2.append(
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1)
+ )
+ )
+ )
+ self.convs1.apply(init_weights)
+ self.convs2.apply(init_weights)
+ self.activations1 = nn.ModuleList([
+ Snake(channels, alpha_logscale=False)
+ for _ in range(len(self.convs1))
+ ])
+ self.activations2 = nn.ModuleList([
+ Snake(channels, alpha_logscale=False)
+ for _ in range(len(self.convs2))
+ ])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for idx in range(len(self.convs1)):
+ xt = self.activations1[idx](x)
+ xt = self.convs1[idx](xt)
+ xt = self.activations2[idx](xt)
+ xt = self.convs2[idx](xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for idx in range(len(self.convs1)):
+ remove_weight_norm(self.convs1[idx])
+ remove_weight_norm(self.convs2[idx])
+
+
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ @torch.no_grad()
+ def forward(self, f0):
+ """
+ :param f0: [B, 1, sample_len], Hz
+ :return: [B, 1, sample_len]
+ """
+
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
+ for i in range(self.harmonic_num + 1):
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
+
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
+ u_dist = Uniform(low=-np.pi, high=np.pi)
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
+ phase_vec[:, 0, :] = 0
+
+ # generate sine waveforms
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
+
+ # generate uv signal
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+    voiced_threshold: threshold to set U/V given F0 (default: 0)
+    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+    F0_sampled (batchsize, length, 1)
+    Sine_source (batchsize, length, 1)
+    noise_source (batchsize, length, 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+        noise_source (batchsize, length, 1)
+ """
+ # source for harmonic branch
+ with torch.no_grad():
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
+ sine_wavs = sine_wavs.transpose(1, 2)
+ uv = uv.transpose(1, 2)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
+class HiFTGenerator(nn.Module):
+ """
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
+ https://arxiv.org/abs/2309.09493
+ """
+ def __init__(
+ self,
+ in_channels: int = 80,
+ base_channels: int = 512,
+ nb_harmonics: int = 8,
+ sampling_rate: int = 22050,
+ nsf_alpha: float = 0.1,
+ nsf_sigma: float = 0.003,
+ nsf_voiced_threshold: float = 10,
+ upsample_rates: List[int] = [8, 8],
+ upsample_kernel_sizes: List[int] = [16, 16],
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ source_resblock_kernel_sizes: List[int] = [7, 11],
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
+ lrelu_slope: float = 0.1,
+ audio_limit: float = 0.99,
+ f0_predictor: torch.nn.Module = None,
+ ):
+ super(HiFTGenerator, self).__init__()
+
+ self.out_channels = 1
+ self.nb_harmonics = nb_harmonics
+ self.sampling_rate = sampling_rate
+ self.istft_params = istft_params
+ self.lrelu_slope = lrelu_slope
+ self.audio_limit = audio_limit
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.m_source = SourceModuleHnNSF(
+ sampling_rate=sampling_rate,
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
+ harmonic_num=nb_harmonics,
+ sine_amp=nsf_alpha,
+ add_noise_std=nsf_sigma,
+ voiced_threshod=nsf_voiced_threshold)
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
+
+ self.conv_pre = weight_norm(
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
+ )
+
+ # Up
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ base_channels // (2**i),
+ base_channels // (2**(i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ # Down
+ self.source_downs = nn.ModuleList()
+ self.source_resblocks = nn.ModuleList()
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
+ downsample_cum_rates = np.cumprod(downsample_rates)
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
+ if u == 1:
+ self.source_downs.append(
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
+ )
+ else:
+ self.source_downs.append(
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
+ )
+
+ self.source_resblocks.append(
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = base_channels // (2**(i + 1))
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(ResBlock(ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
+ self.f0_predictor = f0_predictor
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+ self.m_source.remove_weight_norm()
+ for l in self.source_downs:
+ remove_weight_norm(l)
+ for l in self.source_resblocks:
+ l.remove_weight_norm()
+
+ def _stft(self, x):
+ spec = torch.stft(
+ x,
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
+ return_complex=True)
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
+ return spec[..., 0], spec[..., 1]
+
+ def _istft(self, magnitude, phase):
+ magnitude = torch.clip(magnitude, max=1e2)
+ real = magnitude * torch.cos(phase)
+ img = magnitude * torch.sin(phase)
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+ return inverse_transform
+
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x = self.ups[i](x)
+
+ if i == self.num_upsamples - 1:
+ x = self.reflection_pad(x)
+
+ # fusion
+ si = self.source_downs[i](s_stft)
+ si = self.source_resblocks[i](si)
+ x = x + si
+
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
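+        # conv_post outputs n_fft + 2 channels: the first n_fft // 2 + 1 are log-magnitudes, the rest parameterize the phase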
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
+        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, the sin here is redundant
+
+ x = self._istft(magnitude, phase)
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
+ return x
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
+ # mel->f0
+ f0 = self.f0_predictor(speech_feat)
+ # f0->source
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+ s, _, _ = self.m_source(s)
+ s = s.transpose(1, 2)
+ # mel+source->speech
+ generated_speech = self.decode(x=speech_feat, s=s)
+ return generated_speech, f0
+
+ @torch.inference_mode()
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+ # mel->f0
+ f0 = self.f0_predictor(speech_feat)
+ # f0->source
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+ s, _, _ = self.m_source(s)
+ s = s.transpose(1, 2)
+ # use cache_source to avoid glitch
+ if cache_source.shape[2] != 0:
+ s[:, :, :cache_source.shape[2]] = cache_source
+ generated_speech = self.decode(x=speech_feat, s=s)
+ return generated_speech, s
diff --git a/inspiremusic/hifigan/hifigan.py b/inspiremusic/hifigan/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d7b612d1cd86569d430e8bf03bc7e3e0fa72957
--- /dev/null
+++ b/inspiremusic/hifigan/hifigan.py
@@ -0,0 +1,66 @@
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
+from inspiremusic.utils.losses import tpr_loss, mel_loss
+
+class HiFiGan(nn.Module):
+ def __init__(self, generator, discriminator, mel_spec_transform,
+ multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
+ tpr_loss_weight=1.0, tpr_loss_tau=0.04):
+ super(HiFiGan, self).__init__()
+ self.generator = generator
+ self.discriminator = discriminator
+ self.mel_spec_transform = mel_spec_transform
+ self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
+ self.feat_match_loss_weight = feat_match_loss_weight
+ self.tpr_loss_weight = tpr_loss_weight
+ self.tpr_loss_tau = tpr_loss_tau
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
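+        # the training loop alternates updates by setting batch['turn'] to 'generator' or 'discriminator'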
+ if batch['turn'] == 'generator':
+ return self.forward_generator(batch, device)
+ else:
+ return self.forward_discriminator(batch, device)
+
+ def forward_generator(self, batch, device):
+ real_speech = batch['speech'].to(device)
+ pitch_feat = batch['pitch_feat'].to(device)
+ # 1. calculate generator outputs
+ generated_speech, generated_f0 = self.generator(batch, device)
+ # 2. calculate discriminator outputs
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+ # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
+ loss_gen, _ = generator_loss(y_d_gs)
+ loss_fm = feature_loss(fmap_rs, fmap_gs)
+ loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
+ if self.tpr_loss_weight != 0:
+ loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+ else:
+ loss_tpr = torch.zeros(1).to(device)
+ loss_f0 = F.l1_loss(generated_f0, pitch_feat)
+ loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
+ self.multi_mel_spectral_recon_loss_weight * loss_mel + \
+ self.tpr_loss_weight * loss_tpr + loss_f0
+ return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
+
+ def forward_discriminator(self, batch, device):
+ real_speech = batch['speech'].to(device)
+ # 1. calculate generator outputs
+ with torch.no_grad():
+ generated_speech, generated_f0 = self.generator(batch, device)
+ # 2. calculate discriminator outputs
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+ # 3. calculate discriminator losses, tpr losses [Optional]
+ loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
+ if self.tpr_loss_weight != 0:
+ loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+ else:
+ loss_tpr = torch.zeros(1).to(device)
+ loss = loss_disc + self.tpr_loss_weight * loss_tpr
+ return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
diff --git a/inspiremusic/llm/llm.py b/inspiremusic/llm/llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..338853613ef47957aa5508c151dbbf293849214c
--- /dev/null
+++ b/inspiremusic/llm/llm.py
@@ -0,0 +1,402 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional, Callable, List, Generator
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence, unpad_sequence
+from inspiremusic.utils.common import IGNORE_ID
+from inspiremusic.transformer.label_smoothing_loss import LabelSmoothingLoss
+from inspiremusic.utils.common import th_accuracy
+from torch import Tensor
+from math import log
+from einops import rearrange, reduce, repeat
+import logging
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+class SinusoidalEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x: Tensor) -> Tensor:
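+        # standard sinusoidal embedding of a 1-D (possibly continuous) input: geometric frequency ladder, sin/cos concatenated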
+ device, half_dim = x.device, self.dim // 2
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
+ return torch.cat((emb.sin(), emb.cos()), dim=-1).to(torch.float32)
+
+class LLM(torch.nn.Module):
+ def __init__(
+ self,
+ text_encoder_input_size: int,
+ llm_input_size: int,
+ llm_output_size: int,
+ audio_token_size: int,
+ llm: torch.nn.Module,
+ sampling: Callable,
+ text_encoder_conf: Dict = None,
+ length_normalized_loss: bool = True,
+ lsm_weight: float = 0.0,
+ frozen_input_embed: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.llm_input_size = llm_input_size
+ self.audio_token_size = audio_token_size
+ # 1. build text token inputs related modules
+
+ if llm is None:
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
+ else:
+ self.text_embedding = llm.model.model.embed_tokens
+ if frozen_input_embed:
+ print("Freezing input embedding layer")
+ for p in self.text_embedding.parameters():
+ p.requires_grad = False
+        self.chorus_embedding = torch.nn.Embedding(5, llm_input_size)  # intro, chorus, verse1, verse2, outro
+
+ self.text_encoder_conf = text_encoder_conf
+ self.text_encoder = self.build_encoder(text_encoder_conf)
+ self.infer_cfg_ratio = kwargs.get("infer_cfg_ratio", None)
+ logging.info(f"infer_cfg_ratio: {self.infer_cfg_ratio}")
+ self.train_cfg_ratio = kwargs.get("train_cfg_ratio", None)
+ logging.info(f"train_cfg_ratio: {self.train_cfg_ratio}")
+ # 2. build audio token language model related modules
+ self.sos_eos = 0
+ self.task_id = 1
+
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
+ self.llm = llm
+ self.llm_decoder = nn.Linear(llm_output_size, audio_token_size + 1)
+ self.criterion_ce = LabelSmoothingLoss(
+ size=audio_token_size + 1,
+ padding_idx=IGNORE_ID,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+
+ # 3. [Optional] build audio token related modules
+ self.speech_embedding = torch.nn.Embedding(audio_token_size, llm_input_size)
+ self.spk_embed_affine_layer = torch.nn.Linear(192, llm_input_size)
+ self.num_codebooks = 4
+ # 4. sampling method
+ self.sampling = sampling
+ self.time_embedding = SinusoidalEmbedding(llm_input_size)
+
+ def cfg_dropout(self, text_token, text_token_len, p):
+ # Classifier-Free Guidance Dropout
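+        # zero out the text tokens (and their lengths) for a random subset of the batch so the model also learns an unconditional distribution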
+ B = text_token.size(0)
+ num_samples_to_mask = int(p * B)
+ if num_samples_to_mask == 0:
+ num_samples_to_mask = 1
+ indices_to_mask = torch.randperm(B, device=text_token.device)[:num_samples_to_mask]
+ text_token[indices_to_mask] = 0
+ text_token_len[indices_to_mask] = 0
+
+ return text_token, text_token_len
+
+ def build_encoder(self, encoder_conf=None):
+ if encoder_conf is None:
+ assert hasattr(self, "encoder_conf"), \
+ "function param encoder_conf is None and model doesn't has encoder_conf attribute either."
+ encoder_conf = self.encoder_conf
+
+ encoder_name = encoder_conf.pop("name", "transformer")
+ model = None
+ if encoder_name == "transformer":
+ from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder
+ model = ConformerEncoder(
+ **encoder_conf,
+ input_size=self.input_size,
+ use_cnn_module=False,
+ macaron_style=False,
+ )
+ elif encoder_name == "conformer":
+ from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder
+ model = ConformerEncoder(
+ **encoder_conf,
+ input_size=self.input_size,
+ )
+ elif encoder_name == "llama_encoder":
+ from inspiremusic.transformer.encoder.llama_encoder import LlamaEncoder
+ model = LlamaEncoder(
+ **encoder_conf,
+ input_size=self.input_size,
+ )
+ elif encoder_name == "qwen2":
+ from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder
+ model = QwenEncoder(
+ **encoder_conf,
+ input_size=self.input_size,
+ )
+ elif encoder_name == "qwen2.5":
+ from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder
+ model = QwenEncoder(
+ **encoder_conf,
+ input_size=self.input_size,
+ )
+
+ encoder_conf["name"] = encoder_name
+
+ return model
+
+ def encode(self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor):
+ if self.text_encoder is not None:
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths,
+ decoding_chunk_size=1,
+ num_decoding_left_chunks=-1)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
+ else:
+ encoder_out, encoder_out_lens = text, text_lengths
+ return encoder_out, encoder_out_lens
+
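+    # assemble each sample as [sos/eos] + time/chorus embeddings + text tokens + [task_id] + audio tokens, then re-pad the batch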
+ def pad_unpad_sequence(self, sos_eos_emb, embeddings, text_token,
+ text_token_len, task_id_emb, audio_token,
+ audio_token_len, seg_len):
+ text_token = unpad_sequence(text_token, text_token_len.cpu(),
+ batch_first=True)
+
+ audio_token = unpad_sequence(audio_token, audio_token_len.cpu(),
+ batch_first=True)
+
+ for i in range(len(embeddings)):
+ embeddings[i] = unpad_sequence(embeddings[i], seg_len.cpu(), batch_first=True)
+
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0)] + [embedding[i] for embedding in embeddings] + [text_token[i], task_id_emb.squeeze(dim=0), audio_token[i]], dim=0) for i in range(len(text_token))]
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
+ return lm_input, lm_input_len
+
+ def forward(
+ self,
+ batch: dict,
+ device: torch.device,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Args:
+ text: (B, L, D)
+ text_lengths: (B,)
+ audio: (B, T, N) or (B, T)
+ audio_lengths: (B,)
+ """
+ mask = True
+ text_token = batch['text_token'].to(device)
+ text_token_len = batch['text_token_len'].to(device)
+ if "semantic_token" not in batch:
+ audio_token = batch['acoustic_token'].to(device)
+ audio_token_len = batch['acoustic_token_len'].to(device)
+ audio_token = audio_token.view(audio_token.size(0), -1, self.num_codebooks)
+ audio_token = audio_token[:, :, 0]
+ audio_token_len = (audio_token_len / self.num_codebooks).long()
+
+ else:
+ audio_token = batch['semantic_token'].to(device)
+ audio_token_len = batch['semantic_token_len'].to(device)
+
+ time_start = batch['time_start'].to(device)
+ time_end = batch['time_end'].to(device)
+ chorus = batch['chorus'].to(device)
+ # 1. encode text_token
+
+ if self.train_cfg_ratio > 0:
+ # Classifier-Free Guidance
+ text_token, _ = self.cfg_dropout(text_token, text_token_len, self.train_cfg_ratio)
+
+ # 2. Time Embedding & chorus embedding
+ text_token = self.text_embedding(text_token)
+ text_token, text_token_len = self.encode(text_token, text_token_len)
+ if mask:
+ time_mask = time_start != -1.0
+ seg_len = time_mask.sum(-1)
+ time_start = time_start.masked_fill(~time_mask, 0.0)
+ time_end = time_end.masked_fill(~time_mask, 0.0)
+ chorus = chorus.masked_fill(~time_mask, 0)
+ time_start_embed = self.time_embedding(time_start.view(-1)).to(text_token.dtype)
+ time_end_embed = self.time_embedding(time_end.view(-1)).to(text_token.dtype)
+ time_start_embed = time_start_embed.view(chorus.size(0), chorus.size(1), -1)
+ time_end_embed = time_end_embed.view(chorus.size(0), chorus.size(1), -1)
+ chorus_embed = self.chorus_embedding(chorus)
+ lm_target = [torch.tensor([IGNORE_ID] * (1 + 3 * seg_len[i] + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))]
+ else:
+ time_start_embed = self.time_embedding(time_start).to(text_token.dtype)
+ time_end_embed = self.time_embedding(time_end).to(text_token.dtype)
+ chorus_embed = self.chorus_embedding(chorus)
+
+ lm_target = [torch.tensor(
+ [IGNORE_ID] * (4 + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))]
+
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
+
+ # 3. eos and task_id
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+ # 4. encode audio_token
+ audio_token = self.speech_embedding(audio_token)
+
+ # 5. unpad and pad
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb,
+ [time_start_embed,
+ time_end_embed,
+ chorus_embed],
+ text_token,
+ text_token_len,
+ task_id_emb,
+ audio_token,
+ audio_token_len,
+ seg_len)
+ # 6. run lm forward
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+ logits = self.llm_decoder(lm_output)
+ loss = self.criterion_ce(logits, lm_target)
+
+ acc = th_accuracy(logits.view(-1, self.audio_token_size + 1), lm_target, ignore_label=IGNORE_ID)
+
+ return {'loss': loss, 'acc': acc}
+
+ def sampling_ids(
+ self,
+ weighted_scores: torch.Tensor,
+ decoded_tokens: List,
+ ignore_eos: bool = True,
+ ):
+ top_ids = self.sampling(weighted_scores, decoded_tokens)
+ return top_ids
+
+ @torch.inference_mode()
+ def inference(
+ self,
+ text: torch.Tensor,
+ text_len: torch.Tensor,
+ audio_token: torch.Tensor,
+ audio_token_len: torch.Tensor,
+ prompt_text: torch.Tensor,
+ prompt_text_len: torch.Tensor,
+ prompt_audio_token: torch.Tensor,
+ prompt_audio_token_len: torch.Tensor,
+ embeddings: List,
+ duration_to_gen: float = 30,
+ task: str = "continuation",
+ token_rate: int = 75,
+ limit_audio_prompt_len: int = 5,
+ ) -> Generator[torch.Tensor, None, None]:
+ device = text.device
+
+ if text is not None:
+ text = torch.concat([prompt_text, text], dim=1)
+ text_len += prompt_text_len
+ infer_cfg = self.infer_cfg_ratio >= 0.0
+ if infer_cfg:
+ text_cfg = self.text_embedding(text.new_zeros(text.shape))
+ text = self.text_embedding(text)
+
+ # 1. encode text
+ text, text_len = self.encode(text, text_len)
+
+ # 2. encode embedding
+ if embeddings is not None:
+ time_start, time_end, chorus = embeddings
+
+ if len(chorus.shape) == 1:
+ time_start_embed = self.time_embedding(time_start).reshape(1, 1, -1) # .half()
+ time_end_embed = self.time_embedding(time_end).reshape(1, 1, -1) # .half()
+ chorus_embed = self.chorus_embedding(chorus).reshape(1, 1, -1) # .half()
+ else:
+ time_start_embed = self.time_embedding(
+ time_start.view(-1)).reshape(1, chorus.size(1), -1) # .half()
+ time_end_embed = self.time_embedding(time_end.view(-1)).reshape(1, chorus.size(1), -1) # .half()
+ chorus_embed = self.chorus_embedding(chorus) # .half()
+
+ # 3. concat llm_input
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+ if audio_token_len:
+ audio_token = audio_token[:, :(limit_audio_prompt_len * token_rate)]
+ audio_token_emb = self.speech_embedding(audio_token)
+ else:
+ audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+
+ if prompt_audio_token_len:
+ prompt_audio_token_emb = self.speech_embedding(prompt_audio_token)
+ else:
+ prompt_audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        # Check whether removing the prompt audio tokens would break decoding.
+
+ if task == "continuation":
+ lm_input = torch.concat(
+ [sos_eos_emb, time_start_embed, time_end_embed,
+ chorus_embed, text, task_id_emb, audio_token_emb], dim=1)
+
+ if infer_cfg:
+ audio_cfg = self.speech_embedding(
+ audio_token.new_zeros(audio_token.shape))
+ lm_cf_input = torch.concat(
+ [sos_eos_emb, torch.rand_like(time_start_embed),
+ torch.rand_like(time_end_embed),
+ torch.rand_like(chorus_embed), text_cfg, task_id_emb,
+ audio_cfg], dim=1)
+ lm_input = torch.cat([lm_input, lm_cf_input], 0)
+ else:
+ lm_input = torch.concat(
+ [sos_eos_emb, time_start_embed, time_end_embed,
+ chorus_embed, text, task_id_emb], dim=1)
+ if infer_cfg:
+ lm_cf_input = torch.concat(
+ [sos_eos_emb, torch.rand_like(time_start_embed),
+ torch.rand_like(time_end_embed),
+ torch.rand_like(chorus_embed), text_cfg, task_id_emb],
+ dim=1)
+ lm_input = torch.cat([lm_input, lm_cf_input], 0)
+
+ # 4. cal min/max_length
+ min_len = duration_to_gen * token_rate
+ max_len = duration_to_gen * token_rate
+        logging.info(
+            f"LLM generation: generating up to {int(max_len)} tokens for {duration_to_gen}s of audio.")
+
+ # 5. step by step decode
+ out_tokens = []
+ offset = 0
+ state = None
+
+ for i in range(int(max_len)):
+ y_pred, _, state = self.llm.forward_one_step(lm_input, torch.ones(lm_input.shape[0], lm_input.shape[1], device=lm_input.device).to(torch.bool), cache=state)
+ logits = self.llm_decoder(y_pred[:, -1])
+ if infer_cfg:
+                # apply classifier-free guidance to the logits
+ logits_cf = logits[1]
+ logits = logits[0]
+ infer_cfg_ratio = self.infer_cfg_ratio
+ logits = infer_cfg_ratio * logits + (1 - infer_cfg_ratio) * logits_cf
+
+ logp = logits.log_softmax(dim=-1)
+ logp = logp.squeeze(dim=0)
+ top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item()
+
+ if top_ids == self.audio_token_size:
+ break
+
+            # in stream mode, yield tokens one by one
+
+ yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
+ out_tokens.append(top_ids)
+ offset += lm_input.size(1)
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+ if infer_cfg:
+ lm_input = lm_input.repeat(2, 1, 1)
diff --git a/inspiremusic/metrics/clap_score.py b/inspiremusic/metrics/clap_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..d77b200323d9374f4ea64ee3a7eeeb1c0c21fecf
--- /dev/null
+++ b/inspiremusic/metrics/clap_score.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import requests
+from tqdm import tqdm
+import torch
+import numpy as np
+import laion_clap
+from clap_module.factory import load_state_dict
+import librosa
+import pyloudnorm as pyln
+
+# following documentation from https://github.com/LAION-AI/CLAP
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1., a_max=1.)
+ return (x * 32767.).astype(np.int16)
+
+
+def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_audioset_epoch_15_esc_90.14.pt'):
+ """
+    Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and
+    the LAION-CLAP audio embedding of the generated audio. LAION-CLAP: https://github.com/LAION-AI/CLAP
+
+ This evaluation script assumes that audio_path files are identified with the ids in id2text.
+
+ clap_score() evaluates all ids in id2text.
+
+ GPU-based computation.
+
+ Select one of the following models from https://github.com/LAION-AI/CLAP:
+ - music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen)
+ - music_audioset_epoch_15_esc_90.14.pt
+ - music_speech_epoch_15_esc_89.25.pt
+        - 630k-audioset-fusion-best.pt (with "fusion" to handle longer inputs)
+
+ Params:
+ -- id2text: dictionary with the mapping between id (generated audio filenames in audio_path)
+ and text (prompt used to generate audio). clap_score() evaluates all ids in id2text.
+ -- audio_path: path where the generated audio files to evaluate are available.
+ -- audio_files_extension: files extension (default .wav) in eval_path.
+    -- clap_model: choose one of the above clap_models (default: 'music_audioset_epoch_15_esc_90.14.pt').
+    Returns:
+    -- LAION-CLAP score (average cosine similarity over all ids)
+ """
+ # load model
+ if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt':
+ url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt'
+ clap_path = 'CLAP/music_speech_audioset_epoch_15_esc_89.98.pt'
+ model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
+ elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt':
+ url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt'
+ clap_path = 'CLAP/music_audioset_epoch_15_esc_90.14.pt'
+ model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
+ elif clap_model == 'music_speech_epoch_15_esc_89.25.pt':
+ url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt'
+ clap_path = 'CLAP/music_speech_epoch_15_esc_89.25.pt'
+ model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
+ elif clap_model == '630k-audioset-fusion-best.pt':
+ url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt'
+ clap_path = 'CLAP/630k-audioset-fusion-best.pt'
+ model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda')
+ else:
+ raise ValueError('clap_model not implemented')
+
+ # download clap_model if not already downloaded
+ if not os.path.exists(clap_path):
+ print('Downloading ', clap_model, '...')
+ os.makedirs(os.path.dirname(clap_path), exist_ok=True)
+
+ response = requests.get(url, stream=True)
+ total_size = int(response.headers.get('content-length', 0))
+
+ with open(clap_path, 'wb') as file:
+ with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
+ for data in response.iter_content(chunk_size=8192):
+ file.write(data)
+ progress_bar.update(len(data))
+
+    # fixing LAION-CLAP issue, see: https://github.com/LAION-AI/CLAP/issues/118
+ pkg = load_state_dict(clap_path)
+ pkg.pop('text_branch.embeddings.position_ids', None)
+ model.model.load_state_dict(pkg)
+ model.eval()
+
+ if not os.path.isdir(audio_path):
+ raise ValueError(f'audio_path: {audio_path} does not exist')
+
+ if id2text:
+ print('[EXTRACTING TEXT EMBEDDINGS] ')
+ batch_size = 64
+ text_emb = {}
+ for i in tqdm(range(0, len(id2text), batch_size)):
+ batch_ids = list(id2text.keys())[i:i+batch_size]
+ batch_texts = [id2text[id] for id in batch_ids]
+ with torch.no_grad():
+ embeddings = model.get_text_embedding(batch_texts, use_tensor=True)
+ for id, emb in zip(batch_ids, embeddings):
+ text_emb[id] = emb
+
+ else:
+ raise ValueError('Must specify id2text')
+
+ print('[EVALUATING GENERATIONS] ', audio_path)
+ score = 0
+ count = 0
+ for id in tqdm(id2text.keys()):
+ file_path = os.path.join(audio_path, str(id)+audio_files_extension)
+ if os.path.isfile(file_path):
+ with torch.no_grad():
+ audio, _ = librosa.load(file_path, sr=48000, mono=True) # sample rate should be 48000
+ audio = pyln.normalize.peak(audio, -1.0)
+ audio = audio.reshape(1, -1) # unsqueeze (1,T)
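+                # the int16 round-trip matches the quantization used in the LAION-CLAP reference pipeline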
+ audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float()
+ audio_embeddings = model.get_audio_embedding_from_data(x = audio, use_tensor=True)
+ cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0]
+ print(f"{id} | CLAP score = {cosine_sim}")
+ score += cosine_sim
+ count += 1
+
+ return score / count if count > 0 else 0
+
diff --git a/inspiremusic/metrics/openl3_fd.py b/inspiremusic/metrics/openl3_fd.py
new file mode 100644
index 0000000000000000000000000000000000000000..78287970a8250dec2c6bbc77b5b4791122a02259
--- /dev/null
+++ b/inspiremusic/metrics/openl3_fd.py
@@ -0,0 +1,338 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import openl3
+import librosa
+import numpy as np
+from scipy import linalg
+import glob
+from tqdm import tqdm
+import os
+import soxr
+import pyloudnorm as pyln
+
+
+def calculate_embd_statistics(embd_lst):
+ if isinstance(embd_lst, list):
+ embd_lst = np.array(embd_lst)
+ mu = np.mean(embd_lst, axis=0)
+ sigma = np.cov(embd_lst, rowvar=False)
+ return mu, sigma
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """
+ Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
+ Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py
+
+ Numpy implementation of the Frechet Distance.
+
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+ Params:
+ -- mu1: Embedding's mean statistics for generated samples.
+ -- mu2: Embedding's mean statistics for reference samples.
+ -- sigma1: Covariance matrix over embeddings for generated samples.
+ -- sigma2: Covariance matrix over embeddings for reference samples.
+ Returns:
+ -- Fréchet Distance.
+ """
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert mu1.shape == mu2.shape, \
+ 'Training and test mean vectors have different lengths'
+ assert sigma1.shape == sigma2.shape, \
+ 'Training and test covariances have different dimensions'
+
+ diff = mu1 - mu2
+
+ # product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = ('fid calculation produces singular product; '
+ 'adding %s to diagonal of cov estimates') % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError('Imaginary component {}'.format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return (diff.dot(diff) + np.trace(sigma1)
+ + np.trace(sigma2) - 2 * tr_covmean)
+
+
+def extract_embeddings(directory_path, channels, samplingrate, content_type, openl3_hop_size, batch_size=16):
+ """
+ Given a list of files, compute their embeddings in batches.
+
+ If channels == 1: stereo audio is downmixed to mono. Mono embeddings are of dim=512.
+
+ If channels == 2: mono audio is "faked" to stereo by copying the mono channel.
+ Stereo embeddings are of dim=1024, since we concatenate L (dim=512) and R (dim=512) embeddings.
+
+ Params:
+ -- directory_path: path where the generated audio files are available.
+ -- channels: 1 (mono), or 2 (stereo) to get mono or stereo embeddings.
+ -- samplingrate: max bandwidth at which we evaluate the given signals. Up to 48kHz.
+ -- content_type: 'music' or 'env' to select a content type specific openl3 model.
+ -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec.
+ -- batch_size: number of audio files to process in each batch.
+ Returns:
+ -- list of embeddings: [np.array[], ...], as expected by calculate_frechet_distance()
+ """
+ _, extension = os.path.splitext(directory_path)
+ if extension.lower() == ".scp":
+ wav_files = []
+ with open(directory_path, "r") as f:
+ for line in f:
+ sec = line.strip().split(" ")
+ wav_files.append(sec[1])
+ else:
+ wav_files = glob.glob(directory_path)
+ if len(wav_files) == 0:
+ raise ValueError('No files with this extension in this path!')
+ model = openl3.models.load_audio_embedding_model(input_repr="mel256", content_type=content_type, embedding_size=512)
+
+ first = True
+ for i in tqdm(range(0, len(wav_files), batch_size)):
+ batch_files = wav_files[i:i+batch_size]
+ batch_audio_l = []
+ batch_audio_r = []
+ batch_sr = []
+
+ for file in batch_files:
+ audio, sr = librosa.load(file, sr=None, mono=False)
+ audio = audio.T
+ audio = pyln.normalize.peak(audio, -1.0)
+ if audio.shape[0] < sr:
+ print('Audio shorter than 1 sec, openl3 will zero-pad it:', file, audio.shape, sr)
+
+ # resample to the desired evaluation bandwidth
+ audio = soxr.resample(audio, sr, samplingrate) # mono/stereo <- mono/stereo, input sr, output sr
+
+ # mono embeddings are stored in batch_audio_l (R channel not used)
+ if channels == 1:
+ batch_audio_l.append(audio)
+
+ elif channels == 2:
+ if audio.ndim == 1:
+ # if mono, "fake" stereo by copying mono channel to L and R
+ batch_audio_l.append(audio)
+ batch_audio_r.append(audio)
+ elif audio.ndim == 2:
+ # if it's stereo separate channels for openl3
+ batch_audio_l.append(audio[:,0])
+ batch_audio_r.append(audio[:,1])
+
+ batch_sr.append(samplingrate)
+
+ # extracting mono embeddings (dim=512) or the L channel for stereo embeddings
+ emb, _ = openl3.get_audio_embedding(batch_audio_l, batch_sr, model=model, verbose=False, hop_size=openl3_hop_size, batch_size=batch_size)
+
+ # format mono embedding
+ if channels == 1:
+ emb = np.concatenate(emb,axis=0)
+
+ # extracting stereo embeddings (dim=1024), since we concatenate L (dim=512) and R (dim=512) embeddings
+ elif channels == 2:
+ # extract the missing R channel
+ emb_r, _ = openl3.get_audio_embedding(batch_audio_r, batch_sr, model=model, verbose=False, hop_size=openl3_hop_size, batch_size=batch_size)
+ emb = [np.concatenate([l, r], axis=1) for l, r in zip(emb, emb_r)]
+ emb = np.concatenate(emb, axis=0)
+
+ # concatenate embeddings
+ if first:
+ embeddings = emb
+ first = False
+ else:
+ embeddings = np.concatenate([embeddings, emb], axis=0)
+
+ # return as a list of embeddings: [np.array[], ...]
+ return [e for e in embeddings]
+
+
+def extract_embeddings_nobatching(directory_path, channels, samplingrate, content_type, openl3_hop_size):
+ """
+ Given a list of files, compute their embeddings one by one.
+
+ If channels == 1: stereo audio is downmixed to mono. Mono embeddings are of dim=512.
+
+ If channels == 2: mono audio is "faked" to stereo by copying the mono channel.
+ Stereo embeddings are of dim=1024, since we concatenate L (dim=512) and R (dim=512) embeddings.
+
+ Params:
+ -- directory_path: path where the generated audio files are available.
+ -- channels: 1 (mono), or 2 (stereo) to get mono or stereo embeddings.
+ -- samplingrate: max bandwidth at which we evaluate the given signals. Up to 48kHz.
+ -- content_type: 'music' or 'env' to select a content type specific openl3 model.
+ -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec.
+ Returns:
+ -- list of embeddings: [np.array[], ...], as expected by calculate_frechet_distance()
+ """
+ _, extension = os.path.splitext(directory_path)
+ if extension.lower() == ".scp":
+ wav_files = []
+ with open(directory_path, "r") as f:
+ for line in f:
+ sec = line.strip().split(" ")
+ wav_files.append(sec[1])
+ else:
+ wav_files = glob.glob(directory_path)
+ if len(wav_files) == 0:
+ raise ValueError('No files with this extension in this path!')
+ model = openl3.models.load_audio_embedding_model(input_repr="mel256", content_type=content_type, embedding_size=512)
+
+ first = True
+ for file in tqdm(wav_files):
+ audio, sr = librosa.load(file, sr=None)
+ audio = pyln.normalize.peak(audio, -1.0)
+ if audio.shape[0] < sr:
+ print('Audio shorter than 1 sec, openl3 will zero-pad it:', file, audio.shape, sr)
+
+ # resample to the desired evaluation bandwidth
+ audio = soxr.resample(audio, sr, samplingrate) # mono/stereo <- mono/stereo, input sr, output sr
+
+ # extracting stereo embeddings (dim=1024), since we concatenate L (dim=512) and R (dim=512) embeddings
+ if channels == 2:
+ if audio.ndim == 1:
+ audio_l3, sr_l3 = audio, samplingrate
+ elif audio.ndim == 2:
+ # if it's stereo separate channels for openl3
+ audio_l3 = [audio[:,0], audio[:,1]]
+ sr_l3 = [samplingrate, samplingrate]
+ emb, _ = openl3.get_audio_embedding(audio_l3, sr_l3, model=model, verbose=False, hop_size=openl3_hop_size)
+ if audio.ndim == 1:
+ # if mono audio, "fake" stereo by concatenating mono embedding as L and R embeddings
+ emb = np.concatenate([emb, emb],axis=1)
+ elif audio.ndim == 2:
+ emb = np.concatenate(emb,axis=1)
+
+ # or extracting mono embeddings (dim=512)
+ elif channels == 1:
+ emb, _ = openl3.get_audio_embedding(audio, samplingrate, model=model, verbose=False, hop_size=openl3_hop_size)
+
+ # concatenate embeddings
+ if first:
+ embeddings = emb
+ first = False
+ else:
+ embeddings = np.concatenate([embeddings, emb], axis=0)
+
+ # return as a list of embeddings: [np.array[], ...]
+ return [e for e in embeddings]
+
+
+def openl3_fd(channels, samplingrate, content_type, openl3_hop_size, eval_path,
+ eval_files_extension='.wav', ref_path=None, ref_files_extension='.wav', load_ref_embeddings=None, batching=False):
+ """
+ Compute the Fréchet Distance between files in eval_path and ref_path.
+
+ Fréchet distance computed on top of openl3 embeddings.
+
+ GPU-based computation.
+
+ Extracting the embeddings is time-consuming. After being computed once, we store them.
+ We store pre-computed reference embedding statistics in load/openl3_fd/
+ To load those and save computation, just set the path in load_ref_embeddings.
+ If load_ref_embeddings is set, ref_path is not required.
+
+ Params:
+ -- channels: 1 (mono), or 2 (stereo) to get the Fréchet Distance over mono or stereo embeddings.
+ -- samplingrate: max bandwidth at which we evaluate the given signals. Up to 48kHz.
+ -- content_type: 'music' or 'env' to select a content type for openl3.
+ -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec.
+ -- eval_path: path where the generated audio files to evaluate are available.
+ -- eval_files_extension: files extension (default .wav) in eval_path.
+ -- ref_path: path where the reference audio files are available. (instead of load_ref_embeddings)
+ -- ref_files_extension: files extension (default .wav) in ref_path.
+ -- load_ref_embeddings: path to the reference embedding statistics. (instead of ref_path)
+ -- batching: set batch size (with an int) or set to False (default False).
+ Returns:
+ -- Fréchet distance.
+ """
+
+ if not os.path.isdir(eval_path):
+ raise ValueError('eval_path does not exist')
+
+ if load_ref_embeddings:
+ if not os.path.exists(load_ref_embeddings):
+ raise ValueError('load_ref_embeddings does not exist')
+ print('[LOADING REFERENCE EMBEDDINGS] ', load_ref_embeddings)
+ loaded = np.load(load_ref_embeddings)
+ mu_ref = loaded['mu_ref']
+ sigma_ref = loaded['sigma_ref']
+
+ else:
+ if ref_path:
+ if not os.path.isdir(ref_path):
+ if not os.path.isfile(ref_path):
+ raise ValueError("ref_path does not exist")
+ if os.path.isfile(ref_path):
+ path = ref_path
+ else:
+ path = os.path.join(ref_path, '*'+ref_files_extension)
+ print('[EXTRACTING REFERENCE EMBEDDINGS] ', path)
+ if batching:
+ ref_embeddings = extract_embeddings(path, channels, samplingrate, content_type, openl3_hop_size, batch_size=batching)
+ else:
+ ref_embeddings = extract_embeddings_nobatching(path, channels, samplingrate, content_type, openl3_hop_size)
+ mu_ref, sigma_ref = calculate_embd_statistics(ref_embeddings)
+
+ # store statistics to load later on
+ if not os.path.exists('load/openl3_fd'):
+ os.makedirs('load/openl3_fd/')
+ save_ref_embeddings_path = (
+ 'load/openl3_fd/' +
+ path.replace('/', '_') +
+ '__channels' + str(channels) +
+ '__' + str(samplingrate) +
+ '__openl3' + str(content_type) +
+ '__openl3hopsize' + str(openl3_hop_size) +
+ '__batch' + str(batching) +
+ '.npz'
+ )
+ np.savez(save_ref_embeddings_path, mu_ref=mu_ref, sigma_ref=sigma_ref)
+ print('[REFERENCE EMBEDDINGS][SAVED] ', save_ref_embeddings_path)
+
+ else:
+ raise ValueError('Must specify ref_path or load_ref_embeddings')
+
+ path = os.path.join(eval_path, '*'+eval_files_extension)
+ print('[EXTRACTING EVALUATION EMBEDDINGS] ', path)
+ if batching:
+ eval_embeddings = extract_embeddings(path, channels, samplingrate, content_type, openl3_hop_size, batch_size=batching)
+ else:
+ eval_embeddings = extract_embeddings_nobatching(path, channels, samplingrate, content_type, openl3_hop_size)
+ mu_eval, sigma_eval = calculate_embd_statistics(eval_embeddings)
+
+ fd = calculate_frechet_distance(mu_eval, sigma_eval, mu_ref, sigma_ref)
+ if load_ref_embeddings:
+ print('[FRÉCHET DISTANCE] ', eval_path, load_ref_embeddings, fd)
+ else:
+ print('[FRÉCHET DISTANCE] ', eval_path, ref_path, fd)
+
+ return fd
\ No newline at end of file
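
A brief usage sketch of `openl3_fd()` as defined above, with placeholder paths. The first run extracts reference embeddings and caches their statistics under load/openl3_fd/, so later runs can pass `load_ref_embeddings` instead of `ref_path`.

```python
# Sketch: mono openl3 Fréchet Distance between generated and reference audio.
from inspiremusic.metrics.openl3_fd import openl3_fd

fd = openl3_fd(
    channels=1,                        # mono embeddings (dim=512)
    samplingrate=16000,                # evaluation bandwidth, up to 48 kHz
    content_type='music',
    openl3_hop_size=0.5,               # seconds
    eval_path='exp/generations/',      # placeholder
    ref_path='data/reference_wavs/',   # placeholder
    batching=16,
)
print('openl3 FD:', fd)
```
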
diff --git a/inspiremusic/metrics/passt_kld.py b/inspiremusic/metrics/passt_kld.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa27835ee82161f9e7cd1c7f9d99c07409bbbfc0
--- /dev/null
+++ b/inspiremusic/metrics/passt_kld.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+import os
+import contextlib
+from functools import partial
+from tqdm import tqdm
+import pickle
+import numpy as np
+import librosa
+from hear21passt.base import get_basic_model
+import pyloudnorm as pyln
+
+import torch
+import torch.nn.functional as F
+
+
+SAMPLING_RATE = 32000
+
+
+class _patch_passt_stft:
+ """
+ From version 1.8.0, return_complex must always be given explicitly
+ for real inputs and return_complex=False has been deprecated.
+
+ Decorator to patch torch.stft in PaSST that uses an old stft version.
+
+ Adapted from: https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py
+ """
+ def __init__(self):
+ self.old_stft = torch.stft
+
+ def __enter__(self):
+ # return_complex is a mandatory parameter in latest torch versions.
+ # torch is throwing RuntimeErrors when not set.
+ # see: https://pytorch.org/docs/1.7.1/generated/torch.stft.html?highlight=stft#torch.stft
+ # see: https://github.com/kkoutini/passt_hear21/commit/dce83183674e559162b49924d666c0a916dc967a
+ torch.stft = partial(torch.stft, return_complex=False)
+
+ def __exit__(self, *exc):
+ torch.stft = self.old_stft
+
+
+def return_probabilities(model, audio_path, window_size=10, overlap=5, collect='mean'):
+ """
+ Given an audio and the PaSST model, return the probabilities of each AudioSet class.
+
+ Audio is converted to mono at 32kHz.
+
+ PaSST model is trained with 10 sec inputs. We refer to this parameter as the window_size.
+ We set it to 10 sec for consistency with PaSST training.
+
+ For longer audios, we split the audio into overlapping analysis windows of window_size (10 sec) with an overlap of 5 sec.
+ PaSST supports 10, 20 or 30 sec inputs, but not longer ones: https://github.com/kkoutini/PaSST/issues/19
+
+ Note that AudioSet taggers normally use sigmoid output layers. Yet, to compute the
+ KL we work with normalized probabilities by running a softmax over logits as in MusicGen:
+ https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py
+
+ This implementation assumes run will be on GPU.
+
+ Params:
+ -- model: PaSST model on a GPU.
+ -- audio_path: path to the audio to be loaded with librosa.
+ -- window_size (default=10 sec): analysis window (and receptive field) of PaSST.
+ -- overlap (default=5 sec): overlap of the running analysis window for inputs longer than window_size (10 sec).
+ -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling along logits vector.
+ Returns:
+ -- 527 probabilities (after softmax, no logarithm).
+ """
+ # load the audio using librosa
+ audio, _ = librosa.load(audio_path, sr=SAMPLING_RATE, mono=True)
+ audio = pyln.normalize.peak(audio, -1.0)
+
+ # calculate the step size for the analysis windows with the specified overlap
+ step_size = int((window_size - overlap) * SAMPLING_RATE)
+
+ # iterate over the audio, creating analysis windows
+ probabilities = []
+ for i in range(0, max(step_size, len(audio) - step_size), step_size):
+ # extract the current analysis window
+ window = audio[i:i + int(window_size * SAMPLING_RATE)]
+
+ # pad the window with zeros if it's shorter than the desired window size
+ if len(window) < int(window_size * SAMPLING_RATE):
+ # discard window if it's too small (avoid mostly zeros predicted as silence), as in MusicGen:
+ # https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py
+ if len(window) > int(window_size * SAMPLING_RATE * 0.15):
+ tmp = np.zeros(int(window_size * SAMPLING_RATE))
+ tmp[:len(window)] = window
+ window = tmp
+
+ # convert to a PyTorch tensor and move to GPU
+ audio_wave = torch.from_numpy(window.astype(np.float32)).unsqueeze(0).cuda()
+
+ # get the probabilities for this analysis window
+ with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
+ with torch.no_grad(), _patch_passt_stft():
+ logits = model(audio_wave)
+ probabilities.append(torch.squeeze(logits))
+
+ probabilities = torch.stack(probabilities)
+ if collect == 'mean':
+ probabilities = torch.mean(probabilities, dim=0)
+ elif collect == 'max':
+ probabilities, _ = torch.max(probabilities, dim=0)
+
+ return F.softmax(probabilities, dim=0).squeeze().cpu()
+
+
+def passt_kld(ids, eval_path, eval_files_extension='.wav', ref_path=None, ref_files_extension='.wav', load_ref_probabilities=None, no_ids=[], collect='mean'):
+ """
+ Compute KL-divergence between the label probabilities of the generated audio with respect to the original audio.
+ Both generated audio (in eval_path) and original audio (in ref_path) are represented by the same prompt/description.
+ Audios are identified by an id, that is the name of the file in both directories and links the audio with the prompt/description.
+
+ For inputs longer than the 10 sec PaSST was trained on, we aggregate/collect via 'mean' (default) or 'max' pooling along the logits vector.
+ We split the input into overlapping analysis windows. Subsequently, we aggregate/collect (across windows) the generated logits and then apply a softmax.
+
+ This evaluation script assumes that ids are in both ref_path and eval_path.
+
+ We label probabilities via the PaSST model: https://github.com/kkoutini/PaSST
+
+ GPU-based computation.
+
+ Extracting the probabilities is time-consuming. After being computed once, we store them.
+ We store pre-computed reference probabilities in load/
+ To load those and save computation, just set the path in load_ref_probabilities.
+ If load_ref_probabilities is set, ref_path is not required.
+
+ Params:
+ -- ids: list of ids present in both eval_path and ref_path.
+ -- eval_path: path where the generated audio files to evaluate are available.
+ -- eval_files_extension: files extension (default .wav) in eval_path.
+ -- ref_path: path where the reference audio files are available. (instead of load_ref_probabilities)
+ -- ref_files_extension: files extension (default .wav) in ref_path.
+ -- load_ref_probabilities: path to the reference probabilities. (instead of ref_path)
+ -- no_ids: it is possible that some reference audio is corrupted or not present. Ids in this list are ignored.
+ -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling along the logits vector.
+ Returns:
+ -- KL divergence
+ """
+ with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): # capturing all useless outputs from passt
+ # load model
+ model = get_basic_model(mode="logits")
+ model.eval()
+ model = model.cuda()
+
+ if not os.path.isdir(eval_path):
+ if not os.path.isfile(eval_path):
+ raise ValueError('eval_path does not exist')
+
+ if load_ref_probabilities:
+ if not os.path.exists(load_ref_probabilities):
+ raise ValueError('load_ref_probabilities does not exist')
+ print('[LOADING REFERENCE PROBABILITIES] ', load_ref_probabilities)
+ with open(load_ref_probabilities, 'rb') as fp:
+ ref_p = pickle.load(fp)
+
+ else:
+ if ref_path:
+ if not os.path.isdir(ref_path):
+ if os.path.isfile(ref_path):
+ id2utt = {}
+ with open(ref_path, "r") as f:
+ for line in f:
+ sec = line.strip().split(" ")
+ id2utt[sec[0]] = sec[1]
+ else:
+ raise ValueError("ref_path does not exist")
+ print('[EXTRACTING REFERENCE PROBABILITIES] ', ref_path)
+ ref_p = {}
+ for id in tqdm(ids):
+ if id not in no_ids:
+ try:
+ if os.path.isfile(ref_path):
+ if id in id2utt.keys():
+ audio_path = id2utt[id]
+ else:
+ raise ValueError(f"id: {id} not in {ref_path}!")
+ else:
+ audio_path = os.path.join(ref_path, str(id)+ref_files_extension)
+ if os.path.isfile(audio_path):
+ ref_p[id] = return_probabilities(model, audio_path, collect=collect)
+ except Exception as e:
+ print(f"An unexpected error occurred with {id}: {e}\nIf you failed to download it you can add it to no_ids list.")
+
+ # store reference probabilities to load later on
+ if not os.path.exists('load/passt_kld/'):
+ os.makedirs('load/passt_kld/')
+ save_ref_probabilities_path = 'load/passt_kld/'+ref_path.replace('/', '_')+'_collect'+str(collect)+'__reference_probabilities.pkl'
+ with open(save_ref_probabilities_path, 'wb') as fp:
+ pickle.dump(ref_p, fp)
+ print('[REFERENCE EMBEDDINGS][SAVED] ', save_ref_probabilities_path)
+
+ else:
+ raise ValueError('Must specify ref_path or load_ref_probabilities')
+
+ print('[EVALUATING GENERATIONS] ', eval_path)
+
+ passt_kl = 0
+ count = 0
+ for id in tqdm(ids):
+ if id not in no_ids:
+ try:
+ audio_path = os.path.join(eval_path, str(id)+eval_files_extension)
+ if os.path.isfile(audio_path):
+ eval_p = return_probabilities(model, audio_path, collect=collect)
+ # note: F.kl_div(x, y) is KL(y||x)
+ # see: https://github.com/pytorch/pytorch/issues/7337
+ # see: https://discuss.pytorch.org/t/kl-divergence-different-results-from-tf/56903/2
+ passt_kl += F.kl_div((ref_p[id] + 1e-6).log(), eval_p, reduction='sum', log_target=False)
+ count += 1
+ except Exception as e:
+ print(f"An unexpected error occurred with {id}: {e}\nIf you failed to download it you can add it to no_ids list.")
+ return passt_kl / count if count > 0 else 0
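
A brief usage sketch of `passt_kld()` as defined above. Paths and ids are placeholders; a GPU is required since the PaSST model is moved to CUDA, and the first run caches reference probabilities under load/passt_kld/.

```python
# Sketch: PaSST-based KL divergence between generated and reference audio.
from inspiremusic.metrics.passt_kld import passt_kld

ids = ["0001", "0002", "0003"]            # placeholder ids present in both folders
kld = passt_kld(
    ids,
    eval_path='exp/generations/',         # placeholder
    ref_path='data/reference_wavs/',      # placeholder
    collect='mean',
)
print('PaSST KLD:', kld)
```
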
diff --git a/inspiremusic/music_tokenizer/__init__.py b/inspiremusic/music_tokenizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inspiremusic/music_tokenizer/env.py b/inspiremusic/music_tokenizer/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5843b8a244e9648bcd6f9e085dff9faa2e921a
--- /dev/null
+++ b/inspiremusic/music_tokenizer/env.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def build_env(config, config_name, path):
+ t_path = os.path.join(path, config_name)
+ if config != t_path:
+ os.makedirs(path, exist_ok=True)
+ shutil.copyfile(config, os.path.join(path, config_name))
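
A small sketch of how `AttrDict` and `build_env` above are typically used with a HiFi-GAN style JSON config; the config path and the `sampling_rate`/`n_fft` keys are illustrative assumptions.

```python
import json
from inspiremusic.music_tokenizer.env import AttrDict, build_env

with open('music_tokenizer/config.json') as f:   # placeholder path
    h = AttrDict(json.load(f))
print(h.sampling_rate, h.n_fft)                  # config keys become attributes

# copy the config next to the experiment outputs (HiFi-GAN convention)
build_env('music_tokenizer/config.json', 'config.json', 'exp/run1')
```
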
diff --git a/inspiremusic/music_tokenizer/meldataset.py b/inspiremusic/music_tokenizer/meldataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a13b6247ec04b74559976687ff108c3380d42f1
--- /dev/null
+++ b/inspiremusic/music_tokenizer/meldataset.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# code based on https://github.com/b04901014/MQTTS
+import math
+import os
+import random
+
+import librosa
+import numpy as np
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+
+def load_wav(full_path, sr):
+ wav, sr = librosa.load(full_path, sr=sr)
+ return wav, sr
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+mel_basis = {}
+hann_window = {}
+
+# modified to call torch.stft with return_complex=True for PyTorch >= 2.0
+def mel_spectrogram(y,
+ n_fft,
+ num_mels,
+ sampling_rate,
+ hop_size,
+ win_size,
+ fmin,
+ fmax,
+ center=False):
+
+ global mel_basis, hann_window
+ if str(fmax) + '_' + str(y.device) not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[str(fmax) + '_' +
+ str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int(
+ (n_fft - hop_size) / 2)),
+ mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.view_as_real(torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[str(y.device)],
+ center=center,
+ pad_mode='reflect',
+ normalized=False,
+ onesided=True,
+ return_complex=True
+ ))
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+ spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
+
+
+def get_dataset_filelist(a):
+ with open(a.input_training_file, 'r') as f:
+ training_files = [l.strip() for l in f]
+ with open(a.input_validation_file, 'r') as f:
+ validation_files = [l.strip() for l in f]
+ return training_files, validation_files
+
+
+class MelDataset(torch.utils.data.Dataset):
+ def __init__(self,
+ training_files,
+ segment_size,
+ n_fft,
+ num_mels,
+ hop_size,
+ win_size,
+ sampling_rate,
+ fmin,
+ fmax,
+ split=True,
+ shuffle=True,
+ n_cache_reuse=1,
+ device=None,
+ fmax_loss=None,
+ fine_tuning=False,
+ base_mels_path=None):
+ self.audio_files = training_files
+ random.seed(1234)
+ if shuffle:
+ random.shuffle(self.audio_files)
+ self.segment_size = segment_size
+ self.sampling_rate = sampling_rate
+ self.split = split
+ self.n_fft = n_fft
+ self.num_mels = num_mels
+ self.hop_size = hop_size
+ self.win_size = win_size
+ self.fmin = fmin
+ self.fmax = fmax
+ self.fmax_loss = fmax_loss
+ self.cached_wav = None
+ self.n_cache_reuse = n_cache_reuse
+ self._cache_ref_count = 0
+ self.device = device
+ self.fine_tuning = fine_tuning
+ self.base_mels_path = base_mels_path
+
+ def __getitem__(self, index):
+ filename = self.audio_files[index]
+ if self._cache_ref_count == 0:
+ try:
+ # Note by yuantian: load with the sample_rate of config
+ audio, sampling_rate = load_wav(filename, sr=self.sampling_rate)
+ except Exception as e:
+ print(f"Error on audio: {filename}")
+ audio = np.random.normal(size=(160000, )) * 0.05
+ sampling_rate = self.sampling_rate
+ self.cached_wav = audio
+ if sampling_rate != self.sampling_rate:
+ raise ValueError("{} SR doesn't match target {} SR".format(
+ sampling_rate, self.sampling_rate))
+ self._cache_ref_count = self.n_cache_reuse
+ else:
+ audio = self.cached_wav
+ self._cache_ref_count -= 1
+
+ audio = torch.FloatTensor(audio)
+ audio = audio.unsqueeze(0)
+
+ if not self.fine_tuning:
+ if self.split:
+ if audio.size(1) >= self.segment_size:
+ max_audio_start = audio.size(1) - self.segment_size
+ audio_start = random.randint(0, max_audio_start)
+ audio = audio[:, audio_start:audio_start +
+ self.segment_size]
+ else:
+ audio = torch.nn.functional.pad(audio, (
+ 0, self.segment_size - audio.size(1)), 'constant')
+
+ mel = mel_spectrogram(
+ audio,
+ self.n_fft,
+ self.num_mels,
+ self.sampling_rate,
+ self.hop_size,
+ self.win_size,
+ self.fmin,
+ self.fmax,
+ center=False)
+ else:
+ mel = np.load(
+ os.path.join(self.base_mels_path,
+ os.path.splitext(os.path.split(filename)[-1])[0] +
+ '.npy'))
+ mel = torch.from_numpy(mel)
+
+ if len(mel.shape) < 3:
+ mel = mel.unsqueeze(0)
+
+ if self.split:
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
+
+ if audio.size(1) >= self.segment_size:
+ mel_start = random.randint(0,
+ mel.size(2) - frames_per_seg - 1)
+ mel = mel[:, :, mel_start:mel_start + frames_per_seg]
+ audio = audio[:, mel_start * self.hop_size:(
+ mel_start + frames_per_seg) * self.hop_size]
+ else:
+ mel = torch.nn.functional.pad(mel, (
+ 0, frames_per_seg - mel.size(2)), 'constant')
+ audio = torch.nn.functional.pad(audio, (
+ 0, self.segment_size - audio.size(1)), 'constant')
+
+ mel_loss = mel_spectrogram(
+ audio,
+ self.n_fft,
+ self.num_mels,
+ self.sampling_rate,
+ self.hop_size,
+ self.win_size,
+ self.fmin,
+ self.fmax_loss,
+ center=False)
+
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
+
+ def __len__(self):
+ return len(self.audio_files)
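
A quick sketch of calling `mel_spectrogram()` above on a batch of waveforms. The hyper-parameters mirror a typical 24 kHz HiFi-GAN style setup and are illustrative only; the real values come from the tokenizer config.

```python
import torch
from inspiremusic.music_tokenizer.meldataset import mel_spectrogram

y = torch.randn(2, 24000)   # (batch, samples): 1 s of audio at 24 kHz (placeholder values)
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=24000,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)            # (batch, num_mels, frames)
```
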
diff --git a/inspiremusic/music_tokenizer/models.py b/inspiremusic/music_tokenizer/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..86302c699224252e72b05971c6a52a2ba7e8764d
--- /dev/null
+++ b/inspiremusic/music_tokenizer/models.py
@@ -0,0 +1,548 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import AvgPool1d
+from torch.nn import Conv1d
+from torch.nn import Conv2d
+from torch.nn import ConvTranspose1d
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils import spectral_norm
+from torch.nn.utils import weight_norm
+
+from inspiremusic.utils.tokenizer_utils import get_padding
+from inspiremusic.utils.tokenizer_utils import init_weights
+
+LRELU_SLOPE = 0.1
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.h = h
+ self.convs1 = nn.ModuleList([
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1))), weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1))), weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.h = h
+ self.convs = nn.ModuleList([
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1])))
+ ])
+ self.convs.apply(init_weights)
+
+ def forward(self, x):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+ def __init__(self, h):
+ super(Generator, self).__init__()
+ self.h = h
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+ self.conv_pre = weight_norm(
+ Conv1d(512, h.upsample_initial_channel, 7, 1, padding=3))
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u,
+ k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ h.upsample_initial_channel // (2**i),
+ h.upsample_initial_channel // (2**(i + 1)),
+ k,
+ u,
+ # padding=(u//2 + u%2),
+ padding=(k - u) // 2,
+ # output_padding=u%2
+ )))
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2**(i + 1))
+ for j, (k, d) in enumerate(
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+ self.resblocks.append(resblock(h, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3,
+ use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(
+ Conv2d(
+ 1,
+ 32, (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(
+ Conv2d(
+ 32,
+ 128, (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(
+ Conv2d(
+ 128,
+ 512, (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(
+ Conv2d(
+ 512,
+ 1024, (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+ ])
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList([
+ DiscriminatorP(2),
+ DiscriminatorP(3),
+ DiscriminatorP(5),
+ DiscriminatorP(7),
+ DiscriminatorP(11),
+ ])
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorS(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(DiscriminatorS, self).__init__()
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+ ])
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super(MultiScaleDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList([
+ DiscriminatorS(use_spectral_norm=True),
+ DiscriminatorS(),
+ DiscriminatorS(),
+ ])
+ self.meanpools = nn.ModuleList(
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ if i != 0:
+ y = self.meanpools[i - 1](y)
+ y_hat = self.meanpools[i - 1](y_hat)
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def feature_loss(fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean((1 - dr)**2)
+ g_loss = torch.mean(dg**2)
+ loss += (r_loss + g_loss)
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean((1 - dg)**2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+
+class Encoder(torch.nn.Module):
+ def __init__(self, h):
+ super(Encoder, self).__init__()
+ self.h = h
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+ self.conv_pre = weight_norm(Conv1d(1, 32, 7, 1, padding=3))
+ self.normalize = nn.ModuleList()
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(
+ list(
+ reversed(
+ list(zip(h.upsample_rates, h.upsample_kernel_sizes))))):
+ self.ups.append(
+ weight_norm(
+ Conv1d(
+ 32 * (2**i),
+ 32 * (2**(i + 1)),
+ k,
+ u,
+ padding=((k - u) // 2)
+ # padding=(u//2 + u%2)
+ )))
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = 32 * (2**(i + 1))
+ for j, (k, d) in enumerate(
+ zip(
+ list(reversed(h.resblock_kernel_sizes)),
+ list(reversed(h.resblock_dilation_sizes)))):
+ self.resblocks.append(resblock(h, ch, k, d))
+ self.normalize.append(
+ torch.nn.GroupNorm(ch // 16, ch, eps=1e-6, affine=True))
+ self.conv_post = Conv1d(512, 512, 3, 1, padding=1)
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ xs = self.normalize[i * self.num_kernels + j](xs)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ xs = self.normalize[i * self.num_kernels + j](xs)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+
+
+class Quantizer_module(torch.nn.Module):
+ def __init__(self, n_e, e_dim):
+ super(Quantizer_module, self).__init__()
+ self.embedding = nn.Embedding(n_e, e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
+
+ def forward(self, x):
+ # compute Euclidean distance
+ d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \
+ - 2 * torch.matmul(x, self.embedding.weight.T)
+ min_indicies = torch.argmin(d, 1)
+ z_q = self.embedding(min_indicies)
+ return z_q, min_indicies
+
+
+class Quantizer(torch.nn.Module):
+ def __init__(self, h):
+ super(Quantizer, self).__init__()
+ assert 512 % h.n_code_groups == 0
+ self.quantizer_modules = nn.ModuleList([
+ Quantizer_module(h.n_codes, 512 // h.n_code_groups)
+ for _ in range(h.n_code_groups)
+ ])
+ self.quantizer_modules2 = nn.ModuleList([
+ Quantizer_module(h.n_codes, 512 // h.n_code_groups)
+ for _ in range(h.n_code_groups)
+ ])
+ self.h = h
+ self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1
+ self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25
+ self.residul_layer = 2
+ self.n_code_groups = h.n_code_groups
+
+ def for_one_step(self, xin, idx):
+ xin = xin.transpose(1, 2)
+ x = xin.reshape(-1, 512)
+ x = torch.split(x, 512 // self.h.n_code_groups, dim=-1)
+ min_indicies = []
+ z_q = []
+ if idx == 0:
+ for _x, m in zip(x, self.quantizer_modules):
+ _z_q, _min_indicies = m(_x)
+ z_q.append(_z_q)
+ min_indicies.append(_min_indicies) #B * T,
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
+ # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
+ loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \
+ + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2)
+ z_q = xin + (z_q - xin).detach()
+ z_q = z_q.transpose(1, 2)
+ return z_q, loss, min_indicies
+ else:
+ for _x, m in zip(x, self.quantizer_modules2):
+ _z_q, _min_indicies = m(_x)
+ z_q.append(_z_q)
+ min_indicies.append(_min_indicies) #B * T,
+ z_q = torch.cat(z_q, -1).reshape(xin.shape)
+ # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
+ loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \
+ + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2)
+ z_q = xin + (z_q - xin).detach()
+ z_q = z_q.transpose(1, 2)
+ return z_q, loss, min_indicies
+
+ def forward(self, xin):
+ #B, C, T
+ quantized_out = 0.0
+ residual = xin
+ all_losses = []
+ all_indices = []
+ for i in range(self.residul_layer):
+ quantized, loss, indices = self.for_one_step(residual, i) #
+ residual = residual - quantized
+ quantized_out = quantized_out + quantized
+ all_indices.extend(indices) #
+ all_losses.append(loss)
+ all_losses = torch.stack(all_losses)
+ loss = torch.mean(all_losses)
+ return quantized_out, loss, all_indices
+
+ def embed(self, x):
+ #idx: N, T, 4
+ #print('x ', x.shape)
+ quantized_out = torch.tensor(0.0, device=x.device)
+ x = torch.split(x, 1, 2) # split along the last dimension; each slice belongs to one index group
+ #print('x.shape ', len(x),x[0].shape)
+ for i in range(self.residul_layer):
+ ret = []
+ if i == 0:
+ for j in range(self.n_code_groups):
+ q = x[j]
+ embed = self.quantizer_modules[j]
+ q = embed.embedding(q.squeeze(-1).long())
+ ret.append(q)
+ ret = torch.cat(ret, -1)
+ #print(ret.shape)
+ quantized_out = quantized_out + ret
+ else:
+ for j in range(self.n_code_groups):
+ q = x[j + self.n_code_groups]
+ embed = self.quantizer_modules2[j]
+ q = embed.embedding(q.squeeze(-1).long())
+ ret.append(q)
+ ret = torch.cat(ret, -1)
+ quantized_out = quantized_out + ret
+ return quantized_out.transpose(1, 2) #N, C, T
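
A minimal sketch of the two-stage residual vector quantization implemented by `Quantizer` above. The config values are illustrative placeholders; the real ones come from the music_tokenizer config.

```python
import torch
from inspiremusic.music_tokenizer.env import AttrDict
from inspiremusic.music_tokenizer.models import Quantizer

h = AttrDict({
    'n_code_groups': 4,              # 512-dim features split into 4 groups of 128
    'n_codes': 1024,                 # codebook size per group
    'codebook_loss_lambda': 1.0,
    'commitment_loss_lambda': 0.25,
})
q = Quantizer(h)

x = torch.randn(2, 512, 50)          # (B, C=512, T)
quantized, vq_loss, indices = q(x)
# two residual stages x 4 groups -> 8 index tensors, one per codebook
print(quantized.shape, vq_loss.item(), len(indices))
```
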
diff --git a/inspiremusic/music_tokenizer/vqvae.py b/inspiremusic/music_tokenizer/vqvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..553275b3c25bb2f44bc9009ebaacaef7a346e206
--- /dev/null
+++ b/inspiremusic/music_tokenizer/vqvae.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import torch
+import torch.nn as nn
+from inspiremusic.music_tokenizer.env import AttrDict
+from inspiremusic.music_tokenizer.models import Encoder
+from inspiremusic.music_tokenizer.models import Generator
+from inspiremusic.music_tokenizer.models import Quantizer
+
+
+class VQVAE(nn.Module):
+ def __init__(self,
+ config_path,
+ ckpt_path,
+ with_encoder=False):
+ super(VQVAE, self).__init__()
+ ckpt = torch.load(ckpt_path)
+ with open(config_path) as f:
+ data = f.read()
+ json_config = json.loads(data)
+ self.h = AttrDict(json_config)
+ self.quantizer = Quantizer(self.h)
+ self.generator = Generator(self.h)
+ self.generator.load_state_dict(ckpt['generator'])
+ self.quantizer.load_state_dict(ckpt['quantizer'])
+ if with_encoder:
+ self.encoder = Encoder(self.h)
+ self.encoder.load_state_dict(ckpt['encoder'])
+
+ def forward(self, x):
+ # x is the codebook
+ # x.shape (B, T, Nq)
+ quant_emb = self.quantizer.embed(x)
+ return self.generator(quant_emb)
+
+ def encode(self, x):
+ batch_size = x.size(0)
+ if len(x.shape) == 3 and x.shape[-1] == 1:
+ x = x.squeeze(-1)
+ c = self.encoder(x.unsqueeze(1))
+ q, loss_q, c = self.quantizer(c)
+ c = [code.reshape(batch_size, -1) for code in c]
+ # shape: [N, T, 4]
+ return torch.stack(c, -1)
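
A minimal sketch of loading the music tokenizer and round-tripping audio through it. The config/checkpoint paths are placeholders for the released music_tokenizer assets, which must exist for `VQVAE.__init__` to load its state dicts.

```python
import torch
from inspiremusic.music_tokenizer.vqvae import VQVAE

model = VQVAE('music_tokenizer/config.json',    # placeholder paths
              'music_tokenizer/model.pt',
              with_encoder=True).eval()

wav = torch.randn(1, 24000)          # (B, samples), placeholder waveform
with torch.no_grad():
    codes = model.encode(wav)        # (B, T', Nq) discrete token ids
    recon = model(codes)             # (B, 1, samples') reconstructed waveform
print(codes.shape, recon.shape)
```
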
diff --git a/inspiremusic/text/abs_tokenizer.py b/inspiremusic/text/abs_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..050e3a79bc9f195b6bc069b71327d69651fa2d1b
--- /dev/null
+++ b/inspiremusic/text/abs_tokenizer.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC
+from abc import abstractmethod
+from typing import Iterable
+from typing import List
+
+
+class AbsTokenizer(ABC):
+ @abstractmethod
+ def text2tokens(self, line: str) -> List[str]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def tokens2text(self, tokens: Iterable[str]) -> str:
+ raise NotImplementedError
+
+
+
+ def encode(self, line: str, **kwargs) -> List[str]:
+
+ return self.text2tokens(line)
\ No newline at end of file
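
A minimal sketch of implementing the `AbsTokenizer` interface above with a trivial whitespace tokenizer, mainly to show which methods a subclass must provide.

```python
from typing import Iterable, List
from inspiremusic.text.abs_tokenizer import AbsTokenizer

class WhitespaceTokenizer(AbsTokenizer):
    def text2tokens(self, line: str) -> List[str]:
        return line.split()

    def tokens2text(self, tokens: Iterable[str]) -> str:
        return " ".join(tokens)

tok = WhitespaceTokenizer()
assert tok.encode("jazz piano solo") == ["jazz", "piano", "solo"]
```
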
diff --git a/inspiremusic/text/tokenizer.py b/inspiremusic/text/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b2978783ce1833b797dad58c35ebebe64a492cf
--- /dev/null
+++ b/inspiremusic/text/tokenizer.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import re
+from typing import Iterable, List, Union
+import numpy as np
+import torch
+
+from inspiremusic.text.abs_tokenizer import AbsTokenizer
+from transformers import AutoTokenizer
+
+def get_tokenizer(tokenizer_name, tokenizer_path):
+ if "qwen" in tokenizer_name:
+ return QwenTokenizer(tokenizer_path,skip_special_tokens=True)
+ else:
+ return None
+
+class QwenTokenizer(AbsTokenizer):
+ def __init__(
+ self,
+ token_path: str,
+ skip_special_tokens: bool = True,
+ ):
+ super().__init__()
+ # NOTE: for the non-chat model, all these special tokens remain randomly initialized.
+ special_tokens = {
+ 'eos_token': '<|endoftext|>',
+ 'pad_token': '<|endoftext|>',
+ 'additional_special_tokens': [
+ '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+ '[breath]', '<strong>', '</strong>', '[noise]',
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
+ '[quick_breath]',
+ ]
+ }
+ self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+ self.tokenizer.add_special_tokens(special_tokens)
+ self.skip_special_tokens = skip_special_tokens
+
+ def get_vocab_size(self):
+ return self.tokenizer.vocab_size
+
+ def text2tokens(self, line: str) -> List:
+ tokens = self.tokenizer([line], return_tensors="pt")
+ tokens = tokens["input_ids"][0].cpu().tolist()
+ return tokens
+
+ def tokens2text(self, tokens) -> str:
+ tokens = torch.tensor(tokens, dtype=torch.int64)
+ text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+ return text
+
+
+
+def get_qwen_vocab_size(token_type: str):
+ if "qwen1.5" in token_type.lower() or "qwen2.0" in token_type.lower() or "qwen2.5" in token_type.lower():
+ # 293 for special and extra tokens, including endoftext, im_start, im_end, endofprompt and others in the future.
+ # model.vocab_size = 151936, tokenizer.vocab_size = 151643
+ # NOTE: the first three special tokens (endoftext, im_start, im_end) are trained in Chat series models,
+ # others are kept in random initialization state.
+ return 151643 + 293
+ else:
+ raise ValueError(f"Unknown tokenizer {token_type}")
\ No newline at end of file
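
A brief sketch of round-tripping a prompt through `QwenTokenizer` via `get_tokenizer()` above. The tokenizer path is a placeholder for a local Qwen tokenizer directory or Hugging Face id and is fetched on first use.

```python
from inspiremusic.text.tokenizer import get_tokenizer

tok = get_tokenizer("qwen2.5", "Qwen/Qwen2.5-0.5B")   # placeholder tokenizer path/id
ids = tok.text2tokens("uplifting orchestral theme with strings")
print(ids[:8])                   # first few token ids
print(tok.tokens2text(ids))      # decodes back to the original text
```
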
diff --git a/inspiremusic/transformer/__init__.py b/inspiremusic/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inspiremusic/transformer/activation.py b/inspiremusic/transformer/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b87727d9d1d0ba6a0aeea5e7df21674f99926787
--- /dev/null
+++ b/inspiremusic/transformer/activation.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
+# 2020 Northwestern Polytechnical University (Pengcheng Guo)
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Swish() activation function for Conformer."""
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Swish(torch.nn.Module):
+ """Construct an Swish object."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Return Swish activation function."""
+ return x * torch.sigmoid(x)
+
+
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = Snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake(x) := x + (1/a) * sin^2(a * x)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
\ No newline at end of file
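
A quick numerical check of the `Snake` activation defined above, snake(x) = x + (1/alpha) * sin^2(alpha * x), applied elementwise over a (B, C, T) tensor.

```python
import torch
from inspiremusic.transformer.activation import Snake

act = Snake(in_features=8, alpha=1.0)
x = torch.randn(2, 8, 16)                 # (B, C, T)
y = act(x)
expected = x + torch.sin(x) ** 2          # alpha = 1; the tiny eps term is negligible
print(torch.allclose(y, expected, atol=1e-6))
```
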
diff --git a/inspiremusic/transformer/attention.py b/inspiremusic/transformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf960e96ece8cc624c8c4b1ccd71a42910bfb62
--- /dev/null
+++ b/inspiremusic/transformer/attention.py
@@ -0,0 +1,328 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Multi-Head Attention layer definition."""
+
+import math
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self,
+ n_head: int,
+ n_feat: int,
+ dropout_rate: float,
+ key_bias: bool = True):
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor, size
+ (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor, size
+ (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor, size
+ (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
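+ # Illustrative (assumed sizes): with n_feat=512 and n_head=8 (d_k=64), a
+ # query of shape (B, T1, 512) is mapped to q of shape (B, 8, T1, 64).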
+
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value, size
+ (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score, size
+ (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+ # 1st chunk to ease the onnx export.]
+ # 2. pytorch training
+ if mask.size(2) > 0: # time2 > 0
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ # For last chunk, time2 might be larger than scores.size(-1)
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
+ scores = scores.masked_fill(mask, -float('inf'))
+ attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0) # (batch, head, time1, time2)
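+ # Illustrative: for key lengths [5, 3] and time2 == 5, the padding mask
+ # marks the last two key positions of the second sample; they get -inf
+ # before the softmax and are zeroed here, contributing no attention weight.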
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+ # 1. onnx(16/-1, -1/-1, 16/0)
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+ self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ 1.When applying cross attention between decoder and encoder,
+ the batch padding mask for input is in (#batch, 1, T) shape.
+ 2.When applying self attention of encoder,
+ the mask is in (#batch, T, T) shape.
+ 3.When applying self attention of decoder,
+ the mask is in (#batch, L, L) shape.
+ 4.If the different position in decoder see different block
+ of the encoder, such as Mocha, the passed in mask could be
+ in (#batch, L, T) shape. But there is no such case in current
+ InspireMusic.
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will always be `True`
+ # and we will always do splitting and
+ # concatenation (this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors(see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache,
+ cache.size(-1) // 2,
+ dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
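+ # Illustrative (assumed sizes): with n_feat=256, n_head=4 (d_k=64) and a
+ # cache of shape (1, 4, 16, 128), an 8-frame chunk yields k/v of length 24
+ # and new_cache of shape (1, 4, 24, 128).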
+
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask), new_cache
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self,
+ n_head: int,
+ n_feat: int,
+ dropout_rate: float,
+ key_bias: bool = True):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+ """Compute relative positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+ time1 means the length of query vector.
+
+ Returns:
+ torch.Tensor: Output tensor.
+
+ """
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+ device=x.device,
+ dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(x.size()[0],
+ x.size()[1],
+ x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)[
+ :, :, :, : x.size(-1) // 2 + 1
+ ] # only keep the positions from 0 to time2
+ return x
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, time2, size).
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will always be `True`
+ # and we will always do splitting and
+ # concatenation (this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors(see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache,
+ cache.size(-1) // 2,
+ dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
+ if matrix_ac.shape != matrix_bd.shape:
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask), new_cache
diff --git a/inspiremusic/transformer/convolution.py b/inspiremusic/transformer/convolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5d96149154776000991a681a666fbe55e562fe
--- /dev/null
+++ b/inspiremusic/transformer/convolution.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""ConvolutionModule definition."""
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model."""
+
+ def __init__(self,
+ channels: int,
+ kernel_size: int = 15,
+ activation: nn.Module = nn.ReLU(),
+ norm: str = "batch_norm",
+ causal: bool = False,
+ bias: bool = True):
+ """Construct an ConvolutionModule object.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (bool): Whether to use causal convolution or not.
+ """
+ super().__init__()
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ # self.lorder is used to distinguish if it's a causal convolution,
+ # if self.lorder > 0: it's a causal convolution, the input will be
+ # padded with self.lorder frames on the left in forward.
+ # else: it's a symmetrical convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+ # kernel_size should be an odd number for non-causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+
+ assert norm in ['batch_norm', 'layer_norm']
+ if norm == "batch_norm":
+ self.use_layer_norm = False
+ self.norm = nn.BatchNorm1d(channels)
+ else:
+ self.use_layer_norm = True
+ self.norm = nn.LayerNorm(channels)
+
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+ (0, 0, 0) means fake mask.
+ cache (torch.Tensor): left context cache, it is only
+ used in causal convolution (#batch, channels, cache_t),
+ (0, 0, 0) means fake cache.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2) # (#batch, channels, time)
+
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ if self.lorder > 0:
+ if cache.size(2) == 0: # cache_t == 0
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
+ else:
+ assert cache.size(0) == x.size(0) # equal batch
+ assert cache.size(1) == x.size(1) # equal channel
+ x = torch.cat((cache, x), dim=2)
+ assert (x.size(2) > self.lorder)
+ new_cache = x[:, :, -self.lorder:]
+ else:
+ # It's better we just return None if no cache is required,
+ # However, for JIT export, here we just fake one tensor instead of
+ # None.
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
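+ # Illustrative (assumed sizes): with kernel_size=15 and causal=True,
+ # self.lorder == 14, so the first chunk is left-padded by 14 frames and
+ # the last 14 frames of x are kept as new_cache for the next chunk.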
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, time)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, time)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.activation(self.norm(x))
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ return x.transpose(1, 2), new_cache
diff --git a/inspiremusic/transformer/decoder.py b/inspiremusic/transformer/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ebf3b5109dc01149874d5c9f8c6414474d303bb
--- /dev/null
+++ b/inspiremusic/transformer/decoder.py
@@ -0,0 +1,396 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Decoder definition."""
+from typing import Tuple, List, Optional
+
+import torch
+import torch.utils.checkpoint as ckpt
+import logging
+
+from inspiremusic.transformer.decoder_layer import DecoderLayer
+from inspiremusic.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from inspiremusic.utils.class_utils import (
+ INSPIREMUSIC_EMB_CLASSES,
+ INSPIREMUSIC_ATTENTION_CLASSES,
+ INSPIREMUSIC_ACTIVATION_CLASSES,
+)
+from inspiremusic.utils.mask import (subsequent_mask, make_pad_mask)
+
+
+class TransformerDecoder(torch.nn.Module):
+ """Base class of Transfomer decoder module.
+ Args:
+ vocab_size: output dim
+ encoder_output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the hidden units number of position-wise feedforward
+ num_blocks: the number of decoder blocks
+ dropout_rate: dropout rate
+ self_attention_dropout_rate: dropout rate for attention
+ input_layer: input layer type
+ use_output_layer: whether to use output layer
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before:
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ src_attention: if false, encoder-decoder cross attention is not
+ applied, as in CIF models
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
+ gradient_checkpointing: rerunning a forward-pass segment for each
+ checkpointed segment during backward.
+ tie_word_embedding: Tie or clone module weights depending on whether we are
+ using TorchScript or not
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ normalize_before: bool = True,
+ src_attention: bool = True,
+ key_bias: bool = True,
+ activation_type: str = "relu",
+ gradient_checkpointing: bool = False,
+ tie_word_embedding: bool = False,
+ ):
+ super().__init__()
+ attention_dim = encoder_output_size
+ activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]()
+
+ self.embed = torch.nn.Sequential(
+ torch.nn.Identity() if input_layer == "no_pos" else
+ torch.nn.Embedding(vocab_size, attention_dim),
+ INSPIREMUSIC_EMB_CLASSES[input_layer](attention_dim,
+ positional_dropout_rate),
+ )
+
+ self.normalize_before = normalize_before
+ self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
+ self.use_output_layer = use_output_layer
+ if use_output_layer:
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.output_layer = torch.nn.Identity()
+ self.num_blocks = num_blocks
+ self.decoders = torch.nn.ModuleList([
+ DecoderLayer(
+ attention_dim,
+ INSPIREMUSIC_ATTENTION_CLASSES["selfattn"](
+ attention_heads, attention_dim,
+ self_attention_dropout_rate, key_bias),
+ INSPIREMUSIC_ATTENTION_CLASSES["selfattn"](
+ attention_heads, attention_dim, src_attention_dropout_rate,
+ key_bias) if src_attention else None,
+ PositionwiseFeedForward(attention_dim, linear_units,
+ dropout_rate, activation),
+ dropout_rate,
+ normalize_before,
+ ) for _ in range(self.num_blocks)
+ ])
+
+ self.gradient_checkpointing = gradient_checkpointing
+ self.tie_word_embedding = tie_word_embedding
+
+ def forward(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ r_ys_in_pad: torch.Tensor = torch.empty(0),
+ reverse_weight: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+ ys_in_lens: input lengths of this batch (batch)
+ r_ys_in_pad: not used in transformer decoder, in order to unify api
+ with bidirectional decoder
+ reverse_weight: not used in transformer decoder, in order to unify
+ api with bidirectional decoder
+ Returns:
+ (tuple): tuple containing:
+ x: decoded token score before softmax (batch, maxlen_out,
+ vocab_size) if use_output_layer is True,
+ torch.tensor(0.0), in order to unify api with bidirectional decoder
+ olens: (batch, )
+ NOTE(xcsong):
+ We pass the `__call__` method of the modules instead of `forward` to the
+ checkpointing API because `__call__` attaches all the hooks of the module.
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+ """
+ tgt = ys_in_pad
+ maxlen = tgt.size(1)
+ # tgt_mask: (B, 1, L)
+ tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
+ tgt_mask = tgt_mask.to(tgt.device)
+ # m: (1, L, L)
+ m = subsequent_mask(tgt_mask.size(-1),
+ device=tgt_mask.device).unsqueeze(0)
+ # tgt_mask: (B, L, L)
+ tgt_mask = tgt_mask & m
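+ # Illustrative: for ys_in_lens == [3, 2] and maxlen == 3, position i may
+ # attend only to positions <= i, and the padded third step of the second
+ # sample is additionally masked out.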
+ x, _ = self.embed(tgt)
+ if self.gradient_checkpointing and self.training:
+ x = self.forward_layers_checkpointed(x, tgt_mask, memory,
+ memory_mask)
+ else:
+ x = self.forward_layers(x, tgt_mask, memory, memory_mask)
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.use_output_layer:
+ x = self.output_layer(x)
+ olens = tgt_mask.sum(1)
+ return x, torch.tensor(0.0), olens
+
+ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor) -> torch.Tensor:
+ for layer in self.decoders:
+ x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
+ memory_mask)
+ return x
+
+ @torch.jit.unused
+ def forward_layers_checkpointed(self, x: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor) -> torch.Tensor:
+ for layer in self.decoders:
+ x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
+ layer.__call__, x, tgt_mask, memory, memory_mask)
+ return x
+
+ def forward_one_step(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ cache: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+ This is only used for decoding.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ `y.shape` is (batch, maxlen_out, token)
+ """
+ x, _ = self.embed(tgt)
+ new_cache = []
+ for i, decoder in enumerate(self.decoders):
+ if cache is None:
+ c = None
+ else:
+ c = cache[i]
+ x, tgt_mask, memory, memory_mask = decoder(x,
+ tgt_mask,
+ memory,
+ memory_mask,
+ cache=c)
+ new_cache.append(x)
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.use_output_layer:
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
+ return y, new_cache
+
+ def tie_or_clone_weights(self, jit_mode: bool = True):
+ """Tie or clone module weights (between word_emb and output_layer)
+ depending on whether we are using TorchScript or not"""
+ if not self.use_output_layer:
+ return
+ if jit_mode:
+ logging.info("clone emb.weight to output.weight")
+ self.output_layer.weight = torch.nn.Parameter(
+ self.embed[0].weight.clone())
+ else:
+ logging.info("tie emb.weight with output.weight")
+ self.output_layer.weight = self.embed[0].weight
+
+ if getattr(self.output_layer, "bias", None) is not None:
+ self.output_layer.bias.data = torch.nn.functional.pad(
+ self.output_layer.bias.data,
+ (
+ 0,
+ self.output_layer.weight.shape[0] -
+ self.output_layer.bias.shape[0],
+ ),
+ "constant",
+ 0,
+ )
+
+
+class BiTransformerDecoder(torch.nn.Module):
+ """Base class of Transfomer decoder module.
+ Args:
+ vocab_size: output dim
+ encoder_output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the hidden units number of position-wise feedforward
+ num_blocks: the number of decoder blocks
+ r_num_blocks: the number of right to left decoder blocks
+ dropout_rate: dropout rate
+ self_attention_dropout_rate: dropout rate for attention
+ input_layer: input layer type
+ use_output_layer: whether to use output layer
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before:
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ r_num_blocks: int = 0,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ normalize_before: bool = True,
+ key_bias: bool = True,
+ gradient_checkpointing: bool = False,
+ tie_word_embedding: bool = False,
+ ):
+
+ super().__init__()
+ self.tie_word_embedding = tie_word_embedding
+ self.left_decoder = TransformerDecoder(
+ vocab_size,
+ encoder_output_size,
+ attention_heads,
+ linear_units,
+ num_blocks,
+ dropout_rate,
+ positional_dropout_rate,
+ self_attention_dropout_rate,
+ src_attention_dropout_rate,
+ input_layer,
+ use_output_layer,
+ normalize_before,
+ key_bias=key_bias,
+ gradient_checkpointing=gradient_checkpointing,
+ tie_word_embedding=tie_word_embedding)
+
+ self.right_decoder = TransformerDecoder(
+ vocab_size,
+ encoder_output_size,
+ attention_heads,
+ linear_units,
+ r_num_blocks,
+ dropout_rate,
+ positional_dropout_rate,
+ self_attention_dropout_rate,
+ src_attention_dropout_rate,
+ input_layer,
+ use_output_layer,
+ normalize_before,
+ key_bias=key_bias,
+ gradient_checkpointing=gradient_checkpointing,
+ tie_word_embedding=tie_word_embedding)
+
+ def forward(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ r_ys_in_pad: torch.Tensor,
+ reverse_weight: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+ ys_in_lens: input lengths of this batch (batch)
+ r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
+ used for right to left decoder
+ reverse_weight: used for right to left decoder
+ Returns:
+ (tuple): tuple containing:
+ x: decoded token score before softmax (batch, maxlen_out,
+ vocab_size) if use_output_layer is True,
+ r_x: decoded token score (right to left decoder)
+ before softmax (batch, maxlen_out, vocab_size)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
+ ys_in_lens)
+ r_x = torch.tensor(0.0)
+ if reverse_weight > 0.0:
+ r_x, _, olens = self.right_decoder(memory, memory_mask,
+ r_ys_in_pad, ys_in_lens)
+ return l_x, r_x, olens
+
+ def forward_one_step(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ cache: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+ This is only used for decoding.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ `y.shape` is (batch, maxlen_out, token)
+ """
+ return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
+ tgt_mask, cache)
+
+ def tie_or_clone_weights(self, jit_mode: bool = True):
+ """Tie or clone module weights (between word_emb and output_layer)
+ depending on whether we are using TorchScript or not"""
+ self.left_decoder.tie_or_clone_weights(jit_mode)
+ self.right_decoder.tie_or_clone_weights(jit_mode)
diff --git a/inspiremusic/transformer/decoder_layer.py b/inspiremusic/transformer/decoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..91c7c5d7fb2a8e79cea7705646e5381016f73466
--- /dev/null
+++ b/inspiremusic/transformer/decoder_layer.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Decoder self-attention layer definition."""
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class DecoderLayer(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ src_attn (torch.nn.Module): Inter-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ If `None` is passed, inter-attention is not used, as in
+ CIF, GPT, and other decoder-only models.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: nn.Module,
+ src_attn: Optional[nn.Module],
+ feed_forward: nn.Module,
+ dropout_rate: float,
+ normalize_before: bool = True,
+ ):
+ """Construct an DecoderLayer object."""
+ super().__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
+ self.norm3 = nn.LayerNorm(size, eps=1e-5)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ cache: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor
+ (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory
+ (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask
+ (#batch, maxlen_in).
+ cache (torch.Tensor): cached tensors.
+ (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+ # compute only the last frame query keeping dim: max_time_out -> 1
+ assert cache.shape == (
+ tgt.shape[0],
+ tgt.shape[1] - 1,
+ self.size,
+ ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ x = residual + self.dropout(
+ self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.dropout(
+ self.src_attn(x, memory, memory, memory_mask)[0])
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm3(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
diff --git a/inspiremusic/transformer/embedding.py b/inspiremusic/transformer/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..eae8c8ecabb15b4174cc3aa73c070ae702bb5f82
--- /dev/null
+++ b/inspiremusic/transformer/embedding.py
@@ -0,0 +1,294 @@
+# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Positonal Encoding Module."""
+
+import math
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+class PositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+
+ :param int d_model: embedding dim
+ :param float dropout_rate: dropout rate
+ :param int max_len: maximum input length
+
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
+ """
+
+ def __init__(self,
+ d_model: int,
+ dropout_rate: float,
+ max_len: int = 5000,
+ reverse: bool = False):
+ """Construct an PositionalEncoding object."""
+ super().__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
+
+ self.pe = torch.zeros(self.max_len, self.d_model)
+ position = torch.arange(0, self.max_len,
+ dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
+ self.pe[:, 0::2] = torch.sin(position * div_term)
+ self.pe[:, 1::2] = torch.cos(position * div_term)
+ self.pe = self.pe.unsqueeze(0)
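+ # e.g. (illustrative) d_model=256, max_len=5000 -> self.pe.shape == (1, 5000, 256)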
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
+ offset (int, torch.tensor): position offset
+
+ Returns:
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+ torch.Tensor: for compatibility to RelPositionalEncoding
+ """
+
+ self.pe = self.pe.to(x.device)
+ pos_emb = self.position_encoding(offset, x.size(1), False)
+ x = x * self.xscale + pos_emb
+ return self.dropout(x), self.dropout(pos_emb)
+
+ def position_encoding(self,
+ offset: Union[int, torch.Tensor],
+ size: int,
+ apply_dropout: bool = True) -> torch.Tensor:
+ """ For getting encoding in a streaming fashion
+
+ Attention!!!!!
+ we apply dropout only once at the whole utterance level in a
+ non-streaming way, but will call this function several times with
+ increasing input size in a streaming scenario, so the dropout will
+ be applied several times.
+
+ Args:
+ offset (int or torch.tensor): start offset
+ size (int): required size of position encoding
+
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ # How to subscript a Union type:
+ # https://github.com/pytorch/pytorch/issues/69434
+ if isinstance(offset, int):
+ assert offset + size <= self.max_len
+ pos_emb = self.pe[:, offset:offset + size]
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
+ assert offset + size <= self.max_len
+ pos_emb = self.pe[:, offset:offset + size]
+ else: # for batched streaming decoding on GPU
+ assert torch.max(offset) + size <= self.max_len
+ index = offset.unsqueeze(1) + \
+ torch.arange(0, size).to(offset.device) # B X T
+ flag = index > 0
+ # remove negative offset
+ index = index * flag
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
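+ # Illustrative: offset == tensor([0, 4]) with size == 3 gives index
+ # [[0, 1, 2], [4, 5, 6]], so pos_emb has shape (2, 3, d_model).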
+
+ if apply_dropout:
+ pos_emb = self.dropout(pos_emb)
+ return pos_emb
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.pe = self.pe.to(x.device)
+ x = x * self.xscale
+ pos_emb = self.position_encoding(offset, x.size(1), False)
+ return self.dropout(x), self.dropout(pos_emb)
+
+
+class WhisperPositionalEncoding(PositionalEncoding):
+ """ Sinusoids position encoding used in openai-whisper.encoder
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
+ super().__init__(d_model, dropout_rate, max_len)
+ self.xscale = 1.0
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment *
+ torch.arange(d_model // 2))
+ scaled_time = torch.arange(max_len)[:, np.newaxis] * \
+ inv_timescales[np.newaxis, :]
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+ delattr(self, "pe")
+ self.register_buffer("pe", pe.unsqueeze(0))
+
+
+class LearnablePositionalEncoding(PositionalEncoding):
+ """ Learnable position encoding used in openai-whisper.decoder
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
+ super().__init__(d_model, dropout_rate, max_len)
+ # NOTE(xcsong): overwrite self.pe & self.xscale
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
+ self.xscale = 1.0
+
+
+class NoPositionalEncoding(torch.nn.Module):
+ """ No position encoding
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float):
+ super().__init__()
+ self.d_model = d_model
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """ Just return zero vector for interface compatibility
+ """
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
+ return self.dropout(x), pos_emb
+
+ def position_encoding(self, offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ return torch.zeros(1, size, self.d_model)
+
+
+class EspnetRelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module (new implementation).
+
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+ """Construct an PositionalEncoding object."""
+ super(EspnetRelPositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x: torch.Tensor):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` is the position of the query vector and `j` the position of
+ # the key vector. We use positive relative positions when keys are to the
+ # left (i>j) and negative relative positions otherwise (i<j).
+ # NOTE: the pe construction below follows the standard ESPnet implementation.
+ pe_positive = torch.zeros(x.size(1), self.d_model)
+ pe_negative = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ # Reverse the order of positive indices and concat both positive and
+ # negative indices; this supports the shifting trick described in
+ # https://arxiv.org/abs/1901.02860
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0) \
+ -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
+ return self.dropout(x), self.dropout(pos_emb)
+
+ def position_encoding(self,
+ offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ """ For getting encoding in a streaming fashion
+
+ Attention!!!!!
+ we apply dropout only once at the whole utterance level in a
+ non-streaming way, but will call this function several times with
+ increasing input size in a streaming scenario, so the dropout will
+ be applied several times.
+
+ Args:
+ offset (int or torch.tensor): start offset
+ size (int): required size of position encoding
+
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
+ ]
+ return pos_emb
diff --git a/inspiremusic/transformer/encoder.py b/inspiremusic/transformer/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a46b531778f3c012f0a7e4f5da3c1d6f707c358d
--- /dev/null
+++ b/inspiremusic/transformer/encoder.py
@@ -0,0 +1,477 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder definition."""
+from typing import Tuple
+
+import torch
+import torch.utils.checkpoint as ckpt
+
+from inspiremusic.transformer.convolution import ConvolutionModule
+from inspiremusic.transformer.encoder_layer import TransformerEncoderLayer
+from inspiremusic.transformer.encoder_layer import ConformerEncoderLayer
+from inspiremusic.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from inspiremusic.utils.class_utils import (
+ INSPIREMUSIC_EMB_CLASSES,
+ INSPIREMUSIC_SUBSAMPLE_CLASSES,
+ INSPIREMUSIC_ATTENTION_CLASSES,
+ INSPIREMUSIC_ACTIVATION_CLASSES,
+)
+from inspiremusic.utils.mask import make_pad_mask
+from inspiremusic.utils.mask import add_optional_chunk_mask
+
+
+class BaseEncoder(torch.nn.Module):
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "abs_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ gradient_checkpointing: bool = False,
+ ):
+ """
+ Args:
+ input_size (int): input dim
+ output_size (int): dimension of attention
+ attention_heads (int): the number of heads of multi head attention
+ linear_units (int): the hidden units number of position-wise feed
+ forward
+ num_blocks (int): the number of encoder blocks
+ dropout_rate (float): dropout rate
+ attention_dropout_rate (float): dropout rate in attention
+ positional_dropout_rate (float): dropout rate after adding
+ positional encoding
+ input_layer (str): input layer type.
+ optional [linear, conv2d, conv2d6, conv2d8]
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
+ optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
+ normalize_before (bool):
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ static_chunk_size (int): chunk size for static chunk training and
+ decoding
+ use_dynamic_chunk (bool): whether to use dynamic chunk size for
+ training or not. You can only use fixed chunk (chunk_size > 0)
+ or dynamic chunk size (use_dynamic_chunk = True)
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
+ dynamic chunk training
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
+ gradient_checkpointing: rerunning a forward-pass segment for each
+ checkpointed segment during backward.
+ """
+ super().__init__()
+ self._output_size = output_size
+
+ self.global_cmvn = global_cmvn
+ self.embed = INSPIREMUSIC_SUBSAMPLE_CLASSES[input_layer](
+ input_size,
+ output_size,
+ dropout_rate,
+ INSPIREMUSIC_EMB_CLASSES[pos_enc_layer_type](output_size,
+ positional_dropout_rate),
+ )
+
+ self.normalize_before = normalize_before
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+ self.static_chunk_size = static_chunk_size
+ self.use_dynamic_chunk = use_dynamic_chunk
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
+ self.gradient_checkpointing = gradient_checkpointing
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs: torch.Tensor,
+ xs_lens: torch.Tensor,
+ decoding_chunk_size: int = 0,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Embed positions in tensor.
+
+ Args:
+ xs: padded input tensor (B, T, D)
+ xs_lens: input length (B)
+ decoding_chunk_size: decoding chunk size for dynamic chunk
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ Returns:
+ encoder output tensor xs, and subsampled masks
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+ masks: torch.Tensor batch padding mask after subsample
+ (B, 1, T' ~= T/subsample_rate)
+ NOTE(xcsong):
+ We pass the `__call__` method of the modules instead of `forward` to the
+ checkpointing API because `__call__` attaches all the hooks of the module.
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+ """
+ T = xs.size(1)
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ xs, pos_emb, masks = self.embed(xs, masks)
+ mask_pad = masks # (B, 1, T/subsample_rate)
+ chunk_masks = add_optional_chunk_mask(xs, masks,
+ self.use_dynamic_chunk,
+ self.use_dynamic_left_chunk,
+ decoding_chunk_size,
+ self.static_chunk_size,
+ num_decoding_left_chunks)
+
+ if self.gradient_checkpointing and self.training:
+ xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
+ mask_pad)
+ else:
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+ # Here we assume the mask is not changed in encoder layers, so just
+ # return the masks before encoder layers, and the masks will be used
+ # for cross attention with decoder later
+ return xs, masks
+
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor) -> torch.Tensor:
+ for layer in self.encoders:
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+ return xs
+
+ @torch.jit.unused
+ def forward_layers_checkpointed(self, xs: torch.Tensor,
+ chunk_masks: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor) -> torch.Tensor:
+ for layer in self.encoders:
+ xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
+ chunk_masks, pos_emb,
+ mask_pad)
+ return xs
+
+ @torch.jit.export
+ def forward_chunk(
+ self,
+ xs: torch.Tensor,
+ offset: int,
+ required_cache_size: int,
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """ Forward just one chunk
+
+ Args:
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
+ where `time == (chunk_size - 1) * subsample_rate + \
+ subsample.right_context + 1`
+ offset (int): current offset in encoder output time stamp
+ required_cache_size (int): cache size required for next chunk
+ computation
+ >=0: actual cache size
+ <0: means all history cache is required
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
+ transformer/conformer attention, with shape
+ (elayers, head, cache_t1, d_k * 2), where
+ `head * d_k == hidden-dim` and
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
+ (elayers, b=1, hidden-dim, cache_t2), where
+ `cache_t2 == cnn.lorder - 1`
+
+ Returns:
+ torch.Tensor: output of current input xs,
+ with shape (b=1, chunk_size, hidden-dim).
+ torch.Tensor: new attention cache required for next chunk, with
+ dynamic shape (elayers, head, ?, d_k * 2)
+ depending on required_cache_size.
+ torch.Tensor: new conformer cnn cache required for next chunk, with
+ same shape as the original cnn_cache.
+
+ """
+ assert xs.size(0) == 1
+ # tmp_masks is just for interface compatibility
+ tmp_masks = torch.ones(1,
+ xs.size(1),
+ device=xs.device,
+ dtype=torch.bool)
+ tmp_masks = tmp_masks.unsqueeze(1)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
+
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
+ chunk_size = xs.size(1)
+ attention_key_size = cache_t1 + chunk_size
+ pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
+ size=attention_key_size)
+ if required_cache_size < 0:
+ next_cache_start = 0
+ elif required_cache_size == 0:
+ next_cache_start = attention_key_size
+ else:
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
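+ # Illustrative: with chunk_size=16 and num_decoding_left_chunks=4,
+ # required_cache_size == 64, so each layer keeps at most the last 64
+ # key/value frames in the returned attention cache.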
+ r_att_cache = []
+ r_cnn_cache = []
+ for i, layer in enumerate(self.encoders):
+ # NOTE(xcsong): Before layer.forward
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
+
+ xs, _, new_att_cache, new_cnn_cache = layer(
+ xs,
+ att_mask,
+ pos_emb,
+ att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
+ # NOTE(xcsong): After layer.forward
+ # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
+ # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
+ # ? may be larger than cache_t1, it depends on required_cache_size
+ r_att_cache = torch.cat(r_att_cache, dim=0)
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
+
+ return (xs, r_att_cache, r_cnn_cache)
+
+ @torch.jit.unused
+ def forward_chunk_by_chunk(
+ self,
+ xs: torch.Tensor,
+ decoding_chunk_size: int,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """ Forward input chunk by chunk with chunk_size like a streaming
+ fashion
+
+ Here we should pay special attention to computation cache in the
+ streaming style forward chunk by chunk. Three things should be taken
+ into account for computation in the current network:
+ 1. transformer/conformer encoder layers output cache
+ 2. convolution in conformer
+ 3. convolution in subsampling
+
+ However, we don't implement subsampling cache for:
+ 1. We can control subsampling module to output the right result by
+ overlapping input instead of cache left context, even though it
+ wastes some computation, but subsampling only takes a very
+ small fraction of computation in the whole model.
+ 2. Typically, there are several convolution layers with subsampling
+ in subsampling module, it is tricky and complicated to do cache
+ with different convolution layers with different subsampling
+ rate.
+ 3. Currently, nn.Sequential is used to stack all the convolution
+ layers in subsampling, we need to rewrite it to make it work
+ with cache, which is not preferred.
+ Args:
+ xs (torch.Tensor): (1, max_len, dim)
+ decoding_chunk_size (int): decoding chunk size
+ """
+ assert decoding_chunk_size > 0
+ # The model is trained by static or dynamic chunk
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
+ subsampling = self.embed.subsampling_rate
+ context = self.embed.right_context + 1 # Add current frame
+ stride = subsampling * decoding_chunk_size
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
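+ # Illustrative (assuming 4x subsampling with right_context == 6):
+ # decoding_chunk_size=16 gives stride == 64 input frames and
+ # decoding_window == (16 - 1) * 4 + 7 == 67 frames per chunk.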
+ num_frames = xs.size(1)
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ outputs = []
+ offset = 0
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+
+ # Feed forward overlap input step by step
+ for cur in range(0, num_frames - context + 1, stride):
+ end = min(cur + decoding_window, num_frames)
+ chunk_xs = xs[:, cur:end, :]
+ (y, att_cache,
+ cnn_cache) = self.forward_chunk(chunk_xs, offset,
+ required_cache_size, att_cache,
+ cnn_cache)
+ outputs.append(y)
+ offset += y.size(1)
+ ys = torch.cat(outputs, 1)
+ masks = torch.ones((1, 1, ys.size(1)),
+ device=ys.device,
+ dtype=torch.bool)
+ return ys, masks
+
+
+class TransformerEncoder(BaseEncoder):
+ """Transformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "abs_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ key_bias: bool = True,
+ selfattention_layer_type: str = "selfattn",
+ activation_type: str = "relu",
+ gradient_checkpointing: bool = False,
+ ):
+ """ Construct TransformerEncoder
+
+ See Encoder for the meaning of each parameter.
+ """
+ super().__init__(input_size, output_size, attention_heads,
+ linear_units, num_blocks, dropout_rate,
+ positional_dropout_rate, attention_dropout_rate,
+ input_layer, pos_enc_layer_type, normalize_before,
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
+ use_dynamic_left_chunk, gradient_checkpointing)
+ activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]()
+ self.encoders = torch.nn.ModuleList([
+ TransformerEncoderLayer(
+ output_size,
+ INSPIREMUSIC_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
+ output_size,
+ attention_dropout_rate,
+ key_bias),
+ PositionwiseFeedForward(output_size, linear_units,
+ dropout_rate, activation),
+ dropout_rate, normalize_before) for _ in range(num_blocks)
+ ])
+
+
+class ConformerEncoder(BaseEncoder):
+ """Conformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "rel_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ positionwise_conv_kernel_size: int = 1,
+ macaron_style: bool = True,
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ cnn_module_kernel: int = 15,
+ causal: bool = False,
+ cnn_module_norm: str = "batch_norm",
+ key_bias: bool = True,
+ gradient_checkpointing: bool = False,
+ ):
+ """Construct ConformerEncoder
+
+ Args:
+ input_size to use_dynamic_chunk: see BaseEncoder.
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
+ conv1d layer.
+ macaron_style (bool): Whether to use macaron style for
+ positionwise layer.
+ selfattention_layer_type (str): Encoder attention layer type.
+ The parameter currently has no effect; it is kept only for
+ configuration compatibility.
+ activation_type (str): Encoder activation function type.
+ use_cnn_module (bool): Whether to use convolution module.
+ cnn_module_kernel (int): Kernel size of convolution module.
+ causal (bool): whether to use causal convolution or not.
+ key_bias (bool): whether to use bias in attention.linear_k; False for Whisper models.
+ """
+ super().__init__(input_size, output_size, attention_heads,
+ linear_units, num_blocks, dropout_rate,
+ positional_dropout_rate, attention_dropout_rate,
+ input_layer, pos_enc_layer_type, normalize_before,
+ static_chunk_size, use_dynamic_chunk, global_cmvn,
+ use_dynamic_left_chunk, gradient_checkpointing)
+ activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]()
+
+ # self-attention module definition
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ key_bias,
+ )
+ # feed-forward module definition
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ # convolution module definition
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
+ cnn_module_norm, causal)
+
+ self.encoders = torch.nn.ModuleList([
+ ConformerEncoderLayer(
+ output_size,
+ INSPIREMUSIC_ATTENTION_CLASSES[selfattention_layer_type](
+ *encoder_selfattn_layer_args),
+ PositionwiseFeedForward(*positionwise_layer_args),
+ PositionwiseFeedForward(
+ *positionwise_layer_args) if macaron_style else None,
+ ConvolutionModule(
+ *convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ normalize_before,
+ ) for _ in range(num_blocks)
+ ])
diff --git a/inspiremusic/transformer/encoder_layer.py b/inspiremusic/transformer/encoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fb3a9d4e99d0f8f92ec1802a1a7620328e9353a
--- /dev/null
+++ b/inspiremusic/transformer/encoder_layer.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder self-attention layer definition."""
+
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class TransformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ dropout_rate: float,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): just for interface compatibility
+ to ConformerEncoderLayer
+ mask_pad (torch.Tensor): not used in the transformer layer,
+ kept only for a unified API with the conformer layer.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2), not used here, it's for interface
+ compatibility to ConformerEncoderLayer.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
+
+ """
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ return x, mask, new_att_cache, fake_cnn_cache
+
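+# Illustrative usage sketch (comment only; assumes MultiHeadedAttention is
+# available from inspiremusic.transformer.attention and that the shapes below
+# are example values):
+#
+#   attn = MultiHeadedAttention(4, 256, 0.1)
+#   ffn = PositionwiseFeedForward(256, 2048, 0.1)
+#   layer = TransformerEncoderLayer(256, attn, ffn, dropout_rate=0.1)
+#   x = torch.randn(2, 50, 256)
+#   mask = torch.ones(2, 1, 50, dtype=torch.bool)
+#   pos_emb = torch.zeros(1, 50, 256)     # unused by absolute self-attention
+#   y, mask, att_cache, cnn_cache = layer(x, mask, pos_emb)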
+
+class ConformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+ instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+ `ConvolutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: Optional[nn.Module] = None,
+ feed_forward_macaron: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ dropout_rate: float = 0.1,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
+ self.norm_final = nn.LayerNorm(
+ size, eps=1e-5) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): positional encoding, must not be None
+ for ConformerEncoderLayer.
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
+ (#batch, 1, time), (0, 0, 0) means fake mask.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2)
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+ """
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
+ att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ # Fake new cnn cache here, and then change it in conv_module
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+ x = residual + self.dropout(x)
+
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ return x, mask, new_att_cache, new_cnn_cache
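+
+
+# Reading aid (comment only): with pre-norm and macaron style enabled, one
+# conformer block computes, in order,
+#   x = x + 0.5 * FFN_macaron(x)
+#   x = x + MHSA(x)        # self-attention, consuming/producing att_cache
+#   x = x + Conv(x)        # depthwise conv module, consuming/producing cnn_cache
+#   x = x + 0.5 * FFN(x)
+#   x = LayerNorm(x)       # final norm applied when a conv module is present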
diff --git a/inspiremusic/transformer/label_smoothing_loss.py b/inspiremusic/transformer/label_smoothing_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..73ce09f35bfacb86730e39ef72a097f8a04e469b
--- /dev/null
+++ b/inspiremusic/transformer/label_smoothing_loss.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Label smoothing module."""
+
+import torch
+from torch import nn
+
+
+class LabelSmoothingLoss(nn.Module):
+ """Label-smoothing loss.
+
+ In a standard CE loss, the label's data distribution is:
+ [0,1,2] ->
+ [
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0],
+ ]
+
+ In the smoothed version of the CE loss, some probability mass
+ is taken from the true label (1.0) and divided among the
+ other labels.
+
+ e.g.
+ smoothing=0.1
+ [0,1,2] ->
+ [
+ [0.9, 0.05, 0.05],
+ [0.05, 0.9, 0.05],
+ [0.05, 0.05, 0.9],
+ ]
+
+ Args:
+ size (int): the number of classes
+ padding_idx (int): padding class id which will be ignored when computing the loss
+ smoothing (float): smoothing rate (0.0 means the conventional CE)
+ normalize_length (bool):
+ normalize loss by sequence length if True
+ normalize loss by batch size if False
+ """
+
+ def __init__(self,
+ size: int,
+ padding_idx: int,
+ smoothing: float,
+ normalize_length: bool = False):
+ """Construct an LabelSmoothingLoss object."""
+ super(LabelSmoothingLoss, self).__init__()
+ self.criterion = nn.KLDivLoss(reduction="none")
+ self.padding_idx = padding_idx
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.size = size
+ self.normalize_length = normalize_length
+
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Compute loss between x and target.
+
+ The model output and data label tensors are flattened to
+ (batch*seqlen, class) shape, and a mask is applied to the
+ padding positions, which should not contribute to the loss.
+
+ Args:
+ x (torch.Tensor): prediction (batch, seqlen, class)
+ target (torch.Tensor):
+ target signal masked with self.padding_id (batch, seqlen)
+ Returns:
+ loss (torch.Tensor) : The KL loss, scalar float value
+ """
+ assert x.size(2) == self.size
+ batch_size = x.size(0)
+ x = x.view(-1, self.size)
+ target = target.view(-1)
+ # use zeros_like instead of torch.no_grad() for true_dist,
+ # since no_grad() can not be exported by JIT
+ true_dist = torch.zeros_like(x)
+ true_dist.fill_(self.smoothing / (self.size - 1))
+ ignore = target == self.padding_idx # (B,)
+
+ total = len(target) - ignore.sum().item()
+ target = target.masked_fill(ignore, 0) # avoid -1 index
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
+ denom = total if self.normalize_length else batch_size
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
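+
+
+# Illustrative usage sketch (comment only; the sizes are example values):
+#
+#   criterion = LabelSmoothingLoss(size=5000, padding_idx=-1, smoothing=0.1)
+#   logits = torch.randn(8, 20, 5000)           # (batch, seqlen, class)
+#   target = torch.randint(0, 5000, (8, 20))    # padded positions would be -1
+#   loss = criterion(logits, target)            # scalar tensor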
diff --git a/inspiremusic/transformer/positionwise_feed_forward.py b/inspiremusic/transformer/positionwise_feed_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7a2cf6e7315e3a5ed2794423daff0a59cc5b208
--- /dev/null
+++ b/inspiremusic/transformer/positionwise_feed_forward.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Positionwise feed forward layer definition."""
+
+import torch
+
+
+class PositionwiseFeedForward(torch.nn.Module):
+ """Positionwise feed forward layer.
+
+ FeedForward is applied at each position of the sequence.
+ The output dim is the same as the input dim.
+
+ Args:
+ idim (int): Input dimension.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ ):
+ """Construct a PositionwiseFeedForward object."""
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
+ self.activation = activation
+ self.dropout = torch.nn.Dropout(dropout_rate)
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+ """
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
+
+
+class MoEFFNLayer(torch.nn.Module):
+ """
+ Mixture-of-experts positionwise feed forward layer.
+ See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
+ The output dim is the same as the input dim.
+
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
+ Args:
+ n_expert (int): number of experts.
+ n_expert_per_token (int): the number of experts actually used for each frame
+ idim (int): Input dimension.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ n_expert: int,
+ n_expert_per_token: int,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ ):
+ super(MoEFFNLayer, self).__init__()
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
+ self.experts = torch.nn.ModuleList(
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate,
+ activation) for _ in range(n_expert))
+ self.n_expert_per_token = n_expert_per_token
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Foward function.
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+
+ """
+ B, L, D = xs.size()  # batch size, sequence length, embedding dimension (idim)
+ xs = xs.view(-1, D) # (B*L, D)
+ router = self.gate(xs) # (B*L, n_expert)
+ logits, indices = torch.topk(
+ router, self.n_expert_per_token
+ ) # logits, indices: (B*L, n_expert_per_token)
+ weights = torch.nn.functional.softmax(
+ logits, dim=1,
+ dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
+ output = torch.zeros_like(xs) # (B*L, D)
+ for i, expert in enumerate(self.experts):
+ mask = indices == i
+ batch_idx, ith_expert = torch.where(mask)
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
+ xs[batch_idx])
+ return output.view(B, L, D)
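+
+
+# Illustrative usage sketch (comment only; the sizes are example values):
+# 8 experts with 2 active per frame, hidden size 256.
+#
+#   moe = MoEFFNLayer(n_expert=8, n_expert_per_token=2, idim=256,
+#                     hidden_units=2048, dropout_rate=0.1)
+#   xs = torch.randn(4, 100, 256)
+#   ys = moe(xs)            # (4, 100, 256), same shape as the input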
diff --git a/inspiremusic/transformer/qwen_encoder.py b/inspiremusic/transformer/qwen_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..92e3305ee9a72b370aba87634a30a2a3b0c7ff83
--- /dev/null
+++ b/inspiremusic/transformer/qwen_encoder.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch.nn as nn
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from inspiremusic.utils.mask import make_pad_mask
+from inspiremusic.utils.hinter import hint_once
+
+class QwenEncoder(nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ pretrain_path: str = "Qwen/Qwen2.0-0.5B",
+ trainable: bool = False,
+ do_fusion_emb: bool = False,
+ fusion_drop_rate: float = 0.0,
+ ):
+ super(QwenEncoder, self).__init__()
+ self.input_size = input_size
+ self.trainable = trainable
+ self.model = AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="cpu")
+ self._output_size = self.model.config.hidden_size
+ self.do_fusion_emb = do_fusion_emb
+ self.hidden_norm = torch.nn.LayerNorm(self._output_size)
+ self.fusion_dropout = nn.Dropout(fusion_drop_rate)
+ if do_fusion_emb:
+ self.fusion_layer = torch.nn.Linear(self._output_size * 2, self._output_size)
+ self.emb_norm = torch.nn.LayerNorm(self._output_size)
+ self.fusion_norm = torch.nn.LayerNorm(self._output_size)
+ from inspiremusic.transformer.activation import Swish
+ self.fusion_act = Swish(self)
+
+ if not self.trainable:
+ self.model.eval()
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ ilens: torch.Tensor,
+ ):
+ device = input_ids.device
+ input_ids = torch.clamp(input_ids, min=0, max=None)
+ input_masks = (~make_pad_mask(ilens)).to(device).long()
+ if not self.trainable:
+ with torch.no_grad():
+ model_outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=input_masks,
+ output_hidden_states=True
+ )
+ else:
+ model_outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=input_masks,
+ output_hidden_states=True
+ )
+ outs = model_outputs.hidden_states[-1]
+ outs = self.hidden_norm(outs)
+ if self.do_fusion_emb:
+ hint_once("fuse embedding and LM outputs", "fuse_emb")
+ outs = self.fusion_dropout(self.fusion_act(outs))
+ emb = model_outputs.hidden_states[0]
+ emb = self.fusion_dropout(self.fusion_act(self.emb_norm(emb)))
+ outs = self.fusion_layer(
+ torch.cat([outs, emb], dim=-1)
+ )
+ outs = self.fusion_act(self.fusion_norm(outs))
+
+ return outs, ilens
+
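+# Illustrative usage sketch (comment only; the checkpoint id and lengths are
+# assumptions -- any Qwen2-style causal LM accepted by AutoModelForCausalLM
+# should work):
+#
+#   encoder = QwenEncoder(input_size=512, pretrain_path="Qwen/Qwen2-0.5B")
+#   input_ids = torch.randint(0, 1000, (2, 32))
+#   ilens = torch.tensor([32, 20])
+#   outs, ilens = encoder(input_ids, ilens)     # outs: (2, 32, hidden_size)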
+
+class QwenEmbeddingEncoder(nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ pretrain_path: str = "Qwen/Qwen2.0-0.5B",
+ ):
+ super(QwenEmbeddingEncoder, self).__init__()
+ self.input_size = input_size
+ from transformers import Qwen2ForCausalLM
+ self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2")
+ self._output_size = self.model.config.hidden_size
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ input_embeds: torch.Tensor,
+ ilens: torch.Tensor,
+ ):
+ input_masks = (~make_pad_mask(ilens)).to(input_embeds.device).long()
+
+ outs = self.model(
+ inputs_embeds=input_embeds,
+ attention_mask=input_masks,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+
+ return outs.hidden_states[-1], input_masks
+
+ def forward_one_step(self, xs, masks, cache=None):
+
+ outs = self.model(
+ inputs_embeds=xs,
+ attention_mask=masks,
+ output_hidden_states=True,
+ return_dict=True,
+ use_cache=True,
+ past_key_values=cache,
+ )
+ xs = outs.hidden_states[-1]
+ new_cache = outs.past_key_values
+
+ return xs, masks, new_cache
+
+
+class QwenInputOnlyEncoder(nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ pretrain_path: str = "Qwen/Qwen2.0-0.5B",
+ ):
+ super(QwenInputOnlyEncoder, self).__init__()
+ self.input_size = input_size
+ from transformers import Qwen2ForCausalLM
+ model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2")
+ self.embed = model.model.embed_tokens
+ for p in self.embed.parameters():
+ p.requires_grad = False
+ # set text embedding to non-trainable
+
+ # self.post_embed = model.model.rotary_emb
+ self._output_size = model.config.hidden_size
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ ilens: torch.Tensor,
+ ):
+ input_masks = (~make_pad_mask(ilens)).to(input_ids.device).long()
+
+ outs = self.embed(input_ids)
+
+ return outs, input_masks
+
\ No newline at end of file
diff --git a/inspiremusic/transformer/subsampling.py b/inspiremusic/transformer/subsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..08cf7d4224c0f4bf95853590f7e2f97b387f44f9
--- /dev/null
+++ b/inspiremusic/transformer/subsampling.py
@@ -0,0 +1,384 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Subsampling layer definition."""
+
+from typing import Tuple, Union
+
+import torch
+
+
+class BaseSubsampling(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def position_encoding(self, offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ return self.pos_enc.position_encoding(offset, size)
+
+
+class EmbedinigNoSubsampling(BaseSubsampling):
+ """Embedding input without subsampling
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ super().__init__()
+ self.embed = torch.nn.Embedding(idim, odim)
+ self.pos_enc = pos_enc_class
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time .
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time .
+
+ """
+ x = self.embed(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
+
+
+class LinearNoSubsampling(BaseSubsampling):
+ """Linear transform the input without subsampling
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ torch.nn.LayerNorm(odim, eps=1e-5),
+ torch.nn.Dropout(dropout_rate),
+ )
+ self.pos_enc = pos_enc_class
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time .
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time .
+
+ """
+ x = self.out(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
+
+
+class Conv1dSubsampling2(BaseSubsampling):
+ """Convolutional 1D subsampling (to 1/2 length).
+ It is designed for Whisper, ref:
+ https://github.com/openai/whisper/blob/main/whisper/model.py
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv1dSubsampling2 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
+ torch.nn.GELU(),
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
+ torch.nn.GELU(),
+ )
+ self.pos_enc = pos_enc_class
+ # The right context for every conv layer is computed by:
+ # (kernel_size - 1) * frame_rate_of_this_layer
+ self.subsampling_rate = 2
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
+ self.right_context = 4
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 2.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 2.
+ torch.Tensor: positional encoding
+
+ """
+ time = x.size(1)
+ x = x.transpose(1, 2) # (b, f, t)
+ x = self.conv(x)
+ x = x.transpose(1, 2) # (b, t, f)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
+
+
+class Conv2dSubsampling4(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling4 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ )
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
+ self.pos_enc = pos_enc_class
+ # The right context for every conv layer is computed by:
+ # (kernel_size - 1) * frame_rate_of_this_layer
+ self.subsampling_rate = 4
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
+ self.right_context = 6
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 4.
+ torch.Tensor: positional encoding
+
+ """
+ x = x.unsqueeze(1) # (b, c=1, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
+
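+# Shape note (comment only): with two stride-2, kernel-3 convolutions and no
+# padding, time' = ((time - 1) // 2 - 1) // 2 and the frequency axis shrinks
+# the same way, which is where the Linear input size
+# odim * (((idim - 1) // 2 - 1) // 2) above comes from.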
+
+class Conv2dSubsampling6(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/6 length).
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc (torch.nn.Module): Custom position encoding layer.
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling6 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 5, 3),
+ torch.nn.ReLU(),
+ )
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
+ odim)
+ self.pos_enc = pos_enc_class
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
+ self.subsampling_rate = 6
+ self.right_context = 10
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 6.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 6.
+ torch.Tensor: positional encoding
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
+
+
+class Conv2dSubsampling8(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/8 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an Conv2dSubsampling8 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ )
+ self.linear = torch.nn.Linear(
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 8
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
+ self.right_context = 14
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 8.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 8.
+ torch.Tensor: positional encoding
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
+
+
+class LegacyLinearNoSubsampling(BaseSubsampling):
+ """Linear transform the input without subsampling
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ torch.nn.LayerNorm(odim, eps=1e-5),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ )
+ self.pos_enc = pos_enc_class
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time .
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time .
+
+ """
+
+ x = self.out(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
diff --git a/inspiremusic/utils/__init__.py b/inspiremusic/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inspiremusic/utils/audio_utils.py b/inspiremusic/utils/audio_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7b5d2afe3fd572b4f9fdfcece8b88d99870052a
--- /dev/null
+++ b/inspiremusic/utils/audio_utils.py
@@ -0,0 +1,623 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import io
+import logging
+import re
+import sys
+import inspect
+import random
+import typing as tp
+from functools import partial
+
+import omegaconf
+import torch
+import torchaudio
+import numpy as np
+
+from typing_extensions import Literal
+from typing import (
+ Any,
+ Union,
+ Iterable,
+ List,
+ Dict,
+ Optional,
+ Tuple,
+)
+
+from librosa.filters import mel as librosa_mel_fn
+# ParameterError is raised by normalize() below
+from librosa.util.exceptions import ParameterError
+from scipy.io.wavfile import read
+
+_BoolLike_co = Union[bool, np.bool_]
+_IntLike_co = Union[_BoolLike_co, int, "np.integer[Any]"]
+_FloatLike_co = Union[_IntLike_co, float, "np.floating[Any]"]
+
+def process_audio(file_path, target_sample_rate=24000):
+ audio, sample_rate = torchaudio.load(file_path)
+ # Check if the audio needs to be resampled
+ if sample_rate != target_sample_rate:
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(audio)
+ # Convert stereo to mono (if necessary)
+ audio = audio.mean(dim=0, keepdim=True) if audio.size(0) == 2 else audio
+ return audio, target_sample_rate
+
+def load_wav(full_path):
+ sampling_rate, data = read(full_path)
+ return data, sampling_rate
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+ if torch.min(y) < -1.0:
+ print("min value is ", torch.min(y))
+ if torch.max(y) > 1.0:
+ print("max value is ", torch.max(y))
+
+ # global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
+ mel_basis = {}
+ hann_window = {}
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+ )
+ y = y.squeeze(1)
+
+ spec = torch.view_as_real(
+ torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[str(y.device)],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ )
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
+
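+# Illustrative usage sketch (comment only; "example.wav" and the STFT/mel
+# hyperparameters are placeholder values, not required settings):
+#
+#   wav, sr = torchaudio.load("example.wav")    # (channels, samples) in [-1, 1]
+#   mel = mel_spectrogram(wav, n_fft=1024, num_mels=128, sampling_rate=sr,
+#                         hop_size=256, win_size=1024, fmin=0, fmax=None)
+#   # mel: (channels, num_mels, frames), log-compressed magnitudes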
+
+def fade_out(audio: torch.Tensor, sample_rate: int,
+ fade_duration: float) -> torch.Tensor:
+ """
+ Apply a linear fade-out effect to the given audio waveform.
+
+ Parameters:
+ audio (torch.Tensor): The audio waveform tensor.
+ sample_rate (int): Sample rate of the audio.
+ fade_duration (float): Duration of the fade-out effect in seconds.
+
+ Returns:
+ torch.Tensor: The audio with the fade-out effect applied.
+ """
+ fade_samples = int(fade_duration * sample_rate)
+
+ if fade_samples > audio.shape[1]:
+ fade_samples = audio.shape[1]  # use the whole length of the audio if necessary
+
+ fade_out_envelope = torch.linspace(1.0, 0.0, fade_samples,
+ dtype=audio.dtype, device=audio.device)
+
+ fade_section = audio[:, -fade_samples:].clone()
+
+ fade_section *= fade_out_envelope
+
+ faded_audio = audio.clone()
+ faded_audio[:, -fade_samples:] = fade_section
+
+ return faded_audio
+
+def split_wav_into_chunks(num_samples, wav, max_chunk_size, minimum_chunk_size=720):
+ num_chunks = (num_samples + max_chunk_size - 1) // max_chunk_size # Ceiling division
+ wav_chunks = []
+ for i in range(num_chunks):
+ start_idx = i * max_chunk_size
+ end_idx = min(start_idx + max_chunk_size, num_samples)
+ if (end_idx - start_idx) >= minimum_chunk_size:
+ if len(wav.shape) == 2:
+ chunk = wav[:,start_idx:end_idx]
+ else:
+ chunk = wav[start_idx:end_idx]
+ wav_chunks.append(chunk)
+ else:
+ print(f"{num_samples}:{num_chunks}, chunk size={(end_idx - start_idx)} is lower then minimum_chunk_size!")
+ return wav_chunks
+
+def tiny(x: Union[float, np.ndarray]) -> _FloatLike_co:
+ """Compute the tiny-value corresponding to an input's data type.
+ """
+ # Make sure we have an array view
+ x = np.asarray(x)
+
+ # Only floating types generate a tiny
+ if np.issubdtype(x.dtype, np.floating) or np.issubdtype(
+ x.dtype, np.complexfloating
+ ):
+ dtype = x.dtype
+ else:
+ dtype = np.dtype(np.float32)
+
+ return np.finfo(dtype).tiny
+
+def detect_silence(audio, sample_rate, threshold=0.05, min_silence_duration=1):
+ """
+ Detects the first occurrence of silence in the audio.
+
+ Parameters:
+ audio (Tensor): The audio waveform.
+ sample_rate (int): The sample rate of the audio.
+ threshold (float): The threshold below which the signal is considered silent.
+ min_silence_duration (float): The minimum duration of silence in seconds.
+
+ Returns:
+ int: The timestamp (in samples) where the silence starts.
+ """
+ # Convert the audio to a numpy array for easier manipulation
+ audio_np = audio.numpy().flatten()
+ # Calculate the energy of the signal
+ energy = np.abs(audio_np)
+ # Find the indices where the energy is below the threshold
+ silent_indices = np.where(energy < threshold)[0]
+ # Find the start and end of contiguous silent regions
+ silent_regions = np.split(silent_indices, np.where(np.diff(silent_indices) != 1)[0] + 1)
+ # Filter out regions that are too short
+ min_silence_samples = int(min_silence_duration * sample_rate)
+ for region in silent_regions:
+ if len(region) >= min_silence_samples:
+ return region[0]
+
+ # If no silence is found, return the length of the audio
+ return len(audio_np)
+
+def trim_audio(waveform, sample_rate=24000, threshold=0.05, min_silence_duration=1, minimum_silence_start_sample=24000):
+ """
+ Trims the audio from the beginning to the first occurrence of silence.
+
+ Parameters:
+ waveform (Tensor): The waveform data of the input audio file.
+ sample_rate (int): Sample rate of the input audio file.
+ threshold (float): The threshold below which the signal is considered silent.
+ min_silence_duration (float): The minimum duration of silence in seconds.
+ """
+ # Detect the first occurrence of silence
+ silence_start_sample = detect_silence(waveform, sample_rate, threshold, min_silence_duration)
+ if silence_start_sample > minimum_silence_start_sample:
+ trimmed_waveform = waveform[:silence_start_sample]
+ else:
+ trimmed_waveform = waveform[:minimum_silence_start_sample]
+ if isinstance(trimmed_waveform, torch.Tensor):
+ return trimmed_waveform
+ else:
+ # assume a numpy array here; convert back to a tensor with a channel dim
+ return torch.from_numpy(trimmed_waveform).unsqueeze(0)
+
+def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
+ loudness_compressor: bool = False, energy_floor: float = 2e-3):
+ """Normalize an input signal to a user loudness in dB LKFS.
+ Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
+
+ Args:
+ wav (torch.Tensor): Input multichannel audio data.
+ sample_rate (int): Sample rate.
+ loudness_headroom_db (float): Headroom in dB LUFS; the output is normalized to -loudness_headroom_db LUFS.
+ loudness_compressor (bool): Uses tanh for soft clipping.
+ energy_floor (float): anything below that RMS level will not be rescaled.
+ Returns:
+ torch.Tensor: Loudness normalized output data.
+ """
+ energy = wav.pow(2).mean().sqrt().item()
+ if energy < energy_floor:
+ return wav
+ transform = torchaudio.transforms.Loudness(sample_rate)
+ input_loudness_db = transform(wav).item()
+ # calculate the gain needed to scale to the desired loudness level
+ delta_loudness = -loudness_headroom_db - input_loudness_db
+ gain = 10.0 ** (delta_loudness / 20.0)
+ output = gain * wav
+ if loudness_compressor:
+ output = torch.tanh(output)
+ assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
+ return output
+
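+# Gain math (comment only): to move a measured loudness L_in (dB LUFS) to the
+# target of -loudness_headroom_db LUFS, the linear gain is
+#   gain = 10 ** ((-loudness_headroom_db - L_in) / 20)
+# e.g. L_in = -30 LUFS with a 14 dB headroom gives gain = 10 ** (16 / 20) ~= 6.31.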
+def normalize(
+ S: np.ndarray,
+ *,
+ norm: Optional[float] = np.inf,
+ axis: Optional[int] = 0,
+ threshold: Optional[_FloatLike_co] = None,
+ fill: Optional[bool] = None,
+) -> np.ndarray:
+ """Normalize an array along a chosen axis.
+ """
+ # Avoid div-by-zero
+ if threshold is None:
+ threshold = tiny(S)
+
+ elif threshold <= 0:
+ raise ParameterError(f"threshold={threshold} must be strictly positive")
+
+ if fill not in [None, False, True]:
+ raise ParameterError(f"fill={fill} must be None or boolean")
+
+ if not np.isfinite(S).all():
+ raise ParameterError("Input must be finite")
+
+ # All norms only depend on magnitude, let's do that first
+ S = S.numpy()
+ mag = np.abs(S).astype(float)
+
+ # For max/min norms, filling with 1 works
+ fill_norm = 1
+
+ if norm is None:
+ return S
+
+ elif norm == np.inf:
+ length = np.max(mag, axis=axis, keepdims=True)
+
+ elif norm == -np.inf:
+ length = np.min(mag, axis=axis, keepdims=True)
+
+ elif norm == 0:
+ if fill is True:
+ raise ParameterError("Cannot normalize with norm=0 and fill=True")
+
+ length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype)
+
+ elif np.issubdtype(type(norm), np.number) and norm > 0:
+ length = np.sum(mag**norm, axis=axis, keepdims=True) ** (1.0 / norm)
+
+ if axis is None:
+ fill_norm = mag.size ** (-1.0 / norm)
+ else:
+ fill_norm = mag.shape[axis] ** (-1.0 / norm)
+
+ else:
+ raise ParameterError(f"Unsupported norm: {repr(norm)}")
+
+ # indices where norm is below the threshold
+ small_idx = length < threshold
+
+ Snorm = np.empty_like(S)
+ if fill is None:
+ # Leave small indices un-normalized
+ length[small_idx] = 1.0
+ Snorm[:] = S / length
+
+ elif fill:
+ # If we have a non-zero fill value, we locate those entries by
+ # doing a nan-divide.
+ # If S was finite, then length is finite (except for small positions)
+ length[small_idx] = np.nan
+ Snorm[:] = S / length
+ Snorm[np.isnan(Snorm)] = fill_norm
+ else:
+ # Set small values to zero by doing an inf-divide.
+ # This is safe (by IEEE-754) as long as S is finite.
+ length[small_idx] = np.inf
+ Snorm[:] = S / length
+
+ return Snorm
+
+def normalize_audio(wav: torch.Tensor, normalize: bool = True,
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+ loudness_compressor: bool = False, log_clipping: bool = False,
+ sample_rate: tp.Optional[int] = None,
+ stem_name: tp.Optional[str] = None) -> torch.Tensor:
+ """Normalize the audio according to the prescribed strategy (see after).
+
+ Args:
+ wav (torch.Tensor): Audio data.
+ normalize (bool): if `True` (default), normalizes according to the prescribed
+ strategy (see after). If `False`, the strategy is only used in case clipping
+ would happen.
+ strategy (str): Can be either 'clip', 'peak', 'rms' or 'loudness'.
+ Default is 'peak', i.e. audio is normalized by its largest value.
+ RMS normalizes by root-mean-square with extra headroom to avoid
+ clipping. 'clip' just clips.
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+ than the `peak_clip` one to avoid further clipping.
+ loudness_headroom_db (float): Target loudness for loudness normalization.
+ loudness_compressor (bool): If True, uses tanh based soft clipping.
+ log_clipping (bool): If True, basic logging on stderr when clipping still
+ occurs despite strategy (only for 'rms').
+ sample_rate (int): Sample rate for the audio data (required for loudness).
+ stem_name (str, optional): Stem name for clipping logging.
+ Returns:
+ torch.Tensor: Normalized audio.
+ """
+ scale_peak = 10 ** (-peak_clip_headroom_db / 20)
+ scale_rms = 10 ** (-rms_headroom_db / 20)
+ if strategy == 'peak':
+ rescaling = (scale_peak / wav.abs().max())
+ if normalize or rescaling < 1:
+ wav = wav * rescaling
+ elif strategy == 'clip':
+ wav = wav.clamp(-scale_peak, scale_peak)
+ elif strategy == 'rms':
+ mono = wav.mean(dim=0)
+ rescaling = scale_rms / mono.pow(2).mean().sqrt()
+ if normalize or rescaling < 1:
+ wav = wav * rescaling
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+ elif strategy == 'loudness':
+ assert sample_rate is not None, "Loudness normalization requires sample rate."
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+ else:
+ assert wav.abs().max() < 1
+ assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
+ return wav
+
+
+def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
+ """
+ Convert audio to float 32 bits PCM format.
+ Args:
+ wav (torch.tensor): Input wav tensor
+ Returns:
+ same wav in float32 PCM format
+ """
+ if wav.dtype.is_floating_point:
+ return wav
+ elif wav.dtype == torch.int16:
+ return wav.float() / 2**15
+ elif wav.dtype == torch.int32:
+ return wav.float() / 2**31
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
+
+
+def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
+ """Convert audio to int 16 bits PCM format.
+
+ .. warning:: There exist many formulas for doing this conversion. None is perfect
+ due to the asymmetry of the int16 range. One either gets possible clipping, a DC
+ offset, or inconsistencies with f32_pcm. If the given wav does not have enough
+ headroom, it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
+ Args:
+ wav (torch.tensor): Input wav tensor
+ Returns:
+ same wav in int16 PCM format
+ """
+ if wav.dtype.is_floating_point:
+ assert wav.abs().max() <= 1
+ candidate = (wav * 2 ** 15).round()
+ if candidate.max() >= 2 ** 15: # clipping would occur
+ candidate = (wav * (2 ** 15 - 1)).round()
+ return candidate.short()
+ else:
+ assert wav.dtype == torch.int16
+ return wav
+
+
+def compress(wav: torch.Tensor, sr: int,
+ target_format: tp.Literal["mp3", "ogg", "flac"] = "mp3",
+ bitrate: str = "128k") -> tp.Tuple[torch.Tensor, int]:
+ """Convert audio wave form to a specified lossy format: mp3, ogg, flac
+
+ Args:
+ wav (torch.Tensor): Input wav tensor.
+ sr (int): Sampling rate.
+ target_format (str): Compression format (e.g., 'mp3').
+ bitrate (str): Bitrate for compression.
+
+ Returns:
+ Tuple of compressed WAV tensor and sampling rate.
+ """
+
+ # Extract the bit rate from string (e.g., '128k')
+ match = re.search(r"\d+(\.\d+)?", str(bitrate))
+ parsed_bitrate = float(match.group()) if match else None
+ assert parsed_bitrate, f"Invalid bitrate specified (got {parsed_bitrate})"
+ try:
+ # Create a virtual file instead of saving to disk
+ buffer = io.BytesIO()
+
+ torchaudio.save(
+ buffer, wav, sr, format=target_format, bits_per_sample=parsed_bitrate,
+ )
+ # Move to the beginning of the file
+ buffer.seek(0)
+ compressed_wav, sr = torchaudio.load(buffer)
+ return compressed_wav, sr
+
+ except RuntimeError:
+ logging.warning(
+ f"compression failed, skipping compression: {target_format} {parsed_bitrate}"
+ )
+ return wav, sr
+
+
+def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") -> torch.Tensor:
+ """Convert a batch of audio files to MP3 format, maintaining the original shape.
+
+ This function takes a batch of audio files represented as a PyTorch tensor, converts
+ them to MP3 format using the specified bitrate, and returns the batch in the same
+ shape as the input.
+
+ Args:
+ wav_tensor (torch.Tensor): Batch of audio files represented as a tensor.
+ Shape should be (batch_size, channels, length).
+ sr (int): Sampling rate of the audio.
+ bitrate (str): Bitrate for MP3 conversion, default is '128k'.
+
+ Returns:
+ torch.Tensor: Batch of audio files converted to MP3 format, with the same
+ shape as the input tensor.
+ """
+ device = wav_tensor.device
+ batch_size, channels, original_length = wav_tensor.shape
+
+ # Flatten tensor for conversion and move to CPU
+ wav_tensor_flat = wav_tensor.view(1, -1).cpu()
+
+ # Convert to MP3 format with specified bitrate
+ wav_tensor_flat, _ = compress(wav_tensor_flat, sr, bitrate=bitrate)
+
+ # Reshape back to original batch format and trim or pad if necessary
+ wav_tensor = wav_tensor_flat.view(batch_size, channels, -1)
+ compressed_length = wav_tensor.shape[-1]
+ if compressed_length > original_length:
+ wav_tensor = wav_tensor[:, :, :original_length] # Trim excess frames
+ elif compressed_length < original_length:
+ padding = torch.zeros(
+ batch_size, channels, original_length - compressed_length, device=device
+ )
+ wav_tensor = torch.cat((wav_tensor, padding), dim=-1) # Pad with zeros
+
+ # Move tensor back to the original device
+ return wav_tensor.to(device)
+
+
+def get_aac(
+ wav_tensor: torch.Tensor,
+ sr: int,
+ bitrate: str = "128k",
+ lowpass_freq: tp.Optional[int] = None,
+) -> torch.Tensor:
+ """Converts a batch of audio tensors to AAC format and then back to tensors.
+
+ This function first saves the input tensor batch as WAV files, then uses FFmpeg to convert
+ these WAV files to AAC format. Finally, it loads the AAC files back into tensors.
+
+ Args:
+ wav_tensor (torch.Tensor): A batch of audio files represented as a tensor.
+ Shape should be (batch_size, channels, length).
+ sr (int): Sampling rate of the audio.
+ bitrate (str): Bitrate for AAC conversion, default is '128k'.
+ lowpass_freq (Optional[int]): Frequency for a low-pass filter. If None, no filter is applied.
+
+ Returns:
+ torch.Tensor: Batch of audio files converted to AAC and back, with the same
+ shape as the input tensor.
+ """
+ import tempfile
+ import subprocess
+
+ device = wav_tensor.device
+ batch_size, channels, original_length = wav_tensor.shape
+
+ # Parse the bitrate value from the string
+ match = re.search(r"\d+(\.\d+)?", bitrate)
+ parsed_bitrate = (
+ match.group() if match else "128"
+ ) # Default to 128 if parsing fails
+
+ # Flatten tensor for conversion and move to CPU
+ wav_tensor_flat = wav_tensor.view(1, -1).cpu()
+
+ with tempfile.NamedTemporaryFile(
+ suffix=".wav"
+ ) as f_in, tempfile.NamedTemporaryFile(suffix=".aac") as f_out:
+ input_path, output_path = f_in.name, f_out.name
+
+ # Save the tensor as a WAV file
+ torchaudio.save(input_path, wav_tensor_flat, sr, backend="ffmpeg")
+
+ # Prepare FFmpeg command for AAC conversion
+ command = [
+ "ffmpeg",
+ "-y",
+ "-i",
+ input_path,
+ "-ar",
+ str(sr),
+ "-b:a",
+ f"{parsed_bitrate}k",
+ "-c:a",
+ "aac",
+ ]
+ if lowpass_freq is not None:
+ command += ["-cutoff", str(lowpass_freq)]
+ command.append(output_path)
+
+ try:
+ # Run FFmpeg and suppress output
+ subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+ # Load the AAC audio back into a tensor
+ aac_tensor, _ = torchaudio.load(output_path, backend="ffmpeg")
+ except Exception as exc:
+ raise RuntimeError(
+ "Failed to run command " ".join(command)} "
+ "(Often this means ffmpeg is not installed or the encoder is not supported, "
+ "make sure you installed an older version ffmpeg<5)"
+ ) from exc
+
+ original_length_flat = batch_size * channels * original_length
+ compressed_length_flat = aac_tensor.shape[-1]
+
+ # Trim excess frames
+ if compressed_length_flat > original_length_flat:
+ aac_tensor = aac_tensor[:, :original_length_flat]
+
+ # Pad the shortened frames
+ elif compressed_length_flat < original_length_flat:
+ padding = torch.zeros(
+ 1, original_length_flat - compressed_length_flat, device=device
+ )
+ aac_tensor = torch.cat((aac_tensor, padding), dim=-1)
+
+ # Reshape and adjust length to match original tensor
+ wav_tensor = aac_tensor.view(batch_size, channels, -1)
+ compressed_length = wav_tensor.shape[-1]
+
+ assert compressed_length == original_length, (
+ "AAC-compressed audio does not have the same frames as original one. "
+ "One reason can be ffmpeg is not installed and used as proper backed "
+ "for torchaudio, or the AAC encoder is not correct. Run "
+ "`torchaudio.utils.ffmpeg_utils.get_audio_encoders()` and make sure we see entry for"
+ "AAC in the output."
+ )
+ return wav_tensor.to(device)
\ No newline at end of file
diff --git a/inspiremusic/utils/binary.py b/inspiremusic/utils/binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..862cb467850b9af8e8b035939c018984e590e79c
--- /dev/null
+++ b/inspiremusic/utils/binary.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
+import io
+import json
+import struct
+import typing as tp
+
+# The format is the `ECDC` magic code, followed by a uint8 protocol version (0)
+# and the header size as a uint32. The header is then provided as JSON and
+# should contain all information required for decoding. A raw stream of bytes
+# then follows and should be interpretable using the JSON header.
+_encodec_header_struct = struct.Struct('!4sBI')
+_ENCODEC_MAGIC = b'ECDC'
+
+
+def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
+ meta_dumped = json.dumps(metadata).encode('utf-8')
+ version = 0
+ header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version,
+ len(meta_dumped))
+ fo.write(header)
+ fo.write(meta_dumped)
+ fo.flush()
+
+
+def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
+    buf = b""
+    while len(buf) < size:
+        # only request the bytes that are still missing
+        new_buf = fo.read(size - len(buf))
+        if not new_buf:
+            raise EOFError("Impossible to read enough data from the stream, "
+                           f"{size - len(buf)} bytes remaining.")
+        buf += new_buf
+    return buf
+
+
+def read_ecdc_header(fo: tp.IO[bytes]):
+ header_bytes = _read_exactly(fo, _encodec_header_struct.size)
+ magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
+ if magic != _ENCODEC_MAGIC:
+ raise ValueError("File is not in ECDC format.")
+ if version != 0:
+ raise ValueError("Version not supported.")
+ meta_bytes = _read_exactly(fo, meta_size)
+ return json.loads(meta_bytes.decode('utf-8'))
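+
+
+# A minimal illustrative sketch (not used elsewhere in this file): the header helpers
+# round-trip arbitrary JSON metadata through any binary file-object.
+def _example_header_roundtrip():
+    buf = io.BytesIO()
+    write_ecdc_header(buf, {"sample_rate": 24000, "bandwidth": 6.0})
+    buf.seek(0)
+    assert read_ecdc_header(buf) == {"sample_rate": 24000, "bandwidth": 6.0}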
+
+
+class BitPacker:
+ """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
+ Note that for some bandwidth (1.5, 3), the codebook representation
+ will not cover an integer number of bytes.
+
+ Args:
+ bits (int): number of bits per value that will be pushed.
+ fo (IO[bytes]): file-object to push the bytes to.
+ """
+
+ def __init__(self, bits: int, fo: tp.IO[bytes]):
+ self._current_value = 0
+ self._current_bits = 0
+ self.bits = bits
+ self.fo = fo
+
+ def push(self, value: int):
+ """Push a new value to the stream. This will immediately
+ write as many uint8 as possible to the underlying file-object."""
+ self._current_value += (value << self._current_bits)
+ self._current_bits += self.bits
+ while self._current_bits >= 8:
+ lower_8bits = self._current_value & 0xff
+ self._current_bits -= 8
+ self._current_value >>= 8
+ self.fo.write(bytes([lower_8bits]))
+
+ def flush(self):
+ """Flushes the remaining partial uint8, call this at the end
+ of the stream to encode."""
+ if self._current_bits:
+ self.fo.write(bytes([self._current_value]))
+ self._current_value = 0
+ self._current_bits = 0
+ self.fo.flush()
+
+
+class BitUnpacker:
+ """BitUnpacker does the opposite of `BitPacker`.
+
+ Args:
+ bits (int): number of bits of the values to decode.
+ fo (IO[bytes]): file-object to push the bytes to.
+ """
+
+ def __init__(self, bits: int, fo: tp.IO[bytes]):
+ self.bits = bits
+ self.fo = fo
+ self._mask = (1 << bits) - 1
+ self._current_value = 0
+ self._current_bits = 0
+
+ def pull(self) -> tp.Optional[int]:
+ """
+ Pull a single value from the stream, potentially reading some
+ extra bytes from the underlying file-object.
+ Returns `None` when reaching the end of the stream.
+ """
+ while self._current_bits < self.bits:
+ buf = self.fo.read(1)
+ if not buf:
+ return None
+ character = buf[0]
+ self._current_value += character << self._current_bits
+ self._current_bits += 8
+
+ out = self._current_value & self._mask
+ self._current_value >>= self.bits
+ self._current_bits -= self.bits
+ return out
+
+
+def test():
+ import torch
+ torch.manual_seed(1234)
+ for rep in range(4):
+ length: int = torch.randint(10, 2_000, (1, )).item()
+ bits: int = torch.randint(1, 16, (1, )).item()
+ tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist()
+ rebuilt: tp.List[int] = []
+ buf = io.BytesIO()
+ packer = BitPacker(bits, buf)
+ for token in tokens:
+ packer.push(token)
+ packer.flush()
+ buf.seek(0)
+ unpacker = BitUnpacker(bits, buf)
+ while True:
+ value = unpacker.pull()
+ if value is None:
+ break
+ rebuilt.append(value)
+ assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
+ # The flushing mechanism might lead to "ghost" values at the end of the stream.
+ assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt),
+ len(tokens), bits)
+ for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
+ assert a == b, (idx, a, b)
+
+
+if __name__ == '__main__':
+ test()
diff --git a/inspiremusic/utils/class_utils.py b/inspiremusic/utils/class_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6ddd08863b54afd24e92268dbd1faf15114b3e
--- /dev/null
+++ b/inspiremusic/utils/class_utils.py
@@ -0,0 +1,71 @@
+# Copyright [2023-11-28]
+# 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from inspiremusic.transformer.activation import Swish
+from inspiremusic.transformer.subsampling import (
+ LinearNoSubsampling,
+ EmbedinigNoSubsampling,
+ Conv1dSubsampling2,
+ Conv2dSubsampling4,
+ Conv2dSubsampling6,
+ Conv2dSubsampling8,
+)
+from inspiremusic.transformer.embedding import (PositionalEncoding,
+ RelPositionalEncoding,
+ WhisperPositionalEncoding,
+ LearnablePositionalEncoding,
+ NoPositionalEncoding)
+from inspiremusic.transformer.attention import (MultiHeadedAttention,
+ RelPositionMultiHeadedAttention)
+from inspiremusic.transformer.embedding import EspnetRelPositionalEncoding
+from inspiremusic.transformer.subsampling import LegacyLinearNoSubsampling
+
+
+INSPIREMUSIC_ACTIVATION_CLASSES = {
+ "hardtanh": torch.nn.Hardtanh,
+ "tanh": torch.nn.Tanh,
+ "relu": torch.nn.ReLU,
+ "selu": torch.nn.SELU,
+ "swish": getattr(torch.nn, "SiLU", Swish),
+ "gelu": torch.nn.GELU,
+}
+
+INSPIREMUSIC_SUBSAMPLE_CLASSES = {
+ "linear": LinearNoSubsampling,
+ "linear_legacy": LegacyLinearNoSubsampling,
+ "embed": EmbedinigNoSubsampling,
+ "conv1d2": Conv1dSubsampling2,
+ "conv2d": Conv2dSubsampling4,
+ "conv2d6": Conv2dSubsampling6,
+ "conv2d8": Conv2dSubsampling8,
+ 'paraformer_dummy': torch.nn.Identity
+}
+
+INSPIREMUSIC_EMB_CLASSES = {
+ "embed": PositionalEncoding,
+ "abs_pos": PositionalEncoding,
+ "rel_pos": RelPositionalEncoding,
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
+ "no_pos": NoPositionalEncoding,
+ "abs_pos_whisper": WhisperPositionalEncoding,
+ "embed_learnable_pe": LearnablePositionalEncoding,
+}
+
+INSPIREMUSIC_ATTENTION_CLASSES = {
+ "selfattn": MultiHeadedAttention,
+ "rel_selfattn": RelPositionMultiHeadedAttention,
+}
+
diff --git a/inspiremusic/utils/common.py b/inspiremusic/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..d888fa4b71743142a968d70b7c0bcd09a32de70b
--- /dev/null
+++ b/inspiremusic/utils/common.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Unility functions for Transformer."""
+
+from typing import List
+
+import random
+
+import numpy as np
+import torch
+
+IGNORE_ID = -1
+
+MUSIC_STRUCTURE_LABELS = ["intro", "verse1", "chorus", "verse2", "outro"]
+
+def pad_list(xs: List[torch.Tensor], pad_value: int):
+ """Perform padding for the list of tensors.
+
+ Args:
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+ pad_value (float): Value for padding.
+
+ Returns:
+ Tensor: Padded tensor (B, Tmax, `*`).
+
+ Examples:
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+ >>> x
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+ >>> pad_list(x, 0)
+ tensor([[1., 1., 1., 1.],
+ [1., 1., 0., 0.],
+ [1., 0., 0., 0.]])
+
+ """
+ max_len = max([len(item) for item in xs])
+ batchs = len(xs)
+ ndim = xs[0].ndim
+ if ndim == 1:
+ pad_res = torch.zeros(batchs,
+ max_len,
+ dtype=xs[0].dtype,
+ device=xs[0].device)
+ elif ndim == 2:
+ pad_res = torch.zeros(batchs,
+ max_len,
+ xs[0].shape[1],
+ dtype=xs[0].dtype,
+ device=xs[0].device)
+ elif ndim == 3:
+ pad_res = torch.zeros(batchs,
+ max_len,
+ xs[0].shape[1],
+ xs[0].shape[2],
+ dtype=xs[0].dtype,
+ device=xs[0].device)
+ else:
+ raise ValueError(f"Unsupported ndim: {ndim}")
+ pad_res.fill_(pad_value)
+ for i in range(batchs):
+ pad_res[i, :len(xs[i])] = xs[i]
+ return pad_res
+
+
+def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
+ ignore_label: int) -> torch.Tensor:
+ """Calculate accuracy.
+
+ Args:
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
+ ignore_label (int): Ignore label id.
+
+ Returns:
+ torch.Tensor: Accuracy value (0.0 - 1.0).
+
+ """
+ pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
+ pad_outputs.size(1)).argmax(2)
+ mask = pad_targets != ignore_label
+ numerator = torch.sum(
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
+ denominator = torch.sum(mask)
+ return (numerator / denominator).detach()
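+
+
+# Illustrative usage (a sketch with assumed toy values, not invoked in this module):
+#   >>> logits = torch.randn(2 * 3, 5)                        # (B * Lmax, D) with B=2, Lmax=3, D=5
+#   >>> targets = torch.tensor([[1, 2, IGNORE_ID], [0, 4, 3]])
+#   >>> th_accuracy(logits, targets, ignore_label=IGNORE_ID)   # scalar tensor in [0.0, 1.0]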
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+def topk_sampling(weighted_scores, decoded_tokens, top_k=25):
+    zeros = weighted_scores.new_ones(weighted_scores.shape) * float('-inf')
+    values, indices = torch.topk(weighted_scores, top_k)
+    zeros.scatter_(-1, indices, values)
+    return random_sampling(zeros, decoded_tokens)
+
+# Repetition Aware Sampling in VALL-E 2
+
+def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
+ top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
+ rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
+ if rep_num >= win_size * tau_r:
+ top_ids = random_sampling(weighted_scores, decoded_tokens)
+ return top_ids
+
+def caras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
+ weighted_scores, cfg_weighted_scores = weighted_scores
+ top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
+ rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
+ if rep_num >= win_size * tau_r:
+ top_ids = random_sampling(cfg_weighted_scores, decoded_tokens)
+ return top_ids
+
+def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
+ prob, indices = [], []
+ cum_prob = 0.0
+ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+ for i in range(len(sorted_idx)):
+ # sampling both top-p and numbers.
+ if cum_prob < top_p and len(prob) < top_k:
+ cum_prob += sorted_value[i]
+ prob.append(sorted_value[i])
+ indices.append(sorted_idx[i])
+ else:
+ break
+ prob = torch.tensor(prob).to(weighted_scores)
+ indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+ top_ids = indices[prob.multinomial(1, replacement=True)]
+ return top_ids
+
+
+def random_sampling(weighted_scores, decoded_tokens):
+ top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
+ return top_ids
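+
+
+# A minimal sketch (assumed vocabulary size and history; not called in this module): each
+# sampler takes a 1-D score vector over the token vocabulary plus the tokens decoded so
+# far, and returns the id of the next token as a tensor of shape (1,).
+def _example_ras_sampling():
+    weighted_scores = torch.randn(32)     # scores over a 32-token vocabulary
+    decoded_tokens = [3] * 10             # a highly repetitive history triggers re-sampling
+    top_ids = ras_sampling(weighted_scores, decoded_tokens,
+                           top_p=0.8, top_k=25, win_size=10, tau_r=0.1)
+    assert top_ids.shape == (1,)
+    return top_ids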
+
+
+def fade_in_out(fade_in_mel, fade_out_mel, window):
+ device = fade_in_mel.device
+ fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+ mel_overlap_len = int(window.shape[0] / 2)
+ fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
+ fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+ return fade_in_mel.to(device)
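+
+# Illustrative sketch (assumed shapes): cross-fade two mel chunks with a symmetric window
+# whose first half ramps up and second half ramps down, e.g.
+#   window = torch.hamming_window(2 * overlap_len)
+#   merged = fade_in_out(next_mel, prev_mel, window)   # both mels are (B, n_mels, T), T >= overlap_len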
+
+def set_all_random_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ assert mask.dtype == torch.bool
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+ mask = mask.to(dtype)
+ # attention mask bias
+ # NOTE(Mddct): torch.finfo jit issues
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+ mask = (1.0 - mask) * torch.finfo(dtype).min
+ return mask
\ No newline at end of file
diff --git a/inspiremusic/utils/data_utils.py b/inspiremusic/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..094a280f25ac50be81253854e62790ce50e5e7de
--- /dev/null
+++ b/inspiremusic/utils/data_utils.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torch.utils.data import DataLoader
+from inspiremusic.dataset.dataset import Dataset
+import numpy as np
+import librosa
+
+def audio_process_dataset_and_dataloader(args, configs):
+ input_dataset = Dataset(args.input_data, data_pipeline=configs['data_pipeline'], mode='processing', shuffle=True, partition=True)
+    # do not use persistent_workers=True, as the whisper tokenizer opens the tiktoken file each time the for loop starts
+ input_data_loader = DataLoader(input_dataset,
+ batch_size=None,
+ pin_memory=args.pin_memory,
+ num_workers=args.num_workers,
+ prefetch_factor=args.prefetch)
+ return input_dataset, input_data_loader
+
+def is_silent(wav_path, threshold=0.01, frame_length=2048, hop_length=512):
+ y, sr = librosa.load(wav_path, sr=None)
+ rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
+ silent_frames = np.sum(rms < threshold) / len(rms)
+ silence_fraction_threshold = 0.95
+ return silent_frames >= silence_fraction_threshold
+
+def rich_captions(text=None, tags=None, lyrics=None, chorus="verse", start_time=0.0, end_time=30.0):
+ if text is None and tags is None and lyrics is None:
+ return None
+ else:
+ if start_time is None:
+ start_time = 0.0
+ if end_time is None:
+ end_time = 30.0
+ if chorus is None:
+ chorus = "verse"
+ captions = f"<|{start_time:.1f}|><|{chorus}|>"
+ if tags is not None:
+ captions += f"<|{tags}|>"
+ if text is not None:
+ captions += f"<|{text}|>"
+ if lyrics is not None:
+ captions += f"<|lyrics|><|{lyrics}|>"
+ captions += f"<|{end_time:.1f}|>"
+ return captions
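+
+# Illustrative sketch (assumed values): the caption wraps the time span, chorus tag and any
+# optional fields in special tokens, e.g.
+#   rich_captions(text="upbeat pop song", chorus="chorus", start_time=0.0, end_time=30.0)
+#   returns "<|0.0|><|chorus|><|upbeat pop song|><|30.0|>"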
+
+def process_tags(infile, outfile, timefile = None):
+ key_list = []
+ with open(infile, "r") as f:
+ for line in f:
+ sec = line.strip()
+ key_list.append(sec)
+ f.close()
+ if timefile is None:
+ with open(outfile, 'w') as f:
+ for k in key_list:
+ parts = k.rsplit('_', 1)
+ text = parts[0].replace('_', ' ') + ', ' + parts[1]
+ caption = rich_captions(text, None, None)
+ if caption is not None:
+ f.write("%s\t%s\n" %(k, caption))
+ f.close()
+ else:
+ times = {}
+ with open(timefile, "r") as f:
+ for line in f:
+ sec = line.strip().split("\t")
+ if len(sec) == 2 :
+ times[sec[0]] = sec[1]
+ f.close()
+
+ with open(outfile, 'w') as f:
+ for k in key_list:
+ parts = k.rsplit('_', 1)
+ text = parts[0].replace('_', ' ') + ', ' + parts[1]
+ if k in times.keys():
+ caption = rich_captions(text, None, None, "verse", 0.0, float(times[k]))
+ if caption is not None:
+ f.write("%s\t%s\n" %(k, caption))
+ f.close()
+
+def process_trans(infile, outfile):
+ trans = {}
+ with open(infile, "r") as f:
+ for line in f:
+ sec = line.strip().split("\t")
+ if len(sec) == 2:
+ trans[sec[0]] = sec[1]
+ else:
+ print(line)
+ f.close()
+ with open(outfile, 'w') as f:
+ for k, v in trans.items():
+ f.write("%s\t%s\n" %(k, rich_captions(v)))
+ f.close()
\ No newline at end of file
diff --git a/inspiremusic/utils/executor.py b/inspiremusic/utils/executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..da933e4942c6095da87fb26e964c39c809925ad3
--- /dev/null
+++ b/inspiremusic/utils/executor.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import nullcontext
+import os
+
+import torch
+import torch.distributed as dist
+
+from inspiremusic.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, inspiremusic_join
+from torch.cuda.amp import GradScaler, autocast
+
+class Executor:
+
+ def __init__(self):
+ self.step = 0
+ self.epoch = 0
+ self.rank = int(os.environ.get('RANK', 0))
+ self.device = torch.device('cuda:{}'.format(self.rank))
+
+ def train_one_epoch(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=None):
+ ''' Train one epoch
+ '''
+
+ lr = optimizer.param_groups[0]['lr']
+ logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+ logging.info('using accumulate grad, new batch size is {} times'
+ ' larger than before'.format(info_dict['accum_grad']))
+ # A context manager to be used in conjunction with an instance of
+ # torch.nn.parallel.DistributedDataParallel to be able to train
+ # with uneven inputs across participating processes.
+ model.train()
+ model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+ with model_context():
+ for batch_idx, batch_dict in enumerate(train_data_loader):
+ info_dict["tag"] = "TRAIN"
+ info_dict["step"] = self.step
+ info_dict["epoch"] = self.epoch
+ info_dict["batch_idx"] = batch_idx
+ if inspiremusic_join(group_join, info_dict):
+ break
+
+ # Disable gradient synchronizations across DDP processes.
+ # Within this context, gradients will be accumulated on module
+ # variables, which will later be synchronized.
+ if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+ context = model.no_sync
+ # Used for single gpu training and DDP gradient synchronization
+ # processes.
+ else:
+ context = nullcontext
+
+ with context():
+ with autocast(enabled=scaler is not None):
+ info_dict = batch_forward(model, batch_dict, info_dict, scaler)
+ info_dict = batch_backward(model, info_dict, scaler)
+
+ info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler)
+ log_per_step(writer, info_dict)
+ # NOTE specify save_per_step in inspiremusic.yaml if you want to enable step save
+ if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+ (batch_idx + 1) % info_dict["accum_grad"] == 0:
+ dist.barrier()
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, scaler=scaler)
+ model.train()
+ if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+ self.step += 1
+ dist.barrier()
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, scaler=scaler)
+
+ @torch.inference_mode()
+ def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, capped_at=5, scaler=None):
+        ''' Cross validation
+        '''
+ logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
+ model.eval()
+ total_num_utts, total_loss_dict = 0, {} # avoid division by 0
+ stop = capped_at
+ for batch_idx, batch_dict in enumerate(cv_data_loader):
+ info_dict["tag"] = "CV"
+ info_dict["step"] = self.step
+ info_dict["epoch"] = self.epoch
+ info_dict["batch_idx"] = batch_idx
+
+            num_utts = len(batch_dict["utts"])
+
+            if capped_at > 0:
+                if stop <= 0:
+                    continue
+                else:
+                    stop -= 1
+
+            # only count utterances whose loss actually contributes to the average
+            total_num_utts += num_utts
+
+ with autocast(enabled=scaler is not None):
+ info_dict = batch_forward(model, batch_dict, info_dict, scaler)
+
+ for k, v in info_dict['loss_dict'].items():
+ if k not in total_loss_dict:
+ total_loss_dict[k] = []
+ total_loss_dict[k].append(v.item() * num_utts)
+ log_per_step(None, info_dict)
+
+ for k, v in total_loss_dict.items():
+ total_loss_dict[k] = sum(v) / total_num_utts
+ info_dict['loss_dict'] = total_loss_dict
+ log_per_save(writer, info_dict)
+ model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
+ save_model(model, model_name, info_dict)
diff --git a/inspiremusic/utils/file_utils.py b/inspiremusic/utils/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..56019b4fc015681e2b4d7a79cc2e186c6007d079
--- /dev/null
+++ b/inspiremusic/utils/file_utils.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+# 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import torchaudio
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+def read_trans(list_file):
+ trans = {}
+ with open(list_file, 'r', encoding='utf8') as fin:
+ for line in fin:
+ sec = line.strip().split("\t")
+ if len(sec) > 1:
+ if sec[0] not in trans.keys():
+ trans[sec[0]] = sec[1]
+ return trans
+
+def read_scp(list_file):
+ scp = {}
+ with open(list_file, 'r', encoding='utf8') as fin:
+ for line in fin:
+ sec = line.strip().split(" ")
+ if len(sec) > 1:
+ if sec[0] not in scp.keys():
+ scp[sec[0]] = sec[1]
+ return scp
+
+def read_lists(list_file):
+ lists = []
+ with open(list_file, 'r', encoding='utf8') as fin:
+ for line in fin:
+ lists.append(line.strip())
+ return lists
+
+
+def read_json_lists(list_file):
+ lists = read_lists(list_file)
+ results = {}
+ for fn in lists:
+ with open(fn, 'r', encoding='utf8') as fin:
+ results.update(json.load(fin))
+ return results
+
+
+def load_wav(wav, target_sr):
+ audio, sample_rate = torchaudio.load(wav)
+ audio = audio.mean(dim=0, keepdim=True)
+ if sample_rate != target_sr:
+ assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(audio)
+ return audio
+
+
+def speed_change(waveform, sample_rate, speed_factor: str):
+ effects = [
+ ["tempo", speed_factor], # speed_factor
+ ["rate", f"{sample_rate}"]
+ ]
+ augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+ waveform,
+ sample_rate,
+ effects
+ )
+ return augmented_waveform, new_sample_rate
diff --git a/inspiremusic/utils/frontend_utils.py b/inspiremusic/utils/frontend_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..45239b854718afe0dc3ce194aa37a0ac6bb760eb
--- /dev/null
+++ b/inspiremusic/utils/frontend_utils.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
+
+
+# whether contain chinese character
+def contains_chinese(text):
+ return bool(chinese_char_pattern.search(text))
+
+
+# replace special symbol
+def replace_corner_mark(text):
+    text = text.replace('²', '平方')
+    text = text.replace('³', '立方')
+ return text
+
+
+# remove meaningless symbol
+def remove_bracket(text):
+    text = text.replace('（', '').replace('）', '')
+    text = text.replace('【', '').replace('】', '')
+    text = text.replace('`', '').replace('`', '')
+    text = text.replace("——", " ")
+ return text
+
+
+# spell out Arabic numerals
+def spell_out_number(text: str, inflect_parser):
+ new_text = []
+ st = None
+ for i, c in enumerate(text):
+ if not c.isdigit():
+ if st is not None:
+ num_str = inflect_parser.number_to_words(text[st: i])
+ new_text.append(num_str)
+ st = None
+ new_text.append(c)
+ else:
+ if st is None:
+ st = i
+ if st is not None and st < len(text):
+ num_str = inflect_parser.number_to_words(text[st:])
+ new_text.append(num_str)
+ return ''.join(new_text)
+
+
+# split paragraph logic:
+# 1. each sentence has at most token_max_n and at least token_min_n tokens; merge the last sentence into the previous one if it is shorter than merge_len
+# 2. calculate sentence length according to lang
+# 3. split sentences according to punctuation
+def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
+ def calc_utt_length(_text: str):
+ if lang == "zh":
+ return len(_text)
+ else:
+ return len(tokenize(_text))
+
+ def should_merge(_text: str):
+ if lang == "zh":
+ return len(_text) < merge_len
+ else:
+ return len(tokenize(_text)) < merge_len
+
+ if lang == "zh":
+        pounc = ['。', '？', '！', '；', '：', '、', '.', '?', '!', ';']
+ else:
+ pounc = ['.', '?', '!', ';', ':']
+ if comma_split:
+        pounc.extend(['，', ','])
+ st = 0
+ utts = []
+ for i, c in enumerate(text):
+ if c in pounc:
+ if len(text[st: i]) > 0:
+ utts.append(text[st: i] + c)
+            if i + 1 < len(text) and text[i + 1] in ['"', '”']:
+ tmp = utts.pop(-1)
+ utts.append(tmp + text[i + 1])
+ st = i + 2
+ else:
+ st = i + 1
+ if len(utts) == 0:
+ if lang == "zh":
+            utts.append(text + '。')
+ else:
+ utts.append(text + '.')
+ final_utts = []
+ cur_utt = ""
+ for utt in utts:
+ if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
+ final_utts.append(cur_utt)
+ cur_utt = ""
+ cur_utt = cur_utt + utt
+ if len(cur_utt) > 0:
+ if should_merge(cur_utt) and len(final_utts) != 0:
+ final_utts[-1] = final_utts[-1] + cur_utt
+ else:
+ final_utts.append(cur_utt)
+
+ return final_utts
+
+
+# remove blanks between Chinese characters
+def replace_blank(text: str):
+    out_str = []
+    for i, c in enumerate(text):
+        if c == " ":
+            # keep the space only when it separates two ASCII, non-space characters
+            if (0 < i < len(text) - 1 and
+                    text[i + 1].isascii() and text[i + 1] != " " and
+                    text[i - 1].isascii() and text[i - 1] != " "):
+                out_str.append(c)
+        else:
+            out_str.append(c)
+    return "".join(out_str)
diff --git a/inspiremusic/utils/hinter.py b/inspiremusic/utils/hinter.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b32194336ed680a12772e2e3d3ed48d13cbcf59
--- /dev/null
+++ b/inspiremusic/utils/hinter.py
@@ -0,0 +1,12 @@
+import sys
+import torch.distributed
+import logging
+
+HINTED = set()
+
+
+def hint_once(content, uid, rank=None):
+ if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank:
+ if uid not in HINTED:
+ logging.info(content, stacklevel=3)
+ HINTED.add(uid)
\ No newline at end of file
diff --git a/inspiremusic/utils/losses.py b/inspiremusic/utils/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..78efd3b72ff4c61971f4732626c43613f812761d
--- /dev/null
+++ b/inspiremusic/utils/losses.py
@@ -0,0 +1,20 @@
+import torch
+import torch.nn.functional as F
+
+
+def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
+ loss = 0
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ m_DG = torch.median((dr - dg))
+ L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
+ loss += tau - F.relu(tau - L_rel)
+ return loss
+
+
+def mel_loss(real_speech, generated_speech, mel_transforms):
+ loss = 0
+ for transform in mel_transforms:
+ mel_r = transform(real_speech)
+ mel_g = transform(generated_speech)
+ loss += F.l1_loss(mel_g, mel_r)
+ return loss
diff --git a/inspiremusic/utils/mask.py b/inspiremusic/utils/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3d767dfe0b9a37295c49d6670e5600e60fe075e
--- /dev/null
+++ b/inspiremusic/utils/mask.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2019 Shigeki Karita
+# 2020 Mobvoi Inc (Binbin Zhang)
+# 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+'''
+def subsequent_mask(
+ size: int,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size).
+
+ This mask is used only in decoder which works in an auto-regressive mode.
+ This means the current step could only do attention with its left steps.
+
+ In encoder, fully attention is used when streaming is not necessary and
+ the sequence is not long. In this case, no attention mask is needed.
+
+ When streaming is need, chunk-based attention is used in encoder. See
+ subsequent_chunk_mask for the chunk-based attention mask.
+
+ Args:
+ size (int): size of mask
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
+ dtype (torch.device): result dtype
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_mask(3)
+ [[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]
+ """
+ ret = torch.ones(size, size, device=device, dtype=torch.bool)
+ return torch.tril(ret)
+'''
+
+
+def subsequent_mask(
+ size: int,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size).
+
+    This mask is used only in the decoder, which works in an auto-regressive mode:
+    the current step may only attend to itself and the steps on its left.
+
+    In the encoder, full attention is used when streaming is not necessary and
+    the sequence is not long. In this case, no attention mask is needed.
+
+    When streaming is needed, chunk-based attention is used in the encoder. See
+    subsequent_chunk_mask for the chunk-based attention mask.
+
+    Args:
+        size (int): size of mask
+        device (torch.device): "cpu" or "cuda" or a torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_mask(3)
+ [[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]
+ """
+ arange = torch.arange(size, device=device)
+ mask = arange.expand(size, size)
+ arange = arange.unsqueeze(-1)
+ mask = mask <= arange
+ return mask
+
+
+def subsequent_chunk_mask(
+ size: int,
+ chunk_size: int,
+ num_left_chunks: int = -1,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size) with chunk size,
+ this is for streaming encoder
+
+ Args:
+ size (int): size of mask
+ chunk_size (int): size of chunk
+ num_left_chunks (int): number of left chunks
+ <0: use full chunk
+ >=0: use num_left_chunks
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_chunk_mask(4, 2)
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 1],
+ [1, 1, 1, 1]]
+ """
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+ for i in range(size):
+ if num_left_chunks < 0:
+ start = 0
+ else:
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+ ending = min((i // chunk_size + 1) * chunk_size, size)
+ ret[i, start:ending] = True
+ return ret
+
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+ masks: torch.Tensor,
+ use_dynamic_chunk: bool,
+ use_dynamic_left_chunk: bool,
+ decoding_chunk_size: int,
+ static_chunk_size: int,
+ num_decoding_left_chunks: int,
+ enable_full_context: bool = True):
+ """ Apply optional mask for encoder.
+
+ Args:
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
+ mask (torch.Tensor): mask for xs, (B, 1, L)
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+ training.
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ static_chunk_size (int): chunk size for static chunk training/decoding
+ if it's greater than 0, if use_dynamic_chunk is true,
+ this parameter will be ignored
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ enable_full_context (bool):
+ True: chunk size is either [1, 25] or full context(max_len)
+ False: chunk size ~ U[1, 25]
+
+ Returns:
+ torch.Tensor: chunk mask of the input xs.
+ """
+ # Whether to use chunk mask or not
+ if use_dynamic_chunk:
+ max_len = xs.size(1)
+ if decoding_chunk_size < 0:
+ chunk_size = max_len
+ num_left_chunks = -1
+ elif decoding_chunk_size > 0:
+ chunk_size = decoding_chunk_size
+ num_left_chunks = num_decoding_left_chunks
+ else:
+ # chunk size is either [1, 25] or full context(max_len).
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
+ # delay, the maximum frame is 100 / 4 = 25.
+ chunk_size = torch.randint(1, max_len, (1, )).item()
+ num_left_chunks = -1
+ if chunk_size > max_len // 2 and enable_full_context:
+ chunk_size = max_len
+ else:
+ chunk_size = chunk_size % 25 + 1
+ if use_dynamic_left_chunk:
+ max_left_chunks = (max_len - 1) // chunk_size
+ num_left_chunks = torch.randint(0, max_left_chunks,
+ (1, )).item()
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ elif static_chunk_size > 0:
+ num_left_chunks = num_decoding_left_chunks
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ else:
+ chunk_masks = masks
+ return chunk_masks
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+ """Make mask tensor containing indices of padded part.
+
+ See description of make_non_pad_mask.
+
+ Args:
+ lengths (torch.Tensor): Batch of lengths (B,).
+ Returns:
+ torch.Tensor: Mask tensor containing indices of padded part.
+
+ Examples:
+ >>> lengths = [5, 3, 2]
+ >>> make_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+ """
+ batch_size = lengths.size(0)
+ max_len = max_len if max_len > 0 else lengths.max().item()
+ seq_range = torch.arange(0,
+ max_len,
+ dtype=torch.int64,
+ device=lengths.device)
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+ seq_length_expand = lengths.unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+ return mask
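+
+
+# A minimal sketch (assumed toy shapes, not used elsewhere in this file): combine the
+# padding mask with a static chunk mask for a batch of two sequences padded to length 4.
+def _example_chunk_mask():
+    xs = torch.zeros(2, 4, 8)                       # (B, L, D)
+    masks = ~make_pad_mask(torch.tensor([4, 2]))    # (B, L), True on valid frames
+    masks = masks.unsqueeze(1)                      # (B, 1, L)
+    chunk_masks = add_optional_chunk_mask(xs, masks,
+                                          use_dynamic_chunk=False,
+                                          use_dynamic_left_chunk=False,
+                                          decoding_chunk_size=0,
+                                          static_chunk_size=2,
+                                          num_decoding_left_chunks=-1)
+    assert chunk_masks.shape == (2, 4, 4)           # (B, L, L)
+    return chunk_masks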
diff --git a/inspiremusic/utils/scheduler.py b/inspiremusic/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea5f3a8b00ca9328a0fe917b34d73e21e9c25b2f
--- /dev/null
+++ b/inspiremusic/utils/scheduler.py
@@ -0,0 +1,738 @@
+# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+# 2022 Ximalaya Inc (Yuguang Yang)
+# 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+# NeMo(https://github.com/NVIDIA/NeMo)
+
+from typing import Union
+
+import math
+import warnings
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class WarmupLR(_LRScheduler):
+ """The WarmupLR scheduler
+
+    This scheduler is almost the same as the NoamLR scheduler except for the following
+    difference:
+
+ NoamLR:
+ lr = optimizer.lr * model_size ** -0.5
+ * min(step ** -0.5, step * warmup_step ** -1.5)
+ WarmupLR:
+ lr = optimizer.lr * warmup_step ** 0.5
+ * min(step ** -0.5, step * warmup_step ** -1.5)
+
+ Note that the maximum lr equals to optimizer.lr in this scheduler.
+
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_steps: Union[int, float] = 25000,
+ last_epoch: int = -1,
+ ):
+ self.warmup_steps = warmup_steps
+
+ # __init__() must be invoked before setting field
+ # because step() is also invoked in __init__()
+ super().__init__(optimizer, last_epoch)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+
+ def get_lr(self):
+ step_num = self.last_epoch + 1
+ if self.warmup_steps == 0:
+ return [lr * step_num**-0.5 for lr in self.base_lrs]
+ else:
+ return [
+ lr * self.warmup_steps**0.5 *
+ min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
+ for lr in self.base_lrs
+ ]
+
+ def set_step(self, step: int):
+ self.last_epoch = step
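+
+
+# A minimal sketch (toy model and assumed values, not part of the original file): with
+# warmup_steps=100 the lr ramps up to the optimizer lr at step 100 and then decays ~ step**-0.5.
+def _example_warmup_lr():
+    model = torch.nn.Linear(4, 4)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    scheduler = WarmupLR(optimizer, warmup_steps=100)
+    for _ in range(200):
+        optimizer.step()
+        scheduler.step()
+    return scheduler.get_last_lr()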
+
+
+class WarmupPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1):
+ assert not (warmup_steps is not None and warmup_ratio is not None),\
+ "Either use particular number of step or ratio"
+ assert warmup_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = self.last_epoch
+
+ if step <= self.warmup_steps and self.warmup_steps > 0:
+ return self._get_warmup_lr(step)
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_warmup_lr(self, step):
+ lr_val = (step + 1) / (self.warmup_steps + 1)
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+class SquareRootConstantPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(self,
+ optimizer,
+ *,
+ constant_steps=None,
+ constant_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1):
+ assert not (constant_steps is not None
+ and constant_ratio is not None), \
+ "Either use particular number of step or ratio"
+ assert constant_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if constant_steps is not None:
+ self.constant_steps = constant_steps
+ elif constant_ratio is not None:
+ self.constant_steps = int(constant_ratio * max_steps)
+ else:
+ self.constant_steps = 0
+
+        self.constant_lr = 1 / (self.constant_steps**0.5)
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = self.last_epoch
+
+ if step <= self.constant_steps:
+ return [self.constant_lr for _ in self.base_lrs]
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+class WarmupHoldPolicy(WarmupPolicy):
+ """Variant of WarmupPolicy which maintains high
+ learning rate for a defined number of steps.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ hold_steps: Number of training steps to
+ hold the learning rate after warm up
+ hold_ratio: Ratio of hold steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ hold_steps=None,
+ hold_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ assert not (hold_steps is not None and hold_ratio is not None), \
+ "Either use particular number of step or ratio"
+ assert hold_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ self.min_lr = min_lr
+ self._last_warmup_lr = 0.0
+
+ # Necessary to duplicate as class attributes are hidden in inner class
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ if hold_steps is not None:
+ self.hold_steps = hold_steps + self.warmup_steps
+ elif hold_ratio is not None:
+ self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
+ else:
+ self.hold_steps = 0
+
+ super().__init__(
+ optimizer,
+ warmup_steps=warmup_steps,
+ warmup_ratio=warmup_ratio,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ )
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler,"
+ " "
+ "please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = self.last_epoch
+
+ # Warmup phase
+ if step <= self.warmup_steps and self.warmup_steps > 0:
+ return self._get_warmup_lr(step)
+
+ # Hold phase
+ if (step >= self.warmup_steps) and (step < self.hold_steps):
+ return self.base_lrs
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+
+class WarmupAnnealHoldPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ min_lr: Minimum lr to hold the learning rate after decay at.
+ constant_steps: Number of steps to keep lr constant at.
+ constant_ratio: Ratio of steps to keep lr constant.
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ constant_steps=None,
+ constant_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ assert not (warmup_steps is not None
+ and warmup_ratio is not None), \
+ "Either use particular number of step or ratio"
+ assert not (constant_steps is not None
+ and constant_ratio is not None), \
+ "Either use constant_steps or constant_ratio"
+ assert warmup_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ if constant_steps is not None:
+ self.constant_steps = constant_steps
+ elif constant_ratio is not None:
+ self.constant_steps = int(constant_ratio * max_steps)
+ else:
+ self.constant_steps = 0
+
+ self.decay_steps = max_steps - (self.constant_steps +
+ self.warmup_steps)
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = self.last_epoch
+
+ # Warmup steps
+ if self.warmup_steps > 0 and step <= self.warmup_steps:
+ return self._get_warmup_lr(step)
+
+ # Constant steps after warmup and decay
+ if self.constant_steps > 0 and (
+ self.warmup_steps + self.decay_steps) < step <= self.max_steps:
+ return self._get_constant_lr(step)
+
+ # Min lr after max steps of updates
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_warmup_lr(self, step):
+ lr_val = (step + 1) / (self.warmup_steps + 1)
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+ def _get_constant_lr(self, step):
+ return [self.min_lr for _ in self.base_lrs]
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
+ mult = ((max_steps - step) / max_steps)**0.5
+ out_lr = initial_lr * mult
+ out_lr = max(out_lr, min_lr)
+ return out_lr
+
+
+def _square_annealing(initial_lr, step, max_steps, min_lr):
+ mult = ((max_steps - step) / max_steps)**2
+ out_lr = initial_lr * mult
+ out_lr = max(out_lr, min_lr)
+ return out_lr
+
+
+def _cosine_annealing(initial_lr, step, max_steps, min_lr):
+ mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+ out_lr = (initial_lr - min_lr) * mult + min_lr
+ return out_lr
+
+
+def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
+ decay_steps, min_lr):
+ assert max_lr > min_lr
+ # Use linear warmup for the initial part.
+ if warmup_steps > 0 and step <= warmup_steps:
+ return max_lr * float(step) / float(warmup_steps)
+
+ # For any steps larger than `decay_steps`, use `min_lr`.
+ if step > warmup_steps + decay_steps:
+ return min_lr
+
+ # If we are done with the warmup period, use the decay style.
+ num_steps_ = step - warmup_steps
+ decay_steps_ = decay_steps
+ decay_ratio = float(num_steps_) / float(decay_steps_)
+ assert decay_ratio >= 0.0
+ assert decay_ratio <= 1.0
+ delta_lr = max_lr - min_lr
+
+ coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
+
+ return min_lr + coeff * delta_lr
+
+
+def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
+ if cycle:
+ multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
+ decay_steps *= multiplier
+ else:
+ step = min(step, decay_steps)
+ p = step / decay_steps
+ lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
+ lr += min_lr
+ return lr
+
+
+def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
+ decay_rate, min_lr):
+ # hold_steps = total number of steps
+ # to hold the LR, not the warmup + hold steps.
+ T_warmup_decay = max(1, warmup_steps**decay_rate)
+ T_hold_decay = max(1, (step - hold_steps)**decay_rate)
+ lr = (initial_lr * T_warmup_decay) / T_hold_decay
+ lr = max(lr, min_lr)
+ return lr
+
+
+class SquareAnnealing(WarmupPolicy):
+
+ def __init__(self,
+ optimizer,
+ *,
+ max_steps,
+ min_lr=1e-5,
+ last_epoch=-1,
+ **kwargs):
+ super().__init__(optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs)
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _square_annealing(
+ initial_lr=initial_lr,
+ step=step - self.warmup_steps,
+ max_steps=self.max_steps - self.warmup_steps,
+ min_lr=self.min_lr,
+ ) for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class SquareRootAnnealing(WarmupPolicy):
+
+ def __init__(self,
+ optimizer,
+ *,
+ max_steps,
+ min_lr=0,
+ last_epoch=-1,
+ **kwargs):
+ super().__init__(optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs)
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _squareroot_annealing(initial_lr=initial_lr,
+ step=step,
+ max_steps=self.max_steps,
+ min_lr=self.min_lr)
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class CosineAnnealing(WarmupAnnealHoldPolicy):
+
+ def __init__(self,
+ optimizer,
+ *,
+ max_steps,
+ min_lr=0,
+ last_epoch=-1,
+ **kwargs):
+ super().__init__(optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs)
+
+ def _get_lr(self, step):
+ for initial_lr in self.base_lrs:
+ if initial_lr < self.min_lr:
+ raise ValueError(
+ f"{self} received an initial learning rate "
+ f"that was lower than the minimum learning rate.")
+
+ if self.constant_steps is None or self.constant_steps == 0:
+ new_lrs = [
+ _cosine_annealing(
+ initial_lr=initial_lr,
+ step=step - self.warmup_steps,
+ max_steps=self.max_steps - self.warmup_steps,
+ min_lr=self.min_lr,
+ ) for initial_lr in self.base_lrs
+ ]
+ else:
+ new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
+ return new_lrs
+
+ def _get_warmup_lr(self, step):
+ if self.constant_steps is None or self.constant_steps == 0:
+ return super()._get_warmup_lr(step)
+ else:
+ # Use linear warmup for the initial part.
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
+
+ def _get_constant_lr(self, step):
+ # Only called when `constant_steps` > 0.
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
+
+ def _get_linear_warmup_with_cosine_annealing_lr(self, step):
+ # Cosine Schedule for Megatron LM,
+ # slightly different warmup schedule + constant LR at the end.
+ new_lrs = [
+ _linear_warmup_with_cosine_annealing(
+ max_lr=self.base_lrs[0],
+ warmup_steps=self.warmup_steps,
+ step=step,
+ decay_steps=self.decay_steps,
+ min_lr=self.min_lr,
+ ) for _ in self.base_lrs
+ ]
+ return new_lrs
+
+
+class NoamAnnealing(_LRScheduler):
+
+ def __init__(self,
+ optimizer,
+ *,
+ d_model,
+ warmup_steps=None,
+ warmup_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1):
+ self._normalize = d_model**(-0.5)
+ assert not (warmup_steps is not None and warmup_ratio is not None), \
+ "Either use particular number of step or ratio"
+ assert warmup_ratio is None or max_steps is not None, \
+ "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2)
+
+ step = max(1, self.last_epoch)
+
+ for initial_lr in self.base_lrs:
+ if initial_lr < self.min_lr:
+ raise ValueError(
+ f"{self} received an initial learning rate "
+ f"that was lower than the minimum learning rate.")
+
+ new_lrs = [
+ self._noam_annealing(initial_lr=initial_lr, step=step)
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+ def _noam_annealing(self, initial_lr, step):
+ if self.warmup_steps > 0:
+ mult = self._normalize * min(step**(-0.5),
+ step * (self.warmup_steps**(-1.5)))
+ else:
+ mult = self._normalize * step**(-0.5)
+
+ out_lr = initial_lr * mult
+ if step > self.warmup_steps:
+ out_lr = max(out_lr, self.min_lr)
+ return out_lr
+
+
+class NoamHoldAnnealing(WarmupHoldPolicy):
+
+ def __init__(self,
+ optimizer,
+ *,
+ max_steps,
+ decay_rate=0.5,
+ min_lr=0.0,
+ last_epoch=-1,
+ **kwargs):
+ """
+ From Nemo:
+ Implementation of the Noam Hold Annealing policy
+ from the SqueezeFormer paper.
+
+ Unlike NoamAnnealing, the peak learning rate
+ can be explicitly set for this scheduler.
+ The schedule first performs linear warmup,
+ then holds the peak LR, then decays with some schedule for
+ the remainder of the steps.
+ Therefore the min-lr is still dependent
+ on the hyper parameters selected.
+
+        Its schedule is determined by three factors:
+
+ Warmup Steps: Initial stage, where linear warmup
+            occurs until the peak LR is reached. Unlike NoamAnnealing,
+ the peak LR is explicitly stated here instead of a scaling factor.
+
+ Hold Steps: Intermediate stage, where the peak LR
+ is maintained for some number of steps. In this region,
+ the high peak LR allows the model to converge faster
+ if training is stable. However the high LR
+ may also cause instability during training.
+ Should usually be a significant fraction of training
+ steps (around 30-40% of the entire training steps).
+
+ Decay Steps: Final stage, where the LR rapidly decays
+ with some scaling rate (set by decay rate).
+ To attain Noam decay, use 0.5,
+ for Squeezeformer recommended decay, use 1.0.
+ The fast decay after prolonged high LR during
+ hold phase allows for rapid convergence.
+
+ References:
+ - [Squeezeformer:
+ An Efficient Transformer for Automatic Speech Recognition]
+ (https://arxiv.org/abs/2206.00888)
+
+ Args:
+ optimizer: Pytorch compatible Optimizer object.
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ hold_steps: Number of training steps to
+ hold the learning rate after warm up
+ hold_ratio: Ratio of hold steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ decay_rate: Float value describing the polynomial decay
+ after the hold period. Default value
+ of 0.5 corresponds to Noam decay.
+ min_lr: Minimum learning rate.
+ """
+ self.decay_rate = decay_rate
+ super().__init__(optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs)
+
+ def _get_lr(self, step):
+ if self.warmup_steps is None or self.warmup_steps == 0:
+ raise ValueError(
+ "Noam scheduler cannot be used without warmup steps")
+
+ if self.hold_steps > 0:
+ hold_steps = self.hold_steps - self.warmup_steps
+ else:
+ hold_steps = 0
+
+ new_lrs = [
+ _noam_hold_annealing(
+ initial_lr,
+ step=step,
+ warmup_steps=self.warmup_steps,
+ hold_steps=hold_steps,
+ decay_rate=self.decay_rate,
+ min_lr=self.min_lr,
+ ) for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+ def set_step(self, step: int):
+ self.last_epoch = step
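+
+
+# A minimal sketch (assumed values, not part of the original file): linear warmup to the peak
+# lr, a hold phase of roughly 30% of training, then Noam-style decay towards min_lr.
+def _example_noam_hold_annealing():
+    model = torch.nn.Linear(4, 4)
+    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
+    scheduler = NoamHoldAnnealing(optimizer, max_steps=1000, warmup_steps=100,
+                                  hold_steps=300, decay_rate=0.5, min_lr=1e-5)
+    for _ in range(1000):
+        optimizer.step()
+        scheduler.step()
+    return scheduler.get_last_lr()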
+
+
+class ConstantLR(_LRScheduler):
+ """The ConstantLR scheduler
+
+ This scheduler keeps a constant lr
+
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ ):
+ # __init__() must be invoked before setting field
+ # because step() is also invoked in __init__()
+ super().__init__(optimizer)
+
+ def get_lr(self):
+ return self.base_lrs
+
+ def set_step(self, step: int):
+ self.last_epoch = step
diff --git a/inspiremusic/utils/tokenizer_utils.py b/inspiremusic/utils/tokenizer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..095757f5d0f8a6da8598f06ceed819e659f8bb61
--- /dev/null
+++ b/inspiremusic/utils/tokenizer_utils.py
@@ -0,0 +1,221 @@
+import glob
+import json
+import os
+import random
+import sys
+import time
+import warnings
+
+import matplotlib
+import numpy as np
+import torch
+import yaml
+from torch import distributed as dist
+from torch.nn.utils import weight_norm
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+import re
+import pathlib
+
+
+def seed_everything(seed, cudnn_deterministic=False):
+ """
+ Function that sets seed for pseudo-random number generators in:
+ pytorch, numpy, python.random
+
+ Args:
+ seed: the integer value seed for global random state
+ """
+ if seed is not None:
+ # print(f"Global seed set to {seed}")
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ # if cudnn_deterministic:
+ # torch.backends.cudnn.deterministic = True
+ # warnings.warn('You have chosen to seed training. '
+ # 'This will turn on the CUDNN deterministic setting, '
+ # 'which can slow down your training considerably! '
+ # 'You may see unexpected behavior when restarting '
+ # 'from checkpoints.')
+
+
+def is_primary():
+ return get_rank() == 0
+
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+
+ return dist.get_rank()
+
+
+def load_yaml_config(path):
+ with open(path) as f:
+ config = yaml.full_load(f)
+ return config
+
+
+def save_config_to_yaml(config, path):
+ assert path.endswith('.yaml')
+ with open(path, 'w') as f:
+ f.write(yaml.dump(config))
+ f.close()
+
+
+def save_dict_to_json(d, path, indent=None):
+ json.dump(d, open(path, 'w'), indent=indent)
+
+
+def load_dict_from_json(path):
+ return json.load(open(path, 'r'))
+
+
+def write_args(args, path):
+ args_dict = dict((name, getattr(args, name)) for name in dir(args)
+ if not name.startswith('_'))
+ with open(path, 'a') as args_file:
+ args_file.write('==> torch version: {}\n'.format(torch.__version__))
+ args_file.write(
+ '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
+ args_file.write('==> Cmd:\n')
+ args_file.write(str(sys.argv))
+ args_file.write('\n==> args:\n')
+ for k, v in sorted(args_dict.items()):
+ args_file.write(' %s: %s\n' % (str(k), str(v)))
+ args_file.close()
+
+
+class Logger(object):
+ def __init__(self, args):
+ self.args = args
+ self.save_dir = args.save_dir
+ self.is_primary = is_primary()
+
+ if self.is_primary:
+ os.makedirs(self.save_dir, exist_ok=True)
+
+ # save the args and config
+ self.config_dir = os.path.join(self.save_dir, 'configs')
+ os.makedirs(self.config_dir, exist_ok=True)
+ file_name = os.path.join(self.config_dir, 'args.txt')
+ write_args(args, file_name)
+
+ log_dir = os.path.join(self.save_dir, 'logs')
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir, exist_ok=True)
+ self.text_writer = open(os.path.join(log_dir, 'log.txt'),
+ 'a') # 'w')
+ if args.tensorboard:
+ self.log_info('using tensorboard')
+ self.tb_writer = torch.utils.tensorboard.SummaryWriter(
+ log_dir=log_dir
+ ) # tensorboard.SummaryWriter(log_dir=log_dir)
+ else:
+ self.tb_writer = None
+
+ def save_config(self, config):
+ if self.is_primary:
+ save_config_to_yaml(config,
+ os.path.join(self.config_dir, 'config.yaml'))
+
+ def log_info(self, info, check_primary=True):
+ if self.is_primary or (not check_primary):
+ print(info)
+ if self.is_primary:
+ info = str(info)
+ time_str = time.strftime('%Y-%m-%d-%H-%M')
+ info = '{}: {}'.format(time_str, info)
+ if not info.endswith('\n'):
+ info += '\n'
+ self.text_writer.write(info)
+ self.text_writer.flush()
+
+ def add_scalar(self, **kargs):
+ """Log a scalar variable."""
+ if self.is_primary:
+ if self.tb_writer is not None:
+ self.tb_writer.add_scalar(**kargs)
+
+ def add_scalars(self, **kargs):
+ """Log a scalar variable."""
+ if self.is_primary:
+ if self.tb_writer is not None:
+ self.tb_writer.add_scalars(**kargs)
+
+ def add_image(self, **kargs):
+ """Log a scalar variable."""
+ if self.is_primary:
+ if self.tb_writer is not None:
+ self.tb_writer.add_image(**kargs)
+
+ def add_images(self, **kargs):
+ """Log a scalar variable."""
+ if self.is_primary:
+ if self.tb_writer is not None:
+ self.tb_writer.add_images(**kargs)
+
+ def close(self):
+ if self.is_primary:
+ self.text_writer.close()
+ if self.tb_writer is not None:
+ self.tb_writer.close()
+
+
+def plot_spectrogram(spectrogram):
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(
+ spectrogram, aspect="auto", origin="lower", interpolation='none')
+ plt.colorbar(im, ax=ax)
+
+ fig.canvas.draw()
+ plt.close()
+
+ return fig
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
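+ # "same" padding for a (dilated) conv, e.g. kernel_size=7, dilation=1 -> 3;
+ # kernel_size=3, dilation=2 -> 2.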
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def load_checkpoint(filepath, device):
+ assert os.path.isfile(filepath)
+ print("Loading '{}'".format(filepath))
+ checkpoint_dict = torch.load(filepath, map_location=device)
+ print("Complete.")
+ return checkpoint_dict
+
+
+def save_checkpoint(filepath, obj, num_ckpt_keep=5):
+ name = re.match(r'(do|g)_\d+', pathlib.Path(filepath).name).group(1)
+ ckpts = sorted(pathlib.Path(filepath).parent.glob(f'{name}_*'))
+ if len(ckpts) > num_ckpt_keep:
+ [os.remove(c) for c in ckpts[:-num_ckpt_keep]]
+ print("Saving checkpoint to {}".format(filepath))
+ torch.save(obj, filepath)
+ print("Complete.")
+
+
+def scan_checkpoint(cp_dir, prefix):
+ pattern = os.path.join(cp_dir, prefix + '????????')
+ cp_list = glob.glob(pattern)
+ if len(cp_list) == 0:
+ return None
+ return sorted(cp_list)[-1]
+
diff --git a/inspiremusic/utils/train_utils.py b/inspiremusic/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9312eb5a83f37388cd4de547beef1bc0f95b6d28
--- /dev/null
+++ b/inspiremusic/utils/train_utils.py
@@ -0,0 +1,300 @@
+# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+# 2023 Horizon Inc. (authors: Xingchen Song)
+# 2024 Alibaba Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+import logging
+import os
+import torch
+import json
+import re
+import datetime
+import yaml
+
+import deepspeed
+import torch.optim as optim
+import torch.distributed as dist
+
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data import DataLoader
+from torch.nn.utils import clip_grad_norm_
+
+from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
+
+from inspiremusic.dataset.dataset import Dataset
+from inspiremusic.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
+
+
+def init_distributed(args):
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ rank = int(os.environ.get('RANK', 0))
+ logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
+ ', rank {}, world_size {}'.format(rank, world_size))
+ if args.train_engine == 'torch_ddp':
+ torch.cuda.set_device(local_rank)
+ dist.init_process_group(args.dist_backend)
+ else:
+ deepspeed.init_distributed(dist_backend=args.dist_backend)
+ return world_size, local_rank, rank
+
+
+def init_dataset_and_dataloader(args, configs):
+ gan = False
+ data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
+ train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', shuffle=True, partition=True)
+ cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', shuffle=False, partition=False)
+
+ # do not use persistent_workers=True, as the whisper tokenizer reopens its tiktoken file each time the for loop starts
+ train_data_loader = DataLoader(train_dataset,
+ batch_size=None,
+ pin_memory=args.pin_memory,
+ num_workers=args.num_workers,
+ prefetch_factor=args.prefetch,
+ timeout=60)
+ cv_data_loader = DataLoader(cv_dataset,
+ batch_size=None,
+ pin_memory=args.pin_memory,
+ num_workers=args.num_workers,
+ prefetch_factor=args.prefetch,
+ timeout=60)
+ return train_dataset, cv_dataset, train_data_loader, cv_data_loader
+
+
+def check_modify_and_save_config(args, configs):
+ if args.train_engine == "torch_ddp":
+ configs['train_conf']["dtype"] = 'fp32'
+ else:
+ with open(args.deepspeed_config, 'r') as fin:
+ ds_configs = json.load(fin)
+ if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
+ configs['train_conf']["dtype"] = "fp16"
+ elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
+ configs['train_conf']["dtype"] = "bf16"
+ else:
+ configs['train_conf']["dtype"] = "fp32"
+ assert ds_configs["train_micro_batch_size_per_gpu"] == 1
+ # if use deepspeed, override ddp config
+ configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
+ configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
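+ # e.g. save_per_step=1000 and accum_grad=4 in the YAML config with
+ # gradient_accumulation_steps=2 in the deepspeed config gives
+ # save_per_step = 1000 * 4 / 2 = 2000.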
+ configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
+ configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
+ configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
+ return configs
+
+
+def wrap_cuda_model(args, model):
+ local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
+ if args.train_engine == "torch_ddp": # native pytorch ddp
+ assert (torch.cuda.is_available())
+ model.cuda()
+ model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
+ else:
+ if int(os.environ.get('RANK', 0)) == 0:
+ logging.info("Estimating model states memory needs (zero2)...")
+ estimate_zero2_model_states_mem_needs_all_live(
+ model,
+ num_gpus_per_node=local_world_size,
+ num_nodes=world_size // local_world_size)
+ return model
+
+def init_optimizer_and_scheduler(args, configs, model):
+ if configs['train_conf']['optim'] == 'adam':
+ optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
+ elif configs['train_conf']['optim'] == 'adamw':
+ optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+ else:
+ raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+ if configs['train_conf']['scheduler'] == 'warmuplr':
+ scheduler_type = WarmupLR
+ scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+ elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+ scheduler_type = NoamHoldAnnealing
+ scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+ elif configs['train_conf']['scheduler'] == 'constantlr':
+ scheduler_type = ConstantLR
+ scheduler = ConstantLR(optimizer)
+ else:
+ raise ValueError("unknown scheduler: " + configs['train_conf'])
+
+ # use deepspeed optimizer for speedup
+ if args.train_engine == "deepspeed":
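+ # deepspeed.initialize accepts either a scheduler instance or a callable that
+ # takes the optimizer it builds and returns a scheduler, so wrap the type here.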
+ def scheduler(opt):
+ return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+ model, optimizer, _, scheduler = deepspeed.initialize(
+ args=args,
+ model=model,
+ optimizer=None,
+ lr_scheduler=scheduler,
+ model_parameters=model.parameters())
+
+ return model, optimizer, scheduler
+
+
+def init_summarywriter(args):
+ writer = None
+ if int(os.environ.get('RANK', 0)) == 0:
+ os.makedirs(args.model_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+ return writer
+
+
+def save_model(model, model_name, info_dict):
+ rank = int(os.environ.get('RANK', 0))
+ model_dir = info_dict["model_dir"]
+ save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
+
+ if info_dict["train_engine"] == "torch_ddp":
+ if rank == 0:
+ torch.save(model.module.state_dict(), save_model_path)
+ else:
+ with torch.no_grad():
+ model.save_checkpoint(save_dir=model_dir,
+ tag=model_name,
+ client_state=info_dict)
+ if rank == 0:
+ info_path = re.sub(r'\.pt$', '.yaml', save_model_path)
+ info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
+ with open(info_path, 'w') as fout:
+ data = yaml.dump(info_dict)
+ fout.write(data)
+ logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
+
+
+def inspiremusic_join(group_join, info_dict):
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ rank = int(os.environ.get('RANK', 0))
+
+ if info_dict["batch_idx"] != 0:
+ # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
+ try:
+ dist.monitored_barrier(group=group_join,
+ timeout=group_join.options._timeout)
+ return False
+ except RuntimeError as e:
+ logging.info("Detected uneven workload distribution: {}\n".format(e) +
+ "Break current worker to manually join all workers, " +
+ "world_size {}, current rank {}, current local_rank {}\n".
+ format(world_size, rank, local_rank))
+ return True
+ else:
+ return False
+
+
+def batch_forward(model, batch, info_dict, scaler):
+ device = int(os.environ.get('LOCAL_RANK', 0))
+
+ dtype = info_dict["dtype"]
+ if dtype == "fp16":
+ dtype = torch.float16
+ elif dtype == "bf16":
+ dtype = torch.bfloat16
+ else: # fp32
+ dtype = torch.float32
+
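+ # torch_ddp: autocast only when a GradScaler is provided (fp16 via AMP);
+ # deepspeed: always autocast in the dtype taken from the deepspeed JSON config.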
+ if info_dict['train_engine'] == 'torch_ddp':
+ autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
+ else:
+ autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
+
+ with autocast:
+ info_dict['loss_dict'] = model(batch, device)
+ return info_dict
+
+
+def batch_backward(model, info_dict, scaler):
+ if info_dict["train_engine"] == "deepspeed":
+ scaled_loss = model.backward(info_dict['loss_dict']['loss'])
+ else:
+ scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
+ if scaler is not None:
+ scaler.scale(scaled_loss).backward()
+ else:
+ scaled_loss.backward()
+
+ info_dict['loss_dict']['loss'] = scaled_loss
+ return info_dict
+
+def update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler=None):
+ grad_norm = 0.0
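+ # DeepSpeed's model.step() handles gradient accumulation, clipping and the LR
+ # schedule internally; for torch_ddp the optimizer and scheduler are stepped
+ # only at gradient-accumulation boundaries, after clipping to grad_clip.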
+ if info_dict['train_engine'] == "deepspeed":
+ info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
+ model.step()
+ grad_norm = model.get_global_grad_norm()
+ elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
+ if scaler is not None:
+ scaler.unscale_(optimizer) # Unscale gradients before clipping
+ grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+ if torch.isfinite(grad_norm):
+ optimizer.step()
+ optimizer.zero_grad()
+ scheduler.step()
+ info_dict["lr"] = optimizer.param_groups[0]['lr']
+ info_dict["grad_norm"] = grad_norm
+ return info_dict
+
+
+def log_per_step(writer, info_dict):
+ tag = info_dict["tag"]
+ epoch = info_dict.get('epoch', 0)
+ step = info_dict["step"]
+ batch_idx = info_dict["batch_idx"]
+ loss_dict = info_dict['loss_dict']
+ rank = int(os.environ.get('RANK', 0))
+
+ # only rank 0 write to tensorboard to avoid multi-process write
+ if writer is not None:
+ if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
+ (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
+ for k in ['epoch', 'lr', 'grad_norm']:
+ writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
+ for k, v in loss_dict.items():
+ writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
+
+ # TRAIN & CV, Shell log (stdout)
+ if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
+ log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
+ for name, value in loss_dict.items():
+ log_str += '{} {:.6f} '.format(name, value.item())
+ if tag == "TRAIN":
+ log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
+ info_dict["lr"], info_dict['grad_norm'])
+ log_str += ' rank {}'.format(rank)
+ logging.debug(log_str)
+
+
+def log_per_save(writer, info_dict):
+ tag = info_dict["tag"]
+ epoch = info_dict["epoch"]
+ step = info_dict["step"]
+ loss_dict = info_dict["loss_dict"]
+ lr = info_dict['lr']
+ rank = int(os.environ.get('RANK', 0))
+ logging.info(
+ 'Epoch {} Step {} CV info lr {} {} rank {}'.format(
+ epoch, step + 1, lr, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]), rank))
+
+ if writer is not None:
+ for k in ['epoch', 'lr']:
+ writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
+ for k, v in loss_dict.items():
+ writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
diff --git a/inspiremusic/utils/utils.py b/inspiremusic/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db1d8b41c94be6c41b8424d3d036e17a6471234d
--- /dev/null
+++ b/inspiremusic/utils/utils.py
@@ -0,0 +1,22 @@
+import os
+import sys
+
+def align_trans_scp_file(trans, scp):
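+ # Expected inputs (as parsed below): each `trans` line is "utt_id<TAB>transcription"
+ # and each `scp` line is "utt_id<SPACE>wav_path"; matching pairs are written to a
+ # file named "text" as "utt_id<TAB>transcription".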
+ trans_dict = {}
+ with open(trans, 'r') as f:
+ for line in f:
+ sec = line.strip().split("\t")
+ trans_dict[sec[0]] = sec[1]
+ scp_dict = {}
+ with open(scp, 'r') as f:
+ for line in f:
+ sec = line.strip().split(" ")
+ scp_dict[sec[0]] = sec[1]
+ with open("text", "w") as f:
+ for k, v in scp_dict.items():
+ f.write("%s\t%s\n"%(k,trans_dict[k]))
+
+if __name__ == '__main__':
+ trans = sys.argv[1]
+ scp = sys.argv[2]
+ align_trans_scp_file(trans, scp)
\ No newline at end of file
diff --git a/inspiremusic/version.txt b/inspiremusic/version.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aa3386889155607fc99e0c091986218549648a89
--- /dev/null
+++ b/inspiremusic/version.txt
@@ -0,0 +1 @@
+v0.1
\ No newline at end of file
diff --git a/inspiremusic/wavtokenizer/__init__.py b/inspiremusic/wavtokenizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inspiremusic/wavtokenizer/decoder/__init__.py b/inspiremusic/wavtokenizer/decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inspiremusic/wavtokenizer/decoder/dataset.py b/inspiremusic/wavtokenizer/decoder/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..345a1617291648fe5f8671e7ea897c539fdcb2f5
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/dataset.py
@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+import torchaudio
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import Dataset, DataLoader
+
+import soundfile
+# import librosa
+import random
+
+torch.set_num_threads(1)
+
+
+@dataclass
+class DataConfig:
+ filelist_path: str
+ sampling_rate: int
+ num_samples: int
+ batch_size: int
+ num_workers: int
+
+def collate_fn(batch):
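+ # Drop samples that failed to load (VocosDataset returns None on error) before stacking.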
+ batch = [item for item in batch if item is not None]
+ return torch.stack(batch, dim=0)
+
+class VocosDataModule(LightningDataModule):
+ def __init__(self, train_params: DataConfig, val_params: DataConfig):
+ super().__init__()
+ self.train_config = train_params
+ self.val_config = val_params
+
+ def _get_dataloder(self, cfg: DataConfig, train: bool):
+ dataset = VocosDataset(cfg, train=train)
+ dataloader = DataLoader(
+ dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True, collate_fn=collate_fn
+ )
+ return dataloader
+
+ def train_dataloader(self) -> DataLoader:
+ return self._get_dataloder(self.train_config, train=True)
+
+ def val_dataloader(self) -> DataLoader:
+ return self._get_dataloder(self.val_config, train=False)
+
+
+class VocosDataset(Dataset):
+ def __init__(self, cfg: DataConfig, train: bool):
+ with open(cfg.filelist_path) as f:
+ self.filelist = f.read().splitlines()
+ self.sampling_rate = cfg.sampling_rate
+ self.num_samples = cfg.num_samples
+ self.train = train
+
+ def __len__(self) -> int:
+ return len(self.filelist)
+
+ def __getitem__(self, index: int) -> torch.Tensor:
+ audio_path = self.filelist[index]
+ # y, sr = torchaudio.load(audio_path)
+ # print(audio_path,"111")
+ try:
+ y1, sr = soundfile.read(audio_path)
+ # y1, sr = librosa.load(audio_path,sr=None)
+ y = torch.tensor(y1).float().unsqueeze(0)
+ # if y.size(0) > 1:
+ # # mix to mono
+ # y = y.mean(dim=0, keepdim=True)
+ if y.ndim > 2:
+ # pick one channel at random instead of mixing down to mono
+ # print("problem here, in the data processing part")
+ # y = y.mean(dim=-1, keepdim=False)
+ random_channel = random.randint(0, y.size(-1) - 1)
+ y = y[:, :, random_channel]
+
+ gain = np.random.uniform(-1, -6) if self.train else -3
+ y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
+ if sr != self.sampling_rate:
+ y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
+ if y.size(-1) < self.num_samples:
+ pad_length = self.num_samples - y.size(-1)
+ padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
+ y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
+ elif self.train:
+ start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
+ y = y[:, start : start + self.num_samples]
+ else:
+ # During validation, always take the first segment for determinism
+ y = y[:, : self.num_samples]
+
+ return y[0]
+ except Exception as e:
+ print(f"Error processing file {audio_path} at index {index}: {e}")
+ # Here you could re-raise the exception instead, or return None to mark the sample as invalid
+ return None
+
+ # def __getitem__(self, index: int) -> torch.Tensor:
+ # audio_path = self.filelist[index]
+ # try:
+ # y, sr = torchaudio.load(audio_path)
+ # if y.size(0) > 1:
+ # # randomly pick one channel
+ # random_channel = random.randint(0, y.size(0) - 1)
+ # y = y[random_channel, :].unsqueeze(0) # keep the returned shape as (1, T)
+ # # gain = np.random.uniform(-1, -6) if self.train else -3
+ # # y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
+ # if sr != self.sampling_rate:
+ # y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
+ # if y.size(-1) < self.num_samples:
+ # pad_length = self.num_samples - y.size(-1)
+ # padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
+ # y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
+ # elif self.train:
+ # start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
+ # y = y[:, start: start + self.num_samples]
+ # else:
+ # # During validation, take always the first segment for determinism
+ # y = y[:, :self.num_samples]
+ # return y[0]
+ # except Exception as e:
+ # print(f"Error processing file {audio_path} at index {index}: {e}")
+ # # Here you could re-raise the exception instead, or return None to mark the sample as invalid
+ # return None
\ No newline at end of file
diff --git a/inspiremusic/wavtokenizer/decoder/discriminator_dac.py b/inspiremusic/wavtokenizer/decoder/discriminator_dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..33ef3258a3a4a3ab11f17ca9c3ad381de95b6477
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/discriminator_dac.py
@@ -0,0 +1,249 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from audiotools import AudioSignal
+# from audiotools import ml
+# from audiotools import STFTParams
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from collections import namedtuple
+
+STFTParams = namedtuple(
+ "STFTParams",
+ ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
+)
+
+STFTParams.__new__.__defaults__ = (None, None, None, None, None)
+
+
+def WNConv1d(*args, **kwargs):
+ act = kwargs.pop("act", True)
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
+ if not act:
+ return conv
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+def WNConv2d(*args, **kwargs):
+ act = kwargs.pop("act", True)
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
+ if not act:
+ return conv
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+class MPD(nn.Module):
+ def __init__(self, period):
+ super().__init__()
+ self.period = period
+ self.convs = nn.ModuleList(
+ [
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+ ]
+ )
+ self.conv_post = WNConv2d(
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+ )
+
+ def pad_to_period(self, x):
+ t = x.shape[-1]
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+ return x
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.pad_to_period(x)
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
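+ # (B, 1, T) is reflect-padded to a multiple of `period` and folded into
+ # (B, 1, T // period, period) so the 2D convs operate on period-aligned frames.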
+
+ for layer in self.convs:
+ x = layer(x)
+ fmap.append(x)
+
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class MSD(nn.Module):
+ def __init__(self, rate: int = 1, sample_rate: int = 48000):
+ super().__init__()
+ self.convs = nn.ModuleList(
+ [
+ WNConv1d(1, 16, 15, 1, padding=7),
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+ WNConv1d(1024, 1024, 5, 1, padding=2),
+ ]
+ )
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+ self.sample_rate = sample_rate
+ self.rate = rate
+
+ def forward(self, x):
+ # x = AudioSignal(x, self.sample_rate)
+ # x.resample(self.sample_rate // self.rate)
+ # x = x.audio_data
+
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+class MRD(nn.Module):
+ def __init__(
+ self,
+ window_length: int,
+ hop_factor: float = 0.25,
+ sample_rate: int = 24000,
+ bands: list = BANDS,
+ ):
+ """Complex multi-band spectrogram discriminator.
+ Parameters
+ ----------
+ window_length : int
+ Window length of STFT.
+ hop_factor : float, optional
+ Hop factor of the STFT (hop_length = ``hop_factor * window_length``), by default 0.25.
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 24000
+ bands : list, optional
+ Bands to run discriminator over.
+ """
+ super().__init__()
+
+ self.window_length = window_length
+ self.hop_factor = hop_factor
+ self.sample_rate = sample_rate
+ self.stft_params = STFTParams(
+ window_length=window_length,
+ hop_length=int(window_length * hop_factor),
+ match_stride=True,
+ )
+
+ n_fft = window_length // 2 + 1
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+ self.bands = bands
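+ # e.g. window_length=1024 gives 513 frequency bins, so the default BANDS map to
+ # bin ranges (0, 51), (51, 128), (128, 256), (256, 384), (384, 513).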
+ self.n_fft = window_length
+
+ ch = 32
+ convs = lambda: nn.ModuleList(
+ [
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+ ]
+ )
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
+ def spectrogram(self, x):
+ # x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+ # x = torch.view_as_real(x.stft())
+
+ # x.squeeze(0).stft(n_fft=1024,win_length=1024,return_complex=True).size()
+ # breakpoint()
+ if x.size(0)==1:
+ # x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.window_length,return_complex=True).unsqueeze(0))
+ x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.n_fft,return_complex=True).unsqueeze(0))
+ else:
+ # x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.window_length,return_complex=True).unsqueeze(1))
+ x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.n_fft,return_complex=True).unsqueeze(1))
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
+ # Split into bands
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+ return x_bands
+
+ def forward(self, x):
+ x_bands = self.spectrogram(x)
+ fmap = []
+
+ x = []
+ for band, stack in zip(x_bands, self.band_convs):
+ for layer in stack:
+ band = layer(band)
+ fmap.append(band)
+ x.append(band)
+
+ x = torch.cat(x, dim=-1)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+# class DACDiscriminator(ml.BaseModel):
+class DACDiscriminator(nn.Module):
+ def __init__(
+ self,
+ rates: list = [],
+ periods: list = [2, 3, 5, 7, 11],
+ fft_sizes: list = [2048, 1024, 512],
+ sample_rate: int = 24000,
+ bands: list = BANDS,
+ ):
+ """Discriminator that combines multiple discriminators.
+
+ Parameters
+ ----------
+ rates : list, optional
+ sampling rates (in Hz) to run MSD at, by default []
+ If empty, MSD is not used.
+ periods : list, optional
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+ fft_sizes : list, optional
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 24000
+ bands : list, optional
+ Bands to run MRD at, by default `BANDS`
+ """
+ super().__init__()
+ discs = []
+ discs += [MPD(p) for p in periods]
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
+ self.discriminators = nn.ModuleList(discs)
+
+ def preprocess(self, y):
+ # Remove DC offset
+ y = y - y.mean(dim=-1, keepdims=True)
+ # Peak normalize the volume of input audio
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+ return y
+
+ def forward(self, x):
+ x = self.preprocess(x)
+ fmaps = [d(x) for d in self.discriminators]
+ return fmaps
+
+
+if __name__ == "__main__":
+ disc = DACDiscriminator()
+ x = torch.zeros(1, 1, 24000)
+ results = disc(x)
+ breakpoint()
+ for i, result in enumerate(results):
+ print(f"disc{i}")
+ for i, r in enumerate(result):
+ print(r.shape, r.mean(), r.min(), r.max())
+ print("00")
diff --git a/inspiremusic/wavtokenizer/decoder/discriminators.py b/inspiremusic/wavtokenizer/decoder/discriminators.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f6dece570b3d181f3fd2206a4dae2549b9e0fa3
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/discriminators.py
@@ -0,0 +1,202 @@
+from typing import Tuple, List
+
+import torch
+from torch import nn
+from torch.nn import Conv2d
+from torch.nn.utils import weight_norm
+
+
+class MultiPeriodDiscriminator(nn.Module):
+ """
+ Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+ Args:
+ periods (tuple[int]): Tuple of periods for each discriminator.
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+ Defaults to None.
+ """
+
+ def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
+ super().__init__()
+ self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
+
+ def forward(
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for d in self.discriminators:
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorP(nn.Module):
+ def __init__(
+ self,
+ period: int,
+ in_channels: int = 1,
+ kernel_size: int = 5,
+ stride: int = 3,
+ lrelu_slope: float = 0.1,
+ num_embeddings: int = None,
+ ):
+ super().__init__()
+ self.period = period
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
+ ]
+ )
+ if num_embeddings is not None:
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
+ torch.nn.init.zeros_(self.emb.weight)
+
+ self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+ self.lrelu_slope = lrelu_slope
+
+ def forward(
+ self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ x = x.unsqueeze(1)
+ fmap = []
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for i, l in enumerate(self.convs):
+ x = l(x)
+ x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
+ if i > 0:
+ fmap.append(x)
+ if cond_embedding_id is not None:
+ emb = self.emb(cond_embedding_id)
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+ else:
+ h = 0
+ x = self.conv_post(x)
+ fmap.append(x)
+ x += h
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiResolutionDiscriminator(nn.Module):
+ def __init__(
+ self,
+ resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
+ num_embeddings: int = None,
+ ):
+ """
+ Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+ Args:
+ resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator.
+ Each resolution should be a tuple of (n_fft, hop_length, win_length).
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+ Defaults to None.
+ """
+ super().__init__()
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
+ )
+
+ def forward(
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for d in self.discriminators:
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+ def __init__(
+ self,
+ resolution: Tuple[int, int, int],
+ channels: int = 64,
+ in_channels: int = 1,
+ num_embeddings: int = None,
+ lrelu_slope: float = 0.1,
+ ):
+ super().__init__()
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.lrelu_slope = lrelu_slope
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
+ ]
+ )
+ if num_embeddings is not None:
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+ torch.nn.init.zeros_(self.emb.weight)
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
+
+ def forward(
+ self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ fmap = []
+ x = self.spectrogram(x)
+ x = x.unsqueeze(1)
+ for l in self.convs:
+ x = l(x)
+ x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
+ fmap.append(x)
+ if cond_embedding_id is not None:
+ emb = self.emb(cond_embedding_id)
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+ else:
+ h = 0
+ x = self.conv_post(x)
+ fmap.append(x)
+ x += h
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+ def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
+ n_fft, hop_length, win_length = self.resolution
+ magnitude_spectrogram = torch.stft(
+ x,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=None, # interestingly rectangular window kind of works here
+ center=True,
+ return_complex=True,
+ ).abs()
+
+ return magnitude_spectrogram
diff --git a/inspiremusic/wavtokenizer/decoder/experiment.py b/inspiremusic/wavtokenizer/decoder/experiment.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5557a218aedf77cc43ba3e774d6accb6003e45f
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/experiment.py
@@ -0,0 +1,474 @@
+import math
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torchaudio
+import transformers
+import yaml
+
+from decoder.discriminator_dac import DACDiscriminator
+
+from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
+from decoder.feature_extractors import FeatureExtractor
+from decoder.heads import FourierHead
+from decoder.helpers import plot_spectrogram_to_numpy
+from decoder.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss, DACGANLoss
+from decoder.models import Backbone
+from decoder.modules import safe_log
+from decoder.pretrained_model import instantiate_class
+
+
+class VocosExp(pl.LightningModule):
+ # noinspection PyUnusedLocal
+ def __init__(
+ self,
+ feature_extractor: FeatureExtractor,
+ backbone: Backbone,
+ head: FourierHead,
+ resume_config: str,
+ resume_model: str,
+ sample_rate: int = 24000,
+ initial_learning_rate: float = 2e-4,
+ num_warmup_steps: int = 0,
+ mel_loss_coeff: float = 45,
+ mrd_loss_coeff: float = 1.0,
+ pretrain_mel_steps: int = 0,
+ decay_mel_coeff: bool = False,
+ evaluate_utmos: bool = False,
+ evaluate_pesq: bool = False,
+ evaluate_periodicty: bool = False,
+ resume: bool = False,
+ ):
+ """
+ Args:
+ feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
+ backbone (Backbone): An instance of Backbone model.
+ head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
+ sample_rate (int): Sampling rate of the audio signals.
+ initial_learning_rate (float): Initial learning rate for the optimizer.
+ num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
+ mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
+ mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
+ pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
+ decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
+ evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
+ evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
+ evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
+ """
+ super().__init__()
+ self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
+
+ self.feature_extractor = feature_extractor
+ self.backbone = backbone
+ self.head = head
+
+ self.resume_config = resume_config
+ self.resume_model = resume_model
+ self.resume = resume
+
+ self.multiperioddisc = MultiPeriodDiscriminator()
+ self.multiresddisc = MultiResolutionDiscriminator()
+
+
+ self.dac = DACDiscriminator()
+
+ self.dacdiscriminator = DACGANLoss(self.dac)
+
+ self.disc_loss = DiscriminatorLoss()
+ self.gen_loss = GeneratorLoss()
+ self.feat_matching_loss = FeatureMatchingLoss()
+ self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
+
+ self.train_discriminator = False
+ self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
+
+ def configure_optimizers(self):
+ disc_params = [
+ {"params": self.multiperioddisc.parameters()},
+ {"params": self.multiresddisc.parameters()},
+ {"params": self.dac.parameters()},
+ ]
+ gen_params = [
+ {"params": self.feature_extractor.parameters()},
+ {"params": self.backbone.parameters()},
+ {"params": self.head.parameters()},
+ ]
+
+ opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
+ opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)
+
+ max_steps = self.trainer.max_steps // 2 # Max steps per optimizer
+ scheduler_disc = transformers.get_cosine_schedule_with_warmup(
+ opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+ )
+ scheduler_gen = transformers.get_cosine_schedule_with_warmup(
+ opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+ )
+
+ return (
+ [opt_disc, opt_gen],
+ [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
+ )
+
+ def forward(self, audio_input, **kwargs):
+ features, _, commit_loss = self.feature_extractor(audio_input, **kwargs)
+ # print('1111', self.feature_extractor.state_dict()['encodec.decoder.model.3.convtr.convtr.weight_g'])
+ x = self.backbone(features, **kwargs)
+ audio_output = self.head(x)
+ return audio_output, commit_loss
+
+ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
+ audio_input = batch
+
+ # train discriminator
+ if optimizer_idx == 0 and self.train_discriminator:
+ with torch.no_grad():
+ audio_hat, _ = self(audio_input, **kwargs)
+
+
+ loss_dac = self.dacdiscriminator.discriminator_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
+
+ real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
+ real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
+ loss_mp, loss_mp_real, _ = self.disc_loss(
+ disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
+ )
+ loss_mrd, loss_mrd_real, _ = self.disc_loss(
+ disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
+ )
+ loss_mp /= len(loss_mp_real)
+ loss_mrd /= len(loss_mrd_real)
+ loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + loss_dac
+
+ self.log("discriminator/total", loss, prog_bar=True)
+ self.log("discriminator/multi_period_loss", loss_mp)
+ self.log("discriminator/multi_res_loss", loss_mrd)
+ self.log("discriminator/dac", loss_dac)
+ return loss
+
+ # train generator
+ if optimizer_idx == 1:
+ audio_hat, commit_loss = self(audio_input, **kwargs)
+ if self.train_discriminator:
+
+ loss_dac_1, loss_dac_2 = self.dacdiscriminator.generator_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
+ _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
+ y=audio_input, y_hat=audio_hat, **kwargs,
+ )
+ _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
+ y=audio_input, y_hat=audio_hat, **kwargs,
+ )
+ loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
+ loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
+ loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
+ loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
+ loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
+ loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)
+
+ self.log("generator/multi_period_loss", loss_gen_mp)
+ self.log("generator/multi_res_loss", loss_gen_mrd)
+ self.log("generator/feature_matching_mp", loss_fm_mp)
+ self.log("generator/feature_matching_mrd", loss_fm_mrd)
+ self.log("generator/loss_dac_1", loss_dac_1)
+ self.log("generator/loss_dac_2", loss_dac_2)
+ else:
+ loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = loss_dac_1 = loss_dac_2 = 0
+
+ mel_loss = self.melspec_loss(audio_hat, audio_input)
+ loss = (
+ loss_gen_mp
+ + self.hparams.mrd_loss_coeff * loss_gen_mrd
+ + loss_fm_mp
+ + self.hparams.mrd_loss_coeff * loss_fm_mrd
+ + self.mel_loss_coeff * mel_loss
+ + 1000 * commit_loss
+ + loss_dac_1
+ + loss_dac_2
+ )
+
+ self.log("generator/total_loss", loss, prog_bar=True)
+ self.log("mel_loss_coeff", self.mel_loss_coeff)
+ self.log("generator/mel_loss", mel_loss)
+ self.log("commit_loss", commit_loss)
+
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
+ self.logger.experiment.add_audio(
+ "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
+ )
+ self.logger.experiment.add_audio(
+ "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
+ )
+ with torch.no_grad():
+ mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
+ mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
+ self.logger.experiment.add_image(
+ "train/mel_target",
+ plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
+ self.global_step,
+ dataformats="HWC",
+ )
+ self.logger.experiment.add_image(
+ "train/mel_pred",
+ plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
+ self.global_step,
+ dataformats="HWC",
+ )
+
+ return loss
+
+ def on_validation_epoch_start(self):
+ if self.hparams.evaluate_utmos:
+ from metrics.UTMOS import UTMOSScore
+
+ if not hasattr(self, "utmos_model"):
+ self.utmos_model = UTMOSScore(device=self.device)
+
+ def validation_step(self, batch, batch_idx, **kwargs):
+ audio_input = batch
+ audio_hat, commit_loss = self(audio_input, **kwargs)
+
+ audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
+ audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)
+
+ if self.hparams.evaluate_periodicty:
+ from metrics.periodicity import calculate_periodicity_metrics
+
+ periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
+ else:
+ periodicity_loss = pitch_loss = f1_score = 0
+
+ if self.hparams.evaluate_utmos:
+ utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
+ else:
+ utmos_score = torch.zeros(1, device=self.device)
+
+ if self.hparams.evaluate_pesq:
+ from pesq import pesq
+
+ pesq_score = 0
+ for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
+ pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
+ pesq_score /= len(audio_16_khz)
+ pesq_score = torch.tensor(pesq_score)
+ else:
+ pesq_score = torch.zeros(1, device=self.device)
+
+ mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
+ total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + 1000 * commit_loss
+
+ return {
+ "val_loss": total_loss,
+ "mel_loss": mel_loss,
+ "utmos_score": utmos_score,
+ "pesq_score": pesq_score,
+ "periodicity_loss": periodicity_loss,
+ "pitch_loss": pitch_loss,
+ "f1_score": f1_score,
+ "audio_input": audio_input[0],
+ "audio_pred": audio_hat[0],
+ }
+
+ def validation_epoch_end(self, outputs):
+ if self.global_rank == 0:
+ *_, audio_in, audio_pred = outputs[0].values()
+ self.logger.experiment.add_audio(
+ "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+ )
+ self.logger.experiment.add_audio(
+ "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+ )
+ mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
+ mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
+ self.logger.experiment.add_image(
+ "val_mel_target",
+ plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
+ self.global_step,
+ dataformats="HWC",
+ )
+ self.logger.experiment.add_image(
+ "val_mel_hat",
+ plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
+ self.global_step,
+ dataformats="HWC",
+ )
+ avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
+ mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
+ utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
+ pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
+ periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
+ pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
+ f1_score = np.array([x["f1_score"] for x in outputs]).mean()
+
+ self.log("val_loss", avg_loss, sync_dist=True)
+ self.log("val/mel_loss", mel_loss, sync_dist=True)
+ self.log("val/utmos_score", utmos_score, sync_dist=True)
+ self.log("val/pesq_score", pesq_score, sync_dist=True)
+ self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
+ self.log("val/pitch_loss", pitch_loss, sync_dist=True)
+ self.log("val/f1_score", f1_score, sync_dist=True)
+
+ @property
+ def global_step(self):
+ """
+ Override global_step so that it returns the total number of batches processed
+ """
+ return self.trainer.fit_loop.epoch_loop.total_batch_idx
+
+ def on_train_batch_start(self, *args):
+ if self.global_step >= self.hparams.pretrain_mel_steps:
+ self.train_discriminator = True
+ else:
+ self.train_discriminator = False
+
+ def on_train_batch_end(self, *args):
+ def mel_loss_coeff_decay(current_step, num_cycles=0.5):
+ max_steps = self.trainer.max_steps // 2
+ if current_step < self.hparams.num_warmup_steps:
+ return 1.0
+ progress = float(current_step - self.hparams.num_warmup_steps) / float(
+ max(1, max_steps - self.hparams.num_warmup_steps)
+ )
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
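+ # With the default num_cycles=0.5, mel_loss_coeff_decay is a half-cosine that
+ # takes the mel-loss coefficient from 1.0 at the end of warmup down to 0.0
+ # when current_step reaches trainer.max_steps // 2.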
+ if self.hparams.decay_mel_coeff:
+ self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
+
+
+class WavTokenizer(VocosExp):
+ """
+ WavTokenizer is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
+ It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
+ a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
+ while during validation, a fixed bandwidth_id is used.
+ """
+
+ def __init__(
+ self,
+ feature_extractor: FeatureExtractor,
+ backbone: Backbone,
+ head: FourierHead,
+ resume_config: str,
+ resume_model: str,
+ sample_rate: int = 24000,
+ initial_learning_rate: float = 2e-4,
+ num_warmup_steps: int = 0,
+ mel_loss_coeff: float = 45,
+ mrd_loss_coeff: float = 1.0,
+ pretrain_mel_steps: int = 0,
+ decay_mel_coeff: bool = False,
+ evaluate_utmos: bool = False,
+ evaluate_pesq: bool = False,
+ evaluate_periodicty: bool = False,
+ resume: bool = False,
+ ):
+ super().__init__(
+ feature_extractor,
+ backbone,
+ head,
+ resume_config,
+ resume_model,
+ sample_rate,
+ initial_learning_rate,
+ num_warmup_steps,
+ mel_loss_coeff,
+ mrd_loss_coeff,
+ pretrain_mel_steps,
+ decay_mel_coeff,
+ evaluate_utmos,
+ evaluate_pesq,
+ evaluate_periodicty,
+ resume
+ )
+ # Override with conditional discriminators
+ # VocosExp.__init__(self, feature_extractor, backbone, head, resume_config, resume_model)
+ # if self.resume:
+ # VocosExp.load_from_checkpoint(self.resume_model)
+ self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
+ self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
+ self.dac = DACDiscriminator()
+ if self.resume:
+ print('Loading pretrained model:', self.resume_model)
+ # with open(self.resume_config, "r") as f:
+ # config = yaml.safe_load(f)
+ # feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+ # backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+ # head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+
+ # do not load the quantizer weights
+ state_dict_raw = torch.load(self.resume_model, map_location=self.device)['state_dict']
+ state_dict_fa_qa = dict()
+ state_dict_fa_en = dict()
+ state_dict_fa_de = dict()
+ state_dict_bb = dict()
+ state_dict_hd = dict()
+ state_dict_mp = dict()
+ state_dict_mr = dict()
+ state_dict_dac = dict()
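+ # Split the raw checkpoint by module prefix; the slice offsets below strip
+ # prefixes such as 'feature_extractor.encodec.quantizer.' (36 chars) and
+ # 'backbone.' (9 chars) so each sub-module can be loaded with strict=True.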
+ for k, v in state_dict_raw.items():
+ # breakpoint()
+ if k.startswith('feature_extractor.encodec.quantizer'):
+ # breakpoint()
+ # print("*****",k)
+ ss = k[46:48]
+ if ss[-1] == '.':
+ num = int(ss[0])
+ # print("num,k",num,k[36:])
+ if num <= 7:
+ state_dict_fa_qa[k[36:]] = v
+ if k.startswith('feature_extractor.encodec.encoder'):
+ state_dict_fa_en[k[34:]] = v
+ if k.startswith('feature_extractor.encodec.decoder'):
+ state_dict_fa_de[k[34:]] = v
+ if k.startswith('backbone.'):
+ state_dict_bb[k[9:]] = v
+ if k.startswith('head.'):
+ state_dict_hd[k[5:]] = v
+ if k.startswith('multiperioddisc.'):
+ state_dict_mp[k[16:]] = v
+ if k.startswith('multiresddisc.'):
+ state_dict_mr[k[14:]] = v
+ if k.startswith('dac.'):
+ state_dict_dac[k[4:]] = v
+ # breakpoint()
+ # feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
+ feature_extractor.encodec.encoder.load_state_dict(state_dict_fa_en, strict=True)
+ feature_extractor.encodec.decoder.load_state_dict(state_dict_fa_de, strict=True)
+ feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
+ backbone.load_state_dict(state_dict_bb, strict=True)
+ head.load_state_dict(state_dict_hd, strict=True)
+ self.feature_extractor = feature_extractor.to(self.device)
+ self.backbone = backbone.to(self.device)
+ self.head = head.to(self.device)
+ self.multiperioddisc.load_state_dict(state_dict_mp, strict=True)
+ self.multiresddisc.load_state_dict(state_dict_mr, strict=True)
+ self.dac.load_state_dict(state_dict_dac, strict=True)
+
+ def training_step(self, *args):
+ # print('-------------------train--------------------')
+ # if self.global_rank == 0 and self.resume:
+ # config_path = self.resume_config
+ # model_path = self.resume_model
+ # self.pretrained_load(config_path, model_path)
+ # print('Loading pretrained model:', model_path)
+ bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
+ output = super().training_step(*args, bandwidth_id=bandwidth_id)
+ return output
+
+ def validation_step(self, *args):
+ # print('-------------------valid--------------------')
+ bandwidth_id = torch.tensor([0], device=self.device)
+ output = super().validation_step(*args, bandwidth_id=bandwidth_id)
+ return output
+
+ def validation_epoch_end(self, outputs):
+ if self.global_rank == 0:
+ *_, audio_in, _ = outputs[0].values()
+ # Resynthesis with encodec for reference
+ self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
+ encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
+ self.logger.experiment.add_audio(
+ "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate,
+ )
+
+ super().validation_epoch_end(outputs)
diff --git a/inspiremusic/wavtokenizer/decoder/feature_extractors.py b/inspiremusic/wavtokenizer/decoder/feature_extractors.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4672d141ba89b88ecdec4d48464252cb524fb9f
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/feature_extractors.py
@@ -0,0 +1,176 @@
+from typing import List
+
+import torch
+import torchaudio
+from torch import nn
+import math
+# from inspiremusic.wavtokenizer.decoder.modules import safe_log
+from inspiremusic.wavtokenizer.encoder.modules import SEANetEncoder, SEANetDecoder
+from inspiremusic.wavtokenizer.encoder import EncodecModel
+from inspiremusic.wavtokenizer.encoder.quantization import ResidualVectorQuantizer
+
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+ """
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+ Args:
+ x (Tensor): Input tensor.
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+ Returns:
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
+ """
+ return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
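+
+# symlog compresses large magnitudes, symlog(x) = sign(x) * log(1 + |x|), and
+# symexp is its exact inverse: symexp(symlog(x)) == x for all finite x.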
+
+
+class FeatureExtractor(nn.Module):
+ """Base class for feature extractors."""
+
+ def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Extract features from the given audio.
+
+ Args:
+ audio (Tensor): Input audio waveform.
+
+ Returns:
+ Tensor: Extracted features of shape (B, C, L), where B is the batch size,
+ C denotes output features, and L is the sequence length.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class MelSpectrogramFeatures(FeatureExtractor):
+ def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
+ sample_rate=sample_rate,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ n_mels=n_mels,
+ center=padding == "center",
+ power=1,
+ )
+
+ def forward(self, audio, **kwargs):
+ if self.padding == "same":
+ pad = self.mel_spec.win_length - self.mel_spec.hop_length
+ audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
+ mel = self.mel_spec(audio)
+ features = safe_log(mel)
+ return features
+
+
+class EncodecFeatures(FeatureExtractor):
+ def __init__(
+ self,
+ encodec_model: str = "encodec_24khz",
+ bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
+ train_codebooks: bool = False,
+ num_quantizers: int = 1,
+        dowmsamples: List[int] = [6, 5, 5, 4],  # (sic) spelling kept as-is so callers/configs passing this kwarg by name keep working
+ vq_bins: int = 16384,
+ vq_kmeans: int = 800,
+ ):
+ super().__init__()
+
+ # breakpoint()
+        self.frame_rate = 25  # nominal frame rate handed to the quantizer; not used downstream
+        # n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate))
+        n_q = num_quantizers  # number of residual quantizers actually used
+ encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
+ dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
+ kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
+ true_skip=False, compress=2)
+ decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
+ dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
+ kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
+ true_skip=False, compress=2)
+ quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
+ decay=0.99, kmeans_init=True)
+
+ # breakpoint()
+ if encodec_model == "encodec_24khz":
+ self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
+ target_bandwidths=bandwidths, sample_rate=24000, channels=1)
+ else:
+ raise ValueError(
+ f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
+ )
+ for param in self.encodec.parameters():
+ param.requires_grad = True
+ # self.num_q = n_q
+ # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
+ # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
+ self.bandwidths = bandwidths
+
+ # @torch.no_grad()
+ # def get_encodec_codes(self, audio):
+ # audio = audio.unsqueeze(1)
+ # emb = self.encodec.encoder(audio)
+ # codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
+ # return codes
+
+ def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
+ if self.training:
+ self.encodec.train()
+
+ audio = audio.unsqueeze(1) # audio(16,24000)
+
+ # breakpoint()
+
+ emb = self.encodec.encoder(audio)
+ q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
+ quantized = q_res.quantized
+ codes = q_res.codes
+ commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75)
+
+ return quantized, codes, commit_loss
+
+ # codes = self.get_encodec_codes(audio)
+ # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
+ # # with offsets given by the number of bins, and finally summed in a vectorized operation.
+ # offsets = torch.arange(
+ # 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
+ # )
+ # embeddings_idxs = codes + offsets.view(-1, 1, 1)
+ # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
+ # return features.transpose(1, 2)
+
+ def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
+ if self.training:
+ self.encodec.train()
+
+ audio = audio.unsqueeze(1) # audio(16,24000)
+ emb = self.encodec.encoder(audio)
+ q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
+ quantized = q_res.quantized
+ codes = q_res.codes
+ commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75)
+
+ return quantized, codes, commit_loss
+
+ def _infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)):
+ if self.training:
+ self.encodec.train()
+
+ audio = audio.unsqueeze(1) # audio(16,24000)
+ emb = self.encodec.encoder(audio)
+ q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
+ quantized = q_res.quantized
+ codes = q_res.codes
+ commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75)
+
+ return quantized, codes, commit_loss
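+
+# Usage sketch (hedged; assumes the inspiremusic package is importable and torch is installed).
+# With the defaults above, the extractor maps a batch of mono waveforms to residual-VQ codes
+# and their quantized embeddings:
+#   fe = EncodecFeatures(num_quantizers=8)
+#   quantized, codes, commit_loss = fe(torch.randn(4, 24000), bandwidth_id=torch.tensor(0))
+#   # quantized: (B, 512, T'), codes: (n_q, B, T'), commit_loss: scalar tensor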
\ No newline at end of file
diff --git a/inspiremusic/wavtokenizer/decoder/heads.py b/inspiremusic/wavtokenizer/decoder/heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb3b9f85bda23ae73c09e462cb584bc9878faca9
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/heads.py
@@ -0,0 +1,159 @@
+import torch
+from torch import nn
+from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+
+from inspiremusic.wavtokenizer.decoder.spectral_ops import IMDCT, ISTFT
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
+
+class FourierHead(nn.Module):
+ """Base class for inverse fourier modules."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class ISTFTHead(FourierHead):
+ """
+ ISTFT Head module for predicting STFT complex coefficients.
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
+ the resolution of the input features.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+ super().__init__()
+ out_dim = n_fft + 2
+ self.out = torch.nn.Linear(dim, out_dim)
+ self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the ISTFTHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x).transpose(1, 2)
+ mag, p = x.chunk(2, dim=1)
+ mag = torch.exp(mag)
+ mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
+        # phase wrapping happens implicitly here: these two lines produce the real
+        # and imaginary parts directly
+        x = torch.cos(p)
+        y = torch.sin(p)
+        # recomputing the phase (e.g. phase = torch.atan2(y, x); S = mag * torch.exp(phase * 1j))
+        # adds cost without producing anything new, so the complex value is built directly
+
+ S = mag * (x + 1j * y)
+
+ audio = self.istft(S)
+ return audio
+
+class IMDCTSymExpHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+ based on perceptual scaling. Defaults to None.
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False,
+ ):
+ super().__init__()
+ out_dim = mdct_frame_len // 2
+ self.out = nn.Linear(dim, out_dim)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+ self.clip_audio = clip_audio
+
+ if sample_rate is not None:
+ # optionally init the last layer following mel-scale
+ m_max = _hz_to_mel(sample_rate // 2)
+ m_pts = torch.linspace(0, m_max, out_dim)
+ f_pts = _mel_to_hz(m_pts)
+ scale = 1 - (f_pts / f_pts.max())
+
+ with torch.no_grad():
+ self.out.weight.mul_(scale.view(-1, 1))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTSymExpHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ x = symexp(x)
+ x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(x)
+ if self.clip_audio:
+            audio = torch.clip(audio, min=-1.0, max=1.0)
+
+ return audio
+
+
+class IMDCTCosHead(FourierHead):
+ """
+    IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p)
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
+ super().__init__()
+ self.clip_audio = clip_audio
+ self.out = nn.Linear(dim, mdct_frame_len)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTCosHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ m, p = x.chunk(2, dim=2)
+ m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(m * torch.cos(p))
+ if self.clip_audio:
+            audio = torch.clip(audio, min=-1.0, max=1.0)
+ return audio
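+
+
+if __name__ == "__main__":
+    # Minimal smoke test (hedged example, not part of the library API); it assumes the
+    # inspiremusic package is importable so the ISTFT import above resolves.
+    # 75 frames of 512-dim hidden states are mapped to a waveform by the ISTFT head.
+    head = ISTFTHead(dim=512, n_fft=1024, hop_length=256)
+    hidden = torch.randn(2, 75, 512)   # (B, L, H)
+    waveform = head(hidden)            # (B, T)
+    print(waveform.shape)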
diff --git a/inspiremusic/wavtokenizer/decoder/helpers.py b/inspiremusic/wavtokenizer/decoder/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d303010352ad59dde2996605f124128ee17db36
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/helpers.py
@@ -0,0 +1,71 @@
+import matplotlib
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from pytorch_lightning import Callback
+
+matplotlib.use("Agg")
+
+
+def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
+ """
+ Save a matplotlib figure to a numpy array.
+
+ Args:
+ fig (Figure): Matplotlib figure object.
+
+ Returns:
+ ndarray: Numpy array representing the figure.
+ """
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ return data
+
+
+def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
+ """
+ Plot a spectrogram and convert it to a numpy array.
+
+ Args:
+ spectrogram (ndarray): Spectrogram data.
+
+ Returns:
+ ndarray: Numpy array representing the plotted spectrogram.
+ """
+ spectrogram = spectrogram.astype(np.float32)
+ fig, ax = plt.subplots(figsize=(12, 3))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+ plt.colorbar(im, ax=ax)
+ plt.xlabel("Frames")
+ plt.ylabel("Channels")
+ plt.tight_layout()
+
+ fig.canvas.draw()
+ data = save_figure_to_numpy(fig)
+ plt.close()
+ return data
+
+
+class GradNormCallback(Callback):
+ """
+ Callback to log the gradient norm.
+ """
+
+ def on_after_backward(self, trainer, model):
+ model.log("grad_norm", gradient_norm(model))
+
+
+def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
+ """
+ Compute the gradient norm.
+
+ Args:
+ model (Module): PyTorch model.
+ norm_type (float, optional): Type of the norm. Defaults to 2.0.
+
+ Returns:
+ Tensor: Gradient norm.
+ """
+ grads = [p.grad for p in model.parameters() if p.grad is not None]
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
+ return total_norm
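+
+
+if __name__ == "__main__":
+    # Hedged example: gradient norm of a toy model after one backward pass.
+    # In training, GradNormCallback would be passed to the Lightning Trainer,
+    # e.g. pytorch_lightning.Trainer(callbacks=[GradNormCallback()]).
+    toy = torch.nn.Linear(4, 1)
+    toy(torch.randn(8, 4)).pow(2).mean().backward()
+    print(gradient_norm(toy))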
diff --git a/inspiremusic/wavtokenizer/decoder/loss.py b/inspiremusic/wavtokenizer/decoder/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..30f32ccf9a3f5373335ddb8da1334f508c16f752
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/loss.py
@@ -0,0 +1,159 @@
+from typing import List, Tuple
+
+import torch
+import torchaudio
+from torch import nn
+
+from inspiremusic.wavtokenizer.decoder.modules import safe_log
+
+import torch.nn.functional as F
+
+
+class MelSpecReconstructionLoss(nn.Module):
+ """
+ L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
+ """
+
+ def __init__(
+ self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
+ ):
+ super().__init__()
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
+ )
+
+ def forward(self, y_hat, y) -> torch.Tensor:
+ """
+ Args:
+ y_hat (Tensor): Predicted audio waveform.
+ y (Tensor): Ground truth audio waveform.
+
+ Returns:
+ Tensor: L1 loss between the mel-scaled magnitude spectrograms.
+ """
+ mel_hat = safe_log(self.mel_spec(y_hat))
+ mel = safe_log(self.mel_spec(y))
+
+ loss = torch.nn.functional.l1_loss(mel, mel_hat)
+
+ return loss
+
+
+class GeneratorLoss(nn.Module):
+ """
+ Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
+ """
+
+ def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """
+ Args:
+ disc_outputs (List[Tensor]): List of discriminator outputs.
+
+ Returns:
+ Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
+ the sub-discriminators
+ """
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean(torch.clamp(1 - dg, min=0))
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+
+class DiscriminatorLoss(nn.Module):
+ """
+ Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
+ """
+
+ def forward(
+ self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+ """
+ Args:
+ disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
+ disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
+
+ Returns:
+ Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
+ the sub-discriminators for real outputs, and a list of
+ loss values for generated outputs.
+ """
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean(torch.clamp(1 - dr, min=0))
+ g_loss = torch.mean(torch.clamp(1 + dg, min=0))
+ loss += r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+class FeatureMatchingLoss(nn.Module):
+ """
+ Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
+ """
+
+ def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
+ """
+ Args:
+ fmap_r (List[List[Tensor]]): List of feature maps from real samples.
+ fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
+
+ Returns:
+ Tensor: The calculated feature matching loss.
+ """
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss
+
+class DACGANLoss(nn.Module):
+ """
+ Computes a discriminator loss, given a discriminator on
+ generated waveforms/spectrograms compared to ground truth
+ waveforms/spectrograms. Computes the loss for both the
+ discriminator and the generator in separate functions.
+ """
+
+ def __init__(self, discriminator):
+ super().__init__()
+ self.discriminator = discriminator
+
+ def forward(self, fake, real):
+ # d_fake = self.discriminator(fake.audio_data)
+ # d_real = self.discriminator(real.audio_data)
+ d_fake = self.discriminator(fake)
+ d_real = self.discriminator(real)
+ return d_fake, d_real
+
+ def discriminator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+ loss_d = 0
+ for x_fake, x_real in zip(d_fake, d_real):
+ loss_d += torch.mean(x_fake[-1] ** 2)
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
+ return loss_d
+
+ def generator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake, real)
+
+ loss_g = 0
+ for x_fake in d_fake:
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+ loss_feature = 0
+
+ for i in range(len(d_fake)):
+ for j in range(len(d_fake[i]) - 1):
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+ return loss_g, loss_feature
+
diff --git a/inspiremusic/wavtokenizer/decoder/models.py b/inspiremusic/wavtokenizer/decoder/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7ce3d99c57feb48946039a7501c638874afdf62
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/models.py
@@ -0,0 +1,266 @@
+from typing import List, Optional
+
+import torch
+from torch import nn
+from torch.nn.utils import weight_norm
+
+from inspiremusic.wavtokenizer.decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv1d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb=None):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+        # compute attention over the temporal axis
+        b, c, l = q.shape
+        q = q.permute(0, 2, 1)   # (b, l, c)
+        w_ = torch.bmm(q, k)     # (b, l, l); w_[b, i, j] = sum_c q[b, i, c] * k[b, c, j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        w_ = w_.permute(0, 2, 1)  # (b, l, l) (first index over k positions, second over q positions)
+        h_ = torch.bmm(v, w_)     # (b, c, l); h_[b, c, j] = sum_i v[b, c, i] * w_[b, i, j]
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+def make_attn(in_channels, attn_type="vanilla"):
+    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    if attn_type == "vanilla":
+        return AttnBlock(in_channels)
+    raise NotImplementedError(f"attn_type '{attn_type}' is accepted but not implemented; only 'vanilla' is supported")
+
+
+class Backbone(nn.Module):
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+ C denotes output features, and L is the sequence length.
+
+ Returns:
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+ and H denotes the model dimension.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class VocosBackbone(Backbone):
+ """
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+ num_layers (int): Number of ConvNeXtBlock layers.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional model. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ input_channels: int,
+ dim: int,
+ intermediate_dim: int,
+ num_layers: int,
+ layer_scale_init_value: Optional[float] = None,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+ self.convnext = nn.ModuleList(
+ [
+ ConvNeXtBlock(
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ layer_scale_init_value=layer_scale_init_value,
+ adanorm_num_embeddings=adanorm_num_embeddings,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+ self.apply(self._init_weights)
+
+ self.temb_ch = 0
+ block_in = dim
+ dropout = 0.1
+        attn_type = "vanilla"
+
+        pos_net: List[nn.Module] = [
+            ResnetBlock(in_channels=block_in, out_channels=block_in,
+                        temb_channels=self.temb_ch, dropout=dropout),
+            ResnetBlock(in_channels=block_in, out_channels=block_in,
+                        temb_channels=self.temb_ch, dropout=dropout),
+            make_attn(block_in, attn_type=attn_type),
+            ResnetBlock(in_channels=block_in, out_channels=block_in,
+                        temb_channels=self.temb_ch, dropout=dropout),
+            ResnetBlock(in_channels=block_in, out_channels=block_in,
+                        temb_channels=self.temb_ch, dropout=dropout),
+            Normalize(block_in),
+        ]
+
+ self.pos_net = nn.Sequential(*pos_net)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+ x = self.embed(x)
+ x = self.pos_net(x)
+ if self.adanorm:
+ # assert bandwidth_id is not None
+ if bandwidth_id is None:
+                bandwidth_id = torch.tensor(0, device=x.device)
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+ else:
+ x = self.norm(x.transpose(1, 2))
+ x = x.transpose(1, 2)
+ for conv_block in self.convnext:
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
+ x = self.final_layer_norm(x.transpose(1, 2))
+ return x
+
+
+class VocosResNetBackbone(Backbone):
+ """
+ Vocos backbone module built with ResBlocks.
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ num_blocks (int): Number of ResBlock1 blocks.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+ """
+
+ def __init__(
+ self, input_channels, dim, num_blocks, layer_scale_init_value=None,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+ self.resnet = nn.Sequential(
+ *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
+ )
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ x = self.embed(x)
+ x = self.resnet(x)
+ x = x.transpose(1, 2)
+ return x
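+
+
+if __name__ == "__main__":
+    # Hedged smoke test (assumes the inspiremusic package is importable): the ConvNeXt
+    # backbone maps (B, C, L) features to (B, L, H) hidden states; num_layers is kept
+    # small here just to keep the test light.
+    backbone = VocosBackbone(input_channels=512, dim=768, intermediate_dim=2304, num_layers=2)
+    feats = torch.randn(1, 512, 75)
+    hidden = backbone(feats)
+    print(hidden.shape)  # expected: (1, 75, 768)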
diff --git a/inspiremusic/wavtokenizer/decoder/modules.py b/inspiremusic/wavtokenizer/decoder/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..799a61fb94a2adc26c6e7a39e4ff3285f6556975
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/modules.py
@@ -0,0 +1,214 @@
+from typing import Optional
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+ intermediate_dim (int): Dimensionality of the intermediate layer.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional LayerNorm. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ layer_scale_init_value: Optional[float] = None,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+            if layer_scale_init_value is not None and layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ if self.adanorm:
+ assert cond_embedding_id is not None
+ x = self.norm(x, cond_embedding_id)
+ else:
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+ Args:
+ num_embeddings (int): Number of embeddings.
+ embedding_dim (int): Dimension of the embeddings.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.dim = embedding_dim
+ self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+ self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+ torch.nn.init.ones_(self.scale.weight)
+ torch.nn.init.zeros_(self.shift.weight)
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+ scale = self.scale(cond_embedding_id)
+ shift = self.shift(cond_embedding_id)
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+ x = x * scale + shift
+ return x
+
+
+class ResBlock1(nn.Module):
+ """
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+ but without upsampling layers.
+
+ Args:
+ dim (int): Number of input channels.
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+ Defaults to (1, 3, 5).
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+ Defaults to 0.1.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ kernel_size: int = 3,
+ dilation: Tuple[int, ...] = (1, 3, 5),
+ lrelu_slope: float = 0.1,
+        layer_scale_init_value: Optional[float] = None,
+ ):
+ super().__init__()
+ self.lrelu_slope = lrelu_slope
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=self.get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=self.get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=self.get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ ]
+ )
+
+ self.gamma = nn.ParameterList(
+ [
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+ xt = c1(xt)
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+ xt = c2(xt)
+ if gamma is not None:
+ xt = gamma * xt
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+ @staticmethod
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+ """
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+ Args:
+ x (Tensor): Input tensor.
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+ Returns:
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
+ """
+ return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
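+
+
+if __name__ == "__main__":
+    # Hedged smoke test: both building blocks preserve the (B, C, T) shape of their input.
+    x = torch.randn(2, 64, 50)
+    print(ConvNeXtBlock(dim=64, intermediate_dim=128, layer_scale_init_value=1e-6)(x).shape)
+    print(ResBlock1(dim=64, layer_scale_init_value=0.1)(x).shape)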
diff --git a/inspiremusic/wavtokenizer/decoder/pretrained.py b/inspiremusic/wavtokenizer/decoder/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..1231a05d4945e2f2debe620b1347cad3e6c7ca76
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/pretrained.py
@@ -0,0 +1,253 @@
+import os
+from typing import Tuple, Any, Union, Dict
+
+import torch
+import yaml
+from huggingface_hub import hf_hub_download
+from torch import nn
+from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
+from inspiremusic.wavtokenizer.decoder.heads import FourierHead
+from inspiremusic.wavtokenizer.decoder.models import Backbone
+
+
+def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
+ """Instantiates a class with the given args and init.
+
+ Args:
+ args: Positional arguments required for instantiation.
+ init: Dict of the form {"class_path":...,"init_args":...}.
+
+ Returns:
+ The instantiated class object.
+ """
+ kwargs = init.get("init_args", {})
+ if not isinstance(args, tuple):
+ args = (args,)
+ class_module, class_name = init["class_path"].rsplit(".", 1)
+ module = __import__(class_module, fromlist=[class_name])
+ args_class = getattr(module, class_name)
+ return args_class(*args, **kwargs)
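+
+# Example (hedged): a config entry of the form
+#   {"class_path": "torch.nn.Conv1d",
+#    "init_args": {"in_channels": 1, "out_channels": 8, "kernel_size": 3}}
+# passed as `instantiate_class((), entry)` imports torch.nn and returns Conv1d(1, 8, 3).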
+
+
+class WavTokenizer(nn.Module):
+ """
+    WavTokenizer is a Vocos-style, Fourier-based neural vocoder for audio synthesis.
+ This class is primarily designed for inference, with support for loading from pretrained
+ model checkpoints. It consists of three main components: a feature extractor,
+ a backbone, and a head.
+ """
+
+ def __init__(
+ self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
+ ):
+ super().__init__()
+ self.feature_extractor = feature_extractor
+ self.backbone = backbone
+ self.head = head
+
+ @classmethod
+    def from_hparams(cls, config_path: str) -> "WavTokenizer":
+ """
+ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
+ """
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
+ backbone = instantiate_class(args=(), init=config["backbone"])
+ head = instantiate_class(args=(), init=config["head"])
+ model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
+ return model
+
+ @classmethod
+    def from_pretrained(self, repo_id: str) -> "WavTokenizer":
+ """
+ Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
+ """
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
+ model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
+ model = self.from_hparams(config_path)
+ state_dict = torch.load(model_path, map_location="cpu")
+ if isinstance(model.feature_extractor, EncodecFeatures):
+ encodec_parameters = {
+ "feature_extractor.encodec." + key: value
+ for key, value in model.feature_extractor.encodec.state_dict().items()
+ }
+ state_dict.update(encodec_parameters)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+ @classmethod
+    def from_hparams_feat(cls, config_path: str) -> "WavTokenizer":
+ """
+ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
+ """
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+ backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+ head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+ model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
+ return model
+
+
+ @classmethod
+ def from_pretrained_feat(self, config_path, model_path):
+ """
+        Class method to create a new model instance from a local yaml config file and checkpoint.
+ """
+ model = self.from_hparams_feat(config_path)
+ state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+ state_dict = dict()
+ for k, v in state_dict_raw.items():
+ if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
+ state_dict[k] = v
+
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ @classmethod
+ def estimator(self, config_path, model_path):
+ """
+        Class method to create a new model instance from a local yaml config file and checkpoint.
+ """
+ model = self.from_hparams_feat(config_path)
+ state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+ state_dict = dict()
+ for k, v in state_dict_raw.items():
+ if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
+ state_dict[k] = v
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ @classmethod
+ def from_pretrained0911(self, config_path, model_folder_path):
+ """
+        Class method to create a new model instance by averaging the three best local checkpoints (by validation loss) found in model_folder_path.
+ """
+        model = self.from_hparams_feat(config_path)
+
+ models = os.listdir(model_folder_path)
+ val_loss = []
+ for item in models:
+ if not item.startswith('vocos_'):
+ continue
+ val_loss.append(item[-11:-5])
+ val_loss.sort()
+        val_loss = val_loss[:3]  # average the three checkpoints with the lowest validation loss
+ state_dict = dict()
+ state_dicts = []
+ for item in models:
+ if not item.startswith('vocos_'):
+ continue
+ ll = item[-11:-5]
+ if ll not in val_loss:
+ continue
+ model_path = model_folder_path + '/' + item
+ state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+ state_dict_single = dict()
+ for k, v in state_dict_raw.items():
+ if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
+ state_dict_single[k] = v
+ state_dicts.append(state_dict_single)
+ for kk in state_dicts[0].keys():
+ vv = state_dicts[0][kk]
+ for i in range(1, len(state_dicts)):
+ ss = state_dicts[i]
+ vv += ss[kk]
+ vm = vv/len(state_dicts)
+ state_dict[kk] = vm
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+
+ @torch.inference_mode()
+ def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ """
+ Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
+ which is then passed through the backbone and the head to reconstruct the audio output.
+
+ Args:
+ audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
+                where B is the batch size and T is the waveform length.
+
+
+ Returns:
+ Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
+ """
+ features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818
+ audio_output = self.decode(features, **kwargs)
+ return audio_output
+
+
+ # 0818
+ @torch.inference_mode()
+ def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs)
+        return features, discrete_codes
+
+
+ # 0818
+ @torch.inference_mode()
+ def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs)
+        return features, discrete_codes
+
+ @torch.inference_mode()
+ def infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ _, discrete_codes, _ = self.feature_extractor._infer(audio_input, **kwargs)
+ discrete_codes = discrete_codes.clamp(min=0, max=16383)
+ return discrete_codes
+
+ @torch.inference_mode()
+ def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ """
+ Method to decode audio waveform from already calculated features. The features input is passed through
+ the backbone and the head to reconstruct the audio output.
+
+ Args:
+ features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
+ C denotes the feature dimension, and L is the sequence length.
+
+ Returns:
+ Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
+ """
+ x = self.backbone(features_input, **kwargs)
+ audio_output = self.head(x)
+ return audio_output
+
+ @torch.inference_mode()
+ def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
+ """
+ Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
+ codebook weights.
+
+ Args:
+ codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
+ where K is the number of codebooks, B is the batch size and L is the sequence length.
+
+ Returns:
+ Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
+ and L is the sequence length.
+ """
+ assert isinstance(
+ self.feature_extractor, EncodecFeatures
+ ), "Feature extractor should be an instance of EncodecFeatures"
+
+ if codes.dim() == 2:
+ codes = codes.unsqueeze(1)
+
+ n_bins = self.feature_extractor.encodec.quantizer.bins
+ offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
+ embeddings_idxs = codes + offsets.view(-1, 1, 1)
+
+        codebook_weights = torch.cat(
+            [vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers], dim=0
+        )
+        features = torch.nn.functional.embedding(embeddings_idxs, codebook_weights).sum(dim=0)
+ features = features.transpose(1, 2)
+
+ return features
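+
+# End-to-end usage sketch (hedged): the config/checkpoint paths below are placeholders,
+# not files shipped with this diff.
+#   tokenizer = WavTokenizer.from_pretrained_feat("wavtokenizer_config.yaml", "wavtokenizer.ckpt")
+#   _, codes = tokenizer.encode_infer(audio, bandwidth_id=torch.tensor([0]))   # audio: (B, T)
+#   features = tokenizer.codes_to_features(codes)                              # (B, C, L)
+#   audio_hat = tokenizer.decode(features, bandwidth_id=torch.tensor([0]))     # (B, T)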
diff --git a/inspiremusic/wavtokenizer/decoder/pretrained_model.py b/inspiremusic/wavtokenizer/decoder/pretrained_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c919bb25685d78522c6b638cd46310c7ae5edc0d
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/pretrained_model.py
@@ -0,0 +1,192 @@
+from typing import Tuple, Any, Union, Dict, Optional
+
+import torch
+import yaml
+from huggingface_hub import hf_hub_download
+from torch import nn
+from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures
+from inspiremusic.wavtokenizer.decoder.heads import FourierHead
+from inspiremusic.wavtokenizer.decoder.models import Backbone
+from inspiremusic.wavtokenizer.decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
+
+
+def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
+ """Instantiates a class with the given args and init.
+
+ Args:
+ args: Positional arguments required for instantiation.
+ init: Dict of the form {"class_path":...,"init_args":...}.
+
+ Returns:
+ The instantiated class object.
+ """
+ kwargs = init.get("init_args", {})
+ if not isinstance(args, tuple):
+ args = (args,)
+ class_module, class_name = init["class_path"].rsplit(".", 1)
+ module = __import__(class_module, fromlist=[class_name])
+ args_class = getattr(module, class_name)
+ return args_class(*args, **kwargs)
+
+
+class WavTokenizer(nn.Module):
+ """
+    WavTokenizer is a Vocos-style, Fourier-based neural vocoder for audio synthesis.
+ This class is primarily designed for inference, with support for loading from pretrained
+ model checkpoints. It consists of three main components: a feature extractor,
+ a backbone, and a head.
+ """
+
+ def __init__(
+ self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
+        multiperioddisc: Optional[MultiPeriodDiscriminator] = None, multiresddisc: Optional[MultiResolutionDiscriminator] = None,
+ ):
+ super().__init__()
+ self.feature_extractor = feature_extractor
+ self.backbone = backbone
+ self.head = head
+
+ self.multiperioddisc = multiperioddisc
+ self.multiresddisc = multiresddisc
+
+ @classmethod
+    def from_hparams0828(cls, config_path: str) -> "WavTokenizer":
+ """
+ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
+ """
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+ backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+ head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+ model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head,
+ multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4),
+ multiresddisc=MultiResolutionDiscriminator(num_embeddings=4))
+ return model
+
+ @classmethod
+ def from_pretrained0828(self, config_path, model_path):
+ """
+        Class method to create a new model instance (including discriminators) from a local yaml config file and checkpoint.
+ """
+ model = self.from_hparams0828(config_path)
+ state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+ state_dict = dict()
+ for k, v in state_dict_raw.items():
+ if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \
+ or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'):
+ state_dict[k] = v
+ # if isinstance(model.feature_extractor, EncodecFeatures):
+ # encodec_parameters = {
+ # "feature_extractor.encodec." + key: value
+ # for key, value in model.feature_extractor.encodec.state_dict().items()
+ # }
+ # state_dict.update(encodec_parameters)
+ model.load_state_dict(state_dict)
+ return model
+
+ @classmethod
+    def from_hparams0802(cls, config_path: str) -> "WavTokenizer":
+ """
+ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
+ """
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+ feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+ backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+ head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+ model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
+ return model
+
+ @classmethod
+ def from_pretrained0802(self, config_path, model_path):
+ """
+        Class method to create a new model instance from a local yaml config file and checkpoint.
+ """
+ model = self.from_hparams0802(config_path)
+ state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+ state_dict = dict()
+ for k, v in state_dict_raw.items():
+ if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
+ state_dict[k] = v
+ # if isinstance(model.feature_extractor, EncodecFeatures):
+ # encodec_parameters = {
+ # "feature_extractor.encodec." + key: value
+ # for key, value in model.feature_extractor.encodec.state_dict().items()
+ # }
+ # state_dict.update(encodec_parameters)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ @torch.inference_mode()
+ def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ """
+ Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
+ which is then passed through the backbone and the head to reconstruct the audio output.
+
+ Args:
+ audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
+                where B is the batch size and T is the waveform length.
+
+
+ Returns:
+ Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
+ """
+ features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818
+ audio_output = self.decode(features, **kwargs)
+ return audio_output
+
+
+ # 0818
+ @torch.inference_mode()
+ def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ features, _, _ = self.feature_extractor(audio_input, **kwargs)
+ return features
+
+
+ @torch.inference_mode()
+ def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ """
+ Method to decode audio waveform from already calculated features. The features input is passed through
+ the backbone and the head to reconstruct the audio output.
+
+ Args:
+ features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
+ C denotes the feature dimension, and L is the sequence length.
+
+ Returns:
+ Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
+ """
+ x = self.backbone(features_input, **kwargs)
+ audio_output = self.head(x)
+ return audio_output
+
+ @torch.inference_mode()
+ def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
+ """
+ Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
+ codebook weights.
+
+ Args:
+ codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
+ where K is the number of codebooks, B is the batch size and L is the sequence length.
+
+ Returns:
+ Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
+ and L is the sequence length.
+ """
+ assert isinstance(
+ self.feature_extractor, EncodecFeatures
+ ), "Feature extractor should be an instance of EncodecFeatures"
+
+ if codes.dim() == 2:
+ codes = codes.unsqueeze(1)
+
+ n_bins = self.feature_extractor.encodec.quantizer.bins
+ offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
+ embeddings_idxs = codes + offsets.view(-1, 1, 1)
+        codebook_weights = torch.cat(
+            [vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers], dim=0
+        )
+        features = torch.nn.functional.embedding(embeddings_idxs, codebook_weights).sum(dim=0)
+ features = features.transpose(1, 2)
+
+ return features
diff --git a/inspiremusic/wavtokenizer/decoder/spectral_ops.py b/inspiremusic/wavtokenizer/decoder/spectral_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b062d5ff4c1d82124afa2752cea62132434790d
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/spectral_ops.py
@@ -0,0 +1,242 @@
+import numpy as np
+import scipy
+import torch
+from torch import nn, view_as_real, view_as_complex
+
+class ISTFT(nn.Module):
+ """
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+ See issue: https://github.com/pytorch/pytorch/issues/62323
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+ The NOLA constraint is met as we trim padded samples anyway.
+
+ Args:
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames.
+ win_length (int): The size of window frame and STFT filter.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+ Args:
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+ N is the number of frequency bins, and T is the number of time frames.
+
+ Returns:
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+ """
+ if self.padding == "center":
+ # Fallback to pytorch native implementation
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
+ elif self.padding == "same":
+ pad = (self.win_length - self.hop_length) // 2
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
+ B, N, T = spec.shape
+
+ # Inverse FFT
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+ ifft = ifft * self.window[None, :, None]
+
+ # Overlap and Add
+ output_size = (T - 1) * self.hop_length + self.win_length
+ y = torch.nn.functional.fold(
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+ )[:, 0, 0, pad:-pad]
+
+ # Window envelope
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+ window_envelope = torch.nn.functional.fold(
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+ ).squeeze()[pad:-pad]
+
+ # Normalize
+ # assert (window_envelope > 1e-11).all()
+ if not torch.all(window_envelope > 1e-11):
+ window_envelope = torch.clamp(window_envelope, min=1e-11)
+
+ y = y / window_envelope
+
+ return y
+
+ def onnx_forward(self, spec: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+ Args:
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+ N is the number of frequency bins, and T is the number of time frames.
+
+ Returns:
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+ """
+ if self.padding == "center":
+ # Fallback to pytorch native implementation
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
+ elif self.padding == "same":
+ pad = (self.win_length - self.hop_length) // 2
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
+ B, N, T = spec.shape
+ # Inverse FFT
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+ ifft = ifft * self.window[None, :, None]
+
+ # Overlap and Add
+ output_size = (T - 1) * self.hop_length + self.win_length
+ y = torch.nn.functional.fold(
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+ )[:, 0, 0, pad:-pad]
+
+ # Window envelope
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+ window_envelope = torch.nn.functional.fold(
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+ ).squeeze()[pad:-pad]
+
+ # Normalize
+ # assert (window_envelope > 1e-11).all()
+ if not torch.all(window_envelope > 1e-11):
+ window_envelope = torch.clamp(window_envelope, min=1e-11)
+
+ y = y / window_envelope
+
+ return y
+
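+# Usage sketch (hedged): inverting a complex spectrogram produced by a matching forward STFT.
+#   istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024, padding="same")
+#   spec = torch.stft(torch.randn(1, 24000), n_fft=1024, hop_length=256,
+#                     window=torch.hann_window(1024), return_complex=True)
+#   audio = istft(spec)   # time-domain signal of shape (B, T')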
+
+class MDCT(nn.Module):
+ """
+ Modified Discrete Cosine Transform (MDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+        window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
+ # https://github.com/pytorch/pytorch/issues/71613
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+ Args:
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+ and T is the length of the audio.
+
+ Returns:
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+ and N is the number of frequency bins.
+ """
+ if self.padding == "center":
+ audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
+ elif self.padding == "same":
+ # hop_length is 1/2 frame_len
+ audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+ N = self.frame_len // 2
+ x = x * self.window.expand(x.shape)
+ X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+ return torch.real(res) * np.sqrt(2)
+
+
+class IMDCT(nn.Module):
+ """
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
+
+ Args:
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
+ L is the number of frames, and N is the number of frequency bins.
+
+ Returns:
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
+ """
+ B, L, N = X.shape
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
+ Y[..., :N] = X
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
+ y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
+ y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
+ result = y * self.window.expand(y.shape)
+ output_size = (1, (L + 1) * N)
+ audio = torch.nn.functional.fold(
+ result.transpose(1, 2),
+ output_size=output_size,
+ kernel_size=(1, self.frame_len),
+ stride=(1, self.frame_len // 2),
+ )[:, 0, 0, :]
+
+ if self.padding == "center":
+ pad = self.frame_len // 2
+ elif self.padding == "same":
+ pad = self.frame_len // 4
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ audio = audio[:, pad:-pad]
+ return audio
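+
+
+# Hedged round-trip sketch (not part of the original patch): with "same" padding
+# and the 50% overlap used above, MDCT followed by IMDCT should reconstruct the
+# interior of the signal up to numerical error; the first and last half frames
+# may deviate because of the boundary padding, so only the interior is compared.
+def _mdct_roundtrip_sketch(frame_len: int = 512) -> torch.Tensor:
+    mdct = MDCT(frame_len, padding="same")
+    imdct = IMDCT(frame_len, padding="same")
+    x = torch.randn(1, frame_len * 16)
+    y = imdct(mdct(x))  # (B, T), same shape as the input
+    assert y.shape == x.shape
+    interior = slice(frame_len, -frame_len)
+    return (x[:, interior] - y[:, interior]).abs().max()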
diff --git a/inspiremusic/wavtokenizer/encoder/__init__.py b/inspiremusic/wavtokenizer/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff8fd2ada59e0e15d4df2854052edf150e5238e3
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# flake8: noqa
+
+"""EnCodec neural audio codec."""
+
+__version__ = "0.1.2a3"
+
+from .model import EncodecModel
diff --git a/inspiremusic/wavtokenizer/encoder/distrib.py b/inspiremusic/wavtokenizer/encoder/distrib.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1662d8085cf2878c4cd058537d0f097de91d158
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/distrib.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch distributed utilities."""
+
+import typing as tp
+
+import torch
+
+
+def rank():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_rank()
+ else:
+ return 0
+
+
+def world_size():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_world_size()
+ else:
+ return 1
+
+
+def is_distributed():
+ return world_size() > 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+ if is_distributed():
+ return torch.distributed.all_reduce(tensor, op)
+
+
+def _is_complex_or_float(tensor):
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
+
+
+def _check_number_of_params(params: tp.List[torch.Tensor]):
+ # utility function to check that the number of params in all workers is the same,
+ # and thus avoid a deadlock with distributed all reduce.
+ if not is_distributed() or not params:
+ return
+ tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+ all_reduce(tensor)
+ if tensor.item() != len(params) * world_size():
+ # If not all the workers have the same number, for at least one of them,
+ # this inequality will be verified.
+ raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
+ "at least one worker has a different one.")
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+ """Broadcast the tensors from the given parameters to all workers.
+ This can be used to ensure that all workers have the same model to start with.
+ """
+ if not is_distributed():
+ return
+ tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+ _check_number_of_params(tensors)
+ handles = []
+ for tensor in tensors:
+ handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+ handles.append(handle)
+ for handle in handles:
+ handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+ """
+ Synchronize buffers across workers. If average is False, broadcast from rank 0 instead of averaging.
+ """
+ if not is_distributed():
+ return
+ handles = []
+ for buffer in buffers:
+ if torch.is_floating_point(buffer.data):
+ if average:
+ handle = torch.distributed.all_reduce(
+ buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+ else:
+ handle = torch.distributed.broadcast(
+ buffer.data, src=0, async_op=True)
+ handles.append((buffer, handle))
+ for buffer, handle in handles:
+ handle.wait()
+ if average:
+ buffer.data /= world_size()
+
+
+def sync_grad(params):
+ """
+ Simpler alternative to DistributedDataParallel, that doesn't rely
+ on any black magic. For simple models it can also be as fast.
+ Just call this on your model parameters after the call to backward!
+ """
+ if not is_distributed():
+ return
+ handles = []
+ for p in params:
+ if p.grad is not None:
+ handle = torch.distributed.all_reduce(
+ p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+ handles.append((p, handle))
+ for p, handle in handles:
+ handle.wait()
+ p.grad.data /= world_size()
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.):
+ """Average a dictionary of metrics across all workers, using the optional
+ `count` as unnormalized weight.
+ """
+ if not is_distributed():
+ return metrics
+ keys, values = zip(*metrics.items())
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+ tensor *= count
+ all_reduce(tensor)
+ averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+ return dict(zip(keys, averaged))
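+
+
+# Hedged usage sketch (not part of the original file): `average_metrics` computes a
+# weighted average across workers, e.g. weighting each worker's loss by the number
+# of examples it processed locally. The helper name below is illustrative.
+def _weighted_avg_loss_sketch(local_loss: float, n_local_examples: int) -> float:
+    metrics = average_metrics({"loss": float(local_loss)}, count=n_local_examples)
+    return metrics["loss"]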
diff --git a/inspiremusic/wavtokenizer/encoder/model.py b/inspiremusic/wavtokenizer/encoder/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..33be28de408112b0f54f062df43ac13953e170ea
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/model.py
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""EnCodec model implementation."""
+
+import math
+from pathlib import Path
+import typing as tp
+
+import numpy as np
+import torch
+from torch import nn
+
+from . import quantization as qt
+from . import modules as m
+from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url
+
+
+ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'
+
+EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
+
+
+class LMModel(nn.Module):
+ """Language Model to estimate probabilities of each codebook entry.
+ We predict all codebooks in parallel for a given time step.
+
+ Args:
+ n_q (int): number of codebooks.
+ card (int): codebook cardinality.
+ dim (int): transformer dimension.
+ **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`.
+ """
+ def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
+ super().__init__()
+ self.card = card
+ self.n_q = n_q
+ self.dim = dim
+ self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
+ self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
+ self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])
+
+ def forward(self, indices: torch.Tensor,
+ states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
+ """
+ Args:
+ indices (torch.Tensor): indices from the previous time step. Indices
+ should be 1 + actual index in the codebook. The value 0 is reserved for
+ when the index is missing (i.e. first time step). Shape should be
+ `[B, n_q, T]`.
+ states: state for the streaming decoding.
+ offset: offset of the current time step.
+
+ Returns a 3-tuple `(probabilities, new_states, new_offset)`, where `probabilities`
+ has shape `[B, card, n_q, T]`.
+
+ """
+ B, K, T = indices.shape
+ input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
+ out, states, offset = self.transformer(input_, states, offset)
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
+ return torch.softmax(logits, dim=1), states, offset
+
+
+class EncodecModel(nn.Module):
+ """EnCodec model operating on the raw waveform.
+ Args:
+ target_bandwidths (list of float): Target bandwidths.
+ encoder (nn.Module): Encoder network.
+ decoder (nn.Module): Decoder network.
+ sample_rate (int): Audio sample rate.
+ channels (int): Number of audio channels.
+ normalize (bool): Whether to apply audio normalization.
+ segment (float or None): segment duration in sec. when doing overlap-add.
+ overlap (float): overlap between segments, given as a fraction of the segment duration.
+ name (str): name of the model, used as metadata when compressing audio.
+ """
+ def __init__(self,
+ encoder: m.SEANetEncoder,
+ decoder: m.SEANetDecoder,
+ quantizer: qt.ResidualVectorQuantizer,
+ target_bandwidths: tp.List[float],
+ sample_rate: int,
+ channels: int,
+ normalize: bool = False,
+ segment: tp.Optional[float] = None,
+ overlap: float = 0.01,
+ name: str = 'unset'):
+ super().__init__()
+ self.bandwidth: tp.Optional[float] = None
+ self.target_bandwidths = target_bandwidths
+ self.encoder = encoder
+ self.quantizer = quantizer
+ self.decoder = decoder
+ self.sample_rate = sample_rate
+ self.channels = channels
+ self.normalize = normalize
+ self.segment = segment
+ self.overlap = overlap
+ self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
+ self.name = name
+ self.bits_per_codebook = int(math.log2(self.quantizer.bins))
+ assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
+ "quantizer bins must be a power of 2."
+
+ @property
+ def segment_length(self) -> tp.Optional[int]:
+ if self.segment is None:
+ return None
+ return int(self.segment * self.sample_rate)
+
+ @property
+ def segment_stride(self) -> tp.Optional[int]:
+ segment_length = self.segment_length
+ if segment_length is None:
+ return None
+ return max(1, int((1 - self.overlap) * segment_length))
+
+ def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
+ """Given a tensor `x`, returns a list of frames containing
+ the discrete encoded codes for `x`, along with rescaling factors
+ for each segment, when `self.normalize` is True.
+
+ Each frame is a tuple `(codebook, scale)`, with `codebook` of
+ shape `[B, K, T]`, with `K` the number of codebooks.
+ """
+ assert x.dim() == 3
+ _, channels, length = x.shape
+ assert channels > 0 and channels <= 2
+ segment_length = self.segment_length
+ if segment_length is None:
+ segment_length = length
+ stride = length
+ else:
+ stride = self.segment_stride # type: ignore
+ assert stride is not None
+
+ encoded_frames: tp.List[EncodedFrame] = []
+ for offset in range(0, length, stride):
+ frame = x[:, :, offset: offset + segment_length]
+ encoded_frames.append(self._encode_frame(frame))
+ return encoded_frames
+
+ def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
+ length = x.shape[-1]
+ duration = length / self.sample_rate
+ assert self.segment is None or duration <= 1e-5 + self.segment
+
+ if self.normalize:
+ mono = x.mean(dim=1, keepdim=True)
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+ scale = 1e-8 + volume
+ x = x / scale
+ scale = scale.view(-1, 1)
+ else:
+ scale = None
+
+ emb = self.encoder(x)
+ codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
+ codes = codes.transpose(0, 1)
+ # codes is [B, K, T], with T frames, K nb of codebooks.
+ return codes, scale
+
+ def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
+ """Decode the given frames into a waveform.
+ Note that the output might be a bit bigger than the input. In that case,
+ any extra steps at the end can be trimmed.
+ """
+ segment_length = self.segment_length
+ if segment_length is None:
+ assert len(encoded_frames) == 1
+ return self._decode_frame(encoded_frames[0])
+
+ frames = [self._decode_frame(frame) for frame in encoded_frames]
+ return _linear_overlap_add(frames, self.segment_stride or 1)
+
+ def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
+ codes, scale = encoded_frame
+ codes = codes.transpose(0, 1)
+ emb = self.quantizer.decode(codes)
+ out = self.decoder(emb)
+ if scale is not None:
+ out = out * scale.view(-1, 1, 1)
+ return out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ frames = self.encode(x)
+ return self.decode(frames)[:, :, :x.shape[-1]]
+
+ def set_target_bandwidth(self, bandwidth: float):
+ if bandwidth not in self.target_bandwidths:
+ raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
+ f"Select one of {self.target_bandwidths}.")
+ self.bandwidth = bandwidth
+
+ def get_lm_model(self) -> LMModel:
+ """Return the associated LM model to improve the compression rate.
+ """
+ device = next(self.parameters()).device
+ lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
+ past_context=int(3.5 * self.frame_rate)).to(device)
+ checkpoints = {
+ 'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
+ 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
+ }
+ try:
+ checkpoint_name = checkpoints[self.name]
+ except KeyError:
+ raise RuntimeError("No LM pre-trained for the current Encodec model.")
+ url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
+ state = torch.hub.load_state_dict_from_url(
+ url, map_location='cpu', check_hash=True) # type: ignore
+ lm.load_state_dict(state)
+ lm.eval()
+ return lm
+
+ @staticmethod
+ def _get_model(target_bandwidths: tp.List[float],
+ sample_rate: int = 24_000,
+ channels: int = 1,
+ causal: bool = True,
+ model_norm: str = 'weight_norm',
+ audio_normalize: bool = False,
+ segment: tp.Optional[float] = None,
+ name: str = 'unset'):
+ encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
+ decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
+ n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
+ quantizer = qt.ResidualVectorQuantizer(
+ dimension=encoder.dimension,
+ n_q=n_q,
+ bins=1024,
+ )
+ model = EncodecModel(
+ encoder,
+ decoder,
+ quantizer,
+ target_bandwidths,
+ sample_rate,
+ channels,
+ normalize=audio_normalize,
+ segment=segment,
+ name=name,
+ )
+ return model
+
+ @staticmethod
+ def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
+ if repository is not None:
+ if not repository.is_dir():
+ raise ValueError(f"{repository} must exist and be a directory.")
+ file = repository / checkpoint_name
+ checksum = file.stem.split('-')[1]
+ _check_checksum(file, checksum)
+ return torch.load(file)
+ else:
+ url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
+ return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) # type:ignore
+
+ @staticmethod
+ def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
+ """Return the pretrained causal 24khz model.
+ """
+ if repository:
+ assert pretrained
+ target_bandwidths = [1.5, 3., 6, 12., 24.]
+ checkpoint_name = 'encodec_24khz-d7cc33bc.th'
+ sample_rate = 24_000
+ channels = 1
+ model = EncodecModel._get_model(
+ target_bandwidths, sample_rate, channels,
+ causal=True, model_norm='weight_norm', audio_normalize=False,
+ name='encodec_24khz' if pretrained else 'unset')
+ if pretrained:
+ state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ @staticmethod
+ def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
+ """Return the pretrained 48khz model.
+ """
+ if repository:
+ assert pretrained
+ target_bandwidths = [3., 6., 12., 24.]
+ checkpoint_name = 'encodec_48khz-7e698e3e.th'
+ sample_rate = 48_000
+ channels = 2
+ model = EncodecModel._get_model(
+ target_bandwidths, sample_rate, channels,
+ causal=False, model_norm='time_group_norm', audio_normalize=True,
+ segment=1., name='encodec_48khz' if pretrained else 'unset')
+ if pretrained:
+ state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
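+
+# Hedged helper (not part of the original file): nominal bitrate implied by using
+# `n_q` codebooks. For the 24 kHz model (hop length 320, frame_rate 75, 1024-bin
+# codebooks = 10 bits each) this gives 750 bits/s per codebook, so the 24 kbps
+# target corresponds to 32 codebooks, matching the `n_q` computed in `_get_model`.
+def _nominal_bitrate_bps(model: EncodecModel, n_q: int) -> int:
+    return model.frame_rate * model.bits_per_codebook * n_q
+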
+
+def test():
+ from itertools import product
+ import torchaudio
+ bandwidths = [3, 6, 12, 24]
+ models = {
+ 'encodec_24khz': EncodecModel.encodec_model_24khz,
+ 'encodec_48khz': EncodecModel.encodec_model_48khz
+ }
+ for model_name, bw in product(models.keys(), bandwidths):
+ model = models[model_name]()
+ model.set_target_bandwidth(bw)
+ audio_suffix = model_name.split('_')[1][:3]
+ wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
+ wav = wav[:, :model.sample_rate * 2]
+ wav_in = wav.unsqueeze(0)
+ wav_dec = model(wav_in)[0]
+ assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)
+
+
+if __name__ == '__main__':
+ test()
diff --git a/inspiremusic/wavtokenizer/encoder/modules/__init__.py b/inspiremusic/wavtokenizer/encoder/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8e2f987aafa3abf9b882fe15ca5a3b6e150ea32
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch modules."""
+
+# flake8: noqa
+from .conv import (
+ pad1d,
+ unpad1d,
+ NormConv1d,
+ NormConvTranspose1d,
+ NormConv2d,
+ NormConvTranspose2d,
+ SConv1d,
+ SConvTranspose1d,
+)
+from .lstm import SLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
+from .transformer import StreamingTransformerEncoder
diff --git a/inspiremusic/wavtokenizer/encoder/modules/conv.py b/inspiremusic/wavtokenizer/encoder/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83ae84d20ad2082c6e83bb7fc73bb22ac58cf13
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/conv.py
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Convolutional layers wrappers and utilities."""
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+from .norm import ConvLayerNorm
+
+
+CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
+ assert norm in CONV_NORMALIZATIONS
+ if norm == 'weight_norm':
+ return weight_norm(module)
+ elif norm == 'spectral_norm':
+ return spectral_norm(module)
+ else:
+ # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other
+ # choice doesn't need reparametrization.
+ return module
+
+
+def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
+ """Return the proper normalization module. If causal is True, this will ensure the returned
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
+ """
+ assert norm in CONV_NORMALIZATIONS
+ if norm == 'layer_norm':
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
+ elif norm == 'time_group_norm':
+ if causal:
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+ else:
+ return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+ padding_total: int = 0) -> int:
+ """See `pad_for_conv1d`.
+ """
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+ """Pad for a convolution to make sure that the last window is full.
+ Extra padding is added at the end. This is required to ensure that we can rebuild
+ an output of the same length, as otherwise, even with padding, some time steps
+ might get removed.
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
+ 1 2 3 4 # once the padding is removed, we are missing one time step!
+ """
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == 'reflect':
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left: end]
+
+
+class NormConv1d(nn.Module):
+ """Wrapper around Conv1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConv2d(nn.Module):
+ """Wrapper around Conv2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose1d(nn.Module):
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose2d(nn.Module):
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class SConv1d(nn.Module):
+ """Conv1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+ def __init__(self, in_channels: int, out_channels: int,
+ kernel_size: int, stride: int = 1, dilation: int = 1,
+ groups: int = 1, bias: bool = True, causal: bool = False,
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+ pad_mode: str = 'reflect'):
+ super().__init__()
+ # warn user on unusual setup between dilation and stride
+ if stride > 1 and dilation > 1:
+ warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
+ norm=norm, norm_kwargs=norm_kwargs)
+ self.causal = causal
+ self.pad_mode = pad_mode
+
+ def forward(self, x):
+ B, C, T = x.shape
+ kernel_size = self.conv.conv.kernel_size[0]
+ stride = self.conv.conv.stride[0]
+ dilation = self.conv.conv.dilation[0]
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
+ padding_total = kernel_size - stride
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ if self.causal:
+ # Left padding for causal
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+ return self.conv(x)
+
+
+class SConvTranspose1d(nn.Module):
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+ def __init__(self, in_channels: int, out_channels: int,
+ kernel_size: int, stride: int = 1, causal: bool = False,
+ norm: str = 'none', trim_right_ratio: float = 1.,
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
+ super().__init__()
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+ self.causal = causal
+ self.trim_right_ratio = trim_right_ratio
+ assert self.causal or self.trim_right_ratio == 1., \
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+ def forward(self, x):
+ kernel_size = self.convtr.convtr.kernel_size[0]
+ stride = self.convtr.convtr.stride[0]
+ padding_total = kernel_size - stride
+
+ y = self.convtr(x)
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if self.causal:
+ # Trim the padding on the right according to the specified ratio
+ # if trim_right_ratio = 1.0, trim everything from right
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ return y
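+
+
+# Hedged sanity sketch (not part of the original file): with the padding scheme
+# above, SConv1d with stride `s` maps a length-T input to ceil(T / s) frames, and
+# SConvTranspose1d maps those frames back to `frames * s` samples.
+def _length_sketch() -> None:
+    conv = SConv1d(1, 4, kernel_size=8, stride=4)
+    tconv = SConvTranspose1d(4, 1, kernel_size=8, stride=4)
+    x = torch.randn(1, 1, 1000)
+    z = conv(x)
+    assert z.shape[-1] == math.ceil(1000 / 4)
+    y = tconv(z)
+    assert y.shape[-1] == z.shape[-1] * 4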
diff --git a/inspiremusic/wavtokenizer/encoder/modules/lstm.py b/inspiremusic/wavtokenizer/encoder/modules/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..49908198953deed173bed6eed5199eb74b99e5f8
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/lstm.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""LSTM layers module."""
+
+from torch import nn
+
+
+class SLSTM(nn.Module):
+ """
+ LSTM wrapper that hides hidden-state management and data-layout handling.
+ Expects input in convolutional (B, C, T) layout.
+ """
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+ super().__init__()
+ self.skip = skip
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+ # Note: compared with the reference EnCodec SLSTM, the output is permuted back
+ # to the convolutional (B, C, T) layout before the skip connection, so that
+ # `y + x` adds tensors with matching shapes.
+ def forward(self, x):
+ x1 = x.permute(2, 0, 1)
+ y, _ = self.lstm(x1)
+ y = y.permute(1, 2, 0)
+ if self.skip:
+ y = y + x
+ return y
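+
+
+# Hedged shape sketch (not part of the original file): SLSTM keeps the
+# convolutional (B, C, T) layout, so it can be dropped into a Conv1d stack.
+def _slstm_shape_sketch() -> None:
+    import torch  # local import; the module itself only needs torch.nn
+    layer = SLSTM(dimension=64)
+    x = torch.randn(2, 64, 100)
+    assert layer(x).shape == x.shape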
diff --git a/inspiremusic/wavtokenizer/encoder/modules/norm.py b/inspiremusic/wavtokenizer/encoder/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..19970e0a21ea1c10461cb56d776619dd5f64ff36
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/norm.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Normalization modules."""
+
+import typing as tp
+
+import einops
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+ """
+ Convolution-friendly LayerNorm that moves channels to the last dimension
+ before running the normalization and moves them back to their original position right after.
+ """
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
+ super().__init__(normalized_shape, **kwargs)
+
+ def forward(self, x):
+ x = einops.rearrange(x, 'b ... t -> b t ...')
+ x = super().forward(x)
+ x = einops.rearrange(x, 'b t ... -> b ... t')
+ return x
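+
+
+# Hedged shape sketch (not part of the original file): normalization is applied
+# over the channel dimension while the tensor keeps its (B, C, T) layout outside
+# the module.
+def _conv_layer_norm_sketch() -> None:
+    norm = ConvLayerNorm(128)
+    x = torch.randn(4, 128, 50)
+    assert norm(x).shape == x.shape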
diff --git a/inspiremusic/wavtokenizer/encoder/modules/seanet.py b/inspiremusic/wavtokenizer/encoder/modules/seanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1c02d508cbffce0613a637d4c7943d936b09db
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/seanet.py
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Encodec SEANet-based encoder and decoder implementation."""
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+
+from . import (
+ SConv1d,
+ SConvTranspose1d,
+ SLSTM
+)
+
+
+class SEANetResnetBlock(nn.Module):
+ """Residual block from SEANet model.
+ Args:
+ dim (int): Dimension of the input/output
+ kernel_sizes (list): List of kernel sizes for the convolutions.
+ dilations (list): List of dilations for the convolutions.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3)
+ true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
+ """
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
+ super().__init__()
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
+ act = getattr(nn, activation)
+ hidden = dim // compress
+ block = []
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+ in_chs = dim if i == 0 else hidden
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+ block += [
+ act(**activation_params),
+ SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
+ norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode),
+ ]
+ self.block = nn.Sequential(*block)
+ self.shortcut: nn.Module
+ if true_skip:
+ self.shortcut = nn.Identity()
+ else:
+ self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode)
+
+ def forward(self, x):
+ return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+ """SEANet encoder.
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): nb of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+ upsampling ratios, so it applies them in the reverse order of the ones specified here,
+ which must match the decoder order.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+ last_kernel_size (int): Kernel size for the last convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the end of the encoder.
+ """
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+ pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
+ super().__init__()
+ self.channels = channels
+ self.dimension = dimension
+ self.n_filters = n_filters
+ self.ratios = list(reversed(ratios))
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios)
+
+ act = getattr(nn, activation)
+ mult = 1
+ model: tp.List[nn.Module] = [
+ SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode)
+ ]
+ # Downsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base ** j, 1],
+ norm=norm, norm_params=norm_params,
+ activation=activation, activation_params=activation_params,
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+ # Add downsampling layers
+ model += [
+ act(**activation_params),
+ SConv1d(mult * n_filters, mult * n_filters * 2,
+ kernel_size=ratio * 2, stride=ratio,
+ norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode),
+ ]
+ mult *= 2
+
+ if lstm:
+ model += [SLSTM(mult * n_filters, num_layers=lstm)]
+
+ model += [
+ act(**activation_params),
+ SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode)
+ ]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+ """SEANet decoder.
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): nb of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ final_activation (str): Final activation function after all convolutions.
+ final_activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+ last_kernel_size (int): Kernel size for the last convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the start of the decoder.
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+ If equal to 1.0, it means that all the trimming is done at the right.
+ """
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+ norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+ pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2,
+ trim_right_ratio: float = 1.0):
+ super().__init__()
+ self.dimension = dimension
+ self.channels = channels
+ self.n_filters = n_filters
+ self.ratios = ratios
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios)
+
+ act = getattr(nn, activation)
+ mult = int(2 ** len(self.ratios))
+ model: tp.List[nn.Module] = [
+ SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode)
+ ]
+
+ if lstm:
+ model += [SLSTM(mult * n_filters, num_layers=lstm)]
+
+ # Upsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ # Add upsampling layers
+ model += [
+ act(**activation_params),
+ SConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+ kernel_size=ratio * 2, stride=ratio,
+ norm=norm, norm_kwargs=norm_params,
+ causal=causal, trim_right_ratio=trim_right_ratio),
+ ]
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base ** j, 1],
+ activation=activation, activation_params=activation_params,
+ norm=norm, norm_params=norm_params, causal=causal,
+ pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+ mult //= 2
+
+ # Add final layers
+ model += [
+ act(**activation_params),
+ SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode)
+ ]
+ # Add optional final activation to decoder (eg. tanh)
+ if final_activation is not None:
+ final_act = getattr(nn, final_activation)
+ final_activation_params = final_activation_params or {}
+ model += [
+ final_act(**final_activation_params)
+ ]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, z):
+ y = self.model(z)
+ return y
+
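+
+# Hedged note (not part of the original file): the encoder's temporal downsampling
+# is the product of its strides; with the default ratios [8, 5, 4, 2] the hop
+# length is 320, so one second of 24 kHz audio becomes 75 latent frames, which is
+# what test() below checks.
+def _frames_per_second(sample_rate: int = 24000, ratios=(8, 5, 4, 2)) -> int:
+    return sample_rate // int(np.prod(ratios))
+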
+
+def test():
+ import torch
+ encoder = SEANetEncoder()
+ decoder = SEANetDecoder()
+ x = torch.randn(1, 1, 24000)
+ z = encoder(x)
+ assert list(z.shape) == [1, 128, 75], z.shape
+ y = decoder(z)
+ assert y.shape == x.shape, (x.shape, y.shape)
+
+
+if __name__ == '__main__':
+ test()
diff --git a/inspiremusic/wavtokenizer/encoder/modules/transformer.py b/inspiremusic/wavtokenizer/encoder/modules/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..44b47918f84aa47021c0d6f5bd58364641088541
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/transformer.py
@@ -0,0 +1,119 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""A streamable transformer."""
+
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
+ """Create time embedding for the given positions, target dimension `dim`.
+ """
+ # We aim for BTC format
+ assert dim % 2 == 0
+ half_dim = dim // 2
+ adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
+ phase = positions / (max_period ** (adim / (half_dim - 1)))
+ return torch.cat([
+ torch.cos(phase),
+ torch.sin(phase),
+ ], dim=-1)
+
+
+class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
+ def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore
+ if self.norm_first:
+ sa_input = self.norm1(x)
+ x = x + self._sa_block(sa_input, x_past, past_context)
+ x = x + self._ff_block(self.norm2(x))
+ else:
+ sa_input = x
+ x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
+ x = self.norm2(x + self._ff_block(x))
+
+ return x, sa_input
+
+ # self-attention block
+ def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore
+ _, T, _ = x.shape
+ _, H, _ = x_past.shape
+
+ queries = x
+ keys = torch.cat([x_past, x], dim=1)
+ values = keys
+
+ queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
+ keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
+ delta = queries_pos - keys_pos
+ valid_access = (delta >= 0) & (delta <= past_context)
+ x = self.self_attn(queries, keys, values,
+ attn_mask=~valid_access,
+ need_weights=False)[0]
+ return self.dropout1(x)
+
+
+class StreamingTransformerEncoder(nn.Module):
+ """TransformerEncoder with streaming support.
+
+ Args:
+ dim (int): dimension of the data.
+ hidden_scale (int): intermediate dimension of FF module is this times the dimension.
+ num_heads (int): number of heads.
+ num_layers (int): number of layers.
+ max_period (float): maximum period of cosines in the positional embedding.
+ past_context (int or None): receptive field for the causal mask, infinite if None.
+ gelu (bool): if true uses GeLUs, otherwise use ReLUs.
+ norm_in (bool): normalize the input.
+ dropout (float): dropout probability.
+ **kwargs: See `nn.TransformerEncoderLayer`.
+ """
+ def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5,
+ max_period: float = 10000, past_context: int = 1000, gelu: bool = True,
+ norm_in: bool = True, dropout: float = 0., **kwargs):
+ super().__init__()
+ assert dim % num_heads == 0
+ hidden_dim = int(dim * hidden_scale)
+
+ self.max_period = max_period
+ self.past_context = past_context
+ activation: tp.Any = F.gelu if gelu else F.relu
+
+ self.norm_in: nn.Module
+ if norm_in:
+ self.norm_in = nn.LayerNorm(dim)
+ else:
+ self.norm_in = nn.Identity()
+
+ self.layers = nn.ModuleList()
+ for idx in range(num_layers):
+ self.layers.append(
+ StreamingTransformerEncoderLayer(
+ dim, num_heads, hidden_dim,
+ activation=activation, batch_first=True, dropout=dropout, **kwargs))
+
+ def forward(self, x: torch.Tensor,
+ states: tp.Optional[tp.List[torch.Tensor]] = None,
+ offset: tp.Union[int, torch.Tensor] = 0):
+ B, T, C = x.shape
+ if states is None:
+ states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]
+
+ positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
+ pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)
+
+ new_state: tp.List[torch.Tensor] = []
+ x = self.norm_in(x)
+ x = x + pos_emb
+
+ for layer_state, layer in zip(states, self.layers):
+ x, new_layer_state = layer(x, layer_state, self.past_context)
+ new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
+ new_state.append(new_layer_state[:, -self.past_context:, :])
+ return x, new_state, offset + T
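+
+
+# Hedged streaming sketch (not part of the original file): the encoder can be fed
+# in chunks, carrying `states` and `offset` between calls; with dropout disabled
+# the chunked output should match the full-sequence output within numerical
+# tolerance, as long as the sequence stays inside `past_context`.
+def _streaming_sketch() -> bool:
+    torch.manual_seed(0)
+    enc = StreamingTransformerEncoder(dim=64, num_heads=4, num_layers=2).eval()
+    x = torch.randn(1, 32, 64)
+    with torch.no_grad():
+        y_full, _, _ = enc(x)
+        states, offset, outs = None, 0, []
+        for chunk in x.split(8, dim=1):
+            y, states, offset = enc(chunk, states, offset)
+            outs.append(y)
+    return torch.allclose(torch.cat(outs, dim=1), y_full, atol=1e-5)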
diff --git a/inspiremusic/wavtokenizer/encoder/msstftd.py b/inspiremusic/wavtokenizer/encoder/msstftd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d3242a57e1e20e99bc2fa86e363cc5ec92cbf7
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/msstftd.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""MS-STFT discriminator, provided here for reference."""
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from .modules import NormConv2d
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+ return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+class DiscriminatorSTFT(nn.Module):
+ """STFT sub-discriminator.
+ Args:
+ filters (int): Number of filters in convolutions
+ in_channels (int): Number of input channels. Default: 1
+ out_channels (int): Number of output channels. Default: 1
+ n_fft (int): Size of FFT for each scale. Default: 1024
+ hop_length (int): Length of hop between STFT windows for each scale. Default: 256
+ kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
+ stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
+ dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
+ win_length (int): Window size for each scale. Default: 1024
+ normalized (bool): Whether to normalize by magnitude after stft. Default: True
+ norm (str): Normalization method. Default: `'weight_norm'`
+ activation (str): Activation function. Default: `'LeakyReLU'`
+ activation_params (dict): Parameters to provide to the activation function.
+ growth (int): Growth factor for the filters. Default: 1
+ """
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+ n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+ filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+ stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+ activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+ super().__init__()
+ assert len(kernel_size) == 2
+ assert len(stride) == 2
+ self.filters = filters
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.normalized = normalized
+ self.activation = getattr(torch.nn, activation)(**activation_params)
+ self.spec_transform = torchaudio.transforms.Spectrogram(
+ n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+ normalized=self.normalized, center=False, pad_mode=None, power=None)
+ spec_channels = 2 * self.in_channels
+ self.convs = nn.ModuleList()
+ self.convs.append(
+ NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+ )
+ in_chs = min(filters_scale * self.filters, max_filters)
+ for i, dilation in enumerate(dilations):
+ out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+ dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+ norm=norm))
+ in_chs = out_chs
+ out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+ self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+ norm=norm))
+ self.conv_post = NormConv2d(out_chs, self.out_channels,
+ kernel_size=(kernel_size[0], kernel_size[0]),
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+ norm=norm)
+
+ def forward(self, x: torch.Tensor):
+ fmap = []
+ z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
+ z = torch.cat([z.real, z.imag], dim=1)
+ z = rearrange(z, 'b c w t -> b c t w')
+ for i, layer in enumerate(self.convs):
+ z = layer(z)
+ z = self.activation(z)
+ fmap.append(z)
+ z = self.conv_post(z)
+ return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(nn.Module):
+ """Multi-Scale STFT (MS-STFT) discriminator.
+ Args:
+ filters (int): Number of filters in convolutions
+ in_channels (int): Number of input channels. Default: 1
+ out_channels (int): Number of output channels. Default: 1
+ n_ffts (Sequence[int]): Size of FFT for each scale
+ hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
+ win_lengths (Sequence[int]): Window size for each scale
+ **kwargs: additional args for STFTDiscriminator
+ """
+ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+ n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+ win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+ super().__init__()
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+ self.discriminators = nn.ModuleList([
+ DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+ n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+ for i in range(len(n_ffts))
+ ])
+ self.num_discriminators = len(self.discriminators)
+
+ def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
+ logits = []
+ fmaps = []
+ for disc in self.discriminators:
+ logit, fmap = disc(x)
+ logits.append(logit)
+ fmaps.append(fmap)
+ return logits, fmaps
+
+
+def test():
+ disc = MultiScaleSTFTDiscriminator(filters=32)
+ y = torch.randn(1, 1, 24000)
+ y_hat = torch.randn(1, 1, 24000)
+
+ y_disc_r, fmap_r = disc(y)
+ y_disc_gen, fmap_gen = disc(y_hat)
+ assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators
+
+ assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
+ assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
+ assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
+
+
+if __name__ == '__main__':
+ test()
diff --git a/inspiremusic/wavtokenizer/encoder/quantization/__init__.py b/inspiremusic/wavtokenizer/encoder/quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/quantization/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import QuantizedResult, ResidualVectorQuantizer
diff --git a/inspiremusic/wavtokenizer/encoder/quantization/ac.py b/inspiremusic/wavtokenizer/encoder/quantization/ac.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0f3e5dcd385cd273a145effa3f53ce7ccfdc74c
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/quantization/ac.py
@@ -0,0 +1,292 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Arithmetic coder."""
+
+import io
+import math
+import random
+import typing as tp
+import torch
+
+from ..binary import BitPacker, BitUnpacker
+
+
+def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
+ roundoff: float = 1e-8, min_range: int = 2,
+ check: bool = True) -> torch.Tensor:
+ """Turn the given PDF into a quantized CDF that splits
+ [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
+ to the PDF.
+
+ Args:
+ pdf (torch.Tensor): probability distribution, shape should be `[N]`.
+ total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
+ during the coding process is `[0, 2 ** total_range_bits - 1]`.
+ roundoff (float): will round the pdf up to that level to remove difference coming
+ from e.g. evaluating the Language Model on different architectures.
+ min_range (int): minimum range width. Should always be at least 2 for numerical
+ stability. Use this to avoid pathological behavior if a value
+ that is expected to be rare actually happens in real life.
+ check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
+ """
+ pdf = pdf.detach()
+ if roundoff:
+ pdf = (pdf / roundoff).floor() * roundoff
+ # interpolate with uniform distribution to achieve desired minimum probability.
+ total_range = 2 ** total_range_bits
+ cardinality = len(pdf)
+ alpha = min_range * cardinality / total_range
+ assert alpha <= 1, "you must reduce min_range"
+ ranges = (((1 - alpha) * total_range) * pdf).floor().long()
+ ranges += min_range
+ quantized_cdf = torch.cumsum(ranges, dim=-1)
+ if min_range < 2:
+ raise ValueError("min_range must be at least 2.")
+ if check:
+ assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
+ if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
+ raise ValueError("You must increase your total_range_bits.")
+ return quantized_cdf
+
+
+class ArithmeticCoder:
+ """ArithmeticCoder,
+ Let us take a distribution `p` over `N` symbols, and assume we have a stream
+ of random variables `s_t` sampled from `p`. Let us assume that we have a budget
+ of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
+ corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
+ sequence `(s_t)` by doing the following:
+
+ 1) Initialize the current range to `[0, 2 ** B - 1]`.
+ 2) For each time step t, split the current range into contiguous chunks,
+ one for each possible outcome, with size roughly proportional to `p`.
+ For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
+ would be `{[0, 2], [3, 3]}`.
+ 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
+ 4) When done encoding all the values, just select any value remaining in the range.
+
+ You will notice that this procedure can fail: for instance if at any point in time
+ the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
+ possible outcome. Intuitively, the more likely a value is, the less the range width
+ will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
+ coding scheme, likely outcomes would take less bits, and more of them can be coded
+ with a fixed budget.
+
+ In practice, we do not know `B` ahead of time, but we have a way to inject new bits
+ when the current range decreases below a given limit (given by `total_range_bits`), without
+ having to redo all the computations. If we encode mostly likely values, we will seldom
+ need to inject new bits, but a single rare value can deplete our stock of entropy!
+
+ In this explanation, we assumed that the distribution `p` was constant. In fact, the present
+ code works for any sequence `(p_t)` possibly different for each timestep.
+ We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
+ the KL between the true distribution and `p_t`, the most efficient the coding will be.
+
+ Args:
+ fo (IO[bytes]): file-like object to which the bytes will be written to.
+ total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
+ Any time the current range width fall under this limit, new bits will
+ be injected to rescale the initial range.
+ """
+
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+ assert total_range_bits <= 30
+ self.total_range_bits = total_range_bits
+ self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
+ self.low: int = 0
+ self.high: int = 0
+ self.max_bit: int = -1
+ self._dbg: tp.List[tp.Any] = []
+ self._dbg2: tp.List[tp.Any] = []
+
+ @property
+ def delta(self) -> int:
+ """Return the current range width."""
+ return self.high - self.low + 1
+
+ def _flush_common_prefix(self):
+ # If self.low and self.high start with the same bits,
+ # those won't change anymore as we always just increase the range
+ # by powers of 2, and we can flush them out to the bit stream.
+ assert self.high >= self.low, (self.low, self.high)
+ assert self.high < 2 ** (self.max_bit + 1)
+ while self.max_bit >= 0:
+ b1 = self.low >> self.max_bit
+ b2 = self.high >> self.max_bit
+ if b1 == b2:
+ self.low -= (b1 << self.max_bit)
+ self.high -= (b1 << self.max_bit)
+ assert self.high >= self.low, (self.high, self.low, self.max_bit)
+ assert self.low >= 0
+ self.max_bit -= 1
+ self.packer.push(b1)
+ else:
+ break
+
+ def push(self, symbol: int, quantized_cdf: torch.Tensor):
+ """Push the given symbol on the stream, flushing out bits
+ if possible.
+
+ Args:
+ symbol (int): symbol to encode with the AC.
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ to build this from your pdf estimate.
+ """
+ while self.delta < 2 ** self.total_range_bits:
+ self.low *= 2
+ self.high = self.high * 2 + 1
+ self.max_bit += 1
+
+ range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
+ range_high = quantized_cdf[symbol].item() - 1
+ effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
+ effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
+ assert self.low <= self.high
+ self.high = self.low + effective_high
+ self.low = self.low + effective_low
+ assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
+ self._dbg.append((self.low, self.high))
+ self._dbg2.append((self.low, self.high))
+ outs = self._flush_common_prefix()
+ assert self.low <= self.high
+ assert self.max_bit >= -1
+ assert self.max_bit <= 61, self.max_bit
+ return outs
+
+ def flush(self):
+ """Flush the remaining information to the stream.
+ """
+ while self.max_bit >= 0:
+ b1 = (self.low >> self.max_bit) & 1
+ self.packer.push(b1)
+ self.max_bit -= 1
+ self.packer.flush()
+
+
+class ArithmeticDecoder:
+ """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
+
+ Note that this must be called with **exactly** the same parameters and sequence
+ of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
+
+ If the AC encoder's current range is [L, H], with `L` and `H` sharing a common
+ prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
+ For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
+ `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
+ for a specific sequence of symbols, and a binary search allows us to decode those symbols.
+ At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
+ and we will need to read new bits from the stream and repeat the process.
+
+ """
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+ self.total_range_bits = total_range_bits
+ self.low: int = 0
+ self.high: int = 0
+ self.current: int = 0
+ self.max_bit: int = -1
+ self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
+ # Following is for debugging
+ self._dbg: tp.List[tp.Any] = []
+ self._dbg2: tp.List[tp.Any] = []
+ self._last: tp.Any = None
+
+ @property
+ def delta(self) -> int:
+ return self.high - self.low + 1
+
+ def _flush_common_prefix(self):
+ # Given the current range [L, H], if both have a common prefix,
+ # we know we can remove it from our representation to avoid handling large numbers.
+ while self.max_bit >= 0:
+ b1 = self.low >> self.max_bit
+ b2 = self.high >> self.max_bit
+ if b1 == b2:
+ self.low -= (b1 << self.max_bit)
+ self.high -= (b1 << self.max_bit)
+ self.current -= (b1 << self.max_bit)
+ assert self.high >= self.low
+ assert self.low >= 0
+ self.max_bit -= 1
+ else:
+ break
+
+ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
+ """Pull a symbol, reading as many bits from the stream as required.
+ This returns `None` when the stream has been exhausted.
+
+ Args:
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ to build this from your pdf estimate. This must be **exactly**
+ the same cdf as the one used at encoding time.
+ """
+ while self.delta < 2 ** self.total_range_bits:
+ bit = self.unpacker.pull()
+ if bit is None:
+ return None
+ self.low *= 2
+ self.high = self.high * 2 + 1
+ self.current = self.current * 2 + bit
+ self.max_bit += 1
+
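+ # Binary search over symbols: find the one whose sub-range (computed exactly
+ # as in the encoder's `push`) contains `self.current`.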
+ def bin_search(low_idx: int, high_idx: int):
+ # Binary search is not just for coding interviews :)
+ if high_idx < low_idx:
+ raise RuntimeError("Binary search failed")
+ mid = (low_idx + high_idx) // 2
+ range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
+ range_high = quantized_cdf[mid].item() - 1
+ effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
+ effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
+ low = effective_low + self.low
+ high = effective_high + self.low
+ if self.current >= low:
+ if self.current <= high:
+ return (mid, low, high, self.current)
+ else:
+ return bin_search(mid + 1, high_idx)
+ else:
+ return bin_search(low_idx, mid - 1)
+
+ self._last = (self.low, self.high, self.current, self.max_bit)
+ sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
+ self._dbg.append((self.low, self.high, self.current))
+ self._flush_common_prefix()
+ self._dbg2.append((self.low, self.high, self.current))
+
+ return sym
+
+
+def test():
+ torch.manual_seed(1234)
+ random.seed(1234)
+ for _ in range(4):
+ pdfs = []
+ cardinality = random.randrange(4000)
+ steps = random.randrange(100, 500)
+ fo = io.BytesIO()
+ encoder = ArithmeticCoder(fo)
+ symbols = []
+ for step in range(steps):
+ pdf = torch.softmax(torch.randn(cardinality), dim=0)
+ pdfs.append(pdf)
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
+ symbol = torch.multinomial(pdf, 1).item()
+ symbols.append(symbol)
+ encoder.push(symbol, q_cdf)
+ encoder.flush()
+
+ fo.seek(0)
+ decoder = ArithmeticDecoder(fo)
+ for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
+ decoded_symbol = decoder.pull(q_cdf)
+ assert decoded_symbol == symbol, idx
+ assert decoder.pull(torch.zeros(1)) is None
+
+
+if __name__ == "__main__":
+ test()
diff --git a/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py b/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..774781c2947622e6c0c7a55c6eded26a2813b7c7
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py
@@ -0,0 +1,421 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This implementation is inspired from
+# https://github.com/lucidrains/vector-quantize-pytorch
+# which is released under MIT License. Hereafter, the original license:
+# MIT License
+#
+# Copyright (c) 2020 Phil Wang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""Core vector quantization implementation."""
+
+import typing as tp
+import warnings
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .. import distrib
+
+
+def default(val: tp.Any, d: tp.Any) -> tp.Any:
+ return val if val is not None else d
+
+
+def ema_inplace(moving_avg, new, decay: float):
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+def uniform_init(*shape: int):
+ t = torch.empty(shape)
+ nn.init.kaiming_uniform_(t)
+ return t
+
+
+def sample_vectors(samples, num: int):
+ num_samples, device = samples.shape[0], samples.device
+
+ if num_samples >= num:
+ indices = torch.randperm(num_samples, device=device)[:num]
+ else:
+ indices = torch.randint(0, num_samples, (num,), device=device)
+
+ return samples[indices]
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
+ dim, dtype = samples.shape[-1], samples.dtype
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
+ means, "c d -> () c d"
+ )
+ dists = -(diffs ** 2).sum(dim=-1)
+
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+ new_means = new_means / bins_min_clamped[..., None]
+
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
+
+
+class EuclideanCodebook(nn.Module):
+ """Codebook with Euclidean distance.
+ Args:
+ dim (int): Dimension.
+ codebook_size (int): Codebook size.
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+ If set to true, run the k-means algorithm on the first training batch and use
+ the learned centroids as initialization.
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ """
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ kmeans_init: bool = False,
+ kmeans_iters: int = 10,
+ decay: float = 0.99,
+ epsilon: float = 1e-5,
+ threshold_ema_dead_code: int = 2,
+ ):
+ super().__init__()
+ self.decay = decay
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
+ embed = init_fn(codebook_size, dim)
+
+ self.codebook_size = codebook_size
+
+ self.kmeans_iters = kmeans_iters
+ self.epsilon = epsilon
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
+ self.register_buffer("embed", embed)
+ self.register_buffer("embed_avg", embed.clone())
+
+ @torch.jit.ignore
+ def init_embed_(self, data):
+ if self.inited:
+ return
+
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embed.data.copy_(embed)
+ self.embed_avg.data.copy_(embed.clone())
+ self.cluster_size.data.copy_(cluster_size)
+ self.inited.data.copy_(torch.Tensor([True]))
+ # Make sure all buffers across workers are in sync after initialization
+ distrib.broadcast_tensors(self.buffers())
+
+ def replace_(self, samples, mask):
+ modified_codebook = torch.where(
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+ )
+ self.embed.data.copy_(modified_codebook)
+
+ def expire_codes_(self, batch_samples):
+ if self.threshold_ema_dead_code == 0:
+ return
+
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
+ if not torch.any(expired_codes):
+ return
+
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self.replace_(batch_samples, mask=expired_codes)
+ distrib.broadcast_tensors(self.buffers())
+
+ def preprocess(self, x):
+ x = rearrange(x, "... d -> (...) d")
+ return x
+
+ def quantize(self, x):
+ embed = self.embed.t()
+ dist = -(
+ x.pow(2).sum(1, keepdim=True)
+ - 2 * x @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+ embed_ind = dist.max(dim=-1).indices
+ return embed_ind
+
+ def postprocess_emb(self, embed_ind, shape):
+ return embed_ind.view(*shape[:-1])
+
+ def dequantize(self, embed_ind):
+ quantize = F.embedding(embed_ind, self.embed)
+ return quantize
+
+ def encode(self, x):
+ shape = x.shape
+ # pre-process
+ x = self.preprocess(x)
+ # quantize
+ embed_ind = self.quantize(x)
+ # post-process
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ return embed_ind
+
+ def decode(self, embed_ind):
+ quantize = self.dequantize(embed_ind)
+ return quantize
+
+ def forward(self, x):
+ shape, dtype = x.shape, x.dtype
+ x = self.preprocess(x)
+
+ self.init_embed_(x)
+
+ embed_ind = self.quantize(x)
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ quantize = self.dequantize(embed_ind)
+
+ if self.training:
+ # We do the expiry of code at that point as buffers are in sync
+ # and all the workers will take the same decision.
+ self.expire_codes_(x)
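+ # Exponential moving average update of cluster sizes and code vectors
+ # (embed_avg accumulates the sum of vectors assigned to each code), followed
+ # by Laplace smoothing to avoid dividing by an empty cluster.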
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+ embed_sum = x.t() @ embed_onehot
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+ cluster_size = (
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+ * self.cluster_size.sum()
+ )
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+ self.embed.data.copy_(embed_normalized)
+
+ return quantize, embed_ind
+
+
+class VectorQuantization(nn.Module):
+ """Vector quantization implementation.
+ Currently supports only euclidean distance.
+ Args:
+ dim (int): Dimension
+ codebook_size (int): Codebook size
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ commitment_weight (float): Weight for commitment loss.
+ """
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ codebook_dim: tp.Optional[int] = None,
+ decay: float = 0.99,
+ epsilon: float = 1e-5,
+ kmeans_init: bool = True,
+ kmeans_iters: int = 50,
+ threshold_ema_dead_code: int = 2,
+ commitment_weight: float = 1.,
+ ):
+ super().__init__()
+ _codebook_dim: int = default(codebook_dim, dim)
+
+ requires_projection = _codebook_dim != dim
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+ self.epsilon = epsilon
+ self.commitment_weight = commitment_weight
+
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+ decay=decay, epsilon=epsilon,
+ threshold_ema_dead_code=threshold_ema_dead_code)
+ self.codebook_size = codebook_size
+
+ @property
+ def codebook(self):
+ return self._codebook.embed
+
+ def encode(self, x):
+ x = rearrange(x, "b d n -> b n d")
+ x = self.project_in(x)
+ embed_in = self._codebook.encode(x)
+ return embed_in
+
+ def decode(self, embed_ind):
+ quantize = self._codebook.decode(embed_ind)
+ quantize = self.project_out(quantize)
+ quantize = rearrange(quantize, "b n d -> b d n")
+ return quantize
+
+ def forward(self, x):
+
+ # breakpoint()
+ device = x.device
+ x = rearrange(x, "b d n -> b n d")
+ x = self.project_in(x)
+ quantize, embed_ind = self._codebook(x)
+ if self.training:
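+ # Straight-through estimator: use the quantized value in the forward pass
+ # but let gradients flow to `x` as if quantization were the identity.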
+ quantize = x + (quantize - x).detach()
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+ if self.training:
+ # warnings.warn('When using RVQ in training model, first check '
+ # 'https://github.com/facebookresearch/encodec/issues/25 . '
+ # 'The bug wasn\'t fixed here for reproducibility.')
+ if self.commitment_weight > 0:
+ commit_loss = F.mse_loss(quantize.detach(), x)
+ loss = loss + commit_loss * self.commitment_weight
+
+ quantize = self.project_out(quantize)
+ quantize = rearrange(quantize, "b n d -> b d n")
+ return quantize, embed_ind, loss
+
+
+class ResidualVectorQuantization(nn.Module):
+ """Residual vector quantization implementation.
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+ """
+ def __init__(self, *, num_quantizers, **kwargs):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+ )
+
+ def forward(self, x, n_q: tp.Optional[int] = None):
+ quantized_out = 0.0
+ residual = x
+
+ all_losses = []
+ all_indices = []
+
+ n_q = n_q or len(self.layers)
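+ # Each quantizer codes the residual left by the previous layers; the sum of
+ # the per-layer outputs progressively reconstructs `x` (Algorithm 1 of SoundStream).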
+ for layer in self.layers[:n_q]:
+ quantized, indices, loss = layer(residual)
+ residual = residual - quantized.detach()
+ quantized_out = quantized_out + quantized
+ all_indices.append(indices)
+ all_losses.append(loss)
+
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+ return quantized_out, out_indices, out_losses
+
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+ residual = x
+ all_indices = []
+ n_q = n_q or len(self.layers)
+ for layer in self.layers[:n_q]:
+ indices = layer.encode(residual)
+ all_indices.append(indices)
+ quantized = layer.decode(indices)
+ residual = residual - quantized.detach()
+ out_indices = torch.stack(all_indices)
+ return out_indices
+
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
+ for i, indices in enumerate(q_indices):
+ layer = self.layers[i]
+ quantized = layer.decode(indices)
+ quantized_out = quantized_out + quantized
+ return quantized_out
+
+
+class LanguageVectorQuantization(nn.Module):
+ """Residual vector quantization implementation.
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+ """
+ def __init__(self, *, num_quantizers, **kwargs):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+ )
+ # print("core_vq.py:self.layers",self.layers)
+
+ def forward(self, x, n_q: tp.Optional[int] = None):
+ # x: [b, t, c], e.g. [64, 75, 128]
+ quantized_out = 0.0
+ residual = x
+
+
+ all_losses = []
+ all_indices = []
+
+ # breakpoint()
+
+ n_q = n_q or len(self.layers)
+
+ for layer in self.layers[:n_q]:
+ quantized_out, indices, loss = layer(residual) # this layer's representation, indices and loss; indices shape [64, 75]
+ # residual = residual - quantized.detach()
+ # quantized_out = quantized_out + quantized
+ all_indices.append(indices)
+ all_losses.append(loss)
+
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+ return quantized_out, out_indices, out_losses
+
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+ residual = x
+ all_indices = []
+ n_q = n_q or len(self.layers)
+ for layer in self.layers[:n_q]:
+ indices = layer.encode(residual)
+ all_indices.append(indices)
+ quantized = layer.decode(indices)
+ residual = residual - quantized.detach()
+ out_indices = torch.stack(all_indices)
+ return out_indices
+
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
+ for i, indices in enumerate(q_indices):
+ layer = self.layers[i]
+ quantized = layer.decode(indices)
+ quantized_out = quantized_out + quantized
+ return quantized_out
\ No newline at end of file
diff --git a/inspiremusic/wavtokenizer/encoder/quantization/vq.py b/inspiremusic/wavtokenizer/encoder/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8e316b4bf912c2a743cd27fe038a17e85bceb13
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/quantization/vq.py
@@ -0,0 +1,172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Residual vector quantizer implementation."""
+
+from dataclasses import dataclass, field
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from .core_vq import ResidualVectorQuantization,LanguageVectorQuantization
+
+
+@dataclass
+class QuantizedResult:
+ quantized: torch.Tensor
+ codes: torch.Tensor
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
+ penalty: tp.Optional[torch.Tensor] = None
+ metrics: dict = field(default_factory=dict)
+
+
+class ResidualVectorQuantizer(nn.Module):
+ """Residual Vector Quantizer.
+ Args:
+ dimension (int): Dimension of the codebooks.
+ n_q (int): Number of residual vector quantizers used.
+ bins (int): Codebook size.
+ decay (float): Decay for exponential moving average over the codebooks.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ """
+ def __init__(
+ self,
+ dimension: int = 256,
+ n_q: int = 8,
+ bins: int = 1024,
+ decay: float = 0.99,
+ kmeans_init: bool = True,
+ kmeans_iters: int = 50,
+ threshold_ema_dead_code: int = 2,
+ ):
+ super().__init__()
+ self.n_q = n_q
+ self.dimension = dimension
+ self.bins = bins
+ self.decay = decay
+ self.kmeans_init = kmeans_init
+ self.kmeans_iters = kmeans_iters
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+
+ # print(self.bins)
+
+ # breakpoint()
+
+ self.vq = LanguageVectorQuantization(
+ dim=self.dimension,
+ codebook_size=self.bins,
+ num_quantizers=self.n_q,
+ decay=self.decay,
+ kmeans_init=self.kmeans_init,
+ kmeans_iters=self.kmeans_iters,
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
+ )
+ # self.vq = ResidualVectorQuantization(
+ # dim=self.dimension,
+ # codebook_size=self.bins,
+ # num_quantizers=self.n_q,
+ # decay=self.decay,
+ # kmeans_init=self.kmeans_init,
+ # kmeans_iters=self.kmeans_iters,
+ # threshold_ema_dead_code=self.threshold_ema_dead_code,
+ # )
+
+
+ def forward(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
+ """Residual vector quantization on the given input tensor.
+ Args:
+ x (torch.Tensor): Input tensor.
+ frame_rate (int): Frame rate of the input tensor.
+ bandwidth (float): Target bandwidth.
+ Returns:
+ QuantizedResult:
+ The quantized (or approximately quantized) representation with
+ the associated bandwidth and any penalty term for the loss.
+ """
+ # breakpoint()
+
+
+ bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
+ n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
+ # assert n_q==4
+ # breakpoint()
+ # nq_choice=[3,4,8]
+ nq_choice=[4,6,8]
+ if self.training:
+ # choice = int(torch.randint(0, 3, (1,)).item())
+ choice = int(torch.randint(0, 3, (1,)).item())
+ # breakpoint()
+ n_q=nq_choice[choice]
+ # breakpoint()
+ # n_q=8
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+ bw = torch.tensor(n_q * bw_per_q).to(x)
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+ def infer(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
+ """Residual vector quantization on the given input tensor.
+ Args:
+ x (torch.Tensor): Input tensor.
+ frame_rate (int): Frame rate of the input tensor.
+ bandwidth (float): Target bandwidth.
+ Returns:
+ QuantizedResult:
+ The quantized (or approximately quantized) representation with
+ the associated bandwidth and any penalty term for the loss.
+ """
+ bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
+ # n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
+ # # assert n_q==4
+ # # breakpoint()
+ # # nq_choice=[3,4,8]
+ # nq_choice=[3,4,5,6,7,8]
+ # if self.training:
+ # # choice = int(torch.randint(0, 3, (1,)).item())
+ # choice = int(torch.randint(0, 6, (1,)).item())
+ # # breakpoint()
+ # n_q=nq_choice[choice]
+ n_q=1
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+ bw = torch.tensor(n_q * bw_per_q).to(x)
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+ def get_num_quantizers_for_bandwidth(self, frame_rate: int, bandwidth: tp.Optional[float] = None) -> int:
+ """Return n_q based on specified target bandwidth.
+ """
+ bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
+ n_q = self.n_q
+ if bandwidth and bandwidth > 0.:
+ # bandwidth is given in kbps (e.g. 6 kbps is bandwidth == 6.0), while
+ # bw_per_q is in bits per second, hence the factor of 1000 below.
+ n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
+ return n_q
+
+ def get_bandwidth_per_quantizer(self, frame_rate: int):
+ """Return bandwidth per quantizer for a given input frame rate.
+ Each quantizer encodes a frame with lg(bins) bits.
+ """
+ return math.log2(self.bins) * frame_rate
+
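+ # Illustrative numbers (assumed, not from the original code): with the default
+ # bins=1024 and a frame_rate of 75 Hz, bw_per_q = 10 * 75 = 750 bits/s, so a
+ # 6.0 kbps target bandwidth yields n_q = floor(6000 / 750) = 8 quantizers.
+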
+ def encode(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
+ The RVQ encode method sets the appropriate number of quantizers to use
+ and returns indices for each quantizer.
+ """
+ n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
+ codes = self.vq.encode(x, n_q=n_q)
+ return codes
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation.
+ """
+ quantized = self.vq.decode(codes)
+ return quantized
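+
+# Minimal usage sketch (shapes and values assumed, not part of the original code):
+#   rvq = ResidualVectorQuantizer(dimension=512, n_q=8, bins=1024)
+#   x = torch.randn(2, 512, 75)                           # [batch, dim, frames]
+#   result = rvq(x, frame_rate=75, bandwidth=6.0)         # QuantizedResult
+#   codes = rvq.encode(x, frame_rate=75, bandwidth=6.0)   # [n_q, batch, frames]
+#   recon = rvq.decode(codes)                             # [batch, dim, frames]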
diff --git a/inspiremusic/wavtokenizer/encoder/utils.py b/inspiremusic/wavtokenizer/encoder/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f0f9e9bcb37f2267b2f8adefabfc3672453dc5
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/utils.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Various utilities."""
+
+from hashlib import sha256
+from pathlib import Path
+import typing as tp
+
+import torch
+import torchaudio
+
+
+def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int):
+ # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
+ # e.g., more than 2 frames per position.
+ # The core idea is to use a weight function that is a triangle,
+ # with a maximum value at the middle of the segment.
+ # We use this weighting when summing the frames, and divide by the sum of weights
+ # for each positions at the end. Thus:
+ # - if a frame is the only one to cover a position, the weighting is a no-op.
+ # - if 2 frames cover a position:
+ # ... ...
+ # / \/ \
+ # / /\ \
+ # S T , i.e. S offset of second frame starts, T end of first frame.
+ # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
+ # After the final normalization, the weight of the second frame at position `t` is
+ # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
+ #
+ # - if more than 2 frames overlap at a given point, we hope that by induction
+ # something sensible happens.
+ assert len(frames)
+ device = frames[0].device
+ dtype = frames[0].dtype
+ shape = frames[0].shape[:-1]
+ total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]
+
+ frame_length = frames[0].shape[-1]
+ t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1]
+ weight = 0.5 - (t - 0.5).abs()
+
+ sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
+ out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
+ offset: int = 0
+
+ for frame in frames:
+ frame_length = frame.shape[-1]
+ out[..., offset:offset + frame_length] += weight[:frame_length] * frame
+ sum_weight[offset:offset + frame_length] += weight[:frame_length]
+ offset += stride
+ assert sum_weight.min() > 0
+ return out / sum_weight
+
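+# Illustrative example (values assumed): three 1-second frames added with a
+# 0.5-second stride reconstruct 2 seconds of audio, with triangular weights
+# cross-fading the overlapping halves.
+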
+
+def _get_checkpoint_url(root_url: str, checkpoint: str):
+ if not root_url.endswith('/'):
+ root_url += '/'
+ return root_url + checkpoint
+
+
+def _check_checksum(path: Path, checksum: str):
+ sha = sha256()
+ with open(path, 'rb') as file:
+ while True:
+ buf = file.read(2**20)
+ if not buf:
+ break
+ sha.update(buf)
+ actual_checksum = sha.hexdigest()[:len(checksum)]
+ if actual_checksum != checksum:
+ raise RuntimeError(f'Invalid checksum for file {path}, '
+ f'expected {checksum} but got {actual_checksum}')
+
+
+def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
+ assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions"
+ assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo."
+ *shape, channels, length = wav.shape
+ if target_channels == 1:
+ wav = wav.mean(-2, keepdim=True)
+ elif target_channels == 2:
+ wav = wav.expand(*shape, target_channels, length)
+ elif channels == 1:
+ wav = wav.expand(target_channels, -1)
+ else:
+ raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}")
+ wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
+ return wav
+
+
+def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
+ sample_rate: int, rescale: bool = False):
+ limit = 0.99
+ mx = wav.abs().max()
+ if rescale:
+ wav = wav * min(limit / mx, 1)
+ else:
+ wav = wav.clamp(-limit, limit)
+ torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
diff --git a/requirements.txt b/requirements.txt
index 6275b171d63423a0428c32474b4e7feeeb294368..288f394b475e2ccc99f54eb71df0074d94917213 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ librosa
torch
torchaudio
modelscope
-funasr
\ No newline at end of file
+funasr
+transformers
\ No newline at end of file