Spaces:
Paused
Paused
Update train.py
Browse files
train.py
CHANGED
@@ -30,7 +30,7 @@ from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, ge
|
|
30 |
from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
|
31 |
from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn
|
32 |
|
33 |
-
from so_vits_svc_fork.train import VitsLightning
|
34 |
|
35 |
LOG = getLogger(__name__)
|
36 |
torch.set_float32_matmul_precision("high")
|
@@ -68,37 +68,15 @@ class HuggingFacePushCallback(pl.Callback):
|
|
68 |
commit_message="π» cheers",
|
69 |
ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
|
70 |
)
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
self.batch_size = 1
|
81 |
-
self.collate_fn = TextAudioCollate()
|
82 |
-
|
83 |
-
# these should be called in setup(), but we need to calculate check_val_every_n_epoch
|
84 |
-
self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
|
85 |
-
self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)
|
86 |
-
|
87 |
-
def train_dataloader(self):
|
88 |
-
return DataLoader(
|
89 |
-
self.train_dataset,
|
90 |
-
num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
|
91 |
-
batch_size=self.batch_size,
|
92 |
-
collate_fn=self.collate_fn,
|
93 |
-
persistent_workers=self.__hparams.train.get("persistent_workers", True),
|
94 |
-
)
|
95 |
-
|
96 |
-
def val_dataloader(self):
|
97 |
-
return DataLoader(
|
98 |
-
self.val_dataset,
|
99 |
-
batch_size=1,
|
100 |
-
collate_fn=self.collate_fn,
|
101 |
-
)
|
102 |
|
103 |
|
104 |
def train(
|
|
|
30 |
from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch
|
31 |
from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn
|
32 |
|
33 |
+
from so_vits_svc_fork.train import VitsLightning, VCDataModule
|
34 |
|
35 |
LOG = getLogger(__name__)
|
36 |
torch.set_float32_matmul_precision("high")
|
|
|
68 |
commit_message="π» cheers",
|
69 |
ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
|
70 |
)
|
71 |
+
ckpt_pattern = r'^(D_|G_)\d+\.pth$'
|
72 |
+
todelete = []
|
73 |
+
repo_ckpts = [x for x in list_repo_files(self.repo_id) if re.match(ckpt_pattern, x) and x not in ["G_0.pth", "D_0.pth"]]
|
74 |
+
local_ckpts = [x.name for x in Path(model_dir).glob("*.pth") if re.match(ckpt_pattern, x.name)]
|
75 |
+
to_delete = set(repo_ckpts) - set(local_ckpts)
|
76 |
+
|
77 |
+
for fname in to_delete:
|
78 |
+
print("π Deleting {fname} from repo")
|
79 |
+
delete_file(fname, repo_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
|
82 |
def train(
|