Spaces:
Running
Running
update
Browse files
examples/conv_tasnet_gan/step_2_train_model.py
CHANGED
@@ -211,9 +211,9 @@ def main():
|
|
211 |
if last_step_idx != -1:
|
212 |
logger.info(f"resume from steps-{last_step_idx}.")
|
213 |
model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
|
214 |
-
discriminator_pt = serialization_dir / f"steps-{last_step_idx}/discriminator.pt"
|
215 |
-
|
216 |
optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
|
|
|
|
|
217 |
discriminator_optimizer_pth = serialization_dir / f"steps-{last_step_idx}/discriminator_optimizer.pth"
|
218 |
|
219 |
logger.info(f"load state dict for model.")
|
@@ -221,10 +221,11 @@ def main():
|
|
221 |
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
222 |
model.load_state_dict(state_dict, strict=True)
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
|
|
228 |
|
229 |
if discriminator_pt.exists():
|
230 |
logger.info(f"load state dict for discriminator.")
|
@@ -497,6 +498,8 @@ def main():
|
|
497 |
total_neg_stoi_loss = 0.
|
498 |
total_mr_stft_loss = 0.
|
499 |
total_pesq_loss = 0.
|
|
|
|
|
500 |
total_batches = 0.
|
501 |
|
502 |
progress_bar_eval.close()
|
|
|
211 |
if last_step_idx != -1:
|
212 |
logger.info(f"resume from steps-{last_step_idx}.")
|
213 |
model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
|
|
|
|
|
214 |
optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
|
215 |
+
|
216 |
+
discriminator_pt = serialization_dir / f"steps-{last_step_idx}/discriminator.pt"
|
217 |
discriminator_optimizer_pth = serialization_dir / f"steps-{last_step_idx}/discriminator_optimizer.pth"
|
218 |
|
219 |
logger.info(f"load state dict for model.")
|
|
|
221 |
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
222 |
model.load_state_dict(state_dict, strict=True)
|
223 |
|
224 |
+
if optimizer_pth.exists():
|
225 |
+
logger.info(f"load state dict for optimizer.")
|
226 |
+
with open(optimizer_pth.as_posix(), "rb") as f:
|
227 |
+
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
228 |
+
optimizer.load_state_dict(state_dict)
|
229 |
|
230 |
if discriminator_pt.exists():
|
231 |
logger.info(f"load state dict for discriminator.")
|
|
|
498 |
total_neg_stoi_loss = 0.
|
499 |
total_mr_stft_loss = 0.
|
500 |
total_pesq_loss = 0.
|
501 |
+
total_discriminator_g_loss = 0.
|
502 |
+
total_discriminator_d_loss = 0.
|
503 |
total_batches = 0.
|
504 |
|
505 |
progress_bar_eval.close()
|
examples/conv_tasnet_gan/yaml/config.yaml
CHANGED
@@ -5,11 +5,11 @@ segment_size: 4
|
|
5 |
|
6 |
win_size: 20
|
7 |
freq_bins: 256
|
8 |
-
bottleneck_channels:
|
9 |
num_speakers: 1
|
10 |
-
num_blocks:
|
11 |
-
num_sub_blocks:
|
12 |
-
sub_blocks_channels:
|
13 |
sub_blocks_kernel_size: 3
|
14 |
|
15 |
norm_type: "gLN"
|
|
|
5 |
|
6 |
win_size: 20
|
7 |
freq_bins: 256
|
8 |
+
bottleneck_channels: 128
|
9 |
num_speakers: 1
|
10 |
+
num_blocks: 2
|
11 |
+
num_sub_blocks: 4
|
12 |
+
sub_blocks_channels: 256
|
13 |
sub_blocks_kernel_size: 3
|
14 |
|
15 |
norm_type: "gLN"
|
examples/conv_tasnet_gan/yaml/discriminator_config.yaml
CHANGED
@@ -6,5 +6,5 @@ n_fft: 512
|
|
6 |
win_size: 200
|
7 |
hop_size: 80
|
8 |
|
9 |
-
discriminator_dim:
|
10 |
discriminator_in_channel: 2
|
|
|
6 |
win_size: 200
|
7 |
hop_size: 80
|
8 |
|
9 |
+
discriminator_dim: 32
|
10 |
discriminator_in_channel: 2
|