roychao19477 commited on
Commit
0ff1354
·
1 Parent(s): 42fbee6

Update figs

Browse files
Files changed (1) hide show
  1. app.py +33 -4
app.py CHANGED
@@ -43,12 +43,21 @@ with open(cfg_f, 'r') as f:
43
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  device = "cuda"
45
  model = SEMamba(cfg).to(device)
46
- sdict = torch.load(ckpt, map_location=device)
47
- model.load_state_dict(sdict["generator"])
48
- model.eval()
49
 
50
  @spaces.GPU
51
  def enhance(filepath):
 
 
 
 
 
 
 
 
 
52
  with torch.no_grad():
53
  # load & resample
54
  wav, orig_sr = librosa.load(filepath, sr=None)
@@ -107,13 +116,33 @@ def enhance(filepath):
107
 
108
  return "enhanced.wav", fig
109
 
 
 
 
 
 
 
 
 
 
 
 
110
  with gr.Blocks() as demo:
111
  gr.Markdown(ABOUT)
112
  input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
 
 
 
 
 
113
  enhance_btn = gr.Button("Enhance")
114
  output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
115
  plot_output = gr.Plot(label="Spectrograms")
116
 
117
- enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
 
 
 
 
118
 
119
  demo.queue().launch()
 
43
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  device = "cuda"
45
  model = SEMamba(cfg).to(device)
46
+ #sdict = torch.load(ckpt, map_location=device)
47
+ #model.load_state_dict(sdict["generator"])
48
+ #model.eval()
49
 
50
  @spaces.GPU
51
  def enhance(filepath):
52
+ # Load model based on selection
53
+ ckpt_path = {
54
+ "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
55
+ "VCTK+DNS": "ckpts/vd.pth"
56
+ }[model_name]
57
+
58
+ print("Loading:", ckpt_path)
59
+ model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
60
+ model.eval()
61
  with torch.no_grad():
62
  # load & resample
63
  wav, orig_sr = librosa.load(filepath, sr=None)
 
116
 
117
  return "enhanced.wav", fig
118
 
119
+ #with gr.Blocks() as demo:
120
+ # gr.Markdown(ABOUT)
121
+ # input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
122
+ # enhance_btn = gr.Button("Enhance")
123
+ # output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
124
+ # plot_output = gr.Plot(label="Spectrograms")
125
+ #
126
+ # enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
127
+ #
128
+ #demo.queue().launch()
129
+
130
  with gr.Blocks() as demo:
131
  gr.Markdown(ABOUT)
132
  input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
133
+ model_choice = gr.Radio(
134
+ label="Choose Model",
135
+ choices=["VCTK-Demand", "VCTK+DNS"],
136
+ value="VCTK-Demand"
137
+ )
138
  enhance_btn = gr.Button("Enhance")
139
  output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
140
  plot_output = gr.Plot(label="Spectrograms")
141
 
142
+ enhance_btn.click(
143
+ fn=enhance,
144
+ inputs=[input_audio, model_choice],
145
+ outputs=[output_audio, plot_output]
146
+ )
147
 
148
  demo.queue().launch()