Abhinay45 committed
Commit cf68802 · verified · 1 Parent(s): f69438d

Add files using upload-large-folder tool

best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23aa1fbbef3ec860e559fcb94a9830930ea7111c3bd73ba3d3e3957c5aeca369
+ size 5656611329
best_model_80180.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23aa1fbbef3ec860e559fcb94a9830930ea7111c3bd73ba3d3e3957c5aeca369
+ size 5656611329
checkpoint_80000.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0ec02d5d1a4a668edc2a4bb01dfbb7d78842aa1bc64fa63f456a40bf7637a7a
+ size 5656611265
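
Note: the three .pth entries above are Git LFS pointer files, not the weights themselves; best_model.pth and best_model_80180.pth share the same oid and size, so they reference identical content. A minimal sketch of reading such a pointer (the real payload is fetched separately, e.g. with `git lfs pull`):

def parse_lfs_pointer(path: str) -> dict:
    # Each line of a Git LFS v1 pointer is "<key> <value>".
    fields = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

info = parse_lfs_pointer("best_model.pth")
print(info["oid"], int(info["size"]))  # sha256:23aa1f... 5656611329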
config.json ADDED
@@ -0,0 +1,193 @@
+ {
+     "output_path": "checkpoints/",
+     "logger_uri": null,
+     "run_name": "GPT_XTTS_FT",
+     "project_name": "XTTS_trainer",
+     "run_description": [
+         "\n GPT XTTS training\n "
+     ],
+     "print_step": 50,
+     "plot_step": 100,
+     "model_param_stats": false,
+     "wandb_entity": null,
+     "dashboard_logger": "tensorboard",
+     "save_on_interrupt": true,
+     "log_model_step": 100,
+     "save_step": 1000,
+     "save_n_checkpoints": 1,
+     "save_checkpoints": true,
+     "save_all_best": false,
+     "save_best_after": 10000,
+     "target_loss": null,
+     "print_eval": false,
+     "test_delay_epochs": 0,
+     "run_eval": true,
+     "run_eval_steps": null,
+     "distributed_backend": "nccl",
+     "distributed_url": "tcp://localhost:54321",
+     "mixed_precision": false,
+     "precision": "fp16",
+     "epochs": 10,
+     "batch_size": 8,
+     "eval_batch_size": 16,
+     "grad_clip": 0.0,
+     "scheduler_after_epoch": true,
+     "lr": 5e-06,
+     "optimizer": "AdamW",
+     "optimizer_params": {
+         "betas": [
+             0.9,
+             0.96
+         ],
+         "eps": 1e-08,
+         "weight_decay": 0.01
+     },
+     "lr_scheduler": "MultiStepLR",
+     "lr_scheduler_params": {
+         "milestones": [
+             900000,
+             2700000,
+             5400000
+         ],
+         "gamma": 0.5,
+         "last_epoch": -1
+     },
+     "use_grad_scaler": false,
+     "allow_tf32": false,
+     "cudnn_enable": true,
+     "cudnn_deterministic": false,
+     "cudnn_benchmark": false,
+     "training_seed": 54321,
+     "model": "xtts",
+     "num_loader_workers": 8,
+     "num_eval_loader_workers": 0,
+     "use_noise_augment": false,
+     "audio": {
+         "sample_rate": 22050,
+         "output_sample_rate": 24000,
+         "dvae_sample_rate": 22050
+     },
+     "use_phonemes": false,
+     "phonemizer": null,
+     "phoneme_language": null,
+     "compute_input_seq_cache": false,
+     "text_cleaner": null,
+     "enable_eos_bos_chars": false,
+     "test_sentences_file": "",
+     "phoneme_cache_path": null,
+     "characters": null,
+     "add_blank": false,
+     "batch_group_size": 0,
+     "loss_masking": null,
+     "min_audio_len": 1,
+     "max_audio_len": Infinity,
+     "min_text_len": 1,
+     "max_text_len": Infinity,
+     "compute_f0": false,
+     "compute_energy": false,
+     "compute_linear_spec": false,
+     "precompute_num_workers": 0,
+     "start_by_longest": false,
+     "shuffle": false,
+     "drop_last": false,
+     "datasets": [
+         {
+             "formatter": "",
+             "dataset_name": "",
+             "path": "",
+             "meta_file_train": "",
+             "ignored_speakers": null,
+             "language": "",
+             "phonemizer": "",
+             "meta_file_val": "",
+             "meta_file_attn_mask": ""
+         }
+     ],
+     "test_sentences": [],
+     "eval_split_max_size": 256,
+     "eval_split_size": 0.01,
+     "use_speaker_weighted_sampler": false,
+     "speaker_weighted_sampler_alpha": 1.0,
+     "use_language_weighted_sampler": false,
+     "language_weighted_sampler_alpha": 1.0,
+     "use_length_weighted_sampler": false,
+     "length_weighted_sampler_alpha": 1.0,
+     "model_args": {
+         "gpt_batch_size": 1,
+         "enable_redaction": false,
+         "kv_cache": true,
+         "gpt_checkpoint": "",
+         "clvp_checkpoint": null,
+         "decoder_checkpoint": null,
+         "num_chars": 255,
+         "tokenizer_file": "checkpoints/XTTS_v2.0_original_model_files/vocab.json",
+         "gpt_max_audio_tokens": 605,
+         "gpt_max_text_tokens": 402,
+         "gpt_max_prompt_tokens": 70,
+         "gpt_layers": 30,
+         "gpt_n_model_channels": 1024,
+         "gpt_n_heads": 16,
+         "gpt_number_text_tokens": 8661,
+         "gpt_start_text_token": 261,
+         "gpt_stop_text_token": 0,
+         "gpt_num_audio_tokens": 1026,
+         "gpt_start_audio_token": 1024,
+         "gpt_stop_audio_token": 1025,
+         "gpt_code_stride_len": 1024,
+         "gpt_use_masking_gt_prompt_approach": true,
+         "gpt_use_perceiver_resampler": true,
+         "input_sample_rate": 22050,
+         "output_sample_rate": 24000,
+         "output_hop_length": 256,
+         "decoder_input_dim": 1024,
+         "d_vector_dim": 512,
+         "cond_d_vector_in_each_upsampling_layer": true,
+         "duration_const": 102400,
+         "min_conditioning_length": 11025,
+         "max_conditioning_length": 132300,
+         "gpt_loss_text_ce_weight": 0.01,
+         "gpt_loss_mel_ce_weight": 1.0,
+         "debug_loading_failures": false,
+         "max_wav_length": 330750,
+         "max_text_length": 400,
+         "mel_norm_file": "checkpoints/XTTS_v2.0_original_model_files/mel_stats.pth",
+         "dvae_checkpoint": "checkpoints/XTTS_v2.0_original_model_files/dvae.pth",
+         "xtts_checkpoint": "checkpoints/XTTS_v2.0_original_model_files/model.pth",
+         "vocoder": ""
+     },
+     "model_dir": null,
+     "languages": [
+         "en",
+         "es",
+         "fr",
+         "de",
+         "it",
+         "pt",
+         "pl",
+         "tr",
+         "ru",
+         "nl",
+         "cs",
+         "ar",
+         "zh-cn",
+         "hu",
+         "ko",
+         "ja",
+         "hi",
+         "te"
+     ],
+     "temperature": 0.75,
+     "length_penalty": 1.0,
+     "repetition_penalty": 5.0,
+     "top_k": 50,
+     "top_p": 0.85,
+     "num_gpt_outputs": 1,
+     "gpt_cond_len": 30,
+     "gpt_cond_chunk_len": 4,
+     "max_ref_len": 30,
+     "sound_norm_refs": false,
+     "optimizer_wd_only_on_weights": true,
+     "weighted_loss_attrs": null,
+     "weighted_loss_multipliers": null,
+     "github_branch": "* main"
+ }
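
This config.json is the trainer configuration saved alongside the checkpoints (note the Infinity values, which are not strict JSON but are accepted by Coqui's config loader). A short sketch of loading it the same way inference.py in this commit does; the local path is illustrative:

from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig()
config.load_json("checkpoints/GPT_XTTS_FT/config.json")  # hypothetical local path
print(config.languages[-1])  # "te": the fine-tune's language list ends with Telugu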
events.out.tfevents.1739298885.node2.188457.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ea042bd25d446dae787dfc01d9ec0249466eb72bea5d5a8f0eeb957501a8fd4
+ size 331915
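
The events.out.tfevents file is the TensorBoard log for this run. A sketch of inspecting it programmatically instead of launching the TensorBoard UI (the scalar tag name below is hypothetical; actual tags depend on what the trainer logged):

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

acc = EventAccumulator("events.out.tfevents.1739298885.node2.188457.0")
acc.Reload()  # parse the event file
print(acc.Tags()["scalars"])  # list the available scalar tags
# for e in acc.Scalars("TrainIterStats/avg_loss"):  # hypothetical tag name
#     print(e.step, e.value)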
inference.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torchaudio
+ from tqdm import tqdm
+ from underthesea import sent_tokenize
+ import os
+
+ from TTS.tts.configs.xtts_config import XttsConfig
+ from TTS.tts.models.xtts import Xtts
+ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer  # Import the tokenizer
+
+ # Device configuration
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ # Model paths
+ checkpoint_dir = "/export/home/vivian/svarah/XTTSv2-Finetuning-for-New-Languages/checkpoints/GPT_XTTS_FT-February-12-2025_12+04AM-8e59ec3"
+ xtts_checkpoint = os.path.join(checkpoint_dir, "best_model.pth")
+ xtts_config = os.path.join(checkpoint_dir, "config.json")
+ xtts_vocab = "/export/home/vivian/svarah/XTTSv2-Finetuning-for-New-Languages/checkpoints/XTTS_v2.0_original_model_files/vocab.json"  # Path to vocab file
+ speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth")  # Path to speaker file
+
+ # Load model config
+ config = XttsConfig()
+ config.load_json(xtts_config)
+
+ # Manually initialize the tokenizer
+ tokenizer = VoiceBpeTokenizer(xtts_vocab)
+
+ # Initialize the model
+ XTTS_MODEL = Xtts.init_from_config(config)
+
+ # Load checkpoint
+ XTTS_MODEL.load_checkpoint(
+     config,
+     checkpoint_path=xtts_checkpoint,
+     checkpoint_dir=checkpoint_dir,  # Explicitly provide checkpoint_dir
+     vocab_path=xtts_vocab,
+     speaker_file_path=speaker_file_path,  # Explicitly provide speaker file path
+     use_deepspeed=False,
+ )
+ XTTS_MODEL.to(device)
+ # Note: the manual tokenizer above is redundant; load_checkpoint already loads the vocab via vocab_path
+
+ print("Model loaded successfully!")
+
+ # Inference
+ tts_text = "పల్నాడు ప్రాంతంలోని చిన్న గ్రామంలో అపరాజితుడు అనే బాలుడు తన తల్లిదండ్రులతో కలిసి ఉండేవాడు. అతని కుటుంబం గొప్ప ధనవంతం కాకపోయినా, అతనికి గొప్ప లక్ష్యసాధన కోరిక ఉండేది. అతనికి చదువుపై అపారమైన ఆసక్తి, కానీ కుటుంబ స్థితిగతుల వల్ల మంచి పాఠశాలకు వెళ్లే అవకాశం లేకపోయింది. అపరాజితుడు ఉదయాన్నే లేచి మేతెల్లో పనిచేసి తల్లికి సహాయపడుతూ, మిగిలిన సమయాన్ని తన చదువుకు అంకితం చేసేవాడు. ఊరిలోని చిన్న గ్రంథాలయంలో ఉన్న పుస్తకాలను చదివి కొత్త విషయాలు నేర్చుకునే ప్రయత్నం చేసేవాడు. అతని పట్టుదల చూసిన గ్రామంలోని ఓ పండితుడు, బాలుడా నీ ప్రయత్నం అమోఘం, నీ లక్ష్యాన్ని సాధించేందుకు నేను సహాయపడతాను అని చెప్పి, అతనికి కావాల్సిన పుస్తకాలను అందించేవాడు. కొన్నేళ్ల తర్వాత, అపరాజితుడు తన కృషి, పట్టుదలతో పరీక్షల్లో రాష్ట్ర స్థాయిలో ఉత్తమ ర్యాంకు సాధించి, ప్రభుత్వ స్కాలర్‌షిప్‌తో పెద్ద నగరానికి వెళ్లాడు. అక్కడ ఉన్నత చదువులు పూర్తి చేసి, ఒక గొప్ప శాస్త్రవేత్తగా ఎదిగాడు. తన ఊరిని మర్చిపోకుండా, అక్కడ ఒక పాఠశాలను స్థాపించి, చాలా మంది పేద పిల్లలకు ఉచిత విద్య అందించాడు. అతని విజయాన్ని చూసి గ్రామస్థులు గర్వంతో ఆనందంతో మన ఊరి బాలుడు ప్రపంచానికి వెలుగునిచ్చాడు అని ప్రశంసించారు. కష్టపడే మనసు, పట్టుదల ఉంటే, ఏ లక్ష్యాన్నైనా సాధించవచ్చు."
+ speaker_audio_file = "/export/home/vivian/svarah/telugu_tts/Telugu/wavs/4503599627574654_chunk_10_enhanced.wav"
+
+ lang = "te"
+
+ gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+     audio_path=speaker_audio_file,
+     gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+     max_ref_length=XTTS_MODEL.config.max_ref_len,
+     sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+ )
+
+ tts_texts = sent_tokenize(tts_text)
+
+ wav_chunks = []
+ for text in tqdm(tts_texts):
+     wav_chunk = XTTS_MODEL.inference(
+         text=text,
+         language=lang,
+         gpt_cond_latent=gpt_cond_latent,
+         speaker_embedding=speaker_embedding,
+         temperature=0.1,  # low temperature for more deterministic output
+         length_penalty=1.0,
+         repetition_penalty=10.0,
+         top_k=10,
+         top_p=0.3,
+     )
+     wav_chunks.append(torch.tensor(wav_chunk["wav"]))
+
+ out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0).cpu()
+
+ # Save the synthesized audio
+ output_audio_path = "/export/home/vivian/svarah/XTTSv2-Finetuning-for-New-Languages/checkpoints/output_audio2.wav"
+ torchaudio.save(output_audio_path, out_wav, sample_rate=24000)
+
+ print(f"Audio saved to {output_audio_path}")
karan_narration1.wav ADDED
Binary file (60.6 kB).
 
train_gpt_xtts.py ADDED
@@ -0,0 +1,237 @@
+ import os
+ import gc
+
+ from trainer import Trainer, TrainerArgs
+
+ from TTS.config.shared_configs import BaseDatasetConfig
+ from TTS.tts.datasets import load_tts_samples
+ from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
+ from TTS.utils.manage import ModelManager
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+ from transformers import HfArgumentParser
+
+ import argparse
+
+ def create_xtts_trainer_parser():
+     parser = argparse.ArgumentParser(description="Arguments for XTTS Trainer")
+
+     parser.add_argument("--output_path", type=str, required=True,
+                         help="Output path for downloaded pretrained files and checkpoints")
+     parser.add_argument("--metadatas", nargs='+', type=str, required=True,
+                         help="train_csv_path,eval_csv_path,language")
+     parser.add_argument("--num_epochs", type=int, default=1,
+                         help="Number of epochs")
+     parser.add_argument("--batch_size", type=int, default=1,
+                         help="Mini batch size")
+     parser.add_argument("--grad_acumm", type=int, default=1,
+                         help="Grad accumulation steps")
+     parser.add_argument("--max_audio_length", type=int, default=255995,
+                         help="Max audio length (in samples)")
+     parser.add_argument("--max_text_length", type=int, default=200,
+                         help="Max text length")
+     parser.add_argument("--weight_decay", type=float, default=1e-2,
+                         help="Weight decay")
+     parser.add_argument("--lr", type=float, default=5e-6,
+                         help="Learning rate")
+     parser.add_argument("--save_step", type=int, default=5000,
+                         help="Save step")
+
+     return parser
+
+
+
+ def train_gpt(metadatas, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, max_text_length, lr, weight_decay, save_step):
+     # Logging parameters
+     RUN_NAME = "GPT_XTTS_FT"
+     PROJECT_NAME = "XTTS_trainer"
+     DASHBOARD_LOGGER = "tensorboard"
+     LOGGER_URI = None
+
+     # Set the path where the checkpoints will be saved. Default: ./run/training/
+     # OUT_PATH = os.path.join(output_path, "run", "training")
+     OUT_PATH = output_path
+
+     # Training parameters
+     OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # set this to False for multi-GPU training
+     START_WITH_EVAL = False  # if True, training starts with an evaluation pass
+     BATCH_SIZE = batch_size  # batch size
+     GRAD_ACUMM_STEPS = grad_acumm  # gradient accumulation steps
+
+
+     # Define the dataset(s) to use for fine-tuning
+     DATASETS_CONFIG_LIST = []
+     for metadata in metadatas:
+         train_csv, eval_csv, language = metadata.split(",")
+         print(train_csv, eval_csv, language)
+
+         config_dataset = BaseDatasetConfig(
+             formatter="coqui",
+             dataset_name="ft_dataset",
+             path=os.path.dirname(train_csv),
+             meta_file_train=os.path.basename(train_csv),
+             meta_file_val=os.path.basename(eval_csv),
+             language=language,
+         )
+
+         DATASETS_CONFIG_LIST.append(config_dataset)
+
+     # Define the path where the XTTS v2.0.1 files will be downloaded
+     CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
+     os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
+
+
+     # DVAE files
+     DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
+     MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
+
+     # Set the paths to the downloaded files
+     DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
+     MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))
+
+     # Download DVAE files if needed
+     if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
+         print(" > Downloading DVAE files!")
+         ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+
+
+     # Download XTTS v2.0 checkpoint if needed
+     TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
+     XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
+     XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json"
+
+     # XTTS transfer learning parameters: provide the paths of the XTTS checkpoint that you want to fine-tune
+     TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK))  # vocab.json file
+     XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK))  # model.pth file
+     XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK))  # config.json file
+
+     # Download XTTS v2.0 files if needed
+     if not os.path.isfile(TOKENIZER_FILE):
+         print(" > Downloading XTTS v2.0 tokenizer!")
+         ModelManager._download_model_files(
+             [TOKENIZER_FILE_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+         )
+     if not os.path.isfile(XTTS_CHECKPOINT):
+         print(" > Downloading XTTS v2.0 checkpoint!")
+         ModelManager._download_model_files(
+             [XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+         )
+     if not os.path.isfile(XTTS_CONFIG_FILE):
+         print(" > Downloading XTTS v2.0 config!")
+         ModelManager._download_model_files(
+             [XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+         )
+
+     # Init model args and config
+     model_args = GPTArgs(
+         max_conditioning_length=132300,  # 6 secs
+         min_conditioning_length=11025,  # 0.5 secs
+         debug_loading_failures=False,
+         max_wav_length=max_audio_length,  # default 255995 samples, ~11.6 seconds
+         max_text_length=max_text_length,
+         mel_norm_file=MEL_NORM_FILE,
+         dvae_checkpoint=DVAE_CHECKPOINT,
+         xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model that you want to fine-tune
+         tokenizer_file=TOKENIZER_FILE,
+         gpt_num_audio_tokens=1026,
+         gpt_start_audio_token=1024,
+         gpt_stop_audio_token=1025,
+         gpt_use_masking_gt_prompt_approach=True,
+         gpt_use_perceiver_resampler=True,
+     )
+     # Define audio config
+     audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
+     # Training parameters config
+
+     config = GPTTrainerConfig()
+
+     config.load_json(XTTS_CONFIG_FILE)
+
+     config.epochs = num_epochs
+     config.output_path = OUT_PATH
+     config.model_args = model_args
+     config.run_name = RUN_NAME
+     config.project_name = PROJECT_NAME
+     config.run_description = """
+         GPT XTTS training
+         """
+     config.dashboard_logger = DASHBOARD_LOGGER
+     config.logger_uri = LOGGER_URI
+     config.audio = audio_config
+     config.batch_size = BATCH_SIZE
+     config.num_loader_workers = 8
+     config.eval_split_max_size = 256
+     config.print_step = 50
+     config.plot_step = 100
+     config.log_model_step = 100
+     config.save_step = save_step
+     config.save_n_checkpoints = 1
+     config.save_checkpoints = True
+     config.print_eval = False
+     config.optimizer = "AdamW"
+     config.optimizer_wd_only_on_weights = OPTIMIZER_WD_ONLY_ON_WEIGHTS
+     config.optimizer_params = {"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": weight_decay}
+     config.lr = lr
+     config.lr_scheduler = "MultiStepLR"
+     config.lr_scheduler_params = {"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}
+     config.test_sentences = []
+
+     # Init the model from config
+     model = GPTTrainer.init_from_config(config)
+
+     # Load training samples
+     train_samples, eval_samples = load_tts_samples(
+         DATASETS_CONFIG_LIST,
+         eval_split=True,
+         eval_split_max_size=config.eval_split_max_size,
+         eval_split_size=config.eval_split_size,
+     )
+
+     # Init the trainer and 🚀
+     trainer = Trainer(
+         TrainerArgs(
+             restore_path=None,  # the XTTS checkpoint is restored via the xtts_checkpoint key, so there is no need to restore it with restore_path
+             skip_train_epoch=False,
+             start_with_eval=START_WITH_EVAL,
+             grad_accum_steps=GRAD_ACUMM_STEPS
+         ),
+         config,
+         output_path=os.path.join(output_path, "run", "training"),
+         model=model,
+         train_samples=train_samples,
+         eval_samples=eval_samples,
+     )
+     trainer.fit()
+
+     # Get the audio file with the longest text to use as speaker reference
+     samples_len = [len(item["text"].split(" ")) for item in train_samples]
+     longest_text_idx = samples_len.index(max(samples_len))
+     speaker_ref = train_samples[longest_text_idx]["audio_file"]
+
+     trainer_out_path = trainer.output_path
+
+     # Deallocate VRAM and RAM
+     del model, trainer, train_samples, eval_samples
+     gc.collect()
+
+     return trainer_out_path
+
+ if __name__ == "__main__":
+     parser = create_xtts_trainer_parser()
+     args = parser.parse_args()
+
+     trainer_out_path = train_gpt(
+         metadatas=args.metadatas,
+         output_path=args.output_path,
+         num_epochs=args.num_epochs,
+         batch_size=args.batch_size,
+         grad_acumm=args.grad_acumm,
+         weight_decay=args.weight_decay,
+         lr=args.lr,
+         max_text_length=args.max_text_length,
+         max_audio_length=args.max_audio_length,
+         save_step=args.save_step
+     )
+
+     print(f"Checkpoint saved in dir: {trainer_out_path}")
trainer_0_log.txt ADDED
The diff for this file is too large to render.