sachin committed
Commit ef20fe6 · 1 Parent(s): 6aa9f3a
fix-bud

src/server/main.py CHANGED (+23 -45)
@@ -23,7 +23,6 @@ from tts_config import SPEED, ResponseFormat, config as tts_config
 from gemma_llm import LLMManager
 # from auth import get_api_key, settings as auth_settings
 
-
 import time
 from contextlib import asynccontextmanager
 from typing import Annotated, Any, OrderedDict, List
@@ -95,7 +94,6 @@ class TTSModelManager:
         if description_tokenizer.pad_token is None:
             description_tokenizer.pad_token = description_tokenizer.eos_token
 
-
         # TODO - temporary disable -torch.compile
         '''
         # Update model configuration
@@ -152,7 +150,6 @@ async def lifespan(_: FastAPI):
    tts_model_manager.get_or_load_model(config.model)
    yield
 
-#app = FastAPI(lifespan=lifespan)
 app = FastAPI(
     title="Dhwani API",
     description="AI Chat API supporting Indian languages",
@@ -161,7 +158,6 @@ app = FastAPI(
     lifespan=lifespan
 )
 
-
 def chunk_text(text, chunk_size):
     words = text.split()
     chunks = []
@@ -197,7 +193,6 @@ async def generate_audio(
         padding="max_length",
         max_length=tts_model_manager.max_length).to(device)
 
-    # Use the tensor fields directly instead of BatchEncoding object
    input_ids = desc_inputs["input_ids"]
    attention_mask = desc_inputs["attention_mask"]
    prompt_input_ids = prompt_inputs["input_ids"]
@@ -323,7 +318,6 @@ async def generate_audio_batch(
 
    return StreamingResponse(in_memory_zip, media_type="application/zip")
 
-
 # Supported language codes
 SUPPORTED_LANGUAGES = {
     "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
@@ -354,7 +348,6 @@ class Settings(BaseSettings):
 
 settings = Settings()
 
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -543,7 +536,7 @@ async def load_all_models():
         return {"status": "success", "message": "All models loaded"}
     except Exception as e:
         logger.error(f"Error loading models: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Failed to
+        raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
 
 @app.post("/v1/translate", response_model=TranslationResponse)
 async def translate_endpoint(request: TranslationRequest):
@@ -567,13 +560,10 @@ async def chat(request: Request, chat_request: ChatRequest):
         raise HTTPException(status_code=400, detail="Prompt cannot be empty")
     logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
 
-    # Define European languages that gemma-3-4b-it can handle natively
     EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
 
     try:
-        # Check if the source language is Indian (requires translation) or European/English (direct processing)
         if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
-            # Translate Indian language prompt to English
             translated_prompt = await perform_internal_translation(
                 sentences=[chat_request.prompt],
                 src_lang=chat_request.src_lang,
@@ -582,17 +572,13 @@ async def chat(request: Request, chat_request: ChatRequest):
             prompt_to_process = translated_prompt[0]
             logger.info(f"Translated prompt to English: {prompt_to_process}")
         else:
-            # Use prompt directly for English and European languages
             prompt_to_process = chat_request.prompt
             logger.info("Prompt in English or European language, no translation needed")
 
-        # Generate response with the LLM (assumed to handle multilingual input natively)
         response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
         logger.info(f"Generated response: {response}")
 
-        # Check if the target language is Indian (requires translation) or European/English (direct output)
         if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
-            # Translate response to Indian target language
             translated_response = await perform_internal_translation(
                 sentences=[response],
                 src_lang="eng_Latn",
@@ -601,7 +587,6 @@ async def chat(request: Request, chat_request: ChatRequest):
             final_response = translated_response[0]
             logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
         else:
-            # Keep response as-is for English and European languages
             final_response = response
             logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
 
@@ -622,7 +607,6 @@ async def visual_query(
     if image.size == (0, 0):
         raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
 
-    # Translate query to English if src_lang is not English
     if src_lang != "eng_Latn":
         translated_query = await perform_internal_translation(
             sentences=[query],
@@ -635,11 +619,9 @@ async def visual_query(
         query_to_process = query
         logger.info("Query already in English, no translation needed")
 
-    # Generate response in English
     answer = await llm_manager.vision_query(image, query_to_process)
     logger.info(f"Generated English answer: {answer}")
 
-    # Translate answer to target language if tgt_lang is not English
     if tgt_lang != "eng_Latn":
         translated_answer = await perform_internal_translation(
             sentences=[answer],
@@ -680,7 +662,6 @@ async def chat_v2(
             raise HTTPException(status_code=400, detail="Uploaded image is empty")
         img = Image.open(io.BytesIO(image_data))
 
-        # Translate prompt to English if src_lang is not English
         if src_lang != "eng_Latn":
             translated_prompt = await perform_internal_translation(
                 sentences=[prompt],
@@ -696,7 +677,6 @@ async def chat_v2(
         decoded = await llm_manager.chat_v2(img, prompt_to_process)
         logger.info(f"Generated English response: {decoded}")
 
-        # Translate response to target language if tgt_lang is not English
         if tgt_lang != "eng_Latn":
             translated_response = await perform_internal_translation(
                 sentences=[decoded],
@@ -709,7 +689,6 @@ async def chat_v2(
             final_response = decoded
             logger.info("Response kept in English, no translation needed")
     else:
-        # Translate prompt to English if src_lang is not English
         if src_lang != "eng_Latn":
             translated_prompt = await perform_internal_translation(
                 sentences=[prompt],
@@ -725,7 +704,6 @@ async def chat_v2(
         decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
         logger.info(f"Generated English response: {decoded}")
 
-        # Translate response to target language if tgt_lang is not English
         if tgt_lang != "eng_Latn":
             translated_response = await perform_internal_translation(
                 sentences=[decoded],
@@ -746,7 +724,6 @@ async def chat_v2(
 class TranscriptionResponse(BaseModel):
     text: str
 
-
 class ASRModelManager:
     def __init__(self, device_type="cuda"):
         self.device_type = device_type
@@ -758,7 +735,6 @@ class ASRModelManager:
         "telugu": "te", "urdu": "ur"
     }
 
-
 from fastapi import FastAPI, UploadFile
 import torch
 import torchaudio
@@ -774,27 +750,30 @@ from typing import List
 # Load the model
 model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
 
-asr_manager = ASRModelManager()
+asr_manager = ASRModelManager()
 
-
-
+# Language to script mapping
+LANGUAGE_TO_SCRIPT = {
+    "kannada": "kan_Knda", "hindi": "hin_Deva", "malayalam": "mal_Mlym", "tamil": "tam_Taml",
+    "telugu": "tel_Telu", "assamese": "asm_Beng", "bengali": "ben_Beng", "gujarati": "guj_Gujr",
+    "marathi": "mar_Deva", "odia": "ory_Orya", "punjabi": "pan_Guru", "urdu": "urd_Arab",
+    # Add more as needed
+}
 
 @app.post("/transcribe/", response_model=TranscriptionResponse)
 async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
-    # [old handler body not recoverable from the rendered diff]
-    return JSONResponse(content={"text": transcription_rnnt})
+    try:
+        wav, sr = torchaudio.load(file.file)
+        wav = torch.mean(wav, dim=0, keepdim=True)
+        target_sample_rate = 16000
+        if sr != target_sample_rate:
+            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
+            wav = resampler(wav)
+        transcription_rnnt = model(wav, asr_manager.model_language[language], "rnnt")
+        return TranscriptionResponse(text=transcription_rnnt)
+    except Exception as e:
+        logger.error(f"Error in transcription: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
 
 @app.post("/v1/speech_to_speech")
 async def speech_to_speech(
@@ -809,8 +788,8 @@ async def speech_to_speech(
     # Step 2: Process text with chat endpoint
     chat_request = ChatRequest(
         prompt=transcription.text,
-        src_lang="
-        tgt_lang="
+        src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),  # Dynamic script mapping
+        tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
     )
     processed_text = await chat(Request(), chat_request)
     logger.info(f"Processed text: {processed_text.response}")
@@ -825,7 +804,6 @@ async def speech_to_speech(
     )
     return audio_response
 
-
 class BatchTranscriptionResponse(BaseModel):
     transcriptions: List[str]
 
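For reference, a minimal client sketch for the reworked /transcribe/ endpoint added above. The base URL, port, and audio file name are illustrative assumptions, not part of the commit:

import requests

# Hypothetical call to the new endpoint: `language` is a query parameter
# (it must be a key of asr_manager.model_language) and the audio is sent
# as a multipart upload, matching `file: UploadFile = File(...)`.
with open("sample_kannada.wav", "rb") as f:  # assumed local test file
    resp = requests.post(
        "http://localhost:8000/transcribe/",  # assumed host/port
        params={"language": "kannada"},
        files={"file": f},
    )
resp.raise_for_status()
print(resp.json()["text"])  # TranscriptionResponse.text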
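A note on the new LANGUAGE_TO_SCRIPT lookup in speech_to_speech: because dict.get is called with a "kan_Knda" default, any ASR language missing from the table silently falls back to Kannada script tags. A small self-contained illustration (the "sanskrit" key is hypothetical, not in the commit's table):

# Fallback behavior of the new mapping (illustration only, not commit code).
LANGUAGE_TO_SCRIPT = {"kannada": "kan_Knda", "hindi": "hin_Deva"}
assert LANGUAGE_TO_SCRIPT.get("hindi", "kan_Knda") == "hin_Deva"     # mapped language
assert LANGUAGE_TO_SCRIPT.get("sanskrit", "kan_Knda") == "kan_Knda"  # unmapped -> Kannada default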