# tortoise5c / app.py -- Streamlit front end for Tortoise text-to-speech.
# Snapshot metadata: uploader djkesu, commit 65f8628 ("revert"), 11.7 kB.
import os
import shutil
from pathlib import Path
import streamlit as st
from random import randint
from tortoise.api import MODELS_DIR
from tortoise.inference import (
infer_on_texts,
run_and_save_tts,
split_and_recombine_text,
)
from tortoise.api import TextToSpeech
from tortoise.utils.diffusion import SAMPLERS
from app_utils.filepicker import st_file_selector
from app_utils.conf import TortoiseConfig
from app_utils.funcs import (
timeit,
load_model,
list_voices,
load_voice_conditionings,
)
# Labels offered by the "Latent averaging mode" radio button in the UI.
# Order matters: the position of the selected label in this list is what gets
# passed to the TTS call as ``latent_averaging_mode``.
LATENT_MODES = [
    "Tortoise original (bad)",  # index 0
    "average per 4.27s (broken on small files)",  # index 1
    "average per voice file (broken on small files)",  # index 2
]
def _new_widget_key():
    """Return a random string to use as a Streamlit widget key."""
    return str(randint(1000, 100000000))


def _rerun_app():
    """Restart the script run, using whichever rerun API Streamlit provides."""
    # Newer Streamlit releases expose ``st.rerun``; older ones only have
    # ``st.experimental_rerun``. Prefer the former when it is available.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _safe_voice_dir_name(raw_name):
    """Convert a user-typed voice name into a single, safe directory name.

    Surrounding whitespace is stripped and spaces become underscores.

    Returns:
        The cleaned name, or an empty string when the name is empty, is
        ``.`` or ``..``, or contains a path separator. Such names would
        escape the voices directory, which is dangerous because an existing
        target directory is deleted before the new voice is written.
    """
    candidate = raw_name.strip().replace(" ", "_")
    separators = [sep for sep in (os.sep, os.altsep, "/", "\\") if sep]
    if candidate in ("", ".", ".."):
        return ""
    if any(sep in candidate for sep in separators):
        return ""
    return candidate


def _show_generation(fp, filename: str):
    """Render an audio player and a download button for one generated file.

    Args:
        fp: Path of the generated wav file on disk.
        filename: File name suggested to the browser for the download.
    """
    wav_path = Path(str(fp))
    st.audio(str(wav_path), format="audio/wav")
    st.download_button(
        "Download sample",
        # Send the audio bytes; passing the path string would make the
        # browser download a text file that merely contains the path.
        data=wav_path.read_bytes(),
        file_name=filename,
        mime="audio/wav",
        key=f"download-{filename}",
    )


def _render_create_voice_section():
    """Render the "Create New Voice" expander and handle its submit button.

    Uploaded wav files are written to ``tortoise/voices/<name>/`` as
    ``voice_sample<N>.wav``. An existing directory with the same name is
    replaced. After a successful creation the widget keys are regenerated
    (which clears the uploader and the name field) and the app is rerun.
    """
    with st.expander("Create New Voice", expanded=True):
        # The widget keys live in session state so they can be swapped for
        # fresh ones later; a new key makes Streamlit build an empty widget.
        for state_key in ("file_uploader_key", "text_input_key"):
            if state_key not in st.session_state:
                st.session_state[state_key] = _new_widget_key()

        uploaded_files = st.file_uploader(
            "Upload Audio Samples for a New Voice",
            accept_multiple_files=True,
            type=["wav"],
            key=st.session_state["file_uploader_key"],
        )
        voice_name = st.text_input(
            "New Voice Name",
            help="Enter a name for your new voice.",
            value="",
            key=st.session_state["text_input_key"],
        )
        create_clicked = st.button(
            "Create Voice",
            disabled=not voice_name.strip() or not uploaded_files,
        )
        if not create_clicked:
            return

        new_voice_name = _safe_voice_dir_name(voice_name)
        if not new_voice_name:
            st.error(
                "Invalid voice name: it must not be '.' or '..' "
                "and must not contain path separators."
            )
            return

        with st.spinner(f"Creating new voice: {voice_name}"):
            voice_dir = Path("tortoise") / "voices" / new_voice_name
            if voice_dir.exists():
                shutil.rmtree(voice_dir)
            voice_dir.mkdir(parents=True)
            for index, uploaded_file in enumerate(uploaded_files):
                sample_path = voice_dir / f"voice_sample{index}.wav"
                sample_path.write_bytes(uploaded_file.read())

        st.session_state["text_input_key"] = _new_widget_key()
        st.session_state["file_uploader_key"] = _new_widget_key()
        _rerun_app()


