File size: 3,127 Bytes
aa7cb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from xtts_fine_tune.xtts_v2_data_formattor import Data_Pipeline
from xtts_fine_tune.xtts_v2_model_utils import xtts_v2_Model
import sys
import time
#!/usr/bin/env python


def Train_XTTS_V2(audio_directory: str, num_epochs: int, batch_size: int, grad_acumm: int, output_path: str, max_audio_length: int, language: str):
    """
    Train the XTTS V2 model with the given parameters.

    Builds a ``Data_Pipeline`` for ``audio_directory``, waits until the
    combined wav length passes the length check, formats the data, and then
    runs a single ``xtts_v2_Model`` training pass.

    While the length check fails, the function prints a message, sleeps
    20 seconds, rebuilds the ``Data_Pipeline`` and checks again. This blocks
    the caller for as long as the check keeps failing.

    Args:
        audio_directory (str): Path to the directory containing audio files.
            Only "/" is treated as a path separator.
            Example: "path/to/audio_files/"
        num_epochs (int): Number of training epochs.
            Example: 50
        batch_size (int): Size of each training batch.
            Example: 16
        grad_acumm (int): Gradient accumulation steps.
            Example: 4
        output_path (str): Path to save the trained model outputs.
            Example: "path/to/output/"
        max_audio_length (int): Threshold compared against the combined wav
            length reported by the data pipeline; the same value is also
            forwarded to ``xtts_v2_Model``.
            Example: 3600
        language (str): Language of the audio files, either 'en' for English or 'es' for Spanish.
            Example: "en"

    Returns:
        tuple: A tuple containing paths to the configuration file, vocabulary file, fine-tuned XTTS checkpoint, and speaker wav file.
            Example: ("config_path.json", "vocab_path.json", "checkpoint.pth", "speaker.wav")

    Example usage:
        config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = Train_XTTS_V2(
            "path/to/audio_files/", 50, 16, 4, "path/to/output/", 3600, "en"
        )
    """
    # Poll in a loop instead of recursing: the recursive version discarded the
    # nested call's result (so training ran a second time once the nested call
    # returned) and added a stack frame for every 20-second retry.
    # NOTE(review): the check waits while the combined length is GREATER than
    # max_audio_length, yet the message says the audio is "not long enough" --
    # confirm which of the two reflects the intended behaviour.
    data_pipeline = Data_Pipeline(audio_directory, language)
    while data_pipeline.get_combined_wav_lengths() > max_audio_length:
        print("The audio is not long enough to be fine tuned. Waiting....")
        time.sleep(20)
        # Rebuild the pipeline so the next check sees the directory's current state.
        data_pipeline = Data_Pipeline(audio_directory, language)

    # Drop the last "/"-separated component of the path. With a trailing slash
    # ("a/b/") this only strips the slash and yields "a/b"; without one ("a/b")
    # it yields the parent "a".
    audio_directory_parent = audio_directory.rpartition("/")[0]

    _, train_meta, eval_meta = data_pipeline.data_formatter(audio_directory_parent)
    xtts_v2 = xtts_v2_Model(train_meta, eval_meta, num_epochs, batch_size, grad_acumm, output_path, max_audio_length, language)
    _, config_path, vocab_path, ft_xtts_checkpoint, speaker_wav = xtts_v2.train_model()

    return config_path, vocab_path, ft_xtts_checkpoint, speaker_wav

if __name__ == "__main__":
    # Command-line usage (positional arguments):
    #   <audio_directory> <num_epochs> <batch_size> <grad_acumm>
    #   <output_path> <max_audio_length> <language>
    cli_args = sys.argv

    # Arguments are read and converted left to right, matching the order of
    # the Train_XTTS_V2 parameters; the numeric ones are parsed as integers.
    training_artifacts = Train_XTTS_V2(
        cli_args[1],
        int(cli_args[2]),
        int(cli_args[3]),
        int(cli_args[4]),
        cli_args[5],
        int(cli_args[6]),
        cli_args[7],
    )

    # Report the config path, vocab path, fine-tuned checkpoint and speaker
    # wav returned by training, space-separated on one line.
    print(*training_artifacts)