# tts-openai/src/application/tts_preprocess.py
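"""Text preprocessing and TTS inference helpers.

Given a TTSRequest, the helpers below split the input into sentences,
normalize punctuation and numbers, transliterate English words where a
transliteration model is available, run the two-stage model (mel generator
plus vocoder) for the requested language/gender pair, and return the audio
as a base64-encoded WAV string.
"""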
import base64
import io
import numpy as np
import torch
from fastapi import HTTPException
from indicnlp.tokenize import sentence_tokenize
from mosestokenizer import MosesSentenceSplitter
from scipy.io.wavfile import write
from tts_infer.num_to_word_on_sent import normalize_nums
from src import log_setup
from src.infer.model_inference import ModelService
from src.model.language import Language
from src.model.tts_request import TTSRequest
from src.model.tts_response import TTSResponse, AudioFile, AudioConfig
LOGGER = log_setup.get_logger(__name__)
model_service = ModelService()
_INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]  # languages handled by the Indic sentence tokenizer
_PURAM_VIRAM_LANGUAGES = ["hi", "or", "bn", "as"]  # languages whose sentences end with the purna viram ('।')
_TRANSLITERATION_NOT_AVAILABLE_IN = ["en", "or"]  # skip English-to-script transliteration for these


def infer_tts_request(request: TTSRequest):
    """Run TTS inference for every sentence in the request and collect the results."""
    config = request.config
    lang = config.language.sourceLanguage
    gender = config.gender
    output_list = []
    audio_config = AudioConfig(language=Language(sourceLanguage=lang))
    try:
        for sentence in request.input:
            LOGGER.debug(f'infer for gender {gender} and lang {lang} text {sentence.source}')
            speech_response = infer_tts(language=lang, gender=gender, text_to_infer=sentence.source)
            LOGGER.debug(f'infer done for text {sentence.source}')
            output_list.append(speech_response)
        return TTSResponse(audio=output_list, config=audio_config)
    except Exception as e:
        LOGGER.exception('Failed to infer %s', e)
        raise  # re-raise as-is so the caller sees the original traceback


def infer_tts(language: str, gender: str, text_to_infer: str):
    """Synthesize speech for a single text snippet and return it as base64-encoded WAV."""
    choice = language + "_" + gender
    LOGGER.debug(f'choice for model {choice}')
    if choice in model_service.available_choice:
        t2s = model_service.available_choice[choice]
    else:
        raise NotImplementedError('Requested model not found')
    if not text_to_infer:
        raise HTTPException(status_code=400, detail={"error": "No text"})
    text_to_infer = normalize_text(text_to_infer, language)
    # if len(text_to_infer) > settings.tts_max_text_limit:
    LOGGER.debug("Running in paragraph mode...")
    audio, sr = run_tts_paragraph(text_to_infer, language, t2s)
    # else:
    #     LOGGER.debug("Running in text mode...")
    #     audio, sr = run_tts(text_to_infer, language, t2s)
    torch.cuda.empty_cache()  # TODO: find a better approach for freeing GPU memory
    LOGGER.debug('Audio generated successfully')
    byte_io = io.BytesIO()
    write(byte_io, sr, audio)
    byte_io.seek(0)  # rewind before reading; otherwise read() returns empty bytes
    encoded_bytes = base64.b64encode(byte_io.read())
    encoded_string = encoded_bytes.decode()
    LOGGER.debug(f'Encoded audio string of length {len(encoded_string)}')
    return AudioFile(audioContent=encoded_string)
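

# A minimal client-side sketch (not used by this module): the base64
# `audioContent` carried by AudioFile decodes back to raw WAV bytes with the
# standard library. The helper name `decode_audio_content` is illustrative,
# not part of this service's API.
def decode_audio_content(audio_content: str) -> bytes:
    """Return playable WAV bytes from a base64-encoded audioContent string."""
    return base64.b64decode(audio_content)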


def split_sentences(paragraph, language):
    """Split a paragraph into sentences using a language-appropriate tokenizer."""
    if language == "en":
        with MosesSentenceSplitter(language) as splitter:
            return splitter([paragraph])
    elif language in _INDIC:
        return sentence_tokenize.sentence_split(paragraph, lang=language)
    # Fallback: treat the whole paragraph as one sentence instead of
    # implicitly returning None for unsupported languages.
    return [paragraph]
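# For example (illustrative), split_sentences("राम घर गया। वह सो गया।", "hi")
# is expected to yield ["राम घर गया।", "वह सो गया।"].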


def normalize_text(text, lang):
    """Replace pipe and full-stop characters with the purna viram ('।') where appropriate."""
    if lang in _PURAM_VIRAM_LANGUAGES:
        text = text.replace('|', '।')
        text = text.replace('.', '।')
    return text
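# For example, normalize_text("वह घर गया.", "hi") returns "वह घर गया।",
# since these languages end sentences with the purna viram rather than '.'.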


def pre_process_text(text, lang):
    """Adjust sentence-final punctuation for the target model."""
    if lang == 'hi':
        text = text.replace('।', '.')  # only for hindi models
    if lang == 'en' and text and not text.endswith('.'):  # guard against empty strings
        text = text + '. '
    return text
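# For example, pre_process_text("वह घर गया।", "hi") returns "वह घर गया."
# (the Hindi models expect '.'), and pre_process_text("hello world", "en")
# returns "hello world. ".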


def run_tts_paragraph(text, lang, t2s):
    """Synthesize each sentence separately and concatenate the resulting audio."""
    audio_list = []
    split_sentences_list = split_sentences(text, language=lang)
    for sent in split_sentences_list:
        audio, sr = run_tts(pre_process_text(sent, lang), lang, t2s)
        audio_list.append(audio)
    concatenated_audio = np.concatenate(audio_list)
    # write(filename='temp_long.wav', rate=sr, data=concatenated_audio)
    return concatenated_audio, sr


def run_tts(text, lang, t2s):
    """Run the two-stage model (mel generator + vocoder) on a single sentence."""
    text_num_to_word = normalize_nums(text, lang)  # convert numbers to words in lang
    if lang not in _TRANSLITERATION_NOT_AVAILABLE_IN:
        # transliterate English words into the target script
        text_num_to_word_and_transliterated = model_service.transliterate_obj.translit_sentence(
            text_num_to_word, lang)
    else:
        text_num_to_word_and_transliterated = text_num_to_word
    mel = t2s[0].generate_mel(' ' + text_num_to_word_and_transliterated)
    audio, sr = t2s[1].generate_wav(mel)
    return audio, sr
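

# A minimal smoke-test sketch, assuming a "hi"/"female" model is registered in
# ModelService.available_choice (adjust to a locally available choice). It uses
# SimpleNamespace stand-ins rather than the real request models so that the
# exact fields this module reads are explicit; in the real service, TTSRequest
# is constructed by FastAPI from the request body.
if __name__ == "__main__":
    from types import SimpleNamespace

    request = SimpleNamespace(
        input=[SimpleNamespace(source="नमस्ते, आप कैसे हैं?")],
        config=SimpleNamespace(
            gender="female",
            language=SimpleNamespace(sourceLanguage="hi"),
        ),
    )
    response = infer_tts_request(request)  # type: ignore[arg-type]
    print(f"generated {len(response.audio)} audio segment(s)")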