sweetcocoa committed on
Commit
7a3b53b
·
1 Parent(s): 1aa2e4c

move to gradio

Browse files
Files changed (2) hide show
  1. app.py +66 -43
  2. transformer_wrapper.py +5 -17
app.py CHANGED
@@ -1,15 +1,13 @@
1
- import streamlit as st
2
  import os
3
  from transformer_wrapper import TransformerWrapper
4
  from omegaconf import OmegaConf
5
 
6
 
7
- @st.cache(show_spinner=False)
8
  def get_file_content_as_string(path):
9
  return open(path, "r", encoding="utf-8").read()
10
 
11
 
12
- @st.cache(show_spinner=True)
13
  def model_load():
14
  config = OmegaConf.load("config.yaml")
15
  wrapper = TransformerWrapper(config)
@@ -23,43 +21,68 @@ def model_load():
23
  return wrapper, model_id, config
24
 
25
 
26
- def main():
27
-
28
- wrapper, model_id, config = model_load()
29
- composers = list(config.composer_to_feature_token.keys())
30
- dest_dir = "ytsamples"
31
- os.makedirs(dest_dir, exist_ok=True)
32
- composer = st.selectbox(label="Arranger", options=composers)
33
- file_up = st.file_uploader("Upload an audio", type=["mp3", "wav"])
34
-
35
- if st.button("convert"):
36
-
37
- if file_up is not None:
38
- bytes_data = file_up.getvalue()
39
- target_file = f"{dest_dir}/{file_up.name}"
40
- with open(target_file, "wb") as f:
41
- f.write(bytes_data)
42
-
43
- with st.spinner("Wait for it..."):
44
- midi, arranger, mix_path, midi_path = wrapper.generate(
45
- audio_path=target_file,
46
- composer=composer,
47
- model=model_id,
48
- ignore_duplicate=True,
49
- show_plot=False,
50
- save_midi=True,
51
- save_mix=True,
52
- )
53
-
54
- with open(midi_path, "rb") as midi_f:
55
- file_down = st.download_button(
56
- "Download midi",
57
- data=midi_f,
58
- file_name=os.path.basename(midi_path),
59
- )
60
- with open(mix_path, "rb") as audio_f:
61
- st.audio(audio_f.read(), format="audio/wav")
62
-
63
-
64
- if __name__ == "__main__":
65
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import os
3
  from transformer_wrapper import TransformerWrapper
4
  from omegaconf import OmegaConf
5
 
6
 
 
7
  def get_file_content_as_string(path):
8
  return open(path, "r", encoding="utf-8").read()
9
 
10
 
 
11
  def model_load():
12
  config = OmegaConf.load("config.yaml")
13
  wrapper = TransformerWrapper(config)
 
21
  return wrapper, model_id, config
22
 
23
 
24
+ wrapper, model_id, config = model_load()
25
+ composers = list(config.composer_to_feature_token.keys())
26
+ dest_dir = "ytsamples"
27
+ os.makedirs(dest_dir, exist_ok=True)
28
+
29
+
30
+ def inference(file_up, composer):
31
+
32
+ midi, arranger, mix_path, midi_path = wrapper.generate(
33
+ audio_path=file_up,
34
+ composer=composer,
35
+ model=model_id,
36
+ ignore_duplicate=True,
37
+ show_plot=False,
38
+ save_midi=True,
39
+ save_mix=True,
40
+ )
41
+
42
+ return mix_path
43
+
44
+
45
+ block = gr.Blocks()
46
+
47
+
48
+ with block:
49
+ gr.HTML(
50
+ """
51
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
52
+ <div
53
+ style="
54
+ display: inline-flex;
55
+ align-items: center;
56
+ gap: 0.8rem;
57
+ font-size: 1.75rem;
58
+ "
59
+ >
60
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
61
+ Pop2piano
62
+ </h1>
63
+ </div>
64
+ <p style="margin-bottom: 10px; font-size: 94%">
65
+ A demo for Pop2Piano:Pop Audio-based Piano Cover Generation. Please select the composer and upload the pop audio to submit.
66
+ </p>
67
+ </div>
68
+ """
69
+ )
70
+ with gr.Group():
71
+ with gr.Box():
72
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
73
+ file_up = gr.Audio(label="Upload an audio", type="filepath")
74
+ composer = gr.Dropdown(label="Arranger", choices=composers, value="composer1")
75
+ btn = gr.Button("Convert")
76
+ out = gr.Audio(label="Output")
77
+
78
+ btn.click(inference, inputs=[file_up, composer], outputs=out)
79
+ gr.HTML(
80
+ """
81
+ <div class="footer">
82
+ <p><a href="http://sweetcocoa.github.io/pop2piano_samples" style="text-decoration: underline;" target="_blank">Project Page</a>
83
+ </p>
84
+ </div>
85
+ """
86
+ )
87
+
88
+ block.launch(debug=True)
transformer_wrapper.py CHANGED
@@ -155,9 +155,7 @@ class TransformerWrapper(pl.LightningModule):
155
 
