sachin
committed on
Commit
·
0a0efec
1
Parent(s):
1936ef7
update tts
Browse files: src/server/main.py +31 -5
src/server/main.py
CHANGED
@@ -69,10 +69,29 @@ class Settings(BaseSettings):
|
|
69 |
|
70 |
settings = Settings()
|
71 |
|
72 |
-
# TTS
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
|
|
76 |
EXAMPLES = [
|
77 |
{
|
78 |
"audio_name": "KAN_F (Happy)",
|
@@ -99,7 +118,7 @@ def load_audio_from_url(url: str):
|
|
99 |
return sample_rate, audio_data
|
100 |
raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
|
101 |
|
102 |
-
def synthesize_speech(text: str, ref_audio_name: str, ref_text: str):
|
103 |
ref_audio_url = None
|
104 |
for example in EXAMPLES:
|
105 |
if example["audio_name"] == ref_audio_name:
|
@@ -119,7 +138,7 @@ def synthesize_speech(text: str, ref_audio_name: str, ref_text: str):
|
|
119 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
|
120 |
sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
|
121 |
temp_audio.flush()
|
122 |
-
audio =
|
123 |
|
124 |
if audio.dtype == np.int16:
|
125 |
audio = audio.astype(np.float32) / 32768.0
|
@@ -233,6 +252,7 @@ class ASRModelManager:
|
|
233 |
llm_manager = LLMManager(settings.llm_model_name)
|
234 |
model_manager = ModelManager()
|
235 |
asr_manager = ASRModelManager()
|
|
|
236 |
ip = IndicProcessor(inference=True)
|
237 |
|
238 |
# Pydantic Models
|
@@ -278,6 +298,7 @@ async def lifespan(app: FastAPI):
|
|
278 |
tasks = [
|
279 |
asyncio.create_task(llm_manager.load()),
|
280 |
asyncio.create_task(asr_manager.load()),
|
|
|
281 |
asyncio.create_task(model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic')),
|
282 |
asyncio.create_task(model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng')),
|
283 |
asyncio.create_task(model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic')),
|
@@ -314,11 +335,14 @@ app.state.limiter = limiter
|
|
314 |
# API Endpoints
|
315 |
@app.post("/audio/speech", response_class=StreamingResponse)
|
316 |
async def synthesize_kannada(request: KannadaSynthesizeRequest):
|
|
|
|
|
317 |
kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
|
318 |
if not request.text.strip():
|
319 |
raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
|
320 |
|
321 |
audio_buffer = synthesize_speech(
|
|
|
322 |
text=request.text,
|
323 |
ref_audio_name="KAN_F (Happy)",
|
324 |
ref_text=kannada_example["ref_text"]
|
@@ -610,6 +634,8 @@ async def speech_to_speech(
|
|
610 |
file: UploadFile = File(...),
|
611 |
language: str = Query(..., enum=list(asr_manager.model_language.keys())),
|
612 |
) -> StreamingResponse:
|
|
|
|
|
613 |
transcription = await transcribe_audio(file, language)
|
614 |
logger.info(f"Transcribed text: {transcription.text}")
|
615 |
|
|
|
69 |
|
70 |
settings = Settings()
|
71 |
|
72 |
+
# TTS Manager
|
73 |
+
class TTSManager:
|
74 |
+
def __init__(self, device_type=device):
|
75 |
+
self.device_type = device_type
|
76 |
+
self.model = None
|
77 |
+
self.repo_id = "ai4bharat/IndicF5"
|
78 |
+
|
79 |
+
async def load(self):
|
80 |
+
logger.info("Loading TTS model IndicF5...")
|
81 |
+
self.model = await asyncio.to_thread(
|
82 |
+
AutoModel.from_pretrained,
|
83 |
+
self.repo_id,
|
84 |
+
trust_remote_code=True
|
85 |
+
)
|
86 |
+
self.model = self.model.to(self.device_type)
|
87 |
+
logger.info("TTS model IndicF5 loaded")
|
88 |
+
|
89 |
+
def synthesize(self, text, ref_audio_path, ref_text):
|
90 |
+
if not self.model:
|
91 |
+
raise ValueError("TTS model not loaded")
|
92 |
+
return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
|
93 |
|
94 |
+
# TTS Constants
|
95 |
EXAMPLES = [
|
96 |
{
|
97 |
"audio_name": "KAN_F (Happy)",
|
|
|
118 |
return sample_rate, audio_data
|
119 |
raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
|
120 |
|
121 |
+
def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
|
122 |
ref_audio_url = None
|
123 |
for example in EXAMPLES:
|
124 |
if example["audio_name"] == ref_audio_name:
|
|
|
138 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
|
139 |
sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
|
140 |
temp_audio.flush()
|
141 |
+
audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
|
142 |
|
143 |
if audio.dtype == np.int16:
|
144 |
audio = audio.astype(np.float32) / 32768.0
|
|
|
252 |
llm_manager = LLMManager(settings.llm_model_name)
|
253 |
model_manager = ModelManager()
|
254 |
asr_manager = ASRModelManager()
|
255 |
+
tts_manager = TTSManager()
|
256 |
ip = IndicProcessor(inference=True)
|
257 |
|
258 |
# Pydantic Models
|
|
|
298 |
tasks = [
|
299 |
asyncio.create_task(llm_manager.load()),
|
300 |
asyncio.create_task(asr_manager.load()),
|
301 |
+
asyncio.create_task(tts_manager.load()),
|
302 |
asyncio.create_task(model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic')),
|
303 |
asyncio.create_task(model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng')),
|
304 |
asyncio.create_task(model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic')),
|
|
|
335 |
# API Endpoints
|
336 |
@app.post("/audio/speech", response_class=StreamingResponse)
|
337 |
async def synthesize_kannada(request: KannadaSynthesizeRequest):
|
338 |
+
if not tts_manager.model:
|
339 |
+
raise HTTPException(status_code=503, detail="TTS model still loading, please try again later")
|
340 |
kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
|
341 |
if not request.text.strip():
|
342 |
raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
|
343 |
|
344 |
audio_buffer = synthesize_speech(
|
345 |
+
tts_manager,
|
346 |
text=request.text,
|
347 |
ref_audio_name="KAN_F (Happy)",
|
348 |
ref_text=kannada_example["ref_text"]
|
|
|
634 |
file: UploadFile = File(...),
|
635 |
language: str = Query(..., enum=list(asr_manager.model_language.keys())),
|
636 |
) -> StreamingResponse:
|
637 |
+
if not tts_manager.model:
|
638 |
+
raise HTTPException(status_code=503, detail="TTS model still loading, please try again later")
|
639 |
transcription = await transcribe_audio(file, language)
|
640 |
logger.info(f"Transcribed text: {transcription.text}")
|
641 |
|