Commit 7e9e8cc · committed by sachin · Parent(s): fea8b58

test

src/server/main.py CHANGED (+87 -86)
@@ -28,26 +28,19 @@ from tts_config import SPEED, ResponseFormat, config as tts_config
 import torchaudio
 
 # Device setup
-if torch.cuda.is_available():
-    device = "cuda:0"
-    logger.info("GPU will be used for inference")
-else:
-    device = "cpu"
-    logger.info("CPU will be used for inference")
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
+logger.info(f"Using device: {device} with dtype: {torch_dtype}")
 
 # Check CUDA availability and version
 cuda_available = torch.cuda.is_available()
 cuda_version = torch.version.cuda if cuda_available else None
-
-if torch.cuda.is_available():
+if cuda_available:
     device_idx = torch.cuda.current_device()
     capability = torch.cuda.get_device_capability(device_idx)
-    compute_capability_float = float(f"{capability[0]}.{capability[1]}")
-    print(f"CUDA version: {cuda_version}")
-    print(f"CUDA Compute Capability: {compute_capability_float}")
+    logger.info(f"CUDA version: {cuda_version}, Compute Capability: {capability[0]}.{capability[1]}")
 else:
-
+    logger.info("CUDA is not available; falling back to CPU.")
 
 # Settings
 class Settings(BaseSettings):
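For reference, the consolidated device setup on the new side reduces to the standalone sketch below. It assumes only `torch` and Python's standard `logging` (standing in for the server's `logger`); nothing else from the server is needed.

```python
import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# One-liner device pick replaces the old if/else block.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU roughly halves activation memory; CPU stays on float32.
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
logger.info(f"Using device: {device} with dtype: {torch_dtype}")

if torch.cuda.is_available():
    idx = torch.cuda.current_device()
    major, minor = torch.cuda.get_device_capability(idx)
    logger.info(f"CUDA version: {torch.version.cuda}, Compute Capability: {major}.{minor}")
```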
@@ -94,14 +87,7 @@ class LLMManager:
         try:
             if self.device.type == "cuda":
                 torch.set_float32_matmul_precision('high')
-                logger.info("Enabled TF32 matrix multiplication for improved performance")
-
-            quantization_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=self.torch_dtype,
-                bnb_4bit_use_double_quant=True
-            )
+                logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
 
             self.model = Gemma3ForConditionalGeneration.from_pretrained(
                 self.model_name,
@@ -113,7 +99,7 @@ class LLMManager:
 
             self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
             self.is_loaded = True
-            logger.info(f"LLM {self.model_name} loaded on {self.device}")
+            logger.info(f"LLM {self.model_name} loaded on {self.device}")
         except Exception as e:
             logger.error(f"Failed to load LLM: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
@@ -124,10 +110,10 @@ class LLMManager:
             del self.processor
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()
-                logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+                logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
             self.is_loaded = False
             self.token_cache.clear()
-            logger.info(f"LLM {self.model_name} unloaded")
+            logger.info(f"LLM {self.model_name} unloaded")
 
     async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
         if not self.is_loaded:
@@ -139,14 +125,8 @@ class LLMManager:
             return self.token_cache[cache_key]["response"]
 
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
-            },
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": prompt}]
-            }
+            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
+            {"role": "user", "content": [{"type": "text", "text": prompt}]}
         ]
 
         try:
@@ -169,7 +149,7 @@ class LLMManager:
             input_len = inputs_vlm["input_ids"].shape[-1]
             adjusted_max_tokens = min(max_tokens, max(20, input_len * 2))
 
-            with torch.no_grad():
+            with torch.no_grad():
                 generation = self.model.generate(
                     **inputs_vlm,
                     max_new_tokens=adjusted_max_tokens,
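The `torch.no_grad()` context around each `generate` call keeps autograd from recording the forward pass, which is pure overhead at inference time. A minimal sketch of the same pattern follows; the `model` and `processor` arguments are placeholders, not the server's Gemma3 objects, and any Hugging Face model/processor pair with `generate`/`decode` would fit.

```python
import torch

def generate_reply(model, processor, inputs, max_new_tokens: int) -> str:
    """Run generation without building an autograd graph (placeholder objects)."""
    with torch.no_grad():  # no gradients tracked, so less GPU memory is held
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )
    # Slice off the prompt tokens, mirroring the input_len bookkeeping in the diff.
    input_len = inputs["input_ids"].shape[-1]
    return processor.decode(output_ids[0][input_len:], skip_special_tokens=True)
```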
@@ -189,14 +169,8 @@ class LLMManager:
             self.load()
 
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in one sentence maximum."}]
-            },
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])
-            }
+            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in one sentence maximum."}]},
+            {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])}
         ]
 
         cache_key = f"vision_{query}_{'image' if image else 'no_image'}"
@@ -224,7 +198,7 @@ class LLMManager:
             input_len = inputs_vlm["input_ids"].shape[-1]
             adjusted_max_tokens = min(512, max(20, input_len * 2))
 
-            with torch.no_grad():
+            with torch.no_grad():
                 generation = self.model.generate(
                     **inputs_vlm,
                     max_new_tokens=adjusted_max_tokens,
@@ -244,14 +218,8 @@ class LLMManager:
             self.load()
 
         messages_vlm = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
-            },
-            {
-                "role": "user",
-                "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])
-            }
+            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
+            {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image else [])}
         ]
 
         cache_key = f"chat_v2_{query}_{'image' if image else 'no_image'}"
@@ -279,7 +247,7 @@ class LLMManager:
             input_len = inputs_vlm["input_ids"].shape[-1]
             adjusted_max_tokens = min(512, max(20, input_len * 2))
 
-            with torch.no_grad():
+            with torch.no_grad():
                 generation = self.model.generate(
                     **inputs_vlm,
                     max_new_tokens=adjusted_max_tokens,
@@ -297,19 +265,24 @@ class LLMManager:
 # TTS Manager
 class TTSManager:
     def __init__(self, device_type=device):
-        self.device_type = device_type
+        self.device_type = torch.device(device_type)
         self.model = None
         self.repo_id = "ai4bharat/IndicF5"
 
     def load(self):
         if not self.model:
-            logger.info("Loading TTS model...")
-            self.model = AutoModel.from_pretrained(
-                self.repo_id,
-                trust_remote_code=True
-            )
-            self.model = self.model.to(self.device_type)
-            logger.info("TTS model loaded")
+            logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
+            self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
+            logger.info("TTS model loaded")
+
+    def unload(self):
+        if self.model:
+            del self.model
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+            self.model = None
+            logger.info("TTS model unloaded")
 
     def synthesize(self, text, ref_audio_path, ref_text):
         if not self.model:
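The new `unload()` has the same shape in the TTS, ASR, and translation managers: drop the last Python reference, then ask the CUDA caching allocator to release its blocks. A generic sketch of the pattern under the diff's assumptions (`torch` plus a logger); `ManagedModel` is a hypothetical class, not part of the server:

```python
import logging

import torch

logger = logging.getLogger(__name__)

class ManagedModel:
    """Generic load/unload lifecycle (illustrative only)."""

    def __init__(self, device_type: str = "cuda:0" if torch.cuda.is_available() else "cpu"):
        self.device_type = torch.device(device_type)
        self.model = None

    def unload(self) -> None:
        if self.model is not None:
            del self.model  # drop the last strong reference so tensors can be freed
            if self.device_type.type == "cuda":
                torch.cuda.empty_cache()  # hand cached blocks back to the driver
                logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
            self.model = None
            logger.info("Model unloaded")
```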
@@ -394,11 +367,11 @@ SUPPORTED_LANGUAGES = {
 
 # Translation Manager
 class TranslateManager:
-    def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
-        self.device_type = device_type
-        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)
+    def __init__(self, src_lang, tgt_lang, device_type=device):
+        self.device_type = torch.device(device_type)
+        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang)
 
-    def initialize_model(self, src_lang, tgt_lang, use_distilled):
+    def initialize_model(self, src_lang, tgt_lang, use_distilled=True):
         if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
             model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
         elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
@@ -417,6 +390,17 @@ class TranslateManager:
         ).to(self.device_type)
         return tokenizer, model
 
+    def unload(self):
+        if self.model:
+            del self.model
+            del self.tokenizer
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"Translation GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+            self.model = None
+            self.tokenizer = None
+            logger.info("Translation model unloaded")
+
 class ModelManager:
     def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
         self.models: dict[str, TranslateManager] = {}
@@ -432,7 +416,7 @@ class ModelManager:
             ('kan_Knda', 'hin_Deva', 'indic_indic')
         ]
         for src_lang, tgt_lang, key in translation_pairs:
-            logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang}...")
+            logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang} on {self.device_type}...")
             self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
             logger.info(f"Translation model for {key} preloaded successfully")
 
@@ -452,21 +436,29 @@ class ModelManager:
 
 # ASR Manager
 class ASRModelManager:
-    def __init__(self, device_type=device):
-        self.device_type = device_type
+    def __init__(self, device_type=device):
+        self.device_type = torch.device(device_type)
         self.model = None
         self.model_language = {"kannada": "kn"}
 
     def load(self):
         if not self.model:
-            logger.info("Loading ASR model...")
+            logger.info(f"Loading ASR model on {self.device_type}...")
             self.model = AutoModel.from_pretrained(
                 "ai4bharat/indic-conformer-600m-multilingual",
                 trust_remote_code=True
-            )
-            self.model = self.model.to(self.device_type)
+            ).to(self.device_type)
             logger.info("ASR model loaded")
 
+    def unload(self):
+        if self.model:
+            del self.model
+            if self.device_type.type == "cuda":
+                torch.cuda.empty_cache()
+                logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
+            self.model = None
+            logger.info("ASR model unloaded")
+
 # Global Managers
 llm_manager = LLMManager(settings.llm_model_name)
 model_manager = ModelManager()
@@ -552,15 +544,15 @@ translation_configs = []
 async def lifespan(app: FastAPI):
     def load_all_models():
         try:
-            logger.info("Loading LLM model...")
+            logger.info(f"Loading LLM model on {device}...")
             llm_manager.load()
             logger.info("LLM model loaded successfully")
 
-            logger.info("Loading TTS model...")
+            logger.info(f"Loading TTS model on {device}...")
             tts_manager.load()
             logger.info("TTS model loaded successfully")
 
-            logger.info("Loading ASR model...")
+            logger.info(f"Loading ASR model on {device}...")
             asr_manager.load()
             logger.info("ASR model loaded successfully")
 
@@ -574,7 +566,11 @@ async def lifespan(app: FastAPI):
     load_all_models()
     yield
     llm_manager.unload()
-
+    tts_manager.unload()
+    asr_manager.unload()
+    for model in model_manager.models.values():
+        model.unload()
+    logger.info("Server shutdown complete; all models unloaded")
 
 # FastAPI App
 app = FastAPI(
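The lifespan hook now mirrors load with unload: everything before `yield` runs at startup, everything after runs at shutdown. A runnable skeleton of that pattern, with stub managers standing in for the server's real ones:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

class StubManager:
    """Stand-in for llm_manager / tts_manager / asr_manager."""
    def load(self) -> None: ...
    def unload(self) -> None: ...

llm_manager, tts_manager, asr_manager = StubManager(), StubManager(), StubManager()

@asynccontextmanager
async def lifespan(app: FastAPI):
    for m in (llm_manager, tts_manager, asr_manager):  # startup: load everything
        m.load()
    yield  # the app serves requests while suspended here
    for m in (llm_manager, tts_manager, asr_manager):  # shutdown: release memory
        m.unload()

app = FastAPI(lifespan=lifespan)
```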
@@ -585,7 +581,6 @@ app = FastAPI(
     lifespan=lifespan
 )
 
-# Add CORS Middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -594,7 +589,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Add Timing Middleware
 @app.middleware("http")
 async def add_request_timing(request: Request, call_next):
     start_time = time()
@@ -616,6 +610,10 @@ async def unload_all_models():
     try:
         logger.info("Starting to unload all models...")
         llm_manager.unload()
+        tts_manager.unload()
+        asr_manager.unload()
+        for model in model_manager.models.values():
+            model.unload()
         logger.info("All models unloaded successfully")
         return {"status": "success", "message": "All models unloaded"}
     except Exception as e:
@@ -627,6 +625,8 @@ async def load_all_models():
     try:
         logger.info("Starting to load all models...")
         llm_manager.load()
+        tts_manager.load()
+        asr_manager.load()
         logger.info("All models loaded successfully")
         return {"status": "success", "message": "All models loaded"}
     except Exception as e:
@@ -775,10 +775,9 @@ async def chat_v2(
         logger.error(f"Error processing request: {str(e)}")
         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
 
-# Include LLM Router
 app.include_router(llm_router)
 
-# Improved Endpoints
+# Improved Endpoints with GPU Optimization
 @app.post("/audio/speech", response_class=StreamingResponse)
 async def synthesize_kannada(request: KannadaSynthesizeRequest):
     if not tts_manager.model:
@@ -821,8 +820,11 @@ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(...)):
         if sr != target_sample_rate:
             logger.info(f"Resampling audio from {sr}Hz to {target_sample_rate}Hz")
             resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
-            wav = resampler(wav)
-        transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+            wav = resampler(wav).to(device)
+        else:
+            wav = wav.to(device)
+        with torch.no_grad():
+            transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
         logger.info(f"Transcription completed: {transcription_rnnt[:50]}...")
         return TranscriptionResponse(text=transcription_rnnt)
     except Exception as e:
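The resample-then-move sequence in the transcription path is worth isolating: `torchaudio.transforms.Resample` is an `nn.Module`, so resampling runs wherever its weights live (CPU here), and only the result is moved to the inference device, as in the diff. A standalone sketch; `prepare_for_asr` is a hypothetical helper, not a server function:

```python
import torch
import torchaudio

def prepare_for_asr(wav: torch.Tensor, sr: int, device: str = "cpu") -> torch.Tensor:
    """Resample to the 16 kHz the ASR model expects, then move to `device`."""
    target_sr = 16000
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        wav = resampler(wav)  # runs on CPU; only the result goes to the GPU
    return wav.to(device)
```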
@@ -837,8 +839,11 @@ async def transcribe_step(audio_data: bytes, language: str) -> str:
     target_sample_rate = 16000
     if sr != target_sample_rate:
         resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
-        wav = resampler(wav)
-    return asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
+        wav = resampler(wav).to(device)
+    else:
+        wav = wav.to(device)
+    with torch.no_grad():
+        return asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
 
 async def synthesize_step(text: str) -> io.BytesIO:
     kannada_example = next((ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)"), None)
@@ -863,11 +868,9 @@ async def speech_to_speech(
 
     logger.info(f"Processing speech-to-speech for file: {file.filename} in language: {language}")
     try:
-        # Step 1: Transcribe
         transcription = await transcribe_step(audio_data, language)
         logger.info(f"Transcribed text: {transcription[:50]}...")
 
-        # Step 2: Process with LLM
         chat_request = ChatRequest(
             prompt=transcription,
             src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
@@ -876,7 +879,6 @@ async def speech_to_speech(
         processed_text = await chat(request, chat_request)
         logger.info(f"Processed text: {processed_text.response[:50]}...")
 
-        # Step 3: Synthesize
         audio_buffer = await synthesize_step(processed_text.response)
         logger.info("Speech-to-speech processing completed")
 
@@ -900,7 +902,8 @@ async def health_check():
         "translation_models": list(model_manager.models.keys()),
         "device": device,
         "cuda_available": cuda_available,
-        "cuda_version": cuda_version if cuda_available else "N/A"
+        "cuda_version": cuda_version if cuda_available else "N/A",
+        "gpu_memory_allocated": torch.cuda.memory_allocated() if cuda_available else 0
     }
     logger.info("Health check requested")
     return status
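One note on the new `gpu_memory_allocated` field: `torch.cuda.memory_allocated()` counts bytes held by live tensors, while `torch.cuda.memory_reserved()` also includes the caching allocator's idle pool, so the two can differ substantially. A quick comparison, assuming a CUDA device is present:

```python
import torch

if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")  # ~4 MiB of float32
    print(f"allocated: {torch.cuda.memory_allocated()} bytes")
    print(f"reserved:  {torch.cuda.memory_reserved()} bytes")
    del x
    torch.cuda.empty_cache()  # shrink the reserved pool, as unload() does
    print(f"reserved after empty_cache: {torch.cuda.memory_reserved()} bytes")
```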
@@ -967,7 +970,6 @@ LANGUAGE_TO_SCRIPT = {
     "kannada": "kan_Knda"
 }
 
-# Main Execution
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the FastAPI server.")
     parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
@@ -996,7 +998,6 @@ if __name__ == "__main__":
     llm_manager = LLMManager(settings.llm_model_name)
 
     if selected_config["components"]["ASR"]:
-        asr_model_name = selected_config["components"]["ASR"]["model"]
         asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
 
     if selected_config["components"]["Translation"]: