diff --git a/api.py b/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cbfa4e85398bf961629ec3f85656c77d2a18d0e
--- /dev/null
+++ b/api.py
@@ -0,0 +1,397 @@
+
+# -*- coding: utf-8 -*-
+import numpy as np
+import soundfile
+import audresample
+import text_utils
+import msinference
+import re
+import srt
+import subprocess
+import cv2
+import markdown
+import json
+from pathlib import Path
+from types import SimpleNamespace
+from flask import Flask, request, send_from_directory
+from flask_cors import CORS
+from moviepy.editor import *
+from audiocraft.audiogen import AudioGen, audio_write
+
+sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
+sound_generator.set_generation_params(duration=6)
+
+Path('./flask_cache').mkdir(parents=True, exist_ok=True)
+
+# SSH AGENT
+#   eval $(ssh-agent -s)
+#   ssh-add ~/.ssh/id_ed25519_github2024
+#
+#   git remote set-url origin git@github.com:audeering/shift
+# ==
+
+
+def _shift(x):
+    n = x.shape[0]
+    i = np.random.randint(.24 * n, max(1, .74 * n))  # max(1, .) keeps randint's upper bound >= 1
+    x = np.roll(x, i)
+    # fade_in = .5 + .5 * np.tanh(4*(np.linspace(-10, 10, x.shape[0]) + 9.4))
+    # x = x * fade_in
+    return x
+
+
+def _background(x, sound_background=None):
+    if sound_background is not None:
+        sound_background = sound_background[0, :]
+        len_speech = len(x)
+        if len_speech > len(sound_background):
+            # tile (and shift) the background so it covers the full speech length
+            n_repeat = len_speech // len(sound_background) + 1
+            replica = [sound_background] * n_repeat
+            replica = [_shift(_) for _ in replica]
+            sound_background = np.concatenate(replica)
+
+        print(f'\nSOUND\nBACKGROUND\nSHAPE\n{sound_background=}\n{x.shape=}\n- - - -')
+        x = .74 * x + .26 * sound_background[:len_speech]
+    return x
+
+
+def tts_multi_sentence(precomputed_style_vector=None,
+                       text=None,
+                       voice=None,
+                       scene=None):
+    '''Create a 24 kHz np.array of TTS speech.
+
+    precomputed_style_vector : required if en_US or en_UK is in voice,
+                               in order to perform affective TTS.
+    text                     : list of sentence strings
+    voice                    : string or None (None falls back to StyleTTS2)
+    scene                    : e.g. 'A castle in far away lands' - if passed,
+                               a background sound scene is generated
+    '''
+    # Generate sound scene - upsample to 24 kHz
+    if scene is not None:
+        sound_background = sound_generator.generate([scene])[0]
+        sound_background = audio_write(None,
+                                       sound_background.cpu(),
+                                       24000,  # sound_generator.sample_rate,
+                                       strategy="loudness",
+                                       loudness_compressor=True)
+    else:
+        sound_background = None
+
+    # StyleTTS2
+    if (voice is None) or ('en_US/' in voice) or ('en_UK/' in voice):
+        assert precomputed_style_vector is not None, 'For affective TTS, a style vector is needed.'
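+        # The inference settings below are taken as-is from this PR; roughly (per
+        # StyleTTS2): alpha/beta blend the reference style with the text-predicted
+        # style (timbre vs. prosody), diffusion_steps is the number of style-diffusion
+        # sampling steps, and embedding_scale is the classifier-free guidance weight.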
+        x = []
+        for _sentence in text:
+            x.append(msinference.inference(_sentence,
+                                           precomputed_style_vector,
+                                           alpha=0.3,
+                                           beta=0.7,
+                                           diffusion_steps=7,
+                                           embedding_scale=1))
+        x = np.concatenate(x)
+
+        return _background(x, sound_background)
+
+    # Fallback - Mimic-3
+    text_utils.store_ssml(text=text, voice=voice)  # text has to be a list of single sentences
+    ps = subprocess.Popen(f'cat _tmp_ssml.txt | mimic3 --ssml > _tmp.wav', shell=True)
+    ps.wait()
+    x, fs = soundfile.read('_tmp.wav')
+    x = audresample.resample(x.astype(np.float32), fs, 24000)[0, :]  # Mimic-3 rate (fs) -> 24 kHz; reshapes (64,) -> (1,64)
+
+    return _background(x, sound_background)
+
+
+# voices = {}
+# import phonemizer
+# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
+
+app = Flask(__name__)
+cors = CORS(app)
+
+
+@app.route("/")
+def index():
+    with open('README.md', 'r') as f:
+        return markdown.markdown(f.read())
+
+
+@app.route("/", methods=['GET', 'POST', 'PUT'])
+def serve_wav():
+    # https://stackoverflow.com/questions/13522137/in-flask-convert-form-post-
+    # object-into-a-representation-suitable-for-mongodb
+    r = request.form.to_dict(flat=False)
+
+    # Physically save client files
+    for filename, obj in request.files.items():
+        obj.save(f'flask_cache/{filename.replace("/","")}')
+
+    print('Saved all files on Server Side\n\n')
+
+    args = SimpleNamespace(
+        text=None if r.get('text') is None else 'flask_cache/' + r.get('text')[0],
+        video=None if r.get('video') is None else 'flask_cache/' + r.get('video')[0],
+        image=None if r.get('image') is None else 'flask_cache/' + r.get('image')[0],
+        voice=r.get('voice')[0],
+        native=None if r.get('native') is None else 'flask_cache/' + r.get('native')[0],
+        affective=r.get('affective')[0],
+        scene=r.get('scene')[0])
+    # print('\n==RECOMPOSED as \n', request.data, request.form, '\n==')
+
+    print(args, 'ENTER Script')
+    do_video_dub = args.text.endswith('.srt')
+
+    SILENT_VIDEO = '_silent_video.mp4'
+    AUDIO_TRACK = '_audio_track.wav'
+
+    if do_video_dub:
+        print(f'==\nFound .srt : {args.text}, thus a video should be given as well\n\n')
+        with open(args.text, "r") as f:
+            s = f.read()
+        text = [[j.content, j.start.total_seconds(), j.end.total_seconds()] for j in srt.parse(s)]
+        assert args.video is not None
+        native_audio_file = '_tmp.wav'
+        subprocess.call(
+            ["ffmpeg",
+             "-y",  # https://stackoverflow.com/questions/39788972/ffmpeg-overwrite-output-file-if-exists
+             "-i",
+             args.video,
+             "-f",
+             "mp3",
+             "-ar",
+             "24000",  # "22050 for mimic3",
+             "-vn",
+             native_audio_file])
+        x_native, _ = soundfile.read(native_audio_file)  # reads mp3
+        x_native = x_native[:, 0]  # keep left channel of stereo
+        # ffmpeg -i Sandra\ Kotevska\,\ Painting\ Rose\ bush\,\ mixed\ media\,\ 2017.\ \[NMzC_036MtE\].mkv -f mp3 -ar 22050 -vn out44.wa
+    else:
+        with open(args.text, 'r') as f:
+            t = ''.join(f)
+        t = re.sub(' +', ' ', t)  # collapse repeated spaces
+        text = text_utils.split_into_sentences(t)  # split into short sentences (~200 phonemes max)
+
+    # ====STYLE VECTOR====
+
+    precomputed_style_vector = None
+    if args.native:  # Voice cloning
+        try:
+            precomputed_style_vector = msinference.compute_style(args.native)
+        except soundfile.LibsndfileError:  # Fallback - internal voice
+            print('\n Could not voice clone audio:', args.native, 'falling back to video or internal TTS voice.\n')
+            if do_video_dub:  # Clone voice via the video's native audio
+                native_audio_file = args.video.replace('.', '').replace('/', '')
+                native_audio_file += '__native_audio_track.wav'
+                soundfile.write('tgt_spk.wav',
+                                np.concatenate([
+                                    x_native[:int(4 * 24000)]], 0).astype(np.float32), 24000)  # 27400?
+                precomputed_style_vector = msinference.compute_style('tgt_spk.wav')
+
+    # NOTE: style vector may be None
+
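+    # Filename mapping (illustrative, assuming a Mimic3-style voice key): a voice
+    # such as 'en_US/cmu-arctic_low#lnh' resolves via the replacements below to
+    # 'assets/wavs/style_vector/en_US_cmu_arctic_lnh.wav'
+    # (or 'assets/wavs/style_vector_v2/...' when affective is falsy).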
+    if precomputed_style_vector is None:
+        if 'en_US' in args.voice or 'en_UK' in args.voice:
+            _dir = '/' if args.affective else '_v2/'
+            precomputed_style_vector = msinference.compute_style(
+                'assets/wavs/style_vector' + _dir + args.voice.replace(
+                    '/', '_').replace(
+                    '#', '_').replace(
+                    'cmu-arctic', 'cmu_arctic').replace(
+                    '_low', '') + '.wav')
+    print('\n STYLE VECTOR \n', precomputed_style_vector)
+
+    # ====SILENT VIDEO====
+
+    if args.video is not None:
+        # TTS banner
+        frame_tts = np.zeros((104, 1920, 3), dtype=np.uint8)
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        bottomLeftCornerOfText = (240, 74)  # w, h
+        fontScale = 2
+        fontColor = (255, 255, 255)
+        thickness = 4
+        lineType = 2
+        cv2.putText(frame_tts, 'TTS',
+                    bottomLeftCornerOfText,
+                    font,
+                    fontScale,
+                    fontColor,
+                    thickness,
+                    lineType)
+        # cv2.imshow('i', frame_tts); cv2.waitKey(); cv2.destroyAllWindows()
+        # ====================================== NATIVE VOICE
+        frame_orig = np.zeros((104, 1920, 3), dtype=np.uint8)
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        bottomLeftCornerOfText = (101, 74)  # w, h
+        fontScale = 2
+        fontColor = (255, 255, 255)
+        thickness = 4
+        lineType = 2
+        cv2.putText(frame_orig, 'ORIGINAL VOICE',
+                    bottomLeftCornerOfText,
+                    font,
+                    fontScale,
+                    fontColor,
+                    thickness,
+                    lineType)
+        # ====SILENT VIDEO EXTRACT====
+        # DOWNLOAD SRT from YouTube
+        #
+        #   yt-dlp --write-sub --sub-lang en --convert-subs "srt" https://www.youtube.com/watch?v=F1Ib7TAu7eg&list=PL4x2B6LSwFewdDvRnUTpBM7jkmpwouhPv&index=2
+        #
+        # .mkv -> .mp4 (moviepy loads only .mp4)
+        #
+        #   ffmpeg -y -i Distaff\ \[qVonBgRXcWU\].mkv -c copy -c:a aac Distaff_qVonBgRXcWU.mp4
+        #
+        # video_file, srt_file = ['assets/Head_of_fortuna.mp4',
+        #                         'assets/head_of_fortuna_en.srt']
+        #
+        video_file = args.video
+        vf = VideoFileClip(video_file)
+        try:
+            # inpaint banners if native voice
+            num = x_native.shape[0]
+            is_tts = .5 + .5 * np.tanh(4 * (np.linspace(-10, 10, num) + 9.4))  # Heaviside-like fade
+
+            def inpaint_banner(get_frame, t):
+                '''Blend in the banner of whichever voice (TTS or native) plays now.'''
+                im = np.copy(get_frame(t))
+                ix = min(int(t * 24000), num - 1)  # clamp to the native-audio length
+                if is_tts[ix] > .5:  # mask is 1 thus TTS, else native
+                    frame = frame_tts
+                else:
+                    frame = frame_orig
+                h, w, _ = frame.shape
+                # im[-h:, -w:, :] = (.4 * im[-h:, -w:, :] + .6 * frame_orig).astype(np.uint8)
+                offset_h = 24
+                im[offset_h:h + offset_h, :w, :] = (.4 * im[offset_h:h + offset_h, :w, :]
+                                                    + .6 * frame).astype(np.uint8)
+                # im2 = np.concatenate([im, frame_tts], 0)
+                # cv2.imshow('t', im2); cv2.waitKey(); cv2.destroyAllWindows()
+                return im  # np.concatenate([im, frame_tts], 0)
+        except UnboundLocalError:  # x_native was never set (no native audio)
+            def inpaint_banner(get_frame, t):
+                im = np.copy(get_frame(t))
+                frame = frame_tts
+                h, w, _ = frame.shape
+                offset_h = 24
+                im[offset_h:h + offset_h, :w, :] = (.4 * im[offset_h:h + offset_h, :w, :]
+                                                    + .6 * frame).astype(np.uint8)
+                return im
+        vf = vf.fl(inpaint_banner)
+        vf.write_videofile(SILENT_VIDEO)
+
+    # ==== TTS .srt ====
+
+    if do_video_dub:
+        OUT_FILE = './flask_cache/tmp.mp4'  # args.out_file + '_video_dub.mp4'
+        subtitles = text
+        MAX_LEN = int(subtitles[-1][2] + 17) * 24000
+        # 17 extra seconds fail-safe for a long last segment
+        print("TOTAL LEN SAMPLES ", MAX_LEN, '\n====================')
+        pieces = []
+        for k, (_text_, orig_start, orig_end) in enumerate(subtitles):
+            # PAUSES ?????????????????????????
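+            # One possible (currently inactive) way to honour the .srt timing would be
+            # to pad with silence so each segment starts at its original subtitle time:
+            #
+            #     target_start = int(orig_start * 24000)
+            #     already = sum(len(p) for p in pieces)
+            #     if already < target_start:
+            #         pieces.append(np.zeros(target_start - already, dtype=np.float32))
+            #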
+            pieces.append(tts_multi_sentence(text=[_text_],
+                                             precomputed_style_vector=precomputed_style_vector,
+                                             voice=args.voice,
+                                             scene=args.scene))
+        total = np.concatenate(pieces, 0)
+        # x = audresample.resample(x.astype(np.float32), 24000, 22050)  # reshapes (64,) -> (1,64)
+        # PAD the SHORTER of TTS / NATIVE
+        if len(x_native) > len(total):
+            total = np.pad(total, (0, max(0, x_native.shape[0] - total.shape[0])))
+        else:  # pad native to len of is_tts & total
+            x_native = np.pad(x_native, (0, max(0, total.shape[0] - x_native.shape[0])))
+        # print(total.shape, x_native.shape, 'PADDED TRACKS')
+        soundfile.write(AUDIO_TRACK,
+                        # (is_tts * total + (1 - is_tts) * x_native)[:, None],
+                        (.64 * total + .27 * x_native)[:, None],
+                        24000)
+    else:  # Video from plain text (.txt)
+        OUT_FILE = './flask_cache/tmp.mp4'  # args.out_file + '_video_from_txt.mp4'
+        x = tts_multi_sentence(text=text,
+                               precomputed_style_vector=precomputed_style_vector,
+                               voice=args.voice,
+                               scene=args.scene)
+        soundfile.write(AUDIO_TRACK, x, 24000)
+
+    # IMAGE 2 SPEECH
+
+    if args.image is not None:
+        STATIC_FRAME = args.image  # 'assets/image_from_T31.jpg'
+        OUT_FILE = './flask_cache/tmp.mp4'  # args.out_file + '_image_to_speech.mp4'
+
+        # SILENT CLIP
+        clip_silent = ImageClip(STATIC_FRAME).set_duration(5)  # as long as the audio - TTS first
+        clip_silent.write_videofile(SILENT_VIDEO, fps=24)
+
+        x = tts_multi_sentence(text=text,
+                               precomputed_style_vector=precomputed_style_vector,
+                               voice=args.voice,
+                               scene=args.scene)
+        soundfile.write(AUDIO_TRACK, x, 24000)
+    if args.video or args.image:
+        # write the final output video: mux the silent video with the audio track
+        subprocess.call(
+            ["ffmpeg",
+             "-y",
+             "-i",
+             SILENT_VIDEO,
+             "-i",
+             AUDIO_TRACK,
+             "-c:v",
+             "copy",
+             "-map",
+             "0:v:0",
+             "-map",
+             "1:a:0",
+             OUT_FILE])
+
+        print(f'\noutput video is saved as {OUT_FILE}')
+
+    else:
+        # Fallback: neither image nor video provided - do only TTS
+        x = tts_multi_sentence(text=text,
+                               precomputed_style_vector=precomputed_style_vector,
+                               voice=args.voice,
+                               scene=args.scene)
+        OUT_FILE = './flask_cache/tmp.wav'  # args.out_file + '.wav'
+        soundfile.write(OUT_FILE, x, 24000)
+
+    # audios = [msinference.inference(text,
+    #                                 msinference.compute_style(f'voices/{voice}.wav'),
+    #                                 alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1)]
+    # # for t in [text]:
+    # output_buffer = io.BytesIO()
+    # write(output_buffer, 24000, np.concatenate(audios))
+    # response = Response(output_buffer.getvalue())
+    # response.headers["Content-Type"] = "audio/wav"
+    # https://stackoverflow.com/questions/67591467/
+    # flask-shows-typeerror-send-from-directory-missing-1-required-positional-argum
+    response = send_from_directory('flask_cache/', path=OUT_FILE.split('/')[-1])
+    response.headers['suffix-file-type'] = OUT_FILE.split('/')[-1]
+    return response
+
+
+if __name__ == "__main__":
+    app.run(host="0.0.0.0")
diff --git a/mimic3_foreign/af_ZA_google-nwu_0184.wav b/assets/mimic3_foreign/af_ZA_google-nwu_0184.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_0184.wav rename to assets/mimic3_foreign/af_ZA_google-nwu_0184.wav diff --git a/mimic3_foreign/af_ZA_google-nwu_1919.wav b/assets/mimic3_foreign/af_ZA_google-nwu_1919.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_1919.wav rename to assets/mimic3_foreign/af_ZA_google-nwu_1919.wav diff --git a/mimic3_foreign/af_ZA_google-nwu_2418.wav b/assets/mimic3_foreign/af_ZA_google-nwu_2418.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_2418.wav rename to
assets/mimic3_foreign/af_ZA_google-nwu_2418.wav diff --git a/mimic3_foreign/af_ZA_google-nwu_6590.wav b/assets/mimic3_foreign/af_ZA_google-nwu_6590.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_6590.wav rename to assets/mimic3_foreign/af_ZA_google-nwu_6590.wav diff --git a/mimic3_foreign/af_ZA_google-nwu_7130.wav b/assets/mimic3_foreign/af_ZA_google-nwu_7130.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_7130.wav rename to assets/mimic3_foreign/af_ZA_google-nwu_7130.wav diff --git a/mimic3_foreign/af_ZA_google-nwu_7214.wav b/assets/mimic3_foreign/af_ZA_google-nwu_7214.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_7214.wav rename to assets/mimic3_foreign/af_ZA_google-nwu_7214.wav diff --git a/mimic3_foreign/af_ZA_google-nwu_8148.wav b/assets/mimic3_foreign/af_ZA_google-nwu_8148.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_8148.wav rename to assets/mimic3_foreign/af_ZA_google-nwu_8148.wav diff --git a/mimic3_foreign/af_ZA_google-nwu_8924.wav b/assets/mimic3_foreign/af_ZA_google-nwu_8924.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_8924.wav rename to assets/mimic3_foreign/af_ZA_google-nwu_8924.wav diff --git a/mimic3_foreign/af_ZA_google-nwu_8963.wav b/assets/mimic3_foreign/af_ZA_google-nwu_8963.wav similarity index 100% rename from mimic3_foreign/af_ZA_google-nwu_8963.wav rename to assets/mimic3_foreign/af_ZA_google-nwu_8963.wav diff --git a/mimic3_foreign/bn_multi_00737.wav b/assets/mimic3_foreign/bn_multi_00737.wav similarity index 100% rename from mimic3_foreign/bn_multi_00737.wav rename to assets/mimic3_foreign/bn_multi_00737.wav diff --git a/mimic3_foreign/bn_multi_00779.wav b/assets/mimic3_foreign/bn_multi_00779.wav similarity index 100% rename from mimic3_foreign/bn_multi_00779.wav rename to assets/mimic3_foreign/bn_multi_00779.wav diff --git a/mimic3_foreign/bn_multi_01232.wav b/assets/mimic3_foreign/bn_multi_01232.wav similarity index 100% rename from mimic3_foreign/bn_multi_01232.wav rename to assets/mimic3_foreign/bn_multi_01232.wav diff --git a/mimic3_foreign/bn_multi_01701.wav b/assets/mimic3_foreign/bn_multi_01701.wav similarity index 100% rename from mimic3_foreign/bn_multi_01701.wav rename to assets/mimic3_foreign/bn_multi_01701.wav diff --git a/mimic3_foreign/bn_multi_02194.wav b/assets/mimic3_foreign/bn_multi_02194.wav similarity index 100% rename from mimic3_foreign/bn_multi_02194.wav rename to assets/mimic3_foreign/bn_multi_02194.wav diff --git a/mimic3_foreign/bn_multi_03042.wav b/assets/mimic3_foreign/bn_multi_03042.wav similarity index 100% rename from mimic3_foreign/bn_multi_03042.wav rename to assets/mimic3_foreign/bn_multi_03042.wav diff --git a/mimic3_foreign/bn_multi_0834.wav b/assets/mimic3_foreign/bn_multi_0834.wav similarity index 100% rename from mimic3_foreign/bn_multi_0834.wav rename to assets/mimic3_foreign/bn_multi_0834.wav diff --git a/mimic3_foreign/bn_multi_1010.wav b/assets/mimic3_foreign/bn_multi_1010.wav similarity index 100% rename from mimic3_foreign/bn_multi_1010.wav rename to assets/mimic3_foreign/bn_multi_1010.wav diff --git a/mimic3_foreign/bn_multi_3108.wav b/assets/mimic3_foreign/bn_multi_3108.wav similarity index 100% rename from mimic3_foreign/bn_multi_3108.wav rename to assets/mimic3_foreign/bn_multi_3108.wav diff --git a/mimic3_foreign/bn_multi_3713.wav b/assets/mimic3_foreign/bn_multi_3713.wav similarity index 100% rename from mimic3_foreign/bn_multi_3713.wav rename to assets/mimic3_foreign/bn_multi_3713.wav diff --git 
a/mimic3_foreign/bn_multi_3958.wav b/assets/mimic3_foreign/bn_multi_3958.wav similarity index 100% rename from mimic3_foreign/bn_multi_3958.wav rename to assets/mimic3_foreign/bn_multi_3958.wav diff --git a/mimic3_foreign/bn_multi_4046.wav b/assets/mimic3_foreign/bn_multi_4046.wav similarity index 100% rename from mimic3_foreign/bn_multi_4046.wav rename to assets/mimic3_foreign/bn_multi_4046.wav diff --git a/mimic3_foreign/bn_multi_4811.wav b/assets/mimic3_foreign/bn_multi_4811.wav similarity index 100% rename from mimic3_foreign/bn_multi_4811.wav rename to assets/mimic3_foreign/bn_multi_4811.wav diff --git a/mimic3_foreign/bn_multi_5958.wav b/assets/mimic3_foreign/bn_multi_5958.wav similarity index 100% rename from mimic3_foreign/bn_multi_5958.wav rename to assets/mimic3_foreign/bn_multi_5958.wav diff --git a/mimic3_foreign/bn_multi_9169.wav b/assets/mimic3_foreign/bn_multi_9169.wav similarity index 100% rename from mimic3_foreign/bn_multi_9169.wav rename to assets/mimic3_foreign/bn_multi_9169.wav diff --git a/mimic3_foreign/bn_multi_rm.wav b/assets/mimic3_foreign/bn_multi_rm.wav similarity index 100% rename from mimic3_foreign/bn_multi_rm.wav rename to assets/mimic3_foreign/bn_multi_rm.wav diff --git a/mimic3_foreign/de_DE_m-ailabs_angela_merkel.wav b/assets/mimic3_foreign/de_DE_m-ailabs_angela_merkel.wav similarity index 100% rename from mimic3_foreign/de_DE_m-ailabs_angela_merkel.wav rename to assets/mimic3_foreign/de_DE_m-ailabs_angela_merkel.wav diff --git a/mimic3_foreign/de_DE_m-ailabs_eva_k.wav b/assets/mimic3_foreign/de_DE_m-ailabs_eva_k.wav similarity index 100% rename from mimic3_foreign/de_DE_m-ailabs_eva_k.wav rename to assets/mimic3_foreign/de_DE_m-ailabs_eva_k.wav diff --git a/mimic3_foreign/de_DE_m-ailabs_karlsson.wav b/assets/mimic3_foreign/de_DE_m-ailabs_karlsson.wav similarity index 100% rename from mimic3_foreign/de_DE_m-ailabs_karlsson.wav rename to assets/mimic3_foreign/de_DE_m-ailabs_karlsson.wav diff --git a/mimic3_foreign/de_DE_m-ailabs_ramona_deininger.wav b/assets/mimic3_foreign/de_DE_m-ailabs_ramona_deininger.wav similarity index 100% rename from mimic3_foreign/de_DE_m-ailabs_ramona_deininger.wav rename to assets/mimic3_foreign/de_DE_m-ailabs_ramona_deininger.wav diff --git a/mimic3_foreign/de_DE_m-ailabs_rebecca_braunert_plunkett.wav b/assets/mimic3_foreign/de_DE_m-ailabs_rebecca_braunert_plunkett.wav similarity index 100% rename from mimic3_foreign/de_DE_m-ailabs_rebecca_braunert_plunkett.wav rename to assets/mimic3_foreign/de_DE_m-ailabs_rebecca_braunert_plunkett.wav diff --git a/mimic3_foreign/de_DE_thorsten-emotion_amused.wav b/assets/mimic3_foreign/de_DE_thorsten-emotion_amused.wav similarity index 100% rename from mimic3_foreign/de_DE_thorsten-emotion_amused.wav rename to assets/mimic3_foreign/de_DE_thorsten-emotion_amused.wav diff --git a/mimic3_foreign/de_DE_thorsten-emotion_angry.wav b/assets/mimic3_foreign/de_DE_thorsten-emotion_angry.wav similarity index 100% rename from mimic3_foreign/de_DE_thorsten-emotion_angry.wav rename to assets/mimic3_foreign/de_DE_thorsten-emotion_angry.wav diff --git a/mimic3_foreign/de_DE_thorsten-emotion_disgusted.wav b/assets/mimic3_foreign/de_DE_thorsten-emotion_disgusted.wav similarity index 100% rename from mimic3_foreign/de_DE_thorsten-emotion_disgusted.wav rename to assets/mimic3_foreign/de_DE_thorsten-emotion_disgusted.wav diff --git a/mimic3_foreign/de_DE_thorsten-emotion_drunk.wav b/assets/mimic3_foreign/de_DE_thorsten-emotion_drunk.wav similarity index 100% rename from 
mimic3_foreign/de_DE_thorsten-emotion_drunk.wav rename to assets/mimic3_foreign/de_DE_thorsten-emotion_drunk.wav diff --git a/mimic3_foreign/de_DE_thorsten-emotion_neutral.wav b/assets/mimic3_foreign/de_DE_thorsten-emotion_neutral.wav similarity index 100% rename from mimic3_foreign/de_DE_thorsten-emotion_neutral.wav rename to assets/mimic3_foreign/de_DE_thorsten-emotion_neutral.wav diff --git a/mimic3_foreign/de_DE_thorsten-emotion_sleepy.wav b/assets/mimic3_foreign/de_DE_thorsten-emotion_sleepy.wav similarity index 100% rename from mimic3_foreign/de_DE_thorsten-emotion_sleepy.wav rename to assets/mimic3_foreign/de_DE_thorsten-emotion_sleepy.wav diff --git a/mimic3_foreign/de_DE_thorsten-emotion_surprised.wav b/assets/mimic3_foreign/de_DE_thorsten-emotion_surprised.wav similarity index 100% rename from mimic3_foreign/de_DE_thorsten-emotion_surprised.wav rename to assets/mimic3_foreign/de_DE_thorsten-emotion_surprised.wav diff --git a/mimic3_foreign/de_DE_thorsten-emotion_whisper.wav b/assets/mimic3_foreign/de_DE_thorsten-emotion_whisper.wav similarity index 100% rename from mimic3_foreign/de_DE_thorsten-emotion_whisper.wav rename to assets/mimic3_foreign/de_DE_thorsten-emotion_whisper.wav diff --git a/mimic3_foreign/de_DE_thorsten.wav b/assets/mimic3_foreign/de_DE_thorsten.wav similarity index 100% rename from mimic3_foreign/de_DE_thorsten.wav rename to assets/mimic3_foreign/de_DE_thorsten.wav diff --git a/mimic3_foreign/el_GR_rapunzelina.wav b/assets/mimic3_foreign/el_GR_rapunzelina.wav similarity index 100% rename from mimic3_foreign/el_GR_rapunzelina.wav rename to assets/mimic3_foreign/el_GR_rapunzelina.wav diff --git a/mimic3_foreign/es_ES_carlfm.wav b/assets/mimic3_foreign/es_ES_carlfm.wav similarity index 100% rename from mimic3_foreign/es_ES_carlfm.wav rename to assets/mimic3_foreign/es_ES_carlfm.wav diff --git a/mimic3_foreign/es_ES_m-ailabs_karen_savage.wav b/assets/mimic3_foreign/es_ES_m-ailabs_karen_savage.wav similarity index 100% rename from mimic3_foreign/es_ES_m-ailabs_karen_savage.wav rename to assets/mimic3_foreign/es_ES_m-ailabs_karen_savage.wav diff --git a/mimic3_foreign/es_ES_m-ailabs_tux.wav b/assets/mimic3_foreign/es_ES_m-ailabs_tux.wav similarity index 100% rename from mimic3_foreign/es_ES_m-ailabs_tux.wav rename to assets/mimic3_foreign/es_ES_m-ailabs_tux.wav diff --git a/mimic3_foreign/es_ES_m-ailabs_victor_villarraza.wav b/assets/mimic3_foreign/es_ES_m-ailabs_victor_villarraza.wav similarity index 100% rename from mimic3_foreign/es_ES_m-ailabs_victor_villarraza.wav rename to assets/mimic3_foreign/es_ES_m-ailabs_victor_villarraza.wav diff --git a/mimic3_foreign/fa_haaniye.wav b/assets/mimic3_foreign/fa_haaniye.wav similarity index 100% rename from mimic3_foreign/fa_haaniye.wav rename to assets/mimic3_foreign/fa_haaniye.wav diff --git a/mimic3_foreign/fi_FI_harri-tapani-ylilammi.wav b/assets/mimic3_foreign/fi_FI_harri-tapani-ylilammi.wav similarity index 100% rename from mimic3_foreign/fi_FI_harri-tapani-ylilammi.wav rename to assets/mimic3_foreign/fi_FI_harri-tapani-ylilammi.wav diff --git a/mimic3_foreign/fr_FR_m-ailabs_bernard.wav b/assets/mimic3_foreign/fr_FR_m-ailabs_bernard.wav similarity index 100% rename from mimic3_foreign/fr_FR_m-ailabs_bernard.wav rename to assets/mimic3_foreign/fr_FR_m-ailabs_bernard.wav diff --git a/mimic3_foreign/fr_FR_m-ailabs_ezwa.wav b/assets/mimic3_foreign/fr_FR_m-ailabs_ezwa.wav similarity index 100% rename from mimic3_foreign/fr_FR_m-ailabs_ezwa.wav rename to assets/mimic3_foreign/fr_FR_m-ailabs_ezwa.wav diff --git 
a/mimic3_foreign/fr_FR_m-ailabs_gilles_g_le_blanc.wav b/assets/mimic3_foreign/fr_FR_m-ailabs_gilles_g_le_blanc.wav similarity index 100% rename from mimic3_foreign/fr_FR_m-ailabs_gilles_g_le_blanc.wav rename to assets/mimic3_foreign/fr_FR_m-ailabs_gilles_g_le_blanc.wav diff --git a/mimic3_foreign/fr_FR_m-ailabs_nadine_eckert_boulet.wav b/assets/mimic3_foreign/fr_FR_m-ailabs_nadine_eckert_boulet.wav similarity index 100% rename from mimic3_foreign/fr_FR_m-ailabs_nadine_eckert_boulet.wav rename to assets/mimic3_foreign/fr_FR_m-ailabs_nadine_eckert_boulet.wav diff --git a/mimic3_foreign/fr_FR_m-ailabs_zeckou.wav b/assets/mimic3_foreign/fr_FR_m-ailabs_zeckou.wav similarity index 100% rename from mimic3_foreign/fr_FR_m-ailabs_zeckou.wav rename to assets/mimic3_foreign/fr_FR_m-ailabs_zeckou.wav diff --git a/mimic3_foreign/fr_FR_siwis.wav b/assets/mimic3_foreign/fr_FR_siwis.wav similarity index 100% rename from mimic3_foreign/fr_FR_siwis.wav rename to assets/mimic3_foreign/fr_FR_siwis.wav diff --git a/mimic3_foreign/fr_FR_tom.wav b/assets/mimic3_foreign/fr_FR_tom.wav similarity index 100% rename from mimic3_foreign/fr_FR_tom.wav rename to assets/mimic3_foreign/fr_FR_tom.wav diff --git a/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_ad.wav b/assets/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_ad.wav similarity index 100% rename from mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_ad.wav rename to assets/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_ad.wav diff --git a/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_dp.wav b/assets/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_dp.wav similarity index 100% rename from mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_dp.wav rename to assets/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_dp.wav diff --git a/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_kt.wav b/assets/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_kt.wav similarity index 100% rename from mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_kt.wav rename to assets/mimic3_foreign/gu_IN_cmu-indic_cmu_indic_guj_kt.wav diff --git a/mimic3_foreign/ha_NE_openbible.wav b/assets/mimic3_foreign/ha_NE_openbible.wav similarity index 100% rename from mimic3_foreign/ha_NE_openbible.wav rename to assets/mimic3_foreign/ha_NE_openbible.wav diff --git a/mimic3_foreign/hu_HU_diana-majlinger.wav b/assets/mimic3_foreign/hu_HU_diana-majlinger.wav similarity index 100% rename from mimic3_foreign/hu_HU_diana-majlinger.wav rename to assets/mimic3_foreign/hu_HU_diana-majlinger.wav diff --git a/mimic3_foreign/it_IT_mls_10446.wav b/assets/mimic3_foreign/it_IT_mls_10446.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_10446.wav rename to assets/mimic3_foreign/it_IT_mls_10446.wav diff --git a/mimic3_foreign/it_IT_mls_1157.wav b/assets/mimic3_foreign/it_IT_mls_1157.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_1157.wav rename to assets/mimic3_foreign/it_IT_mls_1157.wav diff --git a/mimic3_foreign/it_IT_mls_12428.wav b/assets/mimic3_foreign/it_IT_mls_12428.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_12428.wav rename to assets/mimic3_foreign/it_IT_mls_12428.wav diff --git a/mimic3_foreign/it_IT_mls_12804.wav b/assets/mimic3_foreign/it_IT_mls_12804.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_12804.wav rename to assets/mimic3_foreign/it_IT_mls_12804.wav diff --git a/mimic3_foreign/it_IT_mls_1595.wav b/assets/mimic3_foreign/it_IT_mls_1595.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_1595.wav rename to assets/mimic3_foreign/it_IT_mls_1595.wav diff --git 
a/mimic3_foreign/it_IT_mls_1725.wav b/assets/mimic3_foreign/it_IT_mls_1725.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_1725.wav rename to assets/mimic3_foreign/it_IT_mls_1725.wav diff --git a/mimic3_foreign/it_IT_mls_1989.wav b/assets/mimic3_foreign/it_IT_mls_1989.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_1989.wav rename to assets/mimic3_foreign/it_IT_mls_1989.wav diff --git a/mimic3_foreign/it_IT_mls_2019.wav b/assets/mimic3_foreign/it_IT_mls_2019.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_2019.wav rename to assets/mimic3_foreign/it_IT_mls_2019.wav diff --git a/mimic3_foreign/it_IT_mls_2033.wav b/assets/mimic3_foreign/it_IT_mls_2033.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_2033.wav rename to assets/mimic3_foreign/it_IT_mls_2033.wav diff --git a/mimic3_foreign/it_IT_mls_277.wav b/assets/mimic3_foreign/it_IT_mls_277.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_277.wav rename to assets/mimic3_foreign/it_IT_mls_277.wav diff --git a/mimic3_foreign/it_IT_mls_4649.wav b/assets/mimic3_foreign/it_IT_mls_4649.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_4649.wav rename to assets/mimic3_foreign/it_IT_mls_4649.wav diff --git a/mimic3_foreign/it_IT_mls_4705.wav b/assets/mimic3_foreign/it_IT_mls_4705.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_4705.wav rename to assets/mimic3_foreign/it_IT_mls_4705.wav diff --git a/mimic3_foreign/it_IT_mls_4971.wav b/assets/mimic3_foreign/it_IT_mls_4971.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_4971.wav rename to assets/mimic3_foreign/it_IT_mls_4971.wav diff --git a/mimic3_foreign/it_IT_mls_4974.wav b/assets/mimic3_foreign/it_IT_mls_4974.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_4974.wav rename to assets/mimic3_foreign/it_IT_mls_4974.wav diff --git a/mimic3_foreign/it_IT_mls_4975.wav b/assets/mimic3_foreign/it_IT_mls_4975.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_4975.wav rename to assets/mimic3_foreign/it_IT_mls_4975.wav diff --git a/mimic3_foreign/it_IT_mls_4998.wav b/assets/mimic3_foreign/it_IT_mls_4998.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_4998.wav rename to assets/mimic3_foreign/it_IT_mls_4998.wav diff --git a/mimic3_foreign/it_IT_mls_5010.wav b/assets/mimic3_foreign/it_IT_mls_5010.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_5010.wav rename to assets/mimic3_foreign/it_IT_mls_5010.wav diff --git a/mimic3_foreign/it_IT_mls_5421.wav b/assets/mimic3_foreign/it_IT_mls_5421.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_5421.wav rename to assets/mimic3_foreign/it_IT_mls_5421.wav diff --git a/mimic3_foreign/it_IT_mls_6001.wav b/assets/mimic3_foreign/it_IT_mls_6001.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_6001.wav rename to assets/mimic3_foreign/it_IT_mls_6001.wav diff --git a/mimic3_foreign/it_IT_mls_6299.wav b/assets/mimic3_foreign/it_IT_mls_6299.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_6299.wav rename to assets/mimic3_foreign/it_IT_mls_6299.wav diff --git a/mimic3_foreign/it_IT_mls_6348.wav b/assets/mimic3_foreign/it_IT_mls_6348.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_6348.wav rename to assets/mimic3_foreign/it_IT_mls_6348.wav diff --git a/mimic3_foreign/it_IT_mls_643.wav b/assets/mimic3_foreign/it_IT_mls_643.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_643.wav rename to assets/mimic3_foreign/it_IT_mls_643.wav diff --git 
a/mimic3_foreign/it_IT_mls_644.wav b/assets/mimic3_foreign/it_IT_mls_644.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_644.wav rename to assets/mimic3_foreign/it_IT_mls_644.wav diff --git a/mimic3_foreign/it_IT_mls_659.wav b/assets/mimic3_foreign/it_IT_mls_659.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_659.wav rename to assets/mimic3_foreign/it_IT_mls_659.wav diff --git a/mimic3_foreign/it_IT_mls_6744.wav b/assets/mimic3_foreign/it_IT_mls_6744.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_6744.wav rename to assets/mimic3_foreign/it_IT_mls_6744.wav diff --git a/mimic3_foreign/it_IT_mls_6807.wav b/assets/mimic3_foreign/it_IT_mls_6807.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_6807.wav rename to assets/mimic3_foreign/it_IT_mls_6807.wav diff --git a/mimic3_foreign/it_IT_mls_7405.wav b/assets/mimic3_foreign/it_IT_mls_7405.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_7405.wav rename to assets/mimic3_foreign/it_IT_mls_7405.wav diff --git a/mimic3_foreign/it_IT_mls_7440.wav b/assets/mimic3_foreign/it_IT_mls_7440.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_7440.wav rename to assets/mimic3_foreign/it_IT_mls_7440.wav diff --git a/mimic3_foreign/it_IT_mls_7444.wav b/assets/mimic3_foreign/it_IT_mls_7444.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_7444.wav rename to assets/mimic3_foreign/it_IT_mls_7444.wav diff --git a/mimic3_foreign/it_IT_mls_7936.wav b/assets/mimic3_foreign/it_IT_mls_7936.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_7936.wav rename to assets/mimic3_foreign/it_IT_mls_7936.wav diff --git a/mimic3_foreign/it_IT_mls_8181.wav b/assets/mimic3_foreign/it_IT_mls_8181.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_8181.wav rename to assets/mimic3_foreign/it_IT_mls_8181.wav diff --git a/mimic3_foreign/it_IT_mls_8207.wav b/assets/mimic3_foreign/it_IT_mls_8207.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_8207.wav rename to assets/mimic3_foreign/it_IT_mls_8207.wav diff --git a/mimic3_foreign/it_IT_mls_8384.wav b/assets/mimic3_foreign/it_IT_mls_8384.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_8384.wav rename to assets/mimic3_foreign/it_IT_mls_8384.wav diff --git a/mimic3_foreign/it_IT_mls_844.wav b/assets/mimic3_foreign/it_IT_mls_844.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_844.wav rename to assets/mimic3_foreign/it_IT_mls_844.wav diff --git a/mimic3_foreign/it_IT_mls_8461.wav b/assets/mimic3_foreign/it_IT_mls_8461.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_8461.wav rename to assets/mimic3_foreign/it_IT_mls_8461.wav diff --git a/mimic3_foreign/it_IT_mls_8828.wav b/assets/mimic3_foreign/it_IT_mls_8828.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_8828.wav rename to assets/mimic3_foreign/it_IT_mls_8828.wav diff --git a/mimic3_foreign/it_IT_mls_8842.wav b/assets/mimic3_foreign/it_IT_mls_8842.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_8842.wav rename to assets/mimic3_foreign/it_IT_mls_8842.wav diff --git a/mimic3_foreign/it_IT_mls_9185.wav b/assets/mimic3_foreign/it_IT_mls_9185.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_9185.wav rename to assets/mimic3_foreign/it_IT_mls_9185.wav diff --git a/mimic3_foreign/it_IT_mls_9772.wav b/assets/mimic3_foreign/it_IT_mls_9772.wav similarity index 100% rename from mimic3_foreign/it_IT_mls_9772.wav rename to assets/mimic3_foreign/it_IT_mls_9772.wav diff --git 
a/mimic3_foreign/it_IT_riccardo-fasol.wav b/assets/mimic3_foreign/it_IT_riccardo-fasol.wav similarity index 100% rename from mimic3_foreign/it_IT_riccardo-fasol.wav rename to assets/mimic3_foreign/it_IT_riccardo-fasol.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_00027.wav b/assets/mimic3_foreign/jv_ID_google-gmu_00027.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_00027.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_00027.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_00264.wav b/assets/mimic3_foreign/jv_ID_google-gmu_00264.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_00264.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_00264.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_00658.wav b/assets/mimic3_foreign/jv_ID_google-gmu_00658.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_00658.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_00658.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_01392.wav b/assets/mimic3_foreign/jv_ID_google-gmu_01392.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_01392.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_01392.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_01519.wav b/assets/mimic3_foreign/jv_ID_google-gmu_01519.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_01519.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_01519.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_01932.wav b/assets/mimic3_foreign/jv_ID_google-gmu_01932.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_01932.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_01932.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_02059.wav b/assets/mimic3_foreign/jv_ID_google-gmu_02059.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_02059.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_02059.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_02326.wav b/assets/mimic3_foreign/jv_ID_google-gmu_02326.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_02326.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_02326.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_02884.wav b/assets/mimic3_foreign/jv_ID_google-gmu_02884.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_02884.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_02884.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_03187.wav b/assets/mimic3_foreign/jv_ID_google-gmu_03187.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_03187.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_03187.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_03314.wav b/assets/mimic3_foreign/jv_ID_google-gmu_03314.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_03314.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_03314.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_03424.wav b/assets/mimic3_foreign/jv_ID_google-gmu_03424.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_03424.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_03424.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_03727.wav b/assets/mimic3_foreign/jv_ID_google-gmu_03727.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_03727.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_03727.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_04175.wav b/assets/mimic3_foreign/jv_ID_google-gmu_04175.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_04175.wav rename to 
assets/mimic3_foreign/jv_ID_google-gmu_04175.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_04285.wav b/assets/mimic3_foreign/jv_ID_google-gmu_04285.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_04285.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_04285.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_04588.wav b/assets/mimic3_foreign/jv_ID_google-gmu_04588.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_04588.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_04588.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_04679.wav b/assets/mimic3_foreign/jv_ID_google-gmu_04679.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_04679.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_04679.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_04715.wav b/assets/mimic3_foreign/jv_ID_google-gmu_04715.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_04715.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_04715.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_04982.wav b/assets/mimic3_foreign/jv_ID_google-gmu_04982.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_04982.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_04982.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_05219.wav b/assets/mimic3_foreign/jv_ID_google-gmu_05219.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_05219.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_05219.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_05522.wav b/assets/mimic3_foreign/jv_ID_google-gmu_05522.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_05522.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_05522.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_05540.wav b/assets/mimic3_foreign/jv_ID_google-gmu_05540.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_05540.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_05540.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_05667.wav b/assets/mimic3_foreign/jv_ID_google-gmu_05667.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_05667.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_05667.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_05970.wav b/assets/mimic3_foreign/jv_ID_google-gmu_05970.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_05970.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_05970.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_06080.wav b/assets/mimic3_foreign/jv_ID_google-gmu_06080.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_06080.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_06080.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_06207.wav b/assets/mimic3_foreign/jv_ID_google-gmu_06207.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_06207.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_06207.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_06383.wav b/assets/mimic3_foreign/jv_ID_google-gmu_06383.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_06383.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_06383.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_06510.wav b/assets/mimic3_foreign/jv_ID_google-gmu_06510.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_06510.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_06510.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_06941.wav b/assets/mimic3_foreign/jv_ID_google-gmu_06941.wav similarity index 100% 
rename from mimic3_foreign/jv_ID_google-gmu_06941.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_06941.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_07335.wav b/assets/mimic3_foreign/jv_ID_google-gmu_07335.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_07335.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_07335.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_07638.wav b/assets/mimic3_foreign/jv_ID_google-gmu_07638.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_07638.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_07638.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_07765.wav b/assets/mimic3_foreign/jv_ID_google-gmu_07765.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_07765.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_07765.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_07875.wav b/assets/mimic3_foreign/jv_ID_google-gmu_07875.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_07875.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_07875.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_08002.wav b/assets/mimic3_foreign/jv_ID_google-gmu_08002.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_08002.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_08002.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_08178.wav b/assets/mimic3_foreign/jv_ID_google-gmu_08178.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_08178.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_08178.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_08305.wav b/assets/mimic3_foreign/jv_ID_google-gmu_08305.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_08305.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_08305.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_08736.wav b/assets/mimic3_foreign/jv_ID_google-gmu_08736.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_08736.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_08736.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_09039.wav b/assets/mimic3_foreign/jv_ID_google-gmu_09039.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_09039.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_09039.wav diff --git a/mimic3_foreign/jv_ID_google-gmu_09724.wav b/assets/mimic3_foreign/jv_ID_google-gmu_09724.wav similarity index 100% rename from mimic3_foreign/jv_ID_google-gmu_09724.wav rename to assets/mimic3_foreign/jv_ID_google-gmu_09724.wav diff --git a/mimic3_foreign/ko_KO_kss.wav b/assets/mimic3_foreign/ko_KO_kss.wav similarity index 100% rename from mimic3_foreign/ko_KO_kss.wav rename to assets/mimic3_foreign/ko_KO_kss.wav diff --git a/mimic3_foreign/ne_NP_ne-google_0258.wav b/assets/mimic3_foreign/ne_NP_ne-google_0258.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_0258.wav rename to assets/mimic3_foreign/ne_NP_ne-google_0258.wav diff --git a/mimic3_foreign/ne_NP_ne-google_0283.wav b/assets/mimic3_foreign/ne_NP_ne-google_0283.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_0283.wav rename to assets/mimic3_foreign/ne_NP_ne-google_0283.wav diff --git a/mimic3_foreign/ne_NP_ne-google_0546.wav b/assets/mimic3_foreign/ne_NP_ne-google_0546.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_0546.wav rename to assets/mimic3_foreign/ne_NP_ne-google_0546.wav diff --git a/mimic3_foreign/ne_NP_ne-google_0649.wav b/assets/mimic3_foreign/ne_NP_ne-google_0649.wav similarity index 100% rename from 
mimic3_foreign/ne_NP_ne-google_0649.wav rename to assets/mimic3_foreign/ne_NP_ne-google_0649.wav diff --git a/mimic3_foreign/ne_NP_ne-google_0883.wav b/assets/mimic3_foreign/ne_NP_ne-google_0883.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_0883.wav rename to assets/mimic3_foreign/ne_NP_ne-google_0883.wav diff --git a/mimic3_foreign/ne_NP_ne-google_2027.wav b/assets/mimic3_foreign/ne_NP_ne-google_2027.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_2027.wav rename to assets/mimic3_foreign/ne_NP_ne-google_2027.wav diff --git a/mimic3_foreign/ne_NP_ne-google_2099.wav b/assets/mimic3_foreign/ne_NP_ne-google_2099.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_2099.wav rename to assets/mimic3_foreign/ne_NP_ne-google_2099.wav diff --git a/mimic3_foreign/ne_NP_ne-google_2139.wav b/assets/mimic3_foreign/ne_NP_ne-google_2139.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_2139.wav rename to assets/mimic3_foreign/ne_NP_ne-google_2139.wav diff --git a/mimic3_foreign/ne_NP_ne-google_3154.wav b/assets/mimic3_foreign/ne_NP_ne-google_3154.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_3154.wav rename to assets/mimic3_foreign/ne_NP_ne-google_3154.wav diff --git a/mimic3_foreign/ne_NP_ne-google_3614.wav b/assets/mimic3_foreign/ne_NP_ne-google_3614.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_3614.wav rename to assets/mimic3_foreign/ne_NP_ne-google_3614.wav diff --git a/mimic3_foreign/ne_NP_ne-google_3960.wav b/assets/mimic3_foreign/ne_NP_ne-google_3960.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_3960.wav rename to assets/mimic3_foreign/ne_NP_ne-google_3960.wav diff --git a/mimic3_foreign/ne_NP_ne-google_3997.wav b/assets/mimic3_foreign/ne_NP_ne-google_3997.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_3997.wav rename to assets/mimic3_foreign/ne_NP_ne-google_3997.wav diff --git a/mimic3_foreign/ne_NP_ne-google_5687.wav b/assets/mimic3_foreign/ne_NP_ne-google_5687.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_5687.wav rename to assets/mimic3_foreign/ne_NP_ne-google_5687.wav diff --git a/mimic3_foreign/ne_NP_ne-google_6329.wav b/assets/mimic3_foreign/ne_NP_ne-google_6329.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_6329.wav rename to assets/mimic3_foreign/ne_NP_ne-google_6329.wav diff --git a/mimic3_foreign/ne_NP_ne-google_6587.wav b/assets/mimic3_foreign/ne_NP_ne-google_6587.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_6587.wav rename to assets/mimic3_foreign/ne_NP_ne-google_6587.wav diff --git a/mimic3_foreign/ne_NP_ne-google_6834.wav b/assets/mimic3_foreign/ne_NP_ne-google_6834.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_6834.wav rename to assets/mimic3_foreign/ne_NP_ne-google_6834.wav diff --git a/mimic3_foreign/ne_NP_ne-google_7957.wav b/assets/mimic3_foreign/ne_NP_ne-google_7957.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_7957.wav rename to assets/mimic3_foreign/ne_NP_ne-google_7957.wav diff --git a/mimic3_foreign/ne_NP_ne-google_9407.wav b/assets/mimic3_foreign/ne_NP_ne-google_9407.wav similarity index 100% rename from mimic3_foreign/ne_NP_ne-google_9407.wav rename to assets/mimic3_foreign/ne_NP_ne-google_9407.wav diff --git a/mimic3_foreign/nl_bart-de-leeuw.wav b/assets/mimic3_foreign/nl_bart-de-leeuw.wav similarity index 100% rename from mimic3_foreign/nl_bart-de-leeuw.wav rename to 
assets/mimic3_foreign/nl_bart-de-leeuw.wav diff --git a/mimic3_foreign/nl_flemishguy.wav b/assets/mimic3_foreign/nl_flemishguy.wav similarity index 100% rename from mimic3_foreign/nl_flemishguy.wav rename to assets/mimic3_foreign/nl_flemishguy.wav diff --git a/mimic3_foreign/nl_nathalie.wav b/assets/mimic3_foreign/nl_nathalie.wav similarity index 100% rename from mimic3_foreign/nl_nathalie.wav rename to assets/mimic3_foreign/nl_nathalie.wav diff --git a/mimic3_foreign/nl_pmk.wav b/assets/mimic3_foreign/nl_pmk.wav similarity index 100% rename from mimic3_foreign/nl_pmk.wav rename to assets/mimic3_foreign/nl_pmk.wav diff --git a/mimic3_foreign/nl_rdh.wav b/assets/mimic3_foreign/nl_rdh.wav similarity index 100% rename from mimic3_foreign/nl_rdh.wav rename to assets/mimic3_foreign/nl_rdh.wav diff --git a/mimic3_foreign/pl_PL_m-ailabs_nina_brown.wav b/assets/mimic3_foreign/pl_PL_m-ailabs_nina_brown.wav similarity index 100% rename from mimic3_foreign/pl_PL_m-ailabs_nina_brown.wav rename to assets/mimic3_foreign/pl_PL_m-ailabs_nina_brown.wav diff --git a/mimic3_foreign/pl_PL_m-ailabs_piotr_nater.wav b/assets/mimic3_foreign/pl_PL_m-ailabs_piotr_nater.wav similarity index 100% rename from mimic3_foreign/pl_PL_m-ailabs_piotr_nater.wav rename to assets/mimic3_foreign/pl_PL_m-ailabs_piotr_nater.wav diff --git a/mimic3_foreign/ru_RU_multi_hajdurova.wav b/assets/mimic3_foreign/ru_RU_multi_hajdurova.wav similarity index 100% rename from mimic3_foreign/ru_RU_multi_hajdurova.wav rename to assets/mimic3_foreign/ru_RU_multi_hajdurova.wav diff --git a/mimic3_foreign/ru_RU_multi_minaev.wav b/assets/mimic3_foreign/ru_RU_multi_minaev.wav similarity index 100% rename from mimic3_foreign/ru_RU_multi_minaev.wav rename to assets/mimic3_foreign/ru_RU_multi_minaev.wav diff --git a/mimic3_foreign/ru_RU_multi_nikolaev.wav b/assets/mimic3_foreign/ru_RU_multi_nikolaev.wav similarity index 100% rename from mimic3_foreign/ru_RU_multi_nikolaev.wav rename to assets/mimic3_foreign/ru_RU_multi_nikolaev.wav diff --git a/mimic3_foreign/sw_lanfrica.wav b/assets/mimic3_foreign/sw_lanfrica.wav similarity index 100% rename from mimic3_foreign/sw_lanfrica.wav rename to assets/mimic3_foreign/sw_lanfrica.wav diff --git a/mimic3_foreign/te_IN_cmu-indic_kpn.wav b/assets/mimic3_foreign/te_IN_cmu-indic_kpn.wav similarity index 100% rename from mimic3_foreign/te_IN_cmu-indic_kpn.wav rename to assets/mimic3_foreign/te_IN_cmu-indic_kpn.wav diff --git a/mimic3_foreign/te_IN_cmu-indic_sk.wav b/assets/mimic3_foreign/te_IN_cmu-indic_sk.wav similarity index 100% rename from mimic3_foreign/te_IN_cmu-indic_sk.wav rename to assets/mimic3_foreign/te_IN_cmu-indic_sk.wav diff --git a/mimic3_foreign/te_IN_cmu-indic_ss.wav b/assets/mimic3_foreign/te_IN_cmu-indic_ss.wav similarity index 100% rename from mimic3_foreign/te_IN_cmu-indic_ss.wav rename to assets/mimic3_foreign/te_IN_cmu-indic_ss.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_0045.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_0045.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_0045.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_0045.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_0378.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_0378.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_0378.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_0378.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_0441.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_0441.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_0441.wav rename to 
assets/mimic3_foreign/tn_ZA_google-nwu_0441.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_1483.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_1483.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_1483.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_1483.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_1498.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_1498.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_1498.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_1498.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_1932.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_1932.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_1932.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_1932.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_2839.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_2839.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_2839.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_2839.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_3342.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_3342.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_3342.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_3342.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_3629.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_3629.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_3629.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_3629.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_4506.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_4506.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_4506.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_4506.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_4850.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_4850.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_4850.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_4850.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_5628.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_5628.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_5628.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_5628.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_6116.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_6116.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_6116.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_6116.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_6206.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_6206.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_6206.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_6206.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_6234.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_6234.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_6234.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_6234.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_6459.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_6459.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_6459.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_6459.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_7674.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_7674.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_7674.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_7674.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_7693.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_7693.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_7693.wav rename to 
assets/mimic3_foreign/tn_ZA_google-nwu_7693.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_7866.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_7866.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_7866.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_7866.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_7896.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_7896.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_7896.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_7896.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_8333.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_8333.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_8333.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_8333.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_8512.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_8512.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_8512.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_8512.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_8532.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_8532.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_8532.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_8532.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_8914.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_8914.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_8914.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_8914.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_9061.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_9061.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_9061.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_9061.wav diff --git a/mimic3_foreign/tn_ZA_google-nwu_9365.wav b/assets/mimic3_foreign/tn_ZA_google-nwu_9365.wav similarity index 100% rename from mimic3_foreign/tn_ZA_google-nwu_9365.wav rename to assets/mimic3_foreign/tn_ZA_google-nwu_9365.wav diff --git a/mimic3_foreign/uk_UK_m-ailabs_loboda.wav b/assets/mimic3_foreign/uk_UK_m-ailabs_loboda.wav similarity index 100% rename from mimic3_foreign/uk_UK_m-ailabs_loboda.wav rename to assets/mimic3_foreign/uk_UK_m-ailabs_loboda.wav diff --git a/mimic3_foreign/uk_UK_m-ailabs_miskun.wav b/assets/mimic3_foreign/uk_UK_m-ailabs_miskun.wav similarity index 100% rename from mimic3_foreign/uk_UK_m-ailabs_miskun.wav rename to assets/mimic3_foreign/uk_UK_m-ailabs_miskun.wav diff --git a/mimic3_foreign/uk_UK_m-ailabs_obruchov.wav b/assets/mimic3_foreign/uk_UK_m-ailabs_obruchov.wav similarity index 100% rename from mimic3_foreign/uk_UK_m-ailabs_obruchov.wav rename to assets/mimic3_foreign/uk_UK_m-ailabs_obruchov.wav diff --git a/mimic3_foreign/uk_UK_m-ailabs_pysariev.wav b/assets/mimic3_foreign/uk_UK_m-ailabs_pysariev.wav similarity index 100% rename from mimic3_foreign/uk_UK_m-ailabs_pysariev.wav rename to assets/mimic3_foreign/uk_UK_m-ailabs_pysariev.wav diff --git a/mimic3_foreign/uk_UK_m-ailabs_shepel.wav b/assets/mimic3_foreign/uk_UK_m-ailabs_shepel.wav similarity index 100% rename from mimic3_foreign/uk_UK_m-ailabs_shepel.wav rename to assets/mimic3_foreign/uk_UK_m-ailabs_shepel.wav diff --git a/mimic3_foreign/uk_UK_m-ailabs_sumska.wav b/assets/mimic3_foreign/uk_UK_m-ailabs_sumska.wav similarity index 100% rename from mimic3_foreign/uk_UK_m-ailabs_sumska.wav rename to assets/mimic3_foreign/uk_UK_m-ailabs_sumska.wav diff --git a/mimic3_foreign/vi_VN_vais1000.wav b/assets/mimic3_foreign/vi_VN_vais1000.wav similarity index 100% rename from mimic3_foreign/vi_VN_vais1000.wav rename to 
assets/mimic3_foreign/vi_VN_vais1000.wav diff --git a/mimic3_foreign/yo_openbible.wav b/assets/mimic3_foreign/yo_openbible.wav similarity index 100% rename from mimic3_foreign/yo_openbible.wav rename to assets/mimic3_foreign/yo_openbible.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_0184.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_0184.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_0184.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_0184.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_1919.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_1919.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_1919.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_1919.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_2418.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_2418.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_2418.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_2418.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_6590.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_6590.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_6590.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_6590.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_7130.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_7130.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_7130.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_7130.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_7214.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_7214.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_7214.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_7214.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_8148.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_8148.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_8148.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_8148.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_8924.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_8924.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_8924.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_8924.wav diff --git a/mimic3_foreign_4x/af_ZA_google-nwu_8963.wav b/assets/mimic3_foreign_4x/af_ZA_google-nwu_8963.wav similarity index 100% rename from mimic3_foreign_4x/af_ZA_google-nwu_8963.wav rename to assets/mimic3_foreign_4x/af_ZA_google-nwu_8963.wav diff --git a/mimic3_foreign_4x/bn_multi_00737.wav b/assets/mimic3_foreign_4x/bn_multi_00737.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_00737.wav rename to assets/mimic3_foreign_4x/bn_multi_00737.wav diff --git a/mimic3_foreign_4x/bn_multi_00779.wav b/assets/mimic3_foreign_4x/bn_multi_00779.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_00779.wav rename to assets/mimic3_foreign_4x/bn_multi_00779.wav diff --git a/mimic3_foreign_4x/bn_multi_01232.wav b/assets/mimic3_foreign_4x/bn_multi_01232.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_01232.wav rename to assets/mimic3_foreign_4x/bn_multi_01232.wav diff --git a/mimic3_foreign_4x/bn_multi_01701.wav b/assets/mimic3_foreign_4x/bn_multi_01701.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_01701.wav rename to assets/mimic3_foreign_4x/bn_multi_01701.wav diff --git a/mimic3_foreign_4x/bn_multi_02194.wav b/assets/mimic3_foreign_4x/bn_multi_02194.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_02194.wav rename to 
assets/mimic3_foreign_4x/bn_multi_02194.wav diff --git a/mimic3_foreign_4x/bn_multi_03042.wav b/assets/mimic3_foreign_4x/bn_multi_03042.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_03042.wav rename to assets/mimic3_foreign_4x/bn_multi_03042.wav diff --git a/mimic3_foreign_4x/bn_multi_0834.wav b/assets/mimic3_foreign_4x/bn_multi_0834.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_0834.wav rename to assets/mimic3_foreign_4x/bn_multi_0834.wav diff --git a/mimic3_foreign_4x/bn_multi_1010.wav b/assets/mimic3_foreign_4x/bn_multi_1010.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_1010.wav rename to assets/mimic3_foreign_4x/bn_multi_1010.wav diff --git a/mimic3_foreign_4x/bn_multi_3108.wav b/assets/mimic3_foreign_4x/bn_multi_3108.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_3108.wav rename to assets/mimic3_foreign_4x/bn_multi_3108.wav diff --git a/mimic3_foreign_4x/bn_multi_3713.wav b/assets/mimic3_foreign_4x/bn_multi_3713.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_3713.wav rename to assets/mimic3_foreign_4x/bn_multi_3713.wav diff --git a/mimic3_foreign_4x/bn_multi_3958.wav b/assets/mimic3_foreign_4x/bn_multi_3958.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_3958.wav rename to assets/mimic3_foreign_4x/bn_multi_3958.wav diff --git a/mimic3_foreign_4x/bn_multi_4046.wav b/assets/mimic3_foreign_4x/bn_multi_4046.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_4046.wav rename to assets/mimic3_foreign_4x/bn_multi_4046.wav diff --git a/mimic3_foreign_4x/bn_multi_4811.wav b/assets/mimic3_foreign_4x/bn_multi_4811.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_4811.wav rename to assets/mimic3_foreign_4x/bn_multi_4811.wav diff --git a/mimic3_foreign_4x/bn_multi_5958.wav b/assets/mimic3_foreign_4x/bn_multi_5958.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_5958.wav rename to assets/mimic3_foreign_4x/bn_multi_5958.wav diff --git a/mimic3_foreign_4x/bn_multi_9169.wav b/assets/mimic3_foreign_4x/bn_multi_9169.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_9169.wav rename to assets/mimic3_foreign_4x/bn_multi_9169.wav diff --git a/mimic3_foreign_4x/bn_multi_rm.wav b/assets/mimic3_foreign_4x/bn_multi_rm.wav similarity index 100% rename from mimic3_foreign_4x/bn_multi_rm.wav rename to assets/mimic3_foreign_4x/bn_multi_rm.wav diff --git a/mimic3_foreign_4x/de_DE_m-ailabs_angela_merkel.wav b/assets/mimic3_foreign_4x/de_DE_m-ailabs_angela_merkel.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_m-ailabs_angela_merkel.wav rename to assets/mimic3_foreign_4x/de_DE_m-ailabs_angela_merkel.wav diff --git a/mimic3_foreign_4x/de_DE_m-ailabs_eva_k.wav b/assets/mimic3_foreign_4x/de_DE_m-ailabs_eva_k.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_m-ailabs_eva_k.wav rename to assets/mimic3_foreign_4x/de_DE_m-ailabs_eva_k.wav diff --git a/mimic3_foreign_4x/de_DE_m-ailabs_karlsson.wav b/assets/mimic3_foreign_4x/de_DE_m-ailabs_karlsson.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_m-ailabs_karlsson.wav rename to assets/mimic3_foreign_4x/de_DE_m-ailabs_karlsson.wav diff --git a/mimic3_foreign_4x/de_DE_m-ailabs_ramona_deininger.wav b/assets/mimic3_foreign_4x/de_DE_m-ailabs_ramona_deininger.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_m-ailabs_ramona_deininger.wav rename to assets/mimic3_foreign_4x/de_DE_m-ailabs_ramona_deininger.wav diff --git 
a/mimic3_foreign_4x/de_DE_m-ailabs_rebecca_braunert_plunkett.wav b/assets/mimic3_foreign_4x/de_DE_m-ailabs_rebecca_braunert_plunkett.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_m-ailabs_rebecca_braunert_plunkett.wav rename to assets/mimic3_foreign_4x/de_DE_m-ailabs_rebecca_braunert_plunkett.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten-emotion_amused.wav b/assets/mimic3_foreign_4x/de_DE_thorsten-emotion_amused.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten-emotion_amused.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten-emotion_amused.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten-emotion_angry.wav b/assets/mimic3_foreign_4x/de_DE_thorsten-emotion_angry.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten-emotion_angry.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten-emotion_angry.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten-emotion_disgusted.wav b/assets/mimic3_foreign_4x/de_DE_thorsten-emotion_disgusted.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten-emotion_disgusted.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten-emotion_disgusted.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten-emotion_drunk.wav b/assets/mimic3_foreign_4x/de_DE_thorsten-emotion_drunk.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten-emotion_drunk.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten-emotion_drunk.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten-emotion_neutral.wav b/assets/mimic3_foreign_4x/de_DE_thorsten-emotion_neutral.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten-emotion_neutral.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten-emotion_neutral.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten-emotion_sleepy.wav b/assets/mimic3_foreign_4x/de_DE_thorsten-emotion_sleepy.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten-emotion_sleepy.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten-emotion_sleepy.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten-emotion_surprised.wav b/assets/mimic3_foreign_4x/de_DE_thorsten-emotion_surprised.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten-emotion_surprised.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten-emotion_surprised.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten-emotion_whisper.wav b/assets/mimic3_foreign_4x/de_DE_thorsten-emotion_whisper.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten-emotion_whisper.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten-emotion_whisper.wav diff --git a/mimic3_foreign_4x/de_DE_thorsten.wav b/assets/mimic3_foreign_4x/de_DE_thorsten.wav similarity index 100% rename from mimic3_foreign_4x/de_DE_thorsten.wav rename to assets/mimic3_foreign_4x/de_DE_thorsten.wav diff --git a/mimic3_foreign_4x/el_GR_rapunzelina.wav b/assets/mimic3_foreign_4x/el_GR_rapunzelina.wav similarity index 100% rename from mimic3_foreign_4x/el_GR_rapunzelina.wav rename to assets/mimic3_foreign_4x/el_GR_rapunzelina.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_aew.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_aew.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_aew.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_aew.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_ahw.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_ahw.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_ahw.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_ahw.wav diff --git 
a/mimic3_foreign_4x/en_US_cmu_arctic_aup.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_aup.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_aup.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_aup.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_awb.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_awb.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_awb.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_awb.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_axb.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_axb.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_axb.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_axb.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_bdl.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_bdl.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_bdl.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_bdl.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_clb.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_clb.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_clb.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_clb.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_eey.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_eey.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_eey.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_eey.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_fem.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_fem.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_fem.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_fem.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_gka.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_gka.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_gka.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_gka.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_jmk.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_jmk.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_jmk.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_jmk.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_ksp.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_ksp.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_ksp.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_ksp.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_ljm.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_ljm.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_ljm.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_ljm.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_lnh.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_lnh.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_lnh.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_lnh.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_rms.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_rms.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_rms.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_rms.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_rxr.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_rxr.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_rxr.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_rxr.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_slp.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_slp.wav similarity index 100% 
rename from mimic3_foreign_4x/en_US_cmu_arctic_slp.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_slp.wav diff --git a/mimic3_foreign_4x/en_US_cmu_arctic_slt.wav b/assets/mimic3_foreign_4x/en_US_cmu_arctic_slt.wav similarity index 100% rename from mimic3_foreign_4x/en_US_cmu_arctic_slt.wav rename to assets/mimic3_foreign_4x/en_US_cmu_arctic_slt.wav diff --git a/mimic3_foreign_4x/en_US_hifi-tts_6097.wav b/assets/mimic3_foreign_4x/en_US_hifi-tts_6097.wav similarity index 100% rename from mimic3_foreign_4x/en_US_hifi-tts_6097.wav rename to assets/mimic3_foreign_4x/en_US_hifi-tts_6097.wav diff --git a/mimic3_foreign_4x/en_US_hifi-tts_9017.wav b/assets/mimic3_foreign_4x/en_US_hifi-tts_9017.wav similarity index 100% rename from mimic3_foreign_4x/en_US_hifi-tts_9017.wav rename to assets/mimic3_foreign_4x/en_US_hifi-tts_9017.wav diff --git a/mimic3_foreign_4x/en_US_hifi-tts_92.wav b/assets/mimic3_foreign_4x/en_US_hifi-tts_92.wav similarity index 100% rename from mimic3_foreign_4x/en_US_hifi-tts_92.wav rename to assets/mimic3_foreign_4x/en_US_hifi-tts_92.wav diff --git a/mimic3_foreign_4x/en_US_ljspeech.wav b/assets/mimic3_foreign_4x/en_US_ljspeech.wav similarity index 100% rename from mimic3_foreign_4x/en_US_ljspeech.wav rename to assets/mimic3_foreign_4x/en_US_ljspeech.wav diff --git a/mimic3_foreign_4x/en_US_m-ailabs_elliot_miller.wav b/assets/mimic3_foreign_4x/en_US_m-ailabs_elliot_miller.wav similarity index 100% rename from mimic3_foreign_4x/en_US_m-ailabs_elliot_miller.wav rename to assets/mimic3_foreign_4x/en_US_m-ailabs_elliot_miller.wav diff --git a/mimic3_foreign_4x/en_US_m-ailabs_judy_bieber.wav b/assets/mimic3_foreign_4x/en_US_m-ailabs_judy_bieber.wav similarity index 100% rename from mimic3_foreign_4x/en_US_m-ailabs_judy_bieber.wav rename to assets/mimic3_foreign_4x/en_US_m-ailabs_judy_bieber.wav diff --git a/mimic3_foreign_4x/en_US_m-ailabs_mary_ann.wav b/assets/mimic3_foreign_4x/en_US_m-ailabs_mary_ann.wav similarity index 100% rename from mimic3_foreign_4x/en_US_m-ailabs_mary_ann.wav rename to assets/mimic3_foreign_4x/en_US_m-ailabs_mary_ann.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p225.wav b/assets/mimic3_foreign_4x/en_US_vctk_p225.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p225.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p225.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p226.wav b/assets/mimic3_foreign_4x/en_US_vctk_p226.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p226.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p226.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p227.wav b/assets/mimic3_foreign_4x/en_US_vctk_p227.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p227.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p227.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p228.wav b/assets/mimic3_foreign_4x/en_US_vctk_p228.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p228.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p228.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p229.wav b/assets/mimic3_foreign_4x/en_US_vctk_p229.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p229.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p229.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p230.wav b/assets/mimic3_foreign_4x/en_US_vctk_p230.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p230.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p230.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p231.wav b/assets/mimic3_foreign_4x/en_US_vctk_p231.wav 
similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p231.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p231.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p232.wav b/assets/mimic3_foreign_4x/en_US_vctk_p232.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p232.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p232.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p233.wav b/assets/mimic3_foreign_4x/en_US_vctk_p233.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p233.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p233.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p234.wav b/assets/mimic3_foreign_4x/en_US_vctk_p234.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p234.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p234.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p236.wav b/assets/mimic3_foreign_4x/en_US_vctk_p236.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p236.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p236.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p237.wav b/assets/mimic3_foreign_4x/en_US_vctk_p237.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p237.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p237.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p238.wav b/assets/mimic3_foreign_4x/en_US_vctk_p238.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p238.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p238.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p239.wav b/assets/mimic3_foreign_4x/en_US_vctk_p239.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p239.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p239.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p240.wav b/assets/mimic3_foreign_4x/en_US_vctk_p240.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p240.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p240.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p241.wav b/assets/mimic3_foreign_4x/en_US_vctk_p241.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p241.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p241.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p243.wav b/assets/mimic3_foreign_4x/en_US_vctk_p243.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p243.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p243.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p244.wav b/assets/mimic3_foreign_4x/en_US_vctk_p244.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p244.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p244.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p245.wav b/assets/mimic3_foreign_4x/en_US_vctk_p245.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p245.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p245.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p246.wav b/assets/mimic3_foreign_4x/en_US_vctk_p246.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p246.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p246.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p247.wav b/assets/mimic3_foreign_4x/en_US_vctk_p247.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p247.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p247.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p248.wav b/assets/mimic3_foreign_4x/en_US_vctk_p248.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p248.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p248.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p249.wav 
b/assets/mimic3_foreign_4x/en_US_vctk_p249.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p249.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p249.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p250.wav b/assets/mimic3_foreign_4x/en_US_vctk_p250.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p250.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p250.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p251.wav b/assets/mimic3_foreign_4x/en_US_vctk_p251.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p251.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p251.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p252.wav b/assets/mimic3_foreign_4x/en_US_vctk_p252.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p252.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p252.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p253.wav b/assets/mimic3_foreign_4x/en_US_vctk_p253.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p253.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p253.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p254.wav b/assets/mimic3_foreign_4x/en_US_vctk_p254.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p254.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p254.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p255.wav b/assets/mimic3_foreign_4x/en_US_vctk_p255.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p255.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p255.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p256.wav b/assets/mimic3_foreign_4x/en_US_vctk_p256.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p256.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p256.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p257.wav b/assets/mimic3_foreign_4x/en_US_vctk_p257.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p257.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p257.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p258.wav b/assets/mimic3_foreign_4x/en_US_vctk_p258.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p258.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p258.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p259.wav b/assets/mimic3_foreign_4x/en_US_vctk_p259.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p259.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p259.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p260.wav b/assets/mimic3_foreign_4x/en_US_vctk_p260.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p260.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p260.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p261.wav b/assets/mimic3_foreign_4x/en_US_vctk_p261.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p261.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p261.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p262.wav b/assets/mimic3_foreign_4x/en_US_vctk_p262.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p262.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p262.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p263.wav b/assets/mimic3_foreign_4x/en_US_vctk_p263.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p263.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p263.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p264.wav b/assets/mimic3_foreign_4x/en_US_vctk_p264.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p264.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p264.wav diff 
--git a/mimic3_foreign_4x/en_US_vctk_p265.wav b/assets/mimic3_foreign_4x/en_US_vctk_p265.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p265.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p265.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p266.wav b/assets/mimic3_foreign_4x/en_US_vctk_p266.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p266.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p266.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p267.wav b/assets/mimic3_foreign_4x/en_US_vctk_p267.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p267.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p267.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p268.wav b/assets/mimic3_foreign_4x/en_US_vctk_p268.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p268.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p268.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p269.wav b/assets/mimic3_foreign_4x/en_US_vctk_p269.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p269.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p269.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p270.wav b/assets/mimic3_foreign_4x/en_US_vctk_p270.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p270.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p270.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p271.wav b/assets/mimic3_foreign_4x/en_US_vctk_p271.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p271.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p271.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p272.wav b/assets/mimic3_foreign_4x/en_US_vctk_p272.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p272.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p272.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p273.wav b/assets/mimic3_foreign_4x/en_US_vctk_p273.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p273.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p273.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p274.wav b/assets/mimic3_foreign_4x/en_US_vctk_p274.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p274.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p274.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p275.wav b/assets/mimic3_foreign_4x/en_US_vctk_p275.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p275.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p275.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p276.wav b/assets/mimic3_foreign_4x/en_US_vctk_p276.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p276.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p276.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p277.wav b/assets/mimic3_foreign_4x/en_US_vctk_p277.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p277.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p277.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p278.wav b/assets/mimic3_foreign_4x/en_US_vctk_p278.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p278.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p278.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p279.wav b/assets/mimic3_foreign_4x/en_US_vctk_p279.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p279.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p279.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p280.wav b/assets/mimic3_foreign_4x/en_US_vctk_p280.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p280.wav rename to 
assets/mimic3_foreign_4x/en_US_vctk_p280.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p281.wav b/assets/mimic3_foreign_4x/en_US_vctk_p281.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p281.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p281.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p282.wav b/assets/mimic3_foreign_4x/en_US_vctk_p282.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p282.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p282.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p283.wav b/assets/mimic3_foreign_4x/en_US_vctk_p283.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p283.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p283.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p284.wav b/assets/mimic3_foreign_4x/en_US_vctk_p284.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p284.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p284.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p285.wav b/assets/mimic3_foreign_4x/en_US_vctk_p285.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p285.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p285.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p286.wav b/assets/mimic3_foreign_4x/en_US_vctk_p286.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p286.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p286.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p287.wav b/assets/mimic3_foreign_4x/en_US_vctk_p287.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p287.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p287.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p288.wav b/assets/mimic3_foreign_4x/en_US_vctk_p288.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p288.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p288.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p292.wav b/assets/mimic3_foreign_4x/en_US_vctk_p292.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p292.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p292.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p293.wav b/assets/mimic3_foreign_4x/en_US_vctk_p293.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p293.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p293.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p294.wav b/assets/mimic3_foreign_4x/en_US_vctk_p294.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p294.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p294.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p295.wav b/assets/mimic3_foreign_4x/en_US_vctk_p295.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p295.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p295.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p297.wav b/assets/mimic3_foreign_4x/en_US_vctk_p297.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p297.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p297.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p298.wav b/assets/mimic3_foreign_4x/en_US_vctk_p298.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p298.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p298.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p299.wav b/assets/mimic3_foreign_4x/en_US_vctk_p299.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p299.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p299.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p300.wav b/assets/mimic3_foreign_4x/en_US_vctk_p300.wav similarity index 100% rename from 
mimic3_foreign_4x/en_US_vctk_p300.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p300.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p301.wav b/assets/mimic3_foreign_4x/en_US_vctk_p301.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p301.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p301.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p302.wav b/assets/mimic3_foreign_4x/en_US_vctk_p302.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p302.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p302.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p303.wav b/assets/mimic3_foreign_4x/en_US_vctk_p303.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p303.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p303.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p304.wav b/assets/mimic3_foreign_4x/en_US_vctk_p304.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p304.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p304.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p305.wav b/assets/mimic3_foreign_4x/en_US_vctk_p305.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p305.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p305.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p306.wav b/assets/mimic3_foreign_4x/en_US_vctk_p306.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p306.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p306.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p307.wav b/assets/mimic3_foreign_4x/en_US_vctk_p307.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p307.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p307.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p308.wav b/assets/mimic3_foreign_4x/en_US_vctk_p308.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p308.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p308.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p310.wav b/assets/mimic3_foreign_4x/en_US_vctk_p310.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p310.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p310.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p311.wav b/assets/mimic3_foreign_4x/en_US_vctk_p311.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p311.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p311.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p312.wav b/assets/mimic3_foreign_4x/en_US_vctk_p312.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p312.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p312.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p313.wav b/assets/mimic3_foreign_4x/en_US_vctk_p313.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p313.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p313.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p314.wav b/assets/mimic3_foreign_4x/en_US_vctk_p314.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p314.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p314.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p316.wav b/assets/mimic3_foreign_4x/en_US_vctk_p316.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p316.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p316.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p317.wav b/assets/mimic3_foreign_4x/en_US_vctk_p317.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p317.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p317.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p318.wav 
b/assets/mimic3_foreign_4x/en_US_vctk_p318.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p318.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p318.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p323.wav b/assets/mimic3_foreign_4x/en_US_vctk_p323.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p323.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p323.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p326.wav b/assets/mimic3_foreign_4x/en_US_vctk_p326.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p326.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p326.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p329.wav b/assets/mimic3_foreign_4x/en_US_vctk_p329.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p329.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p329.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p330.wav b/assets/mimic3_foreign_4x/en_US_vctk_p330.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p330.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p330.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p333.wav b/assets/mimic3_foreign_4x/en_US_vctk_p333.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p333.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p333.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p334.wav b/assets/mimic3_foreign_4x/en_US_vctk_p334.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p334.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p334.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p335.wav b/assets/mimic3_foreign_4x/en_US_vctk_p335.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p335.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p335.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p336.wav b/assets/mimic3_foreign_4x/en_US_vctk_p336.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p336.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p336.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p339.wav b/assets/mimic3_foreign_4x/en_US_vctk_p339.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p339.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p339.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p340.wav b/assets/mimic3_foreign_4x/en_US_vctk_p340.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p340.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p340.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p341.wav b/assets/mimic3_foreign_4x/en_US_vctk_p341.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p341.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p341.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p343.wav b/assets/mimic3_foreign_4x/en_US_vctk_p343.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p343.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p343.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p345.wav b/assets/mimic3_foreign_4x/en_US_vctk_p345.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p345.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p345.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p347.wav b/assets/mimic3_foreign_4x/en_US_vctk_p347.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p347.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p347.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p351.wav b/assets/mimic3_foreign_4x/en_US_vctk_p351.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p351.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p351.wav diff 
--git a/mimic3_foreign_4x/en_US_vctk_p360.wav b/assets/mimic3_foreign_4x/en_US_vctk_p360.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p360.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p360.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p361.wav b/assets/mimic3_foreign_4x/en_US_vctk_p361.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p361.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p361.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p362.wav b/assets/mimic3_foreign_4x/en_US_vctk_p362.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p362.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p362.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p363.wav b/assets/mimic3_foreign_4x/en_US_vctk_p363.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p363.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p363.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p364.wav b/assets/mimic3_foreign_4x/en_US_vctk_p364.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p364.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p364.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p374.wav b/assets/mimic3_foreign_4x/en_US_vctk_p374.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p374.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p374.wav diff --git a/mimic3_foreign_4x/en_US_vctk_p376.wav b/assets/mimic3_foreign_4x/en_US_vctk_p376.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_p376.wav rename to assets/mimic3_foreign_4x/en_US_vctk_p376.wav diff --git a/mimic3_foreign_4x/en_US_vctk_s5.wav b/assets/mimic3_foreign_4x/en_US_vctk_s5.wav similarity index 100% rename from mimic3_foreign_4x/en_US_vctk_s5.wav rename to assets/mimic3_foreign_4x/en_US_vctk_s5.wav diff --git a/mimic3_foreign_4x/es_ES_carlfm.wav b/assets/mimic3_foreign_4x/es_ES_carlfm.wav similarity index 100% rename from mimic3_foreign_4x/es_ES_carlfm.wav rename to assets/mimic3_foreign_4x/es_ES_carlfm.wav diff --git a/mimic3_foreign_4x/es_ES_m-ailabs_karen_savage.wav b/assets/mimic3_foreign_4x/es_ES_m-ailabs_karen_savage.wav similarity index 100% rename from mimic3_foreign_4x/es_ES_m-ailabs_karen_savage.wav rename to assets/mimic3_foreign_4x/es_ES_m-ailabs_karen_savage.wav diff --git a/mimic3_foreign_4x/es_ES_m-ailabs_tux.wav b/assets/mimic3_foreign_4x/es_ES_m-ailabs_tux.wav similarity index 100% rename from mimic3_foreign_4x/es_ES_m-ailabs_tux.wav rename to assets/mimic3_foreign_4x/es_ES_m-ailabs_tux.wav diff --git a/mimic3_foreign_4x/es_ES_m-ailabs_victor_villarraza.wav b/assets/mimic3_foreign_4x/es_ES_m-ailabs_victor_villarraza.wav similarity index 100% rename from mimic3_foreign_4x/es_ES_m-ailabs_victor_villarraza.wav rename to assets/mimic3_foreign_4x/es_ES_m-ailabs_victor_villarraza.wav diff --git a/mimic3_foreign_4x/fa_haaniye.wav b/assets/mimic3_foreign_4x/fa_haaniye.wav similarity index 100% rename from mimic3_foreign_4x/fa_haaniye.wav rename to assets/mimic3_foreign_4x/fa_haaniye.wav diff --git a/mimic3_foreign_4x/fi_FI_harri-tapani-ylilammi.wav b/assets/mimic3_foreign_4x/fi_FI_harri-tapani-ylilammi.wav similarity index 100% rename from mimic3_foreign_4x/fi_FI_harri-tapani-ylilammi.wav rename to assets/mimic3_foreign_4x/fi_FI_harri-tapani-ylilammi.wav diff --git a/mimic3_foreign_4x/fr_FR_m-ailabs_bernard.wav b/assets/mimic3_foreign_4x/fr_FR_m-ailabs_bernard.wav similarity index 100% rename from mimic3_foreign_4x/fr_FR_m-ailabs_bernard.wav rename to assets/mimic3_foreign_4x/fr_FR_m-ailabs_bernard.wav diff --git 
a/mimic3_foreign_4x/fr_FR_m-ailabs_ezwa.wav b/assets/mimic3_foreign_4x/fr_FR_m-ailabs_ezwa.wav similarity index 100% rename from mimic3_foreign_4x/fr_FR_m-ailabs_ezwa.wav rename to assets/mimic3_foreign_4x/fr_FR_m-ailabs_ezwa.wav diff --git a/mimic3_foreign_4x/fr_FR_m-ailabs_gilles_g_le_blanc.wav b/assets/mimic3_foreign_4x/fr_FR_m-ailabs_gilles_g_le_blanc.wav similarity index 100% rename from mimic3_foreign_4x/fr_FR_m-ailabs_gilles_g_le_blanc.wav rename to assets/mimic3_foreign_4x/fr_FR_m-ailabs_gilles_g_le_blanc.wav diff --git a/mimic3_foreign_4x/fr_FR_m-ailabs_nadine_eckert_boulet.wav b/assets/mimic3_foreign_4x/fr_FR_m-ailabs_nadine_eckert_boulet.wav similarity index 100% rename from mimic3_foreign_4x/fr_FR_m-ailabs_nadine_eckert_boulet.wav rename to assets/mimic3_foreign_4x/fr_FR_m-ailabs_nadine_eckert_boulet.wav diff --git a/mimic3_foreign_4x/fr_FR_m-ailabs_zeckou.wav b/assets/mimic3_foreign_4x/fr_FR_m-ailabs_zeckou.wav similarity index 100% rename from mimic3_foreign_4x/fr_FR_m-ailabs_zeckou.wav rename to assets/mimic3_foreign_4x/fr_FR_m-ailabs_zeckou.wav diff --git a/mimic3_foreign_4x/fr_FR_siwis.wav b/assets/mimic3_foreign_4x/fr_FR_siwis.wav similarity index 100% rename from mimic3_foreign_4x/fr_FR_siwis.wav rename to assets/mimic3_foreign_4x/fr_FR_siwis.wav diff --git a/mimic3_foreign_4x/fr_FR_tom.wav b/assets/mimic3_foreign_4x/fr_FR_tom.wav similarity index 100% rename from mimic3_foreign_4x/fr_FR_tom.wav rename to assets/mimic3_foreign_4x/fr_FR_tom.wav diff --git a/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_ad.wav b/assets/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_ad.wav similarity index 100% rename from mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_ad.wav rename to assets/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_ad.wav diff --git a/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_dp.wav b/assets/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_dp.wav similarity index 100% rename from mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_dp.wav rename to assets/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_dp.wav diff --git a/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_kt.wav b/assets/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_kt.wav similarity index 100% rename from mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_kt.wav rename to assets/mimic3_foreign_4x/gu_IN_cmu-indic_cmu_indic_guj_kt.wav diff --git a/mimic3_foreign_4x/ha_NE_openbible.wav b/assets/mimic3_foreign_4x/ha_NE_openbible.wav similarity index 100% rename from mimic3_foreign_4x/ha_NE_openbible.wav rename to assets/mimic3_foreign_4x/ha_NE_openbible.wav diff --git a/mimic3_foreign_4x/hu_HU_diana-majlinger.wav b/assets/mimic3_foreign_4x/hu_HU_diana-majlinger.wav similarity index 100% rename from mimic3_foreign_4x/hu_HU_diana-majlinger.wav rename to assets/mimic3_foreign_4x/hu_HU_diana-majlinger.wav diff --git a/mimic3_foreign_4x/it_IT_mls_10446.wav b/assets/mimic3_foreign_4x/it_IT_mls_10446.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_10446.wav rename to assets/mimic3_foreign_4x/it_IT_mls_10446.wav diff --git a/mimic3_foreign_4x/it_IT_mls_1157.wav b/assets/mimic3_foreign_4x/it_IT_mls_1157.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_1157.wav rename to assets/mimic3_foreign_4x/it_IT_mls_1157.wav diff --git a/mimic3_foreign_4x/it_IT_mls_12428.wav b/assets/mimic3_foreign_4x/it_IT_mls_12428.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_12428.wav rename to assets/mimic3_foreign_4x/it_IT_mls_12428.wav diff --git 
a/mimic3_foreign_4x/it_IT_mls_12804.wav b/assets/mimic3_foreign_4x/it_IT_mls_12804.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_12804.wav rename to assets/mimic3_foreign_4x/it_IT_mls_12804.wav diff --git a/mimic3_foreign_4x/it_IT_mls_1595.wav b/assets/mimic3_foreign_4x/it_IT_mls_1595.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_1595.wav rename to assets/mimic3_foreign_4x/it_IT_mls_1595.wav diff --git a/mimic3_foreign_4x/it_IT_mls_1725.wav b/assets/mimic3_foreign_4x/it_IT_mls_1725.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_1725.wav rename to assets/mimic3_foreign_4x/it_IT_mls_1725.wav diff --git a/mimic3_foreign_4x/it_IT_mls_1989.wav b/assets/mimic3_foreign_4x/it_IT_mls_1989.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_1989.wav rename to assets/mimic3_foreign_4x/it_IT_mls_1989.wav diff --git a/mimic3_foreign_4x/it_IT_mls_2019.wav b/assets/mimic3_foreign_4x/it_IT_mls_2019.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_2019.wav rename to assets/mimic3_foreign_4x/it_IT_mls_2019.wav diff --git a/mimic3_foreign_4x/it_IT_mls_2033.wav b/assets/mimic3_foreign_4x/it_IT_mls_2033.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_2033.wav rename to assets/mimic3_foreign_4x/it_IT_mls_2033.wav diff --git a/mimic3_foreign_4x/it_IT_mls_277.wav b/assets/mimic3_foreign_4x/it_IT_mls_277.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_277.wav rename to assets/mimic3_foreign_4x/it_IT_mls_277.wav diff --git a/mimic3_foreign_4x/it_IT_mls_4649.wav b/assets/mimic3_foreign_4x/it_IT_mls_4649.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_4649.wav rename to assets/mimic3_foreign_4x/it_IT_mls_4649.wav diff --git a/mimic3_foreign_4x/it_IT_mls_4705.wav b/assets/mimic3_foreign_4x/it_IT_mls_4705.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_4705.wav rename to assets/mimic3_foreign_4x/it_IT_mls_4705.wav diff --git a/mimic3_foreign_4x/it_IT_mls_4971.wav b/assets/mimic3_foreign_4x/it_IT_mls_4971.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_4971.wav rename to assets/mimic3_foreign_4x/it_IT_mls_4971.wav diff --git a/mimic3_foreign_4x/it_IT_mls_4974.wav b/assets/mimic3_foreign_4x/it_IT_mls_4974.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_4974.wav rename to assets/mimic3_foreign_4x/it_IT_mls_4974.wav diff --git a/mimic3_foreign_4x/it_IT_mls_4975.wav b/assets/mimic3_foreign_4x/it_IT_mls_4975.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_4975.wav rename to assets/mimic3_foreign_4x/it_IT_mls_4975.wav diff --git a/mimic3_foreign_4x/it_IT_mls_4998.wav b/assets/mimic3_foreign_4x/it_IT_mls_4998.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_4998.wav rename to assets/mimic3_foreign_4x/it_IT_mls_4998.wav diff --git a/mimic3_foreign_4x/it_IT_mls_5010.wav b/assets/mimic3_foreign_4x/it_IT_mls_5010.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_5010.wav rename to assets/mimic3_foreign_4x/it_IT_mls_5010.wav diff --git a/mimic3_foreign_4x/it_IT_mls_5421.wav b/assets/mimic3_foreign_4x/it_IT_mls_5421.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_5421.wav rename to assets/mimic3_foreign_4x/it_IT_mls_5421.wav diff --git a/mimic3_foreign_4x/it_IT_mls_6001.wav b/assets/mimic3_foreign_4x/it_IT_mls_6001.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_6001.wav rename to assets/mimic3_foreign_4x/it_IT_mls_6001.wav diff --git 
a/mimic3_foreign_4x/it_IT_mls_6299.wav b/assets/mimic3_foreign_4x/it_IT_mls_6299.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_6299.wav rename to assets/mimic3_foreign_4x/it_IT_mls_6299.wav diff --git a/mimic3_foreign_4x/it_IT_mls_6348.wav b/assets/mimic3_foreign_4x/it_IT_mls_6348.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_6348.wav rename to assets/mimic3_foreign_4x/it_IT_mls_6348.wav diff --git a/mimic3_foreign_4x/it_IT_mls_643.wav b/assets/mimic3_foreign_4x/it_IT_mls_643.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_643.wav rename to assets/mimic3_foreign_4x/it_IT_mls_643.wav diff --git a/mimic3_foreign_4x/it_IT_mls_644.wav b/assets/mimic3_foreign_4x/it_IT_mls_644.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_644.wav rename to assets/mimic3_foreign_4x/it_IT_mls_644.wav diff --git a/mimic3_foreign_4x/it_IT_mls_659.wav b/assets/mimic3_foreign_4x/it_IT_mls_659.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_659.wav rename to assets/mimic3_foreign_4x/it_IT_mls_659.wav diff --git a/mimic3_foreign_4x/it_IT_mls_6744.wav b/assets/mimic3_foreign_4x/it_IT_mls_6744.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_6744.wav rename to assets/mimic3_foreign_4x/it_IT_mls_6744.wav diff --git a/mimic3_foreign_4x/it_IT_mls_6807.wav b/assets/mimic3_foreign_4x/it_IT_mls_6807.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_6807.wav rename to assets/mimic3_foreign_4x/it_IT_mls_6807.wav diff --git a/mimic3_foreign_4x/it_IT_mls_7405.wav b/assets/mimic3_foreign_4x/it_IT_mls_7405.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_7405.wav rename to assets/mimic3_foreign_4x/it_IT_mls_7405.wav diff --git a/mimic3_foreign_4x/it_IT_mls_7440.wav b/assets/mimic3_foreign_4x/it_IT_mls_7440.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_7440.wav rename to assets/mimic3_foreign_4x/it_IT_mls_7440.wav diff --git a/mimic3_foreign_4x/it_IT_mls_7444.wav b/assets/mimic3_foreign_4x/it_IT_mls_7444.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_7444.wav rename to assets/mimic3_foreign_4x/it_IT_mls_7444.wav diff --git a/mimic3_foreign_4x/it_IT_mls_7936.wav b/assets/mimic3_foreign_4x/it_IT_mls_7936.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_7936.wav rename to assets/mimic3_foreign_4x/it_IT_mls_7936.wav diff --git a/mimic3_foreign_4x/it_IT_mls_8181.wav b/assets/mimic3_foreign_4x/it_IT_mls_8181.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_8181.wav rename to assets/mimic3_foreign_4x/it_IT_mls_8181.wav diff --git a/mimic3_foreign_4x/it_IT_mls_8207.wav b/assets/mimic3_foreign_4x/it_IT_mls_8207.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_8207.wav rename to assets/mimic3_foreign_4x/it_IT_mls_8207.wav diff --git a/mimic3_foreign_4x/it_IT_mls_8384.wav b/assets/mimic3_foreign_4x/it_IT_mls_8384.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_8384.wav rename to assets/mimic3_foreign_4x/it_IT_mls_8384.wav diff --git a/mimic3_foreign_4x/it_IT_mls_844.wav b/assets/mimic3_foreign_4x/it_IT_mls_844.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_844.wav rename to assets/mimic3_foreign_4x/it_IT_mls_844.wav diff --git a/mimic3_foreign_4x/it_IT_mls_8461.wav b/assets/mimic3_foreign_4x/it_IT_mls_8461.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_8461.wav rename to assets/mimic3_foreign_4x/it_IT_mls_8461.wav diff --git 
a/mimic3_foreign_4x/it_IT_mls_8828.wav b/assets/mimic3_foreign_4x/it_IT_mls_8828.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_8828.wav rename to assets/mimic3_foreign_4x/it_IT_mls_8828.wav diff --git a/mimic3_foreign_4x/it_IT_mls_8842.wav b/assets/mimic3_foreign_4x/it_IT_mls_8842.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_8842.wav rename to assets/mimic3_foreign_4x/it_IT_mls_8842.wav diff --git a/mimic3_foreign_4x/it_IT_mls_9185.wav b/assets/mimic3_foreign_4x/it_IT_mls_9185.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_9185.wav rename to assets/mimic3_foreign_4x/it_IT_mls_9185.wav diff --git a/mimic3_foreign_4x/it_IT_mls_9772.wav b/assets/mimic3_foreign_4x/it_IT_mls_9772.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_mls_9772.wav rename to assets/mimic3_foreign_4x/it_IT_mls_9772.wav diff --git a/mimic3_foreign_4x/it_IT_riccardo-fasol.wav b/assets/mimic3_foreign_4x/it_IT_riccardo-fasol.wav similarity index 100% rename from mimic3_foreign_4x/it_IT_riccardo-fasol.wav rename to assets/mimic3_foreign_4x/it_IT_riccardo-fasol.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_00027.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_00027.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_00027.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_00027.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_00264.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_00264.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_00264.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_00264.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_00658.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_00658.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_00658.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_00658.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_01392.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_01392.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_01392.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_01392.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_01519.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_01519.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_01519.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_01519.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_01932.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_01932.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_01932.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_01932.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_02059.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_02059.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_02059.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_02059.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_02326.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_02326.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_02326.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_02326.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_02884.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_02884.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_02884.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_02884.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_03187.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_03187.wav similarity index 100% rename from 
mimic3_foreign_4x/jv_ID_google-gmu_03187.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_03187.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_03314.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_03314.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_03314.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_03314.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_03424.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_03424.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_03424.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_03424.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_03727.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_03727.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_03727.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_03727.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_04175.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_04175.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_04175.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_04175.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_04285.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_04285.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_04285.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_04285.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_04588.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_04588.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_04588.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_04588.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_04679.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_04679.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_04679.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_04679.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_04715.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_04715.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_04715.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_04715.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_04982.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_04982.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_04982.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_04982.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_05219.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_05219.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_05219.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_05219.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_05522.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_05522.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_05522.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_05522.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_05540.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_05540.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_05540.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_05540.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_05667.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_05667.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_05667.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_05667.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_05970.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_05970.wav similarity index 100% rename from 
mimic3_foreign_4x/jv_ID_google-gmu_05970.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_05970.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_06080.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_06080.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_06080.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_06080.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_06207.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_06207.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_06207.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_06207.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_06383.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_06383.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_06383.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_06383.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_06510.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_06510.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_06510.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_06510.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_06941.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_06941.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_06941.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_06941.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_07335.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_07335.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_07335.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_07335.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_07638.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_07638.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_07638.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_07638.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_07765.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_07765.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_07765.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_07765.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_07875.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_07875.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_07875.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_07875.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_08002.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_08002.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_08002.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_08002.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_08178.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_08178.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_08178.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_08178.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_08305.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_08305.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_08305.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_08305.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_08736.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_08736.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_08736.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_08736.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_09039.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_09039.wav similarity index 100% rename from 
mimic3_foreign_4x/jv_ID_google-gmu_09039.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_09039.wav diff --git a/mimic3_foreign_4x/jv_ID_google-gmu_09724.wav b/assets/mimic3_foreign_4x/jv_ID_google-gmu_09724.wav similarity index 100% rename from mimic3_foreign_4x/jv_ID_google-gmu_09724.wav rename to assets/mimic3_foreign_4x/jv_ID_google-gmu_09724.wav diff --git a/mimic3_foreign_4x/ko_KO_kss.wav b/assets/mimic3_foreign_4x/ko_KO_kss.wav similarity index 100% rename from mimic3_foreign_4x/ko_KO_kss.wav rename to assets/mimic3_foreign_4x/ko_KO_kss.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_0258.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_0258.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_0258.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_0258.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_0283.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_0283.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_0283.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_0283.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_0546.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_0546.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_0546.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_0546.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_0649.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_0649.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_0649.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_0649.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_0883.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_0883.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_0883.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_0883.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_2027.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_2027.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_2027.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_2027.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_2099.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_2099.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_2099.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_2099.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_2139.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_2139.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_2139.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_2139.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_3154.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_3154.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_3154.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_3154.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_3614.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_3614.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_3614.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_3614.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_3960.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_3960.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_3960.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_3960.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_3997.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_3997.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_3997.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_3997.wav diff --git 
a/mimic3_foreign_4x/ne_NP_ne-google_5687.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_5687.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_5687.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_5687.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_6329.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_6329.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_6329.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_6329.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_6587.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_6587.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_6587.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_6587.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_6834.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_6834.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_6834.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_6834.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_7957.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_7957.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_7957.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_7957.wav diff --git a/mimic3_foreign_4x/ne_NP_ne-google_9407.wav b/assets/mimic3_foreign_4x/ne_NP_ne-google_9407.wav similarity index 100% rename from mimic3_foreign_4x/ne_NP_ne-google_9407.wav rename to assets/mimic3_foreign_4x/ne_NP_ne-google_9407.wav diff --git a/mimic3_foreign_4x/nl_bart-de-leeuw.wav b/assets/mimic3_foreign_4x/nl_bart-de-leeuw.wav similarity index 100% rename from mimic3_foreign_4x/nl_bart-de-leeuw.wav rename to assets/mimic3_foreign_4x/nl_bart-de-leeuw.wav diff --git a/mimic3_foreign_4x/nl_flemishguy.wav b/assets/mimic3_foreign_4x/nl_flemishguy.wav similarity index 100% rename from mimic3_foreign_4x/nl_flemishguy.wav rename to assets/mimic3_foreign_4x/nl_flemishguy.wav diff --git a/mimic3_foreign_4x/nl_nathalie.wav b/assets/mimic3_foreign_4x/nl_nathalie.wav similarity index 100% rename from mimic3_foreign_4x/nl_nathalie.wav rename to assets/mimic3_foreign_4x/nl_nathalie.wav diff --git a/mimic3_foreign_4x/nl_pmk.wav b/assets/mimic3_foreign_4x/nl_pmk.wav similarity index 100% rename from mimic3_foreign_4x/nl_pmk.wav rename to assets/mimic3_foreign_4x/nl_pmk.wav diff --git a/mimic3_foreign_4x/nl_rdh.wav b/assets/mimic3_foreign_4x/nl_rdh.wav similarity index 100% rename from mimic3_foreign_4x/nl_rdh.wav rename to assets/mimic3_foreign_4x/nl_rdh.wav diff --git a/mimic3_foreign_4x/pl_PL_m-ailabs_nina_brown.wav b/assets/mimic3_foreign_4x/pl_PL_m-ailabs_nina_brown.wav similarity index 100% rename from mimic3_foreign_4x/pl_PL_m-ailabs_nina_brown.wav rename to assets/mimic3_foreign_4x/pl_PL_m-ailabs_nina_brown.wav diff --git a/mimic3_foreign_4x/pl_PL_m-ailabs_piotr_nater.wav b/assets/mimic3_foreign_4x/pl_PL_m-ailabs_piotr_nater.wav similarity index 100% rename from mimic3_foreign_4x/pl_PL_m-ailabs_piotr_nater.wav rename to assets/mimic3_foreign_4x/pl_PL_m-ailabs_piotr_nater.wav diff --git a/mimic3_foreign_4x/ru_RU_multi_hajdurova.wav b/assets/mimic3_foreign_4x/ru_RU_multi_hajdurova.wav similarity index 100% rename from mimic3_foreign_4x/ru_RU_multi_hajdurova.wav rename to assets/mimic3_foreign_4x/ru_RU_multi_hajdurova.wav diff --git a/mimic3_foreign_4x/ru_RU_multi_minaev.wav b/assets/mimic3_foreign_4x/ru_RU_multi_minaev.wav similarity index 100% rename from mimic3_foreign_4x/ru_RU_multi_minaev.wav rename to assets/mimic3_foreign_4x/ru_RU_multi_minaev.wav diff --git 
a/mimic3_foreign_4x/ru_RU_multi_nikolaev.wav b/assets/mimic3_foreign_4x/ru_RU_multi_nikolaev.wav similarity index 100% rename from mimic3_foreign_4x/ru_RU_multi_nikolaev.wav rename to assets/mimic3_foreign_4x/ru_RU_multi_nikolaev.wav diff --git a/mimic3_foreign_4x/sw_lanfrica.wav b/assets/mimic3_foreign_4x/sw_lanfrica.wav similarity index 100% rename from mimic3_foreign_4x/sw_lanfrica.wav rename to assets/mimic3_foreign_4x/sw_lanfrica.wav diff --git a/mimic3_foreign_4x/te_IN_cmu-indic_kpn.wav b/assets/mimic3_foreign_4x/te_IN_cmu-indic_kpn.wav similarity index 100% rename from mimic3_foreign_4x/te_IN_cmu-indic_kpn.wav rename to assets/mimic3_foreign_4x/te_IN_cmu-indic_kpn.wav diff --git a/mimic3_foreign_4x/te_IN_cmu-indic_sk.wav b/assets/mimic3_foreign_4x/te_IN_cmu-indic_sk.wav similarity index 100% rename from mimic3_foreign_4x/te_IN_cmu-indic_sk.wav rename to assets/mimic3_foreign_4x/te_IN_cmu-indic_sk.wav diff --git a/mimic3_foreign_4x/te_IN_cmu-indic_ss.wav b/assets/mimic3_foreign_4x/te_IN_cmu-indic_ss.wav similarity index 100% rename from mimic3_foreign_4x/te_IN_cmu-indic_ss.wav rename to assets/mimic3_foreign_4x/te_IN_cmu-indic_ss.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_0045.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_0045.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_0045.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_0045.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_0378.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_0378.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_0378.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_0378.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_0441.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_0441.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_0441.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_0441.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_1483.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_1483.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_1483.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_1483.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_1498.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_1498.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_1498.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_1498.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_1932.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_1932.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_1932.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_1932.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_2839.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_2839.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_2839.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_2839.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_3342.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_3342.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_3342.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_3342.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_3629.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_3629.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_3629.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_3629.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_4506.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_4506.wav similarity index 100% rename from 
mimic3_foreign_4x/tn_ZA_google-nwu_4506.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_4506.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_4850.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_4850.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_4850.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_4850.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_5628.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_5628.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_5628.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_5628.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_6116.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_6116.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_6116.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_6116.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_6206.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_6206.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_6206.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_6206.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_6234.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_6234.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_6234.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_6234.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_6459.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_6459.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_6459.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_6459.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_7674.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_7674.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_7674.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_7674.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_7693.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_7693.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_7693.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_7693.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_7866.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_7866.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_7866.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_7866.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_7896.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_7896.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_7896.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_7896.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_8333.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_8333.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_8333.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_8333.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_8512.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_8512.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_8512.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_8512.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_8532.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_8532.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_8532.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_8532.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_8914.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_8914.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_8914.wav rename to 
assets/mimic3_foreign_4x/tn_ZA_google-nwu_8914.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_9061.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_9061.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_9061.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_9061.wav diff --git a/mimic3_foreign_4x/tn_ZA_google-nwu_9365.wav b/assets/mimic3_foreign_4x/tn_ZA_google-nwu_9365.wav similarity index 100% rename from mimic3_foreign_4x/tn_ZA_google-nwu_9365.wav rename to assets/mimic3_foreign_4x/tn_ZA_google-nwu_9365.wav diff --git a/mimic3_foreign_4x/uk_UK_m-ailabs_loboda.wav b/assets/mimic3_foreign_4x/uk_UK_m-ailabs_loboda.wav similarity index 100% rename from mimic3_foreign_4x/uk_UK_m-ailabs_loboda.wav rename to assets/mimic3_foreign_4x/uk_UK_m-ailabs_loboda.wav diff --git a/mimic3_foreign_4x/uk_UK_m-ailabs_miskun.wav b/assets/mimic3_foreign_4x/uk_UK_m-ailabs_miskun.wav similarity index 100% rename from mimic3_foreign_4x/uk_UK_m-ailabs_miskun.wav rename to assets/mimic3_foreign_4x/uk_UK_m-ailabs_miskun.wav diff --git a/mimic3_foreign_4x/uk_UK_m-ailabs_obruchov.wav b/assets/mimic3_foreign_4x/uk_UK_m-ailabs_obruchov.wav similarity index 100% rename from mimic3_foreign_4x/uk_UK_m-ailabs_obruchov.wav rename to assets/mimic3_foreign_4x/uk_UK_m-ailabs_obruchov.wav diff --git a/mimic3_foreign_4x/uk_UK_m-ailabs_pysariev.wav b/assets/mimic3_foreign_4x/uk_UK_m-ailabs_pysariev.wav similarity index 100% rename from mimic3_foreign_4x/uk_UK_m-ailabs_pysariev.wav rename to assets/mimic3_foreign_4x/uk_UK_m-ailabs_pysariev.wav diff --git a/mimic3_foreign_4x/uk_UK_m-ailabs_shepel.wav b/assets/mimic3_foreign_4x/uk_UK_m-ailabs_shepel.wav similarity index 100% rename from mimic3_foreign_4x/uk_UK_m-ailabs_shepel.wav rename to assets/mimic3_foreign_4x/uk_UK_m-ailabs_shepel.wav diff --git a/mimic3_foreign_4x/uk_UK_m-ailabs_sumska.wav b/assets/mimic3_foreign_4x/uk_UK_m-ailabs_sumska.wav similarity index 100% rename from mimic3_foreign_4x/uk_UK_m-ailabs_sumska.wav rename to assets/mimic3_foreign_4x/uk_UK_m-ailabs_sumska.wav diff --git a/mimic3_foreign_4x/vi_VN_vais1000.wav b/assets/mimic3_foreign_4x/vi_VN_vais1000.wav similarity index 100% rename from mimic3_foreign_4x/vi_VN_vais1000.wav rename to assets/mimic3_foreign_4x/vi_VN_vais1000.wav diff --git a/mimic3_foreign_4x/yo_openbible.wav b/assets/mimic3_foreign_4x/yo_openbible.wav similarity index 100% rename from mimic3_foreign_4x/yo_openbible.wav rename to assets/mimic3_foreign_4x/yo_openbible.wav diff --git a/style_vector/en_UK_apope.wav b/assets/style_vector/en_UK_apope.wav similarity index 100% rename from style_vector/en_UK_apope.wav rename to assets/style_vector/en_UK_apope.wav diff --git a/style_vector/en_US_cmu_arctic_aew.wav b/assets/style_vector/en_US_cmu_arctic_aew.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_aew.wav rename to assets/style_vector/en_US_cmu_arctic_aew.wav diff --git a/style_vector/en_US_cmu_arctic_ahw.wav b/assets/style_vector/en_US_cmu_arctic_ahw.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_ahw.wav rename to assets/style_vector/en_US_cmu_arctic_ahw.wav diff --git a/style_vector/en_US_cmu_arctic_aup.wav b/assets/style_vector/en_US_cmu_arctic_aup.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_aup.wav rename to assets/style_vector/en_US_cmu_arctic_aup.wav diff --git a/style_vector/en_US_cmu_arctic_awbrms.wav b/assets/style_vector/en_US_cmu_arctic_awbrms.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_awbrms.wav rename to 
assets/style_vector/en_US_cmu_arctic_awbrms.wav diff --git a/style_vector/en_US_cmu_arctic_axb.wav b/assets/style_vector/en_US_cmu_arctic_axb.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_axb.wav rename to assets/style_vector/en_US_cmu_arctic_axb.wav diff --git a/style_vector/en_US_cmu_arctic_bdl.wav b/assets/style_vector/en_US_cmu_arctic_bdl.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_bdl.wav rename to assets/style_vector/en_US_cmu_arctic_bdl.wav diff --git a/style_vector/en_US_cmu_arctic_clb.wav b/assets/style_vector/en_US_cmu_arctic_clb.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_clb.wav rename to assets/style_vector/en_US_cmu_arctic_clb.wav diff --git a/style_vector/en_US_cmu_arctic_eey.wav b/assets/style_vector/en_US_cmu_arctic_eey.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_eey.wav rename to assets/style_vector/en_US_cmu_arctic_eey.wav diff --git a/style_vector/en_US_cmu_arctic_fem.wav b/assets/style_vector/en_US_cmu_arctic_fem.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_fem.wav rename to assets/style_vector/en_US_cmu_arctic_fem.wav diff --git a/style_vector/en_US_cmu_arctic_gka.wav b/assets/style_vector/en_US_cmu_arctic_gka.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_gka.wav rename to assets/style_vector/en_US_cmu_arctic_gka.wav diff --git a/style_vector/en_US_cmu_arctic_jmk.wav b/assets/style_vector/en_US_cmu_arctic_jmk.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_jmk.wav rename to assets/style_vector/en_US_cmu_arctic_jmk.wav diff --git a/style_vector/en_US_cmu_arctic_ksp.wav b/assets/style_vector/en_US_cmu_arctic_ksp.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_ksp.wav rename to assets/style_vector/en_US_cmu_arctic_ksp.wav diff --git a/style_vector/en_US_cmu_arctic_ljm.wav b/assets/style_vector/en_US_cmu_arctic_ljm.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_ljm.wav rename to assets/style_vector/en_US_cmu_arctic_ljm.wav diff --git a/style_vector/en_US_cmu_arctic_lnh.wav b/assets/style_vector/en_US_cmu_arctic_lnh.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_lnh.wav rename to assets/style_vector/en_US_cmu_arctic_lnh.wav diff --git a/style_vector/en_US_cmu_arctic_rxr.wav b/assets/style_vector/en_US_cmu_arctic_rxr.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_rxr.wav rename to assets/style_vector/en_US_cmu_arctic_rxr.wav diff --git a/style_vector/en_US_cmu_arctic_slp.wav b/assets/style_vector/en_US_cmu_arctic_slp.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_slp.wav rename to assets/style_vector/en_US_cmu_arctic_slp.wav diff --git a/style_vector/en_US_cmu_arctic_slt.wav b/assets/style_vector/en_US_cmu_arctic_slt.wav similarity index 100% rename from style_vector/en_US_cmu_arctic_slt.wav rename to assets/style_vector/en_US_cmu_arctic_slt.wav diff --git a/style_vector/en_US_hifi-tts_6097.wav b/assets/style_vector/en_US_hifi-tts_6097.wav similarity index 100% rename from style_vector/en_US_hifi-tts_6097.wav rename to assets/style_vector/en_US_hifi-tts_6097.wav diff --git a/style_vector/en_US_hifi-tts_9017.wav b/assets/style_vector/en_US_hifi-tts_9017.wav similarity index 100% rename from style_vector/en_US_hifi-tts_9017.wav rename to assets/style_vector/en_US_hifi-tts_9017.wav diff --git a/style_vector/en_US_hifi-tts_92.wav b/assets/style_vector/en_US_hifi-tts_92.wav similarity index 100% rename from 
style_vector/en_US_hifi-tts_92.wav rename to assets/style_vector/en_US_hifi-tts_92.wav diff --git a/style_vector/en_US_ljspeech.wav b/assets/style_vector/en_US_ljspeech.wav similarity index 100% rename from style_vector/en_US_ljspeech.wav rename to assets/style_vector/en_US_ljspeech.wav diff --git a/style_vector/en_US_m-ailabs_elliot_miller.wav b/assets/style_vector/en_US_m-ailabs_elliot_miller.wav similarity index 100% rename from style_vector/en_US_m-ailabs_elliot_miller.wav rename to assets/style_vector/en_US_m-ailabs_elliot_miller.wav diff --git a/style_vector/en_US_m-ailabs_judy_bieber.wav b/assets/style_vector/en_US_m-ailabs_judy_bieber.wav similarity index 100% rename from style_vector/en_US_m-ailabs_judy_bieber.wav rename to assets/style_vector/en_US_m-ailabs_judy_bieber.wav diff --git a/style_vector/en_US_m-ailabs_mary_ann.wav b/assets/style_vector/en_US_m-ailabs_mary_ann.wav similarity index 100% rename from style_vector/en_US_m-ailabs_mary_ann.wav rename to assets/style_vector/en_US_m-ailabs_mary_ann.wav diff --git a/style_vector/en_US_vctk_p225.wav b/assets/style_vector/en_US_vctk_p225.wav similarity index 100% rename from style_vector/en_US_vctk_p225.wav rename to assets/style_vector/en_US_vctk_p225.wav diff --git a/style_vector/en_US_vctk_p226.wav b/assets/style_vector/en_US_vctk_p226.wav similarity index 100% rename from style_vector/en_US_vctk_p226.wav rename to assets/style_vector/en_US_vctk_p226.wav diff --git a/style_vector/en_US_vctk_p227.wav b/assets/style_vector/en_US_vctk_p227.wav similarity index 100% rename from style_vector/en_US_vctk_p227.wav rename to assets/style_vector/en_US_vctk_p227.wav diff --git a/style_vector/en_US_vctk_p228.wav b/assets/style_vector/en_US_vctk_p228.wav similarity index 100% rename from style_vector/en_US_vctk_p228.wav rename to assets/style_vector/en_US_vctk_p228.wav diff --git a/style_vector/en_US_vctk_p229.wav b/assets/style_vector/en_US_vctk_p229.wav similarity index 100% rename from style_vector/en_US_vctk_p229.wav rename to assets/style_vector/en_US_vctk_p229.wav diff --git a/style_vector/en_US_vctk_p230.wav b/assets/style_vector/en_US_vctk_p230.wav similarity index 100% rename from style_vector/en_US_vctk_p230.wav rename to assets/style_vector/en_US_vctk_p230.wav diff --git a/style_vector/en_US_vctk_p231.wav b/assets/style_vector/en_US_vctk_p231.wav similarity index 100% rename from style_vector/en_US_vctk_p231.wav rename to assets/style_vector/en_US_vctk_p231.wav diff --git a/style_vector/en_US_vctk_p232.wav b/assets/style_vector/en_US_vctk_p232.wav similarity index 100% rename from style_vector/en_US_vctk_p232.wav rename to assets/style_vector/en_US_vctk_p232.wav diff --git a/style_vector/en_US_vctk_p233.wav b/assets/style_vector/en_US_vctk_p233.wav similarity index 100% rename from style_vector/en_US_vctk_p233.wav rename to assets/style_vector/en_US_vctk_p233.wav diff --git a/style_vector/en_US_vctk_p234.wav b/assets/style_vector/en_US_vctk_p234.wav similarity index 100% rename from style_vector/en_US_vctk_p234.wav rename to assets/style_vector/en_US_vctk_p234.wav diff --git a/style_vector/en_US_vctk_p236.wav b/assets/style_vector/en_US_vctk_p236.wav similarity index 100% rename from style_vector/en_US_vctk_p236.wav rename to assets/style_vector/en_US_vctk_p236.wav diff --git a/style_vector/en_US_vctk_p237.wav b/assets/style_vector/en_US_vctk_p237.wav similarity index 100% rename from style_vector/en_US_vctk_p237.wav rename to assets/style_vector/en_US_vctk_p237.wav diff --git a/style_vector/en_US_vctk_p238.wav 
b/assets/style_vector/en_US_vctk_p238.wav similarity index 100% rename from style_vector/en_US_vctk_p238.wav rename to assets/style_vector/en_US_vctk_p238.wav diff --git a/style_vector/en_US_vctk_p239.wav b/assets/style_vector/en_US_vctk_p239.wav similarity index 100% rename from style_vector/en_US_vctk_p239.wav rename to assets/style_vector/en_US_vctk_p239.wav diff --git a/style_vector/en_US_vctk_p240.wav b/assets/style_vector/en_US_vctk_p240.wav similarity index 100% rename from style_vector/en_US_vctk_p240.wav rename to assets/style_vector/en_US_vctk_p240.wav diff --git a/style_vector/en_US_vctk_p241.wav b/assets/style_vector/en_US_vctk_p241.wav similarity index 100% rename from style_vector/en_US_vctk_p241.wav rename to assets/style_vector/en_US_vctk_p241.wav diff --git a/style_vector/en_US_vctk_p243.wav b/assets/style_vector/en_US_vctk_p243.wav similarity index 100% rename from style_vector/en_US_vctk_p243.wav rename to assets/style_vector/en_US_vctk_p243.wav diff --git a/style_vector/en_US_vctk_p244.wav b/assets/style_vector/en_US_vctk_p244.wav similarity index 100% rename from style_vector/en_US_vctk_p244.wav rename to assets/style_vector/en_US_vctk_p244.wav diff --git a/style_vector/en_US_vctk_p245.wav b/assets/style_vector/en_US_vctk_p245.wav similarity index 100% rename from style_vector/en_US_vctk_p245.wav rename to assets/style_vector/en_US_vctk_p245.wav diff --git a/style_vector/en_US_vctk_p246.wav b/assets/style_vector/en_US_vctk_p246.wav similarity index 100% rename from style_vector/en_US_vctk_p246.wav rename to assets/style_vector/en_US_vctk_p246.wav diff --git a/style_vector/en_US_vctk_p247.wav b/assets/style_vector/en_US_vctk_p247.wav similarity index 100% rename from style_vector/en_US_vctk_p247.wav rename to assets/style_vector/en_US_vctk_p247.wav diff --git a/style_vector/en_US_vctk_p248.wav b/assets/style_vector/en_US_vctk_p248.wav similarity index 100% rename from style_vector/en_US_vctk_p248.wav rename to assets/style_vector/en_US_vctk_p248.wav diff --git a/style_vector/en_US_vctk_p249.wav b/assets/style_vector/en_US_vctk_p249.wav similarity index 100% rename from style_vector/en_US_vctk_p249.wav rename to assets/style_vector/en_US_vctk_p249.wav diff --git a/style_vector/en_US_vctk_p250.wav b/assets/style_vector/en_US_vctk_p250.wav similarity index 100% rename from style_vector/en_US_vctk_p250.wav rename to assets/style_vector/en_US_vctk_p250.wav diff --git a/style_vector/en_US_vctk_p251.wav b/assets/style_vector/en_US_vctk_p251.wav similarity index 100% rename from style_vector/en_US_vctk_p251.wav rename to assets/style_vector/en_US_vctk_p251.wav diff --git a/style_vector/en_US_vctk_p252.wav b/assets/style_vector/en_US_vctk_p252.wav similarity index 100% rename from style_vector/en_US_vctk_p252.wav rename to assets/style_vector/en_US_vctk_p252.wav diff --git a/style_vector/en_US_vctk_p253.wav b/assets/style_vector/en_US_vctk_p253.wav similarity index 100% rename from style_vector/en_US_vctk_p253.wav rename to assets/style_vector/en_US_vctk_p253.wav diff --git a/style_vector/en_US_vctk_p254.wav b/assets/style_vector/en_US_vctk_p254.wav similarity index 100% rename from style_vector/en_US_vctk_p254.wav rename to assets/style_vector/en_US_vctk_p254.wav diff --git a/style_vector/en_US_vctk_p255.wav b/assets/style_vector/en_US_vctk_p255.wav similarity index 100% rename from style_vector/en_US_vctk_p255.wav rename to assets/style_vector/en_US_vctk_p255.wav diff --git a/style_vector/en_US_vctk_p256.wav b/assets/style_vector/en_US_vctk_p256.wav similarity index 100% 
rename from style_vector/en_US_vctk_p256.wav rename to assets/style_vector/en_US_vctk_p256.wav diff --git a/style_vector/en_US_vctk_p257.wav b/assets/style_vector/en_US_vctk_p257.wav similarity index 100% rename from style_vector/en_US_vctk_p257.wav rename to assets/style_vector/en_US_vctk_p257.wav diff --git a/style_vector/en_US_vctk_p258.wav b/assets/style_vector/en_US_vctk_p258.wav similarity index 100% rename from style_vector/en_US_vctk_p258.wav rename to assets/style_vector/en_US_vctk_p258.wav diff --git a/style_vector/en_US_vctk_p259.wav b/assets/style_vector/en_US_vctk_p259.wav similarity index 100% rename from style_vector/en_US_vctk_p259.wav rename to assets/style_vector/en_US_vctk_p259.wav diff --git a/style_vector/en_US_vctk_p260.wav b/assets/style_vector/en_US_vctk_p260.wav similarity index 100% rename from style_vector/en_US_vctk_p260.wav rename to assets/style_vector/en_US_vctk_p260.wav diff --git a/style_vector/en_US_vctk_p261.wav b/assets/style_vector/en_US_vctk_p261.wav similarity index 100% rename from style_vector/en_US_vctk_p261.wav rename to assets/style_vector/en_US_vctk_p261.wav diff --git a/style_vector/en_US_vctk_p262.wav b/assets/style_vector/en_US_vctk_p262.wav similarity index 100% rename from style_vector/en_US_vctk_p262.wav rename to assets/style_vector/en_US_vctk_p262.wav diff --git a/style_vector/en_US_vctk_p263.wav b/assets/style_vector/en_US_vctk_p263.wav similarity index 100% rename from style_vector/en_US_vctk_p263.wav rename to assets/style_vector/en_US_vctk_p263.wav diff --git a/style_vector/en_US_vctk_p264.wav b/assets/style_vector/en_US_vctk_p264.wav similarity index 100% rename from style_vector/en_US_vctk_p264.wav rename to assets/style_vector/en_US_vctk_p264.wav diff --git a/style_vector/en_US_vctk_p265.wav b/assets/style_vector/en_US_vctk_p265.wav similarity index 100% rename from style_vector/en_US_vctk_p265.wav rename to assets/style_vector/en_US_vctk_p265.wav diff --git a/style_vector/en_US_vctk_p266.wav b/assets/style_vector/en_US_vctk_p266.wav similarity index 100% rename from style_vector/en_US_vctk_p266.wav rename to assets/style_vector/en_US_vctk_p266.wav diff --git a/style_vector/en_US_vctk_p267.wav b/assets/style_vector/en_US_vctk_p267.wav similarity index 100% rename from style_vector/en_US_vctk_p267.wav rename to assets/style_vector/en_US_vctk_p267.wav diff --git a/style_vector/en_US_vctk_p268.wav b/assets/style_vector/en_US_vctk_p268.wav similarity index 100% rename from style_vector/en_US_vctk_p268.wav rename to assets/style_vector/en_US_vctk_p268.wav diff --git a/style_vector/en_US_vctk_p269.wav b/assets/style_vector/en_US_vctk_p269.wav similarity index 100% rename from style_vector/en_US_vctk_p269.wav rename to assets/style_vector/en_US_vctk_p269.wav diff --git a/style_vector/en_US_vctk_p270.wav b/assets/style_vector/en_US_vctk_p270.wav similarity index 100% rename from style_vector/en_US_vctk_p270.wav rename to assets/style_vector/en_US_vctk_p270.wav diff --git a/style_vector/en_US_vctk_p271.wav b/assets/style_vector/en_US_vctk_p271.wav similarity index 100% rename from style_vector/en_US_vctk_p271.wav rename to assets/style_vector/en_US_vctk_p271.wav diff --git a/style_vector/en_US_vctk_p272.wav b/assets/style_vector/en_US_vctk_p272.wav similarity index 100% rename from style_vector/en_US_vctk_p272.wav rename to assets/style_vector/en_US_vctk_p272.wav diff --git a/style_vector/en_US_vctk_p273.wav b/assets/style_vector/en_US_vctk_p273.wav similarity index 100% rename from style_vector/en_US_vctk_p273.wav rename to 
assets/style_vector/en_US_vctk_p273.wav diff --git a/style_vector/en_US_vctk_p274.wav b/assets/style_vector/en_US_vctk_p274.wav similarity index 100% rename from style_vector/en_US_vctk_p274.wav rename to assets/style_vector/en_US_vctk_p274.wav diff --git a/style_vector/en_US_vctk_p275.wav b/assets/style_vector/en_US_vctk_p275.wav similarity index 100% rename from style_vector/en_US_vctk_p275.wav rename to assets/style_vector/en_US_vctk_p275.wav diff --git a/style_vector/en_US_vctk_p276.wav b/assets/style_vector/en_US_vctk_p276.wav similarity index 100% rename from style_vector/en_US_vctk_p276.wav rename to assets/style_vector/en_US_vctk_p276.wav diff --git a/style_vector/en_US_vctk_p277.wav b/assets/style_vector/en_US_vctk_p277.wav similarity index 100% rename from style_vector/en_US_vctk_p277.wav rename to assets/style_vector/en_US_vctk_p277.wav diff --git a/style_vector/en_US_vctk_p278.wav b/assets/style_vector/en_US_vctk_p278.wav similarity index 100% rename from style_vector/en_US_vctk_p278.wav rename to assets/style_vector/en_US_vctk_p278.wav diff --git a/style_vector/en_US_vctk_p279.wav b/assets/style_vector/en_US_vctk_p279.wav similarity index 100% rename from style_vector/en_US_vctk_p279.wav rename to assets/style_vector/en_US_vctk_p279.wav diff --git a/style_vector/en_US_vctk_p280.wav b/assets/style_vector/en_US_vctk_p280.wav similarity index 100% rename from style_vector/en_US_vctk_p280.wav rename to assets/style_vector/en_US_vctk_p280.wav diff --git a/style_vector/en_US_vctk_p281.wav b/assets/style_vector/en_US_vctk_p281.wav similarity index 100% rename from style_vector/en_US_vctk_p281.wav rename to assets/style_vector/en_US_vctk_p281.wav diff --git a/style_vector/en_US_vctk_p282.wav b/assets/style_vector/en_US_vctk_p282.wav similarity index 100% rename from style_vector/en_US_vctk_p282.wav rename to assets/style_vector/en_US_vctk_p282.wav diff --git a/style_vector/en_US_vctk_p283.wav b/assets/style_vector/en_US_vctk_p283.wav similarity index 100% rename from style_vector/en_US_vctk_p283.wav rename to assets/style_vector/en_US_vctk_p283.wav diff --git a/style_vector/en_US_vctk_p284.wav b/assets/style_vector/en_US_vctk_p284.wav similarity index 100% rename from style_vector/en_US_vctk_p284.wav rename to assets/style_vector/en_US_vctk_p284.wav diff --git a/style_vector/en_US_vctk_p285.wav b/assets/style_vector/en_US_vctk_p285.wav similarity index 100% rename from style_vector/en_US_vctk_p285.wav rename to assets/style_vector/en_US_vctk_p285.wav diff --git a/style_vector/en_US_vctk_p286.wav b/assets/style_vector/en_US_vctk_p286.wav similarity index 100% rename from style_vector/en_US_vctk_p286.wav rename to assets/style_vector/en_US_vctk_p286.wav diff --git a/style_vector/en_US_vctk_p287.wav b/assets/style_vector/en_US_vctk_p287.wav similarity index 100% rename from style_vector/en_US_vctk_p287.wav rename to assets/style_vector/en_US_vctk_p287.wav diff --git a/style_vector/en_US_vctk_p288.wav b/assets/style_vector/en_US_vctk_p288.wav similarity index 100% rename from style_vector/en_US_vctk_p288.wav rename to assets/style_vector/en_US_vctk_p288.wav diff --git a/style_vector/en_US_vctk_p292.wav b/assets/style_vector/en_US_vctk_p292.wav similarity index 100% rename from style_vector/en_US_vctk_p292.wav rename to assets/style_vector/en_US_vctk_p292.wav diff --git a/style_vector/en_US_vctk_p293.wav b/assets/style_vector/en_US_vctk_p293.wav similarity index 100% rename from style_vector/en_US_vctk_p293.wav rename to assets/style_vector/en_US_vctk_p293.wav diff --git 
a/style_vector/en_US_vctk_p294.wav b/assets/style_vector/en_US_vctk_p294.wav similarity index 100% rename from style_vector/en_US_vctk_p294.wav rename to assets/style_vector/en_US_vctk_p294.wav diff --git a/style_vector/en_US_vctk_p295.wav b/assets/style_vector/en_US_vctk_p295.wav similarity index 100% rename from style_vector/en_US_vctk_p295.wav rename to assets/style_vector/en_US_vctk_p295.wav diff --git a/style_vector/en_US_vctk_p297.wav b/assets/style_vector/en_US_vctk_p297.wav similarity index 100% rename from style_vector/en_US_vctk_p297.wav rename to assets/style_vector/en_US_vctk_p297.wav diff --git a/style_vector/en_US_vctk_p298.wav b/assets/style_vector/en_US_vctk_p298.wav similarity index 100% rename from style_vector/en_US_vctk_p298.wav rename to assets/style_vector/en_US_vctk_p298.wav diff --git a/style_vector/en_US_vctk_p299.wav b/assets/style_vector/en_US_vctk_p299.wav similarity index 100% rename from style_vector/en_US_vctk_p299.wav rename to assets/style_vector/en_US_vctk_p299.wav diff --git a/style_vector/en_US_vctk_p300.wav b/assets/style_vector/en_US_vctk_p300.wav similarity index 100% rename from style_vector/en_US_vctk_p300.wav rename to assets/style_vector/en_US_vctk_p300.wav diff --git a/style_vector/en_US_vctk_p301.wav b/assets/style_vector/en_US_vctk_p301.wav similarity index 100% rename from style_vector/en_US_vctk_p301.wav rename to assets/style_vector/en_US_vctk_p301.wav diff --git a/style_vector/en_US_vctk_p302.wav b/assets/style_vector/en_US_vctk_p302.wav similarity index 100% rename from style_vector/en_US_vctk_p302.wav rename to assets/style_vector/en_US_vctk_p302.wav diff --git a/style_vector/en_US_vctk_p303.wav b/assets/style_vector/en_US_vctk_p303.wav similarity index 100% rename from style_vector/en_US_vctk_p303.wav rename to assets/style_vector/en_US_vctk_p303.wav diff --git a/style_vector/en_US_vctk_p304.wav b/assets/style_vector/en_US_vctk_p304.wav similarity index 100% rename from style_vector/en_US_vctk_p304.wav rename to assets/style_vector/en_US_vctk_p304.wav diff --git a/style_vector/en_US_vctk_p305.wav b/assets/style_vector/en_US_vctk_p305.wav similarity index 100% rename from style_vector/en_US_vctk_p305.wav rename to assets/style_vector/en_US_vctk_p305.wav diff --git a/style_vector/en_US_vctk_p306.wav b/assets/style_vector/en_US_vctk_p306.wav similarity index 100% rename from style_vector/en_US_vctk_p306.wav rename to assets/style_vector/en_US_vctk_p306.wav diff --git a/style_vector/en_US_vctk_p307.wav b/assets/style_vector/en_US_vctk_p307.wav similarity index 100% rename from style_vector/en_US_vctk_p307.wav rename to assets/style_vector/en_US_vctk_p307.wav diff --git a/style_vector/en_US_vctk_p308.wav b/assets/style_vector/en_US_vctk_p308.wav similarity index 100% rename from style_vector/en_US_vctk_p308.wav rename to assets/style_vector/en_US_vctk_p308.wav diff --git a/style_vector/en_US_vctk_p310.wav b/assets/style_vector/en_US_vctk_p310.wav similarity index 100% rename from style_vector/en_US_vctk_p310.wav rename to assets/style_vector/en_US_vctk_p310.wav diff --git a/style_vector/en_US_vctk_p311.wav b/assets/style_vector/en_US_vctk_p311.wav similarity index 100% rename from style_vector/en_US_vctk_p311.wav rename to assets/style_vector/en_US_vctk_p311.wav diff --git a/style_vector/en_US_vctk_p312.wav b/assets/style_vector/en_US_vctk_p312.wav similarity index 100% rename from style_vector/en_US_vctk_p312.wav rename to assets/style_vector/en_US_vctk_p312.wav diff --git a/style_vector/en_US_vctk_p313.wav 
b/assets/style_vector/en_US_vctk_p313.wav similarity index 100% rename from style_vector/en_US_vctk_p313.wav rename to assets/style_vector/en_US_vctk_p313.wav diff --git a/style_vector/en_US_vctk_p314.wav b/assets/style_vector/en_US_vctk_p314.wav similarity index 100% rename from style_vector/en_US_vctk_p314.wav rename to assets/style_vector/en_US_vctk_p314.wav diff --git a/style_vector/en_US_vctk_p316.wav b/assets/style_vector/en_US_vctk_p316.wav similarity index 100% rename from style_vector/en_US_vctk_p316.wav rename to assets/style_vector/en_US_vctk_p316.wav diff --git a/style_vector/en_US_vctk_p317.wav b/assets/style_vector/en_US_vctk_p317.wav similarity index 100% rename from style_vector/en_US_vctk_p317.wav rename to assets/style_vector/en_US_vctk_p317.wav diff --git a/style_vector/en_US_vctk_p318.wav b/assets/style_vector/en_US_vctk_p318.wav similarity index 100% rename from style_vector/en_US_vctk_p318.wav rename to assets/style_vector/en_US_vctk_p318.wav diff --git a/style_vector/en_US_vctk_p323.wav b/assets/style_vector/en_US_vctk_p323.wav similarity index 100% rename from style_vector/en_US_vctk_p323.wav rename to assets/style_vector/en_US_vctk_p323.wav diff --git a/style_vector/en_US_vctk_p326.wav b/assets/style_vector/en_US_vctk_p326.wav similarity index 100% rename from style_vector/en_US_vctk_p326.wav rename to assets/style_vector/en_US_vctk_p326.wav diff --git a/style_vector/en_US_vctk_p329.wav b/assets/style_vector/en_US_vctk_p329.wav similarity index 100% rename from style_vector/en_US_vctk_p329.wav rename to assets/style_vector/en_US_vctk_p329.wav diff --git a/style_vector/en_US_vctk_p330.wav b/assets/style_vector/en_US_vctk_p330.wav similarity index 100% rename from style_vector/en_US_vctk_p330.wav rename to assets/style_vector/en_US_vctk_p330.wav diff --git a/style_vector/en_US_vctk_p333.wav b/assets/style_vector/en_US_vctk_p333.wav similarity index 100% rename from style_vector/en_US_vctk_p333.wav rename to assets/style_vector/en_US_vctk_p333.wav diff --git a/style_vector/en_US_vctk_p334.wav b/assets/style_vector/en_US_vctk_p334.wav similarity index 100% rename from style_vector/en_US_vctk_p334.wav rename to assets/style_vector/en_US_vctk_p334.wav diff --git a/style_vector/en_US_vctk_p335.wav b/assets/style_vector/en_US_vctk_p335.wav similarity index 100% rename from style_vector/en_US_vctk_p335.wav rename to assets/style_vector/en_US_vctk_p335.wav diff --git a/style_vector/en_US_vctk_p336.wav b/assets/style_vector/en_US_vctk_p336.wav similarity index 100% rename from style_vector/en_US_vctk_p336.wav rename to assets/style_vector/en_US_vctk_p336.wav diff --git a/style_vector/en_US_vctk_p339.wav b/assets/style_vector/en_US_vctk_p339.wav similarity index 100% rename from style_vector/en_US_vctk_p339.wav rename to assets/style_vector/en_US_vctk_p339.wav diff --git a/style_vector/en_US_vctk_p340.wav b/assets/style_vector/en_US_vctk_p340.wav similarity index 100% rename from style_vector/en_US_vctk_p340.wav rename to assets/style_vector/en_US_vctk_p340.wav diff --git a/style_vector/en_US_vctk_p341.wav b/assets/style_vector/en_US_vctk_p341.wav similarity index 100% rename from style_vector/en_US_vctk_p341.wav rename to assets/style_vector/en_US_vctk_p341.wav diff --git a/style_vector/en_US_vctk_p343.wav b/assets/style_vector/en_US_vctk_p343.wav similarity index 100% rename from style_vector/en_US_vctk_p343.wav rename to assets/style_vector/en_US_vctk_p343.wav diff --git a/style_vector/en_US_vctk_p345.wav b/assets/style_vector/en_US_vctk_p345.wav similarity index 100% 
rename from style_vector/en_US_vctk_p345.wav rename to assets/style_vector/en_US_vctk_p345.wav diff --git a/style_vector/en_US_vctk_p347.wav b/assets/style_vector/en_US_vctk_p347.wav similarity index 100% rename from style_vector/en_US_vctk_p347.wav rename to assets/style_vector/en_US_vctk_p347.wav diff --git a/style_vector/en_US_vctk_p351.wav b/assets/style_vector/en_US_vctk_p351.wav similarity index 100% rename from style_vector/en_US_vctk_p351.wav rename to assets/style_vector/en_US_vctk_p351.wav diff --git a/style_vector/en_US_vctk_p360.wav b/assets/style_vector/en_US_vctk_p360.wav similarity index 100% rename from style_vector/en_US_vctk_p360.wav rename to assets/style_vector/en_US_vctk_p360.wav diff --git a/style_vector/en_US_vctk_p361.wav b/assets/style_vector/en_US_vctk_p361.wav similarity index 100% rename from style_vector/en_US_vctk_p361.wav rename to assets/style_vector/en_US_vctk_p361.wav diff --git a/style_vector/en_US_vctk_p362.wav b/assets/style_vector/en_US_vctk_p362.wav similarity index 100% rename from style_vector/en_US_vctk_p362.wav rename to assets/style_vector/en_US_vctk_p362.wav diff --git a/style_vector/en_US_vctk_p363.wav b/assets/style_vector/en_US_vctk_p363.wav similarity index 100% rename from style_vector/en_US_vctk_p363.wav rename to assets/style_vector/en_US_vctk_p363.wav diff --git a/style_vector/en_US_vctk_p364.wav b/assets/style_vector/en_US_vctk_p364.wav similarity index 100% rename from style_vector/en_US_vctk_p364.wav rename to assets/style_vector/en_US_vctk_p364.wav diff --git a/style_vector/en_US_vctk_p374.wav b/assets/style_vector/en_US_vctk_p374.wav similarity index 100% rename from style_vector/en_US_vctk_p374.wav rename to assets/style_vector/en_US_vctk_p374.wav diff --git a/style_vector/en_US_vctk_p376.wav b/assets/style_vector/en_US_vctk_p376.wav similarity index 100% rename from style_vector/en_US_vctk_p376.wav rename to assets/style_vector/en_US_vctk_p376.wav diff --git a/style_vector/en_US_vctk_s5.wav b/assets/style_vector/en_US_vctk_s5.wav similarity index 100% rename from style_vector/en_US_vctk_s5.wav rename to assets/style_vector/en_US_vctk_s5.wav diff --git a/style_vector_v2/en_UK_apope.wav b/assets/style_vector_v2/en_UK_apope.wav similarity index 100% rename from style_vector_v2/en_UK_apope.wav rename to assets/style_vector_v2/en_UK_apope.wav diff --git a/style_vector_v2/en_US_cmu_arctic_aew.wav b/assets/style_vector_v2/en_US_cmu_arctic_aew.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_aew.wav rename to assets/style_vector_v2/en_US_cmu_arctic_aew.wav diff --git a/style_vector_v2/en_US_cmu_arctic_ahw.wav b/assets/style_vector_v2/en_US_cmu_arctic_ahw.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_ahw.wav rename to assets/style_vector_v2/en_US_cmu_arctic_ahw.wav diff --git a/style_vector_v2/en_US_cmu_arctic_aup.wav b/assets/style_vector_v2/en_US_cmu_arctic_aup.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_aup.wav rename to assets/style_vector_v2/en_US_cmu_arctic_aup.wav diff --git a/style_vector_v2/en_US_cmu_arctic_awb.wav b/assets/style_vector_v2/en_US_cmu_arctic_awb.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_awb.wav rename to assets/style_vector_v2/en_US_cmu_arctic_awb.wav diff --git a/style_vector_v2/en_US_cmu_arctic_awbrms.wav b/assets/style_vector_v2/en_US_cmu_arctic_awbrms.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_awbrms.wav rename to assets/style_vector_v2/en_US_cmu_arctic_awbrms.wav diff --git 
a/style_vector_v2/en_US_cmu_arctic_axb.wav b/assets/style_vector_v2/en_US_cmu_arctic_axb.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_axb.wav rename to assets/style_vector_v2/en_US_cmu_arctic_axb.wav diff --git a/style_vector_v2/en_US_cmu_arctic_bdl.wav b/assets/style_vector_v2/en_US_cmu_arctic_bdl.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_bdl.wav rename to assets/style_vector_v2/en_US_cmu_arctic_bdl.wav diff --git a/style_vector_v2/en_US_cmu_arctic_clb.wav b/assets/style_vector_v2/en_US_cmu_arctic_clb.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_clb.wav rename to assets/style_vector_v2/en_US_cmu_arctic_clb.wav diff --git a/style_vector_v2/en_US_cmu_arctic_eey.wav b/assets/style_vector_v2/en_US_cmu_arctic_eey.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_eey.wav rename to assets/style_vector_v2/en_US_cmu_arctic_eey.wav diff --git a/style_vector_v2/en_US_cmu_arctic_fem.wav b/assets/style_vector_v2/en_US_cmu_arctic_fem.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_fem.wav rename to assets/style_vector_v2/en_US_cmu_arctic_fem.wav diff --git a/style_vector_v2/en_US_cmu_arctic_gka.wav b/assets/style_vector_v2/en_US_cmu_arctic_gka.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_gka.wav rename to assets/style_vector_v2/en_US_cmu_arctic_gka.wav diff --git a/style_vector_v2/en_US_cmu_arctic_jmk.wav b/assets/style_vector_v2/en_US_cmu_arctic_jmk.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_jmk.wav rename to assets/style_vector_v2/en_US_cmu_arctic_jmk.wav diff --git a/style_vector_v2/en_US_cmu_arctic_ksp.wav b/assets/style_vector_v2/en_US_cmu_arctic_ksp.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_ksp.wav rename to assets/style_vector_v2/en_US_cmu_arctic_ksp.wav diff --git a/style_vector_v2/en_US_cmu_arctic_ljm.wav b/assets/style_vector_v2/en_US_cmu_arctic_ljm.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_ljm.wav rename to assets/style_vector_v2/en_US_cmu_arctic_ljm.wav diff --git a/style_vector_v2/en_US_cmu_arctic_lnh.wav b/assets/style_vector_v2/en_US_cmu_arctic_lnh.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_lnh.wav rename to assets/style_vector_v2/en_US_cmu_arctic_lnh.wav diff --git a/style_vector_v2/en_US_cmu_arctic_rms.wav b/assets/style_vector_v2/en_US_cmu_arctic_rms.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_rms.wav rename to assets/style_vector_v2/en_US_cmu_arctic_rms.wav diff --git a/style_vector_v2/en_US_cmu_arctic_rxr.wav b/assets/style_vector_v2/en_US_cmu_arctic_rxr.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_rxr.wav rename to assets/style_vector_v2/en_US_cmu_arctic_rxr.wav diff --git a/style_vector_v2/en_US_cmu_arctic_slp.wav b/assets/style_vector_v2/en_US_cmu_arctic_slp.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_slp.wav rename to assets/style_vector_v2/en_US_cmu_arctic_slp.wav diff --git a/style_vector_v2/en_US_cmu_arctic_slt.wav b/assets/style_vector_v2/en_US_cmu_arctic_slt.wav similarity index 100% rename from style_vector_v2/en_US_cmu_arctic_slt.wav rename to assets/style_vector_v2/en_US_cmu_arctic_slt.wav diff --git a/style_vector_v2/en_US_hifi-tts_6097.wav b/assets/style_vector_v2/en_US_hifi-tts_6097.wav similarity index 100% rename from style_vector_v2/en_US_hifi-tts_6097.wav rename to assets/style_vector_v2/en_US_hifi-tts_6097.wav diff --git 
a/style_vector_v2/en_US_hifi-tts_9017.wav b/assets/style_vector_v2/en_US_hifi-tts_9017.wav similarity index 100% rename from style_vector_v2/en_US_hifi-tts_9017.wav rename to assets/style_vector_v2/en_US_hifi-tts_9017.wav diff --git a/style_vector_v2/en_US_hifi-tts_92.wav b/assets/style_vector_v2/en_US_hifi-tts_92.wav similarity index 100% rename from style_vector_v2/en_US_hifi-tts_92.wav rename to assets/style_vector_v2/en_US_hifi-tts_92.wav diff --git a/style_vector_v2/en_US_ljspeech.wav b/assets/style_vector_v2/en_US_ljspeech.wav similarity index 100% rename from style_vector_v2/en_US_ljspeech.wav rename to assets/style_vector_v2/en_US_ljspeech.wav diff --git a/style_vector_v2/en_US_m-ailabs_elliot_miller.wav b/assets/style_vector_v2/en_US_m-ailabs_elliot_miller.wav similarity index 100% rename from style_vector_v2/en_US_m-ailabs_elliot_miller.wav rename to assets/style_vector_v2/en_US_m-ailabs_elliot_miller.wav diff --git a/style_vector_v2/en_US_m-ailabs_judy_bieber.wav b/assets/style_vector_v2/en_US_m-ailabs_judy_bieber.wav similarity index 100% rename from style_vector_v2/en_US_m-ailabs_judy_bieber.wav rename to assets/style_vector_v2/en_US_m-ailabs_judy_bieber.wav diff --git a/style_vector_v2/en_US_m-ailabs_mary_ann.wav b/assets/style_vector_v2/en_US_m-ailabs_mary_ann.wav similarity index 100% rename from style_vector_v2/en_US_m-ailabs_mary_ann.wav rename to assets/style_vector_v2/en_US_m-ailabs_mary_ann.wav diff --git a/style_vector_v2/en_US_vctk_p225.wav b/assets/style_vector_v2/en_US_vctk_p225.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p225.wav rename to assets/style_vector_v2/en_US_vctk_p225.wav diff --git a/style_vector_v2/en_US_vctk_p226.wav b/assets/style_vector_v2/en_US_vctk_p226.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p226.wav rename to assets/style_vector_v2/en_US_vctk_p226.wav diff --git a/style_vector_v2/en_US_vctk_p227.wav b/assets/style_vector_v2/en_US_vctk_p227.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p227.wav rename to assets/style_vector_v2/en_US_vctk_p227.wav diff --git a/style_vector_v2/en_US_vctk_p228.wav b/assets/style_vector_v2/en_US_vctk_p228.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p228.wav rename to assets/style_vector_v2/en_US_vctk_p228.wav diff --git a/style_vector_v2/en_US_vctk_p229.wav b/assets/style_vector_v2/en_US_vctk_p229.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p229.wav rename to assets/style_vector_v2/en_US_vctk_p229.wav diff --git a/style_vector_v2/en_US_vctk_p230.wav b/assets/style_vector_v2/en_US_vctk_p230.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p230.wav rename to assets/style_vector_v2/en_US_vctk_p230.wav diff --git a/style_vector_v2/en_US_vctk_p231.wav b/assets/style_vector_v2/en_US_vctk_p231.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p231.wav rename to assets/style_vector_v2/en_US_vctk_p231.wav diff --git a/style_vector_v2/en_US_vctk_p232.wav b/assets/style_vector_v2/en_US_vctk_p232.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p232.wav rename to assets/style_vector_v2/en_US_vctk_p232.wav diff --git a/style_vector_v2/en_US_vctk_p233.wav b/assets/style_vector_v2/en_US_vctk_p233.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p233.wav rename to assets/style_vector_v2/en_US_vctk_p233.wav diff --git a/style_vector_v2/en_US_vctk_p234.wav b/assets/style_vector_v2/en_US_vctk_p234.wav similarity index 100% rename from 
style_vector_v2/en_US_vctk_p234.wav rename to assets/style_vector_v2/en_US_vctk_p234.wav diff --git a/style_vector_v2/en_US_vctk_p236.wav b/assets/style_vector_v2/en_US_vctk_p236.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p236.wav rename to assets/style_vector_v2/en_US_vctk_p236.wav diff --git a/style_vector_v2/en_US_vctk_p237.wav b/assets/style_vector_v2/en_US_vctk_p237.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p237.wav rename to assets/style_vector_v2/en_US_vctk_p237.wav diff --git a/style_vector_v2/en_US_vctk_p238.wav b/assets/style_vector_v2/en_US_vctk_p238.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p238.wav rename to assets/style_vector_v2/en_US_vctk_p238.wav diff --git a/style_vector_v2/en_US_vctk_p239.wav b/assets/style_vector_v2/en_US_vctk_p239.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p239.wav rename to assets/style_vector_v2/en_US_vctk_p239.wav diff --git a/style_vector_v2/en_US_vctk_p240.wav b/assets/style_vector_v2/en_US_vctk_p240.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p240.wav rename to assets/style_vector_v2/en_US_vctk_p240.wav diff --git a/style_vector_v2/en_US_vctk_p241.wav b/assets/style_vector_v2/en_US_vctk_p241.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p241.wav rename to assets/style_vector_v2/en_US_vctk_p241.wav diff --git a/style_vector_v2/en_US_vctk_p243.wav b/assets/style_vector_v2/en_US_vctk_p243.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p243.wav rename to assets/style_vector_v2/en_US_vctk_p243.wav diff --git a/style_vector_v2/en_US_vctk_p244.wav b/assets/style_vector_v2/en_US_vctk_p244.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p244.wav rename to assets/style_vector_v2/en_US_vctk_p244.wav diff --git a/style_vector_v2/en_US_vctk_p245.wav b/assets/style_vector_v2/en_US_vctk_p245.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p245.wav rename to assets/style_vector_v2/en_US_vctk_p245.wav diff --git a/style_vector_v2/en_US_vctk_p246.wav b/assets/style_vector_v2/en_US_vctk_p246.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p246.wav rename to assets/style_vector_v2/en_US_vctk_p246.wav diff --git a/style_vector_v2/en_US_vctk_p247.wav b/assets/style_vector_v2/en_US_vctk_p247.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p247.wav rename to assets/style_vector_v2/en_US_vctk_p247.wav diff --git a/style_vector_v2/en_US_vctk_p248.wav b/assets/style_vector_v2/en_US_vctk_p248.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p248.wav rename to assets/style_vector_v2/en_US_vctk_p248.wav diff --git a/style_vector_v2/en_US_vctk_p249.wav b/assets/style_vector_v2/en_US_vctk_p249.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p249.wav rename to assets/style_vector_v2/en_US_vctk_p249.wav diff --git a/style_vector_v2/en_US_vctk_p250.wav b/assets/style_vector_v2/en_US_vctk_p250.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p250.wav rename to assets/style_vector_v2/en_US_vctk_p250.wav diff --git a/style_vector_v2/en_US_vctk_p251.wav b/assets/style_vector_v2/en_US_vctk_p251.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p251.wav rename to assets/style_vector_v2/en_US_vctk_p251.wav diff --git a/style_vector_v2/en_US_vctk_p252.wav b/assets/style_vector_v2/en_US_vctk_p252.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p252.wav rename to 
assets/style_vector_v2/en_US_vctk_p252.wav diff --git a/style_vector_v2/en_US_vctk_p253.wav b/assets/style_vector_v2/en_US_vctk_p253.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p253.wav rename to assets/style_vector_v2/en_US_vctk_p253.wav diff --git a/style_vector_v2/en_US_vctk_p254.wav b/assets/style_vector_v2/en_US_vctk_p254.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p254.wav rename to assets/style_vector_v2/en_US_vctk_p254.wav diff --git a/style_vector_v2/en_US_vctk_p255.wav b/assets/style_vector_v2/en_US_vctk_p255.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p255.wav rename to assets/style_vector_v2/en_US_vctk_p255.wav diff --git a/style_vector_v2/en_US_vctk_p256.wav b/assets/style_vector_v2/en_US_vctk_p256.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p256.wav rename to assets/style_vector_v2/en_US_vctk_p256.wav diff --git a/style_vector_v2/en_US_vctk_p257.wav b/assets/style_vector_v2/en_US_vctk_p257.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p257.wav rename to assets/style_vector_v2/en_US_vctk_p257.wav diff --git a/style_vector_v2/en_US_vctk_p258.wav b/assets/style_vector_v2/en_US_vctk_p258.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p258.wav rename to assets/style_vector_v2/en_US_vctk_p258.wav diff --git a/style_vector_v2/en_US_vctk_p259.wav b/assets/style_vector_v2/en_US_vctk_p259.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p259.wav rename to assets/style_vector_v2/en_US_vctk_p259.wav diff --git a/style_vector_v2/en_US_vctk_p260.wav b/assets/style_vector_v2/en_US_vctk_p260.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p260.wav rename to assets/style_vector_v2/en_US_vctk_p260.wav diff --git a/style_vector_v2/en_US_vctk_p261.wav b/assets/style_vector_v2/en_US_vctk_p261.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p261.wav rename to assets/style_vector_v2/en_US_vctk_p261.wav diff --git a/style_vector_v2/en_US_vctk_p262.wav b/assets/style_vector_v2/en_US_vctk_p262.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p262.wav rename to assets/style_vector_v2/en_US_vctk_p262.wav diff --git a/style_vector_v2/en_US_vctk_p263.wav b/assets/style_vector_v2/en_US_vctk_p263.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p263.wav rename to assets/style_vector_v2/en_US_vctk_p263.wav diff --git a/style_vector_v2/en_US_vctk_p264.wav b/assets/style_vector_v2/en_US_vctk_p264.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p264.wav rename to assets/style_vector_v2/en_US_vctk_p264.wav diff --git a/style_vector_v2/en_US_vctk_p265.wav b/assets/style_vector_v2/en_US_vctk_p265.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p265.wav rename to assets/style_vector_v2/en_US_vctk_p265.wav diff --git a/style_vector_v2/en_US_vctk_p266.wav b/assets/style_vector_v2/en_US_vctk_p266.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p266.wav rename to assets/style_vector_v2/en_US_vctk_p266.wav diff --git a/style_vector_v2/en_US_vctk_p267.wav b/assets/style_vector_v2/en_US_vctk_p267.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p267.wav rename to assets/style_vector_v2/en_US_vctk_p267.wav diff --git a/style_vector_v2/en_US_vctk_p268.wav b/assets/style_vector_v2/en_US_vctk_p268.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p268.wav rename to assets/style_vector_v2/en_US_vctk_p268.wav diff --git 
a/style_vector_v2/en_US_vctk_p269.wav b/assets/style_vector_v2/en_US_vctk_p269.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p269.wav rename to assets/style_vector_v2/en_US_vctk_p269.wav diff --git a/style_vector_v2/en_US_vctk_p270.wav b/assets/style_vector_v2/en_US_vctk_p270.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p270.wav rename to assets/style_vector_v2/en_US_vctk_p270.wav diff --git a/style_vector_v2/en_US_vctk_p271.wav b/assets/style_vector_v2/en_US_vctk_p271.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p271.wav rename to assets/style_vector_v2/en_US_vctk_p271.wav diff --git a/style_vector_v2/en_US_vctk_p272.wav b/assets/style_vector_v2/en_US_vctk_p272.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p272.wav rename to assets/style_vector_v2/en_US_vctk_p272.wav diff --git a/style_vector_v2/en_US_vctk_p273.wav b/assets/style_vector_v2/en_US_vctk_p273.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p273.wav rename to assets/style_vector_v2/en_US_vctk_p273.wav diff --git a/style_vector_v2/en_US_vctk_p274.wav b/assets/style_vector_v2/en_US_vctk_p274.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p274.wav rename to assets/style_vector_v2/en_US_vctk_p274.wav diff --git a/style_vector_v2/en_US_vctk_p275.wav b/assets/style_vector_v2/en_US_vctk_p275.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p275.wav rename to assets/style_vector_v2/en_US_vctk_p275.wav diff --git a/style_vector_v2/en_US_vctk_p276.wav b/assets/style_vector_v2/en_US_vctk_p276.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p276.wav rename to assets/style_vector_v2/en_US_vctk_p276.wav diff --git a/style_vector_v2/en_US_vctk_p277.wav b/assets/style_vector_v2/en_US_vctk_p277.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p277.wav rename to assets/style_vector_v2/en_US_vctk_p277.wav diff --git a/style_vector_v2/en_US_vctk_p278.wav b/assets/style_vector_v2/en_US_vctk_p278.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p278.wav rename to assets/style_vector_v2/en_US_vctk_p278.wav diff --git a/style_vector_v2/en_US_vctk_p279.wav b/assets/style_vector_v2/en_US_vctk_p279.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p279.wav rename to assets/style_vector_v2/en_US_vctk_p279.wav diff --git a/style_vector_v2/en_US_vctk_p280.wav b/assets/style_vector_v2/en_US_vctk_p280.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p280.wav rename to assets/style_vector_v2/en_US_vctk_p280.wav diff --git a/style_vector_v2/en_US_vctk_p281.wav b/assets/style_vector_v2/en_US_vctk_p281.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p281.wav rename to assets/style_vector_v2/en_US_vctk_p281.wav diff --git a/style_vector_v2/en_US_vctk_p282.wav b/assets/style_vector_v2/en_US_vctk_p282.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p282.wav rename to assets/style_vector_v2/en_US_vctk_p282.wav diff --git a/style_vector_v2/en_US_vctk_p283.wav b/assets/style_vector_v2/en_US_vctk_p283.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p283.wav rename to assets/style_vector_v2/en_US_vctk_p283.wav diff --git a/style_vector_v2/en_US_vctk_p284.wav b/assets/style_vector_v2/en_US_vctk_p284.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p284.wav rename to assets/style_vector_v2/en_US_vctk_p284.wav diff --git a/style_vector_v2/en_US_vctk_p285.wav b/assets/style_vector_v2/en_US_vctk_p285.wav 
similarity index 100% rename from style_vector_v2/en_US_vctk_p285.wav rename to assets/style_vector_v2/en_US_vctk_p285.wav diff --git a/style_vector_v2/en_US_vctk_p286.wav b/assets/style_vector_v2/en_US_vctk_p286.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p286.wav rename to assets/style_vector_v2/en_US_vctk_p286.wav diff --git a/style_vector_v2/en_US_vctk_p287.wav b/assets/style_vector_v2/en_US_vctk_p287.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p287.wav rename to assets/style_vector_v2/en_US_vctk_p287.wav diff --git a/style_vector_v2/en_US_vctk_p288.wav b/assets/style_vector_v2/en_US_vctk_p288.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p288.wav rename to assets/style_vector_v2/en_US_vctk_p288.wav diff --git a/style_vector_v2/en_US_vctk_p292.wav b/assets/style_vector_v2/en_US_vctk_p292.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p292.wav rename to assets/style_vector_v2/en_US_vctk_p292.wav diff --git a/style_vector_v2/en_US_vctk_p293.wav b/assets/style_vector_v2/en_US_vctk_p293.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p293.wav rename to assets/style_vector_v2/en_US_vctk_p293.wav diff --git a/style_vector_v2/en_US_vctk_p294.wav b/assets/style_vector_v2/en_US_vctk_p294.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p294.wav rename to assets/style_vector_v2/en_US_vctk_p294.wav diff --git a/style_vector_v2/en_US_vctk_p295.wav b/assets/style_vector_v2/en_US_vctk_p295.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p295.wav rename to assets/style_vector_v2/en_US_vctk_p295.wav diff --git a/style_vector_v2/en_US_vctk_p297.wav b/assets/style_vector_v2/en_US_vctk_p297.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p297.wav rename to assets/style_vector_v2/en_US_vctk_p297.wav diff --git a/style_vector_v2/en_US_vctk_p298.wav b/assets/style_vector_v2/en_US_vctk_p298.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p298.wav rename to assets/style_vector_v2/en_US_vctk_p298.wav diff --git a/style_vector_v2/en_US_vctk_p299.wav b/assets/style_vector_v2/en_US_vctk_p299.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p299.wav rename to assets/style_vector_v2/en_US_vctk_p299.wav diff --git a/style_vector_v2/en_US_vctk_p300.wav b/assets/style_vector_v2/en_US_vctk_p300.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p300.wav rename to assets/style_vector_v2/en_US_vctk_p300.wav diff --git a/style_vector_v2/en_US_vctk_p301.wav b/assets/style_vector_v2/en_US_vctk_p301.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p301.wav rename to assets/style_vector_v2/en_US_vctk_p301.wav diff --git a/style_vector_v2/en_US_vctk_p302.wav b/assets/style_vector_v2/en_US_vctk_p302.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p302.wav rename to assets/style_vector_v2/en_US_vctk_p302.wav diff --git a/style_vector_v2/en_US_vctk_p303.wav b/assets/style_vector_v2/en_US_vctk_p303.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p303.wav rename to assets/style_vector_v2/en_US_vctk_p303.wav diff --git a/style_vector_v2/en_US_vctk_p304.wav b/assets/style_vector_v2/en_US_vctk_p304.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p304.wav rename to assets/style_vector_v2/en_US_vctk_p304.wav diff --git a/style_vector_v2/en_US_vctk_p305.wav b/assets/style_vector_v2/en_US_vctk_p305.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p305.wav rename to 
assets/style_vector_v2/en_US_vctk_p305.wav diff --git a/style_vector_v2/en_US_vctk_p306.wav b/assets/style_vector_v2/en_US_vctk_p306.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p306.wav rename to assets/style_vector_v2/en_US_vctk_p306.wav diff --git a/style_vector_v2/en_US_vctk_p307.wav b/assets/style_vector_v2/en_US_vctk_p307.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p307.wav rename to assets/style_vector_v2/en_US_vctk_p307.wav diff --git a/style_vector_v2/en_US_vctk_p308.wav b/assets/style_vector_v2/en_US_vctk_p308.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p308.wav rename to assets/style_vector_v2/en_US_vctk_p308.wav diff --git a/style_vector_v2/en_US_vctk_p310.wav b/assets/style_vector_v2/en_US_vctk_p310.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p310.wav rename to assets/style_vector_v2/en_US_vctk_p310.wav diff --git a/style_vector_v2/en_US_vctk_p311.wav b/assets/style_vector_v2/en_US_vctk_p311.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p311.wav rename to assets/style_vector_v2/en_US_vctk_p311.wav diff --git a/style_vector_v2/en_US_vctk_p312.wav b/assets/style_vector_v2/en_US_vctk_p312.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p312.wav rename to assets/style_vector_v2/en_US_vctk_p312.wav diff --git a/style_vector_v2/en_US_vctk_p313.wav b/assets/style_vector_v2/en_US_vctk_p313.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p313.wav rename to assets/style_vector_v2/en_US_vctk_p313.wav diff --git a/style_vector_v2/en_US_vctk_p314.wav b/assets/style_vector_v2/en_US_vctk_p314.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p314.wav rename to assets/style_vector_v2/en_US_vctk_p314.wav diff --git a/style_vector_v2/en_US_vctk_p316.wav b/assets/style_vector_v2/en_US_vctk_p316.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p316.wav rename to assets/style_vector_v2/en_US_vctk_p316.wav diff --git a/style_vector_v2/en_US_vctk_p317.wav b/assets/style_vector_v2/en_US_vctk_p317.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p317.wav rename to assets/style_vector_v2/en_US_vctk_p317.wav diff --git a/style_vector_v2/en_US_vctk_p318.wav b/assets/style_vector_v2/en_US_vctk_p318.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p318.wav rename to assets/style_vector_v2/en_US_vctk_p318.wav diff --git a/style_vector_v2/en_US_vctk_p323.wav b/assets/style_vector_v2/en_US_vctk_p323.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p323.wav rename to assets/style_vector_v2/en_US_vctk_p323.wav diff --git a/style_vector_v2/en_US_vctk_p326.wav b/assets/style_vector_v2/en_US_vctk_p326.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p326.wav rename to assets/style_vector_v2/en_US_vctk_p326.wav diff --git a/style_vector_v2/en_US_vctk_p329.wav b/assets/style_vector_v2/en_US_vctk_p329.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p329.wav rename to assets/style_vector_v2/en_US_vctk_p329.wav diff --git a/style_vector_v2/en_US_vctk_p330.wav b/assets/style_vector_v2/en_US_vctk_p330.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p330.wav rename to assets/style_vector_v2/en_US_vctk_p330.wav diff --git a/style_vector_v2/en_US_vctk_p333.wav b/assets/style_vector_v2/en_US_vctk_p333.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p333.wav rename to assets/style_vector_v2/en_US_vctk_p333.wav diff --git 
a/style_vector_v2/en_US_vctk_p334.wav b/assets/style_vector_v2/en_US_vctk_p334.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p334.wav rename to assets/style_vector_v2/en_US_vctk_p334.wav diff --git a/style_vector_v2/en_US_vctk_p335.wav b/assets/style_vector_v2/en_US_vctk_p335.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p335.wav rename to assets/style_vector_v2/en_US_vctk_p335.wav diff --git a/style_vector_v2/en_US_vctk_p336.wav b/assets/style_vector_v2/en_US_vctk_p336.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p336.wav rename to assets/style_vector_v2/en_US_vctk_p336.wav diff --git a/style_vector_v2/en_US_vctk_p339.wav b/assets/style_vector_v2/en_US_vctk_p339.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p339.wav rename to assets/style_vector_v2/en_US_vctk_p339.wav diff --git a/style_vector_v2/en_US_vctk_p340.wav b/assets/style_vector_v2/en_US_vctk_p340.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p340.wav rename to assets/style_vector_v2/en_US_vctk_p340.wav diff --git a/style_vector_v2/en_US_vctk_p341.wav b/assets/style_vector_v2/en_US_vctk_p341.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p341.wav rename to assets/style_vector_v2/en_US_vctk_p341.wav diff --git a/style_vector_v2/en_US_vctk_p343.wav b/assets/style_vector_v2/en_US_vctk_p343.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p343.wav rename to assets/style_vector_v2/en_US_vctk_p343.wav diff --git a/style_vector_v2/en_US_vctk_p345.wav b/assets/style_vector_v2/en_US_vctk_p345.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p345.wav rename to assets/style_vector_v2/en_US_vctk_p345.wav diff --git a/style_vector_v2/en_US_vctk_p347.wav b/assets/style_vector_v2/en_US_vctk_p347.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p347.wav rename to assets/style_vector_v2/en_US_vctk_p347.wav diff --git a/style_vector_v2/en_US_vctk_p351.wav b/assets/style_vector_v2/en_US_vctk_p351.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p351.wav rename to assets/style_vector_v2/en_US_vctk_p351.wav diff --git a/style_vector_v2/en_US_vctk_p360.wav b/assets/style_vector_v2/en_US_vctk_p360.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p360.wav rename to assets/style_vector_v2/en_US_vctk_p360.wav diff --git a/style_vector_v2/en_US_vctk_p361.wav b/assets/style_vector_v2/en_US_vctk_p361.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p361.wav rename to assets/style_vector_v2/en_US_vctk_p361.wav diff --git a/style_vector_v2/en_US_vctk_p362.wav b/assets/style_vector_v2/en_US_vctk_p362.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p362.wav rename to assets/style_vector_v2/en_US_vctk_p362.wav diff --git a/style_vector_v2/en_US_vctk_p363.wav b/assets/style_vector_v2/en_US_vctk_p363.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p363.wav rename to assets/style_vector_v2/en_US_vctk_p363.wav diff --git a/style_vector_v2/en_US_vctk_p364.wav b/assets/style_vector_v2/en_US_vctk_p364.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p364.wav rename to assets/style_vector_v2/en_US_vctk_p364.wav diff --git a/style_vector_v2/en_US_vctk_p374.wav b/assets/style_vector_v2/en_US_vctk_p374.wav similarity index 100% rename from style_vector_v2/en_US_vctk_p374.wav rename to assets/style_vector_v2/en_US_vctk_p374.wav diff --git a/style_vector_v2/en_US_vctk_p376.wav b/assets/style_vector_v2/en_US_vctk_p376.wav 
similarity index 100% rename from style_vector_v2/en_US_vctk_p376.wav rename to assets/style_vector_v2/en_US_vctk_p376.wav diff --git a/style_vector_v2/en_US_vctk_s5.wav b/assets/style_vector_v2/en_US_vctk_s5.wav similarity index 100% rename from style_vector_v2/en_US_vctk_s5.wav rename to assets/style_vector_v2/en_US_vctk_s5.wav diff --git a/audiocraft/activations.py b/audiocraft/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..2d83d7c4c2dc84c64b724eadbe06157507d4f20d --- /dev/null +++ b/audiocraft/activations.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch import Tensor +from typing import Union, Callable + + +class CustomGLU(nn.Module): + """Custom Gated Linear Unit activation. + Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half + of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation + function (i.e. sigmoid, swish, etc.). + + Args: + activation (nn.Module): The custom activation to apply in the Gated Linear Unit + dim (int): the dimension on which to split the input. Default: -1 + + Shape: + - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional + dimensions + - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + + Examples:: + >>> m = CustomGLU(nn.Sigmoid()) + >>> input = torch.randn(4, 2) + >>> output = m(input) + """ + def __init__(self, activation: nn.Module, dim: int = -1): + super(CustomGLU, self).__init__() + self.dim = dim + self.activation = activation + + def forward(self, x: Tensor): + assert x.shape[self.dim] % 2 == 0 # M = N / 2 + a, b = torch.chunk(x, 2, dim=self.dim) + return a * self.activation(b) + + +class SwiGLU(CustomGLU): + """SiLU Gated Linear Unit activation. + Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is + the first half of the input matrices, :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + """ + def __init__(self, dim: int = -1): + super(SwiGLU, self).__init__(nn.SiLU(), dim) + + +class GeGLU(CustomGLU): + """GeLU Gated Linear Unit activation. + Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is + the first half of the input matrices, :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + """ + def __init__(self, dim: int = -1): + super(GeGLU, self).__init__(nn.GELU(), dim) + + +class ReGLU(CustomGLU): + """ReLU Gated Linear Unit activation. + Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is + the first half of the input matrices, :math:`b` is the second half. + + Args: + dim (int): the dimension on which to split the input. Default: -1 + """ + def __init__(self, dim: int = -1): + super(ReGLU, self).__init__(nn.ReLU(), dim) + + +def get_activation_fn( + activation: Union[str, Callable[[Tensor], Tensor]] +) -> Union[str, Callable[[Tensor], Tensor]]: + """Helper function to map an activation string to the activation class. + If the supplied activation is not a string that is recognized, the activation is passed back. 
+ + Args: + activation (str, or Callable[[Tensor], Tensor]): Activation to check + """ + if isinstance(activation, str): + if activation == "reglu": + return ReGLU() + elif activation == "geglu": + return GeGLU() + elif activation == "swiglu": + return SwiGLU() + return activation diff --git a/audiocraft/audiogen.py b/audiocraft/audiogen.py new file mode 100644 index 0000000000000000000000000000000000000000..45cf50058f77e928d7010792e224950dba4bfbb7 --- /dev/null +++ b/audiocraft/audiogen.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Main model for using AudioGen. This will combine all the required components +and provide easy access to the generation API. +""" + +import typing as tp +import torch + +from audiocraft.encodec import CompressionModel +from audiocraft.genmodel import BaseGenModel +from audiocraft.lm import LMModel +from audiocraft.loaders import load_compression_model, load_lm_model +from .utils.audio_utils import f32_pcm, normalize_audio + + +def audio_write(stem_name, + wav, + sample_rate, + format= 'wav', + mp3_rate=320, + ogg_rate= None, + normalize= True, + strategy= 'peak', + peak_clip_headroom_db=1, + rms_headroom_db= 18, + loudness_headroom_db = 14, + loudness_compressor = False, + log_clipping = True, + make_parent_dir = True, + add_suffix = True): + + assert wav.dtype.is_floating_point, "wav is not floating point" + if wav.dim() == 1: + wav = wav[None] + elif wav.dim() > 2: + raise ValueError("Input wav should be at most 2 dimension.") + assert wav.isfinite().all() + wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, + rms_headroom_db, loudness_headroom_db, loudness_compressor, + log_clipping=log_clipping, sample_rate=sample_rate, + stem_name=str(stem_name)) + return wav +# === + +class AudioGen(BaseGenModel): + """AudioGen main model with convenient generation API. + + Args: + name (str): name of the model. + compression_model (CompressionModel): Compression model + used to map audio to invertible discrete representations. + lm (LMModel): Language model over discrete representations. + max_duration (float, optional): maximum duration the model can produce, + otherwise, inferred from the training params. 
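A minimal sketch of how the gated-linear-unit helpers above behave, assuming the vendored module is importable as audiocraft.activations (the path this patch creates) and that torch is installed; tensor shapes are illustrative only.

import torch
from audiocraft.activations import SwiGLU, get_activation_fn

x = torch.randn(4, 8)                      # the split dimension must have an even size
glu = SwiGLU(dim=-1)                       # computes a * SiLU(b) on the two halves of the last dim
y = glu(x)
assert y.shape == (4, 4)                   # output keeps half of the split dimension

act = get_activation_fn("geglu")           # recognized strings map to module instances (GeGLU here)
passthrough = get_activation_fn("relu")    # unrecognized values are passed back unchanged ("relu")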
+ """ + def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, + max_duration: tp.Optional[float] = None): + # print(f'Using {compression_model=}\n-----=-----') + super().__init__(name, compression_model, lm, max_duration) + self.set_generation_params(duration=5) # default duration + + @staticmethod + def get_pretrained(name: str = 'facebook/audiogen-medium', device=None): + """Return pretrained model, we provide a single model for now: + - facebook/audiogen-medium (1.5B), text to sound, + # see: https://huggingface.co/facebook/audiogen-medium + """ + if device is None: + if torch.cuda.device_count(): + device = 'cuda' + else: + device = 'cpu' + + + + compression_model = load_compression_model(name, device=device) + lm = load_lm_model(name, device=device) + assert 'self_wav' not in lm.condition_provider.conditioners, \ + "AudioGen do not support waveform conditioning for now" + return AudioGen(name, compression_model, lm) + + def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, + top_p: float = 0.0, temperature: float = 1.0, + duration: float = 10.0, cfg_coef: float = 3.0, + two_step_cfg: bool = False, extend_stride: float = 2): + """Set the generation parameters for AudioGen. + + Args: + use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True. + top_k (int, optional): top_k used for sampling. Defaults to 250. + top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0. + temperature (float, optional): Softmax temperature parameter. Defaults to 1.0. + duration (float, optional): Duration of the generated waveform. Defaults to 10.0. + cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0. + two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, + instead of batching together the two. This has some impact on how things + are padded but seems to have little impact in practice. + extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much + should we extend the audio each time. Larger values will mean less context is + preserved, and shorter value will require extra computations. + """ + assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration." + self.extend_stride = extend_stride + self.duration = duration + self.generation_params = { + 'use_sampling': use_sampling, + 'temp': temperature, + 'top_k': top_k, + 'top_p': top_p, + 'cfg_coef': cfg_coef, + 'two_step_cfg': two_step_cfg, + } diff --git a/audiocraft/builders.py b/audiocraft/builders.py new file mode 100644 index 0000000000000000000000000000000000000000..3a035f729204851f95708f281e2a4a8635586bb4 --- /dev/null +++ b/audiocraft/builders.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +All the functions to build the relevant models and modules +from the Hydra config. 
+""" + +import typing as tp + +import audiocraft +import omegaconf +import torch + +from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel +from .lm import LMModel +from .seanet import SEANetEncoder, SEANetDecoder +from .codebooks_patterns import ( + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider, + CoarseFirstPattern, +) +from .conditioners import ( + BaseConditioner, + ChromaStemConditioner, + CLAPEmbeddingConditioner, + ConditionFuser, + ConditioningProvider, + LUTConditioner, + T5Conditioner, +) +from .unet import DiffusionUnet +import audiocraft.quantization as qt +from .utils.utils import dict_from_config +from .diffusion_schedule import MultiBandProcessor, SampleProcessor + + +def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer: + klass = { + 'no_quant': qt.DummyQuantizer, + 'rvq': qt.ResidualVectorQuantizer + }[quantizer] + kwargs = dict_from_config(getattr(cfg, quantizer)) + if quantizer != 'no_quant': + kwargs['dimension'] = dimension + return klass(**kwargs) + + +def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): + if encoder_name == 'seanet': + kwargs = dict_from_config(getattr(cfg, 'seanet')) + encoder_override_kwargs = kwargs.pop('encoder') + decoder_override_kwargs = kwargs.pop('decoder') + encoder_kwargs = {**kwargs, **encoder_override_kwargs} + decoder_kwargs = {**kwargs, **decoder_override_kwargs} + encoder = SEANetEncoder(**encoder_kwargs) + decoder = SEANetDecoder(**decoder_kwargs) + return encoder, decoder + else: + raise KeyError(f"Unexpected compression model {cfg.compression_model}") + + +def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: + """Instantiate a compression model.""" + if cfg.compression_model == 'encodec': + kwargs = dict_from_config(getattr(cfg, 'encodec')) + encoder_name = kwargs.pop('autoencoder') + quantizer_name = kwargs.pop('quantizer') + encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) + quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) + frame_rate = kwargs['sample_rate'] // encoder.hop_length + renormalize = kwargs.pop('renormalize', False) + # deprecated params + kwargs.pop('renorm', None) + return EncodecModel(encoder, decoder, quantizer, + frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) + else: + raise KeyError(f"Unexpected compression model {cfg.compression_model}") + + +def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: + """Instantiate a transformer LM.""" + if cfg.lm_model in ['transformer_lm', 'transformer_lm_magnet']: + kwargs = dict_from_config(getattr(cfg, 'transformer_lm')) + n_q = kwargs['n_q'] + q_modeling = kwargs.pop('q_modeling', None) + codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern') + attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout')) + cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance')) + cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef'] + fuser = get_condition_fuser(cfg) + condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) + if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically + kwargs['cross_attention'] = True + if codebooks_pattern_cfg.modeling is None: + assert q_modeling is not None, \ + "LM model should either have a codebook pattern defined or transformer_lm.q_modeling" + codebooks_pattern_cfg = 
omegaconf.OmegaConf.create( + {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}} + ) + + pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg) + # lm_class = MagnetLMModel if cfg.lm_model == 'transformer_lm_magnet' else LMModel + lm_class = LMModel # hard coded D + print(f'{lm_class=}\n\n\n\n=====================') + return lm_class( + pattern_provider=pattern_provider, + condition_provider=condition_provider, + fuser=fuser, + cfg_dropout=cfg_prob, + cfg_coef=cfg_coef, + attribute_dropout=attribute_dropout, + dtype=getattr(torch, cfg.dtype), + device=cfg.device, + **kwargs + ).to(cfg.device) + else: + raise KeyError(f"Unexpected LM model {cfg.lm_model}") + + +def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider: + """Instantiate a conditioning model.""" + device = cfg.device + duration = cfg.dataset.segment_duration + cfg = getattr(cfg, 'conditioners') + dict_cfg = {} if cfg is None else dict_from_config(cfg) + conditioners: tp.Dict[str, BaseConditioner] = {} + condition_provider_args = dict_cfg.pop('args', {}) + condition_provider_args.pop('merge_text_conditions_p', None) + condition_provider_args.pop('drop_desc_p', None) + + for cond, cond_cfg in dict_cfg.items(): + model_type = cond_cfg['model'] + model_args = cond_cfg[model_type] + if model_type == 't5': + conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args) + elif model_type == 'lut': + conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args) + elif model_type == 'chroma_stem': + conditioners[str(cond)] = ChromaStemConditioner( + output_dim=output_dim, + duration=duration, + device=device, + **model_args + ) + elif model_type == 'clap': + conditioners[str(cond)] = CLAPEmbeddingConditioner( + output_dim=output_dim, + device=device, + **model_args + ) + else: + raise ValueError(f"Unrecognized conditioning model: {model_type}") + conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args) + return conditioner + + +def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: + """Instantiate a condition fuser object.""" + fuser_cfg = getattr(cfg, 'fuser') + fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate'] + fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} + kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} + fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) + return fuser + + +def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider: + """Instantiate a codebooks pattern provider object.""" + pattern_providers = { + 'parallel': ParallelPatternProvider, + 'delay': DelayedPatternProvider, + 'unroll': UnrolledPatternProvider, + 'coarse_first': CoarseFirstPattern, + 'musiclm': MusicLMPattern, + } + name = cfg.modeling + kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {} + klass = pattern_providers[name] + return klass(n_q, **kwargs) + + +def get_debug_compression_model(device='cpu', sample_rate: int = 32000): + """Instantiate a debug compression model to be used for unit tests.""" + assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model" + model_ratios = { + 16000: [10, 8, 8], # 25 Hz at 16kHz + 32000: [10, 8, 16] # 25 Hz at 32kHz + } + ratios: tp.List[int] = model_ratios[sample_rate] + frame_rate = 25 + seanet_kwargs: dict = { + 'n_filters': 4, + 'n_residual_layers': 1, + 'dimension': 32, + 'ratios': ratios, + } + encoder 
= SEANetEncoder(**seanet_kwargs) + decoder = SEANetDecoder(**seanet_kwargs) + quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4) + init_x = torch.randn(8, 32, 128) + quantizer(init_x, 1) # initialize kmeans etc. + compression_model = EncodecModel( + encoder, decoder, quantizer, + frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device) + return compression_model.eval() + + +def get_diffusion_model(cfg: omegaconf.DictConfig): + # TODO Find a way to infer the channels from dset + channels = cfg.channels + num_steps = cfg.schedule.num_steps + return DiffusionUnet( + chin=channels, num_steps=num_steps, **cfg.diffusion_unet) + + +def get_processor(cfg, sample_rate: int = 24000): + sample_processor = SampleProcessor() + if cfg.use: + kw = dict(cfg) + kw.pop('use') + kw.pop('name') + if cfg.name == "multi_band_processor": + sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw) + return sample_processor + + +def get_debug_lm_model(device='cpu'): + """Instantiate a debug LM to be used for unit tests.""" + pattern = DelayedPatternProvider(n_q=4) + dim = 16 + providers = { + 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"), + } + condition_provider = ConditioningProvider(providers) + fuser = ConditionFuser( + {'cross': ['description'], 'prepend': [], + 'sum': [], 'input_interpolate': []}) + lm = LMModel( + pattern, condition_provider, fuser, + n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2, + cross_attention=True, causal=True) + return lm.to(device).eval() + + +def get_wrapped_compression_model( + compression_model: CompressionModel, + cfg: omegaconf.DictConfig) -> CompressionModel: + if hasattr(cfg, 'interleave_stereo_codebooks'): + if cfg.interleave_stereo_codebooks.use: + kwargs = dict_from_config(cfg.interleave_stereo_codebooks) + kwargs.pop('use') + compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs) + if hasattr(cfg, 'compression_model_n_q'): + if cfg.compression_model_n_q is not None: + compression_model.set_num_codebooks(cfg.compression_model_n_q) + return compression_model diff --git a/audiocraft/chroma.py b/audiocraft/chroma.py new file mode 100644 index 0000000000000000000000000000000000000000..e84fb66b4a4aaefb0b3ccac8a9a44c3b20e48f61 --- /dev/null +++ b/audiocraft/chroma.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import typing as tp + +from einops import rearrange +from librosa import filters +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio + + +class ChromaExtractor(nn.Module): + """Chroma extraction and quantization. + + Args: + sample_rate (int): Sample rate for the chroma extraction. + n_chroma (int): Number of chroma bins for the chroma extraction. + radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12). + nfft (int, optional): Number of FFT. + winlen (int, optional): Window length. + winhop (int, optional): Window hop size. + argmax (bool, optional): Whether to use argmax. Defaults to False. + norm (float, optional): Norm for chroma normalization. Defaults to inf. 
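Two of these builders can be exercised without a full Hydra config tree. The sketch below assumes omegaconf is installed and that the vendored EncodecModel keeps the usual (codes, scale) encode interface from upstream audiocraft, which is not shown in this hunk.

import omegaconf
import torch
from audiocraft.builders import get_codebooks_pattern_provider, get_debug_compression_model

# Config-driven dispatch: `modeling` selects the provider class, the matching sub-config its kwargs.
cfg = omegaconf.OmegaConf.create({'modeling': 'delay', 'delay': {'delays': [0, 1, 2, 3]}})
provider = get_codebooks_pattern_provider(n_q=4, cfg=cfg)
print(type(provider).__name__)                           # DelayedPatternProvider

# The debug helper builds a tiny SEANet/RVQ codec meant for unit tests (no config needed).
codec = get_debug_compression_model(device='cpu', sample_rate=16000)
codes, scale = codec.encode(torch.randn(1, 1, 16000))    # assumed (codes, scale) return value
print(codes.shape)                                       # roughly [1, 4, 25] at the 25 Hz frame rate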
+ """ + def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None, + winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False, + norm: float = torch.inf): + super().__init__() + self.winlen = winlen or 2 ** radix2_exp + self.nfft = nfft or self.winlen + self.winhop = winhop or (self.winlen // 4) + self.sample_rate = sample_rate + self.n_chroma = n_chroma + self.norm = norm + self.argmax = argmax + self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, + n_chroma=self.n_chroma)), persistent=False) + self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, + hop_length=self.winhop, power=2, center=True, + pad=0, normalized=True) + + def forward(self, wav: torch.Tensor) -> torch.Tensor: + T = wav.shape[-1] + # in case we are getting a wav that was dropped out (nullified) + # from the conditioner, make sure wav length is no less that nfft + if T < self.nfft: + pad = self.nfft - T + r = 0 if pad % 2 == 0 else 1 + wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) + assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}" + + spec = self.spec(wav).squeeze(1) + raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec) + norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) + norm_chroma = rearrange(norm_chroma, 'b d t -> b t d') + + if self.argmax: + idx = norm_chroma.argmax(-1, keepdim=True) + norm_chroma[:] = 0 + norm_chroma.scatter_(dim=-1, index=idx, value=1) + + return norm_chroma diff --git a/audiocraft/codebooks_patterns.py b/audiocraft/codebooks_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..386df5826937178e29eec670280f8bea57f1a19e --- /dev/null +++ b/audiocraft/codebooks_patterns.py @@ -0,0 +1,548 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import namedtuple +from dataclasses import dataclass +from functools import lru_cache +import logging +import typing as tp + +from abc import ABC, abstractmethod +import torch + +LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) +PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates +logger = logging.getLogger(__name__) + + +@dataclass +class Pattern: + """Base implementation of a pattern over a sequence with multiple codebooks. + + The codebook pattern consists in a layout, defining for each sequence step + the list of coordinates of each codebook timestep in the resulting interleaved sequence. + The first item of the pattern is always an empty list in order to properly insert a special token + to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern + and ``timesteps`` the number of timesteps corresponding to the original sequence. + + The pattern provides convenient methods to build and revert interleaved sequences from it: + ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] + to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size, + K being the number of codebooks, T the number of original timesteps and S the number of sequence steps + for the output sequence. 
The unfilled positions are replaced with a special token and the built sequence + is returned along with a mask indicating valid tokens. + ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment + of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask + to fill and specify invalid positions if needed. + See the dedicated methods for more details. + """ + # Pattern layout, for each sequence step, we have a list of coordinates + # corresponding to the original codebook timestep and position. + # The first list is always an empty list in order to properly insert + # a special token to start with. + layout: PatternLayout + timesteps: int + n_q: int + + def __post_init__(self): + assert len(self.layout) > 0 + self._validate_layout() + self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) + self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) + logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout)) + + def _validate_layout(self): + """Runs checks on the layout to ensure a valid pattern is defined. + A pattern is considered invalid if: + - Multiple timesteps for a same codebook are defined in the same sequence step + - The timesteps for a given codebook are not in ascending order as we advance in the sequence + (this would mean that we have future timesteps before past timesteps). + """ + q_timesteps = {q: 0 for q in range(self.n_q)} + for s, seq_coords in enumerate(self.layout): + if len(seq_coords) > 0: + qs = set() + for coord in seq_coords: + qs.add(coord.q) + last_q_timestep = q_timesteps[coord.q] + assert coord.t >= last_q_timestep, \ + f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" + q_timesteps[coord.q] = coord.t + # each sequence step contains at max 1 coordinate per codebook + assert len(qs) == len(seq_coords), \ + f"Multiple entries for a same codebook are found at step {s}" + + @property + def num_sequence_steps(self): + return len(self.layout) - 1 + + @property + def max_delay(self): + max_t_in_seq_coords = 0 + for seq_coords in self.layout[1:]: + for coords in seq_coords: + max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1) + return max_t_in_seq_coords - self.timesteps + + @property + def valid_layout(self): + valid_step = len(self.layout) - self.max_delay + return self.layout[:valid_step] + + def starts_with_special_token(self): + return self.layout[0] == [] + + def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): + """Get codebook coordinates in the layout that corresponds to the specified timestep t + and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step + and the actual codebook coordinates. 
+ """ + assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" + if q is not None: + assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" + coords = [] + for s, seq_codes in enumerate(self.layout): + for code in seq_codes: + if code.t == t and (q is None or code.q == q): + coords.append((s, code)) + return coords + + def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]: + return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)] + + def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]: + steps_with_timesteps = self.get_steps_with_timestep(t, q) + return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None + + def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool, + device: tp.Union[torch.device, str] = 'cpu'): + """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. + + Args: + timesteps (int): Maximum number of timesteps steps to consider. + keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. + """ + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern" + # use the proper layout based on whether we limit ourselves to valid steps only or not, + # note that using the valid_layout will result in a truncated sequence up to the valid steps + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy() + mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + # the last value is n_q * timesteps as we have flattened z and append special token as the last token + # which will correspond to the index: n_q * timesteps + indexes[:] = n_q * timesteps + # iterate over the pattern and fill scattered indexes and mask + for s, sequence_coords in enumerate(ref_layout): + for coords in sequence_coords: + if coords.t < timesteps: + indexes[coords.q, s] = coords.t + coords.q * timesteps + mask[coords.q, s] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Build sequence corresponding to the pattern from the input tensor z. + The sequence is built using up to sequence_steps if specified, and non-pattern + coordinates are filled with the special token. + + Args: + z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T]. + special_token (int): Special token used to fill non-pattern coordinates in the new sequence. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. 
+ Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S + corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S]. + """ + B, K, T = z.shape + indexes, mask = self._build_pattern_sequence_scatter_indexes( + T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device) + ) + z = z.view(B, -1) + # we append the special token as the last index of our flattened z tensor + z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1) + values = z[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int, + keep_only_valid_steps: bool = False, + is_model_output: bool = False, + device: tp.Union[torch.device, str] = 'cpu'): + """Builds scatter indexes required to retrieve the original multi-codebook sequence + from interleaving pattern. + + Args: + sequence_steps (int): Sequence steps. + n_q (int): Number of codebooks. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # TODO(jade): Do we want to further truncate to only valid timesteps here as well? + timesteps = self.timesteps + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert sequence_steps <= len(ref_layout), \ + f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" + + # ensure we take the appropriate indexes to keep the model output from the first special token as well + if is_model_output and self.starts_with_special_token(): + ref_layout = ref_layout[1:] + + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy() + mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + indexes[:] = n_q * sequence_steps + for s, sequence_codes in enumerate(ref_layout): + if s < sequence_steps: + for code in sequence_codes: + if code.t < timesteps: + indexes[code.q, code.t] = s + code.q * sequence_steps + mask[code.q, code.t] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. + The sequence is reverted using up to timesteps if specified, and non-pattern coordinates + are filled with the special token. 
+ + Args: + s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S]. + special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T + corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + B, K, S = s.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device) + ) + s = s.view(B, -1) + # we append the special token as the last index of our flattened z tensor + s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1) + values = s[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False): + """Revert model logits obtained on a sequence built from the pattern + back to a tensor matching the original sequence. + + This method is similar to ``revert_pattern_sequence`` with the following specificities: + 1. It is designed to work with the extra cardinality dimension + 2. We return the logits for the first sequence item that matches the special_token and + which matching target in the original sequence is the first item of the sequence, + while we skip the last logits as there is no matching target + """ + B, card, K, S = logits.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=True, device=logits.device + ) + logits = logits.reshape(B, card, -1) + # we append the special token as the last index of our flattened z tensor + logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S] + values = logits[:, :, indexes.view(-1)] + values = values.view(B, card, K, indexes.shape[-1]) + return values, indexes, mask + + +class CodebooksPatternProvider(ABC): + """Abstraction around providing pattern for interleaving codebooks. + + The CodebooksPatternProvider abstraction allows to implement various strategies to + define interleaving pattern of sequences composed of multiple codebooks. For a given + number of codebooks `n_q`, the pattern provider can generate a specified pattern + corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern + can be used to construct a new sequence from the original codes respecting the specified + pattern. The pattern is defined as a list of list of code coordinates, code coordinate + being a tuple with the original timestep and codebook to build the new sequence. + Note that all patterns must start with an empty list that is then used to insert a first + sequence step of special tokens in the newly generated sequence. + + Args: + n_q (int): number of codebooks. + cached (bool): if True, patterns for a given length are cached. In general + that should be true for efficiency reason to avoid synchronization points. 
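The abstraction described here can be illustrated with a toy provider. The class below is hypothetical (not part of the patch): it only fills in the get_pattern() hook whose abstract definition follows just below, reusing Pattern and LayoutCoord from this file.

from audiocraft.codebooks_patterns import (
    CodebooksPatternProvider, Pattern, PatternLayout, LayoutCoord)

class SpacedPattern(CodebooksPatternProvider):
    # Hypothetical provider: all codebooks of timestep t are emitted in parallel,
    # followed by one empty step that will be filled with the special token.
    def get_pattern(self, timesteps: int) -> Pattern:
        out: PatternLayout = [[]]                       # every layout starts with an empty step
        for t in range(timesteps):
            out.append([LayoutCoord(t, q) for q in range(self.n_q)])
            out.append([])
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)

pattern = SpacedPattern(n_q=3).get_pattern(timesteps=4)
print(pattern.num_sequence_steps, pattern.max_delay)    # 8 sequence steps, no extra delay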
+ """ + def __init__(self, n_q: int, cached: bool = True): + assert n_q > 0 + self.n_q = n_q + self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore + + @abstractmethod + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern with specific interleaving between codebooks. + + Args: + timesteps (int): Total number of timesteps. + """ + raise NotImplementedError() + + +class DelayedPatternProvider(CodebooksPatternProvider): + """Provider for delayed pattern across delayed codebooks. + Codebooks are delayed in the sequence and sequence steps will contain codebooks + from different timesteps. + + Example: + Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + The resulting sequence obtained from the returned pattern is: + [[S, 1, 2, 3, 4], + [S, S, 1, 2, 3], + [S, S, S, 1, 2]] + (with S being a special token) + + Args: + n_q (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + flatten_first (int): Flatten the first N timesteps. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, + flatten_first: int = 0, empty_initial: int = 0): + super().__init__(n_q) + if delays is None: + delays = list(range(n_q)) + self.delays = delays + self.flatten_first = flatten_first + self.empty_initial = empty_initial + assert len(self.delays) == self.n_q + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + omit_special_token = self.empty_initial < 0 + out: PatternLayout = [] if omit_special_token else [[]] + max_delay = max(self.delays) + if self.empty_initial: + out += [[] for _ in range(self.empty_initial)] + if self.flatten_first: + for t in range(min(timesteps, self.flatten_first)): + for q in range(self.n_q): + out.append([LayoutCoord(t, q)]) + for t in range(self.flatten_first, timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= self.flatten_first: + v.append(LayoutCoord(t_for_q, q)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class ParallelPatternProvider(DelayedPatternProvider): + """Provider for parallel pattern across codebooks. + This pattern provider is a special case of the delayed pattern with actually no delay, + hence delays=repeat(0, n_q). + + Args: + n_q (int): Number of codebooks. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, empty_initial: int = 0): + super().__init__(n_q, [0] * n_q, empty_initial=empty_initial) + + +class UnrolledPatternProvider(CodebooksPatternProvider): + """Provider for unrolling codebooks pattern. + This pattern provider enables to represent the codebook flattened completely or only to some extend + while also specifying a given delay between the flattened codebooks representation, allowing to + unroll the codebooks in the sequence. + + Example: + 1. Flattening of the codebooks. + By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), + taking n_q = 3 and timesteps = 4: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, 1, S, S, 2, S, S, 3, S, S, 4], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 2. Partial flattening of the codebooks. 
The ``flattening`` parameter allows to specify the inner step + for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example + taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks + allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the + same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] + and delays = [0, 3, 3]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, S, 1, S, 2, S, 3, S, 4], + [S, S, S, 1, S, 2, S, 3, S, 4], + [1, 2, 3, S, 4, S, 5, S, 6, S]] + + Args: + n_q (int): Number of codebooks. + flattening (list of int, optional): Flattening schema over the codebooks. If not defined, + the codebooks will be flattened to 1 codebook per step, meaning that the sequence will + have n_q extra steps for each timestep. + delays (list of int, optional): Delay for each of the codebooks. If not defined, + no delay is added and therefore will default to [0] * ``n_q``. + Note that two codebooks that will be flattened to the same inner step + should have the same delay, otherwise the pattern is considered as invalid. + """ + FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) + + def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, + delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if flattening is None: + flattening = list(range(n_q)) + if delays is None: + delays = [0] * n_q + assert len(flattening) == n_q + assert len(delays) == n_q + assert sorted(flattening) == flattening + assert sorted(delays) == delays + self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) + self.max_delay = max(delays) + + def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]): + """Build a flattened codebooks representation as a dictionary of inner step + and the actual codebook indices corresponding to the flattened codebook. For convenience, we + also store the delay associated to the flattened codebook to avoid maintaining an extra mapping. + """ + flattened_codebooks: dict = {} + for q, (inner_step, delay) in enumerate(zip(flattening, delays)): + if inner_step not in flattened_codebooks: + flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) + else: + flat_codebook = flattened_codebooks[inner_step] + assert flat_codebook.delay == delay, ( + "Delay and flattening between codebooks is inconsistent: ", + "two codebooks flattened to the same position should have the same delay." + ) + flat_codebook.codebooks.append(q) + flattened_codebooks[inner_step] = flat_codebook + return flattened_codebooks + + @property + def _num_inner_steps(self): + """Number of inner steps to unroll between timesteps in order to flatten the codebooks. + """ + return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 + + def num_virtual_steps(self, timesteps: int) -> int: + return timesteps * self._num_inner_steps + 1 + + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern for delay across codebooks. + + Args: + timesteps (int): Total number of timesteps. 
+ """ + # the PatternLayout is built as a tuple of sequence position and list of coordinates + # so that it can be reordered properly given the required delay between codebooks of given timesteps + indexed_out: list = [(-1, [])] + max_timesteps = timesteps + self.max_delay + for t in range(max_timesteps): + # for each timestep, we unroll the flattened codebooks, + # emitting the sequence step with the corresponding delay + for step in range(self._num_inner_steps): + if step in self._flattened_codebooks: + # we have codebooks at this virtual step to emit + step_codebooks = self._flattened_codebooks[step] + t_for_q = t + step_codebooks.delay + coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks] + if t_for_q < max_timesteps and t < max_timesteps: + indexed_out.append((t_for_q, coords)) + else: + # there is no codebook in this virtual step so we emit an empty list + indexed_out.append((t, [])) + out = [coords for _, coords in sorted(indexed_out)] + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class CoarseFirstPattern(CodebooksPatternProvider): + """First generates all the codebooks #1 (e.g. coarser), then the remaining ones, + potentially with delays. + + ..Warning:: You must always generate the full training duration at test time, for instance, + 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected + location. This is due to the non causality of the remaining codebooks with respect to + the first ones. + + Args: + n_q (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if delays is None: + delays = [0] * (n_q - 1) + self.delays = delays + assert len(self.delays) == self.n_q - 1 + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for t in range(timesteps): + out.append([LayoutCoord(t, 0)]) + max_delay = max(self.delays) + for t in range(timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= 0: + v.append(LayoutCoord(t_for_q, q + 1)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class MusicLMPattern(CodebooksPatternProvider): + """Almost MusicLM style pattern. This is equivalent to full flattening + but in a different order. + + Args: + n_q (int): Number of codebooks. + group_by (int): Number of codebooks to group together. + """ + def __init__(self, n_q: int, group_by: int = 2): + super().__init__(n_q) + self.group_by = group_by + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for offset in range(0, self.n_q, self.group_by): + for t in range(timesteps): + for q in range(offset, offset + self.group_by): + out.append([LayoutCoord(t, q)]) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) diff --git a/audiocraft/conditioners.py b/audiocraft/conditioners.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1d3a7718b0145ff2e103e9b8298ba1263d316a --- /dev/null +++ b/audiocraft/conditioners.py @@ -0,0 +1,1409 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
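# --- Editor's illustration (not part of the patch) --------------------------
# A minimal sketch of how the pattern providers defined in the codebook-pattern
# file above can be used to interleave and revert codebook sequences. It assumes
# the module is importable as `audiocraft.codebooks_patterns` (matching the file
# layout of this patch) and that `Pattern.build_pattern_sequence`, defined
# earlier in that file, is available; shapes are for illustration only.
import torch
from audiocraft.codebooks_patterns import DelayedPatternProvider

provider = DelayedPatternProvider(n_q=3)        # default delays: [0, 1, 2]
pattern = provider.get_pattern(timesteps=4)
codes = torch.arange(1, 13).view(1, 3, 4)       # [B=1, K=3, T=4] codebook tokens
seq, _, _ = pattern.build_pattern_sequence(codes, special_token=0)
# seq is [B, K, S]; for this provider S should be 1 (initial special step) + T + max(delays)
reverted, _, mask = pattern.revert_pattern_sequence(seq, special_token=0)
# reverted is [B, K, T] and matches `codes` wherever `mask` is set
# -----------------------------------------------------------------------------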
+ +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field +from itertools import chain +import logging +import math +from pathlib import Path +import random +import re +import typing as tp +import warnings +import soundfile +import einops +from num2words import num2words +import spacy +from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from .streaming import StreamingModule + +from .chroma import ChromaExtractor +from .streaming import StreamingModule +from .transformer import create_sin_embedding + + +from .quantization import ResidualVectorQuantizer +from .utils.autocast import TorchAutocast +from .utils.cache import EmbeddingCache +from .utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once + + +logger = logging.getLogger(__name__) +TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist) +ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask + + +class WavCondition(tp.NamedTuple): + wav: torch.Tensor + length: torch.Tensor + sample_rate: tp.List[int] + path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] + + +class JointEmbedCondition(tp.NamedTuple): + wav: torch.Tensor + text: tp.List[tp.Optional[str]] + length: torch.Tensor + sample_rate: tp.List[int] + path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] + + +@dataclass +class ConditioningAttributes: + text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) + wav: tp.Dict[str, WavCondition] = field(default_factory=dict) + joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) + + def __getitem__(self, item): + return getattr(self, item) + + @property + def text_attributes(self): + return self.text.keys() + + @property + def wav_attributes(self): + return self.wav.keys() + + @property + def joint_embed_attributes(self): + return self.joint_embed.keys() + + @property + def attributes(self): + return { + "text": self.text_attributes, + "wav": self.wav_attributes, + "joint_embed": self.joint_embed_attributes, + } + + def to_flat_dict(self): + return { + **{f"text.{k}": v for k, v in self.text.items()}, + **{f"wav.{k}": v for k, v in self.wav.items()}, + **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()} + } + + @classmethod + def from_flat_dict(cls, x): + out = cls() + for k, v in x.items(): + kind, att = k.split(".") + out[kind][att] = v + return out + + + + + +def nullify_condition(condition: ConditionType, dim: int = 1): + """Transform an input condition to a null condition. + The way it is done by converting it to a single zero vector similarly + to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. + + Args: + condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor]) + dim (int): The dimension that will be truncated (should be the time dimension) + WARNING!: dim should not be the batch dimension! + Returns: + ConditionType: A tuple of null condition and mask + """ + assert dim != 0, "dim cannot be the batch dimension!" + assert isinstance(condition, tuple) and \ + isinstance(condition[0], torch.Tensor) and \ + isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!" 
+ cond, mask = condition + B = cond.shape[0] + last_dim = cond.dim() - 1 + out = cond.transpose(dim, last_dim) + out = 0. * out[..., :1] + out = out.transpose(dim, last_dim) + mask = torch.zeros((B, 1), device=out.device).int() + assert cond.dim() == out.dim() + return out, mask + + +def nullify_wav(cond: WavCondition) -> WavCondition: + """Transform a WavCondition to a nullified WavCondition. + It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes. + + Args: + cond (WavCondition): Wav condition with wav, tensor of shape [B, T]. + Returns: + WavCondition: Nullified wav condition. + """ + null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1) + return WavCondition( + wav=null_wav, + length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device), + sample_rate=cond.sample_rate, + path=[None] * cond.wav.shape[0], + seek_time=[None] * cond.wav.shape[0], + ) + + +def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: + """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0, + and replacing metadata by dummy attributes. + + Args: + cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T]. + """ + null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1) + return JointEmbedCondition( + wav=null_wav, text=[None] * len(embed.text), + length=torch.LongTensor([0]).to(embed.wav.device), + sample_rate=embed.sample_rate, + path=[None] * embed.wav.shape[0], + seek_time=[0] * embed.wav.shape[0], + ) + + +class Tokenizer: + """Base tokenizer implementation + (in case we want to introduce more advances tokenizers in the future). + """ + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + +class WhiteSpaceTokenizer(Tokenizer): + """This tokenizer should be used for natural language descriptions. + For example: + ["he didn't, know he's going home.", 'shorter sentence'] => + [[78, 62, 31, 4, 78, 25, 19, 34], + [59, 77, 0, 0, 0, 0, 0, 0]] + """ + PUNCTUATION = "?:!.,;" + + def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm", + lemma: bool = True, stopwords: bool = True) -> None: + self.n_bins = n_bins + self.pad_idx = pad_idx + self.lemma = lemma + self.stopwords = stopwords + try: + self.nlp = spacy.load(language) + except IOError: + spacy.cli.download(language) # type: ignore + self.nlp = spacy.load(language) + + @tp.no_type_check + def __call__(self, texts: tp.List[tp.Optional[str]], + return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Take a list of strings and convert them to a tensor of indices. + + Args: + texts (list[str]): List of strings. + return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False. + Returns: + tuple[torch.Tensor, torch.Tensor]: + - Indices of words in the LUT. 
+ - And a mask indicating where the padding tokens are + """ + output, lengths = [], [] + texts = deepcopy(texts) + for i, text in enumerate(texts): + # if current sample doesn't have a certain attribute, replace with pad token + if text is None: + output.append(torch.Tensor([self.pad_idx])) + lengths.append(0) + continue + + # convert numbers to words + text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore + # normalize text + text = self.nlp(text) # type: ignore + # remove stopwords + if self.stopwords: + text = [w for w in text if not w.is_stop] # type: ignore + # remove punctuation + text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore + # lemmatize if needed + text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore + + texts[i] = " ".join(text) + lengths.append(len(text)) + # convert to tensor + tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text]) + output.append(tokens) + + mask = length_to_mask(torch.IntTensor(lengths)).int() + padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t() + if return_text: + return padded_output, mask, texts # type: ignore + return padded_output, mask + + +class NoopTokenizer(Tokenizer): + """This tokenizer should be used for global conditioners such as: artist, genre, key, etc. + The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split + strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will + split it to ["Jeff", "Buckley"] and return an index per word. + + For example: + ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101] + ["Metal", "Rock", "Classical"] => [0, 223, 51] + """ + def __init__(self, n_bins: int, pad_idx: int = 0): + self.n_bins = n_bins + self.pad_idx = pad_idx + + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + output, lengths = [], [] + for text in texts: + # if current sample doesn't have a certain attribute, replace with pad token + if text is None: + output.append(self.pad_idx) + lengths.append(0) + else: + output.append(hash_trick(text, self.n_bins)) + lengths.append(1) + + tokens = torch.LongTensor(output).unsqueeze(1) + mask = length_to_mask(torch.IntTensor(lengths)).int() + return tokens, mask + + +class BaseConditioner(nn.Module): + """Base model for all conditioner modules. + We allow the output dim to be different than the hidden dim for two reasons: + 1) keep our LUTs small when the vocab is large; + 2) make all condition dims consistent. + + Args: + dim (int): Hidden dim of the model. + output_dim (int): Output dim of the conditioner. + """ + def __init__(self, dim: int, output_dim: int): + super().__init__() + self.dim = dim + self.output_dim = output_dim + self.output_proj = nn.Linear(dim, output_dim) + + def tokenize(self, *args, **kwargs) -> tp.Any: + """Should be any part of the processing that will lead to a synchronization + point, e.g. BPE tokenization with transfer to the GPU. + + The returned value will be saved and return later when calling forward(). + """ + raise NotImplementedError() + + def forward(self, inputs: tp.Any) -> ConditionType: + """Gets input that should be used as conditioning (e.g, genre, description or a waveform). + Outputs a ConditionType, after the input data was embedded as a dense vector. + + Returns: + ConditionType: + - A tensor of size [B, T, D] where B is the batch size, T is the length of the + output embedding and D is the dimension of the embedding. 
+ - And a mask indicating where the padding tokens. + """ + raise NotImplementedError() + + +class TextConditioner(BaseConditioner): + ... + + +class LUTConditioner(TextConditioner): + """Lookup table TextConditioner. + + Args: + n_bins (int): Number of bins. + dim (int): Hidden dim of the model (text-encoder/LUT). + output_dim (int): Output dim of the conditioner. + tokenizer (str): Name of the tokenizer. + pad_idx (int, optional): Index for padding token. Defaults to 0. + """ + def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0): + super().__init__(dim, output_dim) + self.embed = nn.Embedding(n_bins, dim) + self.tokenizer: Tokenizer + if tokenizer == 'whitespace': + self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx) + elif tokenizer == 'noop': + self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx) + else: + raise ValueError(f"unrecognized tokenizer `{tokenizer}`.") + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + device = self.embed.weight.device + tokens, mask = self.tokenizer(x) + tokens, mask = tokens.to(device), mask.to(device) + return tokens, mask + + def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType: + tokens, mask = inputs + embeds = self.embed(tokens) + embeds = self.output_proj(embeds) + embeds = (embeds * mask.unsqueeze(-1)) + return embeds, mask + + +class T5Conditioner(TextConditioner): + """T5-based TextConditioner. + + Args: + name (str): Name of the T5 model. + output_dim (int): Output dim of the conditioner. + finetune (bool): Whether to fine-tune T5 at train time. + device (str): Device for T5 Conditioner. + autocast_dtype (tp.Optional[str], optional): Autocast dtype. + word_dropout (float, optional): Word dropout probability. + normalize_text (bool, optional): Whether to apply text normalization. + """ + MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl"] + MODELS_DIMS = { + "t5-small": 512, + "t5-base": 768, + "t5-large": 1024, + "t5-3b": 1024, + "t5-11b": 1024, + "google/flan-t5-small": 512, + "google/flan-t5-base": 768, + "google/flan-t5-large": 1024, + "google/flan-t5-3b": 1024, + "google/flan-t5-11b": 1024, + } + + def __init__(self, name: str, output_dim: int, finetune: bool, device: str, + autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0., + normalize_text: bool = False): + assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})" + super().__init__(self.MODELS_DIMS[name], output_dim) + self.device = device + self.name = name + self.finetune = finetune + self.word_dropout = word_dropout + if autocast_dtype is None or self.device == 'cpu': + self.autocast = TorchAutocast(enabled=False) + if self.device != 'cpu': + logger.warning("T5 has no autocast, this might lead to NaN") + else: + dtype = getattr(torch, autocast_dtype) + assert isinstance(dtype, torch.dtype) + logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}") + self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) + # Let's disable logging temporarily because T5 will vomit some errors otherwise. 
+ # thanks https://gist.github.com/simon-weber/7853144 + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.t5_tokenizer = T5Tokenizer.from_pretrained(name) + t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune) + finally: + logging.disable(previous_level) + if finetune: + self.t5 = t5 + else: + # this makes sure that the t5 models is not part + # of the saved checkpoint + self.__dict__['t5'] = t5.to(device) + + self.normalize_text = normalize_text + if normalize_text: + self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True) + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: + # if current sample doesn't have a certain attribute, replace with empty string + entries: tp.List[str] = [xi if xi is not None else "" for xi in x] + if self.normalize_text: + _, _, entries = self.text_normalizer(entries, return_text=True) + if self.word_dropout > 0. and self.training: + new_entries = [] + for entry in entries: + words = [word for word in entry.split(" ") if random.random() >= self.word_dropout] + new_entries.append(" ".join(words)) + entries = new_entries + + empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""]) + + inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device) + mask = inputs['attention_mask'] + mask[empty_idx, :] = 0 # zero-out index where the input is non-existant + return inputs + + def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: + mask = inputs['attention_mask'] + with torch.set_grad_enabled(self.finetune), self.autocast: + embeds = self.t5(**inputs).last_hidden_state + embeds = self.output_proj(embeds.to(self.output_proj.weight)) + embeds = (embeds * mask.unsqueeze(-1)) + return embeds, mask + + +class WaveformConditioner(BaseConditioner): + """Base class for all conditioners that take a waveform as input. + Classes that inherit must implement `_get_wav_embedding` that outputs + a continuous tensor, and `_downsampling_factor` that returns the down-sampling + factor of the embedding model. + + Args: + dim (int): The internal representation dimension. + output_dim (int): Output dimension. + device (tp.Union[torch.device, str]): Device. + """ + def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]): + super().__init__(dim, output_dim) + self.device = device + # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample. + self._use_masking = True + + def tokenize(self, x: WavCondition) -> WavCondition: + wav, length, sample_rate, path, seek_time = x + assert length is not None + return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time) + + def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: + """Gets as input a WavCondition and returns a dense embedding.""" + raise NotImplementedError() + + def _downsampling_factor(self): + """Returns the downsampling factor of the embedding model.""" + raise NotImplementedError() + + def forward(self, x: WavCondition) -> ConditionType: + """Extract condition embedding and mask from a waveform and its metadata. + Args: + x (WavCondition): Waveform condition containing raw waveform and metadata. 
+ Returns: + ConditionType: a dense vector representing the conditioning along with its mask + """ + wav, lengths, *_ = x + with torch.no_grad(): + embeds = self._get_wav_embedding(x) + embeds = embeds.to(self.output_proj.weight) + embeds = self.output_proj(embeds) + + if lengths is not None and self._use_masking: + lengths = lengths / self._downsampling_factor() + mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore + else: + mask = torch.ones_like(embeds[..., 0]) + embeds = (embeds * mask.unsqueeze(-1)) + return embeds, mask + + +class ChromaStemConditioner(WaveformConditioner): + """Chroma conditioner based on stems. + The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as + the drums and bass often dominate the chroma leading to the chroma features + not containing information about the melody. + + Args: + output_dim (int): Output dimension for the conditioner. + sample_rate (int): Sample rate for the chroma extractor. + n_chroma (int): Number of chroma bins for the chroma extractor. + radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12). + duration (int): duration used during training. This is later used for correct padding + in case we are using chroma as prefix. + match_len_on_eval (bool, optional): if True then all chromas are padded to the training + duration. Defaults to False. + eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as + conditions during eval (for cases where we don't want to leak test conditions like MusicCaps). + Defaults to None. + n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0. + device (tp.Union[torch.device, str], optional): Device for the conditioner. + **kwargs: Additional parameters for the chroma extractor. + """ + def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int, + duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None, + n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None, + device: tp.Union[torch.device, str] = 'cpu', **kwargs): + from demucs import pretrained + super().__init__(dim=n_chroma, output_dim=output_dim, device=device) + self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32) + self.sample_rate = sample_rate + self.match_len_on_eval = match_len_on_eval + if match_len_on_eval: + self._use_masking = False + self.duration = duration + self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device) + stem_sources: list = self.demucs.sources # type: ignore + self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device) + self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, + radix2_exp=radix2_exp, **kwargs).to(device) + self.chroma_len = self._get_chroma_len() + self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs) + self.cache = None + if cache_path is not None: + self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, + compute_embed_fn=self._get_full_chroma_for_cache, + extract_embed_fn=self._extract_chroma_chunk) + + def _downsampling_factor(self) -> int: + return self.chroma.winhop + + def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]: + """Load pre-defined waveforms from a json. + These waveforms will be used for chroma extraction during evaluation. 
+ This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps). + """ + if path is None: + return None + + logger.info(f"Loading evaluation wavs from {path}") + from audiocraft.data.audio_dataset import AudioDataset + dataset: AudioDataset = AudioDataset.from_meta( + path, segment_duration=self.duration, min_audio_duration=self.duration, + sample_rate=self.sample_rate, channels=1) + + if len(dataset) > 0: + eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device) + logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner") + return eval_wavs + else: + raise ValueError("Could not find evaluation wavs, check lengths of wavs") + + def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None: + self.eval_wavs = eval_wavs + + def has_eval_wavs(self) -> bool: + return self.eval_wavs is not None + + def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor: + """Sample wavs from a predefined list.""" + assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided." + total_eval_wavs = len(self.eval_wavs) + out = self.eval_wavs + if num_samples > total_eval_wavs: + out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1) + return out[torch.randperm(len(out))][:num_samples] + + def _get_chroma_len(self) -> int: + """Get length of chroma during training.""" + dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device) + dummy_chr = self.chroma(dummy_wav) + return dummy_chr.shape[1] + + @torch.no_grad() + def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Get parts of the wav that holds the melody, extracting the main stems from the wav.""" + from demucs.apply import apply_model + from demucs.audio import convert_audio + with self.autocast: + wav = convert_audio( + wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore + stems = apply_model(self.demucs, wav, device=self.device) + stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning + mix_wav = stems.sum(1) # merge extracted stems to single waveform + mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore + return mix_wav + + @torch.no_grad() + def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor: + """Extract chroma features from the waveform.""" + with self.autocast: + return self.chroma(wav) + + @torch.no_grad() + def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Compute wav embedding, applying stem and chroma extraction.""" + # avoid 0-size tensors when we are working with null conds + if wav.shape[-1] == 1: + return self._extract_chroma(wav) + stems = self._get_stemmed_wav(wav, sample_rate) + chroma = self._extract_chroma(stems) + return chroma + + @torch.no_grad() + def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor: + """Extract chroma from the whole audio waveform at the given path.""" + wav, sr = soundfile.read(path) + wav = wav[None].to(self.device) + wav = convert_audio(wav, sr, self.sample_rate, to_channels=1) + chroma = self._compute_wav_embedding(wav, self.sample_rate)[0] + return chroma + + def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor: + """Extract a chunk of chroma from the full chroma derived from the full waveform.""" + wav_length = x.wav.shape[-1] + seek_time = 
x.seek_time[idx] + assert seek_time is not None, ( + "WavCondition seek_time is required " + "when extracting chroma chunks from pre-computed chroma.") + full_chroma = full_chroma.float() + frame_rate = self.sample_rate / self._downsampling_factor() + target_length = int(frame_rate * wav_length / self.sample_rate) + index = int(frame_rate * seek_time) + out = full_chroma[index: index + target_length] + out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0] + return out.to(self.device) + + @torch.no_grad() + def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: + """Get the wav embedding from the WavCondition. + The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly + or will rely on the embedding cache to load the pre-computed embedding if relevant. + """ + sampled_wav: tp.Optional[torch.Tensor] = None + if not self.training and self.eval_wavs is not None: + warn_once(logger, "Using precomputed evaluation wavs!") + sampled_wav = self._sample_eval_wavs(len(x.wav)) + + no_undefined_paths = all(p is not None for p in x.path) + no_nullified_cond = x.wav.shape[-1] > 1 + if sampled_wav is not None: + chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate) + elif self.cache is not None and no_undefined_paths and no_nullified_cond: + paths = [Path(p) for p in x.path if p is not None] + chroma = self.cache.get_embed_from_cache(paths, x) + else: + assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal." + chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0]) + + if self.match_len_on_eval: + B, T, C = chroma.shape + if T > self.chroma_len: + chroma = chroma[:, :self.chroma_len] + logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})") + elif T < self.chroma_len: + n_repeat = int(math.ceil(self.chroma_len / T)) + chroma = chroma.repeat(1, n_repeat, 1) + chroma = chroma[:, :self.chroma_len] + logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})") + + return chroma + + def tokenize(self, x: WavCondition) -> WavCondition: + """Apply WavConditioner tokenization and populate cache if needed.""" + x = super().tokenize(x) + no_undefined_paths = all(p is not None for p in x.path) + if self.cache is not None and no_undefined_paths: + paths = [Path(p) for p in x.path if p is not None] + self.cache.populate_embed_cache(paths, x) + return x + + +class JointEmbeddingConditioner(BaseConditioner): + """Joint embedding conditioning supporting both audio or text conditioning. + + Args: + dim (int): Dimension. + output_dim (int): Output dimension. + device (str): Device. + attribute (str): Attribute used by the conditioner. + autocast_dtype (str): Autocast for the conditioner. + quantize (bool): Whether to quantize the CLAP embedding. + n_q (int): Number of residual quantizers (used if quantize is true). + bins (int): Quantizers' codebooks size (used if quantize is true). + kwargs: Additional parameters for residual vector quantizer. 
+ """ + def __init__(self, dim: int, output_dim: int, device: str, attribute: str, + autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True, + n_q: int = 12, bins: int = 1024, **kwargs): + super().__init__(dim=dim, output_dim=output_dim) + self.device = device + self.attribute = attribute + if autocast_dtype is None or device == 'cpu': + self.autocast = TorchAutocast(enabled=False) + logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.") + else: + dtype = getattr(torch, autocast_dtype) + assert isinstance(dtype, torch.dtype) + logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.") + self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) + # residual vector quantizer to discretize the conditioned embedding + self.quantizer: tp.Optional[ResidualVectorQuantizer] = None + if quantize: + self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs) + + def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Get joint embedding in latent space from the inputs. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding + and corresponding empty indexes. + """ + raise NotImplementedError() + + def forward(self, x: JointEmbedCondition) -> ConditionType: + with self.autocast: + embed, empty_idx = self._get_embed(x) + if self.quantizer is not None: + embed = embed.view(-1, self.dim, 1) + q_res = self.quantizer(embed, frame_rate=1) + out_embed = q_res.x.view(-1, self.dim) + else: + out_embed = embed + out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim) + mask = torch.ones(*out_embed.shape[:2], device=out_embed.device) + mask[empty_idx, :] = 0 # zero-out index where the input is non-existant + out_embed = (out_embed * mask.unsqueeze(-1)) + return out_embed, mask + + def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: + return x + + +class CLAPEmbeddingConditioner(JointEmbeddingConditioner): + """Joint Embedding conditioner based on pre-trained CLAP model. + + This CLAP-based conditioner supports a caching mechanism + over the computed embeddings for faster training. + + Args: + dim (int): Dimension. + output_dim (int): Output dimension. + device (str): Device. + attribute (str): Attribute used by the conditioner. + quantize (bool): Whether to quantize the CLAP embedding. + n_q (int): Number of residual quantizers (used if quantize is true). + bins (int): Quantizers' codebooks size (used if quantize is true). + checkpoint (str): Path to CLAP checkpoint. + model_arch (str): CLAP model architecture. + enable_fusion (bool): Enable fusion for CLAP model. + sample_rate (int): Sample rate used by CLAP model. + max_audio_length (float): Maximum audio length for CLAP model. + audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence. + normalize (bool): Whether to normalize the CLAP embedding. + text_p (float): Probability of using text representation instead of audio at train time. + batch_size (Optional[int]): Batch size for CLAP embedding computation. + autocast_dtype (str): Autocast for the conditioner. + cache_path (Optional[str]): Path for pre-computed embeddings caching. + kwargs: Additional parameters for residual vector quantizer. 
+ """ + def __init__(self, dim: int, output_dim: int, device: str, attribute: str, + quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str, + enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int, + normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None, + autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs): + try: + import laion_clap # type: ignore + except ImportError: + raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'") + warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). " + "Please retrain all models.") + # checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint) + clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base') + clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) + load_clap_state_dict(clap_model, checkpoint) + clap_model.eval() + clap_model.to(device) + super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute, + autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins, + **kwargs) + self.checkpoint = checkpoint + self.enable_fusion = enable_fusion + self.model_arch = model_arch + self.clap: laion_clap.CLAP_Module + self.clap_tokenize: RobertaTokenizer + self.clap_sample_rate = sample_rate + self.clap_max_frames = int(self.clap_sample_rate * max_audio_length) + self.clap_stride = int(self.clap_sample_rate * audio_stride) + self.batch_size = batch_size or 1 + self.normalize = normalize + self.text_p = text_p + self.__dict__['clap_tokenize'] = clap_tokenize + self.__dict__['clap'] = clap_model + self.wav_cache, self.text_cache = None, None + if cache_path is not None: + self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, + compute_embed_fn=self._get_wav_embedding_for_cache, + extract_embed_fn=self._extract_wav_embedding_chunk) + self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device, + compute_embed_fn=self._get_text_embedding_for_cache) + + def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: + # we use the default params from CLAP module here as well + return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") + + def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor: + """Compute text embedding from CLAP model on a given a batch of text. + + Args: + text (list[str]): List of text for the batch, with B items. + Returns: + torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension. + """ + with torch.no_grad(): + embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) + return embed.view(embed.size(0), 1, embed.size(-1)) + + def _get_text_embedding_for_cache(self, path: tp.Union[Path, str], + x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Get text embedding function for the cache.""" + text = x.text[idx] + text = text if text is not None else "" + return self._compute_text_embedding([text])[0] + + def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor: + """Preprocess wav to expected format by CLAP model. + + Args: + wav (torch.Tensor): Audio wav, of shape [B, C, T]. + length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. 
+ sample_rates (list[int]): Sample rates for each sample in the batch + Returns: + torch.Tensor: Audio wav of shape [B, T]. + """ + assert wav.dim() == 3, "Expecting wav to be [B, C, T]" + if sample_rates is not None: + _wav = [] + for i, audio in enumerate(wav): + sr = sample_rates[i] + audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1) + _wav.append(audio) + wav = torch.stack(_wav, dim=0) + wav = wav.mean(dim=1) + return wav + + def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor, + sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor: + """Compute audio wave embedding from CLAP model. + + Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences, + we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and + average the resulting embeddings. + + Args: + wav (torch.Tensor): Audio wav, of shape [B, C, T]. + length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. + sample_rates (list[int]): Sample rates for each sample in the batch. + reduce_mean (bool): Whether to get the average tensor. + Returns: + torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension. + """ + with torch.no_grad(): + wav = self._preprocess_wav(wav, length, sample_rates) + B, T = wav.shape + if T >= self.clap_max_frames: + wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T] + else: + wav = wav.view(-1, 1, T) # [B, F, T] with F=1 + wav = einops.rearrange(wav, 'b f t -> (b f) t') + embed_list = [] + for i in range(0, wav.size(0), self.batch_size): + _wav = wav[i:i+self.batch_size, ...] + _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True) + embed_list.append(_embed) + embed = torch.cat(embed_list, dim=0) + embed = einops.rearrange(embed, '(b f) d -> b f d', b=B) + if reduce_mean: + embed = embed.mean(dim=1, keepdim=True) + return embed # [B, F, D] with F=1 if reduce_mean is True + + def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path], + x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Compute audio wave embedding for the cache. + The embedding is computed on a given audio read from file. + + Args: + path (str or Path): Path to the full audio file. + Returns: + torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension. + """ + wav, sr = soundfile.read(path) # [C, T] + wav = wav.unsqueeze(0).to(self.device) # [1, C, T] + wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device) + embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D] + return embed.squeeze(0) # [F, D] + + def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding. + + Args: + full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D]. + x (JointEmbedCondition): Joint embedding condition for the full batch. + idx (int): Index considered for the given embedding to extract. + Returns: + torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D]. + """ + sample_rate = x.sample_rate[idx] + seek_time = x.seek_time[idx] + seek_time = 0. 
if seek_time is None else seek_time + clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate + end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate + start_offset = int(seek_time * sample_rate // clap_stride) + end_offset = int(end_seek_time * sample_rate // clap_stride) + wav_embed = full_embed[start_offset:end_offset, ...] + wav_embed = wav_embed.mean(dim=0, keepdim=True) + return wav_embed.to(self.device) # [F, D] + + def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor: + """Get CLAP embedding from a batch of text descriptions.""" + no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout + if self.text_cache is not None and no_nullified_cond: + assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + embed = self.text_cache.get_embed_from_cache(paths, x) + else: + text = [xi if xi is not None else "" for xi in x.text] + embed = self._compute_text_embedding(text) + if self.normalize: + embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) + return embed + + def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor: + """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates).""" + no_undefined_paths = all(p is not None for p in x.path) + no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout + if self.wav_cache is not None and no_undefined_paths and no_nullified_cond: + paths = [Path(p) for p in x.path if p is not None] + embed = self.wav_cache.get_embed_from_cache(paths, x) + else: + embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True) + if self.normalize: + embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) + return embed + + def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: + # Trying to limit as much as possible sync points when the cache is warm. + no_undefined_paths = all(p is not None for p in x.path) + if self.wav_cache is not None and no_undefined_paths: + assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + self.wav_cache.populate_embed_cache(paths, x) + if self.text_cache is not None and no_undefined_paths: + assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + self.text_cache.populate_embed_cache(paths, x) + return x + + def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Extract shared latent representation from either the wav or the text using CLAP.""" + # decide whether to use text embedding at train time or not + use_text_embed = random.random() < self.text_p + if self.training and not use_text_embed: + embed = self._get_wav_embedding(x) + empty_idx = torch.LongTensor([]) # we assume we always have the audio wav + else: + embed = self._get_text_embedding(x) + empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""]) + return embed, empty_idx + + +def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes: + """Utility function for nullifying an attribute inside an ConditioningAttributes object. + If the condition is of type "wav", then nullify it using `nullify_condition` function. 
+ If the condition is of any other type, set its value to None. + Works in-place. + """ + if condition_type not in ['text', 'wav', 'joint_embed']: + raise ValueError( + "dropout_condition got an unexpected condition type!" + f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'" + ) + + if condition not in getattr(sample, condition_type): + raise ValueError( + "dropout_condition received an unexpected condition!" + f" expected wav={sample.wav.keys()} and text={sample.text.keys()}" + f" but got '{condition}' of type '{condition_type}'!" + ) + + if condition_type == 'wav': + wav_cond = sample.wav[condition] + sample.wav[condition] = nullify_wav(wav_cond) + elif condition_type == 'joint_embed': + embed = sample.joint_embed[condition] + sample.joint_embed[condition] = nullify_joint_embed(embed) + else: + sample.text[condition] = None + + return sample + + +class DropoutModule(nn.Module): + """Base module for all dropout modules.""" + def __init__(self, seed: int = 1234): + super().__init__() + self.rng = torch.Generator() + self.rng.manual_seed(seed) + + +class AttributeDropout(DropoutModule): + """Dropout with a given probability per attribute. + This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes + to be dropped out separately. For example, "artist" can be dropped while "genre" remains. + This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre" + must also be dropped. + + Args: + p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example: + ... + "genre": 0.1, + "artist": 0.5, + "wav": 0.25, + ... + active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False. + seed (int, optional): Random seed. + """ + def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234): + super().__init__(seed=seed) + self.active_on_eval = active_on_eval + # construct dict that return the values from p otherwise 0 + self.p = {} + for condition_type, probs in p.items(): + self.p[condition_type] = defaultdict(lambda: 0, probs) + + def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """ + Args: + samples (list[ConditioningAttributes]): List of conditions. + Returns: + list[ConditioningAttributes]: List of conditions after certain attributes were set to None. + """ + if not self.training and not self.active_on_eval: + return samples + + samples = deepcopy(samples) + for condition_type, ps in self.p.items(): # for condition types [text, wav] + for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) + if torch.rand(1, generator=self.rng).item() < p: + for sample in samples: + dropout_condition(sample, condition_type, condition) + return samples + + def __repr__(self): + return f"AttributeDropout({dict(self.p)})" + + +class ClassifierFreeGuidanceDropout(DropoutModule): + """Classifier Free Guidance dropout. + All attributes are dropped with the same probability. + + Args: + p (float): Probability to apply condition dropout during training. + seed (int): Random seed. + """ + def __init__(self, p: float, seed: int = 1234): + super().__init__(seed=seed) + self.p = p + + def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """ + Args: + samples (list[ConditioningAttributes]): List of conditions. + Returns: + list[ConditioningAttributes]: List of conditions after all attributes were set to None. 
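        Example (editor's illustrative sketch, not part of the original patch,
        added to show the intended call pattern; ``samples`` is assumed to be a
        list of already populated ConditioningAttributes)::

            cfg_dropout = ClassifierFreeGuidanceDropout(p=0.3)
            samples = cfg_dropout(samples)
            # In training mode, all attributes of all samples are nullified
            # together with probability p; in eval mode the samples pass
            # through unchanged.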
+ """ + if not self.training: + return samples + + # decide on which attributes to drop in a batched fashion + drop = torch.rand(1, generator=self.rng).item() < self.p + if not drop: + return samples + + # nullify conditions of all attributes + samples = deepcopy(samples) + for condition_type in ["wav", "text"]: + for sample in samples: + for condition in sample.attributes[condition_type]: + dropout_condition(sample, condition_type, condition) + return samples + + def __repr__(self): + return f"ClassifierFreeGuidanceDropout(p={self.p})" + + +class ConditioningProvider(nn.Module): + """Prepare and provide conditions given all the supported conditioners. + + Args: + conditioners (dict): Dictionary of conditioners. + device (torch.device or str, optional): Device for conditioners and output condition types. + """ + def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"): + super().__init__() + self.device = device + self.conditioners = nn.ModuleDict(conditioners) + + @property + def joint_embed_conditions(self): + return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)] + + @property + def has_joint_embed_conditions(self): + return len(self.joint_embed_conditions) > 0 + + @property + def text_conditions(self): + return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] + + @property + def wav_conditions(self): + return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)] + + @property + def has_wav_condition(self): + return len(self.wav_conditions) > 0 + + def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]: + """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. + This should be called before starting any real GPU work to avoid synchronization points. + This will return a dict matching conditioner names to their arbitrary tokenized representations. + + Args: + inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing + text and wav conditions. + """ + assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( + "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", + f" but types were {set([type(x) for x in inputs])}" + ) + + output = {} + text = self._collate_text(inputs) + wavs = self._collate_wavs(inputs) + joint_embeds = self._collate_joint_embeds(inputs) + + assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), ( + f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ", + f"got {text.keys(), wavs.keys(), joint_embeds.keys()}" + ) + + for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()): + output[attribute] = self.conditioners[attribute].tokenize(batch) + return output + + def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]: + """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations. + The output is for example: + { + "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), + "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), + ... + } + + Args: + tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. 
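        Example (editor's illustrative sketch, not part of the original patch;
        ``conditioners`` is assumed to be a dict mapping attribute names to
        BaseConditioner instances and ``attributes`` a list of
        ConditioningAttributes)::

            provider = ConditioningProvider(conditioners, device="cpu")
            tokenized = provider.tokenize(attributes)  # cheap prep, do before heavy GPU work
            conditions = provider(tokenized)           # {attribute: (embedding, mask)}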
+ """ + output = {} + for attribute, inputs in tokenized.items(): + condition, mask = self.conditioners[attribute](inputs) + output[attribute] = (condition, mask) + return output + + def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]: + """Given a list of ConditioningAttributes objects, compile a dictionary where the keys + are the attributes and the values are the aggregated input per attribute. + For example: + Input: + [ + ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...), + ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...), + ] + Output: + { + "genre": ["Rock", "Hip-hop"], + "description": ["A rock song with a guitar solo", "A hip-hop verse"] + } + + Args: + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. + Returns: + dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch. + """ + out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) + texts = [x.text for x in samples] + for text in texts: + for condition in self.text_conditions: + out[condition].append(text[condition]) + return out + + def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]: + """Generate a dict where the keys are attributes by which we fetch similar wavs, + and the values are Tensors of wavs according to said attributes. + + *Note*: by the time the samples reach this function, each sample should have some waveform + inside the "wav" attribute. It should be either: + 1. A real waveform + 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset) + 3. A null waveform due to it being dropped in a dropout module (nullified by dropout) + + Args: + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. + Returns: + dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. + """ + wavs = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) + paths = defaultdict(list) + seek_times = defaultdict(list) + out: tp.Dict[str, WavCondition] = {} + + for sample in samples: + for attribute in self.wav_conditions: + wav, length, sample_rate, path, seek_time = sample.wav[attribute] + assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]" + assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1" + # mono-channel conditioning + wav = wav.mean(1, keepdim=True) # [1, 1, T] + wavs[attribute].append(wav.flatten()) # [T] + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) + + # stack all wavs to a single tensor + for attribute in self.wav_conditions: + stacked_wav, _ = collate(wavs[attribute], dim=0) + out[attribute] = WavCondition( + stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute], + paths[attribute], seek_times[attribute]) + + return out + + def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]: + """Generate a dict where the keys are attributes by which we compute joint embeddings, + and the values are Tensors of pre-computed embeddings and the corresponding text attributes. + + Args: + samples (list[ConditioningAttributes]): List of ConditioningAttributes samples. 
+ Returns: + A dictionary mapping an attribute name to joint embeddings. + """ + texts = defaultdict(list) + wavs = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) + paths = defaultdict(list) + seek_times = defaultdict(list) + channels: int = 0 + + out = {} + for sample in samples: + for attribute in self.joint_embed_conditions: + wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute] + assert wav.dim() == 3 + if channels == 0: + channels = wav.size(1) + else: + assert channels == wav.size(1), "not all audio has same number of channels in batch" + assert wav.size(0) == 1, "Expecting single-wav batch in the collate method" + wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T] + wavs[attribute].append(wav) + texts[attribute].extend(text) + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) + + for attribute in self.joint_embed_conditions: + stacked_texts = texts[attribute] + stacked_paths = paths[attribute] + stacked_seek_times = seek_times[attribute] + stacked_wavs = pad_sequence(wavs[attribute]).to(self.device) + stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels) + stacked_sample_rates = sample_rates[attribute] + stacked_lengths = torch.cat(lengths[attribute]).to(self.device) + assert stacked_lengths.size(0) == stacked_wavs.size(0) + assert len(stacked_sample_rates) == stacked_wavs.size(0) + assert len(stacked_texts) == stacked_wavs.size(0) + out[attribute] = JointEmbedCondition( + text=stacked_texts, wav=stacked_wavs, + length=stacked_lengths, sample_rate=stacked_sample_rates, + path=stacked_paths, seek_time=stacked_seek_times) + + return out + + +class ConditionFuser(StreamingModule): + """Condition fuser handles the logic to combine the different conditions + to the actual model input. + + Args: + fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse + each condition. For example: + { + "prepend": ["description"], + "sum": ["genre", "bpm"], + "cross": ["description"], + } + cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention. + cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used. + """ + FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"] + + def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False, + cross_attention_pos_emb_scale: float = 1.0): + super().__init__() + assert all( + [k in self.FUSING_METHODS for k in fuse2cond.keys()] + ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}" + self.cross_attention_pos_emb = cross_attention_pos_emb + self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale + self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond + self.cond2fuse: tp.Dict[str, str] = {} + for fuse_method, conditions in fuse2cond.items(): + for condition in conditions: + self.cond2fuse[condition] = fuse_method + + def forward( + self, + input: torch.Tensor, + conditions: tp.Dict[str, ConditionType] + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """Fuse the conditions to the provided model input. + + Args: + input (torch.Tensor): Transformer input. + conditions (dict[str, ConditionType]): Dict of conditions. + Returns: + tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input + after the conditions have been fused. 
The second output tensor is the tensor + used for cross-attention or None if no cross attention inputs exist. + """ + B, T, _ = input.shape + + if 'offsets' in self._streaming_state: + first_step = False + offsets = self._streaming_state['offsets'] + else: + first_step = True + offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device) + + assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \ + f"given conditions contain unknown attributes for fuser, " \ + f"expected {self.cond2fuse.keys()}, got {conditions.keys()}" + cross_attention_output = None + for cond_type, (cond, cond_mask) in conditions.items(): + op = self.cond2fuse[cond_type] + if op == 'sum': + input += cond + elif op == 'input_interpolate': + cond = einops.rearrange(cond, "b t d -> b d t") + cond = F.interpolate(cond, size=input.shape[1]) + input += einops.rearrange(cond, "b d t -> b t d") + elif op == 'prepend': + if first_step: + input = torch.cat([cond, input], dim=1) + elif op == 'cross': + if cross_attention_output is not None: + cross_attention_output = torch.cat([cross_attention_output, cond], dim=1) + else: + cross_attention_output = cond + else: + raise ValueError(f"unknown op ({op})") + + if self.cross_attention_pos_emb and cross_attention_output is not None: + positions = torch.arange( + cross_attention_output.shape[1], + device=cross_attention_output.device + ).view(1, -1, 1) + pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1]) + cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb + + if self._is_streaming: + self._streaming_state['offsets'] = offsets + T + + return input, cross_attention_output diff --git a/audiocraft/conv.py b/audiocraft/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..d115cbf8729b642ed78608bd00a4d0fd5afae6fd --- /dev/null +++ b/audiocraft/conv.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs): + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. 
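+
+    Illustrative behaviour (a sketch; the channel counts are arbitrary examples):
+        get_norm_module(nn.Conv1d(4, 8, 3), norm='time_group_norm')  # -> nn.GroupNorm(1, 8)
+        get_norm_module(nn.Conv1d(4, 8, 3), norm='weight_norm')      # -> nn.Identity(), the norm lives on the conv itself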
+ """ + assert norm in CONV_NORMALIZATIONS + if norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. 
+ """ + def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class StreamableConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect'): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).") + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) + + +class StreamableConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. 
+ """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y diff --git a/audiocraft/diffusion_schedule.py b/audiocraft/diffusion_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..74ca6e3f2e7c4ff904d96dade315b0b46856778d --- /dev/null +++ b/audiocraft/diffusion_schedule.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Functions for Noise Schedule, defines diffusion process, reverse process and data processor. +""" + +from collections import namedtuple +import random +import typing as tp +import julius +import torch + +TrainingItem = namedtuple("TrainingItem", "noisy noise step") + + +def betas_from_alpha_bar(alpha_bar): + alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]]) + return 1 - alphas + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + return x + + def return_sample(self, z: torch.Tensor): + """Project back from diffusion space to the actual sample space.""" + return z + + +class MultiBandProcessor(SampleProcessor): + """ + MultiBand sample processor. The input audio is splitted across + frequency bands evenly distributed in mel-scale. + + Each band will be rescaled to match the power distribution + of Gaussian noise in that band, using online metrics + computed on the first few samples. + + Args: + n_bands (int): Number of mel-bands to split the signal over. + sample_rate (int): Sample rate of the audio. + num_samples (int): Number of samples to use to fit the rescaling + for each band. The processor won't be stable + until it has seen that many samples. + power_std (float or list/tensor): The rescaling factor computed to match the + power of Gaussian noise in each band is taken to + that power, i.e. 
`1.` means full correction of the energy + in each band, and values less than `1` means only partial + correction. Can be used to balance the relative importance + of low vs. high freq in typical audio signals. + """ + def __init__(self, n_bands: int = 8, sample_rate: float = 24_000, + num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.): + super().__init__() + self.n_bands = n_bands + self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands) + self.num_samples = num_samples + self.power_std = power_std + if isinstance(power_std, list): + assert len(power_std) == n_bands + power_std = torch.tensor(power_std) + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(n_bands)) + self.register_buffer('sum_x2', torch.zeros(n_bands)) + self.register_buffer('sum_target_x2', torch.zeros(n_bands)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + self.sum_target_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + return std + + @property + def target_std(self): + target_std = self.sum_target_x2 / self.counts + return target_std + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + bands = self.split_bands(x) + if self.counts.item() < self.num_samples: + ref_bands = self.split_bands(torch.randn_like(x)) + self.counts += len(x) + self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1) + self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1) + self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1) + return bands.sum(dim=0) + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + bands = self.split_bands(x) + rescale = (self.std / self.target_std) ** self.power_std + bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1) + return bands.sum(dim=0) + + +class NoiseSchedule: + """Noise schedule for diffusion. + + Args: + beta_t0 (float): Variance of the first diffusion step. + beta_t1 (float): Variance of the last diffusion step. + beta_exp (float): Power schedule exponent + num_steps (int): Number of diffusion step. + variance (str): choice of the sigma value for the denoising eq. 
Choices: "beta" or "beta_tilde" + clip (float): clipping value for the denoising steps + rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1) + repartition (str): shape of the schedule only power schedule is supported + sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution + noise_scale (float): Scaling factor for the noise + """ + def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta', + clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1, + repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None, + sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs): + + self.beta_t0 = beta_t0 + self.beta_t1 = beta_t1 + self.variance = variance + self.num_steps = num_steps + self.clip = clip + self.sample_processor = sample_processor + self.rescale = rescale + self.n_bands = n_bands + self.noise_scale = noise_scale + assert n_bands is None + if repartition == "power": + self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps, + device=device, dtype=torch.float) ** beta_exp + else: + raise RuntimeError('Not implemented') + self.rng = random.Random(1234) + + def get_beta(self, step: tp.Union[int, torch.Tensor]): + if self.n_bands is None: + return self.betas[step] + else: + return self.betas[:, step] # [n_bands, len(step)] + + def get_initial_noise(self, x: torch.Tensor): + if self.n_bands is None: + return torch.randn_like(x) + return torch.randn((x.size(0), self.n_bands, x.size(2))) + + def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor: + """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.""" + if step is None: + return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands + if type(step) is int: + return (1 - self.betas[:step + 1]).prod() + else: + return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1) + + def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem: + """Create a noisy data item for diffusion model training: + + Args: + x (torch.Tensor): clean audio data torch.tensor(bs, 1, T) + tensor_step (bool): If tensor_step = false, only one step t is sample, + the whole batch is diffused to the same step and t is int. + If tensor_step = true, t is a tensor of size (x.size(0),) + every element of the batch is diffused to a independently sampled. + """ + step: tp.Union[int, torch.Tensor] + if tensor_step: + bs = x.size(0) + step = torch.randint(0, self.num_steps, size=(bs,), device=x.device) + else: + step = self.rng.randrange(self.num_steps) + alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1] + + x = self.sample_processor.project_sample(x) + noise = torch.randn_like(x) + noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale + return TrainingItem(noisy, noise, step) + + def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None, + condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): + """Full ddpm reverse process. + + Args: + model (nn.Module): Diffusion model. + initial (tensor): Initial Noise. + condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation). + return_list (bool): Whether to return the whole process or only the sampled point. 
+ """ + alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) + current = initial + iterates = [initial] + for step in range(self.num_steps)[::-1]: + with torch.no_grad(): + estimate = model(current, step, condition=condition).sample + alpha = 1 - self.betas[step] + previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() + previous_alpha_bar = self.get_alpha_bar(step=step - 1) + if step == 0: + sigma2 = 0 + elif self.variance == 'beta': + sigma2 = 1 - alpha + elif self.variance == 'beta_tilde': + sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) + elif self.variance == 'none': + sigma2 = 0 + else: + raise ValueError(f'Invalid variance type {self.variance}') + + if sigma2 > 0: + previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale + if self.clip: + previous = previous.clamp(-self.clip, self.clip) + current = previous + alpha_bar = previous_alpha_bar + if step == 0: + previous *= self.rescale + if return_list: + iterates.append(previous.cpu()) + + if return_list: + return iterates + else: + return self.sample_processor.return_sample(previous) + + def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None, + condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): + """Reverse process that only goes through Markov chain states in step_list.""" + if step_list is None: + step_list = list(range(1000))[::-50] + [0] + alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) + alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu() + betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled) + current = initial * self.noise_scale + iterates = [current] + for idx, step in enumerate(step_list[:-1]): + with torch.no_grad(): + estimate = model(current, step, condition=condition).sample * self.noise_scale + alpha = 1 - betas_subsampled[-1 - idx] + previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() + previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1]) + if step == step_list[-2]: + sigma2 = 0 + previous_alpha_bar = torch.tensor(1.0) + else: + sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) + if sigma2 > 0: + previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale + if self.clip: + previous = previous.clamp(-self.clip, self.clip) + current = previous + alpha_bar = previous_alpha_bar + if step == 0: + previous *= self.rescale + if return_list: + iterates.append(previous.cpu()) + if return_list: + return iterates + else: + return self.sample_processor.return_sample(previous) diff --git a/audiocraft/encodec.py b/audiocraft/encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..2f71be36c64c438b54ebf553cf09360bfb7fdacd --- /dev/null +++ b/audiocraft/encodec.py @@ -0,0 +1,506 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Compression models or wrapper around existing models. +Also defines the main interface that a model must follow to be usable as an audio tokenizer. 
+""" + +from abc import ABC, abstractmethod +import logging +import math +from pathlib import Path +import typing as tp + +from einops import rearrange +import numpy as np +import torch +from torch import nn +from transformers import EncodecModel as HFEncodecModel + +import audiocraft.quantization as qt + + +logger = logging.getLogger() + + +class CompressionModel(ABC, nn.Module): + """Base API for all compression models that aim at being used as audio tokenizers + with a language model. + """ + + @abstractmethod + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + ... + + @abstractmethod + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """See `EncodecModel.encode`.""" + ... + + @abstractmethod + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + """See `EncodecModel.decode`.""" + ... + + @abstractmethod + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + ... + + @property + @abstractmethod + def channels(self) -> int: + ... + + @property + @abstractmethod + def frame_rate(self) -> float: + ... + + @property + @abstractmethod + def sample_rate(self) -> int: + ... + + @property + @abstractmethod + def cardinality(self) -> int: + ... + + @property + @abstractmethod + def num_codebooks(self) -> int: + ... + + @property + @abstractmethod + def total_codebooks(self) -> int: + ... + + @abstractmethod + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer.""" + ... + + @staticmethod + def get_pretrained( + name: str, device: tp.Union[torch.device, str] = 'cpu' + ) -> 'CompressionModel': + """Instantiate a CompressionModel from a given pretrained model. + + Args: + name (Path or str): name of the pretrained model. See after. + device (torch.device or str): Device on which the model is loaded. + + Pretrained models: + - dac_44khz (https://github.com/descriptinc/descript-audio-codec) + - dac_24khz (same) + - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz) + - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz) + - your own model on Hugging Face. Export instructions to come... + """ + + from . import builders, loaders + model: CompressionModel + if name in ['dac_44khz', 'dac_24khz']: + model_type = name.split('_')[1] + logger.info("Getting pretrained compression model from DAC %s", model_type) + model = DAC(model_type) + elif name in ['debug_compression_model']: + logger.info("Getting pretrained compression model for debug") + model = builders.get_debug_compression_model() + elif Path(name).exists(): + # We assume here if the path exists that it is in fact an AC checkpoint + # that was exported using `audiocraft.utils.export` functions. + model = loaders.load_compression_model(name, device=device) + else: + logger.info("Getting pretrained compression model from HF %s", name) + hf_model = HFEncodecModel.from_pretrained(name) + model = HFEncodecCompressionModel(hf_model).to(device) + return model.to(device).eval() + + +class EncodecModel(CompressionModel): + """Encodec model operating on the raw waveform. + + Args: + encoder (nn.Module): Encoder network. + decoder (nn.Module): Decoder network. + quantizer (qt.BaseQuantizer): Quantizer network. + frame_rate (int): Frame rate for the latent representation. + sample_rate (int): Audio sample rate. + channels (int): Number of audio channels. + causal (bool): Whether to use a causal version of the model. 
+ renormalize (bool): Whether to renormalize the audio before running the model. + """ + # we need assignment to override the property in the abstract class, + # I couldn't find a better way... + frame_rate: float = 0 + sample_rate: int = 0 + channels: int = 0 + + def __init__(self, + encoder: nn.Module, + decoder: nn.Module, + quantizer: qt.BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.quantizer = quantizer + self.frame_rate = frame_rate + self.sample_rate = sample_rate + self.channels = channels + self.renormalize = renormalize + self.causal = causal + if self.causal: + # we force disabling here to avoid handling linear overlap of segments + # as supported in original EnCodec codebase. + assert not self.renormalize, 'Causal model does not support renormalize' + + @property + def total_codebooks(self): + """Total number of quantizer codebooks available.""" + return self.quantizer.total_codebooks + + @property + def num_codebooks(self): + """Active number of codebooks used by the quantizer.""" + return self.quantizer.num_codebooks + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer.""" + self.quantizer.set_num_codebooks(n) + + @property + def cardinality(self): + """Cardinality of each codebook.""" + return self.quantizer.bins + + def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + scale: tp.Optional[torch.Tensor] + if self.renormalize: + mono = x.mean(dim=1, keepdim=True) + volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() + scale = 1e-8 + volume + x = x / scale + scale = scale.view(-1, 1) + else: + scale = None + return x, scale + + def postprocess(self, + x: torch.Tensor, + scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + if scale is not None: + assert self.renormalize + x = x * scale.view(-1, 1, 1) + return x + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + assert x.dim() == 3 + length = x.shape[-1] + x, scale = self.preprocess(x) + + emb = self.encoder(x) + q_res = self.quantizer(emb, self.frame_rate) + out = self.decoder(q_res.x) + + # remove extra padding added by the encoder and decoder + assert out.shape[-1] >= length, (out.shape[-1], length) + out = out[..., :length] + + q_res.x = self.postprocess(out, scale) + + return q_res + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """Encode the given input tensor to quantized representation along with scale parameter. + + Args: + x (torch.Tensor): Float tensor of shape [B, C, T] + + Returns: + codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of: + codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. + scale: a float tensor containing the scale for audio renormalization. + """ + assert x.dim() == 3 + x, scale = self.preprocess(x) + emb = self.encoder(x) + codes = self.quantizer.encode(emb) + return codes, scale + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + """Decode the given codes to a reconstructed representation, using the scale to perform + audio denormalization if needed. + + Args: + codes (torch.Tensor): Int tensor of shape [B, K, T] + scale (torch.Tensor, optional): Float tensor containing the scale value. + + Returns: + out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. 
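+
+        Illustrative round trip (a sketch; tensor shapes are assumptions):
+            codes, scale = model.encode(wav)    # wav: [B, C, T]
+            audio = model.decode(codes, scale)  # may come back slightly longer than `wav` because of padding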
+ """ + emb = self.decode_latent(codes) + out = self.decoder(emb) + out = self.postprocess(out, scale) + # out contains extra padding added by the encoder and decoder + return out + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.quantizer.decode(codes) + + +class DAC(CompressionModel): + def __init__(self, model_type: str = "44khz"): + super().__init__() + try: + import dac.utils + except ImportError: + raise RuntimeError("Could not import dac, make sure it is installed, " + "please run `pip install descript-audio-codec`") + self.model = dac.utils.load_model(model_type=model_type) + self.n_quantizers = self.total_codebooks + self.model.eval() + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + codes = self.model.encode(x, self.n_quantizers)[1] + return codes[:, :self.n_quantizers], None + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + assert scale is None + z_q = self.decode_latent(codes) + return self.model.decode(z_q) + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.model.quantizer.from_codes(codes)[0] + + @property + def channels(self) -> int: + return 1 + + @property + def frame_rate(self) -> float: + return self.model.sample_rate / self.model.hop_length + + @property + def sample_rate(self) -> int: + return self.model.sample_rate + + @property + def cardinality(self) -> int: + return self.model.codebook_size + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + return self.model.n_codebooks + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n + + +class HFEncodecCompressionModel(CompressionModel): + """Wrapper around HuggingFace Encodec. + """ + def __init__(self, model: HFEncodecModel): + super().__init__() + self.model = model + bws = self.model.config.target_bandwidths + num_codebooks = [ + bw * 1000 / (self.frame_rate * math.log2(self.cardinality)) + for bw in bws + ] + deltas = [nc - int(nc) for nc in num_codebooks] + # Checking we didn't do some bad maths and we indeed have integers! + assert all(deltas) <= 1e-3, deltas + self.possible_num_codebooks = [int(nc) for nc in num_codebooks] + self.set_num_codebooks(max(self.possible_num_codebooks)) + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + # We don't support training with this. 
+ raise NotImplementedError("Forward and training with HF EncodecModel not supported.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks) + bandwidth = self.model.config.target_bandwidths[bandwidth_index] + res = self.model.encode(x, None, bandwidth) + assert len(res[0]) == 1 + assert len(res[1]) == 1 + return res[0][0], res[1][0] + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + if scale is None: + scales = [None] # type: ignore + else: + scales = scale # type: ignore + res = self.model.decode(codes[None], scales) + return res[0] + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.model.quantizer.decode(codes.transpose(0, 1)) + + @property + def channels(self) -> int: + return self.model.config.audio_channels + + @property + def frame_rate(self) -> float: + hop_length = int(np.prod(self.model.config.upsampling_ratios)) + return self.sample_rate / hop_length + + @property + def sample_rate(self) -> int: + return self.model.config.sampling_rate + + @property + def cardinality(self) -> int: + return self.model.config.codebook_size + + @property + def num_codebooks(self) -> int: + return self._num_codebooks + + @property + def total_codebooks(self) -> int: + return max(self.possible_num_codebooks) + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + if n not in self.possible_num_codebooks: + raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}") + self._num_codebooks = n + + +class InterleaveStereoCompressionModel(CompressionModel): + """Wraps a CompressionModel to support stereo inputs. The wrapped model + will be applied independently to the left and right channels, and both codebooks + will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per + channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on + `per_timestep`. + + Args: + model (CompressionModel): Compression model to wrap. + per_timestep (bool): Whether to interleave on the timestep dimension + or on the codebooks dimension. + """ + def __init__(self, model: CompressionModel, per_timestep: bool = False): + super().__init__() + self.model = model + self.per_timestep = per_timestep + assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio" + + @property + def total_codebooks(self): + return self.model.total_codebooks + + @property + def num_codebooks(self): + """Active number of codebooks used by the quantizer. + + ..Warning:: this reports the number of codebooks after the interleaving + of the codebooks! + """ + return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + + ..Warning:: this sets the number of codebooks before the interleaving! + """ + self.model.set_num_codebooks(n) + + @property + def num_virtual_steps(self) -> float: + """Return the number of virtual steps, e.g. one real step + will be split into that many steps. 
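+
+        Concretely, this is 2 when interleaving on the timestep dimension (`per_timestep=True`)
+        and 1 when interleaving on the codebook dimension.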
+ """ + return 2 if self.per_timestep else 1 + + @property + def frame_rate(self) -> float: + return self.model.frame_rate * self.num_virtual_steps + + @property + def sample_rate(self) -> int: + return self.model.sample_rate + + @property + def channels(self) -> int: + return 2 + + @property + def cardinality(self): + """Cardinality of each codebook. + """ + return self.model.cardinality + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + raise NotImplementedError("Not supported, use encode and decode.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + B, C, T = x.shape + assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}" + + indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1)) + indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1)) + indices = torch.stack([indices_c0, indices_c1], dim=0) + scales: tp.Optional[torch.Tensor] = None + if scales_c0 is not None and scales_c1 is not None: + scales = torch.stack([scales_c0, scales_c1], dim=1) + + if self.per_timestep: + indices = rearrange(indices, 'c b k t -> b k (t c)', c=2) + else: + indices = rearrange(indices, 'c b k t -> b (k c) t', c=2) + + return (indices, scales) + + def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + if self.per_timestep: + codes = rearrange(codes, 'b k (t c) -> c b k t', c=2) + else: + codes = rearrange(codes, 'b (k c) t -> c b k t', c=2) + return codes[0], codes[1] + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + B, K, T = codes.shape + assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match" + assert K == self.num_codebooks, "Provided codes' number of codebooks does not match" + + scale_c0, scale_c1 = None, None + if scale is not None: + assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}" + scale_c0 = scale[0, ...] + scale_c1 = scale[1, ...] + + codes_c0, codes_c1 = self.get_left_right_codes(codes) + audio_c0 = self.model.decode(codes_c0, scale_c0) + audio_c1 = self.model.decode(codes_c1, scale_c1) + return torch.cat([audio_c0, audio_c1], dim=1) + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + raise NotImplementedError("Not supported by interleaved stereo wrapped models.") diff --git a/audiocraft/genmodel.py b/audiocraft/genmodel.py new file mode 100644 index 0000000000000000000000000000000000000000..6be361c0215269b375252c41a338ccd756f19028 --- /dev/null +++ b/audiocraft/genmodel.py @@ -0,0 +1,254 @@ +from abc import ABC, abstractmethod +import typing as tp + +import omegaconf +import torch + +from .encodec import CompressionModel +from .lm import LMModel +from .builders import get_wrapped_compression_model +from .utils.audio_utils import convert_audio +from .conditioners import ConditioningAttributes +from .utils.autocast import TorchAutocast + + +class BaseGenModel(ABC): + """Base generative model with convenient generation API. + + Args: + name (str): name of the model. + compression_model (CompressionModel): Compression model + used to map audio to invertible discrete representations. + lm (LMModel): Language model over discrete representations. + max_duration (float, optional): maximum duration the model can produce, + otherwise, inferred from the training params. 
+ """ + def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, + max_duration: tp.Optional[float] = None): + self.name = name + self.compression_model = compression_model + self.lm = lm + self.cfg: tp.Optional[omegaconf.DictConfig] = None + # Just to be safe, let's put everything in eval mode. + self.compression_model.eval() + self.lm.eval() + + if hasattr(lm, 'cfg'): + cfg = lm.cfg + assert isinstance(cfg, omegaconf.DictConfig) + self.cfg = cfg + + if self.cfg is not None: + self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg) + + if max_duration is None: + if self.cfg is not None: + max_duration = lm.cfg.dataset.segment_duration # type: ignore + else: + raise ValueError("You must provide max_duration when building directly your GenModel") + assert max_duration is not None + + self.max_duration: float = max_duration + self.duration = self.max_duration + + # self.extend_stride is the length of audio extension when generating samples longer + # than self.max_duration. NOTE: the derived class must set self.extend_stride to a + # positive float value when generating with self.duration > self.max_duration. + self.extend_stride: tp.Optional[float] = None + self.device = next(iter(lm.parameters())).device + self.generation_params: dict = {} + self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None + if self.device.type == 'cpu': + self.autocast = TorchAutocast(enabled=False) + else: + self.autocast = TorchAutocast( + enabled=True, device_type=self.device.type, dtype=torch.float16) + + @property + def frame_rate(self) -> float: + """Roughly the number of AR steps per seconds.""" + return self.compression_model.frame_rate + + @property + def sample_rate(self) -> int: + """Sample rate of the generated audio.""" + return self.compression_model.sample_rate + + @property + def audio_channels(self) -> int: + """Audio channels of the generated audio.""" + return self.compression_model.channels + + def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): + """Override the default progress callback.""" + self._progress_callback = progress_callback + + @abstractmethod + def set_generation_params(self, *args, **kwargs): + """Set the generation parameters.""" + raise NotImplementedError("No base implementation for setting generation params.") + + @staticmethod + @abstractmethod + def get_pretrained(name: str, device=None): + raise NotImplementedError("No base implementation for getting pretrained model") + + @torch.no_grad() + def _prepare_tokens_and_attributes( + self, + descriptions: tp.Sequence[tp.Optional[str]], + prompt: tp.Optional[torch.Tensor], + ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: + """Prepare model inputs. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + prompt (torch.Tensor): A batch of waveforms used for continuation. + """ + attributes = [ + ConditioningAttributes(text={'description': description}) + for description in descriptions] + + if prompt is not None: + if descriptions is not None: + assert len(descriptions) == len(prompt), "Prompt and nb. 
descriptions doesn't match" + prompt = prompt.to(self.device) + prompt_tokens, scale = self.compression_model.encode(prompt) + assert scale is None + else: + prompt_tokens = None + return attributes, prompt_tokens + + def generate_unconditional(self, num_samples: int, progress: bool = False, + return_tokens: bool = False) -> tp.Union[torch.Tensor, + tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples in an unconditional manner. + + Args: + num_samples (int): Number of samples to be generated. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + descriptions: tp.List[tp.Optional[str]] = [None] * num_samples + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) + + def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples conditioned on text. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) + assert prompt_tokens is None + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) + + def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, + descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, + progress: bool = False, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples conditioned on audio prompts and an optional text description. + + Args: + prompt (torch.Tensor): A batch of waveforms used for continuation. + Prompt should be [B, C, T], or [C, T] if only one sample is generated. + prompt_sample_rate (int): Sampling rate of the given audio waveforms. + descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + if prompt.dim() == 2: + prompt = prompt[None] + if prompt.dim() != 3: + raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") + prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) + if descriptions is None: + descriptions = [None] * len(prompt) + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) + assert prompt_tokens is not None + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) + + def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], + prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: + """Generate discrete audio tokens given audio prompt and/or conditions. + + Args: + attributes (list of ConditioningAttributes): Conditions used for generation (here text). + prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. 
+ Returns: + torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. + """ + total_gen_len = int(self.duration * self.frame_rate) + max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) + current_gen_offset: int = 0 + + def _progress_callback(generated_tokens: int, tokens_to_generate: int): + generated_tokens += current_gen_offset + if self._progress_callback is not None: + # Note that total_gen_len might be quite wrong depending on the + # codebook pattern used, but with delay it is almost accurate. + self._progress_callback(generated_tokens, tokens_to_generate) + else: + print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r') + + if prompt_tokens is not None: + assert max_prompt_len >= prompt_tokens.shape[-1], \ + "Prompt is longer than audio to generate" + + callback = None + if progress: + callback = _progress_callback + + if self.duration <= self.max_duration: + # generate by sampling from LM, simple case. + with self.autocast: + gen_tokens = self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=total_gen_len, **self.generation_params) + + else: + assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration" + assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration." + all_tokens = [] + if prompt_tokens is None: + prompt_length = 0 + else: + all_tokens.append(prompt_tokens) + prompt_length = prompt_tokens.shape[-1] + + stride_tokens = int(self.frame_rate * self.extend_stride) + while current_gen_offset + prompt_length < total_gen_len: + time_offset = current_gen_offset / self.frame_rate + chunk_duration = min(self.duration - time_offset, self.max_duration) + max_gen_len = int(chunk_duration * self.frame_rate) + with self.autocast: + gen_tokens = self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=max_gen_len, **self.generation_params) + if prompt_tokens is None: + all_tokens.append(gen_tokens) + else: + all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:]) + prompt_tokens = gen_tokens[:, :, stride_tokens:] + prompt_length = prompt_tokens.shape[-1] + current_gen_offset += stride_tokens + + gen_tokens = torch.cat(all_tokens, dim=-1) + return gen_tokens + + def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor: + """Generate Audio from tokens.""" + assert gen_tokens.dim() == 3 + with torch.no_grad(): + gen_audio = self.compression_model.decode(gen_tokens, None) + return gen_audio diff --git a/audiocraft/lm.py b/audiocraft/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..3f6e2c3d382fd23b5a453fd54dc78b0447318a4b --- /dev/null +++ b/audiocraft/lm.py @@ -0,0 +1,1751 @@ +# ========================= From conditioners.py +import soundfile +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field +from itertools import chain +import logging +import math +from pathlib import Path +import random +import re +import typing as tp +import warnings +import einops +from num2words import num2words +import spacy +from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from audiocraft.streaming import StreamingModule +from audiocraft.transformer import create_sin_embedding +from audiocraft.utils.audio_utils import convert_audio +from audiocraft.utils.autocast import TorchAutocast +from 
audiocraft.utils.cache import EmbeddingCache +from audiocraft.utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once +from audiocraft.transformer import StreamingTransformer, create_norm_fn +from dataclasses import dataclass +from functools import partial +import logging +import math +import typing as tp + + +from torch import nn + +from audiocraft.utils import utils +from audiocraft.codebooks_patterns import CodebooksPatternProvider +from audiocraft.activations import get_activation_fn + + + + + +logger = logging.getLogger(__name__) +TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist) +ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask + + +class WavCondition(tp.NamedTuple): + wav: torch.Tensor + length: torch.Tensor + sample_rate: tp.List[int] + path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] + + +class JointEmbedCondition(tp.NamedTuple): + wav: torch.Tensor + text: tp.List[tp.Optional[str]] + length: torch.Tensor + sample_rate: tp.List[int] + path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] + + +@dataclass +class ConditioningAttributes: + text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) + wav: tp.Dict[str, WavCondition] = field(default_factory=dict) + joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) + + def __getitem__(self, item): + return getattr(self, item) + + @property + def text_attributes(self): + return self.text.keys() + + @property + def wav_attributes(self): + return self.wav.keys() + + @property + def joint_embed_attributes(self): + return self.joint_embed.keys() + + @property + def attributes(self): + return { + "text": self.text_attributes, + "wav": self.wav_attributes, + "joint_embed": self.joint_embed_attributes, + } + + def to_flat_dict(self): + return { + **{f"text.{k}": v for k, v in self.text.items()}, + **{f"wav.{k}": v for k, v in self.wav.items()}, + **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()} + } + + @classmethod + def from_flat_dict(cls, x): + out = cls() + for k, v in x.items(): + kind, att = k.split(".") + out[kind][att] = v + return out + + + + + +def nullify_condition(condition: ConditionType, dim: int = 1): + """Transform an input condition to a null condition. + The way it is done by converting it to a single zero vector similarly + to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. + + Args: + condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor]) + dim (int): The dimension that will be truncated (should be the time dimension) + WARNING!: dim should not be the batch dimension! + Returns: + ConditionType: A tuple of null condition and mask + """ + assert dim != 0, "dim cannot be the batch dimension!" + assert isinstance(condition, tuple) and \ + isinstance(condition[0], torch.Tensor) and \ + isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!" + cond, mask = condition + B = cond.shape[0] + last_dim = cond.dim() - 1 + out = cond.transpose(dim, last_dim) + out = 0. * out[..., :1] + out = out.transpose(dim, last_dim) + mask = torch.zeros((B, 1), device=out.device).int() + assert cond.dim() == out.dim() + return out, mask + + +def nullify_wav(cond: WavCondition) -> WavCondition: + """Transform a WavCondition to a nullified WavCondition. + It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes. 
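+    Illustratively (shapes are an assumption), a condition whose wav has shape [B, ..., T] ends up
+    with an all-zero wav of shape [B, ..., 1], a length of 0 for every item, and `None` paths and
+    seek times.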
+ + Args: + cond (WavCondition): Wav condition with wav, tensor of shape [B, T]. + Returns: + WavCondition: Nullified wav condition. + """ + null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1) + return WavCondition( + wav=null_wav, + length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device), + sample_rate=cond.sample_rate, + path=[None] * cond.wav.shape[0], + seek_time=[None] * cond.wav.shape[0], + ) + + +def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: + """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0, + and replacing metadata by dummy attributes. + + Args: + cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T]. + """ + null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1) + return JointEmbedCondition( + wav=null_wav, text=[None] * len(embed.text), + length=torch.LongTensor([0]).to(embed.wav.device), + sample_rate=embed.sample_rate, + path=[None] * embed.wav.shape[0], + seek_time=[0] * embed.wav.shape[0], + ) + + +class Tokenizer: + """Base tokenizer implementation + (in case we want to introduce more advances tokenizers in the future). + """ + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + +class WhiteSpaceTokenizer(Tokenizer): + """This tokenizer should be used for natural language descriptions. + For example: + ["he didn't, know he's going home.", 'shorter sentence'] => + [[78, 62, 31, 4, 78, 25, 19, 34], + [59, 77, 0, 0, 0, 0, 0, 0]] + """ + PUNCTUATION = "?:!.,;" + + def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm", + lemma: bool = True, stopwords: bool = True) -> None: + self.n_bins = n_bins + self.pad_idx = pad_idx + self.lemma = lemma + self.stopwords = stopwords + try: + self.nlp = spacy.load(language) + except IOError: + spacy.cli.download(language) # type: ignore + self.nlp = spacy.load(language) + + @tp.no_type_check + def __call__(self, texts: tp.List[tp.Optional[str]], + return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Take a list of strings and convert them to a tensor of indices. + + Args: + texts (list[str]): List of strings. + return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False. + Returns: + tuple[torch.Tensor, torch.Tensor]: + - Indices of words in the LUT. 
+ - And a mask indicating where the padding tokens are + """ + output, lengths = [], [] + texts = deepcopy(texts) + for i, text in enumerate(texts): + # if current sample doesn't have a certain attribute, replace with pad token + if text is None: + output.append(torch.Tensor([self.pad_idx])) + lengths.append(0) + continue + + # convert numbers to words + text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore + # normalize text + text = self.nlp(text) # type: ignore + # remove stopwords + if self.stopwords: + text = [w for w in text if not w.is_stop] # type: ignore + # remove punctuation + text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore + # lemmatize if needed + text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore + + texts[i] = " ".join(text) + lengths.append(len(text)) + # convert to tensor + tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text]) + output.append(tokens) + + mask = length_to_mask(torch.IntTensor(lengths)).int() + padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t() + if return_text: + return padded_output, mask, texts # type: ignore + return padded_output, mask + + +class NoopTokenizer(Tokenizer): + """This tokenizer should be used for global conditioners such as: artist, genre, key, etc. + The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split + strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will + split it to ["Jeff", "Buckley"] and return an index per word. + + For example: + ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101] + ["Metal", "Rock", "Classical"] => [0, 223, 51] + """ + def __init__(self, n_bins: int, pad_idx: int = 0): + self.n_bins = n_bins + self.pad_idx = pad_idx + + def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + output, lengths = [], [] + for text in texts: + # if current sample doesn't have a certain attribute, replace with pad token + if text is None: + output.append(self.pad_idx) + lengths.append(0) + else: + output.append(hash_trick(text, self.n_bins)) + lengths.append(1) + + tokens = torch.LongTensor(output).unsqueeze(1) + mask = length_to_mask(torch.IntTensor(lengths)).int() + return tokens, mask + + +class BaseConditioner(nn.Module): + """Base model for all conditioner modules. + We allow the output dim to be different than the hidden dim for two reasons: + 1) keep our LUTs small when the vocab is large; + 2) make all condition dims consistent. + + Args: + dim (int): Hidden dim of the model. + output_dim (int): Output dim of the conditioner. + """ + def __init__(self, dim: int, output_dim: int): + super().__init__() + self.dim = dim + self.output_dim = output_dim + self.output_proj = nn.Linear(dim, output_dim) + + def tokenize(self, *args, **kwargs) -> tp.Any: + """Should be any part of the processing that will lead to a synchronization + point, e.g. BPE tokenization with transfer to the GPU. + + The returned value will be saved and return later when calling forward(). + """ + raise NotImplementedError() + + def forward(self, inputs: tp.Any) -> ConditionType: + """Gets input that should be used as conditioning (e.g, genre, description or a waveform). + Outputs a ConditionType, after the input data was embedded as a dense vector. + + Returns: + ConditionType: + - A tensor of size [B, T, D] where B is the batch size, T is the length of the + output embedding and D is the dimension of the embedding. 
+ - And a mask indicating where the padding tokens. + """ + raise NotImplementedError() + + +class TextConditioner(BaseConditioner): + ... + + +class LUTConditioner(TextConditioner): + """Lookup table TextConditioner. + + Args: + n_bins (int): Number of bins. + dim (int): Hidden dim of the model (text-encoder/LUT). + output_dim (int): Output dim of the conditioner. + tokenizer (str): Name of the tokenizer. + pad_idx (int, optional): Index for padding token. Defaults to 0. + """ + def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0): + super().__init__(dim, output_dim) + self.embed = nn.Embedding(n_bins, dim) + self.tokenizer: Tokenizer + if tokenizer == 'whitespace': + self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx) + elif tokenizer == 'noop': + self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx) + else: + raise ValueError(f"unrecognized tokenizer `{tokenizer}`.") + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + device = self.embed.weight.device + tokens, mask = self.tokenizer(x) + tokens, mask = tokens.to(device), mask.to(device) + return tokens, mask + + def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType: + tokens, mask = inputs + embeds = self.embed(tokens) + embeds = self.output_proj(embeds) + embeds = (embeds * mask.unsqueeze(-1)) + return embeds, mask + + +class T5Conditioner(TextConditioner): + """T5-based TextConditioner. + + Args: + name (str): Name of the T5 model. + output_dim (int): Output dim of the conditioner. + finetune (bool): Whether to fine-tune T5 at train time. + device (str): Device for T5 Conditioner. + autocast_dtype (tp.Optional[str], optional): Autocast dtype. + word_dropout (float, optional): Word dropout probability. + normalize_text (bool, optional): Whether to apply text normalization. + """ + MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl"] + MODELS_DIMS = { + "t5-small": 512, + "t5-base": 768, + "t5-large": 1024, + "t5-3b": 1024, + "t5-11b": 1024, + "google/flan-t5-small": 512, + "google/flan-t5-base": 768, + "google/flan-t5-large": 1024, + "google/flan-t5-3b": 1024, + "google/flan-t5-11b": 1024, + } + + def __init__(self, name: str, output_dim: int, finetune: bool, device: str, + autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0., + normalize_text: bool = False): + assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})" + super().__init__(self.MODELS_DIMS[name], output_dim) + self.device = device + self.name = name + self.finetune = finetune + self.word_dropout = word_dropout + if autocast_dtype is None or self.device == 'cpu': + self.autocast = TorchAutocast(enabled=False) + if self.device != 'cpu': + logger.warning("T5 has no autocast, this might lead to NaN") + else: + dtype = getattr(torch, autocast_dtype) + assert isinstance(dtype, torch.dtype) + logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}") + self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) + # Let's disable logging temporarily because T5 will vomit some errors otherwise. 
+ # thanks https://gist.github.com/simon-weber/7853144 + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.t5_tokenizer = T5Tokenizer.from_pretrained(name) + t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune) + finally: + logging.disable(previous_level) + if finetune: + self.t5 = t5 + else: + # this makes sure that the t5 models is not part + # of the saved checkpoint + self.__dict__['t5'] = t5.to(device) + + self.normalize_text = normalize_text + if normalize_text: + self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True) + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: + # if current sample doesn't have a certain attribute, replace with empty string + entries: tp.List[str] = [xi if xi is not None else "" for xi in x] + if self.normalize_text: + _, _, entries = self.text_normalizer(entries, return_text=True) + if self.word_dropout > 0. and self.training: + new_entries = [] + for entry in entries: + words = [word for word in entry.split(" ") if random.random() >= self.word_dropout] + new_entries.append(" ".join(words)) + entries = new_entries + + empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""]) + + inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device) + mask = inputs['attention_mask'] + mask[empty_idx, :] = 0 # zero-out index where the input is non-existant + return inputs + + def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: + mask = inputs['attention_mask'] + with torch.set_grad_enabled(self.finetune), self.autocast: + embeds = self.t5(**inputs).last_hidden_state + embeds = self.output_proj(embeds.to(self.output_proj.weight)) + embeds = (embeds * mask.unsqueeze(-1)) + return embeds, mask + + +class WaveformConditioner(BaseConditioner): + """Base class for all conditioners that take a waveform as input. + Classes that inherit must implement `_get_wav_embedding` that outputs + a continuous tensor, and `_downsampling_factor` that returns the down-sampling + factor of the embedding model. + + Args: + dim (int): The internal representation dimension. + output_dim (int): Output dimension. + device (tp.Union[torch.device, str]): Device. + """ + def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]): + super().__init__(dim, output_dim) + self.device = device + # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample. + self._use_masking = True + + def tokenize(self, x: WavCondition) -> WavCondition: + wav, length, sample_rate, path, seek_time = x + assert length is not None + return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time) + + def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: + """Gets as input a WavCondition and returns a dense embedding.""" + raise NotImplementedError() + + def _downsampling_factor(self): + """Returns the downsampling factor of the embedding model.""" + raise NotImplementedError() + + def forward(self, x: WavCondition) -> ConditionType: + """Extract condition embedding and mask from a waveform and its metadata. + Args: + x (WavCondition): Waveform condition containing raw waveform and metadata. 
+ Returns: + ConditionType: a dense vector representing the conditioning along with its mask + """ + wav, lengths, *_ = x + with torch.no_grad(): + embeds = self._get_wav_embedding(x) + embeds = embeds.to(self.output_proj.weight) + embeds = self.output_proj(embeds) + + if lengths is not None and self._use_masking: + lengths = lengths / self._downsampling_factor() + mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore + else: + mask = torch.ones_like(embeds[..., 0]) + embeds = (embeds * mask.unsqueeze(-1)) + return embeds, mask + + + + + +class JointEmbeddingConditioner(BaseConditioner): + """Joint embedding conditioning supporting both audio or text conditioning. + + Args: + dim (int): Dimension. + output_dim (int): Output dimension. + device (str): Device. + attribute (str): Attribute used by the conditioner. + autocast_dtype (str): Autocast for the conditioner. + quantize (bool): Whether to quantize the CLAP embedding. + n_q (int): Number of residual quantizers (used if quantize is true). + bins (int): Quantizers' codebooks size (used if quantize is true). + kwargs: Additional parameters for residual vector quantizer. + """ + def __init__(self, dim: int, output_dim: int, device: str, attribute: str, + autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True, + n_q: int = 12, bins: int = 1024, **kwargs): + super().__init__(dim=dim, output_dim=output_dim) + self.device = device + self.attribute = attribute + if autocast_dtype is None or device == 'cpu': + self.autocast = TorchAutocast(enabled=False) + logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.") + else: + dtype = getattr(torch, autocast_dtype) + assert isinstance(dtype, torch.dtype) + logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.") + self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) + # residual vector quantizer to discretize the conditioned embedding + self.quantizer=None + if quantize: + print('\n\n\n\nWANTS TO QUANTIZE on Inference\n\n\n\n') + # self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs) + + def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Get joint embedding in latent space from the inputs. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding + and corresponding empty indexes. + """ + raise NotImplementedError() + + def forward(self, x: JointEmbedCondition) -> ConditionType: + with self.autocast: + embed, empty_idx = self._get_embed(x) + if self.quantizer is not None: + embed = embed.view(-1, self.dim, 1) + q_res = self.quantizer(embed, frame_rate=1) + out_embed = q_res.x.view(-1, self.dim) + else: + out_embed = embed + out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim) + mask = torch.ones(*out_embed.shape[:2], device=out_embed.device) + mask[empty_idx, :] = 0 # zero-out index where the input is non-existant + out_embed = (out_embed * mask.unsqueeze(-1)) + return out_embed, mask + + def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: + return x + + +class CLAPEmbeddingConditioner(JointEmbeddingConditioner): + """Joint Embedding conditioner based on pre-trained CLAP model. + + This CLAP-based conditioner supports a caching mechanism + over the computed embeddings for faster training. + + Args: + dim (int): Dimension. + output_dim (int): Output dimension. + device (str): Device. + attribute (str): Attribute used by the conditioner. 
+ quantize (bool): Whether to quantize the CLAP embedding. + n_q (int): Number of residual quantizers (used if quantize is true). + bins (int): Quantizers' codebooks size (used if quantize is true). + checkpoint (str): Path to CLAP checkpoint. + model_arch (str): CLAP model architecture. + enable_fusion (bool): Enable fusion for CLAP model. + sample_rate (int): Sample rate used by CLAP model. + max_audio_length (float): Maximum audio length for CLAP model. + audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence. + normalize (bool): Whether to normalize the CLAP embedding. + text_p (float): Probability of using text representation instead of audio at train time. + batch_size (Optional[int]): Batch size for CLAP embedding computation. + autocast_dtype (str): Autocast for the conditioner. + cache_path (Optional[str]): Path for pre-computed embeddings caching. + kwargs: Additional parameters for residual vector quantizer. + """ + def __init__(self, dim: int, output_dim: int, device: str, attribute: str, + quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str, + enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int, + normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None, + autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs): + try: + import laion_clap # type: ignore + except ImportError: + raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'") + warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). " + "Please retrain all models.") + checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint) + clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base') + clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) + load_clap_state_dict(clap_model, checkpoint) + clap_model.eval() + clap_model.to(device) + super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute, + autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins, + **kwargs) + self.checkpoint = checkpoint + self.enable_fusion = enable_fusion + self.model_arch = model_arch + self.clap: laion_clap.CLAP_Module + self.clap_tokenize: RobertaTokenizer + self.clap_sample_rate = sample_rate + self.clap_max_frames = int(self.clap_sample_rate * max_audio_length) + self.clap_stride = int(self.clap_sample_rate * audio_stride) + self.batch_size = batch_size or 1 + self.normalize = normalize + self.text_p = text_p + self.__dict__['clap_tokenize'] = clap_tokenize + self.__dict__['clap'] = clap_model + self.wav_cache, self.text_cache = None, None + if cache_path is not None: + self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, + compute_embed_fn=self._get_wav_embedding_for_cache, + extract_embed_fn=self._extract_wav_embedding_chunk) + self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device, + compute_embed_fn=self._get_text_embedding_for_cache) + + def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: + # we use the default params from CLAP module here as well + return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") + + def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor: + """Compute text embedding from CLAP model on a given a batch of text. + + Args: + text (list[str]): List of text for the batch, with B items. 
+ Returns: + torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension. + """ + with torch.no_grad(): + embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) + return embed.view(embed.size(0), 1, embed.size(-1)) + + def _get_text_embedding_for_cache(self, path: tp.Union[Path, str], + x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Get text embedding function for the cache.""" + text = x.text[idx] + text = text if text is not None else "" + return self._compute_text_embedding([text])[0] + + def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor: + """Preprocess wav to expected format by CLAP model. + + Args: + wav (torch.Tensor): Audio wav, of shape [B, C, T]. + length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. + sample_rates (list[int]): Sample rates for each sample in the batch + Returns: + torch.Tensor: Audio wav of shape [B, T]. + """ + assert wav.dim() == 3, "Expecting wav to be [B, C, T]" + if sample_rates is not None: + _wav = [] + for i, audio in enumerate(wav): + sr = sample_rates[i] + audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1) + _wav.append(audio) + wav = torch.stack(_wav, dim=0) + wav = wav.mean(dim=1) + return wav + + def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor, + sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor: + """Compute audio wave embedding from CLAP model. + + Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences, + we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and + average the resulting embeddings. + + Args: + wav (torch.Tensor): Audio wav, of shape [B, C, T]. + length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. + sample_rates (list[int]): Sample rates for each sample in the batch. + reduce_mean (bool): Whether to get the average tensor. + Returns: + torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension. + """ + with torch.no_grad(): + wav = self._preprocess_wav(wav, length, sample_rates) + B, T = wav.shape + if T >= self.clap_max_frames: + wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T] + else: + wav = wav.view(-1, 1, T) # [B, F, T] with F=1 + wav = einops.rearrange(wav, 'b f t -> (b f) t') + embed_list = [] + for i in range(0, wav.size(0), self.batch_size): + _wav = wav[i:i+self.batch_size, ...] + _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True) + embed_list.append(_embed) + embed = torch.cat(embed_list, dim=0) + embed = einops.rearrange(embed, '(b f) d -> b f d', b=B) + if reduce_mean: + embed = embed.mean(dim=1, keepdim=True) + return embed # [B, F, D] with F=1 if reduce_mean is True + + def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path], + x: JointEmbedCondition, idx: int) -> torch.Tensor: + """Compute audio wave embedding for the cache. + The embedding is computed on a given audio read from file. + + Args: + path (str or Path): Path to the full audio file. + Returns: + torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension. 
+        """
+        wav_np, sr = soundfile.read(path, dtype='float32')  # numpy array of shape [T] or [T, C]
+        wav = torch.from_numpy(wav_np)
+        if wav.dim() == 1:
+            wav = wav[None, :]  # mono -> [1, T]
+        else:
+            wav = wav.t()  # [T, C] -> [C, T]
+        wav = wav.unsqueeze(0).to(self.device)  # [1, C, T]
+        wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
+        embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False)  # [B, F, D]
+        return embed.squeeze(0)  # [F, D]
+
+    def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
+        """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
+
+        Args:
+            full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
+            x (JointEmbedCondition): Joint embedding condition for the full batch.
+            idx (int): Index considered for the given embedding to extract.
+        Returns:
+            torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
+        """
+        sample_rate = x.sample_rate[idx]
+        seek_time = x.seek_time[idx]
+        seek_time = 0. if seek_time is None else seek_time
+        clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
+        end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
+        start_offset = int(seek_time * sample_rate // clap_stride)
+        end_offset = int(end_seek_time * sample_rate // clap_stride)
+        wav_embed = full_embed[start_offset:end_offset, ...]
+        wav_embed = wav_embed.mean(dim=0, keepdim=True)
+        return wav_embed.to(self.device)  # [F, D]
+
+    def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
+        """Get CLAP embedding from a batch of text descriptions."""
+        no_nullified_cond = x.wav.shape[-1] > 1  # don't read from cache when the condition was dropped (nullified)
+        if self.text_cache is not None and no_nullified_cond:
+            assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
+            paths = [Path(p) for p in x.path if p is not None]
+            embed = self.text_cache.get_embed_from_cache(paths, x)
+        else:
+            text = [xi if xi is not None else "" for xi in x.text]
+            embed = self._compute_text_embedding(text)
+        if self.normalize:
+            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
+        return embed
+
+    def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
+        """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
+        no_undefined_paths = all(p is not None for p in x.path)
+        no_nullified_cond = x.wav.shape[-1] > 1  # don't read from cache when the condition was dropped (nullified)
+        if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
+            paths = [Path(p) for p in x.path if p is not None]
+            embed = self.wav_cache.get_embed_from_cache(paths, x)
+        else:
+            embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
+        if self.normalize:
+            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
+        return embed
+
+    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
+        # Limit synchronization points as much as possible when the cache is warm.
+ no_undefined_paths = all(p is not None for p in x.path) + if self.wav_cache is not None and no_undefined_paths: + assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + self.wav_cache.populate_embed_cache(paths, x) + if self.text_cache is not None and no_undefined_paths: + assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" + paths = [Path(p) for p in x.path if p is not None] + self.text_cache.populate_embed_cache(paths, x) + return x + + def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Extract shared latent representation from either the wav or the text using CLAP.""" + # decide whether to use text embedding at train time or not + use_text_embed = random.random() < self.text_p + if self.training and not use_text_embed: + embed = self._get_wav_embedding(x) + empty_idx = torch.LongTensor([]) # we assume we always have the audio wav + else: + embed = self._get_text_embedding(x) + empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""]) + return embed, empty_idx + + +def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes: + """Utility function for nullifying an attribute inside an ConditioningAttributes object. + If the condition is of type "wav", then nullify it using `nullify_condition` function. + If the condition is of any other type, set its value to None. + Works in-place. + """ + if condition_type not in ['text', 'wav', 'joint_embed']: + raise ValueError( + "dropout_condition got an unexpected condition type!" + f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'" + ) + + if condition not in getattr(sample, condition_type): + raise ValueError( + "dropout_condition received an unexpected condition!" + f" expected wav={sample.wav.keys()} and text={sample.text.keys()}" + f" but got '{condition}' of type '{condition_type}'!" + ) + + if condition_type == 'wav': + wav_cond = sample.wav[condition] + sample.wav[condition] = nullify_wav(wav_cond) + elif condition_type == 'joint_embed': + embed = sample.joint_embed[condition] + sample.joint_embed[condition] = nullify_joint_embed(embed) + else: + sample.text[condition] = None + + return sample + + +class DropoutModule(nn.Module): + """Base module for all dropout modules.""" + def __init__(self, seed: int = 1234): + super().__init__() + self.rng = torch.Generator() + self.rng.manual_seed(seed) + + +class AttributeDropout(DropoutModule): + """Dropout with a given probability per attribute. + This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes + to be dropped out separately. For example, "artist" can be dropped while "genre" remains. + This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre" + must also be dropped. + + Args: + p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example: + ... + "genre": 0.1, + "artist": 0.5, + "wav": 0.25, + ... + active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False. + seed (int, optional): Random seed. 
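+
+    Example:
+        A minimal usage sketch. Note that `p` is nested by condition type
+        ("text", "wav", "joint_embed"); the attribute names below are only
+        illustrative and must match the conditioners that are actually configured:
+
+            att_dropout = AttributeDropout(p={
+                "text": {"genre": 0.1, "artist": 0.5},
+                "wav": {"self_wav": 0.25},
+            })
+            samples = att_dropout(samples)  # no-op at eval time unless active_on_eval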
+ """ + def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234): + super().__init__(seed=seed) + self.active_on_eval = active_on_eval + # construct dict that return the values from p otherwise 0 + self.p = {} + for condition_type, probs in p.items(): + self.p[condition_type] = defaultdict(lambda: 0, probs) + + def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """ + Args: + samples (list[ConditioningAttributes]): List of conditions. + Returns: + list[ConditioningAttributes]: List of conditions after certain attributes were set to None. + """ + if not self.training and not self.active_on_eval: + return samples + + samples = deepcopy(samples) + for condition_type, ps in self.p.items(): # for condition types [text, wav] + for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) + if torch.rand(1, generator=self.rng).item() < p: + for sample in samples: + dropout_condition(sample, condition_type, condition) + return samples + + def __repr__(self): + return f"AttributeDropout({dict(self.p)})" + + +class ClassifierFreeGuidanceDropout(DropoutModule): + """Classifier Free Guidance dropout. + All attributes are dropped with the same probability. + + Args: + p (float): Probability to apply condition dropout during training. + seed (int): Random seed. + """ + def __init__(self, p: float, seed: int = 1234): + super().__init__(seed=seed) + self.p = p + + def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """ + Args: + samples (list[ConditioningAttributes]): List of conditions. + Returns: + list[ConditioningAttributes]: List of conditions after all attributes were set to None. + """ + if not self.training: + return samples + + # decide on which attributes to drop in a batched fashion + drop = torch.rand(1, generator=self.rng).item() < self.p + if not drop: + return samples + + # nullify conditions of all attributes + samples = deepcopy(samples) + for condition_type in ["wav", "text"]: + for sample in samples: + for condition in sample.attributes[condition_type]: + dropout_condition(sample, condition_type, condition) + return samples + + def __repr__(self): + return f"ClassifierFreeGuidanceDropout(p={self.p})" + + +class ConditioningProvider(nn.Module): + """Prepare and provide conditions given all the supported conditioners. + + Args: + conditioners (dict): Dictionary of conditioners. + device (torch.device or str, optional): Device for conditioners and output condition types. 
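+
+    Example:
+        A minimal sketch with a single text conditioner; the "description" key and
+        the T5 variant are illustrative choices, not requirements:
+
+            provider = ConditioningProvider(
+                {"description": T5Conditioner("t5-base", output_dim=512, finetune=False, device="cpu")},
+                device="cpu")
+            attrs = [ConditioningAttributes(text={"description": "happy rock"})]
+            tokenized = provider.tokenize(attrs)
+            conditions = provider(tokenized)  # {"description": (embedding [B, T, 512], mask [B, T])}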
+ """ + def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"): + super().__init__() + self.device = device + self.conditioners = nn.ModuleDict(conditioners) + + @property + def joint_embed_conditions(self): + return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)] + + @property + def has_joint_embed_conditions(self): + return len(self.joint_embed_conditions) > 0 + + @property + def text_conditions(self): + return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] + + @property + def wav_conditions(self): + return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)] + + @property + def has_wav_condition(self): + return len(self.wav_conditions) > 0 + + def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]: + """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. + This should be called before starting any real GPU work to avoid synchronization points. + This will return a dict matching conditioner names to their arbitrary tokenized representations. + + Args: + inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing + text and wav conditions. + """ + assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( + "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", + f" but types were {set([type(x) for x in inputs])}" + ) + + output = {} + text = self._collate_text(inputs) + wavs = self._collate_wavs(inputs) + joint_embeds = self._collate_joint_embeds(inputs) + + assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), ( + f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ", + f"got {text.keys(), wavs.keys(), joint_embeds.keys()}" + ) + + for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()): + output[attribute] = self.conditioners[attribute].tokenize(batch) + return output + + def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]: + """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations. + The output is for example: + { + "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), + "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), + ... + } + + Args: + tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. + """ + output = {} + for attribute, inputs in tokenized.items(): + condition, mask = self.conditioners[attribute](inputs) + output[attribute] = (condition, mask) + return output + + def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]: + """Given a list of ConditioningAttributes objects, compile a dictionary where the keys + are the attributes and the values are the aggregated input per attribute. + For example: + Input: + [ + ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...), + ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...), + ] + Output: + { + "genre": ["Rock", "Hip-hop"], + "description": ["A rock song with a guitar solo", "A hip-hop verse"] + } + + Args: + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. 
+ Returns: + dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch. + """ + out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) + texts = [x.text for x in samples] + for text in texts: + for condition in self.text_conditions: + out[condition].append(text[condition]) + return out + + def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]: + """Generate a dict where the keys are attributes by which we fetch similar wavs, + and the values are Tensors of wavs according to said attributes. + + *Note*: by the time the samples reach this function, each sample should have some waveform + inside the "wav" attribute. It should be either: + 1. A real waveform + 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset) + 3. A null waveform due to it being dropped in a dropout module (nullified by dropout) + + Args: + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. + Returns: + dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. + """ + wavs = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) + paths = defaultdict(list) + seek_times = defaultdict(list) + out: tp.Dict[str, WavCondition] = {} + + for sample in samples: + for attribute in self.wav_conditions: + wav, length, sample_rate, path, seek_time = sample.wav[attribute] + assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]" + assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1" + # mono-channel conditioning + wav = wav.mean(1, keepdim=True) # [1, 1, T] + wavs[attribute].append(wav.flatten()) # [T] + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) + + # stack all wavs to a single tensor + for attribute in self.wav_conditions: + stacked_wav, _ = collate(wavs[attribute], dim=0) + out[attribute] = WavCondition( + stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute], + paths[attribute], seek_times[attribute]) + + return out + + def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]: + """Generate a dict where the keys are attributes by which we compute joint embeddings, + and the values are Tensors of pre-computed embeddings and the corresponding text attributes. + + Args: + samples (list[ConditioningAttributes]): List of ConditioningAttributes samples. + Returns: + A dictionary mapping an attribute name to joint embeddings. 
+ """ + texts = defaultdict(list) + wavs = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) + paths = defaultdict(list) + seek_times = defaultdict(list) + channels: int = 0 + + out = {} + for sample in samples: + for attribute in self.joint_embed_conditions: + wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute] + assert wav.dim() == 3 + if channels == 0: + channels = wav.size(1) + else: + assert channels == wav.size(1), "not all audio has same number of channels in batch" + assert wav.size(0) == 1, "Expecting single-wav batch in the collate method" + wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T] + wavs[attribute].append(wav) + texts[attribute].extend(text) + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) + + for attribute in self.joint_embed_conditions: + stacked_texts = texts[attribute] + stacked_paths = paths[attribute] + stacked_seek_times = seek_times[attribute] + stacked_wavs = pad_sequence(wavs[attribute]).to(self.device) + stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels) + stacked_sample_rates = sample_rates[attribute] + stacked_lengths = torch.cat(lengths[attribute]).to(self.device) + assert stacked_lengths.size(0) == stacked_wavs.size(0) + assert len(stacked_sample_rates) == stacked_wavs.size(0) + assert len(stacked_texts) == stacked_wavs.size(0) + out[attribute] = JointEmbedCondition( + text=stacked_texts, wav=stacked_wavs, + length=stacked_lengths, sample_rate=stacked_sample_rates, + path=stacked_paths, seek_time=stacked_seek_times) + + return out + + +class ConditionFuser(StreamingModule): + """Condition fuser handles the logic to combine the different conditions + to the actual model input. + + Args: + fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse + each condition. For example: + { + "prepend": ["description"], + "sum": ["genre", "bpm"], + "cross": ["description"], + } + cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention. + cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used. + """ + FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"] + + def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False, + cross_attention_pos_emb_scale: float = 1.0): + super().__init__() + assert all( + [k in self.FUSING_METHODS for k in fuse2cond.keys()] + ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}" + self.cross_attention_pos_emb = cross_attention_pos_emb + self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale + self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond + self.cond2fuse: tp.Dict[str, str] = {} + for fuse_method, conditions in fuse2cond.items(): + for condition in conditions: + self.cond2fuse[condition] = fuse_method + + def forward( + self, + input: torch.Tensor, + conditions: tp.Dict[str, ConditionType] + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """Fuse the conditions to the provided model input. + + Args: + input (torch.Tensor): Transformer input. + conditions (dict[str, ConditionType]): Dict of conditions. + Returns: + tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input + after the conditions have been fused. The second output tensor is the tensor + used for cross-attention or None if no cross attention inputs exist. 
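+
+        Example:
+            Shapes below are illustrative; "description" must be one of the attributes
+            declared in `fuse2cond`:
+
+                fuser = ConditionFuser({"prepend": [], "sum": [], "cross": ["description"], "input_interpolate": []})
+                x = torch.zeros(2, 10, 512)                         # transformer input [B, T, D]
+                desc = (torch.zeros(2, 4, 512), torch.ones(2, 4))   # (condition, mask)
+                fused, cross = fuser(x, {"description": desc})      # fused: [2, 10, 512], cross: [2, 4, 512]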
+ """ + B, T, _ = input.shape + + if 'offsets' in self._streaming_state: + first_step = False + offsets = self._streaming_state['offsets'] + else: + first_step = True + offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device) + + assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \ + f"given conditions contain unknown attributes for fuser, " \ + f"expected {self.cond2fuse.keys()}, got {conditions.keys()}" + cross_attention_output = None + for cond_type, (cond, cond_mask) in conditions.items(): + op = self.cond2fuse[cond_type] + if op == 'sum': + input += cond + elif op == 'input_interpolate': + cond = einops.rearrange(cond, "b t d -> b d t") + cond = F.interpolate(cond, size=input.shape[1]) + input += einops.rearrange(cond, "b d t -> b t d") + elif op == 'prepend': + if first_step: + input = torch.cat([cond, input], dim=1) + elif op == 'cross': + if cross_attention_output is not None: + cross_attention_output = torch.cat([cross_attention_output, cond], dim=1) + else: + cross_attention_output = cond + else: + raise ValueError(f"unknown op ({op})") + + if self.cross_attention_pos_emb and cross_attention_output is not None: + positions = torch.arange( + cross_attention_output.shape[1], + device=cross_attention_output.device + ).view(1, -1, 1) + pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1]) + cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb + + if self._is_streaming: + self._streaming_state['offsets'] = offsets + T + + return input, cross_attention_output + + + +# ============================================== From LM.py + + + +logger = logging.getLogger(__name__) +ConditionTensors = tp.Dict[str, ConditionType] +CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]] + + +def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None): + """LM layer initialization. + Inspired from xlformers: https://github.com/fairinternal/xlformers + + Args: + method (str): Method name for init function. Valid options are: + 'gaussian', 'uniform'. + input_dim (int): Input dimension of the initialized module. + init_depth (int, optional): Optional init depth value used to rescale + the standard deviation if defined. + """ + # Compute std + std = 1 / math.sqrt(input_dim) + # Rescale with depth + if init_depth is not None: + std = std / math.sqrt(2 * init_depth) + + if method == 'gaussian': + return partial( + torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std + ) + elif method == 'uniform': + bound = math.sqrt(3) * std # ensure the standard deviation is `std` + return partial(torch.nn.init.uniform_, a=-bound, b=bound) + else: + raise ValueError("Unsupported layer initialization method") + + +def init_layer(m: nn.Module, + method: str, + init_depth: tp.Optional[int] = None, + zero_bias_init: bool = False): + """Wrapper around ``get_init_fn`` for proper initialization of LM modules. + + Args: + m (nn.Module): Module to initialize. + method (str): Method name for the init function. + init_depth (int, optional): Optional init depth value used to rescale + the standard deviation if defined. + zero_bias_init (bool): Whether to initialize the bias to 0 or not. 
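+
+    Example:
+        Illustrative values; `init_depth` would typically be the 1-based layer index
+        or the total number of transformer layers (see `LMModel._init_weights`):
+
+            layer = nn.Linear(1024, 1024)
+            init_layer(layer, method="gaussian", init_depth=24, zero_bias_init=True)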
+ """ + if isinstance(m, nn.Linear): + init_fn = get_init_fn(method, m.in_features, init_depth=init_depth) + if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: + weight = m.weight.float() + init_fn(weight) + m.weight.data[:] = weight.half() + else: + init_fn(m.weight) + if zero_bias_init and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Embedding): + init_fn = get_init_fn(method, m.embedding_dim, init_depth=None) + if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: + weight = m.weight.float() + init_fn(weight) + m.weight.data[:] = weight.half() + else: + init_fn(m.weight) + + +class ScaledEmbedding(nn.Embedding): + """Boost learning rate for embeddings (with `scale`). + """ + def __init__(self, *args, lr=None, **kwargs): + super().__init__(*args, **kwargs) + self.lr = lr + + def make_optim_group(self): + group = {"params": list(self.parameters())} + if self.lr is not None: + group["lr"] = self.lr + return group + + +@dataclass +class LMOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + + +class LMModel(StreamingModule): + """Transformer-based language model on multiple streams of codes. + + Args: + pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving. + condition_provider (MusicConditioningProvider): Conditioning provider from metadata. + fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input. + n_q (int): Number of parallel streams to model. + card (int): Cardinality, vocabulary size. + dim (int): Dimension of the transformer encoder. + num_heads (int): Number of heads for the transformer encoder. + hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. + norm (str): Normalization method. + norm_first (bool): Use pre-norm instead of post-norm. + emb_lr (float, optional): Embedding-specific learning rate. + bias_proj (bool): Use bias for output projections. + weight_init (str, optional): Method for weight initialization. + depthwise_init (str, optional): Method for depthwise weight initialization. + zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros. + cfg_dropout (float): Classifier-free guidance dropout. + cfg_coef (float): Classifier-free guidance coefficient. + attribute_dropout (dict): Attribute dropout probabilities. + two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps. + **kwargs: Additional parameters for the transformer encoder. 
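+
+    Example:
+        A rough construction sketch, assuming `DelayedPatternProvider` is available from
+        `audiocraft.codebooks_patterns` and that the remaining kwargs (e.g. `num_layers`)
+        are forwarded to `StreamingTransformer`; all sizes are illustrative:
+
+            lm = LMModel(
+                pattern_provider=DelayedPatternProvider(n_q=4),
+                condition_provider=ConditioningProvider({}, device="cpu"),
+                fuser=ConditionFuser({"prepend": [], "sum": [], "cross": [], "input_interpolate": []}),
+                n_q=4, card=1024, dim=128, num_heads=8, num_layers=2,
+            ).eval()
+            codes = lm.generate(conditions=[], num_samples=1, max_gen_len=32)  # [1, 4, 32]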
+ """ + def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider, + fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8, + hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False, + emb_lr: tp.Optional[float] = None, bias_proj: bool = True, + weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None, + zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0, + attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False, + **kwargs): + super().__init__() + self.cfg_coef = cfg_coef + self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout) + self.att_dropout = AttributeDropout(p=attribute_dropout) + self.condition_provider = condition_provider + self.fuser = fuser + self.card = card + embed_dim = self.card + 1 + self.n_q = n_q + self.dim = dim + self.pattern_provider = pattern_provider + self.two_step_cfg = two_step_cfg + self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)]) + if 'activation' in kwargs: + kwargs['activation'] = get_activation_fn(kwargs['activation']) + self.transformer = StreamingTransformer( + d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim), + norm=norm, norm_first=norm_first, **kwargs) + self.out_norm: tp.Optional[nn.Module] = None + if norm_first: + self.out_norm = create_norm_fn(norm, dim) + self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)]) + self._init_weights(weight_init, depthwise_init, zero_bias_init) + self._fsdp: tp.Optional[nn.Module] + self.__dict__['_fsdp'] = None + + def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool): + """Initialization of the transformer module weights. + + Args: + weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options. + depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid: + 'current' where the depth corresponds to the current layer index or 'global' where the total number + of layer is used as depth. If not set, no depthwise initialization strategy is used. + zero_bias_init (bool): Whether to initialize bias to zero or not. + """ + assert depthwise_init is None or depthwise_init in ['current', 'global'] + assert depthwise_init is None or weight_init is not None, \ + "If 'depthwise_init' is defined, a 'weight_init' method should be provided." 
+ assert not zero_bias_init or weight_init is not None, \ + "If 'zero_bias_init', a 'weight_init' method should be provided" + + if weight_init is None: + return + + for emb_layer in self.emb: + init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) + + for layer_idx, tr_layer in enumerate(self.transformer.layers): + depth = None + if depthwise_init == 'current': + depth = layer_idx + 1 + elif depthwise_init == 'global': + depth = len(self.transformer.layers) + init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init) + tr_layer.apply(init_fn) + + for linear in self.linears: + init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) + + @property + def special_token_id(self) -> int: + return self.card + + @property + def num_codebooks(self) -> int: + return self.n_q + + def forward(self, sequence: torch.Tensor, + conditions: tp.List[ConditioningAttributes], + condition_tensors: tp.Optional[ConditionTensors] = None, + stage: int = -1) -> torch.Tensor: + """Apply language model on sequence and conditions. + Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and + S the sequence steps, return the logits with shape [B, card, K, S]. + + Args: + indices (torch.Tensor): Indices of the codes to model. + conditions (list of ConditioningAttributes): Conditions to use when modeling + the given codes. Note that when evaluating multiple time with the same conditioning + you should pre-compute those and pass them as `condition_tensors`. + condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning + tensors, see `conditions`. + stage (int): The codebook level that is being predicted. Relevant for MAGNeT + in which prediction is done in a codebook-by-codebook manner. + Takes values in range(n_q), and ignored by default. + Returns: + torch.Tensor: Logits. + """ + B, K, S = sequence.shape + assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks" + input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)]) + if condition_tensors is None: + assert not self._is_streaming, "Conditions tensors should be precomputed when streaming." + # apply dropout modules + conditions = self.cfg_dropout(conditions) + conditions = self.att_dropout(conditions) + tokenized = self.condition_provider.tokenize(conditions) + # encode conditions and fuse, both have a streaming cache to not recompute when generating. + condition_tensors = self.condition_provider(tokenized) + else: + assert not conditions, "Shouldn't pass both conditions and condition_tensors." + + input_, cross_attention_input = self.fuser(input_, condition_tensors) + + out = self.transformer(input_, cross_attention_src=cross_attention_input, + src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None)) + if self.out_norm: + out = self.out_norm(out) + logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card] + + # remove the prefix from the model outputs + if len(self.fuser.fuse2cond['prepend']) > 0: + logits = logits[:, :, -S:] + + return logits # [B, K, S, card] + + def compute_predictions( + self, codes: torch.Tensor, + conditions: tp.List[ConditioningAttributes], + condition_tensors: tp.Optional[ConditionTensors] = None, + stage: int = -1, + keep_only_valid_steps: bool = True) -> LMOutput: + """Given an input tensor of codes [B, K, T] and list of conditions, runs the model + forward using the specified codes interleaving pattern. 
+ + Args: + codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size, + K the number of codebooks and T the number of timesteps. + conditions (list of ConditioningAttributes): conditionings to use when modeling + the given codes. Note that when evaluating multiple time with the same conditioning + you should pre-compute those and pass them as `condition_tensors`. + condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning + tensors, see `conditions`. + stage (int): The codebook level that is being predicted. Relevant for MAGNeT + in which prediction is done in a codebook-by-codebook manner. + Takes values in range(n_q), and ignored by default. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + Returns: + LMOutput: Language model outputs + logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, + i.e. the first item corresponds to logits to predict the first code, meaning that + no additional shifting of codes and logits is required. + mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions. + Given the specified interleaving strategies, parts of the logits and codes should + not be considered as valid predictions because of invalid context. + """ + B, K, T = codes.shape + codes = codes.contiguous() + # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens + pattern = self.pattern_provider.get_pattern(T) + sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence( + codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps, + ) + + # apply model on pattern sequence + model = self if self._fsdp is None else self._fsdp + logits = model(sequence_codes, conditions, condition_tensors, stage=stage) # [B, K, S, card] + # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card] + # and provide the corresponding mask over invalid positions of tokens + logits = logits.permute(0, 3, 1, 2) # [B, card, K, S] + # note: we use nans as special token to make it obvious if we feed unexpected logits + logits, logits_indexes, logits_mask = pattern.revert_pattern_logits( + logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps + ) + logits = logits.permute(0, 2, 3, 1) # [B, K, T, card] + logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T] + return LMOutput(logits, logits_mask) + + def _sample_next_token(self, + sequence, + cfg_conditions, + unconditional_state, + use_sampling=False, + temp: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, + cfg_coef: tp.Optional[float] = None, + two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor: + """Sample next token from the model given a sequence and a set of conditions. The model supports + multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). + + Args: + sequence (torch.Tensor): Current sequence of shape [B, K, S] + with K corresponding to the number of codebooks and S the number of sequence steps. + S = 1 in streaming mode, except for the first step that contains a bigger prompt. + condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used, + should be twice the batch size, being the concatenation of the conditions + null conditions. + use_sampling (bool): Whether to use a sampling strategy or not. 
+ temp (float): Sampling temperature. + top_k (int): K for "top-k" sampling. + top_p (float): P for "top-p" sampling. + cfg_coef (float, optional): classifier free guidance coefficient + Returns: + next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. + """ + B = sequence.shape[0] + cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef + model = self if self._fsdp is None else self._fsdp + two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg + if two_step_cfg and cfg_conditions != {}: + assert isinstance(cfg_conditions, tuple), type(cfg_conditions) + condition_tensors, null_condition_tensors = cfg_conditions + cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors) + state = self.get_streaming_state() + self.set_streaming_state(unconditional_state) + uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors) + unconditional_state.update(self.get_streaming_state()) + self.set_streaming_state(state) + logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef + else: + assert isinstance(cfg_conditions, dict) + condition_tensors = cfg_conditions + if condition_tensors: + # Preparing for CFG, predicting both conditional and unconditional logits. + sequence = torch.cat([sequence, sequence], dim=0) + all_logits = model( + sequence, + conditions=[], condition_tensors=condition_tensors) + if condition_tensors: + cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card] + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef + else: + logits = all_logits + + logits = logits.permute(0, 1, 3, 2) # [B, K, card, T] + logits = logits[..., -1] # [B x K x card] + + # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error. + if use_sampling and temp > 0.0: + probs = torch.softmax(logits / temp, dim=-1) + if top_p > 0.0: + next_token = utils.sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token = utils.sample_top_k(probs, k=top_k) + else: + next_token = utils.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) + + return next_token + + @torch.no_grad() + def generate(self, + prompt: tp.Optional[torch.Tensor] = None, + conditions: tp.List[ConditioningAttributes] = [], + num_samples: tp.Optional[int] = None, + max_gen_len: int = 256, + use_sampling: bool = True, + temp: float = 1.0, + top_k: int = 250, + top_p: float = 0.0, + cfg_coef: tp.Optional[float] = None, + two_step_cfg: tp.Optional[bool] = None, + remove_prompts: bool = False, + check: bool = False, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + **kwargs) -> torch.Tensor: + """Generate tokens sampling from the model given a prompt or unconditionally. Generation can + be performed in a greedy fashion or using sampling with top K and top P strategies. + + Args: + prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T]. + conditions_tensors (list of ConditioningAttributes, optional): List of conditions. + num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given. + max_gen_len (int): Maximum generation length. + use_sampling (bool): Whether to use a sampling strategy or not. + temp (float): Sampling temperature. + top_k (int): K for "top-k" sampling. + top_p (float): P for "top-p" sampling. + cfg_coeff (float, optional): Classifier-free guidance coefficient. + two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation. 
+ remove_prompts (bool): Whether to remove prompts from generation or not. + check (bool): Whether to apply further checks on generated sequence. + callback (Callback, optional): Callback function to report generation progress. + Returns: + torch.Tensor: Generated tokens. + """ + assert not self.training, "generation shouldn't be used in training mode." + first_param = next(iter(self.parameters())) + device = first_param.device + + # Checking all input shapes are consistent. + possible_num_samples = [] + if num_samples is not None: + possible_num_samples.append(num_samples) + elif prompt is not None: + possible_num_samples.append(prompt.shape[0]) + elif conditions: + possible_num_samples.append(len(conditions)) + else: + possible_num_samples.append(1) + assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes" + num_samples = possible_num_samples[0] + + # below we create set of conditions: one conditional and one unconditional + # to do that we merge the regular condition together with the null condition + # we then do 1 forward pass instead of 2. + # the reason for that is two-fold: + # 1. it is about x2 faster than doing 2 forward passes + # 2. avoid the streaming API treating the 2 passes as part of different time steps + # We also support doing two different passes, in particular to ensure that + # the padding structure is exactly the same between train and test. + # With a batch size of 1, this can be slower though. + cfg_conditions: CFGConditions + two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg + if conditions: + null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions) + if two_step_cfg: + cfg_conditions = ( + self.condition_provider(self.condition_provider.tokenize(conditions)), + self.condition_provider(self.condition_provider.tokenize(null_conditions)), + ) + else: + conditions = conditions + null_conditions + tokenized = self.condition_provider.tokenize(conditions) + cfg_conditions = self.condition_provider(tokenized) + else: + cfg_conditions = {} + + if prompt is None: + assert num_samples > 0 + prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device) + + B, K, T = prompt.shape + start_offset = T + assert start_offset < max_gen_len + + pattern = self.pattern_provider.get_pattern(max_gen_len) + # this token is used as default value for codes that are not generated yet + unknown_token = -1 + + # we generate codes up to the max_gen_len that will be mapped to the pattern sequence + gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device) + # filling the gen_codes with the prompt if needed + gen_codes[..., :start_offset] = prompt + # create the gen_sequence with proper interleaving from the pattern: [B, K, S] + gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id) + # retrieve the start_offset in the sequence: + # it is the first sequence step that contains the `start_offset` timestep + start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) + assert start_offset_sequence is not None + + with self.streaming(): + unconditional_state = self.get_streaming_state() + prev_offset = 0 + gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S] + for offset in range(start_offset_sequence, gen_sequence_len): + # get current sequence (note that the streaming API is providing the caching over previous offsets) + curr_sequence = gen_sequence[..., prev_offset:offset] + 
curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1) + if check: + # check coherence between mask and sequence + assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all() + # should never happen as gen_sequence is filled progressively + assert not (curr_sequence == unknown_token).any() + # sample next token from the model, next token shape is [B, K, 1] + next_token = self._sample_next_token( + curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, + cfg_coef=cfg_coef, two_step_cfg=two_step_cfg) + # ensure the tokens that should be masked are properly set to special_token_id + # as the model never output special_token_id + valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) + next_token[~valid_mask] = self.special_token_id + # ensure we don't overwrite prompt tokens, we only write over unknown tokens + # (then mask tokens should be left as is as well, which is correct) + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, gen_sequence[..., offset:offset+1] + ) + prev_offset = offset + if callback is not None: + callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) + unconditional_state.clear() + + # ensure sequence has been entirely filled + assert not (gen_sequence == unknown_token).any() + # ensure gen_sequence pattern and mask are matching + # which means the gen_sequence is valid according to the pattern + assert ( + gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id) + ).all() + # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps + out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) + + # sanity checks over the returned codes and corresponding masks + assert (out_codes[..., :max_gen_len] != unknown_token).all() + assert (out_mask[..., :max_gen_len] == 1).all() + + out_start_offset = start_offset if remove_prompts else 0 + out_codes = out_codes[..., out_start_offset:max_gen_len] + + # ensure the returned codes are all valid + assert (out_codes >= 0).all() and (out_codes <= self.card).all() + return out_codes diff --git a/audiocraft/loaders.py b/audiocraft/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6bbaa82bc653b8d0cf1a2b1a2401a67426fd77 --- /dev/null +++ b/audiocraft/loaders.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility functions to load from the checkpoints. +Each checkpoint is a torch.saved dict with the following keys: +- 'xp.cfg': the hydra config as dumped during training. This should be used + to rebuild the object using the audiocraft.models.builders functions, +- 'model_best_state': a readily loadable best state for the model, including + the conditioner. The model obtained from `xp.cfg` should be compatible + with this state dict. In the case of a LM, the encodec model would not be + bundled along but instead provided separately. + +Those functions also support loading from a remote location with the Torch Hub API. +They also support overriding some parameters, in particular the device and dtype +of the returned model. 
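+
+A minimal usage sketch (assuming the Hugging Face repo id 'facebook/musicgen-small',
+which is the id used elsewhere in this codebase, ships the expected
+'compression_state_dict.bin' / 'state_dict.bin' files):
+
+    from audiocraft import loaders
+
+    compression_model = loaders.load_compression_model('facebook/musicgen-small', device='cpu')
+    lm = loaders.load_lm_model('facebook/musicgen-small', device='cpu')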
+""" + +from pathlib import Path +from huggingface_hub import hf_hub_download +import typing as tp +import os + +from omegaconf import OmegaConf, DictConfig +import torch + +import audiocraft +from . import builders +from .encodec import CompressionModel + + +def get_audiocraft_cache_dir() -> tp.Optional[str]: + return os.environ.get('AUDIOCRAFT_CACHE_DIR', None) + + +def _get_state_dict( + file_or_url_or_id: tp.Union[Path, str], + filename: tp.Optional[str] = None, + device='cpu', + cache_dir: tp.Optional[str] = None, +): + if cache_dir is None: + cache_dir = get_audiocraft_cache_dir() + # Return the state dict either from a file or url + file_or_url_or_id = str(file_or_url_or_id) + assert isinstance(file_or_url_or_id, str) + + if os.path.isfile(file_or_url_or_id): + return torch.load(file_or_url_or_id, map_location=device) + + if os.path.isdir(file_or_url_or_id): + file = f"{file_or_url_or_id}/{filename}" + return torch.load(file, map_location=device) + + elif file_or_url_or_id.startswith('https://'): + return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True) + + else: + assert filename is not None, "filename needs to be defined if using HF checkpoints" + + file = hf_hub_download( + repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, + library_name="audiocraft", + library_version= '1.3.0a1') # Found at __init__.py #audiocraft.__version__) + return torch.load(file, map_location=device) + + +def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) + + +def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): + pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) + if 'pretrained' in pkg: + return CompressionModel.get_pretrained(pkg['pretrained'], device=device) + cfg = OmegaConf.create(pkg['xp.cfg']) + cfg.device = str(device) + model = builders.get_compression_model(cfg) + model.load_state_dict(pkg['best_state']) + model.eval() + return model + + +def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir) + + +def _delete_param(cfg: DictConfig, full_name: str): + parts = full_name.split('.') + for part in parts[:-1]: + if part in cfg: + cfg = cfg[part] + else: + return + OmegaConf.set_struct(cfg, False) + if parts[-1] in cfg: + del cfg[parts[-1]] + OmegaConf.set_struct(cfg, True) + + +def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): + pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) + cfg = OmegaConf.create(pkg['xp.cfg']) + cfg.device = str(device) + if cfg.device == 'cpu': + cfg.dtype = 'float32' + else: + cfg.dtype = 'float16' + _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path') + _delete_param(cfg, 'conditioners.args.merge_text_conditions_p') + _delete_param(cfg, 'conditioners.args.drop_desc_p') + model = builders.get_lm_model(cfg) + model.load_state_dict(pkg['best_state']) + model.eval() + model.cfg = cfg + return model + + +def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int, + device='cpu', cache_dir: tp.Optional[str] = None): + pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) + cfg = 
OmegaConf.create(pkg['xp.cfg']) + cfg.device = str(device) + if cfg.device == 'cpu': + cfg.dtype = 'float32' + else: + cfg.dtype = 'float16' + _delete_param(cfg, 'conditioners.args.merge_text_conditions_p') + _delete_param(cfg, 'conditioners.args.drop_desc_p') + + cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate + cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration + cfg.transformer_lm.span_len = cfg.masking.span_len + + # MAGNeT models v1 support only xformers backend. + from .transformer import set_efficient_attention_backend + if cfg.transformer_lm.memory_efficient: + set_efficient_attention_backend("xformers") + + model = builders.get_lm_model(cfg) + model.load_state_dict(pkg['best_state']) + model.eval() + model.cfg = cfg + return model + + +def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], + filename: tp.Optional[str] = None, + cache_dir: tp.Optional[str] = None): + return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir) + + +def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], + device='cpu', + filename: tp.Optional[str] = None, + cache_dir: tp.Optional[str] = None): + pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir) + models = [] + processors = [] + cfgs = [] + sample_rate = pkg['sample_rate'] + for i in range(pkg['n_bands']): + cfg = pkg[i]['cfg'] + model = builders.get_diffusion_model(cfg) + model_dict = pkg[i]['model_state'] + model.load_state_dict(model_dict) + model.to(device) + processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate) + processor_dict = pkg[i]['processor_state'] + processor.load_state_dict(processor_dict) + processor.to(device) + models.append(model) + processors.append(processor) + cfgs.append(cfg) + return models, processors, cfgs diff --git a/audiocraft/lstm.py b/audiocraft/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..c0866175950c1ca4f6cca98649525e6481853bba --- /dev/null +++ b/audiocraft/lstm.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn + + +class StreamableLSTM(nn.Module): + """LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + + def forward(self, x): + x = x.permute(2, 0, 1) + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y diff --git a/audiocraft/multibanddiffusion.py b/audiocraft/multibanddiffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..baab10890f0e74af79ed804b96e7a3767317004d --- /dev/null +++ b/audiocraft/multibanddiffusion.py @@ -0,0 +1,392 @@ +#====================================== From CompressionSolver.py + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import multiprocessing +from pathlib import Path +import typing as tp + +import flashy +import omegaconf +import torch +from torch import nn + +# from . import base, builders +from .. 
import models, quantization +from ..utils import checkpoint +from ..utils.samples.manager import SampleManager +from ..utils.utils import get_pool_executor + + + + + +class CompressionSolver(): #base.StandardSolver): + """Solver for compression task. + + The compression task combines a set of perceptual and objective losses + to train an EncodecModel (composed of an encoder-decoder and a quantizer) + to perform high fidelity audio reconstruction. + """ + def __init__(self, cfg: omegaconf.DictConfig): + # super().__init__(cfg) + self.cfg = cfg + self.rng: torch.Generator # set at each epoch + self.adv_losses = builders.get_adversarial_losses(self.cfg) + self.aux_losses = nn.ModuleDict() + self.info_losses = nn.ModuleDict() + assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver." + loss_weights = dict() + for loss_name, weight in self.cfg.losses.items(): + if loss_name in ['adv', 'feat']: + for adv_name, _ in self.adv_losses.items(): + loss_weights[f'{loss_name}_{adv_name}'] = weight + elif weight > 0: + self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg) + loss_weights[loss_name] = weight + else: + self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg) + self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer) + self.register_stateful('adv_losses') + + @property + def best_metric_name(self) -> tp.Optional[str]: + # best model is the last for the compression model + return None + + def build_model(self): + """Instantiate model and optimizer.""" + # Model and optimizer + self.model = models.builders.get_compression_model(self.cfg).to(self.device) + self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim) + self.register_stateful('model', 'optimizer') + self.register_best_state('model') + self.register_ema('model') + + + + def evaluate(self): + """Evaluate stage. 
Runs audio reconstruction evaluation.""" + self.model.eval() + evaluate_stage_name = str(self.current_stage) + + loader = self.dataloaders['evaluate'] + updates = len(loader) + lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates) + average = flashy.averager() + + pendings = [] + ctx = multiprocessing.get_context('spawn') + with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool: + for idx, batch in enumerate(lp): + x = batch.to(self.device) + with torch.no_grad(): + qres = self.model(x) + + y_pred = qres.x.cpu() + y = batch.cpu() # should already be on CPU but just in case + pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg)) + + metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates) + for pending in metrics_lp: + metrics = pending.result() + metrics = average(metrics) + + metrics = flashy.distrib.average_metrics(metrics, len(loader)) + return metrics + + def generate(self): + """Generate stage.""" + self.model.eval() + sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True) + generate_stage_name = str(self.current_stage) + + loader = self.dataloaders['generate'] + updates = len(loader) + lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) + + for batch in lp: + reference, _ = batch + reference = reference.to(self.device) + with torch.no_grad(): + qres = self.model(reference) + assert isinstance(qres, quantization.QuantizedResult) + + reference = reference.cpu() + estimate = qres.x.cpu() + sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference) + + flashy.distrib.barrier() + + def load_from_pretrained(self, name: str) -> dict: + model = models.CompressionModel.get_pretrained(name) + if isinstance(model, models.DAC): + raise RuntimeError("Cannot fine tune a DAC model.") + elif isinstance(model, models.HFEncodecCompressionModel): + self.logger.warning('Trying to automatically convert a HuggingFace model ' + 'to AudioCraft, this might fail!') + state = model.model.state_dict() + new_state = {} + for k, v in state.items(): + if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k: + # We need to determine if this a convtr or a regular conv. + layer = int(k.split('.')[2]) + if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d): + + k = k.replace('.conv.', '.convtr.') + k = k.replace('encoder.layers.', 'encoder.model.') + k = k.replace('decoder.layers.', 'decoder.model.') + k = k.replace('conv.', 'conv.conv.') + k = k.replace('convtr.', 'convtr.convtr.') + k = k.replace('quantizer.layers.', 'quantizer.vq.layers.') + k = k.replace('.codebook.', '._codebook.') + new_state[k] = v + state = new_state + elif isinstance(model, models.EncodecModel): + state = model.state_dict() + else: + raise RuntimeError(f"Cannot fine tune model type {type(model)}.") + return { + 'best_state': {'model': state} + } + + @staticmethod + def model_from_checkpoint(checkpoint_path: tp.Union[Path, str], + device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel: + """Instantiate a CompressionModel from a given checkpoint path or dora sig. + This method is a convenient endpoint to load a CompressionModel to use in other solvers. + + Args: + checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved. + This also supports pre-trained models by using a path of the form //pretrained/NAME. 
+ See `model_from_pretrained` for a list of supported pretrained models. + use_ema (bool): Use EMA variant of the model instead of the actual model. + device (torch.device or str): Device on which the model is loaded. + """ + checkpoint_path = str(checkpoint_path) + if checkpoint_path.startswith('//pretrained/'): + name = checkpoint_path.split('/', 3)[-1] + return models.CompressionModel.get_pretrained(name, device) + logger = logging.getLogger(__name__) + logger.info(f"Loading compression model from checkpoint: {checkpoint_path}") + _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False) + assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}" + state = checkpoint.load_checkpoint(_checkpoint_path) + assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}" + cfg = state['xp.cfg'] + cfg.device = device + compression_model = models.builders.get_compression_model(cfg).to(device) + assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" + + assert 'best_state' in state and state['best_state'] != {} + assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix." + compression_model.load_state_dict(state['best_state']['model']) + compression_model.eval() + logger.info("Compression model loaded!") + return compression_model + + @staticmethod + def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig, + checkpoint_path: tp.Union[Path, str], + device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel: + """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig. + + Args: + cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode. + checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved. + use_ema (bool): Use EMA variant of the model instead of the actual model. + device (torch.device or str): Device on which the model is loaded. + """ + compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device) + compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg) + return compression_model + + + + + +#=========================================================================== ORIG + +import typing as tp + +import torch +import julius + +from .unet import DiffusionUnet +from ..modules.diffusion_schedule import NoiseSchedule +from .encodec import CompressionModel +from .loaders import load_compression_model, load_diffusion_models + + +class DiffusionProcess: + """Sampling for a diffusion Model. + + Args: + model (DiffusionUnet): Diffusion U-Net model. + noise_schedule (NoiseSchedule): Noise schedule for diffusion process. + """ + def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None: + self.model = model + self.schedule = noise_schedule + + def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor, + step_list: tp.Optional[tp.List[int]] = None): + """Perform one diffusion process to generate one of the bands. + + Args: + condition (torch.Tensor): The embeddings from the compression model. + initial_noise (torch.Tensor): The initial noise to start the process. + """ + return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list, + condition=condition) + + +class MultiBandDiffusion: + """Sample from multiple diffusion models. 
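+
+    A rough usage sketch (the variable names, tensor shapes and the 32 kHz rate are
+    assumptions based on the MusicGen checkpoints loaded below, not guarantees):
+
+        mbd = MultiBandDiffusion.get_mbd_musicgen()
+        emb = mbd.get_condition(wav, sample_rate=32000)      # wav: [B, C, T] waveform tensor
+        regenerated = mbd.generate(emb)                      # diffusion decoding of the latents
+        roundtrip = mbd.regenerate(wav, sample_rate=32000)   # or compress + regenerate in one call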
+ + Args: + DPs (list of DiffusionProcess): Diffusion processes. + codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens. + """ + def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None: + self.DPs = DPs + self.codec_model = codec_model + self.device = next(self.codec_model.parameters()).device + + @property + def sample_rate(self) -> int: + return self.codec_model.sample_rate + + @staticmethod + def get_mbd_musicgen(device=None): + """Load our diffusion models trained for MusicGen.""" + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + path = 'facebook/multiband-diffusion' + filename = 'mbd_musicgen_32khz.th' + name = 'facebook/musicgen-small' + codec_model = load_compression_model(name, device=device) + models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device) + DPs = [] + for i in range(len(models)): + schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) + DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) + return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) + + @staticmethod + def get_mbd_24khz(bw: float = 3.0, + device: tp.Optional[tp.Union[torch.device, str]] = None, + n_q: tp.Optional[int] = None): + """Get the pretrained Models for MultibandDiffusion. + + Args: + bw (float): Bandwidth of the compression model. + device (torch.device or str, optional): Device on which the models are loaded. + n_q (int, optional): Number of quantizers to use within the compression model. + """ + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available" + if n_q is not None: + assert n_q in [2, 4, 8] + assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \ + f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}" + n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw] + codec_model = CompressionSolver.model_from_checkpoint( + '//pretrained/facebook/encodec_24khz', device=device) + codec_model.set_num_codebooks(n_q) + codec_model = codec_model.to(device) + path = 'facebook/multiband-diffusion' + filename = f'mbd_comp_{n_q}.pt' + models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device) + DPs = [] + for i in range(len(models)): + schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) + DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) + return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) + + @torch.no_grad() + def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Get the conditioning (i.e. latent representations of the compression model) from a waveform. + Args: + wav (torch.Tensor): The audio that we want to extract the conditioning from. + sample_rate (int): Sample rate of the audio.""" + if sample_rate != self.sample_rate: + wav = julius.resample_frac(wav, sample_rate, self.sample_rate) + codes, scale = self.codec_model.encode(wav) + assert scale is None, "Scaled compression models not supported." + emb = self.get_emb(codes) + return emb + + @torch.no_grad() + def get_emb(self, codes: torch.Tensor): + """Get latent representation from the discrete codes. 
+ Args: + codes (torch.Tensor): Discrete tokens.""" + emb = self.codec_model.decode_latent(codes) + return emb + + def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None, + step_list: tp.Optional[tp.List[int]] = None): + """Generate waveform audio from the latent embeddings of the compression model. + Args: + emb (torch.Tensor): Conditioning embeddings + size (None, torch.Size): Size of the output + if None this is computed from the typical upsampling of the model. + step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step. + """ + if size is None: + upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate) + size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling]) + assert size[0] == emb.size(0) + out = torch.zeros(size).to(self.device) + for DP in self.DPs: + out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out)) + return out + + def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1): + """Match the eq to the encodec output by matching the standard deviation of some frequency bands. + Args: + wav (torch.Tensor): Audio to equalize. + ref (torch.Tensor): Reference audio from which we match the spectrogram. + n_bands (int): Number of bands of the eq. + strictness (float): How strict the matching. 0 is no matching, 1 is exact matching. + """ + split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device) + bands = split(wav) + bands_ref = split(ref) + out = torch.zeros_like(ref) + for i in range(n_bands): + out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness + return out + + def regenerate(self, wav: torch.Tensor, sample_rate: int): + """Regenerate a waveform through compression and diffusion regeneration. + Args: + wav (torch.Tensor): Original 'ground truth' audio. + sample_rate (int): Sample rate of the input (and output) wav. + """ + if sample_rate != self.codec_model.sample_rate: + wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate) + emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate) + size = wav.size() + out = self.generate(emb, size=size) + if sample_rate != self.codec_model.sample_rate: + out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate) + return out + + def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32): + """Generate Waveform audio with diffusion from the discrete codes. + Args: + tokens (torch.Tensor): Discrete codes. + n_bands (int): Bands for the eq matching. + """ + wav_encodec = self.codec_model.decode(tokens) + condition = self.get_emb(tokens) + wav_diffusion = self.generate(emb=condition, size=wav_encodec.size()) + return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands) diff --git a/audiocraft/quantization/__init__.py b/audiocraft/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e0c7e429ab96d67be667e23bf7a0ffa389c036b --- /dev/null +++ b/audiocraft/quantization/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
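+
+# This package exposes the quantizer building blocks used by the compression models:
+# `ResidualVectorQuantizer` (vq.py) together with `BaseQuantizer`, `DummyQuantizer`
+# and `QuantizedResult` (base.py); see the imports below.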
+"""RVQ.""" +# flake8: noqa +from .vq import ResidualVectorQuantizer +from .base import BaseQuantizer, DummyQuantizer, QuantizedResult diff --git a/audiocraft/quantization/base.py b/audiocraft/quantization/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a77fefb98e62a5bbc6385910261ffdde2ffa5a25 --- /dev/null +++ b/audiocraft/quantization/base.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Base class for all quantizers. +""" + +from dataclasses import dataclass, field +import typing as tp + +import torch +from torch import nn + + +@dataclass +class QuantizedResult: + x: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class BaseQuantizer(nn.Module): + """Base class for quantizers. + """ + + def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: + """ + Given input tensor x, returns first the quantized (or approximately quantized) + representation along with quantized codes, bandwidth, and any penalty term for the loss. + Finally, this returns a dict of metrics to update logging etc. + Frame rate must be passed so that the bandwidth is properly computed. + """ + raise NotImplementedError() + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth.""" + raise NotImplementedError() + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + raise NotImplementedError() + + @property + def total_codebooks(self): + """Total number of codebooks.""" + raise NotImplementedError() + + @property + def num_codebooks(self): + """Number of active codebooks.""" + raise NotImplementedError() + + def set_num_codebooks(self, n: int): + """Set the number of active codebooks.""" + raise NotImplementedError() + + +class DummyQuantizer(BaseQuantizer): + """Fake quantizer that actually does not perform any quantization. + """ + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, frame_rate: int): + q = x.unsqueeze(1) + return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + In the case of the DummyQuantizer, the codes are actually identical + to the input and resulting quantized representation as no quantization is done. + """ + return x.unsqueeze(1) + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation. + In the case of the DummyQuantizer, the codes are actually identical + to the input and resulting quantized representation as no quantization is done. 
+ """ + return codes.squeeze(1) + + @property + def total_codebooks(self): + """Total number of codebooks.""" + return 1 + + @property + def num_codebooks(self): + """Total number of codebooks.""" + return self.total_codebooks + + def set_num_codebooks(self, n: int): + """Set the number of active codebooks.""" + raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") diff --git a/audiocraft/quantization/core_vq.py b/audiocraft/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..6aaa3b077c53b413e2b2a904ac7e769d1c623b36 --- /dev/null +++ b/audiocraft/quantization/core_vq.py @@ -0,0 +1,405 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +from einops import rearrange, repeat +import flashy +import torch +from torch import nn, einsum +import torch.nn.functional as F + + +def exists(val: tp.Optional[tp.Any]) -> bool: + return val is not None + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if exists(val) else d + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +def orthogonal_loss_fn(t): + # eq (2) from https://arxiv.org/abs/2112.00384 + n = t.shape[0] + normed_codes = l2norm(t) + identity = torch.eye(n, device=t.device) + cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) + return ((cosine_sim - identity) ** 2).sum() / (n ** 2) + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. 
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.8, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + flashy.distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + flashy.distrib.broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. 
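+            # What follows is the EMA codebook update: decay the per-code usage counts
+            # (cluster_size) and the running sum of assigned vectors (embed_avg), then
+            # renormalise with Laplace smoothing so unused codes do not divide by zero.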
+ self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): + channels_last (bool): Channels are the last dimension in the input tensors. + commitment_weight (float): Weight for commitment loss. + orthogonal_reg_weight (float): Orthogonal regularization weights. + orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. + orthogonal_reg_max_codes (optional int): Maximum number of codes to consider + for orthogonal regularization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.8, + epsilon: float = 1e-5, + kmeans_init: bool = False, + kmeans_iters: int = 10, + threshold_ema_dead_code: int = 2, + channels_last: bool = False, + commitment_weight: float = 1., + orthogonal_reg_weight: float = 0.0, + orthogonal_reg_active_codes_only: bool = False, + orthogonal_reg_max_codes: tp.Optional[int] = None, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) + self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + + self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, + kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, + decay=decay, epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code) + self.codebook_size = codebook_size + + self.channels_last = channels_last + + @property + def codebook(self): + return self._codebook.embed + + @property + def inited(self): + return self._codebook.inited + + def _preprocess(self, x): + if not self.channels_last: + x = rearrange(x, "b d n -> b n d") + return x + + def _postprocess(self, quantize): + if not self.channels_last: + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def encode(self, x): + x = self._preprocess(x) + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): 
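+        # Indices -> codebook vectors, projected back to the input dimension and
+        # restored to the channels-first [B, D, N] layout expected by callers.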
+ quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = self._postprocess(quantize) + return quantize + + def forward(self, x): + device = x.device + x = self._preprocess(x) + + x = self.project_in(x) + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + if self.orthogonal_reg_weight > 0: + codebook = self.codebook + + if self.orthogonal_reg_active_codes_only: + # only calculate orthogonal loss for the activated codes for this batch + unique_code_ids = torch.unique(embed_ind) + codebook = codebook[unique_code_ids] + + num_codes = codebook.shape[0] + if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes] + codebook = codebook[rand_ids] + + orthogonal_reg_loss = orthogonal_loss_fn(codebook) + loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight + + quantize = self.project_out(quantize) + quantize = self._postprocess(quantize) + + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for i, layer in enumerate(self.layers[:n_q]): + quantized, indices, loss = layer(residual) + quantized = quantized.detach() + residual = residual - quantized + quantized_out = quantized_out + quantized + all_indices.append(indices) + all_losses.append(loss) + + if self.training: + # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25 + quantized_out = x + (quantized_out - x).detach() + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/audiocraft/quantization/vq.py b/audiocraft/quantization/vq.py new file mode 100644 index 0000000000000000000000000000000000000000..aa57bea59db95ddae35e0657f723ca3a29ee943b --- /dev/null +++ b/audiocraft/quantization/vq.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
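+
+# A minimal usage sketch for the quantizer defined below (the tensor shapes are
+# illustrative assumptions, not an official example); inputs are channels-first
+# latents of shape [B, D, T]:
+#
+#     import torch
+#     rvq = ResidualVectorQuantizer(dimension=128, n_q=4, bins=1024)
+#     res = rvq(torch.randn(2, 128, 50), frame_rate=50)   # QuantizedResult(x, codes, bandwidth, ...)
+#     codes = rvq.encode(torch.randn(2, 128, 50))         # [B, K, T] codebook indices
+#     latents = rvq.decode(codes)                         # [B, 128, T] quantized latents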
+ +import math +import typing as tp + +import torch + +from .base import BaseQuantizer, QuantizedResult +from .core_vq import ResidualVectorQuantization + + +class ResidualVectorQuantizer(BaseQuantizer): + """Residual Vector Quantizer. + + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + q_dropout (bool): Random quantizer drop out at train time. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + orthogonal_reg_weight (float): Orthogonal regularization weights. + orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. + orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. + for orthogonal regularization. + """ + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + q_dropout: bool = False, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 10, + threshold_ema_dead_code: int = 2, + orthogonal_reg_weight: float = 0.0, + orthogonal_reg_active_codes_only: bool = False, + orthogonal_reg_max_codes: tp.Optional[int] = None, + ): + super().__init__() + self.max_n_q = n_q + self.n_q = n_q + self.q_dropout = q_dropout + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + orthogonal_reg_weight=self.orthogonal_reg_weight, + orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only, + orthogonal_reg_max_codes=self.orthogonal_reg_max_codes, + channels_last=False + ) + + def forward(self, x: torch.Tensor, frame_rate: int): + n_q = self.n_q + if self.training and self.q_dropout: + n_q = int(torch.randint(1, self.n_q + 1, (1,)).item()) + bw_per_q = math.log2(self.bins) * frame_rate / 1000 + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + codes = codes.transpose(0, 1) + # codes is [B, K, T], with T frames, K nb of codebooks. + bw = torch.tensor(n_q * bw_per_q).to(x) + return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified frame rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + """ + n_q = self.n_q + codes = self.vq.encode(x, n_q=n_q) + codes = codes.transpose(0, 1) + # codes is [B, K, T], with T frames, K nb of codebooks. 
+ return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T]. + codes = codes.transpose(0, 1) + quantized = self.vq.decode(codes) + return quantized + + @property + def total_codebooks(self): + return self.max_n_q + + @property + def num_codebooks(self): + return self.n_q + + def set_num_codebooks(self, n: int): + assert n > 0 and n <= self.max_n_q + self.n_q = n diff --git a/audiocraft/rope.py b/audiocraft/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..c12cee0954f27c45d79627771fdf7fa9fc10dfcc --- /dev/null +++ b/audiocraft/rope.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +from torch import nn +import torch + + +class XPos(nn.Module): + """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1). + This applies an exponential decay to the RoPE rotation matrix. + + Args: + dim (int): Embedding dimension. + smoothing (float): Smoothing factor applied to the decay rates. + base_scale (int): Base decay rate, given in terms of scaling time. + device (torch.device, optional): Device on which to initialize the module. + dtype (torch.dtype): dtype to use to generate the embedding. + """ + def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512, + device=None, dtype: torch.dtype = torch.float32): + super().__init__() + assert dim % 2 == 0 + assert dtype in [torch.float64, torch.float32] + self.dtype = dtype + self.base_scale = base_scale + + half_dim = dim // 2 + adim = torch.arange(half_dim, device=device, dtype=dtype) + decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing) + self.register_buffer("decay_rates", decay_rates) + self.decay: tp.Optional[torch.Tensor] = None + + def get_decay(self, start: int, end: int): + """Create complex decay tensor, cache values for fast computation.""" + if self.decay is None or end > self.decay.shape[0]: + assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker. + idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype) + power = idx / self.base_scale + scale = self.decay_rates ** power.unsqueeze(-1) + self.decay = torch.polar(scale, torch.zeros_like(scale)) + return self.decay[start:end] # [T, C/2] + + +class RotaryEmbedding(nn.Module): + """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864). + + Args: + dim (int): Embedding dimension (twice the number of frequencies). + max_period (float): Maximum period of the rotation frequencies. + xpos (bool): Use xPos, applies an exponential decay to rotation matrix. + scale (float): Scale of positional embedding, set to 0 to deactivate. + device (torch.device, optional): Device on which to initialize the module. + dtype (torch.dtype): dtype to use to generate the embedding. 
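+
+    A small illustrative sketch (the `q`/`k` names and the [B, T, H, D] layout with
+    time at `time_dim=1` are assumptions about the caller, not part of this API;
+    the last dimension of `q` and `k` should equal `dim`):
+
+        rope = RotaryEmbedding(dim=64)                     # per-head dimension, must be even
+        q, k = rope.rotate_qk(q, k, start=0, time_dim=1)   # rotate query and key together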
+ """ + def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False, + scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32): + super().__init__() + assert dim % 2 == 0 + self.scale = scale + assert dtype in [torch.float64, torch.float32] + self.dtype = dtype + + adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)] + frequencies = 1.0 / (max_period ** (adim / dim)) + self.register_buffer("frequencies", frequencies) + self.rotation: tp.Optional[torch.Tensor] = None + + self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None + + def get_rotation(self, start: int, end: int): + """Create complex rotation tensor, cache values for fast computation.""" + if self.rotation is None or end > self.rotation.shape[0]: + assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker. + idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype) + angles = torch.outer(idx, self.frequencies) + self.rotation = torch.polar(torch.ones_like(angles), angles) + return self.rotation[start:end] + + def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False): + """Apply rope rotation to query or key tensor.""" + T = x.shape[time_dim] + target_shape = [1] * x.dim() + target_shape[time_dim] = T + target_shape[-1] = -1 + rotation = self.get_rotation(start, start + T).view(target_shape) + + if self.xpos: + decay = self.xpos.get_decay(start, start + T).view(target_shape) + else: + decay = 1.0 + + if invert_decay: + decay = decay ** -1 + + x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2)) + scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale) + x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x) + + return x_out.type_as(x) + + def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1): + """ Apply rope rotation to both query and key tensors. + Supports streaming mode, in which query and key are not expected to have the same shape. + In streaming mode, key will be of length [P + C] with P the cached past timesteps, but + query will be [C] (typically C == 1). + + Args: + query (torch.Tensor): Query to rotate. + key (torch.Tensor): Key to rotate. + start (int): Start index of the sequence for time offset. + time_dim (int): which dimension represent the time steps. + """ + query_timesteps = query.shape[time_dim] + key_timesteps = key.shape[time_dim] + streaming_offset = key_timesteps - query_timesteps + + query_out = self.rotate(query, start + streaming_offset, time_dim) + key_out = self.rotate(key, start, time_dim, invert_decay=True) + + return query_out, key_out diff --git a/audiocraft/seanet.py b/audiocraft/seanet.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5998e9153afb6e68ea410d565e00ea835db248 --- /dev/null +++ b/audiocraft/seanet.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import typing as tp + +import numpy as np +import torch.nn as nn + +from .conv import StreamableConv1d, StreamableConvTranspose1d +from .lstm import StreamableLSTM + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + + Args: + dim (int): Dimension of the input/output. + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. 
+ activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection. + """ + def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], + activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, + pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + super().__init__() + assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order. We use the decoder order as some models may only employ the decoder. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. + For the encoder, it corresponds to the N first blocks. 
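+
+    Example (sketch, using the defaults below): with ratios [8, 5, 4, 2] the hop
+    length is 320, so a mono one-second input at 24 kHz (24 000 samples) maps to
+    75 latent frames.
+
+        encoder = SEANetEncoder(channels=1, dimension=128, n_filters=32)
+        z = encoder(x)   # x: [2, 1, 24000] waveform -> z: [2, 128, 75]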
+ """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, + disable_norm_outer_blocks: int = 0): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks + self.disable_norm_outer_blocks = disable_norm_outer_blocks + assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ + "Number of blocks for which to disable norm is invalid." \ + "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." + + act = getattr(nn, activation) + mult = 1 + model: tp.List[nn.Module] = [ + StreamableConv1d(channels, mult * n_filters, kernel_size, + norm='none' if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + norm=block_norm, norm_params=norm_params, + activation=activation, activation_params=activation_params, + causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + # Add downsampling layers + model += [ + act(**activation_params), + StreamableConv1d(mult * n_filters, mult * n_filters * 2, + kernel_size=ratio * 2, stride=ratio, + norm=block_norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + mult *= 2 + + if lstm: + model += [StreamableLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + StreamableConv1d(mult * n_filters, dimension, last_kernel_size, + norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, + norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. 
+ residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple. + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. + For the decoder, it corresponds to the N last blocks. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, + norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, + disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks + self.disable_norm_outer_blocks = disable_norm_outer_blocks + assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ + "Number of blocks for which to disable norm is invalid." \ + "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
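+
+        # The decoder mirrors the encoder: it starts from 2 ** len(ratios) * n_filters
+        # channels at the bottleneck and halves the channel count after every
+        # upsampling stage, down to n_filters before the final convolution.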
+ + act = getattr(nn, activation) + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + StreamableConv1d(dimension, mult * n_filters, kernel_size, + norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, + norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + ] + + if lstm: + model += [StreamableLSTM(mult * n_filters, num_layers=lstm)] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm + # Add upsampling layers + model += [ + act(**activation_params), + StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2, + kernel_size=ratio * 2, stride=ratio, + norm=block_norm, norm_kwargs=norm_params, + causal=causal, trim_right_ratio=trim_right_ratio), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + activation=activation, activation_params=activation_params, + norm=block_norm, norm_params=norm_params, causal=causal, + pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params), + StreamableConv1d(n_filters, channels, last_kernel_size, + norm='none' if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [ + final_act(**final_activation_params) + ] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y diff --git a/audiocraft/streaming.py b/audiocraft/streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..fba06936294ca15d72acd2d44f9dbda39a638107 --- /dev/null +++ b/audiocraft/streaming.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Streaming module API that should be implemented by all Streaming components, +""" + +from contextlib import contextmanager +import typing as tp +from torch import nn +import torch + + +State = tp.Dict[str, torch.Tensor] + + +class StreamingModule(nn.Module): + """Common API for streaming components. + + Each streaming component has a streaming state, which is just a dict[str, Tensor]. + By convention, the first dim of each tensor must be the batch size. + Don't use dots in the key names, as this would clash with submodules + (like in state_dict). + + If `self._is_streaming` is True, the component should use and remember + the proper state inside `self._streaming_state`. + + To set a streaming component in streaming state, use + + with module.streaming(): + ... + + This will automatically reset the streaming state when exiting the context manager. + This also automatically propagates to all streaming children module. + + Some module might also implement the `StreamingModule.flush` method, although + this one is trickier, as all parents module must be StreamingModule and implement + it as well for it to work properly. See `StreamingSequential` after. 
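+
+    A hedged sketch of saving and restoring state with the accessors defined below
+    (the surrounding model and chunking are illustrative only):
+
+        with model.streaming():
+            out = model(first_chunk)
+            state = model.get_streaming_state()   # dict[str, torch.Tensor]
+            ...
+            model.set_streaming_state(state)      # resume from the saved state
+            out = model(next_chunk)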
+ """ + def __init__(self) -> None: + super().__init__() + self._streaming_state: State = {} + self._is_streaming = False + + def _apply_named_streaming(self, fn: tp.Any): + for name, module in self.named_modules(): + if isinstance(module, StreamingModule): + fn(name, module) + + def _set_streaming(self, streaming: bool): + def _set_streaming(name, module): + module._is_streaming = streaming + self._apply_named_streaming(_set_streaming) + + @contextmanager + def streaming(self): + """Context manager to enter streaming mode. Reset streaming state on exit.""" + self._set_streaming(True) + try: + yield + finally: + self._set_streaming(False) + self.reset_streaming() + + def reset_streaming(self): + """Reset the streaming state.""" + def _reset(name: str, module: StreamingModule): + module._streaming_state.clear() + + self._apply_named_streaming(_reset) + + def get_streaming_state(self) -> State: + """Return the streaming state, including that of sub-modules.""" + state: State = {} + + def _add(name: str, module: StreamingModule): + if name: + name += "." + for key, value in module._streaming_state.items(): + state[name + key] = value + + self._apply_named_streaming(_add) + return state + + def set_streaming_state(self, state: State): + """Set the streaming state, including that of sub-modules.""" + state = dict(state) + + def _set(name: str, module: StreamingModule): + if name: + name += "." + module._streaming_state.clear() + for key, value in list(state.items()): + # complexity is not ideal here, but probably fine. + if key.startswith(name): + local_key = key[len(name):] + if '.' not in local_key: + module._streaming_state[local_key] = value + del state[key] + + self._apply_named_streaming(_set) + assert len(state) == 0, list(state.keys()) + + def flush(self, x: tp.Optional[torch.Tensor] = None): + """Flush any remaining outputs that were waiting for completion. + Typically, for convolutions, this will add the final padding + and process the last buffer. + + This should take an optional argument `x`, which will be provided + if a module before this one in the streaming pipeline has already + spitted out a flushed out buffer. + """ + if x is None: + return None + else: + return self(x) + + +class StreamingSequential(StreamingModule, nn.Sequential): + """A streaming compatible alternative of `nn.Sequential`. + """ + def flush(self, x: tp.Optional[torch.Tensor] = None): + for module in self: + if isinstance(module, StreamingModule): + x = module.flush(x) + elif x is not None: + x = module(x) + return x diff --git a/audiocraft/transformer.py b/audiocraft/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d44b39ecbe6ce2ec4370a149d0b285ebf663f44 --- /dev/null +++ b/audiocraft/transformer.py @@ -0,0 +1,755 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Transformer model, with streaming support, xformer attention support +and easy causal attention with a potentially finite receptive field. + +See `StreamingTransformer` for more information. + +Unlike regular PyTorch Transformer, we make the hard choice that batches are first. 
+""" + +import typing as tp + +from einops import rearrange +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint as torch_checkpoint +from xformers import ops + +from .rope import RotaryEmbedding +from .streaming import StreamingModule + +_efficient_attention_backend: str = 'torch' + + +def set_efficient_attention_backend(backend: str = 'torch'): + # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster). + global _efficient_attention_backend + assert _efficient_attention_backend in ['xformers', 'torch'] + _efficient_attention_backend = backend + + +def _get_attention_time_dimension(memory_efficient: bool) -> int: + if _efficient_attention_backend == 'torch' and memory_efficient: + return 2 + else: + return 1 + + +def _is_profiled() -> bool: + # Return true if we are currently running with a xformers profiler activated. + try: + from xformers.profiler import profiler + except ImportError: + return False + return profiler._Profiler._CURRENT_PROFILER is not None + + +def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: + """Create normalization module for transformer encoder layer. + + Args: + norm_type (str): Normalization method. + dim (int): Dimension of the normalized layer. + **kwargs (dict): Additional parameters for normalization layer. + Returns: + nn.Module: Normalization module. + """ + if norm_type == 'layer_norm': + return nn.LayerNorm(dim, eps=1e-5, **kwargs) + else: + raise ValueError(f"Unknown norm type: {norm_type}") + + +def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000, + dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Create sinusoidal positional embedding, with shape `[B, T, C]`. + + Args: + positions (torch.Tensor): LongTensor of positions. + dim (int): Dimension of the embedding. + max_period (float): Maximum period of the cosine/sine functions. + dtype (torch.dtype or str): dtype to use to generate the embedding. + Returns: + torch.Tensor: Sinusoidal positional embedding. + """ + # We aim for BTC format + assert dim % 2 == 0 + half_dim = dim // 2 + positions = positions.to(dtype) + adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1) + max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point + phase = positions / (max_period_tensor ** (adim / (half_dim - 1))) + return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) + + +def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers.""" + if n_rep == 1: + return x + if _efficient_attention_backend == 'torch' and memory_efficient: + bs, n_kv_heads, slen, head_dim = x.shape + return ( + x[:, :, None, :, :] + .expand(bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(bs, n_kv_heads * n_rep, slen, head_dim) + ) + else: + bs, slen, n_kv_heads, head_dim = x.shape + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonally the residual outputs close to 0, with a learnt scale. + + Args: + channels (int): Number of channels. + init (float): Initial scale. + channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`. 
+ device (torch.device or str, optional): Device on which to initialize the module. + dtype (torch.dtype, optional): dtype to use to initialize the module. + """ + def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True, + device=None, dtype=None): + super().__init__() + self.channel_last = channel_last + self.scale = nn.Parameter( + torch.full((channels,), init, + requires_grad=True, device=device, dtype=dtype)) + + def forward(self, x: torch.Tensor): + if self.channel_last: + return self.scale * x + else: + return self.scale[:, None] * x + + +class StreamingMultiheadAttention(StreamingModule): + """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation. + + Args: + embed_dim (int): Dimension to project to. + num_heads (int): Number of heads. + dropout (float): Dropout level. + bias (bool): Use bias in projections. + causal (bool): Causal mask applied automatically. + past_context (int, optional): Receptive field for the causal mask, infinite if None. + custom (bool): Use custom MHA implementation, for testing / benchmarking. + memory_efficient (bool): Use xformers based memory efficient attention. + attention_as_float32 (bool): Perform the attention as float32 + (especially important with memory_efficient as autocast won't do this automatically). + rope (`RotaryEmbedding`, optional): Rope embedding to use. + cross_attention: Should be true when used as a cross attention. + All keys and values must be available at once, streaming is only for the queries. + Cannot be used with `causal` or `rope` (as it wouldn't make sens to + interpret the time steps in the keys relative to those in the queries). + safe_streaming (bool): Bug fix, will go away with xformers update. + qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product. + kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). + This will lead to faster decoding time on A100 or other GPUs with tensorcore. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. + """ + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, + causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False, + memory_efficient: bool = False, attention_as_float32: bool = False, + rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False, + safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1, + device=None, dtype=None): + super().__init__() + factory_kwargs = {'device': device, 'dtype': dtype} + if past_context is not None: + assert causal + + self.embed_dim = embed_dim + self.causal = causal + self.past_context = past_context + self.memory_efficient = memory_efficient + self.attention_as_float32 = attention_as_float32 + self.rope = rope + self.cross_attention = cross_attention + self.safe_streaming = safe_streaming + self.num_heads = num_heads + self.dropout = dropout + self.kv_repeat = kv_repeat + if cross_attention: + assert not causal, "Causal cannot work with cross attention." + assert rope is None, "Rope cannot work with cross attention." 
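+
+        # In the custom path built below, a single fused input projection is used:
+        # queries keep `embed_dim` outputs while keys and values each get
+        # `embed_dim // kv_repeat` (grouped-query style), hence
+        # out_dim = embed_dim + 2 * kv_dim.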
+ + if memory_efficient: + _verify_xformers_memory_efficient_compat() + + self.custom = _is_custom(custom, memory_efficient) + if self.custom: + out_dim = embed_dim + assert num_heads % kv_repeat == 0 + assert not cross_attention or kv_repeat == 1 + num_kv = num_heads // kv_repeat + kv_dim = (embed_dim // num_heads) * num_kv + out_dim += 2 * kv_dim + in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs) + # We try to follow the default PyTorch MHA convention, to easily compare results. + self.in_proj_weight = in_proj.weight + self.in_proj_bias = in_proj.bias + if bias: + self.in_proj_bias.data.zero_() # Following Pytorch convention + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + if bias: + self.out_proj.bias.data.zero_() + else: + assert not qk_layer_norm + assert kv_repeat == 1 + self.mha = nn.MultiheadAttention( + embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True, + **factory_kwargs) + self.qk_layer_norm = qk_layer_norm + if qk_layer_norm: + assert self.custom + assert kv_repeat == 1 + ln_dim = embed_dim + self.q_layer_norm = nn.LayerNorm(ln_dim) + self.k_layer_norm = nn.LayerNorm(ln_dim) + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + if not self.custom: + # Support compat with regular MHA + keys = [n for n, _ in self.mha.named_parameters()] + for key in keys: + if prefix + key in state_dict: + state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype): + # Return a causal mask, accounting for potentially stored past keys/values + # We actually return a bias for the attention score, as this has the same + # convention both in the builtin MHA in Pytorch, and Xformers functions. + time_dim = _get_attention_time_dimension(self.memory_efficient) + if self.memory_efficient: + from xformers.ops import LowerTriangularMask + if current_steps == 1: + # If we only have one step, then we do not need a mask. + return None + elif 'past_keys' in self._streaming_state: + raise RuntimeError("Not supported at the moment") + else: + # Then we can safely use a lower triangular mask + return LowerTriangularMask() + if self._streaming_state: + past_keys = self._streaming_state['past_keys'] + past_steps = past_keys.shape[time_dim] + else: + past_steps = 0 + + queries_pos = torch.arange( + past_steps, current_steps + past_steps, device=device).view(-1, 1) + keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1) + delta = queries_pos - keys_pos + valid = delta >= 0 + if self.past_context is not None: + valid &= (delta <= self.past_context) + return torch.where( + valid, + torch.zeros([], device=device, dtype=dtype), + torch.full([], float('-inf'), device=device, dtype=dtype)) + + def _complete_kv(self, k, v): + time_dim = _get_attention_time_dimension(self.memory_efficient) + if self.cross_attention: + # With cross attention we assume all keys and values + # are already available, and streaming is with respect + # to the queries only. + return k, v + # Complete the key/value pair using the streaming state. 
+ if self._streaming_state: + pk = self._streaming_state['past_keys'] + nk = torch.cat([pk, k], dim=time_dim) + if v is k: + nv = nk + else: + pv = self._streaming_state['past_values'] + nv = torch.cat([pv, v], dim=time_dim) + else: + nk = k + nv = v + + assert nk.shape[time_dim] == nv.shape[time_dim] + offset = 0 + if self.past_context is not None: + offset = max(0, nk.shape[time_dim] - self.past_context) + if self._is_streaming: + self._streaming_state['past_keys'] = nk[:, offset:] + if v is not k: + self._streaming_state['past_values'] = nv[:, offset:] + if 'offset' in self._streaming_state: + self._streaming_state['offset'] += offset + else: + self._streaming_state['offset'] = torch.tensor(0) + return nk, nv + + def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): + time_dim = _get_attention_time_dimension(self.memory_efficient) + # Apply rope embeddings to query and key tensors. + assert self.rope is not None + if 'past_keys' in self._streaming_state: + past_keys_offset = self._streaming_state['past_keys'].shape[1] + else: + past_keys_offset = 0 + if 'offset' in self._streaming_state: + past_context_offset = int(self._streaming_state['offset'].item()) + else: + past_context_offset = 0 + streaming_offset = past_context_offset + past_keys_offset + return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim) + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + key_padding_mask=None, need_weights=False, attn_mask=None, + average_attn_weights=True, is_causal=False): + assert not is_causal, ("New param added in torch 2.0.1 not supported, " + "use the causal args in the constructor.") + + time_dim = _get_attention_time_dimension(self.memory_efficient) + if time_dim == 2: + layout = "b h t d" + else: + layout = "b t h d" + dtype = query.dtype + if self._is_streaming: + assert self.causal or self.cross_attention, \ + "Streaming only available for causal or cross attention" + + custom_attn_mask = attn_mask is not None + + if self.causal: + assert attn_mask is None + # At the moment we specialize only for the self-attention case. + assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value" + assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value" + attn_mask = self._get_mask(query.shape[1], query.device, query.dtype) + + if self.custom: + # custom implementation + assert need_weights is False + assert key_padding_mask is None + if self.cross_attention: + # Different queries, keys, values, we have to spit manually the weights + # before applying the linear. + dim = self.in_proj_weight.shape[0] // 3 + if self.in_proj_bias is None: + bias_q, bias_k, bias_v = None, None, None + else: + bias_q = self.in_proj_bias[:dim] + bias_k = self.in_proj_bias[dim: 2 * dim] + bias_v = self.in_proj_bias[2 * dim:] + q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q) + # todo: when streaming, we could actually save k, v and check the shape actually match. + k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k) + v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v) + if self.qk_layer_norm is True: + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]] + else: + if not _is_profiled(): + # profiling breaks that propertysomehow. 
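+                    # ("that property" being that query, key and value are the very same
+                    #  tensor object in self-attention; the identity asserts below rely on
+                    #  it, and the xformers profiler appears to break it.)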
+ assert query is key, "specialized implementation" + assert value is key, "specialized implementation" + projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias) + if self.kv_repeat == 1: + if time_dim == 2: + bound_layout = "b h p t d" + else: + bound_layout = "b t p h d" + packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads) + q, k, v = ops.unbind(packed, dim=2) + else: + embed_dim = self.embed_dim + per_head_dim = (embed_dim // self.num_heads) + kv_heads = self.num_heads // self.kv_repeat + q = projected[:, :, :embed_dim] + start = embed_dim + end = start + per_head_dim * kv_heads + k = projected[:, :, start: end] + v = projected[:, :, end:] + q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads) + k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads) + v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads) + + if self.qk_layer_norm is True: + assert self.kv_repeat == 1 + q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]] + q = self.q_layer_norm(q) + k = self.k_layer_norm(k) + q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]] + if self.rope: + q, k = self._apply_rope(q, k) + k, v = self._complete_kv(k, v) + if self.kv_repeat > 1: + k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient) + v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient) + if self.attention_as_float32: + q, k, v = [x.float() for x in [q, k, v]] + if self.memory_efficient: + if custom_attn_mask: + # When using a custom attn mask: + # Move to query's device, repeat for each sample, remove align8 padding + seq_len = query.shape[1] + attn_mask = attn_mask.to(q.dtype) + attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1)) + attn_mask = attn_mask[..., :seq_len, :seq_len] + + p = self.dropout if self.training else 0 + if _efficient_attention_backend == 'torch': + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=attn_mask is not None, dropout_p=p) + else: + x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p) + else: + # We include the dot product as float32, for consistency + # with the other implementations that include that step + # as part of the attention. Note that when using `autocast`, + # the einsums would be done as bfloat16, but the softmax + # would be done as bfloat16, so `attention_as_float32` will + # extend a bit the range of operations done in float32, + # although this should make no difference. + q = q / q.shape[-1] ** 0.5 + key_layout = layout.replace('t', 'k') + query_layout = layout + if self._is_streaming and self.safe_streaming and q.device.type == 'cuda': + with torch.autocast(device_type=q.device.type, dtype=torch.float32): + pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k) + else: + pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k) + if attn_mask is not None: + pre_w = pre_w + attn_mask + w = torch.softmax(pre_w, dim=-1) + w = F.dropout(w, self.dropout, training=self.training).to(v) + # Key and value have the same format. 
+ x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v) + x = x.to(dtype) + x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads) + x = self.out_proj(x) + else: + key, value = self._complete_kv(key, value) + if self.attention_as_float32: + query, key, value = [x.float() for x in [query, key, value]] + x, _ = self.mha( + query, key, value, key_padding_mask, + need_weights, attn_mask, average_attn_weights) + x = x.to(dtype) + + return x, None + + +class StreamingTransformerLayer(nn.TransformerEncoderLayer): + """TransformerLayer with Streaming / Causal support. + This also integrates cross_attention, when passing `cross_attention=True`, + rather than having two separate classes like in PyTorch. + + Args: + d_model (int): Dimension of the data. + num_heads (int): Number of heads. + dim_feedforward (int): Intermediate dimension of FF module. + dropout (float): Dropout both for MHA and FF. + bias_ff (bool): Use bias for FF. + bias_attn (bool): Use bias for MHA. + causal (bool): Causal mask applied automatically. + past_context (int, optional): Receptive field for the causal mask, infinite if None. + custom (bool): Use custom MHA implementation, for testing / benchmarking. + memory_efficient (bool): Use xformers based memory efficient attention. + attention_as_float32 (bool): Perform the attention as float32 + (especially important with memory_efficient as autocast won't do this automatically). + qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention. + qk_layer_norm_cross (bool): Same for the cross attention. + cross_attention (bool): If True, expect to get secondary input for cross-attention. + Cross attention will use the default MHA, as it typically won't require + special treatment. + layer_scale (float, optional): If not None, LayerScale will be used with + the given value as initial scale. + rope (`RotaryEmbedding`, optional): Rope embedding to use. + attention_dropout (float, optional): If not None, separate the value of the dimension dropout + in FFN and of the attention dropout. + kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). + This will lead to faster decoding time on A100 or other GPUs with tensorcore. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. + **kwargs: See `nn.TransformerEncoderLayer`. 
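+
+    A minimal self-attention sketch (illustrative; every other argument keeps its
+    default value):
+
+        layer = StreamingTransformerLayer(d_model=512, num_heads=8, causal=True, custom=True)
+        x = torch.randn(2, 50, 512)   # batch-first [B, T, C]
+        y = layer(x)                  # causal self-attention + feedforward, -> [2, 50, 512]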
+ """ + def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1, + bias_ff: bool = True, bias_attn: bool = True, causal: bool = False, + past_context: tp.Optional[int] = None, custom: bool = False, + memory_efficient: bool = False, attention_as_float32: bool = False, + qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False, + cross_attention: bool = False, layer_scale: tp.Optional[float] = None, + rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None, + kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs): + super().__init__(d_model, num_heads, dim_feedforward, dropout, + device=device, dtype=dtype, batch_first=True, **kwargs) + factory_kwargs = {'device': device, 'dtype': dtype} + # Redefine self_attn to our streaming multi-head attention + attn_kwargs: tp.Dict[str, tp.Any] = { + 'embed_dim': d_model, + 'num_heads': num_heads, + 'dropout': dropout if attention_dropout is None else attention_dropout, + 'bias': bias_attn, + 'custom': custom, + 'memory_efficient': memory_efficient, + 'attention_as_float32': attention_as_float32, + } + self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention( + causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm, + kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore + # Redefine feedforward layers to expose bias parameter + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs) + + self.layer_scale_1: nn.Module + self.layer_scale_2: nn.Module + if layer_scale is None: + self.layer_scale_1 = nn.Identity() + self.layer_scale_2 = nn.Identity() + else: + self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) + self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) + + self.cross_attention: tp.Optional[nn.Module] = None + if cross_attention: + self.cross_attention = StreamingMultiheadAttention( + cross_attention=True, qk_layer_norm=qk_layer_norm_cross, + **attn_kwargs, **factory_kwargs) + # Norm and dropout + self.dropout_cross = nn.Dropout(dropout) + # eps value matching that used in PyTorch reference implementation. + self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs) + self.layer_scale_cross: nn.Module + if layer_scale is None: + self.layer_scale_cross = nn.Identity() + else: + self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs) + self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore + self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore + + def _cross_attention_block(self, src: torch.Tensor, + cross_attention_src: torch.Tensor) -> torch.Tensor: + assert self.cross_attention is not None + # queries are from src, keys and values from cross_attention_src. 
+ x = self.cross_attention( + src, cross_attention_src, cross_attention_src, need_weights=False)[0] + return self.dropout_cross(x) # type: ignore + + def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore + src_key_padding_mask: tp.Optional[torch.Tensor] = None, + cross_attention_src: tp.Optional[torch.Tensor] = None): + if self.cross_attention is None: + assert cross_attention_src is None + else: + assert cross_attention_src is not None + x = src + if self.norm_first: + x = x + self.layer_scale_1( + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)) + if cross_attention_src is not None: + x = x + self.layer_scale_cross( + self._cross_attention_block( + self.norm_cross(x), cross_attention_src)) + x = x + self.layer_scale_2(self._ff_block(self.norm2(x))) + else: + x = self.norm1(x + self.layer_scale_1( + self._sa_block(x, src_mask, src_key_padding_mask))) + if cross_attention_src is not None: + x = self.norm_cross( + x + self.layer_scale_cross( + self._cross_attention_block(src, cross_attention_src))) + x = self.norm2(x + self.layer_scale_2(self._ff_block(x))) + return x + + +class StreamingTransformer(StreamingModule): + """Transformer with Streaming / Causal support. + + Args: + d_model (int): Dimension of the data. + num_heads (int): Number of heads. + dim_feedforward (int): Intermediate dimension of FF module. + dropout (float): Dropout both for MHA and FF. + bias_ff (bool): Use bias for FF. + bias_attn (bool): Use bias for MHA. + causal (bool): Causal mask applied automatically. + past_context (int, optional): Receptive field for the causal mask, infinite if None. + custom (bool): Use custom MHA implementation, for testing / benchmarking. + memory_efficient (bool): Use xformers based memory efficient attention. + attention_as_float32 (bool): Perform the attention as float32 + (especially important with memory_efficient as autocast won't do this automatically). + cross_attention (bool): If True, expect to get secondary input for cross-attention. + layer_scale (float, optional): If not None, LayerScale will be used + with the given value as initial scale. + positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope). + max_period (float): Maximum period of the time embedding. + positional_scale (float): Scale of positional embedding, set to 0 to deactivate. + xpos (bool): Apply xpos exponential decay to positional embedding (rope only). + lr (float, optional): learning rate override through the `make_optim_group` API. + weight_decay (float, optional): Weight_decay override through the `make_optim_group` API. + layer_class: (subclass of `StreamingTransformerLayer): class to use + to initialize the layers, allowing further customization outside of AudioCraft. + checkpointing (str): Checkpointing strategy to reduce memory usage. + No checkpointing if set to 'none'. Per layer checkpointing using PyTorch + if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice, + minimal memory usage, but maximal runtime). Finally, `xformers_default` provide + a policy for opting-out some operations of the checkpointing like + linear layers and attention, providing a middle ground between speed and memory. + device (torch.device, optional): Device on which to initialize. + dtype (torch.dtype, optional): dtype to use. + **kwargs: See `nn.TransformerEncoderLayer`. 
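+
+    A hedged end-to-end sketch (chunk size and tensor shapes are illustrative):
+
+        model = StreamingTransformer(d_model=512, num_heads=8, num_layers=2, causal=True)
+        x = torch.randn(1, 100, 512)             # batch-first [B, T, C]
+        y = model(x)                             # one full-sequence pass
+        with model.streaming():                  # incremental pass over the same input
+            outs = [model(chunk) for chunk in x.split(25, dim=1)]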
+ """ + def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, + dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True, + causal: bool = False, past_context: tp.Optional[int] = None, + custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, + cross_attention: bool = False, layer_scale: tp.Optional[float] = None, + positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1., + xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None, + layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer, + checkpointing: str = 'none', device=None, dtype=None, **kwargs): + super().__init__() + assert d_model % num_heads == 0 + + self.positional_embedding = positional_embedding + self.max_period = max_period + self.positional_scale = positional_scale + self.weight_decay = weight_decay + self.lr = lr + + assert positional_embedding in ['sin', 'rope', 'sin_rope'] + self.rope: tp.Optional[RotaryEmbedding] = None + if self.positional_embedding in ['rope', 'sin_rope']: + assert _is_custom(custom, memory_efficient) + self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period, + xpos=xpos, scale=positional_scale, device=device) + + self.checkpointing = checkpointing + + assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm'] + if self.checkpointing.startswith('xformers'): + _verify_xformers_internal_compat() + + self.layers = nn.ModuleList() + for idx in range(num_layers): + self.layers.append( + layer_class( + d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, + dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn, + causal=causal, past_context=past_context, custom=custom, + memory_efficient=memory_efficient, attention_as_float32=attention_as_float32, + cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope, + device=device, dtype=dtype, **kwargs)) + + if self.checkpointing != 'none': + for layer in self.layers: + # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the + # backward hook inside of FSDP... + layer._magma_checkpointed = True # type: ignore + + def _apply_layer(self, layer, *args, **kwargs): + method = self.checkpointing + if method == 'none': + return layer(*args, **kwargs) + elif method == 'torch': + return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs) + elif method.startswith('xformers'): + from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy + if method == 'xformers_default': + # those operations will be saved, and not recomputed. + # According to Francisco we can get smarter policies but this is a good start. + allow_list = [ + "xformers.efficient_attention_forward_cutlass.default", + "xformers_flash.flash_fwd.default", + "aten.addmm.default", + "aten.mm.default", + ] + elif method == 'xformers_mm': + # those operations will be saved, and not recomputed. + # According to Francisco we can get smarter policies but this is a good start. 
+ allow_list = [ + "aten.addmm.default", + "aten.mm.default", + ] + else: + raise ValueError(f"xformers checkpointing xformers policy {method} is not known.") + policy_fn = _get_default_policy(allow_list) + return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs) + else: + raise ValueError(f"Checkpointing method {method} is unknown.") + + def forward(self, x: torch.Tensor, *args, **kwargs): + B, T, C = x.shape + + if 'offsets' in self._streaming_state: + offsets = self._streaming_state['offsets'] + else: + offsets = torch.zeros(B, dtype=torch.long, device=x.device) + + if self.positional_embedding in ['sin', 'sin_rope']: + positions = torch.arange(T, device=x.device).view(1, -1, 1) + positions = positions + offsets.view(-1, 1, 1) + pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype) + x = x + self.positional_scale * pos_emb + + for layer in self.layers: + x = self._apply_layer(layer, x, *args, **kwargs) + + if self._is_streaming: + self._streaming_state['offsets'] = offsets + T + + return x + + def make_optim_group(self): + group = {"params": list(self.parameters())} + if self.lr is not None: + group["lr"] = self.lr + if self.weight_decay is not None: + group["weight_decay"] = self.weight_decay + return group + + +# special attention related function + +def _verify_xformers_memory_efficient_compat(): + try: + from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa + except ImportError: + raise ImportError( + "xformers is not installed. Please install it and try again.\n" + "To install on AWS and Azure, run \n" + "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n" + "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n" + "To install on FAIR Cluster, run \n" + "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n" + "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") + + +def _verify_xformers_internal_compat(): + try: + from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa + except ImportError: + raise ImportError( + "Francisco's fairinternal xformers is not installed. Please install it and try again.\n" + "To install on AWS and Azure, run \n" + "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n" + "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n" + "To install on FAIR Cluster, run \n" + "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n" + "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") + + +def _is_custom(custom: bool, memory_efficient: bool): + return custom or memory_efficient diff --git a/audiocraft/unet.py b/audiocraft/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e383afea384ad9230356f78dd81045f65e9af9 --- /dev/null +++ b/audiocraft/unet.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Pytorch Unet Module used for diffusion. 
+""" + +from dataclasses import dataclass +import typing as tp + +import torch +from torch import nn +from torch.nn import functional as F +from .transformer import StreamingTransformer, create_sin_embedding + + +@dataclass +class Output: + sample: torch.Tensor + + +def get_model(cfg, channels: int, side: int, num_steps: int): + if cfg.model == 'unet': + return DiffusionUnet( + chin=channels, num_steps=num_steps, **cfg.diffusion_unet) + else: + raise RuntimeError('Not Implemented') + + +class ResBlock(nn.Module): + def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, + dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + stride = 1 + padding = dilation * (kernel - stride) // 2 + Conv = nn.Conv1d + Drop = nn.Dropout1d + self.norm1 = nn.GroupNorm(norm_groups, channels) + self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) + self.activation1 = activation() + self.dropout1 = Drop(dropout) + + self.norm2 = nn.GroupNorm(norm_groups, channels) + self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) + self.activation2 = activation() + self.dropout2 = Drop(dropout) + + def forward(self, x): + h = self.dropout1(self.conv1(self.activation1(self.norm1(x)))) + h = self.dropout2(self.conv2(self.activation2(self.norm2(h)))) + return x + h + + +class DecoderLayer(nn.Module): + def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, + norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + padding = (kernel - stride) // 2 + self.res_blocks = nn.Sequential( + *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) + for idx in range(res_blocks)]) + self.norm = nn.GroupNorm(norm_groups, chin) + ConvTr = nn.ConvTranspose1d + self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False) + self.activation = activation() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.res_blocks(x) + x = self.norm(x) + x = self.activation(x) + x = self.convtr(x) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, + norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, + dropout: float = 0.): + super().__init__() + padding = (kernel - stride) // 2 + Conv = nn.Conv1d + self.conv = Conv(chin, chout, kernel, stride, padding, bias=False) + self.norm = nn.GroupNorm(norm_groups, chout) + self.activation = activation() + self.res_blocks = nn.Sequential( + *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) + for idx in range(res_blocks)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, T = x.shape + stride, = self.conv.stride + pad = (stride - (T % stride)) % stride + x = F.pad(x, (0, pad)) + + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + x = self.res_blocks(x) + return x + + +class BLSTM(nn.Module): + """BiLSTM with same hidden units as input dim. 
+ """ + def __init__(self, dim, layers=2): + super().__init__() + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x): + x = x.permute(2, 0, 1) + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + return x + + +class DiffusionUnet(nn.Module): + def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2., + max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False, + bilstm: bool = False, transformer: bool = False, + codec_dim: tp.Optional[int] = None, **kwargs): + super().__init__() + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.embeddings: tp.Optional[nn.ModuleList] = None + self.embedding = nn.Embedding(num_steps, hidden) + if emb_all_layers: + self.embeddings = nn.ModuleList() + self.condition_embedding: tp.Optional[nn.Module] = None + for d in range(depth): + encoder = EncoderLayer(chin, hidden, **kwargs) + decoder = DecoderLayer(hidden, chin, **kwargs) + self.encoders.append(encoder) + self.decoders.insert(0, decoder) + if emb_all_layers and d > 0: + assert self.embeddings is not None + self.embeddings.append(nn.Embedding(num_steps, hidden)) + chin = hidden + hidden = min(int(chin * growth), max_channels) + self.bilstm: tp.Optional[nn.Module] + if bilstm: + self.bilstm = BLSTM(chin) + else: + self.bilstm = None + self.use_transformer = transformer + self.cross_attention = False + if transformer: + self.cross_attention = cross_attention + self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False, + cross_attention=cross_attention) + + self.use_codec = False + if codec_dim is not None: + self.conv_codec = nn.Conv1d(codec_dim, chin, 1) + self.use_codec = True + + def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None): + skips = [] + bs = x.size(0) + z = x + view_args = [1] + if type(step) is torch.Tensor: + step_tensor = step + else: + step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs) + + for idx, encoder in enumerate(self.encoders): + z = encoder(z) + if idx == 0: + z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z) + elif self.embeddings is not None: + z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z) + + skips.append(z) + + if self.use_codec: # insert condition in the bottleneck + assert condition is not None, "Model defined for conditionnal generation" + condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim + assert condition_emb.size(-1) <= 2 * z.size(-1), \ + f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}" + if not self.cross_attention: + + condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1)) + assert z.size() == condition_emb.size() + z += condition_emb + cross_attention_src = None + else: + cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C + B, T, C = cross_attention_src.shape + positions = torch.arange(T, device=x.device).view(1, -1, 1) + pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype) + cross_attention_src = cross_attention_src + pos_emb + if self.use_transformer: + z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1) + else: + if self.bilstm is None: + z = torch.zeros_like(z) + else: + 
z = self.bilstm(z) + + for decoder in self.decoders: + s = skips.pop(-1) + z = z[:, :, :s.shape[2]] + z = z + s + z = decoder(z) + + z = z[:, :, :x.shape[2]] + return Output(z) diff --git a/audiocraft/utils/__init__.py b/audiocraft/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75e25a0212f98e4a18d97c86c6cda225636a3215 --- /dev/null +++ b/audiocraft/utils/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Utilities.""" diff --git a/audiocraft/utils/audio_utils.py b/audiocraft/utils/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9d3129b84b114c5572078295604279884c79f2cc --- /dev/null +++ b/audiocraft/utils/audio_utils.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Various utilities for audio convertion (pcm format, sample rate and channels), +and volume normalization.""" +import sys +import typing as tp + +import julius +import torch +import torchaudio + + +def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor: + """Convert audio to the given number of channels. + + Args: + wav (torch.Tensor): Audio wave of shape [B, C, T]. + channels (int): Expected number of channels as output. + Returns: + torch.Tensor: Downmixed or unchanged audio wave [B, C, T]. + """ + *shape, src_channels, length = wav.shape + if src_channels == channels: + pass + elif channels == 1: + # Case 1: + # The caller asked 1-channel audio, and the stream has multiple + # channels, downmix all channels. + wav = wav.mean(dim=-2, keepdim=True) + elif src_channels == 1: + # Case 2: + # The caller asked for multiple channels, but the input file has + # a single channel, replicate the audio over all channels. + wav = wav.expand(*shape, channels, length) + elif src_channels >= channels: + # Case 3: + # The caller asked for multiple channels, and the input file has + # more channels than requested. In that case return the first channels. + wav = wav[..., :channels, :] + else: + # Case 4: What is a reasonable choice here? + raise ValueError('The audio file has less channels than requested but is not mono.') + return wav + + +def convert_audio(wav: torch.Tensor, from_rate: float, + to_rate: float, to_channels: int) -> torch.Tensor: + """Convert audio to new sample rate and number of audio channels.""" + wav = julius.resample_frac(wav, int(from_rate), int(to_rate)) + wav = convert_audio_channels(wav, to_channels) + return wav + + +def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, energy_floor: float = 2e-3): + """Normalize an input signal to a user loudness in dB LKFS. + Audio loudness is defined according to the ITU-R BS.1770-4 recommendation. + + Args: + wav (torch.Tensor): Input multichannel audio data. + sample_rate (int): Sample rate. + loudness_headroom_db (float): Target loudness of the output in dB LUFS. + loudness_compressor (bool): Uses tanh for soft clipping. + energy_floor (float): anything below that RMS level will not be rescaled. + Returns: + torch.Tensor: Loudness normalized output data. 
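+
+    Worked example of the gain computation below: with the default
+    `loudness_headroom_db=14`, an input measured at -20 dB LUFS gets
+    `delta_loudness = -14 - (-20) = 6` dB, i.e. a linear gain of
+    10 ** (6 / 20), roughly 2.0.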
+ """ + energy = wav.pow(2).mean().sqrt().item() + if energy < energy_floor: + return wav + transform = torchaudio.transforms.Loudness(sample_rate) + input_loudness_db = transform(wav).item() + # calculate the gain needed to scale to the desired loudness level + delta_loudness = -loudness_headroom_db - input_loudness_db + gain = 10.0 ** (delta_loudness / 20.0) + output = gain * wav + if loudness_compressor: + output = torch.tanh(output) + assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt()) + return output + + +def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None: + """Utility function to clip the audio with logging if specified.""" + max_scale = wav.abs().max() + if log_clipping and max_scale > 1: + clamp_prob = (wav.abs() > 1).float().mean().item() + print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):", + clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr) + wav.clamp_(-1, 1) + + +def normalize_audio(wav: torch.Tensor, normalize: bool = True, + strategy: str = 'peak', peak_clip_headroom_db: float = 1, + rms_headroom_db: float = 18, loudness_headroom_db: float = 14, + loudness_compressor: bool = False, log_clipping: bool = False, + sample_rate: tp.Optional[int] = None, + stem_name: tp.Optional[str] = None) -> torch.Tensor: + """Normalize the audio according to the prescribed strategy (see after). + + Args: + wav (torch.Tensor): Audio data. + normalize (bool): if `True` (default), normalizes according to the prescribed + strategy (see after). If `False`, the strategy is only used in case clipping + would happen. + strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', + i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square + with extra headroom to avoid clipping. 'clip' just clips. + peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. + rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger + than the `peak_clip` one to avoid further clipping. + loudness_headroom_db (float): Target loudness for loudness normalization. + loudness_compressor (bool): If True, uses tanh based soft clipping. + log_clipping (bool): If True, basic logging on stderr when clipping still + occurs despite strategy (only for 'rms'). + sample_rate (int): Sample rate for the audio data (required for loudness). + stem_name (str, optional): Stem name for clipping logging. + Returns: + torch.Tensor: Normalized audio. + """ + scale_peak = 10 ** (-peak_clip_headroom_db / 20) + scale_rms = 10 ** (-rms_headroom_db / 20) + if strategy == 'peak': + rescaling = (scale_peak / wav.abs().max()) + if normalize or rescaling < 1: + wav = wav * rescaling + elif strategy == 'clip': + wav = wav.clamp(-scale_peak, scale_peak) + elif strategy == 'rms': + mono = wav.mean(dim=0) + rescaling = scale_rms / mono.pow(2).mean().sqrt() + if normalize or rescaling < 1: + wav = wav * rescaling + _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) + elif strategy == 'loudness': + assert sample_rate is not None, "Loudness normalization requires sample rate." 
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor) + _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) + else: + assert wav.abs().max() < 1 + assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'" + return wav + + +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. + """ + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / 2**15 + elif wav.dtype == torch.int32: + return wav.float() / 2**31 + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def i16_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to int 16 bits PCM format. + + ..Warning:: There exist many formula for doing this conversion. None are perfect + due to the asymmetry of the int16 range. One either have possible clipping, DC offset, + or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom, + it is possible that `i16_pcm(f32_pcm)) != Identity`. + """ + if wav.dtype.is_floating_point: + assert wav.abs().max() <= 1 + candidate = (wav * 2 ** 15).round() + if candidate.max() >= 2 ** 15: # clipping would occur + candidate = (wav * (2 ** 15 - 1)).round() + return candidate.short() + else: + assert wav.dtype == torch.int16 + return wav diff --git a/audiocraft/utils/autocast.py b/audiocraft/utils/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..ed644843bb37cf8a92a20fbd51d6cebaa43b9a08 --- /dev/null +++ b/audiocraft/utils/autocast.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +class TorchAutocast: + """TorchAutocast utility class. + Allows you to enable and disable autocast. This is specially useful + when dealing with different architectures and clusters with different + levels of support. + + Args: + enabled (bool): Whether to enable torch.autocast or not. + args: Additional args for torch.autocast. + kwargs: Additional kwargs for torch.autocast + """ + def __init__(self, enabled: bool, *args, **kwargs): + self.autocast = torch.autocast(*args, **kwargs) if enabled else None + + def __enter__(self): + if self.autocast is None: + return + try: + self.autocast.__enter__() + except RuntimeError: + device = self.autocast.device + dtype = self.autocast.fast_dtype + raise RuntimeError( + f"There was an error autocasting with dtype={dtype} device={device}\n" + "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16" + ) + + def __exit__(self, *args, **kwargs): + if self.autocast is None: + return + self.autocast.__exit__(*args, **kwargs) diff --git a/audiocraft/utils/cache.py b/audiocraft/utils/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba017a761a29c44d3385e0b483877cb4a8d1ec1 --- /dev/null +++ b/audiocraft/utils/cache.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
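Usage sketch for the audio_utils and autocast helpers added above (illustrative only, not part of the patch): convert_audio resamples with julius and remaps channels, normalize_audio applies the selected gain strategy, and TorchAutocast wraps torch.autocast behind a single enabled flag. The sample rates and headroom values below are arbitrary examples.

    import torch
    from audiocraft.utils.audio_utils import convert_audio, normalize_audio
    from audiocraft.utils.autocast import TorchAutocast

    wav = torch.randn(2, 16000)  # one second of stereo noise at 16 kHz
    wav = convert_audio(wav, from_rate=16000, to_rate=24000, to_channels=1)
    # 'loudness' needs the sample rate; 'peak', 'rms' and 'clip' do not.
    wav = normalize_audio(wav, strategy='loudness', sample_rate=24000,
                          loudness_headroom_db=14, loudness_compressor=True)

    with TorchAutocast(enabled=torch.cuda.is_available(),
                       device_type='cuda', dtype=torch.float16):
        pass  # model forward pass would go here; a no-op context when disabled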
+ +from concurrent.futures import ThreadPoolExecutor +from collections import deque +from functools import partial +from hashlib import sha1 +import logging +from pathlib import Path +import sys +import typing as tp +import zipfile + +import flashy +import torch + + +logger = logging.getLogger(__name__) + + +def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor: + """Utility function for the EmbeddingCache, returning the full embedding without any chunking. + This method can be used in case there is no need in extracting a chunk of the full embedding + read from the cache. + + Args: + full_embed (torch.Tensor): The full embedding. + x (any): Batch object from which the full embedding is derived. + idx (torch.Tensor): Index of object to consider in the batch object. + Returns: + full_embed (torch.Tensor): The full embedding + """ + return full_embed.to(device) + + +class EmbeddingCache: + """Cache around embeddings computation for faster execution. + The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API + to retrieve the pre-computed embeddings on full inputs and extract only a given chunk + using a user-provided function. When the cache is warm (all embeddings are pre-computed), + the EmbeddingCache allows for faster training as it removes the need of computing the embeddings. + Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint + and synchronization points in the forward calls. + + Args: + cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk. + device (str or torch.device): Device on which the embedding is returned. + compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute + the embedding from a given object and path. This user provided function can compute the + embedding from the provided object or using the provided path as entry point. The last parameter + specify the index corresponding to the current embedding in the object that can represent batch metadata. + extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract + the desired embedding chunk from the full embedding loaded from the cache. The last parameter + specify the index corresponding to the current embedding in the object that can represent batch metadata. + If not specified, will return the full embedding unmodified. 
+ """ + def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device], + compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor], + extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None): + self.cache_path = Path(cache_path) + self.device = device + self._compute_embed_fn = compute_embed_fn + self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor] + if extract_embed_fn is not None: + self._extract_embed_fn = extract_embed_fn + else: + self._extract_embed_fn = partial(get_full_embed, device=device) + if self.cache_path is not None: + self.cache_path.mkdir(exist_ok=True, parents=True) + logger.info(f"Cache instantiated at: {self.cache_path}") + self.pool = ThreadPoolExecutor(8) + self.pool.__enter__() + self._current_batch_cache: dict = {} + self._memory_cache: dict = {} + + def _get_cache_path(self, path: tp.Union[Path, str]): + """Get cache path for the given file path.""" + sig = sha1(str(path).encode()).hexdigest() + return self.cache_path / sig + + @staticmethod + def _get_full_embed_from_cache(cache: Path): + """Loads full pre-computed embedding from the cache.""" + try: + embed = torch.load(cache, 'cpu') + except Exception as exc: + logger.error("Error loading %s: %r", cache, exc) + embed = None + return embed + + def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor: + """Get embedding from cache, computing and storing it to cache if not already cached. + The EmbeddingCache first tries to load the embedding from the in-memory cache + containing the pre-computed chunks populated through `populate_embed_cache`. + If not found, the full embedding is computed and stored on disk to be later accessed + to populate the in-memory cache, and the desired embedding chunk is extracted and returned. + + Args: + paths (list[Path or str]): List of paths from where the embeddings can be loaded. + x (any): Object from which the embedding is extracted. + """ + embeds = [] + for idx, path in enumerate(paths): + cache = self._get_cache_path(path) + if cache in self._current_batch_cache: + embed = self._current_batch_cache[cache] + else: + full_embed = self._compute_embed_fn(path, x, idx) + try: + with flashy.utils.write_and_rename(cache, pid=True) as f: + torch.save(full_embed.cpu(), f) + except Exception as exc: + logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc) + else: + logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape) + embed = self._extract_embed_fn(full_embed, x, idx) + embeds.append(embed) + embed = torch.stack(embeds, dim=0) + return embed + + def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None: + """Populate in-memory caches for embeddings reading from the embeddings stored on disk. + The in-memory caches consist in a cache for the full embedding and another cache for the + final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings + and reduce the IO footprint and synchronization points during forward passes. + + Args: + paths (list[Path]): List of paths from where the embeddings can be loaded. + x (any): Object from which the embedding is extracted. 
+ """ + self._current_batch_cache.clear() + if self.cache_path is not None: + futures: list = [] + for path in paths: + assert path is not None, "Path is required for computation from cache" + cache = self._get_cache_path(path) + if cache in self._memory_cache or not cache.exists(): + futures.append(None) + else: + futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache)) + for idx, (path, future) in enumerate(zip(paths, futures)): + assert path is not None + cache = self._get_cache_path(path) + full_embed = None + if future is None: + if cache in self._memory_cache: + full_embed = self._memory_cache[cache] + else: + full_embed = future.result() + if full_embed is not None: + self._memory_cache[cache] = full_embed + full_embed = full_embed.to(self.device) + if full_embed is not None: + embed = self._extract_embed_fn(full_embed, x, idx) + self._current_batch_cache[cache] = embed + + +class CachedBatchWriter: + """Write pre computed caches for mini batches. This can + make loading a lot more efficient depending on your filesystem. + + Args: + cache_folder (Path): folder in which the cached minibatches + will be stored. + + Inside cache folder, the structure is the following: + `epoch_number / update_number.zip` + And the zip file contains one entry per batch item. + + It is possible to use the cache with a batch size smaller than + created with but obviously not larger. Make sure to call the + `start_epoch(epoch)` method for indicating changes of epochs. + + See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py` + for an example of how to warmup the cache. + """ + def __init__(self, cache_folder: Path): + self.cache_folder = cache_folder + self._current_epoch: tp.Optional[int] = None + self._current_index = 0 + + def start_epoch(self, epoch: int): + """Call at the beginning of each epoch. + """ + self._current_epoch = epoch + self._current_index = 0 + self._zip_path.parent.mkdir(exist_ok=True, parents=True) + + @staticmethod + def _get_zip_path(cache_folder: Path, epoch: int, index: int): + return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip" + + @property + def _zip_path(self): + assert self._current_epoch is not None + return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index) + + def save(self, *content): + """Save one mini batch. This function is distributed-aware + and will automatically merge all the items from the different + workers. + """ + all_contents = [] + for rank in range(flashy.distrib.world_size()): + their_content = flashy.distrib.broadcast_object(content, src=rank) + all_contents.append(their_content) + + if flashy.distrib.is_rank_zero(): + idx = 0 + with flashy.utils.write_and_rename(self._zip_path) as tmp: + with zipfile.ZipFile(tmp, 'w') as zf: + for content in all_contents: + for vals in zip(*content): + with zf.open(f'{idx}', 'w') as f: # type: ignore + torch.save(vals, f) + idx += 1 + flashy.distrib.barrier() + self._current_index += 1 + + +class CachedBatchLoader: + """Loader for cached mini-batches dumped with `CachedBatchWriter`. + + Args: + cache_folder (Path): folder in which the cached minibatches are stored. + batch_size (int): batch size (per GPU) expected. + num_workers (int): number of workers to use for loading. + min_length (int): minimum expected length for each epoch. If some + mini-batches are missing, and error is raised. + + This is iterable just like a regular DataLoader. 
+ """ + + def __init__(self, cache_folder: Path, batch_size: int, + num_workers: int = 10, min_length: int = 1): + self.cache_folder = cache_folder + self.batch_size = batch_size + self.num_workers = num_workers + self.min_length = min_length + self._current_epoch: tp.Optional[int] = None + self.sampler = None # for compatibility with the regular DataLoader + + def __len__(self): + path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent + return len([p for p in path.iterdir() if p.suffix == ".zip"]) + + def start_epoch(self, epoch: int): + """Call at the beginning of each epoch. + """ + self._current_epoch = epoch + + def _zip_path(self, index: int): + assert self._current_epoch is not None + return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index) + + def _load_one(self, index: int): + zip_path = self._zip_path(index) + if not zip_path.exists(): + if index < self.min_length: + raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist") + + return None + mode = "rb" if sys.version_info >= (3, 9) else "r" + try: + with zipfile.ZipFile(zip_path, 'r') as zf: + rank = flashy.distrib.rank() + world_size = flashy.distrib.world_size() + root = zipfile.Path(zf) + items = list(root.iterdir()) + total_batch_size = self.batch_size * world_size + if len(items) < total_batch_size: + raise RuntimeError( + f"The cache can handle a max batch size of {len(items)}, " + f"but {total_batch_size} is needed.") + start = rank * self.batch_size + items = items[start: start + self.batch_size] + assert len(items) == self.batch_size + entries = [] + entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore + transposed = zip(*entries) + out = [] + for part in transposed: + assert len(part) > 0 + if isinstance(part[0], torch.Tensor): + out.append(torch.stack(part)) + else: + assert isinstance(part, torch.Tensor) + out.append(part) + return out + except Exception: + logger.error("Error when reading zip path %s", zip_path) + raise + + def __iter__(self): + """This will yields tuples, exactly as provided to the + `CachedBatchWriter.save` method. + """ + pool = ThreadPoolExecutor(self.num_workers) + next_index = 0 + queue = deque() + + def _get_next(): + nonlocal next_index + r = queue.popleft().result() + if r is None: + return None + else: + queue.append(pool.submit(self._load_one, next_index)) + next_index += 1 + return r + + with pool: + # fill the buffer of fetching jobs. + for _ in range(2 * self.num_workers): + queue.append(pool.submit(self._load_one, next_index)) + next_index += 1 + while True: + batch = _get_next() + if batch is None: + return + yield batch diff --git a/audiocraft/utils/checkpoint.py b/audiocraft/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f871837e09c5cc7832b85b0d80b84f59e87ca0 --- /dev/null +++ b/audiocraft/utils/checkpoint.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
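Illustrative sketch of how the EmbeddingCache above is driven (not part of the patch; the toy compute function and paths are made up): a user-supplied compute_embed_fn produces the full embedding for a path on a cache miss, populate_embed_cache pre-loads whatever already sits on disk for the current batch, and get_embed_from_cache returns the stacked, optionally chunked, embeddings.

    import torch
    from pathlib import Path
    from audiocraft.utils.cache import EmbeddingCache

    def compute_embed(path: Path, x, idx: int) -> torch.Tensor:
        # Hypothetical stand-in for a real conditioner forward pass.
        return torch.randn(1, 512)

    cache = EmbeddingCache(cache_path='./embed_cache', device='cpu',
                           compute_embed_fn=compute_embed)
    paths = [Path('a.wav'), Path('b.wav')]
    cache.populate_embed_cache(paths, x=None)            # warm the in-memory cache from disk
    embeds = cache.get_embed_from_cache(paths, x=None)   # shape [2, 1, 512]; computed and saved if missing

CachedBatchWriter and CachedBatchLoader follow the same warm-once, read-fast idea one level up, at the granularity of whole pre-computed mini-batches.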
+ +from enum import Enum +import logging +from pathlib import Path +import re +import typing as tp + +import flashy +import torch + +from ..environment import AudioCraftEnvironment + + +logger = logging.getLogger(__name__) + + +class CheckpointSource(Enum): + CURRENT_XP = "current_xp" + PRETRAINED = "pretrained" + OTHER = "other" + + +def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str: + """Checkpoint name formatted for all use in AudioCraft codebase and has the following format: + `checkpoint_.th(.)`. By convention, name is expected to be empty for last checkpoint, + 'best' for the best checkpoint or the epoch number. + + Args: + name (str, optional): Name suffix for the checkpoint file stem. + rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided. + use_fsdp (bool): Whether the calling solver relies on FSDP. + Returns: + str: The checkpoint name. + """ + suffix = '' + if rank is None: + rank = flashy.distrib.rank() + if rank > 0 and use_fsdp: + suffix = '.' + str(rank) + name_part = '' + if name is not None: + name_part = f'_{name}' + return f'checkpoint{name_part}.th{suffix}' + + +def is_sharded_checkpoint(path: Path) -> bool: + """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank.""" + return re.search(r'\.th\.\d+$', path.name) is not None + + +def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None, + use_fsdp: bool = False) -> tp.Optional[Path]: + """Resolve a given checkpoint path for a provided dora sig or path. + + Args: + sig_or_path (Path or str): Checkpoint path or dora signature. + name (str, optional): Name suffix for the checkpoint file stem. + rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided. + use_fsdp (bool): Whether the calling solver relies on FSDP. + Returns: + Path, optional: Resolved checkpoint path, if it exists. 
+ """ + from audiocraft import train + xps_root = train.main.dora.dir / 'xps' + sig_or_path = str(sig_or_path) + if sig_or_path.startswith('//sig/'): + sig = sig_or_path[len('//sig/'):] + path = xps_root / sig + else: + path = Path(sig_or_path) + path = AudioCraftEnvironment.resolve_reference_path(path) + + if path.is_dir(): + path = path / checkpoint_name(name, use_fsdp=use_fsdp) + + if path.exists(): + return path + else: + return None + + +def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any: + """Load state from checkpoints at the specified checkpoint path.""" + if is_sharded: + rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False) + if rank0_checkpoint_path.exists(): + check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path) + state = torch.load(checkpoint_path, 'cpu') + logger.info("Checkpoint loaded from %s", checkpoint_path) + return state + + +def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None: + """Save state to disk to the specified checkpoint_path.""" + _safe_save_checkpoint(state, checkpoint_path, is_sharded) + logger.info("Checkpoint saved to %s", checkpoint_path) + + +def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None: + """Flush checkpoints to only keep last N checkpoints.""" + if keep_last is None or keep_last <= 0: + return + checkpoint_dir = checkpoint_path.parent + suffix = '' + if flashy.distrib.rank() > 0: + suffix = f'.{flashy.distrib.rank()}' + checkpoint_files_with_epoch = [] + for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'): + epoch_part = path.name.split('.', 1)[0].split('_', 1)[1] + if epoch_part.isdigit(): + checkpoint_files_with_epoch.append((path, int(epoch_part))) + checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))] + total_to_flush = max(0, len(checkpoint_files) - keep_last) + files_to_flush = checkpoint_files[:total_to_flush] + for path in files_to_flush: + logger.debug("Removing checkpoint: %s", str(path)) + path.unlink(missing_ok=True) + + +def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None: + """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.""" + # Finish the work of a previous run that got interrupted while dumping. 
+ old_path = Path(str(checkpoint_path) + '.old') + if old_path.exists(): + raise RuntimeError( + f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.") + token = Path(str(rank0_checkpoint_path) + '.tmp.done') + tmp_path = Path(str(checkpoint_path) + '.tmp') + if token.exists(): + if tmp_path.exists(): + tmp_path.rename(checkpoint_path) + flashy.distrib.barrier() + if flashy.distrib.is_rank_zero() and token.exists(): + token.unlink() + + +def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None: + """Save checkpoints in a safe manner even with when sharded checkpoints across nodes.""" + def _barrier_if_sharded(): + if is_sharded: + flashy.distrib.barrier() + + if flashy.distrib.is_rank_zero(): + token = Path(str(checkpoint_path) + '.tmp.done') + if token.exists(): + token.unlink() + _barrier_if_sharded() + with flashy.utils.write_and_rename(checkpoint_path) as f: + torch.save(state, f) + _barrier_if_sharded() + if flashy.distrib.is_rank_zero(): + token.touch() + _barrier_if_sharded() + _barrier_if_sharded() + if flashy.distrib.rank() == 0: + token.unlink() diff --git a/audiocraft/utils/cluster.py b/audiocraft/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3380d031739d473fb859c76b9c25350f47fa77e8 --- /dev/null +++ b/audiocraft/utils/cluster.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility functions for SLURM configuration and cluster settings. +""" + +from enum import Enum +import os +import socket +import typing as tp + +import omegaconf + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + LOCAL_DARWIN = "darwin" + DEFAULT = "default" # used for any other cluster. + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + fqdn = socket.getfqdn() + if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn): + return ClusterType.AWS + + if fqdn.endswith(".fair"): + return ClusterType.FAIR + + if fqdn.endswith(".facebook.com"): + return ClusterType.RSC + + if uname.sysname == "Darwin": + return ClusterType.LOCAL_DARWIN + + return ClusterType.DEFAULT + + +def get_cluster_type( + cluster_type: tp.Optional[ClusterType] = None, +) -> tp.Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_slurm_parameters( + cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None +) -> omegaconf.DictConfig: + """Update SLURM parameters in configuration based on cluster type. + If the cluster type is not specify, it infers it automatically. 
+ """ + from ..environment import AudioCraftEnvironment + cluster_type = get_cluster_type(cluster_type) + # apply cluster-specific adjustments + if cluster_type == ClusterType.AWS: + cfg["mem_per_gpu"] = None + cfg["constraint"] = None + cfg["setup"] = [] + elif cluster_type == ClusterType.RSC: + cfg["mem_per_gpu"] = None + cfg["setup"] = [] + cfg["constraint"] = None + cfg["partition"] = "learn" + slurm_exclude = AudioCraftEnvironment.get_slurm_exclude() + if slurm_exclude is not None: + cfg["exclude"] = slurm_exclude + return cfg diff --git a/audiocraft/utils/deadlock.py b/audiocraft/utils/deadlock.py new file mode 100644 index 0000000000000000000000000000000000000000..8abd1bbeea5909e664cf816c020bd7c37effdb66 --- /dev/null +++ b/audiocraft/utils/deadlock.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from queue import Queue, Empty +import signal +import sys +import threading +import traceback + +logger = logging.getLogger(__name__) + + +class DeadlockDetect: + def __init__(self, use: bool = False, timeout: float = 120.): + self.use = use + self.timeout = timeout + self._queue: Queue = Queue() + + def update(self, stage: str): + if self.use: + self._queue.put(stage) + + def __enter__(self): + if self.use: + self._thread = threading.Thread(target=self._detector_thread) + self._thread.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.use: + self._queue.put(None) + self._thread.join() + + def _detector_thread(self): + logger.debug("Deadlock detector started") + last_stage = "init" + while True: + try: + stage = self._queue.get(timeout=self.timeout) + except Empty: + break + if stage is None: + logger.debug("Exiting deadlock detector thread") + return + else: + last_stage = stage + logger.error("Deadlock detector timed out, last stage was %s", last_stage) + for th in threading.enumerate(): + print(th, file=sys.stderr) + traceback.print_stack(sys._current_frames()[th.ident]) + print(file=sys.stderr) + sys.stdout.flush() + sys.stderr.flush() + os.kill(os.getpid(), signal.SIGKILL) diff --git a/audiocraft/utils/export.py b/audiocraft/utils/export.py new file mode 100644 index 0000000000000000000000000000000000000000..28b214017d9ac23934b67e8254a96131cefa6501 --- /dev/null +++ b/audiocraft/utils/export.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utility to export a training checkpoint to a lightweight release checkpoint. +""" + +from pathlib import Path +import typing as tp + +from omegaconf import OmegaConf +import torch + +from audiocraft import __version__ + + +def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + """Export only the best state from the given EnCodec checkpoint. This + should be used if you trained your own EnCodec model. 
+ """ + pkg = torch.load(checkpoint_path, 'cpu') + new_pkg = { + 'best_state': pkg['best_state']['model'], + 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + 'version': __version__, + 'exported': True, + } + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(new_pkg, out_file) + return out_file + + +def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]): + """Export a compression model (potentially EnCodec) from a pretrained model. + This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. + Do not include the //pretrained/ prefix. For instance if you trained a model + with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`. + + In that case, this will not actually include a copy of the model, simply the reference + to the model used. + """ + if Path(pretrained_encodec).exists(): + pkg = torch.load(pretrained_encodec) + assert 'best_state' in pkg + assert 'xp.cfg' in pkg + assert 'version' in pkg + assert 'exported' in pkg + else: + pkg = { + 'pretrained': pretrained_encodec, + 'exported': True, + 'version': __version__, + } + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(pkg, out_file) + + +def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + """Export only the best state from the given MusicGen or AudioGen checkpoint. + """ + pkg = torch.load(checkpoint_path, 'cpu') + if pkg['fsdp_best_state']: + best_state = pkg['fsdp_best_state']['model'] + else: + assert pkg['best_state'] + best_state = pkg['best_state']['model'] + new_pkg = { + 'best_state': best_state, + 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + 'version': __version__, + 'exported': True, + } + + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(new_pkg, out_file) + return out_file diff --git a/audiocraft/utils/export_legacy.py b/audiocraft/utils/export_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..367c3f3c9f95ae59a95edbb60b470e03cc842fbb --- /dev/null +++ b/audiocraft/utils/export_legacy.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Legacy functions used at the time of the first release, kept for referencd. +""" + +from pathlib import Path +import typing as tp + +from omegaconf import OmegaConf, DictConfig +import torch + +from audiocraft import __version__ + + +def _clean_lm_cfg(cfg: DictConfig): + OmegaConf.set_struct(cfg, False) + # This used to be set automatically in the LM solver, need a more robust solution + # for the future. + cfg['transformer_lm']['card'] = 2048 + n_q = 4 + stereo_cfg = getattr(cfg, 'interleave_stereo_codebooks', None) + if stereo_cfg is not None and stereo_cfg.use: + if 'downsample' in stereo_cfg: + del stereo_cfg['downsample'] + n_q = 8 + cfg['transformer_lm']['n_q'] = n_q + # Experimental params no longer supported. 
+ bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', + 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] + for name in bad_params: + del cfg['transformer_lm'][name] + OmegaConf.set_struct(cfg, True) + return cfg + + +def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + pkg = torch.load(checkpoint_path, 'cpu') + new_pkg = { + 'best_state': pkg['ema']['state']['model'], + 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + # The following params were NOT exported for the first release of MusicGen. + 'version': __version__, + 'exported': True, + } + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(new_pkg, out_file) + return out_file + + +def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): + pkg = torch.load(checkpoint_path, 'cpu') + if pkg['fsdp_best_state']: + best_state = pkg['fsdp_best_state']['model'] + else: + best_state = pkg['best_state']['model'] + new_pkg = { + 'best_state': best_state, + 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])), + # The following params were NOT exported for the first release of MusicGen. + 'version': __version__, + 'exported': True, + } + Path(out_file).parent.mkdir(exist_ok=True, parents=True) + torch.save(new_pkg, out_file) + return out_file diff --git a/audiocraft/utils/notebook.py b/audiocraft/utils/notebook.py new file mode 100644 index 0000000000000000000000000000000000000000..019b9d19e5bef976bedddf428fd25da42a8a9726 --- /dev/null +++ b/audiocraft/utils/notebook.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +try: + import IPython.display as ipd # type: ignore +except ImportError: + # Note in a notebook... + pass + + +import torch + + +def display_audio(samples: torch.Tensor, sample_rate: int): + """Renders an audio player for the given audio samples. + + Args: + samples (torch.Tensor): a Tensor of decoded audio samples + with shapes [B, C, T] or [C, T] + sample_rate (int): sample rate audio should be displayed with. + """ + assert samples.dim() == 2 or samples.dim() == 3 + + samples = samples.detach().cpu() + if samples.dim() == 2: + samples = samples[None, ...] + + for audio in samples: + ipd.display(ipd.Audio(audio, rate=sample_rate)) diff --git a/audiocraft/utils/profiler.py b/audiocraft/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..b45b6d15910b50305c7b212c089ffad3c25b324d --- /dev/null +++ b/audiocraft/utils/profiler.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import typing as tp + +import dora +import torch + + +logger = logging.getLogger(__name__) + + +class Profiler: + """Context manager wrapper for xformers profiler. 
+ """ + def __init__(self, module: torch.nn.Module, enabled: bool = False): + self.profiler: tp.Optional[tp.Any] = None + if enabled: + from xformers.profiler import profile + output_dir = dora.get_xp().folder / 'profiler_data' + logger.info("Profiling activated, results with be saved to %s", output_dir) + self.profiler = profile(output_dir=output_dir, module=module) + + def step(self): + if self.profiler is not None: + self.profiler.step() # type: ignore + + def __enter__(self): + if self.profiler is not None: + return self.profiler.__enter__() # type: ignore + + def __exit__(self, exc_type, exc_value, exc_tb): + if self.profiler is not None: + return self.profiler.__exit__(exc_type, exc_value, exc_tb) # type: ignore diff --git a/audiocraft/utils/samples/__init__.py b/audiocraft/utils/samples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/audiocraft/utils/samples/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/audiocraft/utils/samples/manager.py b/audiocraft/utils/samples/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..bf0fb21b2d2867c03f7cce6f27d9524fdb89b51d --- /dev/null +++ b/audiocraft/utils/samples/manager.py @@ -0,0 +1,386 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +API that can manage the storage and retrieval of generated samples produced by experiments. + +It offers the following benefits: +* Samples are stored in a consistent way across epoch +* Metadata about the samples can be stored and retrieved +* Can retrieve audio +* Identifiers are reliable and deterministic for prompted and conditioned samples +* Can request the samples for multiple XPs, grouped by sample identifier +* For no-input samples (not prompt and no conditions), samples across XPs are matched + by sorting their identifiers +""" + +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict, dataclass +from functools import lru_cache +import hashlib +import json +import logging +from pathlib import Path +import re +import typing as tp +import unicodedata +import uuid + +import dora +import torch + +from ...data.audio import audio_read, audio_write + + +logger = logging.getLogger(__name__) + + +@dataclass +class ReferenceSample: + id: str + path: str + duration: float + + +@dataclass +class Sample: + id: str + path: str + epoch: int + duration: float + conditioning: tp.Optional[tp.Dict[str, tp.Any]] + prompt: tp.Optional[ReferenceSample] + reference: tp.Optional[ReferenceSample] + generation_args: tp.Optional[tp.Dict[str, tp.Any]] + + def __hash__(self): + return hash(self.id) + + def audio(self) -> tp.Tuple[torch.Tensor, int]: + return audio_read(self.path) + + def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: + return audio_read(self.prompt.path) if self.prompt is not None else None + + def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: + return audio_read(self.reference.path) if self.reference is not None else None + + +class SampleManager: + """Audio samples IO handling within a given dora xp. 
+ + The sample manager handles the dumping and loading logic for generated and + references samples across epochs for a given xp, providing a simple API to + store, retrieve and compare audio samples. + + Args: + xp (dora.XP): Dora experiment object. The XP contains information on the XP folder + where all outputs are stored and the configuration of the experiment, + which is useful to retrieve audio-related parameters. + map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples + instead of generating a dedicated hash id. This is useful to allow easier comparison + with ground truth sample from the files directly without having to read the JSON metadata + to do the mapping (at the cost of potentially dumping duplicate prompts/references + depending on the task). + """ + def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False): + self.xp = xp + self.base_folder: Path = xp.folder / xp.cfg.generate.path + self.reference_folder = self.base_folder / 'reference' + self.map_reference_to_sample_id = map_reference_to_sample_id + self.samples: tp.List[Sample] = [] + self._load_samples() + + @property + def latest_epoch(self): + """Latest epoch across all samples.""" + return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0 + + def _load_samples(self): + """Scan the sample folder and load existing samples.""" + jsons = self.base_folder.glob('**/*.json') + with ThreadPoolExecutor(6) as pool: + self.samples = list(pool.map(self._load_sample, jsons)) + + @staticmethod + @lru_cache(2**26) + def _load_sample(json_file: Path) -> Sample: + with open(json_file, 'r') as f: + data: tp.Dict[str, tp.Any] = json.load(f) + # fetch prompt data + prompt_data = data.get('prompt') + prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'], + duration=prompt_data['duration']) if prompt_data else None + # fetch reference data + reference_data = data.get('reference') + reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'], + duration=reference_data['duration']) if reference_data else None + # build sample object + return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'], + prompt=prompt, conditioning=data.get('conditioning'), reference=reference, + generation_args=data.get('generation_args')) + + def _init_hash(self): + return hashlib.sha1() + + def _get_tensor_id(self, tensor: torch.Tensor) -> str: + hash_id = self._init_hash() + hash_id.update(tensor.numpy().data) + return hash_id.hexdigest() + + def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor], + conditions: tp.Optional[tp.Dict[str, str]]) -> str: + """Computes an id for a sample given its input data. + This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input. + Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned. + + Args: + index (int): Batch index, Helpful to differentiate samples from the same batch. + prompt_wav (torch.Tensor): Prompt used during generation. + conditions (dict[str, str]): Conditioning used during generation. + """ + # For totally unconditioned generations we will just use a random UUID. + # The function get_samples_for_xps will do a simple ordered match with a custom key. 
+ if prompt_wav is None and not conditions: + return f"noinput_{uuid.uuid4().hex}" + + # Human readable portion + hr_label = "" + # Create a deterministic id using hashing + hash_id = self._init_hash() + hash_id.update(f"{index}".encode()) + if prompt_wav is not None: + hash_id.update(prompt_wav.numpy().data) + hr_label += "_prompted" + else: + hr_label += "_unprompted" + if conditions: + encoded_json = json.dumps(conditions, sort_keys=True).encode() + hash_id.update(encoded_json) + cond_str = "-".join([f"{key}={slugify(value)}" + for key, value in sorted(conditions.items())]) + cond_str = cond_str[:100] # some raw text might be too long to be a valid filename + cond_str = cond_str if len(cond_str) > 0 else "unconditioned" + hr_label += f"_{cond_str}" + else: + hr_label += "_unconditioned" + + return hash_id.hexdigest() + hr_label + + def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path: + """Stores the audio with the given stem path using the XP's configuration. + + Args: + wav (torch.Tensor): Audio to store. + stem_path (Path): Path in sample output directory with file stem to use. + overwrite (bool): When False (default), skips storing an existing audio file. + Returns: + Path: The path at which the audio is stored. + """ + existing_paths = [ + path for path in stem_path.parent.glob(stem_path.stem + '.*') + if path.suffix != '.json' + ] + exists = len(existing_paths) > 0 + if exists and overwrite: + logger.warning(f"Overwriting existing audio file with stem path {stem_path}") + elif exists: + return existing_paths[0] + + audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio) + return audio_path + + def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0, + conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None, + ground_truth_wav: tp.Optional[torch.Tensor] = None, + generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample: + """Adds a single sample. + The sample is stored in the XP's sample output directory, under a corresponding epoch folder. + Each sample is assigned an id which is computed using the input data. In addition to the + sample itself, a json file containing associated metadata is stored next to it. + + Args: + sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape]. + epoch (int): current training epoch. + index (int): helpful to differentiate samples from the same batch. + conditions (dict[str, str], optional): conditioning used during generation. + prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape]. + ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from. + Tensor of shape [channels, shape]. + generation_args (dict[str, any], optional): dictionary of other arguments used during generation. + Returns: + Sample: The saved sample. 
+ """ + sample_id = self._get_sample_id(index, prompt_wav, conditions) + reuse_id = self.map_reference_to_sample_id + prompt, ground_truth = None, None + if prompt_wav is not None: + prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True)) + prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate + prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id) + prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration) + if ground_truth_wav is not None: + ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True)) + ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate + ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id) + ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration) + sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True) + duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate + sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args) + self.samples.append(sample) + with open(sample_path.with_suffix('.json'), 'w') as f: + json.dump(asdict(sample), f, indent=2) + return sample + + def add_samples(self, samples_wavs: torch.Tensor, epoch: int, + conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None, + prompt_wavs: tp.Optional[torch.Tensor] = None, + ground_truth_wavs: tp.Optional[torch.Tensor] = None, + generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]: + """Adds a batch of samples. + The samples are stored in the XP's sample output directory, under a corresponding + epoch folder. Each sample is assigned an id which is computed using the input data and their batch index. + In addition to the sample itself, a json file containing associated metadata is stored next to it. + + Args: + sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape]. + epoch (int): Current training epoch. + conditioning (list of dict[str, str], optional): List of conditions used during generation, + one per sample in the batch. + prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape + [batch_size, channels, shape]. + ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from. + Tensor of shape [batch_size, channels, shape]. + generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation. + Returns: + samples (list of Sample): The saved audio samples with prompts, ground truth and metadata. + """ + samples = [] + for idx, wav in enumerate(samples_wavs): + prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None + gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None + conditions = conditioning[idx] if conditioning is not None else None + samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args)) + return samples + + def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False, + exclude_unprompted: bool = False, exclude_conditioned: bool = False, + exclude_unconditioned: bool = False) -> tp.Set[Sample]: + """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain. 
+ Please note that existing samples are loaded during the manager's initialization, and added samples through this + manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager + is the only way detect them. + + Args: + epoch (int): If provided, only return samples corresponding to this epoch. + max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch. + exclude_prompted (bool): If True, does not include samples that used a prompt. + exclude_unprompted (bool): If True, does not include samples that did not use a prompt. + exclude_conditioned (bool): If True, excludes samples that used conditioning. + exclude_unconditioned (bool): If True, excludes samples that did not use conditioning. + Returns: + Samples (set of Sample): The retrieved samples matching the provided filters. + """ + if max_epoch >= 0: + samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch) + else: + samples_epoch = self.latest_epoch if epoch < 0 else epoch + samples = { + sample + for sample in self.samples + if ( + (sample.epoch == samples_epoch) and + (not exclude_prompted or sample.prompt is None) and + (not exclude_unprompted or sample.prompt is not None) and + (not exclude_conditioned or not sample.conditioning) and + (not exclude_unconditioned or sample.conditioning) + ) + } + return samples + + +def slugify(value: tp.Any, allow_unicode: bool = False): + """Process string for safer file naming. + + Taken from https://github.com/django/django/blob/master/django/utils/text.py + + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + """ + value = str(value) + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = ( + unicodedata.normalize("NFKD", value) + .encode("ascii", "ignore") + .decode("ascii") + ) + value = re.sub(r"[^\w\s-]", "", value.lower()) + return re.sub(r"[-\s]+", "-", value).strip("-_") + + +def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]: + # Create a dictionary of stable id -> sample per XP + stable_samples_per_xp = [{ + sample.id: sample for sample in samples + if sample.prompt is not None or sample.conditioning + } for samples in samples_per_xp] + # Set of all stable ids + stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()} + # Dictionary of stable id -> list of samples. If an XP does not have it, assign None + stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids} + # Filter out ids that contain None values (we only want matched samples after all) + # cast is necessary to avoid mypy linter errors. 
+ return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples} + + +def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]: + # For unstable ids, we use a sorted list since we'll match them in order + unstable_samples_per_xp = [[ + sample for sample in sorted(samples, key=lambda x: x.id) + if sample.prompt is None and not sample.conditioning + ] for samples in samples_per_xp] + # Trim samples per xp so all samples can have a match + min_len = min([len(samples) for samples in unstable_samples_per_xp]) + unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp] + # Dictionary of index -> list of matched samples + return { + f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len) + } + + +def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]: + """Gets a dictionary of matched samples across the given XPs. + Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id + will always match the number of XPs provided and will correspond to each XP in the same order given. + In other words, only samples that can be match across all provided XPs will be returned + in order to satisfy this rule. + + There are two types of ids that can be returned: stable and unstable. + * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs + (prompts/conditioning). This is why we can match them across XPs. + * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples + that used non-deterministic, random ids. This is the case for samples that did not use prompts or + conditioning for their generation. This function will sort these samples by their id and match them + by their index. + + Args: + xps: a list of XPs to match samples from. + start_epoch (int): If provided, only return samples corresponding to this epoch or newer. + end_epoch (int): If provided, only return samples corresponding to this epoch or older. + exclude_prompted (bool): If True, does not include samples that used a prompt. + exclude_unprompted (bool): If True, does not include samples that did not use a prompt. + exclude_conditioned (bool): If True, excludes samples that used conditioning. + exclude_unconditioned (bool): If True, excludes samples that did not use conditioning. + """ + managers = [SampleManager(xp) for xp in xps] + samples_per_xp = [manager.get_samples(**kwargs) for manager in managers] + stable_samples = _match_stable_samples(samples_per_xp) + unstable_samples = _match_unstable_samples(samples_per_xp) + return dict(stable_samples, **unstable_samples) diff --git a/audiocraft/utils/utils.py b/audiocraft/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5799f8bc4ee07dd8d60d6afe67fbc5a6039215 --- /dev/null +++ b/audiocraft/utils/utils.py @@ -0,0 +1,298 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
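A small sketch of the naming scheme described above (illustrative, not part of the patch): prompted or conditioned samples get a deterministic sha1-based id with a human-readable suffix built from slugify, while fully unconditioned samples fall back to random noinput_* ids that get_samples_for_xps later matches across XPs by sort order. Only slugify is exercised directly here; SampleManager itself needs a dora XP, so its use is shown as comments.

    from audiocraft.utils.samples.manager import SampleManager, slugify

    print(slugify('Départ!! 90s Hip Hop, "Lo-Fi"'))   # depart-90s-hip-hop-lo-fi

    # Inside a solver (sketch; xp is the current dora.XP, tensors are batched audio):
    # manager = SampleManager(xp, map_reference_to_sample_id=True)
    # manager.add_samples(gen_wavs, epoch=epoch, conditioning=attrs,
    #                     prompt_wavs=prompts, ground_truth_wavs=refs)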
+ +from concurrent.futures import ProcessPoolExecutor +from contextlib import contextmanager +from functools import wraps, lru_cache +import hashlib +import json +import logging +from pathlib import Path +import typing as tp + +import flashy +import flashy.distrib +import omegaconf +import torch +from torch.nn.utils.rnn import pad_sequence + + +logger = logging.getLogger(__name__) + + +def model_hash(model: torch.nn.Module) -> str: + """Return a model hash. This should allow us to track regressions in model init + from the logs of past experiments. + """ + hasher = hashlib.sha1() + for p in model.parameters(): + hasher.update(p.data.cpu().numpy().tobytes()) + return hasher.hexdigest() + + +def dict_from_config(cfg: omegaconf.DictConfig) -> dict: + """Convenience function to map an omegaconf configuration to a dictionary. + + Args: + cfg (omegaconf.DictConfig): Original configuration to map to dict. + Returns: + dict: Config as dictionary object. + """ + dct = omegaconf.OmegaConf.to_container(cfg, resolve=True) + assert isinstance(dct, dict) + return dct + + +def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset: + if max_samples >= len(dataset): + return dataset + + generator = torch.Generator().manual_seed(seed) + perm = torch.randperm(len(dataset), generator=generator) + return torch.utils.data.Subset(dataset, perm[:max_samples].tolist()) + + +def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int, + num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader: + """Convenience function to load dataset into a dataloader with optional subset sampling. + + Args: + dataset: Dataset to load. + num_samples (Optional[int]): Number of samples to limit subset size. + batch_size (int): Batch size. + num_workers (int): Number of workers for data loading. + seed (int): Random seed. + """ + if num_samples is not None: + dataset = random_subset(dataset, num_samples, seed) + + dataloader = flashy.distrib.loader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + **kwargs + ) + return dataloader + + +def get_dataset_from_loader(dataloader): + dataset = dataloader.dataset + if isinstance(dataset, torch.utils.data.Subset): + return dataset.dataset + else: + return dataset + + +def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): + """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. + + Args: + input (torch.Tensor): The input tensor containing probabilities. + num_samples (int): Number of samples to draw. + replacement (bool): Whether to draw with replacement or not. + Keywords args: + generator (torch.Generator): A pseudorandom number generator for sampling. + Returns: + torch.Tensor: Last dimension contains num_samples indices + sampled from the multinomial probability distribution + located in the last dimension of tensor input. + """ + input_ = input.reshape(-1, input.shape[-1]) + output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) + output = output_.reshape(*list(input.shape[:-1]), -1) + return output + + +def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: + """Sample next token from top K values along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + k (int): The k in “top-k”. + Returns: + torch.Tensor: Sampled tokens. 
+ """ + top_k_value, _ = torch.topk(probs, k, dim=-1) + min_value_top_k = top_k_value[..., [-1]] + probs *= (probs >= min_value_top_k).float() + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """Sample next token from top P probabilities along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + p (int): The p in “top-p”. + Returns: + torch.Tensor: Sampled tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort *= (~mask).float() + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + + +class DummyPoolExecutor: + """Dummy pool executor to use when we actually have only 1 worker. + (e.g. instead of ProcessPoolExecutor). + """ + class DummyResult: + def __init__(self, func, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + def result(self): + return self.func(*self.args, **self.kwargs) + + def __init__(self, workers, mp_context=None): + pass + + def submit(self, func, *args, **kwargs): + return DummyPoolExecutor.DummyResult(func, *args, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return + + +def get_pool_executor(num_workers: int, mp_context=None): + return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1) + + +def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor: + """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences). + For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]] + + Args: + lengths (torch.Tensor): tensor with lengths + max_len (int): can set the max length manually. Defaults to None. + Returns: + torch.Tensor: mask with 0s where there is pad tokens else 1s + """ + assert len(lengths.shape) == 1, "Length shape should be 1 dimensional." + final_length = lengths.max().item() if not max_len else max_len + final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor + return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None] + + +def hash_trick(word: str, vocab_size: int) -> int: + """Hash trick to pair each word with an index + + Args: + word (str): word we wish to convert to an index + vocab_size (int): size of the vocabulary + Returns: + int: index of the word in the embedding LUT + """ + hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16) + return hash % vocab_size + + +def with_rank_rng(base_seed: int = 1234): + """Decorator for a function so that the function will use a Random Number Generator + whose state depend on the GPU rank. The original RNG state is restored upon returning. + + Args: + base_seed (int): Random seed. 
+    """
+    def _decorator(fun: tp.Callable):
+        @wraps(fun)
+        def _decorated(*args, **kwargs):
+            state = torch.get_rng_state()
+            seed = base_seed ^ flashy.distrib.rank()
+            torch.manual_seed(seed)
+            logger.debug('Rank dependent seed set to %d', seed)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                torch.set_rng_state(state)
+                logger.debug('RNG state restored.')
+        return _decorated
+    return _decorator
+
+
+def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Get a list of tensors and collate them into a single tensor, according to the following logic:
+    - `dim` specifies the time dimension which will be stacked and padded.
+    - The output will contain 1 new dimension (dimension index 0) which will be the size
+    of the original list.
+
+    Args:
+        tensors (tp.List[torch.Tensor]): List of tensors to collate.
+        dim (int): Dimension which will be stacked and padded.
+    Returns:
+        tp.Tuple[torch.Tensor, torch.Tensor]:
+            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
+                (dimension index 0) which will be the size of the original list.
+            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+    """
+    tensors = [x.transpose(0, dim) for x in tensors]
+    lens = torch.LongTensor([len(x) for x in tensors])
+    padded_tensors = pad_sequence(tensors)
+    padded_tensors = padded_tensors.transpose(0, 1)
+    padded_tensors = padded_tensors.transpose(1, dim + 1)
+    return padded_tensors, lens
+
+
+# TODO: Move to flashy?
+def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
+               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
+    # Recursively copy a (possibly nested) collection of tensors to `device`,
+    # casting floating point tensors to `dtype` when provided.
+    if isinstance(state, torch.Tensor):
+        if dtype is None or not state.is_floating_point():
+            dtype = state.dtype
+        return state.detach().to(device=device, dtype=dtype, copy=True)
+    elif isinstance(state, dict):
+        return {k: copy_state(v, device, dtype) for k, v in state.items()}
+    elif isinstance(state, list):
+        return [copy_state(v, device, dtype) for v in state]
+
+
+# TODO: Move to flashy?
+@contextmanager
+def swap_state(model, state, **kwargs):
+    # Temporarily load `state` into `model`; the original state dict is restored on exit.
+    old_state = copy_state(model.state_dict())
+    model.load_state_dict(state, **kwargs)
+    try:
+        yield
+    finally:
+        model.load_state_dict(old_state)
+
+
+@lru_cache(None)
+def warn_once(logger, msg):
+    """Warn about a given message only once."""
+    logger.warning(msg)
+
+
+def is_jsonable(x: tp.Any):
+    """Check if an object can be serialized into JSON."""
+    try:
+        json.dumps(x)
+        return True
+    except (TypeError, OverflowError):
+        return False
+
+
+def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
+    """Wrapper around state dict loading of CLAP model
+    addressing compatibility issues between CLAP and AudioCraft
+    HuggingFace transformer version.
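+
+    Usage sketch (the wrapper object and checkpoint path below are assumptions for illustration only):
+        clap = laion_clap.CLAP_Module(enable_fusion=False)
+        load_clap_state_dict(clap, Path('checkpoints/clap.pt'))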
+ See: https://github.com/LAION-AI/CLAP/issues/118 + """ + from clap_module.factory import load_state_dict # type: ignore + pkg = load_state_dict(path) + pkg.pop('text_branch.embeddings.position_ids', None) + clap_model.model.load_state_dict(pkg) diff --git a/msinference.py b/msinference.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b38131daa2bd30e53d514fa49ff59bef3ebe4a --- /dev/null +++ b/msinference.py @@ -0,0 +1,261 @@ +import torch +from cached_path import cached_path +import nltk +# nltk.download('punkt') +import random +random.seed(0) +import numpy as np +np.random.seed(0) +import time +import random +import yaml +import torch.nn.functional as F +import copy +import torchaudio +import librosa +from models import * + +from scipy.io.wavfile import write +from munch import Munch +from torch import nn +from nltk.tokenize import word_tokenize +from monotonic_align import mask_from_lens +from monotonic_align.core import maximum_path_c + +torch.manual_seed(0) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + + +# IPA Phonemizer: https://github.com/bootphon/phonemizer + +_pad = "$" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +dicts = {} +for i in range(len((symbols))): + dicts[symbols[i]] = i + +class TextCleaner: + def __init__(self, dummy=None): + self.word_index_dictionary = dicts + print(len(dicts)) + def __call__(self, text): + indexes = [] + for char in text: + try: + indexes.append(self.word_index_dictionary[char]) + except KeyError: + print('CLEAN', text) + return indexes + + + +textclenaer = TextCleaner() + + +to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=80, n_fft=2048, win_length=1200, hop_length=300) +mean, std = -4, 4 + +# START UTIL + + + + + + + +def recursive_munch(d): + if isinstance(d, dict): + return Munch((k, recursive_munch(v)) for k, v in d.items()) + elif isinstance(d, list): + return [recursive_munch(v) for v in d] + else: + return d + + + +# ======== UTILS ABOVE + +def length_to_mask(lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + +def preprocess(wave): + wave_tensor = torch.from_numpy(wave).float() + mel_tensor = to_mel(wave_tensor) + mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std + return mel_tensor + +def compute_style(path): + wave, sr = librosa.load(path, sr=24000) + audio, index = librosa.effects.trim(wave, top_db=30) + if sr != 24000: + audio = librosa.resample(audio, sr, 24000) + mel_tensor = preprocess(audio).to(device) + + with torch.no_grad(): + ref_s = model.style_encoder(mel_tensor.unsqueeze(1)) + ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1)) + + return torch.cat([ref_s, ref_p], dim=1) + +device = 'cpu' +if torch.cuda.is_available(): + device = 'cuda' +elif torch.backends.mps.is_available(): + # print("MPS would be available but cannot be used rn") + pass + # device = 'mps' + +import phonemizer +global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True) +# phonemizer = 
Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt'))) + + +config = yaml.safe_load(open(str('Utils/config.yml'))) + +# load pretrained ASR model +ASR_config = config.get('ASR_config', False) +ASR_path = config.get('ASR_path', False) +text_aligner = load_ASR_models(ASR_path, ASR_config) + +# load pretrained F0 model +F0_path = config.get('F0_path', False) +pitch_extractor = load_F0_models(F0_path) + +# load BERT model +from Utils.PLBERT.util import load_plbert +BERT_path = config.get('PLBERT_dir', False) +plbert = load_plbert(BERT_path) + +model_params = recursive_munch(config['model_params']) +model = build_model(model_params, text_aligner, pitch_extractor, plbert) +_ = [model[key].eval() for key in model] +_ = [model[key].to(device) for key in model] + +# params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu') +params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu') +params = params_whole['net'] + +for key in model: + if key in params: + print('%s loaded' % key) + try: + model[key].load_state_dict(params[key]) + except: + from collections import OrderedDict + state_dict = params[key] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + # load params + model[key].load_state_dict(new_state_dict, strict=False) +# except: +# _load(params[key], model[key]) +_ = [model[key].eval() for key in model] + +from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule + +sampler = DiffusionSampler( + model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters + clamp=False +) + +def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False): + text = text.strip() + ps = global_phonemizer.phonemize([text]) + # print(f'PHONEMIZER: {ps=}\n\n') #PHONEMIZER: ps=['ɐbˈɛbæbləm '] + ps = word_tokenize(ps[0]) + # print(f'TOKENIZER: {ps=}\n\n') #OKENIZER: ps=['ɐbˈɛbæbləm'] + ps = ' '.join(ps) + tokens = textclenaer(ps) + # print(f'TEXTCLEAN: {ps=}\n\n') #TEXTCLEAN: ps='ɐbˈɛbæbləm' + tokens.insert(0, 0) + tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) + print(f'TOKENSFINAL: {ps=}\n\n') + + with torch.no_grad(): + input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) + text_mask = length_to_mask(input_lengths).to(device) + # ----------------------- + # WHO TRANSLATES these tokens to sylla + # print(text_mask.shape, '\n__\n', tokens, '\n__\n', text_mask.min(), text_mask.max()) + # text_mask=is binary + # tokes = tensor([[ 0, 55, 157, 86, 125, 83, 55, 156, 57, 158, 123, 48, 83, 61, + # 157, 102, 61, 16, 138, 64, 16, 53, 156, 138, 54, 62, 131, 85, + # 123, 83, 54, 16, 50, 156, 86, 123, 102, 125, 102, 46, 147, 16, + # 62, 135, 16, 76, 158, 92, 55, 156, 86, 56, 62, 177, 46, 16, + # 50, 157, 43, 102, 58, 85, 55, 156, 51, 158, 46, 51, 158, 83, + # 16, 48, 76, 158, 123, 16, 72, 53, 61, 157, 86, 61, 83, 44, + # 156, 102, 54, 177, 125, 51, 16, 72, 56, 46, 16, 102, 112, 53, + # 54, 156, 63, 158, 147, 83, 56, 16, 4]], device='cuda:0') + + + t_en = model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = model.bert(tokens, attention_mask=(~text_mask).int()) + d_en = model.bert_encoder(bert_dur).transpose(-1, -2) + # print('BERTdu', bert_dur.shape, tokens.shape, 
'\n') # bert what is the 768 per token -> IS USED in sampler + # BERTdu torch.Size([1, 11, 768]) torch.Size([1, 11]) + + s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device), + embedding=bert_dur, + embedding_scale=embedding_scale, + features=ref_s, # reference from the same speaker as the embedding + num_steps=diffusion_steps).squeeze(1) + + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + ref = alpha * ref + (1 - alpha) * ref_s[:, :128] + s = beta * s + (1 - beta) * ref_s[:, 128:] + + d = model.predictor.text_encoder(d_en, + s, input_lengths, text_mask) + + x, _ = model.predictor.lstm(d) + duration = model.predictor.duration_proj(x) + + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(en) + asr_new[:, :, 0] = en[:, :, 0] + asr_new[:, :, 1:] = en[:, :, 0:-1] + en = asr_new + + F0_pred, N_pred = model.predictor.F0Ntrain(en, s) + + asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device)) + if model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(asr) + asr_new[:, :, 0] = asr[:, :, 0] + asr_new[:, :, 1:] = asr[:, :, 0:-1] + asr = asr_new + + out = model.decoder(asr, + F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + + return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later \ No newline at end of file diff --git a/text_utils.py b/text_utils.py index 5ed5f96f98cd4edcb600598f5c80dbdc8566c0bf..acf11856178d749f5d577ef405b0bc82c0660680 100644 --- a/text_utils.py +++ b/text_utils.py @@ -2,6 +2,7 @@ import numpy as np import re import codecs +import textwrap # IPA Phonemizer: https://github.com/bootphon/phonemizer _pad = "$" @@ -82,6 +83,10 @@ def split_into_sentences(text): text = text.replace("",".") sentences = text.split("") sentences = [s.strip() for s in sentences] + + # Split Very long sentences >500 phoneme - StyleTTS2 crashes + sentences = [sub_sent+' ' for s in sentences for sub_sent in textwrap.wrap(s, 400, break_long_words=0)] + if sentences and not sentences[-1]: sentences = sentences[:-1] return sentences diff --git a/tts.py b/tts.py new file mode 100644 index 0000000000000000000000000000000000000000..b81e4db4f81f230785d697227bc5f101990717d1 --- /dev/null +++ b/tts.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +import numpy as np +import argparse +import os +import requests + +# SSH AGENT +# eval $(ssh-agent -s) +# ssh-add ~/.ssh/id_ed25519_github2024 +# +# git remote set-url origin git@github.com:audeering/shift +# == + + + + + +def command_line_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + '--affective', + help="Select Emotional or non-emotional variant of Available voices: https://audeering.github.io/shift/", + action='store_false', + ) + parser.add_argument( + '--device', + help="Device ID", + type=str, + default='cpu', + ) + parser.add_argument( + '--text', + help="Text to be synthesized.", + default='sample.txt', + type=str, + ) + parser.add_argument( + '--native', + help=""" + --native: (without argument) a flag to do voice cloning using the speech from --video, + --native 
my_voice.wav: Voice cloning from user-provided audio""",
+        # nargs='?',
+        # const=None,
+        # default=False  # default has to be none
+    )
+    parser.add_argument(
+        '--voice',
+        help="TTS voice - Available voices: https://audeering.github.io/shift/",
+        default="en_US/m-ailabs_low#judy_bieber",  #'en_US/cmu-arctic_low#lnh',
+        type=str,
+    )
+    parser.add_argument(
+        '--image',
+        help="If provided, it is set as the background of the output video; see --text",
+        type=str,
+    )
+    parser.add_argument(
+        '--video',
+        help="Video file for video translation. Voice cloned from the video",
+        type=str,
+    )
+    parser.add_argument(
+        '--out_file',
+        help="Output file name.",
+        type=str,
+        default='out'
+    )
+    parser.add_argument(
+        '--scene',
+        help='Sound scene description.',
+        type=str,
+        default='calm background sounds of a castle'
+    )
+    return parser
+
+def send_to_server(args):
+    url = "http://192.168.88.209:5000"
+
+    payload = {
+        'affective': args.affective,
+        'voice': args.voice,
+        'native': args.native,
+        'text': args.text,
+        'image': args.image,
+        'video': args.video,
+        'scene': args.scene,
+        'out_file': args.out_file
+    }
+
+    # data= carries the plain string arguments
+
+    # files= carries the actual files, if provided
+    text_file = open(args.text, 'rb')
+
+    image_file, video_file, native_file = None, None, None
+    if args.image is not None:
+        print('\nLOADING IMAGE\n')
+        try:
+            image_file = open(args.image, 'rb')
+        except FileNotFoundError:
+            pass
+
+
+    if args.video is not None:
+        print('\nLOADING VIDEO\n')
+        try:
+            video_file = open(args.video, 'rb')
+        except FileNotFoundError:
+            pass
+
+    if args.native is not None:
+        print('\nLOADING NATIVE\n')
+        try:
+            native_file = open(args.native, 'rb')
+        except FileNotFoundError:
+            pass
+
+
+
+    # --------------------- send this extra
+
+    print('Sending...\n')
+
+    response = requests.post(url, data=payload,
+                             files=[(args.text, text_file),
+                                    (args.image, image_file),
+                                    (args.video, video_file),
+                                    (args.native, native_file)])  # None entries never reach the server's files dict
+
+    # Check the response from the server
+    if response.status_code == 200:
+        print("\nRequest was successful!")
+        # print("Response:", response.__dict__.keys(), '\n=====\n')
+
+    else:
+        print("Failed to send the request")
+        print("Status Code:", response.status_code)
+        print("Response:", response.text)
+
+    # Close any file handles we opened above
+    for _f in (text_file, image_file, video_file, native_file):
+        if _f is not None:
+            _f.close()
+    return response
+
+
+def cli():
+    parser = command_line_args()
+    args = parser.parse_args()
+    response = send_to_server(args)
+
+    with open(
+        args.out_file + '.' + response.headers['suffix-file-type'].split('.')[-1],
+        'wb'
+    ) as f:
+        f.write(response.content)
+    print('Response at client\n----------------------------', response.headers)
+
+
+if __name__ == '__main__':
+    cli()
+
+# TODO: for video + text inputs we also need wrapper classes for video in audiocraft;
+# tts.py is then called on that video with non-empty labels, which triggers audiocraft.
\ No newline at end of file
diff --git a/utils.py b/utils.py
deleted file mode 100644
index c2206d9277879b5abbd6d29be86eb2f181a8c1db..0000000000000000000000000000000000000000
--- a/utils.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from monotonic_align import maximum_path
-from monotonic_align import mask_from_lens
-from monotonic_align.core import maximum_path_c
-import numpy as np
-import torch
-import copy
-from torch import nn
-import torch.nn.functional as F
-import torchaudio
-import librosa
-import matplotlib.pyplot as plt
-from munch import Munch
-
-def maximum_path(neg_cent, mask):
-    """ Cython optimized version.
- neg_cent: [b, t_t, t_s] - mask: [b, t_t, t_s] - """ - device = neg_cent.device - dtype = neg_cent.dtype - neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32)) - path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32)) - - t_t_max = np.ascontiguousarray(mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)) - t_s_max = np.ascontiguousarray(mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)) - maximum_path_c(path, neg_cent, t_t_max, t_s_max) - return torch.from_numpy(path).to(device=device, dtype=dtype) - -def get_data_path_list(train_path=None, val_path=None): - if train_path is None: - train_path = "Data/train_list.txt" - if val_path is None: - val_path = "Data/val_list.txt" - - with open(train_path, 'r', encoding='utf-8', errors='ignore') as f: - train_list = f.readlines() - with open(val_path, 'r', encoding='utf-8', errors='ignore') as f: - val_list = f.readlines() - - return train_list, val_list - -def length_to_mask(lengths): - mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) - mask = torch.gt(mask+1, lengths.unsqueeze(1)) - return mask - -# for norm consistency loss -def log_norm(x, mean=-4, std=4, dim=2): - """ - normalized log mel -> mel -> norm -> log(norm) - """ - x = torch.log(torch.exp(x * std + mean).norm(dim=dim)) - return x - -def get_image(arrs): - plt.switch_backend('agg') - fig = plt.figure() - ax = plt.gca() - ax.imshow(arrs) - - return fig - -def recursive_munch(d): - if isinstance(d, dict): - return Munch((k, recursive_munch(v)) for k, v in d.items()) - elif isinstance(d, list): - return [recursive_munch(v) for v in d] - else: - return d - -def log_print(message, logger): - logger.info(message) - print(message) - \ No newline at end of file