import gradio as gr
import torch
import soundfile as sf
from snac import SNAC
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use a CUDA GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def find_last_instance_of_separator(lst, element=50258):
    # Return the index of the last occurrence of the separator token.
    reversed_list = lst[::-1]
    try:
        reversed_index = reversed_list.index(element)
        return len(lst) - 1 - reversed_index
    except ValueError:
        raise ValueError(f"Separator token {element} not found in sequence")
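
# e.g. find_last_instance_of_separator([50258, 11, 22, 50258, 33]) == 3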

def reconstruct_tensors(flattened_output):
    """Rebuild the per-codebook SNAC code tensors from a flattened token list."""
    def count_elements_between_hashes(lst):
        # Number of codes per frame, measured between two consecutive separators.
        try:
            first_index = lst.index(50258)
            second_index = lst.index(50258, first_index + 1)
            return second_index - first_index - 1
        except ValueError:
            raise ValueError("List does not contain two separator tokens (50258)")

    def remove_elements_before_hash(flattened_list):
        # Drop any tokens before the first separator.
        try:
            first_hash_index = flattened_list.index(50258)
            return flattened_list[first_hash_index:]
        except ValueError:
            raise ValueError("List does not contain the separator token (50258)")

    def list_to_torch_tensor(tensor1):
        # Convert a list of codes into a (1, n) tensor batch.
        tensor = torch.tensor(tensor1)
        tensor = tensor.unsqueeze(0)
        return tensor

    # Keep only the span from the first separator up to the last one.
    flattened_output = remove_elements_before_hash(flattened_output)
    last_index = find_last_instance_of_separator(flattened_output)
    flattened_output = flattened_output[:last_index]

    codes = []
    tensor1 = []
    tensor2 = []
    tensor3 = []
    tensor4 = []

    n_tensors = count_elements_between_hashes(flattened_output)
    if n_tensors == 7:
        # Three codebooks (1 + 2 + 4 codes per frame); each 8-token frame is
        # [sep, c1, c2, c3, c3, c2, c3, c3].
        for i in range(0, len(flattened_output), 8):
            tensor1.append(flattened_output[i+1])
            tensor2.append(flattened_output[i+2])
            tensor3.append(flattened_output[i+3])
            tensor3.append(flattened_output[i+4])
            tensor2.append(flattened_output[i+5])
            tensor3.append(flattened_output[i+6])
            tensor3.append(flattened_output[i+7])
        codes = [list_to_torch_tensor(tensor1).to(device),
                 list_to_torch_tensor(tensor2).to(device),
                 list_to_torch_tensor(tensor3).to(device)]

    if n_tensors == 15:
        # Four codebooks (1 + 2 + 4 + 8 codes per frame); each 16-token frame is
        # [sep, c1, c2, c3, c4, c4, c3, c4, c4, c2, c3, c4, c4, c3, c4, c4].
        for i in range(0, len(flattened_output), 16):
            tensor1.append(flattened_output[i+1])
            tensor2.append(flattened_output[i+2])
            tensor3.append(flattened_output[i+3])
            tensor4.append(flattened_output[i+4])
            tensor4.append(flattened_output[i+5])
            tensor3.append(flattened_output[i+6])
            tensor4.append(flattened_output[i+7])
            tensor4.append(flattened_output[i+8])
            tensor2.append(flattened_output[i+9])
            tensor3.append(flattened_output[i+10])
            tensor4.append(flattened_output[i+11])
            tensor4.append(flattened_output[i+12])
            tensor3.append(flattened_output[i+13])
            tensor4.append(flattened_output[i+14])
            tensor4.append(flattened_output[i+15])
        codes = [list_to_torch_tensor(tensor1).to(device),
                 list_to_torch_tensor(tensor2).to(device),
                 list_to_torch_tensor(tensor3).to(device),
                 list_to_torch_tensor(tensor4).to(device)]

    return codes
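
# Note: the 24 kHz SNAC model loaded below has three codebooks, so generated
# streams should hit the n_tensors == 7 branch: between separator tokens, each
# frame carries one coarse code, two mid codes, and four fine codes.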

def load_model():
    # Both from_pretrained calls download and cache the weights on first use.
    tokenizer = AutoTokenizer.from_pretrained("Lwasinam/voicera-jenny-finetune")
    model = AutoModelForCausalLM.from_pretrained("Lwasinam/voicera-jenny-finetune").to(device)
    # SNAC turns the generated audio codes back into a waveform.
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
    return model, tokenizer, snac_model

def SpeechDecoder(codes, snac_model):
    # Strip the batch dimension, rebuild the SNAC code tensors, decode them
    # to a waveform, and save the result as a 24 kHz WAV file.
    codes = codes.squeeze(0).tolist()
    reconstructed_codes = reconstruct_tensors(codes)
    audio_hat = snac_model.decode(reconstructed_codes)
    audio_path = "reconstructed_audio.wav"
    sf.write(audio_path, audio_hat.squeeze().cpu().detach().numpy(), 24000)
    return audio_path

def generate_audio(text, tokenizer, model, snac_model):
    # Tokenize the prompt, sample audio tokens from the model, then decode
    # them to a WAV file. num_beams with do_sample=True gives beam sampling.
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors='pt').to(device)
        output_codes = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                      max_length=1024, num_beams=5, top_p=0.95, temperature=0.8,
                                      do_sample=True, repetition_penalty=2.0)
        audio_path = SpeechDecoder(output_codes, snac_model)
    return audio_path
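
# Illustrative one-off use outside the Gradio UI (the prompt is an example):
#   mdl, tok, snac = load_model()
#   wav_path = generate_audio("Hello from Voicera.", tok, mdl, snac)
#   # wav_path -> "reconstructed_audio.wav"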

# Load the models once at startup rather than on every request.
model, tokenizer, snac_model = load_model()

def main(text):
    audio_path = generate_audio(text, tokenizer, model, snac_model)
    return audio_path

# Define the Gradio interface
iface = gr.Interface(
    fn=main,
    inputs="textbox",
    outputs="audio",
    title="Voicera TTS",
    description="Generate speech from text using the Voicera TTS model."
)

if __name__ == "__main__":
    iface.launch()
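
# To run locally (assuming this file is saved as app.py):
#   python app.py
# Gradio serves the interface at http://127.0.0.1:7860 by default.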