balthou committed on
Commit
e2d04e8
·
1 Parent(s): 55ca18f

Simplified pipeline for Gradio

Browse files
src/gyraudio/audio_separation/properties.py CHANGED
@@ -68,8 +68,8 @@ CONFIGURATION = "configuration"
68
 
69
  # Signal names
70
  CLEAN = "clean"
71
- NOISY = "noise"
72
- MIXED = "mixed"
73
  PREDICTED = "predicted"
74
 
75
 
 
68
 
69
  # Signal names
70
  CLEAN = "clean"
71
+ NOISY = "pure noise"
72
+ MIXED = "noisy"
73
  PREDICTED = "predicted"
74
 
75
 
src/gyraudio/audio_separation/visualization/audio_player.py CHANGED
@@ -12,8 +12,8 @@ LOGOS = {
12
  PREDICTED: HERE/"play_logo_pred.png",
13
  MIXED: HERE/"play_logo_mixed.png",
14
  CLEAN: HERE/"play_logo_clean.png",
15
- NOISY: HERE/"play_logo_noise.png",
16
- MUTE: HERE/"mute_logo.png",
17
  }
18
  ICONS = [it for key, it in LOGOS.items()]
19
  KEYS = [key for key, it in LOGOS.items()]
@@ -22,7 +22,7 @@ ping_pong_index = 0
22
 
23
 
24
  @interactive(
25
- player=Control(MUTE, KEYS, icons=ICONS))
26
  def audio_selector(sig, mixed, pred, global_params={}, player=MUTE):
27
 
28
  global_params["selected_audio"] = player if player != MUTE else global_params.get("selected_audio", MIXED)
@@ -40,8 +40,8 @@ def audio_selector(sig, mixed, pred, global_params={}, player=MUTE):
40
  return audio_track
41
 
42
 
43
- @interactive(
44
- loop=KeyboardControl(True, keydown="l"))
45
  def audio_trim(audio_track, global_params={}, loop=True):
46
  sampling_rate = global_params.get(SAMPLING_RATE, 8000)
47
  if global_params.get("trim", False):
@@ -51,7 +51,7 @@ def audio_trim(audio_track, global_params={}, loop=True):
51
  repeat_factor = int(sampling_rate*4./(end-start))
52
  logging.debug(f"{repeat_factor}")
53
  repeat_factor = max(1, repeat_factor)
54
- if loop:
55
  repeat_factor = 1
56
  audio_trim = audio_trim.repeat(1, repeat_factor)
57
  logging.debug(f"{audio_trim.shape}")
@@ -60,10 +60,10 @@ def audio_trim(audio_track, global_params={}, loop=True):
60
  return audio_trim
61
 
62
 
63
- @interactive(
64
- volume=(100, [0, 1000], "volume"),
65
- )
66
- def audio_player(audio_trim, global_params={}, volume=100):
67
  sampling_rate = global_params.get(SAMPLING_RATE, 8000)
68
  try:
69
  if global_params.get(MUTE, True):
 
12
  PREDICTED: HERE/"play_logo_pred.png",
13
  MIXED: HERE/"play_logo_mixed.png",
14
  CLEAN: HERE/"play_logo_clean.png",
15
+ # NOISY: HERE/"play_logo_noise.png",
16
+ # MUTE: HERE/"mute_logo.png",
17
  }
18
  ICONS = [it for key, it in LOGOS.items()]
19
  KEYS = [key for key, it in LOGOS.items()]
 
22
 
23
 
24
  @interactive(
25
+ player=Control(PREDICTED, KEYS, icons=ICONS, name="Player selection"))
26
  def audio_selector(sig, mixed, pred, global_params={}, player=MUTE):
27
 
28
  global_params["selected_audio"] = player if player != MUTE else global_params.get("selected_audio", MIXED)
 
40
  return audio_track
41
 
42
 
43
+ # @interactive(
44
+ # loop=KeyboardControl(True, keydown="l"))
45
  def audio_trim(audio_track, global_params={}, loop=True):
46
  sampling_rate = global_params.get(SAMPLING_RATE, 8000)
47
  if global_params.get("trim", False):
 
51
  repeat_factor = int(sampling_rate*4./(end-start))
52
  logging.debug(f"{repeat_factor}")
53
  repeat_factor = max(1, repeat_factor)
54
+ if not loop:
55
  repeat_factor = 1
56
  audio_trim = audio_trim.repeat(1, repeat_factor)
57
  logging.debug(f"{audio_trim.shape}")
 
60
  return audio_trim
61
 
62
 
