Spaces:
Runtime error
Runtime error
import base64 | |
import io | |
import numpy as np | |
import torch | |
from fastapi import HTTPException | |
from indicnlp.tokenize import sentence_tokenize | |
from mosestokenizer import MosesSentenceSplitter | |
from scipy.io.wavfile import write | |
from tts_infer.num_to_word_on_sent import normalize_nums | |
from src import log_setup | |
from src.infer.model_inference import ModelService | |
from src.model.language import Language | |
from src.model.tts_request import TTSRequest | |
from src.model.tts_response import TTSResponse, AudioFile, AudioConfig | |
LOGGER = log_setup.get_logger(__name__) | |
model_service = ModelService() | |
_INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"] | |
_PURAM_VIRAM_LANGUAGES = ["hi", "or", "bn", "as"] | |
_TRANSLITERATION_NOT_AVAILABLE_IN = ["en","or"] | |
def infer_tts_request(request: TTSRequest): | |
config = request.config | |
lang = config.language.sourceLanguage | |
gender = config.gender | |
output_list = [] | |
audio_config = AudioConfig(language=Language(sourceLanguage=lang)) | |
try: | |
for sentence in request.input: | |
LOGGER.debug(f'infer for gender {gender} and lang {lang} text {sentence.source}') | |
speech_response = infer_tts(language=lang, gender=gender, text_to_infer=sentence.source) | |
LOGGER.debug(f'infer done for text {sentence.source}') | |
output_list.append(speech_response) | |
return TTSResponse(audio=output_list, config=audio_config) | |
except Exception as e: | |
LOGGER.exception('Failed to infer %s', e) | |
raise e | |
def infer_tts(language: str, gender: str, text_to_infer: str): | |
choice = language + "_" + gender | |
LOGGER.debug(f'choice for model {choice}') | |
if choice in model_service.available_choice.keys(): | |
t2s = model_service.available_choice[choice] | |
else: | |
raise NotImplementedError('Requested model not found') | |
if text_to_infer: | |
text_to_infer = normalize_text(text_to_infer, language) | |
# if len(text_to_infer) > settings.tts_max_text_limit: | |
LOGGER.debug("Running in paragraph mode...") | |
audio, sr = run_tts_paragraph(text_to_infer, language, t2s) | |
# else: | |
# LOGGER.debug("Running in text mode...") | |
# audio, sr = run_tts(text_to_infer, language, t2s) | |
torch.cuda.empty_cache() # TODO: find better approach for this | |
LOGGER.debug('Audio generates successfully') | |
bytes_wav = bytes() | |
byte_io = io.BytesIO(bytes_wav) | |
write(byte_io, sr, audio) | |
encoded_bytes = base64.b64encode(byte_io.read()) | |
encoded_string = encoded_bytes.decode() | |
LOGGER.debug(f'Encoded Audio string {encoded_string}') | |
return AudioFile(audioContent=encoded_string) | |
else: | |
raise HTTPException(status_code=400, detail={"error": "No text"}) | |
def split_sentences(paragraph, language): | |
if language == "en": | |
with MosesSentenceSplitter(language) as splitter: | |
return splitter([paragraph]) | |
elif language in _INDIC: | |
return sentence_tokenize.sentence_split(paragraph, lang=language) | |
def normalize_text(text, lang): | |
if lang in _PURAM_VIRAM_LANGUAGES: | |
text = text.replace('|', '।') | |
text = text.replace('.', '।') | |
return text | |
def pre_process_text(text, lang): | |
if lang == 'hi': | |
text = text.replace('।', '.') # only for hindi models | |
if lang == 'en' and text[-1] != '.': | |
text = text + '. ' | |
return text | |
def run_tts_paragraph(text, lang, t2s): | |
audio_list = [] | |
split_sentences_list = split_sentences(text, language=lang) | |
for sent in split_sentences_list: | |
audio, sr = run_tts(pre_process_text(sent, lang), lang, t2s) | |
audio_list.append(audio) | |
concatenated_audio = np.concatenate([i for i in audio_list]) | |
# write(filename='temp_long.wav', rate=sr, data=concatenated_audio) | |
return concatenated_audio, sr | |
def run_tts(text, lang, t2s): | |
text_num_to_word = normalize_nums(text, lang) # converting numbers to words in lang | |
if lang not in _TRANSLITERATION_NOT_AVAILABLE_IN: | |
text_num_to_word_and_transliterated = model_service.transliterate_obj.translit_sentence(text_num_to_word, | |
lang) # transliterating english words to lang | |
else: | |
text_num_to_word_and_transliterated = text_num_to_word | |
mel = t2s[0].generate_mel(' ' + text_num_to_word_and_transliterated) | |
audio, sr = t2s[1].generate_wav(mel) | |
return audio, sr | |