import io
import zipfile
from time import perf_counter
from typing import Annotated, List

import numpy as np
import soundfile as sf
import torch
from fastapi import APIRouter, Body, HTTPException
from fastapi.responses import StreamingResponse

from logging_config import logger
from models.tts_manager import TTSModelManager
from tts_config import SPEED, ResponseFormat, config
from utils.helpers import chunk_text

router = APIRouter()
tts_model_manager = TTSModelManager()

# Tokenized inputs are moved to the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"


@router.post("/audio/speech")
async def generate_audio(
    input: Annotated[str, Body()] = config.input,
    voice: Annotated[str, Body()] = config.voice,
    model: Annotated[str, Body()] = config.model,
    response_format: Annotated[ResponseFormat, Body(include_in_schema=False)] = config.response_format,
    speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
) -> StreamingResponse:
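    """Synthesize speech for a single input text.

    Repeated requests are served from an in-memory cache. Long inputs are
    split into chunks, all chunks are generated in one batched call, and the
    concatenated audio is streamed back in the requested format.
    """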
    tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
    if speed != SPEED:
        logger.warning(
            "Specifying speed isn't supported by this model. Audio will be generated with the default speed."
        )
    start = perf_counter()

    # Serve straight from the cache when this exact text/voice/format
    # combination has been generated before.
    cache_key = f"{input}_{voice}_{response_format}"
    if cache_key in tts_model_manager.audio_cache:
        logger.info("Returning cached audio")
        audio_buffer = io.BytesIO(tts_model_manager.audio_cache[cache_key])
        audio_buffer.seek(0)
        return StreamingResponse(audio_buffer, media_type=f"audio/{response_format}")

    all_chunks = chunk_text(input, chunk_size=10)

    # Voice descriptions are tokenized once per voice and cached for reuse.
    cache_key_voice = f"voice_{voice}"
    if cache_key_voice in tts_model_manager.voice_cache:
        desc_inputs = tts_model_manager.voice_cache[cache_key_voice]
        logger.info("Using cached voice description")
    else:
        desc_inputs = description_tokenizer(
            voice,
            return_tensors="pt",
            padding="max_length",
            max_length=tts_model_manager.max_length,
        ).to(device)
        tts_model_manager.voice_cache[cache_key_voice] = desc_inputs

    if len(all_chunks) == 1:
        # Short input: a single generate() call with the cached voice description.
        prompt_inputs = tokenizer(
            input,
            return_tensors="pt",
            padding="max_length",
            max_length=tts_model_manager.max_length,
        ).to(device)

        generation = tts.generate(
            input_ids=desc_inputs["input_ids"],
            prompt_input_ids=prompt_inputs["input_ids"],
            attention_mask=desc_inputs["attention_mask"],
            prompt_attention_mask=prompt_inputs["attention_mask"],
        ).to(torch.float32)
        audio_arr = generation.cpu().float().numpy().squeeze()
    else:
        # Longer input: batch every chunk through one generate() call, pairing
        # each chunk with a copy of the voice description.
        all_descriptions = [voice] * len(all_chunks)
        description_inputs = description_tokenizer(
            all_descriptions,
            return_tensors="pt",
            padding=True,
        ).to(device)
        prompts = tokenizer(
            all_chunks,
            return_tensors="pt",
            padding=True,
        ).to(device)

        generation = tts.generate(
            input_ids=description_inputs["input_ids"],
            attention_mask=description_inputs["attention_mask"],
            prompt_input_ids=prompts["input_ids"],
            prompt_attention_mask=prompts["attention_mask"],
            do_sample=False,
            return_dict_in_generate=True,
        )

        # Trim each generated sequence to its true length, then stitch the
        # chunks back together.
        chunk_audios = []
        for i, audio in enumerate(generation.sequences):
            audio_data = audio[: generation.audios_length[i]].cpu().float().numpy().squeeze()
            chunk_audios.append(audio_data)
        audio_arr = np.concatenate(chunk_audios)

    logger.info(f"Took {perf_counter() - start:.2f} seconds to generate audio for {len(input.split())} words")

    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, audio_arr, tts.config.sampling_rate, format=response_format)
    audio_buffer.seek(0)
    tts_model_manager.audio_cache[cache_key] = audio_buffer.getvalue()
    return StreamingResponse(audio_buffer, media_type=f"audio/{response_format}")
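
# Example request against this endpoint -- a sketch, assuming the router is
# mounted at the application root and served on localhost:8000 (neither is
# defined in this module):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/audio/speech",
#       json={"input": "Hello there!", "voice": "A calm, clear female voice."},
#   )
#   out_path = "speech.wav"  # extension should match the configured response format
#   with open(out_path, "wb") as f:
#       f.write(resp.content)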


@router.post("/audio/speech_batch")
async def generate_audio_batch(
    input: Annotated[List[str], Body()] = config.input,
    voice: Annotated[List[str], Body()] = config.voice,
    model: Annotated[str, Body(include_in_schema=False)] = config.model,
    response_format: Annotated[ResponseFormat, Body()] = config.response_format,
    speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
) -> StreamingResponse:
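    """Synthesize speech for a batch of texts, one voice description per text.

    Cached texts are served directly; the rest are chunked, generated in a
    single batched call, and re-assembled per text. The response is a zip
    archive with one out_<i>.<format> file per input, in input order.
    """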
    tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
    if len(input) != len(voice):
        # Each text needs its own voice description; a mismatch would otherwise
        # surface as an unhandled IndexError below.
        raise HTTPException(status_code=400, detail="'input' and 'voice' must have the same length")
    if speed != SPEED:
        logger.warning(
            "Specifying speed isn't supported by this model. Audio will be generated with the default speed."
        )
    start = perf_counter()

    # Partition the batch into cache hits and texts that still need synthesis.
    cached_outputs = []
    uncached_inputs = []
    uncached_voices = []
    cache_keys = [f"{text}_{voice[i]}_{response_format}" for i, text in enumerate(input)]
    for i, key in enumerate(cache_keys):
        if key in tts_model_manager.audio_cache:
            cached_outputs.append((i, tts_model_manager.audio_cache[key]))
        else:
            uncached_inputs.append(input[i])
            uncached_voices.append(voice[i])

    if uncached_inputs:
        # Flatten every text into chunks, pairing each chunk with its voice
        # description so one batched generate() call covers the whole set.
        all_chunks = []
        all_descriptions = []
        for i, text in enumerate(uncached_inputs):
            chunks = chunk_text(text, chunk_size=10)
            all_chunks.extend(chunks)
            all_descriptions.extend([uncached_voices[i]] * len(chunks))

        # Warm the per-voice tokenizer cache. The batched call below re-tokenizes
        # with dynamic padding, so these tensors only benefit later requests to
        # the single-utterance endpoint.
        for desc in set(all_descriptions):
            cache_key_voice = f"voice_{desc}"
            if cache_key_voice not in tts_model_manager.voice_cache:
                tts_model_manager.voice_cache[cache_key_voice] = description_tokenizer(
                    desc,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=tts_model_manager.max_length,
                ).to(device)

        description_inputs = description_tokenizer(
            all_descriptions,
            return_tensors="pt",
            padding=True,
        ).to(device)
        prompts = tokenizer(
            all_chunks,
            return_tensors="pt",
            padding=True,
        ).to(device)

        generation = tts.generate(
            input_ids=description_inputs["input_ids"],
            attention_mask=description_inputs["attention_mask"],
            prompt_input_ids=prompts["input_ids"],
            prompt_attention_mask=prompts["attention_mask"],
            do_sample=False,
            return_dict_in_generate=True,
        )

        # Re-assemble per-text audio by walking the flat chunk sequence in
        # order, trimming each generated sequence to its true length.
        audio_outputs = []
        current_index = 0
        for text in uncached_inputs:
            chunks = chunk_text(text, chunk_size=10)
            chunk_audios = []
            for _ in range(len(chunks)):
                audio_arr = (
                    generation.sequences[current_index][: generation.audios_length[current_index]]
                    .cpu()
                    .float()
                    .numpy()
                    .squeeze()
                )
                chunk_audios.append(audio_arr)
                current_index += 1
            audio_outputs.append(np.concatenate(chunk_audios))

        # Encode each newly generated clip once, caching the bytes for later
        # requests and keeping them for the response below.
        uncached_data = []
        for audio_arr, text, voice_ in zip(audio_outputs, uncached_inputs, uncached_voices):
            key = f"{text}_{voice_}_{response_format}"
            audio_buffer = io.BytesIO()
            sf.write(audio_buffer, audio_arr, tts.config.sampling_rate, format=response_format)
            data = audio_buffer.getvalue()
            tts_model_manager.audio_cache[key] = data
            uncached_data.append(data)

    # Merge cached and freshly generated clips back into the original input order.
    final_outputs = [None] * len(input)
    for idx, audio_data in cached_outputs:
        final_outputs[idx] = audio_data
    uncached_idx = 0
    for i in range(len(final_outputs)):
        if final_outputs[i] is None:
            final_outputs[i] = uncached_data[uncached_idx]
            uncached_idx += 1

    # Package one file per input text into an in-memory zip archive.
    file_data = {f"out_{i}.{response_format}": data for i, data in enumerate(final_outputs)}
    in_memory_zip = io.BytesIO()
    with zipfile.ZipFile(in_memory_zip, "w") as zipf:
        for file_name, data in file_data.items():
            zipf.writestr(file_name, data)
    in_memory_zip.seek(0)

    logger.info(f"Took {perf_counter() - start:.2f} seconds to generate audio batch")
    return StreamingResponse(in_memory_zip, media_type="application/zip")
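
# Example batch request -- a sketch with an assumed host/port; the response is
# a zip archive containing one out_<i>.<format> entry per input text:
#
#   import io, zipfile, requests
#   resp = requests.post(
#       "http://localhost:8000/audio/speech_batch",
#       json={
#           "input": ["First sentence to speak.", "Second sentence to speak."],
#           "voice": ["A calm female voice.", "A deep male voice."],
#       },
#   )
#   with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
#       zf.extractall("batch_audio")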