diff --git a/app.py b/app.py index 8731f6999b2098ef841c1495adfb0d53832d1604..542a98adfb4f594f23e889c91f6d3506ed8e0426 100644 --- a/app.py +++ b/app.py @@ -1,193 +1,129 @@ # coding=utf-8 -import os -import librosa -import base64 import io -import gradio as gr -import re - import numpy as np -import torch import torchaudio -from modelscope import HubApi - -api = HubApi() - -key = os.environ["apikey"] if "apikey" in os.environ else "" -try: - api.login(key) -except: - pass - -from funasr import AutoModel - -model = "iic/SenseVoiceSmall" -model = AutoModel(model=model, - vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", - vad_kwargs={"max_single_segment_time": 30000}, - trust_remote_code=True, - ) - -import re - -emo_dict = { - "<|HAPPY|>": "๐Ÿ˜Š", - "<|SAD|>": "๐Ÿ˜”", - "<|ANGRY|>": "๐Ÿ˜ก", - "<|NEUTRAL|>": "", - "<|FEARFUL|>": "๐Ÿ˜ฐ", - "<|DISGUSTED|>": "๐Ÿคข", - "<|SURPRISED|>": "๐Ÿ˜ฎ", -} - -event_dict = { - "<|BGM|>": "๐ŸŽผ", - "<|Speech|>": "", - "<|Applause|>": "๐Ÿ‘", - "<|Laughter|>": "๐Ÿ˜€", - "<|Cry|>": "๐Ÿ˜ญ", - "<|Sneeze|>": "๐Ÿคง", - "<|Breath|>": "", - "<|Cough|>": "๐Ÿคง", -} - -emoji_dict = { - "<|nospeech|><|Event_UNK|>": "โ“", - "<|zh|>": "", - "<|en|>": "", - "<|yue|>": "", - "<|ja|>": "", - "<|ko|>": "", - "<|nospeech|>": "", - "<|HAPPY|>": "๐Ÿ˜Š", - "<|SAD|>": "๐Ÿ˜”", - "<|ANGRY|>": "๐Ÿ˜ก", - "<|NEUTRAL|>": "", - "<|BGM|>": "๐ŸŽผ", - "<|Speech|>": "", - "<|Applause|>": "๐Ÿ‘", - "<|Laughter|>": "๐Ÿ˜€", - "<|FEARFUL|>": "๐Ÿ˜ฐ", - "<|DISGUSTED|>": "๐Ÿคข", - "<|SURPRISED|>": "๐Ÿ˜ฎ", - "<|Cry|>": "๐Ÿ˜ญ", - "<|EMO_UNKNOWN|>": "", - "<|Sneeze|>": "๐Ÿคง", - "<|Breath|>": "", - "<|Cough|>": "๐Ÿ˜ท", - "<|Sing|>": "", - "<|Speech_Noise|>": "", - "<|withitn|>": "", - "<|woitn|>": "", - "<|GBG|>": "", - "<|Event_UNK|>": "", -} - -lang_dict = { - "<|zh|>": "<|lang|>", - "<|en|>": "<|lang|>", - "<|yue|>": "<|lang|>", - "<|ja|>": "<|lang|>", - "<|ko|>": "<|lang|>", - "<|nospeech|>": "<|lang|>", -} - - -emo_set = {"๐Ÿ˜Š", "๐Ÿ˜”", "๐Ÿ˜ก", "๐Ÿ˜ฐ", "๐Ÿคข", "๐Ÿ˜ฎ"} -event_set = {"๐ŸŽผ", "๐Ÿ‘", "๐Ÿ˜€", "๐Ÿ˜ญ", "๐Ÿคง", "๐Ÿ˜ท",} - -notes = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] - -def format_str(s): - for sptk in emoji_dict: - s = s.replace(sptk, emoji_dict[sptk]) - return s - - -def format_str_v2(s): - sptk_dict = {} - for sptk in emoji_dict: - sptk_dict[sptk] = s.count(sptk) - s = s.replace(sptk, "") - emo = "<|NEUTRAL|>" - for e in emo_dict: - if sptk_dict[e] > sptk_dict[emo]: - emo = e - for e in event_dict: - if sptk_dict[e] > 0: - s = event_dict[e] + s - s = s + emo_dict[emo] - - for emoji in emo_set.union(event_set): - s = s.replace(" " + emoji, emoji) - s = s.replace(emoji + " ", emoji) - return s.strip() - -def format_str_v3(s): - def get_emo(s): - return s[-1] if s[-1] in emo_set else None - def get_event(s): - return s[0] if s[0] in event_set else None - - s = s.replace("<|nospeech|><|Event_UNK|>", "โ“") - for lang in lang_dict: - s = s.replace(lang, "<|lang|>") - s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")] - new_s = " " + s_list[0] - cur_ent_event = get_event(new_s) - for i in range(1, len(s_list)): - if len(s_list[i]) == 0: - continue - if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None: - s_list[i] = s_list[i][1:] - #else: - cur_ent_event = get_event(s_list[i]) - if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s): - new_s = new_s[:-1] - new_s += s_list[i].strip().lstrip() - new_s = new_s.replace("The.", " ") - return new_s.strip() - -def model_inference(input_wav, language, 
fs=16000): - # task_abbr = {"Speech Recognition": "ASR", "Rich Text Transcription": ("ASR", "AED", "SER")} - language_abbr = {"auto": "auto", "zh": "zh", "en": "en", "yue": "yue", "ja": "ja", "ko": "ko", - "nospeech": "nospeech"} - - # task = "Speech Recognition" if task is None else task - language = "auto" if len(language) < 1 else language - selected_language = language_abbr[language] - # selected_task = task_abbr.get(task) - - # print(f"input_wav: {type(input_wav)}, {input_wav[1].shape}, {input_wav}") - - if isinstance(input_wav, tuple): - fs, input_wav = input_wav - input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max - if len(input_wav.shape) > 1: - input_wav = input_wav.mean(-1) - if fs != 16000: - print(f"audio_fs: {fs}") - resampler = torchaudio.transforms.Resample(fs, 16000) - input_wav_t = torch.from_numpy(input_wav).to(torch.float32) - input_wav = resampler(input_wav_t[None, :])[0, :].numpy() - - - merge_vad = True - print(f"language: {language}, merge_vad: {merge_vad}") - text = model.generate(input=input_wav, - cache={}, - language=language, - use_itn=True, - batch_size_s=300, merge_vad=merge_vad) - - print(text) - text = text[0]["text"] - text = format_str_v3(text) - print(text) - - return text +import argparse +import torch +import soundfile as sf +import gradio as gr +import spaces +from inspiremusic.cli.inference import InspireMusicUnified, set_env_variables +import os +import sys + + +def get_args(): + parser = argparse.ArgumentParser( + description='Run inference with your model') + parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long", + help='Model name') + + parser.add_argument('-d', '--model_dir', + help='Model folder path') + + parser.add_argument('-t', '--text', + default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.", + help='Prompt text') + + parser.add_argument('-a', '--audio_prompt', default=None, + help='Prompt audio') + + parser.add_argument('-c', '--chorus', default="intro", + help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)') + + parser.add_argument('--fast', type=bool, default=False, + help='Enable fast inference mode (without flow matching)') + + parser.add_argument('-g', '--gpu', type=int, default=0, + help='GPU ID for this rank, -1 for CPU') + + parser.add_argument('--task', default='text-to-music', + choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'], + help='Inference task type: text-to-music, continuation, reconstruct, super_resolution') + + parser.add_argument('-r', '--result_dir', default="exp/inspiremusic", + help='Directory to save generated audio') + + parser.add_argument('-o', '--output_fn', default="output_audio", + help='Output file name') + + parser.add_argument('-f', '--format', type=str, default="wav", + choices=["wav", "mp3", "m4a", "flac"], + help='Format of output audio') + + parser.add_argument('--sample_rate', type=int, default=24000, + help='Sampling rate of input audio') + + parser.add_argument('--output_sample_rate', type=int, default=48000, + choices=[24000, 48000], + help='Sampling rate of generated output audio') + + parser.add_argument('-s', '--time_start', type=float, default=0.0, + help='Start time in seconds') + + parser.add_argument('-e', '--time_end', type=float, default=30.0, + help='End time in seconds') + + parser.add_argument('--max_audio_prompt_length', type=float, default=5.0, + help='Maximum audio prompt length in seconds') + + 
parser.add_argument('--min_generate_audio_seconds', type=float, + default=10.0, + help='Minimum generated audio length in seconds') + + parser.add_argument('--max_generate_audio_seconds', type=float, + default=30.0, + help='Maximum generated audio length in seconds') + + parser.add_argument('--fp16', type=bool, default=True, + help='Inference with fp16 model') + + parser.add_argument('--fade_out', type=bool, default=True, + help='Apply fade out effect to generated audio') + + parser.add_argument('--fade_out_duration', type=float, default=1.0, + help='Fade out duration in seconds') + + parser.add_argument('--trim', type=bool, default=False, + help='Trim the silence ending of generated audio') + + args = parser.parse_args() + + if not args.model_dir: + args.model_dir = os.path.join("./pretrained_models", args.model_name) + + print(args) + return args + +def InspireMusic(args): + set_env_variables() + model = InspireMusicUnified(model_name=args.model_name, + model_dir=args.model_dir, + min_generate_audio_seconds=args.min_generate_audio_seconds, + max_generate_audio_seconds=args.max_generate_audio_seconds, + sample_rate=args.sample_rate, + output_sample_rate=args.output_sample_rate, + load_jit=True, + load_onnx=False, + fast=args.fast, + fp16=args.fp16, + gpu=args.gpu, + result_dir=args.result_dir) + + model.inference(task=args.task, + text=args.text, + audio_prompt=args.audio_prompt, + chorus=args.chorus, + time_start=args.time_start, + time_end=args.time_end, + output_fn=args.output_fn, + max_audio_prompt_length=args.max_audio_prompt_length, + fade_out_duration=args.fade_out_duration, + output_format=args.format, + fade_out_mode=args.fade_out, + trim=args.trim) + return os.path.join(args.result_dir, f"{args.output_fn}.{args.format}") audio_examples = [ ["example/inspiremusic/inspiremusic_01.wav", "text-to-music"], @@ -218,7 +154,7 @@ description = """ - `The instrumental rock piece features a prominent bass guitar, delivering a pure and energetic sound.` - `A serene blend of instrumental and light pop, featuring soothing melodies and a gentle, soulful keyboard performance.` -Recommended select audio duration is below 30 seconds. For audio longer than 30 seconds, local deployment is recommended, github repo. +The recommended audio prompt duration is 5 seconds, and the generated audio length is below 30 seconds. To generate audio longer than 30 seconds, local deployment is recommended; see the github repo. """ @@ -232,86 +168,92 @@ html_content = """

Code

Demo

Models

-

Modelscope Model:

-

Huggingface Model

+

Modelscope Model:

+

Huggingface Model

""" -# ่‡ชๅฎšไน‰่กจๆ ผ็š„ HTML ๅ’Œ CSS ไปฃ็  -centered_table_html = """ - -
- - - - - - - - - - - - - - - - -
Samples | InspireMusic | Text-to-Music
normal mode | Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.
fast mode | The instrumental piece exudes a playful and whimsical atmosphere, likely featuring lively and rhythmic elements. The music seems to be inspired by nature and animals, creating an engaging and light-hearted experience.
-
-""" - - -def launch(): - with gr.Blocks(theme=gr.themes.Soft()) as demo: - # gr.Markdown(description) - gr.HTML(html_content) - with gr.Column(): - with gr.Row(): - with gr.Column(): - text_inputs = gr.Textbox( - label="Input Text", - placeholder="Enter the text you want to generate music, e.g., Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.", - lines=3 - ) - fn_button = gr.Button("Start", variant="primary") - audio_inputs = gr.Audio( - label="Upload prompt audio", - ) - with gr.Column(): - with gr.Accordion("Configuration"): - # task_inputs = gr.Radio(choices=["Speech Recognition", "Rich Text Transcription"], - # value="Speech Recognition", label="Task") - task_inputs = gr.Dropdown(choices=["text-to-music", "music-continuation"], - value="text-to-music", - label="Task") - inference_mode_inputs = gr.Dropdown(choices=["normal", "fast"], - value="normal", - label="Inference Mode") - cfg_input = gr.Slider(3, 10, step=1, label="CFG value") - audio_length = gr.Textbox(value="30", - label="Duration in seconds") - - gr.Examples(examples=audio_examples, - inputs=[text_inputs, audio_inputs, task_inputs], - examples_per_page=5) - - audio_output = gr.Audio(label="Audio Output") - - fn_button.click(model_inference, inputs=[text_inputs, audio_inputs, task_inputs], outputs=audio_output) - - # with gr.Accordion("More examples"): - # gr.HTML(centered_table_html) - demo.launch() - - -if __name__ == "__main__": - # iface.launch() - launch() +def music_generation(task, text=None, audio=None): + args = get_args() + args.task = task + args.text = text if text + args.audio_prompt = audio if audio + generate_audio_path = InspireMusic(args) + return generate_audio_path + +demo = gr.Blocks() + +t2m_demo = gr.Interface( + fn=music_generation, + inputs = [ + gr.Dropdown(["Text-To-Music"], value="text-to-music", multiselect=False, info="Choose a task."), + gr.Text(label="Input Text"), + ], + outputs = [ + gr.Audio(label="Generated Music", type="generated audio filepath"), + ], + title = "InspireMusic: A Unified Framework for Music, Song, Audio Generation.", + description = ("InspireMusic ([Github Repo](https://github.com/FunAudioLLM/InspireMusic)) is a fundamental AIGC toolkit and models designed for music, song, and audio generation using PyTorch." + "To try it, simply type text to generation music, or click one of the examples. "), + article = ("

InspireMusic

" + "

WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling

"), + examples = [ + ["example/inspiremusic/inspiremusic_01.wav", "24000 Hz"], + ["example/ras/chorus/chorus_01.wav", "48000 Hz"], + ], + cache_examples = True, +) + +con_demo = gr.Interface( + fn=music_generation, + inputs = [ + gr.Dropdown(["Music Continuation"], value="continuation", multiselect=False, info="Choose a task."), + gr.Text(label="Input Text"), + gr.Audio(label="Input Audio Prompt", type="audio prompt filepath"), + ], + outputs = [ + gr.Audio(label="Generated Music", type="generated audio filepath"), + ], + title = "InspireMusic: A Unified Framework for Music, Song, Audio Generation.", + description = ("InspireMusic ([Github Repo](https://github.com/FunAudioLLM/InspireMusic)) is a fundamental AIGC toolkit and models designed for music, song, and audio generation using PyTorch." + "To try it, simply type text to generation music, or click one of the examples. "), + article = ("

InspireMusic

" + "

WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling

"), + examples = [ + ["example/inspiremusic/inspiremusic_01.wav", "24000 Hz"], + ["example/ras/chorus/chorus_01.wav", "48000 Hz"], + ], + cache_examples = True, +) + +con_demo = gr.Interface( + fn=music_generation, + inputs = [ + gr.Dropdown(["Music Continuation"], value="continuation", multiselect=False, info="Choose a task."), + gr.Text(label="Input Text"), + gr.Audio(label="Input Audio Prompt", type="audio prompt filepath"), + ], + outputs = [ + gr.Audio(label="Generated Music", type="generated audio filepath"), + ], + title = "InspireMusic: A Unified Framework for Music, Song, Audio Generation.", + description = ("InspireMusic ([Github Repo](https://github.com/FunAudioLLM/InspireMusic)) is a fundamental AIGC toolkit and models designed for music, song, and audio generation using PyTorch." + "To try it, simply type text to generation music, or click one of the examples. "), + article = ("

InspireMusic

" + "

WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling

"), + examples = [ + ["example/inspiremusic/inspiremusic_01.wav", "24000 Hz"], + ["example/ras/chorus/chorus_01.wav", "48000 Hz"], + ], + cache_examples = True, +) + +with demo: + gr.TabbedInterface([t2m_demo, con_demo,], + ["Task 1: Text-to-Music", + "Task 2: Music Continuation"]) + # gr.TabbedInterface([t2m_demo, con_demo, fast_demo], ["Task 1: Text-to-Music", "Task 2: Music Continuation", "Task 3: Without Flow Matching"]) + +demo.launch() diff --git a/inspiremusic/__init__.py b/inspiremusic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/bin/export_jit.py b/inspiremusic/bin/export_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..f68203f61246f1167f68682b5a1ff8fc5929f521 --- /dev/null +++ b/inspiremusic/bin/export_jit.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import sys +import torch +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append('{}/../..'.format(ROOT_DIR)) +sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) +from inspiremusic.cli.inspiremusic import InspireMusic + + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='pretrained_models/InspireMusic', + help='local path') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + torch._C._jit_set_fusion_strategy([('STATIC', 1)]) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + + inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False) + + # 1. export llm text_encoder + llm_text_encoder = inspiremusic.model.llm.text_encoder.half() + script = torch.jit.script(llm_text_encoder) + script = torch.jit.freeze(script) + script = torch.jit.optimize_for_inference(script) + script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) + + # 2. export llm llm + llm_llm = inspiremusic.model.llm.llm.half() + script = torch.jit.script(llm_llm) + script = torch.jit.freeze(script, preserved_attrs=['forward_chunk']) + script = torch.jit.optimize_for_inference(script) + script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) + + # 3. 
export flow encoder + flow_encoder = inspiremusic.model.flow.encoder + script = torch.jit.script(flow_encoder) + script = torch.jit.freeze(script) + script = torch.jit.optimize_for_inference(script) + script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) + + +if __name__ == '__main__': + main() diff --git a/inspiremusic/bin/export_onnx.py b/inspiremusic/bin/export_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..659ee13e4fc495757d11dfc558bf2b1629f35089 --- /dev/null +++ b/inspiremusic/bin/export_onnx.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import sys +import onnxruntime +import random +import torch +from tqdm import tqdm +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append('{}/../..'.format(ROOT_DIR)) +sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) +from inspiremusic.cli.inspiremusic import InspireMusic + + +def get_dummy_input(batch_size, seq_len, out_channels, device): + x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) + mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + t = torch.rand((batch_size), dtype=torch.float32, device=device) + spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) + cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) + return x, mask, mu, t, spks, cond + + +def get_args(): + parser = argparse.ArgumentParser(description='export your model for deployment') + parser.add_argument('--model_dir', + type=str, + default='pretrained_models/InspireMusic', + help='local path') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False) + + # 1. 
export flow decoder estimator + estimator = inspiremusic.model.flow.decoder.estimator + + device = inspiremusic.model.device + batch_size, seq_len = 1, 256 + out_channels = inspiremusic.model.flow.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {0: 'batch_size', 2: 'seq_len'}, + 'mask': {0: 'batch_size', 2: 'seq_len'}, + 'mu': {0: 'batch_size', 2: 'seq_len'}, + 'cond': {0: 'batch_size', 2: 'seq_len'}, + 't': {0: 'batch_size'}, + 'spks': {0: 'batch_size'}, + 'estimator_out': {0: 'batch_size', 2: 'seq_len'}, + } + ) + + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + sess_options=option, providers=providers) + + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + + +if __name__ == "__main__": + main() diff --git a/inspiremusic/bin/flow_only_infer.py b/inspiremusic/bin/flow_only_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..d84e9c3671546071b02e1e78d071f8e0ded78b94 --- /dev/null +++ b/inspiremusic/bin/flow_only_infer.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import argparse +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import torch +from torch.utils.data import DataLoader +import torchaudio +from hyperpyyaml import load_hyperpyyaml +from tqdm import tqdm +from inspiremusic.cli.model import InspireMusicModel +from inspiremusic.dataset.dataset import Dataset +from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS + +def get_args(): + parser = argparse.ArgumentParser(description='inference only with flow model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--prompt_data', required=True, help='prompt data file') + parser.add_argument('--flow_model', required=True, help='flow model file') + parser.add_argument('--llm_model', default=None,required=False, help='llm model file') + + parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file') + parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file') + parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. random, verse, chorus, intro.') + parser.add_argument('--sample_rate', type=int, default=48000, required=False, + help='sampling rate of generated audio') + parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False, + help='the minimum generated audio length in seconds') + parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False, + help='the maximum generated audio length in seconds') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--result_dir', required=True, help='asr result file') + args = parser.parse_args() + print(args) + return args + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + # Init inspiremusic models from configs + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f) + + model = InspireMusicModel(None, configs['flow'], configs['hift'], configs['wavtokenizer']) + model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer) + + if args.llm_model is None: + model.llm = None + else: + model.llm = model.llm.to(torch.float32) + + if args.flow_model is None: + model.flow = None + + test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + del configs + os.makedirs(args.result_dir, exist_ok=True) + fn = os.path.join(args.result_dir, 'wav.scp') + f = open(fn, 'w') + with torch.no_grad(): + for _, batch in tqdm(enumerate(test_data_loader)): + utts = batch["utts"] + assert len(utts) == 1, "inference mode only support batchsize 1" + + if "semantic_token" in batch: + token = batch["semantic_token"].to(device) + token_len = batch["semantic_token_len"].to(device) + else: + if audio_token is None: + token = None + token_len = None + else: + token = audio_token.view(audio_token.size(0),-1,4)[:,:,0] + token_len = audio_token_len / 4 + + text_token = batch["text_token"].to(device) + text_token_len = batch["text_token_len"].to(device) + text = batch["text"] + + if "time_start" not in 
batch.keys(): + batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64) + if "time_end" not in batch.keys(): + batch["time_end"] = torch.randint(args.min_generate_audio_seconds, args.max_generate_audio_seconds, (1,)).to(torch.float64) + elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds: + batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64) + + if "chorus" not in batch.keys(): + batch["chorus"] = torch.randint(1, 5, (1,)) + + if args.chorus == "random": + batch["chorus"] = torch.randint(1, 5, (1,)) + elif args.chorus == "intro": + batch["chorus"] = torch.Tensor([0]) + elif "verse" in args.chorus: + batch["chorus"] = torch.Tensor([1]) + elif args.chorus == "chorus": + batch["chorus"] = torch.Tensor([2]) + elif args.chorus == "outro": + batch["chorus"] = torch.Tensor([4]) + + time_start = batch["time_start"].to(device) + time_end = batch["time_end"].to(device) + chorus = batch["chorus"].to(torch.int) + + text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>" + chorus = chorus.to(device) + + model_input = {"text": text, "audio_token": token, "audio_token_len": token_len, + "text_token": text_token, "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], "raw_text":text} + + music_audios = [] + for model_output in model.inference(**model_input): + music_audios.append(model_output['music_audio']) + + music_key = utts[0] + music_fn = os.path.join(args.result_dir, '{}.wav'.format(music_key)) + torchaudio.save(music_fn, music_audios[0], sample_rate=args.sample_rate) + f.write('{} {}\n'.format(music_key, music_fn)) + f.flush() + f.close() + logging.info('Result wav.scp saved in {}'.format(fn)) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/inspiremusic/bin/inference.py b/inspiremusic/bin/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c7429fd34487c7b749cdca6afecd4b3a5dcc4bad --- /dev/null +++ b/inspiremusic/bin/inference.py @@ -0,0 +1,266 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import argparse +import logging + +logging.getLogger('matplotlib').setLevel(logging.WARNING) +import os +import torch +from torch.utils.data import DataLoader +import torchaudio +from hyperpyyaml import load_hyperpyyaml +from tqdm import tqdm +from inspiremusic.cli.model import InspireMusicModel +from inspiremusic.dataset.dataset import Dataset +import time +from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio +from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def get_args(): + parser = argparse.ArgumentParser(description='inference only with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--prompt_data', required=True, help='prompt data file') + parser.add_argument('--flow_model', default=None, required=False, help='flow model file') + parser.add_argument('--llm_model', default=None,required=False, help='flow model file') + parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file') + parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file') + parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. random, verse, chorus, intro.') + parser.add_argument('--fast', action='store_true', required=False, help='True: fast inference mode, without flow matching for fast inference. False: normal inference mode, with flow matching for high quality.') + parser.add_argument('--fp16', default=True, type=bool, required=False, help='inference with fp16 model') + parser.add_argument('--fade_out', default=True, type=bool, required=False, help='add fade out effect to generated audio') + parser.add_argument('--fade_out_duration', default=1.0, type=float, required=False, help='fade out duration in seconds') + parser.add_argument('--trim', default=False, type=bool, required=False, help='trim the silence ending of generated audio') + parser.add_argument('--format', type=str, default="wav", required=False, + choices=["wav", "mp3", "m4a", "flac"], + help='sampling rate of input audio') + parser.add_argument('--sample_rate', type=int, default=24000, required=False, + help='sampling rate of input audio') + parser.add_argument('--output_sample_rate', type=int, default=48000, required=False, choices=[24000, 48000], + help='sampling rate of generated output audio') + parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False, + help='the minimum generated audio length in seconds') + parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False, + help='the maximum generated audio length in seconds') + parser.add_argument('--gpu', + type=int, + default=0, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--task', + default='text-to-music', + choices=['text-to-music', 'continuation', "reconstruct", "super_resolution"], + help='choose inference task type. text-to-music: text-to-music task. continuation: music continuation task. reconstruct: reconstruction of original music. 
super_resolution: convert original 24kHz music into 48kHz music.') + parser.add_argument('--result_dir', required=True, help='asr result file') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + if args.fast: + args.output_sample_rate = 24000 + + min_generate_audio_length = int(args.output_sample_rate * args.min_generate_audio_seconds) + max_generate_audio_length = int(args.output_sample_rate * args.max_generate_audio_seconds) + assert args.min_generate_audio_seconds <= args.max_generate_audio_seconds + + # Init inspiremusic models from configs + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f) + + model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.fast, args.fp16) + + model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer) + + if args.llm_model is None: + model.llm = None + else: + model.llm = model.llm.to(torch.float32) + + if args.flow_model is None: + model.flow = None + + test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + del configs + os.makedirs(args.result_dir, exist_ok=True) + fn = os.path.join(args.result_dir, 'wav.scp') + f = open(fn, 'w') + caption_fn = os.path.join(args.result_dir, 'captions.txt') + caption_f = open(caption_fn, 'w') + + with torch.no_grad(): + for _, batch in tqdm(enumerate(test_data_loader)): + utts = batch["utts"] + + assert len(utts) == 1, "inference mode only support batchsize 1" + text_token = batch["text_token"].to(device) + text_token_len = batch["text_token_len"].to(device) + + if "time_start" not in batch.keys(): + batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64) + + if batch["time_start"].numpy()[0] > 300: + batch["time_start"] = torch.Tensor([0]).to(torch.float64) + + if "time_end" not in batch.keys(): + batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64) + else: + if (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds: + batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64) + elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) > args.max_generate_audio_seconds: + batch["time_end"] = torch.Tensor([(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds)]).to(torch.float64) + + if "chorus" not in batch.keys(): + batch["chorus"] = torch.randint(1, 5, (1,)) + + if args.chorus == "random": + batch["chorus"] = torch.randint(1, 5, (1,)) + elif args.chorus == "intro": + batch["chorus"] = torch.Tensor([0]) + elif "verse" in args.chorus: + batch["chorus"] = torch.Tensor([1]) + elif args.chorus == "chorus": + batch["chorus"] = torch.Tensor([2]) + elif args.chorus == "outro": + batch["chorus"] = torch.Tensor([4]) + else: + batch["chorus"] = batch["chorus"] + + time_start = 
batch["time_start"].to(device) + time_end = batch["time_end"].to(device) + chorus = batch["chorus"].to(torch.int) + + text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>" + chorus = chorus.to(device) + + if batch["acoustic_token"] is None: + audio_token = None + audio_token_len = None + else: + audio_token = batch["acoustic_token"].to(device) + audio_token_len = batch["acoustic_token_len"].to(device) + + text = batch["text"] + + if "semantic_token" in batch: + token = batch["semantic_token"].to(device) + token_len = batch["semantic_token_len"].to(device) + else: + if audio_token is None: + token = None + token_len = None + else: + token = audio_token.view(audio_token.size(0), -1, 4)[:, :, 0] + token_len = audio_token_len / 4 + + if args.task in ['text-to-music', 'continuation']: + # text to music, music continuation + model_input = {"text": text, "audio_token": token, + "audio_token_len": token_len, + "text_token": text_token, + "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], + "raw_text": text, + "sample_rate": args.output_sample_rate, + "duration_to_gen": args.max_generate_audio_seconds, + "task": args.task} + elif args.task in ['reconstruct', 'super_resolution']: + # audio reconstruction, audio super resolution + model_input = {"text": text, "audio_token": audio_token, + "audio_token_len": audio_token_len, + "text_token": text_token, + "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], + "raw_text": text, + "sample_rate": args.output_sample_rate, + "duration_to_gen": args.max_generate_audio_seconds, + "task": args.task} + else: + # zero-shot + model_input = {'text' : text, + 'text_len' : text_token_len, + 'prompt_text' : text_token, + 'prompt_text_len' : text_token_len, + 'llm_prompt_audio_token' : token, + 'llm_prompt_audio_token_len' : token_len, + 'flow_prompt_audio_token' : audio_token, + 'flow_prompt_audio_token_len': audio_token_len, + 'prompt_audio_feat' : audio_feat, + 'prompt_audio_feat_len' : audio_feat_len, + "embeddings" : [time_start, + time_end, + chorus]} + + music_key = utts[0] + music_audios = [] + music_fn = os.path.join(args.result_dir, f'{music_key}.{args.format}') + bench_start = time.time() + + for model_output in model.inference(**model_input): + music_audios.append(model_output['music_audio']) + bench_end = time.time() + if args.trim: + music_audio = trim_audio(music_audios[0], + sample_rate=args.output_sample_rate, + threshold=0.05, + min_silence_duration=0.8) + else: + music_audio = music_audios[0] + if music_audio.shape[0] != 0: + if music_audio.shape[1] > max_generate_audio_length: + music_audio = music_audio[:, :max_generate_audio_length] + if music_audio.shape[1] >= min_generate_audio_length: + try: + if args.fade_out: + music_audio = fade_out(music_audio, args.output_sample_rate, args.fade_out_duration) + music_audio = music_audio.repeat(2, 1) + if args.format in ["wav", "flac"]: + torchaudio.save(music_fn, music_audio, sample_rate=args.output_sample_rate, encoding="PCM_S", bits_per_sample=24) + elif args.format in ["mp3", "m4a"]: + torchaudio.backend.sox_io_backend.save(filepath=music_fn, src=music_audio, sample_rate=args.output_sample_rate, format=args.format) + else: + logging.info(f"Format is not supported. 
Please choose from wav, mp3, m4a, flac.") + except Exception as e: + logging.info(f"Error saving file: {e}") + raise + + audio_duration = music_audio.shape[1] / args.output_sample_rate + rtf = (bench_end - bench_start) / audio_duration + logging.info(f"processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}") + f.write('{} {}\n'.format(music_key, music_fn)) + f.flush() + caption_f.write('{}\t{}\n'.format(music_key, text_prompt)) + caption_f.flush() + else: + logging.info(f"Generate audio length {music_audio.shape[1]} is shorter than min_generate_audio_length.") + else: + logging.info(f"Generate audio is empty, dim = {music_audio.shape[0]}.") + f.close() + logging.info('Result wav.scp saved in {}'.format(fn)) + + +if __name__ == '__main__': + main() diff --git a/inspiremusic/bin/train.py b/inspiremusic/bin/train.py new file mode 100644 index 0000000000000000000000000000000000000000..92a9bae52670a4d7d9b48d944f3e96ac086ca762 --- /dev/null +++ b/inspiremusic/bin/train.py @@ -0,0 +1,194 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import argparse +import datetime +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +from copy import deepcopy +import torch +import torch.distributed as dist +import deepspeed +import glob +import os +from hyperpyyaml import load_hyperpyyaml +from torch.cuda.amp import GradScaler, autocast +from torch.distributed.elastic.multiprocessing.errors import record +from peft import get_peft_config, get_peft_model, LoraConfig, TaskType +from inspiremusic.utils.executor import Executor +from inspiremusic.utils.train_utils import ( + init_distributed, + init_dataset_and_dataloader, + init_optimizer_and_scheduler, + init_summarywriter, save_model, + wrap_cuda_model, check_modify_and_save_config) + + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--train_engine', + default='torch_ddp', + choices=['torch_ddp', 'deepspeed'], + help='Engine for paralleled training') + parser.add_argument('--model', required=True, help='model which will be trained') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--num_workers', + default=0, + type=int, + help='number of subprocess workers for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + 
parser.add_argument('--pin_memory', + action='store_true', + default=True, + help='Use pinned memory buffers used for reading') + parser.add_argument('--deepspeed.save_states', + dest='save_states', + default='model_only', + choices=['model_only', 'model+optimizer'], + help='save model/optimizer states') + parser.add_argument('--timeout', + default=30, + type=int, + help='timeout (in seconds) of inspiremusic_join.') + parser.add_argument('--fp16', + action='store_true', + default=False, + help='Enable fp16 mixed precision training') + parser.add_argument('--lora', + action='store_true', + default=False, + help='Enable LoRA training') + parser.add_argument('--lora_rank', + default=4, + type=int, + help='LoRA rank') + parser.add_argument('--lora_alpha', + default=16, + type=int, + help='LoRA alpha') + parser.add_argument('--lora_dropout', + default=0.1, + type=float, + help='LoRA dropout rate') + parser.add_argument('--lora_target_modules', + nargs='+', + default=["k_proj","v_proj"], + help='Target modules to apply LoRA (e.g., ["q_proj", "v_proj"])') + + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + return args + + +@record +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model} + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict) + configs['train_conf'].update(vars(args)) + + # Init env for ddp + init_distributed(args) + + # Get dataset & dataloader + train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ + init_dataset_and_dataloader(args, configs) + + # Do some sanity checks and save config to arsg.model_dir + configs = check_modify_and_save_config(args, configs) + + # Tensorboard summary + writer = init_summarywriter(args) + + # load checkpoint + model = configs[args.model] + + if args.checkpoint is not None: + model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')) + else: + # Find and load the latest checkpoint + checkpoint_files = glob.glob(os.path.join(args.model_dir, '*.pt')) + + if checkpoint_files: + latest_checkpoint = max(checkpoint_files, key=os.path.getctime) + logging.info(f"Loaded latest checkpoint from {latest_checkpoint}") + + model.load_state_dict(torch.load(latest_checkpoint, map_location='cpu')) + + if args.lora: + logging.info("Applying LoRA to the model...") + if not args.lora_target_modules: + raise ValueError("No target modules specified for LoRA. Please provide --lora_target_modules.") + lora_config = LoraConfig( + task_type="CAUSAL_LM", # Change to appropriate task type + inference_mode=False, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.lora_target_modules + ) + model.llm.model = get_peft_model(model.llm.model, lora_config) + # Optionally freeze the base model + else: + logging.info("LoRA is not enabled. 
Training the full model.") + + # Dispatch model from cpu to gpu + model = wrap_cuda_model(args, model) + + # Get optimizer & scheduler + model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model) + + # Initialize AMP for torch_ddp if fp16 is enabled + scaler = None + if args.fp16: + scaler = GradScaler() + logging.info("Initialized AMP GradScaler for mixed precision training.") + + # Save init checkpoints + info_dict = deepcopy(configs['train_conf']) + + # Get executor + executor = Executor() + + # Start training loop + for epoch in range(info_dict['max_epoch']): + executor.epoch = epoch + train_dataset.set_epoch(epoch) + dist.barrier() + group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + executor.train_one_epoch(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=scaler) + dist.destroy_process_group(group_join) + +if __name__ == '__main__': + main() diff --git a/inspiremusic/cli/__init__.py b/inspiremusic/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/cli/frontend.py b/inspiremusic/cli/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..d717c3a5c77cbf3c9cc2513f75b3045ad26d9149 --- /dev/null +++ b/inspiremusic/cli/frontend.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial +import onnxruntime +import torch +import numpy as np +import whisper +from typing import Callable +import torchaudio.compliance.kaldi as kaldi +import torchaudio +import os +import re +import inflect +from inspiremusic.cli.model import InspireMusicModel +from inspiremusic.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph +from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer + +class InspireMusicFrontEnd: + def __init__(self, + configs: Callable, + get_tokenizer: Callable, + llm_model: str, + flow_model: str, + music_tokenizer_dir: str, + audio_tokenizer_dir: str, + instruct: bool = False, + fast: bool = False, + fp16: bool = True, + allowed_special: str = 'all'): + self.tokenizer = get_tokenizer() + self.audio_tokenizer_dir = audio_tokenizer_dir + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.bandwidth_id = torch.tensor([0]).to(self.device) + self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device) + + self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16) + self.model = self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir) + + self.instruct = instruct + self.allowed_special = allowed_special + self.inflect_parser = inflect.engine() + + def _extract_text_token(self, text): + text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special) + text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) + text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) + return text_token, text_token_len + + def _extract_audio_token(self, audio, sample_rate=24000): + audio = torch.tensor(audio, dtype=torch.float32, device=self.device) + _, audio_token = self.wavtokenizer.encode_infer(audio, bandwidth_id=self.bandwidth_id) + audio_token = audio_token.squeeze(0) + audio_token_len = torch.tensor([audio_token.shape[1]], dtype=torch.int32, device=self.device) + return audio_token, audio_token_len + + def text_normalize(self, text, split=True): + text = text.strip() + if contains_chinese(text): + text = text.replace("\n", "") + text = replace_blank(text) + text = replace_corner_mark(text) + text = text.replace(".", "ใ€") + text = text.replace(" - ", "๏ผŒ") + text = remove_bracket(text) + text = re.sub(r'[๏ผŒ,]+$', 'ใ€‚', text) + texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, + token_min_n=60, merge_len=20, comma_split=False)) + else: + text = spell_out_number(text, self.inflect_parser) + texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, + token_min_n=60, merge_len=20, comma_split=False)) + if split is False: + return text + return texts + + def frontend_text_to_music(self, text, time_start, time_end, chorus): + text_token, text_token_len = self._extract_text_token(text) + model_input = {"text": text, "audio_token": None, "audio_token_len": None, + "text_token": text_token, "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], "raw_text":text} + return model_input + + def frontend_continuation(self, text, audio, time_start, time_end, chorus, target_sr=24000): + if text is None: + text_token = None + text_token_len = None + else: 
+ text_token, text_token_len = self._extract_text_token(text) + audio_token, audio_token_len = self._extract_audio_token(audio, target_sr) + model_input = {"text": text, "audio_token": audio_token, "audio_token_len": audio_token_len, + "text_token": text_token, "text_token_len": text_token_len, + "embeddings": [time_start, time_end, chorus], "raw_text":text} + return model_input + diff --git a/inspiremusic/cli/inference.py b/inspiremusic/cli/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0ad0994a5989f4082e94af776fb38590f550a5 --- /dev/null +++ b/inspiremusic/cli/inference.py @@ -0,0 +1,296 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import torchaudio +import time +import logging +import argparse + +from modelscope import snapshot_download +from inspiremusic.cli.inspiremusic import InspireMusic +from inspiremusic.utils.file_utils import logging +import torch +from inspiremusic.utils.audio_utils import trim_audio, fade_out +from transformers import AutoModel + +def set_env_variables(): + os.environ['PYTHONIOENCODING'] = 'UTF-8' + os.environ['TOKENIZERS_PARALLELISM'] = 'False' + current_working_dir = os.getcwd() + main_root = os.path.realpath(os.path.join(current_working_dir, '../../')) + bin_dir = os.path.join(main_root, 'inspiremusic') + third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS') + python_path = f"{main_root}:{bin_dir}:{third_party_matcha_tts_path}:{os.environ.get('PYTHONPATH', '')}" + os.environ['PYTHONPATH'] = python_path + sys.path.extend([main_root, third_party_matcha_tts_path]) + +class InspireMusicUnified: + def __init__(self, + model_name: str = "InspireMusic-1.5B-Long", + model_dir: str = None, + min_generate_audio_seconds: float = 10.0, + max_generate_audio_seconds: float = 30.0, + sample_rate: int = 24000, + output_sample_rate: int = 48000, + load_jit: bool = True, + load_onnx: bool = False, + fast: bool = False, + fp16: bool = True, + gpu: int = 0, + result_dir: str = None): + os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) + + # Set model_dir or default to downloading if it doesn't exist + self.model_dir = model_dir or f"../../pretrained_models/{model_name}" + if not os.path.exists(self.model_dir): + self.model_dir = snapshot_download(f"iic/{model_name}", cache_dir=self.model_dir) + + self.sample_rate = sample_rate + self.output_sample_rate = 24000 if fast else output_sample_rate + self.result_dir = result_dir or f"exp/{model_name}" + os.makedirs(self.result_dir, exist_ok=True) + + self.min_generate_audio_seconds = min_generate_audio_seconds + self.max_generate_audio_seconds = max_generate_audio_seconds + self.min_generate_audio_length = int(self.output_sample_rate * self.min_generate_audio_seconds) + self.max_generate_audio_length = int(self.output_sample_rate * self.max_generate_audio_seconds) + assert self.min_generate_audio_seconds <= self.max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds" + + 
use_cuda = gpu >= 0 and torch.cuda.is_available() + self.device = torch.device('cuda' if use_cuda else 'cpu') + self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, fast=fast, fp16=fp16) + + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + @torch.inference_mode() + def inference(self, + task: str = 'text-to-music', + text: str = None, + audio_prompt: str = None, # audio prompt file path + chorus: str = "verse", + time_start: float = 0.0, + time_end: float = 30.0, + output_fn: str = "output_audio", + max_audio_prompt_length: float = 5.0, + fade_out_duration: float = 1.0, + output_format: str = "wav", + fade_out_mode: bool = True, + trim: bool = False, + ): + + with torch.no_grad(): + text_prompt = f"<|{time_start}|><|{chorus}|><|{text}|><|{time_end}|>" + chorus_dict = {"random": torch.randint(1, 5, (1,)).item(), "intro" : 0, "verse": 1, "chorus": 2, "outro": 4} + chorus = chorus_dict.get(chorus, 1) + chorus = torch.tensor([chorus], dtype=torch.int).to(self.device) + + time_start_tensor = torch.tensor([time_start], dtype=torch.float64).to(self.device) + time_end_tensor = torch.tensor([time_end], dtype=torch.float64).to(self.device) + + music_fn = os.path.join(self.result_dir, f'{output_fn}.{output_format}') + + bench_start = time.time() + + if task == 'text-to-music': + model_input = { + "text" : text, + "audio_prompt" : audio_prompt, + "time_start" : time_start_tensor, + "time_end" : time_end_tensor, + "chorus" : chorus, + "task" : task, + "stream" : False, + "duration_to_gen": self.max_generate_audio_seconds, + "sr" : self.sample_rate + } + elif task == 'continuation': + if audio_prompt is not None: + audio, _ = process_audio(audio_prompt, self.sample_rate) + if audio.size(1) < self.sample_rate: + logging.warning("Warning: Input prompt audio length is shorter than 1s. Please provide an appropriate length audio prompt and try again.") + audio = None + else: + max_audio_prompt_length_samples = int(max_audio_prompt_length * self.sample_rate) + audio = audio[:, :max_audio_prompt_length_samples] # Trimming prompt audio + + model_input = { + "text" : text, + "audio_prompt" : audio, + "time_start" : time_start_tensor, + "time_end" : time_end_tensor, + "chorus" : chorus, + "task" : task, + "stream" : False, + "duration_to_gen": self.max_generate_audio_seconds, + "sr" : self.sample_rate + } + + music_audios = [] + for model_output in self.model.cli_inference(**model_input): + music_audios.append(model_output['music_audio']) + + bench_end = time.time() + + if trim: + music_audio = trim_audio(music_audios[0], + sample_rate=self.output_sample_rate, + threshold=0.05, + min_silence_duration=0.8) + else: + music_audio = music_audios[0] + + if music_audio.shape[0] != 0: + if music_audio.shape[1] > self.max_generate_audio_length: + music_audio = music_audio[:, :self.max_generate_audio_length] + + if music_audio.shape[1] >= self.min_generate_audio_length: + try: + if fade_out_mode: + music_audio = fade_out(music_audio, self.output_sample_rate, fade_out_duration) + + music_audio = music_audio.repeat(2, 1) + + if output_format in ["wav", "flac"]: + torchaudio.save(music_fn, music_audio, + sample_rate=self.output_sample_rate, + encoding="PCM_S", + bits_per_sample=24) + elif output_format in ["mp3", "m4a"]: + torchaudio.backend.sox_io_backend.save( + filepath=music_fn, src=music_audio, + sample_rate=self.output_sample_rate, + format=output_format) + else: + logging.info("Format is not supported. 
Please choose from wav, mp3, m4a, flac.") + + except Exception as e: + logging.error(f"Error saving file: {e}") + raise + + audio_duration = music_audio.shape[1] / self.output_sample_rate + rtf = (bench_end - bench_start) / audio_duration + logging.info(f"Processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}") + + else: + logging.error(f"Generated audio length is shorter than minimum required audio length.") + +def get_args(): + parser = argparse.ArgumentParser(description='Run inference with your model') + parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long", + help='Model name') + + parser.add_argument('-d', '--model_dir', + help='Model folder path') + + parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.", + help='Prompt text') + + parser.add_argument('-a', '--audio_prompt', default=None, + help='Prompt audio') + + parser.add_argument('-c', '--chorus', default="intro", + help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)') + + parser.add_argument('-f', '--fast', type=bool, default=False, + help='Enable fast inference mode (without flow matching)') + + parser.add_argument('-g', '--gpu', type=int, default=0, + help='GPU ID for this rank, -1 for CPU') + + parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'], + help='Inference task type: text-to-music, continuation, reconstruct, super_resolution') + + parser.add_argument('-r', '--result_dir', default="exp/inspiremusic", + help='Directory to save generated audio') + + parser.add_argument('-o', '--output_fn', default="output_audio", + help='Output file name') + + parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"], + help='Format of output audio') + + parser.add_argument('--sample_rate', type=int, default=24000, + help='Sampling rate of input audio') + + parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000], + help='Sampling rate of generated output audio') + + parser.add_argument('-s', '--time_start', type=float, default=0.0, + help='Start time in seconds') + + parser.add_argument('-e', '--time_end', type=float, default=30.0, + help='End time in seconds') + + parser.add_argument('--max_audio_prompt_length', type=float, default=5.0, + help='Maximum audio prompt length in seconds') + + parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, + help='Minimum generated audio length in seconds') + + parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, + help='Maximum generated audio length in seconds') + + parser.add_argument('--fp16', type=bool, default=True, + help='Inference with fp16 model') + + parser.add_argument('--fade_out', type=bool, default=True, + help='Apply fade out effect to generated audio') + + parser.add_argument('--fade_out_duration', type=float, default=1.0, + help='Fade out duration in seconds') + + parser.add_argument('--trim', type=bool, default=False, + help='Trim the silence ending of generated audio') + + args = parser.parse_args() + + if not args.model_dir: + args.model_dir = os.path.join("../../pretrained_models", args.model_name) + + print(args) + return args + +def main(): + set_env_variables() + args = get_args() + model = InspireMusicUnified(model_name = args.model_name, + 
model_dir = args.model_dir, + min_generate_audio_seconds = args.min_generate_audio_seconds, + max_generate_audio_seconds = args.max_generate_audio_seconds, + sample_rate = args.sample_rate, + output_sample_rate = args.output_sample_rate, + load_jit = True, + load_onnx = False, + fast = args.fast, + fp16 = args.fp16, + gpu = args.gpu, + result_dir = args.result_dir) + + model.inference(task = args.task, + text = args.text, + audio_prompt = args.audio_prompt, + chorus = args.chorus, + time_start = args.time_start, + time_end = args.time_end, + output_fn = args.output_fn, + max_audio_prompt_length = args.max_audio_prompt_length, + fade_out_duration = args.fade_out_duration, + output_format = args.format, + fade_out_mode = args.fade_out, + trim = args.trim) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/inspiremusic/cli/inspiremusic.py b/inspiremusic/cli/inspiremusic.py new file mode 100644 index 0000000000000000000000000000000000000000..c70f07fe1a37be37f1d2910436847ee03aad97e9 --- /dev/null +++ b/inspiremusic/cli/inspiremusic.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time +from tqdm import tqdm +from hyperpyyaml import load_hyperpyyaml +from modelscope import snapshot_download +from inspiremusic.cli.frontend import InspireMusicFrontEnd +from inspiremusic.cli.model import InspireMusicModel +from inspiremusic.utils.file_utils import logging +import torch + +class InspireMusic: + def __init__(self, model_dir, load_jit=True, load_onnx=False, fast = False, fp16=True): + instruct = True if '-Instruct' in model_dir else False + self.model_dir = model_dir + if not os.path.exists(model_dir): + model_dir = snapshot_download(model_dir) + with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f: + configs = load_hyperpyyaml(f) + + self.frontend = InspireMusicFrontEnd(configs, + configs['get_tokenizer'], + '{}/llm.pt'.format(model_dir), + '{}/flow.pt'.format(model_dir), + '{}/music_tokenizer/'.format(model_dir), + '{}/wavtokenizer/'.format(model_dir), + instruct, + fast, + fp16, + configs['allowed_special']) + + self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16) + self.model.load('{}/llm.pt'.format(model_dir), + '{}/flow.pt'.format(model_dir), + '{}/music_tokenizer/'.format(model_dir), + '{}/wavtokenizer/model.pt'.format(model_dir)) + del configs + + @torch.inference_mode() + def inference(self, task, text, audio, time_start, time_end, chorus, stream=False, sr=24000): + if task == "text-to-music": + for i in tqdm(self.frontend.text_normalize(text, split=True)): + model_input = self.frontend.frontend_text_to_music(i, time_start, time_end, chorus) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.inference(**model_input, stream=stream): + music_audios_len = model_output['music_audio'].shape[1] / sr + logging.info('yield music len {}, rtf {}'.format(music_audios_len, 
(time.time() - start_time) / music_audios_len)) + yield model_output + start_time = time.time() + + elif task == "continuation": + if text is None: + if audio is not None: + for i in tqdm(audio): + model_input = self.frontend.frontend_continuation(None, i, time_start, time_end, chorus, sr, max_audio_length) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.continuation_inference(**model_input, stream=stream): + music_audios_len = model_output['music_audio'].shape[1] / sr + logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len)) + yield model_output + start_time = time.time() + else: + if audio is not None: + for i in tqdm(self.frontend.text_normalize(text, split=True)): + model_input = self.frontend.frontend_continuation(i, audio, time_start, time_end, chorus, sr, max_audio_length) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.continuation_inference(**model_input, stream=stream): + music_audios_len = model_output['music_audio'].shape[1] / sr + logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len)) + yield model_output + start_time = time.time() + else: + print("Please input text or audio.") + else: + print("Currently only support text-to-music and music continuation tasks.") + + @torch.inference_mode() + def cli_inference(self, text, audio_prompt, time_start, time_end, chorus, task, stream=False, duration_to_gen=30, sr=24000): + if task == "text-to-music": + model_input = self.frontend.frontend_text_to_music(text, time_start, time_end, chorus) + logging.info('prompt text {}'.format(text)) + elif task == "continuation": + model_input = self.frontend.frontend_continuation(text, audio_prompt, time_start, time_end, chorus, sr) + logging.info('prompt audio length: {}'.format(len(audio_prompt))) + + start_time = time.time() + for model_output in self.model.inference(**model_input, duration_to_gen=duration_to_gen, task=task): + music_audios_len = model_output['music_audio'].shape[1] / sr + logging.info('yield music len {}, rtf {}'.format(music_audios_len, (time.time() - start_time) / music_audios_len)) + yield model_output + start_time = time.time() + + @torch.inference_mode() + def inference_zero_shot(self, text, prompt_text, prompt_audio_16k, stream=False, sr=24000): + prompt_text = self.frontend.text_normalize(prompt_text, split=False) + for i in tqdm(self.frontend.text_normalize(text, split=True)): + model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_audio_16k) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in self.model.inference(**model_input, stream=stream): + audio_len = model_output['music_audio'].shape[1] / sr + logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len)) + yield model_output + start_time = time.time() + @torch.inference_mode() + def inference_instruct(self, text, spk_id, instruct_text, stream=False, sr=24000): + if self.frontend.instruct is False: + raise ValueError('{} do not support instruct inference'.format(self.model_dir)) + instruct_text = self.frontend.text_normalize(instruct_text, split=False) + for i in tqdm(self.frontend.text_normalize(text, split=True)): + model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) + start_time = time.time() + logging.info('prompt text {}'.format(i)) + for model_output in 
self.model.inference(**model_input, stream=stream): + audio_len = model_output['music_audio'].shape[1] / sr + logging.info('yield audio len {}, rtf {}'.format(audio_len, (time.time() - start_time) / audio_len)) + yield model_output + start_time = time.time() diff --git a/inspiremusic/cli/model.py b/inspiremusic/cli/model.py new file mode 100644 index 0000000000000000000000000000000000000000..47c23da6284d607f89c337ff712f9d7d5978b998 --- /dev/null +++ b/inspiremusic/cli/model.py @@ -0,0 +1,297 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import numpy as np +import threading +import time +from contextlib import nullcontext +import uuid +from inspiremusic.utils.common import fade_in_out +from inspiremusic.music_tokenizer.vqvae import VQVAE +from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer +from torch.cuda.amp import autocast +import logging +import torch +import os + + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +class InspireMusicModel: + + def __init__(self, + llm: torch.nn.Module, + flow: torch.nn.Module, + music_tokenizer: torch.nn.Module, + wavtokenizer: torch.nn.Module, + fast: bool = False, + fp16: bool = True, + ): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.llm = llm + self.flow = flow + self.music_tokenizer = music_tokenizer + self.wavtokenizer = wavtokenizer + self.fp16 = fp16 + self.token_min_hop_len = 100 + self.token_max_hop_len = 200 + self.token_overlap_len = 20 + # mel fade in out + self.mel_overlap_len = 34 + self.mel_window = np.hamming(2 * self.mel_overlap_len) + # hift cache + self.mel_cache_len = 20 + self.source_cache_len = int(self.mel_cache_len * 256) + # rtf and decoding related + self.stream_scale_factor = 1 + assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' + self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.lock = threading.Lock() + # dict used to store session related variable + self.music_token_dict = {} + self.llm_end_dict = {} + self.mel_overlap_dict = {} + self.fast = fast + self.generator = "hifi" + + def load(self, llm_model, flow_model, hift_model, wavtokenizer_model): + if llm_model is not None: + self.llm.load_state_dict(torch.load(llm_model, map_location=self.device)) + self.llm.to(self.device).eval() + else: + self.llm = None + if flow_model is not None: + self.flow.load_state_dict(torch.load(flow_model, map_location=self.device)) + self.flow.to(self.device).eval() + if hift_model is not None: + if ".pt" not in hift_model: + self.music_tokenizer = VQVAE( hift_model + '/config.json', + hift_model + '/model.pt', with_encoder=True) + else: + self.music_tokenizer = VQVAE(os.path.dirname(hift_model) + '/config.json', + hift_model, with_encoder=True) + self.music_tokenizer.to(self.device).eval() + if wavtokenizer_model is not None: + if ".pt" not in 
wavtokenizer_model: + self.wavtokenizer = WavTokenizer.from_pretrained_feat( wavtokenizer_model + '/config.yaml', + wavtokenizer_model + '/model.pt') + else: + self.wavtokenizer = WavTokenizer.from_pretrained_feat( os.path.dirname(wavtokenizer_model) + '/config.yaml', + wavtokenizer_model ) + self.wavtokenizer.to(self.device) + + def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): + assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model" + llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device) + self.llm.text_encoder = llm_text_encoder + llm_llm = torch.jit.load(llm_llm_model) + self.llm.llm = llm_llm + flow_encoder = torch.jit.load(flow_encoder_model) + self.flow.encoder = flow_encoder + + def load_onnx(self, flow_decoder_estimator_model): + import onnxruntime + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + del self.flow.decoder.estimator + self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers) + + def llm_job(self, text, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, uuid, duration_to_gen, task): + with self.llm_context: + local_res = [] + with autocast(enabled=self.fp16): + inference_kwargs = { + 'text': text.to(self.device), + 'text_len': torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), + 'prompt_text': prompt_text.to(self.device), + 'prompt_text_len': torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), + 'prompt_audio_token': llm_prompt_audio_token.to(self.device), + 'prompt_audio_token_len': torch.tensor([llm_prompt_audio_token.shape[1]], dtype=torch.int32).to(self.device), + 'embeddings': embeddings, + 'duration_to_gen': duration_to_gen, + 'task': task + } + + if audio_token is not None: + inference_kwargs['audio_token'] = audio_token.to(self.device) + else: + inference_kwargs['audio_token'] = torch.Tensor([0]).to(self.device) + + if audio_token_len is not None: + inference_kwargs['audio_token_len'] = audio_token_len.to(self.device) + else: + inference_kwargs['audio_token_len'] = torch.Tensor([0]).to(self.device) + + for i in self.llm.inference(**inference_kwargs): + local_res.append(i) + + self.music_token_dict[uuid] = local_res + self.llm_end_dict[uuid] = True + + # def token2wav(self, token, token_len, text, text_len, uuid, sample_rate, finalize=False): + def token2wav(self, token, token_len, uuid, sample_rate, finalize=False, flow_cfg=None): + # if self.flow is not None: + # if isinstance(self.flow,MaskedDiffWithText): + # codec_embed = self.flow.inference(token=token.to(self.device), + # token_len=token_len.to(self.device), + # text_token=text, + # text_token_len=text_len, + # ) + # else: + if flow_cfg is not None: + codec_embed = self.flow.inference_cfg(token=token.to(self.device), + token_len=token_len.to(self.device), + sample_rate=sample_rate + ) + else: + codec_embed = self.flow.inference(token=token.to(self.device), + token_len=token_len.to(self.device), + sample_rate=sample_rate + ) + # use music_tokenizer decoder + wav = self.music_tokenizer.generator(codec_embed) + wav = wav.squeeze(0).cpu().detach() + return wav + + def acoustictoken2wav(self, token): + # use music_tokenizer to generate waveform from token + token = 
token.view(token.size(0), -1, 4) + # codec = token.view(1, -1, 4) + codec_embed = self.music_tokenizer.quantizer.embed(torch.tensor(token).long().to(self.device)).cuda() + wav = self.music_tokenizer.generator(codec_embed) + wav = wav.squeeze(0).cpu().detach() + return wav + + def semantictoken2wav(self, token): + # fast mode, use wavtokenizer decoder + new_tensor = torch.tensor(token.to(self.device)).unsqueeze(0) + features = self.wavtokenizer.codes_to_features(new_tensor) + bandwidth_id = torch.tensor([0]).to(self.device) + wav = self.wavtokenizer.to(self.device).decode(features, bandwidth_id=bandwidth_id) + wav = wav.cpu().detach() + return wav + + @torch.inference_mode() + def inference(self, text, audio_token, audio_token_len, text_token, text_token_len, embeddings=None, + prompt_text=torch.zeros(1, 0, dtype=torch.int32), + llm_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32), + flow_prompt_audio_token=torch.zeros(1, 0, dtype=torch.int32), + prompt_audio_feat=torch.zeros(1, 0, 80), sample_rate=48000, duration_to_gen = 30, task="continuation", trim = True, stream=False, **kwargs): + + # this_uuid is used to track variables related to this inference thread + # support tasks: + # text to music task + # music continuation task + # require either audio input only or text and audio inputs + + this_uuid = str(uuid.uuid1()) + + if self.llm: + with self.lock: + self.music_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False + + p = threading.Thread(target=self.llm_job, args=(text_token, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, this_uuid, duration_to_gen, task)) + p.start() + + if stream is True: + token_hop_len = self.token_min_hop_len + while True: + time.sleep(0.1) + if len(self.music_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: + this_music_audio = self.token2wav(token=text_token, + token_len=text_token_len, + uuid=this_uuid, + sample_rate=sample_rate, + finalize=False) + yield {'music_audio': this_music_audio.cpu()} + with self.lock: + self.music_token_dict[this_uuid] = self.music_token_dict[this_uuid][token_hop_len:] + # increase token_hop_len for better audio quality + token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) + if self.llm_end_dict[this_uuid] is True and len(self.music_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: + break + p.join() + # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None + this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1) + with self.flow_hift_context: + this_music_audio = self.token2wav(token=this_music_token, + prompt_token=flow_prompt_audio_token, + prompt_feat=prompt_audio_feat, + embedding=flow_embedding, + uuid=this_uuid, + sample_rate=sample_rate, + finalize=True) + yield {'music_audio': this_music_audio.cpu()} + else: + # deal with all tokens + if self.fast: + if task == "reconstruct": + assert audio_token is None + this_music_token = audio_token + this_music_audio = self.acoustictoken2wav(token=this_music_token) + else: + if self.llm: + p.join() + print(len(self.music_token_dict[this_uuid])) + this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1) + print(this_music_token.shape) + else: + this_music_token = text_token + + logging.info("using wavtokenizer generator without flow matching") + this_music_audio = self.semantictoken2wav(token=this_music_token) + print(this_music_audio.shape) + + else: + if self.llm: + 
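+                    # Wait for the background llm_job thread started above to finish producing
+                    # semantic tokens, then concatenate the per-step chunks collected under this UUID.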
p.join() + if len(self.music_token_dict[this_uuid]) != 0: + this_music_token = torch.concat(self.music_token_dict[this_uuid], dim=1) + else: + print(f"The list of tensors is empty for UUID: {this_uuid}") + else: + this_music_token = text_token + logging.info(f"LLM generated audio token length: {this_music_token.shape[1]}") + logging.info(f"using flow matching and {self.generator} generator") + + if self.generator == "hifi": + if (embeddings[1] - embeddings[0]) <= duration_to_gen: + if trim: + trim_length = (int((embeddings[1] - embeddings[0])*75)) + this_music_token = this_music_token[:, :trim_length] + logging.info(f"After trimmed, generated audio token length: {this_music_token.shape[1]}") + elif (embeddings[1] - embeddings[0]) < 1: + logging.info(f"Given audio length={(embeddings[1] - embeddings[0])}, which is too short, please give a longer audio length.") + + this_music_audio = self.token2wav(token=this_music_token, + token_len=torch.LongTensor([this_music_token.size(1)]), + uuid=this_uuid, + sample_rate=sample_rate, + finalize=True) + logging.info(f"Generated audio sequence length: {this_music_audio.shape[1]}") + elif self.generator == "wavtokenizer": + if (embeddings[1] - embeddings[0]) < duration_to_gen: + if trim: + trim_length = (int((embeddings[1] - embeddings[0])*75)) + this_music_token = this_music_token[:,:trim_length] + logging.info(f"After trimmed, generated audio token length: {this_music_token.shape[1]}") + elif (embeddings[1] - embeddings[0]) < 1: + logging.info(f"Given audio length={(embeddings[1] - embeddings[0])}, which is too short, please give a longer audio length.") + + this_music_audio = self.semantictoken2wav(token=this_music_token) + + yield {'music_audio': this_music_audio.cpu()} + torch.cuda.synchronize() \ No newline at end of file diff --git a/inspiremusic/dataset/__init__.py b/inspiremusic/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/dataset/dataset.py b/inspiremusic/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..188872bf21cfbf55bd3df12ff3071e62c8b49f06 --- /dev/null +++ b/inspiremusic/dataset/dataset.py @@ -0,0 +1,154 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import json +import math +from functools import partial + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset +from inspiremusic.utils.file_utils import read_lists, read_json_lists + +class Processor(IterableDataset): + + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. 
+ """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + # force datalist even + + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + if len(data) < self.world_size: + print(len(data), self.world_size) + data = data * math.ceil(self.world_size / len(data)) + data = data[:self.world_size] + data = data[self.rank::self.world_size] + if len(data) < self.num_workers: + data = data * math.ceil(self.num_workers / len(data)) + data = data[:self.num_workers] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_list_file, + data_pipeline, + mode='train', + shuffle=True, + partition=True + ): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. + + Args: + data_type(str): raw/shard + tokenizer (BaseTokenizer): tokenizer to tokenize + partition(bool): whether to do data partition in terms of rank + """ + assert mode in ['train', 'inference', 'processing'] + lists = read_lists(data_list_file) + + dataset = DataList(lists, + shuffle=shuffle, + partition=partition) + + for func in data_pipeline: + dataset = Processor(dataset, func, mode=mode) + + return dataset diff --git a/inspiremusic/dataset/processor.py b/inspiremusic/dataset/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..21572b593e8e36b96202a1264b62cd73c7b6ecf0 --- /dev/null +++ b/inspiremusic/dataset/processor.py @@ -0,0 +1,595 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random + +import pyarrow.parquet as pq +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +import numpy as np +import re + +torchaudio.set_audio_backend('soundfile') + +AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} +CHORUS = {"intro": 0, "chorus": 1, "verse1": 2, "verse2": 3, "verse": 2, + "outro": 4} + +metadata_pattern = re.compile(r'^\[(ti|ar|al|by|offset):.*\]$') +timestamp_pattern = re.compile(r'^\[\d{2}:\d{2}\.\d{2}\](.*)$') + + +def parquet_opener(data, mode='train', audio_data={}): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + + url = sample['src'] + try: + df = pq.read_table(url).to_pandas() + for i in df.index: + sample.update(dict(df.loc[i])) + yield {**sample} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(url, ex)) + + +def clean_lyrics(data, mode="train"): + for sample in data: + lyrics = sample["text"] + cleaned = [] + for line in lyrics.splitlines(): + if metadata_pattern.match(line): + continue + timestamp_match = timestamp_pattern.match(line) + if timestamp_match: + lyric = timestamp_match.group(1).strip() + if lyric: + cleaned.append(lyric) + else: + if line.strip(): + cleaned.append(line.strip()) + sample["text"] = '\n'.join(cleaned) + yield sample + + +def cut_by_length(data, max_length=8000, num_times=4, mode="train"): + for sample in data: + if "semantic_token" in sample: + sample["semantic_token"] = [ + sample["semantic_token"][0][:max_length]] + if "acoustic_token" not in sample: + sample["acoustic_token"] = sample["speech_token"] + sample["acoustic_token"] = sample["acoustic_token"][ + :max_length * num_times] + + yield sample + + +def filter(data, + max_length=22500, # 22500 #5min #10240 + max_acoustic_length=45000, + min_length=10, + min_acoustic_length=150, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1, + mode='train'): + """ Filter sample according to feature and label length + Inplace operation. 
+ + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + if mode == "train": + for sample in data: + if "semantic_token" in sample: + new_sample_frames = sample['semantic_token'][0].shape[0] + else: + new_sample_frames = sample['speech_token'] + + if "text_token" in sample: + new_sample_frames += len(sample['text_token']) + + if new_sample_frames > max_length or new_sample_frames < min_length: + print(f"skipped 1 item length={new_sample_frames}") + continue + + sample["chorus"] = sample["chorus"].split(",") + if not isinstance(sample["time_start"], np.ndarray): + sample["time_start"] = [sample["time_start"]] + sample["time_end"] = [sample["time_end"]] + for i, t in enumerate(sample["chorus"]): + if sample["chorus"][i] == "verse": + sample["chorus"][i] = "verse1" + + yield sample + + if mode == "train_flow": + for sample in data: + if "semantic_token" in sample: + new_sample_frames = sample['semantic_token'][0].shape[0] + if "acoustic_token" in sample: + target_sample_frames = sample['acoustic_token'][0].shape[0] + + if new_sample_frames > max_length or new_sample_frames < min_acoustic_length or new_sample_frames < min_length or target_sample_frames > max_acoustic_length: + print( + f"skipped 1 item length={new_sample_frames}, target_length={target_sample_frames}") + continue + + yield sample + + elif mode == "inference": + for sample in data: + yield sample + + +def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + sample_rate = sample['sample_rate'] + waveform = sample['speech'] + if sample_rate != resample_rate: + if sample_rate < min_sample_rate: + continue + sample['sample_rate'] = resample_rate + sample['speech'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + max_val = sample['speech'].abs().max() + if max_val > 1: + sample['speech'] /= max_val + yield sample + + +def truncate(data, truncate_length=24576, mode='train'): + """ Truncate data. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + truncate_length: truncate length + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + waveform = sample['audio'] + if waveform.shape[1] > truncate_length: + start = random.randint(0, waveform.shape[1] - truncate_length) + waveform = waveform[:, start: start + truncate_length] + else: + waveform = torch.concat([waveform, torch.zeros(1, truncate_length - + waveform.shape[1])], + dim=1) + sample['audio'] = waveform + yield sample + + +def upsample(data, resample_rate=48000, min_sample_rate=16000, mode='train', + n_codebook=4): + """ Resample data. + Inplace operation. 
+ + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'semantic_token' in sample + # TODO: unify data processing key names + if 'acoustic_token' not in sample: + continue + + if 'sample_rate' in sample.keys(): + sample_rate = sample['sample_rate'] + else: + sample_rate = 24000 + token = np.array(sample['semantic_token'][0][:-1]) + + # Calculate the repetition factor for resampling + repetition_factor = int(n_codebook * resample_rate / sample_rate) + if sample_rate != resample_rate: + if sample_rate < min_sample_rate: + continue + sample['sample_rate'] = resample_rate + sample['semantic_token'] = np.array( + [np.repeat(token, repetition_factor)]) + + yield sample + +def compute_fbank(data, + feat_extractor, + mode='train'): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) + sample['speech_feat'] = mat + del sample['speech'] + yield sample + + +def parse_embedding(data, normalize, mode='train'): + """ Parse utt_embedding/spk_embedding + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + + for sample in data: + sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], + dtype=torch.float32) + sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], + dtype=torch.float32) + if normalize: + sample['utt_embedding'] = F.normalize(sample['utt_embedding'], + dim=0) + sample['spk_embedding'] = F.normalize(sample['spk_embedding'], + dim=0) + yield sample + +def tokenize(data, get_tokenizer, allowed_special, mode='train'): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + tokenizer = get_tokenizer() + + for sample in data: + assert 'text' in sample + sample['text_token'] = tokenizer.encode(sample['text'], + allowed_special=allowed_special) + yield sample + + +def shuffle(data, shuffle_size=10000, mode='train'): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500, mode='train'): + """ Sort the data by feature length. 
+ Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + if sample["chorus"] == "verse": + sample["chorus"] = "verse1" + + if sample["acoustic_token"].shape[0] == 1: + sample["acoustic_token"] = np.concatenate( + sample["acoustic_token"][0]) + else: + sample["acoustic_token"] = np.concatenate(sample["acoustic_token"]) + + sample["acoustic_token"] = torch.from_numpy(sample["acoustic_token"]) + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['acoustic_token'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['acoustic_token'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=32): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + data_empty = True + for sample in data: + data_empty = False + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if data_empty: + raise ValueError("data is empty") + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000, mode='train'): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'acoustic_token' in sample + assert isinstance(sample['acoustic_token'], torch.Tensor) + + if 'semantic_token' in sample: + new_sample_frames = sample['semantic_token'][0].shape[0] + else: + new_sample_frames = sample['semantic_token'] + + if "text_token" in sample: + new_sample_frames += len(sample['text_token']) + + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + + if frames_after_padding > max_frames_in_batch: + if len(buf) > 0: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, + mode='train'): + """ Wrapper for static/dynamic batch + """ + if mode == 'inference': + return static_batch(data, 1) + elif mode == 'processing': + return static_batch(data, batch_size) + else: + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data, mode='train'): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + if mode == "train": + for sample in data: + assert isinstance(sample, list) + if len(sample) != 0: + acoustic_feat_len = torch.tensor( + [x['acoustic_token'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(acoustic_feat_len, descending=True) + utts = [sample[i]['utt'] for i in order] + acoustic_token = [ + sample[i]['acoustic_token'].clone().to(torch.int32) for i in + order] + acoustic_token_len = torch.tensor( + [i.size(0) for i in 
acoustic_token], dtype=torch.int32) + + acoustic_token = pad_sequence(acoustic_token, + batch_first=True, + padding_value=0) + + text = [sample[i]['text'] for i in order] + text_token = [torch.tensor(sample[i]['text_token']).long() for i + in order] + text_token_len = torch.tensor([i.size(0) for i in text_token], + dtype=torch.int32) + text_token = pad_sequence(text_token, batch_first=True, + padding_value=0) + time_start = torch.tensor( + [sample[i]['time_start'] for i in order]) + time_end = torch.tensor([sample[i]['time_end'] for i in order]) + + if isinstance(sample[0]['chorus'], str): + chorus = torch.tensor( + [CHORUS[sample[i]['chorus']] for i in order]) + else: + chorus = [ + torch.tensor([CHORUS[t] for t in sample[i]['chorus']]) + for i in order] + chorus = pad_sequence(chorus, batch_first=True, + padding_value=-1) + + batch = { + "utts" : utts, + "acoustic_token" : acoustic_token, + "acoustic_token_len": acoustic_token_len, + "time_start" : time_start, + "time_end" : time_end, + "chorus" : chorus, + "text" : text, + "text_token" : text_token, + "text_token_len" : text_token_len, + } + + if "semantic_token" in sample[0]: + semantic_token = [ + torch.tensor(sample[i]['semantic_token'][0], + dtype=torch.int32) for i in order] + semantic_token_len = torch.tensor( + [i.size(0) for i in semantic_token], + dtype=torch.int32) + semantic_token = pad_sequence(semantic_token, + batch_first=True, + padding_value=0) + batch.update({"semantic_token" : semantic_token, + "semantic_token_len": semantic_token_len}) + + yield batch + else: + logging.info("WARNING: sample is empty []!") + + elif mode == "inference": + for sample in data: + assert isinstance(sample, list) + utts = [sample[i]['utt'] for i in range(len(sample))] + text = [sample[i]['text'] for i in range(len(sample))] + text_token = [torch.tensor(sample[i]['text_token']).long() for i in + range(len(sample))] + text_token_len = torch.tensor([i.size(0) for i in text_token], + dtype=torch.int32) + text_token = pad_sequence(text_token, batch_first=True, + padding_value=0) + time_start = torch.tensor( + [sample[i]['time_start'] for i in range(len(sample))]) + time_end = torch.tensor( + [sample[i]['time_end'] for i in range(len(sample))]) + + if isinstance(sample[0]['chorus'], str): + chorus = torch.tensor([CHORUS[sample[i]['chorus']] for i in + range(len(sample))]) + else: + chorus = [torch.tensor([CHORUS[t] for t in sample[i]['chorus']]) + for i in range(len(sample))] + chorus = pad_sequence(chorus, batch_first=True, + padding_value=-1) + + if "acoustic_token" in sample[0]: + acoustic_token = [ + sample[i]['acoustic_token'].clone().to(torch.int32) for i in + range(len(sample))] + acoustic_token_len = torch.tensor( + [i.size(0) for i in acoustic_token], dtype=torch.int32) + acoustic_token = pad_sequence(acoustic_token, + batch_first=True, + padding_value=0) + else: + acoustic_token = None + acoustic_token_len = None + + batch = { + "utts" : utts, + "acoustic_token" : acoustic_token, + "acoustic_token_len": acoustic_token_len, + "time_start" : time_start, + "time_end" : time_end, + "chorus" : chorus, + "text" : text, + "text_token" : text_token, + "text_token_len" : text_token_len, + } + + if "semantic_token" in sample[0]: + semantic_token = [torch.tensor(sample[i]['semantic_token'][0], + dtype=torch.int32) for i in + range(len(sample))] + semantic_token_len = torch.tensor( + [i.size(0) for i in semantic_token], dtype=torch.int32) + semantic_token = pad_sequence(semantic_token, + batch_first=True, + padding_value=0) + 
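+                # Padded semantic tokens go into the batch together with their true
+                # lengths so downstream models can mask out the padding positions.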
batch.update({"semantic_token" : semantic_token, + "semantic_token_len": semantic_token_len}) + + yield batch diff --git a/inspiremusic/flow/decoder.py b/inspiremusic/flow/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1dff57ba44a1b4952d6836729233368c131fe8 --- /dev/null +++ b/inspiremusic/flow/decoder.py @@ -0,0 +1,277 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from einops import pack, rearrange, repeat +from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D +from matcha.models.components.transformer import BasicTransformerBlock + +class Transpose(torch.nn.Module): + def __init__(self, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor): + x = torch.transpose(x, self.dim0, self.dim1) + return x + +class CausalBlock1D(Block1D): + def __init__(self, dim: int, dim_out: int): + super(CausalBlock1D, self).__init__(dim, dim_out) + self.block = torch.nn.Sequential( + CausalConv1d(dim, dim_out, 3), + Transpose(1, 2), + nn.LayerNorm(dim_out), + Transpose(1, 2), + nn.Mish(), + ) + + def forward(self, x: torch.Tensor, mask: torch.Tensor): + output = self.block(x * mask) + return output * mask + + +class CausalResnetBlock1D(ResnetBlock1D): + def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): + super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) + self.block1 = CausalBlock1D(dim, dim_out) + self.block2 = CausalBlock1D(dim_out, dim_out) + +class CausalConv1d(torch.nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + super(CausalConv1d, self).__init__(in_channels, out_channels, + kernel_size, stride, + padding=0, dilation=dilation, + groups=groups, bias=bias, + padding_mode=padding_mode, + device=device, dtype=dtype) + assert stride == 1 + self.causal_padding = (kernel_size - 1, 0) + + def forward(self, x: torch.Tensor): + x = F.pad(x, self.causal_padding) + x = super(CausalConv1d, self).forward(x) + return x + +class ConditionalDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + ): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
+ """ + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = ResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t).to(t.dtype) + t = self.time_mlp(t) + x = pack([x, mu], "b * t")[0] + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask diff --git a/inspiremusic/flow/flow.py b/inspiremusic/flow/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..c67534ed8a1465c1e34a911263f57172801507d8 --- /dev/null +++ b/inspiremusic/flow/flow.py @@ -0,0 +1,143 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
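+# MaskedDiff (below) is the flow-matching stage: input tokens are embedded,
+# encoded and length-regulated to the codec frame rate, and a conditional
+# decoder is trained/run against VQ-VAE codec embeddings that the vocoder
+# later turns into waveforms.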
+import logging +import random +from typing import Dict, Optional +import torch +import torch.nn as nn +from torch.nn import functional as F +from omegaconf import DictConfig +from inspiremusic.utils.mask import make_pad_mask +from inspiremusic.music_tokenizer.vqvae import VQVAE + +class MaskedDiff(torch.nn.Module): + def __init__(self, + input_size: int = 512, + output_size: int = 128, + output_type: str = "mel", + vocab_size: int = 4096, + input_frame_rate: int = 50, + only_mask_loss: bool = True, + encoder: torch.nn.Module = None, + length_regulator: torch.nn.Module = None, + decoder: torch.nn.Module = None, + decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, + 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', + 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), + 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, + 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, + mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000, + 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000}, + generator_model_dir: str = "../../pretrained_models/InspireMusic-Base/music_tokenizer", + num_codebooks: int = 4 + ): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.mel_feat_conf = mel_feat_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = nn.Embedding(vocab_size, input_size) + + self.encoder = encoder + self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) + self.decoder = decoder + self.length_regulator = length_regulator + self.only_mask_loss = only_mask_loss + self.quantizer = VQVAE( f'{generator_model_dir}/config.json', + f'{generator_model_dir}/model.pt',with_encoder=True).quantizer + self.quantizer.eval() + self.num_codebooks = num_codebooks + self.cond = None + self.interpolate = False + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + + audio_token = batch['acoustic_token'].to(device) + audio_token_len = batch['acoustic_token_len'].to(device) + audio_token = audio_token.view(audio_token.size(0),-1,self.num_codebooks) + if "semantic_token" not in batch: + token = audio_token[:,:,0] + token_len = (audio_token_len/self.num_codebooks).long() + + else: + token = batch['semantic_token'].to(device) + token_len = batch['semantic_token_len'].to(device) + + with torch.no_grad(): + feat = self.quantizer.embed(audio_token) + feat_len = (audio_token_len/self.num_codebooks).long() + + token = self.input_embedding(token) + h, h_lengths = self.encoder(token, token_len) + h, h_lengths = self.length_regulator(h, feat_len) + + # get conditions + if self.cond: + conds = torch.zeros(feat.shape, device=token.device) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.3 * j)) + conds[i, :index] = feat[i, :index] + conds = conds.transpose(1, 2) + else: + conds = None + + mask = (~make_pad_mask(feat_len)).to(h) + + loss, _ = self.decoder.compute_loss( + feat, + mask.unsqueeze(1), + h.transpose(1, 2).contiguous(), + None, + cond=conds + ) + + return {'loss': loss} + + @torch.inference_mode() + def inference(self, + token, + token_len, + sample_rate): + assert token.shape[0] == 1 + + token = 
self.input_embedding(torch.clamp(token, min=0)) + h, h_lengths = self.encoder(token, token_len) + + if sample_rate == 48000: + token_len = 2 * token_len + + h, h_lengths = self.length_regulator(h, token_len) + + # get conditions + conds = None + + mask = (~make_pad_mask(token_len)).to(h) + feat = self.decoder( + mu=h.transpose(1, 2).contiguous(), + mask=mask.unsqueeze(1), + spks=None, + cond=conds, + n_timesteps=10 + ) + return feat \ No newline at end of file diff --git a/inspiremusic/flow/flow_matching.py b/inspiremusic/flow/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..e803c6c59f4c2daa1fd1d0a7ee374706507e78c3 --- /dev/null +++ b/inspiremusic/flow/flow_matching.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.functional as F +from matcha.models.components.flow_matching import BASECFM + + +class ConditionalCFM(BASECFM): + def __init__(self, in_channels, cfm_params, estimator: torch.nn.Module = None): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + ) + self.t_scheduler = cfm_params.t_scheduler + self.training_cfg_rate = cfm_params.training_cfg_rate + self.inference_cfg_rate = cfm_params.inference_cfg_rate + # Just change the architecture of the estimator here + self.estimator = estimator + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t = t.unsqueeze(dim=0) + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + for step in range(1, len(t_span)): + dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond) + # Classifier-Free Guidance inference introduced in VoiceBox + if self.inference_cfg_rate > 0: + cfg_dphi_dt = self.forward_estimator( + x, mask, + torch.zeros_like(mu), t, + torch.zeros_like(spks) if spks is not None else None, + torch.zeros_like(cond) if cond is not None else None + ) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - + self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def forward_estimator(self, x, mask, mu, t, spks, cond): + if isinstance(self.estimator, torch.nn.Module): + return self.estimator.forward(x, mask, mu, t, spks, cond) + elif isinstance(self.estimator, onnxruntime.InferenceSession): + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output = self.estimator.run(None, ort_inputs)[0] + return torch.tensor(output, dtype=x.dtype, device=x.device) + else: + self.estimator.set_input_shape('x', (2, 80, x.size(2))) + self.estimator.set_input_shape('mask', (2, 1, x.size(2))) + self.estimator.set_input_shape('mu', (2, 80, x.size(2))) + self.estimator.set_input_shape('t', (2,)) + self.estimator.set_input_shape('spks', (2, 80)) + self.estimator.set_input_shape('cond', (2, 80, x.size(2))) + # run trt engine + self.estimator.execute_v2([x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()]) + return x + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mo) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + if self.t_scheduler == 'cosine': + t = 1 - torch.cos(t * 0.5 * torch.pi) + + z = torch.randn_like(x1) + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + # during training, we randomly drop condition to trade off mode coverage and sample fidelity + if self.training_cfg_rate > 0: + cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate + mu = mu * cfg_mask.view(-1, 1, 1) + if cond is not None: + cond = cond * cfg_mask.view(-1, 1, 1) + + pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) + loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) + return loss, y + diff --git a/inspiremusic/flow/length_regulator.py b/inspiremusic/flow/length_regulator.py new file mode 100644 index 0000000000000000000000000000000000000000..05b74a9403a526c65dd05f0e558c62084b1772fe --- /dev/null +++ b/inspiremusic/flow/length_regulator.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple +import torch.nn as nn +import torch +from torch.nn import functional as F +from inspiremusic.utils.mask import make_pad_mask + + +class InterpolateRegulator(nn.Module): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + out_channels: int = None, + groups: int = 1, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = nn.ModuleList([]) + if len(sampling_ratios) > 0: + for _ in sampling_ratios: + module = nn.Conv1d(channels, channels, 3, 1, 1) + norm = nn.GroupNorm(groups, channels) + act = nn.Mish() + model.extend([module, norm, act]) + model.append( + nn.Conv1d(channels, out_channels, 1, 1) + ) + self.model = nn.Sequential(*model) + + def forward(self, x, ylens=None): + # x in (B, T, D) + mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) + x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') + out = self.model(x).transpose(1, 2).contiguous() + olens = ylens + return out * mask, olens + + def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): + # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel + # x in (B, T, D) + if x2.shape[1] > 40: + x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') + x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, + mode='linear') + x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') + x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) + else: + x2 = 
F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') + if x1.shape[1] != 0: + x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') + x = torch.concat([x1, x2], dim=2) + else: + x = x2 + out = self.model(x).transpose(1, 2).contiguous() + return out, mel_len1 + mel_len2 diff --git a/inspiremusic/hifigan/discriminator.py b/inspiremusic/hifigan/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc784599f3493e20830290a9cd182789c0428d5 --- /dev/null +++ b/inspiremusic/hifigan/discriminator.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm +from typing import List, Optional, Tuple +from einops import rearrange +from torchaudio.transforms import Spectrogram + + +class MultipleDiscriminator(nn.Module): + def __init__( + self, mpd: nn.Module, mrd: nn.Module + ): + super().__init__() + self.mpd = mpd + self.mrd = mrd + + def forward(self, y: torch.Tensor, y_hat: torch.Tensor): + y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] + this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1)) + y_d_rs += this_y_d_rs + y_d_gs += this_y_d_gs + fmap_rs += this_fmap_rs + fmap_gs += this_fmap_gs + this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat) + y_d_rs += this_y_d_rs + y_d_gs += this_y_d_gs + fmap_rs += this_fmap_rs + fmap_gs += this_fmap_gs + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + fft_sizes: Tuple[int, ...] = (2048, 1024, 512), + num_embeddings: Optional[int] = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + window_length: int, + num_embeddings: Optional[int] = None, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] 
= ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) + + def spectrogram(self, x): + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + x = rearrange(x, "b f t c -> b c t f") + # Split into bands + x_bands = [x[..., b[0]: b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): + x_bands = self.spectrogram(x) + fmap = [] + x = [] + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + x = torch.cat(x, dim=-1) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + + return x, fmap diff --git a/inspiremusic/hifigan/f0_predictor.py b/inspiremusic/hifigan/f0_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..331394ddc5c9240a2f66186ab0ce263d80ceeac0 --- /dev/null +++ b/inspiremusic/hifigan/f0_predictor.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
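A small sketch (illustrative only, not part of the diff) of how DiscriminatorR above maps its relative band edges onto STFT bins for an FFT size of 2048; each resulting bin range is fed to its own convolutional stack.

window_length = 2048
n_bins = window_length // 2 + 1        # 1025 frequency bins
bands = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))
bin_ranges = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]
print(bin_ranges)  # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]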
+import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + + +class ConvRNNF0Predictor(nn.Module): + def __init__(self, + num_class: int = 1, + in_channels: int = 80, + cond_channels: int = 512 + ): + super().__init__() + + self.num_class = num_class + self.condnet = nn.Sequential( + weight_norm( + nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + ) + self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.condnet(x) + x = x.transpose(1, 2) + return torch.abs(self.classifier(x).squeeze(-1)) diff --git a/inspiremusic/hifigan/generator.py b/inspiremusic/hifigan/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..033758c05d55b9a1e23f79c7b551c6000762ee26 --- /dev/null +++ b/inspiremusic/hifigan/generator.py @@ -0,0 +1,411 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HIFI-GAN""" + +from typing import Dict, Optional, List +import numpy as np +from scipy.signal import get_window +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d +from torch.nn import ConvTranspose1d +from torch.nn.utils import remove_weight_norm +from torch.nn.utils import weight_norm +from torch.distributions.uniform import Uniform + +from inspiremusic.transformer.activation import Snake +from inspiremusic.utils.common import get_padding +from inspiremusic.utils.common import init_weights + + +"""hifigan based generator implementation. 
+ +This code is modified from https://github.com/jik876/hifi-gan + ,https://github.com/kan-bayashi/ParallelWaveGAN and + https://github.com/NVIDIA/BigVGAN + +""" + + +class ResBlock(torch.nn.Module): + """Residual block module in HiFiGAN/BigVGAN.""" + def __init__( + self, + channels: int = 512, + kernel_size: int = 3, + dilations: List[int] = [1, 3, 5], + ): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList() + self.convs2 = nn.ModuleList() + + for dilation in dilations: + self.convs1.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + padding=get_padding(kernel_size, dilation) + ) + ) + ) + self.convs2.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1) + ) + ) + ) + self.convs1.apply(init_weights) + self.convs2.apply(init_weights) + self.activations1 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs1)) + ]) + self.activations2 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs2)) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for idx in range(len(self.convs1)): + xt = self.activations1[idx](x) + xt = self.convs1[idx](xt) + xt = self.activations2[idx](xt) + xt = self.convs2[idx](xt) + x = xt + x + return x + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): + remove_weight_norm(self.convs1[idx]) + remove_weight_norm(self.convs2[idx]) + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + @torch.no_grad() + def forward(self, f0): + """ + :param f0: [B, 1, sample_len], Hz + :return: [B, 1, sample_len] + """ + + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) + for i in range(self.harmonic_num + 1): + F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate + + theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) + u_dist = Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) + phase_vec[:, 0, :] = 0 + + # generate sine waveforms + sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2)) + sine_wavs = sine_wavs.transpose(1, 2) + uv = uv.transpose(1, 2) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class HiFTGenerator(nn.Module): + """ + HiFTNet Generator: Neural Source Filter + ISTFTNet + https://arxiv.org/abs/2309.09493 + """ + def __init__( + self, + in_channels: int = 80, + base_channels: int = 512, + nb_harmonics: int = 8, + sampling_rate: int = 22050, + nsf_alpha: float = 0.1, + nsf_sigma: float = 0.003, + nsf_voiced_threshold: float = 10, + upsample_rates: List[int] = [8, 8], + upsample_kernel_sizes: List[int] = [16, 16], + istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + source_resblock_kernel_sizes: List[int] = [7, 11], + source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]], + lrelu_slope: float = 0.1, + audio_limit: float = 0.99, + f0_predictor: torch.nn.Module = None, + ): + super(HiFTGenerator, self).__init__() + + self.out_channels = 1 + self.nb_harmonics = nb_harmonics + self.sampling_rate = sampling_rate + self.istft_params = istft_params + self.lrelu_slope = lrelu_slope + self.audio_limit = audio_limit + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=sampling_rate, + upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], + harmonic_num=nb_harmonics, + sine_amp=nsf_alpha, + add_noise_std=nsf_sigma, + 
voiced_threshod=nsf_voiced_threshold) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) + + self.conv_pre = weight_norm( + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + # Up + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + # Down + self.source_downs = nn.ModuleList() + self.source_resblocks = nn.ModuleList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)): + if u == 1: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) + ) + else: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2)) + ) + + self.source_resblocks.append( + ResBlock(base_channels // (2 ** (i + 1)), k, d) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = nn.ReflectionPad1d((1, 0)) + self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) + self.f0_predictor = f0_predictor + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + self.m_source.remove_weight_norm() + for l in self.source_downs: + remove_weight_norm(l) + for l in self.source_resblocks: + l.remove_weight_norm() + + def _stft(self, x): + spec = torch.stft( + x, + self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), + return_complex=True) + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[..., 0], spec[..., 1] + + def _istft(self, magnitude, phase): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) + inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], + self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) + return inverse_transform + + def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) + s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) + + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + # fusion + si = self.source_downs[i](s_stft) + si = self.source_resblocks[i](si) + x = x + si + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = 
self.conv_post(x) + magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) + phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy + + x = self._istft(magnitude, phase) + x = torch.clamp(x, -self.audio_limit, self.audio_limit) + return x + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + speech_feat = batch['speech_feat'].transpose(1, 2).to(device) + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # mel+source->speech + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, f0 + + @torch.inference_mode() + def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: + # mel->f0 + f0 = self.f0_predictor(speech_feat) + # f0->source + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + s, _, _ = self.m_source(s) + s = s.transpose(1, 2) + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, s diff --git a/inspiremusic/hifigan/hifigan.py b/inspiremusic/hifigan/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..8d7b612d1cd86569d430e8bf03bc7e3e0fa72957 --- /dev/null +++ b/inspiremusic/hifigan/hifigan.py @@ -0,0 +1,66 @@ +from typing import Dict, Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss +from inspiremusic.utils.losses import tpr_loss, mel_loss + +class HiFiGan(nn.Module): + def __init__(self, generator, discriminator, mel_spec_transform, + multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0, + tpr_loss_weight=1.0, tpr_loss_tau=0.04): + super(HiFiGan, self).__init__() + self.generator = generator + self.discriminator = discriminator + self.mel_spec_transform = mel_spec_transform + self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight + self.feat_match_loss_weight = feat_match_loss_weight + self.tpr_loss_weight = tpr_loss_weight + self.tpr_loss_tau = tpr_loss_tau + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + if batch['turn'] == 'generator': + return self.forward_generator(batch, device) + else: + return self.forward_discriminator(batch, device) + + def forward_generator(self, batch, device): + real_speech = batch['speech'].to(device) + pitch_feat = batch['pitch_feat'].to(device) + # 1. calculate generator outputs + generated_speech, generated_f0 = self.generator(batch, device) + # 2. calculate discriminator outputs + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) + # 3. 
calculate generator losses, feature loss, mel loss, tpr losses [Optional] + loss_gen, _ = generator_loss(y_d_gs) + loss_fm = feature_loss(fmap_rs, fmap_gs) + loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) + if self.tpr_loss_weight != 0: + loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) + else: + loss_tpr = torch.zeros(1).to(device) + loss_f0 = F.l1_loss(generated_f0, pitch_feat) + loss = loss_gen + self.feat_match_loss_weight * loss_fm + \ + self.multi_mel_spectral_recon_loss_weight * loss_mel + \ + self.tpr_loss_weight * loss_tpr + loss_f0 + return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} + + def forward_discriminator(self, batch, device): + real_speech = batch['speech'].to(device) + # 1. calculate generator outputs + with torch.no_grad(): + generated_speech, generated_f0 = self.generator(batch, device) + # 2. calculate discriminator outputs + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) + # 3. calculate discriminator losses, tpr losses [Optional] + loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) + if self.tpr_loss_weight != 0: + loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) + else: + loss_tpr = torch.zeros(1).to(device) + loss = loss_disc + self.tpr_loss_weight * loss_tpr + return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr} diff --git a/inspiremusic/llm/llm.py b/inspiremusic/llm/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..338853613ef47957aa5508c151dbbf293849214c --- /dev/null +++ b/inspiremusic/llm/llm.py @@ -0,0 +1,402 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
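For context, a minimal sketch (assumed, not taken from the repository) of how a training loop might drive the 'turn' key that HiFiGan.forward above dispatches on; hifigan, batch, optim_g and optim_d are hypothetical stand-ins.

def train_step(hifigan, batch, device, optim_g, optim_d):
    # one discriminator update followed by one generator update on the same batch
    for turn, optim in (("discriminator", optim_d), ("generator", optim_g)):
        batch["turn"] = turn              # selects forward_discriminator / forward_generator
        losses = hifigan(batch, device)   # both paths return a dict with a combined 'loss'
        optim.zero_grad()
        losses["loss"].backward()
        optim.step()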
+from typing import Dict, Optional, Callable, List, Generator +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence, unpad_sequence +from inspiremusic.utils.common import IGNORE_ID +from inspiremusic.transformer.label_smoothing_loss import LabelSmoothingLoss +from inspiremusic.utils.common import th_accuracy +from torch import Tensor +from math import log +from einops import rearrange, reduce, repeat +import logging + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1).to(torch.float32) + +class LLM(torch.nn.Module): + def __init__( + self, + text_encoder_input_size: int, + llm_input_size: int, + llm_output_size: int, + audio_token_size: int, + llm: torch.nn.Module, + sampling: Callable, + text_encoder_conf: Dict = None, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + frozen_input_embed: bool = False, + **kwargs, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.audio_token_size = audio_token_size + # 1. build text token inputs related modules + + if llm is None: + self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size) + else: + self.text_embedding = llm.model.model.embed_tokens + if frozen_input_embed: + print("Freezing input embedding layer") + for p in self.text_embedding.parameters(): + p.requires_grad = False + self.chorus_embedding = torch.nn.Embedding(5, llm_input_size) # intro, chorus, verse1, verse2 , outro + + self.text_encoder_conf = text_encoder_conf + self.text_encoder = self.build_encoder(text_encoder_conf) + self.infer_cfg_ratio = kwargs.get("infer_cfg_ratio", None) + logging.info(f"infer_cfg_ratio: {self.infer_cfg_ratio}") + self.train_cfg_ratio = kwargs.get("train_cfg_ratio", None) + logging.info(f"train_cfg_ratio: {self.train_cfg_ratio}") + # 2. build audio token language model related modules + self.sos_eos = 0 + self.task_id = 1 + + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, audio_token_size + 1) + self.criterion_ce = LabelSmoothingLoss( + size=audio_token_size + 1, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build audio token related modules + self.speech_embedding = torch.nn.Embedding(audio_token_size, llm_input_size) + self.spk_embed_affine_layer = torch.nn.Linear(192, llm_input_size) + self.num_codebooks = 4 + # 4. 
sampling method + self.sampling = sampling + self.time_embedding = SinusoidalEmbedding(llm_input_size) + + def cfg_dropout(self, text_token, text_token_len, p): + # Classifier-Free Guidance Dropout + B = text_token.size(0) + num_samples_to_mask = int(p * B) + if num_samples_to_mask == 0: + num_samples_to_mask = 1 + indices_to_mask = torch.randperm(B, device=text_token.device)[:num_samples_to_mask] + text_token[indices_to_mask] = 0 + text_token_len[indices_to_mask] = 0 + + return text_token, text_token_len + + def build_encoder(self, encoder_conf=None): + if encoder_conf is None: + assert hasattr(self, "encoder_conf"), \ + "function param encoder_conf is None and model doesn't has encoder_conf attribute either." + encoder_conf = self.encoder_conf + + encoder_name = encoder_conf.pop("name", "transformer") + model = None + if encoder_name == "transformer": + from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder + model = ConformerEncoder( + **encoder_conf, + input_size=self.input_size, + use_cnn_module=False, + macaron_style=False, + ) + elif encoder_name == "conformer": + from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder + model = ConformerEncoder( + **encoder_conf, + input_size=self.input_size, + ) + elif encoder_name == "llama_encoder": + from inspiremusic.transformer.encoder.llama_encoder import LlamaEncoder + model = LlamaEncoder( + **encoder_conf, + input_size=self.input_size, + ) + elif encoder_name == "qwen2": + from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder + model = QwenEncoder( + **encoder_conf, + input_size=self.input_size, + ) + elif encoder_name == "qwen2.5": + from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder + model = QwenEncoder( + **encoder_conf, + input_size=self.input_size, + ) + + encoder_conf["name"] = encoder_name + + return model + + def encode(self, + text: torch.Tensor, + text_lengths: torch.Tensor): + if self.text_encoder is not None: + encoder_out, encoder_mask = self.text_encoder(text, text_lengths, + decoding_chunk_size=1, + num_decoding_left_chunks=-1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out = self.text_encoder_affine_layer(encoder_out) + else: + encoder_out, encoder_out_lens = text, text_lengths + return encoder_out, encoder_out_lens + + def pad_unpad_sequence(self, sos_eos_emb, embeddings, text_token, + text_token_len, task_id_emb, audio_token, + audio_token_len, seg_len): + text_token = unpad_sequence(text_token, text_token_len.cpu(), + batch_first=True) + + audio_token = unpad_sequence(audio_token, audio_token_len.cpu(), + batch_first=True) + + for i in range(len(embeddings)): + embeddings[i] = unpad_sequence(embeddings[i], seg_len.cpu(), batch_first=True) + + lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0)] + [embedding[i] for embedding in embeddings] + [text_token[i], task_id_emb.squeeze(dim=0), audio_token[i]], dim=0) for i in range(len(text_token))] + lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) + return lm_input, lm_input_len + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + mask = True + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + if "semantic_token" not in batch: + audio_token = 
batch['acoustic_token'].to(device) + audio_token_len = batch['acoustic_token_len'].to(device) + audio_token = audio_token.view(audio_token.size(0), -1, self.num_codebooks) + audio_token = audio_token[:, :, 0] + audio_token_len = (audio_token_len / self.num_codebooks).long() + + else: + audio_token = batch['semantic_token'].to(device) + audio_token_len = batch['semantic_token_len'].to(device) + + time_start = batch['time_start'].to(device) + time_end = batch['time_end'].to(device) + chorus = batch['chorus'].to(device) + # 1. encode text_token + + if self.train_cfg_ratio > 0: + # Classifier-Free Guidance + text_token, _ = self.cfg_dropout(text_token, text_token_len, self.train_cfg_ratio) + + # 2. Time Embedding & chorus embedding + text_token = self.text_embedding(text_token) + text_token, text_token_len = self.encode(text_token, text_token_len) + if mask: + time_mask = time_start != -1.0 + seg_len = time_mask.sum(-1) + time_start = time_start.masked_fill(~time_mask, 0.0) + time_end = time_end.masked_fill(~time_mask, 0.0) + chorus = chorus.masked_fill(~time_mask, 0) + time_start_embed = self.time_embedding(time_start.view(-1)).to(text_token.dtype) + time_end_embed = self.time_embedding(time_end.view(-1)).to(text_token.dtype) + time_start_embed = time_start_embed.view(chorus.size(0), chorus.size(1), -1) + time_end_embed = time_end_embed.view(chorus.size(0), chorus.size(1), -1) + chorus_embed = self.chorus_embedding(chorus) + lm_target = [torch.tensor([IGNORE_ID] * (1 + 3 * seg_len[i] + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))] + else: + time_start_embed = self.time_embedding(time_start).to(text_token.dtype) + time_end_embed = self.time_embedding(time_end).to(text_token.dtype) + chorus_embed = self.chorus_embedding(chorus) + + lm_target = [torch.tensor( + [IGNORE_ID] * (4 + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))] + + lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 3. eos and task_id + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + # 4. encode audio_token + audio_token = self.speech_embedding(audio_token) + + # 5. unpad and pad + lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, + [time_start_embed, + time_end_embed, + chorus_embed], + text_token, + text_token_len, + task_id_emb, + audio_token, + audio_token_len, + seg_len) + # 6. 
run lm forward + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target) + + acc = th_accuracy(logits.view(-1, self.audio_token_size + 1), lm_target, ignore_label=IGNORE_ID) + + return {'loss': loss, 'acc': acc} + + def sampling_ids( + self, + weighted_scores: torch.Tensor, + decoded_tokens: List, + ignore_eos: bool = True, + ): + top_ids = self.sampling(weighted_scores, decoded_tokens) + return top_ids + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + audio_token: torch.Tensor, + audio_token_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_audio_token: torch.Tensor, + prompt_audio_token_len: torch.Tensor, + embeddings: List, + duration_to_gen: float = 30, + task: str = "continuation", + token_rate: int = 75, + limit_audio_prompt_len: int = 5, + ) -> Generator[torch.Tensor, None, None]: + device = text.device + + if text is not None: + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + infer_cfg = self.infer_cfg_ratio >= 0.0 + if infer_cfg: + text_cfg = self.text_embedding(text.new_zeros(text.shape)) + text = self.text_embedding(text) + + # 1. encode text + text, text_len = self.encode(text, text_len) + + # 2. encode embedding + if embeddings is not None: + time_start, time_end, chorus = embeddings + + if len(chorus.shape) == 1: + time_start_embed = self.time_embedding(time_start).reshape(1, 1, -1) # .half() + time_end_embed = self.time_embedding(time_end).reshape(1, 1, -1) # .half() + chorus_embed = self.chorus_embedding(chorus).reshape(1, 1, -1) # .half() + else: + time_start_embed = self.time_embedding( + time_start.view(-1)).reshape(1, chorus.size(1), -1) # .half() + time_end_embed = self.time_embedding(time_end.view(-1)).reshape(1, chorus.size(1), -1) # .half() + chorus_embed = self.chorus_embedding(chorus) # .half() + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + if audio_token_len: + audio_token = audio_token[:, :(limit_audio_prompt_len * token_rate)] + audio_token_emb = self.speech_embedding(audio_token) + else: + audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + + if prompt_audio_token_len: + prompt_audio_token_emb = self.speech_embedding(prompt_audio_token) + else: + prompt_audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + # Check if removing prompt audio token will fail decoding. 
+ + if task == "continuation": + lm_input = torch.concat( + [sos_eos_emb, time_start_embed, time_end_embed, + chorus_embed, text, task_id_emb, audio_token_emb], dim=1) + + if infer_cfg: + audio_cfg = self.speech_embedding( + audio_token.new_zeros(audio_token.shape)) + lm_cf_input = torch.concat( + [sos_eos_emb, torch.rand_like(time_start_embed), + torch.rand_like(time_end_embed), + torch.rand_like(chorus_embed), text_cfg, task_id_emb, + audio_cfg], dim=1) + lm_input = torch.cat([lm_input, lm_cf_input], 0) + else: + lm_input = torch.concat( + [sos_eos_emb, time_start_embed, time_end_embed, + chorus_embed, text, task_id_emb], dim=1) + if infer_cfg: + lm_cf_input = torch.concat( + [sos_eos_emb, torch.rand_like(time_start_embed), + torch.rand_like(time_end_embed), + torch.rand_like(chorus_embed), text_cfg, task_id_emb], + dim=1) + lm_input = torch.cat([lm_input, lm_cf_input], 0) + + # 4. cal min/max_length + min_len = duration_to_gen * token_rate + max_len = duration_to_gen * token_rate + logging.info( + f"LLM generation sequence length: {max_len}, generate audio length {duration_to_gen}s.") + + # 5. step by step decode + out_tokens = [] + offset = 0 + state = None + + for i in range(int(max_len)): + y_pred, _, state = self.llm.forward_one_step(lm_input, torch.ones(lm_input.shape[0], lm_input.shape[1], device=lm_input.device).to(torch.bool), cache=state) + logits = self.llm_decoder(y_pred[:, -1]) + if infer_cfg: + # perform context free guidance + logits_cf = logits[1] + logits = logits[0] + infer_cfg_ratio = self.infer_cfg_ratio + logits = infer_cfg_ratio * logits + (1 - infer_cfg_ratio) * logits_cf + + logp = logits.log_softmax(dim=-1) + logp = logp.squeeze(dim=0) + top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item() + + if top_ids == self.audio_token_size: + break + + # # in stream mode, yield token one by one + + yield torch.tensor([[top_ids]], dtype=torch.int64, device=device) + out_tokens.append(top_ids) + offset += lm_input.size(1) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + if infer_cfg: + lm_input = lm_input.repeat(2, 1, 1) diff --git a/inspiremusic/metrics/clap_score.py b/inspiremusic/metrics/clap_score.py new file mode 100644 index 0000000000000000000000000000000000000000..d77b200323d9374f4ea64ee3a7eeeb1c0c21fecf --- /dev/null +++ b/inspiremusic/metrics/clap_score.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import requests +from tqdm import tqdm +import torch +import numpy as np +import laion_clap +from clap_module.factory import load_state_dict +import librosa +import pyloudnorm as pyln + +# following documentation from https://github.com/LAION-AI/CLAP +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + +def float32_to_int16(x): + x = np.clip(x, a_min=-1., a_max=1.) 
+ return (x * 32767.).astype(np.int16) + + +def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_audioset_epoch_15_esc_90.14.pt'): + """ + Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and + the LAION-CLAP audio embedding of the generated audio. LION-CLAP: https://github.com/LAION-AI/CLAP + + This evaluation script assumes that audio_path files are identified with the ids in id2text. + + clap_score() evaluates all ids in id2text. + + GPU-based computation. + + Select one of the following models from https://github.com/LAION-AI/CLAP: + - music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen) + - music_audioset_epoch_15_esc_90.14.pt + - music_speech_epoch_15_esc_89.25.pt + - 630k-audioset-fusion-best.pt (our default, with "fusion" to handle longer inputs) + + Params: + -- id2text: dictionary with the mapping between id (generated audio filenames in audio_path) + and text (prompt used to generate audio). clap_score() evaluates all ids in id2text. + -- audio_path: path where the generated audio files to evaluate are available. + -- audio_files_extension: files extension (default .wav) in eval_path. + -- clap_model: choose one of the above clap_models (default: '630k-audioset-fusion-best.pt'). + Returns: + -- CLAP-LION score + """ + # load model + if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt': + url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt' + clap_path = 'CLAP/music_speech_audioset_epoch_15_esc_89.98.pt' + model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda') + elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt': + url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt' + clap_path = 'CLAP/music_audioset_epoch_15_esc_90.14.pt' + model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda') + elif clap_model == 'music_speech_epoch_15_esc_89.25.pt': + url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt' + clap_path = 'CLAP/music_speech_epoch_15_esc_89.25.pt' + model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda') + elif clap_model == '630k-audioset-fusion-best.pt': + url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt' + clap_path = 'CLAP/630k-audioset-fusion-best.pt' + model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda') + else: + raise ValueError('clap_model not implemented') + + # download clap_model if not already downloaded + if not os.path.exists(clap_path): + print('Downloading ', clap_model, '...') + os.makedirs(os.path.dirname(clap_path), exist_ok=True) + + response = requests.get(url, stream=True) + total_size = int(response.headers.get('content-length', 0)) + + with open(clap_path, 'wb') as file: + with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar: + for data in response.iter_content(chunk_size=8192): + file.write(data) + progress_bar.update(len(data)) + + # fixing CLAP-LION issue, see: https://github.com/LAION-AI/CLAP/issues/118 + pkg = load_state_dict(clap_path) + pkg.pop('text_branch.embeddings.position_ids', None) + model.model.load_state_dict(pkg) + model.eval() + + if not os.path.isdir(audio_path): + raise ValueError(f'audio_path: {audio_path} does not exist') + + if id2text: + print('[EXTRACTING TEXT EMBEDDINGS] ') + batch_size = 64 + text_emb = {} + for i in 
tqdm(range(0, len(id2text), batch_size)): + batch_ids = list(id2text.keys())[i:i+batch_size] + batch_texts = [id2text[id] for id in batch_ids] + with torch.no_grad(): + embeddings = model.get_text_embedding(batch_texts, use_tensor=True) + for id, emb in zip(batch_ids, embeddings): + text_emb[id] = emb + + else: + raise ValueError('Must specify id2text') + + print('[EVALUATING GENERATIONS] ', audio_path) + score = 0 + count = 0 + for id in tqdm(id2text.keys()): + file_path = os.path.join(audio_path, str(id)+audio_files_extension) + if os.path.isfile(file_path): + with torch.no_grad(): + audio, _ = librosa.load(file_path, sr=48000, mono=True) # sample rate should be 48000 + audio = pyln.normalize.peak(audio, -1.0) + audio = audio.reshape(1, -1) # unsqueeze (1,T) + audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float() + audio_embeddings = model.get_audio_embedding_from_data(x = audio, use_tensor=True) + cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0] + print(f"{id} | CLAP score = {cosine_sim}") + score += cosine_sim + count += 1 + + return score / count if count > 0 else 0 + diff --git a/inspiremusic/metrics/openl3_fd.py b/inspiremusic/metrics/openl3_fd.py new file mode 100644 index 0000000000000000000000000000000000000000..78287970a8250dec2c6bbc77b5b4791122a02259 --- /dev/null +++ b/inspiremusic/metrics/openl3_fd.py @@ -0,0 +1,338 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import openl3 +import librosa +import numpy as np +from scipy import linalg +import glob +from tqdm import tqdm +import os +import soxr +import pyloudnorm as pyln + + +def calculate_embd_statistics(embd_lst): + if isinstance(embd_lst, list): + embd_lst = np.array(embd_lst) + mu = np.mean(embd_lst, axis=0) + sigma = np.cov(embd_lst, rowvar=False) + return mu, sigma + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """ + Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py + Adapted from: https://github.com/gudgud96/frechet-audio-distance/blob/main/frechet_audio_distance/fad.py + + Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Params: + -- mu1: Embedding's mean statistics for generated samples. + -- mu2: Embedding's mean statistics for reference samples. + -- sigma1: Covariance matrix over embeddings for generated samples. + -- sigma2: Covariance matrix over embeddings for reference samples. + Returns: + -- Frรฉchet Distance. 
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def extract_embeddings(directory_path, channels, samplingrate, content_type, openl3_hop_size, batch_size=16): + """ + Given a list of files, compute their embeddings in batches. + + If channels == 1: stereo audio is downmixed to mono. Mono embeddings are of dim=512. + + If channels == 2: mono audio is "faked" to stereo by copying the mono channel. + Stereo embeddings are of dim=1024, since we concatenate L (dim=512) and R (dim=512) embeddings. + + Params: + -- directory_path: path where the generated audio files are available. + -- channels: 1 (mono), or 2 (stereo) to get mono or stereo embeddings. + -- samplingrate: max bandwidth at which we evaluate the given signals. Up to 48kHz. + -- content_type: 'music' or 'env' to select a content type specific openl3 model. + -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec. + -- batch_size: number of audio files to process in each batch. 
+ Returns: + -- list of embeddings: [np.array[], ...], as expected by calculate_frechet_distance() + """ + _, extension = os.path.splitext(directory_path) + if extension.lower() == ".scp": + wav_files = [] + with open(directory_path, "r") as f: + for line in f: + sec = line.strip().split(" ") + wav_files.append(sec[1]) + else: + wav_files = glob.glob(directory_path) + if len(wav_files) == 0: + raise ValueError('No files with this extension in this path!') + model = openl3.models.load_audio_embedding_model(input_repr="mel256", content_type=content_type, embedding_size=512) + + first = True + for i in tqdm(range(0, len(wav_files), batch_size)): + batch_files = wav_files[i:i+batch_size] + batch_audio_l = [] + batch_audio_r = [] + batch_sr = [] + + for file in batch_files: + audio, sr = librosa.load(file, sr=None, mono=False) + audio = audio.T + audio = pyln.normalize.peak(audio, -1.0) + if audio.shape[0] < sr: + print('Audio shorter than 1 sec, openl3 will zero-pad it:', file, audio.shape, sr) + + # resample to the desired evaluation bandwidth + audio = soxr.resample(audio, sr, samplingrate) # mono/stereo <- mono/stereo, input sr, output sr + + # mono embeddings are stored in batch_audio_l (R channel not used) + if channels == 1: + batch_audio_l.append(audio) + + elif channels == 2: + if audio.ndim == 1: + # if mono, "fake" stereo by copying mono channel to L and R + batch_audio_l.append(audio) + batch_audio_r.append(audio) + elif audio.ndim == 2: + # if it's stereo separate channels for openl3 + batch_audio_l.append(audio[:,0]) + batch_audio_r.append(audio[:,1]) + + batch_sr.append(samplingrate) + + # extracting mono embeddings (dim=512) or the L channel for stereo embeddings + emb, _ = openl3.get_audio_embedding(batch_audio_l, batch_sr, model=model, verbose=False, hop_size=openl3_hop_size, batch_size=batch_size) + + # format mono embedding + if channels == 1: + emb = np.concatenate(emb,axis=0) + + # extracting stereo embeddings (dim=1024), since we concatenate L (dim=512) and R (dim=512) embeddings + elif channels == 2: + # extract the missing R channel + emb_r, _ = openl3.get_audio_embedding(batch_audio_r, batch_sr, model=model, verbose=False, hop_size=openl3_hop_size, batch_size=batch_size) + emb = [np.concatenate([l, r], axis=1) for l, r in zip(emb, emb_r)] + emb = np.concatenate(emb, axis=0) + + # concatenate embeddings + if first: + embeddings = emb + first = False + else: + embeddings = np.concatenate([embeddings, emb], axis=0) + + # return as a list of embeddings: [np.array[], ...] + return [e for e in embeddings] + + +def extract_embeddings_nobatching(directory_path, channels, samplingrate, content_type, openl3_hop_size): + """ + Given a list of files, compute their embeddings one by one. + + If channels == 1: stereo audio is downmixed to mono. Mono embeddings are of dim=512. + + If channels == 2: mono audio is "faked" to stereo by copying the mono channel. + Stereo embeddings are of dim=1024, since we concatenate L (dim=512) and R (dim=512) embeddings. + + Params: + -- directory_path: path where the generated audio files are available. + -- channels: 1 (mono), or 2 (stereo) to get mono or stereo embeddings. + -- samplingrate: max bandwidth at which we evaluate the given signals. Up to 48kHz. + -- content_type: 'music' or 'env' to select a content type specific openl3 model. + -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec. 
+ Returns: + -- list of embeddings: [np.array[], ...], as expected by calculate_frechet_distance() + """ + _, extension = os.path.splitext(directory_path) + if extension.lower() == ".scp": + wav_files = [] + with open(directory_path, "r") as f: + for line in f: + sec = line.strip().split(" ") + wav_files.append(sec[1]) + else: + wav_files = glob.glob(directory_path) + if len(wav_files) == 0: + raise ValueError('No files with this extension in this path!') + model = openl3.models.load_audio_embedding_model(input_repr="mel256", content_type=content_type, embedding_size=512) + + first = True + for file in tqdm(wav_files): + audio, sr = librosa.load(file, sr=None) + audio = pyln.normalize.peak(audio, -1.0) + if audio.shape[0] < sr: + print('Audio shorter than 1 sec, openl3 will zero-pad it:', file, audio.shape, sr) + + # resample to the desired evaluation bandwidth + audio = soxr.resample(audio, sr, samplingrate) # mono/stereo <- mono/stereo, input sr, output sr + + # extracting stereo embeddings (dim=1024), since we concatenate L (dim=512) and R (dim=512) embeddings + if channels == 2: + if audio.ndim == 1: + audio_l3, sr_l3 = audio, samplingrate + elif audio.ndim == 2: + # if it's stereo separate channels for openl3 + audio_l3 = [audio[:,0], audio[:,1]] + sr_l3 = [samplingrate, samplingrate] + emb, _ = openl3.get_audio_embedding(audio_l3, sr_l3, model=model, verbose=False, hop_size=openl3_hop_size) + if audio.ndim == 1: + # if mono audio, "fake" stereo by concatenating mono embedding as L and R embeddings + emb = np.concatenate([emb, emb],axis=1) + elif audio.ndim == 2: + emb = np.concatenate(emb,axis=1) + + # or extracting mono embeddings (dim=512) + elif channels == 1: + emb, _ = openl3.get_audio_embedding(audio, samplingrate, model=model, verbose=False, hop_size=openl3_hop_size) + + # concatenate embeddings + if first: + embeddings = emb + first = False + else: + embeddings = np.concatenate([embeddings, emb], axis=0) + + # return as a list of embeddings: [np.array[], ...] + return [e for e in embeddings] + + +def openl3_fd(channels, samplingrate, content_type, openl3_hop_size, eval_path, + eval_files_extension='.wav', ref_path=None, ref_files_extension='.wav', load_ref_embeddings=None, batching=False): + """ + Compute the Frรฉchet Distance between files in eval_path and ref_path. + + Frรฉchet distance computed on top of openl3 embeddings. + + GPU-based computation. + + Extracting the embeddings is timeconsuming. After being computed once, we store them. + We store pre-computed reference embedding statistics in load/openl3_fd/ + To load those and save computation, just set the path in load_ref_embeddings. + If load_ref_embeddings is set, ref_path is not required. + + Params: + -- channels: 1 (mono), or 2 (stereo) to get the Frรฉchet Distance over mono or stereo embeddings. + -- samplingrate: max bandwith at wich we evaluate the given signals. Up to 48kHz. + -- content_type: 'music' or 'env' to select a content type for openl3. + -- openl3_hop_size: analysis resolution of openl3 in seconds. Openl3's input window is 1 sec. + -- eval_path: path where the generated audio files to evaluate are available. + -- eval_files_extenstion: files extension (default .wav) in eval_path. + -- ref_path: path where the reference audio files are available. (instead of load_ref_embeddings) + -- ref_files_extension: files extension (default .wav) in ref_path. + -- load_ref_embeddings: path to the reference embedding statistics. 
(inestead of ref_path) + -- batching: set batch size (with an int) or set to False (default False). + Returns: + -- Frรฉchet distance. + """ + + if not os.path.isdir(eval_path): + raise ValueError('eval_path does not exist') + + if load_ref_embeddings: + if not os.path.exists(load_ref_embeddings): + raise ValueError('load_ref_embeddings does not exist') + print('[LOADING REFERENCE EMBEDDINGS] ', load_ref_embeddings) + loaded = np.load(load_ref_embeddings) + mu_ref = loaded['mu_ref'] + sigma_ref = loaded['sigma_ref'] + + else: + if ref_path: + if not os.path.isdir(ref_path): + if not os.path.isfile(ref_path): + raise ValueError("ref_path does not exist") + if os.path.isfile(ref_path): + path = ref_path + else: + path = os.path.join(ref_path, '*'+ref_files_extension) + print('[EXTRACTING REFERENCE EMBEDDINGS] ', path) + if batching: + ref_embeddings = extract_embeddings(path, channels, samplingrate, content_type, openl3_hop_size, batch_size=batching) + else: + ref_embeddings = extract_embeddings_nobatching(path, channels, samplingrate, content_type, openl3_hop_size) + mu_ref, sigma_ref = calculate_embd_statistics(ref_embeddings) + + # store statistics to load later on + if not os.path.exists('load/openl3_fd'): + os.makedirs('load/openl3_fd/') + save_ref_embeddings_path = ( + 'load/openl3_fd/' + + path.replace('/', '_') + + '__channels' + str(channels) + + '__' + str(samplingrate) + + '__openl3' + str(content_type) + + '__openl3hopsize' + str(openl3_hop_size) + + '__batch' + str(batching) + + '.npz' + ) + np.savez(save_ref_embeddings_path, mu_ref=mu_ref, sigma_ref=sigma_ref) + print('[REFERENCE EMBEDDINGS][SAVED] ', save_ref_embeddings_path) + + else: + raise ValueError('Must specify ref_path or load_ref_embeddings') + + path = os.path.join(eval_path, '*'+eval_files_extension) + print('[EXTRACTING EVALUATION EMBEDDINGS] ', path) + if batching: + eval_embeddings = extract_embeddings(path, channels, samplingrate, content_type, openl3_hop_size, batch_size=batching) + else: + eval_embeddings = extract_embeddings_nobatching(path, channels, samplingrate, content_type, openl3_hop_size) + mu_eval, sigma_eval = calculate_embd_statistics(eval_embeddings) + + fd = calculate_frechet_distance(mu_eval, sigma_eval, mu_ref, sigma_ref) + if load_ref_embeddings: + print('[FRรฉCHET DISTANCE] ', eval_path, load_ref_embeddings, fd) + else: + print('[FRรฉCHET DISTANCE] ', eval_path, ref_path, fd) + + return fd \ No newline at end of file diff --git a/inspiremusic/metrics/passt_kld.py b/inspiremusic/metrics/passt_kld.py new file mode 100644 index 0000000000000000000000000000000000000000..aa27835ee82161f9e7cd1c7f9d99c07409bbbfc0 --- /dev/null +++ b/inspiremusic/metrics/passt_kld.py @@ -0,0 +1,232 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
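[editor's note] A minimal, illustrative sketch of how the two statistics helpers in openl3_fd.py above fit together. It runs on synthetic low-dimensional vectors purely to show the data flow; real usage feeds 512- or 1024-dim openl3 embeddings produced by extract_embeddings()/extract_embeddings_nobatching(), and the import path assumes the inspiremusic package is importable.

    import numpy as np
    from inspiremusic.metrics.openl3_fd import (calculate_embd_statistics,
                                                calculate_frechet_distance)

    rng = np.random.default_rng(0)
    # toy stand-ins for openl3 embeddings; low dimension keeps the covariances well conditioned
    ref_embeddings  = [rng.standard_normal(16) for _ in range(1000)]
    eval_embeddings = [rng.standard_normal(16) + 0.1 for _ in range(1000)]

    mu_ref, sigma_ref   = calculate_embd_statistics(ref_embeddings)
    mu_eval, sigma_eval = calculate_embd_statistics(eval_embeddings)
    # small positive scalar; 0 only when the two Gaussians coincide
    print(calculate_frechet_distance(mu_eval, sigma_eval, mu_ref, sigma_ref))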
+import warnings +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +import os +import contextlib +from functools import partial +from tqdm import tqdm +import pickle +import numpy as np +import librosa +from hear21passt.base import get_basic_model +import pyloudnorm as pyln + +import torch +import torch.nn.functional as F + + +SAMPLING_RATE = 32000 + + +class _patch_passt_stft: + """ + From version 1.8.0, return_complex must always be given explicitly + for real inputs and return_complex=False has been deprecated. + + Decorator to patch torch.stft in PaSST that uses an old stft version. + + Adapted from: https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py + """ + def __init__(self): + self.old_stft = torch.stft + + def __enter__(self): + # return_complex is a mandatory parameter in latest torch versions. + # torch is throwing RuntimeErrors when not set. + # see: https://pytorch.org/docs/1.7.1/generated/torch.stft.html?highlight=stft#torch.stft + #ย see: https://github.com/kkoutini/passt_hear21/commit/dce83183674e559162b49924d666c0a916dc967a + torch.stft = partial(torch.stft, return_complex=False) + + def __exit__(self, *exc): + torch.stft = self.old_stft + + +def return_probabilities(model, audio_path, window_size=10, overlap=5, collect='mean'): + """ + Given an audio and the PaSST model, return the probabilities of each AudioSet class. + + Audio is converted to mono at 32kHz. + + PaSST model is trained with 10 sec inputs. We refer to this parameter as the window_size. + We set it to 10 sec for consistency with PaSST training. + + For longer audios, we split audio into overlapping analysis windows of window_size and overlap of 10 and 5 seconds. + PaSST supports 10, 20 or 30 sec inputs. Not longer inputs: https://github.com/kkoutini/PaSST/issues/19 + + Note that AudioSet taggers normally use sigmoid output layers. Yet, to compute the + KL we work with normalized probabilities by running a softmax over logits as in MusicGen: + https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py + + This implementation assumes run will be on GPU. + + Params: + -- model: PaSST model on a GPU. + -- audio_path: path to the audio to be loaded with librosa. + -- window_size (default=10 sec): analysis window (and receptive field) of PaSST. + -- overlap (default=5 sec): overlap of the running analysis window for inputs longar than window_size (10 sec). + -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling along logits vector. + Returns: + -- 527 probabilities (after softmax, no logarithm). 
+ """ + # load the audio using librosa + audio, _ = librosa.load(audio_path, sr=SAMPLING_RATE, mono=True) + audio = pyln.normalize.peak(audio, -1.0) + + # calculate the step size for the analysis windows with the specified overlap + step_size = int((window_size - overlap) * SAMPLING_RATE) + + # iterate over the audio, creating analysis windows + probabilities = [] + for i in range(0, max(step_size, len(audio) - step_size), step_size): + # extract the current analysis window + window = audio[i:i + int(window_size * SAMPLING_RATE)] + + # pad the window with zeros if it's shorter than the desired window size + if len(window) < int(window_size * SAMPLING_RATE): + # discard window if it's too small (avoid mostly zeros predicted as silence), as in MusicGen: + # https://github.com/facebookresearch/audiocraft/blob/a2b96756956846e194c9255d0cdadc2b47c93f1b/audiocraft/metrics/kld.py + if len(window) > int(window_size * SAMPLING_RATE * 0.15): + tmp = np.zeros(int(window_size * SAMPLING_RATE)) + tmp[:len(window)] = window + window = tmp + + # convert to a PyTorch tensor and move to GPU + audio_wave = torch.from_numpy(window.astype(np.float32)).unsqueeze(0).cuda() + + # get the probabilities for this analysis window + with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): + with torch.no_grad(), _patch_passt_stft(): + logits = model(audio_wave) + probabilities.append(torch.squeeze(logits)) + + probabilities = torch.stack(probabilities) + if collect == 'mean': + probabilities = torch.mean(probabilities, dim=0) + elif collect == 'max': + probabilities, _ = torch.max(probabilities, dim=0) + + return F.softmax(probabilities, dim=0).squeeze().cpu() + + +def passt_kld(ids, eval_path, eval_files_extension='.wav', ref_path=None, ref_files_extension='.wav', load_ref_probabilities=None, no_ids=[], collect='mean'): + """ + Compute KL-divergence between the label probabilities of the generated audio with respect to the original audio. + Both generated audio (in eval_path) and original audio (in ref_path) are represented by the same prompt/description. + Audios are identified by an id, that is the name of the file in both directories and links the audio with the prompt/description. + segmenting the audio + + For inputs longer that the 10 sec PaSST was trained on, we aggregate/collect via 'mean' (default) or 'max' pooling along the logits vector. + We split the inpot into overlapping analysis windows. Subsequently, we aggregate/collect (accross windows) the generated logits and then apply a softmax. + + This evaluation script assumes that ids are in both ref_path and eval_path. + + We label probabilities via the PaSST model: https://github.com/kkoutini/PaSST + + GPU-based computation. + + Extracting the probabilities is timeconsuming. After being computed once, we store them. + We store pre-computed reference probabilities in load/ + To load those and save computation, just set the path in load_ref_probabilities. + If load_ref_probabilities is set, ref_path is not required. + + Params: + -- ids: list of ids present in both eval_path and ref_path. + -- eval_path: path where the generated audio files to evaluate are available. + -- eval_files_extenstion: files extension (default .wav) in eval_path. + -- ref_path: path where the reference audio files are available. (instead of load_ref_probabilities) + -- ref_files_extenstion: files extension (default .wav) in ref_path. + -- load_ref_probabilities: path to the reference probabilities. 
(inestead of ref_path) + -- no_ids: it is possible that some reference audio is corrupted or not present. Ignore some this list of ids. + -- collect (default='mean'): for longer inputs, aggregate/collect via 'mean' or 'max' pooling along the logits vector. + Returns: + -- KL divergence + """ + with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): # capturing all useless outputs from passt + # load model + model = get_basic_model(mode="logits") + model.eval() + model = model.cuda() + + if not os.path.isdir(eval_path): + if not os.path.isfile(eval_path): + raise ValueError('eval_path does not exist') + + if load_ref_probabilities: + if not os.path.exists(load_ref_probabilities): + raise ValueError('load_ref_probabilities does not exist') + print('[LOADING REFERENCE PROBABILITIES] ', load_ref_probabilities) + with open(load_ref_probabilities, 'rb') as fp: + ref_p = pickle.load(fp) + + else: + if ref_path: + if not os.path.isdir(ref_path): + if os.path.isfile(ref_path): + id2utt = {} + with open(ref_path, "r") as f: + for line in f: + sec = line.strip().split(" ") + id2utt[sec[0]] = sec[1] + f.close() + else: + raise ValueError("ref_path does not exist") + print('[EXTRACTING REFERENCE PROBABILITIES] ', ref_path) + ref_p = {} + for id in tqdm(ids): + if id not in no_ids: + try: + if os.path.isfile(ref_path): + if id in id2utt.keys(): + audio_path = id2utt[id] + else: + raise ValueError(f"id: {id} not in {ref_path}!") + else: + audio_path = os.path.join(ref_path, str(id)+ref_files_extension) + if os.path.isfile(audio_path): + ref_p[id] = return_probabilities(model, audio_path, collect=collect) + except Exception as e: + print(f"An unexpected error occurred with {id}: {e}\nIf you failed to download it you can add it to no_ids list.") + + # store reference probabilities to load later on + if not os.path.exists('load/passt_kld/'): + os.makedirs('load/passt_kld/') + save_ref_probabilities_path = 'load/passt_kld/'+ref_path.replace('/', '_')+'_collect'+str(collect)+'__reference_probabilities.pkl' + with open(save_ref_probabilities_path, 'wb') as fp: + pickle.dump(ref_p, fp) + print('[REFERENCE EMBEDDINGS][SAVED] ', save_ref_probabilities_path) + + else: + raise ValueError('Must specify ref_path or load_ref_probabilities') + + print('[EVALUATING GENERATIONS] ', eval_path) + + passt_kl = 0 + count = 0 + for id in tqdm(ids): + if id not in no_ids: + try: + audio_path = os.path.join(eval_path, str(id)+eval_files_extension) + if os.path.isfile(audio_path): + eval_p = return_probabilities(model, audio_path, collect=collect) + # note: F.kl_div(x, y) is KL(y||x) + # see: https://github.com/pytorch/pytorch/issues/7337 + # see: https://discuss.pytorch.org/t/kl-divergence-different-results-from-tf/56903/2 + passt_kl += F.kl_div((ref_p[id] + 1e-6).log(), eval_p, reduction='sum', log_target=False) + count += 1 + except Exception as e: + print(f"An unexpected error occurred with {id}: {e}\nIf you failed to download it you can add it to no_ids list.") + return passt_kl / count if count > 0 else 0 diff --git a/inspiremusic/music_tokenizer/__init__.py b/inspiremusic/music_tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/music_tokenizer/env.py b/inspiremusic/music_tokenizer/env.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5843b8a244e9648bcd6f9e085dff9faa2e921a --- /dev/null +++ b/inspiremusic/music_tokenizer/env.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Alibaba Inc +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/inspiremusic/music_tokenizer/meldataset.py b/inspiremusic/music_tokenizer/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4a13b6247ec04b74559976687ff108c3380d42f1 --- /dev/null +++ b/inspiremusic/music_tokenizer/meldataset.py @@ -0,0 +1,226 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
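[editor's note] A tiny usage sketch of AttrDict from env.py above (illustrative only): it exposes a parsed JSON config through attribute access, which is how the tokenizer modules such as VQVAE read their hyperparameters.

    import json
    from inspiremusic.music_tokenizer.env import AttrDict

    h = AttrDict(json.loads('{"n_fft": 1024, "hop_size": 256, "num_mels": 80}'))
    print(h.n_fft, h["hop_size"])  # 1024 256 -- attribute and key access share one dict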
+ +# code based on https://github.com/b04901014/MQTTS +import math +import os +import random + +import librosa +import numpy as np +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +def load_wav(full_path, sr): + wav, sr = librosa.load(full_path, sr=sr) + return wav, sr + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + +mel_basis = {} +hann_window = {} + +## modified to get stft with return complex value = True for pytorch ver2.0 +def mel_spectrogram(y, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + center=False): + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + '_' + + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int( + (n_fft - hop_size) / 2)), + mode='reflect') + y = y.squeeze(1) + + spec = torch.view_as_real(torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True + )) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, 'r') as f: + training_files = [l.strip() for l in f] + with open(a.input_validation_file, 'r') as f: + validation_files = [l.strip() for l in f] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__(self, + training_files, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + try: + # Note by yuantian: load with the sample_rate of config + audio, sampling_rate = load_wav(filename, sr=self.sampling_rate) + except Exception as e: + print(f"Error on audio: {filename}") + audio = np.random.normal(size=(160000, )) * 0.05 + sampling_rate = self.sampling_rate + self.cached_wav = audio + if 
sampling_rate != self.sampling_rate: + raise ValueError("{} SR doesn't match target {} SR".format( + sampling_rate, self.sampling_rate)) + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start:audio_start + + self.segment_size] + else: + audio = torch.nn.functional.pad(audio, ( + 0, self.segment_size - audio.size(1)), 'constant') + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False) + else: + mel = np.load( + os.path.join(self.base_mels_path, + os.path.splitext(os.path.split(filename)[-1])[0] + + '.npy')) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, + mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start:mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size:( + mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, ( + 0, frames_per_seg - mel.size(2)), 'constant') + audio = torch.nn.functional.pad(audio, ( + 0, self.segment_size - audio.size(1)), 'constant') + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/inspiremusic/music_tokenizer/models.py b/inspiremusic/music_tokenizer/models.py new file mode 100644 index 0000000000000000000000000000000000000000..86302c699224252e72b05971c6a52a2ba7e8764d --- /dev/null +++ b/inspiremusic/music_tokenizer/models.py @@ -0,0 +1,548 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
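[editor's note] A shape-level sketch of the mel_spectrogram() helper defined in meldataset.py above. The hyperparameter values are placeholders rather than the project's actual settings, and it assumes a librosa version whose librosa.filters.mel still accepts positional arguments, since the helper calls it that way.

    import torch
    from inspiremusic.music_tokenizer.meldataset import mel_spectrogram

    wav = torch.randn(1, 24000)  # (batch, samples): one second of audio at 24 kHz
    mel = mel_spectrogram(wav, n_fft=1024, num_mels=80, sampling_rate=24000,
                          hop_size=256, win_size=1024, fmin=0, fmax=8000)
    print(mel.shape)  # (1, 80, n_frames): log-compressed mel magnitudes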
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d +from torch.nn import Conv1d +from torch.nn import Conv2d +from torch.nn import ConvTranspose1d +from torch.nn.utils import remove_weight_norm +from torch.nn.utils import spectral_norm +from torch.nn.utils import weight_norm + +from inspiremusic.utils.tokenizer_utils import get_padding +from inspiremusic.utils.tokenizer_utils import init_weights + +LRELU_SLOPE = 0.1 + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))), weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))), weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(512, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, + k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + k, + u, + # padding=(u//2 + u%2), + padding=(k - u) // 2, + # output_padding=u%2 + ))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + 
self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, + use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f( + Conv2d( + 1, + 32, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 32, + 128, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 128, + 512, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 512, + 1024, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = 
self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class Encoder(torch.nn.Module): + def __init__(self, h): + super(Encoder, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(1, 32, 7, 1, padding=3)) + self.normalize = nn.ModuleList() + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + list( + reversed( + list(zip(h.upsample_rates, h.upsample_kernel_sizes))))): + self.ups.append( + weight_norm( + Conv1d( + 32 * (2**i), + 32 * (2**(i + 1)), + k, + u, + padding=((k - u) // 2) + # padding=(u//2 + u%2) + ))) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = 32 * (2**(i + 1)) + for j, (k, d) in enumerate( + zip( + list(reversed(h.resblock_kernel_sizes)), + list(reversed(h.resblock_dilation_sizes)))): + self.resblocks.append(resblock(h, ch, k, d)) + self.normalize.append( + torch.nn.GroupNorm(ch // 16, ch, eps=1e-6, affine=True)) + self.conv_post = Conv1d(512, 512, 3, 1, padding=1) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + xs = self.normalize[i * self.num_kernels + j](xs) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + xs = self.normalize[i * self.num_kernels + j](xs) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + + +class Quantizer_module(torch.nn.Module): + def __init__(self, n_e, e_dim): + super(Quantizer_module, self).__init__() + self.embedding = 
nn.Embedding(n_e, e_dim) + self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) + + def forward(self, x): + # compute Euclidean distance + d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \ + - 2 * torch.matmul(x, self.embedding.weight.T) + min_indicies = torch.argmin(d, 1) + z_q = self.embedding(min_indicies) + return z_q, min_indicies + + +class Quantizer(torch.nn.Module): + def __init__(self, h): + super(Quantizer, self).__init__() + assert 512 % h.n_code_groups == 0 + self.quantizer_modules = nn.ModuleList([ + Quantizer_module(h.n_codes, 512 // h.n_code_groups) + for _ in range(h.n_code_groups) + ]) + self.quantizer_modules2 = nn.ModuleList([ + Quantizer_module(h.n_codes, 512 // h.n_code_groups) + for _ in range(h.n_code_groups) + ]) + self.h = h + self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1 + self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25 + self.residul_layer = 2 + self.n_code_groups = h.n_code_groups + + def for_one_step(self, xin, idx): + xin = xin.transpose(1, 2) + x = xin.reshape(-1, 512) + x = torch.split(x, 512 // self.h.n_code_groups, dim=-1) + min_indicies = [] + z_q = [] + if idx == 0: + for _x, m in zip(x, self.quantizer_modules): + _z_q, _min_indicies = m(_x) + z_q.append(_z_q) + min_indicies.append(_min_indicies) #B * T, + z_q = torch.cat(z_q, -1).reshape(xin.shape) + # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2) + loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \ + + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2) + z_q = xin + (z_q - xin).detach() + z_q = z_q.transpose(1, 2) + return z_q, loss, min_indicies + else: + for _x, m in zip(x, self.quantizer_modules2): + _z_q, _min_indicies = m(_x) + z_q.append(_z_q) + min_indicies.append(_min_indicies) #B * T, + z_q = torch.cat(z_q, -1).reshape(xin.shape) + # loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2) + loss = self.codebook_loss_lambda * torch.mean((z_q - xin.detach()) ** 2) \ + + self.commitment_loss_lambda * torch.mean((z_q.detach() - xin) ** 2) + z_q = xin + (z_q - xin).detach() + z_q = z_q.transpose(1, 2) + return z_q, loss, min_indicies + + def forward(self, xin): + #B, C, T + quantized_out = 0.0 + residual = xin + all_losses = [] + all_indices = [] + for i in range(self.residul_layer): + quantized, loss, indices = self.for_one_step(residual, i) # + residual = residual - quantized + quantized_out = quantized_out + quantized + all_indices.extend(indices) # + all_losses.append(loss) + all_losses = torch.stack(all_losses) + loss = torch.mean(all_losses) + return quantized_out, loss, all_indices + + def embed(self, x): + #idx: N, T, 4 + #print('x ', x.shape) + quantized_out = torch.tensor(0.0, device=x.device) + x = torch.split(x, 1, 2) # split, ๅฐ†ๆœ€ๅŽไธ€ไธช็ปดๅบฆๅˆ†ๅผ€, ๆฏไธชๅฑžไบŽไธ€ไธชindex group + #print('x.shape ', len(x),x[0].shape) + for i in range(self.residul_layer): + ret = [] + if i == 0: + for j in range(self.n_code_groups): + q = x[j] + embed = self.quantizer_modules[j] + q = embed.embedding(q.squeeze(-1).long()) + ret.append(q) + ret = torch.cat(ret, -1) + #print(ret.shape) + quantized_out = quantized_out + ret + else: + for j in range(self.n_code_groups): + q = x[j + self.n_code_groups] + embed = self.quantizer_modules2[j] + q = embed.embedding(q.squeeze(-1).long()) + ret.append(q) + ret = torch.cat(ret, -1) + quantized_out = quantized_out + ret + return quantized_out.transpose(1, 
2) #N, C, T diff --git a/inspiremusic/music_tokenizer/vqvae.py b/inspiremusic/music_tokenizer/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..553275b3c25bb2f44bc9009ebaacaef7a346e206 --- /dev/null +++ b/inspiremusic/music_tokenizer/vqvae.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import torch +import torch.nn as nn +from inspiremusic.music_tokenizer.env import AttrDict +from inspiremusic.music_tokenizer.models import Encoder +from inspiremusic.music_tokenizer.models import Generator +from inspiremusic.music_tokenizer.models import Quantizer + + +class VQVAE(nn.Module): + def __init__(self, + config_path, + ckpt_path, + with_encoder=False): + super(VQVAE, self).__init__() + ckpt = torch.load(ckpt_path) + with open(config_path) as f: + data = f.read() + json_config = json.loads(data) + self.h = AttrDict(json_config) + self.quantizer = Quantizer(self.h) + self.generator = Generator(self.h) + self.generator.load_state_dict(ckpt['generator']) + self.quantizer.load_state_dict(ckpt['quantizer']) + if with_encoder: + self.encoder = Encoder(self.h) + self.encoder.load_state_dict(ckpt['encoder']) + + def forward(self, x): + # x is the codebook + # x.shape (B, T, Nq) + quant_emb = self.quantizer.embed(x) + return self.generator(quant_emb) + + def encode(self, x): + batch_size = x.size(0) + if len(x.shape) == 3 and x.shape[-1] == 1: + x = x.squeeze(-1) + c = self.encoder(x.unsqueeze(1)) + q, loss_q, c = self.quantizer(c) + c = [code.reshape(batch_size, -1) for code in c] + # shape: [N, T, 4] + return torch.stack(c, -1) diff --git a/inspiremusic/text/abs_tokenizer.py b/inspiremusic/text/abs_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..050e3a79bc9f195b6bc069b71327d69651fa2d1b --- /dev/null +++ b/inspiremusic/text/abs_tokenizer.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
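[editor's note] An illustrative round trip through the VQVAE wrapper defined just above: encode a waveform to discrete token indices, then decode them back through the quantizer embeddings and the HiFi-GAN-style Generator. The paths 'config.json' and 'model.pt' are placeholders for a matching config/checkpoint pair, not files shipped in this patch.

    import torch
    from inspiremusic.music_tokenizer.vqvae import VQVAE

    model = VQVAE("config.json", "model.pt", with_encoder=True).eval()
    wav = torch.randn(1, 24000)              # (B, T) raw mono audio
    with torch.no_grad():
        codes = model.encode(wav)            # (B, n_frames, n_codebooks) integer token indices
        recon = model(codes)                 # (B, 1, T') waveform decoded from the tokens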
+ +from abc import ABC +from abc import abstractmethod +from typing import Iterable +from typing import List + + +class AbsTokenizer(ABC): + @abstractmethod + def text2tokens(self, line: str) -> List[str]: + raise NotImplementedError + + @abstractmethod + def tokens2text(self, tokens: Iterable[str]) -> str: + raise NotImplementedError + + + + def encode(self, line: str, **kwargs) -> List[str]: + + return self.text2tokens(line) \ No newline at end of file diff --git a/inspiremusic/text/tokenizer.py b/inspiremusic/text/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2978783ce1833b797dad58c35ebebe64a492cf --- /dev/null +++ b/inspiremusic/text/tokenizer.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +import re +from typing import Iterable, List, Union +import numpy as np +import torch + +from inspiremusic.text.abs_tokenizer import AbsTokenizer +from transformers import AutoTokenizer + +def get_tokenizer(tokenizer_name, tokenizer_path): + if "qwen" in tokenizer_name: + return QwenTokenizer(tokenizer_path,skip_special_tokens=True) + else: + return None + +class QwenTokenizer(AbsTokenizer): + def __init__( + self, + token_path: str, + skip_special_tokens: bool = True, + ): + super().__init__() + # NOTE: non-chat model, all these special tokens keep randomly initialized. + special_tokens = { + 'eos_token': '<|endoftext|>', + 'pad_token': '<|endoftext|>', + 'additional_special_tokens': [ + '<|im_start|>', '<|im_end|>', '<|endofprompt|>', + '[breath]', '', '', '[noise]', + '[laughter]', '[cough]', '[clucking]', '[accent]', + '[quick_breath]', + ] + } + self.tokenizer = AutoTokenizer.from_pretrained(token_path) + self.tokenizer.add_special_tokens(special_tokens) + self.skip_special_tokens = skip_special_tokens + + def get_vocab_size(self): + return self.tokenizer.vocab_size + + def text2tokens(self, line: str) -> List: + tokens = self.tokenizer([line], return_tensors="pt") + tokens = tokens["input_ids"][0].cpu().tolist() + return tokens + + def tokens2text(self, tokens) -> str: + tokens = torch.tensor(tokens, dtype=torch.int64) + text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0] + return text + + + +def get_qwen_vocab_size(token_type: str): + if "qwen1.5" in token_type.lower() or "qwen2.0" in token_type.lower() or "qwen2.5" in token_type.lower(): + # 293 for special and extra tokens, including endoftext, im_start, im_end, endofprompt and others in the future. + # model.vocab_size = 151936, tokenizer.vocab_size = 151643 + # NOTE: the first three special tokens (endoftext, im_start, im_end) are trained in Chat series models, + # others are kept in random initialization state. 
+ return 151643 + 293 + else: + raise ValueError(f"Unknown tokenizer {token_type}") \ No newline at end of file diff --git a/inspiremusic/transformer/__init__.py b/inspiremusic/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/transformer/activation.py b/inspiremusic/transformer/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..b87727d9d1d0ba6a0aeea5e7df21674f99926787 --- /dev/null +++ b/inspiremusic/transformer/activation.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) +# 2020 Northwestern Polytechnical University (Pengcheng Guo) +# 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swish() activation function for Conformer.""" + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swish activation function.""" + return x * torch.sigmoid(x) + + +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. 
+ Snake โˆถ= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/inspiremusic/transformer/attention.py b/inspiremusic/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf960e96ece8cc624c8c4b1ccd71a42910bfb62 --- /dev/null +++ b/inspiremusic/transformer/attention.py @@ -0,0 +1,328 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-Head Attention layer definition.""" + +import math +from typing import Tuple + +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + """Construct an MultiHeadedAttention object.""" + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). 
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. pytorch training + if mask.size(2) > 0: # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + InspireMusic. + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + """ + q, k, v = self.forward_qkv(query, key, value) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. 
+ # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask), new_cache + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, + n_head: int, + n_feat: int, + dropout_rate: float, + key_bias: bool = True): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate, key_bias) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + return x + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). 
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split(cache, + cache.size(-1) // 2, + dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used + if matrix_ac.shape != matrix_bd.shape: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/inspiremusic/transformer/convolution.py b/inspiremusic/transformer/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d96149154776000991a681a666fbe55e562fe --- /dev/null +++ b/inspiremusic/transformer/convolution.py @@ -0,0 +1,145 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from ESPnet(https://github.com/espnet/espnet) +"""ConvolutionModule definition.""" + +from typing import Tuple + +import torch +from torch import nn + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model.""" + + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + super().__init__() + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. + # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + assert norm in ['batch_norm', 'layer_norm'] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = nn.BatchNorm1d(channels) + else: + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + x = torch.cat((cache, x), dim=2) + assert (x.size(2) > self.lorder) + new_cache = x[:, :, -self.lorder:] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. 
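+            # NOTE: an illustrative usage sketch of the streaming cache
+            # contract, assuming causal=True and kernel_size=5 (lorder=4);
+            # tensors follow the (#batch, time, channels) interface:
+            # >>> conv = ConvolutionModule(channels=8, kernel_size=5, causal=True)
+            # >>> y1, cache = conv(torch.randn(1, 10, 8))   # cache: (1, 8, 4)
+            # >>> y2, cache = conv(torch.randn(1, 10, 8), cache=cache)
+            # In this non-causal branch the returned cache is only the
+            # zero-shaped placeholder created below, and callers may ignore it.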
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + return x.transpose(1, 2), new_cache diff --git a/inspiremusic/transformer/decoder.py b/inspiremusic/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebf3b5109dc01149874d5c9f8c6414474d303bb --- /dev/null +++ b/inspiremusic/transformer/decoder.py @@ -0,0 +1,396 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Decoder definition.""" +from typing import Tuple, List, Optional + +import torch +import torch.utils.checkpoint as ckpt +import logging + +from inspiremusic.transformer.decoder_layer import DecoderLayer +from inspiremusic.transformer.positionwise_feed_forward import PositionwiseFeedForward +from inspiremusic.utils.class_utils import ( + INSPIREMUSIC_EMB_CLASSES, + INSPIREMUSIC_ATTENTION_CLASSES, + INSPIREMUSIC_ACTIVATION_CLASSES, +) +from inspiremusic.utils.mask import (subsequent_mask, make_pad_mask) + + +class TransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + src_attention: if false, encoder-decoder cross attention is not + applied, such as CIF model + key_bias: whether use bias in attention.linear_k, False for whisper models. + gradient_checkpointing: rerunning a forward-pass segment for each + checkpointed segment during backward. 
+ tie_word_embedding: Tie or clone module weights depending of whether we are + using TorchScript or not + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + normalize_before: bool = True, + src_attention: bool = True, + key_bias: bool = True, + activation_type: str = "relu", + gradient_checkpointing: bool = False, + tie_word_embedding: bool = False, + ): + super().__init__() + attention_dim = encoder_output_size + activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]() + + self.embed = torch.nn.Sequential( + torch.nn.Identity() if input_layer == "no_pos" else + torch.nn.Embedding(vocab_size, attention_dim), + INSPIREMUSIC_EMB_CLASSES[input_layer](attention_dim, + positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5) + self.use_output_layer = use_output_layer + if use_output_layer: + self.output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.output_layer = torch.nn.Identity() + self.num_blocks = num_blocks + self.decoders = torch.nn.ModuleList([ + DecoderLayer( + attention_dim, + INSPIREMUSIC_ATTENTION_CLASSES["selfattn"]( + attention_heads, attention_dim, + self_attention_dropout_rate, key_bias), + INSPIREMUSIC_ATTENTION_CLASSES["selfattn"]( + attention_heads, attention_dim, src_attention_dropout_rate, + key_bias) if src_attention else None, + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate, activation), + dropout_rate, + normalize_before, + ) for _ in range(self.num_blocks) + ]) + + self.gradient_checkpointing = gradient_checkpointing + self.tie_word_embedding = tie_word_embedding + + def forward( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + r_ys_in_pad: torch.Tensor = torch.empty(0), + reverse_weight: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward decoder. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: not used in transformer decoder, in order to unify api + with bidirectional decoder + reverse_weight: not used in transformer decoder, in order to unify + api with bidirectional decode + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + torch.tensor(0.0), in order to unify api with bidirectional decoder + olens: (batch, ) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. 
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + tgt = ys_in_pad + maxlen = tgt.size(1) + # tgt_mask: (B, 1, L) + tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1) + tgt_mask = tgt_mask.to(tgt.device) + # m: (1, L, L) + m = subsequent_mask(tgt_mask.size(-1), + device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + tgt_mask = tgt_mask & m + x, _ = self.embed(tgt) + if self.gradient_checkpointing and self.training: + x = self.forward_layers_checkpointed(x, tgt_mask, memory, + memory_mask) + else: + x = self.forward_layers(x, tgt_mask, memory, memory_mask) + if self.normalize_before: + x = self.after_norm(x) + if self.use_output_layer: + x = self.output_layer(x) + olens = tgt_mask.sum(1) + return x, torch.tensor(0.0), olens + + def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, + memory_mask) + return x + + @torch.jit.unused + def forward_layers_checkpointed(self, x: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + x, tgt_mask, memory, memory_mask = ckpt.checkpoint( + layer.__call__, x, tgt_mask, memory, memory_mask) + return x + + def forward_one_step( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + This is only used for decoding. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + """ + x, _ = self.embed(tgt) + new_cache = [] + for i, decoder in enumerate(self.decoders): + if cache is None: + c = None + else: + c = cache[i] + x, tgt_mask, memory, memory_mask = decoder(x, + tgt_mask, + memory, + memory_mask, + cache=c) + new_cache.append(x) + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.use_output_layer: + y = torch.log_softmax(self.output_layer(y), dim=-1) + return y, new_cache + + def tie_or_clone_weights(self, jit_mode: bool = True): + """Tie or clone module weights (between word_emb and output_layer) + depending of whether we are using TorchScript or not""" + if not self.use_output_layer: + return + if jit_mode: + logging.info("clone emb.weight to output.weight") + self.output_layer.weight = torch.nn.Parameter( + self.embed[0].weight.clone()) + else: + logging.info("tie emb.weight with output.weight") + self.output_layer.weight = self.embed[0].weight + + if getattr(self.output_layer, "bias", None) is not None: + self.output_layer.bias.data = torch.nn.functional.pad( + self.output_layer.bias.data, + ( + 0, + self.output_layer.weight.shape[0] - + self.output_layer.bias.shape[0], + ), + "constant", + 0, + ) + + +class BiTransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. 
+ Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + r_num_blocks: the number of right to left decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + key_bias: whether use bias in attention.linear_k, False for whisper models. + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + r_num_blocks: int = 0, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + normalize_before: bool = True, + key_bias: bool = True, + gradient_checkpointing: bool = False, + tie_word_embedding: bool = False, + ): + + super().__init__() + self.tie_word_embedding = tie_word_embedding + self.left_decoder = TransformerDecoder( + vocab_size, + encoder_output_size, + attention_heads, + linear_units, + num_blocks, + dropout_rate, + positional_dropout_rate, + self_attention_dropout_rate, + src_attention_dropout_rate, + input_layer, + use_output_layer, + normalize_before, + key_bias=key_bias, + gradient_checkpointing=gradient_checkpointing, + tie_word_embedding=tie_word_embedding) + + self.right_decoder = TransformerDecoder( + vocab_size, + encoder_output_size, + attention_heads, + linear_units, + r_num_blocks, + dropout_rate, + positional_dropout_rate, + self_attention_dropout_rate, + src_attention_dropout_rate, + input_layer, + use_output_layer, + normalize_before, + key_bias=key_bias, + gradient_checkpointing=gradient_checkpointing, + tie_word_embedding=tie_word_embedding) + + def forward( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + r_ys_in_pad: torch.Tensor, + reverse_weight: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward decoder. 
+ Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out), + used for right to left decoder + reverse_weight: used for right to left decoder + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + r_x: x: decoded token score (right to left decoder) + before softmax (batch, maxlen_out, vocab_size) + if use_output_layer is True, + olens: (batch, ) + """ + l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, + ys_in_lens) + r_x = torch.tensor(0.0) + if reverse_weight > 0.0: + r_x, _, olens = self.right_decoder(memory, memory_mask, + r_ys_in_pad, ys_in_lens) + return l_x, r_x, olens + + def forward_one_step( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + This is only used for decoding. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + """ + return self.left_decoder.forward_one_step(memory, memory_mask, tgt, + tgt_mask, cache) + + def tie_or_clone_weights(self, jit_mode: bool = True): + """Tie or clone module weights (between word_emb and output_layer) + depending of whether we are using TorchScript or not""" + self.left_decoder.tie_or_clone_weights(jit_mode) + self.right_decoder.tie_or_clone_weights(jit_mode) diff --git a/inspiremusic/transformer/decoder_layer.py b/inspiremusic/transformer/decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..91c7c5d7fb2a8e79cea7705646e5381016f73466 --- /dev/null +++ b/inspiremusic/transformer/decoder_layer.py @@ -0,0 +1,132 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Decoder self-attention layer definition.""" +from typing import Optional, Tuple + +import torch +from torch import nn + + +class DecoderLayer(nn.Module): + """Single decoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + src_attn (torch.nn.Module): Inter-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. 
+ If `None` is passed, Inter-attention is not used, such as + CIF, GPT, and other decoder only model. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. + """ + + def __init__( + self, + size: int, + self_attn: nn.Module, + src_attn: Optional[nn.Module], + feed_forward: nn.Module, + dropout_rate: float, + normalize_before: bool = True, + ): + """Construct an DecoderLayer object.""" + super().__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-5) + self.norm2 = nn.LayerNorm(size, eps=1e-5) + self.norm3 = nn.LayerNorm(size, eps=1e-5) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + + def forward( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor, + cache: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor + (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory + (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask + (#batch, maxlen_in). + cache (torch.Tensor): cached tensors. + (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). + + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = tgt_mask[:, -1:, :] + + x = residual + self.dropout( + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) + if not self.normalize_before: + x = self.norm1(x) + + if self.src_attn is not None: + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout( + self.src_attn(x, memory, memory, memory_mask)[0]) + if not self.normalize_before: + x = self.norm2(x) + + residual = x + if self.normalize_before: + x = self.norm3(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask diff --git a/inspiremusic/transformer/embedding.py b/inspiremusic/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..eae8c8ecabb15b4174cc3aa73c070ae702bb5f82 --- /dev/null +++ b/inspiremusic/transformer/embedding.py @@ -0,0 +1,294 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Positonal Encoding Module.""" + +import math +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import numpy as np + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.max_len = max_len + + self.pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + self.pe[:, 0::2] = torch.sin(position * div_term) + self.pe[:, 1::2] = torch.cos(position * div_term) + self.pe = self.pe.unsqueeze(0) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + offset (int, torch.tensor): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ + + self.pe = self.pe.to(x.device) + pos_emb = self.position_encoding(offset, x.size(1), False) + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int, + apply_dropout: bool = True) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. 
+ + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + if isinstance(offset, int): + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + else: # for batched streaming decoding on GPU + assert torch.max(offset) + size <= self.max_len + index = offset.unsqueeze(1) + \ + torch.arange(0, size).to(offset.device) # B X T + flag = index > 0 + # remove negative offset + index = index * flag + pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model + + if apply_dropout: + pos_emb = self.dropout(pos_emb) + return pos_emb + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + self.pe = self.pe.to(x.device) + x = x * self.xscale + pos_emb = self.position_encoding(offset, x.size(1), False) + return self.dropout(x), self.dropout(pos_emb) + + +class WhisperPositionalEncoding(PositionalEncoding): + """ Sinusoids position encoding used in openai-whisper.encoder + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): + super().__init__(d_model, dropout_rate, max_len) + self.xscale = 1.0 + log_timescale_increment = np.log(10000) / (d_model // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * + torch.arange(d_model // 2)) + scaled_time = torch.arange(max_len)[:, np.newaxis] * \ + inv_timescales[np.newaxis, :] + pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + delattr(self, "pe") + self.register_buffer("pe", pe.unsqueeze(0)) + + +class LearnablePositionalEncoding(PositionalEncoding): + """ Learnable position encoding used in openai-whisper.decoder + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): + super().__init__(d_model, dropout_rate, max_len) + # NOTE(xcsong): overwrite self.pe & self.xscale + self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model)) + self.xscale = 1.0 + + +class NoPositionalEncoding(torch.nn.Module): + """ No position encoding + """ + + def __init__(self, d_model: int, dropout_rate: float): + super().__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """ Just return zero vector for interface compatibility + """ + pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) + return self.dropout(x), pos_emb + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return 
torch.zeros(1, size, self.d_model)
+
+
+class EspnetRelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module (new implementation).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    See : Appendix B in https://arxiv.org/abs/1901.02860
+
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+        """Construct an EspnetRelPositionalEncoding object."""
+        super(EspnetRelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: torch.Tensor):
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means the position of the query vector and `j` means the
+        # position of the key vector. We use positive relative positions when
+        # keys are to the left (i>j) and negative relative positions otherwise
+        # (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+            -(math.log(10000.0) / self.d_model))
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self,
+                x: torch.Tensor,
+                offset: Union[int, torch.Tensor] = 0) \
+            -> Tuple[torch.Tensor, torch.Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+
+        """
+        self.extend_pe(x)
+        x = x * self.xscale
+        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
+        return self.dropout(x), self.dropout(pos_emb)
+
+    def position_encoding(self,
+                          offset: Union[int, torch.Tensor],
+                          size: int) -> torch.Tensor:
+        """ For getting encoding in a streaming fashion
+
+        Attention!!!!!
+        we apply dropout only once at the whole utterance level in a
+        non-streaming way, but will call this function several times with
+        increasing input size in a streaming scenario, so the dropout will
+        be applied several times.
+
+        Args:
+            offset (int or torch.tensor): start offset
+            size (int): required size of position encoding
+
+        Returns:
+            torch.Tensor: Corresponding encoding
+        """
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
+        ]
+        return pos_emb
diff --git a/inspiremusic/transformer/encoder.py b/inspiremusic/transformer/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a46b531778f3c012f0a7e4f5da3c1d6f707c358d
--- /dev/null
+++ b/inspiremusic/transformer/encoder.py
@@ -0,0 +1,477 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#               2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder definition.""" +from typing import Tuple + +import torch +import torch.utils.checkpoint as ckpt + +from inspiremusic.transformer.convolution import ConvolutionModule +from inspiremusic.transformer.encoder_layer import TransformerEncoderLayer +from inspiremusic.transformer.encoder_layer import ConformerEncoderLayer +from inspiremusic.transformer.positionwise_feed_forward import PositionwiseFeedForward +from inspiremusic.utils.class_utils import ( + INSPIREMUSIC_EMB_CLASSES, + INSPIREMUSIC_SUBSAMPLE_CLASSES, + INSPIREMUSIC_ATTENTION_CLASSES, + INSPIREMUSIC_ACTIVATION_CLASSES, +) +from inspiremusic.utils.mask import make_pad_mask +from inspiremusic.utils.mask import add_optional_chunk_mask + + +class BaseEncoder(torch.nn.Module): + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + gradient_checkpointing: bool = False, + ): + """ + Args: + input_size (int): input dim + output_size (int): dimension of attention + attention_heads (int): the number of heads of multi head attention + linear_units (int): the hidden units number of position-wise feed + forward + num_blocks (int): the number of decoder blocks + dropout_rate (float): dropout rate + attention_dropout_rate (float): dropout rate in attention + positional_dropout_rate (float): dropout rate after adding + positional encoding + input_layer (str): input layer type. + optional [linear, conv2d, conv2d6, conv2d8] + pos_enc_layer_type (str): Encoder positional encoding layer type. + opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] + normalize_before (bool): + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + static_chunk_size (int): chunk size for static chunk training and + decoding + use_dynamic_chunk (bool): whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dyanmic chunk size(use_dynamic_chunk = True) + global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module + use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + key_bias: whether use bias in attention.linear_k, False for whisper models. + gradient_checkpointing: rerunning a forward-pass segment for each + checkpointed segment during backward. 
+ """ + super().__init__() + self._output_size = output_size + + self.global_cmvn = global_cmvn + self.embed = INSPIREMUSIC_SUBSAMPLE_CLASSES[input_layer]( + input_size, + output_size, + dropout_rate, + INSPIREMUSIC_EMB_CLASSES[pos_enc_layer_type](output_size, + positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.gradient_checkpointing = gradient_checkpointing + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. + https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + + if self.gradient_checkpointing and self.training: + xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, + mask_pad) + else: + xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + return xs + + @torch.jit.unused + def forward_layers_checkpointed(self, xs: torch.Tensor, + chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + for layer in self.encoders: + xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs, + chunk_masks, pos_emb, + mask_pad) + return xs + + @torch.jit.export + def forward_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> 
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.size(0), att_cache.size(2) + chunk_size = xs.size(1) + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, + size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + + xs, _, new_att_cache, new_cnn_cache = layer( + xs, + att_mask, + pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) + if self.normalize_before: + xs = self.after_norm(xs) + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? 
may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.cat(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=0) + + return (xs, r_att_cache, r_cnn_cache) + + @torch.jit.unused + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not preferred. + Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + # The model is trained by static or dynamic chunk + assert self.static_chunk_size > 0 or self.use_dynamic_chunk + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, att_cache, + cnn_cache) = self.forward_chunk(chunk_xs, offset, + required_cache_size, att_cache, + cnn_cache) + outputs.append(y) + offset += y.size(1) + ys = torch.cat(outputs, 1) + masks = torch.ones((1, 1, ys.size(1)), + device=ys.device, + dtype=torch.bool) + return ys, masks + + +class TransformerEncoder(BaseEncoder): + """Transformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + key_bias: bool = True, + selfattention_layer_type: str = "selfattn", + activation_type: str = "relu", + gradient_checkpointing: bool = False, + ): + """ Construct TransformerEncoder + + See Encoder for the meaning of each parameter. 
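+
+        Example (illustrative sketch, assuming the default `conv2d` input
+        layer subsamples the time axis by roughly 4x):
+
+            >>> encoder = TransformerEncoder(input_size=80, output_size=256)
+            >>> xs = torch.randn(2, 100, 80)        # (B, T, input_size)
+            >>> xs_lens = torch.tensor([100, 80])
+            >>> ys, masks = encoder(xs, xs_lens)    # ys: (B, T', 256), T' ~= T / 4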
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) + activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]() + self.encoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + output_size, + INSPIREMUSIC_ATTENTION_CLASSES[selfattention_layer_type](attention_heads, + output_size, + attention_dropout_rate, + key_bias), + PositionwiseFeedForward(output_size, linear_units, + dropout_rate, activation), + dropout_rate, normalize_before) for _ in range(num_blocks) + ]) + + +class ConformerEncoder(BaseEncoder): + """Conformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + key_bias: bool = True, + gradient_checkpointing: bool = False, + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + key_bias: whether use bias in attention.linear_k, False for whisper models. 
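+
+        Example (illustrative sketch; assumes the default `conv2d` subsampling
+        module exposes `subsampling_rate` and `right_context`, which
+        `forward_chunk_by_chunk` relies on):
+
+            >>> encoder = ConformerEncoder(input_size=80, output_size=256,
+            ...                            static_chunk_size=16, causal=True)
+            >>> xs = torch.randn(1, 400, 80)
+            >>> ys, masks = encoder.forward_chunk_by_chunk(xs, decoding_chunk_size=16)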
+ """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing) + activation = INSPIREMUSIC_ACTIVATION_CLASSES[activation_type]() + + # self-attention module definition + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + key_bias, + ) + # feed-forward module definition + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal) + + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + INSPIREMUSIC_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + PositionwiseFeedForward(*positionwise_layer_args), + PositionwiseFeedForward( + *positionwise_layer_args) if macaron_style else None, + ConvolutionModule( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + ) for _ in range(num_blocks) + ]) diff --git a/inspiremusic/transformer/encoder_layer.py b/inspiremusic/transformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb3a9d4e99d0f8f92ec1802a1a7620328e9353a --- /dev/null +++ b/inspiremusic/transformer/encoder_layer.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder self-attention layer definition.""" + +from typing import Optional, Tuple + +import torch +from torch import nn + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. 
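+
+    Example (illustrative sketch of wiring a layer together):
+
+        >>> attn = MultiHeadedAttention(4, 256, 0.1)
+        >>> ff = PositionwiseFeedForward(256, 2048, 0.1, torch.nn.ReLU())
+        >>> layer = TransformerEncoderLayer(256, attn, ff, dropout_rate=0.1)
+        >>> x = torch.randn(2, 10, 256)
+        >>> mask = torch.ones(2, 10, 10, dtype=torch.bool)
+        >>> y, mask, att_cache, _ = layer(x, mask, pos_emb=torch.empty(0))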
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: torch.nn.Module, + dropout_rate: float, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-5) + self.norm2 = nn.LayerNorm(size, eps=1e-5) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time๏ผŒtime), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): just for interface compatibility + to ConformerEncoderLayer + mask_pad (torch.Tensor): does not used in transformer layer, + just for unified api with conformer. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2), not used here, it's for interface + compatibility to ConformerEncoderLayer. + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). + + """ + residual = x + if self.normalize_before: + x = self.norm1(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + return x, mask, new_att_cache, fake_cnn_cache + + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. 
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module + self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module + self.norm_final = nn.LayerNorm( + size, eps=1e-5) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time๏ผŒtime), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1๏ผŒtime), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
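+
+        Note:
+            The sub-blocks below run in the order: macaron feed-forward
+            (if any) -> self-attention -> convolution (if any) ->
+            feed-forward, and the final LayerNorm is applied only when a
+            convolution module is present.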
+ """ + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, + att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/inspiremusic/transformer/label_smoothing_loss.py b/inspiremusic/transformer/label_smoothing_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..73ce09f35bfacb86730e39ef72a097f8a04e469b --- /dev/null +++ b/inspiremusic/transformer/label_smoothing_loss.py @@ -0,0 +1,97 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Label smoothing module.""" + +import torch +from torch import nn + + +class LabelSmoothingLoss(nn.Module): + """Label-smoothing loss. + + In a standard CE loss, the label's data distribution is: + [0,1,2] -> + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + + In the smoothing version CE Loss,some probabilities + are taken from the true label prob (1.0) and are divided + among other labels. + + e.g. 
+ smoothing=0.1 + [0,1,2] -> + [ + [0.9, 0.05, 0.05], + [0.05, 0.9, 0.05], + [0.05, 0.05, 0.9], + ] + + Args: + size (int): the number of class + padding_idx (int): padding class id which will be ignored for loss + smoothing (float): smoothing rate (0.0 means the conventional CE) + normalize_length (bool): + normalize loss by sequence length if True + normalize loss by batch size if False + """ + + def __init__(self, + size: int, + padding_idx: int, + smoothing: float, + normalize_length: bool = False): + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = nn.KLDivLoss(reduction="none") + self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.normalize_length = normalize_length + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Compute loss between x and target. + + The model outputs and data labels tensors are flatten to + (batch*seqlen, class) shape and a mask is applied to the + padding part which should not be calculated for loss. + + Args: + x (torch.Tensor): prediction (batch, seqlen, class) + target (torch.Tensor): + target signal masked with self.padding_id (batch, seqlen) + Returns: + loss (torch.Tensor) : The KL loss, scalar float value + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + # use zeros_like instead of torch.no_grad() for true_dist, + # since no_grad() can not be exported by JIT + true_dist = torch.zeros_like(x) + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom diff --git a/inspiremusic/transformer/positionwise_feed_forward.py b/inspiremusic/transformer/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a2cf6e7315e3a5ed2794423daff0a59cc5b208 --- /dev/null +++ b/inspiremusic/transformer/positionwise_feed_forward.py @@ -0,0 +1,115 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. 
+ activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.activation = activation + self.dropout = torch.nn.Dropout(dropout_rate) + self.w_2 = torch.nn.Linear(hidden_units, idim) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) + + +class MoEFFNLayer(torch.nn.Module): + """ + Mixture of expert with Positionwise feed forward layer + See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf + The output dim is same with the input dim. + + Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 + https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 + Args: + n_expert: number of expert. + n_expert_per_token: The actual number of experts used for each frame + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (torch.nn.Module): Activation function + """ + + def __init__( + self, + n_expert: int, + n_expert_per_token: int, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + ): + super(MoEFFNLayer, self).__init__() + self.gate = torch.nn.Linear(idim, n_expert, bias=False) + self.experts = torch.nn.ModuleList( + PositionwiseFeedForward(idim, hidden_units, dropout_rate, + activation) for _ in range(n_expert)) + self.n_expert_per_token = n_expert_per_token + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Foward function. + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + + """ + B, L, D = xs.size( + ) # batch size, sequence length, embedding dimension (idim) + xs = xs.view(-1, D) # (B*L, D) + router = self.gate(xs) # (B*L, n_expert) + logits, indices = torch.topk( + router, self.n_expert_per_token + ) # probs:(B*L, n_expert), indices: (B*L, n_expert) + weights = torch.nn.functional.softmax( + logits, dim=1, + dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) + output = torch.zeros_like(xs) # (B*L, D) + for i, expert in enumerate(self.experts): + mask = indices == i + batch_idx, ith_expert = torch.where(mask) + output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( + xs[batch_idx]) + return output.view(B, L, D) diff --git a/inspiremusic/transformer/qwen_encoder.py b/inspiremusic/transformer/qwen_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..92e3305ee9a72b370aba87634a30a2a3b0c7ff83 --- /dev/null +++ b/inspiremusic/transformer/qwen_encoder.py @@ -0,0 +1,165 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
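+
+# This module wraps Qwen2 causal language models as text encoders for
+# InspireMusic. Illustrative usage sketch (the checkpoint name and hidden
+# size below are assumptions, not pinned by this file; loading uses
+# flash_attention_2, so the flash-attn package must be available):
+#
+#     encoder = QwenEmbeddingEncoder(input_size=896,
+#                                    pretrain_path="Qwen/Qwen2-0.5B")
+#     hidden, masks = encoder(input_embeds, ilens)  # (B, T, hidden), (B, T)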
+ +import torch.nn as nn +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from inspiremusic.utils.mask import make_pad_mask +from inspiremusic.utils.hinter import hint_once + +class QwenEncoder(nn.Module): + def __init__( + self, + input_size: int, + pretrain_path: str = "Qwen/Qwen2.0-0.5B", + trainable: bool = False, + do_fusion_emb: bool = False, + fusion_drop_rate: float = 0.0, + ): + super(QwenEncoder, self).__init__() + self.input_size = input_size + self.trainable = trainable + self.model = AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="cpu") + self._output_size = self.model.config.hidden_size + self.do_fusion_emb = do_fusion_emb + self.hidden_norm = torch.nn.LayerNorm(self._output_size) + self.fusion_dropout = nn.Dropout(fusion_drop_rate) + if do_fusion_emb: + self.fusion_layer = torch.nn.Linear(self._output_size * 2, self._output_size) + self.emb_norm = torch.nn.LayerNorm(self._output_size) + self.fusion_norm = torch.nn.LayerNorm(self._output_size) + from inspiremusic.transformer.activation import Swish + self.fusion_act = Swish(self) + + if not self.trainable: + self.model.eval() + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + input_ids: torch.Tensor, + ilens: torch.Tensor, + ): + device = input_ids.device + input_ids = torch.clamp(input_ids, min=0, max=None) + input_masks = (~make_pad_mask(ilens)).to(device).long() + if not self.trainable: + with torch.no_grad(): + model_outputs = self.model( + input_ids=input_ids, + attention_mask=input_masks, + output_hidden_states=True + ) + else: + model_outputs = self.model( + input_ids=input_ids, + attention_mask=input_masks, + output_hidden_states=True + ) + outs = model_outputs.hidden_states[-1] + outs = self.hidden_norm(outs) + if self.do_fusion_emb: + hint_once("fuse embedding and LM outputs", "fuse_emb") + outs = self.fusion_dropout(self.fusion_act(outs)) + emb = model_outputs.hidden_states[0] + emb = self.fusion_dropout(self.fusion_act(self.emb_norm(emb))) + outs = self.fusion_layer( + torch.cat([outs, emb], dim=-1) + ) + outs = self.fusion_act(self.fusion_norm(outs)) + + return outs, ilens + + +class QwenEmbeddingEncoder(nn.Module): + def __init__( + self, + input_size: int, + pretrain_path: str = "Qwen/Qwen2.0-0.5B", + ): + super(QwenEmbeddingEncoder, self).__init__() + self.input_size = input_size + from transformers import Qwen2ForCausalLM + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2") + self._output_size = self.model.config.hidden_size + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + input_embeds: torch.Tensor, + ilens: torch.Tensor, + ): + input_masks = (~make_pad_mask(ilens)).to(input_embeds.device).long() + + outs = self.model( + inputs_embeds=input_embeds, + attention_mask=input_masks, + output_hidden_states=True, + return_dict=True, + ) + + return outs.hidden_states[-1], input_masks + + def forward_one_step(self, xs, masks, cache=None): + + outs = self.model( + inputs_embeds=xs, + attention_mask=masks, + output_hidden_states=True, + return_dict=True, + use_cache=True, + past_key_values=cache, + ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + + return xs, masks, new_cache + + +class QwenInputOnlyEncoder(nn.Module): + def __init__( + self, + input_size: int, + pretrain_path: str = "Qwen/Qwen2.0-0.5B", + ): + super(QwenInputOnlyEncoder, self).__init__() + self.input_size = input_size + from transformers import 
Qwen2ForCausalLM + model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2") + self.embed = model.model.embed_tokens + for p in self.embed.parameters(): + p.requires_grad = False + # set text embedding to non-trainable + + # self.post_embed = model.model.rotary_emb + self._output_size = model.config.hidden_size + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + input_ids: torch.Tensor, + ilens: torch.Tensor, + ): + input_masks = (~make_pad_mask(ilens)).to(input_ids.device).long() + + outs = self.embed(input_ids) + + return outs, input_masks + \ No newline at end of file diff --git a/inspiremusic/transformer/subsampling.py b/inspiremusic/transformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..08cf7d4224c0f4bf95853590f7e2f97b387f44f9 --- /dev/null +++ b/inspiremusic/transformer/subsampling.py @@ -0,0 +1,384 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2024 Alibaba Inc (Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Subsampling layer definition.""" + +from typing import Tuple, Union + +import torch + + +class BaseSubsampling(torch.nn.Module): + + def __init__(self): + super().__init__() + self.right_context = 0 + self.subsampling_rate = 1 + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return self.pos_enc.position_encoding(offset, size) + + +class EmbedinigNoSubsampling(BaseSubsampling): + """Embedding input without subsampling + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + super().__init__() + self.embed = torch.nn.Embedding(idim, odim) + self.pos_enc = pos_enc_class + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.embed(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class LinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class Conv1dSubsampling2(BaseSubsampling): + """Convolutional 1D subsampling (to 1/2 length). + It is designed for Whisper, ref: + https://github.com/openai/whisper/blob/main/whisper/model.py + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv1dSubsampling2 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1), + torch.nn.GELU(), + torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1), + torch.nn.GELU(), + ) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 2 + # 4 = (3 - 1) * 1 + (3 - 1) * 1 + self.right_context = 4 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 2. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 2. + torch.Tensor: positional encoding + + """ + time = x.size(1) + x = x.transpose(1, 2) # (b, f, t) + x = self.conv(x) + x = x.transpose(1, 2) # (b, t, f) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, (time + 1) % 2::2] + + +class Conv2dSubsampling4(BaseSubsampling): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
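+
+    Example (illustrative sketch; `PositionalEncoding` is the absolute
+    positional encoding defined in this package, and the feature sizes are
+    arbitrary):
+        >>> sub = Conv2dSubsampling4(80, 256, 0.1, PositionalEncoding(256, 0.1))
+        >>> x = torch.randn(1, 100, 80)
+        >>> mask = torch.ones(1, 1, 100, dtype=torch.bool)
+        >>> y, pos_emb, y_mask = sub(x, mask)  # y: (1, 24, 256)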
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling4 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + torch.Tensor: positional encoding + + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2] + + +class Conv2dSubsampling6(BaseSubsampling): + """Convolutional 2D subsampling (to 1/6 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling6 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), + odim) + self.pos_enc = pos_enc_class + # 10 = (3 - 1) * 1 + (5 - 1) * 2 + self.subsampling_rate = 6 + self.right_context = 10 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3] + + +class Conv2dSubsampling8(BaseSubsampling): + """Convolutional 2D subsampling (to 1/8 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling8 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear( + odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) + self.pos_enc = pos_enc_class + self.subsampling_rate = 8 + # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 + self.right_context = 14 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2] + + +class LegacyLinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask diff --git a/inspiremusic/utils/__init__.py b/inspiremusic/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/utils/audio_utils.py b/inspiremusic/utils/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b5d2afe3fd572b4f9fdfcece8b88d99870052a --- /dev/null +++ b/inspiremusic/utils/audio_utils.py @@ -0,0 +1,623 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import io +import logging +import re +import sys +import inspect +import random +import typing as tp +from functools import partial + +import omegaconf +import torch +import torchaudio +import numpy as np + +from typing_extensions import Literal +from typing import ( + Any, + Union, + Iterable, + List, + Dict, + Optional, + Tuple, +) + +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +_BoolLike_co = Union[bool, np.bool_] +_IntLike_co = Union[_BoolLike_co, int, "np.integer[Any]"] +_FloatLike_co = Union[_IntLike_co, float, "np.floating[Any]"] + +def process_audio(file_path, target_sample_rate=24000): + audio, sample_rate = torchaudio.load(file_path) + # Check if the audio needs to be resampled + if sample_rate != target_sample_rate: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(audio) + # Convert stereo to mono (if necessary) + audio = audio.mean(dim=0, keepdim=True) if audio.size(0) == 2 else audio + return audio, target_sample_rate + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + # global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned + mel_basis = {} + hann_window = {} + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def fade_out(audio: torch.Tensor, sample_rate: int, + fade_duration: float) -> torch.Tensor: + """ + Apply a linear fade-out effect to the given audio waveform. + + Parameters: + audio (torch.Tensor): The audio waveform tensor. + sample_rate (int): Sample rate of the audio. + fade_duration (float): Duration of the fade-out effect in seconds. 
+ + Returns: + torch.Tensor: The audio with the fade-out effect applied. + """ + fade_samples = int(fade_duration * sample_rate) + + if fade_samples > audio.shape[1]: + fade_samples = audio.shape[ + 1] # use the whole length of audio if necessary + + fade_out_envelope = torch.linspace(1.0, 0.0, fade_samples, + dtype=audio.dtype, device=audio.device) + + fade_section = audio[:, -fade_samples:].clone() + + fade_section *= fade_out_envelope + + faded_audio = audio.clone() + faded_audio[:, -fade_samples:] = fade_section + + return faded_audio + +def split_wav_into_chunks(num_samples, wav, max_chunk_size, minimum_chunk_size=720): + num_chunks = (num_samples + max_chunk_size - 1) // max_chunk_size # Ceiling division + wav_chunks = [] + for i in range(num_chunks): + start_idx = i * max_chunk_size + end_idx = min(start_idx + max_chunk_size, num_samples) + if (end_idx - start_idx) >= minimum_chunk_size: + if len(wav.shape) == 2: + chunk = wav[:,start_idx:end_idx] + else: + chunk = wav[start_idx:end_idx] + wav_chunks.append(chunk) + else: + print(f"{num_samples}:{num_chunks}, chunk size={(end_idx - start_idx)} is lower then minimum_chunk_size!") + return wav_chunks + +def tiny(x: Union[float, np.ndarray]) -> _FloatLike_co: + """Compute the tiny-value corresponding to an input's data type. + """ + # Make sure we have an array view + x = np.asarray(x) + + # Only floating types generate a tiny + if np.issubdtype(x.dtype, np.floating) or np.issubdtype( + x.dtype, np.complexfloating + ): + dtype = x.dtype + else: + dtype = np.dtype(np.float32) + + return np.finfo(dtype).tiny + +def detect_silence(audio, sample_rate, threshold=0.05, min_silence_duration=1): + """ + Detects the first occurrence of silence in the audio. + + Parameters: + audio (Tensor): The audio waveform. + sample_rate (int): The sample rate of the audio. + threshold (float): The threshold below which the signal is considered silent. + min_silence_duration (float): The minimum duration of silence in seconds. + + Returns: + int: The timestamp (in samples) where the silence starts. + """ + # Convert the audio to a numpy array for easier manipulation + audio_np = audio.numpy().flatten() + # Calculate the energy of the signal + energy = np.abs(audio_np) + # Find the indices where the energy is below the threshold + silent_indices = np.where(energy < threshold)[0] + # Find the start and end of contiguous silent regions + silent_regions = np.split(silent_indices, np.where(np.diff(silent_indices) != 1)[0] + 1) + # Filter out regions that are too short + min_silence_samples = int(min_silence_duration * sample_rate) + for region in silent_regions: + if len(region) >= min_silence_samples: + return region[0] + + # If no silence is found, return the length of the audio + return len(audio_np) + +def trim_audio(waveform, sample_rate=24000, threshold=0.05, min_silence_duration=1, minimum_silence_start_sample=24000): + """ + Trims the audio from the beginning to the first occurrence of silence. + + Parameters: + waveform (Tensor): The waveform data to the input audio file. + sample_rate (int): Sample rate of the input audio file. + threshold (float): The threshold below which the signal is considered silent. + min_silence_duration (float): The minimum duration of silence in seconds. 
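+    minimum_silence_start_sample (int): The minimum number of leading samples
+    to keep even when silence is detected earlier than this point.
+
+    Returns:
+    Tensor: The trimmed waveform.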
+ """ + # Detect the first occurrence of silence + silence_start_sample = detect_silence(waveform, sample_rate, threshold, min_silence_duration) + if silence_start_sample > minimum_silence_start_sample : + trimmed_waveform = waveform[:silence_start_sample] + else: + trimmed_waveform = waveform[:minimum_silence_start_sample] + if isinstance(trimmed_waveform, torch.Tensor): + return trimmed_waveform + else: + return trimmed_waveform.unsqueeze() + +def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, energy_floor: float = 2e-3): + """Normalize an input signal to a user loudness in dB LKFS. + Audio loudness is defined according to the ITU-R BS.1770-4 recommendation. + + Args: + wav (torch.Tensor): Input multichannel audio data. + sample_rate (int): Sample rate. + loudness_headroom_db (float): Target loudness of the output in dB LUFS. + loudness_compressor (bool): Uses tanh for soft clipping. + energy_floor (float): anything below that RMS level will not be rescaled. + Returns: + torch.Tensor: Loudness normalized output data. + """ + energy = wav.pow(2).mean().sqrt().item() + if energy < energy_floor: + return wav + transform = torchaudio.transforms.Loudness(sample_rate) + input_loudness_db = transform(wav).item() + # calculate the gain needed to scale to the desired loudness level + delta_loudness = -loudness_headroom_db - input_loudness_db + gain = 10.0 ** (delta_loudness / 20.0) + output = gain * wav + if loudness_compressor: + output = torch.tanh(output) + assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt()) + return output + +def normalize( + S: np.ndarray, + *, + norm: Optional[float] = np.inf, + axis: Optional[int] = 0, + threshold: Optional[_FloatLike_co] = None, + fill: Optional[bool] = None, +) -> np.ndarray: + """Normalize an array along a chosen axis. + """ + # Avoid div-by-zero + if threshold is None: + threshold = tiny(S) + + elif threshold <= 0: + raise ParameterError(f"threshold={threshold} must be strictly positive") + + if fill not in [None, False, True]: + raise ParameterError(f"fill={fill} must be None or boolean") + + if not np.isfinite(S).all(): + raise ParameterError("Input must be finite") + + # All norms only depend on magnitude, let's do that first + S = S.numpy() + mag = np.abs(S).astype(float) + + # For max/min norms, filling with 1 works + fill_norm = 1 + + if norm is None: + return S + + elif norm == np.inf: + length = np.max(mag, axis=axis, keepdims=True) + + elif norm == -np.inf: + length = np.min(mag, axis=axis, keepdims=True) + + elif norm == 0: + if fill is True: + raise ParameterError("Cannot normalize with norm=0 and fill=True") + + length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype) + + elif np.issubdtype(type(norm), np.number) and norm > 0: + length = np.sum(mag**norm, axis=axis, keepdims=True) ** (1.0 / norm) + + if axis is None: + fill_norm = mag.size ** (-1.0 / norm) + else: + fill_norm = mag.shape[axis] ** (-1.0 / norm) + + else: + raise ParameterError(f"Unsupported norm: {repr(norm)}") + + # indices where norm is below the threshold + small_idx = length < threshold + + Snorm = np.empty_like(S) + if fill is None: + # Leave small indices un-normalized + length[small_idx] = 1.0 + Snorm[:] = S / length + + elif fill: + # If we have a non-zero fill value, we locate those entries by + # doing a nan-divide. 
+ # If S was finite, then length is finite (except for small positions) + length[small_idx] = np.nan + Snorm[:] = S / length + Snorm[np.isnan(Snorm)] = fill_norm + else: + # Set small values to zero by doing an inf-divide. + # This is safe (by IEEE-754) as long as S is finite. + length[small_idx] = np.inf + Snorm[:] = S / length + + return Snorm + +def normalize_audio(wav: torch.Tensor, normalize: bool = True, + strategy: str = 'peak', peak_clip_headroom_db: float = 1, + rms_headroom_db: float = 18, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, log_clipping: bool = False, + sample_rate: tp.Optional[int] = None, + stem_name: tp.Optional[str] = None) -> torch.Tensor: + """Normalize the audio according to the prescribed strategy (see after). + + Args: + wav (torch.Tensor): Audio data. + normalize (bool): if `True` (default), normalizes according to the prescribed + strategy (see after). If `False`, the strategy is only used in case clipping + would happen. + strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', + i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square + with extra headroom to avoid clipping. 'clip' just clips. + peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. + rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger + than the `peak_clip` one to avoid further clipping. + loudness_headroom_db (float): Target loudness for loudness normalization. + loudness_compressor (bool): If True, uses tanh based soft clipping. + log_clipping (bool): If True, basic logging on stderr when clipping still + occurs despite strategy (only for 'rms'). + sample_rate (int): Sample rate for the audio data (required for loudness). + stem_name (str, optional): Stem name for clipping logging. + Returns: + torch.Tensor: Normalized audio. + """ + scale_peak = 10 ** (-peak_clip_headroom_db / 20) + scale_rms = 10 ** (-rms_headroom_db / 20) + if strategy == 'peak': + rescaling = (scale_peak / wav.abs().max()) + if normalize or rescaling < 1: + wav = wav * rescaling + elif strategy == 'clip': + wav = wav.clamp(-scale_peak, scale_peak) + elif strategy == 'rms': + mono = wav.mean(dim=0) + rescaling = scale_rms / mono.pow(2).mean().sqrt() + if normalize or rescaling < 1: + wav = wav * rescaling + _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) + elif strategy == 'loudness': + assert sample_rate is not None, "Loudness normalization requires sample rate." + wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor) + _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) + else: + assert wav.abs().max() < 1 + assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'" + return wav + + +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """ + Convert audio to float 32 bits PCM format. + Args: + wav (torch.tensor): Input wav tensor + Returns: + same wav in float32 PCM format + """ + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / 2**15 + elif wav.dtype == torch.int32: + return wav.float() / 2**31 + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def i16_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to int 16 bits PCM format. + + ..Warning:: There exist many formula for doing this conversion. None are perfect + due to the asymmetry of the int16 range. 
One either have possible clipping, DC offset, + or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom, + it is possible that `i16_pcm(f32_pcm)) != Identity`. + Args: + wav (torch.tensor): Input wav tensor + Returns: + same wav in float16 PCM format + """ + if wav.dtype.is_floating_point: + assert wav.abs().max() <= 1 + candidate = (wav * 2 ** 15).round() + if candidate.max() >= 2 ** 15: # clipping would occur + candidate = (wav * (2 ** 15 - 1)).round() + return candidate.short() + else: + assert wav.dtype == torch.int16 + return wav + + +def compress(wav: torch.Tensor, sr: int, + target_format: tp.Literal["mp3", "ogg", "flac"] = "mp3", + bitrate: str = "128k") -> tp.Tuple[torch.Tensor, int]: + """Convert audio wave form to a specified lossy format: mp3, ogg, flac + + Args: + wav (torch.Tensor): Input wav tensor. + sr (int): Sampling rate. + target_format (str): Compression format (e.g., 'mp3'). + bitrate (str): Bitrate for compression. + + Returns: + Tuple of compressed WAV tensor and sampling rate. + """ + + # Extract the bit rate from string (e.g., '128k') + match = re.search(r"\d+(\.\d+)?", str(bitrate)) + parsed_bitrate = float(match.group()) if match else None + assert parsed_bitrate, f"Invalid bitrate specified (got {parsed_bitrate})" + try: + # Create a virtual file instead of saving to disk + buffer = io.BytesIO() + + torchaudio.save( + buffer, wav, sr, format=target_format, bits_per_sample=parsed_bitrate, + ) + # Move to the beginning of the file + buffer.seek(0) + compressed_wav, sr = torchaudio.load(buffer) + return compressed_wav, sr + + except RuntimeError: + logger.warning( + f"compression failed skipping compression: {format} {parsed_bitrate}" + ) + return wav, sr + + +def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") -> torch.Tensor: + """Convert a batch of audio files to MP3 format, maintaining the original shape. + + This function takes a batch of audio files represented as a PyTorch tensor, converts + them to MP3 format using the specified bitrate, and returns the batch in the same + shape as the input. + + Args: + wav_tensor (torch.Tensor): Batch of audio files represented as a tensor. + Shape should be (batch_size, channels, length). + sr (int): Sampling rate of the audio. + bitrate (str): Bitrate for MP3 conversion, default is '128k'. + + Returns: + torch.Tensor: Batch of audio files converted to MP3 format, with the same + shape as the input tensor. 
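+
+    Note:
+        The conversion round-trips through `compress`, so the returned tensor
+        is the decoded MP3 signal; this assumes a torchaudio backend with MP3
+        support is available at runtime.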
+ """ + device = wav_tensor.device + batch_size, channels, original_length = wav_tensor.shape + + # Flatten tensor for conversion and move to CPU + wav_tensor_flat = wav_tensor.view(1, -1).cpu() + + # Convert to MP3 format with specified bitrate + wav_tensor_flat, _ = compress(wav_tensor_flat, sr, bitrate=bitrate) + + # Reshape back to original batch format and trim or pad if necessary + wav_tensor = wav_tensor_flat.view(batch_size, channels, -1) + compressed_length = wav_tensor.shape[-1] + if compressed_length > original_length: + wav_tensor = wav_tensor[:, :, :original_length] # Trim excess frames + elif compressed_length < original_length: + padding = torch.zeros( + batch_size, channels, original_length - compressed_length, device=device + ) + wav_tensor = torch.cat((wav_tensor, padding), dim=-1) # Pad with zeros + + # Move tensor back to the original device + return wav_tensor.to(device) + + +def get_aac( + wav_tensor: torch.Tensor, + sr: int, + bitrate: str = "128k", + lowpass_freq: tp.Optional[int] = None, +) -> torch.Tensor: + """Converts a batch of audio tensors to AAC format and then back to tensors. + + This function first saves the input tensor batch as WAV files, then uses FFmpeg to convert + these WAV files to AAC format. Finally, it loads the AAC files back into tensors. + + Args: + wav_tensor (torch.Tensor): A batch of audio files represented as a tensor. + Shape should be (batch_size, channels, length). + sr (int): Sampling rate of the audio. + bitrate (str): Bitrate for AAC conversion, default is '128k'. + lowpass_freq (Optional[int]): Frequency for a low-pass filter. If None, no filter is applied. + + Returns: + torch.Tensor: Batch of audio files converted to AAC and back, with the same + shape as the input tensor. + """ + import tempfile + import subprocess + + device = wav_tensor.device + batch_size, channels, original_length = wav_tensor.shape + + # Parse the bitrate value from the string + match = re.search(r"\d+(\.\d+)?", bitrate) + parsed_bitrate = ( + match.group() if match else "128" + ) # Default to 128 if parsing fails + + # Flatten tensor for conversion and move to CPU + wav_tensor_flat = wav_tensor.view(1, -1).cpu() + + with tempfile.NamedTemporaryFile( + suffix=".wav" + ) as f_in, tempfile.NamedTemporaryFile(suffix=".aac") as f_out: + input_path, output_path = f_in.name, f_out.name + + # Save the tensor as a WAV file + torchaudio.save(input_path, wav_tensor_flat, sr, backend="ffmpeg") + + # Prepare FFmpeg command for AAC conversion + command = [ + "ffmpeg", + "-y", + "-i", + input_path, + "-ar", + str(sr), + "-b:a", + f"{parsed_bitrate}k", + "-c:a", + "aac", + ] + if lowpass_freq is not None: + command += ["-cutoff", str(lowpass_freq)] + command.append(output_path) + + try: + # Run FFmpeg and suppress output + subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + # Load the AAC audio back into a tensor + aac_tensor, _ = torchaudio.load(output_path, backend="ffmpeg") + except Exception as exc: + raise RuntimeError( + "Failed to run command " ".join(command)} " + "(Often this means ffmpeg is not installed or the encoder is not supported, " + "make sure you installed an older version ffmpeg<5)" + ) from exc + + original_length_flat = batch_size * channels * original_length + compressed_length_flat = aac_tensor.shape[-1] + + # Trim excess frames + if compressed_length_flat > original_length_flat: + aac_tensor = aac_tensor[:, :original_length_flat] + + # Pad the shortedn frames + elif compressed_length_flat < original_length_flat: + 
padding = torch.zeros( + 1, original_length_flat - compressed_length_flat, device=device + ) + aac_tensor = torch.cat((aac_tensor, padding), dim=-1) + + # Reshape and adjust length to match original tensor + wav_tensor = aac_tensor.view(batch_size, channels, -1) + compressed_length = wav_tensor.shape[-1] + + assert compressed_length == original_length, ( + "AAC-compressed audio does not have the same frames as original one. " + "One reason can be ffmpeg is not installed and used as proper backed " + "for torchaudio, or the AAC encoder is not correct. Run " + "`torchaudio.utils.ffmpeg_utils.get_audio_encoders()` and make sure we see entry for" + "AAC in the output." + ) + return wav_tensor.to(device) \ No newline at end of file diff --git a/inspiremusic/utils/binary.py b/inspiremusic/utils/binary.py new file mode 100644 index 0000000000000000000000000000000000000000..862cb467850b9af8e8b035939c018984e590e79c --- /dev/null +++ b/inspiremusic/utils/binary.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" +import io +import json +import struct +import typing as tp + +# format is `ECDC` magic code, followed by the header size as uint32. +# Then an uint8 indicates the protocol version (0.) +# The header is then provided as json and should contain all required +# informations for decoding. A raw stream of bytes is then provided +# and should be interpretable using the json header. +_encodec_header_struct = struct.Struct('!4sBI') +_ENCODEC_MAGIC = b'ECDC' + + +def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any): + meta_dumped = json.dumps(metadata).encode('utf-8') + version = 0 + header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, + len(meta_dumped)) + fo.write(header) + fo.write(meta_dumped) + fo.flush() + + +def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes: + buf = b"" + while len(buf) < size: + new_buf = fo.read(size) + if not new_buf: + raise EOFError("Impossible to read enough data from the stream, " + f"{size} bytes remaining.") + buf += new_buf + size -= len(new_buf) + return buf + + +def read_ecdc_header(fo: tp.IO[bytes]): + header_bytes = _read_exactly(fo, _encodec_header_struct.size) + magic, version, meta_size = _encodec_header_struct.unpack(header_bytes) + if magic != _ENCODEC_MAGIC: + raise ValueError("File is not in ECDC format.") + if version != 0: + raise ValueError("Version not supported.") + meta_bytes = _read_exactly(fo, meta_size) + return json.loads(meta_bytes.decode('utf-8')) + + +class BitPacker: + """Simple bit packer to handle ints with a non standard width, e.g. 10 bits. + Note that for some bandwidth (1.5, 3), the codebook representation + will not cover an integer number of bytes. + + Args: + bits (int): number of bits per value that will be pushed. + fo (IO[bytes]): file-object to push the bytes to. + """ + + def __init__(self, bits: int, fo: tp.IO[bytes]): + self._current_value = 0 + self._current_bits = 0 + self.bits = bits + self.fo = fo + + def push(self, value: int): + """Push a new value to the stream. 
This will immediately + write as many uint8 as possible to the underlying file-object.""" + self._current_value += (value << self._current_bits) + self._current_bits += self.bits + while self._current_bits >= 8: + lower_8bits = self._current_value & 0xff + self._current_bits -= 8 + self._current_value >>= 8 + self.fo.write(bytes([lower_8bits])) + + def flush(self): + """Flushes the remaining partial uint8, call this at the end + of the stream to encode.""" + if self._current_bits: + self.fo.write(bytes([self._current_value])) + self._current_value = 0 + self._current_bits = 0 + self.fo.flush() + + +class BitUnpacker: + """BitUnpacker does the opposite of `BitPacker`. + + Args: + bits (int): number of bits of the values to decode. + fo (IO[bytes]): file-object to push the bytes to. + """ + + def __init__(self, bits: int, fo: tp.IO[bytes]): + self.bits = bits + self.fo = fo + self._mask = (1 << bits) - 1 + self._current_value = 0 + self._current_bits = 0 + + def pull(self) -> tp.Optional[int]: + """ + Pull a single value from the stream, potentially reading some + extra bytes from the underlying file-object. + Returns `None` when reaching the end of the stream. + """ + while self._current_bits < self.bits: + buf = self.fo.read(1) + if not buf: + return None + character = buf[0] + self._current_value += character << self._current_bits + self._current_bits += 8 + + out = self._current_value & self._mask + self._current_value >>= self.bits + self._current_bits -= self.bits + return out + + +def test(): + import torch + torch.manual_seed(1234) + for rep in range(4): + length: int = torch.randint(10, 2_000, (1, )).item() + bits: int = torch.randint(1, 16, (1, )).item() + tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist() + rebuilt: tp.List[int] = [] + buf = io.BytesIO() + packer = BitPacker(bits, buf) + for token in tokens: + packer.push(token) + packer.flush() + buf.seek(0) + unpacker = BitUnpacker(bits, buf) + while True: + value = unpacker.pull() + if value is None: + break + rebuilt.append(value) + assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens)) + # The flushing mechanism might lead to "ghost" values at the end of the stream. + assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), + len(tokens), bits) + for idx, (a, b) in enumerate(zip(tokens, rebuilt)): + assert a == b, (idx, a, b) + + +if __name__ == '__main__': + test() diff --git a/inspiremusic/utils/class_utils.py b/inspiremusic/utils/class_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6ddd08863b54afd24e92268dbd1faf15114b3e --- /dev/null +++ b/inspiremusic/utils/class_utils.py @@ -0,0 +1,71 @@ +# Copyright [2023-11-28] +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + +from inspiremusic.transformer.activation import Swish +from inspiremusic.transformer.subsampling import ( + LinearNoSubsampling, + EmbedinigNoSubsampling, + Conv1dSubsampling2, + Conv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, +) +from inspiremusic.transformer.embedding import (PositionalEncoding, + RelPositionalEncoding, + WhisperPositionalEncoding, + LearnablePositionalEncoding, + NoPositionalEncoding) +from inspiremusic.transformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention) +from inspiremusic.transformer.embedding import EspnetRelPositionalEncoding +from inspiremusic.transformer.subsampling import LegacyLinearNoSubsampling + + +INSPIREMUSIC_ACTIVATION_CLASSES = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": getattr(torch.nn, "SiLU", Swish), + "gelu": torch.nn.GELU, +} + +INSPIREMUSIC_SUBSAMPLE_CLASSES = { + "linear": LinearNoSubsampling, + "linear_legacy": LegacyLinearNoSubsampling, + "embed": EmbedinigNoSubsampling, + "conv1d2": Conv1dSubsampling2, + "conv2d": Conv2dSubsampling4, + "conv2d6": Conv2dSubsampling6, + "conv2d8": Conv2dSubsampling8, + 'paraformer_dummy': torch.nn.Identity +} + +INSPIREMUSIC_EMB_CLASSES = { + "embed": PositionalEncoding, + "abs_pos": PositionalEncoding, + "rel_pos": RelPositionalEncoding, + "rel_pos_espnet": EspnetRelPositionalEncoding, + "no_pos": NoPositionalEncoding, + "abs_pos_whisper": WhisperPositionalEncoding, + "embed_learnable_pe": LearnablePositionalEncoding, +} + +INSPIREMUSIC_ATTENTION_CLASSES = { + "selfattn": MultiHeadedAttention, + "rel_selfattn": RelPositionMultiHeadedAttention, +} + diff --git a/inspiremusic/utils/common.py b/inspiremusic/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d888fa4b71743142a968d70b7c0bcd09a32de70b --- /dev/null +++ b/inspiremusic/utils/common.py @@ -0,0 +1,173 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Unility functions for Transformer.""" + +from typing import List + +import torch +IGNORE_ID = -1 + +MUSIC_STRUCTURE_LABELS = ["intro", "verse1", "chorus", "verse2", "outro"] + +def pad_list(xs: List[torch.Tensor], pad_value: int): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). 
+ + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + max_len = max([len(item) for item in xs]) + batchs = len(xs) + ndim = xs[0].ndim + if ndim == 1: + pad_res = torch.zeros(batchs, + max_len, + dtype=xs[0].dtype, + device=xs[0].device) + elif ndim == 2: + pad_res = torch.zeros(batchs, + max_len, + xs[0].shape[1], + dtype=xs[0].dtype, + device=xs[0].device) + elif ndim == 3: + pad_res = torch.zeros(batchs, + max_len, + xs[0].shape[1], + xs[0].shape[2], + dtype=xs[0].dtype, + device=xs[0].device) + else: + raise ValueError(f"Unsupported ndim: {ndim}") + pad_res.fill_(pad_value) + for i in range(batchs): + pad_res[i, :len(xs[i])] = xs[i] + return pad_res + + +def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, + ignore_label: int) -> torch.Tensor: + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax). + ignore_label (int): Ignore label id. + + Returns: + torch.Tensor: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), + pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = torch.sum(mask) + return (numerator / denominator).detach() + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + +def topk_sampling(weighted_scores, decoded_tokens, top_k=25): + zeros = weighted_scores.new_ones(weighted_scores.shape) * float('-inf') + values,indices = torch.topk(weighted_scores,top_k) + zeros.scatter_(-1, indices, values) + return random_sampling(zeros,decoded_tokens) + +# Repetition Aware Sampling in VALL-E 2 + +def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): + top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() + if rep_num >= win_size * tau_r: + top_ids = random_sampling(weighted_scores, decoded_tokens) + return top_ids + +def caras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): + weighted_scores, cfg_weighted_scores = weighted_scores + top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() + if rep_num >= win_size * tau_r: + top_ids = random_sampling(cfg_weighted_scores, decoded_tokens) + return top_ids + +def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): + prob, indices = [], [] + cum_prob = 0.0 + sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) + for i in range(len(sorted_idx)): + # sampling both top-p and numbers. 
+ if cum_prob < top_p and len(prob) < top_k: + cum_prob += sorted_value[i] + prob.append(sorted_value[i]) + indices.append(sorted_idx[i]) + else: + break + prob = torch.tensor(prob).to(weighted_scores) + indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) + top_ids = indices[prob.multinomial(1, replacement=True)] + return top_ids + + +def random_sampling(weighted_scores, decoded_tokens): + top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) + return top_ids + + +def fade_in_out(fade_in_mel, fade_out_mel, window): + device = fade_in_mel.device + fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() + mel_overlap_len = int(window.shape[0] / 2) + fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \ + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:] + return fade_in_mel.to(device) + +def set_all_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + assert mask.dtype == torch.bool + assert dtype in [torch.float32, torch.bfloat16, torch.float16] + mask = mask.to(dtype) + # attention mask bias + # NOTE(Mddct): torch.finfo jit issues + # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min + mask = (1.0 - mask) * torch.finfo(dtype).min + return mask \ No newline at end of file diff --git a/inspiremusic/utils/data_utils.py b/inspiremusic/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..094a280f25ac50be81253854e62790ce50e5e7de --- /dev/null +++ b/inspiremusic/utils/data_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
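A short sketch of how the sampling helpers in `common.py` above are typically called during autoregressive token decoding; the import path follows this diff, while the codebook size, thresholds, and the decoded-token history are illustrative.

import torch
# assumes the module added in this diff
from inspiremusic.utils.common import nucleus_sampling, ras_sampling

vocab_size = 4096                      # illustrative codebook size
logits = torch.randn(vocab_size)       # scores for the next audio token
decoded = [17, 23, 23, 23, 23]         # tokens generated so far

# plain nucleus (top-p) sampling restricted to the top-k candidates
tok = nucleus_sampling(logits, top_p=0.8, top_k=25)

# repetition-aware sampling: if the nucleus pick already occurs at least
# win_size * tau_r times in the last win_size tokens, fall back to
# random_sampling over the full softmax distribution
tok = ras_sampling(logits, decoded, top_p=0.8, top_k=25, win_size=10, tau_r=0.1)
next_token = int(tok.item())           # both helpers return a 1-element tensor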
+from torch.utils.data import DataLoader +from inspiremusic.dataset.dataset import Dataset +import numpy as np +import librosa + +def audio_process_dataset_and_dataloader(args, configs): + input_dataset = Dataset(args.input_data, data_pipeline=configs['data_pipeline'], mode='processing', shuffle=True, partition=True) + # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts + input_data_loader = DataLoader(input_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + return input_dataset, input_data_loader + +def is_silent(wav_path, threshold=0.01, frame_length=2048, hop_length=512): + y, sr = librosa.load(wav_path, sr=None) + rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0] + silent_frames = np.sum(rms < threshold) / len(rms) + silence_fraction_threshold = 0.95 + return silent_frames >= silence_fraction_threshold + +def rich_captions(text=None, tags=None, lyrics=None, chorus="verse", start_time=0.0, end_time=30.0): + if text is None and tags is None and lyrics is None: + return None + else: + if start_time is None: + start_time = 0.0 + if end_time is None: + end_time = 30.0 + if chorus is None: + chorus = "verse" + captions = f"<|{start_time:.1f}|><|{chorus}|>" + if tags is not None: + captions += f"<|{tags}|>" + if text is not None: + captions += f"<|{text}|>" + if lyrics is not None: + captions += f"<|lyrics|><|{lyrics}|>" + captions += f"<|{end_time:.1f}|>" + return captions + +def process_tags(infile, outfile, timefile = None): + key_list = [] + with open(infile, "r") as f: + for line in f: + sec = line.strip() + key_list.append(sec) + f.close() + if timefile is None: + with open(outfile, 'w') as f: + for k in key_list: + parts = k.rsplit('_', 1) + text = parts[0].replace('_', ' ') + ', ' + parts[1] + caption = rich_captions(text, None, None) + if caption is not None: + f.write("%s\t%s\n" %(k, caption)) + f.close() + else: + times = {} + with open(timefile, "r") as f: + for line in f: + sec = line.strip().split("\t") + if len(sec) == 2 : + times[sec[0]] = sec[1] + f.close() + + with open(outfile, 'w') as f: + for k in key_list: + parts = k.rsplit('_', 1) + text = parts[0].replace('_', ' ') + ', ' + parts[1] + if k in times.keys(): + caption = rich_captions(text, None, None, "verse", 0.0, float(times[k])) + if caption is not None: + f.write("%s\t%s\n" %(k, caption)) + f.close() + +def process_trans(infile, outfile): + trans = {} + with open(infile, "r") as f: + for line in f: + sec = line.strip().split("\t") + if len(sec) == 2: + trans[sec[0]] = sec[1] + else: + print(line) + f.close() + with open(outfile, 'w') as f: + for k, v in trans.items(): + f.write("%s\t%s\n" %(k, rich_captions(v))) + f.close() \ No newline at end of file diff --git a/inspiremusic/utils/executor.py b/inspiremusic/utils/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..da933e4942c6095da87fb26e964c39c809925ad3 --- /dev/null +++ b/inspiremusic/utils/executor.py @@ -0,0 +1,121 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from contextlib import nullcontext +import os + +import torch +import torch.distributed as dist + +from inspiremusic.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, inspiremusic_join +from torch.cuda.amp import GradScaler, autocast + +class Executor: + + def __init__(self): + self.step = 0 + self.epoch = 0 + self.rank = int(os.environ.get('RANK', 0)) + self.device = torch.device('cuda:{}'.format(self.rank)) + + def train_one_epoch(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=None): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if inspiremusic_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. 
+ else: + context = nullcontext + + with context(): + with autocast(enabled=scaler is not None): + info_dict = batch_forward(model, batch_dict, info_dict, scaler) + info_dict = batch_backward(model, info_dict, scaler) + + info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler) + log_per_step(writer, info_dict) + # NOTE specify save_per_step in inspiremusic.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, scaler=scaler) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, scaler=scaler) + + @torch.inference_mode() + def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, capped_at=5, scaler=None): + ''' Cross validation on + ''' + logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) + model.eval() + total_num_utts, total_loss_dict = 0, {} # avoid division by 0 + stop = capped_at + for batch_idx, batch_dict in enumerate(cv_data_loader): + info_dict["tag"] = "CV" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + + num_utts = len(batch_dict["utts"]) + total_num_utts += num_utts + + if capped_at>0: + if stop <= 0: + continue + else: + stop -= 1 + + with autocast(enabled=scaler is not None): + info_dict = batch_forward(model, batch_dict, info_dict, scaler) + + for k, v in info_dict['loss_dict'].items(): + if k not in total_loss_dict: + total_loss_dict[k] = [] + total_loss_dict[k].append(v.item() * num_utts) + log_per_step(None, info_dict) + + for k, v in total_loss_dict.items(): + total_loss_dict[k] = sum(v) / total_num_utts + info_dict['loss_dict'] = total_loss_dict + log_per_save(writer, info_dict) + model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) + save_model(model, model_name, info_dict) diff --git a/inspiremusic/utils/file_utils.py b/inspiremusic/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56019b4fc015681e2b4d7a79cc2e186c6007d079 --- /dev/null +++ b/inspiremusic/utils/file_utils.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
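A hedged sketch of the outer loop that drives `Executor`; everything passed in (model, optimizer, scheduler, the data loaders, the SummaryWriter, `info_dict`, the `group_join` process group, and the optional GradScaler) is assumed to come from the helpers in `inspiremusic/utils/train_utils.py` later in this diff, and `max_epoch` is illustrative.

# Assumed to exist already: args, model, optimizer, scheduler, train_data_loader,
# cv_data_loader, writer, group_join, scaler (or None), built via the train_utils
# helpers; info_dict is assumed to be the train_conf section of the yaml config
# plus a few runtime fields (accum_grad, save_per_step, grad_clip, log_interval, dtype).
executor = Executor()
info_dict['model_dir'] = args.model_dir          # used by save_model
info_dict['train_engine'] = args.train_engine    # 'torch_ddp' or 'deepspeed'

for epoch in range(max_epoch):                   # max_epoch: illustrative
    executor.epoch = epoch
    executor.train_one_epoch(model, optimizer, scheduler,
                             train_data_loader, cv_data_loader,
                             writer, info_dict, group_join, scaler=scaler)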
+ +import json +import torchaudio +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + +def read_trans(list_file): + trans = {} + with open(list_file, 'r', encoding='utf8') as fin: + for line in fin: + sec = line.strip().split("\t") + if len(sec) > 1: + if sec[0] not in trans.keys(): + trans[sec[0]] = sec[1] + return trans + +def read_scp(list_file): + scp = {} + with open(list_file, 'r', encoding='utf8') as fin: + for line in fin: + sec = line.strip().split(" ") + if len(sec) > 1: + if sec[0] not in scp.keys(): + scp[sec[0]] = sec[1] + return scp + +def read_lists(list_file): + lists = [] + with open(list_file, 'r', encoding='utf8') as fin: + for line in fin: + lists.append(line.strip()) + return lists + + +def read_json_lists(list_file): + lists = read_lists(list_file) + results = {} + for fn in lists: + with open(fn, 'r', encoding='utf8') as fin: + results.update(json.load(fin)) + return results + + +def load_wav(wav, target_sr): + audio, sample_rate = torchaudio.load(wav) + audio = audio.mean(dim=0, keepdim=True) + if sample_rate != target_sr: + assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) + return audio + + +def speed_change(waveform, sample_rate, speed_factor: str): + effects = [ + ["tempo", speed_factor], # speed_factor + ["rate", f"{sample_rate}"] + ] + augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor( + waveform, + sample_rate, + effects + ) + return augmented_waveform, new_sample_rate diff --git a/inspiremusic/utils/frontend_utils.py b/inspiremusic/utils/frontend_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45239b854718afe0dc3ce194aa37a0ac6bb760eb --- /dev/null +++ b/inspiremusic/utils/frontend_utils.py @@ -0,0 +1,126 @@ +# Copyright (c) 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
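A brief usage sketch for the I/O helpers above; the paths and keys are hypothetical, the import path follows this diff, and the source file is assumed to already be at the 24 kHz target rate so no resampling is needed.

from inspiremusic.utils.file_utils import read_scp, read_trans, load_wav, speed_change

scp = read_scp("data/train/wav.scp")       # lines like "utt1 /path/to/utt1.wav"
text = read_trans("data/train/text.tsv")   # lines like "utt1\t<caption text>"

# load_wav returns a mono (1, T) float tensor; 24 kHz assumed here
wav = load_wav(scp["utt1"], target_sr=24000)

# sox tempo change; note speed_factor is passed as a string
fast_wav, sr = speed_change(wav, 24000, speed_factor="1.1")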
+ +import re +chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') + + +# whether contain chinese character +def contains_chinese(text): + return bool(chinese_char_pattern.search(text)) + + +# replace special symbol +def replace_corner_mark(text): + text = text.replace('ยฒ', 'ๅนณๆ–น') + text = text.replace('ยณ', '็ซ‹ๆ–น') + return text + + +# remove meaningless symbol +def remove_bracket(text): + text = text.replace('๏ผˆ', '').replace('๏ผ‰', '') + text = text.replace('ใ€', '').replace('ใ€‘', '') + text = text.replace('`', '').replace('`', '') + text = text.replace("โ€”โ€”", " ") + return text + + +# spell Arabic numerals +def spell_out_number(text: str, inflect_parser): + new_text = [] + st = None + for i, c in enumerate(text): + if not c.isdigit(): + if st is not None: + num_str = inflect_parser.number_to_words(text[st: i]) + new_text.append(num_str) + st = None + new_text.append(c) + else: + if st is None: + st = i + if st is not None and st < len(text): + num_str = inflect_parser.number_to_words(text[st:]) + new_text.append(num_str) + return ''.join(new_text) + + +# split paragrah logic๏ผš +# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len +# 2. cal sentence len according to lang +# 3. split sentence according to puncatation +def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): + def calc_utt_length(_text: str): + if lang == "zh": + return len(_text) + else: + return len(tokenize(_text)) + + def should_merge(_text: str): + if lang == "zh": + return len(_text) < merge_len + else: + return len(tokenize(_text)) < merge_len + + if lang == "zh": + pounc = ['ใ€‚', '๏ผŸ', '๏ผ', '๏ผ›', '๏ผš', 'ใ€', '.', '?', '!', ';'] + else: + pounc = ['.', '?', '!', ';', ':'] + if comma_split: + pounc.extend(['๏ผŒ', ',']) + st = 0 + utts = [] + for i, c in enumerate(text): + if c in pounc: + if len(text[st: i]) > 0: + utts.append(text[st: i] + c) + if i + 1 < len(text) and text[i + 1] in ['"', 'โ€']: + tmp = utts.pop(-1) + utts.append(tmp + text[i + 1]) + st = i + 2 + else: + st = i + 1 + if len(utts) == 0: + if lang == "zh": + utts.append(text + 'ใ€‚') + else: + utts.append(text + '.') + final_utts = [] + cur_utt = "" + for utt in utts: + if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: + final_utts.append(cur_utt) + cur_utt = "" + cur_utt = cur_utt + utt + if len(cur_utt) > 0: + if should_merge(cur_utt) and len(final_utts) != 0: + final_utts[-1] = final_utts[-1] + cur_utt + else: + final_utts.append(cur_utt) + + return final_utts + + +# remove blank between chinese character +def replace_blank(text: str): + out_str = [] + for i, c in enumerate(text): + if c == " ": + if ((text[i + 1].isascii() and text[i + 1] != " ") and + (text[i - 1].isascii() and text[i - 1] != " ")): + out_str.append(c) + else: + out_str.append(c) + return "".join(out_str) diff --git a/inspiremusic/utils/hinter.py b/inspiremusic/utils/hinter.py new file mode 100644 index 0000000000000000000000000000000000000000..6b32194336ed680a12772e2e3d3ed48d13cbcf59 --- /dev/null +++ b/inspiremusic/utils/hinter.py @@ -0,0 +1,12 @@ +import sys +import torch.distributed +import logging + +HINTED = set() + + +def hint_once(content, uid, rank=None): + if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank: + if uid not in HINTED: + logging.info(content, stacklevel=3) + HINTED.add(uid) \ No newline at end of file diff --git 
a/inspiremusic/utils/losses.py b/inspiremusic/utils/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..78efd3b72ff4c61971f4732626c43613f812761d --- /dev/null +++ b/inspiremusic/utils/losses.py @@ -0,0 +1,20 @@ +import torch +import torch.nn.functional as F + + +def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): + loss = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + m_DG = torch.median((dr - dg)) + L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) + loss += tau - F.relu(tau - L_rel) + return loss + + +def mel_loss(real_speech, generated_speech, mel_transforms): + loss = 0 + for transform in mel_transforms: + mel_r = transform(real_speech) + mel_g = transform(generated_speech) + loss += F.l1_loss(mel_g, mel_r) + return loss diff --git a/inspiremusic/utils/mask.py b/inspiremusic/utils/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d767dfe0b9a37295c49d6670e5600e60fe075e --- /dev/null +++ b/inspiremusic/utils/mask.py @@ -0,0 +1,227 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +''' +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. + + Args: + size (int): size of mask + str device (str): "cpu" or "cuda" or torch.Tensor.device + dtype (torch.device): result dtype + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + ret = torch.ones(size, size, device=device, dtype=torch.bool) + return torch.tril(ret) +''' + + +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. 
+ + Args: + size (int): size of mask + str device (str): "cpu" or "cuda" or torch.Tensor.device + dtype (torch.device): result dtype + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + arange = torch.arange(size, device=device) + mask = arange.expand(size, size) + arange = arange.unsqueeze(-1) + mask = mask <= arange + return mask + + +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size, size) + ret[i, start:ending] = True + return ret + + +def add_optional_chunk_mask(xs: torch.Tensor, + masks: torch.Tensor, + use_dynamic_chunk: bool, + use_dynamic_left_chunk: bool, + decoding_chunk_size: int, + static_chunk_size: int, + num_decoding_left_chunks: int, + enable_full_context: bool = True): + """ Apply optional mask for encoder. + + Args: + xs (torch.Tensor): padded input, (B, L, D), L for max length + mask (torch.Tensor): mask for xs, (B, 1, L) + use_dynamic_chunk (bool): whether to use dynamic chunk or not + use_dynamic_left_chunk (bool): whether to use dynamic left chunk for + training. + decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + enable_full_context (bool): + True: chunk size is either [1, 25] or full context(max_len) + False: chunk size ~ U[1, 25] + + Returns: + torch.Tensor: chunk mask of the input xs. + """ + # Whether to use chunk mask or not + if use_dynamic_chunk: + max_len = xs.size(1) + if decoding_chunk_size < 0: + chunk_size = max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: + chunk_size = decoding_chunk_size + num_left_chunks = num_decoding_left_chunks + else: + # chunk size is either [1, 25] or full context(max_len). + # Since we use 4 times subsampling and allow up to 1s(100 frames) + # delay, the maximum frame is 100 / 4 = 25. 
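+            # sample a random chunk size; it is then either promoted to the
+            # full context or folded into the range [1, 25] below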
+ chunk_size = torch.randint(1, max_len, (1, )).item() + num_left_chunks = -1 + if chunk_size > max_len // 2 and enable_full_context: + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + if use_dynamic_left_chunk: + max_left_chunks = (max_len - 1) // chunk_size + num_left_chunks = torch.randint(0, max_left_chunks, + (1, )).item() + chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + elif static_chunk_size > 0: + num_left_chunks = num_decoding_left_chunks + chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + else: + chunk_masks = masks + return chunk_masks + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask diff --git a/inspiremusic/utils/scheduler.py b/inspiremusic/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5f3a8b00ca9328a0fe917b34d73e21e9c25b2f --- /dev/null +++ b/inspiremusic/utils/scheduler.py @@ -0,0 +1,738 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Ximalaya Inc (Yuguang Yang) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +# NeMo(https://github.com/NVIDIA/NeMo) + +from typing import Union + +import math +import warnings +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class WarmupLR(_LRScheduler): + """The WarmupLR scheduler + + This scheduler is almost same as NoamLR Scheduler except for following + difference: + + NoamLR: + lr = optimizer.lr * model_size ** -0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + WarmupLR: + lr = optimizer.lr * warmup_step ** 0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + + Note that the maximum lr equals to optimizer.lr in this scheduler. 
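+
+    For example, with optimizer lr 1e-3 and warmup_steps=25000, the lr rises
+    linearly to its peak 1e-3 at step 25000 and then decays in proportion to
+    step**-0.5.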
+ + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 25000, + last_epoch: int = -1, + ): + self.warmup_steps = warmup_steps + + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer, last_epoch) + + def __repr__(self): + return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" + + def get_lr(self): + step_num = self.last_epoch + 1 + if self.warmup_steps == 0: + return [lr * step_num**-0.5 for lr in self.base_lrs] + else: + return [ + lr * self.warmup_steps**0.5 * + min(step_num**-0.5, step_num * self.warmup_steps**-1.5) + for lr in self.base_lrs + ] + + def set_step(self, step: int): + self.last_epoch = step + + +class WarmupPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + assert not (warmup_steps is not None and warmup_ratio is not None),\ + "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class SquareRootConstantPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, + optimizer, + *, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use particular number of step or ratio" + assert constant_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. 
+ self.max_steps = max_steps + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.constant_lr = 1 / (constant_steps**0.5) + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + if step <= self.constant_steps: + return [self.constant_lr for _ in self.base_lrs] + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class WarmupHoldPolicy(WarmupPolicy): + """Variant of WarmupPolicy which maintains high + learning rate for a defined number of steps. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to + hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + hold_steps=None, + hold_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (hold_steps is not None and hold_ratio is not None), \ + "Either use particular number of step or ratio" + assert hold_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + self.min_lr = min_lr + self._last_warmup_lr = 0.0 + + # Necessary to duplicate as class attributes are hidden in inner class + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if hold_steps is not None: + self.hold_steps = hold_steps + self.warmup_steps + elif hold_ratio is not None: + self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps + else: + self.hold_steps = 0 + + super().__init__( + optimizer, + warmup_steps=warmup_steps, + warmup_ratio=warmup_ratio, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + ) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler," + " " + "please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + # Warmup phase + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + # Hold phase + if (step >= self.warmup_steps) and (step < self.hold_steps): + return self.base_lrs + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + +class WarmupAnnealHoldPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + min_lr: Minimum lr to hold the learning rate after decay at. + constant_steps: Number of steps to keep lr constant at. 
+ constant_ratio: Ratio of steps to keep lr constant. + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (warmup_steps is not None + and warmup_ratio is not None), \ + "Either use particular number of step or ratio" + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use constant_steps or constant_ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.decay_steps = max_steps - (self.constant_steps + + self.warmup_steps) + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = self.last_epoch + + # Warmup steps + if self.warmup_steps > 0 and step <= self.warmup_steps: + return self._get_warmup_lr(step) + + # Constant steps after warmup and decay + if self.constant_steps > 0 and ( + self.warmup_steps + self.decay_steps) < step <= self.max_steps: + return self._get_constant_lr(step) + + # Min lr after max steps of updates + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_constant_lr(self, step): + return [self.min_lr for _ in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +def _squareroot_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps)**0.5 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _square_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps)**2 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _cosine_annealing(initial_lr, step, max_steps, min_lr): + mult = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + out_lr = (initial_lr - min_lr) * mult + min_lr + return out_lr + + +def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, + decay_steps, min_lr): + assert max_lr > min_lr + # Use linear warmup for the initial part. + if warmup_steps > 0 and step <= warmup_steps: + return max_lr * float(step) / float(warmup_steps) + + # For any steps larger than `decay_steps`, use `min_lr`. + if step > warmup_steps + decay_steps: + return min_lr + + # If we are done with the warmup period, use the decay style. 
+ num_steps_ = step - warmup_steps + decay_steps_ = decay_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = max_lr - min_lr + + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + + return min_lr + coeff * delta_lr + + +def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): + if cycle: + multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps) + decay_steps *= multiplier + else: + step = min(step, decay_steps) + p = step / decay_steps + lr = (initial_lr - min_lr) * math.pow(1.0 - p, power) + lr += min_lr + return lr + + +def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, + decay_rate, min_lr): + # hold_steps = total number of steps + # to hold the LR, not the warmup + hold steps. + T_warmup_decay = max(1, warmup_steps**decay_rate) + T_hold_decay = max(1, (step - hold_steps)**decay_rate) + lr = (initial_lr * T_warmup_decay) / T_hold_decay + lr = max(lr, min_lr) + return lr + + +class SquareAnnealing(WarmupPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=1e-5, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _square_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + return new_lrs + + +class SquareRootAnnealing(WarmupPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=0, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _squareroot_annealing(initial_lr=initial_lr, + step=step, + max_steps=self.max_steps, + min_lr=self.min_lr) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class CosineAnnealing(WarmupAnnealHoldPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + min_lr=0, + last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate " + f"that was lower than the minimum learning rate.") + + if self.constant_steps is None or self.constant_steps == 0: + new_lrs = [ + _cosine_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + else: + new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step) + return new_lrs + + def _get_warmup_lr(self, step): + if self.constant_steps is None or self.constant_steps == 0: + return super()._get_warmup_lr(step) + else: + # Use linear warmup for the initial part. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_constant_lr(self, step): + # Only called when `constant_steps` > 0. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_linear_warmup_with_cosine_annealing_lr(self, step): + # Cosine Schedule for Megatron LM, + # slightly different warmup schedule + constant LR at the end. 
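+        # three phases: linear warmup over warmup_steps, cosine decay over
+        # decay_steps down to min_lr, then min_lr held for the remaining
+        # constant_steps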
+ new_lrs = [ + _linear_warmup_with_cosine_annealing( + max_lr=self.base_lrs[0], + warmup_steps=self.warmup_steps, + step=step, + decay_steps=self.decay_steps, + min_lr=self.min_lr, + ) for _ in self.base_lrs + ] + return new_lrs + + +class NoamAnnealing(_LRScheduler): + + def __init__(self, + optimizer, + *, + d_model, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1): + self._normalize = d_model**(-0.5) + assert not (warmup_steps is not None and warmup_ratio is not None), \ + "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning, + stacklevel=2) + + step = max(1, self.last_epoch) + + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate " + f"that was lower than the minimum learning rate.") + + new_lrs = [ + self._noam_annealing(initial_lr=initial_lr, step=step) + for initial_lr in self.base_lrs + ] + return new_lrs + + def _noam_annealing(self, initial_lr, step): + if self.warmup_steps > 0: + mult = self._normalize * min(step**(-0.5), + step * (self.warmup_steps**(-1.5))) + else: + mult = self._normalize * step**(-0.5) + + out_lr = initial_lr * mult + if step > self.warmup_steps: + out_lr = max(out_lr, self.min_lr) + return out_lr + + +class NoamHoldAnnealing(WarmupHoldPolicy): + + def __init__(self, + optimizer, + *, + max_steps, + decay_rate=0.5, + min_lr=0.0, + last_epoch=-1, + **kwargs): + """ + From Nemo: + Implementation of the Noam Hold Annealing policy + from the SqueezeFormer paper. + + Unlike NoamAnnealing, the peak learning rate + can be explicitly set for this scheduler. + The schedule first performs linear warmup, + then holds the peak LR, then decays with some schedule for + the remainder of the steps. + Therefore the min-lr is still dependent + on the hyper parameters selected. + + It's schedule is determined by three factors- + + Warmup Steps: Initial stage, where linear warmup + occurs uptil the peak LR is reached. Unlike NoamAnnealing, + the peak LR is explicitly stated here instead of a scaling factor. + + Hold Steps: Intermediate stage, where the peak LR + is maintained for some number of steps. In this region, + the high peak LR allows the model to converge faster + if training is stable. However the high LR + may also cause instability during training. + Should usually be a significant fraction of training + steps (around 30-40% of the entire training steps). + + Decay Steps: Final stage, where the LR rapidly decays + with some scaling rate (set by decay rate). + To attain Noam decay, use 0.5, + for Squeezeformer recommended decay, use 1.0. + The fast decay after prolonged high LR during + hold phase allows for rapid convergence. 
+ + References: + - [Squeezeformer: + An Efficient Transformer for Automatic Speech Recognition] + (https://arxiv.org/abs/2206.00888) + + Args: + optimizer: Pytorch compatible Optimizer object. + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to + hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + decay_rate: Float value describing the polynomial decay + after the hold period. Default value + of 0.5 corresponds to Noam decay. + min_lr: Minimum learning rate. + """ + self.decay_rate = decay_rate + super().__init__(optimizer=optimizer, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + **kwargs) + + def _get_lr(self, step): + if self.warmup_steps is None or self.warmup_steps == 0: + raise ValueError( + "Noam scheduler cannot be used without warmup steps") + + if self.hold_steps > 0: + hold_steps = self.hold_steps - self.warmup_steps + else: + hold_steps = 0 + + new_lrs = [ + _noam_hold_annealing( + initial_lr, + step=step, + warmup_steps=self.warmup_steps, + hold_steps=hold_steps, + decay_rate=self.decay_rate, + min_lr=self.min_lr, + ) for initial_lr in self.base_lrs + ] + return new_lrs + + def set_step(self, step: int): + self.last_epoch = step + + +class ConstantLR(_LRScheduler): + """The ConstantLR scheduler + + This scheduler keeps a constant lr + + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + ): + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer) + + def get_lr(self): + return self.base_lrs + + def set_step(self, step: int): + self.last_epoch = step diff --git a/inspiremusic/utils/tokenizer_utils.py b/inspiremusic/utils/tokenizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..095757f5d0f8a6da8598f06ceed819e659f8bb61 --- /dev/null +++ b/inspiremusic/utils/tokenizer_utils.py @@ -0,0 +1,221 @@ +import glob +import json +import os +import random +import sys +import time +import warnings + +import matplotlib +import numpy as np +import torch +import yaml +from torch import distributed as dist +from torch.nn.utils import weight_norm +matplotlib.use("Agg") +import matplotlib.pylab as plt +import re +import pathlib + + +def seed_everything(seed, cudnn_deterministic=False): + """ + Function that sets seed for pseudo-random number generators in: + pytorch, numpy, python.random + + Args: + seed: the integer value seed for global random state + """ + if seed is not None: + # print(f"Global seed set to {seed}") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # if cudnn_deterministic: + # torch.backends.cudnn.deterministic = True + # warnings.warn('You have chosen to seed training. ' + # 'This will turn on the CUDNN deterministic setting, ' + # 'which can slow down your training considerably! 
' + # 'You may see unexpected behavior when restarting ' + # 'from checkpoints.') + + +def is_primary(): + return get_rank() == 0 + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + + return dist.get_rank() + + +def load_yaml_config(path): + with open(path) as f: + config = yaml.full_load(f) + return config + + +def save_config_to_yaml(config, path): + assert path.endswith('.yaml') + with open(path, 'w') as f: + f.write(yaml.dump(config)) + f.close() + + +def save_dict_to_json(d, path, indent=None): + json.dump(d, open(path, 'w'), indent=indent) + + +def load_dict_from_json(path): + return json.load(open(path, 'r')) + + +def write_args(args, path): + args_dict = dict((name, getattr(args, name)) for name in dir(args) + if not name.startswith('_')) + with open(path, 'a') as args_file: + args_file.write('==> torch version: {}\n'.format(torch.__version__)) + args_file.write( + '==> cudnn version: {}\n'.format(torch.backends.cudnn.version())) + args_file.write('==> Cmd:\n') + args_file.write(str(sys.argv)) + args_file.write('\n==> args:\n') + for k, v in sorted(args_dict.items()): + args_file.write(' %s: %s\n' % (str(k), str(v))) + args_file.close() + + +class Logger(object): + def __init__(self, args): + self.args = args + self.save_dir = args.save_dir + self.is_primary = is_primary() + + if self.is_primary: + os.makedirs(self.save_dir, exist_ok=True) + + # save the args and config + self.config_dir = os.path.join(self.save_dir, 'configs') + os.makedirs(self.config_dir, exist_ok=True) + file_name = os.path.join(self.config_dir, 'args.txt') + write_args(args, file_name) + + log_dir = os.path.join(self.save_dir, 'logs') + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + self.text_writer = open(os.path.join(log_dir, 'log.txt'), + 'a') # 'w') + if args.tensorboard: + self.log_info('using tensorboard') + self.tb_writer = torch.utils.tensorboard.SummaryWriter( + log_dir=log_dir + ) # tensorboard.SummaryWriter(log_dir=log_dir) + else: + self.tb_writer = None + + def save_config(self, config): + if self.is_primary: + save_config_to_yaml(config, + os.path.join(self.config_dir, 'config.yaml')) + + def log_info(self, info, check_primary=True): + if self.is_primary or (not check_primary): + print(info) + if self.is_primary: + info = str(info) + time_str = time.strftime('%Y-%m-%d-%H-%M') + info = '{}: {}'.format(time_str, info) + if not info.endswith('\n'): + info += '\n' + self.text_writer.write(info) + self.text_writer.flush() + + def add_scalar(self, **kargs): + """Log a scalar variable.""" + if self.is_primary: + if self.tb_writer is not None: + self.tb_writer.add_scalar(**kargs) + + def add_scalars(self, **kargs): + """Log a scalar variable.""" + if self.is_primary: + if self.tb_writer is not None: + self.tb_writer.add_scalars(**kargs) + + def add_image(self, **kargs): + """Log a scalar variable.""" + if self.is_primary: + if self.tb_writer is not None: + self.tb_writer.add_image(**kargs) + + def add_images(self, **kargs): + """Log a scalar variable.""" + if self.is_primary: + if self.tb_writer is not None: + self.tb_writer.add_images(**kargs) + + def close(self): + if self.is_primary: + self.text_writer.close() + self.tb_writer.close() + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow( + spectrogram, aspect="auto", origin="lower", interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + 
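+    # normal-init for Conv* layers; same helper as init_weights in
+    # inspiremusic/utils/common.py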
classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj, num_ckpt_keep=5): + name = re.match(r'(do|g)_\d+', pathlib.Path(filepath).name).group(1) + ckpts = sorted(pathlib.Path(filepath).parent.glob(f'{name}_*')) + if len(ckpts) > num_ckpt_keep: + [os.remove(c) for c in ckpts[:-num_ckpt_keep]] + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/inspiremusic/utils/train_utils.py b/inspiremusic/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9312eb5a83f37388cd4de547beef1bc0f95b6d28 --- /dev/null +++ b/inspiremusic/utils/train_utils.py @@ -0,0 +1,300 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2023 Horizon Inc. (authors: Xingchen Song) +# 2024 Alibaba Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
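A small sketch (illustrative names only) showing how the seeding and checkpoint helpers above fit together; the import path follows this diff, and the checkpoint key and experiment directory are made up.

import os
import torch
# assumes the module added in this diff
from inspiremusic.utils.tokenizer_utils import (
    seed_everything, scan_checkpoint, load_checkpoint, save_checkpoint)

seed_everything(1986)
ckpt_dir = "exp/music_tokenizer"         # hypothetical experiment directory
model = torch.nn.Linear(4, 4)            # stand-in for the real generator
step = 1000

# resume from the newest generator checkpoint, if any ("g_" + 8-digit step)
last = scan_checkpoint(ckpt_dir, "g_")
if last is not None:
    model.load_state_dict(load_checkpoint(last, device="cpu")["generator"])

# save, keeping only the 5 most recent g_* checkpoints in this directory
os.makedirs(ckpt_dir, exist_ok=True)
save_checkpoint(os.path.join(ckpt_dir, f"g_{step:08d}"),
                {"generator": model.state_dict()}, num_ckpt_keep=5)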
+ +from contextlib import nullcontext +import logging +import os +import torch +import json +import re +import datetime +import yaml + +import deepspeed +import torch.optim as optim +import torch.distributed as dist + +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from torch.nn.utils import clip_grad_norm_ + +from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live + +from inspiremusic.dataset.dataset import Dataset +from inspiremusic.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR + + +def init_distributed(args): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + logging.info('training on multiple gpus, this gpu {}'.format(local_rank) + + ', rank {}, world_size {}'.format(rank, world_size)) + if args.train_engine == 'torch_ddp': + torch.cuda.set_device(local_rank) + dist.init_process_group(args.dist_backend) + else: + deepspeed.init_distributed(dist_backend=args.dist_backend) + return world_size, local_rank, rank + + +def init_dataset_and_dataloader(args, configs): + gan = False + data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline'] + train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', shuffle=True, partition=True) + cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', shuffle=False, partition=False) + + # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts + train_data_loader = DataLoader(train_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + timeout=60) + cv_data_loader = DataLoader(cv_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + timeout=60) + return train_dataset, cv_dataset, train_data_loader, cv_data_loader + + +def check_modify_and_save_config(args, configs): + if args.train_engine == "torch_ddp": + configs['train_conf']["dtype"] = 'fp32' + else: + with open(args.deepspeed_config, 'r') as fin: + ds_configs = json.load(fin) + if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]: + configs['train_conf']["dtype"] = "fp16" + elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]: + configs['train_conf']["dtype"] = "bf16" + else: + configs['train_conf']["dtype"] = "fp32" + assert ds_configs["train_micro_batch_size_per_gpu"] == 1 + # if use deepspeed, override ddp config + configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * + configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"]) + configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"] + configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"] + configs['train_conf']['log_interval'] = ds_configs["steps_per_print"] + return configs + + +def wrap_cuda_model(args, model): + local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + if args.train_engine == "torch_ddp": # native pytorch ddp + assert (torch.cuda.is_available()) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) + else: + if int(os.environ.get('RANK', 0)) == 0: + logging.info("Estimating model states memory needs (zero2)...") + 
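+            # DeepSpeed utility that only prints an estimate of ZeRO-2
+            # optimizer/gradient state memory; the engine itself is created
+            # later by deepspeed.initialize in init_optimizer_and_scheduler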
estimate_zero2_model_states_mem_needs_all_live( + model, + num_gpus_per_node=local_world_size, + num_nodes=world_size // local_world_size) + return model + +def init_optimizer_and_scheduler(args, configs, model): + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + # use deepspeed optimizer for speedup + if args.train_engine == "deepspeed": + def scheduler(opt): + return scheduler_type(opt, **configs['train_conf']['scheduler_conf']) + model, optimizer, _, scheduler = deepspeed.initialize( + args=args, + model=model, + optimizer=None, + lr_scheduler=scheduler, + model_parameters=model.parameters()) + + return model, optimizer, scheduler + + +def init_summarywriter(args): + writer = None + if int(os.environ.get('RANK', 0)) == 0: + os.makedirs(args.model_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + return writer + + +def save_model(model, model_name, info_dict): + rank = int(os.environ.get('RANK', 0)) + model_dir = info_dict["model_dir"] + save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name)) + + if info_dict["train_engine"] == "torch_ddp": + if rank == 0: + torch.save(model.module.state_dict(), save_model_path) + else: + with torch.no_grad(): + model.save_checkpoint(save_dir=model_dir, + tag=model_name, + client_state=info_dict) + if rank == 0: + info_path = re.sub('.pt$', '.yaml', save_model_path) + info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + with open(info_path, 'w') as fout: + data = yaml.dump(info_dict) + fout.write(data) + logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path)) + + +def inspiremusic_join(group_join, info_dict): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + + if info_dict["batch_idx"] != 0: + # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr + try: + dist.monitored_barrier(group=group_join, + timeout=group_join.options._timeout) + return False + except RuntimeError as e: + logging.info("Detected uneven workload distribution: {}\n".format(e) + + "Break current worker to manually join all workers, " + + "world_size {}, current rank {}, current local_rank {}\n". 
+ format(world_size, rank, local_rank)) + return True + else: + return False + + +def batch_forward(model, batch, info_dict, scaler): + device = int(os.environ.get('LOCAL_RANK', 0)) + + dtype = info_dict["dtype"] + if dtype == "fp16": + dtype = torch.float16 + elif dtype == "bf16": + dtype = torch.bfloat16 + else: # fp32 + dtype = torch.float32 + + if info_dict['train_engine'] == 'torch_ddp': + autocast = torch.cuda.amp.autocast(enabled=scaler is not None) + else: + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False) + + with autocast: + info_dict['loss_dict'] = model(batch, device) + return info_dict + + +def batch_backward(model, info_dict, scaler): + if info_dict["train_engine"] == "deepspeed": + scaled_loss = model.backward(info_dict['loss_dict']['loss']) + else: + scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad'] + if scaler is not None: + scaler.scale(scaled_loss).backward() + else: + scaled_loss.backward() + + info_dict['loss_dict']['loss'] = scaled_loss + return info_dict + +def update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler=None): + grad_norm = 0.0 + if info_dict['train_engine'] == "deepspeed": + info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary() + model.step() + grad_norm = model.get_global_grad_norm() + elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0: + if scaler is not None: + scaler.unscale_(optimizer) # Unscale gradients before clipping + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + scaler.step(optimizer) + scaler.update() + else: + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + if torch.isfinite(grad_norm): + optimizer.step() + optimizer.zero_grad() + scheduler.step() + info_dict["lr"] = optimizer.param_groups[0]['lr'] + info_dict["grad_norm"] = grad_norm + return info_dict + + +def log_per_step(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict.get('epoch', 0) + step = info_dict["step"] + batch_idx = info_dict["batch_idx"] + loss_dict = info_dict['loss_dict'] + rank = int(os.environ.get('RANK', 0)) + + # only rank 0 write to tensorboard to avoid multi-process write + if writer is not None: + if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \ + (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0): + for k in ['epoch', 'lr', 'grad_norm']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) + + # TRAIN & CV, Shell log (stdout) + if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0: + log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1) + for name, value in loss_dict.items(): + log_str += '{} {:.6f} '.format(name, value.item()) + if tag == "TRAIN": + log_str += 'lr {:.8f} grad_norm {:.6f}'.format( + info_dict["lr"], info_dict['grad_norm']) + log_str += ' rank {}'.format(rank) + logging.debug(log_str) + + +def log_per_save(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict["epoch"] + step = info_dict["step"] + loss_dict = info_dict["loss_dict"] + lr = info_dict['lr'] + rank = int(os.environ.get('RANK', 0)) + logging.info( + 'Epoch {} Step {} CV info lr {} {} rank {}'.format( + epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]))) + + if writer is not None: + for k in ['epoch', 'lr']: + 
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) diff --git a/inspiremusic/utils/utils.py b/inspiremusic/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db1d8b41c94be6c41b8424d3d036e17a6471234d --- /dev/null +++ b/inspiremusic/utils/utils.py @@ -0,0 +1,22 @@ +import os +import sys + +def align_trans_scp_file(trans, scp): + trans_dict = {} + with open(trans, 'r') as f: + for line in f: + sec = line.strip().split("\t") + trans_dict[sec[0]] = sec[1] + scp_dict = {} + with open(scp, 'r') as f: + for line in f: + sec = line.strip().split(" ") + scp_dict[sec[0]] = sec[1] + with open("text", "w") as f: + for k, v in scp_dict.items(): + f.write("%s\t%s\n"%(k,trans_dict[k])) + +if __name__ == '__main__': + trans = sys.argv[1] + scp = sys.argv[2] + align_trans_scp_file(trans, scp) \ No newline at end of file diff --git a/inspiremusic/version.txt b/inspiremusic/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa3386889155607fc99e0c091986218549648a89 --- /dev/null +++ b/inspiremusic/version.txt @@ -0,0 +1 @@ +v0.1 \ No newline at end of file diff --git a/inspiremusic/wavtokenizer/__init__.py b/inspiremusic/wavtokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/wavtokenizer/decoder/__init__.py b/inspiremusic/wavtokenizer/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inspiremusic/wavtokenizer/decoder/dataset.py b/inspiremusic/wavtokenizer/decoder/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..345a1617291648fe5f8671e7ea897c539fdcb2f5 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/dataset.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass + +import numpy as np +import torch +import torchaudio +from pytorch_lightning import LightningDataModule +from torch.utils.data import Dataset, DataLoader + +import soundfile +# import librosa +import random + +torch.set_num_threads(1) + + +@dataclass +class DataConfig: + filelist_path: str + sampling_rate: int + num_samples: int + batch_size: int + num_workers: int + +def collate_fn(batch): + batch = [item for item in batch if item is not None] + return torch.stack(batch, dim=0) + +class VocosDataModule(LightningDataModule): + def __init__(self, train_params: DataConfig, val_params: DataConfig): + super().__init__() + self.train_config = train_params + self.val_config = val_params + + def _get_dataloder(self, cfg: DataConfig, train: bool): + dataset = VocosDataset(cfg, train=train) + dataloader = DataLoader( + dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True, collate_fn=collate_fn + ) + return dataloader + + def train_dataloader(self) -> DataLoader: + return self._get_dataloder(self.train_config, train=True) + + def val_dataloader(self) -> DataLoader: + return self._get_dataloder(self.val_config, train=False) + + +class VocosDataset(Dataset): + def __init__(self, cfg: DataConfig, train: bool): + with open(cfg.filelist_path) as f: + self.filelist = f.read().splitlines() + self.sampling_rate = cfg.sampling_rate + self.num_samples = cfg.num_samples + self.train = train + + def __len__(self) -> int: + return len(self.filelist) + + def __getitem__(self, index: int) -> torch.Tensor: + audio_path = 
self.filelist[index]
+        # y, sr = torchaudio.load(audio_path)
+        # print(audio_path, "111")
+        try:
+            y1, sr = soundfile.read(audio_path)
+            # y1, sr = librosa.load(audio_path, sr=None)
+            y = torch.tensor(y1).float().unsqueeze(0)
+            # if y.size(0) > 1:
+            #     # mix to mono
+            #     y = y.mean(dim=0, keepdim=True)
+            if y.ndim > 2:
+                # multi-channel input: keep one randomly chosen channel (mixing down to mono is left commented out)
+                # print("something looks wrong here - data processing part")
+                # y = y.mean(dim=-1, keepdim=False)
+                random_channel = random.randint(0, y.size(-1) - 1)
+                y = y[:, :, random_channel]
+
+            gain = np.random.uniform(-1, -6) if self.train else -3
+            y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
+            if sr != self.sampling_rate:
+                y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
+            if y.size(-1) < self.num_samples:
+                pad_length = self.num_samples - y.size(-1)
+                padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
+                y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
+            elif self.train:
+                start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
+                y = y[:, start : start + self.num_samples]
+            else:
+                # During validation, always take the first segment for determinism
+                y = y[:, : self.num_samples]
+
+            return y[0]
+        except Exception as e:
+            print(f"Error processing file {audio_path} at index {index}: {e}")
+            # Either re-raise the exception here, or return None to mark the sample as invalid (None items are dropped by collate_fn)
+            return None
+
+    # def __getitem__(self, index: int) -> torch.Tensor:
+    #     audio_path = self.filelist[index]
+    #     try:
+    #         y, sr = torchaudio.load(audio_path)
+    #         if y.size(0) > 1:
+    #             # randomly select one channel
+    #             random_channel = random.randint(0, y.size(0) - 1)
+    #             y = y[random_channel, :].unsqueeze(0)  # keep the returned tensor in (1, T) shape
+    #         # gain = np.random.uniform(-1, -6) if self.train else -3
+    #         # y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
+    #         if sr != self.sampling_rate:
+    #             y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
+    #         if y.size(-1) < self.num_samples:
+    #             pad_length = self.num_samples - y.size(-1)
+    #             padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
+    #             y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
+    #         elif self.train:
+    #             start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
+    #             y = y[:, start: start + self.num_samples]
+    #         else:
+    #             # During validation, always take the first segment for determinism
+    #             y = y[:, :self.num_samples]
+    #         return y[0]
+    #     except Exception as e:
+    #         print(f"Error processing file {audio_path} at index {index}: {e}")
+    #         # Either re-raise the exception here, or return None to mark the sample as invalid
+    #         return None
\ No newline at end of file
diff --git a/inspiremusic/wavtokenizer/decoder/discriminator_dac.py b/inspiremusic/wavtokenizer/decoder/discriminator_dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..33ef3258a3a4a3ab11f17ca9c3ad381de95b6477
--- /dev/null
+++ b/inspiremusic/wavtokenizer/decoder/discriminator_dac.py
@@ -0,0 +1,249 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from audiotools import AudioSignal
+# from audiotools import ml
+# from audiotools import STFTParams
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from collections import namedtuple
+
+STFTParams = namedtuple(
+    "STFTParams",
+    ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
+)
+
+STFTParams.__new__.__defaults__ = (None, None, None, None, None)
+
+
+def
WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 48000): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + # x = AudioSignal(x, self.sample_rate) + # x.resample(self.sample_rate // self.rate) + # x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 24000, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 24000 + bands : list, optional + Bands to run discriminator over. 
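+            Defaults to ``BANDS``, i.e. five (low, high) fractions of the spectrum:
+            (0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75) and (0.75, 1.0).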
+ """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + self.n_fft = window_length + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + # x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + # x = torch.view_as_real(x.stft()) + + # x.squeeze(0).stft(n_fft=1024,win_length=1024,return_complex=True).size() + # breakpoint() + if x.size(0)==1: + # x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.window_length,return_complex=True).unsqueeze(0)) + x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.n_fft,return_complex=True).unsqueeze(0)) + else: + # x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.window_length,return_complex=True).unsqueeze(1)) + x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.n_fft,return_complex=True).unsqueeze(1)) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +# class DACDiscriminator(ml.BaseModel): +class DACDiscriminator(nn.Module): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 24000, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. 
+ periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 24000 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = DACDiscriminator() + x = torch.zeros(1, 1, 24000) + results = disc(x) + breakpoint() + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print("00") diff --git a/inspiremusic/wavtokenizer/decoder/discriminators.py b/inspiremusic/wavtokenizer/decoder/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6dece570b3d181f3fd2206a4dae2549b9e0fa3 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/discriminators.py @@ -0,0 +1,202 @@ +from typing import Tuple, List + +import torch +from torch import nn +from torch.nn import Conv2d +from torch.nn.utils import weight_norm + + +class MultiPeriodDiscriminator(nn.Module): + """ + Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + periods (tuple[int]): Tuple of periods for each discriminator. + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. 
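+
+    The forward pass runs each `DiscriminatorP` on both the real and generated waveforms and returns their
+    logits together with the intermediate feature maps used for feature matching.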
+ """ + + def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None): + super().__init__() + self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(nn.Module): + def __init__( + self, + period: int, + in_channels: int = 1, + kernel_size: int = 5, + stride: int = 3, + lrelu_slope: float = 0.1, + num_embeddings: int = None, + ): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + self.lrelu_slope = lrelu_slope + + def forward( + self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x = x.unsqueeze(1) + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = torch.nn.functional.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for i, l in enumerate(self.convs): + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + if i > 0: + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)), + num_embeddings: int = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator. + Each resolution should be a tuple of (n_fft, hop_length, win_length). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. 
+ """ + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + resolution: Tuple[int, int, int], + channels: int = 64, + in_channels: int = 1, + num_embeddings: int = None, + lrelu_slope: float = 0.1, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.lrelu_slope = lrelu_slope + self.convs = nn.ModuleList( + [ + weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)), + weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) + torch.nn.init.zeros_(self.emb.weight) + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1))) + + def forward( + self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x: torch.Tensor) -> torch.Tensor: + n_fft, hop_length, win_length = self.resolution + magnitude_spectrogram = torch.stft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=None, # interestingly rectangular window kind of works here + center=True, + return_complex=True, + ).abs() + + return magnitude_spectrogram diff --git a/inspiremusic/wavtokenizer/decoder/experiment.py b/inspiremusic/wavtokenizer/decoder/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..f5557a218aedf77cc43ba3e774d6accb6003e45f --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/experiment.py @@ -0,0 +1,474 @@ +import math + +import numpy as np +import pytorch_lightning as pl +import torch +import torchaudio +import transformers +import yaml + +from decoder.discriminator_dac import DACDiscriminator + +from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator +from decoder.feature_extractors import FeatureExtractor +from decoder.heads import FourierHead +from decoder.helpers import plot_spectrogram_to_numpy +from decoder.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss, 
DACGANLoss +from decoder.models import Backbone +from decoder.modules import safe_log +from decoder.pretrained_model import instantiate_class + + +class VocosExp(pl.LightningModule): + # noinspection PyUnusedLocal + def __init__( + self, + feature_extractor: FeatureExtractor, + backbone: Backbone, + head: FourierHead, + resume_config: str, + resume_model: str, + sample_rate: int = 24000, + initial_learning_rate: float = 2e-4, + num_warmup_steps: int = 0, + mel_loss_coeff: float = 45, + mrd_loss_coeff: float = 1.0, + pretrain_mel_steps: int = 0, + decay_mel_coeff: bool = False, + evaluate_utmos: bool = False, + evaluate_pesq: bool = False, + evaluate_periodicty: bool = False, + resume: bool = False, + ): + """ + Args: + feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals. + backbone (Backbone): An instance of Backbone model. + head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform. + sample_rate (int): Sampling rate of the audio signals. + initial_learning_rate (float): Initial learning rate for the optimizer. + num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0. + mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45. + mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0. + pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0. + decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False. + evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run. + evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run. + evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run. 
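+            resume_config (str): Path to the training config associated with `resume_model`.
+            resume_model (str): Path to a pretrained checkpoint; its weights are restored by the `WavTokenizer`
+                subclass when `resume` is True.
+            resume (bool, optional): If True, initialize the model from `resume_model`. Default is False.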
+ """ + super().__init__() + self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"]) + + self.feature_extractor = feature_extractor + self.backbone = backbone + self.head = head + + self.resume_config = resume_config + self.resume_model = resume_model + self.resume = resume + + self.multiperioddisc = MultiPeriodDiscriminator() + self.multiresddisc = MultiResolutionDiscriminator() + + + self.dac = DACDiscriminator() + + self.dacdiscriminator = DACGANLoss(self.dac) + + self.disc_loss = DiscriminatorLoss() + self.gen_loss = GeneratorLoss() + self.feat_matching_loss = FeatureMatchingLoss() + self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate) + + self.train_discriminator = False + self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff + + def configure_optimizers(self): + disc_params = [ + {"params": self.multiperioddisc.parameters()}, + {"params": self.multiresddisc.parameters()}, + {"params": self.dac.parameters()}, + ] + gen_params = [ + {"params": self.feature_extractor.parameters()}, + {"params": self.backbone.parameters()}, + {"params": self.head.parameters()}, + ] + + opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate) + opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate) + + max_steps = self.trainer.max_steps // 2 # Max steps per optimizer + scheduler_disc = transformers.get_cosine_schedule_with_warmup( + opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, + ) + scheduler_gen = transformers.get_cosine_schedule_with_warmup( + opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, + ) + + return ( + [opt_disc, opt_gen], + [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}], + ) + + def forward(self, audio_input, **kwargs): + features, _, commit_loss = self.feature_extractor(audio_input, **kwargs) + # print('1111', self.feature_extractor.state_dict()['encodec.decoder.model.3.convtr.convtr.weight_g']) + x = self.backbone(features, **kwargs) + audio_output = self.head(x) + return audio_output, commit_loss + + def training_step(self, batch, batch_idx, optimizer_idx, **kwargs): + audio_input = batch + + # train discriminator + if optimizer_idx == 0 and self.train_discriminator: + with torch.no_grad(): + audio_hat, _ = self(audio_input, **kwargs) + + + loss_dac=self.dacdiscriminator.discriminator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1)) + + real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,) + real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,) + loss_mp, loss_mp_real, _ = self.disc_loss( + disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp + ) + loss_mrd, loss_mrd_real, _ = self.disc_loss( + disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd + ) + loss_mp /= len(loss_mp_real) + loss_mrd /= len(loss_mrd_real) + loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + loss_dac + + self.log("discriminator/total", loss, prog_bar=True) + self.log("discriminator/multi_period_loss", loss_mp) + self.log("discriminator/multi_res_loss", loss_mrd) + self.log("discriminator/dac", loss_dac) + return loss + + # train generator + if optimizer_idx == 1: + audio_hat, commit_loss = self(audio_input, **kwargs) + if self.train_discriminator: + + loss_dac_1,loss_dac_2 = self.dacdiscriminator.generator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1)) 
+ _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc( + y=audio_input, y_hat=audio_hat, **kwargs, + ) + _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc( + y=audio_input, y_hat=audio_hat, **kwargs, + ) + loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp) + loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd) + loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp) + loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) + loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp) + loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd) + + self.log("generator/multi_period_loss", loss_gen_mp) + self.log("generator/multi_res_loss", loss_gen_mrd) + self.log("generator/feature_matching_mp", loss_fm_mp) + self.log("generator/feature_matching_mrd", loss_fm_mrd) + self.log("generator/loss_dac_1", loss_dac_1) + self.log("generator/loss_dac_2", loss_dac_2) + else: + loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0 + + mel_loss = self.melspec_loss(audio_hat, audio_input) + loss = ( + loss_gen_mp + + self.hparams.mrd_loss_coeff * loss_gen_mrd + + loss_fm_mp + + self.hparams.mrd_loss_coeff * loss_fm_mrd + + self.mel_loss_coeff * mel_loss + + 1000 * commit_loss + + loss_dac_1 + + loss_dac_2 + ) + + self.log("generator/total_loss", loss, prog_bar=True) + self.log("mel_loss_coeff", self.mel_loss_coeff) + self.log("generator/mel_loss", mel_loss) + self.log("commit_loss", commit_loss) + + if self.global_step % 1000 == 0 and self.global_rank == 0: + self.logger.experiment.add_audio( + "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate + ) + self.logger.experiment.add_audio( + "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate + ) + with torch.no_grad(): + mel = safe_log(self.melspec_loss.mel_spec(audio_input[0])) + mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0])) + self.logger.experiment.add_image( + "train/mel_target", + plot_spectrogram_to_numpy(mel.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + self.logger.experiment.add_image( + "train/mel_pred", + plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + + return loss + + def on_validation_epoch_start(self): + if self.hparams.evaluate_utmos: + from metrics.UTMOS import UTMOSScore + + if not hasattr(self, "utmos_model"): + self.utmos_model = UTMOSScore(device=self.device) + + def validation_step(self, batch, batch_idx, **kwargs): + audio_input = batch + audio_hat, commit_loss = self(audio_input, **kwargs) + + audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000) + audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000) + + if self.hparams.evaluate_periodicty: + from metrics.periodicity import calculate_periodicity_metrics + + periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz) + else: + periodicity_loss = pitch_loss = f1_score = 0 + + if self.hparams.evaluate_utmos: + utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean() + else: + utmos_score = torch.zeros(1, device=self.device) + + if self.hparams.evaluate_pesq: + from pesq import pesq + + pesq_score = 0 + for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()): + pesq_score += pesq(16000, ref, deg, "wb", 
on_error=1) + pesq_score /= len(audio_16_khz) + pesq_score = torch.tensor(pesq_score) + else: + pesq_score = torch.zeros(1, device=self.device) + + mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1)) + total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + 1000 * commit_loss + + return { + "val_loss": total_loss, + "mel_loss": mel_loss, + "utmos_score": utmos_score, + "pesq_score": pesq_score, + "periodicity_loss": periodicity_loss, + "pitch_loss": pitch_loss, + "f1_score": f1_score, + "audio_input": audio_input[0], + "audio_pred": audio_hat[0], + } + + def validation_epoch_end(self, outputs): + if self.global_rank == 0: + *_, audio_in, audio_pred = outputs[0].values() + self.logger.experiment.add_audio( + "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate + ) + self.logger.experiment.add_audio( + "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate + ) + mel_target = safe_log(self.melspec_loss.mel_spec(audio_in)) + mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred)) + self.logger.experiment.add_image( + "val_mel_target", + plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + self.logger.experiment.add_image( + "val_mel_hat", + plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() + mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean() + utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean() + pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean() + periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean() + pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean() + f1_score = np.array([x["f1_score"] for x in outputs]).mean() + + self.log("val_loss", avg_loss, sync_dist=True) + self.log("val/mel_loss", mel_loss, sync_dist=True) + self.log("val/utmos_score", utmos_score, sync_dist=True) + self.log("val/pesq_score", pesq_score, sync_dist=True) + self.log("val/periodicity_loss", periodicity_loss, sync_dist=True) + self.log("val/pitch_loss", pitch_loss, sync_dist=True) + self.log("val/f1_score", f1_score, sync_dist=True) + + @property + def global_step(self): + """ + Override global_step so that it returns the total number of batches processed + """ + return self.trainer.fit_loop.epoch_loop.total_batch_idx + + def on_train_batch_start(self, *args): + if self.global_step >= self.hparams.pretrain_mel_steps: + self.train_discriminator = True + else: + self.train_discriminator = False + + def on_train_batch_end(self, *args): + def mel_loss_coeff_decay(current_step, num_cycles=0.5): + max_steps = self.trainer.max_steps // 2 + if current_step < self.hparams.num_warmup_steps: + return 1.0 + progress = float(current_step - self.hparams.num_warmup_steps) / float( + max(1, max_steps - self.hparams.num_warmup_steps) + ) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + if self.hparams.decay_mel_coeff: + self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1) + + +class WavTokenizer(VocosExp): + """ + WavTokenizer is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN. + It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to + a specific bandwidth value of EnCodec. 
During training, a random bandwidth_id is generated for each step,
+    while during validation, a fixed bandwidth_id is used.
+    """
+
+    def __init__(
+        self,
+        feature_extractor: FeatureExtractor,
+        backbone: Backbone,
+        head: FourierHead,
+        resume_config: str,
+        resume_model: str,
+        sample_rate: int = 24000,
+        initial_learning_rate: float = 2e-4,
+        num_warmup_steps: int = 0,
+        mel_loss_coeff: float = 45,
+        mrd_loss_coeff: float = 1.0,
+        pretrain_mel_steps: int = 0,
+        decay_mel_coeff: bool = False,
+        evaluate_utmos: bool = False,
+        evaluate_pesq: bool = False,
+        evaluate_periodicty: bool = False,
+        resume: bool = False,
+    ):
+        super().__init__(
+            feature_extractor,
+            backbone,
+            head,
+            resume_config,
+            resume_model,
+            sample_rate,
+            initial_learning_rate,
+            num_warmup_steps,
+            mel_loss_coeff,
+            mrd_loss_coeff,
+            pretrain_mel_steps,
+            decay_mel_coeff,
+            evaluate_utmos,
+            evaluate_pesq,
+            evaluate_periodicty,
+            resume
+        )
+        # Override with conditional discriminators
+        # VocosExp.__init__(self, feature_extractor, backbone, head, resume_config, resume_model)
+        # if self.resume:
+        #     VocosExp.load_from_checkpoint(self.resume_model)
+        self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
+        self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
+        self.dac = DACDiscriminator()
+        if self.resume:
+            print('Loading pretrained model:', self.resume_model)
+            # with open(self.resume_config, "r") as f:
+            #     config = yaml.safe_load(f)
+            # feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+            # backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+            # head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+
+            # Split the checkpoint state_dict by submodule; for the quantizer only the first 8 VQ layers (indices 0-7) are kept
+            state_dict_raw = torch.load(self.resume_model, map_location=self.device)['state_dict']
+            state_dict_fa_qa = dict()
+            state_dict_fa_en = dict()
+            state_dict_fa_de = dict()
+            state_dict_bb = dict()
+            state_dict_hd = dict()
+            state_dict_mp = dict()
+            state_dict_mr = dict()
+            state_dict_dac = dict()
+            for k, v in state_dict_raw.items():
+                # breakpoint()
+                if k.startswith('feature_extractor.encodec.quantizer'):
+                    # breakpoint()
+                    # print("*****",k)
+                    ss = k[46:48]
+                    if ss[-1] == '.':
+                        num = int(ss[0])
+                        # print("num,k",num,k[36:])
+                        if num <= 7:
+                            state_dict_fa_qa[k[36:]] = v
+                if k.startswith('feature_extractor.encodec.encoder'):
+                    state_dict_fa_en[k[34:]] = v
+                if k.startswith('feature_extractor.encodec.decoder'):
+                    state_dict_fa_de[k[34:]] = v
+                if k.startswith('backbone.'):
+                    state_dict_bb[k[9:]] = v
+                if k.startswith('head.'):
+                    state_dict_hd[k[5:]] = v
+                if k.startswith('multiperioddisc.'):
+                    state_dict_mp[k[16:]] = v
+                if k.startswith('multiresddisc.'):
+                    state_dict_mr[k[14:]] = v
+                if k.startswith('dac.'):
+                    state_dict_dac[k[4:]] = v
+            # breakpoint()
+            # feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
+            feature_extractor.encodec.encoder.load_state_dict(state_dict_fa_en, strict=True)
+            feature_extractor.encodec.decoder.load_state_dict(state_dict_fa_de, strict=True)
+            feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
+            backbone.load_state_dict(state_dict_bb, strict=True)
+            head.load_state_dict(state_dict_hd, strict=True)
+            self.feature_extractor = feature_extractor.to(self.device)
+            self.backbone = backbone.to(self.device)
+            self.head = head.to(self.device)
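+            # The discriminator weights below are restored as well, so adversarial training resumes from the checkpoint state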
self.multiperioddisc.load_state_dict(state_dict_mp, strict=True) + self.multiresddisc.load_state_dict(state_dict_mr, strict=True) + self.dac.load_state_dict(state_dict_dac, strict=True) + + def training_step(self, *args): + # print('-------------------train--------------------') + # if self.global_rank == 0 and self.resume: + # config_path = self.resume_config + # model_path = self.resume_model + # self.pretrained_load(config_path, model_path) + # print('ๅŠ ่ฝฝ้ข„่ฎญ็ปƒๆจกๅž‹:', model_path) + bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,) + output = super().training_step(*args, bandwidth_id=bandwidth_id) + return output + + def validation_step(self, *args): + # print('-------------------valid--------------------') + bandwidth_id = torch.tensor([0], device=self.device) + output = super().validation_step(*args, bandwidth_id=bandwidth_id) + return output + + def validation_epoch_end(self, outputs): + if self.global_rank == 0: + *_, audio_in, _ = outputs[0].values() + # Resynthesis with encodec for reference + self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0]) + encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :]) + self.logger.experiment.add_audio( + "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate, + ) + + super().validation_epoch_end(outputs) diff --git a/inspiremusic/wavtokenizer/decoder/feature_extractors.py b/inspiremusic/wavtokenizer/decoder/feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..d4672d141ba89b88ecdec4d48464252cb524fb9f --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/feature_extractors.py @@ -0,0 +1,176 @@ +from typing import List + +import torch +import torchaudio +from torch import nn +import math +# from inspiremusic.wavtokenizer.decoder.modules import safe_log +from inspiremusic.wavtokenizer.encoder.modules import SEANetEncoder, SEANetDecoder +from inspiremusic.wavtokenizer.encoder import EncodecModel +from inspiremusic.wavtokenizer.encoder.quantization import ResidualVectorQuantizer + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + + +class FeatureExtractor(nn.Module): + """Base class for feature extractors.""" + + def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Extract features from the given audio. + + Args: + audio (Tensor): Input audio waveform. + + Returns: + Tensor: Extracted features of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. 
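+
+        Note:
+            Concrete extractors may return additional values; `EncodecFeatures.forward`, for instance, also
+            returns the discrete codes and the VQ commitment loss alongside the quantized features.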
+ """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class MelSpectrogramFeatures(FeatureExtractor): + def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + center=padding == "center", + power=1, + ) + + def forward(self, audio, **kwargs): + if self.padding == "same": + pad = self.mel_spec.win_length - self.mel_spec.hop_length + audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect") + mel = self.mel_spec(audio) + features = safe_log(mel) + return features + + +class EncodecFeatures(FeatureExtractor): + def __init__( + self, + encodec_model: str = "encodec_24khz", + bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0], + train_codebooks: bool = False, + num_quantizers: int = 1, + dowmsamples: List[int] = [6, 5, 5, 4], + vq_bins: int = 16384, + vq_kmeans: int = 800, + ): + super().__init__() + + # breakpoint() + self.frame_rate = 25 # not use + # n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate)) + n_q = num_quantizers # important + encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2, + dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU', + kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2, + true_skip=False, compress=2) + decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2, + dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU', + kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2, + true_skip=False, compress=2) + quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans, + decay=0.99, kmeans_init=True) + + # breakpoint() + if encodec_model == "encodec_24khz": + self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer, + target_bandwidths=bandwidths, sample_rate=24000, channels=1) + else: + raise ValueError( + f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'." 
+ ) + for param in self.encodec.parameters(): + param.requires_grad = True + # self.num_q = n_q + # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0) + # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks) + self.bandwidths = bandwidths + + # @torch.no_grad() + # def get_encodec_codes(self, audio): + # audio = audio.unsqueeze(1) + # emb = self.encodec.encoder(audio) + # codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth) + # return codes + + def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)): + if self.training: + self.encodec.train() + + audio = audio.unsqueeze(1) # audio(16,24000) + + # breakpoint() + + emb = self.encodec.encoder(audio) + q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id]) + quantized = q_res.quantized + codes = q_res.codes + commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75) + + return quantized, codes, commit_loss + + # codes = self.get_encodec_codes(audio) + # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights` + # # with offsets given by the number of bins, and finally summed in a vectorized operation. + # offsets = torch.arange( + # 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device + # ) + # embeddings_idxs = codes + offsets.view(-1, 1, 1) + # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0) + # return features.transpose(1, 2) + + def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor): + if self.training: + self.encodec.train() + + audio = audio.unsqueeze(1) # audio(16,24000) + emb = self.encodec.encoder(audio) + q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id]) + quantized = q_res.quantized + codes = q_res.codes + commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75) + + return quantized, codes, commit_loss + + def _infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor = torch.tensor(0)): + if self.training: + self.encodec.train() + + audio = audio.unsqueeze(1) # audio(16,24000) + emb = self.encodec.encoder(audio) + q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id]) + quantized = q_res.quantized + codes = q_res.codes + commit_loss = q_res.penalty # codes(8,16,75),features(16,128,75) + + return quantized, codes, commit_loss \ No newline at end of file diff --git a/inspiremusic/wavtokenizer/decoder/heads.py b/inspiremusic/wavtokenizer/decoder/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..eb3b9f85bda23ae73c09e462cb584bc9878faca9 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/heads.py @@ -0,0 +1,159 @@ +import torch +from torch import nn +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz + +from inspiremusic.wavtokenizer.decoder.spectral_ops import IMDCT, ISTFT + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. 
+ + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + + S = mag * (x + 1j * y) + + audio = self.istft(S) + return audio + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False, + ): + super().__init__() + out_dim = mdct_frame_len // 2 + self.out = nn.Linear(dim, out_dim) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + self.clip_audio = clip_audio + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.out.weight.mul_(scale.view(-1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 
+ """ + x = self.out(x) + x = symexp(x) + x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + + return audio + + +class IMDCTCosHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) ยท cos(p) + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False): + super().__init__() + self.clip_audio = clip_audio + self.out = nn.Linear(dim, mdct_frame_len) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTCosHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + m, p = x.chunk(2, dim=2) + m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(m * torch.cos(p)) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + return audio diff --git a/inspiremusic/wavtokenizer/decoder/helpers.py b/inspiremusic/wavtokenizer/decoder/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..3d303010352ad59dde2996605f124128ee17db36 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/helpers.py @@ -0,0 +1,71 @@ +import matplotlib +import numpy as np +import torch +from matplotlib import pyplot as plt +from pytorch_lightning import Callback + +matplotlib.use("Agg") + + +def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray: + """ + Save a matplotlib figure to a numpy array. + + Args: + fig (Figure): Matplotlib figure object. + + Returns: + ndarray: Numpy array representing the figure. + """ + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray: + """ + Plot a spectrogram and convert it to a numpy array. + + Args: + spectrogram (ndarray): Spectrogram data. + + Returns: + ndarray: Numpy array representing the plotted spectrogram. + """ + spectrogram = spectrogram.astype(np.float32) + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +class GradNormCallback(Callback): + """ + Callback to log the gradient norm. + """ + + def on_after_backward(self, trainer, model): + model.log("grad_norm", gradient_norm(model)) + + +def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor: + """ + Compute the gradient norm. + + Args: + model (Module): PyTorch model. + norm_type (float, optional): Type of the norm. Defaults to 2.0. + + Returns: + Tensor: Gradient norm. 
+ """ + grads = [p.grad for p in model.parameters() if p.grad is not None] + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type) + return total_norm diff --git a/inspiremusic/wavtokenizer/decoder/loss.py b/inspiremusic/wavtokenizer/decoder/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..30f32ccf9a3f5373335ddb8da1334f508c16f752 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/loss.py @@ -0,0 +1,159 @@ +from typing import List, Tuple + +import torch +import torchaudio +from torch import nn + +from decoder.modules import safe_log + +import torch.nn.functional as F + + +class MelSpecReconstructionLoss(nn.Module): + """ + L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample + """ + + def __init__( + self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100, + ): + super().__init__() + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1, + ) + + def forward(self, y_hat, y) -> torch.Tensor: + """ + Args: + y_hat (Tensor): Predicted audio waveform. + y (Tensor): Ground truth audio waveform. + + Returns: + Tensor: L1 loss between the mel-scaled magnitude spectrograms. + """ + mel_hat = safe_log(self.mel_spec(y_hat)) + mel = safe_log(self.mel_spec(y)) + + loss = torch.nn.functional.l1_loss(mel, mel_hat) + + return loss + + +class GeneratorLoss(nn.Module): + """ + Generator Loss module. Calculates the loss for the generator based on discriminator outputs. + """ + + def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + disc_outputs (List[Tensor]): List of discriminator outputs. + + Returns: + Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from + the sub-discriminators + """ + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean(torch.clamp(1 - dg, min=0)) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class DiscriminatorLoss(nn.Module): + """ + Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. + """ + + def forward( + self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Args: + disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. + disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. + + Returns: + Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from + the sub-discriminators for real outputs, and a list of + loss values for generated outputs. + """ + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(torch.clamp(1 - dr, min=0)) + g_loss = torch.mean(torch.clamp(1 + dg, min=0)) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +class FeatureMatchingLoss(nn.Module): + """ + Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. 
+ """ + + def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: + """ + Args: + fmap_r (List[List[Tensor]]): List of feature maps from real samples. + fmap_g (List[List[Tensor]]): List of feature maps from generated samples. + + Returns: + Tensor: The calculated feature matching loss. + """ + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss + +class DACGANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + # d_fake = self.discriminator(fake.audio_data) + # d_real = self.discriminator(real.audio_data) + d_fake = self.discriminator(fake) + d_real = self.discriminator(real) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature + diff --git a/inspiremusic/wavtokenizer/decoder/models.py b/inspiremusic/wavtokenizer/decoder/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ce3d99c57feb48946039a7501c638874afdf62 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/models.py @@ -0,0 +1,266 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from inspiremusic.wavtokenizer.decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv1d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb=None): + h = x + h = self.norm1(h) + h = 
nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h = q.shape + q = q.permute(0, 2, 1) # b,hw,c + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + h_ = self.proj_out(h_) + + return x + h_ + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. 
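+
+ Example (illustrative sketch; the sizes are assumed, and dim must be divisible
+ by 32 because of the GroupNorm layers in the positional network):
+ >>> import torch
+ >>> backbone = VocosBackbone(input_channels=512, dim=512, intermediate_dim=1536, num_layers=8)
+ >>> feats = torch.randn(1, 512, 200) # (B, C, L) codec features
+ >>> hidden = backbone(feats) # (B, L, dim), consumed by a Fourier head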
+ """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + self.temb_ch = 0 + block_in = dim + dropout = 0.1 + attn_type="vanilla" + + pos_net : tp.List[nn.Module] = [ + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + make_attn(block_in, attn_type=attn_type), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + Normalize(block_in) + ] + + self.pos_net = nn.Sequential(*pos_net) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.embed(x) + x = self.pos_net(x) + if self.adanorm: + # assert bandwidth_id is not None + if bandwidth_id is None: + bandwidth_id = torch.tensor(0, device='cuda') + x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, cond_embedding_id=bandwidth_id) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. 
+ """ + + def __init__( + self, input_channels, dim, num_blocks, layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x diff --git a/inspiremusic/wavtokenizer/decoder/modules.py b/inspiremusic/wavtokenizer/decoder/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..799a61fb94a2adc26c6e7a39e4ff3285f6556975 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/modules.py @@ -0,0 +1,214 @@ +from typing import Optional +from typing import Tuple + +import torch +from torch import nn +from torch.nn.utils import weight_norm, remove_weight_norm + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, ...] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: float = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> 
torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) diff --git a/inspiremusic/wavtokenizer/decoder/pretrained.py b/inspiremusic/wavtokenizer/decoder/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..1231a05d4945e2f2debe620b1347cad3e6c7ca76 --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/pretrained.py @@ -0,0 +1,253 @@ +import os +from typing import Tuple, Any, Union, Dict + +import torch +import yaml +from huggingface_hub import hf_hub_download +from torch import nn +from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures +from inspiremusic.wavtokenizer.decoder.heads import FourierHead +from inspiremusic.wavtokenizer.decoder.models import Backbone + + +def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: + """Instantiates a class with the given args and init. + + Args: + args: Positional arguments required for instantiation. + init: Dict of the form {"class_path":...,"init_args":...}. + + Returns: + The instantiated class object. + """ + kwargs = init.get("init_args", {}) + if not isinstance(args, tuple): + args = (args,) + class_module, class_name = init["class_path"].rsplit(".", 1) + module = __import__(class_module, fromlist=[class_name]) + args_class = getattr(module, class_name) + return args_class(*args, **kwargs) + + +class WavTokenizer(nn.Module): + """ + The Vocos class represents a Fourier-based neural vocoder for audio synthesis. + This class is primarily designed for inference, with support for loading from pretrained + model checkpoints. It consists of three main components: a feature extractor, + a backbone, and a head. + """ + + def __init__( + self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead, + ): + super().__init__() + self.feature_extractor = feature_extractor + self.backbone = backbone + self.head = head + + @classmethod + def from_hparams(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. + """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config["feature_extractor"]) + backbone = instantiate_class(args=(), init=config["backbone"]) + head = instantiate_class(args=(), init=config["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) + return model + + @classmethod + def from_pretrained(self, repo_id: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml") + model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin") + model = self.from_hparams(config_path) + state_dict = torch.load(model_path, map_location="cpu") + if isinstance(model.feature_extractor, EncodecFeatures): + encodec_parameters = { + "feature_extractor.encodec." 
+ key: value + for key, value in model.feature_extractor.encodec.state_dict().items() + } + state_dict.update(encodec_parameters) + model.load_state_dict(state_dict) + model.eval() + return model + + + @classmethod + def from_hparams_feat(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. + """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"]) + backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"]) + head = instantiate_class(args=(), init=config['model']['init_args']["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) + return model + + + @classmethod + def from_pretrained_feat(self, config_path, model_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams_feat(config_path) + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'): + state_dict[k] = v + + model.load_state_dict(state_dict) + model.eval() + return model + + @classmethod + def estimator(self, config_path, model_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams_feat(config_path) + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'): + state_dict[k] = v + model.load_state_dict(state_dict) + model.eval() + return model + + @classmethod + def from_pretrained0911(self, config_path, model_folder_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams0802(config_path) + + models = os.listdir(model_folder_path) + val_loss = [] + for item in models: + if not item.startswith('vocos_'): + continue + val_loss.append(item[-11:-5]) + val_loss.sort() + val_loss = val_loss[:3] # ๅ–ๅ‰3ๆ€ง่ƒฝ่พƒๅฅฝ็š„ๆจกๅž‹ๅนณๅ‡ + state_dict = dict() + state_dicts = [] + for item in models: + if not item.startswith('vocos_'): + continue + ll = item[-11:-5] + if ll not in val_loss: + continue + model_path = model_folder_path + '/' + item + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict_single = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'): + state_dict_single[k] = v + state_dicts.append(state_dict_single) + for kk in state_dicts[0].keys(): + vv = state_dicts[0][kk] + for i in range(1, len(state_dicts)): + ss = state_dicts[i] + vv += ss[kk] + vm = vv/len(state_dicts) + state_dict[kk] = vm + model.load_state_dict(state_dict) + model.eval() + return model + + + @torch.inference_mode() + def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input, + which is then passed through the backbone and the head to reconstruct the audio output. 
+ + Args: + audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T), + where B is the batch size and L is the waveform length. + + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). + """ + features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818 + audio_output = self.decode(features, **kwargs) + return audio_output + + + # 0818 + @torch.inference_mode() + def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs) + return features,discrete_codes + + + # 0818 + @torch.inference_mode() + def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs) + return features,discrete_codes + + @torch.inference_mode() + def infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + _, discrete_codes, _ = self.feature_extractor._infer(audio_input, **kwargs) + discrete_codes = discrete_codes.clamp(min=0, max=16383) + return discrete_codes + + @torch.inference_mode() + def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to decode audio waveform from already calculated features. The features input is passed through + the backbone and the head to reconstruct the audio output. + + Args: + features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size, + C denotes the feature dimension, and L is the sequence length. + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). + """ + x = self.backbone(features_input, **kwargs) + audio_output = self.head(x) + return audio_output + + @torch.inference_mode() + def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor: + """ + Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's + codebook weights. + + Args: + codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L), + where K is the number of codebooks, B is the batch size and L is the sequence length. + + Returns: + Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension, + and L is the sequence length. 
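+
+ Example (illustrative sketch; assumes `model` is a WavTokenizer built around an
+ EncodecFeatures extractor, e.g. loaded via from_pretrained_feat, and `codes`
+ are tokens of the documented (K, B, L) shape):
+ >>> feats = model.codes_to_features(codes) # (B, C, L) embeddings
+ >>> wav = model.decode(feats) # (B, T) reconstructed waveform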
+ """ + assert isinstance( + self.feature_extractor, EncodecFeatures + ), "Feature extractor should be an instance of EncodecFeatures" + + if codes.dim() == 2: + codes = codes.unsqueeze(1) + + n_bins = self.feature_extractor.encodec.quantizer.bins + offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device) + embeddings_idxs = codes + offsets.view(-1, 1, 1) + + tmp=torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers],dim=0) + # features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0) + features = torch.nn.functional.embedding(embeddings_idxs, tmp).sum(dim=0) + features = features.transpose(1, 2) + + return features diff --git a/inspiremusic/wavtokenizer/decoder/pretrained_model.py b/inspiremusic/wavtokenizer/decoder/pretrained_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c919bb25685d78522c6b638cd46310c7ae5edc0d --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/pretrained_model.py @@ -0,0 +1,192 @@ +from typing import Tuple, Any, Union, Dict + +import torch +import yaml +from huggingface_hub import hf_hub_download +from torch import nn +from inspiremusic.wavtokenizer.decoder.feature_extractors import FeatureExtractor, EncodecFeatures +from inspiremusic.wavtokenizer.decoder.heads import FourierHead +from inspiremusic.wavtokenizer.decoder.models import Backbone +from inspiremusic.wavtokenizer.decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator + + +def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: + """Instantiates a class with the given args and init. + + Args: + args: Positional arguments required for instantiation. + init: Dict of the form {"class_path":...,"init_args":...}. + + Returns: + The instantiated class object. + """ + kwargs = init.get("init_args", {}) + if not isinstance(args, tuple): + args = (args,) + class_module, class_name = init["class_path"].rsplit(".", 1) + module = __import__(class_module, fromlist=[class_name]) + args_class = getattr(module, class_name) + return args_class(*args, **kwargs) + + +class WavTokenizer(nn.Module): + """ + The Vocos class represents a Fourier-based neural vocoder for audio synthesis. + This class is primarily designed for inference, with support for loading from pretrained + model checkpoints. It consists of three main components: a feature extractor, + a backbone, and a head. + """ + + def __init__( + self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead, + multiperioddisc: MultiPeriodDiscriminator, multiresddisc: MultiResolutionDiscriminator, + ): + super().__init__() + self.feature_extractor = feature_extractor + self.backbone = backbone + self.head = head + + self.multiperioddisc = multiperioddisc + self.multiresddisc = multiresddisc + + @classmethod + def from_hparams0828(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. 
+ """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"]) + backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"]) + head = instantiate_class(args=(), init=config['model']['init_args']["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head, + multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4), + multiresddisc=MultiResolutionDiscriminator(num_embeddings=4)) + return model + + @classmethod + def from_pretrained0828(self, config_path, model_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams0828(config_path) + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \ + or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'): + state_dict[k] = v + # if isinstance(model.feature_extractor, EncodecFeatures): + # encodec_parameters = { + # "feature_extractor.encodec." + key: value + # for key, value in model.feature_extractor.encodec.state_dict().items() + # } + # state_dict.update(encodec_parameters) + model.load_state_dict(state_dict) + return model + + @classmethod + def from_hparams0802(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. + """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"]) + backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"]) + head = instantiate_class(args=(), init=config['model']['init_args']["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) + return model + + @classmethod + def from_pretrained0802(self, config_path, model_path): + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + model = self.from_hparams0802(config_path) + state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict'] + state_dict = dict() + for k, v in state_dict_raw.items(): + if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'): + state_dict[k] = v + # if isinstance(model.feature_extractor, EncodecFeatures): + # encodec_parameters = { + # "feature_extractor.encodec." + key: value + # for key, value in model.feature_extractor.encodec.state_dict().items() + # } + # state_dict.update(encodec_parameters) + model.load_state_dict(state_dict) + model.eval() + return model + + @torch.inference_mode() + def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input, + which is then passed through the backbone and the head to reconstruct the audio output. + + Args: + audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T), + where B is the batch size and L is the waveform length. + + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). 
+ """ + features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818 + audio_output = self.decode(features, **kwargs) + return audio_output + + + # 0818 + @torch.inference_mode() + def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + features, _, _ = self.feature_extractor(audio_input, **kwargs) + return features + + + @torch.inference_mode() + def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to decode audio waveform from already calculated features. The features input is passed through + the backbone and the head to reconstruct the audio output. + + Args: + features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size, + C denotes the feature dimension, and L is the sequence length. + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). + """ + x = self.backbone(features_input, **kwargs) + audio_output = self.head(x) + return audio_output + + @torch.inference_mode() + def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor: + """ + Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's + codebook weights. + + Args: + codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L), + where K is the number of codebooks, B is the batch size and L is the sequence length. + + Returns: + Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension, + and L is the sequence length. + """ + assert isinstance( + self.feature_extractor, EncodecFeatures + ), "Feature extractor should be an instance of EncodecFeatures" + + if codes.dim() == 2: + codes = codes.unsqueeze(1) + + n_bins = self.feature_extractor.encodec.quantizer.bins + offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device) + embeddings_idxs = codes + offsets.view(-1, 1, 1) + features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0) + features = features.transpose(1, 2) + + return features diff --git a/inspiremusic/wavtokenizer/decoder/spectral_ops.py b/inspiremusic/wavtokenizer/decoder/spectral_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8b062d5ff4c1d82124afa2752cea62132434790d --- /dev/null +++ b/inspiremusic/wavtokenizer/decoder/spectral_ops.py @@ -0,0 +1,242 @@ +import numpy as np +import scipy +import torch +from torch import nn, view_as_real, view_as_complex +import pdb + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
+ """ + + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + # assert (window_envelope > 1e-11).all() + if not torch.all(window_envelope > 1e-11): + window_envelope = torch.clamp(window_envelope, min=1e-11) + + y = y / window_envelope + + return y + + def onnx_forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
+ """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + pdb.set_trace() + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + # assert (window_envelope > 1e-11).all() + if not torch.all(window_envelope > 1e-11): + window_envelope = torch.clamp(window_envelope, min=1e-11) + + y = y / window_envelope + + return y + + +class MDCT(nn.Module): + """ + Modified Discrete Cosine Transform (MDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) + post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) + # view_as_real: NCCL Backend does not support ComplexFloat data type + # https://github.com/pytorch/pytorch/issues/71613 + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + """ + Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. + + Args: + audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size + and T is the length of the audio. + + Returns: + Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames + and N is the number of frequency bins. + """ + if self.padding == "center": + audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2)) + elif self.padding == "same": + # hop_length is 1/2 frame_len + audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4)) + else: + raise ValueError("Padding must be 'center' or 'same'.") + + x = audio.unfold(-1, self.frame_len, self.frame_len // 2) + N = self.frame_len // 2 + x = x * self.window.expand(x.shape) + X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N] + res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) + return torch.real(res) * np.sqrt(2) + + +class IMDCT(nn.Module): + """ + Inverse Modified Discrete Cosine Transform (IMDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. 
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. + """ + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1) + y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio diff --git a/inspiremusic/wavtokenizer/encoder/__init__.py b/inspiremusic/wavtokenizer/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8fd2ada59e0e15d4df2854052edf150e5238e3 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# flake8: noqa + +"""EnCodec neural audio codec.""" + +__version__ = "0.1.2a3" + +from .model import EncodecModel diff --git a/inspiremusic/wavtokenizer/encoder/distrib.py b/inspiremusic/wavtokenizer/encoder/distrib.py new file mode 100644 index 0000000000000000000000000000000000000000..b1662d8085cf2878c4cd058537d0f097de91d158 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/distrib.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. + raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, " + "at least one worker has a different one.") + + +def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = torch.distributed.all_reduce( + buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + else: + handle = torch.distributed.broadcast( + buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = torch.distributed.all_reduce( + p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: tp.Dict[str, float], count=1.): + """Average a dictionary of metrics across all workers, using the optional + `count` as unnormalized weight. 
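+
+ Example (sketch; in a single-process run this is a no-op and the dict is
+ returned unchanged):
+ >>> average_metrics({"loss": 0.7, "mel_loss": 0.3})
+ {'loss': 0.7, 'mel_loss': 0.3}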
+ """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/inspiremusic/wavtokenizer/encoder/model.py b/inspiremusic/wavtokenizer/encoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..33be28de408112b0f54f062df43ac13953e170ea --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/model.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""EnCodec model implementation.""" + +import math +from pathlib import Path +import typing as tp + +import numpy as np +import torch +from torch import nn + +from . import quantization as qt +from . import modules as m +from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url + + +ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/' + +EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]] + + +class LMModel(nn.Module): + """Language Model to estimate probabilities of each codebook entry. + We predict all codebooks in parallel for a given time step. + + Args: + n_q (int): number of codebooks. + card (int): codebook cardinality. + dim (int): transformer dimension. + **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`. + """ + def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs): + super().__init__() + self.card = card + self.n_q = n_q + self.dim = dim + self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs) + self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)]) + self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)]) + + def forward(self, indices: torch.Tensor, + states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0): + """ + Args: + indices (torch.Tensor): indices from the previous time step. Indices + should be 1 + actual index in the codebook. The value 0 is reserved for + when the index is missing (i.e. first time step). Shape should be + `[B, n_q, T]`. + states: state for the streaming decoding. + offset: offset of the current time step. + + Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities + with a shape `[B, card, n_q, T]`. + + """ + B, K, T = indices.shape + input_ = sum([self.emb[k](indices[:, k]) for k in range(K)]) + out, states, offset = self.transformer(input_, states, offset) + logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2) + return torch.softmax(logits, dim=1), states, offset + + +class EncodecModel(nn.Module): + """EnCodec model operating on the raw waveform. + Args: + target_bandwidths (list of float): Target bandwidths. + encoder (nn.Module): Encoder network. + decoder (nn.Module): Decoder network. + sample_rate (int): Audio sample rate. + channels (int): Number of audio channels. + normalize (bool): Whether to apply audio normalization. + segment (float or None): segment duration in sec. when doing overlap-add. + overlap (float): overlap between segment, given as a fraction of the segment duration. + name (str): name of the model, used as metadata when compressing audio. 
+ """ + def __init__(self, + encoder: m.SEANetEncoder, + decoder: m.SEANetDecoder, + quantizer: qt.ResidualVectorQuantizer, + target_bandwidths: tp.List[float], + sample_rate: int, + channels: int, + normalize: bool = False, + segment: tp.Optional[float] = None, + overlap: float = 0.01, + name: str = 'unset'): + super().__init__() + self.bandwidth: tp.Optional[float] = None + self.target_bandwidths = target_bandwidths + self.encoder = encoder + self.quantizer = quantizer + self.decoder = decoder + self.sample_rate = sample_rate + self.channels = channels + self.normalize = normalize + self.segment = segment + self.overlap = overlap + self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios)) + self.name = name + self.bits_per_codebook = int(math.log2(self.quantizer.bins)) + assert 2 ** self.bits_per_codebook == self.quantizer.bins, \ + "quantizer bins must be a power of 2." + + @property + def segment_length(self) -> tp.Optional[int]: + if self.segment is None: + return None + return int(self.segment * self.sample_rate) + + @property + def segment_stride(self) -> tp.Optional[int]: + segment_length = self.segment_length + if segment_length is None: + return None + return max(1, int((1 - self.overlap) * segment_length)) + + def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]: + """Given a tensor `x`, returns a list of frames containing + the discrete encoded codes for `x`, along with rescaling factors + for each segment, when `self.normalize` is True. + + Each frames is a tuple `(codebook, scale)`, with `codebook` of + shape `[B, K, T]`, with `K` the number of codebooks. + """ + assert x.dim() == 3 + _, channels, length = x.shape + assert channels > 0 and channels <= 2 + segment_length = self.segment_length + if segment_length is None: + segment_length = length + stride = length + else: + stride = self.segment_stride # type: ignore + assert stride is not None + + encoded_frames: tp.List[EncodedFrame] = [] + for offset in range(0, length, stride): + frame = x[:, :, offset: offset + segment_length] + encoded_frames.append(self._encode_frame(frame)) + return encoded_frames + + def _encode_frame(self, x: torch.Tensor) -> EncodedFrame: + length = x.shape[-1] + duration = length / self.sample_rate + assert self.segment is None or duration <= 1e-5 + self.segment + + if self.normalize: + mono = x.mean(dim=1, keepdim=True) + volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() + scale = 1e-8 + volume + x = x / scale + scale = scale.view(-1, 1) + else: + scale = None + + emb = self.encoder(x) + codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth) + codes = codes.transpose(0, 1) + # codes is [B, K, T], with T frames, K nb of codebooks. + return codes, scale + + def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor: + """Decode the given frames into a waveform. + Note that the output might be a bit bigger than the input. In that case, + any extra steps at the end can be trimmed. 
+ """ + segment_length = self.segment_length + if segment_length is None: + assert len(encoded_frames) == 1 + return self._decode_frame(encoded_frames[0]) + + frames = [self._decode_frame(frame) for frame in encoded_frames] + return _linear_overlap_add(frames, self.segment_stride or 1) + + def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor: + codes, scale = encoded_frame + codes = codes.transpose(0, 1) + emb = self.quantizer.decode(codes) + out = self.decoder(emb) + if scale is not None: + out = out * scale.view(-1, 1, 1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + frames = self.encode(x) + return self.decode(frames)[:, :, :x.shape[-1]] + + def set_target_bandwidth(self, bandwidth: float): + if bandwidth not in self.target_bandwidths: + raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. " + f"Select one of {self.target_bandwidths}.") + self.bandwidth = bandwidth + + def get_lm_model(self) -> LMModel: + """Return the associated LM model to improve the compression rate. + """ + device = next(self.parameters()).device + lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200, + past_context=int(3.5 * self.frame_rate)).to(device) + checkpoints = { + 'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th', + 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th', + } + try: + checkpoint_name = checkpoints[self.name] + except KeyError: + raise RuntimeError("No LM pre-trained for the current Encodec model.") + url = _get_checkpoint_url(ROOT_URL, checkpoint_name) + state = torch.hub.load_state_dict_from_url( + url, map_location='cpu', check_hash=True) # type: ignore + lm.load_state_dict(state) + lm.eval() + return lm + + @staticmethod + def _get_model(target_bandwidths: tp.List[float], + sample_rate: int = 24_000, + channels: int = 1, + causal: bool = True, + model_norm: str = 'weight_norm', + audio_normalize: bool = False, + segment: tp.Optional[float] = None, + name: str = 'unset'): + encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal) + decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal) + n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10)) + quantizer = qt.ResidualVectorQuantizer( + dimension=encoder.dimension, + n_q=n_q, + bins=1024, + ) + model = EncodecModel( + encoder, + decoder, + quantizer, + target_bandwidths, + sample_rate, + channels, + normalize=audio_normalize, + segment=segment, + name=name, + ) + return model + + @staticmethod + def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None): + if repository is not None: + if not repository.is_dir(): + raise ValueError(f"{repository} must exist and be a directory.") + file = repository / checkpoint_name + checksum = file.stem.split('-')[1] + _check_checksum(file, checksum) + return torch.load(file) + else: + url = _get_checkpoint_url(ROOT_URL, checkpoint_name) + return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) # type:ignore + + @staticmethod + def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None): + """Return the pretrained causal 24khz model. + """ + if repository: + assert pretrained + target_bandwidths = [1.5, 3., 6, 12., 24.] 
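+        # With the default SEANet ratios the hop length is 320 samples (a 75 Hz frame
+        # rate) and each codebook holds log2(1024) = 10 bits per frame, so the 24 kbps
+        # top bandwidth above leads `_get_model` to build
+        # 1000 * 24 // (75 * 10) = 32 residual codebooks.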
+ checkpoint_name = 'encodec_24khz-d7cc33bc.th' + sample_rate = 24_000 + channels = 1 + model = EncodecModel._get_model( + target_bandwidths, sample_rate, channels, + causal=True, model_norm='weight_norm', audio_normalize=False, + name='encodec_24khz' if pretrained else 'unset') + if pretrained: + state_dict = EncodecModel._get_pretrained(checkpoint_name, repository) + model.load_state_dict(state_dict) + model.eval() + return model + + @staticmethod + def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None): + """Return the pretrained 48khz model. + """ + if repository: + assert pretrained + target_bandwidths = [3., 6., 12., 24.] + checkpoint_name = 'encodec_48khz-7e698e3e.th' + sample_rate = 48_000 + channels = 2 + model = EncodecModel._get_model( + target_bandwidths, sample_rate, channels, + causal=False, model_norm='time_group_norm', audio_normalize=True, + segment=1., name='encodec_48khz' if pretrained else 'unset') + if pretrained: + state_dict = EncodecModel._get_pretrained(checkpoint_name, repository) + model.load_state_dict(state_dict) + model.eval() + return model + + +def test(): + from itertools import product + import torchaudio + bandwidths = [3, 6, 12, 24] + models = { + 'encodec_24khz': EncodecModel.encodec_model_24khz, + 'encodec_48khz': EncodecModel.encodec_model_48khz + } + for model_name, bw in product(models.keys(), bandwidths): + model = models[model_name]() + model.set_target_bandwidth(bw) + audio_suffix = model_name.split('_')[1][:3] + wav, sr = torchaudio.load(f"test_{audio_suffix}.wav") + wav = wav[:, :model.sample_rate * 2] + wav_in = wav.unsqueeze(0) + wav_dec = model(wav_in)[0] + assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape) + + +if __name__ == '__main__': + test() diff --git a/inspiremusic/wavtokenizer/encoder/modules/__init__.py b/inspiremusic/wavtokenizer/encoder/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e2f987aafa3abf9b882fe15ca5a3b6e150ea32 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/modules/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch modules.""" + +# flake8: noqa +from .conv import ( + pad1d, + unpad1d, + NormConv1d, + NormConvTranspose1d, + NormConv2d, + NormConvTranspose2d, + SConv1d, + SConvTranspose1d, +) +from .lstm import SLSTM +from .seanet import SEANetEncoder, SEANetDecoder +from .transformer import StreamingTransformerEncoder diff --git a/inspiremusic/wavtokenizer/encoder/modules/conv.py b/inspiremusic/wavtokenizer/encoder/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..e83ae84d20ad2082c6e83bb7fc73bb22ac58cf13 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/modules/conv.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
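# Illustrative usage sketch for the EncodecModel defined in model.py above; this is
# editorial, not part of the patch. It assumes the package layout of this repo makes
# the file importable as `inspiremusic.wavtokenizer.encoder.model`. With
# pretrained=False no checkpoint is downloaded, so the reconstruction is meaningless,
# but the shapes and API calls are exercised end to end.
import torch
from inspiremusic.wavtokenizer.encoder.model import EncodecModel

model = EncodecModel.encodec_model_24khz(pretrained=False).eval()
model.set_target_bandwidth(6.0)        # 6 kbps at 75 frames/s and 10 bits/codebook should select 8 codebooks
wav = torch.randn(1, 1, 24000)         # [batch, channels, samples], one second of audio
with torch.no_grad():
    frames = model.encode(wav)         # list of (codes, scale) tuples, one per segment
    codes, scale = frames[0]           # codes: [1, n_q, 75]; scale is None since normalize=False
    recon = model.decode(frames)[:, :, :wav.shape[-1]]   # back to [1, 1, 24000]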
+ +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .norm import ConvLayerNorm + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`. + """ + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. 
+ """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. 
+ """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect'): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. 
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
diff --git a/inspiremusic/wavtokenizer/encoder/modules/lstm.py b/inspiremusic/wavtokenizer/encoder/modules/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..49908198953deed173bed6eed5199eb74b99e5f8
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/lstm.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""LSTM layers module."""
+
+from torch import nn
+
+
+class SLSTM(nn.Module):
+    """
+    LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    # def forward(self, x):
+    #     x = x.permute(2, 0, 1)
+    #     y, _ = self.lstm(x)
+    #     if self.skip:
+    #         y = y + x
+    #     y = y.permute(1, 2, 0)
+    #     return y
+
+    # Modified the transpose order: the residual skip is applied in the (B, C, T) layout.
+    def forward(self, x):
+        # # insert reshape
+        # x = x.reshape(x.shape)
+        x1 = x.permute(2, 0, 1)
+        y, _ = self.lstm(x1)
+        y = y.permute(1, 2, 0)
+        if self.skip:
+            y = y + x
+        return y
diff --git a/inspiremusic/wavtokenizer/encoder/modules/norm.py b/inspiremusic/wavtokenizer/encoder/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..19970e0a21ea1c10461cb56d776619dd5f64ff36
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/norm.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Normalization modules."""
+
+import typing as tp
+
+import einops
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+    """
+    Convolution-friendly LayerNorm that moves channels to last dimensions
+    before running the normalization and moves them back to original position right after.
+    """
+    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
+        super().__init__(normalized_shape, **kwargs)
+
+    def forward(self, x):
+        x = einops.rearrange(x, 'b ... t -> b t ...')
+        x = super().forward(x)
+        x = einops.rearrange(x, 'b t ... -> b ... t')
+        return x
diff --git a/inspiremusic/wavtokenizer/encoder/modules/seanet.py b/inspiremusic/wavtokenizer/encoder/modules/seanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1c02d508cbffce0613a637d4c7943d936b09db
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/modules/seanet.py
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
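# Quick sanity check for the streamable conv wrappers defined in conv.py above
# (illustrative, not part of the patch): SConv1d pads its input so that a stride-S
# layer always yields ceil(T / S) output frames, with all padding on the left when
# causal=True. The import path follows the package layout of this patch.
import math
import torch
from inspiremusic.wavtokenizer.encoder.modules import SConv1d

conv = SConv1d(1, 8, kernel_size=7, stride=2, causal=True)
for T in (100, 101, 160):
    out = conv(torch.randn(1, 1, T))
    assert out.shape[-1] == math.ceil(T / 2), (T, out.shape)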
+ +"""Encodec SEANet-based encoder and decoder implementation.""" + +import typing as tp + +import numpy as np +import torch.nn as nn + +from . import ( + SConv1d, + SConvTranspose1d, + SLSTM +) + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + Args: + dim (int): Dimension of the input/output + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3) + true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. + """ + def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], + activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, + pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + super().__init__() + assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. 
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = 1 + model: tp.List[nn.Module] = [ + SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + norm=norm, norm_params=norm_params, + activation=activation, activation_params=activation_params, + causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + # Add downsampling layers + model += [ + act(**activation_params), + SConv1d(mult * n_filters, mult * n_filters * 2, + kernel_size=ratio * 2, stride=ratio, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + mult *= 2 + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. 
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2, + trim_right_ratio: float = 1.0): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add upsampling layers + model += [ + act(**activation_params), + SConvTranspose1d(mult * n_filters, mult * n_filters // 2, + kernel_size=ratio * 2, stride=ratio, + norm=norm, norm_kwargs=norm_params, + causal=causal, trim_right_ratio=trim_right_ratio), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + activation=activation, activation_params=activation_params, + norm=norm, norm_params=norm_params, causal=causal, + pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params), + SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [ + final_act(**final_activation_params) + ] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def test(): + import torch + encoder = SEANetEncoder() + decoder = SEANetDecoder() + x = torch.randn(1, 1, 24000) + z = encoder(x) + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + +if __name__ == '__main__': + test() diff --git a/inspiremusic/wavtokenizer/encoder/modules/transformer.py b/inspiremusic/wavtokenizer/encoder/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..44b47918f84aa47021c0d6f5bd58364641088541 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/modules/transformer.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
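# Small shape check for the SEANet pair defined above (illustrative, not part of the
# patch). The encoder downsamples by the product of its stride ratios, so 24 kHz audio
# becomes 24000 / 320 = 75 latent frames per second, and the decoder mirrors that.
import numpy as np
import torch
from inspiremusic.wavtokenizer.encoder.modules import SEANetEncoder, SEANetDecoder

encoder = SEANetEncoder(ratios=[8, 5, 4, 2], causal=True)
decoder = SEANetDecoder(ratios=[8, 5, 4, 2], causal=True)
assert encoder.hop_length == np.prod([8, 5, 4, 2]) == 320
z = encoder(torch.randn(1, 1, 24000))    # [1, 128, 75]: dimension x latent frames
y = decoder(z)                           # [1, 1, 24000]: waveform length is restored
assert list(z.shape) == [1, 128, 75] and y.shape[-1] == 24000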
+ +"""A streamable transformer.""" + +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000): + """Create time embedding for the given positions, target dimension `dim`. + """ + # We aim for BTC format + assert dim % 2 == 0 + half_dim = dim // 2 + adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) + phase = positions / (max_period ** (adim / (half_dim - 1))) + return torch.cat([ + torch.cos(phase), + torch.sin(phase), + ], dim=-1) + + +class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): + def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore + if self.norm_first: + sa_input = self.norm1(x) + x = x + self._sa_block(sa_input, x_past, past_context) + x = x + self._ff_block(self.norm2(x)) + else: + sa_input = x + x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) + x = self.norm2(x + self._ff_block(x)) + + return x, sa_input + + # self-attention block + def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore + _, T, _ = x.shape + _, H, _ = x_past.shape + + queries = x + keys = torch.cat([x_past, x], dim=1) + values = keys + + queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) + keys_pos = torch.arange(T + H, device=x.device).view(1, -1) + delta = queries_pos - keys_pos + valid_access = (delta >= 0) & (delta <= past_context) + x = self.self_attn(queries, keys, values, + attn_mask=~valid_access, + need_weights=False)[0] + return self.dropout1(x) + + +class StreamingTransformerEncoder(nn.Module): + """TransformerEncoder with streaming support. + + Args: + dim (int): dimension of the data. + hidden_scale (int): intermediate dimension of FF module is this times the dimension. + num_heads (int): number of heads. + num_layers (int): number of layers. + max_period (float): maxium period of cosines in the positional embedding. + past_context (int or None): receptive field for the causal mask, infinite if None. + gelu (bool): if true uses GeLUs, otherwise use ReLUs. + norm_in (bool): normalize the input. + dropout (float): dropout probability. + **kwargs: See `nn.TransformerEncoderLayer`. 
+ """ + def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5, + max_period: float = 10000, past_context: int = 1000, gelu: bool = True, + norm_in: bool = True, dropout: float = 0., **kwargs): + super().__init__() + assert dim % num_heads == 0 + hidden_dim = int(dim * hidden_scale) + + self.max_period = max_period + self.past_context = past_context + activation: tp.Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + else: + self.norm_in = nn.Identity() + + self.layers = nn.ModuleList() + for idx in range(num_layers): + self.layers.append( + StreamingTransformerEncoderLayer( + dim, num_heads, hidden_dim, + activation=activation, batch_first=True, dropout=dropout, **kwargs)) + + def forward(self, x: torch.Tensor, + states: tp.Optional[tp.List[torch.Tensor]] = None, + offset: tp.Union[int, torch.Tensor] = 0): + B, T, C = x.shape + if states is None: + states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))] + + positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset + pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) + + new_state: tp.List[torch.Tensor] = [] + x = self.norm_in(x) + x = x + pos_emb + + for layer_state, layer in zip(states, self.layers): + x, new_layer_state = layer(x, layer_state, self.past_context) + new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) + new_state.append(new_layer_state[:, -self.past_context:, :]) + return x, new_state, offset + T diff --git a/inspiremusic/wavtokenizer/encoder/msstftd.py b/inspiremusic/wavtokenizer/encoder/msstftd.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d3242a57e1e20e99bc2fa86e363cc5ec92cbf7 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/msstftd.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""MS-STFT discriminator, provided here for reference.""" + +import typing as tp + +import torchaudio +import torch +from torch import nn +from einops import rearrange + +from .modules import NormConv2d + + +FeatureMapType = tp.List[torch.Tensor] +LogitsType = torch.Tensor +DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] + + +def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): + return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) + + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_fft (int): Size of FFT for each scale. Default: 1024 + hop_length (int): Length of hop between STFT windows for each scale. Default: 256 + kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` + stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` + dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` + win_length (int): Window size for each scale. Default: 1024 + normalized (bool): Whether to normalize by magnitude after stft. Default: True + norm (str): Normalization method. Default: `'weight_norm'` + activation (str): Activation function. 
Default: `'LeakyReLU'` + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. Default: 1 + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, + n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, + filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], + stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', + activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + self.filters = filters + self.in_channels = in_channels + self.out_channels = out_channels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.activation = getattr(torch.nn, activation)(**activation_params) + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, + normalized=self.normalized, center=False, pad_mode=None, power=None) + spec_channels = 2 * self.in_channels + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) + ) + in_chs = min(filters_scale * self.filters, max_filters) + for i, dilation in enumerate(dilations): + out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, + dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), + norm=norm)) + in_chs = out_chs + out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) + self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm)) + self.conv_post = NormConv2d(out_chs, self.out_channels, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm) + + def forward(self, x: torch.Tensor): + fmap = [] + z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] + z = torch.cat([z.real, z.imag], dim=1) + z = rearrange(z, 'b c w t -> b c t w') + for i, layer in enumerate(self.convs): + z = layer(z) + z = self.activation(z) + fmap.append(z) + z = self.conv_post(z) + return z, fmap + + +class MultiScaleSTFTDiscriminator(nn.Module): + """Multi-Scale STFT (MS-STFT) discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. 
Default: 1 + n_ffts (Sequence[int]): Size of FFT for each scale + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale + win_lengths (Sequence[int]): Window size for each scale + **kwargs: additional args for STFTDiscriminator + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, + n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128], + win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs): + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.discriminators = nn.ModuleList([ + DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, + n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) + for i in range(len(n_ffts)) + ]) + self.num_discriminators = len(self.discriminators) + + def forward(self, x: torch.Tensor) -> DiscriminatorOutput: + logits = [] + fmaps = [] + for disc in self.discriminators: + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps + + +def test(): + disc = MultiScaleSTFTDiscriminator(filters=32) + y = torch.randn(1, 1, 24000) + y_hat = torch.randn(1, 1, 24000) + + y_disc_r, fmap_r = disc(y) + y_disc_gen, fmap_gen = disc(y_hat) + assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators + + assert all([len(fm) == 5 for fm in fmap_r + fmap_gen]) + assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm]) + assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen]) + + +if __name__ == '__main__': + test() diff --git a/inspiremusic/wavtokenizer/encoder/quantization/__init__.py b/inspiremusic/wavtokenizer/encoder/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/quantization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/inspiremusic/wavtokenizer/encoder/quantization/ac.py b/inspiremusic/wavtokenizer/encoder/quantization/ac.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f3e5dcd385cd273a145effa3f53ce7ccfdc74c --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/quantization/ac.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Arithmetic coder.""" + +import io +import math +import random +import typing as tp +import torch + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int, + roundoff: float = 1e-8, min_range: int = 2, + check: bool = True) -> torch.Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (torch.Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. 
+ roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. Use this to avoid pathological behavior is a value + that is expected to be rare actually happens in real life. + check (bool): if True, checks that nothing bad happened, can be deactivated for speed. + """ + pdf = pdf.detach() + if roundoff: + pdf = (pdf / roundoff).floor() * roundoff + # interpolate with uniform distribution to achieve desired minimum probability. + total_range = 2 ** total_range_bits + cardinality = len(pdf) + alpha = min_range * cardinality / total_range + assert alpha <= 1, "you must reduce min_range" + ranges = (((1 - alpha) * total_range) * pdf).floor().long() + ranges += min_range + quantized_cdf = torch.cumsum(ranges, dim=-1) + if min_range < 2: + raise ValueError("min_range must be at least 2.") + if check: + assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1] + if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range: + raise ValueError("You must increase your total_range_bits.") + return quantized_cdf + + +class ArithmeticCoder: + """ArithmeticCoder, + Let us take a distribution `p` over `N` symbols, and assume we have a stream + of random variables `s_t` sampled from `p`. Let us assume that we have a budget + of `B` bits that we can afford to write on device. There are `2**B` possible numbers, + corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single + sequence `(s_t)` by doing the following: + + 1) Initialize the current range to` [0 ** 2 B - 1]`. + 2) For each time step t, split the current range into contiguous chunks, + one for each possible outcome, with size roughly proportional to `p`. + For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks + would be `{[0, 2], [3, 3]}`. + 3) Select the chunk corresponding to `s_t`, and replace the current range with this. + 4) When done encoding all the values, just select any value remaining in the range. + + You will notice that this procedure can fail: for instance if at any point in time + the range is smaller than `N`, then we can no longer assign a non-empty chunk to each + possible outcome. Intuitively, the more likely a value is, the less the range width + will reduce, and the longer we can go on encoding values. This makes sense: for any efficient + coding scheme, likely outcomes would take less bits, and more of them can be coded + with a fixed budget. + + In practice, we do not know `B` ahead of time, but we have a way to inject new bits + when the current range decreases below a given limit (given by `total_range_bits`), without + having to redo all the computations. If we encode mostly likely values, we will seldom + need to inject new bits, but a single rare value can deplete our stock of entropy! + + In this explanation, we assumed that the distribution `p` was constant. In fact, the present + code works for any sequence `(p_t)` possibly different for each timestep. + We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller + the KL between the true distribution and `p_t`, the most efficient the coding will be. + + Args: + fo (IO[bytes]): file-like object to which the bytes will be written to. + total_range_bits (int): the range `M` described above is `2 ** total_range_bits. 
+ Any time the current range width fall under this limit, new bits will + be injected to rescale the initial range. + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= (b1 << self.max_bit) + self.high -= (b1 << self.max_bit) + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: torch.Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. + """ + while self.delta < 2 ** self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, (effective_low, effective_high, range_low, range_high) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + outs = self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + return outs + + def flush(self): + """Flush the remaining information to the stream. + """ + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. + + If the AC encoder current range is [L, H], with `L` and `H` having the some common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. 
+ + """ + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + self._last: tp.Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= (b1 << self.max_bit) + self.high -= (b1 << self.max_bit) + self.current -= (b1 << self.max_bit) + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. + """ + while self.delta < 2 ** self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git 
a/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py b/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..774781c2947622e6c0c7a55c6eded26a2813b7c7 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/quantization/core_vq.py @@ -0,0 +1,421 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp +import warnings + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +from .. 
import distrib + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
+ """ + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) #dataไธๅ˜ + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + distrib.broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. 
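+            # The buffer updates below implement the usual EMA codebook learning rule:
+            #   cluster_size <- decay * cluster_size + (1 - decay) * one_hot_counts
+            #   embed_avg    <- decay * embed_avg + (1 - decay) * x^T @ one_hot
+            # and the codebook itself is embed_avg divided by the Laplace-smoothed
+            # cluster sizes.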
+            self.expire_codes_(x)
+            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+            embed_sum = x.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = (
+                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+                * self.cluster_size.sum()
+            )
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.data.copy_(embed_normalized)
+
+        return quantize, embed_ind
+
+
+class VectorQuantization(nn.Module):
+    """Vector quantization implementation.
+    Currently supports only euclidean distance.
+    Args:
+        dim (int): Dimension
+        codebook_size (int): Codebook size
+        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+        commitment_weight (float): Weight for commitment loss.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        codebook_dim: tp.Optional[int] = None,
+        decay: float = 0.99,
+        epsilon: float = 1e-5,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 50,
+        threshold_ema_dead_code: int = 2,
+        commitment_weight: float = 1.,
+    ):
+        super().__init__()
+        _codebook_dim: int = default(codebook_dim, dim)
+
+        requires_projection = _codebook_dim != dim
+        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+        self.epsilon = epsilon
+        self.commitment_weight = commitment_weight
+
+        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+                                           decay=decay, epsilon=epsilon,
+                                           threshold_ema_dead_code=threshold_ema_dead_code)
+        self.codebook_size = codebook_size
+
+    @property
+    def codebook(self):
+        return self._codebook.embed
+
+    def encode(self, x):
+        x = rearrange(x, "b d n -> b n d")
+        x = self.project_in(x)
+        embed_in = self._codebook.encode(x)
+        return embed_in
+
+    def decode(self, embed_ind):
+        quantize = self._codebook.decode(embed_ind)
+        quantize = self.project_out(quantize)
+        quantize = rearrange(quantize, "b n d -> b d n")
+        return quantize
+
+    def forward(self, x):
+        device = x.device
+        x = rearrange(x, "b d n -> b n d")
+        x = self.project_in(x)
+        quantize, embed_ind = self._codebook(x)
+        if self.training:
+            quantize = x + (quantize - x).detach()
+        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+        if self.training:
+            # warnings.warn('When using RVQ in training model, first check '
+            #               'https://github.com/facebookresearch/encodec/issues/25 . '
+            #               'The bug wasn\'t fixed here for reproducibility.')
+            if self.commitment_weight > 0:
+                commit_loss = F.mse_loss(quantize.detach(), x)
+                loss = loss + commit_loss * self.commitment_weight
+
+        quantize = self.project_out(quantize)
+        quantize = rearrange(quantize, "b n d -> b d n")
+        return quantize, embed_ind, loss
+
+
+class ResidualVectorQuantization(nn.Module):
+    """Residual vector quantization implementation.
+    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized.detach()
+            quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            all_indices.append(indices)
+            quantized = layer.decode(indices)
+            residual = residual - quantized.detach()
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
+
+
+class LanguageVectorQuantization(nn.Module):
+    """Parallel variant of ResidualVectorQuantization: in forward(), every
+    quantizer layer quantizes the same input (no residual subtraction) and only
+    the last layer's output is returned, while encode()/decode() keep the
+    residual behaviour of Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        # x: [b, t, c], e.g. [64, 75, 128]
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+
+        for layer in self.layers[:n_q]:
+            # This layer's quantized output, indices and loss; indices e.g. [64, 75].
+            quantized_out, indices, loss = layer(residual)
+            # Unlike ResidualVectorQuantization, the residual is intentionally not updated:
+            # residual = residual - quantized.detach()
+            # quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            all_indices.append(indices)
+            quantized = layer.decode(indices)
+            residual = residual - quantized.detach()
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
\ No newline at end of file
diff --git a/inspiremusic/wavtokenizer/encoder/quantization/vq.py b/inspiremusic/wavtokenizer/encoder/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8e316b4bf912c2a743cd27fe038a17e85bceb13
--- /dev/null
+++ b/inspiremusic/wavtokenizer/encoder/quantization/vq.py
@@ -0,0 +1,172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Residual vector quantizer implementation."""
+
+from dataclasses import dataclass, field
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from .core_vq import ResidualVectorQuantization, LanguageVectorQuantization
+
+
+@dataclass
+class QuantizedResult:
+    quantized: torch.Tensor
+    codes: torch.Tensor
+    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
+    penalty: tp.Optional[torch.Tensor] = None
+    metrics: dict = field(default_factory=dict)
+
+
+class ResidualVectorQuantizer(nn.Module):
+    """Residual Vector Quantizer.
+    Args:
+        dimension (int): Dimension of the codebooks.
+        n_q (int): Number of residual vector quantizers used.
+        bins (int): Codebook size.
+        decay (float): Decay for exponential moving average over the codebooks.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dimension: int = 256,
+        n_q: int = 8,
+        bins: int = 1024,
+        decay: float = 0.99,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 50,
+        threshold_ema_dead_code: int = 2,
+    ):
+        super().__init__()
+        self.n_q = n_q
+        self.dimension = dimension
+        self.bins = bins
+        self.decay = decay
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+
+        self.vq = LanguageVectorQuantization(
+            dim=self.dimension,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            kmeans_init=self.kmeans_init,
+            kmeans_iters=self.kmeans_iters,
+            threshold_ema_dead_code=self.threshold_ema_dead_code,
+        )
+        # The plain residual quantizer is kept here for reference:
+        # self.vq = ResidualVectorQuantization(
+        #     dim=self.dimension,
+        #     codebook_size=self.bins,
+        #     num_quantizers=self.n_q,
+        #     decay=self.decay,
+        #     kmeans_init=self.kmeans_init,
+        #     kmeans_iters=self.kmeans_iters,
+        #     threshold_ema_dead_code=self.threshold_ema_dead_code,
+        # )
+
+    def forward(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
+        """Residual vector quantization on the given input tensor.
+        Args:
+            x (torch.Tensor): Input tensor.
+            frame_rate (int): Frame rate of the input tensor.
+            bandwidth (float): Target bandwidth.
+        Returns:
+            QuantizedResult:
+                The quantized (or approximately quantized) representation with
+                the associated bandwidth and any penalty term for the loss.
+ """ + # breakpoint() + + + bw_per_q = self.get_bandwidth_per_quantizer(frame_rate) + n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth) + # assert n_q==4 + # breakpoint() + # nq_choice=[3,4,8] + nq_choice=[4,6,8] + if self.training: + # choice = int(torch.randint(0, 3, (1,)).item()) + choice = int(torch.randint(0, 3, (1,)).item()) + # breakpoint() + n_q=nq_choice[choice] + # breakpoint() + # n_q=8 + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def infer(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + frame_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. + """ + bw_per_q = self.get_bandwidth_per_quantizer(frame_rate) + # n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth) + # # assert n_q==4 + # # breakpoint() + # # nq_choice=[3,4,8] + # nq_choice=[3,4,5,6,7,8] + # if self.training: + # # choice = int(torch.randint(0, 3, (1,)).item()) + # choice = int(torch.randint(0, 6, (1,)).item()) + # # breakpoint() + # n_q=nq_choice[choice] + n_q=1 + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def get_num_quantizers_for_bandwidth(self, frame_rate: int, bandwidth: tp.Optional[float] = None) -> int: + """Return n_q based on specified target bandwidth. + """ + bw_per_q = self.get_bandwidth_per_quantizer(frame_rate) + n_q = self.n_q + if bandwidth and bandwidth > 0.: + # bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as + # bandwidth == 6.0 + n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q))) + return n_q + + def get_bandwidth_per_quantizer(self, frame_rate: int): + """Return bandwidth per quantizer for a given input frame rate. + Each quantizer encodes a frame with lg(bins) bits. + """ + return math.log2(self.bins) * frame_rate + + def encode(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: + """Encode a given input tensor with the specified frame rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizers to use + and returns indices for each quantizer. + """ + n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth) + codes = self.vq.encode(x, n_q=n_q) + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation. + """ + quantized = self.vq.decode(codes) + return quantized diff --git a/inspiremusic/wavtokenizer/encoder/utils.py b/inspiremusic/wavtokenizer/encoder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f0f9e9bcb37f2267b2f8adefabfc3672453dc5 --- /dev/null +++ b/inspiremusic/wavtokenizer/encoder/utils.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Various utilities.""" + +from hashlib import sha256 +from pathlib import Path +import typing as tp + +import torch +import torchaudio + + +def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int): + # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario + # e.g., more than 2 frames per position. + # The core idea is to use a weight function that is a triangle, + # with a maximum value at the middle of the segment. + # We use this weighting when summing the frames, and divide by the sum of weights + # for each positions at the end. Thus: + # - if a frame is the only one to cover a position, the weighting is a no-op. + # - if 2 frames cover a position: + # ... ... + # / \/ \ + # / /\ \ + # S T , i.e. S offset of second frame starts, T end of first frame. + # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset. + # After the final normalization, the weight of the second frame at position `t` is + # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want. + # + # - if more than 2 frames overlap at a given point, we hope that by induction + # something sensible happens. + assert len(frames) + device = frames[0].device + dtype = frames[0].dtype + shape = frames[0].shape[:-1] + total_size = stride * (len(frames) - 1) + frames[-1].shape[-1] + + frame_length = frames[0].shape[-1] + t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1] + weight = 0.5 - (t - 0.5).abs() + + sum_weight = torch.zeros(total_size, device=device, dtype=dtype) + out = torch.zeros(*shape, total_size, device=device, dtype=dtype) + offset: int = 0 + + for frame in frames: + frame_length = frame.shape[-1] + out[..., offset:offset + frame_length] += weight[:frame_length] * frame + sum_weight[offset:offset + frame_length] += weight[:frame_length] + offset += stride + assert sum_weight.min() > 0 + return out / sum_weight + + +def _get_checkpoint_url(root_url: str, checkpoint: str): + if not root_url.endswith('/'): + root_url += '/' + return root_url + checkpoint + + +def _check_checksum(path: Path, checksum: str): + sha = sha256() + with open(path, 'rb') as file: + while True: + buf = file.read(2**20) + if not buf: + break + sha.update(buf) + actual_checksum = sha.hexdigest()[:len(checksum)] + if actual_checksum != checksum: + raise RuntimeError(f'Invalid checksum for file {path}, ' + f'expected {checksum} but got {actual_checksum}') + + +def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): + assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions" + assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo." 
+    *shape, channels, length = wav.shape
+    if target_channels == 1:
+        wav = wav.mean(-2, keepdim=True)
+    elif target_channels == 2:
+        wav = wav.expand(*shape, target_channels, length)
+    elif channels == 1:
+        wav = wav.expand(target_channels, -1)
+    else:
+        raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}")
+    wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
+    return wav
+
+
+def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
+               sample_rate: int, rescale: bool = False):
+    limit = 0.99
+    mx = wav.abs().max()
+    if rescale:
+        wav = wav * min(limit / mx, 1)
+    else:
+        wav = wav.clamp(-limit, limit)
+    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
diff --git a/requirements.txt b/requirements.txt
index 6275b171d63423a0428c32474b4e7feeeb294368..288f394b475e2ccc99f54eb71df0074d94917213 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ librosa
 torch
 torchaudio
 modelscope
-funasr
\ No newline at end of file
+funasr
+transformers
\ No newline at end of file
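
A minimal usage sketch for the new quantizer stack, not part of the patch itself. It assumes the repository is importable as a package, so that the file added above maps to the module inspiremusic.wavtokenizer.encoder.quantization.vq, and it uses the (batch, dimension, frames) layout that VectorQuantization.forward expects; the tensor sizes and the 75 Hz frame rate are illustrative only.

    import torch

    # Hypothetical import path, mirroring inspiremusic/wavtokenizer/encoder/quantization/vq.py above.
    from inspiremusic.wavtokenizer.encoder.quantization.vq import ResidualVectorQuantizer

    rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
    rvq.eval()  # avoid the training-time random choice of n_q in forward()

    x = torch.randn(2, 256, 75)  # (batch, dimension, frames), e.g. one second at a 75 Hz frame rate
    frame_rate = 75

    # Each quantizer spends log2(bins) = 10 bits per frame, i.e. 750 bits/s at 75 Hz,
    # so a target bandwidth of 6.0 maps to floor(6000 / 750) = 8 quantizers
    # in get_num_quantizers_for_bandwidth.
    with torch.no_grad():
        codes = rvq.encode(x, frame_rate, bandwidth=6.0)  # (n_q, batch, frames) integer indices
        x_hat = rvq.decode(codes)                         # (batch, dimension, frames)
        result = rvq(x, frame_rate, bandwidth=6.0)        # QuantizedResult(quantized, codes, bandwidth, penalty)
    print(codes.shape, x_hat.shape, float(result.bandwidth))

Because ResidualVectorQuantizer routes through LanguageVectorQuantization, the quantized tensor returned by forward() comes from the last quantizer layer only, while encode()/decode() still follow the residual scheme of the original ResidualVectorQuantization.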