Add files using upload-large-folder tool
- best_model.pth +3 -0
- best_model_80180.pth +3 -0
- checkpoint_80000.pth +3 -0
- config.json +193 -0
- events.out.tfevents.1739298885.node2.188457.0 +3 -0
- inference.py +81 -0
- karan_narration1.wav +0 -0
- train_gpt_xtts.py +237 -0
- trainer_0_log.txt +0 -0
best_model.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23aa1fbbef3ec860e559fcb94a9830930ea7111c3bd73ba3d3e3957c5aeca369
+size 5656611329
best_model_80180.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23aa1fbbef3ec860e559fcb94a9830930ea7111c3bd73ba3d3e3957c5aeca369
+size 5656611329
checkpoint_80000.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0ec02d5d1a4a668edc2a4bb01dfbb7d78842aa1bc64fa63f456a40bf7637a7a
+size 5656611265
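
The three .pth entries above are Git LFS pointer files rather than the weights themselves: each records only the pointer-spec version, the SHA-256 object id, and the byte size of the blob stored in LFS. Note that best_model.pth and best_model_80180.pth carry the same oid and size, so they are two names for the same checkpoint. A minimal sketch for reading these fields, assuming the pointer is checked out as plain text (the path is a placeholder):

# Minimal sketch: parse a Git LFS pointer file into its key/value fields.
def parse_lfs_pointer(path):
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

ptr = parse_lfs_pointer("best_model.pth")  # placeholder path
print(ptr["oid"], ptr["size"])  # sha256:23aa1fbb... 5656611329
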
config.json
ADDED
@@ -0,0 +1,193 @@
+{
+    "output_path": "checkpoints/",
+    "logger_uri": null,
+    "run_name": "GPT_XTTS_FT",
+    "project_name": "XTTS_trainer",
+    "run_description": [
+        "\n GPT XTTS training\n "
+    ],
+    "print_step": 50,
+    "plot_step": 100,
+    "model_param_stats": false,
+    "wandb_entity": null,
+    "dashboard_logger": "tensorboard",
+    "save_on_interrupt": true,
+    "log_model_step": 100,
+    "save_step": 1000,
+    "save_n_checkpoints": 1,
+    "save_checkpoints": true,
+    "save_all_best": false,
+    "save_best_after": 10000,
+    "target_loss": null,
+    "print_eval": false,
+    "test_delay_epochs": 0,
+    "run_eval": true,
+    "run_eval_steps": null,
+    "distributed_backend": "nccl",
+    "distributed_url": "tcp://localhost:54321",
+    "mixed_precision": false,
+    "precision": "fp16",
+    "epochs": 10,
+    "batch_size": 8,
+    "eval_batch_size": 16,
+    "grad_clip": 0.0,
+    "scheduler_after_epoch": true,
+    "lr": 5e-06,
+    "optimizer": "AdamW",
+    "optimizer_params": {
+        "betas": [
+            0.9,
+            0.96
+        ],
+        "eps": 1e-08,
+        "weight_decay": 0.01
+    },
+    "lr_scheduler": "MultiStepLR",
+    "lr_scheduler_params": {
+        "milestones": [
+            900000,
+            2700000,
+            5400000
+        ],
+        "gamma": 0.5,
+        "last_epoch": -1
+    },
+    "use_grad_scaler": false,
+    "allow_tf32": false,
+    "cudnn_enable": true,
+    "cudnn_deterministic": false,
+    "cudnn_benchmark": false,
+    "training_seed": 54321,
+    "model": "xtts",
+    "num_loader_workers": 8,
+    "num_eval_loader_workers": 0,
+    "use_noise_augment": false,
+    "audio": {
+        "sample_rate": 22050,
+        "output_sample_rate": 24000,
+        "dvae_sample_rate": 22050
+    },
+    "use_phonemes": false,
+    "phonemizer": null,
+    "phoneme_language": null,
+    "compute_input_seq_cache": false,
+    "text_cleaner": null,
+    "enable_eos_bos_chars": false,
+    "test_sentences_file": "",
+    "phoneme_cache_path": null,
+    "characters": null,
+    "add_blank": false,
+    "batch_group_size": 0,
+    "loss_masking": null,
+    "min_audio_len": 1,
+    "max_audio_len": Infinity,
+    "min_text_len": 1,
+    "max_text_len": Infinity,
+    "compute_f0": false,
+    "compute_energy": false,
+    "compute_linear_spec": false,
+    "precompute_num_workers": 0,
+    "start_by_longest": false,
+    "shuffle": false,
+    "drop_last": false,
+    "datasets": [
+        {
+            "formatter": "",
+            "dataset_name": "",
+            "path": "",
+            "meta_file_train": "",
+            "ignored_speakers": null,
+            "language": "",
+            "phonemizer": "",
+            "meta_file_val": "",
+            "meta_file_attn_mask": ""
+        }
+    ],
+    "test_sentences": [],
+    "eval_split_max_size": 256,
+    "eval_split_size": 0.01,
+    "use_speaker_weighted_sampler": false,
+    "speaker_weighted_sampler_alpha": 1.0,
+    "use_language_weighted_sampler": false,
+    "language_weighted_sampler_alpha": 1.0,
+    "use_length_weighted_sampler": false,
+    "length_weighted_sampler_alpha": 1.0,
+    "model_args": {
+        "gpt_batch_size": 1,
+        "enable_redaction": false,
+        "kv_cache": true,
+        "gpt_checkpoint": "",
+        "clvp_checkpoint": null,
+        "decoder_checkpoint": null,
+        "num_chars": 255,
+        "tokenizer_file": "checkpoints/XTTS_v2.0_original_model_files/vocab.json",
+        "gpt_max_audio_tokens": 605,
+        "gpt_max_text_tokens": 402,
+        "gpt_max_prompt_tokens": 70,
+        "gpt_layers": 30,
+        "gpt_n_model_channels": 1024,
+        "gpt_n_heads": 16,
+        "gpt_number_text_tokens": 8661,
+        "gpt_start_text_token": 261,
+        "gpt_stop_text_token": 0,
+        "gpt_num_audio_tokens": 1026,
+        "gpt_start_audio_token": 1024,
+        "gpt_stop_audio_token": 1025,
+        "gpt_code_stride_len": 1024,
+        "gpt_use_masking_gt_prompt_approach": true,
+        "gpt_use_perceiver_resampler": true,
+        "input_sample_rate": 22050,
+        "output_sample_rate": 24000,
+        "output_hop_length": 256,
+        "decoder_input_dim": 1024,
+        "d_vector_dim": 512,
+        "cond_d_vector_in_each_upsampling_layer": true,
+        "duration_const": 102400,
+        "min_conditioning_length": 11025,
+        "max_conditioning_length": 132300,
+        "gpt_loss_text_ce_weight": 0.01,
+        "gpt_loss_mel_ce_weight": 1.0,
+        "debug_loading_failures": false,
+        "max_wav_length": 330750,
+        "max_text_length": 400,
+        "mel_norm_file": "checkpoints/XTTS_v2.0_original_model_files/mel_stats.pth",
+        "dvae_checkpoint": "checkpoints/XTTS_v2.0_original_model_files/dvae.pth",
+        "xtts_checkpoint": "checkpoints/XTTS_v2.0_original_model_files/model.pth",
+        "vocoder": ""
+    },
+    "model_dir": null,
+    "languages": [
+        "en",
+        "es",
+        "fr",
+        "de",
+        "it",
+        "pt",
+        "pl",
+        "tr",
+        "ru",
+        "nl",
+        "cs",
+        "ar",
+        "zh-cn",
+        "hu",
+        "ko",
+        "ja",
+        "hi",
+        "te"
+    ],
+    "temperature": 0.75,
+    "length_penalty": 1.0,
+    "repetition_penalty": 5.0,
+    "top_k": 50,
+    "top_p": 0.85,
+    "num_gpt_outputs": 1,
+    "gpt_cond_len": 30,
+    "gpt_cond_chunk_len": 4,
+    "max_ref_len": 30,
+    "sound_norm_refs": false,
+    "optimizer_wd_only_on_weights": true,
+    "weighted_loss_attrs": null,
+    "weighted_loss_multipliers": null,
+    "github_branch": "* main"
+}
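
config.json above is the full Coqui trainer configuration saved with the fine-tuning run, with "te" (Telugu) added at the end of the language list and the fine-tuning hyperparameters recorded (10 epochs, batch size 8, lr 5e-06, save step 1000). A minimal sketch for loading it back, mirroring how inference.py below reads it (the path is a placeholder):

from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig()
config.load_json("config.json")  # placeholder path to the file above
print(config.epochs, config.lr, config.languages[-1])  # 10 5e-06 te
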
events.out.tfevents.1739298885.node2.188457.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ea042bd25d446dae787dfc01d9ec0249466eb72bea5d5a8f0eeb957501a8fd4
+size 331915
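
The events.out.tfevents.* file is the TensorBoard log for the run (the config sets dashboard_logger to "tensorboard"). Besides opening it in the TensorBoard UI, the logged scalars can be listed programmatically; a sketch assuming the tensorboard package is installed:

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

acc = EventAccumulator("events.out.tfevents.1739298885.node2.188457.0")
acc.Reload()                  # parse the event file
print(acc.Tags()["scalars"])  # names of the logged scalar series
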
inference.py
ADDED
@@ -0,0 +1,81 @@
+import torch
+import torchaudio
+from tqdm import tqdm
+from underthesea import sent_tokenize  # sentence splitter from the underthesea NLP toolkit
+import os
+
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer  # import the tokenizer
+
+# Device configuration
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# Model paths
+checkpoint_dir = "/export/home/vivian/svarah/XTTSv2-Finetuning-for-New-Languages/checkpoints/GPT_XTTS_FT-February-12-2025_12+04AM-8e59ec3"
+xtts_checkpoint = os.path.join(checkpoint_dir, "best_model.pth")
+xtts_config = os.path.join(checkpoint_dir, "config.json")
+xtts_vocab = "/export/home/vivian/svarah/XTTSv2-Finetuning-for-New-Languages/checkpoints/XTTS_v2.0_original_model_files/vocab.json"  # path to vocab file
+speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth")  # path to speaker file
+
+# Load the config
+config = XttsConfig()
+config.load_json(xtts_config)
+
+# Initialize the tokenizer manually (redundant here: load_checkpoint below
+# already builds the tokenizer from vocab_path)
+tokenizer = VoiceBpeTokenizer(xtts_vocab)
+
+# Initialize the model
+XTTS_MODEL = Xtts.init_from_config(config)
+
+# Load the checkpoint
+XTTS_MODEL.load_checkpoint(
+    config,
+    checkpoint_path=xtts_checkpoint,
+    checkpoint_dir=checkpoint_dir,  # explicitly provide checkpoint_dir
+    vocab_path=xtts_vocab,
+    speaker_file_path=speaker_file_path,  # explicitly provide the speaker file path
+    use_deepspeed=False,
+)
+XTTS_MODEL.to(device)
+
+print("Model loaded successfully!")
+
+# Inference: a long Telugu test passage (a short moral story) used as TTS input
+tts_text = "పల్నాడు ప్రాంతంలోని చిన్న గ్రామంలో అపరాజితుడు అనే బాలుడు తన తల్లిదండ్రులతో కలిసి ఉండేవాడు. అతని కుటుంబం గొప్ప ధనవంతం కాకపోయినా, అతనికి గొప్ప లక్ష్యసాధన కోరిక ఉండేది. అతనికి చదువుపై అపారమైన ఆసక్తి, కానీ కుటుంబ స్థితిగతుల వల్ల మంచి పాఠశాలకు వెళ్లే అవకాశం లేకపోయింది. అపరాజితుడు ఉదయాన్నే లేచి మేతెల్లో పనిచేసి తల్లికి సహాయపడుతూ, మిగిలిన సమయాన్ని తన చదువుకు అంకితం చేసేవాడు. ఊరిలోని చిన్న గ్రంథాలయంలో ఉన్న పుస్తకాలను చదివి కొత్త విషయాలు నేర్చుకునే ప్రయత్నం చేసేవాడు. అతని పట్టుదల చూసిన గ్రామంలోని ఓ పండితుడు, బాలుడా నీ ప్రయత్నం అమోఘం, నీ లక్ష్యాన్ని సాధించేందుకు నేను సహాయపడతాను అని చెప్పి, అతనికి కావాల్సిన పుస్తకాలను అందించేవాడు. కొన్నేళ్ల తర్వాత, అపరాజితుడు తన కృషి, పట్టుదలతో పరీక్షల్లో రాష్ట్ర స్థాయిలో ఉత్తమ ర్యాంకు సాధించి, ప్రభుత్వ స్కాలర్షిప్తో పెద్ద నగరానికి వెళ్లాడు. అక్కడ ఉన్నత చదువులు పూర్తి చేసి, ఒక గొప్ప శాస్త్రవేత్తగా ఎదిగాడు. తన ఊరిని మర్చిపోకుండా, అక్కడ ఒక పాఠశాలను స్థాపించి, చాలా మంది పేద పిల్లలకు ఉచిత విద్య అందించాడు. అతని విజయాన్ని చూసి గ్రామస్థులు గర్వంతో ఆనందంతో మన ఊరి బాలుడు ప్రపంచానికి వెలుగునిచ్చాడు అని ప్రశంసించారు. కష్టపడే మనసు, పట్టుదల ఉంటే, ఏ లక్ష్యాన్నైనా సాధించవచ్చు."
+speaker_audio_file = "/export/home/vivian/svarah/telugu_tts/Telugu/wavs/4503599627574654_chunk_10_enhanced.wav"
+
+lang = "te"
+
+# Extract conditioning latents and the speaker embedding from the reference audio
+gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+    audio_path=speaker_audio_file,
+    gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+    max_ref_length=XTTS_MODEL.config.max_ref_len,
+    sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+)
+
+# Split the passage into sentences and synthesize chunk by chunk
+tts_texts = sent_tokenize(tts_text)
+
+wav_chunks = []
+for text in tqdm(tts_texts):
+    wav_chunk = XTTS_MODEL.inference(
+        text=text,
+        language=lang,
+        gpt_cond_latent=gpt_cond_latent,
+        speaker_embedding=speaker_embedding,
+        temperature=0.1,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
+        top_k=10,
+        top_p=0.3,
+    )
+    wav_chunks.append(torch.tensor(wav_chunk["wav"]))
+
+out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0).cpu()
+
+# Save the concatenated audio
+output_audio_path = "/export/home/vivian/svarah/XTTSv2-Finetuning-for-New-Languages/checkpoints/output_audio2.wav"
+torchaudio.save(output_audio_path, out_wav, sample_rate=24000)
+
+print(f"Audio saved to {output_audio_path}")
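
inference.py writes the concatenated waveform to disk with torchaudio. For the notebook playback the original comment hinted at, an optional snippet (assumes IPython, i.e. a Jupyter environment; out_wav is the (1, n_samples) tensor built above):

from IPython.display import Audio

# Play the 24 kHz synthesized waveform inline in a notebook.
Audio(out_wav.squeeze(0).numpy(), rate=24000)
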
karan_narration1.wav
ADDED
Binary file (60.6 kB).
train_gpt_xtts.py
ADDED
@@ -0,0 +1,237 @@
+import os
+import gc
+
+from trainer import Trainer, TrainerArgs
+
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
+from TTS.utils.manage import ModelManager
+
+import argparse
+
+
+def create_xtts_trainer_parser():
+    parser = argparse.ArgumentParser(description="Arguments for XTTS Trainer")
+
+    parser.add_argument("--output_path", type=str, required=True,
+                        help="Path where pretrained model files and checkpoints are stored")
+    parser.add_argument("--metadatas", nargs="+", type=str, required=True,
+                        help="One or more 'train_csv_path,eval_csv_path,language' triples")
+    parser.add_argument("--num_epochs", type=int, default=1,
+                        help="Number of epochs")
+    parser.add_argument("--batch_size", type=int, default=1,
+                        help="Mini-batch size")
+    parser.add_argument("--grad_acumm", type=int, default=1,
+                        help="Gradient accumulation steps")
+    parser.add_argument("--max_audio_length", type=int, default=255995,
+                        help="Max audio length in samples")
+    parser.add_argument("--max_text_length", type=int, default=200,
+                        help="Max text length in characters")
+    parser.add_argument("--weight_decay", type=float, default=1e-2,
+                        help="Weight decay")
+    parser.add_argument("--lr", type=float, default=5e-6,
+                        help="Learning rate")
+    parser.add_argument("--save_step", type=int, default=5000,
+                        help="Checkpoint save interval in steps")
+
+    return parser
+
+
+def train_gpt(metadatas, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, max_text_length, lr, weight_decay, save_step):
+    # Logging parameters
+    RUN_NAME = "GPT_XTTS_FT"
+    PROJECT_NAME = "XTTS_trainer"
+    DASHBOARD_LOGGER = "tensorboard"
+    LOGGER_URI = None
+
+    # Set the path where the checkpoints will be saved. Default: ./run/training/
+    # OUT_PATH = os.path.join(output_path, "run", "training")
+    OUT_PATH = output_path
+
+    # Training parameters
+    OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # for multi-GPU training, set this to False
+    START_WITH_EVAL = False  # if True, training starts with an evaluation pass
+    BATCH_SIZE = batch_size  # batch size
+    GRAD_ACUMM_STEPS = grad_acumm  # gradient accumulation steps
+
+    # Define the datasets to fine-tune on.
+    DATASETS_CONFIG_LIST = []
+    for metadata in metadatas:
+        train_csv, eval_csv, language = metadata.split(",")
+        print(train_csv, eval_csv, language)
+
+        config_dataset = BaseDatasetConfig(
+            formatter="coqui",
+            dataset_name="ft_dataset",
+            path=os.path.dirname(train_csv),
+            meta_file_train=os.path.basename(train_csv),
+            meta_file_val=os.path.basename(eval_csv),
+            language=language,
+        )
+
+        DATASETS_CONFIG_LIST.append(config_dataset)
+
+    # Define the path where the XTTS v2.0.1 files will be downloaded
+    CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
+    os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
+
+    # DVAE files
+    DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
+    MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
+
+    # Set the paths to the downloaded files
+    DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
+    MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))
+
+    # Download the DVAE files if needed
+    if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
+        print(" > Downloading DVAE files!")
+        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+
+    # Download the XTTS v2.0 checkpoint if needed
+    TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
+    XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
+    XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json"
+
+    # XTTS transfer-learning parameters: the paths of the XTTS checkpoint to fine-tune.
+    TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK))  # vocab.json file
+    XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK))  # model.pth file
+    XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK))  # config.json file
+
+    # Download the XTTS v2.0 files if needed
+    if not os.path.isfile(TOKENIZER_FILE):
+        print(" > Downloading XTTS v2.0 tokenizer!")
+        ModelManager._download_model_files(
+            [TOKENIZER_FILE_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+        )
+    if not os.path.isfile(XTTS_CHECKPOINT):
+        print(" > Downloading XTTS v2.0 checkpoint!")
+        ModelManager._download_model_files(
+            [XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+        )
+    if not os.path.isfile(XTTS_CONFIG_FILE):
+        print(" > Downloading XTTS v2.0 config!")
+        ModelManager._download_model_files(
+            [XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+        )
+
+    # Init args and config
+    model_args = GPTArgs(
+        max_conditioning_length=132300,  # 6 secs
+        min_conditioning_length=11025,  # 0.5 secs
+        debug_loading_failures=False,
+        max_wav_length=max_audio_length,  # ~11.6 seconds at the default of 255995 samples
+        max_text_length=max_text_length,
+        mel_norm_file=MEL_NORM_FILE,
+        dvae_checkpoint=DVAE_CHECKPOINT,
+        xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model to fine-tune
+        tokenizer_file=TOKENIZER_FILE,
+        gpt_num_audio_tokens=1026,
+        gpt_start_audio_token=1024,
+        gpt_stop_audio_token=1025,
+        gpt_use_masking_gt_prompt_approach=True,
+        gpt_use_perceiver_resampler=True,
+    )
+    # Define the audio config
+    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
+
+    # Training parameters config
+    config = GPTTrainerConfig()
+    config.load_json(XTTS_CONFIG_FILE)
+
+    config.epochs = num_epochs
+    config.output_path = OUT_PATH
+    config.model_args = model_args
+    config.run_name = RUN_NAME
+    config.project_name = PROJECT_NAME
+    config.run_description = """
+        GPT XTTS training
+        """
+    config.dashboard_logger = DASHBOARD_LOGGER
+    config.logger_uri = LOGGER_URI
+    config.audio = audio_config
+    config.batch_size = BATCH_SIZE
+    config.num_loader_workers = 8
+    config.eval_split_max_size = 256
+    config.print_step = 50
+    config.plot_step = 100
+    config.log_model_step = 100
+    config.save_step = save_step
+    config.save_n_checkpoints = 1
+    config.save_checkpoints = True
+    config.print_eval = False
+    config.optimizer = "AdamW"
+    config.optimizer_wd_only_on_weights = OPTIMIZER_WD_ONLY_ON_WEIGHTS
+    config.optimizer_params = {"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": weight_decay}
+    config.lr = lr
+    config.lr_scheduler = "MultiStepLR"
+    config.lr_scheduler_params = {"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}
+    config.test_sentences = []
+
+    # Init the model from the config
+    model = GPTTrainer.init_from_config(config)
+
+    # Load training samples
+    train_samples, eval_samples = load_tts_samples(
+        DATASETS_CONFIG_LIST,
+        eval_split=True,
+        eval_split_max_size=config.eval_split_max_size,
+        eval_split_size=config.eval_split_size,
+    )
+
+    # Init the trainer and 🚀
+    trainer = Trainer(
+        TrainerArgs(
+            restore_path=None,  # the XTTS checkpoint is restored via the xtts_checkpoint key, so there is no need to restore it with the Trainer restore_path parameter
+            skip_train_epoch=False,
+            start_with_eval=START_WITH_EVAL,
+            grad_accum_steps=GRAD_ACUMM_STEPS,
+        ),
+        config,
+        output_path=os.path.join(output_path, "run", "training"),
+        model=model,
+        train_samples=train_samples,
+        eval_samples=eval_samples,
+    )
+    trainer.fit()
+
+    # Use the audio file with the longest text as the speaker reference
+    samples_len = [len(item["text"].split(" ")) for item in train_samples]
+    longest_text_idx = samples_len.index(max(samples_len))
+    speaker_ref = train_samples[longest_text_idx]["audio_file"]
+
+    trainer_out_path = trainer.output_path
+
+    # Deallocate VRAM and RAM
+    del model, trainer, train_samples, eval_samples
+    gc.collect()
+
+    return trainer_out_path
+
+
+if __name__ == "__main__":
+    parser = create_xtts_trainer_parser()
+    args = parser.parse_args()
+
+    trainer_out_path = train_gpt(
+        metadatas=args.metadatas,
+        output_path=args.output_path,
+        num_epochs=args.num_epochs,
+        batch_size=args.batch_size,
+        grad_acumm=args.grad_acumm,
+        weight_decay=args.weight_decay,
+        lr=args.lr,
+        max_text_length=args.max_text_length,
+        max_audio_length=args.max_audio_length,
+        save_step=args.save_step,
+    )
+
+    print(f"Checkpoint saved in dir: {trainer_out_path}")
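
train_gpt is normally driven through the argparse CLI above. For reference, a hypothetical direct call whose hyperparameters mirror the saved config.json (10 epochs, batch size 8, lr 5e-06, weight decay 0.01, save step 1000, max_wav_length 330750, max_text_length 400); the CSV paths are placeholders:

trainer_out = train_gpt(
    metadatas=["data/metadata_train.csv,data/metadata_eval.csv,te"],  # placeholder CSVs
    num_epochs=10,
    batch_size=8,
    grad_acumm=1,
    output_path="checkpoints/",
    max_audio_length=330750,
    max_text_length=400,
    lr=5e-6,
    weight_decay=1e-2,
    save_step=1000,
)
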
trainer_0_log.txt
ADDED
The diff for this file is too large to render.