Vladimir Alabov commited on
Commit
46b0a70
·
1 Parent(s): 1ba66c7

Refactor #3

Browse files
Files changed (48) hide show
  1. so_vits_svc_fork/__init__.py +5 -0
  2. so_vits_svc_fork/__main__.py +917 -0
  3. so_vits_svc_fork/cluster/__init__.py +48 -0
  4. so_vits_svc_fork/cluster/train_cluster.py +141 -0
  5. so_vits_svc_fork/dataset.py +87 -0
  6. so_vits_svc_fork/default_gui_presets.json +92 -0
  7. so_vits_svc_fork/f0.py +239 -0
  8. so_vits_svc_fork/gui.py +851 -0
  9. so_vits_svc_fork/hparams.py +38 -0
  10. so_vits_svc_fork/inference/__init__.py +0 -0
  11. so_vits_svc_fork/inference/core.py +692 -0
  12. so_vits_svc_fork/inference/main.py +272 -0
  13. so_vits_svc_fork/logger.py +46 -0
  14. so_vits_svc_fork/modules/__init__.py +0 -0
  15. so_vits_svc_fork/modules/attentions.py +488 -0
  16. so_vits_svc_fork/modules/commons.py +132 -0
  17. so_vits_svc_fork/modules/decoders/__init__.py +0 -0
  18. so_vits_svc_fork/modules/decoders/f0.py +46 -0
  19. so_vits_svc_fork/modules/decoders/hifigan/__init__.py +3 -0
  20. so_vits_svc_fork/modules/decoders/hifigan/_models.py +311 -0
  21. so_vits_svc_fork/modules/decoders/hifigan/_utils.py +15 -0
  22. so_vits_svc_fork/modules/decoders/mb_istft/__init__.py +15 -0
  23. so_vits_svc_fork/modules/decoders/mb_istft/_generators.py +376 -0
  24. so_vits_svc_fork/modules/decoders/mb_istft/_loss.py +11 -0
  25. so_vits_svc_fork/modules/decoders/mb_istft/_pqmf.py +128 -0
  26. so_vits_svc_fork/modules/decoders/mb_istft/_stft.py +244 -0
  27. so_vits_svc_fork/modules/decoders/mb_istft/_stft_loss.py +142 -0
  28. so_vits_svc_fork/modules/descriminators.py +177 -0
  29. so_vits_svc_fork/modules/encoders.py +136 -0
  30. so_vits_svc_fork/modules/flows.py +48 -0
  31. so_vits_svc_fork/modules/losses.py +58 -0
  32. so_vits_svc_fork/modules/mel_processing.py +205 -0
  33. so_vits_svc_fork/modules/modules.py +452 -0
  34. so_vits_svc_fork/modules/synthesizers.py +233 -0
  35. so_vits_svc_fork/preprocessing/__init__.py +0 -0
  36. so_vits_svc_fork/preprocessing/config_templates/quickvc.json +78 -0
  37. so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json +69 -0
  38. so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json +71 -0
  39. so_vits_svc_fork/preprocessing/preprocess_classify.py +95 -0
  40. so_vits_svc_fork/preprocessing/preprocess_flist_config.py +86 -0
  41. so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py +157 -0
  42. so_vits_svc_fork/preprocessing/preprocess_resample.py +144 -0
  43. so_vits_svc_fork/preprocessing/preprocess_speaker_diarization.py +93 -0
  44. so_vits_svc_fork/preprocessing/preprocess_split.py +78 -0
  45. so_vits_svc_fork/preprocessing/preprocess_utils.py +5 -0
  46. so_vits_svc_fork/py.typed +0 -0
  47. so_vits_svc_fork/train.py +571 -0
  48. so_vits_svc_fork/utils.py +478 -0
