balthou committed
Commit 53c936d · 1 Parent(s): fd22fde

update interface

src/gyraudio/audio_separation/visualization/interactive_audio.py CHANGED
@@ -21,8 +21,7 @@ import numpy as np
 import logging
 from interactive_pipe.data_objects.curves import Curve, SingleCurve
 from interactive_pipe import interactive, KeyboardControl, Control
-from interactive_pipe.headless.pipeline import HeadlessPipeline
-from interactive_pipe.graphical.gradio_gui import InteractivePipeGradio
+from interactive_pipe import interactive_pipeline
 
 from gyraudio.audio_separation.visualization.audio_player import audio_selector, audio_trim, audio_player
 
@@ -30,7 +29,6 @@ default_device = "cuda" if torch.cuda.is_available() else "cpu"
 LEARNT_SAMPLING_RATE = 8000
 
 
-
 @interactive(
     snr=(0., [-10., 10.], "SNR [dB]")
 )
@@ -43,7 +41,6 @@ def remix(signals, snr=0., global_params={}):
     return mixed_signal
 
 
-
 @interactive(std_dev=Control(0., value_range=[0., 0.1], name="extra noise std", step=0.0001),
              amplify=(1., [0., 10.], "amplification of everything"))
 def augment(signals, mixed, std_dev=0., amplify=1.):
@@ -65,7 +62,16 @@ def select_device(device=default_device, global_params={}):
 #     model=KeyboardControl(value_default=0, value_range=[
 #                           0, 99], keyup="pagedown", keydown="pageup")
 # )
+ALL_MODELS = ["Tiny UNET", "Large UNET", "Large UNET (Bias Free)"]
+
+
+@interactive(
+    model=(ALL_MODELS[-1], ALL_MODELS, "Model selection")
+)
 def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}):
+    if isinstance(model, str):
+        model = ALL_MODELS.index(model)
+    assert isinstance(model, int)
     selected_model = models[model % len(models)]
     config = configs[model % len(models)]
     short_name = config.get(SHORT_NAME, "")
@@ -114,7 +120,7 @@ def zin(sig, zoom, center, num_samples=300):
 @interactive(
     center=KeyboardControl(value_default=0.5, value_range=[
         0., 1.], step=0.01, keyup="6", keydown="4", name="Trim (center)"),
-    zoom=KeyboardControl(value_default=3., value_range=[
+    zoom=KeyboardControl(value_default=0., value_range=[
         0., 15.], step=1, keyup="+", keydown="-", name="Trim (zoom)"),
     # zoomy=KeyboardControl(
     #     value_default=0., value_range=[-15., 15.], step=1, keyup="up", keydown="down")
@@ -161,7 +167,8 @@ def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0
 
 
 @interactive(
-    idx=("Voice 1", ["Voice 1", "Voice 2", "Voice 3", "Voice 4"], "Clean signal"),
+    idx=("Voice 1", ["Voice 1", "Voice 2",
+                     "Voice 3", "Voice 4"], "Clean signal"),
     # idx=KeyboardControl(value_default=0, value_range=[
     #     0, 1000], modulo=True, keyup="8", keydown="2", name="clean signal index"),
     # idn=KeyboardControl(value_default=0, value_range=[
@@ -230,23 +237,14 @@ def interactive_audio_separation_visualization(
     all_signals: List[dict],
     model_list: List[torch.nn.Module],
     config_list: List[dict],
-    gui="qt"
+    gui="gradio"
 ):
-    pip = HeadlessPipeline.from_function(
-        interactive_audio_separation_processing, cache=True)
-    if gui == "gradio":
-        app = InteractivePipeGradio(
-            pipeline=pip, name="audio separation", audio=True)
-    elif gui == "qt":
-        from interactive_pipe.graphical.qt_gui import InteractivePipeQT
-        app = InteractivePipeQT(
-            pipeline=pip, name="audio separation", size=(1000, 1000), audio=True)
-    else:
-        from interactive_pipe.graphical.mpl_gui import InteractivePipeMatplotlib
-        logging.warning("No support for audio player with Matplotlib")
-        app = InteractivePipeMatplotlib(
-            pipeline=pip, name="audio separation", size=None, audio=False)
-    app(all_signals, model_list, config_list)
+
+    interactive_pipeline(gui=gui, cache=True, audio=True)(
+        interactive_audio_separation_processing
+    )(
+        all_signals, model_list, config_list
+    )
 
 
 def visualization(
@@ -285,7 +283,7 @@ def parse_command_line_gradio(parser: Batch = None, gradio_demo=True) -> argpars
     parser = parse_command_line_audio_load()
     default_device = "cuda" if torch.cuda.is_available() else "cpu"
     iparse = parser.add_argument_group("Audio separation visualization")
-    iparse.add_argument("-e", "--experiments", type=int, nargs="+", default=[3001,],
+    iparse.add_argument("-e", "--experiments", type=int, nargs="+", default=[4, 1004, 3001,],
                         help="Experiment ids to be inferred sequentially")
    iparse.add_argument("-p", "--interactive", default=True,
                        action="store_true", help="Play = Interactive mode")
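
For context, a minimal, self-contained sketch of the interactive_pipe pattern this commit switches to: filter functions decorated with @interactive expose their parameters as GUI controls, and interactive_pipeline(gui=..., cache=...) wraps a plain pipeline function instead of hand-building a HeadlessPipeline plus a GUI backend. The toy filter, pipeline, and input below are illustrative assumptions, not code from this repository; only the @interactive control syntax and the wrap-then-call style of interactive_pipeline mirror the diff above.

import numpy as np
from interactive_pipe import interactive, interactive_pipeline


# Hypothetical filter: the (default, [min, max], "label") tuple uses the same
# control syntax as snr / std_dev / amplify in the diff above.
@interactive(brightness=(1., [0., 2.], "brightness"))
def brighten(img, brightness=1.):
    return np.clip(img * brightness, 0., 1.)


def toy_pipeline(img):
    # The pipeline chains filters; every decorated parameter becomes a GUI control.
    return brighten(img)


if __name__ == "__main__":
    img = 0.5 * np.ones((64, 64, 3))
    # Same wrap-then-call style as the updated
    # interactive_audio_separation_visualization: wrap the pipeline function,
    # then call the wrapped object with the pipeline inputs.
    interactive_pipeline(gui="gradio", cache=True)(toy_pipeline)(img)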