156
  return relative_tokens, notes, pm
157
 
158
- def prepare_inference_mel(
159
- self, audio, beatstep, n_bars, padding_value, composer_value=None
160
- ):
161
  n_steps = n_bars * 4
162
  n_target_step = len(beatstep)
163
  sample_rate = self.config.dataset.sample_rate
@@ -240,9 +238,7 @@ class TransformerWrapper(pl.LightningModule):
240
  composer = random.sample(list(composer_to_feature_token.keys()), 1)[0]
241
 
242
  composer_value = composer_to_feature_token[composer]
243
- mix_sample_rate = (
244
- config.dataset.sample_rate if mix_sample_rate is None else mix_sample_rate
245
- )
246
 
247
  if not ignore_duplicate:
248
  if os.path.exists(midi_path):
@@ -295,8 +291,7 @@ class TransformerWrapper(pl.LightningModule):
295
  feature_tokens=fzs,
296
  audio=_audio,
297
  beatstep=beatsteps - beatsteps[0],
298
- max_length=config.dataset.target_length
299
- * max(1, (n_bars // config.dataset.n_bars)),
300
  max_batch_size=max_batch_size,
301
  n_bars=n_bars,
302
  composer_value=composer_value,
@@ -311,22 +306,15 @@ class TransformerWrapper(pl.LightningModule):
311
  y = librosa.core.resample(y, orig_sr=sr, target_sr=mix_sample_rate)
312
  sr = mix_sample_rate
313
  if add_click:
314
- clicks = (
315
- librosa.clicks(times=beatsteps, sr=sr, length=len(y)) * click_amp
316
- )
317
  y = y + clicks
318
  pm_y = pm.fluidsynth(sr)
319
  stereo = get_stereo(y, pm_y, pop_scale=stereo_amp)
320
 
321
  if show_plot:
322
- import IPython.display as ipd
323
- from IPython.display import display
324
  import note_seq
325
 
326
- display("Stereo MIX", ipd.Audio(stereo, rate=sr))
327
- display("Rendered MIDI", ipd.Audio(pm_y, rate=sr))
328
- display("Original Song", ipd.Audio(y, rate=sr))
329
- display(note_seq.plot_sequence(note_seq.midi_to_note_sequence(pm)))
330
 
331
  if save_mix:
332
  sf.write(
 
155
 
156
  return relative_tokens, notes, pm
157
 
158
+ def prepare_inference_mel(self, audio, beatstep, n_bars, padding_value, composer_value=None):
 
 
159
  n_steps = n_bars * 4
160
  n_target_step = len(beatstep)
161
  sample_rate = self.config.dataset.sample_rate
 
238
  composer = random.sample(list(composer_to_feature_token.keys()), 1)[0]
239
 
240
  composer_value = composer_to_feature_token[composer]
241
+ mix_sample_rate = config.dataset.sample_rate if mix_sample_rate is None else mix_sample_rate
 
 
242
 
243
  if not ignore_duplicate:
244
  if os.path.exists(midi_path):
 
291
  feature_tokens=fzs,
292
  audio=_audio,
293
  beatstep=beatsteps - beatsteps[0],
294
+ max_length=config.dataset.target_length * max(1, (n_bars // config.dataset.n_bars)),
 
295
  max_batch_size=max_batch_size,
296
  n_bars=n_bars,
297
  composer_value=composer_value,
 
306
  y = librosa.core.resample(y, orig_sr=sr, target_sr=mix_sample_rate)
307
  sr = mix_sample_rate
308
  if add_click:
309
+ clicks = librosa.clicks(times=beatsteps, sr=sr, length=len(y)) * click_amp
 
 
310
  y = y + clicks
311
  pm_y = pm.fluidsynth(sr)
312
  stereo = get_stereo(y, pm_y, pop_scale=stereo_amp)
313
 
314
  if show_plot:
 
 
315
  import note_seq
316
 
317
+ note_seq.plot_sequence(note_seq.midi_to_note_sequence(pm))
 
 
 
318
 
319
  if save_mix:
320
  sf.write(