from fastapi import APIRouter, Body, HTTPException
from fastapi.responses import StreamingResponse
from models.tts_manager import TTSModelManager
from tts_config import SPEED, ResponseFormat, config
from utils.helpers import chunk_text
from logging_config import logger
from typing import Annotated, List
import io
import zipfile
import soundfile as sf
import numpy as np
from time import perf_counter
import torch
router = APIRouter()
tts_model_manager = TTSModelManager()
@router.post("/audio/speech")
async def generate_audio(
    input: Annotated[str, Body()] = config.input,
    voice: Annotated[str, Body()] = config.voice,
    model: Annotated[str, Body()] = config.model,
    response_format: Annotated[ResponseFormat, Body(include_in_schema=False)] = config.response_format,
    speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
) -> StreamingResponse:
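    """Generate speech for a single input text and stream it back in the requested audio format."""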
    tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
    if speed != SPEED:
        logger.warning("Specifying speed isn't supported by this model. Audio will be generated with the default speed")
    start = perf_counter()
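
    # Serve from the in-memory audio cache if this exact text/voice/format combination was generated before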
    cache_key = f"{input}_{voice}_{response_format}"
    if cache_key in tts_model_manager.audio_cache:
        logger.info("Returning cached audio")
        audio_buffer = io.BytesIO(tts_model_manager.audio_cache[cache_key])
        audio_buffer.seek(0)
        return StreamingResponse(audio_buffer, media_type=f"audio/{response_format}")
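
    # Split the input into chunks so longer texts can be generated piecewise and stitched back together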
    all_chunks = chunk_text(input, chunk_size=10)
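
    # Reuse the tokenized voice description if this voice has been seen before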
    cache_key_voice = f"voice_{voice}"
    if cache_key_voice in tts_model_manager.voice_cache:
        desc_inputs = tts_model_manager.voice_cache[cache_key_voice]
        logger.info("Using cached voice description")
    else:
        desc_inputs = description_tokenizer(
            voice,
            return_tensors="pt",
            padding="max_length",
            max_length=tts_model_manager.max_length,
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        tts_model_manager.voice_cache[cache_key_voice] = desc_inputs
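
    # A single chunk is generated in one pass; multiple chunks are generated as a batch and concatenated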
    if len(all_chunks) == 1:
        prompt_inputs = tokenizer(
            input,
            return_tensors="pt",
            padding="max_length",
            max_length=tts_model_manager.max_length,
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        generation = tts.generate(
            input_ids=desc_inputs["input_ids"],
            prompt_input_ids=prompt_inputs["input_ids"],
            attention_mask=desc_inputs["attention_mask"],
            prompt_attention_mask=prompt_inputs["attention_mask"],
        ).to(torch.float32)
        audio_arr = generation.cpu().float().numpy().squeeze()
    else:
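        # Repeat the voice description per chunk, generate all chunks in one batch,
        # then trim each output to its valid length before concatenating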
        all_descriptions = [voice] * len(all_chunks)
        description_inputs = description_tokenizer(
            all_descriptions,
            return_tensors="pt",
            padding=True,
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        prompts = tokenizer(
            all_chunks,
            return_tensors="pt",
            padding=True,
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        generation = tts.generate(
            input_ids=description_inputs["input_ids"],
            attention_mask=description_inputs["attention_mask"],
            prompt_input_ids=prompts["input_ids"],
            prompt_attention_mask=prompts["attention_mask"],
            do_sample=False,
            return_dict_in_generate=True,
        )
        chunk_audios = []
        for i, audio in enumerate(generation.sequences):
            audio_data = audio[:generation.audios_length[i]].cpu().float().numpy().squeeze()
            chunk_audios.append(audio_data)
        audio_arr = np.concatenate(chunk_audios)

    logger.info(f"Took {perf_counter() - start:.2f} seconds to generate audio for {len(input.split())} words")
    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, audio_arr, tts.config.sampling_rate, format=response_format)
    audio_buffer.seek(0)
    tts_model_manager.audio_cache[cache_key] = audio_buffer.getvalue()
    return StreamingResponse(audio_buffer, media_type=f"audio/{response_format}")


@router.post("/audio/speech_batch")
async def generate_audio_batch(
    input: Annotated[List[str], Body()] = config.input,
    voice: Annotated[List[str], Body()] = config.voice,
    model: Annotated[str, Body(include_in_schema=False)] = config.model,
    response_format: Annotated[ResponseFormat, Body()] = config.response_format,
    speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
) -> StreamingResponse:
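    """Generate speech for a batch of texts and stream back a zip archive with one audio file per input."""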
    tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
    if speed != SPEED:
        logger.warning("Specifying speed isn't supported by this model. Audio will be generated with the default speed")
    start = perf_counter()
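
    # Partition the batch into cached and uncached requests, keyed by text, voice, and format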
    cached_outputs = []
    uncached_inputs = []
    uncached_voices = []
    cache_keys = [f"{text}_{voice[i]}_{response_format}" for i, text in enumerate(input)]
    for i, key in enumerate(cache_keys):
        if key in tts_model_manager.audio_cache:
            cached_outputs.append((i, tts_model_manager.audio_cache[key]))
        else:
            uncached_inputs.append(input[i])
            uncached_voices.append(voice[i])
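
    # Generate audio only for the requests that were not found in the cache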
    if uncached_inputs:
        all_chunks = []
        all_descriptions = []
        for i, text in enumerate(uncached_inputs):
            chunks = chunk_text(text, chunk_size=10)
            all_chunks.extend(chunks)
            all_descriptions.extend([uncached_voices[i]] * len(chunks))
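
        # Tokenize each unique voice description once and populate the shared voice cache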
        unique_descriptions = list(set(all_descriptions))
        desc_inputs_dict = {}
        for desc in unique_descriptions:
            cache_key_voice = f"voice_{desc}"
            if cache_key_voice in tts_model_manager.voice_cache:
                desc_inputs_dict[desc] = tts_model_manager.voice_cache[cache_key_voice]
            else:
                desc_inputs = description_tokenizer(
                    desc,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=tts_model_manager.max_length,
                ).to("cuda" if torch.cuda.is_available() else "cpu")
                desc_inputs_dict[desc] = desc_inputs
                tts_model_manager.voice_cache[cache_key_voice] = desc_inputs
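
        # Batch-tokenize all chunk texts and their voice descriptions, then generate everything in one call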
        description_inputs = description_tokenizer(
            all_descriptions,
            return_tensors="pt",
            padding=True,
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        prompts = tokenizer(
            all_chunks,
            return_tensors="pt",
            padding=True,
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        generation = tts.generate(
            input_ids=description_inputs["input_ids"],
            attention_mask=description_inputs["attention_mask"],
            prompt_input_ids=prompts["input_ids"],
            prompt_attention_mask=prompts["attention_mask"],
            do_sample=False,
            return_dict_in_generate=True,
        )
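
        # Reassemble per-request audio by concatenating that request's chunk outputs in order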
        audio_outputs = []
        current_index = 0
        for i, text in enumerate(uncached_inputs):
            chunks = chunk_text(text, chunk_size=10)
            chunk_audios = []
            for _ in range(len(chunks)):
                audio_arr = generation.sequences[current_index][:generation.audios_length[current_index]].cpu().float().numpy().squeeze()
                chunk_audios.append(audio_arr)
                current_index += 1
            combined_audio = np.concatenate(chunk_audios)
            audio_outputs.append(combined_audio)
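
        # Encode each newly generated waveform and store the bytes in the audio cache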
        for i, (text, voice_) in enumerate(zip(uncached_inputs, uncached_voices)):
            key = f"{text}_{voice_}_{response_format}"
            audio_buffer = io.BytesIO()
            sf.write(audio_buffer, audio_outputs[i], tts.config.sampling_rate, format=response_format)
            audio_buffer.seek(0)
            tts_model_manager.audio_cache[key] = audio_buffer.getvalue()
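
    # Merge cached and freshly generated results back into the original request order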
    final_outputs = [None] * len(input)
    for idx, audio_data in cached_outputs:
        final_outputs[idx] = audio_data
    uncached_idx = 0
    for i in range(len(final_outputs)):
        if final_outputs[i] is None:
            audio_buffer = io.BytesIO()
            sf.write(audio_buffer, audio_outputs[uncached_idx], tts.config.sampling_rate, format=response_format)
            audio_buffer.seek(0)
            final_outputs[i] = audio_buffer.getvalue()
            uncached_idx += 1
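
    # Package one file per request into an in-memory zip archive and stream it back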
    file_data = {f"out_{i}.{response_format}": data for i, data in enumerate(final_outputs)}
    in_memory_zip = io.BytesIO()
    with zipfile.ZipFile(in_memory_zip, 'w') as zipf:
        for file_name, data in file_data.items():
            zipf.writestr(file_name, data)
    in_memory_zip.seek(0)
    logger.info(f"Took {perf_counter() - start:.2f} seconds to generate audio batch")
    return StreamingResponse(in_memory_zip, media_type="application/zip")