File size: 4,559 Bytes
2f07364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import base64
import io

import numpy as np
import torch
from fastapi import HTTPException
from indicnlp.tokenize import sentence_tokenize
from mosestokenizer import MosesSentenceSplitter
from scipy.io.wavfile import write
from tts_infer.num_to_word_on_sent import normalize_nums

from src import log_setup
from src.infer.model_inference import ModelService
from src.model.language import Language
from src.model.tts_request import TTSRequest
from src.model.tts_response import TTSResponse, AudioFile, AudioConfig

LOGGER = log_setup.get_logger(__name__)
model_service = ModelService()
_INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
_PURAM_VIRAM_LANGUAGES = ["hi", "or", "bn", "as"]
_TRANSLITERATION_NOT_AVAILABLE_IN = ["en","or"]

def infer_tts_request(request: TTSRequest):
    """Run TTS inference for every sentence in *request* and bundle the results.

    Args:
        request: TTSRequest carrying the input sentences plus the language and
            gender configuration used to pick the model.

    Returns:
        TTSResponse with one AudioFile per input sentence and an AudioConfig
        echoing the source language.

    Raises:
        Whatever ``infer_tts`` raises; the exception is logged with its
        traceback and then propagated unchanged.
    """
    config = request.config
    lang = config.language.sourceLanguage
    gender = config.gender
    output_list = []
    audio_config = AudioConfig(language=Language(sourceLanguage=lang))
    try:
        for sentence in request.input:
            LOGGER.debug(f'infer for gender {gender} and lang {lang} text {sentence.source}')
            speech_response = infer_tts(language=lang, gender=gender, text_to_infer=sentence.source)
            LOGGER.debug(f'infer done for text {sentence.source}')
            output_list.append(speech_response)

        return TTSResponse(audio=output_list, config=audio_config)
    except Exception as e:
        LOGGER.exception('Failed to infer %s', e)
        # Bare raise preserves the original traceback; `raise e` would restart
        # the traceback from this frame.
        raise


def infer_tts(language: str, gender: str, text_to_infer: str):
    """Synthesize speech for *text_to_infer* and return it base64-encoded.

    Args:
        language: source language code (e.g. "hi", "en").
        gender: voice gender; combined with *language* to select a model.
        text_to_infer: text to synthesize. Must be non-empty.

    Returns:
        AudioFile whose ``audioContent`` is the base64-encoded WAV bytes.

    Raises:
        NotImplementedError: no model is loaded for the language/gender pair.
        HTTPException: 400 when *text_to_infer* is empty.
    """
    choice = language + "_" + gender
    LOGGER.debug(f'choice for model {choice}')

    # Membership test on the dict directly; .keys() is redundant.
    if choice in model_service.available_choice:
        t2s = model_service.available_choice[choice]
    else:
        raise NotImplementedError('Requested model not found')

    if not text_to_infer:
        raise HTTPException(status_code=400, detail={"error": "No text"})

    text_to_infer = normalize_text(text_to_infer, language)

    LOGGER.debug("Running in paragraph mode...")
    audio, sr = run_tts_paragraph(text_to_infer, language, t2s)
    torch.cuda.empty_cache()  # TODO: find better approach for this
    LOGGER.debug('Audio generates successfully')

    byte_io = io.BytesIO()
    write(byte_io, sr, audio)
    # BUG FIX: the previous code called byte_io.read() right after write(),
    # which reads from the current cursor position (EOF) and yields b'', so
    # the response audio was always empty. getvalue() returns the whole
    # buffer regardless of position.
    encoded_string = base64.b64encode(byte_io.getvalue()).decode()
    # Log only the size — the full base64 payload can be megabytes.
    LOGGER.debug('Encoded audio length %d', len(encoded_string))
    return AudioFile(audioContent=encoded_string)


def split_sentences(paragraph, language):
    """Split *paragraph* into sentences with a language-appropriate splitter.

    Args:
        paragraph: text to split.
        language: language code; "en" uses Moses, Indic codes use indic-nlp.

    Returns:
        List of sentence strings. For any other language the paragraph is
        returned as a single-element list (the previous implementation fell
        through and returned None, which crashed the caller's iteration).
    """
    if language == "en":
        with MosesSentenceSplitter(language) as splitter:
            return splitter([paragraph])
    if language in _INDIC:
        return sentence_tokenize.sentence_split(paragraph, lang=language)
    return [paragraph]


def normalize_text(text, lang):
    """Normalize sentence terminators for languages that use the purna viram.

    For the configured languages, both '|' (a common typo for the danda) and
    '.' are mapped to the Devanagari danda '।'. Other languages pass through
    unchanged.
    """
    if lang not in _PURAM_VIRAM_LANGUAGES:
        return text
    # Single C-level pass instead of two chained .replace() calls.
    return text.translate(str.maketrans({'|': '।', '.': '।'}))


def pre_process_text(text, lang):
    """Prepare a single sentence for the TTS model.

    Args:
        text: sentence to adjust.
        lang: language code.

    Returns:
        For Hindi, the danda '।' is mapped back to '.' (only the Hindi models
        expect this). For English, a trailing '. ' is appended when the
        sentence does not already end with a period.
    """
    if lang == 'hi':
        text = text.replace('।', '.')  # only for hindi models

    # Guard `text` for truthiness: the original indexed text[-1] and raised
    # IndexError on an empty string.
    if lang == 'en' and text and not text.endswith('.'):
        text = text + '. '

    return text


def run_tts_paragraph(text, lang, t2s):
    """Synthesize *text* sentence-by-sentence and concatenate the audio.

    Args:
        text: paragraph to synthesize.
        lang: language code, used for splitting and per-sentence preprocessing.
        t2s: model pair used by ``run_tts`` — assumed (mel generator, vocoder);
            TODO confirm against ModelService.

    Returns:
        Tuple of (concatenated audio array, sample rate of the last sentence).
    """
    audio_list = []
    sr = None
    for sent in split_sentences(text, language=lang):
        audio, sr = run_tts(pre_process_text(sent, lang), lang, t2s)
        audio_list.append(audio)

    # np.concatenate takes the list directly; the old `[i for i in audio_list]`
    # built a redundant copy.
    return np.concatenate(audio_list), sr


def run_tts(text, lang, t2s):
    """Generate a waveform for one sentence.

    Numbers are first expanded into words for *lang*; when transliteration is
    available for the language, Latin-script words are transliterated into the
    target script before synthesis.

    Returns:
        Tuple of (audio, sample_rate) as produced by the vocoder stage.
    """
    prepared = normalize_nums(text, lang)  # converting numbers to words in lang
    if lang not in _TRANSLITERATION_NOT_AVAILABLE_IN:
        # transliterating english words to lang
        prepared = model_service.transliterate_obj.translit_sentence(prepared, lang)

    mel_model, wav_model = t2s
    mel = mel_model.generate_mel(' ' + prepared)
    return wav_model.generate_wav(mel)