nateraw commited on
Commit
17e242c
Β·
1 Parent(s): d0ea7e0

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +10 -32
train.py CHANGED
@@ -30,7 +30,7 @@ from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, ge
30
  from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
31
  from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn
32
 
33
- from so_vits_svc_fork.train import VitsLightning
34
 
35
  LOG = getLogger(__name__)
36
  torch.set_float32_matmul_precision("high")
@@ -68,37 +68,15 @@ class HuggingFacePushCallback(pl.Callback):
68
  commit_message="🍻 cheers",
69
  ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
70
  )
71
-
72
- class VCDataModule(pl.LightningDataModule):
73
- batch_size: int
74
-
75
- def __init__(self, hparams: Any):
76
- super().__init__()
77
- self.__hparams = hparams
78
- self.batch_size = hparams.train.batch_size
79
- if not isinstance(self.batch_size, int):
80
- self.batch_size = 1
81
- self.collate_fn = TextAudioCollate()
82
-
83
- # these should be called in setup(), but we need to calculate check_val_every_n_epoch
84
- self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
85
- self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)
86
-
87
- def train_dataloader(self):
88
- return DataLoader(
89
- self.train_dataset,
90
- num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
91
- batch_size=self.batch_size,
92
- collate_fn=self.collate_fn,
93
- persistent_workers=self.__hparams.train.get("persistent_workers", True),
94
- )
95
-
96
- def val_dataloader(self):
97
- return DataLoader(
98
- self.val_dataset,
99
- batch_size=1,
100
- collate_fn=self.collate_fn,
101
- )
102
 
103
 
104
  def train(
 
30
  from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
31
  from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn
32
 
33
+ from so_vits_svc_fork.train import VitsLightning, VCDataModule
34
 
35
  LOG = getLogger(__name__)
36
  torch.set_float32_matmul_precision("high")
 
68
  commit_message="🍻 cheers",
69
  ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
70
  )
71
+ ckpt_pattern = r'^(D_|G_)\d+\.pth$'
72
+ todelete = []
73
+ repo_ckpts = [x for x in list_repo_files(self.repo_id) if re.match(ckpt_pattern, x) and x not in ["G_0.pth", "D_0.pth"]]
74
+ local_ckpts = [x.name for x in Path(model_dir).glob("*.pth") if re.match(ckpt_pattern, x.name)]
75
+ to_delete = set(repo_ckpts) - set(local_ckpts)
76
+
77
+ for fname in to_delete:
78
+ print("πŸ—‘ Deleting {fname} from repo")
79
+ delete_file(fname, repo_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  def train(