wasmdashai committed
Commit 60d02ae · verified · 1 Parent(s): be47285

Update app.py

Files changed (1): app.py (+40, -0)
app.py CHANGED
@@ -407,7 +407,47 @@ class TrinerModelVITS:
         self.len_dataset=len(self.DataSets['train'])
         self.load_model()
         self.init_wandb()
+        self.training_args=load_training_args(self.path_training_args)
+        training_args= self.training_args
         scaler = GradScaler(enabled=True)
+        for disc in self.model.discriminator.discriminators:
+            disc.apply_weight_norm()
+        self.model.decoder.apply_weight_norm()
+        # torch.nn.utils.weight_norm(self.decoder.conv_pre)
+        # torch.nn.utils.weight_norm(self.decoder.conv_post)
+        for flow in self.model.flow.flows:
+            torch.nn.utils.weight_norm(flow.conv_pre)
+            torch.nn.utils.weight_norm(flow.conv_post)
+
+        discriminator = self.model.discriminator
+        self.model.discriminator = None
+
+        optimizer = torch.optim.AdamW(
+            self.model.parameters(),
+            training_args.learning_rate,
+            betas=[training_args.adam_beta1, training_args.adam_beta2],
+            eps=training_args.adam_epsilon,
+        )
+
+        # Hack to be able to train on multiple device
+        disc_optimizer = torch.optim.AdamW(
+            discriminator.parameters(),
+            training_args.d_learning_rate,
+            betas=[training_args.d_adam_beta1, training_args.d_adam_beta2],
+            eps=training_args.adam_epsilon,
+        )
+        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
+            optimizer, gamma=training_args.lr_decay, last_epoch=-1
+        )
+        disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
+            disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1
+        )
+        self.models=(self.model,discriminator)
+        self.optimizers=(optimizer,disc_optimizer,scaler)
+        self.lr_schedulers=(lr_scheduler,disc_lr_scheduler)
+        self.tools=load_tools()
+        self.stute_mode=True
+        print(self.lr_schedulers)
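Note on the added block: it wires up the usual two-optimizer GAN setup for VITS-style training, one AdamW per network, matching ExponentialLR schedulers, and a single shared GradScaler, all stored as tuples on self. Below is a minimal, self-contained sketch of how such a setup is typically driven for one step; the nn.Linear stand-ins, placeholder losses, and hyperparameter values are illustrative only and are not taken from app.py.

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

# Toy stand-ins for the real generator and discriminator built in app.py.
generator = nn.Linear(16, 16).to(device)
discriminator = nn.Linear(16, 1).to(device)

# Same optimizer/scheduler/scaler layout as the commit; values are placeholders.
optimizer = torch.optim.AdamW(generator.parameters(), 2e-4, betas=[0.8, 0.99], eps=1e-9)
disc_optimizer = torch.optim.AdamW(discriminator.parameters(), 2e-4, betas=[0.8, 0.99], eps=1e-9)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999, last_epoch=-1)
disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(disc_optimizer, gamma=0.999, last_epoch=-1)
scaler = GradScaler(enabled=use_amp)

x = torch.randn(4, 16, device=device)

# Discriminator step on detached generator output (placeholder loss, not a real GAN loss).
with autocast(enabled=use_amp):
    d_loss = discriminator(generator(x).detach()).pow(2).mean()
disc_optimizer.zero_grad()
scaler.scale(d_loss).backward()
scaler.step(disc_optimizer)

# Generator step (again a placeholder objective).
with autocast(enabled=use_amp):
    g_loss = -discriminator(generator(x)).mean()
optimizer.zero_grad()
scaler.scale(g_loss).backward()
scaler.step(optimizer)
scaler.update()  # one update per iteration, even with two optimizers

# ExponentialLR decay, typically stepped once per epoch.
lr_scheduler.step()
disc_lr_scheduler.step()

When two optimizers share one GradScaler, scaler.step() is called once per optimizer but scaler.update() only once per iteration, which is the pattern PyTorch documents for multi-optimizer AMP training.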