import os
import torch
import librosa
import binascii
import warnings
import midi2audio  # converts MIDI files to WAV files
import numpy as np
import pytube as pt  # downloads audio from YouTube videos
import gradio as gr
import soundfile as sf
from transformers import Pop2PianoForConditionalGeneration, Pop2PianoProcessor

# ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ
yt_video_dir = "./yt_dir"  # ์œ ํŠœ๋ธŒ ๋น„๋””์˜ค ๋‹ค์šด๋กœ๋“œ ๊ฒฝ๋กœ
outputs_dir = "./midi_wav_outputs"  # ์ถœ๋ ฅ ํŒŒ์ผ ๊ฒฝ๋กœ
os.makedirs(outputs_dir, exist_ok=True)
os.makedirs(yt_video_dir, exist_ok=True)

# ๋ชจ๋ธ ์„ค์ •
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano").to(device)
processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
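# Composer/arranger presets defined in the model's generation config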
composers = model.generation_config.composer_to_feature_token.keys()

# Extract the audio track from a YouTube video
def get_audio_from_yt_video(yt_link):
    try:
        yt = pt.YouTube(yt_link)
        t = yt.streams.filter(only_audio=True)
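        # Save the first audio-only stream under a random hex file name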
        filename = os.path.join(yt_video_dir, binascii.hexlify(os.urandom(8)).decode() + ".mp4")
        t[0].download(filename=filename)
    except Exception:
        warnings.warn(f"Video Not Found at {yt_link}")
        filename = None
    
    return filename, filename

# ๋ชจ๋ธ ์ถ”๋ก  ํ•จ์ˆ˜
def inference(file_uploaded, composer):
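    # Load the uploaded audio at its native sampling rate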
    waveform, sr = librosa.load(file_uploaded, sr=None) 
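    # Extract model input features from the waveform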
    inputs = processor(audio=waveform, sampling_rate=sr, return_tensors="pt").to(device)
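    # Generate MIDI token ids conditioned on the selected composer style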
    model_output = model.generate(input_features=inputs["input_features"], composer=composer)
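    # Decode the generated tokens back into pretty_midi objects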
    tokenizer_output = processor.batch_decode(token_ids=model_output.to("cpu"), feature_extractor_output=inputs.to("cpu"))["pretty_midi_objects"]

    return prepare_output_file(tokenizer_output, sr)    

# ์ถœ๋ ฅ ํŒŒ์ผ ์ค€๋น„ ํ•จ์ˆ˜
def prepare_output_file(tokenizer_output, sr):
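    # Write the first generated MIDI object to disk under a random file name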
    output_file_name = "output_" + binascii.hexlify(os.urandom(8)).decode()
    midi_output = os.path.join(outputs_dir, output_file_name + ".mid")
    tokenizer_output[0].write(midi_output)
    wav_output = midi_output.replace(".mid", ".wav")
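    # Render the MIDI file to WAV (requires FluidSynth and a SoundFont installed locally)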
    midi2audio.FluidSynth().midi_to_audio(midi_output, wav_output)
    
    return wav_output, wav_output, midi_output

# Gradio UI setup
block = gr.Blocks(theme="Taithrah/Minimal")

with block:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 800px; margin: 0 auto;">
            <h1 style="font-weight: 900; margin-bottom: 12px;">
                ๐ŸŽน Pop2Piano: Piano Cover Generator ๐ŸŽน
            </h1>
            <p style="margin-bottom: 12px; font-size: 90%">
                Pop2Piano ๋ฐ๋ชจ: ํŒ ์˜ค๋””์˜ค ๊ธฐ๋ฐ˜ ํ”ผ์•„๋…ธ ์ปค๋ฒ„๊ณก ์ƒ์„ฑ. <br>
                ์ž‘๊ณก๊ฐ€(ํŽธ๊ณก์ž)๋ฅผ ์„ ํƒํ•˜๊ณ  ํŒ ์˜ค๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•˜๊ฑฐ๋‚˜ ์œ ํŠœ๋ธŒ ๋งํฌ๋ฅผ ์ž…๋ ฅํ•œ ํ›„ ์ƒ์„ฑ ๋ฒ„ํŠผ์„ ํด๋ฆญํ•˜์„ธ์š”.
            </p>
        </div>
        """
    )
    with gr.Group():
        with gr.Row():
            with gr.Column():
                file_uploaded = gr.Audio(label="Upload audio", type="filepath")
            with gr.Column():
                with gr.Row():
                    yt_link = gr.Textbox(label="Enter a YouTube link.", autofocus=True, lines=3)
                    yt_btn = gr.Button("Download audio from the YouTube link", size="lg")
                yt_audio_path = gr.Audio(label="Audio extracted from the YouTube video", interactive=False)
                yt_btn.click(get_audio_from_yt_video, inputs=[yt_link], outputs=[yt_audio_path, file_uploaded])

    with gr.Group():
        with gr.Column():
            composer = gr.Dropdown(label="Arranger", choices=composers, value="composer1")
            generate_btn = gr.Button("Generate my piano cover ๐ŸŽน๐ŸŽต")
        with gr.Row():
            wav_output2 = gr.File(label="Download my piano cover (.wav)")
            wav_output1 = gr.Audio(label="Listen to my piano cover")
            midi_output = gr.File(label="Download the generated MIDI file (.mid)")
            generate_btn.click(
                inference, 
                inputs=[file_uploaded, composer], 
                outputs=[wav_output1, wav_output2, midi_output])

block.launch(debug=False)