piano_trans / app.py
admin
upd req
9b709d7
raw
history blame
5.59 kB
import json
import os
import re
import shutil
from functools import lru_cache
from urllib.parse import urlparse

import torch
import requests
import gradio as gr
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from modelscope import snapshot_download

from convert import midi2xml, xml2abc, xml2mxl, xml2jpg
# Runs at import time: ModelScope's snapshot_download fetches the
# "Genius-Society/piano_trans" model into ./__pycache__ (presumably reusing an
# existing local copy) and returns its directory. WEIGHTS_PATH is the CRNN
# checkpoint file inside that directory.
_MODEL_DIR = snapshot_download("Genius-Society/piano_trans", cache_dir="./__pycache__")
WEIGHTS_PATH = _MODEL_DIR + "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
def clean_cache(cache_dir):
    """Reset *cache_dir* to an empty directory.

    Any existing directory (and everything in it) is deleted, then the
    directory is recreated. Missing parent directories are created too.

    Args:
        cache_dir: Path of the working directory to reset.
    """
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
    # makedirs (unlike mkdir) also creates missing parents instead of raising
    # FileNotFoundError; exist_ok tolerates the directory reappearing.
    os.makedirs(cache_dir, exist_ok=True)
def download_audio(url: str, save_path: str, timeout: float = 60):
    """Stream the resource at *url* into the file *save_path*.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: Destination file path; overwritten if it already exists.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes (not a cap on total download time). Without it a stalled
            server would block the caller indefinitely.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    # ``with`` releases the streamed connection even if raise_for_status or a
    # file write fails part-way through.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(save_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
def is_url(s: str) -> bool:
    """Return True if *s* parses as an absolute URL (non-empty scheme and host).

    Input that ``urlparse`` rejects (e.g. an unterminated IPv6 literal such as
    ``"http://[::1"``, or an object it cannot coerce) yields False rather than
    raising.
    """
    try:
        parts = urlparse(s)
    except (ValueError, TypeError, AttributeError):
        # Narrowed from a bare ``except:``, which also swallowed
        # KeyboardInterrupt and SystemExit.
        return False
    return bool(parts.scheme and parts.netloc)
@lru_cache(maxsize=None)
def _load_transcriptor(device: str, checkpoint_path: str):
    """Build a PianoTranscription model once per (device, checkpoint) pair.

    Constructing it loads the checkpoint at *checkpoint_path*. This previously
    happened on every request; the cached instance is now reused.
    NOTE(review): assumes transcribe() keeps no per-call state on the
    instance, so sharing it between requests is safe — verify if concurrency
    is raised.
    """
    return PianoTranscription(device=device, checkpoint_path=checkpoint_path)


def audio2midi(audio_path: str, cache_dir: str):
    """Transcribe a piano recording into a MIDI file.

    Args:
        audio_path: Path of the input audio file.
        cache_dir: Existing directory in which ``output.mid`` is written.

    Returns:
        ``(midi_path, title)``: the written MIDI path and the capitalized file
        stem of *audio_path*, used as the score title.
    """
    audio, _ = load_audio(audio_path, sr=sample_rate)
    transcriptor = _load_transcriptor(
        "cuda" if torch.cuda.is_available() else "cpu", WEIGHTS_PATH
    )
    midi_path = f"{cache_dir}/output.mid"
    transcriptor.transcribe(audio, midi_path)
    # splitext keeps dotted names whole ("my.song.mp3" -> "my.song") and
    # tolerates a missing extension. split(".")[-2] picked the wrong piece or
    # raised IndexError in those cases.
    title = os.path.splitext(os.path.basename(audio_path))[0]
    return midi_path, title.capitalize()
def upl_infer(audio_path: str, cache_dir="./__pycache__/mode1"):
    """Gradio handler for the "Uploading Mode" tab.

    Resets *cache_dir*, transcribes the uploaded audio to MIDI, then converts
    it through MusicXML into the other score formats shown in the UI.

    Returns:
        ``(midi, pdf, xml, mxl, abc, jpg)`` in the order of the tab's output
        components. If any step raises, every slot is None except the ABC
        textbox slot, which carries the error message.
    """
    clean_cache(cache_dir)
    try:
        midi_file, score_title = audio2midi(audio_path, cache_dir)
        xml_file = midi2xml(midi_file, score_title)
        abc = xml2abc(xml_file)
        mxl_file = xml2mxl(xml_file)
        pdf_file, jpg_file = xml2jpg(xml_file)
    except Exception as err:
        # UI boundary: show the failure in the ABC textbox instead of crashing.
        return None, None, None, None, f"{err}", None
    return midi_file, pdf_file, xml_file, mxl_file, abc, jpg_file
def get_1st_int(input_string: str):
    """Return the first run of digits in *input_string*, normalized through int().

    Leading zeros are dropped ("007" -> "7"). Returns "" when the string
    contains no digit.
    """
    found = re.search(r"\d+", input_string)
    return "" if found is None else str(int(found[0]))
def music163_song_info(id: str, timeout: float = 10):
    """Look up a NetEase Cloud Music song's title and whether it is free.

    Args:
        id: Numeric song ID as a string. (The name shadows the ``id`` builtin
            but is kept for keyword-argument compatibility.)
        timeout: Seconds ``requests`` waits to connect and between received
            bytes, so a stalled API call cannot block the handler forever.

    Returns:
        ``(song_name, free)``. *free* is True when the song's ``fee`` code is
        0 or 8. If the API response lists no song, returns
        ``("The song does not exist", False)``.

    Raises:
        ConnectionError: If the API responds with a non-200 status.
        requests.RequestException: On network errors or timeouts.
        ValueError / KeyError: If the response body is not the expected JSON.
    """
    detail_api = "https://music.163.com/api/v3/song/detail"
    # NOTE(review): "c" is the Python repr of the list (single quotes), not
    # JSON. Kept as-is so the request is unchanged; confirm what the endpoint
    # expects.
    parm_dict = {"id": id, "c": str([{"id": id}]), "csrf_token": ""}
    response = requests.get(detail_api, params=parm_dict, timeout=timeout)
    if response.status_code != 200:
        raise ConnectionError(f"Error: {response.status_code}, {response.text}")
    data = json.loads(response.text)
    if not (data and "songs" in data and data["songs"]):
        return "The song does not exist", False
    song = data["songs"][0]
    fee = int(song["fee"])
    # Fee codes 0 and 8 are treated as downloadable without VIP (presumably
    # NetEase's "free" / "free at standard quality" codes — verify).
    return str(song["name"]), fee in (0, 8)
def url_infer(song: str, cache_dir="./__pycache__/mode2"):
    """Gradio handler for the "Direct Link Mode" tab.

    Accepts a NetEase Cloud Music song link containing ``?id=`` or a bare
    numeric song ID. If the song is free, it is downloaded, transcribed to
    MIDI, and converted into the score formats shown in the UI.

    Args:
        song: Text from the input box; surrounding whitespace is ignored.
        cache_dir: Working directory, wiped and recreated on every call.

    Returns:
        ``(audio, midi, pdf, xml, mxl, abc, jpg)`` in the order of the tab's
        output components. If any step raises, every slot is None except the
        ABC textbox slot, which carries the error message.
    """
    clean_cache(cache_dir)
    audio_path = f"{cache_dir}/output.mp3"
    try:
        # Pasted links/IDs often carry stray whitespace, which made isdigit()
        # reject an otherwise valid numeric ID.
        song = (song or "").strip()
        is_163_link = is_url(song) and "163" in song and "?id=" in song
        if not (is_163_link or song.isdigit()):
            # Previously any other input skipped the download and failed
            # later, inside transcription, on a file that was never written.
            raise ValueError(
                "Unsupported input: please enter a NetEase Cloud Music song "
                "link (containing ?id=) or a numeric song ID"
            )
        song_id = get_1st_int(song.split("?id=")[-1])
        if not song_id:
            raise ValueError("No song ID found after ?id= in the link")
        song_url = f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
        song_name, free = music163_song_info(song_id)
        if not free:
            raise AttributeError("Unable to parse VIP songs")
        download_audio(song_url, audio_path)
        midi, title = audio2midi(audio_path, cache_dir)
        if song_name:
            # Prefer the NetEase title over the file stem ("Output").
            title = song_name
        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
        return audio_path, midi, pdf, xml, mxl, abc, jpg
    except Exception as e:
        # UI boundary: show the failure in the ABC textbox instead of crashing.
        return None, None, None, None, None, f"{e}", None
if __name__ == "__main__":
    # Two-tab Gradio app. Each tab wraps one handler defined above in a
    # gr.Interface; the order of each `outputs` list must match the order of
    # the tuple that handler returns (including its error-path tuple).
    with gr.Blocks() as iface:
        gr.Markdown("# Piano Transcription Tool")
        # Tab 1: transcribe an uploaded audio file (handler: upl_infer).
        with gr.Tab("Uploading Mode"):
            gr.Interface(
                fn=upl_infer,
                inputs=gr.Audio(
                    label="Upload an audio",
                    # The handler receives a file path string, matching
                    # upl_infer's audio_path parameter.
                    type="filepath",
                ),
                # Order: midi, pdf, xml, mxl, abc, jpg (see upl_infer).
                outputs=[
                    gr.File(label="Download MIDI"),
                    gr.File(label="Download PDF score"),
                    gr.File(label="Download MusicXML"),
                    gr.File(label="Download MXL"),
                    # Also receives the error message when processing fails.
                    gr.Textbox(label="ABC notation", show_copy_button=True),
                    gr.Image(label="Staff", type="filepath"),
                ],
                description="Please make sure the audio is completely uploaded before clicking Submit",
                # Hide Gradio's "Flag" button.
                flagging_mode="never",
            )
        # Tab 2: fetch a NetEase Cloud Music song by link or ID
        # (handler: url_infer).
        with gr.Tab("Direct Link Mode"):
            gr.Interface(
                fn=url_infer,
                # NOTE(review): url_infer only handles NetEase song links/IDs,
                # so "direct link" in this label may overstate what is accepted.
                inputs=gr.Textbox(
                    label="Input audio direct link",
                    placeholder="https://music.163.com/#/song?id=",
                ),
                # Order: audio, midi, pdf, xml, mxl, abc, jpg (see url_infer).
                outputs=[
                    gr.Audio(label="Download audio", type="filepath"),
                    gr.File(label="Download MIDI"),
                    gr.File(label="Download PDF score"),
                    gr.File(label="Download MusicXML"),
                    gr.File(label="Download MXL"),
                    # Also receives the error message when processing fails.
                    gr.Textbox(label="ABC notation", show_copy_button=True),
                    gr.Image(label="Staff", type="filepath"),
                ],
                description="For Netease Cloud music, you can directly input the non-VIP song page link",
                # Bare NetEase song IDs, which url_infer accepts directly.
                examples=["1945798894", "1945798973", "1946098771"],
                flagging_mode="never",
                # Don't pre-run the examples when the app starts (each one
                # would hit the network and run a full transcription).
                cache_examples=False,
            )
    iface.launch()