mimbres commited on
Commit
7888f4e
1 Parent(s): dd13312
Files changed (5) hide show
  1. .gitignore +2 -0
  2. app.py +56 -0
  3. gradio_helper.py +80 -0
  4. html_helper.py +100 -0
  5. model_helper.py +161 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ amt/
2
+ examples/
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import gradio as gr
3
+
4
+ from gradio_helper import *
5
+
6
+ AUDIO_EXAMPLES = glob.glob('/content/examples/*.*', recursive=True)
7
+ YOUTUBE_EXAMPLES = ["https://www.youtube.com/watch?v=vMboypSkj3c"]
8
+
9
+ theme = 'gradio/dracula_revamped' #'Insuz/Mocha' #gr.themes.Soft()
10
+ with gr.Blocks(theme=theme) as demo:
11
+
12
+ with gr.Row():
13
+ with gr.Column(scale=10):
14
+ gr.Markdown(
15
+ """
16
+ # YourMT3+: Bridging the Gap in Multi-instrument Music Transcription with Advanced Model Architectures and Cross-dataset Stem Augmentation
17
+ """)
18
+
19
+ with gr.Group():
20
+ with gr.Tab("Upload audio"):
21
+ # Input
22
+ audio_input = gr.Audio(label="Record Audio", type="filepath",
23
+ show_share_button=True, show_download_button=True)
24
+ # Display examples
25
+ gr.Examples(examples=AUDIO_EXAMPLES, inputs=audio_input)
26
+ # Submit button
27
+ transcribe_audio_button = gr.Button("Transcribe", variant="primary")
28
+ # Transcribe
29
+ output_tab1 = gr.HTML()
30
+ # audio_output = gr.Text(label="Audio Info")
31
+ # transcribe_audio_button.click(process_audio, inputs=audio_input, outputs=output_tab1)
32
+ transcribe_audio_button.click(process_audio, inputs=audio_input, outputs=output_tab1)
33
+
34
+ with gr.Tab("From YouTube"):
35
+ with gr.Row():
36
+ # Input URL
37
+ youtube_url = gr.Textbox(label="YouTube Link URL",
38
+ placeholder="https://youtu.be/...")
39
+ # Play youtube
40
+ youtube_player = gr.HTML(render=True)
41
+ with gr.Row():
42
+ # Play button
43
+ play_video_button = gr.Button("Play", variant="primary")
44
+ # Submit button
45
+ transcribe_video_button = gr.Button("Transcribe", variant="primary")
46
+ # Transcribe
47
+ output_tab2 = gr.HTML(render=True)
48
+ # video_output = gr.Text(label="Video Info")
49
+ transcribe_video_button.click(process_video, inputs=youtube_url, outputs=output_tab2)
50
+ # Play
51
+ play_video_button.click(play_video, inputs=youtube_url, outputs=youtube_player)
52
+
53
+ # Display examples
54
+ gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_url)
55
+
56
+ demo.launch(debug=True)
gradio_helper.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @title GradIO helper
2
+ import os
3
+ import subprocess
4
+ import glob
5
+ from typing import Tuple, Dict, Literal
6
+ from ctypes import ArgumentError
7
+ # from google.colab import output
8
+
9
+ from model_helper import *
10
+ from html_helper import *
11
+
12
+ from pytube import YouTube
13
+ import gradio as gr
14
+ import torchaudio
15
+
16
+ def prepare_media(source_path_or_url: os.PathLike,
17
+ source_type: Literal['audio_filepath', 'youtube_url'],
18
+ delete_video: bool = True) -> Dict:
19
+ """prepare media from source path or youtube, and return audio info"""
20
+ # Get audio_file
21
+ if source_type == 'audio_filepath':
22
+ audio_file = source_path_or_url
23
+ elif source_type == 'youtube_url':
24
+ # Download from youtube
25
+ try:
26
+ # Try PyTube first
27
+ yt = YouTube(source_path_or_url)
28
+ audio_stream = min(yt.streams.filter(only_audio=True), key=lambda s: s.bitrate)
29
+ mp4_file = audio_stream.download(output_path='downloaded') # ./downloaded
30
+ audio_file = mp4_file[:-3] + 'mp3'
31
+ subprocess.run(['ffmpeg', '-i', mp4_file, '-ac', '1', audio_file])
32
+ os.remove(mp4_file)
33
+ except Exception as e:
34
+ try:
35
+ # Try alternative
36
+ print(f"Failed with PyTube, error: {e}. Trying yt-dlp...")
37
+ audio_file = './downloaded/yt_audio'
38
+ subprocess.run(['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
39
+ '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
40
+ '--force-overwrites'])
41
+ audio_file += '.mp3'
42
+ except Exception as e:
43
+ print(f"Alternative downloader failed, error: {e}. Please try again later!")
44
+ return None
45
+ else:
46
+ raise ValueError(source_type)
47
+
48
+ # Create info
49
+ info = torchaudio.info(audio_file)
50
+ return {
51
+ "filepath": audio_file,
52
+ "track_name": os.path.basename(audio_file).split('.')[0],
53
+ "sample_rate": int(info.sample_rate),
54
+ "bits_per_sample": int(info.bits_per_sample),
55
+ "num_channels": int(info.num_channels),
56
+ "num_frames": int(info.num_frames),
57
+ "duration": int(info.num_frames / info.sample_rate),
58
+ "encoding": str.lower(info.encoding),
59
+ }
60
+
61
+ def process_audio(audio_filepath):
62
+ if audio_filepath is None:
63
+ return None
64
+ audio_info = prepare_media(audio_filepath, source_type='audio_filepath')
65
+ midifile = transcribe(model, audio_info)
66
+ midifile = to_data_url(midifile)
67
+ return create_html_from_midi(midifile) # html midiplayer
68
+
69
+ def process_video(youtube_url):
70
+ if 'youtu' not in youtube_url:
71
+ return None
72
+ audio_info = prepare_media(youtube_url, source_type='youtube_url')
73
+ midifile = transcribe(model, audio_info)
74
+ midifile = to_data_url(midifile)
75
+ return create_html_from_midi(midifile) # html midiplayer
76
+
77
+ def play_video(youtube_url):
78
+ if 'youtu' not in youtube_url:
79
+ return None
80
+ return create_html_youtube_player(youtube_url)
html_helper.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @title HTML helper
2
+ import re
3
+ import base64
4
+ def to_data_url(midi_filename):
5
+ """ This is crucial for Colab/WandB support. Thanks to Scott Hawley!!
6
+ https://github.com/drscotthawley/midi-player/blob/main/midi_player/midi_player.py
7
+
8
+ """
9
+ with open(midi_filename, "rb") as f:
10
+ encoded_string = base64.b64encode(f.read())
11
+ return 'data:audio/midi;base64,'+encoded_string.decode('utf-8')
12
+
13
+
14
+ def to_youtube_embed_url(video_url):
15
+ regex = r"(?:https:\/\/)?(?:www\.)?(?:youtube\.com|youtu\.be)\/(?:watch\?v=)?(.+)"
16
+ return re.sub(regex, r"https://www.youtube.com/embed/\1",video_url)
17
+
18
+
19
+ def create_html_from_midi(midifile):
20
+ html_template = """
21
+ <!DOCTYPE html>
22
+ <html>
23
+ <head>
24
+ <title>Awesome MIDI Player</title>
25
+ <script src="https://cdn.jsdelivr.net/combine/npm/[email protected],npm/@magenta/[email protected]/es6/core.js,npm/focus-visible@5,npm/[email protected]">
26
+ </script>
27
+ <style>
28
+ /* Background color for the section */
29
+ #proll {{background-color:transparent}}
30
+
31
+ /* Custom player style */
32
+ #proll midi-player {{
33
+ display: block;
34
+ width: inherit;
35
+ margin: 4px;
36
+ margin-bottom: 0;
37
+ }}
38
+
39
+ #proll midi-player::part(control-panel) {{
40
+ background: #D8DAE8;
41
+ border-radius: 8px 8px 0 0;
42
+ border: 1px solid #A0A0A0;
43
+ }}
44
+
45
+ /* Custom visualizer style */
46
+ #proll midi-visualizer .piano-roll-visualizer {{
47
+ background: #F7FAFA;
48
+ border-radius: 0 0 8px 8px;
49
+ border: 1px solid #A0A0A0;
50
+ margin: 4px;
51
+ margin-top: 2;
52
+ overflow: auto;
53
+ }}
54
+
55
+ #proll midi-visualizer svg rect.note {{
56
+ opacity: 0.6;
57
+ stroke-width: 2;
58
+ }}
59
+
60
+ #proll midi-visualizer svg rect.note[data-instrument="0"] {{
61
+ fill: #e22;
62
+ stroke: #055;
63
+ }}
64
+
65
+ #proll midi-visualizer svg rect.note[data-instrument="2"] {{
66
+ fill: #2ee;
67
+ stroke: #055;
68
+ }}
69
+
70
+ #proll midi-visualizer svg rect.note[data-is-drum="true"] {{
71
+ fill: #888;
72
+ stroke: #888;
73
+ }}
74
+
75
+ #proll midi-visualizer svg rect.note.active {{
76
+ opacity: 0.9;
77
+ stroke: #34384F;
78
+ }}
79
+ </style>
80
+ </head>
81
+ <body>
82
+ <div>
83
+ <a href="{midifile}" target="_blank">Download MIDI</a> <br>
84
+ <section id="proll">
85
+ <midi-player src="{midifile}" sound-font="https://storage.googleapis.com/magentadata/js/soundfonts/sgm_plus" visualizer="#proll midi-visualizer">
86
+ </midi-player>
87
+ <midi-visualizer src="{midifile}">
88
+ </midi-visualizer>
89
+ </section>
90
+ </div>
91
+ </body>
92
+ </html>
93
+ """.format(midifile=midifile)
94
+ html = f"""<iframe style="width: 100%; height: 400px; overflow:auto" srcdoc='{html_template}'></iframe>"""
95
+ return html
96
+
97
+ def create_html_youtube_player(youtube_url):
98
+ youtube_url = to_youtube_embed_url(youtube_url)
99
+ html = f"""<iframe width="560" height="315" src='{youtube_url}' title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>"""
100
+ return html
model_helper.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @title Model helper
2
+ %cd /content/amt/src
3
+ from collections import Counter
4
+ import argparse
5
+ import torch
6
+ import numpy as np
7
+
8
+ from model.init_train import initialize_trainer, update_config
9
+ from utils.task_manager import TaskManager
10
+ from config.vocabulary import drum_vocab_presets
11
+ from utils.utils import str2bool
12
+ from utils.utils import Timer
13
+ from utils.audio import slice_padded_array
14
+ from utils.note2event import mix_notes
15
+ from utils.event2note import merge_zipped_note_events_and_ties_to_notes
16
+ from utils.utils import write_model_output_as_midi, write_err_cnt_as_json
17
+ from model.ymt3 import YourMT3
18
+
19
+
20
+ def load_model_checkpoint(args=None):
21
+ parser = argparse.ArgumentParser(description="YourMT3")
22
+ # General
23
+ parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
24
+ parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
25
+ parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
26
+ parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300} 128 for MT3, 300 for PerceiverTFIf None, default value defined in config.py will be used.')
27
+ parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
28
+ parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
29
+ # Model configurations
30
+ parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
31
+ parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
32
+ parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
33
+ parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
34
+ parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
35
+ parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
36
+ parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
37
+ parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
38
+ parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
39
+ parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False')
40
+ parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
41
+ parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
42
+ parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
43
+ parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
44
+ parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
45
+ # Perceiver-TF configurations
46
+ parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
47
+ parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
48
+ parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
49
+ parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
50
+ parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
51
+ parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
52
+ parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
53
+ parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
54
+ parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
55
+ parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
56
+ parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
57
+ parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
58
+ parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
59
+ parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
60
+ parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
61
+ parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
62
+ # Decoder configurations
63
+ parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
64
+ parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
65
+ # Task and Evaluation configurations
66
+ parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
67
+ parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
68
+ parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.')
69
+ parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.')
70
+ parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).')
71
+ parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False')
72
+ parser.add_argument('-w', '--write-model-output', type=str2bool, default=True, help='write model test output to file (default=False). True or False')
73
+ # Trainer configurations
74
+ parser.add_argument('-pr','--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
75
+ parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
76
+ parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
77
+ parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
78
+ parser.add_argument('-wb', '--wandb-mode', type=str, default="disabled", help='wandb mode for logging (default=None). "disabled" or "online" or "offline". If None, default value defined in config.py will be used.')
79
+ # Debug
80
+ parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False')
81
+ parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). semitone in int.')
82
+ args = parser.parse_args(args)
83
+ # yapf: enable
84
+ if torch.__version__ >= "1.13":
85
+ torch.set_float32_matmul_precision("high")
86
+ args.epochs = None
87
+
88
+ # Initialize and update config
89
+ _, _, dir_info, shared_cfg = initialize_trainer(args, stage='test')
90
+ shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test')
91
+
92
+ if args.eval_drum_vocab != None: # override eval_drum_vocab
93
+ eval_drum_vocab = drum_vocab_presets[args.eval_drum_vocab]
94
+
95
+ # Initialize task manager
96
+ tm = TaskManager(task_name=args.task,
97
+ max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]),
98
+ debug_mode=args.debug_mode)
99
+ print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
100
+
101
+ # Use GPU if available
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+
104
+ # Model
105
+ model = YourMT3(
106
+ audio_cfg=audio_cfg,
107
+ model_cfg=model_cfg,
108
+ shared_cfg=shared_cfg,
109
+ optimizer=None,
110
+ task_manager=tm, # tokenizer is a member of task_manager
111
+ eval_subtask_key=args.eval_subtask_key,
112
+ write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
113
+ ).to(device)
114
+ checkpoint = torch.load(dir_info["last_ckpt_path"])
115
+ state_dict = checkpoint['state_dict']
116
+ new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
117
+ model.load_state_dict(new_state_dict, strict=False)
118
+ return model.eval()
119
+
120
+
121
+ def transcribe(model, audio_info):
122
+ t = Timer()
123
+
124
+ # Converting Audio
125
+ t.start()
126
+ audio, sr = torchaudio.load(uri=audio_info['filepath'])
127
+ audio = torch.mean(audio, dim=0).unsqueeze(0)
128
+ audio = torchaudio.functional.resample(audio, sr, model.audio_cfg['sample_rate'])
129
+ audio_segments = slice_padded_array(audio, model.audio_cfg['input_frames'], model.audio_cfg['input_frames'])
130
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
131
+ audio_segments = torch.from_numpy(audio_segments.astype('float32')).to(device).unsqueeze(1) # (n_seg, 1, seg_sz)
132
+ t.stop(); t.print_elapsed_time("converting audio");
133
+
134
+ # Inference
135
+ t.start()
136
+ pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments)
137
+ t.stop(); t.print_elapsed_time("model inference");
138
+
139
+ # Post-processing
140
+ t.start()
141
+ num_channels = model.task_manager.num_decoding_channels
142
+ n_items = audio_segments.shape[0]
143
+ start_secs_file = [model.audio_cfg['input_frames'] * i / model.audio_cfg['sample_rate'] for i in range(n_items)]
144
+ pred_notes_in_file = []
145
+ n_err_cnt = Counter()
146
+ for ch in range(num_channels):
147
+ pred_token_arr_ch = [arr[:, ch, :] for arr in pred_token_arr] # (B, L)
148
+ zipped_note_events_and_tie, list_events, ne_err_cnt = model.task_manager.detokenize_list_batches(
149
+ pred_token_arr_ch, start_secs_file, return_events=True)
150
+ pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
151
+ pred_notes_in_file.append(pred_notes_ch)
152
+ n_err_cnt += n_err_cnt_ch
153
+ pred_notes = mix_notes(pred_notes_in_file) # This is the mixed notes from all channels
154
+
155
+ # Write MIDI
156
+ write_model_output_as_midi(pred_notes, '/content/',
157
+ audio_info['track_name'], model.midi_output_inverse_vocab)
158
+ t.stop(); t.print_elapsed_time("post processing");
159
+ midifile = os.path.join('/content/model_output/', audio_info['track_name'] + '.mid')
160
+ assert os.path.exists(midifile)
161
+ return midifile