|
import sys |
|
import os |
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src'))) |
|
|
|
|
|
import glob |
|
import gradio as gr |
|
|
|
from gradio_helper import * |
|
from model_helper import * |
|
|
|
|
|
# Model variant to serve; must be one of the keys of _MODEL_CONFIGS below.
model_name = 'YPTF.MoE+Multi (noPS)'
# Forwarded to the loader as the value of the '-pr' option.
precision = '16'
# Forwarded to the loader as the value of the '-p' option.
project = '2024'

# Option list shared verbatim by the two "YPTF.MoE+Multi" variants.
_MOE_MULTI_FLAGS = [
    '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
    '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
    '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
    '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1',
]

# model name -> (checkpoint id, model-specific options).
# The options are placed between the common "<checkpoint> -p <project>" prefix
# and the common "-pr <precision>" suffix when `args` is assembled below.
_MODEL_CONFIGS = {
    "YMT3+": (
        # NOTE(review): the previous literal here was "[email protected]", an
        # e-mail-obfuscation artifact rather than a usable checkpoint id.
        # Restored to what is believed to be the published YMT3+ checkpoint;
        # confirm against the checkpoint directory before selecting this model.
        "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt",
        [],
    ),
    "YPTF+Single (noPS)": (
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        ['-enc', 'perceiver-tf', '-ac', 'spec', '-hop', '300', '-atc', '1'],
    ),
    "YPTF+Multi (PS)": (
        # NOTE(review): this is the same checkpoint id as "YPTF.MoE+Multi (PS)"
        # but it is loaded without the MoE options -- confirm this is intended.
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        ['-tk', 'mc13_full_plus_256',
         '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
         '-ac', 'spec', '-hop', '300', '-atc', '1'],
    ),
    "YPTF.MoE+Multi (noPS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        _MOE_MULTI_FLAGS,
    ),
    "YPTF.MoE+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        _MOE_MULTI_FLAGS,
    ),
}

# An unknown model name is a configuration error; fail before loading anything.
if model_name not in _MODEL_CONFIGS:
    raise ValueError(model_name)

checkpoint, _model_flags = _MODEL_CONFIGS[model_name]

# Command-line style argument list handed to the checkpoint loader.
args = [checkpoint, '-p', project] + list(_model_flags) + ['-pr', precision]
|
|
|
# Build the model once, at start-up, from the argument list assembled above.
# `load_model_checkpoint` arrives through one of the star imports at the top
# of the file (presumably model_helper -- confirm).
model = load_model_checkpoint(args=args)
|
|
|
|
|
|
|
# Audio files offered as one-click examples in the "Upload audio" tab.
# The pattern contains no "**", so glob's `recursive` flag would have no
# effect and is omitted; only direct children of the directory whose name
# contains a dot are matched. glob returns entries in arbitrary order, so the
# result is sorted to keep the example list stable between runs.
AUDIO_EXAMPLES = sorted(glob.glob('/content/examples/*.*'))

# Video URLs offered as one-click examples in the "From YouTube" tab.
YOUTUBE_EXAMPLES = ["https://www.youtube.com/watch?v=vMboypSkj3c"]
|
|
|
# Theme identifier passed to gr.Blocks.
theme = 'gradio/dracula_revamped'

with gr.Blocks(theme=theme) as demo:

    # ---- Page header -------------------------------------------------------
    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown(
                "\n"
                "### YourMT3+: Multi-instrument Music Transcription with Enhanced "
                "Transformer Architectures and Cross-dataset Stem Augmentation\n"
            )

    with gr.Group():

        # ---- Tab 1: transcribe an uploaded / recorded audio file -----------
        with gr.Tab("Upload audio"):
            # type="filepath": the click handler is given a path string.
            audio_input = gr.Audio(
                label="Record Audio",
                type="filepath",
                show_share_button=True,
                show_download_button=True,
            )

            # Clicking an example fills `audio_input`.
            gr.Examples(examples=AUDIO_EXAMPLES, inputs=audio_input)

            transcribe_audio_button = gr.Button("Transcribe", variant="primary")

            # Receives whatever `process_audio` returns.
            output_tab1 = gr.HTML()

            transcribe_audio_button.click(
                process_audio,
                inputs=audio_input,
                outputs=output_tab1,
            )

        # ---- Tab 2: transcribe audio taken from a YouTube link -------------
        with gr.Tab("From YouTube"):
            with gr.Row():
                youtube_url = gr.Textbox(
                    label="YouTube Link URL",
                    placeholder="https://youtu.be/...",
                )
                # Receives whatever `play_video` returns.
                youtube_player = gr.HTML(render=True)

            with gr.Row():
                play_video_button = gr.Button("Play", variant="primary")
                transcribe_video_button = gr.Button("Transcribe", variant="primary")

            # Receives whatever `process_video` returns.
            output_tab2 = gr.HTML(render=True)

            transcribe_video_button.click(
                process_video,
                inputs=youtube_url,
                outputs=output_tab2,
            )
            play_video_button.click(
                play_video,
                inputs=youtube_url,
                outputs=youtube_player,
            )

            # Clicking an example fills `youtube_url`.
            gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_url)

# Start the web app; debug=True is passed through to Gradio's launch().
demo.launch(debug=True)
|
|