so_vits_svc_fork/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __version__ = "4.1.1"
2
+
3
+ from .logger import init_logger
4
+
5
+ init_logger()
so_vits_svc_fork/__main__.py ADDED
@@ -0,0 +1,917 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from logging import getLogger
5
+ from multiprocessing import freeze_support
6
+ from pathlib import Path
7
+ from typing import Literal
8
+
9
+ import click
10
+ import torch
11
+
12
+ from so_vits_svc_fork import __version__
13
+ from so_vits_svc_fork.utils import get_optimal_device
14
+
15
+ LOG = getLogger(__name__)
16
+
17
+ IS_TEST = "test" in Path(__file__).parent.stem
18
+ if IS_TEST:
19
+ LOG.debug("Test mode is on.")
20
+
21
+
22
+ class RichHelpFormatter(click.HelpFormatter):
23
+ def __init__(
24
+ self,
25
+ indent_increment: int = 2,
26
+ width: int | None = None,
27
+ max_width: int | None = None,
28
+ ) -> None:
29
+ width = 100
30
+ super().__init__(indent_increment, width, max_width)
31
+ LOG.info(f"Version: {__version__}")
32
+
33
+
34
+ def patch_wrap_text():
35
+ orig_wrap_text = click.formatting.wrap_text
36
+
37
+ def wrap_text(
38
+ text,
39
+ width=78,
40
+ initial_indent="",
41
+ subsequent_indent="",
42
+ preserve_paragraphs=False,
43
+ ):
44
+ return orig_wrap_text(
45
+ text.replace("\n", "\n\n"),
46
+ width=width,
47
+ initial_indent=initial_indent,
48
+ subsequent_indent=subsequent_indent,
49
+ preserve_paragraphs=True,
50
+ ).replace("\n\n", "\n")
51
+
52
+ click.formatting.wrap_text = wrap_text
53
+
54
+
55
+ patch_wrap_text()
56
+
57
+ CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"], show_default=True)
58
+ click.Context.formatter_class = RichHelpFormatter
59
+
60
+
61
+ @click.group(context_settings=CONTEXT_SETTINGS)
62
+ def cli():
63
+ """so-vits-svc allows any folder structure for training data.
64
+ However, the following folder structure is recommended.\n
65
+ When training: dataset_raw/{speaker_name}/**/{wav_name}.{any_format}\n
66
+ When inference: configs/44k/config.json, logs/44k/G_XXXX.pth\n
67
+ If the folder structure is followed, you DO NOT NEED TO SPECIFY model path, config path, etc.
68
+ (The latest model will be automatically loaded.)\n
69
+ To train a model, run pre-resample, pre-config, pre-hubert, train.\n
70
+ To infer a model, run infer.
71
+ """
72
+
73
+
74
+ @cli.command()
75
+ @click.option(
76
+ "-c",
77
+ "--config-path",
78
+ type=click.Path(exists=True),
79
+ help="path to config",
80
+ default=Path("./configs/44k/config.json"),
81
+ )
82
+ @click.option(
83
+ "-m",
84
+ "--model-path",
85
+ type=click.Path(),
86
+ help="path to output dir",
87
+ default=Path("./logs/44k"),
88
+ )
89
+ @click.option(
90
+ "-t/-nt",
91
+ "--tensorboard/--no-tensorboard",
92
+ default=False,
93
+ type=bool,
94
+ help="launch tensorboard",
95
+ )
96
+ @click.option(
97
+ "-r",
98
+ "--reset-optimizer",
99
+ default=False,
100
+ type=bool,
101
+ help="reset optimizer",
102
+ is_flag=True,
103
+ )
104
+ def train(
105
+ config_path: Path,
106
+ model_path: Path,
107
+ tensorboard: bool = False,
108
+ reset_optimizer: bool = False,
109
+ ):
110
+ """Train model
111
+ If D_0.pth or G_0.pth not found, automatically download from hub."""
112
+ from .train import train
113
+
114
+ config_path = Path(config_path)
115
+ model_path = Path(model_path)
116
+
117
+ if tensorboard:
118
+ import webbrowser
119
+
120
+ from tensorboard import program
121
+
122
+ getLogger("tensorboard").setLevel(30)
123
+ tb = program.TensorBoard()
124
+ tb.configure(argv=[None, "--logdir", model_path.as_posix()])
125
+ url = tb.launch()
126
+ webbrowser.open(url)
127
+
128
+ train(
129
+ config_path=config_path, model_path=model_path, reset_optimizer=reset_optimizer
130
+ )
131
+
132
+
133
+ @cli.command()
134
+ def gui():
135
+ """Opens GUI
136
+ for conversion and realtime inference"""
137
+ from .gui import main
138
+
139
+ main()
140
+
141
+
142
+ @cli.command()
143
+ @click.argument(
144
+ "input-path",
145
+ type=click.Path(exists=True),
146
+ )
147
+ @click.option(
148
+ "-o",
149
+ "--output-path",
150
+ type=click.Path(),
151
+ help="path to output dir",
152
+ )
153
+ @click.option("-s", "--speaker", type=str, default=None, help="speaker name")
154
+ @click.option(
155
+ "-m",
156
+ "--model-path",
157
+ type=click.Path(exists=True),
158
+ default=Path("./logs/44k/"),
159
+ help="path to model",
160
+ )
161
+ @click.option(
162
+ "-c",
163
+ "--config-path",
164
+ type=click.Path(exists=True),
165
+ default=Path("./configs/44k/config.json"),
166
+ help="path to config",
167
+ )
168
+ @click.option(
169
+ "-k",
170
+ "--cluster-model-path",
171
+ type=click.Path(exists=True),
172
+ default=None,
173
+ help="path to cluster model",
174
+ )
175
+ @click.option(
176
+ "-re",
177
+ "--recursive",
178
+ type=bool,
179
+ default=False,
180
+ help="Search recursively",
181
+ is_flag=True,
182
+ )
183
+ @click.option("-t", "--transpose", type=int, default=0, help="transpose")
184
+ @click.option(
185
+ "-db", "--db-thresh", type=int, default=-20, help="threshold (DB) (RELATIVE)"
186
+ )
187
+ @click.option(
188
+ "-fm",
189
+ "--f0-method",
190
+ type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
191
+ default="dio",
192
+ help="f0 prediction method",
193
+ )
194
+ @click.option(
195
+ "-a/-na",
196
+ "--auto-predict-f0/--no-auto-predict-f0",
197
+ type=bool,
198
+ default=True,
199
+ help="auto predict f0",
200
+ )
201
+ @click.option(
202
+ "-r", "--cluster-infer-ratio", type=float, default=0, help="cluster infer ratio"
203
+ )
204
+ @click.option("-n", "--noise-scale", type=float, default=0.4, help="noise scale")
205
+ @click.option("-p", "--pad-seconds", type=float, default=0.5, help="pad seconds")
206
+ @click.option(
207
+ "-d",
208
+ "--device",
209
+ type=str,
210
+ default=get_optimal_device(),
211
+ help="device",
212
+ )
213
+ @click.option("-ch", "--chunk-seconds", type=float, default=0.5, help="chunk seconds")
214
+ @click.option(
215
+ "-ab/-nab",
216
+ "--absolute-thresh/--no-absolute-thresh",
217
+ type=bool,
218
+ default=False,
219
+ help="absolute thresh",
220
+ )
221
+ @click.option(
222
+ "-mc",
223
+ "--max-chunk-seconds",
224
+ type=float,
225
+ default=40,
226
+ help="maximum allowed single chunk length, set lower if you get out of memory (0 to disable)",
227
+ )
228
+ def infer(
229
+ # paths
230
+ input_path: Path,
231
+ output_path: Path,
232
+ model_path: Path,
233
+ config_path: Path,
234
+ recursive: bool,
235
+ # svc config
236
+ speaker: str,
237
+ cluster_model_path: Path | None = None,
238
+ transpose: int = 0,
239
+ auto_predict_f0: bool = False,
240
+ cluster_infer_ratio: float = 0,
241
+ noise_scale: float = 0.4,
242
+ f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
243
+ # slice config
244
+ db_thresh: int = -40,
245
+ pad_seconds: float = 0.5,
246
+ chunk_seconds: float = 0.5,
247
+ absolute_thresh: bool = False,
248
+ max_chunk_seconds: float = 40,
249
+ device: str | torch.device = get_optimal_device(),
250
+ ):
251
+ """Inference"""
252
+ from so_vits_svc_fork.inference.main import infer
253
+
254
+ if not auto_predict_f0:
255
+ LOG.warning(
256
+ f"auto_predict_f0 = False, transpose = {transpose}. If you want to change the pitch, please set transpose."
257
+ "Generally transpose = 0 does not work because your voice pitch and target voice pitch are different."
258
+ )
259
+
260
+ input_path = Path(input_path)
261
+ if output_path is None:
262
+ output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}"
263
+ output_path = Path(output_path)
264
+ if input_path.is_dir() and not recursive:
265
+ raise ValueError(
266
+ "input_path is a directory. Use 0re or --recursive to infer recursively."
267
+ )
268
+ model_path = Path(model_path)
269
+ if model_path.is_dir():
270
+ model_path = list(
271
+ sorted(model_path.glob("G_*.pth"), key=lambda x: x.stat().st_mtime)
272
+ )[-1]
273
+ LOG.info(f"Since model_path is a directory, use {model_path}")
274
+ config_path = Path(config_path)
275
+ if cluster_model_path is not None:
276
+ cluster_model_path = Path(cluster_model_path)
277
+ infer(
278
+ # paths
279
+ input_path=input_path,
280
+ output_path=output_path,
281
+ model_path=model_path,
282
+ config_path=config_path,
283
+ recursive=recursive,
284
+ # svc config
285
+ speaker=speaker,
286
+ cluster_model_path=cluster_model_path,
287
+ transpose=transpose,
288
+ auto_predict_f0=auto_predict_f0,
289
+ cluster_infer_ratio=cluster_infer_ratio,
290
+ noise_scale=noise_scale,
291
+ f0_method=f0_method,
292
+ # slice config
293
+ db_thresh=db_thresh,
294
+ pad_seconds=pad_seconds,
295
+ chunk_seconds=chunk_seconds,
296
+ absolute_thresh=absolute_thresh,
297
+ max_chunk_seconds=max_chunk_seconds,
298
+ device=device,
299
+ )
300
+
301
+
302
+ @cli.command()
303
+ @click.option(
304
+ "-m",
305
+ "--model-path",
306
+ type=click.Path(exists=True),
307
+ default=Path("./logs/44k/"),
308
+ help="path to model",
309
+ )
310
+ @click.option(
311
+ "-c",
312
+ "--config-path",
313
+ type=click.Path(exists=True),
314
+ default=Path("./configs/44k/config.json"),
315
+ help="path to config",
316
+ )
317
+ @click.option(
318
+ "-k",
319
+ "--cluster-model-path",
320
+ type=click.Path(exists=True),
321
+ default=None,
322
+ help="path to cluster model",
323
+ )
324
+ @click.option("-t", "--transpose", type=int, default=12, help="transpose")
325
+ @click.option(
326
+ "-a/-na",
327
+ "--auto-predict-f0/--no-auto-predict-f0",
328
+ type=bool,
329
+ default=True,
330
+ help="auto predict f0 (not recommended for realtime since voice pitch will not be stable)",
331
+ )
332
+ @click.option(
333
+ "-r", "--cluster-infer-ratio", type=float, default=0, help="cluster infer ratio"
334
+ )
335
+ @click.option("-n", "--noise-scale", type=float, default=0.4, help="noise scale")
336
+ @click.option(
337
+ "-db", "--db-thresh", type=int, default=-30, help="threshold (DB) (ABSOLUTE)"
338
+ )
339
+ @click.option(
340
+ "-fm",
341
+ "--f0-method",
342
+ type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
343
+ default="dio",
344
+ help="f0 prediction method",
345
+ )
346
+ @click.option("-p", "--pad-seconds", type=float, default=0.02, help="pad seconds")
347
+ @click.option("-ch", "--chunk-seconds", type=float, default=0.5, help="chunk seconds")
348
+ @click.option(
349
+ "-cr",
350
+ "--crossfade-seconds",
351
+ type=float,
352
+ default=0.01,
353
+ help="crossfade seconds",
354
+ )
355
+ @click.option(
356
+ "-ab",
357
+ "--additional-infer-before-seconds",
358
+ type=float,
359
+ default=0.2,
360
+ help="additional infer before seconds",
361
+ )
362
+ @click.option(
363
+ "-aa",
364
+ "--additional-infer-after-seconds",
365
+ type=float,
366
+ default=0.1,
367
+ help="additional infer after seconds",
368
+ )
369
+ @click.option("-b", "--block-seconds", type=float, default=0.5, help="block seconds")
370
+ @click.option(
371
+ "-d",
372
+ "--device",
373
+ type=str,
374
+ default=get_optimal_device(),
375
+ help="device",
376
+ )
377
+ @click.option("-s", "--speaker", type=str, default=None, help="speaker name")
378
+ @click.option("-v", "--version", type=int, default=2, help="version")
379
+ @click.option("-i", "--input-device", type=int, default=None, help="input device")
380
+ @click.option("-o", "--output-device", type=int, default=None, help="output device")
381
+ @click.option(
382
+ "-po",
383
+ "--passthrough-original",
384
+ type=bool,
385
+ default=False,
386
+ is_flag=True,
387
+ help="passthrough original (for latency check)",
388
+ )
389
+ def vc(
390
+ # paths
391
+ model_path: Path,
392
+ config_path: Path,
393
+ # svc config
394
+ speaker: str,
395
+ cluster_model_path: Path | None,
396
+ transpose: int,
397
+ auto_predict_f0: bool,
398
+ cluster_infer_ratio: float,
399
+ noise_scale: float,
400
+ f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
401
+ # slice config
402
+ db_thresh: int,
403
+ pad_seconds: float,
404
+ chunk_seconds: float,
405
+ # realtime config
406
+ crossfade_seconds: float,
407
+ additional_infer_before_seconds: float,
408
+ additional_infer_after_seconds: float,
409
+ block_seconds: float,
410
+ version: int,
411
+ input_device: int | str | None,
412
+ output_device: int | str | None,
413
+ device: torch.device,
414
+ passthrough_original: bool = False,
415
+ ) -> None:
416
+ """Realtime inference from microphone"""
417
+ from so_vits_svc_fork.inference.main import realtime
418
+
419
+ if auto_predict_f0:
420
+ LOG.warning(
421
+ "auto_predict_f0 = True in realtime inference will cause unstable voice pitch, use with caution"
422
+ )
423
+ else:
424
+ LOG.warning(
425
+ f"auto_predict_f0 = False, transpose = {transpose}. If you want to change the pitch, please change the transpose value."
426
+ "Generally transpose = 0 does not work because your voice pitch and target voice pitch are different."
427
+ )
428
+ model_path = Path(model_path)
429
+ config_path = Path(config_path)
430
+ if cluster_model_path is not None:
431
+ cluster_model_path = Path(cluster_model_path)
432
+ if model_path.is_dir():
433
+ model_path = list(
434
+ sorted(model_path.glob("G_*.pth"), key=lambda x: x.stat().st_mtime)
435
+ )[-1]
436
+ LOG.info(f"Since model_path is a directory, use {model_path}")
437
+
438
+ realtime(
439
+ # paths
440
+ model_path=model_path,
441
+ config_path=config_path,
442
+ # svc config
443
+ speaker=speaker,
444
+ cluster_model_path=cluster_model_path,
445
+ transpose=transpose,
446
+ auto_predict_f0=auto_predict_f0,
447
+ cluster_infer_ratio=cluster_infer_ratio,
448
+ noise_scale=noise_scale,
449
+ f0_method=f0_method,
450
+ # slice config
451
+ db_thresh=db_thresh,
452
+ pad_seconds=pad_seconds,
453
+ chunk_seconds=chunk_seconds,
454
+ # realtime config
455
+ crossfade_seconds=crossfade_seconds,
456
+ additional_infer_before_seconds=additional_infer_before_seconds,
457
+ additional_infer_after_seconds=additional_infer_after_seconds,
458
+ block_seconds=block_seconds,
459
+ version=version,
460
+ input_device=input_device,
461
+ output_device=output_device,
462
+ device=device,
463
+ passthrough_original=passthrough_original,
464
+ )
465
+
466
+
467
+ @cli.command()
468
+ @click.option(
469
+ "-i",
470
+ "--input-dir",
471
+ type=click.Path(exists=True),
472
+ default=Path("./dataset_raw"),
473
+ help="path to source dir",
474
+ )
475
+ @click.option(
476
+ "-o",
477
+ "--output-dir",
478
+ type=click.Path(),
479
+ default=Path("./dataset/44k"),
480
+ help="path to output dir",
481
+ )
482
+ @click.option("-s", "--sampling-rate", type=int, default=44100, help="sampling rate")
483
+ @click.option(
484
+ "-n",
485
+ "--n-jobs",
486
+ type=int,
487
+ default=-1,
488
+ help="number of jobs (optimal value may depend on your RAM capacity and audio duration per file)",
489
+ )
490
+ @click.option("-d", "--top-db", type=float, default=30, help="top db")
491
+ @click.option("-f", "--frame-seconds", type=float, default=1, help="frame seconds")
492
+ @click.option(
493
+ "-ho", "-hop", "--hop-seconds", type=float, default=0.3, help="hop seconds"
494
+ )
495
+ def pre_resample(
496
+ input_dir: Path,
497
+ output_dir: Path,
498
+ sampling_rate: int,
499
+ n_jobs: int,
500
+ top_db: int,
501
+ frame_seconds: float,
502
+ hop_seconds: float,
503
+ ) -> None:
504
+ """Preprocessing part 1: resample"""
505
+ from so_vits_svc_fork.preprocessing.preprocess_resample import preprocess_resample
506
+
507
+ input_dir = Path(input_dir)
508
+ output_dir = Path(output_dir)
509
+ preprocess_resample(
510
+ input_dir=input_dir,
511
+ output_dir=output_dir,
512
+ sampling_rate=sampling_rate,
513
+ n_jobs=n_jobs,
514
+ top_db=top_db,
515
+ frame_seconds=frame_seconds,
516
+ hop_seconds=hop_seconds,
517
+ )
518
+
519
+
520
+ from so_vits_svc_fork.preprocessing.preprocess_flist_config import CONFIG_TEMPLATE_DIR
521
+
522
+
523
+ @cli.command()
524
+ @click.option(
525
+ "-i",
526
+ "--input-dir",
527
+ type=click.Path(exists=True),
528
+ default=Path("./dataset/44k"),
529
+ help="path to source dir",
530
+ )
531
+ @click.option(
532
+ "-f",
533
+ "--filelist-path",
534
+ type=click.Path(),
535
+ default=Path("./filelists/44k"),
536
+ help="path to filelist dir",
537
+ )
538
+ @click.option(
539
+ "-c",
540
+ "--config-path",
541
+ type=click.Path(),
542
+ default=Path("./configs/44k/config.json"),
543
+ help="path to config",
544
+ )
545
+ @click.option(
546
+ "-t",
547
+ "--config-type",
548
+ type=click.Choice([x.stem for x in CONFIG_TEMPLATE_DIR.rglob("*.json")]),
549
+ default="so-vits-svc-4.0v1",
550
+ help="config type",
551
+ )
552
+ def pre_config(
553
+ input_dir: Path,
554
+ filelist_path: Path,
555
+ config_path: Path,
556
+ config_type: str,
557
+ ):
558
+ """Preprocessing part 2: config"""
559
+ from so_vits_svc_fork.preprocessing.preprocess_flist_config import preprocess_config
560
+
561
+ input_dir = Path(input_dir)
562
+ filelist_path = Path(filelist_path)
563
+ config_path = Path(config_path)
564
+ preprocess_config(
565
+ input_dir=input_dir,
566
+ train_list_path=filelist_path / "train.txt",
567
+ val_list_path=filelist_path / "val.txt",
568
+ test_list_path=filelist_path / "test.txt",
569
+ config_path=config_path,
570
+ config_name=config_type,
571
+ )
572
+
573
+
574
+ @cli.command()
575
+ @click.option(
576
+ "-i",
577
+ "--input-dir",
578
+ type=click.Path(exists=True),
579
+ default=Path("./dataset/44k"),
580
+ help="path to source dir",
581
+ )
582
+ @click.option(
583
+ "-c",
584
+ "--config-path",
585
+ type=click.Path(exists=True),
586
+ help="path to config",
587
+ default=Path("./configs/44k/config.json"),
588
+ )
589
+ @click.option(
590
+ "-n",
591
+ "--n-jobs",
592
+ type=int,
593
+ default=None,
594
+ help="number of jobs (optimal value may depend on your VRAM capacity and audio duration per file)",
595
+ )
596
+ @click.option(
597
+ "-f/-nf",
598
+ "--force-rebuild/--no-force-rebuild",
599
+ type=bool,
600
+ default=True,
601
+ help="force rebuild existing preprocessed files",
602
+ )
603
+ @click.option(
604
+ "-fm",
605
+ "--f0-method",
606
+ type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
607
+ default="dio",
608
+ )
609
+ def pre_hubert(
610
+ input_dir: Path,
611
+ config_path: Path,
612
+ n_jobs: bool,
613
+ force_rebuild: bool,
614
+ f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
615
+ ) -> None:
616
+ """Preprocessing part 3: hubert
617
+ If the HuBERT model is not found, it will be downloaded automatically."""
618
+ from so_vits_svc_fork.preprocessing.preprocess_hubert_f0 import preprocess_hubert_f0
619
+
620
+ input_dir = Path(input_dir)
621
+ config_path = Path(config_path)
622
+ preprocess_hubert_f0(
623
+ input_dir=input_dir,
624
+ config_path=config_path,
625
+ n_jobs=n_jobs,
626
+ force_rebuild=force_rebuild,
627
+ f0_method=f0_method,
628
+ )
629
+
630
+
631
+ @cli.command()
632
+ @click.option(
633
+ "-i",
634
+ "--input-dir",
635
+ type=click.Path(exists=True),
636
+ default=Path("./dataset_raw_raw/"),
637
+ help="path to source dir",
638
+ )
639
+ @click.option(
640
+ "-o",
641
+ "--output-dir",
642
+ type=click.Path(),
643
+ default=Path("./dataset_raw/"),
644
+ help="path to output dir",
645
+ )
646
+ @click.option(
647
+ "-n",
648
+ "--n-jobs",
649
+ type=int,
650
+ default=-1,
651
+ help="number of jobs (optimal value may depend on your VRAM capacity and audio duration per file)",
652
+ )
653
+ @click.option("-min", "--min-speakers", type=int, default=2, help="min speakers")
654
+ @click.option("-max", "--max-speakers", type=int, default=2, help="max speakers")
655
+ @click.option(
656
+ "-t", "--huggingface-token", type=str, default=None, help="huggingface token"
657
+ )
658
+ @click.option("-s", "--sr", type=int, default=44100, help="sampling rate")
659
+ def pre_sd(
660
+ input_dir: Path | str,
661
+ output_dir: Path | str,
662
+ min_speakers: int,
663
+ max_speakers: int,
664
+ huggingface_token: str | None,
665
+ n_jobs: int,
666
+ sr: int,
667
+ ):
668
+ """Speech diarization using pyannote.audio"""
669
+ if huggingface_token is None:
670
+ huggingface_token = os.environ.get("HUGGINGFACE_TOKEN", None)
671
+ if huggingface_token is None:
672
+ huggingface_token = click.prompt(
673
+ "Please enter your HuggingFace token", hide_input=True
674
+ )
675
+ if os.environ.get("HUGGINGFACE_TOKEN", None) is None:
676
+ LOG.info("You can also set the HUGGINGFACE_TOKEN environment variable.")
677
+ assert huggingface_token is not None
678
+ huggingface_token = huggingface_token.rstrip(" \n\r\t\0")
679
+ if len(huggingface_token) <= 1:
680
+ raise ValueError("HuggingFace token is empty: " + huggingface_token)
681
+
682
+ if max_speakers == 1:
683
+ LOG.warning("Consider using pre-split if max_speakers == 1")
684
+ from so_vits_svc_fork.preprocessing.preprocess_speaker_diarization import (
685
+ preprocess_speaker_diarization,
686
+ )
687
+
688
+ preprocess_speaker_diarization(
689
+ input_dir=input_dir,
690
+ output_dir=output_dir,
691
+ min_speakers=min_speakers,
692
+ max_speakers=max_speakers,
693
+ huggingface_token=huggingface_token,
694
+ n_jobs=n_jobs,
695
+ sr=sr,
696
+ )
697
+
698
+
699
+ @cli.command()
700
+ @click.option(
701
+ "-i",
702
+ "--input-dir",
703
+ type=click.Path(exists=True),
704
+ default=Path("./dataset_raw_raw/"),
705
+ help="path to source dir",
706
+ )
707
+ @click.option(
708
+ "-o",
709
+ "--output-dir",
710
+ type=click.Path(),
711
+ default=Path("./dataset_raw/"),
712
+ help="path to output dir",
713
+ )
714
+ @click.option(
715
+ "-n",
716
+ "--n-jobs",
717
+ type=int,
718
+ default=-1,
719
+ help="number of jobs (optimal value may depend on your RAM capacity and audio duration per file)",
720
+ )
721
+ @click.option(
722
+ "-l",
723
+ "--max-length",
724
+ type=float,
725
+ default=10,
726
+ help="max length of each split in seconds",
727
+ )
728
+ @click.option("-d", "--top-db", type=float, default=30, help="top db")
729
+ @click.option("-f", "--frame-seconds", type=float, default=1, help="frame seconds")
730
+ @click.option(
731
+ "-ho", "-hop", "--hop-seconds", type=float, default=0.3, help="hop seconds"
732
+ )
733
+ @click.option("-s", "--sr", type=int, default=44100, help="sample rate")
734
+ def pre_split(
735
+ input_dir: Path | str,
736
+ output_dir: Path | str,
737
+ max_length: float,
738
+ top_db: int,
739
+ frame_seconds: float,
740
+ hop_seconds: float,
741
+ n_jobs: int,
742
+ sr: int,
743
+ ):
744
+ """Split audio files into multiple files"""
745
+ from so_vits_svc_fork.preprocessing.preprocess_split import preprocess_split
746
+
747
+ preprocess_split(
748
+ input_dir=input_dir,
749
+ output_dir=output_dir,
750
+ max_length=max_length,
751
+ top_db=top_db,
752
+ frame_seconds=frame_seconds,
753
+ hop_seconds=hop_seconds,
754
+ n_jobs=n_jobs,
755
+ sr=sr,
756
+ )
757
+
758
+
759
+ @cli.command()
760
+ @click.option(
761
+ "-i",
762
+ "--input-dir",
763
+ type=click.Path(exists=True),
764
+ required=True,
765
+ help="path to source dir",
766
+ )
767
+ @click.option(
768
+ "-o",
769
+ "--output-dir",
770
+ type=click.Path(),
771
+ default=None,
772
+ help="path to output dir",
773
+ )
774
+ @click.option(
775
+ "-c/-nc",
776
+ "--create-new/--no-create-new",
777
+ type=bool,
778
+ default=True,
779
+ help="create a new folder for the speaker if not exist",
780
+ )
781
+ def pre_classify(
782
+ input_dir: Path | str,
783
+ output_dir: Path | str | None,
784
+ create_new: bool,
785
+ ) -> None:
786
+ """Classify multiple audio files into multiple files"""
787
+ from so_vits_svc_fork.preprocessing.preprocess_classify import preprocess_classify
788
+
789
+ if output_dir is None:
790
+ output_dir = input_dir
791
+ preprocess_classify(
792
+ input_dir=input_dir,
793
+ output_dir=output_dir,
794
+ create_new=create_new,
795
+ )
796
+
797
+
798
+ @cli.command
799
+ def clean():
800
+ """Clean up files, only useful if you are using the default file structure"""
801
+ import shutil
802
+
803
+ folders = ["dataset", "filelists", "logs"]
804
+ # if pyip.inputYesNo(f"Are you sure you want to delete files in {folders}?") == "yes":
805
+ if input("Are you sure you want to delete files in {folders}?") in ["yes", "y"]:
806
+ for folder in folders:
807
+ if Path(folder).exists():
808
+ shutil.rmtree(folder)
809
+ LOG.info("Cleaned up files")
810
+ else:
811
+ LOG.info("Aborted")
812
+
813
+
814
+ @cli.command
815
+ @click.option(
816
+ "-i",
817
+ "--input-path",
818
+ type=click.Path(exists=True),
819
+ help="model path",
820
+ default=Path("./logs/44k/"),
821
+ )
822
+ @click.option(
823
+ "-o",
824
+ "--output-path",
825
+ type=click.Path(),
826
+ help="onnx model path to save",
827
+ default=None,
828
+ )
829
+ @click.option(
830
+ "-c",
831
+ "--config-path",
832
+ type=click.Path(),
833
+ help="config path",
834
+ default=Path("./configs/44k/config.json"),
835
+ )
836
+ @click.option(
837
+ "-d",
838
+ "--device",
839
+ type=str,
840
+ default="cpu",
841
+ help="device to use",
842
+ )
843
+ def onnx(
844
+ input_path: Path, output_path: Path, config_path: Path, device: torch.device | str
845
+ ) -> None:
846
+ """Export model to onnx (currently not working)"""
847
+ raise NotImplementedError("ONNX export is not yet supported")
848
+ input_path = Path(input_path)
849
+ if input_path.is_dir():
850
+ input_path = list(input_path.glob("*.pth"))[0]
851
+ if output_path is None:
852
+ output_path = input_path.with_suffix(".onnx")
853
+ output_path = Path(output_path)
854
+ if output_path.is_dir():
855
+ output_path = output_path / (input_path.stem + ".onnx")
856
+ config_path = Path(config_path)
857
+ device_ = torch.device(device)
858
+ from so_vits_svc_fork.modules.onnx._export import onnx_export
859
+
860
+ onnx_export(
861
+ input_path=input_path,
862
+ output_path=output_path,
863
+ config_path=config_path,
864
+ device=device_,
865
+ )
866
+
867
+
868
+ @cli.command
869
+ @click.option(
870
+ "-i",
871
+ "--input-dir",
872
+ type=click.Path(exists=True),
873
+ help="dataset directory",
874
+ default=Path("./dataset/44k"),
875
+ )
876
+ @click.option(
877
+ "-o",
878
+ "--output-path",
879
+ type=click.Path(),
880
+ help="model path to save",
881
+ default=Path("./logs/44k/kmeans.pt"),
882
+ )
883
+ @click.option("-n", "--n-clusters", type=int, help="number of clusters", default=2000)
884
+ @click.option(
885
+ "-m/-nm", "--minibatch/--no-minibatch", default=True, help="use minibatch k-means"
886
+ )
887
+ @click.option(
888
+ "-b", "--batch-size", type=int, default=4096, help="batch size for minibatch kmeans"
889
+ )
890
+ @click.option(
891
+ "-p/-np", "--partial-fit", default=False, help="use partial fit (only use with -m)"
892
+ )
893
+ def train_cluster(
894
+ input_dir: Path,
895
+ output_path: Path,
896
+ n_clusters: int,
897
+ minibatch: bool,
898
+ batch_size: int,
899
+ partial_fit: bool,
900
+ ) -> None:
901
+ """Train k-means clustering"""
902
+ from .cluster.train_cluster import main
903
+
904
+ main(
905
+ input_dir=input_dir,
906
+ output_path=output_path,
907
+ n_clusters=n_clusters,
908
+ verbose=True,
909
+ use_minibatch=minibatch,
910
+ batch_size=batch_size,
911
+ partial_fit=partial_fit,
912
+ )
913
+
914
+
915
+ if __name__ == "__main__":
916
+ freeze_support()
917
+ cli()
so_vits_svc_fork/cluster/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import torch
7
+ from sklearn.cluster import KMeans
8
+
9
+
10
+ def get_cluster_model(ckpt_path: Path | str):
11
+ with Path(ckpt_path).open("rb") as f:
12
+ checkpoint = torch.load(
13
+ f, map_location="cpu"
14
+ ) # Danger of arbitrary code execution
15
+ kmeans_dict = {}
16
+ for spk, ckpt in checkpoint.items():
17
+ km = KMeans(ckpt["n_features_in_"])
18
+ km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
19
+ km.__dict__["_n_threads"] = ckpt["_n_threads"]
20
+ km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"]
21
+ kmeans_dict[spk] = km
22
+ return kmeans_dict
23
+
24
+
25
+ def check_speaker(model: Any, speaker: Any):
26
+ if speaker not in model:
27
+ raise ValueError(f"Speaker {speaker} not in {list(model.keys())}")
28
+
29
+
30
+ def get_cluster_result(model: Any, x: Any, speaker: Any):
31
+ """
32
+ x: np.array [t, 256]
33
+ return cluster class result
34
+ """
35
+ check_speaker(model, speaker)
36
+ return model[speaker].predict(x)
37
+
38
+
39
+ def get_cluster_center_result(model: Any, x: Any, speaker: Any):
40
+ """x: np.array [t, 256]"""
41
+ check_speaker(model, speaker)
42
+ predict = model[speaker].predict(x)
43
+ return model[speaker].cluster_centers_[predict]
44
+
45
+
46
+ def get_center(model: Any, x: Any, speaker: Any):
47
+ check_speaker(model, speaker)
48
+ return model[speaker].cluster_centers_[x]
so_vits_svc_fork/cluster/train_cluster.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from logging import getLogger
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+ from cm_time import timer
11
+ from joblib import Parallel, delayed
12
+ from sklearn.cluster import KMeans, MiniBatchKMeans
13
+ from tqdm_joblib import tqdm_joblib
14
+
15
+ LOG = getLogger(__name__)
16
+
17
+
18
+ def train_cluster(
19
+ input_dir: Path | str,
20
+ n_clusters: int,
21
+ use_minibatch: bool = True,
22
+ batch_size: int = 4096,
23
+ partial_fit: bool = False,
24
+ verbose: bool = False,
25
+ ) -> dict:
26
+ input_dir = Path(input_dir)
27
+ if not partial_fit:
28
+ LOG.info(f"Loading features from {input_dir}")
29
+ features = []
30
+ for path in input_dir.rglob("*.data.pt"):
31
+ with path.open("rb") as f:
32
+ features.append(
33
+ torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T
34
+ )
35
+ if not features:
36
+ raise ValueError(f"No features found in {input_dir}")
37
+ features = np.concatenate(features, axis=0).astype(np.float32)
38
+ if features.shape[0] < n_clusters:
39
+ raise ValueError(
40
+ "Too few HuBERT features to cluster. Consider using a smaller number of clusters."
41
+ )
42
+ LOG.info(
43
+ f"shape: {features.shape}, size: {features.nbytes/1024**2:.2f} MB, dtype: {features.dtype}"
44
+ )
45
+ with timer() as t:
46
+ if use_minibatch:
47
+ kmeans = MiniBatchKMeans(
48
+ n_clusters=n_clusters,
49
+ verbose=verbose,
50
+ batch_size=batch_size,
51
+ max_iter=80,
52
+ n_init="auto",
53
+ ).fit(features)
54
+ else:
55
+ kmeans = KMeans(
56
+ n_clusters=n_clusters, verbose=verbose, n_init="auto"
57
+ ).fit(features)
58
+ LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
59
+
60
+ x = {
61
+ "n_features_in_": kmeans.n_features_in_,
62
+ "_n_threads": kmeans._n_threads,
63
+ "cluster_centers_": kmeans.cluster_centers_,
64
+ }
65
+ return x
66
+ else:
67
+ # minibatch partial fit
68
+ paths = list(input_dir.rglob("*.data.pt"))
69
+ if len(paths) == 0:
70
+ raise ValueError(f"No features found in {input_dir}")
71
+ LOG.info(f"Found {len(paths)} features in {input_dir}")
72
+ n_batches = math.ceil(len(paths) / batch_size)
73
+ LOG.info(f"Splitting into {n_batches} batches")
74
+ with timer() as t:
75
+ kmeans = MiniBatchKMeans(
76
+ n_clusters=n_clusters,
77
+ verbose=verbose,
78
+ batch_size=batch_size,
79
+ max_iter=80,
80
+ n_init="auto",
81
+ )
82
+ for i in range(0, len(paths), batch_size):
83
+ LOG.info(
84
+ f"Processing batch {i//batch_size+1}/{n_batches} for speaker {input_dir.stem}"
85
+ )
86
+ features = []
87
+ for path in paths[i : i + batch_size]:
88
+ with path.open("rb") as f:
89
+ features.append(
90
+ torch.load(f, weights_only=True)["content"]
91
+ .squeeze(0)
92
+ .numpy()
93
+ .T
94
+ )
95
+ features = np.concatenate(features, axis=0).astype(np.float32)
96
+ kmeans.partial_fit(features)
97
+ LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
98
+
99
+ x = {
100
+ "n_features_in_": kmeans.n_features_in_,
101
+ "_n_threads": kmeans._n_threads,
102
+ "cluster_centers_": kmeans.cluster_centers_,
103
+ }
104
+ return x
105
+
106
+
107
+ def main(
108
+ input_dir: Path | str,
109
+ output_path: Path | str,
110
+ n_clusters: int = 10000,
111
+ use_minibatch: bool = True,
112
+ batch_size: int = 4096,
113
+ partial_fit: bool = False,
114
+ verbose: bool = False,
115
+ ) -> None:
116
+ input_dir = Path(input_dir)
117
+ output_path = Path(output_path)
118
+
119
+ if not (use_minibatch or not partial_fit):
120
+ raise ValueError("partial_fit requires use_minibatch")
121
+
122
+ def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]:
123
+ return input_path.stem, train_cluster(input_path, **kwargs)
124
+
125
+ with tqdm_joblib(desc="Training clusters", total=len(list(input_dir.iterdir()))):
126
+ parallel_result = Parallel(n_jobs=-1)(
127
+ delayed(train_cluster_)(
128
+ speaker_name,
129
+ n_clusters=n_clusters,
130
+ use_minibatch=use_minibatch,
131
+ batch_size=batch_size,
132
+ partial_fit=partial_fit,
133
+ verbose=verbose,
134
+ )
135
+ for speaker_name in input_dir.iterdir()
136
+ )
137
+ assert parallel_result is not None
138
+ checkpoint = dict(parallel_result)
139
+ output_path.parent.mkdir(exist_ok=True, parents=True)
140
+ with output_path.open("wb") as f:
141
+ torch.save(checkpoint, f)
so_vits_svc_fork/dataset.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from random import Random
5
+ from typing import Sequence
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.utils.data import Dataset
11
+
12
+ from .hparams import HParams
13
+
14
+
15
+ class TextAudioDataset(Dataset):
16
+ def __init__(self, hps: HParams, is_validation: bool = False):
17
+ self.datapaths = [
18
+ Path(x).parent / (Path(x).name + ".data.pt")
19
+ for x in Path(
20
+ hps.data.validation_files if is_validation else hps.data.training_files
21
+ )
22
+ .read_text("utf-8")
23
+ .splitlines()
24
+ ]
25
+ self.hps = hps
26
+ self.random = Random(hps.train.seed)
27
+ self.random.shuffle(self.datapaths)
28
+ self.max_spec_len = 800
29
+
30
+ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
31
+ with Path(self.datapaths[index]).open("rb") as f:
32
+ data = torch.load(f, weights_only=True, map_location="cpu")
33
+
34
+ # cut long data randomly
35
+ spec_len = data["mel_spec"].shape[1]
36
+ hop_len = self.hps.data.hop_length
37
+ if spec_len > self.max_spec_len:
38
+ start = self.random.randint(0, spec_len - self.max_spec_len)
39
+ end = start + self.max_spec_len - 10
40
+ for key in data.keys():
41
+ if key == "audio":
42
+ data[key] = data[key][:, start * hop_len : end * hop_len]
43
+ elif key == "spk":
44
+ continue
45
+ else:
46
+ data[key] = data[key][..., start:end]
47
+ torch.cuda.empty_cache()
48
+ return data
49
+
50
+ def __len__(self) -> int:
51
+ return len(self.datapaths)
52
+
53
+
54
+ def _pad_stack(array: Sequence[torch.Tensor]) -> torch.Tensor:
55
+ max_idx = torch.argmax(torch.tensor([x_.shape[-1] for x_ in array]))
56
+ max_x = array[max_idx]
57
+ x_padded = [
58
+ F.pad(x_, (0, max_x.shape[-1] - x_.shape[-1]), mode="constant", value=0)
59
+ for x_ in array
60
+ ]
61
+ return torch.stack(x_padded)
62
+
63
+
64
+ class TextAudioCollate(nn.Module):
65
+ def forward(
66
+ self, batch: Sequence[dict[str, torch.Tensor]]
67
+ ) -> tuple[torch.Tensor, ...]:
68
+ batch = [b for b in batch if b is not None]
69
+ batch = list(sorted(batch, key=lambda x: x["mel_spec"].shape[1], reverse=True))
70
+ lengths = torch.tensor([b["mel_spec"].shape[1] for b in batch]).long()
71
+ results = {}
72
+ for key in batch[0].keys():
73
+ if key not in ["spk"]:
74
+ results[key] = _pad_stack([b[key] for b in batch]).cpu()
75
+ else:
76
+ results[key] = torch.tensor([[b[key]] for b in batch]).cpu()
77
+
78
+ return (
79
+ results["content"],
80
+ results["f0"],
81
+ results["spec"],
82
+ results["mel_spec"],
83
+ results["audio"],
84
+ results["spk"],
85
+ lengths,
86
+ results["uv"],
87
+ )
so_vits_svc_fork/default_gui_presets.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Default VC (GPU, GTX 1060)": {
3
+ "silence_threshold": -35.0,
4
+ "transpose": 12.0,
5
+ "auto_predict_f0": false,
6
+ "f0_method": "dio",
7
+ "cluster_infer_ratio": 0.0,
8
+ "noise_scale": 0.4,
9
+ "pad_seconds": 0.1,
10
+ "chunk_seconds": 0.5,
11
+ "absolute_thresh": true,
12
+ "max_chunk_seconds": 40,
13
+ "crossfade_seconds": 0.05,
14
+ "block_seconds": 0.35,
15
+ "additional_infer_before_seconds": 0.15,
16
+ "additional_infer_after_seconds": 0.1,
17
+ "realtime_algorithm": "1 (Divide constantly)",
18
+ "passthrough_original": false,
19
+ "use_gpu": true
20
+ },
21
+ "Default VC (CPU)": {
22
+ "silence_threshold": -35.0,
23
+ "transpose": 12.0,
24
+ "auto_predict_f0": false,
25
+ "f0_method": "dio",
26
+ "cluster_infer_ratio": 0.0,
27
+ "noise_scale": 0.4,
28
+ "pad_seconds": 0.1,
29
+ "chunk_seconds": 0.5,
30
+ "absolute_thresh": true,
31
+ "max_chunk_seconds": 40,
32
+ "crossfade_seconds": 0.05,
33
+ "block_seconds": 1.5,
34
+ "additional_infer_before_seconds": 0.01,
35
+ "additional_infer_after_seconds": 0.01,
36
+ "realtime_algorithm": "1 (Divide constantly)",
37
+ "passthrough_original": false,
38
+ "use_gpu": false
39
+ },
40
+ "Default VC (Mobile CPU)": {
41
+ "silence_threshold": -35.0,
42
+ "transpose": 12.0,
43
+ "auto_predict_f0": false,
44
+ "f0_method": "dio",
45
+ "cluster_infer_ratio": 0.0,
46
+ "noise_scale": 0.4,
47
+ "pad_seconds": 0.1,
48
+ "chunk_seconds": 0.5,
49
+ "absolute_thresh": true,
50
+ "max_chunk_seconds": 40,
51
+ "crossfade_seconds": 0.05,
52
+ "block_seconds": 2.5,
53
+ "additional_infer_before_seconds": 0.01,
54
+ "additional_infer_after_seconds": 0.01,
55
+ "realtime_algorithm": "1 (Divide constantly)",
56
+ "passthrough_original": false,
57
+ "use_gpu": false
58
+ },
59
+ "Default VC (Crooning)": {
60
+ "silence_threshold": -35.0,
61
+ "transpose": 12.0,
62
+ "auto_predict_f0": false,
63
+ "f0_method": "dio",
64
+ "cluster_infer_ratio": 0.0,
65
+ "noise_scale": 0.4,
66
+ "pad_seconds": 0.1,
67
+ "chunk_seconds": 0.5,
68
+ "absolute_thresh": true,
69
+ "max_chunk_seconds": 40,
70
+ "crossfade_seconds": 0.04,
71
+ "block_seconds": 0.15,
72
+ "additional_infer_before_seconds": 0.05,
73
+ "additional_infer_after_seconds": 0.05,
74
+ "realtime_algorithm": "1 (Divide constantly)",
75
+ "passthrough_original": false,
76
+ "use_gpu": true
77
+ },
78
+ "Default File": {
79
+ "silence_threshold": -35.0,
80
+ "transpose": 0.0,
81
+ "auto_predict_f0": true,
82
+ "f0_method": "crepe",
83
+ "cluster_infer_ratio": 0.0,
84
+ "noise_scale": 0.4,
85
+ "pad_seconds": 0.1,
86
+ "chunk_seconds": 0.5,
87
+ "absolute_thresh": true,
88
+ "max_chunk_seconds": 40,
89
+ "auto_play": true,
90
+ "passthrough_original": false
91
+ }
92
+ }
so_vits_svc_fork/f0.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from logging import getLogger
4
+ from typing import Any, Literal
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torchcrepe
9
+ from cm_time import timer
10
+ from numpy import dtype, float32, ndarray
11
+ from torch import FloatTensor, Tensor
12
+
13
+ from so_vits_svc_fork.utils import get_optimal_device
14
+
15
+ LOG = getLogger(__name__)
16
+
17
+
18
+ def normalize_f0(
19
+ f0: FloatTensor, x_mask: FloatTensor, uv: FloatTensor, random_scale=True
20
+ ) -> FloatTensor:
21
+ # calculate means based on x_mask
22
+ uv_sum = torch.sum(uv, dim=1, keepdim=True)
23
+ uv_sum[uv_sum == 0] = 9999
24
+ means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum
25
+
26
+ if random_scale:
27
+ factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
28
+ else:
29
+ factor = torch.ones(f0.shape[0], 1).to(f0.device)
30
+ # normalize f0 based on means and factor
31
+ f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
32
+ if torch.isnan(f0_norm).any():
33
+ exit(0)
34
+ return f0_norm * x_mask
35
+
36
+
37
+ def interpolate_f0(
38
+ f0: ndarray[Any, dtype[float32]]
39
+ ) -> tuple[ndarray[Any, dtype[float32]], ndarray[Any, dtype[float32]]]:
40
+ data = np.reshape(f0, (f0.size, 1))
41
+
42
+ vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
43
+ vuv_vector[data > 0.0] = 1.0
44
+ vuv_vector[data <= 0.0] = 0.0
45
+
46
+ ip_data = data
47
+
48
+ frame_number = data.size
49
+ last_value = 0.0
50
+ for i in range(frame_number):
51
+ if data[i] <= 0.0:
52
+ j = i + 1
53
+ for j in range(i + 1, frame_number):
54
+ if data[j] > 0.0:
55
+ break
56
+ if j < frame_number - 1:
57
+ if last_value > 0.0:
58
+ step = (data[j] - data[i - 1]) / float(j - i)
59
+ for k in range(i, j):
60
+ ip_data[k] = data[i - 1] + step * (k - i + 1)
61
+ else:
62
+ for k in range(i, j):
63
+ ip_data[k] = data[j]
64
+ else:
65
+ for k in range(i, frame_number):
66
+ ip_data[k] = last_value
67
+ else:
68
+ ip_data[i] = data[i]
69
+ last_value = data[i]
70
+
71
+ return ip_data[:, 0], vuv_vector[:, 0]
72
+
73
+
74
+ def compute_f0_parselmouth(
75
+ wav_numpy: ndarray[Any, dtype[float32]],
76
+ p_len: None | int = None,
77
+ sampling_rate: int = 44100,
78
+ hop_length: int = 512,
79
+ ):
80
+ import parselmouth
81
+
82
+ x = wav_numpy
83
+ if p_len is None:
84
+ p_len = x.shape[0] // hop_length
85
+ else:
86
+ assert abs(p_len - x.shape[0] // hop_length) < 4, "pad length error"
87
+ time_step = hop_length / sampling_rate * 1000
88
+ f0_min = 50
89
+ f0_max = 1100
90
+ f0 = (
91
+ parselmouth.Sound(x, sampling_rate)
92
+ .to_pitch_ac(
93
+ time_step=time_step / 1000,
94
+ voicing_threshold=0.6,
95
+ pitch_floor=f0_min,
96
+ pitch_ceiling=f0_max,
97
+ )
98
+ .selected_array["frequency"]
99
+ )
100
+
101
+ pad_size = (p_len - len(f0) + 1) // 2
102
+ if pad_size > 0 or p_len - len(f0) - pad_size > 0:
103
+ f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
104
+ return f0
105
+
106
+
107
+ def _resize_f0(
108
+ x: ndarray[Any, dtype[float32]], target_len: int
109
+ ) -> ndarray[Any, dtype[float32]]:
110
+ source = np.array(x)
111
+ source[source < 0.001] = np.nan
112
+ target = np.interp(
113
+ np.arange(0, len(source) * target_len, len(source)) / target_len,
114
+ np.arange(0, len(source)),
115
+ source,
116
+ )
117
+ res = np.nan_to_num(target)
118
+ return res
119
+
120
+
121
+ def compute_f0_pyworld(
122
+ wav_numpy: ndarray[Any, dtype[float32]],
123
+ p_len: None | int = None,
124
+ sampling_rate: int = 44100,
125
+ hop_length: int = 512,
126
+ type_: Literal["dio", "harvest"] = "dio",
127
+ ):
128
+ import pyworld
129
+
130
+ if p_len is None:
131
+ p_len = wav_numpy.shape[0] // hop_length
132
+ if type_ == "dio":
133
+ f0, t = pyworld.dio(
134
+ wav_numpy.astype(np.double),
135
+ fs=sampling_rate,
136
+ f0_ceil=f0_max,
137
+ f0_floor=f0_min,
138
+ frame_period=1000 * hop_length / sampling_rate,
139
+ )
140
+ elif type_ == "harvest":
141
+ f0, t = pyworld.harvest(
142
+ wav_numpy.astype(np.double),
143
+ fs=sampling_rate,
144
+ f0_ceil=f0_max,
145
+ f0_floor=f0_min,
146
+ frame_period=1000 * hop_length / sampling_rate,
147
+ )
148
+ f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
149
+ for index, pitch in enumerate(f0):
150
+ f0[index] = round(pitch, 1)
151
+ return _resize_f0(f0, p_len)
152
+
153
+
154
+ def compute_f0_crepe(
155
+ wav_numpy: ndarray[Any, dtype[float32]],
156
+ p_len: None | int = None,
157
+ sampling_rate: int = 44100,
158
+ hop_length: int = 512,
159
+ device: str | torch.device = get_optimal_device(),
160
+ model: Literal["full", "tiny"] = "full",
161
+ ):
162
+ audio = torch.from_numpy(wav_numpy).to(device, copy=True)
163
+ audio = torch.unsqueeze(audio, dim=0)
164
+
165
+ if audio.ndim == 2 and audio.shape[0] > 1:
166
+ audio = torch.mean(audio, dim=0, keepdim=True).detach()
167
+ # (T) -> (1, T)
168
+ audio = audio.detach()
169
+
170
+ pitch: Tensor = torchcrepe.predict(
171
+ audio,
172
+ sampling_rate,
173
+ hop_length,
174
+ f0_min,
175
+ f0_max,
176
+ model,
177
+ batch_size=hop_length * 2,
178
+ device=device,
179
+ pad=True,
180
+ )
181
+
182
+ f0 = pitch.squeeze(0).cpu().float().numpy()
183
+ p_len = p_len or wav_numpy.shape[0] // hop_length
184
+ f0 = _resize_f0(f0, p_len)
185
+ return f0
186
+
187
+
188
+ def compute_f0(
189
+ wav_numpy: ndarray[Any, dtype[float32]],
190
+ p_len: None | int = None,
191
+ sampling_rate: int = 44100,
192
+ hop_length: int = 512,
193
+ method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
194
+ **kwargs,
195
+ ):
196
+ with timer() as t:
197
+ wav_numpy = wav_numpy.astype(np.float32)
198
+ wav_numpy /= np.quantile(np.abs(wav_numpy), 0.999)
199
+ if method in ["dio", "harvest"]:
200
+ f0 = compute_f0_pyworld(wav_numpy, p_len, sampling_rate, hop_length, method)
201
+ elif method == "crepe":
202
+ f0 = compute_f0_crepe(wav_numpy, p_len, sampling_rate, hop_length, **kwargs)
203
+ elif method == "crepe-tiny":
204
+ f0 = compute_f0_crepe(
205
+ wav_numpy, p_len, sampling_rate, hop_length, model="tiny", **kwargs
206
+ )
207
+ elif method == "parselmouth":
208
+ f0 = compute_f0_parselmouth(wav_numpy, p_len, sampling_rate, hop_length)
209
+ else:
210
+ raise ValueError(
211
+ "type must be dio, crepe, crepe-tiny, harvest or parselmouth"
212
+ )
213
+ rtf = t.elapsed / (len(wav_numpy) / sampling_rate)
214
+ LOG.info(f"F0 inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
215
+ return f0
216
+
217
+
218
+ def f0_to_coarse(f0: torch.Tensor | float):
219
+ is_torch = isinstance(f0, torch.Tensor)
220
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
221
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (
222
+ f0_mel_max - f0_mel_min
223
+ ) + 1
224
+
225
+ f0_mel[f0_mel <= 1] = 1
226
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
227
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int)
228
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
229
+ f0_coarse.max(),
230
+ f0_coarse.min(),
231
+ )
232
+ return f0_coarse
233
+
234
+
235
+ f0_bin = 256
236
+ f0_max = 1100.0
237
+ f0_min = 50.0
238
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
239
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
so_vits_svc_fork/gui.py ADDED
@@ -0,0 +1,851 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import multiprocessing
5
+ import os
6
+ from copy import copy
7
+ from logging import getLogger
8
+ from pathlib import Path
9
+
10
+ import PySimpleGUI as sg
11
+ import sounddevice as sd
12
+ import soundfile as sf
13
+ import torch
14
+ from pebble import ProcessFuture, ProcessPool
15
+
16
+ from . import __version__
17
+ from .utils import get_optimal_device
18
+
19
+ GUI_DEFAULT_PRESETS_PATH = Path(__file__).parent / "default_gui_presets.json"
20
+ GUI_PRESETS_PATH = Path("./user_gui_presets.json").absolute()
21
+
22
+ LOG = getLogger(__name__)
23
+
24
+
25
+ def play_audio(path: Path | str):
26
+ if isinstance(path, Path):
27
+ path = path.as_posix()
28
+ data, sr = sf.read(path)
29
+ sd.play(data, sr)
30
+
31
+
32
+ def load_presets() -> dict:
33
+ defaults = json.loads(GUI_DEFAULT_PRESETS_PATH.read_text("utf-8"))
34
+ users = (
35
+ json.loads(GUI_PRESETS_PATH.read_text("utf-8"))
36
+ if GUI_PRESETS_PATH.exists()
37
+ else {}
38
+ )
39
+ # prioriy: defaults > users
40
+ # order: defaults -> users
41
+ return {**defaults, **users, **defaults}
42
+
43
+
44
+ def add_preset(name: str, preset: dict) -> dict:
45
+ presets = load_presets()
46
+ presets[name] = preset
47
+ with GUI_PRESETS_PATH.open("w") as f:
48
+ json.dump(presets, f, indent=2)
49
+ return load_presets()
50
+
51
+
52
+ def delete_preset(name: str) -> dict:
53
+ presets = load_presets()
54
+ if name in presets:
55
+ del presets[name]
56
+ else:
57
+ LOG.warning(f"Cannot delete preset {name} because it does not exist.")
58
+ with GUI_PRESETS_PATH.open("w") as f:
59
+ json.dump(presets, f, indent=2)
60
+ return load_presets()
61
+
62
+
63
+ def get_output_path(input_path: Path) -> Path:
64
+ # Default output path
65
+ output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}"
66
+
67
+ # Increment file number in path if output file already exists
68
+ file_num = 1
69
+ while output_path.exists():
70
+ output_path = (
71
+ input_path.parent / f"{input_path.stem}.out_{file_num}{input_path.suffix}"
72
+ )
73
+ file_num += 1
74
+ return output_path
75
+
76
+
77
+ def get_supported_file_types() -> tuple[tuple[str, str], ...]:
78
+ res = tuple(
79
+ [
80
+ (extension, f".{extension.lower()}")
81
+ for extension in sf.available_formats().keys()
82
+ ]
83
+ )
84
+
85
+ # Sort by popularity
86
+ common_file_types = ["WAV", "MP3", "FLAC", "OGG", "M4A", "WMA"]
87
+ res = sorted(
88
+ res,
89
+ key=lambda x: common_file_types.index(x[0])
90
+ if x[0] in common_file_types
91
+ else len(common_file_types),
92
+ )
93
+ return res
94
+
95
+
96
+ def get_supported_file_types_concat() -> tuple[tuple[str, str], ...]:
97
+ return (("Audio", " ".join(sf.available_formats().keys())),)
98
+
99
+
100
+ def validate_output_file_type(output_path: Path) -> bool:
101
+ supported_file_types = sorted(
102
+ [f".{extension.lower()}" for extension in sf.available_formats().keys()]
103
+ )
104
+ if not output_path.suffix:
105
+ sg.popup_ok(
106
+ "Error: Output path missing file type extension, enter "
107
+ + "one of the following manually:\n\n"
108
+ + "\n".join(supported_file_types)
109
+ )
110
+ return False
111
+ if output_path.suffix.lower() not in supported_file_types:
112
+ sg.popup_ok(
113
+ f"Error: {output_path.suffix.lower()} is not a supported "
114
+ + "extension; use one of the following:\n\n"
115
+ + "\n".join(supported_file_types)
116
+ )
117
+ return False
118
+ return True
119
+
120
+
121
+ def get_devices(
122
+ update: bool = True,
123
+ ) -> tuple[list[str], list[str], list[int], list[int]]:
124
+ if update:
125
+ sd._terminate()
126
+ sd._initialize()
127
+ devices = sd.query_devices()
128
+ hostapis = sd.query_hostapis()
129
+ for hostapi in hostapis:
130
+ for device_idx in hostapi["devices"]:
131
+ devices[device_idx]["hostapi_name"] = hostapi["name"]
132
+ input_devices = [
133
+ f"{d['name']} ({d['hostapi_name']})"
134
+ for d in devices
135
+ if d["max_input_channels"] > 0
136
+ ]
137
+ output_devices = [
138
+ f"{d['name']} ({d['hostapi_name']})"
139
+ for d in devices
140
+ if d["max_output_channels"] > 0
141
+ ]
142
+ input_devices_indices = [d["index"] for d in devices if d["max_input_channels"] > 0]
143
+ output_devices_indices = [
144
+ d["index"] for d in devices if d["max_output_channels"] > 0
145
+ ]
146
+ return input_devices, output_devices, input_devices_indices, output_devices_indices
147
+
148
+
149
+ def after_inference(window: sg.Window, path: Path, auto_play: bool, output_path: Path):
150
+ try:
151
+ LOG.info(f"Finished inference for {path.stem}{path.suffix}")
152
+ window["infer"].update(disabled=False)
153
+
154
+ if auto_play:
155
+ play_audio(output_path)
156
+ except Exception as e:
157
+ LOG.exception(e)
158
+
159
+
160
+ def main():
161
+ LOG.info(f"version: {__version__}")
162
+
163
+ # sg.theme("Dark")
164
+ sg.theme_add_new(
165
+ "Very Dark",
166
+ {
167
+ "BACKGROUND": "#111111",
168
+ "TEXT": "#FFFFFF",
169
+ "INPUT": "#444444",
170
+ "TEXT_INPUT": "#FFFFFF",
171
+ "SCROLL": "#333333",
172
+ "BUTTON": ("white", "#112233"),
173
+ "PROGRESS": ("#111111", "#333333"),
174
+ "BORDER": 2,
175
+ "SLIDER_DEPTH": 2,
176
+ "PROGRESS_DEPTH": 2,
177
+ },
178
+ )
179
+ sg.theme("Very Dark")
180
+
181
+ model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth")))
182
+
183
+ frame_contents = {
184
+ "Paths": [
185
+ [
186
+ sg.Text("Model path"),
187
+ sg.Push(),
188
+ sg.InputText(
189
+ key="model_path",
190
+ default_text=model_candidates[-1].absolute().as_posix()
191
+ if model_candidates
192
+ else "",
193
+ enable_events=True,
194
+ ),
195
+ sg.FileBrowse(
196
+ initial_folder=Path("./logs/44k/").absolute
197
+ if Path("./logs/44k/").exists()
198
+ else Path(".").absolute().as_posix(),
199
+ key="model_path_browse",
200
+ file_types=(
201
+ ("PyTorch", "G_*.pth G_*.pt"),
202
+ ("Pytorch", "*.pth *.pt"),
203
+ ),
204
+ ),
205
+ ],
206
+ [
207
+ sg.Text("Config path"),
208
+ sg.Push(),
209
+ sg.InputText(
210
+ key="config_path",
211
+ default_text=Path("./configs/44k/config.json").absolute().as_posix()
212
+ if Path("./configs/44k/config.json").exists()
213
+ else "",
214
+ enable_events=True,
215
+ ),
216
+ sg.FileBrowse(
217
+ initial_folder=Path("./configs/44k/").as_posix()
218
+ if Path("./configs/44k/").exists()
219
+ else Path(".").absolute().as_posix(),
220
+ key="config_path_browse",
221
+ file_types=(("JSON", "*.json"),),
222
+ ),
223
+ ],
224
+ [
225
+ sg.Text("Cluster model path (Optional)"),
226
+ sg.Push(),
227
+ sg.InputText(
228
+ key="cluster_model_path",
229
+ default_text=Path("./logs/44k/kmeans.pt").absolute().as_posix()
230
+ if Path("./logs/44k/kmeans.pt").exists()
231
+ else "",
232
+ enable_events=True,
233
+ ),
234
+ sg.FileBrowse(
235
+ initial_folder="./logs/44k/"
236
+ if Path("./logs/44k/").exists()
237
+ else ".",
238
+ key="cluster_model_path_browse",
239
+ file_types=(("PyTorch", "*.pt"), ("Pickle", "*.pt *.pth *.pkl")),
240
+ ),
241
+ ],
242
+ ],
243
+ "Common": [
244
+ [
245
+ sg.Text("Speaker"),
246
+ sg.Push(),
247
+ sg.Combo(values=[], key="speaker", size=(20, 1)),
248
+ ],
249
+ [
250
+ sg.Text("Silence threshold"),
251
+ sg.Push(),
252
+ sg.Slider(
253
+ range=(-60.0, 0),
254
+ orientation="h",
255
+ key="silence_threshold",
256
+ resolution=0.1,
257
+ ),
258
+ ],
259
+ [
260
+ sg.Text(
261
+ "Pitch (12 = 1 octave)\n"
262
+ "ADJUST THIS based on your voice\n"
263
+ "when Auto predict F0 is turned off.",
264
+ size=(None, 4),
265
+ ),
266
+ sg.Push(),
267
+ sg.Slider(
268
+ range=(-36, 36),
269
+ orientation="h",
270
+ key="transpose",
271
+ tick_interval=12,
272
+ ),
273
+ ],
274
+ [
275
+ sg.Checkbox(
276
+ key="auto_predict_f0",
277
+ text="Auto predict F0 (Pitch may become unstable when turned on in real-time inference.)",
278
+ )
279
+ ],
280
+ [
281
+ sg.Text("F0 prediction method"),
282
+ sg.Push(),
283
+ sg.Combo(
284
+ ["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
285
+ key="f0_method",
286
+ ),
287
+ ],
288
+ [
289
+ sg.Text("Cluster infer ratio"),
290
+ sg.Push(),
291
+ sg.Slider(
292
+ range=(0, 1.0),
293
+ orientation="h",
294
+ key="cluster_infer_ratio",
295
+ resolution=0.01,
296
+ ),
297
+ ],
298
+ [
299
+ sg.Text("Noise scale"),
300
+ sg.Push(),
301
+ sg.Slider(
302
+ range=(0.0, 1.0),
303
+ orientation="h",
304
+ key="noise_scale",
305
+ resolution=0.01,
306
+ ),
307
+ ],
308
+ [
309
+ sg.Text("Pad seconds"),
310
+ sg.Push(),
311
+ sg.Slider(
312
+ range=(0.0, 1.0),
313
+ orientation="h",
314
+ key="pad_seconds",
315
+ resolution=0.01,
316
+ ),
317
+ ],
318
+ [
319
+ sg.Text("Chunk seconds"),
320
+ sg.Push(),
321
+ sg.Slider(
322
+ range=(0.0, 3.0),
323
+ orientation="h",
324
+ key="chunk_seconds",
325
+ resolution=0.01,
326
+ ),
327
+ ],
328
+ [
329
+ sg.Text("Max chunk seconds (set lower if Out Of Memory, 0 to disable)"),
330
+ sg.Push(),
331
+ sg.Slider(
332
+ range=(0.0, 240.0),
333
+ orientation="h",
334
+ key="max_chunk_seconds",
335
+ resolution=1.0,
336
+ ),
337
+ ],
338
+ [
339
+ sg.Checkbox(
340
+ key="absolute_thresh",
341
+ text="Absolute threshold (ignored (True) in realtime inference)",
342
+ )
343
+ ],
344
+ ],
345
+ "File": [
346
+ [
347
+ sg.Text("Input audio path"),
348
+ sg.Push(),
349
+ sg.InputText(key="input_path", enable_events=True),
350
+ sg.FileBrowse(
351
+ initial_folder=".",
352
+ key="input_path_browse",
353
+ file_types=get_supported_file_types_concat(),
354
+ ),
355
+ sg.FolderBrowse(
356
+ button_text="Browse(Folder)",
357
+ initial_folder=".",
358
+ key="input_path_folder_browse",
359
+ target="input_path",
360
+ ),
361
+ sg.Button("Play", key="play_input"),
362
+ ],
363
+ [
364
+ sg.Text("Output audio path"),
365
+ sg.Push(),
366
+ sg.InputText(key="output_path"),
367
+ sg.FileSaveAs(
368
+ initial_folder=".",
369
+ key="output_path_browse",
370
+ file_types=get_supported_file_types(),
371
+ ),
372
+ ],
373
+ [sg.Checkbox(key="auto_play", text="Auto play", default=True)],
374
+ ],
375
+ "Realtime": [
376
+ [
377
+ sg.Text("Crossfade seconds"),
378
+ sg.Push(),
379
+ sg.Slider(
380
+ range=(0, 0.6),
381
+ orientation="h",
382
+ key="crossfade_seconds",
383
+ resolution=0.001,
384
+ ),
385
+ ],
386
+ [
387
+ sg.Text(
388
+ "Block seconds", # \n(big -> more robust, slower, (the same) latency)"
389
+ tooltip="Big -> more robust, slower, (the same) latency",
390
+ ),
391
+ sg.Push(),
392
+ sg.Slider(
393
+ range=(0, 3.0),
394
+ orientation="h",
395
+ key="block_seconds",
396
+ resolution=0.001,
397
+ ),
398
+ ],
399
+ [
400
+ sg.Text(
401
+ "Additional Infer seconds (before)", # \n(big -> more robust, slower)"
402
+ tooltip="Big -> more robust, slower, additional latency",
403
+ ),
404
+ sg.Push(),
405
+ sg.Slider(
406
+ range=(0, 2.0),
407
+ orientation="h",
408
+ key="additional_infer_before_seconds",
409
+ resolution=0.001,
410
+ ),
411
+ ],
412
+ [
413
+ sg.Text(
414
+ "Additional Infer seconds (after)", # \n(big -> more robust, slower, additional latency)"
415
+ tooltip="Big -> more robust, slower, additional latency",
416
+ ),
417
+ sg.Push(),
418
+ sg.Slider(
419
+ range=(0, 2.0),
420
+ orientation="h",
421
+ key="additional_infer_after_seconds",
422
+ resolution=0.001,
423
+ ),
424
+ ],
425
+ [
426
+ sg.Text("Realtime algorithm"),
427
+ sg.Push(),
428
+ sg.Combo(
429
+ ["2 (Divide by speech)", "1 (Divide constantly)"],
430
+ default_value="1 (Divide constantly)",
431
+ key="realtime_algorithm",
432
+ ),
433
+ ],
434
+ [
435
+ sg.Text("Input device"),
436
+ sg.Push(),
437
+ sg.Combo(
438
+ key="input_device",
439
+ values=[],
440
+ size=(60, 1),
441
+ ),
442
+ ],
443
+ [
444
+ sg.Text("Output device"),
445
+ sg.Push(),
446
+ sg.Combo(
447
+ key="output_device",
448
+ values=[],
449
+ size=(60, 1),
450
+ ),
451
+ ],
452
+ [
453
+ sg.Checkbox(
454
+ "Passthrough original audio (for latency check)",
455
+ key="passthrough_original",
456
+ default=False,
457
+ ),
458
+ sg.Push(),
459
+ sg.Button("Refresh devices", key="refresh_devices"),
460
+ ],
461
+ [
462
+ sg.Frame(
463
+ "Notes",
464
+ [
465
+ [
466
+ sg.Text(
467
+ "In Realtime Inference:\n"
468
+ " - Setting F0 prediction method to 'crepe` may cause performance degradation.\n"
469
+ " - Auto Predict F0 must be turned off.\n"
470
+ "If the audio sounds mumbly and choppy:\n"
471
+ " Case: The inference has not been made in time (Increase Block seconds)\n"
472
+ " Case: Mic input is low (Decrease Silence threshold)\n"
473
+ )
474
+ ]
475
+ ],
476
+ ),
477
+ ],
478
+ ],
479
+ "Presets": [
480
+ [
481
+ sg.Text("Presets"),
482
+ sg.Push(),
483
+ sg.Combo(
484
+ key="presets",
485
+ values=list(load_presets().keys()),
486
+ size=(40, 1),
487
+ enable_events=True,
488
+ ),
489
+ sg.Button("Delete preset", key="delete_preset"),
490
+ ],
491
+ [
492
+ sg.Text("Preset name"),
493
+ sg.Stretch(),
494
+ sg.InputText(key="preset_name", size=(26, 1)),
495
+ sg.Button("Add current settings as a preset", key="add_preset"),
496
+ ],
497
+ ],
498
+ }
499
+
500
+ # frames
501
+ frames = {}
502
+ for name, items in frame_contents.items():
503
+ frame = sg.Frame(name, items)
504
+ frame.expand_x = True
505
+ frames[name] = [frame]
506
+
507
+ bottoms = [
508
+ [
509
+ sg.Checkbox(
510
+ key="use_gpu",
511
+ default=get_optimal_device() != torch.device("cpu"),
512
+ text="Use GPU"
513
+ + (
514
+ " (not available; if your device has GPU, make sure you installed PyTorch with CUDA support)"
515
+ if get_optimal_device() == torch.device("cpu")
516
+ else ""
517
+ ),
518
+ disabled=get_optimal_device() == torch.device("cpu"),
519
+ )
520
+ ],
521
+ [
522
+ sg.Button("Infer", key="infer"),
523
+ sg.Button("(Re)Start Voice Changer", key="start_vc"),
524
+ sg.Button("Stop Voice Changer", key="stop_vc"),
525
+ sg.Push(),
526
+ # sg.Button("ONNX Export", key="onnx_export"),
527
+ ],
528
+ ]
529
+ column1 = sg.Column(
530
+ [
531
+ frames["Paths"],
532
+ frames["Common"],
533
+ ],
534
+ vertical_alignment="top",
535
+ )
536
+ column2 = sg.Column(
537
+ [
538
+ frames["File"],
539
+ frames["Realtime"],
540
+ frames["Presets"],
541
+ ]
542
+ + bottoms
543
+ )
544
+ # columns
545
+ layout = [[column1, column2]]
546
+ # get screen size
547
+ screen_width, screen_height = sg.Window.get_screen_size()
548
+ if screen_height < 720:
549
+ layout = [
550
+ [
551
+ sg.Column(
552
+ layout,
553
+ vertical_alignment="top",
554
+ scrollable=False,
555
+ expand_x=True,
556
+ expand_y=True,
557
+ vertical_scroll_only=True,
558
+ key="main_column",
559
+ )
560
+ ]
561
+ ]
562
+ window = sg.Window(
563
+ f"{__name__.split('.')[0].replace('_', '-')} v{__version__}",
564
+ layout,
565
+ grab_anywhere=True,
566
+ finalize=True,
567
+ scaling=1,
568
+ font=("Yu Gothic UI", 11) if os.name == "nt" else None,
569
+ # resizable=True,
570
+ # size=(1280, 720),
571
+ # Below disables taskbar, which may be not useful for some users
572
+ # use_custom_titlebar=True, no_titlebar=False
573
+ # Keep on top
574
+ # keep_on_top=True
575
+ )
576
+
577
+ # event, values = window.read(timeout=0.01)
578
+ # window["main_column"].Scrollable = True
579
+
580
+ # make slider height smaller
581
+ try:
582
+ for v in window.element_list():
583
+ if isinstance(v, sg.Slider):
584
+ v.Widget.configure(sliderrelief="flat", width=10, sliderlength=20)
585
+ except Exception as e:
586
+ LOG.exception(e)
587
+
588
+ # for n in ["input_device", "output_device"]:
589
+ # window[n].Widget.configure(justify="right")
590
+ event, values = window.read(timeout=0.01)
591
+
592
+ def update_speaker() -> None:
593
+ from . import utils
594
+
595
+ config_path = Path(values["config_path"])
596
+ if config_path.exists() and config_path.is_file():
597
+ hp = utils.get_hparams(values["config_path"])
598
+ LOG.debug(f"Loaded config from {values['config_path']}")
599
+ window["speaker"].update(
600
+ values=list(hp.__dict__["spk"].keys()), set_to_index=0
601
+ )
602
+
603
+ def update_devices() -> None:
604
+ (
605
+ input_devices,
606
+ output_devices,
607
+ input_device_indices,
608
+ output_device_indices,
609
+ ) = get_devices()
610
+ input_device_indices_reversed = {
611
+ v: k for k, v in enumerate(input_device_indices)
612
+ }
613
+ output_device_indices_reversed = {
614
+ v: k for k, v in enumerate(output_device_indices)
615
+ }
616
+ window["input_device"].update(
617
+ values=input_devices, value=values["input_device"]
618
+ )
619
+ window["output_device"].update(
620
+ values=output_devices, value=values["output_device"]
621
+ )
622
+ input_default, output_default = sd.default.device
623
+ if values["input_device"] not in input_devices:
624
+ window["input_device"].update(
625
+ values=input_devices,
626
+ set_to_index=input_device_indices_reversed.get(input_default, 0),
627
+ )
628
+ if values["output_device"] not in output_devices:
629
+ window["output_device"].update(
630
+ values=output_devices,
631
+ set_to_index=output_device_indices_reversed.get(output_default, 0),
632
+ )
633
+
634
+ PRESET_KEYS = [
635
+ key
636
+ for key in values.keys()
637
+ if not any(exclude in key for exclude in ["preset", "browse"])
638
+ ]
639
+
640
+ def apply_preset(name: str) -> None:
641
+ for key, value in load_presets()[name].items():
642
+ if key in PRESET_KEYS:
643
+ window[key].update(value)
644
+ values[key] = value
645
+
646
+ default_name = list(load_presets().keys())[0]
647
+ apply_preset(default_name)
648
+ window["presets"].update(default_name)
649
+ del default_name
650
+ update_speaker()
651
+ update_devices()
652
+ # with ProcessPool(max_workers=1) as pool:
653
+ # to support Linux
654
+ with ProcessPool(
655
+ max_workers=min(2, multiprocessing.cpu_count()),
656
+ context=multiprocessing.get_context("spawn"),
657
+ ) as pool:
658
+ future: None | ProcessFuture = None
659
+ infer_futures: set[ProcessFuture] = set()
660
+ while True:
661
+ event, values = window.read(200)
662
+ if event == sg.WIN_CLOSED:
663
+ break
664
+ if not event == sg.EVENT_TIMEOUT:
665
+ LOG.info(f"Event {event}, values {values}")
666
+ if event.endswith("_path"):
667
+ for name in window.AllKeysDict:
668
+ if str(name).endswith("_browse"):
669
+ browser = window[name]
670
+ if isinstance(browser, sg.Button):
671
+ LOG.info(
672
+ f"Updating browser {browser} to {Path(values[event]).parent}"
673
+ )
674
+ browser.InitialFolder = Path(values[event]).parent
675
+ browser.update()
676
+ else:
677
+ LOG.warning(f"Browser {browser} is not a FileBrowse")
678
+ window["transpose"].update(
679
+ disabled=values["auto_predict_f0"],
680
+ visible=not values["auto_predict_f0"],
681
+ )
682
+
683
+ input_path = Path(values["input_path"])
684
+ output_path = Path(values["output_path"])
685
+
686
+ if event == "add_preset":
687
+ presets = add_preset(
688
+ values["preset_name"], {key: values[key] for key in PRESET_KEYS}
689
+ )
690
+ window["presets"].update(values=list(presets.keys()))
691
+ elif event == "delete_preset":
692
+ presets = delete_preset(values["presets"])
693
+ window["presets"].update(values=list(presets.keys()))
694
+ elif event == "presets":
695
+ apply_preset(values["presets"])
696
+ update_speaker()
697
+ elif event == "refresh_devices":
698
+ update_devices()
699
+ elif event == "config_path":
700
+ update_speaker()
701
+ elif event == "input_path":
702
+ # Don't change the output path if it's already set
703
+ # if values["output_path"]:
704
+ # continue
705
+ # Set a sensible default output path
706
+ window.Element("output_path").Update(str(get_output_path(input_path)))
707
+ elif event == "infer":
708
+ if "Default VC" in values["presets"]:
709
+ window["presets"].update(
710
+ set_to_index=list(load_presets().keys()).index("Default File")
711
+ )
712
+ apply_preset("Default File")
713
+ if values["input_path"] == "":
714
+ LOG.warning("Input path is empty.")
715
+ continue
716
+ if not input_path.exists():
717
+ LOG.warning(f"Input path {input_path} does not exist.")
718
+ continue
719
+ # if not validate_output_file_type(output_path):
720
+ # continue
721
+
722
+ try:
723
+ from so_vits_svc_fork.inference.main import infer
724
+
725
+ LOG.info("Starting inference...")
726
+ window["infer"].update(disabled=True)
727
+ infer_future = pool.schedule(
728
+ infer,
729
+ kwargs=dict(
730
+ # paths
731
+ model_path=Path(values["model_path"]),
732
+ output_path=output_path,
733
+ input_path=input_path,
734
+ config_path=Path(values["config_path"]),
735
+ recursive=True,
736
+ # svc config
737
+ speaker=values["speaker"],
738
+ cluster_model_path=Path(values["cluster_model_path"])
739
+ if values["cluster_model_path"]
740
+ else None,
741
+ transpose=values["transpose"],
742
+ auto_predict_f0=values["auto_predict_f0"],
743
+ cluster_infer_ratio=values["cluster_infer_ratio"],
744
+ noise_scale=values["noise_scale"],
745
+ f0_method=values["f0_method"],
746
+ # slice config
747
+ db_thresh=values["silence_threshold"],
748
+ pad_seconds=values["pad_seconds"],
749
+ chunk_seconds=values["chunk_seconds"],
750
+ absolute_thresh=values["absolute_thresh"],
751
+ max_chunk_seconds=values["max_chunk_seconds"],
752
+ device="cpu"
753
+ if not values["use_gpu"]
754
+ else get_optimal_device(),
755
+ ),
756
+ )
757
+ infer_future.add_done_callback(
758
+ lambda _future: after_inference(
759
+ window, input_path, values["auto_play"], output_path
760
+ )
761
+ )
762
+ infer_futures.add(infer_future)
763
+ except Exception as e:
764
+ LOG.exception(e)
765
+ elif event == "play_input":
766
+ if Path(values["input_path"]).exists():
767
+ pool.schedule(play_audio, args=[Path(values["input_path"])])
768
+ elif event == "start_vc":
769
+ _, _, input_device_indices, output_device_indices = get_devices(
770
+ update=False
771
+ )
772
+ from so_vits_svc_fork.inference.main import realtime
773
+
774
+ if future:
775
+ LOG.info("Canceling previous task")
776
+ future.cancel()
777
+ future = pool.schedule(
778
+ realtime,
779
+ kwargs=dict(
780
+ # paths
781
+ model_path=Path(values["model_path"]),
782
+ config_path=Path(values["config_path"]),
783
+ speaker=values["speaker"],
784
+ # svc config
785
+ cluster_model_path=Path(values["cluster_model_path"])
786
+ if values["cluster_model_path"]
787
+ else None,
788
+ transpose=values["transpose"],
789
+ auto_predict_f0=values["auto_predict_f0"],
790
+ cluster_infer_ratio=values["cluster_infer_ratio"],
791
+ noise_scale=values["noise_scale"],
792
+ f0_method=values["f0_method"],
793
+ # slice config
794
+ db_thresh=values["silence_threshold"],
795
+ pad_seconds=values["pad_seconds"],
796
+ chunk_seconds=values["chunk_seconds"],
797
+ # realtime config
798
+ crossfade_seconds=values["crossfade_seconds"],
799
+ additional_infer_before_seconds=values[
800
+ "additional_infer_before_seconds"
801
+ ],
802
+ additional_infer_after_seconds=values[
803
+ "additional_infer_after_seconds"
804
+ ],
805
+ block_seconds=values["block_seconds"],
806
+ version=int(values["realtime_algorithm"][0]),
807
+ input_device=input_device_indices[
808
+ window["input_device"].widget.current()
809
+ ],
810
+ output_device=output_device_indices[
811
+ window["output_device"].widget.current()
812
+ ],
813
+ device=get_optimal_device() if values["use_gpu"] else "cpu",
814
+ passthrough_original=values["passthrough_original"],
815
+ ),
816
+ )
817
+ elif event == "stop_vc":
818
+ if future:
819
+ future.cancel()
820
+ future = None
821
+ elif event == "onnx_export":
822
+ try:
823
+ raise NotImplementedError("ONNX export is not implemented yet.")
824
+ from so_vits_svc_fork.modules.onnx._export import onnx_export
825
+
826
+ onnx_export(
827
+ input_path=Path(values["model_path"]),
828
+ output_path=Path(values["model_path"]).with_suffix(".onnx"),
829
+ config_path=Path(values["config_path"]),
830
+ device="cpu",
831
+ )
832
+ except Exception as e:
833
+ LOG.exception(e)
834
+ if future is not None and future.done():
835
+ try:
836
+ future.result()
837
+ except Exception as e:
838
+ LOG.error("Error in realtime: ")
839
+ LOG.exception(e)
840
+ future = None
841
+ for future in copy(infer_futures):
842
+ if future.done():
843
+ try:
844
+ future.result()
845
+ except Exception as e:
846
+ LOG.error("Error in inference: ")
847
+ LOG.exception(e)
848
+ infer_futures.remove(future)
849
+ if future:
850
+ future.cancel()
851
+ window.close()
so_vits_svc_fork/hparams.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
+ class HParams:
7
+ def __init__(self, **kwargs: Any) -> None:
8
+ for k, v in kwargs.items():
9
+ if type(v) == dict:
10
+ v = HParams(**v)
11
+ self[k] = v
12
+
13
+ def keys(self):
14
+ return self.__dict__.keys()
15
+
16
+ def items(self):
17
+ return self.__dict__.items()
18
+
19
+ def values(self):
20
+ return self.__dict__.values()
21
+
22
+ def get(self, key: str, default: Any = None):
23
+ return self.__dict__.get(key, default)
24
+
25
+ def __len__(self):
26
+ return len(self.__dict__)
27
+
28
+ def __getitem__(self, key):
29
+ return getattr(self, key)
30
+
31
+ def __setitem__(self, key, value):
32
+ return setattr(self, key, value)
33
+
34
+ def __contains__(self, key):
35
+ return key in self.__dict__
36
+
37
+ def __repr__(self):
38
+ return self.__dict__.__repr__()
so_vits_svc_fork/inference/__init__.py ADDED
File without changes
so_vits_svc_fork/inference/core.py ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from copy import deepcopy
4
+ from logging import getLogger
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Iterable, Literal
7
+
8
+ import attrs
9
+ import librosa
10
+ import numpy as np
11
+ import torch
12
+ from cm_time import timer
13
+ from numpy import dtype, float32, ndarray
14
+
15
+ import so_vits_svc_fork.f0
16
+ from so_vits_svc_fork import cluster, utils
17
+
18
+ from ..modules.synthesizers import SynthesizerTrn
19
+ from ..utils import get_optimal_device
20
+
21
+ LOG = getLogger(__name__)
22
+
23
+
24
+ def pad_array(array_, target_length: int):
25
+ current_length = array_.shape[0]
26
+ if current_length >= target_length:
27
+ return array_[
28
+ (current_length - target_length)
29
+ // 2 : (current_length - target_length)
30
+ // 2
31
+ + target_length,
32
+ ...,
33
+ ]
34
+ else:
35
+ pad_width = target_length - current_length
36
+ pad_left = pad_width // 2
37
+ pad_right = pad_width - pad_left
38
+ padded_arr = np.pad(
39
+ array_, (pad_left, pad_right), "constant", constant_values=(0, 0)
40
+ )
41
+ return padded_arr
42
+
43
+
44
+ @attrs.frozen(kw_only=True)
45
+ class Chunk:
46
+ is_speech: bool
47
+ audio: ndarray[Any, dtype[float32]]
48
+ start: int
49
+ end: int
50
+
51
+ @property
52
+ def duration(self) -> float32:
53
+ # return self.end - self.start
54
+ return float32(self.audio.shape[0])
55
+
56
+ def __repr__(self) -> str:
57
+ return f"Chunk(Speech: {self.is_speech}, {self.duration})"
58
+
59
+
60
+ def split_silence(
61
+ audio: ndarray[Any, dtype[float32]],
62
+ top_db: int = 40,
63
+ ref: float | Callable[[ndarray[Any, dtype[float32]]], float] = 1,
64
+ frame_length: int = 2048,
65
+ hop_length: int = 512,
66
+ aggregate: Callable[[ndarray[Any, dtype[float32]]], float] = np.mean,
67
+ max_chunk_length: int = 0,
68
+ ) -> Iterable[Chunk]:
69
+ non_silence_indices = librosa.effects.split(
70
+ audio,
71
+ top_db=top_db,
72
+ ref=ref,
73
+ frame_length=frame_length,
74
+ hop_length=hop_length,
75
+ aggregate=aggregate,
76
+ )
77
+ last_end = 0
78
+ for start, end in non_silence_indices:
79
+ if start != last_end:
80
+ yield Chunk(
81
+ is_speech=False, audio=audio[last_end:start], start=last_end, end=start
82
+ )
83
+ while max_chunk_length > 0 and end - start > max_chunk_length:
84
+ yield Chunk(
85
+ is_speech=True,
86
+ audio=audio[start : start + max_chunk_length],
87
+ start=start,
88
+ end=start + max_chunk_length,
89
+ )
90
+ start += max_chunk_length
91
+ if end - start > 0:
92
+ yield Chunk(is_speech=True, audio=audio[start:end], start=start, end=end)
93
+ last_end = end
94
+ if last_end != len(audio):
95
+ yield Chunk(
96
+ is_speech=False, audio=audio[last_end:], start=last_end, end=len(audio)
97
+ )
98
+
99
+
100
+ class Svc:
101
+ def __init__(
102
+ self,
103
+ *,
104
+ net_g_path: Path | str,
105
+ config_path: Path | str,
106
+ device: torch.device | str | None = None,
107
+ cluster_model_path: Path | str | None = None,
108
+ half: bool = False,
109
+ ):
110
+ self.net_g_path = net_g_path
111
+ if device is None:
112
+ self.device = (get_optimal_device(),)
113
+ else:
114
+ self.device = torch.device(device)
115
+ self.hps = utils.get_hparams(config_path)
116
+ self.target_sample = self.hps.data.sampling_rate
117
+ self.hop_size = self.hps.data.hop_length
118
+ self.spk2id = self.hps.spk
119
+ self.hubert_model = utils.get_hubert_model(
120
+ self.device, self.hps.data.get("contentvec_final_proj", True)
121
+ )
122
+ self.dtype = torch.float16 if half else torch.float32
123
+ self.contentvec_final_proj = self.hps.data.__dict__.get(
124
+ "contentvec_final_proj", True
125
+ )
126
+ self.load_model()
127
+ if cluster_model_path is not None and Path(cluster_model_path).exists():
128
+ self.cluster_model = cluster.get_cluster_model(cluster_model_path)
129
+
130
+ def load_model(self):
131
+ self.net_g = SynthesizerTrn(
132
+ self.hps.data.filter_length // 2 + 1,
133
+ self.hps.train.segment_size // self.hps.data.hop_length,
134
+ **self.hps.model,
135
+ )
136
+ _ = utils.load_checkpoint(self.net_g_path, self.net_g, None)
137
+ _ = self.net_g.eval()
138
+ for m in self.net_g.modules():
139
+ utils.remove_weight_norm_if_exists(m)
140
+ _ = self.net_g.to(self.device, dtype=self.dtype)
141
+ self.net_g = self.net_g
142
+
143
+ def get_unit_f0(
144
+ self,
145
+ audio: ndarray[Any, dtype[float32]],
146
+ tran: int,
147
+ cluster_infer_ratio: float,
148
+ speaker: int | str,
149
+ f0_method: Literal[
150
+ "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
151
+ ] = "dio",
152
+ ):
153
+ f0 = so_vits_svc_fork.f0.compute_f0(
154
+ audio,
155
+ sampling_rate=self.target_sample,
156
+ hop_length=self.hop_size,
157
+ method=f0_method,
158
+ )
159
+ f0, uv = so_vits_svc_fork.f0.interpolate_f0(f0)
160
+ f0 = torch.as_tensor(f0, dtype=self.dtype, device=self.device)
161
+ uv = torch.as_tensor(uv, dtype=self.dtype, device=self.device)
162
+ f0 = f0 * 2 ** (tran / 12)
163
+ f0 = f0.unsqueeze(0)
164
+ uv = uv.unsqueeze(0)
165
+
166
+ c = utils.get_content(
167
+ self.hubert_model,
168
+ audio,
169
+ self.device,
170
+ self.target_sample,
171
+ self.contentvec_final_proj,
172
+ ).to(self.dtype)
173
+ c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
174
+
175
+ if cluster_infer_ratio != 0:
176
+ cluster_c = cluster.get_cluster_center_result(
177
+ self.cluster_model, c.cpu().numpy().T, speaker
178
+ ).T
179
+ cluster_c = torch.FloatTensor(cluster_c).to(self.device)
180
+ c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c
181
+
182
+ c = c.unsqueeze(0)
183
+ return c, f0, uv
184
+
185
+ def infer(
186
+ self,
187
+ speaker: int | str,
188
+ transpose: int,
189
+ audio: ndarray[Any, dtype[float32]],
190
+ cluster_infer_ratio: float = 0,
191
+ auto_predict_f0: bool = False,
192
+ noise_scale: float = 0.4,
193
+ f0_method: Literal[
194
+ "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
195
+ ] = "dio",
196
+ ) -> tuple[torch.Tensor, int]:
197
+ audio = audio.astype(np.float32)
198
+ # get speaker id
199
+ if isinstance(speaker, int):
200
+ if len(self.spk2id.__dict__) >= speaker:
201
+ speaker_id = speaker
202
+ else:
203
+ raise ValueError(
204
+ f"Speaker id {speaker} >= number of speakers {len(self.spk2id.__dict__)}"
205
+ )
206
+ else:
207
+ if speaker in self.spk2id.__dict__:
208
+ speaker_id = self.spk2id.__dict__[speaker]
209
+ else:
210
+ LOG.warning(f"Speaker {speaker} is not found. Use speaker 0 instead.")
211
+ speaker_id = 0
212
+ speaker_candidates = list(
213
+ filter(lambda x: x[1] == speaker_id, self.spk2id.__dict__.items())
214
+ )
215
+ if len(speaker_candidates) > 1:
216
+ raise ValueError(
217
+ f"Speaker_id {speaker_id} is not unique. Candidates: {speaker_candidates}"
218
+ )
219
+ elif len(speaker_candidates) == 0:
220
+ raise ValueError(f"Speaker_id {speaker_id} is not found.")
221
+ speaker = speaker_candidates[0][0]
222
+ sid = torch.LongTensor([int(speaker_id)]).to(self.device).unsqueeze(0)
223
+
224
+ # get unit f0
225
+ c, f0, uv = self.get_unit_f0(
226
+ audio, transpose, cluster_infer_ratio, speaker, f0_method
227
+ )
228
+
229
+ # inference
230
+ with torch.no_grad():
231
+ with timer() as t:
232
+ audio = self.net_g.infer(
233
+ c,
234
+ f0=f0,
235
+ g=sid,
236
+ uv=uv,
237
+ predict_f0=auto_predict_f0,
238
+ noice_scale=noise_scale,
239
+ )[0, 0].data.float()
240
+ audio_duration = audio.shape[-1] / self.target_sample
241
+ LOG.info(
242
+ f"Inference time: {t.elapsed:.2f}s, RTF: {t.elapsed / audio_duration:.2f}"
243
+ )
244
+ torch.cuda.empty_cache()
245
+ return audio, audio.shape[-1]
246
+
247
+ def infer_silence(
248
+ self,
249
+ audio: np.ndarray[Any, np.dtype[np.float32]],
250
+ *,
251
+ # svc config
252
+ speaker: int | str,
253
+ transpose: int = 0,
254
+ auto_predict_f0: bool = False,
255
+ cluster_infer_ratio: float = 0,
256
+ noise_scale: float = 0.4,
257
+ f0_method: Literal[
258
+ "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
259
+ ] = "dio",
260
+ # slice config
261
+ db_thresh: int = -40,
262
+ pad_seconds: float = 0.5,
263
+ chunk_seconds: float = 0.5,
264
+ absolute_thresh: bool = False,
265
+ max_chunk_seconds: float = 40,
266
+ # fade_seconds: float = 0.0,
267
+ ) -> np.ndarray[Any, np.dtype[np.float32]]:
268
+ sr = self.target_sample
269
+ result_audio = np.array([], dtype=np.float32)
270
+ chunk_length_min = chunk_length_min = (
271
+ int(
272
+ min(
273
+ sr / so_vits_svc_fork.f0.f0_min * 20 + 1,
274
+ chunk_seconds * sr,
275
+ )
276
+ )
277
+ // 2
278
+ )
279
+ for chunk in split_silence(
280
+ audio,
281
+ top_db=-db_thresh,
282
+ frame_length=chunk_length_min * 2,
283
+ hop_length=chunk_length_min,
284
+ ref=1 if absolute_thresh else np.max,
285
+ max_chunk_length=int(max_chunk_seconds * sr),
286
+ ):
287
+ LOG.info(f"Chunk: {chunk}")
288
+ if not chunk.is_speech:
289
+ audio_chunk_infer = np.zeros_like(chunk.audio)
290
+ else:
291
+ # pad
292
+ pad_len = int(sr * pad_seconds)
293
+ audio_chunk_pad = np.concatenate(
294
+ [
295
+ np.zeros([pad_len], dtype=np.float32),
296
+ chunk.audio,
297
+ np.zeros([pad_len], dtype=np.float32),
298
+ ]
299
+ )
300
+ audio_chunk_pad_infer_tensor, _ = self.infer(
301
+ speaker,
302
+ transpose,
303
+ audio_chunk_pad,
304
+ cluster_infer_ratio=cluster_infer_ratio,
305
+ auto_predict_f0=auto_predict_f0,
306
+ noise_scale=noise_scale,
307
+ f0_method=f0_method,
308
+ )
309
+ audio_chunk_pad_infer = audio_chunk_pad_infer_tensor.cpu().numpy()
310
+ pad_len = int(self.target_sample * pad_seconds)
311
+ cut_len_2 = (len(audio_chunk_pad_infer) - len(chunk.audio)) // 2
312
+ audio_chunk_infer = audio_chunk_pad_infer[
313
+ cut_len_2 : cut_len_2 + len(chunk.audio)
314
+ ]
315
+
316
+ # add fade
317
+ # fade_len = int(self.target_sample * fade_seconds)
318
+ # _audio[:fade_len] = _audio[:fade_len] * np.linspace(0, 1, fade_len)
319
+ # _audio[-fade_len:] = _audio[-fade_len:] * np.linspace(1, 0, fade_len)
320
+
321
+ # empty cache
322
+ torch.cuda.empty_cache()
323
+ result_audio = np.concatenate([result_audio, audio_chunk_infer])
324
+ result_audio = result_audio[: audio.shape[0]]
325
+ return result_audio
326
+
327
+
328
+ def sola_crossfade(
329
+ first: ndarray[Any, dtype[float32]],
330
+ second: ndarray[Any, dtype[float32]],
331
+ crossfade_len: int,
332
+ sola_search_len: int,
333
+ ) -> ndarray[Any, dtype[float32]]:
334
+ cor_nom = np.convolve(
335
+ second[: sola_search_len + crossfade_len],
336
+ np.flip(first[-crossfade_len:]),
337
+ "valid",
338
+ )
339
+ cor_den = np.sqrt(
340
+ np.convolve(
341
+ second[: sola_search_len + crossfade_len] ** 2,
342
+ np.ones(crossfade_len),
343
+ "valid",
344
+ )
345
+ + 1e-8
346
+ )
347
+ sola_shift = np.argmax(cor_nom / cor_den)
348
+ LOG.info(f"SOLA shift: {sola_shift}")
349
+ second = second[sola_shift : sola_shift + len(second) - sola_search_len]
350
+ return np.concatenate(
351
+ [
352
+ first[:-crossfade_len],
353
+ first[-crossfade_len:] * np.linspace(1, 0, crossfade_len)
354
+ + second[:crossfade_len] * np.linspace(0, 1, crossfade_len),
355
+ second[crossfade_len:],
356
+ ]
357
+ )
358
+
359
+
360
+ class Crossfader:
361
+ def __init__(
362
+ self,
363
+ *,
364
+ additional_infer_before_len: int,
365
+ additional_infer_after_len: int,
366
+ crossfade_len: int,
367
+ sola_search_len: int = 384,
368
+ ) -> None:
369
+ if additional_infer_before_len < 0:
370
+ raise ValueError("additional_infer_len must be >= 0")
371
+ if crossfade_len < 0:
372
+ raise ValueError("crossfade_len must be >= 0")
373
+ if additional_infer_after_len < 0:
374
+ raise ValueError("additional_infer_len must be >= 0")
375
+ if additional_infer_before_len < 0:
376
+ raise ValueError("additional_infer_len must be >= 0")
377
+ self.additional_infer_before_len = additional_infer_before_len
378
+ self.additional_infer_after_len = additional_infer_after_len
379
+ self.crossfade_len = crossfade_len
380
+ self.sola_search_len = sola_search_len
381
+ self.last_input_left = np.zeros(
382
+ sola_search_len
383
+ + crossfade_len
384
+ + additional_infer_before_len
385
+ + additional_infer_after_len,
386
+ dtype=np.float32,
387
+ )
388
+ self.last_infered_left = np.zeros(crossfade_len, dtype=np.float32)
389
+
390
+ def process(
391
+ self, input_audio: ndarray[Any, dtype[float32]], *args, **kwargs: Any
392
+ ) -> ndarray[Any, dtype[float32]]:
393
+ """
394
+ chunks : ■■■■■■□□□□□□
395
+ add last input:□■■■■■■
396
+ ■□□□□□□
397
+ infer :□■■■■■■
398
+ ■□□□□□□
399
+ crossfade :▲■■■■■
400
+ ▲□□□□□
401
+ """
402
+ # check input
403
+ if input_audio.ndim != 1:
404
+ raise ValueError("Input audio must be 1-dimensional.")
405
+ if (
406
+ input_audio.shape[0] + self.additional_infer_before_len
407
+ <= self.crossfade_len
408
+ ):
409
+ raise ValueError(
410
+ f"Input audio length ({input_audio.shape[0]}) + additional_infer_len ({self.additional_infer_before_len}) must be greater than crossfade_len ({self.crossfade_len})."
411
+ )
412
+ input_audio = input_audio.astype(np.float32)
413
+ input_audio_len = len(input_audio)
414
+
415
+ # concat last input and infer
416
+ input_audio_concat = np.concatenate([self.last_input_left, input_audio])
417
+ del input_audio
418
+ pad_len = 0
419
+ if pad_len:
420
+ infer_audio_concat = self.infer(
421
+ np.pad(input_audio_concat, (pad_len, pad_len), mode="reflect"),
422
+ *args,
423
+ **kwargs,
424
+ )[pad_len:-pad_len]
425
+ else:
426
+ infer_audio_concat = self.infer(input_audio_concat, *args, **kwargs)
427
+
428
+ # debug SOLA (using copy synthesis with a random shift)
429
+ """
430
+ rs = int(np.random.uniform(-200,200))
431
+ LOG.info(f"Debug random shift: {rs}")
432
+ infer_audio_concat = np.roll(input_audio_concat, rs)
433
+ """
434
+
435
+ if len(infer_audio_concat) != len(input_audio_concat):
436
+ raise ValueError(
437
+ f"Inferred audio length ({len(infer_audio_concat)}) should be equal to input audio length ({len(input_audio_concat)})."
438
+ )
439
+ infer_audio_to_use = infer_audio_concat[
440
+ -(
441
+ self.sola_search_len
442
+ + self.crossfade_len
443
+ + input_audio_len
444
+ + self.additional_infer_after_len
445
+ ) : -self.additional_infer_after_len
446
+ ]
447
+ assert (
448
+ len(infer_audio_to_use)
449
+ == input_audio_len + self.sola_search_len + self.crossfade_len
450
+ ), f"{len(infer_audio_to_use)} != {input_audio_len + self.sola_search_len + self.cross_fade_len}"
451
+ _audio = sola_crossfade(
452
+ self.last_infered_left,
453
+ infer_audio_to_use,
454
+ self.crossfade_len,
455
+ self.sola_search_len,
456
+ )
457
+ result_audio = _audio[: -self.crossfade_len]
458
+ assert (
459
+ len(result_audio) == input_audio_len
460
+ ), f"{len(result_audio)} != {input_audio_len}"
461
+
462
+ # update last input and inferred
463
+ self.last_input_left = input_audio_concat[
464
+ -(
465
+ self.sola_search_len
466
+ + self.crossfade_len
467
+ + self.additional_infer_before_len
468
+ + self.additional_infer_after_len
469
+ ) :
470
+ ]
471
+ self.last_infered_left = _audio[-self.crossfade_len :]
472
+ return result_audio
473
+
474
+ def infer(
475
+ self, input_audio: ndarray[Any, dtype[float32]]
476
+ ) -> ndarray[Any, dtype[float32]]:
477
+ return input_audio
478
+
479
+
480
+ class RealtimeVC(Crossfader):
481
+ def __init__(
482
+ self,
483
+ *,
484
+ svc_model: Svc,
485
+ crossfade_len: int = 3840,
486
+ additional_infer_before_len: int = 7680,
487
+ additional_infer_after_len: int = 7680,
488
+ split: bool = True,
489
+ ) -> None:
490
+ self.svc_model = svc_model
491
+ self.split = split
492
+ super().__init__(
493
+ crossfade_len=crossfade_len,
494
+ additional_infer_before_len=additional_infer_before_len,
495
+ additional_infer_after_len=additional_infer_after_len,
496
+ )
497
+
498
+ def process(
499
+ self,
500
+ input_audio: ndarray[Any, dtype[float32]],
501
+ *args: Any,
502
+ **kwargs: Any,
503
+ ) -> ndarray[Any, dtype[float32]]:
504
+ return super().process(input_audio, *args, **kwargs)
505
+
506
+ def infer(
507
+ self,
508
+ input_audio: np.ndarray[Any, np.dtype[np.float32]],
509
+ # svc config
510
+ speaker: int | str,
511
+ transpose: int,
512
+ cluster_infer_ratio: float = 0,
513
+ auto_predict_f0: bool = False,
514
+ noise_scale: float = 0.4,
515
+ f0_method: Literal[
516
+ "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
517
+ ] = "dio",
518
+ # slice config
519
+ db_thresh: int = -40,
520
+ pad_seconds: float = 0.5,
521
+ chunk_seconds: float = 0.5,
522
+ ) -> ndarray[Any, dtype[float32]]:
523
+ # infer
524
+ if self.split:
525
+ return self.svc_model.infer_silence(
526
+ audio=input_audio,
527
+ speaker=speaker,
528
+ transpose=transpose,
529
+ cluster_infer_ratio=cluster_infer_ratio,
530
+ auto_predict_f0=auto_predict_f0,
531
+ noise_scale=noise_scale,
532
+ f0_method=f0_method,
533
+ db_thresh=db_thresh,
534
+ pad_seconds=pad_seconds,
535
+ chunk_seconds=chunk_seconds,
536
+ absolute_thresh=True,
537
+ )
538
+ else:
539
+ rms = np.sqrt(np.mean(input_audio**2))
540
+ min_rms = 10 ** (db_thresh / 20)
541
+ if rms < min_rms:
542
+ LOG.info(f"Skip silence: RMS={rms:.2f} < {min_rms:.2f}")
543
+ return np.zeros_like(input_audio)
544
+ else:
545
+ LOG.info(f"Start inference: RMS={rms:.2f} >= {min_rms:.2f}")
546
+ infered_audio_c, _ = self.svc_model.infer(
547
+ speaker=speaker,
548
+ transpose=transpose,
549
+ audio=input_audio,
550
+ cluster_infer_ratio=cluster_infer_ratio,
551
+ auto_predict_f0=auto_predict_f0,
552
+ noise_scale=noise_scale,
553
+ f0_method=f0_method,
554
+ )
555
+ return infered_audio_c.cpu().numpy()
556
+
557
+
558
+ class RealtimeVC2:
559
+ chunk_store: list[Chunk]
560
+
561
+ def __init__(self, svc_model: Svc) -> None:
562
+ self.input_audio_store = np.array([], dtype=np.float32)
563
+ self.chunk_store = []
564
+ self.svc_model = svc_model
565
+
566
+ def process(
567
+ self,
568
+ input_audio: np.ndarray[Any, np.dtype[np.float32]],
569
+ # svc config
570
+ speaker: int | str,
571
+ transpose: int,
572
+ cluster_infer_ratio: float = 0,
573
+ auto_predict_f0: bool = False,
574
+ noise_scale: float = 0.4,
575
+ f0_method: Literal[
576
+ "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
577
+ ] = "dio",
578
+ # slice config
579
+ db_thresh: int = -40,
580
+ chunk_seconds: float = 0.5,
581
+ ) -> ndarray[Any, dtype[float32]]:
582
+ def infer(audio: ndarray[Any, dtype[float32]]) -> ndarray[Any, dtype[float32]]:
583
+ infered_audio_c, _ = self.svc_model.infer(
584
+ speaker=speaker,
585
+ transpose=transpose,
586
+ audio=audio,
587
+ cluster_infer_ratio=cluster_infer_ratio,
588
+ auto_predict_f0=auto_predict_f0,
589
+ noise_scale=noise_scale,
590
+ f0_method=f0_method,
591
+ )
592
+ return infered_audio_c.cpu().numpy()
593
+
594
+ self.input_audio_store = np.concatenate([self.input_audio_store, input_audio])
595
+ LOG.info(f"input_audio_store: {self.input_audio_store.shape}")
596
+ sr = self.svc_model.target_sample
597
+ chunk_length_min = (
598
+ int(min(sr / so_vits_svc_fork.f0.f0_min * 20 + 1, chunk_seconds * sr)) // 2
599
+ )
600
+ LOG.info(f"Chunk length min: {chunk_length_min}")
601
+ chunk_list = list(
602
+ split_silence(
603
+ self.input_audio_store,
604
+ -db_thresh,
605
+ frame_length=chunk_length_min * 2,
606
+ hop_length=chunk_length_min,
607
+ ref=1, # use absolute threshold
608
+ )
609
+ )
610
+ assert len(chunk_list) > 0
611
+ LOG.info(f"Chunk list: {chunk_list}")
612
+ # do not infer LAST incomplete is_speech chunk and save to store
613
+ if chunk_list[-1].is_speech:
614
+ self.input_audio_store = chunk_list.pop().audio
615
+ else:
616
+ self.input_audio_store = np.array([], dtype=np.float32)
617
+
618
+ # infer complete is_speech chunk and save to store
619
+ self.chunk_store.extend(
620
+ [
621
+ attrs.evolve(c, audio=infer(c.audio) if c.is_speech else c.audio)
622
+ for c in chunk_list
623
+ ]
624
+ )
625
+
626
+ # calculate lengths and determine compress rate
627
+ total_speech_len = sum(
628
+ [c.duration if c.is_speech else 0 for c in self.chunk_store]
629
+ )
630
+ total_silence_len = sum(
631
+ [c.duration if not c.is_speech else 0 for c in self.chunk_store]
632
+ )
633
+ input_audio_len = input_audio.shape[0]
634
+ silence_compress_rate = total_silence_len / max(
635
+ 0, input_audio_len - total_speech_len
636
+ )
637
+ LOG.info(
638
+ f"Total speech len: {total_speech_len}, silence len: {total_silence_len}, silence compress rate: {silence_compress_rate}"
639
+ )
640
+
641
+ # generate output audio
642
+ output_audio = np.array([], dtype=np.float32)
643
+ break_flag = False
644
+ LOG.info(f"Chunk store: {self.chunk_store}")
645
+ for chunk in deepcopy(self.chunk_store):
646
+ compress_rate = 1 if chunk.is_speech else silence_compress_rate
647
+ left_len = input_audio_len - output_audio.shape[0]
648
+ # calculate chunk duration
649
+ chunk_duration_output = int(min(chunk.duration / compress_rate, left_len))
650
+ chunk_duration_input = int(min(chunk.duration, left_len * compress_rate))
651
+ LOG.info(
652
+ f"Chunk duration output: {chunk_duration_output}, input: {chunk_duration_input}, left len: {left_len}"
653
+ )
654
+
655
+ # remove chunk from store
656
+ self.chunk_store.pop(0)
657
+ if chunk.duration > chunk_duration_input:
658
+ left_chunk = attrs.evolve(
659
+ chunk, audio=chunk.audio[chunk_duration_input:]
660
+ )
661
+ chunk = attrs.evolve(chunk, audio=chunk.audio[:chunk_duration_input])
662
+
663
+ self.chunk_store.insert(0, left_chunk)
664
+ break_flag = True
665
+
666
+ if chunk.is_speech:
667
+ # if is_speech, just concat
668
+ output_audio = np.concatenate([output_audio, chunk.audio])
669
+ else:
670
+ # if is_silence, concat with zeros and compress with silence_compress_rate
671
+ output_audio = np.concatenate(
672
+ [
673
+ output_audio,
674
+ np.zeros(
675
+ chunk_duration_output,
676
+ dtype=np.float32,
677
+ ),
678
+ ]
679
+ )
680
+
681
+ if break_flag:
682
+ break
683
+ LOG.info(f"Chunk store: {self.chunk_store}, output_audio: {output_audio.shape}")
684
+ # make same length (errors)
685
+ output_audio = output_audio[:input_audio_len]
686
+ output_audio = np.concatenate(
687
+ [
688
+ output_audio,
689
+ np.zeros(input_audio_len - output_audio.shape[0], dtype=np.float32),
690
+ ]
691
+ )
692
+ return output_audio
so_vits_svc_fork/inference/main.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from logging import getLogger
4
+ from pathlib import Path
5
+ from typing import Literal, Sequence
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import soundfile
10
+ import torch
11
+ from cm_time import timer
12
+ from tqdm import tqdm
13
+
14
+ from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc
15
+ from so_vits_svc_fork.utils import get_optimal_device
16
+
17
+ LOG = getLogger(__name__)
18
+
19
+
20
+ def infer(
21
+ *,
22
+ # paths
23
+ input_path: Path | str | Sequence[Path | str],
24
+ output_path: Path | str | Sequence[Path | str],
25
+ model_path: Path | str,
26
+ config_path: Path | str,
27
+ recursive: bool = False,
28
+ # svc config
29
+ speaker: int | str,
30
+ cluster_model_path: Path | str | None = None,
31
+ transpose: int = 0,
32
+ auto_predict_f0: bool = False,
33
+ cluster_infer_ratio: float = 0,
34
+ noise_scale: float = 0.4,
35
+ f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
36
+ # slice config
37
+ db_thresh: int = -40,
38
+ pad_seconds: float = 0.5,
39
+ chunk_seconds: float = 0.5,
40
+ absolute_thresh: bool = False,
41
+ max_chunk_seconds: float = 40,
42
+ device: str | torch.device = get_optimal_device(),
43
+ ):
44
+ if isinstance(input_path, (str, Path)):
45
+ input_path = [input_path]
46
+ if isinstance(output_path, (str, Path)):
47
+ output_path = [output_path]
48
+ if len(input_path) != len(output_path):
49
+ raise ValueError(
50
+ f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}"
51
+ )
52
+
53
+ model_path = Path(model_path)
54
+ config_path = Path(config_path)
55
+ output_path = [Path(p) for p in output_path]
56
+ input_path = [Path(p) for p in input_path]
57
+ output_paths = []
58
+ input_paths = []
59
+
60
+ for input_path, output_path in zip(input_path, output_path):
61
+ if input_path.is_dir():
62
+ if not recursive:
63
+ raise ValueError(
64
+ f"input_path is a directory, but recursive is False: {input_path}"
65
+ )
66
+ input_paths.extend(list(input_path.rglob("*.*")))
67
+ output_paths.extend(
68
+ [output_path / p.relative_to(input_path) for p in input_paths]
69
+ )
70
+ continue
71
+ input_paths.append(input_path)
72
+ output_paths.append(output_path)
73
+
74
+ cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
75
+ svc_model = Svc(
76
+ net_g_path=model_path.as_posix(),
77
+ config_path=config_path.as_posix(),
78
+ cluster_model_path=cluster_model_path.as_posix()
79
+ if cluster_model_path
80
+ else None,
81
+ device=device,
82
+ )
83
+
84
+ try:
85
+ pbar = tqdm(list(zip(input_paths, output_paths)), disable=len(input_paths) == 1)
86
+ for input_path, output_path in pbar:
87
+ pbar.set_description(f"{input_path}")
88
+ try:
89
+ audio, _ = librosa.load(str(input_path), sr=svc_model.target_sample)
90
+ except Exception as e:
91
+ LOG.error(f"Failed to load {input_path}")
92
+ LOG.exception(e)
93
+ continue
94
+ output_path.parent.mkdir(parents=True, exist_ok=True)
95
+ audio = svc_model.infer_silence(
96
+ audio.astype(np.float32),
97
+ speaker=speaker,
98
+ transpose=transpose,
99
+ auto_predict_f0=auto_predict_f0,
100
+ cluster_infer_ratio=cluster_infer_ratio,
101
+ noise_scale=noise_scale,
102
+ f0_method=f0_method,
103
+ db_thresh=db_thresh,
104
+ pad_seconds=pad_seconds,
105
+ chunk_seconds=chunk_seconds,
106
+ absolute_thresh=absolute_thresh,
107
+ max_chunk_seconds=max_chunk_seconds,
108
+ )
109
+ soundfile.write(str(output_path), audio, svc_model.target_sample)
110
+ finally:
111
+ del svc_model
112
+ torch.cuda.empty_cache()
113
+
114
+
115
+ def realtime(
116
+ *,
117
+ # paths
118
+ model_path: Path | str,
119
+ config_path: Path | str,
120
+ # svc config
121
+ speaker: str,
122
+ cluster_model_path: Path | str | None = None,
123
+ transpose: int = 0,
124
+ auto_predict_f0: bool = False,
125
+ cluster_infer_ratio: float = 0,
126
+ noise_scale: float = 0.4,
127
+ f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
128
+ # slice config
129
+ db_thresh: int = -40,
130
+ pad_seconds: float = 0.5,
131
+ chunk_seconds: float = 0.5,
132
+ # realtime config
133
+ crossfade_seconds: float = 0.05,
134
+ additional_infer_before_seconds: float = 0.2,
135
+ additional_infer_after_seconds: float = 0.1,
136
+ block_seconds: float = 0.5,
137
+ version: int = 2,
138
+ input_device: int | str | None = None,
139
+ output_device: int | str | None = None,
140
+ device: str | torch.device = get_optimal_device(),
141
+ passthrough_original: bool = False,
142
+ ):
143
+ import sounddevice as sd
144
+
145
+ model_path = Path(model_path)
146
+ config_path = Path(config_path)
147
+ cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
148
+ svc_model = Svc(
149
+ net_g_path=model_path.as_posix(),
150
+ config_path=config_path.as_posix(),
151
+ cluster_model_path=cluster_model_path.as_posix()
152
+ if cluster_model_path
153
+ else None,
154
+ device=device,
155
+ )
156
+
157
+ LOG.info("Creating realtime model...")
158
+ if version == 1:
159
+ model = RealtimeVC(
160
+ svc_model=svc_model,
161
+ crossfade_len=int(crossfade_seconds * svc_model.target_sample),
162
+ additional_infer_before_len=int(
163
+ additional_infer_before_seconds * svc_model.target_sample
164
+ ),
165
+ additional_infer_after_len=int(
166
+ additional_infer_after_seconds * svc_model.target_sample
167
+ ),
168
+ )
169
+ else:
170
+ model = RealtimeVC2(
171
+ svc_model=svc_model,
172
+ )
173
+
174
+ # LOG all device info
175
+ devices = sd.query_devices()
176
+ LOG.info(f"Device: {devices}")
177
+ if isinstance(input_device, str):
178
+ input_device_candidates = [
179
+ i for i, d in enumerate(devices) if d["name"] == input_device
180
+ ]
181
+ if len(input_device_candidates) == 0:
182
+ LOG.warning(f"Input device {input_device} not found, using default")
183
+ input_device = None
184
+ else:
185
+ input_device = input_device_candidates[0]
186
+ if isinstance(output_device, str):
187
+ output_device_candidates = [
188
+ i for i, d in enumerate(devices) if d["name"] == output_device
189
+ ]
190
+ if len(output_device_candidates) == 0:
191
+ LOG.warning(f"Output device {output_device} not found, using default")
192
+ output_device = None
193
+ else:
194
+ output_device = output_device_candidates[0]
195
+ if input_device is None or input_device >= len(devices):
196
+ input_device = sd.default.device[0]
197
+ if output_device is None or output_device >= len(devices):
198
+ output_device = sd.default.device[1]
199
+ LOG.info(
200
+ f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}"
201
+ )
202
+
203
+ # the model RTL is somewhat significantly high only in the first inference
204
+ # there could be no better way to warm up the model than to do a dummy inference
205
+ # (there are not differences in the behavior of the model between the first and the later inferences)
206
+ # so we do a dummy inference to warm up the model (1 second of audio)
207
+ LOG.info("Warming up the model...")
208
+ svc_model.infer(
209
+ speaker=speaker,
210
+ transpose=transpose,
211
+ auto_predict_f0=auto_predict_f0,
212
+ cluster_infer_ratio=cluster_infer_ratio,
213
+ noise_scale=noise_scale,
214
+ f0_method=f0_method,
215
+ audio=np.zeros(svc_model.target_sample, dtype=np.float32),
216
+ )
217
+
218
+ def callback(
219
+ indata: np.ndarray,
220
+ outdata: np.ndarray,
221
+ frames: int,
222
+ time: int,
223
+ status: sd.CallbackFlags,
224
+ ) -> None:
225
+ LOG.debug(
226
+ f"Frames: {frames}, Status: {status}, Shape: {indata.shape}, Time: {time}"
227
+ )
228
+
229
+ kwargs = dict(
230
+ input_audio=indata.mean(axis=1).astype(np.float32),
231
+ # svc config
232
+ speaker=speaker,
233
+ transpose=transpose,
234
+ auto_predict_f0=auto_predict_f0,
235
+ cluster_infer_ratio=cluster_infer_ratio,
236
+ noise_scale=noise_scale,
237
+ f0_method=f0_method,
238
+ # slice config
239
+ db_thresh=db_thresh,
240
+ # pad_seconds=pad_seconds,
241
+ chunk_seconds=chunk_seconds,
242
+ )
243
+ if version == 1:
244
+ kwargs["pad_seconds"] = pad_seconds
245
+ with timer() as t:
246
+ inference = model.process(
247
+ **kwargs,
248
+ ).reshape(-1, 1)
249
+ if passthrough_original:
250
+ outdata[:] = (indata + inference) / 2
251
+ else:
252
+ outdata[:] = inference
253
+ rtf = t.elapsed / block_seconds
254
+ LOG.info(f"Realtime inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
255
+ if rtf > 1:
256
+ LOG.warning("RTF is too high, consider increasing block_seconds")
257
+
258
+ try:
259
+ with sd.Stream(
260
+ device=(input_device, output_device),
261
+ channels=1,
262
+ callback=callback,
263
+ samplerate=svc_model.target_sample,
264
+ blocksize=int(block_seconds * svc_model.target_sample),
265
+ latency="low",
266
+ ) as stream:
267
+ LOG.info(f"Latency: {stream.latency}")
268
+ while True:
269
+ sd.sleep(1000)
270
+ finally:
271
+ # del model, svc_model
272
+ torch.cuda.empty_cache()
so_vits_svc_fork/logger.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from logging import DEBUG, INFO, StreamHandler, basicConfig, captureWarnings, getLogger
4
+ from pathlib import Path
5
+
6
+ from rich.logging import RichHandler
7
+
8
+ LOGGER_INIT = False
9
+
10
+
11
+ def init_logger() -> None:
12
+ global LOGGER_INIT
13
+ if LOGGER_INIT:
14
+ return
15
+
16
+ IS_TEST = "test" in Path.cwd().stem
17
+ package_name = sys.modules[__name__].__package__
18
+ basicConfig(
19
+ level=INFO,
20
+ format="%(asctime)s %(message)s",
21
+ datefmt="[%X]",
22
+ handlers=[
23
+ StreamHandler() if is_notebook() else RichHandler(),
24
+ # FileHandler(f"{package_name}.log"),
25
+ ],
26
+ )
27
+ if IS_TEST:
28
+ getLogger(package_name).setLevel(DEBUG)
29
+ captureWarnings(True)
30
+ LOGGER_INIT = True
31
+
32
+
33
+ def is_notebook():
34
+ try:
35
+ from IPython import get_ipython
36
+
37
+ if "IPKernelApp" not in get_ipython().config: # pragma: no cover
38
+ raise ImportError("console")
39
+ return False
40
+ if "VSCODE_PID" in os.environ: # pragma: no cover
41
+ raise ImportError("vscode")
42
+ return False
43
+ except Exception:
44
+ return False
45
+ else: # pragma: no cover
46
+ return True
so_vits_svc_fork/modules/__init__.py ADDED
File without changes
so_vits_svc_fork/modules/attentions.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from so_vits_svc_fork.modules import commons
8
+ from so_vits_svc_fork.modules.modules import LayerNorm
9
+
10
+
11
+ class FFT(nn.Module):
12
+ def __init__(
13
+ self,
14
+ hidden_channels,
15
+ filter_channels,
16
+ n_heads,
17
+ n_layers=1,
18
+ kernel_size=1,
19
+ p_dropout=0.0,
20
+ proximal_bias=False,
21
+ proximal_init=True,
22
+ **kwargs
23
+ ):
24
+ super().__init__()
25
+ self.hidden_channels = hidden_channels
26
+ self.filter_channels = filter_channels
27
+ self.n_heads = n_heads
28
+ self.n_layers = n_layers
29
+ self.kernel_size = kernel_size
30
+ self.p_dropout = p_dropout
31
+ self.proximal_bias = proximal_bias
32
+ self.proximal_init = proximal_init
33
+
34
+ self.drop = nn.Dropout(p_dropout)
35
+ self.self_attn_layers = nn.ModuleList()
36
+ self.norm_layers_0 = nn.ModuleList()
37
+ self.ffn_layers = nn.ModuleList()
38
+ self.norm_layers_1 = nn.ModuleList()
39
+ for i in range(self.n_layers):
40
+ self.self_attn_layers.append(
41
+ MultiHeadAttention(
42
+ hidden_channels,
43
+ hidden_channels,
44
+ n_heads,
45
+ p_dropout=p_dropout,
46
+ proximal_bias=proximal_bias,
47
+ proximal_init=proximal_init,
48
+ )
49
+ )
50
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
51
+ self.ffn_layers.append(
52
+ FFN(
53
+ hidden_channels,
54
+ hidden_channels,
55
+ filter_channels,
56
+ kernel_size,
57
+ p_dropout=p_dropout,
58
+ causal=True,
59
+ )
60
+ )
61
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
62
+
63
+ def forward(self, x, x_mask):
64
+ """
65
+ x: decoder input
66
+ h: encoder output
67
+ """
68
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
69
+ device=x.device, dtype=x.dtype
70
+ )
71
+ x = x * x_mask
72
+ for i in range(self.n_layers):
73
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
74
+ y = self.drop(y)
75
+ x = self.norm_layers_0[i](x + y)
76
+
77
+ y = self.ffn_layers[i](x, x_mask)
78
+ y = self.drop(y)
79
+ x = self.norm_layers_1[i](x + y)
80
+ x = x * x_mask
81
+ return x
82
+
83
+
84
+ class Encoder(nn.Module):
85
+ def __init__(
86
+ self,
87
+ hidden_channels,
88
+ filter_channels,
89
+ n_heads,
90
+ n_layers,
91
+ kernel_size=1,
92
+ p_dropout=0.0,
93
+ window_size=4,
94
+ **kwargs
95
+ ):
96
+ super().__init__()
97
+ self.hidden_channels = hidden_channels
98
+ self.filter_channels = filter_channels
99
+ self.n_heads = n_heads
100
+ self.n_layers = n_layers
101
+ self.kernel_size = kernel_size
102
+ self.p_dropout = p_dropout
103
+ self.window_size = window_size
104
+
105
+ self.drop = nn.Dropout(p_dropout)
106
+ self.attn_layers = nn.ModuleList()
107
+ self.norm_layers_1 = nn.ModuleList()
108
+ self.ffn_layers = nn.ModuleList()
109
+ self.norm_layers_2 = nn.ModuleList()
110
+ for i in range(self.n_layers):
111
+ self.attn_layers.append(
112
+ MultiHeadAttention(
113
+ hidden_channels,
114
+ hidden_channels,
115
+ n_heads,
116
+ p_dropout=p_dropout,
117
+ window_size=window_size,
118
+ )
119
+ )
120
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
121
+ self.ffn_layers.append(
122
+ FFN(
123
+ hidden_channels,
124
+ hidden_channels,
125
+ filter_channels,
126
+ kernel_size,
127
+ p_dropout=p_dropout,
128
+ )
129
+ )
130
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
131
+
132
+ def forward(self, x, x_mask):
133
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
134
+ x = x * x_mask
135
+ for i in range(self.n_layers):
136
+ y = self.attn_layers[i](x, x, attn_mask)
137
+ y = self.drop(y)
138
+ x = self.norm_layers_1[i](x + y)
139
+
140
+ y = self.ffn_layers[i](x, x_mask)
141
+ y = self.drop(y)
142
+ x = self.norm_layers_2[i](x + y)
143
+ x = x * x_mask
144
+ return x
145
+
146
+
147
+ class Decoder(nn.Module):
148
+ def __init__(
149
+ self,
150
+ hidden_channels,
151
+ filter_channels,
152
+ n_heads,
153
+ n_layers,
154
+ kernel_size=1,
155
+ p_dropout=0.0,
156
+ proximal_bias=False,
157
+ proximal_init=True,
158
+ **kwargs
159
+ ):
160
+ super().__init__()
161
+ self.hidden_channels = hidden_channels
162
+ self.filter_channels = filter_channels
163
+ self.n_heads = n_heads
164
+ self.n_layers = n_layers
165
+ self.kernel_size = kernel_size
166
+ self.p_dropout = p_dropout
167
+ self.proximal_bias = proximal_bias
168
+ self.proximal_init = proximal_init
169
+
170
+ self.drop = nn.Dropout(p_dropout)
171
+ self.self_attn_layers = nn.ModuleList()
172
+ self.norm_layers_0 = nn.ModuleList()
173
+ self.encdec_attn_layers = nn.ModuleList()
174
+ self.norm_layers_1 = nn.ModuleList()
175
+ self.ffn_layers = nn.ModuleList()
176
+ self.norm_layers_2 = nn.ModuleList()
177
+ for i in range(self.n_layers):
178
+ self.self_attn_layers.append(
179
+ MultiHeadAttention(
180
+ hidden_channels,
181
+ hidden_channels,
182
+ n_heads,
183
+ p_dropout=p_dropout,
184
+ proximal_bias=proximal_bias,
185
+ proximal_init=proximal_init,
186
+ )
187
+ )
188
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
189
+ self.encdec_attn_layers.append(
190
+ MultiHeadAttention(
191
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
192
+ )
193
+ )
194
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
195
+ self.ffn_layers.append(
196
+ FFN(
197
+ hidden_channels,
198
+ hidden_channels,
199
+ filter_channels,
200
+ kernel_size,
201
+ p_dropout=p_dropout,
202
+ causal=True,
203
+ )
204
+ )
205
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
206
+
207
+ def forward(self, x, x_mask, h, h_mask):
208
+ """
209
+ x: decoder input
210
+ h: encoder output
211
+ """
212
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
213
+ device=x.device, dtype=x.dtype
214
+ )
215
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
216
+ x = x * x_mask
217
+ for i in range(self.n_layers):
218
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
219
+ y = self.drop(y)
220
+ x = self.norm_layers_0[i](x + y)
221
+
222
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
223
+ y = self.drop(y)
224
+ x = self.norm_layers_1[i](x + y)
225
+
226
+ y = self.ffn_layers[i](x, x_mask)
227
+ y = self.drop(y)
228
+ x = self.norm_layers_2[i](x + y)
229
+ x = x * x_mask
230
+ return x
231
+
232
+
233
+ class MultiHeadAttention(nn.Module):
234
+ def __init__(
235
+ self,
236
+ channels,
237
+ out_channels,
238
+ n_heads,
239
+ p_dropout=0.0,
240
+ window_size=None,
241
+ heads_share=True,
242
+ block_length=None,
243
+ proximal_bias=False,
244
+ proximal_init=False,
245
+ ):
246
+ super().__init__()
247
+ assert channels % n_heads == 0
248
+
249
+ self.channels = channels
250
+ self.out_channels = out_channels
251
+ self.n_heads = n_heads
252
+ self.p_dropout = p_dropout
253
+ self.window_size = window_size
254
+ self.heads_share = heads_share
255
+ self.block_length = block_length
256
+ self.proximal_bias = proximal_bias
257
+ self.proximal_init = proximal_init
258
+ self.attn = None
259
+
260
+ self.k_channels = channels // n_heads
261
+ self.conv_q = nn.Conv1d(channels, channels, 1)
262
+ self.conv_k = nn.Conv1d(channels, channels, 1)
263
+ self.conv_v = nn.Conv1d(channels, channels, 1)
264
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
265
+ self.drop = nn.Dropout(p_dropout)
266
+
267
+ if window_size is not None:
268
+ n_heads_rel = 1 if heads_share else n_heads
269
+ rel_stddev = self.k_channels**-0.5
270
+ self.emb_rel_k = nn.Parameter(
271
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
272
+ * rel_stddev
273
+ )
274
+ self.emb_rel_v = nn.Parameter(
275
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
276
+ * rel_stddev
277
+ )
278
+
279
+ nn.init.xavier_uniform_(self.conv_q.weight)
280
+ nn.init.xavier_uniform_(self.conv_k.weight)
281
+ nn.init.xavier_uniform_(self.conv_v.weight)
282
+ if proximal_init:
283
+ with torch.no_grad():
284
+ self.conv_k.weight.copy_(self.conv_q.weight)
285
+ self.conv_k.bias.copy_(self.conv_q.bias)
286
+
287
+ def forward(self, x, c, attn_mask=None):
288
+ q = self.conv_q(x)
289
+ k = self.conv_k(c)
290
+ v = self.conv_v(c)
291
+
292
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
293
+
294
+ x = self.conv_o(x)
295
+ return x
296
+
297
+ def attention(self, query, key, value, mask=None):
298
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
299
+ b, d, t_s, t_t = (*key.size(), query.size(2))
300
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
301
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
302
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
303
+
304
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
305
+ if self.window_size is not None:
306
+ assert (
307
+ t_s == t_t
308
+ ), "Relative attention is only available for self-attention."
309
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
310
+ rel_logits = self._matmul_with_relative_keys(
311
+ query / math.sqrt(self.k_channels), key_relative_embeddings
312
+ )
313
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
314
+ scores = scores + scores_local
315
+ if self.proximal_bias:
316
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
317
+ scores = scores + self._attention_bias_proximal(t_s).to(
318
+ device=scores.device, dtype=scores.dtype
319
+ )
320
+ if mask is not None:
321
+ scores = scores.masked_fill(mask == 0, -1e4)
322
+ if self.block_length is not None:
323
+ assert (
324
+ t_s == t_t
325
+ ), "Local attention is only available for self-attention."
326
+ block_mask = (
327
+ torch.ones_like(scores)
328
+ .triu(-self.block_length)
329
+ .tril(self.block_length)
330
+ )
331
+ scores = scores.masked_fill(block_mask == 0, -1e4)
332
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
333
+ p_attn = self.drop(p_attn)
334
+ output = torch.matmul(p_attn, value)
335
+ if self.window_size is not None:
336
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
337
+ value_relative_embeddings = self._get_relative_embeddings(
338
+ self.emb_rel_v, t_s
339
+ )
340
+ output = output + self._matmul_with_relative_values(
341
+ relative_weights, value_relative_embeddings
342
+ )
343
+ output = (
344
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
345
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
346
+ return output, p_attn
347
+
348
+ def _matmul_with_relative_values(self, x, y):
349
+ """
350
+ x: [b, h, l, m]
351
+ y: [h or 1, m, d]
352
+ ret: [b, h, l, d]
353
+ """
354
+ ret = torch.matmul(x, y.unsqueeze(0))
355
+ return ret
356
+
357
+ def _matmul_with_relative_keys(self, x, y):
358
+ """
359
+ x: [b, h, l, d]
360
+ y: [h or 1, m, d]
361
+ ret: [b, h, l, m]
362
+ """
363
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
364
+ return ret
365
+
366
+ def _get_relative_embeddings(self, relative_embeddings, length):
367
+ 2 * self.window_size + 1
368
+ # Pad first before slice to avoid using cond ops.
369
+ pad_length = max(length - (self.window_size + 1), 0)
370
+ slice_start_position = max((self.window_size + 1) - length, 0)
371
+ slice_end_position = slice_start_position + 2 * length - 1
372
+ if pad_length > 0:
373
+ padded_relative_embeddings = F.pad(
374
+ relative_embeddings,
375
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
376
+ )
377
+ else:
378
+ padded_relative_embeddings = relative_embeddings
379
+ used_relative_embeddings = padded_relative_embeddings[
380
+ :, slice_start_position:slice_end_position
381
+ ]
382
+ return used_relative_embeddings
383
+
384
+ def _relative_position_to_absolute_position(self, x):
385
+ """
386
+ x: [b, h, l, 2*l-1]
387
+ ret: [b, h, l, l]
388
+ """
389
+ batch, heads, length, _ = x.size()
390
+ # Concat columns of pad to shift from relative to absolute indexing.
391
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
392
+
393
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
394
+ x_flat = x.view([batch, heads, length * 2 * length])
395
+ x_flat = F.pad(
396
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
397
+ )
398
+
399
+ # Reshape and slice out the padded elements.
400
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
401
+ :, :, :length, length - 1 :
402
+ ]
403
+ return x_final
404
+
405
+ def _absolute_position_to_relative_position(self, x):
406
+ """
407
+ x: [b, h, l, l]
408
+ ret: [b, h, l, 2*l-1]
409
+ """
410
+ batch, heads, length, _ = x.size()
411
+ # pad along column
412
+ x = F.pad(
413
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
414
+ )
415
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
416
+ # add 0's in the beginning that will skew the elements after reshape
417
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
418
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
419
+ return x_final
420
+
421
+ def _attention_bias_proximal(self, length):
422
+ """Bias for self-attention to encourage attention to close positions.
423
+ Args:
424
+ length: an integer scalar.
425
+ Returns:
426
+ a Tensor with shape [1, 1, length, length]
427
+ """
428
+ r = torch.arange(length, dtype=torch.float32)
429
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
430
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
431
+
432
+
433
+ class FFN(nn.Module):
434
+ def __init__(
435
+ self,
436
+ in_channels,
437
+ out_channels,
438
+ filter_channels,
439
+ kernel_size,
440
+ p_dropout=0.0,
441
+ activation=None,
442
+ causal=False,
443
+ ):
444
+ super().__init__()
445
+ self.in_channels = in_channels
446
+ self.out_channels = out_channels
447
+ self.filter_channels = filter_channels
448
+ self.kernel_size = kernel_size
449
+ self.p_dropout = p_dropout
450
+ self.activation = activation
451
+ self.causal = causal
452
+
453
+ if causal:
454
+ self.padding = self._causal_padding
455
+ else:
456
+ self.padding = self._same_padding
457
+
458
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
459
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
460
+ self.drop = nn.Dropout(p_dropout)
461
+
462
+ def forward(self, x, x_mask):
463
+ x = self.conv_1(self.padding(x * x_mask))
464
+ if self.activation == "gelu":
465
+ x = x * torch.sigmoid(1.702 * x)
466
+ else:
467
+ x = torch.relu(x)
468
+ x = self.drop(x)
469
+ x = self.conv_2(self.padding(x * x_mask))
470
+ return x * x_mask
471
+
472
+ def _causal_padding(self, x):
473
+ if self.kernel_size == 1:
474
+ return x
475
+ pad_l = self.kernel_size - 1
476
+ pad_r = 0
477
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
478
+ x = F.pad(x, commons.convert_pad_shape(padding))
479
+ return x
480
+
481
+ def _same_padding(self, x):
482
+ if self.kernel_size == 1:
483
+ return x
484
+ pad_l = (self.kernel_size - 1) // 2
485
+ pad_r = self.kernel_size // 2
486
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
487
+ x = F.pad(x, commons.convert_pad_shape(padding))
488
+ return x
so_vits_svc_fork/modules/commons.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+
7
+
8
+ def slice_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
9
+ if length is None:
10
+ return x
11
+ length = min(length, x.size(-1))
12
+ x_slice = torch.zeros((x.size()[:-1] + (length,)), dtype=x.dtype, device=x.device)
13
+ ends = starts + length
14
+ for i, (start, end) in enumerate(zip(starts, ends)):
15
+ # LOG.debug(i, start, end, x.size(), x[i, ..., start:end].size(), x_slice.size())
16
+ # x_slice[i, ...] = x[i, ..., start:end] need to pad
17
+ # x_slice[i, ..., :end - start] = x[i, ..., start:end] this does not work
18
+ x_slice[i, ...] = F.pad(x[i, ..., start:end], (0, max(0, length - x.size(-1))))
19
+ return x_slice
20
+
21
+
22
+ def rand_slice_segments_with_pitch(
23
+ x: Tensor, f0: Tensor, x_lengths: Tensor | int | None, segment_size: int | None
24
+ ):
25
+ if segment_size is None:
26
+ return x, f0, torch.arange(x.size(0), device=x.device)
27
+ if x_lengths is None:
28
+ x_lengths = x.size(-1) * torch.ones(
29
+ x.size(0), dtype=torch.long, device=x.device
30
+ )
31
+ # slice_starts = (torch.rand(z.size(0), device=z.device) * (z_lengths - segment_size)).long()
32
+ slice_starts = (
33
+ torch.rand(x.size(0), device=x.device)
34
+ * torch.max(
35
+ x_lengths - segment_size, torch.zeros_like(x_lengths, device=x.device)
36
+ )
37
+ ).long()
38
+ z_slice = slice_segments(x, slice_starts, segment_size)
39
+ f0_slice = slice_segments(f0, slice_starts, segment_size)
40
+ return z_slice, f0_slice, slice_starts
41
+
42
+
43
+ def slice_2d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
44
+ batch_size, num_features, seq_len = x.shape
45
+ ends = starts + length
46
+ idxs = (
47
+ torch.arange(seq_len, device=x.device)
48
+ .unsqueeze(0)
49
+ .unsqueeze(1)
50
+ .repeat(batch_size, num_features, 1)
51
+ )
52
+ mask = (idxs >= starts.unsqueeze(-1).unsqueeze(-1)) & (
53
+ idxs < ends.unsqueeze(-1).unsqueeze(-1)
54
+ )
55
+ return x[mask].reshape(batch_size, num_features, length)
56
+
57
+
58
+ def slice_1d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
59
+ batch_size, seq_len = x.shape
60
+ ends = starts + length
61
+ idxs = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)
62
+ mask = (idxs >= starts.unsqueeze(-1)) & (idxs < ends.unsqueeze(-1))
63
+ return x[mask].reshape(batch_size, length)
64
+
65
+
66
+ def _slice_segments_v3(x: Tensor, starts: Tensor, length: int) -> Tensor:
67
+ shape = x.shape[:-1] + (length,)
68
+ ends = starts + length
69
+ idxs = torch.arange(x.shape[-1], device=x.device).unsqueeze(0).unsqueeze(0)
70
+ unsqueeze_dims = len(shape) - len(
71
+ x.shape
72
+ ) # calculate number of dimensions to unsqueeze
73
+ starts = starts.reshape(starts.shape + (1,) * unsqueeze_dims)
74
+ ends = ends.reshape(ends.shape + (1,) * unsqueeze_dims)
75
+ mask = (idxs >= starts) & (idxs < ends)
76
+ return x[mask].reshape(shape)
77
+
78
+
79
+ def init_weights(m, mean=0.0, std=0.01):
80
+ classname = m.__class__.__name__
81
+ if classname.find("Conv") != -1:
82
+ m.weight.data.normal_(mean, std)
83
+
84
+
85
+ def get_padding(kernel_size, dilation=1):
86
+ return int((kernel_size * dilation - dilation) / 2)
87
+
88
+
89
+ def convert_pad_shape(pad_shape):
90
+ l = pad_shape[::-1]
91
+ pad_shape = [item for sublist in l for item in sublist]
92
+ return pad_shape
93
+
94
+
95
+ def subsequent_mask(length):
96
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97
+ return mask
98
+
99
+
100
+ @torch.jit.script
101
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102
+ n_channels_int = n_channels[0]
103
+ in_act = input_a + input_b
104
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
105
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106
+ acts = t_act * s_act
107
+ return acts
108
+
109
+
110
+ def sequence_mask(length, max_length=None):
111
+ if max_length is None:
112
+ max_length = length.max()
113
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
114
+ return x.unsqueeze(0) < length.unsqueeze(1)
115
+
116
+
117
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
118
+ if isinstance(parameters, torch.Tensor):
119
+ parameters = [parameters]
120
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
121
+ norm_type = float(norm_type)
122
+ if clip_value is not None:
123
+ clip_value = float(clip_value)
124
+
125
+ total_norm = 0
126
+ for p in parameters:
127
+ param_norm = p.grad.data.norm(norm_type)
128
+ total_norm += param_norm.item() ** norm_type
129
+ if clip_value is not None:
130
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
131
+ total_norm = total_norm ** (1.0 / norm_type)
132
+ return total_norm
so_vits_svc_fork/modules/decoders/__init__.py ADDED
File without changes
so_vits_svc_fork/modules/decoders/f0.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from so_vits_svc_fork.modules import attentions as attentions
5
+
6
+
7
+ class F0Decoder(nn.Module):
8
+ def __init__(
9
+ self,
10
+ out_channels,
11
+ hidden_channels,
12
+ filter_channels,
13
+ n_heads,
14
+ n_layers,
15
+ kernel_size,
16
+ p_dropout,
17
+ spk_channels=0,
18
+ ):
19
+ super().__init__()
20
+ self.out_channels = out_channels
21
+ self.hidden_channels = hidden_channels
22
+ self.filter_channels = filter_channels
23
+ self.n_heads = n_heads
24
+ self.n_layers = n_layers
25
+ self.kernel_size = kernel_size
26
+ self.p_dropout = p_dropout
27
+ self.spk_channels = spk_channels
28
+
29
+ self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
30
+ self.decoder = attentions.FFT(
31
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
32
+ )
33
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
34
+ self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
35
+ self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
36
+
37
+ def forward(self, x, norm_f0, x_mask, spk_emb=None):
38
+ x = torch.detach(x)
39
+ if spk_emb is not None:
40
+ spk_emb = torch.detach(spk_emb)
41
+ x = x + self.cond(spk_emb)
42
+ x += self.f0_prenet(norm_f0)
43
+ x = self.prenet(x) * x_mask
44
+ x = self.decoder(x * x_mask, x_mask)
45
+ x = self.proj(x) * x_mask
46
+ return x
so_vits_svc_fork/modules/decoders/hifigan/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from ._models import NSFHifiGANGenerator
2
+
3
+ __all__ = ["NSFHifiGANGenerator"]
so_vits_svc_fork/modules/decoders/hifigan/_models.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging import getLogger
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn import Conv1d, ConvTranspose1d
8
+ from torch.nn.utils import remove_weight_norm, weight_norm
9
+
10
+ from ...modules import ResBlock1, ResBlock2
11
+ from ._utils import init_weights
12
+
13
+ LOG = getLogger(__name__)
14
+
15
+ LRELU_SLOPE = 0.1
16
+
17
+
18
+ def padDiff(x):
19
+ return F.pad(
20
+ F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
21
+ )
22
+
23
+
24
+ class SineGen(torch.nn.Module):
25
+ """Definition of sine generator
26
+ SineGen(samp_rate, harmonic_num = 0,
27
+ sine_amp = 0.1, noise_std = 0.003,
28
+ voiced_threshold = 0,
29
+ flag_for_pulse=False)
30
+ samp_rate: sampling rate in Hz
31
+ harmonic_num: number of harmonic overtones (default 0)
32
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
33
+ noise_std: std of Gaussian noise (default 0.003)
34
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
35
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
36
+ Note: when flag_for_pulse is True, the first time step of a voiced
37
+ segment is always sin(np.pi) or cos(0)
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ samp_rate,
43
+ harmonic_num=0,
44
+ sine_amp=0.1,
45
+ noise_std=0.003,
46
+ voiced_threshold=0,
47
+ flag_for_pulse=False,
48
+ ):
49
+ super().__init__()
50
+ self.sine_amp = sine_amp
51
+ self.noise_std = noise_std
52
+ self.harmonic_num = harmonic_num
53
+ self.dim = self.harmonic_num + 1
54
+ self.sampling_rate = samp_rate
55
+ self.voiced_threshold = voiced_threshold
56
+ self.flag_for_pulse = flag_for_pulse
57
+
58
+ def _f02uv(self, f0):
59
+ # generate uv signal
60
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
61
+ return uv
62
+
63
+ def _f02sine(self, f0_values):
64
+ """f0_values: (batchsize, length, dim)
65
+ where dim indicates fundamental tone and overtones
66
+ """
67
+ # convert to F0 in rad. The integer part n can be ignored
68
+ # because 2 * np.pi * n doesn't affect phase
69
+ rad_values = (f0_values / self.sampling_rate) % 1
70
+
71
+ # initial phase noise (no noise for fundamental component)
72
+ rand_ini = torch.rand(
73
+ f0_values.shape[0], f0_values.shape[2], device=f0_values.device
74
+ )
75
+ rand_ini[:, 0] = 0
76
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
77
+
78
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
79
+ if not self.flag_for_pulse:
80
+ # for normal case
81
+
82
+ # To prevent torch.cumsum numerical overflow,
83
+ # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
84
+ # Buffer tmp_over_one_idx indicates the time step to add -1.
85
+ # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
86
+ tmp_over_one = torch.cumsum(rad_values, 1) % 1
87
+ tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
88
+ cumsum_shift = torch.zeros_like(rad_values)
89
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
90
+
91
+ sines = torch.sin(
92
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
93
+ )
94
+ else:
95
+ # If necessary, make sure that the first time step of every
96
+ # voiced segments is sin(pi) or cos(0)
97
+ # This is used for pulse-train generation
98
+
99
+ # identify the last time step in unvoiced segments
100
+ uv = self._f02uv(f0_values)
101
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
102
+ uv_1[:, -1, :] = 1
103
+ u_loc = (uv < 1) * (uv_1 > 0)
104
+
105
+ # get the instantanouse phase
106
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
107
+ # different batch needs to be processed differently
108
+ for idx in range(f0_values.shape[0]):
109
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
110
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
111
+ # stores the accumulation of i.phase within
112
+ # each voiced segments
113
+ tmp_cumsum[idx, :, :] = 0
114
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
115
+
116
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
117
+ # within the previous voiced segment.
118
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
119
+
120
+ # get the sines
121
+ sines = torch.cos(i_phase * 2 * np.pi)
122
+ return sines
123
+
124
+ def forward(self, f0):
125
+ """sine_tensor, uv = forward(f0)
126
+ input F0: tensor(batchsize=1, length, dim=1)
127
+ f0 for unvoiced steps should be 0
128
+ output sine_tensor: tensor(batchsize=1, length, dim)
129
+ output uv: tensor(batchsize=1, length, 1)
130
+ """
131
+ with torch.no_grad():
132
+ # f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
133
+ # fundamental component
134
+ # fn = torch.multiply(
135
+ # f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
136
+ # )
137
+ fn = torch.multiply(
138
+ f0, torch.arange(1, self.harmonic_num + 2).to(f0.device).to(f0.dtype)
139
+ )
140
+
141
+ # generate sine waveforms
142
+ sine_waves = self._f02sine(fn) * self.sine_amp
143
+
144
+ # generate uv signal
145
+ # uv = torch.ones(f0.shape)
146
+ # uv = uv * (f0 > self.voiced_threshold)
147
+ uv = self._f02uv(f0)
148
+
149
+ # noise: for unvoiced should be similar to sine_amp
150
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
151
+ # . for voiced regions is self.noise_std
152
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
153
+ noise = noise_amp * torch.randn_like(sine_waves)
154
+
155
+ # first: set the unvoiced part to 0 by uv
156
+ # then: additive noise
157
+ sine_waves = sine_waves * uv + noise
158
+ return sine_waves, uv, noise
159
+
160
+
161
+ class SourceModuleHnNSF(torch.nn.Module):
162
+ """SourceModule for hn-nsf
163
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
164
+ add_noise_std=0.003, voiced_threshod=0)
165
+ sampling_rate: sampling_rate in Hz
166
+ harmonic_num: number of harmonic above F0 (default: 0)
167
+ sine_amp: amplitude of sine source signal (default: 0.1)
168
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
169
+ note that amplitude of noise in unvoiced is decided
170
+ by sine_amp
171
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
172
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
173
+ F0_sampled (batchsize, length, 1)
174
+ Sine_source (batchsize, length, 1)
175
+ noise_source (batchsize, length 1)
176
+ uv (batchsize, length, 1)
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ sampling_rate,
182
+ harmonic_num=0,
183
+ sine_amp=0.1,
184
+ add_noise_std=0.003,
185
+ voiced_threshod=0,
186
+ ):
187
+ super().__init__()
188
+
189
+ self.sine_amp = sine_amp
190
+ self.noise_std = add_noise_std
191
+
192
+ # to produce sine waveforms
193
+ self.l_sin_gen = SineGen(
194
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
195
+ )
196
+
197
+ # to merge source harmonics into a single excitation
198
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
199
+ self.l_tanh = torch.nn.Tanh()
200
+
201
+ def forward(self, x):
202
+ """
203
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
204
+ F0_sampled (batchsize, length, 1)
205
+ Sine_source (batchsize, length, 1)
206
+ noise_source (batchsize, length 1)
207
+ """
208
+ # source for harmonic branch
209
+ sine_wavs, uv, _ = self.l_sin_gen(x)
210
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
211
+
212
+ # source for noise branch, in the same shape as uv
213
+ noise = torch.randn_like(uv) * self.sine_amp / 3
214
+ return sine_merge, noise, uv
215
+
216
+
217
+ class NSFHifiGANGenerator(torch.nn.Module):
218
+ def __init__(self, h):
219
+ super().__init__()
220
+ self.h = h
221
+
222
+ self.num_kernels = len(h["resblock_kernel_sizes"])
223
+ self.num_upsamples = len(h["upsample_rates"])
224
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"]))
225
+ self.m_source = SourceModuleHnNSF(
226
+ sampling_rate=h["sampling_rate"], harmonic_num=8
227
+ )
228
+ self.noise_convs = nn.ModuleList()
229
+ self.conv_pre = weight_norm(
230
+ Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)
231
+ )
232
+ resblock = ResBlock1 if h["resblock"] == "1" else ResBlock2
233
+ self.ups = nn.ModuleList()
234
+ for i, (u, k) in enumerate(
235
+ zip(h["upsample_rates"], h["upsample_kernel_sizes"])
236
+ ):
237
+ c_cur = h["upsample_initial_channel"] // (2 ** (i + 1))
238
+ self.ups.append(
239
+ weight_norm(
240
+ ConvTranspose1d(
241
+ h["upsample_initial_channel"] // (2**i),
242
+ h["upsample_initial_channel"] // (2 ** (i + 1)),
243
+ k,
244
+ u,
245
+ padding=(k - u) // 2,
246
+ )
247
+ )
248
+ )
249
+ if i + 1 < len(h["upsample_rates"]): #
250
+ stride_f0 = np.prod(h["upsample_rates"][i + 1 :])
251
+ self.noise_convs.append(
252
+ Conv1d(
253
+ 1,
254
+ c_cur,
255
+ kernel_size=stride_f0 * 2,
256
+ stride=stride_f0,
257
+ padding=stride_f0 // 2,
258
+ )
259
+ )
260
+ else:
261
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
262
+ self.resblocks = nn.ModuleList()
263
+ for i in range(len(self.ups)):
264
+ ch = h["upsample_initial_channel"] // (2 ** (i + 1))
265
+ for j, (k, d) in enumerate(
266
+ zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])
267
+ ):
268
+ self.resblocks.append(resblock(ch, k, d))
269
+
270
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
271
+ self.ups.apply(init_weights)
272
+ self.conv_post.apply(init_weights)
273
+ self.cond = nn.Conv1d(h["gin_channels"], h["upsample_initial_channel"], 1)
274
+
275
+ def forward(self, x, f0, g=None):
276
+ # LOG.info(1,x.shape,f0.shape,f0[:, None].shape)
277
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
278
+ # LOG.info(2,f0.shape)
279
+ har_source, noi_source, uv = self.m_source(f0)
280
+ har_source = har_source.transpose(1, 2)
281
+ x = self.conv_pre(x)
282
+ x = x + self.cond(g)
283
+ # LOG.info(124,x.shape,har_source.shape)
284
+ for i in range(self.num_upsamples):
285
+ x = F.leaky_relu(x, LRELU_SLOPE)
286
+ # LOG.info(3,x.shape)
287
+ x = self.ups[i](x)
288
+ x_source = self.noise_convs[i](har_source)
289
+ # LOG.info(4,x_source.shape,har_source.shape,x.shape)
290
+ x = x + x_source
291
+ xs = None
292
+ for j in range(self.num_kernels):
293
+ if xs is None:
294
+ xs = self.resblocks[i * self.num_kernels + j](x)
295
+ else:
296
+ xs += self.resblocks[i * self.num_kernels + j](x)
297
+ x = xs / self.num_kernels
298
+ x = F.leaky_relu(x)
299
+ x = self.conv_post(x)
300
+ x = torch.tanh(x)
301
+
302
+ return x
303
+
304
+ def remove_weight_norm(self):
305
+ LOG.info("Removing weight norm...")
306
+ for l in self.ups:
307
+ remove_weight_norm(l)
308
+ for l in self.resblocks:
309
+ l.remove_weight_norm()
310
+ remove_weight_norm(self.conv_pre)
311
+ remove_weight_norm(self.conv_post)
so_vits_svc_fork/modules/decoders/hifigan/_utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging import getLogger
2
+
3
+ # matplotlib.use("Agg")
4
+
5
+ LOG = getLogger(__name__)
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size * dilation - dilation) / 2)
so_vits_svc_fork/modules/decoders/mb_istft/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._generators import (
2
+ Multiband_iSTFT_Generator,
3
+ Multistream_iSTFT_Generator,
4
+ iSTFT_Generator,
5
+ )
6
+ from ._loss import subband_stft_loss
7
+ from ._pqmf import PQMF
8
+
9
+ __all__ = [
10
+ "subband_stft_loss",
11
+ "PQMF",
12
+ "iSTFT_Generator",
13
+ "Multiband_iSTFT_Generator",
14
+ "Multistream_iSTFT_Generator",
15
+ ]
so_vits_svc_fork/modules/decoders/mb_istft/_generators.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Conv1d, ConvTranspose1d
6
+ from torch.nn import functional as F
7
+ from torch.nn.utils import remove_weight_norm, weight_norm
8
+
9
+ from ....modules import modules
10
+ from ....modules.commons import get_padding, init_weights
11
+ from ._pqmf import PQMF
12
+ from ._stft import TorchSTFT
13
+
14
+
15
+ class iSTFT_Generator(torch.nn.Module):
16
+ def __init__(
17
+ self,
18
+ initial_channel,
19
+ resblock,
20
+ resblock_kernel_sizes,
21
+ resblock_dilation_sizes,
22
+ upsample_rates,
23
+ upsample_initial_channel,
24
+ upsample_kernel_sizes,
25
+ gen_istft_n_fft,
26
+ gen_istft_hop_size,
27
+ gin_channels=0,
28
+ ):
29
+ super().__init__()
30
+ # self.h = h
31
+ self.gen_istft_n_fft = gen_istft_n_fft
32
+ self.gen_istft_hop_size = gen_istft_hop_size
33
+
34
+ self.num_kernels = len(resblock_kernel_sizes)
35
+ self.num_upsamples = len(upsample_rates)
36
+ self.conv_pre = weight_norm(
37
+ Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
38
+ )
39
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
40
+
41
+ self.ups = nn.ModuleList()
42
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
43
+ self.ups.append(
44
+ weight_norm(
45
+ ConvTranspose1d(
46
+ upsample_initial_channel // (2**i),
47
+ upsample_initial_channel // (2 ** (i + 1)),
48
+ k,
49
+ u,
50
+ padding=(k - u) // 2,
51
+ )
52
+ )
53
+ )
54
+
55
+ self.resblocks = nn.ModuleList()
56
+ for i in range(len(self.ups)):
57
+ ch = upsample_initial_channel // (2 ** (i + 1))
58
+ for j, (k, d) in enumerate(
59
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
60
+ ):
61
+ self.resblocks.append(resblock(ch, k, d))
62
+
63
+ self.post_n_fft = self.gen_istft_n_fft
64
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
65
+ self.ups.apply(init_weights)
66
+ self.conv_post.apply(init_weights)
67
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
68
+ self.stft = TorchSTFT(
69
+ filter_length=self.gen_istft_n_fft,
70
+ hop_length=self.gen_istft_hop_size,
71
+ win_length=self.gen_istft_n_fft,
72
+ )
73
+
74
+ def forward(self, x, g=None):
75
+ x = self.conv_pre(x)
76
+ for i in range(self.num_upsamples):
77
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
78
+ x = self.ups[i](x)
79
+ xs = None
80
+ for j in range(self.num_kernels):
81
+ if xs is None:
82
+ xs = self.resblocks[i * self.num_kernels + j](x)
83
+ else:
84
+ xs += self.resblocks[i * self.num_kernels + j](x)
85
+ x = xs / self.num_kernels
86
+ x = F.leaky_relu(x)
87
+ x = self.reflection_pad(x)
88
+ x = self.conv_post(x)
89
+ spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
90
+ phase = math.pi * torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
91
+ out = self.stft.inverse(spec, phase).to(x.device)
92
+ return out, None
93
+
94
+ def remove_weight_norm(self):
95
+ print("Removing weight norm...")
96
+ for l in self.ups:
97
+ remove_weight_norm(l)
98
+ for l in self.resblocks:
99
+ l.remove_weight_norm()
100
+ remove_weight_norm(self.conv_pre)
101
+ remove_weight_norm(self.conv_post)
102
+
103
+
104
+ class Multiband_iSTFT_Generator(torch.nn.Module):
105
+ def __init__(
106
+ self,
107
+ initial_channel,
108
+ resblock,
109
+ resblock_kernel_sizes,
110
+ resblock_dilation_sizes,
111
+ upsample_rates,
112
+ upsample_initial_channel,
113
+ upsample_kernel_sizes,
114
+ gen_istft_n_fft,
115
+ gen_istft_hop_size,
116
+ subbands,
117
+ gin_channels=0,
118
+ ):
119
+ super().__init__()
120
+ # self.h = h
121
+ self.subbands = subbands
122
+ self.num_kernels = len(resblock_kernel_sizes)
123
+ self.num_upsamples = len(upsample_rates)
124
+ self.conv_pre = weight_norm(
125
+ Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
126
+ )
127
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
128
+
129
+ self.ups = nn.ModuleList()
130
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
131
+ self.ups.append(
132
+ weight_norm(
133
+ ConvTranspose1d(
134
+ upsample_initial_channel // (2**i),
135
+ upsample_initial_channel // (2 ** (i + 1)),
136
+ k,
137
+ u,
138
+ padding=(k - u) // 2,
139
+ )
140
+ )
141
+ )
142
+
143
+ self.resblocks = nn.ModuleList()
144
+ for i in range(len(self.ups)):
145
+ ch = upsample_initial_channel // (2 ** (i + 1))
146
+ for j, (k, d) in enumerate(
147
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
148
+ ):
149
+ self.resblocks.append(resblock(ch, k, d))
150
+
151
+ self.post_n_fft = gen_istft_n_fft
152
+ self.ups.apply(init_weights)
153
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
154
+ self.reshape_pixelshuffle = []
155
+
156
+ self.subband_conv_post = weight_norm(
157
+ Conv1d(ch, self.subbands * (self.post_n_fft + 2), 7, 1, padding=3)
158
+ )
159
+
160
+ self.subband_conv_post.apply(init_weights)
161
+
162
+ self.gen_istft_n_fft = gen_istft_n_fft
163
+ self.gen_istft_hop_size = gen_istft_hop_size
164
+
165
+ def forward(self, x, g=None):
166
+ stft = TorchSTFT(
167
+ filter_length=self.gen_istft_n_fft,
168
+ hop_length=self.gen_istft_hop_size,
169
+ win_length=self.gen_istft_n_fft,
170
+ ).to(x.device)
171
+ pqmf = PQMF(x.device, subbands=self.subbands).to(x.device, dtype=x.dtype)
172
+
173
+ x = self.conv_pre(x) # [B, ch, length]
174
+
175
+ for i in range(self.num_upsamples):
176
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
177
+ x = self.ups[i](x)
178
+
179
+ xs = None
180
+ for j in range(self.num_kernels):
181
+ if xs is None:
182
+ xs = self.resblocks[i * self.num_kernels + j](x)
183
+ else:
184
+ xs += self.resblocks[i * self.num_kernels + j](x)
185
+ x = xs / self.num_kernels
186
+
187
+ x = F.leaky_relu(x)
188
+ x = self.reflection_pad(x)
189
+ x = self.subband_conv_post(x)
190
+ x = torch.reshape(
191
+ x, (x.shape[0], self.subbands, x.shape[1] // self.subbands, x.shape[-1])
192
+ )
193
+
194
+ spec = torch.exp(x[:, :, : self.post_n_fft // 2 + 1, :])
195
+ phase = math.pi * torch.sin(x[:, :, self.post_n_fft // 2 + 1 :, :])
196
+
197
+ y_mb_hat = stft.inverse(
198
+ torch.reshape(
199
+ spec,
200
+ (
201
+ spec.shape[0] * self.subbands,
202
+ self.gen_istft_n_fft // 2 + 1,
203
+ spec.shape[-1],
204
+ ),
205
+ ),
206
+ torch.reshape(
207
+ phase,
208
+ (
209
+ phase.shape[0] * self.subbands,
210
+ self.gen_istft_n_fft // 2 + 1,
211
+ phase.shape[-1],
212
+ ),
213
+ ),
214
+ )
215
+ y_mb_hat = torch.reshape(
216
+ y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1])
217
+ )
218
+ y_mb_hat = y_mb_hat.squeeze(-2)
219
+
220
+ y_g_hat = pqmf.synthesis(y_mb_hat)
221
+
222
+ return y_g_hat, y_mb_hat
223
+
224
+ def remove_weight_norm(self):
225
+ print("Removing weight norm...")
226
+ for l in self.ups:
227
+ remove_weight_norm(l)
228
+ for l in self.resblocks:
229
+ l.remove_weight_norm()
230
+
231
+
232
+ class Multistream_iSTFT_Generator(torch.nn.Module):
233
+ def __init__(
234
+ self,
235
+ initial_channel,
236
+ resblock,
237
+ resblock_kernel_sizes,
238
+ resblock_dilation_sizes,
239
+ upsample_rates,
240
+ upsample_initial_channel,
241
+ upsample_kernel_sizes,
242
+ gen_istft_n_fft,
243
+ gen_istft_hop_size,
244
+ subbands,
245
+ gin_channels=0,
246
+ ):
247
+ super().__init__()
248
+ # self.h = h
249
+ self.subbands = subbands
250
+ self.num_kernels = len(resblock_kernel_sizes)
251
+ self.num_upsamples = len(upsample_rates)
252
+ self.conv_pre = weight_norm(
253
+ Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
254
+ )
255
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
256
+
257
+ self.ups = nn.ModuleList()
258
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
259
+ self.ups.append(
260
+ weight_norm(
261
+ ConvTranspose1d(
262
+ upsample_initial_channel // (2**i),
263
+ upsample_initial_channel // (2 ** (i + 1)),
264
+ k,
265
+ u,
266
+ padding=(k - u) // 2,
267
+ )
268
+ )
269
+ )
270
+
271
+ self.resblocks = nn.ModuleList()
272
+ for i in range(len(self.ups)):
273
+ ch = upsample_initial_channel // (2 ** (i + 1))
274
+ for j, (k, d) in enumerate(
275
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
276
+ ):
277
+ self.resblocks.append(resblock(ch, k, d))
278
+
279
+ self.post_n_fft = gen_istft_n_fft
280
+ self.ups.apply(init_weights)
281
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
282
+ self.reshape_pixelshuffle = []
283
+
284
+ self.subband_conv_post = weight_norm(
285
+ Conv1d(ch, self.subbands * (self.post_n_fft + 2), 7, 1, padding=3)
286
+ )
287
+
288
+ self.subband_conv_post.apply(init_weights)
289
+
290
+ self.gen_istft_n_fft = gen_istft_n_fft
291
+ self.gen_istft_hop_size = gen_istft_hop_size
292
+
293
+ updown_filter = torch.zeros(
294
+ (self.subbands, self.subbands, self.subbands)
295
+ ).float()
296
+ for k in range(self.subbands):
297
+ updown_filter[k, k, 0] = 1.0
298
+ self.register_buffer("updown_filter", updown_filter)
299
+ self.multistream_conv_post = weight_norm(
300
+ Conv1d(
301
+ self.subbands, 1, kernel_size=63, bias=False, padding=get_padding(63, 1)
302
+ )
303
+ )
304
+ self.multistream_conv_post.apply(init_weights)
305
+
306
+ def forward(self, x, g=None):
307
+ stft = TorchSTFT(
308
+ filter_length=self.gen_istft_n_fft,
309
+ hop_length=self.gen_istft_hop_size,
310
+ win_length=self.gen_istft_n_fft,
311
+ ).to(x.device)
312
+ # pqmf = PQMF(x.device)
313
+
314
+ x = self.conv_pre(x) # [B, ch, length]
315
+
316
+ for i in range(self.num_upsamples):
317
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
318
+ x = self.ups[i](x)
319
+
320
+ xs = None
321
+ for j in range(self.num_kernels):
322
+ if xs is None:
323
+ xs = self.resblocks[i * self.num_kernels + j](x)
324
+ else:
325
+ xs += self.resblocks[i * self.num_kernels + j](x)
326
+ x = xs / self.num_kernels
327
+
328
+ x = F.leaky_relu(x)
329
+ x = self.reflection_pad(x)
330
+ x = self.subband_conv_post(x)
331
+ x = torch.reshape(
332
+ x, (x.shape[0], self.subbands, x.shape[1] // self.subbands, x.shape[-1])
333
+ )
334
+
335
+ spec = torch.exp(x[:, :, : self.post_n_fft // 2 + 1, :])
336
+ phase = math.pi * torch.sin(x[:, :, self.post_n_fft // 2 + 1 :, :])
337
+
338
+ y_mb_hat = stft.inverse(
339
+ torch.reshape(
340
+ spec,
341
+ (
342
+ spec.shape[0] * self.subbands,
343
+ self.gen_istft_n_fft // 2 + 1,
344
+ spec.shape[-1],
345
+ ),
346
+ ),
347
+ torch.reshape(
348
+ phase,
349
+ (
350
+ phase.shape[0] * self.subbands,
351
+ self.gen_istft_n_fft // 2 + 1,
352
+ phase.shape[-1],
353
+ ),
354
+ ),
355
+ )
356
+ y_mb_hat = torch.reshape(
357
+ y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1])
358
+ )
359
+ y_mb_hat = y_mb_hat.squeeze(-2)
360
+
361
+ y_mb_hat = F.conv_transpose1d(
362
+ y_mb_hat,
363
+ self.updown_filter.to(x.device) * self.subbands,
364
+ stride=self.subbands,
365
+ )
366
+
367
+ y_g_hat = self.multistream_conv_post(y_mb_hat)
368
+
369
+ return y_g_hat, y_mb_hat
370
+
371
+ def remove_weight_norm(self):
372
+ print("Removing weight norm...")
373
+ for l in self.ups:
374
+ remove_weight_norm(l)
375
+ for l in self.resblocks:
376
+ l.remove_weight_norm()
so_vits_svc_fork/modules/decoders/mb_istft/_loss.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._stft_loss import MultiResolutionSTFTLoss
2
+
3
+
4
+ def subband_stft_loss(h, y_mb, y_hat_mb):
5
+ sub_stft_loss = MultiResolutionSTFTLoss(
6
+ h.train.fft_sizes, h.train.hop_sizes, h.train.win_lengths
7
+ )
8
+ y_mb = y_mb.view(-1, y_mb.size(2))
9
+ y_hat_mb = y_hat_mb.view(-1, y_hat_mb.size(2))
10
+ sub_sc_loss, sub_mag_loss = sub_stft_loss(y_hat_mb[:, : y_mb.size(-1)], y_mb)
11
+ return sub_sc_loss + sub_mag_loss
so_vits_svc_fork/modules/decoders/mb_istft/_pqmf.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+
4
+ """Pseudo QMF modules."""
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from scipy.signal import kaiser
10
+
11
+
12
+ def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
13
+ """Design prototype filter for PQMF.
14
+ This method is based on `A Kaiser window approach for the design of prototype
15
+ filters of cosine modulated filterbanks`_.
16
+ Args:
17
+ taps (int): The number of filter taps.
18
+ cutoff_ratio (float): Cut-off frequency ratio.
19
+ beta (float): Beta coefficient for kaiser window.
20
+ Returns:
21
+ ndarray: Impluse response of prototype filter (taps + 1,).
22
+ .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
23
+ https://ieeexplore.ieee.org/abstract/document/681427
24
+ """
25
+ # check the arguments are valid
26
+ assert taps % 2 == 0, "The number of taps mush be even number."
27
+ assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
28
+
29
+ # make initial filter
30
+ omega_c = np.pi * cutoff_ratio
31
+ with np.errstate(invalid="ignore"):
32
+ h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
33
+ np.pi * (np.arange(taps + 1) - 0.5 * taps)
34
+ )
35
+ h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
36
+
37
+ # apply kaiser window
38
+ w = kaiser(taps + 1, beta)
39
+ h = h_i * w
40
+
41
+ return h
42
+
43
+
44
+ class PQMF(torch.nn.Module):
45
+ """PQMF module.
46
+ This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
47
+ .. _`Near-perfect-reconstruction pseudo-QMF banks`:
48
+ https://ieeexplore.ieee.org/document/258122
49
+ """
50
+
51
+ def __init__(self, device, subbands=8, taps=62, cutoff_ratio=0.15, beta=9.0):
52
+ """Initialize PQMF module.
53
+ Args:
54
+ subbands (int): The number of subbands.
55
+ taps (int): The number of filter taps.
56
+ cutoff_ratio (float): Cut-off frequency ratio.
57
+ beta (float): Beta coefficient for kaiser window.
58
+ """
59
+ super().__init__()
60
+
61
+ # define filter coefficient
62
+ h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
63
+ h_analysis = np.zeros((subbands, len(h_proto)))
64
+ h_synthesis = np.zeros((subbands, len(h_proto)))
65
+ for k in range(subbands):
66
+ h_analysis[k] = (
67
+ 2
68
+ * h_proto
69
+ * np.cos(
70
+ (2 * k + 1)
71
+ * (np.pi / (2 * subbands))
72
+ * (np.arange(taps + 1) - ((taps - 1) / 2))
73
+ + (-1) ** k * np.pi / 4
74
+ )
75
+ )
76
+ h_synthesis[k] = (
77
+ 2
78
+ * h_proto
79
+ * np.cos(
80
+ (2 * k + 1)
81
+ * (np.pi / (2 * subbands))
82
+ * (np.arange(taps + 1) - ((taps - 1) / 2))
83
+ - (-1) ** k * np.pi / 4
84
+ )
85
+ )
86
+
87
+ # convert to tensor
88
+ analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1).to(device)
89
+ synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0).to(device)
90
+
91
+ # register coefficients as buffer
92
+ self.register_buffer("analysis_filter", analysis_filter)
93
+ self.register_buffer("synthesis_filter", synthesis_filter)
94
+
95
+ # filter for downsampling & upsampling
96
+ updown_filter = torch.zeros((subbands, subbands, subbands)).float().to(device)
97
+ for k in range(subbands):
98
+ updown_filter[k, k, 0] = 1.0
99
+ self.register_buffer("updown_filter", updown_filter)
100
+ self.subbands = subbands
101
+
102
+ # keep padding info
103
+ self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
104
+
105
+ def analysis(self, x):
106
+ """Analysis with PQMF.
107
+ Args:
108
+ x (Tensor): Input tensor (B, 1, T).
109
+ Returns:
110
+ Tensor: Output tensor (B, subbands, T // subbands).
111
+ """
112
+ x = F.conv1d(self.pad_fn(x), self.analysis_filter)
113
+ return F.conv1d(x, self.updown_filter, stride=self.subbands)
114
+
115
+ def synthesis(self, x):
116
+ """Synthesis with PQMF.
117
+ Args:
118
+ x (Tensor): Input tensor (B, subbands, T // subbands).
119
+ Returns:
120
+ Tensor: Output tensor (B, 1, T).
121
+ """
122
+ # NOTE(kan-bayashi): Power will be dreased so here multiply by # subbands.
123
+ # Not sure this is the correct way, it is better to check again.
124
+ # TODO(kan-bayashi): Understand the reconstruction procedure
125
+ x = F.conv_transpose1d(
126
+ x, self.updown_filter * self.subbands, stride=self.subbands
127
+ )
128
+ return F.conv1d(self.pad_fn(x), self.synthesis_filter)
so_vits_svc_fork/modules/decoders/mb_istft/_stft.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BSD 3-Clause License
3
+ Copyright (c) 2017, Prem Seetharaman
4
+ All rights reserved.
5
+ * Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+ * Redistributions of source code must retain the above copyright notice,
8
+ this list of conditions and the following disclaimer.
9
+ * Redistributions in binary form must reproduce the above copyright notice, this
10
+ list of conditions and the following disclaimer in the
11
+ documentation and/or other materials provided with the distribution.
12
+ * Neither the name of the copyright holder nor the names of its
13
+ contributors may be used to endorse or promote products derived from this
14
+ software without specific prior written permission.
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ """
26
+
27
+ import librosa.util as librosa_util
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ from librosa.util import pad_center, tiny
32
+ from scipy.signal import get_window
33
+ from torch.autograd import Variable
34
+
35
+
36
+ def window_sumsquare(
37
+ window,
38
+ n_frames,
39
+ hop_length=200,
40
+ win_length=800,
41
+ n_fft=800,
42
+ dtype=np.float32,
43
+ norm=None,
44
+ ):
45
+ """
46
+ # from librosa 0.6
47
+ Compute the sum-square envelope of a window function at a given hop length.
48
+ This is used to estimate modulation effects induced by windowing
49
+ observations in short-time fourier transforms.
50
+ Parameters
51
+ ----------
52
+ window : string, tuple, number, callable, or list-like
53
+ Window specification, as in `get_window`
54
+ n_frames : int > 0
55
+ The number of analysis frames
56
+ hop_length : int > 0
57
+ The number of samples to advance between frames
58
+ win_length : [optional]
59
+ The length of the window function. By default, this matches `n_fft`.
60
+ n_fft : int > 0
61
+ The length of each analysis frame.
62
+ dtype : np.dtype
63
+ The data type of the output
64
+ Returns
65
+ -------
66
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
67
+ The sum-squared envelope of the window function
68
+ """
69
+ if win_length is None:
70
+ win_length = n_fft
71
+
72
+ n = n_fft + hop_length * (n_frames - 1)
73
+ x = np.zeros(n, dtype=dtype)
74
+
75
+ # Compute the squared window at the desired length
76
+ win_sq = get_window(window, win_length, fftbins=True)
77
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
78
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
79
+
80
+ # Fill the envelope
81
+ for i in range(n_frames):
82
+ sample = i * hop_length
83
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
84
+ return x
85
+
86
+
87
+ class STFT(torch.nn.Module):
88
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
89
+
90
+ def __init__(
91
+ self, filter_length=800, hop_length=200, win_length=800, window="hann"
92
+ ):
93
+ super().__init__()
94
+ self.filter_length = filter_length
95
+ self.hop_length = hop_length
96
+ self.win_length = win_length
97
+ self.window = window
98
+ self.forward_transform = None
99
+ scale = self.filter_length / self.hop_length
100
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
101
+
102
+ cutoff = int(self.filter_length / 2 + 1)
103
+ fourier_basis = np.vstack(
104
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
105
+ )
106
+
107
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
108
+ inverse_basis = torch.FloatTensor(
109
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
110
+ )
111
+
112
+ if window is not None:
113
+ assert filter_length >= win_length
114
+ # get window and zero center pad it to filter_length
115
+ fft_window = get_window(window, win_length, fftbins=True)
116
+ fft_window = pad_center(fft_window, filter_length)
117
+ fft_window = torch.from_numpy(fft_window).float()
118
+
119
+ # window the bases
120
+ forward_basis *= fft_window
121
+ inverse_basis *= fft_window
122
+
123
+ self.register_buffer("forward_basis", forward_basis.float())
124
+ self.register_buffer("inverse_basis", inverse_basis.float())
125
+
126
+ def transform(self, input_data):
127
+ num_batches = input_data.size(0)
128
+ num_samples = input_data.size(1)
129
+
130
+ self.num_samples = num_samples
131
+
132
+ # similar to librosa, reflect-pad the input
133
+ input_data = input_data.view(num_batches, 1, num_samples)
134
+ input_data = F.pad(
135
+ input_data.unsqueeze(1),
136
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
137
+ mode="reflect",
138
+ )
139
+ input_data = input_data.squeeze(1)
140
+
141
+ forward_transform = F.conv1d(
142
+ input_data,
143
+ Variable(self.forward_basis, requires_grad=False),
144
+ stride=self.hop_length,
145
+ padding=0,
146
+ )
147
+
148
+ cutoff = int((self.filter_length / 2) + 1)
149
+ real_part = forward_transform[:, :cutoff, :]
150
+ imag_part = forward_transform[:, cutoff:, :]
151
+
152
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
153
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
154
+
155
+ return magnitude, phase
156
+
157
+ def inverse(self, magnitude, phase):
158
+ recombine_magnitude_phase = torch.cat(
159
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
160
+ )
161
+
162
+ inverse_transform = F.conv_transpose1d(
163
+ recombine_magnitude_phase,
164
+ Variable(self.inverse_basis, requires_grad=False),
165
+ stride=self.hop_length,
166
+ padding=0,
167
+ )
168
+
169
+ if self.window is not None:
170
+ window_sum = window_sumsquare(
171
+ self.window,
172
+ magnitude.size(-1),
173
+ hop_length=self.hop_length,
174
+ win_length=self.win_length,
175
+ n_fft=self.filter_length,
176
+ dtype=np.float32,
177
+ )
178
+ # remove modulation effects
179
+ approx_nonzero_indices = torch.from_numpy(
180
+ np.where(window_sum > tiny(window_sum))[0]
181
+ )
182
+ window_sum = torch.autograd.Variable(
183
+ torch.from_numpy(window_sum), requires_grad=False
184
+ )
185
+ window_sum = window_sum.to(inverse_transform.device())
186
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
187
+ approx_nonzero_indices
188
+ ]
189
+
190
+ # scale by hop ratio
191
+ inverse_transform *= float(self.filter_length) / self.hop_length
192
+
193
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
194
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
195
+
196
+ return inverse_transform
197
+
198
+ def forward(self, input_data):
199
+ self.magnitude, self.phase = self.transform(input_data)
200
+ reconstruction = self.inverse(self.magnitude, self.phase)
201
+ return reconstruction
202
+
203
+
204
+ class TorchSTFT(torch.nn.Module):
205
+ def __init__(
206
+ self, filter_length=800, hop_length=200, win_length=800, window="hann"
207
+ ):
208
+ super().__init__()
209
+ self.filter_length = filter_length
210
+ self.hop_length = hop_length
211
+ self.win_length = win_length
212
+ self.window = torch.from_numpy(
213
+ get_window(window, win_length, fftbins=True).astype(np.float32)
214
+ )
215
+
216
+ def transform(self, input_data):
217
+ forward_transform = torch.stft(
218
+ input_data,
219
+ self.filter_length,
220
+ self.hop_length,
221
+ self.win_length,
222
+ window=self.window,
223
+ return_complex=True,
224
+ )
225
+
226
+ return torch.abs(forward_transform), torch.angle(forward_transform)
227
+
228
+ def inverse(self, magnitude, phase):
229
+ inverse_transform = torch.istft(
230
+ magnitude * torch.exp(phase * 1j),
231
+ self.filter_length,
232
+ self.hop_length,
233
+ self.win_length,
234
+ window=self.window.to(magnitude.device),
235
+ )
236
+
237
+ return inverse_transform.unsqueeze(
238
+ -2
239
+ ) # unsqueeze to stay consistent with conv_transpose1d implementation
240
+
241
+ def forward(self, input_data):
242
+ self.magnitude, self.phase = self.transform(input_data)
243
+ reconstruction = self.inverse(self.magnitude, self.phase)
244
+ return reconstruction
so_vits_svc_fork/modules/decoders/mb_istft/_stft_loss.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+
4
+ """STFT-based Loss modules."""
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def stft(x, fft_size, hop_size, win_length, window):
11
+ """Perform STFT and convert to magnitude spectrogram.
12
+ Args:
13
+ x (Tensor): Input signal tensor (B, T).
14
+ fft_size (int): FFT size.
15
+ hop_size (int): Hop size.
16
+ win_length (int): Window length.
17
+ window (str): Window function type.
18
+ Returns:
19
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
20
+ """
21
+ x_stft = torch.stft(
22
+ x, fft_size, hop_size, win_length, window.to(x.device), return_complex=False
23
+ )
24
+ real = x_stft[..., 0]
25
+ imag = x_stft[..., 1]
26
+
27
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
28
+ return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
29
+
30
+
31
+ class SpectralConvergengeLoss(torch.nn.Module):
32
+ """Spectral convergence loss module."""
33
+
34
+ def __init__(self):
35
+ """Initialize spectral convergence loss module."""
36
+ super().__init__()
37
+
38
+ def forward(self, x_mag, y_mag):
39
+ """Calculate forward propagation.
40
+ Args:
41
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
42
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
43
+ Returns:
44
+ Tensor: Spectral convergence loss value.
45
+ """
46
+ return torch.norm(y_mag - x_mag) / torch.norm(
47
+ y_mag
48
+ ) # MB-iSTFT-VITS changed here due to codespell
49
+
50
+
51
+ class LogSTFTMagnitudeLoss(torch.nn.Module):
52
+ """Log STFT magnitude loss module."""
53
+
54
+ def __init__(self):
55
+ """Initialize los STFT magnitude loss module."""
56
+ super().__init__()
57
+
58
+ def forward(self, x_mag, y_mag):
59
+ """Calculate forward propagation.
60
+ Args:
61
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
62
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
63
+ Returns:
64
+ Tensor: Log STFT magnitude loss value.
65
+ """
66
+ return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
67
+
68
+
69
+ class STFTLoss(torch.nn.Module):
70
+ """STFT loss module."""
71
+
72
+ def __init__(
73
+ self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"
74
+ ):
75
+ """Initialize STFT loss module."""
76
+ super().__init__()
77
+ self.fft_size = fft_size
78
+ self.shift_size = shift_size
79
+ self.win_length = win_length
80
+ self.window = getattr(torch, window)(win_length)
81
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
82
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
83
+
84
+ def forward(self, x, y):
85
+ """Calculate forward propagation.
86
+ Args:
87
+ x (Tensor): Predicted signal (B, T).
88
+ y (Tensor): Groundtruth signal (B, T).
89
+ Returns:
90
+ Tensor: Spectral convergence loss value.
91
+ Tensor: Log STFT magnitude loss value.
92
+ """
93
+ x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
94
+ y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
95
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
96
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
97
+
98
+ return sc_loss, mag_loss
99
+
100
+
101
+ class MultiResolutionSTFTLoss(torch.nn.Module):
102
+ """Multi resolution STFT loss module."""
103
+
104
+ def __init__(
105
+ self,
106
+ fft_sizes=[1024, 2048, 512],
107
+ hop_sizes=[120, 240, 50],
108
+ win_lengths=[600, 1200, 240],
109
+ window="hann_window",
110
+ ):
111
+ """Initialize Multi resolution STFT loss module.
112
+ Args:
113
+ fft_sizes (list): List of FFT sizes.
114
+ hop_sizes (list): List of hop sizes.
115
+ win_lengths (list): List of window lengths.
116
+ window (str): Window function type.
117
+ """
118
+ super().__init__()
119
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
120
+ self.stft_losses = torch.nn.ModuleList()
121
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
122
+ self.stft_losses += [STFTLoss(fs, ss, wl, window)]
123
+
124
+ def forward(self, x, y):
125
+ """Calculate forward propagation.
126
+ Args:
127
+ x (Tensor): Predicted signal (B, T).
128
+ y (Tensor): Groundtruth signal (B, T).
129
+ Returns:
130
+ Tensor: Multi resolution spectral convergence loss value.
131
+ Tensor: Multi resolution log STFT magnitude loss value.
132
+ """
133
+ sc_loss = 0.0
134
+ mag_loss = 0.0
135
+ for f in self.stft_losses:
136
+ sc_l, mag_l = f(x, y)
137
+ sc_loss += sc_l
138
+ mag_loss += mag_l
139
+ sc_loss /= len(self.stft_losses)
140
+ mag_loss /= len(self.stft_losses)
141
+
142
+ return sc_loss, mag_loss
so_vits_svc_fork/modules/descriminators.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import AvgPool1d, Conv1d, Conv2d
4
+ from torch.nn import functional as F
5
+ from torch.nn.utils import spectral_norm, weight_norm
6
+
7
+ from so_vits_svc_fork.modules import modules as modules
8
+ from so_vits_svc_fork.modules.commons import get_padding
9
+
10
+
11
+ class DiscriminatorP(torch.nn.Module):
12
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
13
+ super().__init__()
14
+ self.period = period
15
+ self.use_spectral_norm = use_spectral_norm
16
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
17
+ self.convs = nn.ModuleList(
18
+ [
19
+ norm_f(
20
+ Conv2d(
21
+ 1,
22
+ 32,
23
+ (kernel_size, 1),
24
+ (stride, 1),
25
+ padding=(get_padding(kernel_size, 1), 0),
26
+ )
27
+ ),
28
+ norm_f(
29
+ Conv2d(
30
+ 32,
31
+ 128,
32
+ (kernel_size, 1),
33
+ (stride, 1),
34
+ padding=(get_padding(kernel_size, 1), 0),
35
+ )
36
+ ),
37
+ norm_f(
38
+ Conv2d(
39
+ 128,
40
+ 512,
41
+ (kernel_size, 1),
42
+ (stride, 1),
43
+ padding=(get_padding(kernel_size, 1), 0),
44
+ )
45
+ ),
46
+ norm_f(
47
+ Conv2d(
48
+ 512,
49
+ 1024,
50
+ (kernel_size, 1),
51
+ (stride, 1),
52
+ padding=(get_padding(kernel_size, 1), 0),
53
+ )
54
+ ),
55
+ norm_f(
56
+ Conv2d(
57
+ 1024,
58
+ 1024,
59
+ (kernel_size, 1),
60
+ 1,
61
+ padding=(get_padding(kernel_size, 1), 0),
62
+ )
63
+ ),
64
+ ]
65
+ )
66
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
67
+
68
+ def forward(self, x):
69
+ fmap = []
70
+
71
+ # 1d to 2d
72
+ b, c, t = x.shape
73
+ if t % self.period != 0: # pad first
74
+ n_pad = self.period - (t % self.period)
75
+ x = F.pad(x, (0, n_pad), "reflect")
76
+ t = t + n_pad
77
+ x = x.view(b, c, t // self.period, self.period)
78
+
79
+ for l in self.convs:
80
+ x = l(x)
81
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
82
+ fmap.append(x)
83
+ x = self.conv_post(x)
84
+ fmap.append(x)
85
+ x = torch.flatten(x, 1, -1)
86
+
87
+ return x, fmap
88
+
89
+
90
+ class DiscriminatorS(torch.nn.Module):
91
+ def __init__(self, use_spectral_norm=False):
92
+ super().__init__()
93
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
94
+ self.convs = nn.ModuleList(
95
+ [
96
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
97
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
98
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
99
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
100
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
101
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
102
+ ]
103
+ )
104
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
105
+
106
+ def forward(self, x):
107
+ fmap = []
108
+
109
+ for l in self.convs:
110
+ x = l(x)
111
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
112
+ fmap.append(x)
113
+ x = self.conv_post(x)
114
+ fmap.append(x)
115
+ x = torch.flatten(x, 1, -1)
116
+
117
+ return x, fmap
118
+
119
+
120
+ class MultiPeriodDiscriminator(torch.nn.Module):
121
+ def __init__(self, use_spectral_norm=False):
122
+ super().__init__()
123
+ periods = [2, 3, 5, 7, 11]
124
+
125
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
126
+ discs = discs + [
127
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
128
+ ]
129
+ self.discriminators = nn.ModuleList(discs)
130
+
131
+ def forward(self, y, y_hat):
132
+ y_d_rs = []
133
+ y_d_gs = []
134
+ fmap_rs = []
135
+ fmap_gs = []
136
+ for i, d in enumerate(self.discriminators):
137
+ y_d_r, fmap_r = d(y)
138
+ y_d_g, fmap_g = d(y_hat)
139
+ y_d_rs.append(y_d_r)
140
+ y_d_gs.append(y_d_g)
141
+ fmap_rs.append(fmap_r)
142
+ fmap_gs.append(fmap_g)
143
+
144
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
145
+
146
+
147
+ class MultiScaleDiscriminator(torch.nn.Module):
148
+ def __init__(self):
149
+ super().__init__()
150
+ self.discriminators = nn.ModuleList(
151
+ [
152
+ DiscriminatorS(use_spectral_norm=True),
153
+ DiscriminatorS(),
154
+ DiscriminatorS(),
155
+ ]
156
+ )
157
+ self.meanpools = nn.ModuleList(
158
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
159
+ )
160
+
161
+ def forward(self, y, y_hat):
162
+ y_d_rs = []
163
+ y_d_gs = []
164
+ fmap_rs = []
165
+ fmap_gs = []
166
+ for i, d in enumerate(self.discriminators):
167
+ if i != 0:
168
+ y = self.meanpools[i - 1](y)
169
+ y_hat = self.meanpools[i - 1](y_hat)
170
+ y_d_r, fmap_r = d(y)
171
+ y_d_g, fmap_g = d(y_hat)
172
+ y_d_rs.append(y_d_r)
173
+ fmap_rs.append(fmap_r)
174
+ y_d_gs.append(y_d_g)
175
+ fmap_gs.append(fmap_g)
176
+
177
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
so_vits_svc_fork/modules/encoders.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from so_vits_svc_fork.modules import attentions as attentions
5
+ from so_vits_svc_fork.modules import commons as commons
6
+ from so_vits_svc_fork.modules import modules as modules
7
+
8
+
9
+ class SpeakerEncoder(torch.nn.Module):
10
+ def __init__(
11
+ self,
12
+ mel_n_channels=80,
13
+ model_num_layers=3,
14
+ model_hidden_size=256,
15
+ model_embedding_size=256,
16
+ ):
17
+ super().__init__()
18
+ self.lstm = nn.LSTM(
19
+ mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
20
+ )
21
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
22
+ self.relu = nn.ReLU()
23
+
24
+ def forward(self, mels):
25
+ self.lstm.flatten_parameters()
26
+ _, (hidden, _) = self.lstm(mels)
27
+ embeds_raw = self.relu(self.linear(hidden[-1]))
28
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
29
+
30
+ def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
31
+ mel_slices = []
32
+ for i in range(0, total_frames - partial_frames, partial_hop):
33
+ mel_range = torch.arange(i, i + partial_frames)
34
+ mel_slices.append(mel_range)
35
+
36
+ return mel_slices
37
+
38
+ def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
39
+ mel_len = mel.size(1)
40
+ last_mel = mel[:, -partial_frames:]
41
+
42
+ if mel_len > partial_frames:
43
+ mel_slices = self.compute_partial_slices(
44
+ mel_len, partial_frames, partial_hop
45
+ )
46
+ mels = list(mel[:, s] for s in mel_slices)
47
+ mels.append(last_mel)
48
+ mels = torch.stack(tuple(mels), 0).squeeze(1)
49
+
50
+ with torch.no_grad():
51
+ partial_embeds = self(mels)
52
+ embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
53
+ # embed = embed / torch.linalg.norm(embed, 2)
54
+ else:
55
+ with torch.no_grad():
56
+ embed = self(last_mel)
57
+
58
+ return embed
59
+
60
+
61
+ class Encoder(nn.Module):
62
+ def __init__(
63
+ self,
64
+ in_channels,
65
+ out_channels,
66
+ hidden_channels,
67
+ kernel_size,
68
+ dilation_rate,
69
+ n_layers,
70
+ gin_channels=0,
71
+ ):
72
+ super().__init__()
73
+ self.in_channels = in_channels
74
+ self.out_channels = out_channels
75
+ self.hidden_channels = hidden_channels
76
+ self.kernel_size = kernel_size
77
+ self.dilation_rate = dilation_rate
78
+ self.n_layers = n_layers
79
+ self.gin_channels = gin_channels
80
+
81
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
82
+ self.enc = modules.WN(
83
+ hidden_channels,
84
+ kernel_size,
85
+ dilation_rate,
86
+ n_layers,
87
+ gin_channels=gin_channels,
88
+ )
89
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
90
+
91
+ def forward(self, x, x_lengths, g=None):
92
+ # print(x.shape,x_lengths.shape)
93
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
94
+ x.dtype
95
+ )
96
+ x = self.pre(x) * x_mask
97
+ x = self.enc(x, x_mask, g=g)
98
+ stats = self.proj(x) * x_mask
99
+ m, logs = torch.split(stats, self.out_channels, dim=1)
100
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
101
+ return z, m, logs, x_mask
102
+
103
+
104
+ class TextEncoder(nn.Module):
105
+ def __init__(
106
+ self,
107
+ out_channels,
108
+ hidden_channels,
109
+ kernel_size,
110
+ n_layers,
111
+ gin_channels=0,
112
+ filter_channels=None,
113
+ n_heads=None,
114
+ p_dropout=None,
115
+ ):
116
+ super().__init__()
117
+ self.out_channels = out_channels
118
+ self.hidden_channels = hidden_channels
119
+ self.kernel_size = kernel_size
120
+ self.n_layers = n_layers
121
+ self.gin_channels = gin_channels
122
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
123
+ self.f0_emb = nn.Embedding(256, hidden_channels)
124
+
125
+ self.enc_ = attentions.Encoder(
126
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
127
+ )
128
+
129
+ def forward(self, x, x_mask, f0=None, noice_scale=1):
130
+ x = x + self.f0_emb(f0).transpose(1, 2)
131
+ x = self.enc_(x * x_mask, x_mask)
132
+ stats = self.proj(x) * x_mask
133
+ m, logs = torch.split(stats, self.out_channels, dim=1)
134
+ z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
135
+
136
+ return z, m, logs, x_mask
so_vits_svc_fork/modules/flows.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ from so_vits_svc_fork.modules import modules as modules
4
+
5
+
6
+ class ResidualCouplingBlock(nn.Module):
7
+ def __init__(
8
+ self,
9
+ channels,
10
+ hidden_channels,
11
+ kernel_size,
12
+ dilation_rate,
13
+ n_layers,
14
+ n_flows=4,
15
+ gin_channels=0,
16
+ ):
17
+ super().__init__()
18
+ self.channels = channels
19
+ self.hidden_channels = hidden_channels
20
+ self.kernel_size = kernel_size
21
+ self.dilation_rate = dilation_rate
22
+ self.n_layers = n_layers
23
+ self.n_flows = n_flows
24
+ self.gin_channels = gin_channels
25
+
26
+ self.flows = nn.ModuleList()
27
+ for i in range(n_flows):
28
+ self.flows.append(
29
+ modules.ResidualCouplingLayer(
30
+ channels,
31
+ hidden_channels,
32
+ kernel_size,
33
+ dilation_rate,
34
+ n_layers,
35
+ gin_channels=gin_channels,
36
+ mean_only=True,
37
+ )
38
+ )
39
+ self.flows.append(modules.Flip())
40
+
41
+ def forward(self, x, x_mask, g=None, reverse=False):
42
+ if not reverse:
43
+ for flow in self.flows:
44
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
45
+ else:
46
+ for flow in reversed(self.flows):
47
+ x = flow(x, x_mask, g=g, reverse=reverse)
48
+ return x
so_vits_svc_fork/modules/losses.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def feature_loss(fmap_r, fmap_g):
5
+ loss = 0
6
+ for dr, dg in zip(fmap_r, fmap_g):
7
+ for rl, gl in zip(dr, dg):
8
+ rl = rl.float().detach()
9
+ gl = gl.float()
10
+ loss += torch.mean(torch.abs(rl - gl))
11
+
12
+ return loss * 2
13
+
14
+
15
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
16
+ loss = 0
17
+ r_losses = []
18
+ g_losses = []
19
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
20
+ dr = dr.float()
21
+ dg = dg.float()
22
+ r_loss = torch.mean((1 - dr) ** 2)
23
+ g_loss = torch.mean(dg**2)
24
+ loss += r_loss + g_loss
25
+ r_losses.append(r_loss.item())
26
+ g_losses.append(g_loss.item())
27
+
28
+ return loss, r_losses, g_losses
29
+
30
+
31
+ def generator_loss(disc_outputs):
32
+ loss = 0
33
+ gen_losses = []
34
+ for dg in disc_outputs:
35
+ dg = dg.float()
36
+ l = torch.mean((1 - dg) ** 2)
37
+ gen_losses.append(l)
38
+ loss += l
39
+
40
+ return loss, gen_losses
41
+
42
+
43
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
44
+ """
45
+ z_p, logs_q: [b, h, t_t]
46
+ m_p, logs_p: [b, h, t_t]
47
+ """
48
+ z_p = z_p.float()
49
+ logs_q = logs_q.float()
50
+ m_p = m_p.float()
51
+ logs_p = logs_p.float()
52
+ z_mask = z_mask.float()
53
+ # print(logs_p)
54
+ kl = logs_p - logs_q - 0.5
55
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
56
+ kl = torch.sum(kl * z_mask)
57
+ l = kl / torch.sum(z_mask)
58
+ return l
so_vits_svc_fork/modules/mel_processing.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """from logging import getLogger
2
+
3
+ import torch
4
+ import torch.utils.data
5
+ import torchaudio
6
+
7
+ LOG = getLogger(__name__)
8
+
9
+
10
+ from ..hparams import HParams
11
+
12
+
13
+ def spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
14
+ return torchaudio.transforms.Spectrogram(
15
+ n_fft=hps.data.filter_length,
16
+ win_length=hps.data.win_length,
17
+ hop_length=hps.data.hop_length,
18
+ power=1.0,
19
+ window_fn=torch.hann_window,
20
+ normalized=False,
21
+ ).to(audio.device)(audio)
22
+
23
+
24
+ def spec_to_mel_torch(spec: torch.Tensor, hps: HParams) -> torch.Tensor:
25
+ return torchaudio.transforms.MelScale(
26
+ n_mels=hps.data.n_mel_channels,
27
+ sample_rate=hps.data.sampling_rate,
28
+ f_min=hps.data.mel_fmin,
29
+ f_max=hps.data.mel_fmax,
30
+ ).to(spec.device)(spec)
31
+
32
+
33
+ def mel_spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
34
+ return torchaudio.transforms.MelSpectrogram(
35
+ sample_rate=hps.data.sampling_rate,
36
+ n_fft=hps.data.filter_length,
37
+ n_mels=hps.data.n_mel_channels,
38
+ win_length=hps.data.win_length,
39
+ hop_length=hps.data.hop_length,
40
+ f_min=hps.data.mel_fmin,
41
+ f_max=hps.data.mel_fmax,
42
+ power=1.0,
43
+ window_fn=torch.hann_window,
44
+ normalized=False,
45
+ ).to(audio.device)(audio)"""
46
+
47
+ from logging import getLogger
48
+
49
+ import torch
50
+ import torch.utils.data
51
+ from librosa.filters import mel as librosa_mel_fn
52
+
53
+ LOG = getLogger(__name__)
54
+
55
+ MAX_WAV_VALUE = 32768.0
56
+
57
+
58
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
59
+ """
60
+ PARAMS
61
+ ------
62
+ C: compression factor
63
+ """
64
+ return torch.log(torch.clamp(x, min=clip_val) * C)
65
+
66
+
67
+ def dynamic_range_decompression_torch(x, C=1):
68
+ """
69
+ PARAMS
70
+ ------
71
+ C: compression factor used to compress
72
+ """
73
+ return torch.exp(x) / C
74
+
75
+
76
+ def spectral_normalize_torch(magnitudes):
77
+ output = dynamic_range_compression_torch(magnitudes)
78
+ return output
79
+
80
+
81
+ def spectral_de_normalize_torch(magnitudes):
82
+ output = dynamic_range_decompression_torch(magnitudes)
83
+ return output
84
+
85
+
86
+ mel_basis = {}
87
+ hann_window = {}
88
+
89
+
90
+ def spectrogram_torch(y, hps, center=False):
91
+ if torch.min(y) < -1.0:
92
+ LOG.info("min value is ", torch.min(y))
93
+ if torch.max(y) > 1.0:
94
+ LOG.info("max value is ", torch.max(y))
95
+ n_fft = hps.data.filter_length
96
+ hop_size = hps.data.hop_length
97
+ win_size = hps.data.win_length
98
+ global hann_window
99
+ dtype_device = str(y.dtype) + "_" + str(y.device)
100
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
101
+ if wnsize_dtype_device not in hann_window:
102
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
103
+ dtype=y.dtype, device=y.device
104
+ )
105
+
106
+ y = torch.nn.functional.pad(
107
+ y.unsqueeze(1),
108
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
109
+ mode="reflect",
110
+ )
111
+ y = y.squeeze(1)
112
+
113
+ spec = torch.stft(
114
+ y,
115
+ n_fft,
116
+ hop_length=hop_size,
117
+ win_length=win_size,
118
+ window=hann_window[wnsize_dtype_device],
119
+ center=center,
120
+ pad_mode="reflect",
121
+ normalized=False,
122
+ onesided=True,
123
+ return_complex=False,
124
+ )
125
+
126
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
127
+ return spec
128
+
129
+
130
+ def spec_to_mel_torch(spec, hps):
131
+ sampling_rate = hps.data.sampling_rate
132
+ n_fft = hps.data.filter_length
133
+ num_mels = hps.data.n_mel_channels
134
+ fmin = hps.data.mel_fmin
135
+ fmax = hps.data.mel_fmax
136
+ global mel_basis
137
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
138
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
139
+ if fmax_dtype_device not in mel_basis:
140
+ mel = librosa_mel_fn(
141
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
142
+ )
143
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
144
+ dtype=spec.dtype, device=spec.device
145
+ )
146
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
147
+ spec = spectral_normalize_torch(spec)
148
+ return spec
149
+
150
+
151
+ def mel_spectrogram_torch(y, hps, center=False):
152
+ sampling_rate = hps.data.sampling_rate
153
+ n_fft = hps.data.filter_length
154
+ num_mels = hps.data.n_mel_channels
155
+ fmin = hps.data.mel_fmin
156
+ fmax = hps.data.mel_fmax
157
+ hop_size = hps.data.hop_length
158
+ win_size = hps.data.win_length
159
+ if torch.min(y) < -1.0:
160
+ LOG.info(f"min value is {torch.min(y)}")
161
+ if torch.max(y) > 1.0:
162
+ LOG.info(f"max value is {torch.max(y)}")
163
+
164
+ global mel_basis, hann_window
165
+ dtype_device = str(y.dtype) + "_" + str(y.device)
166
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
167
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
168
+ if fmax_dtype_device not in mel_basis:
169
+ mel = librosa_mel_fn(
170
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
171
+ )
172
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
173
+ dtype=y.dtype, device=y.device
174
+ )
175
+ if wnsize_dtype_device not in hann_window:
176
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
177
+ dtype=y.dtype, device=y.device
178
+ )
179
+
180
+ y = torch.nn.functional.pad(
181
+ y.unsqueeze(1),
182
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
183
+ mode="reflect",
184
+ )
185
+ y = y.squeeze(1)
186
+
187
+ spec = torch.stft(
188
+ y,
189
+ n_fft,
190
+ hop_length=hop_size,
191
+ win_length=win_size,
192
+ window=hann_window[wnsize_dtype_device],
193
+ center=center,
194
+ pad_mode="reflect",
195
+ normalized=False,
196
+ onesided=True,
197
+ return_complex=False,
198
+ )
199
+
200
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
201
+
202
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
203
+ spec = spectral_normalize_torch(spec)
204
+
205
+ return spec
so_vits_svc_fork/modules/modules.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Conv1d
4
+ from torch.nn import functional as F
5
+ from torch.nn.utils import remove_weight_norm, weight_norm
6
+
7
+ from so_vits_svc_fork.modules import commons
8
+ from so_vits_svc_fork.modules.commons import get_padding, init_weights
9
+
10
+ LRELU_SLOPE = 0.1
11
+
12
+
13
+ class LayerNorm(nn.Module):
14
+ def __init__(self, channels, eps=1e-5):
15
+ super().__init__()
16
+ self.channels = channels
17
+ self.eps = eps
18
+
19
+ self.gamma = nn.Parameter(torch.ones(channels))
20
+ self.beta = nn.Parameter(torch.zeros(channels))
21
+
22
+ def forward(self, x):
23
+ x = x.transpose(1, -1)
24
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
25
+ return x.transpose(1, -1)
26
+
27
+
28
+ class ConvReluNorm(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_channels,
32
+ hidden_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ n_layers,
36
+ p_dropout,
37
+ ):
38
+ super().__init__()
39
+ self.in_channels = in_channels
40
+ self.hidden_channels = hidden_channels
41
+ self.out_channels = out_channels
42
+ self.kernel_size = kernel_size
43
+ self.n_layers = n_layers
44
+ self.p_dropout = p_dropout
45
+ assert n_layers > 1, "Number of layers should be larger than 0."
46
+
47
+ self.conv_layers = nn.ModuleList()
48
+ self.norm_layers = nn.ModuleList()
49
+ self.conv_layers.append(
50
+ nn.Conv1d(
51
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
52
+ )
53
+ )
54
+ self.norm_layers.append(LayerNorm(hidden_channels))
55
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
56
+ for _ in range(n_layers - 1):
57
+ self.conv_layers.append(
58
+ nn.Conv1d(
59
+ hidden_channels,
60
+ hidden_channels,
61
+ kernel_size,
62
+ padding=kernel_size // 2,
63
+ )
64
+ )
65
+ self.norm_layers.append(LayerNorm(hidden_channels))
66
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
67
+ self.proj.weight.data.zero_()
68
+ self.proj.bias.data.zero_()
69
+
70
+ def forward(self, x, x_mask):
71
+ x_org = x
72
+ for i in range(self.n_layers):
73
+ x = self.conv_layers[i](x * x_mask)
74
+ x = self.norm_layers[i](x)
75
+ x = self.relu_drop(x)
76
+ x = x_org + self.proj(x)
77
+ return x * x_mask
78
+
79
+
80
+ class DDSConv(nn.Module):
81
+ """
82
+ Dialted and Depth-Separable Convolution
83
+ """
84
+
85
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
86
+ super().__init__()
87
+ self.channels = channels
88
+ self.kernel_size = kernel_size
89
+ self.n_layers = n_layers
90
+ self.p_dropout = p_dropout
91
+
92
+ self.drop = nn.Dropout(p_dropout)
93
+ self.convs_sep = nn.ModuleList()
94
+ self.convs_1x1 = nn.ModuleList()
95
+ self.norms_1 = nn.ModuleList()
96
+ self.norms_2 = nn.ModuleList()
97
+ for i in range(n_layers):
98
+ dilation = kernel_size**i
99
+ padding = (kernel_size * dilation - dilation) // 2
100
+ self.convs_sep.append(
101
+ nn.Conv1d(
102
+ channels,
103
+ channels,
104
+ kernel_size,
105
+ groups=channels,
106
+ dilation=dilation,
107
+ padding=padding,
108
+ )
109
+ )
110
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
111
+ self.norms_1.append(LayerNorm(channels))
112
+ self.norms_2.append(LayerNorm(channels))
113
+
114
+ def forward(self, x, x_mask, g=None):
115
+ if g is not None:
116
+ x = x + g
117
+ for i in range(self.n_layers):
118
+ y = self.convs_sep[i](x * x_mask)
119
+ y = self.norms_1[i](y)
120
+ y = F.gelu(y)
121
+ y = self.convs_1x1[i](y)
122
+ y = self.norms_2[i](y)
123
+ y = F.gelu(y)
124
+ y = self.drop(y)
125
+ x = x + y
126
+ return x * x_mask
127
+
128
+
129
+ class WN(torch.nn.Module):
130
+ def __init__(
131
+ self,
132
+ hidden_channels,
133
+ kernel_size,
134
+ dilation_rate,
135
+ n_layers,
136
+ gin_channels=0,
137
+ p_dropout=0,
138
+ ):
139
+ super().__init__()
140
+ assert kernel_size % 2 == 1
141
+ self.hidden_channels = hidden_channels
142
+ self.kernel_size = (kernel_size,)
143
+ self.dilation_rate = dilation_rate
144
+ self.n_layers = n_layers
145
+ self.gin_channels = gin_channels
146
+ self.p_dropout = p_dropout
147
+
148
+ self.in_layers = torch.nn.ModuleList()
149
+ self.res_skip_layers = torch.nn.ModuleList()
150
+ self.drop = nn.Dropout(p_dropout)
151
+
152
+ if gin_channels != 0:
153
+ cond_layer = torch.nn.Conv1d(
154
+ gin_channels, 2 * hidden_channels * n_layers, 1
155
+ )
156
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
157
+
158
+ for i in range(n_layers):
159
+ dilation = dilation_rate**i
160
+ padding = int((kernel_size * dilation - dilation) / 2)
161
+ in_layer = torch.nn.Conv1d(
162
+ hidden_channels,
163
+ 2 * hidden_channels,
164
+ kernel_size,
165
+ dilation=dilation,
166
+ padding=padding,
167
+ )
168
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
169
+ self.in_layers.append(in_layer)
170
+
171
+ # last one is not necessary
172
+ if i < n_layers - 1:
173
+ res_skip_channels = 2 * hidden_channels
174
+ else:
175
+ res_skip_channels = hidden_channels
176
+
177
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
178
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
179
+ self.res_skip_layers.append(res_skip_layer)
180
+
181
+ def forward(self, x, x_mask, g=None, **kwargs):
182
+ output = torch.zeros_like(x)
183
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
184
+
185
+ if g is not None:
186
+ g = self.cond_layer(g)
187
+
188
+ for i in range(self.n_layers):
189
+ x_in = self.in_layers[i](x)
190
+ if g is not None:
191
+ cond_offset = i * 2 * self.hidden_channels
192
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
193
+ else:
194
+ g_l = torch.zeros_like(x_in)
195
+
196
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
197
+ acts = self.drop(acts)
198
+
199
+ res_skip_acts = self.res_skip_layers[i](acts)
200
+ if i < self.n_layers - 1:
201
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
202
+ x = (x + res_acts) * x_mask
203
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
204
+ else:
205
+ output = output + res_skip_acts
206
+ return output * x_mask
207
+
208
+ def remove_weight_norm(self):
209
+ if self.gin_channels != 0:
210
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
211
+ for l in self.in_layers:
212
+ torch.nn.utils.remove_weight_norm(l)
213
+ for l in self.res_skip_layers:
214
+ torch.nn.utils.remove_weight_norm(l)
215
+
216
+
217
+ class ResBlock1(torch.nn.Module):
218
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
219
+ super().__init__()
220
+ self.convs1 = nn.ModuleList(
221
+ [
222
+ weight_norm(
223
+ Conv1d(
224
+ channels,
225
+ channels,
226
+ kernel_size,
227
+ 1,
228
+ dilation=dilation[0],
229
+ padding=get_padding(kernel_size, dilation[0]),
230
+ )
231
+ ),
232
+ weight_norm(
233
+ Conv1d(
234
+ channels,
235
+ channels,
236
+ kernel_size,
237
+ 1,
238
+ dilation=dilation[1],
239
+ padding=get_padding(kernel_size, dilation[1]),
240
+ )
241
+ ),
242
+ weight_norm(
243
+ Conv1d(
244
+ channels,
245
+ channels,
246
+ kernel_size,
247
+ 1,
248
+ dilation=dilation[2],
249
+ padding=get_padding(kernel_size, dilation[2]),
250
+ )
251
+ ),
252
+ ]
253
+ )
254
+ self.convs1.apply(init_weights)
255
+
256
+ self.convs2 = nn.ModuleList(
257
+ [
258
+ weight_norm(
259
+ Conv1d(
260
+ channels,
261
+ channels,
262
+ kernel_size,
263
+ 1,
264
+ dilation=1,
265
+ padding=get_padding(kernel_size, 1),
266
+ )
267
+ ),
268
+ weight_norm(
269
+ Conv1d(
270
+ channels,
271
+ channels,
272
+ kernel_size,
273
+ 1,
274
+ dilation=1,
275
+ padding=get_padding(kernel_size, 1),
276
+ )
277
+ ),
278
+ weight_norm(
279
+ Conv1d(
280
+ channels,
281
+ channels,
282
+ kernel_size,
283
+ 1,
284
+ dilation=1,
285
+ padding=get_padding(kernel_size, 1),
286
+ )
287
+ ),
288
+ ]
289
+ )
290
+ self.convs2.apply(init_weights)
291
+
292
+ def forward(self, x, x_mask=None):
293
+ for c1, c2 in zip(self.convs1, self.convs2):
294
+ xt = F.leaky_relu(x, LRELU_SLOPE)
295
+ if x_mask is not None:
296
+ xt = xt * x_mask
297
+ xt = c1(xt)
298
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
299
+ if x_mask is not None:
300
+ xt = xt * x_mask
301
+ xt = c2(xt)
302
+ x = xt + x
303
+ if x_mask is not None:
304
+ x = x * x_mask
305
+ return x
306
+
307
+ def remove_weight_norm(self):
308
+ for l in self.convs1:
309
+ remove_weight_norm(l)
310
+ for l in self.convs2:
311
+ remove_weight_norm(l)
312
+
313
+
314
+ class ResBlock2(torch.nn.Module):
315
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
316
+ super().__init__()
317
+ self.convs = nn.ModuleList(
318
+ [
319
+ weight_norm(
320
+ Conv1d(
321
+ channels,
322
+ channels,
323
+ kernel_size,
324
+ 1,
325
+ dilation=dilation[0],
326
+ padding=get_padding(kernel_size, dilation[0]),
327
+ )
328
+ ),
329
+ weight_norm(
330
+ Conv1d(
331
+ channels,
332
+ channels,
333
+ kernel_size,
334
+ 1,
335
+ dilation=dilation[1],
336
+ padding=get_padding(kernel_size, dilation[1]),
337
+ )
338
+ ),
339
+ ]
340
+ )
341
+ self.convs.apply(init_weights)
342
+
343
+ def forward(self, x, x_mask=None):
344
+ for c in self.convs:
345
+ xt = F.leaky_relu(x, LRELU_SLOPE)
346
+ if x_mask is not None:
347
+ xt = xt * x_mask
348
+ xt = c(xt)
349
+ x = xt + x
350
+ if x_mask is not None:
351
+ x = x * x_mask
352
+ return x
353
+
354
+ def remove_weight_norm(self):
355
+ for l in self.convs:
356
+ remove_weight_norm(l)
357
+
358
+
359
+ class Log(nn.Module):
360
+ def forward(self, x, x_mask, reverse=False, **kwargs):
361
+ if not reverse:
362
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
363
+ logdet = torch.sum(-y, [1, 2])
364
+ return y, logdet
365
+ else:
366
+ x = torch.exp(x) * x_mask
367
+ return x
368
+
369
+
370
+ class Flip(nn.Module):
371
+ def forward(self, x, *args, reverse=False, **kwargs):
372
+ x = torch.flip(x, [1])
373
+ if not reverse:
374
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
375
+ return x, logdet
376
+ else:
377
+ return x
378
+
379
+
380
+ class ElementwiseAffine(nn.Module):
381
+ def __init__(self, channels):
382
+ super().__init__()
383
+ self.channels = channels
384
+ self.m = nn.Parameter(torch.zeros(channels, 1))
385
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
386
+
387
+ def forward(self, x, x_mask, reverse=False, **kwargs):
388
+ if not reverse:
389
+ y = self.m + torch.exp(self.logs) * x
390
+ y = y * x_mask
391
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
392
+ return y, logdet
393
+ else:
394
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
395
+ return x
396
+
397
+
398
+ class ResidualCouplingLayer(nn.Module):
399
+ def __init__(
400
+ self,
401
+ channels,
402
+ hidden_channels,
403
+ kernel_size,
404
+ dilation_rate,
405
+ n_layers,
406
+ p_dropout=0,
407
+ gin_channels=0,
408
+ mean_only=False,
409
+ ):
410
+ assert channels % 2 == 0, "channels should be divisible by 2"
411
+ super().__init__()
412
+ self.channels = channels
413
+ self.hidden_channels = hidden_channels
414
+ self.kernel_size = kernel_size
415
+ self.dilation_rate = dilation_rate
416
+ self.n_layers = n_layers
417
+ self.half_channels = channels // 2
418
+ self.mean_only = mean_only
419
+
420
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
421
+ self.enc = WN(
422
+ hidden_channels,
423
+ kernel_size,
424
+ dilation_rate,
425
+ n_layers,
426
+ p_dropout=p_dropout,
427
+ gin_channels=gin_channels,
428
+ )
429
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
430
+ self.post.weight.data.zero_()
431
+ self.post.bias.data.zero_()
432
+
433
+ def forward(self, x, x_mask, g=None, reverse=False):
434
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
435
+ h = self.pre(x0) * x_mask
436
+ h = self.enc(h, x_mask, g=g)
437
+ stats = self.post(h) * x_mask
438
+ if not self.mean_only:
439
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
440
+ else:
441
+ m = stats
442
+ logs = torch.zeros_like(m)
443
+
444
+ if not reverse:
445
+ x1 = m + x1 * torch.exp(logs) * x_mask
446
+ x = torch.cat([x0, x1], 1)
447
+ logdet = torch.sum(logs, [1, 2])
448
+ return x, logdet
449
+ else:
450
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
451
+ x = torch.cat([x0, x1], 1)
452
+ return x
so_vits_svc_fork/modules/synthesizers.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from logging import getLogger
3
+ from typing import Any, Literal, Sequence
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ import so_vits_svc_fork.f0
9
+ from so_vits_svc_fork.f0 import f0_to_coarse
10
+ from so_vits_svc_fork.modules import commons as commons
11
+ from so_vits_svc_fork.modules.decoders.f0 import F0Decoder
12
+ from so_vits_svc_fork.modules.decoders.hifigan import NSFHifiGANGenerator
13
+ from so_vits_svc_fork.modules.decoders.mb_istft import (
14
+ Multiband_iSTFT_Generator,
15
+ Multistream_iSTFT_Generator,
16
+ iSTFT_Generator,
17
+ )
18
+ from so_vits_svc_fork.modules.encoders import Encoder, TextEncoder
19
+ from so_vits_svc_fork.modules.flows import ResidualCouplingBlock
20
+
21
+ LOG = getLogger(__name__)
22
+
23
+
24
+ class SynthesizerTrn(nn.Module):
25
+ """
26
+ Synthesizer for Training
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ spec_channels: int,
32
+ segment_size: int,
33
+ inter_channels: int,
34
+ hidden_channels: int,
35
+ filter_channels: int,
36
+ n_heads: int,
37
+ n_layers: int,
38
+ kernel_size: int,
39
+ p_dropout: int,
40
+ resblock: str,
41
+ resblock_kernel_sizes: Sequence[int],
42
+ resblock_dilation_sizes: Sequence[Sequence[int]],
43
+ upsample_rates: Sequence[int],
44
+ upsample_initial_channel: int,
45
+ upsample_kernel_sizes: Sequence[int],
46
+ gin_channels: int,
47
+ ssl_dim: int,
48
+ n_speakers: int,
49
+ sampling_rate: int = 44100,
50
+ type_: Literal["hifi-gan", "istft", "ms-istft", "mb-istft"] = "hifi-gan",
51
+ gen_istft_n_fft: int = 16,
52
+ gen_istft_hop_size: int = 4,
53
+ subbands: int = 4,
54
+ **kwargs: Any,
55
+ ):
56
+ super().__init__()
57
+ self.spec_channels = spec_channels
58
+ self.inter_channels = inter_channels
59
+ self.hidden_channels = hidden_channels
60
+ self.filter_channels = filter_channels
61
+ self.n_heads = n_heads
62
+ self.n_layers = n_layers
63
+ self.kernel_size = kernel_size
64
+ self.p_dropout = p_dropout
65
+ self.resblock = resblock
66
+ self.resblock_kernel_sizes = resblock_kernel_sizes
67
+ self.resblock_dilation_sizes = resblock_dilation_sizes
68
+ self.upsample_rates = upsample_rates
69
+ self.upsample_initial_channel = upsample_initial_channel
70
+ self.upsample_kernel_sizes = upsample_kernel_sizes
71
+ self.segment_size = segment_size
72
+ self.gin_channels = gin_channels
73
+ self.ssl_dim = ssl_dim
74
+ self.n_speakers = n_speakers
75
+ self.sampling_rate = sampling_rate
76
+ self.type_ = type_
77
+ self.gen_istft_n_fft = gen_istft_n_fft
78
+ self.gen_istft_hop_size = gen_istft_hop_size
79
+ self.subbands = subbands
80
+ if kwargs:
81
+ warnings.warn(f"Unused arguments: {kwargs}")
82
+
83
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
84
+
85
+ if ssl_dim is None:
86
+ self.pre = nn.LazyConv1d(hidden_channels, kernel_size=5, padding=2)
87
+ else:
88
+ self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
89
+
90
+ self.enc_p = TextEncoder(
91
+ inter_channels,
92
+ hidden_channels,
93
+ filter_channels=filter_channels,
94
+ n_heads=n_heads,
95
+ n_layers=n_layers,
96
+ kernel_size=kernel_size,
97
+ p_dropout=p_dropout,
98
+ )
99
+
100
+ LOG.info(f"Decoder type: {type_}")
101
+ if type_ == "hifi-gan":
102
+ hps = {
103
+ "sampling_rate": sampling_rate,
104
+ "inter_channels": inter_channels,
105
+ "resblock": resblock,
106
+ "resblock_kernel_sizes": resblock_kernel_sizes,
107
+ "resblock_dilation_sizes": resblock_dilation_sizes,
108
+ "upsample_rates": upsample_rates,
109
+ "upsample_initial_channel": upsample_initial_channel,
110
+ "upsample_kernel_sizes": upsample_kernel_sizes,
111
+ "gin_channels": gin_channels,
112
+ }
113
+ self.dec = NSFHifiGANGenerator(h=hps)
114
+ self.mb = False
115
+ else:
116
+ hps = {
117
+ "initial_channel": inter_channels,
118
+ "resblock": resblock,
119
+ "resblock_kernel_sizes": resblock_kernel_sizes,
120
+ "resblock_dilation_sizes": resblock_dilation_sizes,
121
+ "upsample_rates": upsample_rates,
122
+ "upsample_initial_channel": upsample_initial_channel,
123
+ "upsample_kernel_sizes": upsample_kernel_sizes,
124
+ "gin_channels": gin_channels,
125
+ "gen_istft_n_fft": gen_istft_n_fft,
126
+ "gen_istft_hop_size": gen_istft_hop_size,
127
+ "subbands": subbands,
128
+ }
129
+
130
+ # gen_istft_n_fft, gen_istft_hop_size, subbands
131
+ if type_ == "istft":
132
+ del hps["subbands"]
133
+ self.dec = iSTFT_Generator(**hps)
134
+ elif type_ == "ms-istft":
135
+ self.dec = Multistream_iSTFT_Generator(**hps)
136
+ elif type_ == "mb-istft":
137
+ self.dec = Multiband_iSTFT_Generator(**hps)
138
+ else:
139
+ raise ValueError(f"Unknown type: {type_}")
140
+ self.mb = True
141
+
142
+ self.enc_q = Encoder(
143
+ spec_channels,
144
+ inter_channels,
145
+ hidden_channels,
146
+ 5,
147
+ 1,
148
+ 16,
149
+ gin_channels=gin_channels,
150
+ )
151
+ self.flow = ResidualCouplingBlock(
152
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
153
+ )
154
+ self.f0_decoder = F0Decoder(
155
+ 1,
156
+ hidden_channels,
157
+ filter_channels,
158
+ n_heads,
159
+ n_layers,
160
+ kernel_size,
161
+ p_dropout,
162
+ spk_channels=gin_channels,
163
+ )
164
+ self.emb_uv = nn.Embedding(2, hidden_channels)
165
+
166
+ def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None):
167
+ g = self.emb_g(g).transpose(1, 2)
168
+ # ssl prenet
169
+ x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(
170
+ c.dtype
171
+ )
172
+ x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
173
+
174
+ # f0 predict
175
+ lf0 = 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500
176
+ norm_lf0 = so_vits_svc_fork.f0.normalize_f0(lf0, x_mask, uv)
177
+ pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
178
+
179
+ # encoder
180
+ z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0))
181
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
182
+
183
+ # flow
184
+ z_p = self.flow(z, spec_mask, g=g)
185
+ z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(
186
+ z, f0, spec_lengths, self.segment_size
187
+ )
188
+
189
+ # MB-iSTFT-VITS
190
+ if self.mb:
191
+ o, o_mb = self.dec(z_slice, g=g)
192
+ # HiFi-GAN
193
+ else:
194
+ o = self.dec(z_slice, g=g, f0=pitch_slice)
195
+ o_mb = None
196
+ return (
197
+ o,
198
+ o_mb,
199
+ ids_slice,
200
+ spec_mask,
201
+ (z, z_p, m_p, logs_p, m_q, logs_q),
202
+ pred_lf0,
203
+ norm_lf0,
204
+ lf0,
205
+ )
206
+
207
+ def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False):
208
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
209
+ g = self.emb_g(g).transpose(1, 2)
210
+ x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(
211
+ c.dtype
212
+ )
213
+ x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
214
+
215
+ if predict_f0:
216
+ lf0 = 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500
217
+ norm_lf0 = so_vits_svc_fork.f0.normalize_f0(
218
+ lf0, x_mask, uv, random_scale=False
219
+ )
220
+ pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
221
+ f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
222
+
223
+ z_p, m_p, logs_p, c_mask = self.enc_p(
224
+ x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale
225
+ )
226
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
227
+
228
+ # MB-iSTFT-VITS
229
+ if self.mb:
230
+ o, o_mb = self.dec(z * c_mask, g=g)
231
+ else:
232
+ o = self.dec(z * c_mask, g=g, f0=f0)
233
+ return o
so_vits_svc_fork/preprocessing/__init__.py ADDED
File without changes
so_vits_svc_fork/preprocessing/config_templates/quickvc.json ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 200,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 16,
11
+ "fp16_run": false,
12
+ "bf16_run": false,
13
+ "lr_decay": 0.999875,
14
+ "segment_size": 10240,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "use_sr": true,
20
+ "max_speclen": 512,
21
+ "port": "8001",
22
+ "keep_ckpts": 3,
23
+ "fft_sizes": [768, 1366, 342],
24
+ "hop_sizes": [60, 120, 20],
25
+ "win_lengths": [300, 600, 120],
26
+ "window": "hann_window",
27
+ "num_workers": 4,
28
+ "log_version": 0,
29
+ "ckpt_name_by_step": false,
30
+ "accumulate_grad_batches": 1
31
+ },
32
+ "data": {
33
+ "training_files": "filelists/44k/train.txt",
34
+ "validation_files": "filelists/44k/val.txt",
35
+ "max_wav_value": 32768.0,
36
+ "sampling_rate": 44100,
37
+ "filter_length": 2048,
38
+ "hop_length": 512,
39
+ "win_length": 2048,
40
+ "n_mel_channels": 80,
41
+ "mel_fmin": 0.0,
42
+ "mel_fmax": 22050,
43
+ "contentvec_final_proj": false
44
+ },
45
+ "model": {
46
+ "inter_channels": 192,
47
+ "hidden_channels": 192,
48
+ "filter_channels": 768,
49
+ "n_heads": 2,
50
+ "n_layers": 6,
51
+ "kernel_size": 3,
52
+ "p_dropout": 0.1,
53
+ "resblock": "1",
54
+ "resblock_kernel_sizes": [3, 7, 11],
55
+ "resblock_dilation_sizes": [
56
+ [1, 3, 5],
57
+ [1, 3, 5],
58
+ [1, 3, 5]
59
+ ],
60
+ "upsample_rates": [8, 4],
61
+ "upsample_initial_channel": 512,
62
+ "upsample_kernel_sizes": [32, 16],
63
+ "n_layers_q": 3,
64
+ "use_spectral_norm": false,
65
+ "gin_channels": 256,
66
+ "ssl_dim": 768,
67
+ "n_speakers": 200,
68
+ "type_": "ms-istft",
69
+ "gen_istft_n_fft": 16,
70
+ "gen_istft_hop_size": 4,
71
+ "subbands": 4,
72
+ "pretrained": {
73
+ "D_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth",
74
+ "G_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth"
75
+ }
76
+ },
77
+ "spk": {}
78
+ }
so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 800,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 16,
11
+ "fp16_run": false,
12
+ "bf16_run": false,
13
+ "lr_decay": 0.999875,
14
+ "segment_size": 10240,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "use_sr": true,
20
+ "max_speclen": 512,
21
+ "port": "8001",
22
+ "keep_ckpts": 3,
23
+ "num_workers": 4,
24
+ "log_version": 0,
25
+ "ckpt_name_by_step": false,
26
+ "accumulate_grad_batches": 1
27
+ },
28
+ "data": {
29
+ "training_files": "filelists/44k/train.txt",
30
+ "validation_files": "filelists/44k/val.txt",
31
+ "max_wav_value": 32768.0,
32
+ "sampling_rate": 44100,
33
+ "filter_length": 2048,
34
+ "hop_length": 512,
35
+ "win_length": 2048,
36
+ "n_mel_channels": 80,
37
+ "mel_fmin": 0.0,
38
+ "mel_fmax": 22050
39
+ },
40
+ "model": {
41
+ "inter_channels": 192,
42
+ "hidden_channels": 192,
43
+ "filter_channels": 768,
44
+ "n_heads": 2,
45
+ "n_layers": 6,
46
+ "kernel_size": 3,
47
+ "p_dropout": 0.1,
48
+ "resblock": "1",
49
+ "resblock_kernel_sizes": [3, 7, 11],
50
+ "resblock_dilation_sizes": [
51
+ [1, 3, 5],
52
+ [1, 3, 5],
53
+ [1, 3, 5]
54
+ ],
55
+ "upsample_rates": [8, 8, 2, 2, 2],
56
+ "upsample_initial_channel": 512,
57
+ "upsample_kernel_sizes": [16, 16, 4, 4, 4],
58
+ "n_layers_q": 3,
59
+ "use_spectral_norm": false,
60
+ "gin_channels": 256,
61
+ "ssl_dim": 256,
62
+ "n_speakers": 200,
63
+ "pretrained": {
64
+ "D_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
65
+ "G_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth"
66
+ }
67
+ },
68
+ "spk": {}
69
+ }
so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 200,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 16,
11
+ "fp16_run": false,
12
+ "bf16_run": false,
13
+ "lr_decay": 0.999875,
14
+ "segment_size": 10240,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "use_sr": true,
20
+ "max_speclen": 512,
21
+ "port": "8001",
22
+ "keep_ckpts": 3,
23
+ "num_workers": 4,
24
+ "log_version": 0,
25
+ "ckpt_name_by_step": false,
26
+ "accumulate_grad_batches": 1
27
+ },
28
+ "data": {
29
+ "training_files": "filelists/44k/train.txt",
30
+ "validation_files": "filelists/44k/val.txt",
31
+ "max_wav_value": 32768.0,
32
+ "sampling_rate": 44100,
33
+ "filter_length": 2048,
34
+ "hop_length": 512,
35
+ "win_length": 2048,
36
+ "n_mel_channels": 80,
37
+ "mel_fmin": 0.0,
38
+ "mel_fmax": 22050,
39
+ "contentvec_final_proj": false
40
+ },
41
+ "model": {
42
+ "inter_channels": 192,
43
+ "hidden_channels": 192,
44
+ "filter_channels": 768,
45
+ "n_heads": 2,
46
+ "n_layers": 6,
47
+ "kernel_size": 3,
48
+ "p_dropout": 0.1,
49
+ "resblock": "1",
50
+ "resblock_kernel_sizes": [3, 7, 11],
51
+ "resblock_dilation_sizes": [
52
+ [1, 3, 5],
53
+ [1, 3, 5],
54
+ [1, 3, 5]
55
+ ],
56
+ "upsample_rates": [8, 8, 2, 2, 2],
57
+ "upsample_initial_channel": 512,
58
+ "upsample_kernel_sizes": [16, 16, 4, 4, 4],
59
+ "n_layers_q": 3,
60
+ "use_spectral_norm": false,
61
+ "gin_channels": 256,
62
+ "ssl_dim": 768,
63
+ "n_speakers": 200,
64
+ "type_": "hifi-gan",
65
+ "pretrained": {
66
+ "D_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth",
67
+ "G_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth"
68
+ }
69
+ },
70
+ "spk": {}
71
+ }
so_vits_svc_fork/preprocessing/preprocess_classify.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from logging import getLogger
4
+ from pathlib import Path
5
+
6
+ import keyboard
7
+ import librosa
8
+ import sounddevice as sd
9
+ import soundfile as sf
10
+ from rich.console import Console
11
+ from tqdm.rich import tqdm
12
+
13
+ LOG = getLogger(__name__)
14
+
15
+
16
+ def preprocess_classify(
17
+ input_dir: Path | str, output_dir: Path | str, create_new: bool = True
18
+ ) -> None:
19
+ # paths
20
+ input_dir_ = Path(input_dir)
21
+ output_dir_ = Path(output_dir)
22
+ speed = 1
23
+ if not input_dir_.is_dir():
24
+ raise ValueError(f"{input_dir} is not a directory.")
25
+ output_dir_.mkdir(exist_ok=True)
26
+
27
+ console = Console()
28
+ # get audio paths and folders
29
+ audio_paths = list(input_dir_.glob("*.*"))
30
+ last_folders = [x for x in output_dir_.glob("*") if x.is_dir()]
31
+ console.print("Press ↑ or ↓ to change speed. Press any other key to classify.")
32
+ console.print(f"Folders: {[x.name for x in last_folders]}")
33
+
34
+ pbar_description = ""
35
+
36
+ pbar = tqdm(audio_paths)
37
+ for audio_path in pbar:
38
+ # read file
39
+ audio, sr = sf.read(audio_path)
40
+
41
+ # update description
42
+ duration = librosa.get_duration(y=audio, sr=sr)
43
+ pbar_description = f"{duration:.1f} {pbar_description}"
44
+ pbar.set_description(pbar_description)
45
+
46
+ while True:
47
+ # start playing
48
+ sd.play(librosa.effects.time_stretch(audio, rate=speed), sr, loop=True)
49
+
50
+ # wait for key press
51
+ key = str(keyboard.read_key())
52
+ if key == "down":
53
+ speed /= 1.1
54
+ console.print(f"Speed: {speed:.2f}")
55
+ elif key == "up":
56
+ speed *= 1.1
57
+ console.print(f"Speed: {speed:.2f}")
58
+ else:
59
+ break
60
+
61
+ # stop playing
62
+ sd.stop()
63
+
64
+ # print if folder changed
65
+ folders = [x for x in output_dir_.glob("*") if x.is_dir()]
66
+ if folders != last_folders:
67
+ console.print(f"Folders updated: {[x.name for x in folders]}")
68
+ last_folders = folders
69
+
70
+ # get folder
71
+ folder_candidates = [x for x in folders if x.name.startswith(key)]
72
+ if len(folder_candidates) == 0:
73
+ if create_new:
74
+ folder = output_dir_ / key
75
+ else:
76
+ console.print(f"No folder starts with {key}.")
77
+ continue
78
+ else:
79
+ if len(folder_candidates) > 1:
80
+ LOG.warning(
81
+ f"Multiple folders ({[x.name for x in folder_candidates]}) start with {key}. "
82
+ f"Using first one ({folder_candidates[0].name})."
83
+ )
84
+ folder = folder_candidates[0]
85
+ folder.mkdir(exist_ok=True)
86
+
87
+ # move file
88
+ new_path = folder / audio_path.name
89
+ audio_path.rename(new_path)
90
+
91
+ # update description
92
+ pbar_description = f"Last: {audio_path.name} -> {folder.name}"
93
+
94
+ # yield result
95
+ # yield audio_path, key, folder, new_path
so_vits_svc_fork/preprocessing/preprocess_flist_config.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from copy import deepcopy
6
+ from logging import getLogger
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ from librosa import get_duration
11
+ from tqdm import tqdm
12
+
13
+ LOG = getLogger(__name__)
14
+ CONFIG_TEMPLATE_DIR = Path(__file__).parent / "config_templates"
15
+
16
+
17
+ def preprocess_config(
18
+ input_dir: Path | str,
19
+ train_list_path: Path | str,
20
+ val_list_path: Path | str,
21
+ test_list_path: Path | str,
22
+ config_path: Path | str,
23
+ config_name: str,
24
+ ):
25
+ input_dir = Path(input_dir)
26
+ train_list_path = Path(train_list_path)
27
+ val_list_path = Path(val_list_path)
28
+ test_list_path = Path(test_list_path)
29
+ config_path = Path(config_path)
30
+ train = []
31
+ val = []
32
+ test = []
33
+ spk_dict = {}
34
+ spk_id = 0
35
+ random = np.random.RandomState(1234)
36
+ for speaker in os.listdir(input_dir):
37
+ spk_dict[speaker] = spk_id
38
+ spk_id += 1
39
+ paths = []
40
+ for path in tqdm(list((input_dir / speaker).rglob("*.wav"))):
41
+ if get_duration(filename=path) < 0.3:
42
+ LOG.warning(f"skip {path} because it is too short.")
43
+ continue
44
+ paths.append(path)
45
+ random.shuffle(paths)
46
+ if len(paths) <= 4:
47
+ raise ValueError(
48
+ f"too few files in {input_dir / speaker} (expected at least 5)."
49
+ )
50
+ train += paths[2:-2]
51
+ val += paths[:2]
52
+ test += paths[-2:]
53
+
54
+ LOG.info(f"Writing {train_list_path}")
55
+ train_list_path.parent.mkdir(parents=True, exist_ok=True)
56
+ train_list_path.write_text(
57
+ "\n".join([x.as_posix() for x in train]), encoding="utf-8"
58
+ )
59
+
60
+ LOG.info(f"Writing {val_list_path}")
61
+ val_list_path.parent.mkdir(parents=True, exist_ok=True)
62
+ val_list_path.write_text("\n".join([x.as_posix() for x in val]), encoding="utf-8")
63
+
64
+ LOG.info(f"Writing {test_list_path}")
65
+ test_list_path.parent.mkdir(parents=True, exist_ok=True)
66
+ test_list_path.write_text("\n".join([x.as_posix() for x in test]), encoding="utf-8")
67
+
68
+ config = deepcopy(
69
+ json.loads(
70
+ (
71
+ CONFIG_TEMPLATE_DIR
72
+ / (
73
+ config_name
74
+ if config_name.endswith(".json")
75
+ else config_name + ".json"
76
+ )
77
+ ).read_text(encoding="utf-8")
78
+ )
79
+ )
80
+ config["spk"] = spk_dict
81
+ config["data"]["training_files"] = train_list_path.as_posix()
82
+ config["data"]["validation_files"] = val_list_path.as_posix()
83
+ LOG.info(f"Writing {config_path}")
84
+ config_path.parent.mkdir(parents=True, exist_ok=True)
85
+ with config_path.open("w", encoding="utf-8") as f:
86
+ json.dump(config, f, indent=2)
so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from logging import getLogger
4
+ from pathlib import Path
5
+ from random import shuffle
6
+ from typing import Iterable, Literal
7
+
8
+ import librosa
9
+ import numpy as np
10
+ import torch
11
+ import torchaudio
12
+ from joblib import Parallel, cpu_count, delayed
13
+ from tqdm import tqdm
14
+ from transformers import HubertModel
15
+
16
+ import so_vits_svc_fork.f0
17
+ from so_vits_svc_fork import utils
18
+
19
+ from ..hparams import HParams
20
+ from ..modules.mel_processing import spec_to_mel_torch, spectrogram_torch
21
+ from ..utils import get_optimal_device, get_total_gpu_memory
22
+ from .preprocess_utils import check_hubert_min_duration
23
+
24
+ LOG = getLogger(__name__)
25
+ HUBERT_MEMORY = 2900
26
+ HUBERT_MEMORY_CREPE = 3900
27
+
28
+
29
+ def _process_one(
30
+ *,
31
+ filepath: Path,
32
+ content_model: HubertModel,
33
+ device: torch.device | str = get_optimal_device(),
34
+ f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
35
+ force_rebuild: bool = False,
36
+ hps: HParams,
37
+ ):
38
+ audio, sr = librosa.load(filepath, sr=hps.data.sampling_rate, mono=True)
39
+
40
+ if not check_hubert_min_duration(audio, sr):
41
+ LOG.info(f"Skip {filepath} because it is too short.")
42
+ return
43
+
44
+ data_path = filepath.parent / (filepath.name + ".data.pt")
45
+ if data_path.exists() and not force_rebuild:
46
+ return
47
+
48
+ # Compute f0
49
+ f0 = so_vits_svc_fork.f0.compute_f0(
50
+ audio, sampling_rate=sr, hop_length=hps.data.hop_length, method=f0_method
51
+ )
52
+ f0, uv = so_vits_svc_fork.f0.interpolate_f0(f0)
53
+ f0 = torch.from_numpy(f0).float()
54
+ uv = torch.from_numpy(uv).float()
55
+
56
+ # Compute HuBERT content
57
+ audio = torch.from_numpy(audio).float().to(device)
58
+ c = utils.get_content(
59
+ content_model,
60
+ audio,
61
+ device,
62
+ sr=sr,
63
+ legacy_final_proj=hps.data.get("contentvec_final_proj", True),
64
+ )
65
+ c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
66
+ torch.cuda.empty_cache()
67
+
68
+ # Compute spectrogram
69
+ audio, sr = torchaudio.load(filepath)
70
+ spec = spectrogram_torch(audio, hps).squeeze(0)
71
+ mel_spec = spec_to_mel_torch(spec, hps)
72
+ torch.cuda.empty_cache()
73
+
74
+ # fix lengths
75
+ lmin = min(spec.shape[1], mel_spec.shape[1], f0.shape[0], uv.shape[0], c.shape[1])
76
+ spec, mel_spec, f0, uv, c = (
77
+ spec[:, :lmin],
78
+ mel_spec[:, :lmin],
79
+ f0[:lmin],
80
+ uv[:lmin],
81
+ c[:, :lmin],
82
+ )
83
+
84
+ # get speaker id
85
+ spk_name = filepath.parent.name
86
+ spk = hps.spk.__dict__[spk_name]
87
+ spk = torch.tensor(spk).long()
88
+ assert (
89
+ spec.shape[1] == mel_spec.shape[1] == f0.shape[0] == uv.shape[0] == c.shape[1]
90
+ ), (spec.shape, mel_spec.shape, f0.shape, uv.shape, c.shape)
91
+ data = {
92
+ "spec": spec,
93
+ "mel_spec": mel_spec,
94
+ "f0": f0,
95
+ "uv": uv,
96
+ "content": c,
97
+ "audio": audio,
98
+ "spk": spk,
99
+ }
100
+ data = {k: v.cpu() for k, v in data.items()}
101
+ with data_path.open("wb") as f:
102
+ torch.save(data, f)
103
+
104
+
105
+ def _process_batch(filepaths: Iterable[Path], pbar_position: int, **kwargs):
106
+ hps = kwargs["hps"]
107
+ content_model = utils.get_hubert_model(
108
+ get_optimal_device(), hps.data.get("contentvec_final_proj", True)
109
+ )
110
+
111
+ for filepath in tqdm(filepaths, position=pbar_position):
112
+ _process_one(
113
+ content_model=content_model,
114
+ filepath=filepath,
115
+ **kwargs,
116
+ )
117
+
118
+
119
+ def preprocess_hubert_f0(
120
+ input_dir: Path | str,
121
+ config_path: Path | str,
122
+ n_jobs: int | None = None,
123
+ f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
124
+ force_rebuild: bool = False,
125
+ ):
126
+ input_dir = Path(input_dir)
127
+ config_path = Path(config_path)
128
+ hps = utils.get_hparams(config_path)
129
+ if n_jobs is None:
130
+ # add cpu_count() to avoid SIGKILL
131
+ memory = get_total_gpu_memory("total")
132
+ n_jobs = min(
133
+ max(
134
+ memory
135
+ // (HUBERT_MEMORY_CREPE if f0_method == "crepe" else HUBERT_MEMORY)
136
+ if memory is not None
137
+ else 1,
138
+ 1,
139
+ ),
140
+ cpu_count(),
141
+ )
142
+ LOG.info(f"n_jobs automatically set to {n_jobs}, memory: {memory} MiB")
143
+
144
+ filepaths = list(input_dir.rglob("*.wav"))
145
+ n_jobs = min(len(filepaths) // 16 + 1, n_jobs)
146
+ shuffle(filepaths)
147
+ filepath_chunks = np.array_split(filepaths, n_jobs)
148
+ Parallel(n_jobs=n_jobs)(
149
+ delayed(_process_batch)(
150
+ filepaths=chunk,
151
+ pbar_position=pbar_position,
152
+ f0_method=f0_method,
153
+ force_rebuild=force_rebuild,
154
+ hps=hps,
155
+ )
156
+ for (pbar_position, chunk) in enumerate(filepath_chunks)
157
+ )
so_vits_svc_fork/preprocessing/preprocess_resample.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from logging import getLogger
5
+ from pathlib import Path
6
+ from typing import Iterable
7
+
8
+ import librosa
9
+ import soundfile
10
+ from joblib import Parallel, delayed
11
+ from tqdm_joblib import tqdm_joblib
12
+
13
+ from .preprocess_utils import check_hubert_min_duration
14
+
15
+ LOG = getLogger(__name__)
16
+
17
+ # input_dir and output_dir exists.
18
+ # write code to convert input dir audio files to output dir audio files,
19
+ # without changing folder structure. Use joblib to parallelize.
20
+ # Converting audio files includes:
21
+ # - resampling to specified sampling rate
22
+ # - trim silence
23
+ # - adjust volume in a smart way
24
+ # - save as 16-bit wav file
25
+
26
+
27
+ def _get_unique_filename(path: Path, existing_paths: Iterable[Path]) -> Path:
28
+ """Return a unique path by appending a number to the original path."""
29
+ if path not in existing_paths:
30
+ return path
31
+ i = 1
32
+ while True:
33
+ new_path = path.parent / f"{path.stem}_{i}{path.suffix}"
34
+ if new_path not in existing_paths:
35
+ return new_path
36
+ i += 1
37
+
38
+
39
+ def is_relative_to(path: Path, *other):
40
+ """Return True if the path is relative to another path or False.
41
+ Python 3.9+ has Path.is_relative_to() method, but we need to support Python 3.8.
42
+ """
43
+ try:
44
+ path.relative_to(*other)
45
+ return True
46
+ except ValueError:
47
+ return False
48
+
49
+
50
+ def _preprocess_one(
51
+ input_path: Path,
52
+ output_path: Path,
53
+ sr: int,
54
+ *,
55
+ top_db: int,
56
+ frame_seconds: float,
57
+ hop_seconds: float,
58
+ ) -> None:
59
+ """Preprocess one audio file."""
60
+
61
+ try:
62
+ audio, sr = librosa.load(input_path, sr=sr, mono=True)
63
+
64
+ # Audioread is the last backend it will attempt, so this is the exception thrown on failure
65
+ except Exception as e:
66
+ # Failure due to attempting to load a file that is not audio, so return early
67
+ LOG.warning(f"Failed to load {input_path} due to {e}")
68
+ return
69
+
70
+ if not check_hubert_min_duration(audio, sr):
71
+ LOG.info(f"Skip {input_path} because it is too short.")
72
+ return
73
+
74
+ # Adjust volume
75
+ audio /= max(audio.max(), -audio.min())
76
+
77
+ # Trim silence
78
+ audio, _ = librosa.effects.trim(
79
+ audio,
80
+ top_db=top_db,
81
+ frame_length=int(frame_seconds * sr),
82
+ hop_length=int(hop_seconds * sr),
83
+ )
84
+
85
+ if not check_hubert_min_duration(audio, sr):
86
+ LOG.info(f"Skip {input_path} because it is too short.")
87
+ return
88
+
89
+ soundfile.write(output_path, audio, samplerate=sr, subtype="PCM_16")
90
+
91
+
92
+ def preprocess_resample(
93
+ input_dir: Path | str,
94
+ output_dir: Path | str,
95
+ sampling_rate: int,
96
+ n_jobs: int = -1,
97
+ *,
98
+ top_db: int = 30,
99
+ frame_seconds: float = 0.1,
100
+ hop_seconds: float = 0.05,
101
+ ) -> None:
102
+ input_dir = Path(input_dir)
103
+ output_dir = Path(output_dir)
104
+ """Preprocess audio files in input_dir and save them to output_dir."""
105
+
106
+ out_paths = []
107
+ in_paths = list(input_dir.rglob("*.*"))
108
+ if not in_paths:
109
+ raise ValueError(f"No audio files found in {input_dir}")
110
+ for in_path in in_paths:
111
+ in_path_relative = in_path.relative_to(input_dir)
112
+ if not in_path.is_absolute() and is_relative_to(
113
+ in_path, Path("dataset_raw") / "44k"
114
+ ):
115
+ new_in_path_relative = in_path_relative.relative_to("44k")
116
+ warnings.warn(
117
+ f"Recommended folder structure has changed since v1.0.0. "
118
+ "Please move your dataset directly under dataset_raw folder. "
119
+ f"Recoginzed {in_path_relative} as {new_in_path_relative}"
120
+ )
121
+ in_path_relative = new_in_path_relative
122
+
123
+ if len(in_path_relative.parts) < 2:
124
+ continue
125
+ speaker_name = in_path_relative.parts[0]
126
+ file_name = in_path_relative.with_suffix(".wav").name
127
+ out_path = output_dir / speaker_name / file_name
128
+ out_path = _get_unique_filename(out_path, out_paths)
129
+ out_path.parent.mkdir(parents=True, exist_ok=True)
130
+ out_paths.append(out_path)
131
+
132
+ in_and_out_paths = list(zip(in_paths, out_paths))
133
+
134
+ with tqdm_joblib(desc="Preprocessing", total=len(in_and_out_paths)):
135
+ Parallel(n_jobs=n_jobs)(
136
+ delayed(_preprocess_one)(
137
+ *args,
138
+ sr=sampling_rate,
139
+ top_db=top_db,
140
+ frame_seconds=frame_seconds,
141
+ hop_seconds=hop_seconds,
142
+ )
143
+ for args in in_and_out_paths
144
+ )
so_vits_svc_fork/preprocessing/preprocess_speaker_diarization.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from logging import getLogger
5
+ from pathlib import Path
6
+
7
+ import librosa
8
+ import soundfile as sf
9
+ import torch
10
+ from joblib import Parallel, delayed
11
+ from pyannote.audio import Pipeline
12
+ from tqdm import tqdm
13
+ from tqdm_joblib import tqdm_joblib
14
+
15
+ LOG = getLogger(__name__)
16
+
17
+
18
+ def _process_one(
19
+ input_path: Path,
20
+ output_dir: Path,
21
+ sr: int,
22
+ *,
23
+ min_speakers: int = 1,
24
+ max_speakers: int = 1,
25
+ huggingface_token: str | None = None,
26
+ ) -> None:
27
+ try:
28
+ audio, sr = librosa.load(input_path, sr=sr, mono=True)
29
+ except Exception as e:
30
+ LOG.warning(f"Failed to read {input_path}: {e}")
31
+ return
32
+ pipeline = Pipeline.from_pretrained(
33
+ "pyannote/speaker-diarization", use_auth_token=huggingface_token
34
+ )
35
+ if pipeline is None:
36
+ raise ValueError("Failed to load pipeline")
37
+
38
+ LOG.info(f"Processing {input_path}. This may take a while...")
39
+ diarization = pipeline(
40
+ input_path, min_speakers=min_speakers, max_speakers=max_speakers
41
+ )
42
+
43
+ LOG.info(f"Found {len(diarization)} tracks, writing to {output_dir}")
44
+ speaker_count = defaultdict(int)
45
+
46
+ output_dir.mkdir(parents=True, exist_ok=True)
47
+ for segment, track, speaker in tqdm(
48
+ list(diarization.itertracks(yield_label=True)), desc=f"Writing {input_path}"
49
+ ):
50
+ if segment.end - segment.start < 1:
51
+ continue
52
+ speaker_count[speaker] += 1
53
+ audio_cut = audio[int(segment.start * sr) : int(segment.end * sr)]
54
+ sf.write(
55
+ (output_dir / f"{speaker}_{speaker_count[speaker]}.wav"),
56
+ audio_cut,
57
+ sr,
58
+ )
59
+
60
+ LOG.info(f"Speaker count: {speaker_count}")
61
+
62
+
63
+ def preprocess_speaker_diarization(
64
+ input_dir: Path | str,
65
+ output_dir: Path | str,
66
+ sr: int,
67
+ *,
68
+ min_speakers: int = 1,
69
+ max_speakers: int = 1,
70
+ huggingface_token: str | None = None,
71
+ n_jobs: int = -1,
72
+ ) -> None:
73
+ if huggingface_token is not None and not huggingface_token.startswith("hf_"):
74
+ LOG.warning("Huggingface token probably should start with hf_")
75
+ if not torch.cuda.is_available():
76
+ LOG.warning("CUDA is not available. This will be extremely slow.")
77
+ input_dir = Path(input_dir)
78
+ output_dir = Path(output_dir)
79
+ input_dir.mkdir(parents=True, exist_ok=True)
80
+ output_dir.mkdir(parents=True, exist_ok=True)
81
+ input_paths = list(input_dir.rglob("*.*"))
82
+ with tqdm_joblib(desc="Preprocessing speaker diarization", total=len(input_paths)):
83
+ Parallel(n_jobs=n_jobs)(
84
+ delayed(_process_one)(
85
+ input_path,
86
+ output_dir / input_path.relative_to(input_dir).parent / input_path.stem,
87
+ sr,
88
+ max_speakers=max_speakers,
89
+ min_speakers=min_speakers,
90
+ huggingface_token=huggingface_token,
91
+ )
92
+ for input_path in input_paths
93
+ )
so_vits_svc_fork/preprocessing/preprocess_split.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from logging import getLogger
4
+ from pathlib import Path
5
+
6
+ import librosa
7
+ import soundfile as sf
8
+ from joblib import Parallel, delayed
9
+ from tqdm import tqdm
10
+ from tqdm_joblib import tqdm_joblib
11
+
12
+ LOG = getLogger(__name__)
13
+
14
+
15
+ def _process_one(
16
+ input_path: Path,
17
+ output_dir: Path,
18
+ sr: int,
19
+ *,
20
+ max_length: float = 10.0,
21
+ top_db: int = 30,
22
+ frame_seconds: float = 0.5,
23
+ hop_seconds: float = 0.1,
24
+ ):
25
+ try:
26
+ audio, sr = librosa.load(input_path, sr=sr, mono=True)
27
+ except Exception as e:
28
+ LOG.warning(f"Failed to read {input_path}: {e}")
29
+ return
30
+ intervals = librosa.effects.split(
31
+ audio,
32
+ top_db=top_db,
33
+ frame_length=int(sr * frame_seconds),
34
+ hop_length=int(sr * hop_seconds),
35
+ )
36
+ output_dir.mkdir(parents=True, exist_ok=True)
37
+ for start, end in tqdm(intervals, desc=f"Writing {input_path}"):
38
+ for sub_start in range(start, end, int(sr * max_length)):
39
+ sub_end = min(sub_start + int(sr * max_length), end)
40
+ audio_cut = audio[sub_start:sub_end]
41
+ sf.write(
42
+ (
43
+ output_dir
44
+ / f"{input_path.stem}_{sub_start / sr:.3f}_{sub_end / sr:.3f}.wav"
45
+ ),
46
+ audio_cut,
47
+ sr,
48
+ )
49
+
50
+
51
+ def preprocess_split(
52
+ input_dir: Path | str,
53
+ output_dir: Path | str,
54
+ sr: int,
55
+ *,
56
+ max_length: float = 10.0,
57
+ top_db: int = 30,
58
+ frame_seconds: float = 0.5,
59
+ hop_seconds: float = 0.1,
60
+ n_jobs: int = -1,
61
+ ):
62
+ input_dir = Path(input_dir)
63
+ output_dir = Path(output_dir)
64
+ output_dir.mkdir(parents=True, exist_ok=True)
65
+ input_paths = list(input_dir.rglob("*.*"))
66
+ with tqdm_joblib(desc="Splitting", total=len(input_paths)):
67
+ Parallel(n_jobs=n_jobs)(
68
+ delayed(_process_one)(
69
+ input_path,
70
+ output_dir / input_path.relative_to(input_dir).parent,
71
+ sr,
72
+ max_length=max_length,
73
+ top_db=top_db,
74
+ frame_seconds=frame_seconds,
75
+ hop_seconds=hop_seconds,
76
+ )
77
+ for input_path in input_paths
78
+ )
so_vits_svc_fork/preprocessing/preprocess_utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from numpy import ndarray
2
+
3
+
4
+ def check_hubert_min_duration(audio: ndarray, sr: int) -> bool:
5
+ return len(audio) / sr >= 0.3
so_vits_svc_fork/py.typed ADDED
File without changes
so_vits_svc_fork/train.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import warnings
5
+ from logging import getLogger
6
+ from multiprocessing import cpu_count
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import lightning.pytorch as pl
11
+ import torch
12
+ from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
13
+ from lightning.pytorch.callbacks import DeviceStatsMonitor
14
+ from lightning.pytorch.loggers import TensorBoardLogger
15
+ from lightning.pytorch.strategies.ddp import DDPStrategy
16
+ from lightning.pytorch.tuner import Tuner
17
+ from torch.cuda.amp import autocast
18
+ from torch.nn import functional as F
19
+ from torch.utils.data import DataLoader
20
+ from torch.utils.tensorboard.writer import SummaryWriter
21
+
22
+ import so_vits_svc_fork.f0
23
+ import so_vits_svc_fork.modules.commons as commons
24
+ import so_vits_svc_fork.utils
25
+
26
+ from . import utils
27
+ from .dataset import TextAudioCollate, TextAudioDataset
28
+ from .logger import is_notebook
29
+ from .modules.descriminators import MultiPeriodDiscriminator
30
+ from .modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
31
+ from .modules.mel_processing import mel_spectrogram_torch
32
+ from .modules.synthesizers import SynthesizerTrn
33
+
34
+ LOG = getLogger(__name__)
35
+ torch.set_float32_matmul_precision("high")
36
+
37
+
38
+ class VCDataModule(pl.LightningDataModule):
39
+ batch_size: int
40
+
41
+ def __init__(self, hparams: Any):
42
+ super().__init__()
43
+ self.__hparams = hparams
44
+ self.batch_size = hparams.train.batch_size
45
+ if not isinstance(self.batch_size, int):
46
+ self.batch_size = 1
47
+ self.collate_fn = TextAudioCollate()
48
+
49
+ # these should be called in setup(), but we need to calculate check_val_every_n_epoch
50
+ self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
51
+ self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)
52
+
53
+ def train_dataloader(self):
54
+ return DataLoader(
55
+ self.train_dataset,
56
+ num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
57
+ batch_size=self.batch_size,
58
+ collate_fn=self.collate_fn,
59
+ persistent_workers=True,
60
+ )
61
+
62
+ def val_dataloader(self):
63
+ return DataLoader(
64
+ self.val_dataset,
65
+ batch_size=1,
66
+ collate_fn=self.collate_fn,
67
+ )
68
+
69
+
70
+ def train(
71
+ config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
72
+ ):
73
+ config_path = Path(config_path)
74
+ model_path = Path(model_path)
75
+
76
+ hparams = utils.get_backup_hparams(config_path, model_path)
77
+ utils.ensure_pretrained_model(
78
+ model_path,
79
+ hparams.model.get(
80
+ "pretrained",
81
+ {
82
+ "D_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
83
+ "G_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth",
84
+ },
85
+ ),
86
+ )
87
+
88
+ datamodule = VCDataModule(hparams)
89
+ strategy = (
90
+ (
91
+ "ddp_find_unused_parameters_true"
92
+ if os.name != "nt"
93
+ else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
94
+ )
95
+ if torch.cuda.device_count() > 1
96
+ else "auto"
97
+ )
98
+ LOG.info(f"Using strategy: {strategy}")
99
+ trainer = pl.Trainer(
100
+ logger=TensorBoardLogger(
101
+ model_path, "lightning_logs", hparams.train.get("log_version", 0)
102
+ ),
103
+ # profiler="simple",
104
+ val_check_interval=hparams.train.eval_interval,
105
+ max_epochs=hparams.train.epochs,
106
+ check_val_every_n_epoch=None,
107
+ precision="16-mixed"
108
+ if hparams.train.fp16_run
109
+ else "bf16-mixed"
110
+ if hparams.train.get("bf16_run", False)
111
+ else 32,
112
+ strategy=strategy,
113
+ callbacks=([pl.callbacks.RichProgressBar()] if not is_notebook() else [])
114
+ + [DeviceStatsMonitor()],
115
+ benchmark=True,
116
+ enable_checkpointing=False,
117
+ )
118
+ tuner = Tuner(trainer)
119
+ model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)
120
+
121
+ # automatic batch size scaling
122
+ batch_size = hparams.train.batch_size
123
+ batch_split = str(batch_size).split("-")
124
+ batch_size = batch_split[0]
125
+ init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
126
+ max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
127
+ if batch_size == "auto":
128
+ batch_size = "binsearch"
129
+ if batch_size in ["power", "binsearch"]:
130
+ model.tuning = True
131
+ tuner.scale_batch_size(
132
+ model,
133
+ mode=batch_size,
134
+ datamodule=datamodule,
135
+ steps_per_trial=1,
136
+ init_val=init_val,
137
+ max_trials=max_trials,
138
+ )
139
+ model.tuning = False
140
+ else:
141
+ batch_size = int(batch_size)
142
+ # automatic learning rate scaling is not supported for multiple optimizers
143
+ """if hparams.train.learning_rate == "auto":
144
+ lr_finder = tuner.lr_find(model)
145
+ LOG.info(lr_finder.results)
146
+ fig = lr_finder.plot(suggest=True)
147
+ fig.savefig(model_path / "lr_finder.png")"""
148
+
149
+ trainer.fit(model, datamodule=datamodule)
150
+
151
+
152
+ class VitsLightning(pl.LightningModule):
153
+ def __init__(self, reset_optimizer: bool = False, **hparams: Any):
154
+ super().__init__()
155
+ self._temp_epoch = 0 # Add this line to initialize the _temp_epoch attribute
156
+ self.save_hyperparameters("reset_optimizer")
157
+ self.save_hyperparameters(*[k for k in hparams.keys()])
158
+ torch.manual_seed(self.hparams.train.seed)
159
+ self.net_g = SynthesizerTrn(
160
+ self.hparams.data.filter_length // 2 + 1,
161
+ self.hparams.train.segment_size // self.hparams.data.hop_length,
162
+ **self.hparams.model,
163
+ )
164
+ self.net_d = MultiPeriodDiscriminator(self.hparams.model.use_spectral_norm)
165
+ self.automatic_optimization = False
166
+ self.learning_rate = self.hparams.train.learning_rate
167
+ self.optim_g = torch.optim.AdamW(
168
+ self.net_g.parameters(),
169
+ self.learning_rate,
170
+ betas=self.hparams.train.betas,
171
+ eps=self.hparams.train.eps,
172
+ )
173
+ self.optim_d = torch.optim.AdamW(
174
+ self.net_d.parameters(),
175
+ self.learning_rate,
176
+ betas=self.hparams.train.betas,
177
+ eps=self.hparams.train.eps,
178
+ )
179
+ self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
180
+ self.optim_g, gamma=self.hparams.train.lr_decay
181
+ )
182
+ self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
183
+ self.optim_d, gamma=self.hparams.train.lr_decay
184
+ )
185
+ self.optimizers_count = 2
186
+ self.load(reset_optimizer)
187
+ self.tuning = False
188
+
189
+ def on_train_start(self) -> None:
190
+ if not self.tuning:
191
+ self.set_current_epoch(self._temp_epoch)
192
+ total_batch_idx = self._temp_epoch * len(self.trainer.train_dataloader)
193
+ self.set_total_batch_idx(total_batch_idx)
194
+ global_step = total_batch_idx * self.optimizers_count
195
+ self.set_global_step(global_step)
196
+
197
+ # check if using tpu or mps
198
+ if isinstance(self.trainer.accelerator, (TPUAccelerator, MPSAccelerator)):
199
+ # patch torch.stft to use cpu
200
+ LOG.warning("Using TPU/MPS. Patching torch.stft to use cpu.")
201
+
202
+ def stft(
203
+ input: torch.Tensor,
204
+ n_fft: int,
205
+ hop_length: int | None = None,
206
+ win_length: int | None = None,
207
+ window: torch.Tensor | None = None,
208
+ center: bool = True,
209
+ pad_mode: str = "reflect",
210
+ normalized: bool = False,
211
+ onesided: bool | None = None,
212
+ return_complex: bool | None = None,
213
+ ) -> torch.Tensor:
214
+ device = input.device
215
+ input = input.cpu()
216
+ if window is not None:
217
+ window = window.cpu()
218
+ return torch.functional.stft(
219
+ input,
220
+ n_fft,
221
+ hop_length,
222
+ win_length,
223
+ window,
224
+ center,
225
+ pad_mode,
226
+ normalized,
227
+ onesided,
228
+ return_complex,
229
+ ).to(device)
230
+
231
+ torch.stft = stft
232
+
233
+ elif "bf" in self.trainer.precision:
234
+ LOG.warning("Using bf. Patching torch.stft to use fp32.")
235
+
236
+ def stft(
237
+ input: torch.Tensor,
238
+ n_fft: int,
239
+ hop_length: int | None = None,
240
+ win_length: int | None = None,
241
+ window: torch.Tensor | None = None,
242
+ center: bool = True,
243
+ pad_mode: str = "reflect",
244
+ normalized: bool = False,
245
+ onesided: bool | None = None,
246
+ return_complex: bool | None = None,
247
+ ) -> torch.Tensor:
248
+ dtype = input.dtype
249
+ input = input.float()
250
+ if window is not None:
251
+ window = window.float()
252
+ return torch.functional.stft(
253
+ input,
254
+ n_fft,
255
+ hop_length,
256
+ win_length,
257
+ window,
258
+ center,
259
+ pad_mode,
260
+ normalized,
261
+ onesided,
262
+ return_complex,
263
+ ).to(dtype)
264
+
265
+ torch.stft = stft
266
+
267
+ def on_train_end(self) -> None:
268
+ self.save_checkpoints(adjust=0)
269
+
270
+ def save_checkpoints(self, adjust=1):
271
+ if self.tuning or self.trainer.sanity_checking:
272
+ return
273
+
274
+ # only save checkpoints if we are on the main device
275
+ if (
276
+ hasattr(self.device, "index")
277
+ and self.device.index != None
278
+ and self.device.index != 0
279
+ ):
280
+ return
281
+
282
+ # `on_train_end` will be the actual epoch, not a -1, so we have to call it with `adjust = 0`
283
+ current_epoch = self.current_epoch + adjust
284
+ total_batch_idx = self.total_batch_idx - 1 + adjust
285
+
286
+ utils.save_checkpoint(
287
+ self.net_g,
288
+ self.optim_g,
289
+ self.learning_rate,
290
+ current_epoch,
291
+ Path(self.hparams.model_dir)
292
+ / f"G_{total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else current_epoch}.pth",
293
+ )
294
+ utils.save_checkpoint(
295
+ self.net_d,
296
+ self.optim_d,
297
+ self.learning_rate,
298
+ current_epoch,
299
+ Path(self.hparams.model_dir)
300
+ / f"D_{total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else current_epoch}.pth",
301
+ )
302
+ keep_ckpts = self.hparams.train.get("keep_ckpts", 0)
303
+ if keep_ckpts > 0:
304
+ utils.clean_checkpoints(
305
+ path_to_models=self.hparams.model_dir,
306
+ n_ckpts_to_keep=keep_ckpts,
307
+ sort_by_time=True,
308
+ )
309
+
310
+ def set_current_epoch(self, epoch: int):
311
+ LOG.info(f"Setting current epoch to {epoch}")
312
+ self.trainer.fit_loop.epoch_progress.current.completed = epoch
313
+ self.trainer.fit_loop.epoch_progress.current.processed = epoch
314
+ assert self.current_epoch == epoch, f"{self.current_epoch} != {epoch}"
315
+
316
+ def set_global_step(self, global_step: int):
317
+ LOG.info(f"Setting global step to {global_step}")
318
+ self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed = (
319
+ global_step
320
+ )
321
+ self.trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
322
+ global_step
323
+ )
324
+ assert self.global_step == global_step, f"{self.global_step} != {global_step}"
325
+
326
+ def set_total_batch_idx(self, total_batch_idx: int):
327
+ LOG.info(f"Setting total batch idx to {total_batch_idx}")
328
+ self.trainer.fit_loop.epoch_loop.batch_progress.total.ready = (
329
+ total_batch_idx + 1
330
+ )
331
+ self.trainer.fit_loop.epoch_loop.batch_progress.total.completed = (
332
+ total_batch_idx
333
+ )
334
+ assert (
335
+ self.total_batch_idx == total_batch_idx + 1
336
+ ), f"{self.total_batch_idx} != {total_batch_idx + 1}"
337
+
338
+ @property
339
+ def total_batch_idx(self) -> int:
340
+ return self.trainer.fit_loop.epoch_loop.total_batch_idx + 1
341
+
342
+ def load(self, reset_optimizer: bool = False):
343
+ latest_g_path = utils.latest_checkpoint_path(self.hparams.model_dir, "G_*.pth")
344
+ latest_d_path = utils.latest_checkpoint_path(self.hparams.model_dir, "D_*.pth")
345
+ if latest_g_path is not None and latest_d_path is not None:
346
+ try:
347
+ _, _, _, epoch = utils.load_checkpoint(
348
+ latest_g_path,
349
+ self.net_g,
350
+ self.optim_g,
351
+ reset_optimizer,
352
+ )
353
+ _, _, _, epoch = utils.load_checkpoint(
354
+ latest_d_path,
355
+ self.net_d,
356
+ self.optim_d,
357
+ reset_optimizer,
358
+ )
359
+ self._temp_epoch = epoch
360
+ self.scheduler_g.last_epoch = epoch - 1
361
+ self.scheduler_d.last_epoch = epoch - 1
362
+ except Exception as e:
363
+ raise RuntimeError("Failed to load checkpoint") from e
364
+ else:
365
+ LOG.warning("No checkpoint found. Start from scratch.")
366
+
367
+ def configure_optimizers(self):
368
+ return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]
369
+
370
+ def log_image_dict(
371
+ self, image_dict: dict[str, Any], dataformats: str = "HWC"
372
+ ) -> None:
373
+ if not isinstance(self.logger, TensorBoardLogger):
374
+ warnings.warn("Image logging is only supported with TensorBoardLogger.")
375
+ return
376
+ writer: SummaryWriter = self.logger.experiment
377
+ for k, v in image_dict.items():
378
+ try:
379
+ writer.add_image(k, v, self.total_batch_idx, dataformats=dataformats)
380
+ except Exception as e:
381
+ warnings.warn(f"Failed to log image {k}: {e}")
382
+
383
+ def log_audio_dict(self, audio_dict: dict[str, Any]) -> None:
384
+ if not isinstance(self.logger, TensorBoardLogger):
385
+ warnings.warn("Audio logging is only supported with TensorBoardLogger.")
386
+ return
387
+ writer: SummaryWriter = self.logger.experiment
388
+ for k, v in audio_dict.items():
389
+ writer.add_audio(
390
+ k,
391
+ v.float(),
392
+ self.total_batch_idx,
393
+ sample_rate=self.hparams.data.sampling_rate,
394
+ )
395
+
396
+ def log_dict_(self, log_dict: dict[str, Any], **kwargs) -> None:
397
+ if not isinstance(self.logger, TensorBoardLogger):
398
+ warnings.warn("Logging is only supported with TensorBoardLogger.")
399
+ return
400
+ writer: SummaryWriter = self.logger.experiment
401
+ for k, v in log_dict.items():
402
+ writer.add_scalar(k, v, self.total_batch_idx)
403
+ kwargs["logger"] = False
404
+ self.log_dict(log_dict, **kwargs)
405
+
406
+ def log_(self, key: str, value: Any, **kwargs) -> None:
407
+ self.log_dict_({key: value}, **kwargs)
408
+
409
+ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
410
+ self.net_g.train()
411
+ self.net_d.train()
412
+
413
+ # get optims
414
+ optim_g, optim_d = self.optimizers()
415
+
416
+ # Generator
417
+ # train
418
+ self.toggle_optimizer(optim_g)
419
+ c, f0, spec, mel, y, g, lengths, uv = batch
420
+ (
421
+ y_hat,
422
+ y_hat_mb,
423
+ ids_slice,
424
+ z_mask,
425
+ (z, z_p, m_p, logs_p, m_q, logs_q),
426
+ pred_lf0,
427
+ norm_lf0,
428
+ lf0,
429
+ ) = self.net_g(c, f0, uv, spec, g=g, c_lengths=lengths, spec_lengths=lengths)
430
+ y_mel = commons.slice_segments(
431
+ mel,
432
+ ids_slice,
433
+ self.hparams.train.segment_size // self.hparams.data.hop_length,
434
+ )
435
+ y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1), self.hparams)
436
+ y_mel = y_mel[..., : y_hat_mel.shape[-1]]
437
+ y = commons.slice_segments(
438
+ y,
439
+ ids_slice * self.hparams.data.hop_length,
440
+ self.hparams.train.segment_size,
441
+ )
442
+ y = y[..., : y_hat.shape[-1]]
443
+
444
+ # generator loss
445
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(y, y_hat)
446
+
447
+ with autocast(enabled=False):
448
+ loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.train.c_mel
449
+ loss_kl = (
450
+ kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.train.c_kl
451
+ )
452
+ loss_fm = feature_loss(fmap_r, fmap_g)
453
+ loss_gen, losses_gen = generator_loss(y_d_hat_g)
454
+ loss_lf0 = F.mse_loss(pred_lf0, lf0)
455
+ loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0
456
+
457
+ # MB-iSTFT-VITS
458
+ loss_subband = torch.tensor(0.0)
459
+ if self.hparams.model.get("type_") == "mb-istft":
460
+ from .modules.decoders.mb_istft import PQMF, subband_stft_loss
461
+
462
+ y_mb = PQMF(y.device, self.hparams.model.subbands).analysis(y)
463
+ loss_subband = subband_stft_loss(self.hparams, y_mb, y_hat_mb)
464
+ loss_gen_all += loss_subband
465
+
466
+ # log loss
467
+ self.log_("lr", self.optim_g.param_groups[0]["lr"])
468
+ self.log_dict_(
469
+ {
470
+ "loss/g/total": loss_gen_all,
471
+ "loss/g/fm": loss_fm,
472
+ "loss/g/mel": loss_mel,
473
+ "loss/g/kl": loss_kl,
474
+ "loss/g/lf0": loss_lf0,
475
+ },
476
+ prog_bar=True,
477
+ )
478
+ if self.hparams.model.get("type_") == "mb-istft":
479
+ self.log_("loss/g/subband", loss_subband)
480
+ if self.total_batch_idx % self.hparams.train.log_interval == 0:
481
+ self.log_image_dict(
482
+ {
483
+ "slice/mel_org": utils.plot_spectrogram_to_numpy(
484
+ y_mel[0].data.cpu().float().numpy()
485
+ ),
486
+ "slice/mel_gen": utils.plot_spectrogram_to_numpy(
487
+ y_hat_mel[0].data.cpu().float().numpy()
488
+ ),
489
+ "all/mel": utils.plot_spectrogram_to_numpy(
490
+ mel[0].data.cpu().float().numpy()
491
+ ),
492
+ "all/lf0": so_vits_svc_fork.utils.plot_data_to_numpy(
493
+ lf0[0, 0, :].cpu().float().numpy(),
494
+ pred_lf0[0, 0, :].detach().cpu().float().numpy(),
495
+ ),
496
+ "all/norm_lf0": so_vits_svc_fork.utils.plot_data_to_numpy(
497
+ lf0[0, 0, :].cpu().float().numpy(),
498
+ norm_lf0[0, 0, :].detach().cpu().float().numpy(),
499
+ ),
500
+ }
501
+ )
502
+
503
+ accumulate_grad_batches = self.hparams.train.get("accumulate_grad_batches", 1)
504
+ should_update = (
505
+ batch_idx + 1
506
+ ) % accumulate_grad_batches == 0 or self.trainer.is_last_batch
507
+ # optimizer
508
+ self.manual_backward(loss_gen_all / accumulate_grad_batches)
509
+ if should_update:
510
+ self.log_(
511
+ "grad_norm_g", commons.clip_grad_value_(self.net_g.parameters(), None)
512
+ )
513
+ optim_g.step()
514
+ optim_g.zero_grad()
515
+ self.untoggle_optimizer(optim_g)
516
+
517
+ # Discriminator
518
+ # train
519
+ self.toggle_optimizer(optim_d)
520
+ y_d_hat_r, y_d_hat_g, _, _ = self.net_d(y, y_hat.detach())
521
+
522
+ # discriminator loss
523
+ with autocast(enabled=False):
524
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
525
+ y_d_hat_r, y_d_hat_g
526
+ )
527
+ loss_disc_all = loss_disc
528
+
529
+ # log loss
530
+ self.log_("loss/d/total", loss_disc_all, prog_bar=True)
531
+
532
+ # optimizer
533
+ self.manual_backward(loss_disc_all / accumulate_grad_batches)
534
+ if should_update:
535
+ self.log_(
536
+ "grad_norm_d", commons.clip_grad_value_(self.net_d.parameters(), None)
537
+ )
538
+ optim_d.step()
539
+ optim_d.zero_grad()
540
+ self.untoggle_optimizer(optim_d)
541
+
542
+ # end of epoch
543
+ if self.trainer.is_last_batch:
544
+ self.scheduler_g.step()
545
+ self.scheduler_d.step()
546
+
547
+ def validation_step(self, batch, batch_idx):
548
+ # avoid logging with wrong global step
549
+ if self.global_step == 0:
550
+ return
551
+ with torch.no_grad():
552
+ self.net_g.eval()
553
+ c, f0, _, mel, y, g, _, uv = batch
554
+ y_hat = self.net_g.infer(c, f0, uv, g=g)
555
+ y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1).float(), self.hparams)
556
+ self.log_audio_dict(
557
+ {f"gen/audio_{batch_idx}": y_hat[0], f"gt/audio_{batch_idx}": y[0]}
558
+ )
559
+ self.log_image_dict(
560
+ {
561
+ "gen/mel": utils.plot_spectrogram_to_numpy(
562
+ y_hat_mel[0].cpu().float().numpy()
563
+ ),
564
+ "gt/mel": utils.plot_spectrogram_to_numpy(
565
+ mel[0].cpu().float().numpy()
566
+ ),
567
+ }
568
+ )
569
+
570
+ def on_validation_end(self) -> None:
571
+ self.save_checkpoints()
so_vits_svc_fork/utils.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import re
6
+ import subprocess
7
+ import warnings
8
+ from itertools import groupby
9
+ from logging import getLogger
10
+ from pathlib import Path
11
+ from typing import Any, Literal, Sequence
12
+
13
+ import matplotlib
14
+ import matplotlib.pylab as plt
15
+ import numpy as np
16
+ import requests
17
+ import torch
18
+ import torch.backends.mps
19
+ import torch.nn as nn
20
+ import torchaudio
21
+ from cm_time import timer
22
+ from numpy import ndarray
23
+ from tqdm import tqdm
24
+ from transformers import HubertModel
25
+
26
+ from so_vits_svc_fork.hparams import HParams
27
+
28
+ LOG = getLogger(__name__)
29
+ HUBERT_SAMPLING_RATE = 16000
30
+ IS_COLAB = os.getenv("COLAB_RELEASE_TAG", False)
31
+
32
+
33
+ def get_optimal_device(index: int = 0) -> torch.device:
34
+ if torch.cuda.is_available():
35
+ return torch.device(f"cuda:{index % torch.cuda.device_count()}")
36
+ elif torch.backends.mps.is_available():
37
+ return torch.device("mps")
38
+ else:
39
+ try:
40
+ import torch_xla.core.xla_model as xm # noqa
41
+
42
+ if xm.xrt_world_size() > 0:
43
+ return torch.device("xla")
44
+ # return xm.xla_device()
45
+ except ImportError:
46
+ pass
47
+ return torch.device("cpu")
48
+
49
+
50
+ def download_file(
51
+ url: str,
52
+ filepath: Path | str,
53
+ chunk_size: int = 64 * 1024,
54
+ tqdm_cls: type = tqdm,
55
+ skip_if_exists: bool = False,
56
+ overwrite: bool = False,
57
+ **tqdm_kwargs: Any,
58
+ ):
59
+ if skip_if_exists is True and overwrite is True:
60
+ raise ValueError("skip_if_exists and overwrite cannot be both True")
61
+ filepath = Path(filepath)
62
+ filepath.parent.mkdir(parents=True, exist_ok=True)
63
+ temppath = filepath.parent / f"{filepath.name}.download"
64
+ if filepath.exists():
65
+ if skip_if_exists:
66
+ return
67
+ elif not overwrite:
68
+ filepath.unlink()
69
+ else:
70
+ raise FileExistsError(f"{filepath} already exists")
71
+ temppath.unlink(missing_ok=True)
72
+ resp = requests.get(url, stream=True)
73
+ total = int(resp.headers.get("content-length", 0))
74
+ kwargs = dict(
75
+ total=total,
76
+ unit="iB",
77
+ unit_scale=True,
78
+ unit_divisor=1024,
79
+ desc=f"Downloading {filepath.name}",
80
+ )
81
+ kwargs.update(tqdm_kwargs)
82
+ with temppath.open("wb") as f, tqdm_cls(**kwargs) as pbar:
83
+ for data in resp.iter_content(chunk_size=chunk_size):
84
+ size = f.write(data)
85
+ pbar.update(size)
86
+ temppath.rename(filepath)
87
+
88
+
89
+ PRETRAINED_MODEL_URLS = {
90
+ "hifi-gan": [
91
+ [
92
+ "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
93
+ "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth",
94
+ ],
95
+ [
96
+ "https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/D_0.pth",
97
+ "https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/G_0.pth",
98
+ ],
99
+ ],
100
+ "contentvec": [
101
+ [
102
+ "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/checkpoint_best_legacy_500.pt"
103
+ ],
104
+ [
105
+ "https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/checkpoint_best_legacy_500.pt"
106
+ ],
107
+ [
108
+ "http://obs.cstcloud.cn/share/obs/sankagenkeshi/checkpoint_best_legacy_500.pt"
109
+ ],
110
+ ],
111
+ }
112
+ from joblib import Parallel, delayed
113
+
114
+
115
+ def ensure_pretrained_model(
116
+ folder_path: Path | str, type_: str | dict[str, str], **tqdm_kwargs: Any
117
+ ) -> tuple[Path, ...] | None:
118
+ folder_path = Path(folder_path)
119
+
120
+ # new code
121
+ if not isinstance(type_, str):
122
+ try:
123
+ Parallel(n_jobs=len(type_))(
124
+ [
125
+ delayed(download_file)(
126
+ url,
127
+ folder_path / filename,
128
+ position=i,
129
+ skip_if_exists=True,
130
+ **tqdm_kwargs,
131
+ )
132
+ for i, (filename, url) in enumerate(type_.items())
133
+ ]
134
+ )
135
+ return tuple(folder_path / filename for filename in type_.values())
136
+ except Exception as e:
137
+ LOG.error(f"Failed to download {type_}")
138
+ LOG.exception(e)
139
+
140
+ # old code
141
+ models_candidates = PRETRAINED_MODEL_URLS.get(type_, None)
142
+ if models_candidates is None:
143
+ LOG.warning(f"Unknown pretrained model type: {type_}")
144
+ return
145
+ for model_urls in models_candidates:
146
+ paths = [folder_path / model_url.split("/")[-1] for model_url in model_urls]
147
+ try:
148
+ Parallel(n_jobs=len(paths))(
149
+ [
150
+ delayed(download_file)(
151
+ url, path, position=i, skip_if_exists=True, **tqdm_kwargs
152
+ )
153
+ for i, (url, path) in enumerate(zip(model_urls, paths))
154
+ ]
155
+ )
156
+ return tuple(paths)
157
+ except Exception as e:
158
+ LOG.error(f"Failed to download {model_urls}")
159
+ LOG.exception(e)
160
+
161
+
162
+ class HubertModelWithFinalProj(HubertModel):
163
+ def __init__(self, config):
164
+ super().__init__(config)
165
+
166
+ # The final projection layer is only used for backward compatibility.
167
+ # Following https://github.com/auspicious3000/contentvec/issues/6
168
+ # Remove this layer is necessary to achieve the desired outcome.
169
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
170
+
171
+
172
+ def remove_weight_norm_if_exists(module, name: str = "weight"):
173
+ r"""Removes the weight normalization reparameterization from a module.
174
+
175
+ Args:
176
+ module (Module): containing module
177
+ name (str, optional): name of weight parameter
178
+
179
+ Example:
180
+ >>> m = weight_norm(nn.Linear(20, 40))
181
+ >>> remove_weight_norm(m)
182
+ """
183
+ from torch.nn.utils.weight_norm import WeightNorm
184
+
185
+ for k, hook in module._forward_pre_hooks.items():
186
+ if isinstance(hook, WeightNorm) and hook.name == name:
187
+ hook.remove(module)
188
+ del module._forward_pre_hooks[k]
189
+ return module
190
+
191
+
192
+ def get_hubert_model(
193
+ device: str | torch.device, final_proj: bool = True
194
+ ) -> HubertModel:
195
+ if final_proj:
196
+ model = HubertModelWithFinalProj.from_pretrained("lengyue233/content-vec-best")
197
+ else:
198
+ model = HubertModel.from_pretrained("lengyue233/content-vec-best")
199
+ # Hubert is always used in inference mode, we can safely remove weight-norms
200
+ for m in model.modules():
201
+ if isinstance(m, (nn.Conv2d, nn.Conv1d)):
202
+ remove_weight_norm_if_exists(m)
203
+
204
+ return model.to(device)
205
+
206
+
207
+ def get_content(
208
+ cmodel: HubertModel,
209
+ audio: torch.Tensor | ndarray[Any, Any],
210
+ device: torch.device | str,
211
+ sr: int,
212
+ legacy_final_proj: bool = False,
213
+ ) -> torch.Tensor:
214
+ audio = torch.as_tensor(audio)
215
+ if sr != HUBERT_SAMPLING_RATE:
216
+ audio = (
217
+ torchaudio.transforms.Resample(sr, HUBERT_SAMPLING_RATE)
218
+ .to(audio.device)(audio)
219
+ .to(device)
220
+ )
221
+ if audio.ndim == 1:
222
+ audio = audio.unsqueeze(0)
223
+ with torch.no_grad(), timer() as t:
224
+ if legacy_final_proj:
225
+ warnings.warn("legacy_final_proj is deprecated")
226
+ if not hasattr(cmodel, "final_proj"):
227
+ raise ValueError("HubertModel does not have final_proj")
228
+ c = cmodel(audio, output_hidden_states=True)["hidden_states"][9]
229
+ c = cmodel.final_proj(c)
230
+ else:
231
+ c = cmodel(audio)["last_hidden_state"]
232
+ c = c.transpose(1, 2)
233
+ wav_len = audio.shape[-1] / HUBERT_SAMPLING_RATE
234
+ LOG.info(
235
+ f"HuBERT inference time : {t.elapsed:.3f}s, RTF: {t.elapsed / wav_len:.3f}"
236
+ )
237
+ return c
238
+
239
+
240
+ def _substitute_if_same_shape(to_: dict[str, Any], from_: dict[str, Any]) -> None:
241
+ not_in_to = list(filter(lambda x: x not in to_, from_.keys()))
242
+ not_in_from = list(filter(lambda x: x not in from_, to_.keys()))
243
+ if not_in_to:
244
+ warnings.warn(f"Keys not found in model state dict:" f"{not_in_to}")
245
+ if not_in_from:
246
+ warnings.warn(f"Keys not found in checkpoint state dict:" f"{not_in_from}")
247
+ shape_missmatch = []
248
+ for k, v in from_.items():
249
+ if k not in to_:
250
+ pass
251
+ elif hasattr(v, "shape"):
252
+ if not hasattr(to_[k], "shape"):
253
+ raise ValueError(f"Key {k} is not a tensor")
254
+ if to_[k].shape == v.shape:
255
+ to_[k] = v
256
+ else:
257
+ shape_missmatch.append((k, to_[k].shape, v.shape))
258
+ elif isinstance(v, dict):
259
+ assert isinstance(to_[k], dict)
260
+ _substitute_if_same_shape(to_[k], v)
261
+ else:
262
+ to_[k] = v
263
+ if shape_missmatch:
264
+ warnings.warn(
265
+ f"Shape mismatch: {[f'{k}: {v1} -> {v2}' for k, v1, v2 in shape_missmatch]}"
266
+ )
267
+
268
+
269
+ def safe_load(model: torch.nn.Module, state_dict: dict[str, Any]) -> None:
270
+ model_state_dict = model.state_dict()
271
+ _substitute_if_same_shape(model_state_dict, state_dict)
272
+ model.load_state_dict(model_state_dict)
273
+
274
+
275
+ def load_checkpoint(
276
+ checkpoint_path: Path | str,
277
+ model: torch.nn.Module,
278
+ optimizer: torch.optim.Optimizer | None = None,
279
+ skip_optimizer: bool = False,
280
+ ) -> tuple[torch.nn.Module, torch.optim.Optimizer | None, float, int]:
281
+ if not Path(checkpoint_path).is_file():
282
+ raise FileNotFoundError(f"File {checkpoint_path} not found")
283
+ with Path(checkpoint_path).open("rb") as f:
284
+ with warnings.catch_warnings():
285
+ warnings.filterwarnings(
286
+ "ignore", category=UserWarning, message="TypedStorage is deprecated"
287
+ )
288
+ checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
289
+ iteration = checkpoint_dict["iteration"]
290
+ learning_rate = checkpoint_dict["learning_rate"]
291
+
292
+ # safe load module
293
+ if hasattr(model, "module"):
294
+ safe_load(model.module, checkpoint_dict["model"])
295
+ else:
296
+ safe_load(model, checkpoint_dict["model"])
297
+ # safe load optim
298
+ if (
299
+ optimizer is not None
300
+ and not skip_optimizer
301
+ and checkpoint_dict["optimizer"] is not None
302
+ ):
303
+ with warnings.catch_warnings():
304
+ warnings.simplefilter("ignore")
305
+ safe_load(optimizer, checkpoint_dict["optimizer"])
306
+
307
+ LOG.info(f"Loaded checkpoint '{checkpoint_path}' (epoch {iteration})")
308
+ return model, optimizer, learning_rate, iteration
309
+
310
+
311
+ def save_checkpoint(
312
+ model: torch.nn.Module,
313
+ optimizer: torch.optim.Optimizer,
314
+ learning_rate: float,
315
+ iteration: int,
316
+ checkpoint_path: Path | str,
317
+ ) -> None:
318
+ LOG.info(
319
+ "Saving model and optimizer state at epoch {} to {}".format(
320
+ iteration, checkpoint_path
321
+ )
322
+ )
323
+ if hasattr(model, "module"):
324
+ state_dict = model.module.state_dict()
325
+ else:
326
+ state_dict = model.state_dict()
327
+ with Path(checkpoint_path).open("wb") as f:
328
+ torch.save(
329
+ {
330
+ "model": state_dict,
331
+ "iteration": iteration,
332
+ "optimizer": optimizer.state_dict(),
333
+ "learning_rate": learning_rate,
334
+ },
335
+ f,
336
+ )
337
+
338
+
339
+ def clean_checkpoints(
340
+ path_to_models: Path | str, n_ckpts_to_keep: int = 2, sort_by_time: bool = True
341
+ ) -> None:
342
+ """Freeing up space by deleting saved ckpts
343
+
344
+ Arguments:
345
+ path_to_models -- Path to the model directory
346
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
347
+ sort_by_time -- True -> chronologically delete ckpts
348
+ False -> lexicographically delete ckpts
349
+ """
350
+ LOG.info("Cleaning old checkpoints...")
351
+ path_to_models = Path(path_to_models)
352
+
353
+ # Define sort key functions
354
+ name_key = lambda p: int(re.match(r"[GD]_(\d+)", p.stem).group(1))
355
+ time_key = lambda p: p.stat().st_mtime
356
+ path_key = lambda p: (p.stem[0], time_key(p) if sort_by_time else name_key(p))
357
+
358
+ models = list(
359
+ filter(
360
+ lambda p: (
361
+ p.is_file()
362
+ and re.match(r"[GD]_\d+", p.stem)
363
+ and not p.stem.endswith("_0")
364
+ ),
365
+ path_to_models.glob("*.pth"),
366
+ )
367
+ )
368
+
369
+ models_sorted = sorted(models, key=path_key)
370
+
371
+ models_sorted_grouped = groupby(models_sorted, lambda p: p.stem[0])
372
+
373
+ for group_name, group_items in models_sorted_grouped:
374
+ to_delete_list = list(group_items)[:-n_ckpts_to_keep]
375
+
376
+ for to_delete in to_delete_list:
377
+ if to_delete.exists():
378
+ LOG.info(f"Removing {to_delete}")
379
+ if IS_COLAB:
380
+ to_delete.write_text("")
381
+ to_delete.unlink()
382
+
383
+
384
+ def latest_checkpoint_path(dir_path: Path | str, regex: str = "G_*.pth") -> Path | None:
385
+ dir_path = Path(dir_path)
386
+ name_key = lambda p: int(re.match(r"._(\d+)\.pth", p.name).group(1))
387
+ paths = list(sorted(dir_path.glob(regex), key=name_key))
388
+ if len(paths) == 0:
389
+ return None
390
+ return paths[-1]
391
+
392
+
393
+ def plot_spectrogram_to_numpy(spectrogram: ndarray) -> ndarray:
394
+ matplotlib.use("Agg")
395
+ fig, ax = plt.subplots(figsize=(10, 2))
396
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
397
+ plt.colorbar(im, ax=ax)
398
+ plt.xlabel("Frames")
399
+ plt.ylabel("Channels")
400
+ plt.tight_layout()
401
+
402
+ fig.canvas.draw()
403
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
404
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
405
+ plt.close()
406
+ return data
407
+
408
+
409
+ def get_backup_hparams(
410
+ config_path: Path, model_path: Path, init: bool = True
411
+ ) -> HParams:
412
+ model_path.mkdir(parents=True, exist_ok=True)
413
+ config_save_path = model_path / "config.json"
414
+ if init:
415
+ with config_path.open() as f:
416
+ data = f.read()
417
+ with config_save_path.open("w") as f:
418
+ f.write(data)
419
+ else:
420
+ with config_save_path.open() as f:
421
+ data = f.read()
422
+ config = json.loads(data)
423
+
424
+ hparams = HParams(**config)
425
+ hparams.model_dir = model_path.as_posix()
426
+ return hparams
427
+
428
+
429
+ def get_hparams(config_path: Path | str) -> HParams:
430
+ config = json.loads(Path(config_path).read_text("utf-8"))
431
+ hparams = HParams(**config)
432
+ return hparams
433
+
434
+
435
+ def repeat_expand_2d(content: torch.Tensor, target_len: int) -> torch.Tensor:
436
+ # content : [h, t]
437
+ src_len = content.shape[-1]
438
+ if target_len < src_len:
439
+ return content[:, :target_len]
440
+ else:
441
+ return torch.nn.functional.interpolate(
442
+ content.unsqueeze(0), size=target_len, mode="nearest"
443
+ ).squeeze(0)
444
+
445
+
446
+ def plot_data_to_numpy(x: ndarray, y: ndarray) -> ndarray:
447
+ matplotlib.use("Agg")
448
+ fig, ax = plt.subplots(figsize=(10, 2))
449
+ plt.plot(x)
450
+ plt.plot(y)
451
+ plt.tight_layout()
452
+
453
+ fig.canvas.draw()
454
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
455
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
456
+ plt.close()
457
+ return data
458
+
459
+
460
+ def get_gpu_memory(type_: Literal["total", "free", "used"]) -> Sequence[int] | None:
461
+ command = f"nvidia-smi --query-gpu=memory.{type_} --format=csv"
462
+ try:
463
+ memory_free_info = (
464
+ subprocess.check_output(command.split())
465
+ .decode("ascii")
466
+ .split("\n")[:-1][1:]
467
+ )
468
+ memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
469
+ return memory_free_values
470
+ except Exception:
471
+ return
472
+
473
+
474
+ def get_total_gpu_memory(type_: Literal["total", "free", "used"]) -> int | None:
475
+ memories = get_gpu_memory(type_)
476
+ if memories is None:
477
+ return
478
+ return sum(memories)