import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
import tempfile
import os
import numpy as np
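# LLaSA-1B multilingual model: autoregressively generates discrete speech tokens from text.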
llasa_1b = "HKUSTAudio/Llasa-1B-Multilingual"
tokenizer = AutoTokenizer.from_pretrained(llasa_1b, token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
llasa_1b, trust_remote_code=True, device_map="cuda", token=os.getenv("HF_TOKEN")
)
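# XCodec2 neural codec: encodes reference audio into speech tokens and decodes generated tokens back to a 16 kHz waveform.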
model_path = "srinivasbilla/xcodec2"
Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().cuda()
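# Whisper transcribes the reference audio whenever no transcript is supplied.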
whisper_turbo_pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device="cuda",
)
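# Predefined reference speakers: each entry pairs a prompt audio file with its transcript and a short description shown in the dropdown.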
SPEAKERS = {
"Male 1": {
"path": "speakers/deep_speaker.mp3",
"transcript": "Das große Tor von Minas Tirith brach erst, nachdem er die Ramme eingesetzt hatte.",
"description": "Eine tiefe epische Männerstimme.",
},
"Male 2": {
"path": "speakers/male_austrian_accent.mp3",
"transcript": "Man kann sich auch leichter vorstellen, wie schwierig es ist, dass man Entscheidungen trifft, die allen passen.",
"description": "Eine männliche Stimme mit österreicherischem Akzent.",
},
"Male 3": {
"path": "speakers/male_energic.mp3",
"transcript": "Wo keine Infrastruktur, da auch keine Ansiedlung von IT-Unternehmen und deren Beschäftigten bzw. dem geeigneten Fachkräftenachwuchs. Kann man diese Rechnung so einfach aufmachen, wie es es tatsächlich um deren regionale Verteilung beschäftigt?",
"description": "Eine männliche energische Stimme",
},
"Male 4": {
"path": "speakers/schneller_speaker.mp3",
"transcript": "Genau, wenn wir alle Dächer voll machen, also alle Dächer von Einfamilienhäusern, alleine mit den Einfamilienhäusern können wir 20 Prozent des heutigen Strombedarfs decken.",
"description": "Eine männliche Spreche mit schnellerem Tempo.",
},
"Female 1": {
"path": "speakers/female_standard.mp3",
"transcript": "Es wird ein Beispiel für ein barrierearmes Layout gegeben, sowie Tipps und ein Verweis auf eine Checkliste, die hilft, Barrierearmut in den eigenen Materialien zu prüfen bzw. umzusetzen.",
"description": "Eine weibliche Stimme.",
},
"Female 2": {
"path": "speakers/female_energic.mp3",
"transcript": "Dunkel flog weiter durch das Wald. Er sah die Sterne am Phaneten an sich vorbeiziehen und fühlte sich frei und glücklich.",
"description": "Eine weibliche Erzähler-Stimme.",
},
"Female 3": {
"path": "speakers/austrian_accent.mp3",
"transcript": "Die politische Europäische Union war geboren, verbrieft im Vertrag von Maastricht. Ab diesem Zeitpunkt bestehen zwei Vertragswerke.",
"description": "Eine weibliche Stimme mit österreicherischem Akzent.",
},
"Special 1": {
"path": "speakers/low_audio.mp3",
"transcript": "Druckplatten und Lasersensoren, um sicherzugehen, dass er auch da drin ist und",
"description": "Eine männliche Stimme mit schlechter Audioqualität als Effekt.",
},
}
def preview_speaker(display_name):
"""Returns the audio and transcript for preview"""
speaker_name = speaker_display_dict[display_name]
if speaker_name in SPEAKERS:
waveform, sample_rate = torchaudio.load(SPEAKERS[speaker_name]["path"])
return (sample_rate, waveform[0].numpy()), SPEAKERS[speaker_name]["transcript"]
return None, ""
def normalize_audio(waveform: torch.Tensor, target_db: float = -20) -> torch.Tensor:
"""
    Normalize audio volume toward target_db (gain limited to ±3 dB), then peak-normalize.
Args:
waveform (torch.Tensor): Input audio waveform
target_db (float): Target dB level (default: -20)
Returns:
torch.Tensor: Normalized audio waveform
"""
# Calculate current dB
eps = 1e-10
current_db = 20 * torch.log10(torch.max(torch.abs(waveform)) + eps)
# Calculate required gain
gain_db = target_db - current_db
# Limit gain to -3 to 3 dB range
gain_db = torch.clamp(gain_db, min=-3, max=3)
# Apply gain
gain_factor = 10 ** (gain_db / 20)
normalized = waveform * gain_factor
# Final peak normalization
max_amplitude = torch.max(torch.abs(normalized))
if max_amplitude > 0:
normalized = normalized / max_amplitude
return normalized
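# Convert integer codec IDs to speech-token strings, e.g. [12345] -> ["<|s_12345|>"].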
def ids_to_speech_tokens(speech_ids):
speech_tokens_str = []
for speech_id in speech_ids:
speech_tokens_str.append(f"<|s_{speech_id}|>")
return speech_tokens_str
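# Inverse of ids_to_speech_tokens, e.g. ["<|s_23456|>"] -> [23456]; unexpected tokens are skipped with a warning.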
def extract_speech_ids(speech_tokens_str):
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith("<|s_") and token_str.endswith("|>"):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
@spaces.GPU(duration=30)
@torch.inference_mode()
def infer_with_speaker(
display_name,
target_text,
temp,
top_p_val,
min_new_tokens,
do_sample,
progress=gr.Progress(),
):
"""Modified infer function that uses predefined speaker"""
speaker_name = speaker_display_dict[display_name] # Get actual speaker name
if speaker_name not in SPEAKERS:
return None, "Invalid speaker selected"
return infer(
SPEAKERS[speaker_name]["path"],
target_text,
temp,
top_p_val,
min_new_tokens,
do_sample,
SPEAKERS[speaker_name]["transcript"], # Pass the predefined transcript
progress,
)
@spaces.GPU(duration=30)
@torch.inference_mode()
def gradio_infer(*args, **kwargs):
return infer(*args, **kwargs)
def infer(
sample_audio_path,
target_text,
temp,
top_p_val,
min_new_tokens,
do_sample,
transcribed_text=None,
progress=gr.Progress(),
):
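    """Clone the voice in sample_audio_path and speak target_text with it.

    Returns a (sample_rate, waveform) tuple for Gradio together with the raw
    generated speech-token string.
    """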
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
progress(0, "Loading and trimming audio...")
waveform, sample_rate = torchaudio.load(sample_audio_path)
waveform = normalize_audio(waveform)
if len(waveform[0]) / sample_rate > 15:
gr.Warning("Trimming audio to first 15secs.")
waveform = waveform[:, : sample_rate * 15]
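        # Append 0.5 s of trailing silence to the reference audio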
waveform = torch.nn.functional.pad(
waveform, (0, int(sample_rate * 0.5)), "constant", 0
)
# Check if the audio is stereo (i.e., has more than one channel)
if waveform.size(0) > 1:
# Convert stereo to mono by averaging the channels
waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
else:
# If already mono, just use the original waveform
waveform_mono = waveform
prompt_wav = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)(waveform_mono)
if transcribed_text is None:
progress(0.3, "Transcribing audio...")
prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())["text"].strip()
else:
prompt_text = transcribed_text
progress(0.5, "Transcribed! Generating speech...")
        if len(target_text) == 0:
            return None, ""
        elif len(target_text) > 500:
            gr.Warning("Text is too long. Please keep it under 500 characters.")
            target_text = target_text[:500]
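        # Prepend the reference transcript so target_text is generated as a continuation of the prompt speech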
input_text = prompt_text + " " + target_text
# TTS start!
with torch.no_grad():
# Encode the prompt wav
vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
vq_code_prompt = vq_code_prompt[0, 0, :]
# Convert int 12345 to token <|s_12345|>
speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
formatted_text = (
f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
)
# Tokenize the text and the speech prefix
chat = [
{
"role": "user",
"content": "Convert the text to speech:" + formatted_text,
},
{
"role": "assistant",
"content": "<|SPEECH_GENERATION_START|>"
+ "".join(speech_ids_prefix),
},
]
input_ids = tokenizer.apply_chat_template(
chat,
tokenize=True,
return_tensors="pt",
continue_final_message=True,
)
input_ids = input_ids.to("cuda")
speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
# Generate the speech autoregressively
outputs = model.generate(
input_ids,
max_length=2048, # We trained our model with a max length of 2048
eos_token_id=speech_end_id,
do_sample=do_sample,
top_p=top_p_val,
temperature=temp,
min_new_tokens=min_new_tokens,
)
# Extract the speech tokens
generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix) : -1]
speech_tokens = tokenizer.batch_decode(
generated_ids, skip_special_tokens=False
)
raw_output = " ".join(speech_tokens) # Capture raw tokens
speech_tokens = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True
)
# Convert token <|s_23456|> to int 23456
speech_tokens = extract_speech_ids(speech_tokens)
speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
# Decode the speech tokens to speech waveform
gen_wav = Codec_model.decode_code(speech_tokens)
            # Keep only the generated portion (drop the reconstructed prompt audio)
gen_wav = gen_wav[:, :, prompt_wav.shape[1] :]
progress(1, "Synthesized!")
return (
16000,
gen_wav[0, 0, :].cpu().numpy(),
), raw_output # Return both audio and raw tokens
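# Gradio tab: zero-shot voice cloning from user-supplied reference audio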
with gr.Blocks() as app_tts:
gr.Markdown("# Zero Shot Voice Clone TTS")
with gr.Accordion("Model Settings", open=False):
temperature = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.8,
step=0.1,
label="Temperature",
info="Higher values = more random/creative output",
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=1.0,
step=0.1,
label="Top P",
info="Nucleus sampling threshold",
)
min_new_tokens = gr.Slider(
minimum=0,
maximum=128,
value=3,
step=1,
label="Min Length",
info="If the model just produces a click you can force it to create longer generations.",
)
do_sample = gr.Checkbox(
label="Sample", value=True, info="Sample from the distribution"
)
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
generate_btn = gr.Button("Synthesize", variant="primary")
audio_output = gr.Audio(label="Synthesized Audio")
    raw_output_display = gr.Textbox(
        label="Raw Model Output", interactive=False
    )  # Shows the raw generated speech tokens for debugging
generate_btn.click(
lambda *args: gradio_infer(*args, transcribed_text=None),
inputs=[
ref_audio_input,
gen_text_input,
temperature,
top_p,
min_new_tokens,
do_sample,
],
outputs=[audio_output, raw_output_display], # Include both outputs
)
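# Gradio tab: synthesis with the predefined speakers defined above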
with gr.Blocks() as app_speaker:
gr.Markdown("# Predefined Speaker TTS")
with gr.Accordion("Model Settings", open=False):
temperature = gr.Slider(
            minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature",
info="Higher values = more random/creative output",
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=1.0,
step=0.1,
label="Top P",
info="Nucleus sampling threshold",
)
min_new_tokens = gr.Slider(
minimum=0,
maximum=128,
value=3,
step=1,
label="Min Length",
info="If the model just produces a click you can force it to create longer generations.",
)
do_sample = gr.Checkbox(
label="Sample", value=True, info="Sample from the distribution"
)
with gr.Row():
speaker_display_dict = {
f"{name} - {SPEAKERS[name]['description']}": name
for name in SPEAKERS.keys()
}
speaker_dropdown = gr.Dropdown(
choices=list(speaker_display_dict.keys()),
label="Select Speaker",
value=list(speaker_display_dict.keys())[0],
)
preview_btn = gr.Button("Preview Voice")
with gr.Row():
preview_audio = gr.Audio(label="Preview")
preview_text = gr.Textbox(label="Original Transcript", interactive=False)
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
generate_btn = gr.Button("Synthesize", variant="primary")
audio_output = gr.Audio(label="Synthesized Audio")
raw_output_display = gr.Textbox(label="Raw Model Output", interactive=False)
# Connect the preview button
preview_btn.click(
preview_speaker,
inputs=[speaker_dropdown],
outputs=[preview_audio, preview_text],
)
# Connect the generate button
generate_btn.click(
infer_with_speaker,
inputs=[
speaker_dropdown,
gen_text_input,
temperature,
top_p,
min_new_tokens,
do_sample,
],
outputs=[audio_output, raw_output_display],
)
with gr.Blocks() as app_credits:
gr.Markdown("""
# Credits
* [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
""")
with gr.Blocks() as app:
gr.Markdown(
"""
    Official multilingual version of the LLaSA-1B zero-shot voice cloning TTS.
"""
)
gr.TabbedInterface([app_speaker, app_tts], ["Speaker", "Clone"])
app.launch(ssr_mode=False)