def main() -> None:
    """Build the Streamlit page: voice creation, settings, and generation.

    The page is laid out top to bottom as:

    1. a "Create New Voice" section for uploading wav samples,
    2. the text, voice and preset selectors,
    3. an "Advanced" expander with model, optimization, text-splitting and
       debug options,
    4. an "Update Basic Settings" button that stores settings in the config,
    5. a "Start" button that runs synthesis for the selected voice(s) and
       shows every generated file with a player and a download button.

    Bare string literals inside this function are intentional: Streamlit's
    "magic" feature renders them as markdown on the page.
    """
    conf = TortoiseConfig()

    _render_create_voice_section()

    text = st.text_area(
        "Text",
        help="Text to speak.",
        value="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
    )
    voices = [v for v in os.listdir("tortoise/voices") if v != "cond_latent_example"]
    voice = st.selectbox(
        "Voice",
        voices,
        help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) "
        "Use the & character to join two voices together. Use a comma to perform inference on multiple voices.",
        index=0,
    )
    preset = st.selectbox(
        "Preset",
        (
            "single_sample",
            "ultra_fast",
            "very_fast",
            "ultra_fast_old",
            "fast",
            "standard",
            "high_quality",
        ),
        help="Which voice preset to use.",
        index=1,
    )

    with st.expander("Advanced"):
        col1, col2 = st.columns(2)
        with col1:
            """#### Model parameters"""
            candidates = st.number_input(
                "Candidates",
                help="How many output candidates to produce per-voice.",
                min_value=1,  # zero or negative candidates make no sense
                value=1,
            )
            latent_averaging_mode = st.radio(
                "Latent averaging mode",
                LATENT_MODES,
                help="How voice samples should be averaged together.",
                index=0,
            )
            sampler = st.radio(
                "Sampler",
                ["dpm++2m", "p", "ddim"],
                help="Diffusion sampler. Note that dpm++2m is experimental and typically requires more steps.",
                index=1,
            )
            steps = st.number_input(
                "Steps",
                help="Override the steps used for diffusion (default depends on preset)",
                min_value=1,  # at least one diffusion step is required
                value=10,
            )
            seed = st.number_input(
                "Seed",
                help="Random seed which can be used to reproduce results.",
                value=-1,
            )
            # -1 is the UI's way of saying "no fixed seed".
            if seed == -1:
                seed = None
            voice_fixer = st.checkbox(
                "Voice fixer",
                help="Use `voicefixer` to improve audio quality. This is a post-processing step which can be applied to any output.",
                value=True,
            )
            """#### Directories"""
            output_path = st.text_input(
                "Output Path", help="Where to store outputs.", value="results/"
            )
        with col2:
            """#### Optimizations"""
            # The checkbox is phrased as "Low VRAM", the model loader wants
            # the opposite flag.
            high_vram = not st.checkbox(
                "Low VRAM",
                help="Re-enable default offloading behaviour of tortoise",
                value=True,
            )
            half = st.checkbox(
                "Half-Precision",
                help="Enable autocast to half precision for autoregressive model",
                value=False,
            )
            kv_cache = st.checkbox(
                "Key-Value Cache",
                help="Enable kv_cache usage, leading to drastic speedups but worse memory usage",
                value=True,
            )
            cond_free = st.checkbox(
                "Conditioning Free",
                help="Force conditioning free diffusion",
                value=True,
            )
            no_cond_free = st.checkbox(
                "Force Not Conditioning Free",
                help="Force disable conditioning free diffusion",
                value=False,
            )
            """#### Text Splitting"""
            min_chars_to_split = st.number_input(
                "Min Chars to Split",
                help="Minimum number of characters to split text on",
                min_value=50,
                value=200,
                step=1,
            )
            """#### Debug"""
            produce_debug_state = st.checkbox(
                "Produce Debug State",
                help="Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.",
                value=True,
            )

    # "Force Not Conditioning Free" is documented as a forced disable, so it
    # takes precedence over the "Conditioning Free" checkbox.
    if no_cond_free:
        cond_free = False

    if st.button("Update Basic Settings"):
        # EXTRA_VOICES_DIR is deliberately not passed: this page has no
        # widget or variable for it, and naming an undefined variable here
        # raised NameError on every click.
        # NOTE(review): assumes TortoiseConfig.update leaves keys that are
        # not passed unchanged -- confirm in app_utils.conf.
        conf.update(
            LOW_VRAM=not high_vram,
            AR_CHECKPOINT=".",
            DIFF_CHECKPOINT=".",
        )

    # No checkpoint picker exists on this page, so both checkpoint arguments
    # are None (presumably selecting the stock weights -- confirm in
    # app_utils.funcs.load_model).
    tts = load_model(MODELS_DIR, high_vram, kv_cache, None, None)

    if st.button("Start"):
        if not voice:
            # st.selectbox yields None when the voices directory is empty.
            st.error("No voice is available. Create or add a voice first.")
            return

        # Optional overrides forwarded to the TTS call; None values are
        # omitted so the preset's own defaults apply.
        override_kwargs = {
            name: value
            for name, value in (
                ("sampler", sampler),
                ("diffusion_iterations", steps),
                ("cond_free", cond_free),
            )
            if value is not None
        }
        latent_mode_index = LATENT_MODES.index(latent_averaging_mode)

        with st.spinner(
            f"Generating {candidates} candidates for voice {voice} (seed={seed}). You can see progress in the terminal"
        ):
            os.makedirs(output_path, exist_ok=True)
            # A comma separates independent voices; "&" joins several voices
            # into one blended voice.
            for selected_voice in voice.split(","):
                voice_sel = selected_voice.split("&")
                voice_samples, conditioning_latents = load_voice_conditionings(
                    voice_sel, []
                )
                voice_path = Path(os.path.join(output_path, selected_voice))
                with timeit(
                    f"Generating {candidates} candidates for voice {selected_voice} (seed={seed})"
                ):

                    def call_tts(text: str):
                        """Synthesize ``text`` with the current UI settings."""
                        return tts.tts_with_preset(
                            text,
                            k=candidates,
                            voice_samples=voice_samples,
                            conditioning_latents=conditioning_latents,
                            preset=preset,
                            use_deterministic_seed=seed,
                            return_deterministic_state=True,
                            cvvp_amount=0.0,
                            half=half,
                            latent_averaging_mode=latent_mode_index,
                            **override_kwargs,
                        )

                    if len(text) < min_chars_to_split:
                        # Short text: synthesize it in a single pass.
                        filepaths = run_and_save_tts(
                            call_tts,
                            text,
                            voice_path,
                            return_deterministic_state=True,
                            return_filepaths=True,
                            voicefixer=voice_fixer,
                        )
                    else:
                        # Long text: split into chunks and synthesize each.
                        desired_length = int(min_chars_to_split)
                        texts = split_and_recombine_text(
                            text, desired_length, desired_length + 100
                        )
                        filepaths = infer_on_texts(
                            call_tts,
                            texts,
                            voice_path,
                            return_deterministic_state=True,
                            return_filepaths=True,
                            lines_to_regen=set(range(len(texts))),
                            voicefixer=voice_fixer,
                        )
                    for i, fp in enumerate(filepaths):
                        _show_generation(fp, f"{selected_voice}-text-{i}.wav")
        if produce_debug_state:
            """Debug states can be found in the output directory"""
# Build the page only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()