"""Gradio demo for LLaSA multilingual TTS with zero-shot voice cloning (Hugging Face Space, running on ZeroGPU)."""
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
import tempfile
import os
import numpy as np
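# Load the LLaSA TTS language model and tokenizer; the XCodec2 model below
# converts between waveforms and the discrete speech tokens LLaSA emits.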
llasa_1b = "HKUSTAudio/Llasa-1B-Multilingual"
tokenizer = AutoTokenizer.from_pretrained(llasa_1b, token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    llasa_1b, trust_remote_code=True, device_map="cuda", token=os.getenv("HF_TOKEN")
)

model_path = "srinivasbilla/xcodec2"
Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().cuda()
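# Whisper transcribes the reference audio whenever no transcript is supplied.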
whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device="cuda",
)
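# Predefined reference voices: each entry pairs a prompt clip with its exact
# transcript (kept verbatim in German because it must match the spoken audio)
# and a description shown in the speaker dropdown.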
SPEAKERS = {
    "Male 1": {
        "path": "speakers/deep_speaker.mp3",
        "transcript": "Das große Tor von Minas Tirith brach erst, nachdem er die Ramme eingesetzt hatte.",
        "description": "A deep, epic male voice.",
    },
    "Male 2": {
        "path": "speakers/male_austrian_accent.mp3",
        "transcript": "Man kann sich auch leichter vorstellen, wie schwierig es ist, dass man Entscheidungen trifft, die allen passen.",
        "description": "A male voice with an Austrian accent.",
    },
    "Male 3": {
        "path": "speakers/male_energic.mp3",
        "transcript": "Wo keine Infrastruktur, da auch keine Ansiedlung von IT-Unternehmen und deren Beschäftigten bzw. dem geeigneten Fachkräftenachwuchs. Kann man diese Rechnung so einfach aufmachen, wie es es tatsächlich um deren regionale Verteilung beschäftigt?",
        "description": "An energetic male voice.",
    },
    "Male 4": {
        "path": "speakers/schneller_speaker.mp3",
        "transcript": "Genau, wenn wir alle Dächer voll machen, also alle Dächer von Einfamilienhäusern, alleine mit den Einfamilienhäusern können wir 20 Prozent des heutigen Strombedarfs decken.",
        "description": "A male voice with a faster speaking tempo.",
    },
    "Female 1": {
        "path": "speakers/female_standard.mp3",
        "transcript": "Es wird ein Beispiel für ein barrierearmes Layout gegeben, sowie Tipps und ein Verweis auf eine Checkliste, die hilft, Barrierearmut in den eigenen Materialien zu prüfen bzw. umzusetzen.",
        "description": "A female voice.",
    },
    "Female 2": {
        "path": "speakers/female_energic.mp3",
        "transcript": "Dunkel flog weiter durch das Wald. Er sah die Sterne am Phaneten an sich vorbeiziehen und fühlte sich frei und glücklich.",
        "description": "A female narrator voice.",
    },
    "Female 3": {
        "path": "speakers/austrian_accent.mp3",
        "transcript": "Die politische Europäische Union war geboren, verbrieft im Vertrag von Maastricht. Ab diesem Zeitpunkt bestehen zwei Vertragswerke.",
        "description": "A female voice with an Austrian accent.",
    },
    "Special 1": {
        "path": "speakers/low_audio.mp3",
        "transcript": "Druckplatten und Lasersensoren, um sicherzugehen, dass er auch da drin ist und",
        "description": "A male voice with deliberately low audio quality as an effect.",
    },
}
def preview_speaker(display_name):
    """Returns the audio and transcript for preview"""
    speaker_name = speaker_display_dict[display_name]
    if speaker_name in SPEAKERS:
        waveform, sample_rate = torchaudio.load(SPEAKERS[speaker_name]["path"])
        return (sample_rate, waveform[0].numpy()), SPEAKERS[speaker_name]["transcript"]
    return None, ""


def normalize_audio(waveform: torch.Tensor, target_db: float = -20) -> torch.Tensor:
    """
    Normalize audio volume to target dB and limit gain range.

    Args:
        waveform (torch.Tensor): Input audio waveform
        target_db (float): Target dB level (default: -20)

    Returns:
        torch.Tensor: Normalized audio waveform
    """
    # Calculate current peak level in dB
    eps = 1e-10
    current_db = 20 * torch.log10(torch.max(torch.abs(waveform)) + eps)
    # Calculate required gain
    gain_db = target_db - current_db
    # Limit gain to the -3 to 3 dB range
    gain_db = torch.clamp(gain_db, min=-3, max=3)
    # Apply gain
    gain_factor = 10 ** (gain_db / 20)
    normalized = waveform * gain_factor
    # Final peak normalization
    max_amplitude = torch.max(torch.abs(normalized))
    if max_amplitude > 0:
        normalized = normalized / max_amplitude
    return normalized
def ids_to_speech_tokens(speech_ids):
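    """Map integer codec IDs to LLaSA speech-token strings, e.g. 123 -> '<|s_123|>'."""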
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str
def extract_speech_ids(speech_tokens_str):
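    """Parse speech-token strings like '<|s_123|>' back into integer codec IDs."""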
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def infer_with_speaker(
    display_name,
    target_text,
    temp,
    top_p_val,
    min_new_tokens,
    do_sample,
    progress=gr.Progress(),
):
    """Variant of infer() that uses a predefined speaker as the reference."""
    speaker_name = speaker_display_dict[display_name]  # Resolve display name to speaker key
    if speaker_name not in SPEAKERS:
        return None, "Invalid speaker selected"
    return infer(
        SPEAKERS[speaker_name]["path"],
        target_text,
        temp,
        top_p_val,
        min_new_tokens,
        do_sample,
        SPEAKERS[speaker_name]["transcript"],  # Pass the predefined transcript
        progress,
    )
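# Thin pass-through wrapper used by the cloning tab's click handler.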
def gradio_infer(*args, **kwargs):
    return infer(*args, **kwargs)
@spaces.GPU  # Request a ZeroGPU slice for the duration of each call
def infer(
    sample_audio_path,
    target_text,
    temp,
    top_p_val,
    min_new_tokens,
    do_sample,
    transcribed_text=None,
    progress=gr.Progress(),
):
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        progress(0, "Loading and trimming audio...")
        waveform, sample_rate = torchaudio.load(sample_audio_path)
        waveform = normalize_audio(waveform)
        if len(waveform[0]) / sample_rate > 15:
            gr.Warning("Trimming audio to the first 15 seconds.")
            waveform = waveform[:, : sample_rate * 15]
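            # Append 0.5 s of silence to the end of the trimmed prompt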
            waveform = torch.nn.functional.pad(
                waveform, (0, int(sample_rate * 0.5)), "constant", 0
            )
        # Check if the audio is stereo (i.e., has more than one channel)
        if waveform.size(0) > 1:
            # Convert stereo to mono by averaging the channels
            waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
        else:
            # If already mono, just use the original waveform
            waveform_mono = waveform
        prompt_wav = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_mono)
        if transcribed_text is None:
            progress(0.3, "Transcribing audio...")
            prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())["text"].strip()
        else:
            prompt_text = transcribed_text
        progress(0.5, "Transcribed! Generating speech...")
        if len(target_text) == 0:
            return None
        elif len(target_text) > 500:
            gr.Warning("Text is too long. Please keep it under 500 characters.")
            target_text = target_text[:500]
        input_text = prompt_text + " " + target_text

        # TTS start!
        with torch.no_grad():
            # Encode the prompt wav
            vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
            vq_code_prompt = vq_code_prompt[0, 0, :]
            # Convert int 12345 to token <|s_12345|>
            speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
            formatted_text = (
                f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
            )
            # Tokenize the text and the speech prefix
            chat = [
                {
                    "role": "user",
                    "content": "Convert the text to speech:" + formatted_text,
                },
                {
                    "role": "assistant",
                    "content": "<|SPEECH_GENERATION_START|>"
                    + "".join(speech_ids_prefix),
                },
            ]
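            # continue_final_message=True resumes the assistant turn that already
            # contains the speech prefix instead of starting a new turn.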
            input_ids = tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                return_tensors="pt",
                continue_final_message=True,
            )
            input_ids = input_ids.to("cuda")
            speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
            # Generate the speech tokens autoregressively
            outputs = model.generate(
                input_ids,
                max_length=2048,  # We trained our model with a max length of 2048
                eos_token_id=speech_end_id,
                do_sample=do_sample,
                top_p=top_p_val,
                temperature=temp,
                min_new_tokens=min_new_tokens,
            )
            # Extract the speech tokens (prompt prefix included, trailing EOS dropped)
            generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix) : -1]
            speech_tokens = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=False
            )
            raw_output = " ".join(speech_tokens)  # Capture the raw token string
            speech_tokens = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            # Convert token <|s_23456|> to int 23456
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
            # Decode the speech tokens back to a waveform
            gen_wav = Codec_model.decode_code(speech_tokens)
            # Keep only the newly generated part, dropping the prompt audio
            gen_wav = gen_wav[:, :, prompt_wav.shape[1] :]

        progress(1, "Synthesized!")
        return (
            16000,
            gen_wav[0, 0, :].cpu().numpy(),
        ), raw_output  # Return both audio and raw tokens
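# Tab 1: zero-shot voice cloning from user-provided reference audio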
with gr.Blocks() as app_tts:
    gr.Markdown("# Zero-Shot Voice Clone TTS")
    with gr.Accordion("Model Settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.8,
            step=0.1,
            label="Temperature",
            info="Higher values = more random/creative output",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.1,
            label="Top P",
            info="Nucleus sampling threshold",
        )
        min_new_tokens = gr.Slider(
            minimum=0,
            maximum=128,
            value=3,
            step=1,
            label="Min Length",
            info="If the model produces only a click, raise this to force a longer generation.",
        )
        do_sample = gr.Checkbox(
            label="Sample", value=True, info="Sample from the distribution"
        )
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")
    raw_output_display = gr.Textbox(
        label="Raw Model Output", interactive=False
    )
    generate_btn.click(
        lambda *args: gradio_infer(*args, transcribed_text=None),
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature,
            top_p,
            min_new_tokens,
            do_sample,
        ],
        outputs=[audio_output, raw_output_display],  # Audio plus raw token string
    )
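# Tab 2: synthesis with the predefined speakers defined above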
with gr.Blocks() as app_speaker:
    gr.Markdown("# Predefined Speaker TTS")
    with gr.Accordion("Model Settings", open=False):
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature",
            info="Higher values = more random/creative output",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.1,
            label="Top P",
            info="Nucleus sampling threshold",
        )
        min_new_tokens = gr.Slider(
            minimum=0,
            maximum=128,
            value=3,
            step=1,
            label="Min Length",
            info="If the model produces only a click, raise this to force a longer generation.",
        )
        do_sample = gr.Checkbox(
            label="Sample", value=True, info="Sample from the distribution"
        )
    with gr.Row():
        speaker_display_dict = {
            f"{name} - {SPEAKERS[name]['description']}": name
            for name in SPEAKERS.keys()
        }
        speaker_dropdown = gr.Dropdown(
            choices=list(speaker_display_dict.keys()),
            label="Select Speaker",
            value=list(speaker_display_dict.keys())[0],
        )
        preview_btn = gr.Button("Preview Voice")
    with gr.Row():
        preview_audio = gr.Audio(label="Preview")
        preview_text = gr.Textbox(label="Original Transcript", interactive=False)
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")
    raw_output_display = gr.Textbox(label="Raw Model Output", interactive=False)
    # Connect the preview button
    preview_btn.click(
        preview_speaker,
        inputs=[speaker_dropdown],
        outputs=[preview_audio, preview_text],
    )
    # Connect the generate button
    generate_btn.click(
        infer_with_speaker,
        inputs=[
            speaker_dropdown,
            gen_text_input,
            temperature,
            top_p,
            min_new_tokens,
            do_sample,
        ],
        outputs=[audio_output, raw_output_display],
    )
with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits

* [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
""")

with gr.Blocks() as app:
    gr.Markdown(
        """
        Official Multilingual version
        """
    )
    gr.TabbedInterface([app_speaker, app_tts], ["Speaker", "Clone"])

app.launch(ssr_mode=False)