File size: 4,104 Bytes
943290c
0ea1315
 
 
 
 
 
943290c
0ea1315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c411cd
0ea1315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efcf731
 
 
 
 
 
0ea1315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c411cd
 
0ea1315
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import spaces
import gradio as gr
import torch
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM
from xcodec2.modeling_xcodec2 import XCodec2Model
import tempfile

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

####################
#  Load models once at import time (shared by every text2speech call)
####################
# NOTE: despite the variable name, this is the 1B two-speaker (kore/puck)
# fine-tune, not a 3B checkpoint.
llasa_3b = "HKUSTAudio/Llasa-1B-two-speakers-kore-puck"
print("Loading tokenizer & model ...")
tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
model = AutoModelForCausalLM.from_pretrained(llasa_3b)
# Put the LLM in inference mode and move its weights to `device`.
model.eval().to(device)

print("Loading XCodec2Model ...")
# XCodec2 decodes the LLM's discrete speech-token ids back into a waveform.
codec_model_path = "HKUSTAudio/xcodec2"
Codec_model = XCodec2Model.from_pretrained(codec_model_path)
Codec_model.eval().to(device)

print("Models loaded.")

####################
#  Inference functions
####################
def extract_speech_ids(speech_tokens_str):
    """
    Convert decoded speech tokens such as ``<|s_23456|>`` to integer ids.

    Args:
        speech_tokens_str: Iterable of decoded token strings, one string per
            token (as produced by ``tokenizer.batch_decode`` on the ids).

    Returns:
        List of ints, one per well-formed ``<|s_N|>`` token, in input order.
        Every other token -- including one that has the ``<|s_`` / ``|>``
        wrapper but no valid integer inside -- is reported via ``print`` and
        skipped rather than aborting the whole conversion.
    """
    speech_ids = []
    for token_str in speech_tokens_str:
        speech_id = _parse_speech_id(token_str)
        if speech_id is None:
            print(f"Unexpected token: {token_str}")
        else:
            speech_ids.append(speech_id)
    return speech_ids


def _parse_speech_id(token_str):
    """Return N for a ``<|s_N|>`` token string, or None if it is not one."""
    if not (token_str.startswith("<|s_") and token_str.endswith("|>")):
        return None
    try:
        return int(token_str[4:-2])
    except ValueError:
        # e.g. "<|s_|>" or "<|s_abc|>": the wrapper matches but the payload
        # is not an integer, so treat it like any other unexpected token.
        return None
@spaces.GPU
def text2speech(input_text, speaker_choice):
    """
    Synthesize speech for ``input_text`` in the voice of ``speaker_choice``.

    The text is wrapped in the text-understanding markers. The assistant turn
    is pre-filled with the speaker tag and ``<|SPEECH_GENERATION_START|>``,
    and the LLM continues it with speech tokens (``<|s_N|>``). Those tokens
    are decoded to a waveform by XCodec2 and written to a WAV file.

    Args:
        input_text: Text to synthesize.
        speaker_choice: Speaker name inserted into the prompt
            (the UI offers "puck" and "kore").

    Returns:
        Path to a temporary ``.wav`` file written at 16 kHz. The file is
        created with ``delete=False``; this function does not remove it.

    Raises:
        gr.Error: if the input text is blank or the model produced no
            usable speech tokens.
    """
    if not input_text or not input_text.strip():
        raise gr.Error("Please enter some text to synthesize.")

    with torch.no_grad():
        # Wrap the input text in the prompt marker tokens.
        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": f"Speaker {speaker_choice} <|SPEECH_GENERATION_START|>"}
        ]

        # continue_final_message=True leaves the assistant turn open, so the
        # model continues right after <|SPEECH_GENERATION_START|>.
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        ).to(device)

        # Token id that ends speech generation.
        speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")

        outputs = model.generate(
            input_ids,
            max_length=2048,  # We trained our model with a max length of 2048
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=0.95,  # Adjusts the diversity of generated content
            temperature=0.9,  # Controls randomness in output
            repetition_penalty=1.2,
        )

        # Keep only the newly generated tokens (drop the prompt). Strip the
        # trailing end marker only when it is actually present: if generation
        # stopped at max_length, the last token is real speech and is kept.
        generated_ids = outputs[0][input_ids.shape[1]:]
        if generated_ids.numel() > 0 and generated_ids[-1].item() == speech_end_id:
            generated_ids = generated_ids[:-1]
        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # "<|s_23456|>" -> 23456
        speech_ids = extract_speech_ids(speech_tokens_str)
        if not speech_ids:
            raise gr.Error("The model did not generate any speech tokens; please try again.")
        speech_tokens_int = torch.tensor(speech_ids).to(device).unsqueeze(0).unsqueeze(0)

        # Decode the token ids to a waveform with XCodec2.
        gen_wav = Codec_model.decode_code(speech_tokens_int)  # [batch, channels, samples]

    audio = gen_wav[0, 0, :].cpu().numpy()
    sample_rate = 16000

    # Create the temp file, close it, then write by name: an open
    # NamedTemporaryFile cannot be reopened by sf.write on Windows.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        audio_path = tmpfile.name
    sf.write(audio_path, audio, sample_rate)

    return audio_path

####################
#  Gradio interface
####################
# Speaker names offered in the dropdown; the selected one is passed to text2speech.
speaker_choices = ["puck", "kore"]

# Build the UI components first, then wire them into a single-function Interface.
_text_input = gr.Textbox(label="Enter text", lines=5)
_speaker_input = gr.Dropdown(choices=speaker_choices, label="Select Speaker", value="puck")
_audio_output = gr.Audio(label="Generated Audio", type="filepath")

demo = gr.Interface(
    fn=text2speech,
    inputs=[_text_input, _speaker_input],
    outputs=_audio_output,
    title="Llasa-1B TTS finetuned using shb777/gemini-flash-2.0-speech",
    description=(
        "Input a piece of text in English, select a speaker (puck or kore), and click to generate speech.\n"
        "Model: HKUSTAudio/Llasa-1B-two-speakers-kore-puck + HKUSTAudio/xcodec2"
    ),
)

if __name__ == "__main__":
    # Start the web UI; share=True also asks Gradio for a public share link.
    demo.launch(share=True)