XzJosh commited on
Commit
ca38eea
·
1 Parent(s): a89409b

Update train_ms.py

Browse files
Files changed (1) hide show
  1. train_ms.py +28 -19
train_ms.py CHANGED
@@ -4,6 +4,7 @@ import argparse
4
  import itertools
5
  import math
6
  import torch
 
7
  from torch import nn, optim
8
  from torch.nn import functional as F
9
  from torch.utils.data import DataLoader
@@ -38,12 +39,8 @@ from text.symbols import symbols
38
 
39
  torch.backends.cudnn.benchmark = True
40
  torch.backends.cuda.matmul.allow_tf32 = True
41
- torch.backends.cudnn.allow_tf32 = True # If encontered training problem,please try to disable TF32.
42
  torch.set_float32_matmul_precision('medium')
43
- torch.backends.cuda.sdp_kernel("flash")
44
- torch.backends.cuda.enable_flash_sdp(True)
45
- torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
46
- torch.backends.cuda.enable_math_sdp(True)
47
  global_step = 0
48
 
49
 
@@ -56,6 +53,10 @@ def main():
56
  os.environ['MASTER_PORT'] = '65280'
57
 
58
  hps = utils.get_hparams()
 
 
 
 
59
  mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
60
 
61
 
@@ -68,7 +69,7 @@ def run(rank, n_gpus, hps):
68
  writer = SummaryWriter(log_dir=hps.model_dir)
69
  writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
70
 
71
- dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
72
  torch.manual_seed(hps.train.seed)
73
  torch.cuda.set_device(rank)
74
 
@@ -81,9 +82,8 @@ def run(rank, n_gpus, hps):
81
  rank=rank,
82
  shuffle=True)
83
  collate_fn = TextAudioSpeakerCollate()
84
- train_loader = DataLoader(train_dataset, num_workers=24, shuffle=False, pin_memory=True,
85
- collate_fn=collate_fn, batch_sampler=train_sampler,
86
- persistent_workers=True,prefetch_factor=4) #256G Memory suitable loader.
87
  if rank == 0:
88
  eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
89
  eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
@@ -155,20 +155,29 @@ def run(rank, n_gpus, hps):
155
  net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
156
  if net_dur_disc is not None:
157
  net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
158
- try:
159
- if net_dur_disc is not None:
160
- _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"), net_dur_disc, optim_dur_disc, skip_optimizer=True)
161
- _, optim_g, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
162
- optim_g, skip_optimizer=True)
163
- _, optim_d, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
164
- optim_d, skip_optimizer=True)
 
 
 
165
 
166
- epoch_str = max(epoch_str, 1)
167
- global_step = (epoch_str - 1) * len(train_loader)
168
- except Exception as e:
169
  print(e)
170
  epoch_str = 1
171
  global_step = 0
 
 
 
 
 
 
172
 
173
 
174
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
 
4
  import itertools
5
  import math
6
  import torch
7
+ import shutil
8
  from torch import nn, optim
9
  from torch.nn import functional as F
10
  from torch.utils.data import DataLoader
 
39
 
40
  torch.backends.cudnn.benchmark = True
41
  torch.backends.cuda.matmul.allow_tf32 = True
42
+ torch.backends.cudnn.allow_tf32 = True
43
  torch.set_float32_matmul_precision('medium')
 
 
 
 
44
  global_step = 0
45
 
46
 
 
53
  os.environ['MASTER_PORT'] = '65280'
54
 
55
  hps = utils.get_hparams()
56
+ if not hps.cont:
57
+ shutil.copy('./pretrained_models/D_0.pth','./logs/OUTPUT_MODEL/D_0.pth')
58
+ shutil.copy('./pretrained_models/G_0.pth','./logs/OUTPUT_MODEL/G_0.pth')
59
+ shutil.copy('./pretrained_models/DUR_0.pth','./logs/OUTPUT_MODEL/DUR_0.pth')
60
  mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
61
 
62
 
 
69
  writer = SummaryWriter(log_dir=hps.model_dir)
70
  writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
71
 
72
+ dist.init_process_group(backend= 'gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus, rank=rank)
73
  torch.manual_seed(hps.train.seed)
74
  torch.cuda.set_device(rank)
75
 
 
82
  rank=rank,
83
  shuffle=True)
84
  collate_fn = TextAudioSpeakerCollate()
85
+ train_loader = DataLoader(train_dataset, num_workers=2, shuffle=False, pin_memory=True,
86
+ collate_fn=collate_fn, batch_sampler=train_sampler)
 
87
  if rank == 0:
88
  eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
89
  eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
 
155
  net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
156
  if net_dur_disc is not None:
157
  net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
158
+
159
+ pretrain_dir = None
160
+ if pretrain_dir is None:
161
+ try:
162
+ if net_dur_disc is not None:
163
+ _, optim_dur_disc, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"), net_dur_disc, optim_dur_disc, skip_optimizer=not hps.cont)
164
+ _, optim_g, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
165
+ optim_g, skip_optimizer=not hps.cont)
166
+ _, optim_d, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
167
+ optim_d, skip_optimizer=not hps.cont)
168
 
169
+ epoch_str = max(epoch_str, 1)
170
+ global_step = (epoch_str - 1) * len(train_loader)
171
+ except Exception as e:
172
  print(e)
173
  epoch_str = 1
174
  global_step = 0
175
+ else:
176
+ _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(pretrain_dir, "G_*.pth"), net_g,
177
+ optim_g, True)
178
+ _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(pretrain_dir, "D_*.pth"), net_d,
179
+ optim_d, True)
180
+
181
 
182
 
183
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)