63
+ # @interactive(
64
+ # volume=(1000, [0, 1000], "volume"),
65
+ # )
66
+ def audio_player(audio_trim, global_params={}, volume=1000):
67
  sampling_rate = global_params.get(SAMPLING_RATE, 8000)
68
  try:
69
  if global_params.get(MUTE, True):
src/gyraudio/audio_separation/visualization/interactive_audio.py CHANGED
@@ -30,55 +30,6 @@ default_device = "cuda" if torch.cuda.is_available() else "cpu"
30
  LEARNT_SAMPLING_RATE = 8000
31
 
32
 
33
- @interactive(
34
- idx=KeyboardControl(value_default=0, value_range=[
35
- 0, 1000], modulo=True, keyup="8", keydown="2", name="clean signal index"),
36
- idn=KeyboardControl(value_default=0, value_range=[
37
- 0, 1000], modulo=True, keyup="9", keydown="3", name="noisy signal index")
38
- )
39
- def signal_selector(signals, idx=0, idn=0, global_params={}):
40
- if isinstance(signals, dict):
41
- clean_sigs = signals[CLEAN]
42
- clean = clean_sigs[idx % len(clean_sigs)]
43
- if BUFFERS not in clean:
44
- load_buffers_custom(clean)
45
- noise_sigs = signals[NOISY]
46
- noise = noise_sigs[idn % len(noise_sigs)]
47
- if BUFFERS not in noise:
48
- load_buffers_custom(noise)
49
- cbuf, nbuf = clean[BUFFERS], noise[BUFFERS]
50
- if clean[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
51
- cbuf = resample(cbuf, clean[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
52
- clean[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
53
- if noise[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
54
- nbuf = resample(nbuf, noise[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
55
- noise[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
56
- min_length = min(cbuf.shape[-1], nbuf.shape[-1])
57
- min_length = min_length - min_length % 1024
58
- signal = {
59
- PATHS: {
60
- CLEAN: clean[PATHS],
61
- NOISY: noise[PATHS]
62
-
63
- },
64
- BUFFERS: {
65
- CLEAN: cbuf[..., :1, :min_length],
66
- NOISY: nbuf[..., :1, :min_length],
67
- },
68
- NAME: f"Clean={clean[NAME]} | Noise={noise[NAME]}",
69
- SAMPLING_RATE: LEARNT_SAMPLING_RATE
70
- }
71
- else:
72
- # signals are loaded in CPU
73
- signal = signals[idx % len(signals)]
74
- if BUFFERS not in signal:
75
- load_buffers(signal)
76
- global_params["premixed_snr"] = signal.get("premixed_snr", None)
77
- signal[NAME] = f"File={signal[NAME]}"
78
- global_params["selected_info"] = signal[NAME]
79
- global_params[SAMPLING_RATE] = signal[SAMPLING_RATE]
80
- return signal
81
-
82
 
83
  @interactive(
84
  snr=(0., [-10., 10.], "SNR [dB]")
@@ -92,6 +43,7 @@ def remix(signals, snr=0., global_params={}):
92
  return mixed_signal
93
 
94
 
 
95
  @interactive(std_dev=Control(0., value_range=[0., 0.1], name="extra noise std", step=0.0001),
96
  amplify=(1., [0., 10.], "amplification of everything"))
97
  def augment(signals, mixed, std_dev=0., amplify=1.):
@@ -109,10 +61,10 @@ def select_device(device=default_device, global_params={}):
109
  global_params["device"] = device
110
 
111
 
112
- @interactive(
113
- model=KeyboardControl(value_default=0, value_range=[
114
- 0, 99], keyup="pagedown", keydown="pageup")
115
- )
116
  def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}):
117
  selected_model = models[model % len(models)]
118
  config = configs[model % len(models)]
@@ -161,11 +113,11 @@ def zin(sig, zoom, center, num_samples=300):
161
 
162
  @interactive(
163
  center=KeyboardControl(value_default=0.5, value_range=[
164
- 0., 1.], step=0.01, keyup="6", keydown="4"),
165
- zoom=KeyboardControl(value_default=0., value_range=[
166
- 0., 15.], step=1, keyup="+", keydown="-"),
167
- zoomy=KeyboardControl(
168
- value_default=0., value_range=[-15., 15.], step=1, keyup="up", keydown="down")
169
  )
170
  def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0.5, global_params={}):
171
  """Create curves
@@ -208,6 +160,58 @@ def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0
208
  return Curve(curves, ylim=[-0.04 * 1.5 ** zoomy, 0.04 * 1.5 ** zoomy], xlabel="Time index", ylabel="Amplitude", title=title)
209
 
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  def interactive_audio_separation_processing(signals, model_list, config_list):
212
  sig = signal_selector(signals)
213
  mixed = remix(sig)
 
30
  LEARNT_SAMPLING_RATE = 8000
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  @interactive(
35
  snr=(0., [-10., 10.], "SNR [dB]")
 
43
  return mixed_signal
44
 
45
 
46
+
47
  @interactive(std_dev=Control(0., value_range=[0., 0.1], name="extra noise std", step=0.0001),
48
  amplify=(1., [0., 10.], "amplification of everything"))
49
  def augment(signals, mixed, std_dev=0., amplify=1.):
 
61
  global_params["device"] = device
62
 
63
 
64
+ # @interactive(
65
+ # model=KeyboardControl(value_default=0, value_range=[
66
+ # 0, 99], keyup="pagedown", keydown="pageup")
67
+ # )
68
  def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}):
69
  selected_model = models[model % len(models)]
70
  config = configs[model % len(models)]
 
113
 
114
  @interactive(
115
  center=KeyboardControl(value_default=0.5, value_range=[
116
+ 0., 1.], step=0.01, keyup="6", keydown="4", name="Trim (center)"),
117
+ zoom=KeyboardControl(value_default=3., value_range=[
118
+ 0., 15.], step=1, keyup="+", keydown="-", name="Trim (zoom)"),
119
+ # zoomy=KeyboardControl(
120
+ # value_default=0., value_range=[-15., 15.], step=1, keyup="up", keydown="down")
121
  )
122
  def visualize_audio(signal: dict, mixed_signal, pred, zoom=1, zoomy=0., center=0.5, global_params={}):
123
  """Create curves
 
160
  return Curve(curves, ylim=[-0.04 * 1.5 ** zoomy, 0.04 * 1.5 ** zoomy], xlabel="Time index", ylabel="Amplitude", title=title)
161
 
162
 
163
+ @interactive(
164
+ idx=("Voice 1", ["Voice 1", "Voice 2",], "Clean signal"),
165
+ # idx=KeyboardControl(value_default=0, value_range=[
166
+ # 0, 1000], modulo=True, keyup="8", keydown="2", name="clean signal index"),
167
+ # idn=KeyboardControl(value_default=0, value_range=[
168
+ # 0, 1000], modulo=True, keyup="9", keydown="3", name="noisy signal index")
169
+ )
170
+ def signal_selector(signals, idx="Voice 1", idn=0, global_params={}):
171
+ idx = int(idx.split("Voice ")[-1])
172
+ if isinstance(signals, dict):
173
+ clean_sigs = signals[CLEAN]
174
+ clean = clean_sigs[idx % len(clean_sigs)]
175
+ if BUFFERS not in clean:
176
+ load_buffers_custom(clean)
177
+ noise_sigs = signals[NOISY]
178
+ noise = noise_sigs[idn % len(noise_sigs)]
179
+ if BUFFERS not in noise:
180
+ load_buffers_custom(noise)
181
+ cbuf, nbuf = clean[BUFFERS], noise[BUFFERS]
182
+ if clean[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
183
+ cbuf = resample(cbuf, clean[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
184
+ clean[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
185
+ if noise[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
186
+ nbuf = resample(nbuf, noise[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
187
+ noise[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
188
+ min_length = min(cbuf.shape[-1], nbuf.shape[-1])
189
+ min_length = min_length - min_length % 1024
190
+ signal = {
191
+ PATHS: {
192
+ CLEAN: clean[PATHS],
193
+ NOISY: noise[PATHS]
194
+
195
+ },
196
+ BUFFERS: {
197
+ CLEAN: cbuf[..., :1, :min_length],
198
+ NOISY: nbuf[..., :1, :min_length],
199
+ },
200
+ NAME: f"Clean={clean[NAME]} | Noise={noise[NAME]}",
201
+ SAMPLING_RATE: LEARNT_SAMPLING_RATE
202
+ }
203
+ else:
204
+ # signals are loaded in CPU
205
+ signal = signals[idx % len(signals)]
206
+ if BUFFERS not in signal:
207
+ load_buffers(signal)
208
+ global_params["premixed_snr"] = signal.get("premixed_snr", None)
209
+ signal[NAME] = f"File={signal[NAME]}"
210
+ global_params["selected_info"] = signal[NAME]
211
+ global_params[SAMPLING_RATE] = signal[SAMPLING_RATE]
212
+ return signal
213
+
214
+
215
  def interactive_audio_separation_processing(signals, model_list, config_list):
216
  sig = signal_selector(signals)
217
  mixed = remix(sig)