HoneyTian committed on
Commit
20b7fa8
·
1 Parent(s): 3f9acc2
examples/conv_tasnet_gan/step_2_train_model.py CHANGED
@@ -211,9 +211,9 @@ def main():
211
  if last_step_idx != -1:
212
  logger.info(f"resume from steps-{last_step_idx}.")
213
  model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
214
- discriminator_pt = serialization_dir / f"steps-{last_step_idx}/discriminator.pt"
215
-
216
  optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
 
 
217
  discriminator_optimizer_pth = serialization_dir / f"steps-{last_step_idx}/discriminator_optimizer.pth"
218
 
219
  logger.info(f"load state dict for model.")
@@ -221,10 +221,11 @@ def main():
221
  state_dict = torch.load(f, map_location="cpu", weights_only=True)
222
  model.load_state_dict(state_dict, strict=True)
223
 
224
- logger.info(f"load state dict for optimizer.")
225
- with open(optimizer_pth.as_posix(), "rb") as f:
226
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
227
- optimizer.load_state_dict(state_dict)
 
228
 
229
  if discriminator_pt.exists():
230
  logger.info(f"load state dict for discriminator.")
@@ -497,6 +498,8 @@ def main():
497
  total_neg_stoi_loss = 0.
498
  total_mr_stft_loss = 0.
499
  total_pesq_loss = 0.
 
 
500
  total_batches = 0.
501
 
502
  progress_bar_eval.close()
 
211
  if last_step_idx != -1:
212
  logger.info(f"resume from steps-{last_step_idx}.")
213
  model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
 
 
214
  optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
215
+
216
+ discriminator_pt = serialization_dir / f"steps-{last_step_idx}/discriminator.pt"
217
  discriminator_optimizer_pth = serialization_dir / f"steps-{last_step_idx}/discriminator_optimizer.pth"
218
 
219
  logger.info(f"load state dict for model.")
 
221
  state_dict = torch.load(f, map_location="cpu", weights_only=True)
222
  model.load_state_dict(state_dict, strict=True)
223
 
224
+ if optimizer_pth.exists():
225
+ logger.info(f"load state dict for optimizer.")
226
+ with open(optimizer_pth.as_posix(), "rb") as f:
227
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
228
+ optimizer.load_state_dict(state_dict)
229
 
230
  if discriminator_pt.exists():
231
  logger.info(f"load state dict for discriminator.")
 
498
  total_neg_stoi_loss = 0.
499
  total_mr_stft_loss = 0.
500
  total_pesq_loss = 0.
501
+ total_discriminator_g_loss = 0.
502
+ total_discriminator_d_loss = 0.
503
  total_batches = 0.
504
 
505
  progress_bar_eval.close()
examples/conv_tasnet_gan/yaml/config.yaml CHANGED
@@ -5,11 +5,11 @@ segment_size: 4
5
 
6
  win_size: 20
7
  freq_bins: 256
8
- bottleneck_channels: 256
9
  num_speakers: 1
10
- num_blocks: 4
11
- num_sub_blocks: 8
12
- sub_blocks_channels: 512
13
  sub_blocks_kernel_size: 3
14
 
15
  norm_type: "gLN"
 
5
 
6
  win_size: 20
7
  freq_bins: 256
8
+ bottleneck_channels: 128
9
  num_speakers: 1
10
+ num_blocks: 2
11
+ num_sub_blocks: 4
12
+ sub_blocks_channels: 256
13
  sub_blocks_kernel_size: 3
14
 
15
  norm_type: "gLN"
examples/conv_tasnet_gan/yaml/discriminator_config.yaml CHANGED
@@ -6,5 +6,5 @@ n_fft: 512
6
  win_size: 200
7
  hop_size: 80
8
 
9
- discriminator_dim: 24
10
  discriminator_in_channel: 2
 
6
  win_size: 200
7
  hop_size: 80
8
 
9
+ discriminator_dim: 32
10
  discriminator_in_channel: 2