Commit 769db2b by balthou · 1 Parent(s): 5d8001d

support proper cache, fix inplace issue with legend in plot

app.py CHANGED
@@ -3,6 +3,9 @@ import os
 src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
 os.sys.path.append(src_path)
 from gyraudio.audio_separation.visualization.interactive_audio import main as interactive_audio_main
+
 if __name__ == "__main__":
-    # interactive_audio_main(sys.argv[1:])
-    interactive_audio_main(["-i", "__data_source_separation/source_separation/test/000*"])
+    if len(sys.argv[1:]) == 0:
+        interactive_audio_main(["-i", "__data_source_separation/source_separation/test/000*"])
+    else:
+        interactive_audio_main(sys.argv[1:])
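The net effect in app.py: with no command-line arguments the Space falls back to the bundled test glob, while explicit arguments are passed through again instead of being ignored. A minimal standalone sketch of the same fallback pattern, where main and the default glob are placeholders rather than the project's actual API:

import sys


def main(argv):
    # placeholder for interactive_audio_main
    print(f"launching the audio viewer with: {argv}")


if __name__ == "__main__":
    args = sys.argv[1:]
    if len(args) == 0:
        # no CLI arguments given: fall back to a default input glob
        args = ["-i", "data/test/000*"]
    main(args)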
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 batch_processing
-interactive-pipe>=0.7.2
+interactive-pipe>=0.8.2
 torch>=2.0.0
 torchaudio
 scipy
src/gyraudio/audio_separation/visualization/interactive_audio.py CHANGED
@@ -52,7 +52,8 @@ def augment(signals, mixed, std_dev=0., amplify=1.):
 
 
 # @interactive(
-#     device=("cuda", ["cpu", "cuda"]) if default_device == "cuda" else ("cpu", ["cpu"])
+#     device=("cuda", ["cpu", "cuda"]
+#             ) if default_device == "cuda" else ("cpu", ["cpu"])
 # )
 def select_device(device=default_device, global_params={}):
     global_params["device"] = device
@@ -76,6 +77,8 @@ def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}
     config = configs[model % len(models)]
     short_name = config.get(SHORT_NAME, "")
     annotations = config.get(ANNOTATIONS, "")
+    global_params[SHORT_NAME] = short_name
+    global_params[ANNOTATIONS] = annotations
     device = global_params.get("device", "cpu")
     with torch.no_grad():
         selected_model.eval()
@@ -83,8 +86,7 @@ def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}
         predicted_signal, predicted_noise = selected_model(
             mixed.to(device).unsqueeze(0))
         predicted_signal = predicted_signal.squeeze(0)
-    pred_curve = SingleCurve(y=predicted_signal[0, :].detach().cpu().numpy(),
-                             style="g-", label=f"predicted_{short_name} {annotations}")
+    pred_curve = predicted_signal.detach().cpu().numpy()
     return predicted_signal, pred_curve
 
 
@@ -125,14 +127,19 @@ def zin(sig, zoom, center, num_samples=300):
 #     zoomy=KeyboardControl(
 #         value_default=0., value_range=[-15., 15.], step=1, keyup="up", keydown="down")
 )
-def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0.5, global_params={}):
+def visualize_audio(signal: dict, mixed_signal, predicted_signal, zoom=1, zoomy=0., center=0.5, global_params={}):
     """Create curves
     """
+    selected = global_params.get("selected_audio", MIXED)
+    short_name = global_params.get(SHORT_NAME, "")
+    annotations = global_params.get(ANNOTATIONS, "")
     zval = 1.5**zoom
     start_idx, end_idx, _skip_factor = get_trim(
         signal[BUFFERS][CLEAN][0, :], zval, center)
     global_params["trim"] = dict(start=start_idx, end=end_idx)
     selected = global_params.get("selected_audio", MIXED)
+    pred = SingleCurve(y=zin(predicted_signal[0, :], zval, center),
+                       style="g-", label=("*" if selected == PREDICTED else " ")+f"predicted_{short_name} {annotations}")
     clean = SingleCurve(y=zin(signal[BUFFERS][CLEAN][0, :], zval, center),
                         alpha=1.,
                         style="k-",
@@ -150,10 +157,8 @@ def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0
                         label=("*" if selected == MIXED else " ") + "mixed")
     # true_mixed = SingleCurve(y=zin(signal[BUFFERS][MIXED][0, :], zval, center),
     #                          alpha=0.3, style="b-", linewidth=1, label="true mixed")
-    pred.y = zin(pred.y, zval, center)
-    pred.label = ("*" if selected == PREDICTED else " ") + pred.label
     curves = [noisy, mixed, pred, clean]
-    title = f"SNR in {global_params['snr']:.1f} dB"
+    title = f"SNR in {global_params.get('snr', np.NaN):.1f} dB"
     if "selected_info" in global_params:
         title += f" | {global_params['selected_info']}"
     title += "\n"
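Taken together, the interactive_audio.py hunks implement both halves of the commit title. audio_sep_inference now returns a plain numpy array instead of a ready-made SingleCurve, with short_name and annotations handed over through global_params, and visualize_audio rebuilds the curve from scratch on every redraw. The old code mutated the curve in place (pred.y = zin(pred.y, ...), then prepended "*" to pred.label), which breaks as soon as the inference output is cached, presumably what the interactive-pipe>=0.8.2 bump above enables: the cache hands back the same object each time, so the signal gets re-zoomed and the legend accumulates asterisks. A minimal sketch of the failure mode and the fix, using functools.lru_cache as a stand-in for interactive-pipe's cache and a hypothetical Curve class in place of SingleCurve:

from dataclasses import dataclass
from functools import lru_cache  # stand-in for interactive-pipe's cache


@dataclass
class Curve:  # hypothetical stand-in for SingleCurve
    y: tuple
    label: str


@lru_cache(maxsize=None)
def inference(model_id: int) -> Curve:
    # Cached step: after the first call, the SAME Curve object is
    # handed back on every redraw.
    return Curve(y=(0.1, 0.2, 0.3), label="predicted")


def visualize_buggy(model_id: int) -> Curve:
    pred = inference(model_id)
    pred.label = "*" + pred.label  # mutates the cached object in place
    return pred


def visualize_fixed(model_id: int) -> Curve:
    pred_y = inference(model_id).y  # treat the cached output as read-only
    return Curve(y=pred_y, label="*" + "predicted")  # fresh curve each redraw


print(visualize_buggy(0).label)  # *predicted
print(visualize_buggy(0).label)  # **predicted  <- the legend bug
print(visualize_fixed(1).label)  # *predicted
print(visualize_fixed(1).label)  # *predicted   <- stable

The global_params.get('snr', np.NaN) guard in the title fits the same picture: with caching, visualize_audio can run before 'snr' has been written, so the title degrades to NaN instead of raising a KeyError.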