import os
import tempfile

import gradio as gr
import torch
from gradio_client import Client, handle_file

from tc5.config import SAMPLE_RATE, HOP_LENGTH
from tc5.model import TaikoConformer5
from tc5 import infer as tc5infer
from tc6.model import TaikoConformer6
from tc6 import infer as tc6infer
from tc7.model import TaikoConformer7
from tc7 import infer as tc7infer
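
# Note: tc5, tc6, and tc7 are assumed to be local model packages bundled
# alongside this app in the Space repository (not PyPI distributions).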

DEVICE = torch.device("cpu")

# Load TC5 model once at startup
tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
tc5.to(DEVICE)
tc5.eval()

# Load TC6 model
tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
tc6.to(DEVICE)
tc6.eval()

# Load TC7 model
tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
tc7.to(DEVICE)
tc7.eval()
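
# Chart-to-audio synthesis is delegated to the public
# "ryanlinjui/taiko-music-generator" Space via gradio_client: given a TJA
# chart and the source audio, its /handle endpoint returns the rendered
# gameplay audio among its outputs (see result[1] below).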
synthesizer = Client("ryanlinjui/taiko-music-generator")


def infer_tc5(audio, nps, bpm):
    audio_path = audio
    filename = os.path.basename(audio_path)
    # Preprocess
    mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
        tc5, mel_input, nps_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE  # seconds per output frame
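    # decode_onsets peak-picks the three energy curves; the semantics below
    # are assumptions from its signature: energies above `threshold` become
    # candidate onsets, and detections closer together than
    # `min_distance_frames` frames are suppressed.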
    onsets = tc5infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc5infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc5infer.write_tja(onsets, bpm=bpm, audio=filename)
    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name
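    # The numeric arguments below mirror the synthesizer Space's UI defaults;
    # param_2 selects the 達人譜面 / Master chart, while the exact meaning of
    # param_3..param_15 is defined by that Space's /handle API (assumption).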
    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content


def infer_tc6(audio, nps, bpm, difficulty, level):
    audio_path = audio
    filename = os.path.basename(audio_path)
    # Preprocess
    mel_input = tc6infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
        tc6, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc6infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc6infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc6infer.write_tja(onsets, bpm=bpm, audio=filename)
    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name
    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content


def infer_tc7(audio, nps, bpm, difficulty, level):
    audio_path = audio
    filename = os.path.basename(audio_path)
    # Preprocess
    mel_input = tc7infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
        tc7, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc7infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc7infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc7infer.write_tja(onsets, bpm=bpm, audio=filename)
    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name
    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content


def run_inference(audio, model_choice, nps, bpm, difficulty, level):
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm)
    elif model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, difficulty, level)
    else:  # TC7
        return infer_tc7(audio, nps, bpm, difficulty, level)
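
# Note: TC5 is not conditioned on difficulty/level; run_inference simply
# drops those inputs when TC5 is selected (the corresponding sliders are
# also hidden in the UI below).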


with gr.Blocks() as demo:
    gr.Markdown("# Taiko Conformer 5/6/7 Demo")
    with gr.Row():
        audio_input = gr.Audio(sources=["upload"], type="filepath", label="Input Audio")
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["TC5", "TC6", "TC7"],
            value="TC7",
            label="Model Selection",
            info="Choose between TaikoConformer 5, 6, or 7",
        )
    with gr.Row():
        nps = gr.Slider(
            value=5.0,
            minimum=0.5,
            maximum=11.0,
            step=0.5,
            label="NPS (Notes Per Second)",
        )
        bpm = gr.Slider(
            value=240,
            minimum=160,
            maximum=640,
            step=1,
            label="BPM (Used by TJA Quantization)",
        )
    with gr.Row():
        difficulty = gr.Slider(
            value=3.0,
            minimum=1.0,
            maximum=3.0,
            step=1.0,
            label="Difficulty",
            visible=False,
            info="1=Normal, 2=Hard, 3=Oni",
        )
        level = gr.Slider(
            value=8.0,
            minimum=1.0,
            maximum=10.0,
            step=1.0,
            label="Level",
            visible=False,
            info="Difficulty level from 1 to 10",
        )
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    plot_output = gr.Plot(label="Onset/Energy Plot")
    tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)
    run_btn = gr.Button("Run Inference")

    # Show the difficulty/level controls only for the models that use them
    # (TC6 and TC7)
    def update_visibility(model_choice):
        if model_choice in ("TC6", "TC7"):
            return gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    model_choice.change(
        update_visibility, inputs=[model_choice], outputs=[difficulty, level]
    )
    run_btn.click(
        run_inference,
        inputs=[audio_input, model_choice, nps, bpm, difficulty, level],
        outputs=[audio_output, plot_output, tja_output],
    )


if __name__ == "__main__":
    demo.launch()
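    # When running outside Spaces, demo.launch(share=True) (a standard Gradio
    # option) could be used to expose a temporary public URL.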