import argparse
import io
import os
import tempfile
from time import time
from typing import List
import uvicorn
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from PIL import Image
from pydantic import BaseModel, field_validator
from pydantic_settings import BaseSettings
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, AutoModel, Gemma3ForConditionalGeneration
from IndicTransToolkit import IndicProcessor
import json
import asyncio
from contextlib import asynccontextmanager
import soundfile as sf
import numpy as np
import requests
import logging
from starlette.responses import StreamingResponse
from logging_config import logger  # Assumed external logging config
from tts_config import SPEED, ResponseFormat, config as tts_config  # Assumed external TTS config
import torchaudio
from tenacity import retry, stop_after_attempt, wait_exponential
from torch.cuda.amp import autocast

# Device setup
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32  # Use float16 for speed
logger.info(f"Using device: {device} with dtype: {torch_dtype}")

# Check CUDA availability and version
cuda_available = torch.cuda.is_available()
cuda_version = torch.version.cuda if cuda_available else None
if cuda_available:
    device_idx = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability(device_idx)
    logger.info(f"CUDA version: {cuda_version}, Compute Capability: {capability[0]}.{capability[1]}")
else:
    logger.info("CUDA is not available; falling back to CPU.")

# Settings
class Settings(BaseSettings):
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"

    @field_validator("chat_rate_limit", "speech_rate_limit")
    def validate_rate_limit(cls, v):
        if v.count("/") != 1 or not v.split("/")[0].isdigit():
            raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
        return v

    class Config:
        env_file = ".env"

settings = Settings()
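
# Illustrative .env example for the Settings above (placeholder values, not the actual
# deployment configuration); pydantic-settings maps each field to an environment variable
# of the same name:
#
#   LLM_MODEL_NAME=google/gemma-3-4b-it
#   MAX_TOKENS=512
#   HOST=0.0.0.0
#   PORT=7860
#   CHAT_RATE_LIMIT=100/minute
#   SPEECH_RATE_LIMIT=5/minute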

# Concurrency limiter: allow at most 10 simultaneous GPU-bound requests
# (asyncio.Queue is not an async context manager, so a Semaphore is used instead)
request_semaphore = asyncio.Semaphore(10)

# Configure root logging level from the environment (defaults to INFO)
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))

# LLM Manager with persistent loading and improved caching
class LLMManager:
    def __init__(self, model_name: str, device: str = device):
        self.model_name = model_name
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype
        self.model = None
        self.processor = None
        self.is_loaded = False
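        # Simple in-process response cache keyed by prompt and generation parameters;
        # it is unbounded and cleared only when the model is unloaded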
        self.token_cache = {}
        self.load()  # Load persistently at initialization
        logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

    def load(self):
        if not self.is_loaded:
            try:
                if self.device.type == "cuda":
                    torch.set_float32_matmul_precision('high')
                    logger.info("Enabled TF32 matrix multiplication for improved GPU performance")

                self.model = Gemma3ForConditionalGeneration.from_pretrained(
                    self.model_name,
                    device_map="auto",
                    torch_dtype=self.torch_dtype,  # float16 on GPU, float32 on CPU fallback
                    max_memory={0: "10GiB"}
                ).eval()

                self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
                # Warm-up model
                dummy_input = self.processor("test", return_tensors="pt").to(self.device)
                with torch.no_grad():
                    self.model.generate(**dummy_input, max_new_tokens=10)
                self.is_loaded = True
                logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
            except Exception as e:
                logger.error(f"Failed to load LLM: {str(e)}")
                self.is_loaded = False  # Allow graceful degradation

    def unload(self):
        if self.is_loaded:
            del self.model
            del self.processor
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
            self.is_loaded = False
            self.token_cache.clear()
            logger.info(f"LLM {self.model_name} unloaded")

    async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
        if not self.is_loaded:
            logger.warning("LLM not loaded; attempting reload")
            self.load()
        if not self.is_loaded:
            raise HTTPException(status_code=503, detail="LLM model unavailable")

        # Improved cache key with parameters
        cache_key = f"{prompt}:{max_tokens}:{temperature}"
        if cache_key in self.token_cache:
            logger.info("Using cached response")
            return self.token_cache[cache_key]

        messages_vlm = [
            {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        ]

        try:
            inputs_vlm = self.processor.apply_chat_template(
                messages_vlm,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.device)

            with autocast():  # Mixed precision for speed
                generation = self.model.generate(
                    **inputs_vlm,
                    max_new_tokens=max_tokens,
                    do_sample=True,
                    top_p=0.9,
                    temperature=temperature
                )
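                # Keep only the newly generated tokens by slicing off the prompt length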
                generation = generation[0][inputs_vlm["input_ids"].shape[-1]:]

            response = self.processor.decode(generation, skip_special_tokens=True)
            self.token_cache[cache_key] = response
            logger.info(f"Generated response: {response}")
            return response
        except Exception as e:
            logger.error(f"Error in generation: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

# TTS Manager with file-based synthesis
class TTSManager:
    def __init__(self, device_type=device):
        self.device_type = torch.device(device_type)
        self.model = None
        self.repo_id = "ai4bharat/IndicF5"
        self.load()  # Persistent loading

    def load(self):
        if not self.model:
            logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
            self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
            logger.info("TTS model loaded")

    def unload(self):
        if self.model:
            del self.model
            if self.device_type.type == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
            self.model = None
            logger.info("TTS model unloaded")

    def synthesize(self, text, ref_audio_path, ref_text):
        if not self.model:
            raise ValueError("TTS model not loaded")
        with autocast():  # Mixed precision
            return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)

# Translation Manager with warm-up and error handling
class TranslateManager:
    def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
        self.device_type = torch.device(device_type)
        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)
        if self.model:
            self.warm_up()

    def initialize_model(self, src_lang, tgt_lang, use_distilled=True):
        try:
            if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
                model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
            elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
                model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
            elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
                model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
            else:
                raise ValueError("Invalid language combination")

            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                attn_implementation="flash_attention_2"
            ).to(self.device_type)
            return tokenizer, model
        except Exception as e:
            logger.error(f"Failed to load translation model: {str(e)}")
            return None, None  # Graceful degradation

    def warm_up(self):
        dummy_input = self.tokenizer("test", return_tensors="pt").to(self.device_type)
        with torch.no_grad(), autocast():
            self.model.generate(**dummy_input, max_length=10)
        logger.info("Translation model warmed up")

    def unload(self):
        if self.model:
            del self.model
            del self.tokenizer
            if self.device_type.type == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"Translation GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
            self.model = None
            self.tokenizer = None
            logger.info("Translation model unloaded")

# Model Manager with preloading
class ModelManager:
    def __init__(self, device_type=device, use_distilled=True):
        self.models = {}
        self.device_type = device_type
        self.use_distilled = use_distilled
        self.preload_models()

    def preload_models(self):
        translation_pairs = [
            ('eng_Latn', 'kan_Knda', 'eng_indic'),
            ('kan_Knda', 'eng_Latn', 'indic_eng'),
            ('kan_Knda', 'hin_Deva', 'indic_indic')
        ]
        for src_lang, tgt_lang, key in translation_pairs:
            logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang}...")
            self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)

    def get_model(self, src_lang, tgt_lang):
        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            key = 'eng_indic'
        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
            key = 'indic_eng'
        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            key = 'indic_indic'
        else:
            raise ValueError("Invalid language combination")
        if key not in self.models or not self.models[key].model:
            raise HTTPException(status_code=503, detail=f"Translation model for {key} unavailable")
        return self.models[key]

# ASR Manager with GPU audio processing
class ASRModelManager:
    def __init__(self, device_type=device):
        self.device_type = torch.device(device_type)
        self.model = None
        self.model_language = {"kannada": "kn"}
        self.load()

    def load(self):
        if not self.model:
            logger.info(f"Loading ASR model on {self.device_type}...")
            self.model = AutoModel.from_pretrained(
                "ai4bharat/indic-conformer-600m-multilingual",
                trust_remote_code=True
            ).to(self.device_type)
            logger.info("ASR model loaded")

    def unload(self):
        if self.model:
            del self.model
            if self.device_type.type == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
            self.model = None
            logger.info("ASR model unloaded")

# Global Managers
llm_manager = LLMManager(settings.llm_model_name)
model_manager = ModelManager()
asr_manager = ASRModelManager()
tts_manager = TTSManager()
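# IndicProcessor (IndicTransToolkit) provides IndicTrans2 pre/post-processing:
# language tagging and normalization before tokenization, detokenization afterwards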
ip = IndicProcessor(inference=True)

# TTS Constants
EXAMPLES = [
    {
        "audio_name": "KAN_F (Happy)",
        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ।",
    },
]

# Pydantic Models
class ChatRequest(BaseModel):
    prompt: str
    src_lang: str = "kan_Knda"
    tgt_lang: str = "kan_Knda"

    @field_validator("prompt")
    def prompt_must_be_valid(cls, v):
        if len(v) > 1000:
            raise ValueError("Prompt cannot exceed 1000 characters")
        return v.strip()

class ChatResponse(BaseModel):
    response: str

class KannadaSynthesizeRequest(BaseModel):
    text: str

    @field_validator("text")
    def text_must_be_valid(cls, v):
        if len(v) > 500:
            raise ValueError("Text cannot exceed 500 characters")
        return v.strip()

class TranscriptionResponse(BaseModel):
    text: str

# TTS Functions
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10), reraise=True)
def load_audio_from_url(url: str):
    response = requests.get(url, timeout=10)
    if response.status_code == 200:
        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
        return sample_rate, audio_data
    # reraise=True makes tenacity surface this exception (not a RetryError) once retries are exhausted
    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL after retries")

async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
    # Concurrency is limited by the calling endpoint, which already holds request_semaphore;
    # acquiring it again here could exhaust the permits under load.
    ref_audio_url = None
    for example in EXAMPLES:
        if example["audio_name"] == ref_audio_name:
            ref_audio_url = example["audio_url"]
            if not ref_text:
                ref_text = example["ref_text"]
            break

    if not ref_audio_url:
        raise HTTPException(status_code=400, detail=f"Invalid reference audio name: {ref_audio_name}")
    if not text.strip() or not ref_text.strip():
        raise HTTPException(status_code=400, detail="Text or reference text cannot be empty")

    logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
    sample_rate, audio_data = load_audio_from_url(ref_audio_url)

    # Use a temporary file since IndicF5 expects a reference audio path
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_ref_audio:
        sf.write(temp_ref_audio.name, audio_data, sample_rate, format='WAV')
        temp_ref_audio.flush()
        audio = tts_manager.synthesize(text, temp_ref_audio.name, ref_text)

    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
    output_buffer = io.BytesIO()
    sf.write(output_buffer, audio, 24000, format='WAV')
    output_buffer.seek(0)
    logger.info("Speech synthesis completed")
    return output_buffer

# Lifespan event handler (defined before the app so it can be passed to FastAPI)
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Starting server with preloaded models...")
    yield
    llm_manager.unload()
    tts_manager.unload()
    asr_manager.unload()
    for model in model_manager.models.values():
        model.unload()
    logger.info("Server shutdown complete; all models unloaded")

# FastAPI App
app = FastAPI(
    title="Optimized Dhwani API",
    description="AI Chat API with optimized performance and robustness",
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.middleware("http")
async def add_request_timing(request: Request, call_next):
    start_time = time()
    response = await call_next(request)
    end_time = time()
    duration = end_time - start_time
    logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
    response.headers["X-Response-Time"] = f"{duration:.3f}"
    return response

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Endpoints
@app.post("/v1/speech_to_speech", response_class=StreamingResponse)
async def speech_to_speech(
    request: Request,
    file: UploadFile = File(...),
    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
):
    async with request_semaphore:
        if not tts_manager.model or not asr_manager.model:
            raise HTTPException(status_code=503, detail="TTS or ASR model not loaded")

        audio_data = await file.read()
        if not audio_data:
            raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
        if len(audio_data) > 10 * 1024 * 1024:
            raise HTTPException(status_code=400, detail="Audio file exceeds 10MB limit")

        logger.info(f"Processing speech-to-speech for file: {file.filename} in language: {language}")
        try:
            # Decode the uploaded audio on CPU, then move it to the target device
            # (torchaudio.load's backend argument selects a decoding library, not a device)
            wav, sr = torchaudio.load(io.BytesIO(audio_data))
            wav = torch.mean(wav, dim=0, keepdim=True).to(device)
            target_sample_rate = 16000
            if sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
                wav = resampler(wav)
            with autocast(), torch.no_grad():
                transcription = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
            logger.info(f"Transcribed text: {transcription[:50]}...")

            chat_request = ChatRequest(
                prompt=transcription,
                src_lang="kan_Knda",
                tgt_lang="kan_Knda"
            )
            translate_mgr = model_manager.get_model(chat_request.src_lang, "eng_Latn")
            if translate_mgr.model:
                translated_prompt = await perform_internal_translation(
                    [chat_request.prompt], chat_request.src_lang, "eng_Latn"
                )
                prompt_to_process = translated_prompt[0]
            else:
                prompt_to_process = chat_request.prompt

            response = await llm_manager.generate(prompt_to_process)
            if chat_request.tgt_lang != "eng_Latn":
                translate_mgr = model_manager.get_model("eng_Latn", chat_request.tgt_lang)
                if translate_mgr.model:
                    translated_response = await perform_internal_translation(
                        [response], "eng_Latn", chat_request.tgt_lang
                    )
                    final_response = translated_response[0]
                else:
                    final_response = response
            else:
                final_response = response
            logger.info(f"Processed text: {final_response[:50]}...")

            audio_buffer = await synthesize_speech(tts_manager, final_response, "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
            logger.info("Speech-to-speech processing completed")
            return StreamingResponse(
                audio_buffer,
                media_type="audio/wav",
                headers={"Content-Disposition": "attachment; filename=speech_to_speech_output.wav"}
            )
        except Exception as e:
            logger.error(f"Error in speech-to-speech pipeline: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Speech-to-speech failed: {str(e)}")

@app.post("/v1/chat", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat(request: Request, chat_request: ChatRequest):
    async with request_semaphore:
        logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
        try:
            if chat_request.src_lang != "eng_Latn":
                translate_mgr = model_manager.get_model(chat_request.src_lang, "eng_Latn")
                if translate_mgr.model:
                    translated_prompt = await perform_internal_translation(
                        [chat_request.prompt], chat_request.src_lang, "eng_Latn"
                    )
                    prompt_to_process = translated_prompt[0]
                    logger.info(f"Translated prompt to English: {prompt_to_process}")
                else:
                    prompt_to_process = chat_request.prompt
            else:
                prompt_to_process = chat_request.prompt

            response = await llm_manager.generate(prompt_to_process)
            logger.info(f"Generated English response: {response}")

            if chat_request.tgt_lang != "eng_Latn":
                translate_mgr = model_manager.get_model("eng_Latn", chat_request.tgt_lang)
                if translate_mgr.model:
                    translated_response = await perform_internal_translation(
                        [response], "eng_Latn", chat_request.tgt_lang
                    )
                    final_response = translated_response[0]
                    logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
                else:
                    final_response = response
            else:
                final_response = response
            return ChatResponse(response=final_response)
        except Exception as e:
            logger.error(f"Error in chat: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}")

async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    translate_mgr = model_manager.get_model(src_lang, tgt_lang)
    if not translate_mgr.model:
        raise HTTPException(status_code=503, detail="Translation model unavailable")
    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = translate_mgr.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(device)
    with torch.no_grad(), autocast():
        tokens = translate_mgr.model.generate(**inputs, max_length=256, num_beams=5)
    translations = translate_mgr.tokenizer.batch_decode(tokens, skip_special_tokens=True)
    return ip.postprocess_batch(translations, lang=tgt_lang)

@app.get("/v1/health")
async def health_check():
    # Report GPU memory usage relative to the device's actual total VRAM
    total_vram = torch.cuda.get_device_properties(0).total_memory if cuda_available else 0
    memory_usage = torch.cuda.memory_allocated() / total_vram if cuda_available else 0
    if memory_usage > 0.9:
        logger.warning("GPU memory usage exceeds 90%; consider unloading models")
    status = {
        "status": "healthy",
        "llm_loaded": llm_manager.is_loaded,
        "tts_loaded": bool(tts_manager.model),
        "asr_loaded": bool(asr_manager.model),
        "translation_models": list(model_manager.models.keys()),
        "gpu_memory_usage": f"{memory_usage:.2%}"
    }
    return status

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the FastAPI server.")
    parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
    parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
    args = parser.parse_args()

    # Single worker: uvicorn.run ignores workers > 1 when given an app object (an import
    # string is required), and each extra worker would load its own copy of the GPU models
    uvicorn.run(app, host=args.host, port=args.port)
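
# Illustrative local usage (assumes this module is saved as main.py; hosts, file names and
# payloads below are examples only):
#
#   python main.py --host 0.0.0.0 --port 7860
#
#   # Health check
#   curl http://localhost:7860/v1/health
#
#   # Chat (Kannada in, Kannada out by default)
#   curl -X POST http://localhost:7860/v1/chat \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "ನಮಸ್ಕಾರ", "src_lang": "kan_Knda", "tgt_lang": "kan_Knda"}'
#
#   # Speech-to-speech (uploads a WAV file, returns synthesized WAV audio)
#   curl -X POST "http://localhost:7860/v1/speech_to_speech?language=kannada" \
#        -F "file=@sample_kannada.wav" -o reply.wav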