sachin committed
Commit 1936ef7 · 1 Parent(s): 843c466

test-changes

Files changed (1): src/server/main.py (+169 -240)
src/server/main.py CHANGED
@@ -3,7 +3,6 @@ import io
import os
from time import time
from typing import List
-
import tempfile
import uvicorn
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
@@ -15,31 +14,18 @@ from pydantic_settings import BaseSettings
from slowapi import Limiter
from slowapi.util import get_remote_address
import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
from IndicTransToolkit import IndicProcessor
-
-from logging_config import logger
-from tts_config import SPEED, ResponseFormat, config as tts_config
-from gemma_llm import LLMManager
-# from auth import get_api_key, settings as auth_settings
-
-import time
+import json
+import asyncio
from contextlib import asynccontextmanager
-from typing import Annotated, Any, OrderedDict, List
-import zipfile
import soundfile as sf
-import torch
-from fastapi import Body, FastAPI, HTTPException, Response
-from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
import numpy as np
-from config import SPEED, ResponseFormat, config
-from logger import logger
-import uvicorn
-import argparse
-from fastapi.responses import RedirectResponse, StreamingResponse
-import io
-import os
-import logging
+import requests
+from starlette.responses import StreamingResponse
+from logging_config import logger
+from tts_config import SPEED, ResponseFormat, config as tts_config
+from gemma_llm import LLMManager  # Assuming this is your custom LLMManager

# Device setup
if torch.cuda.is_available():
@@ -63,40 +49,29 @@ if torch.cuda.is_available():
else:
    print("CUDA is not available on this system.")

-app = FastAPI(
-    title="Dhwani API",
-    description="AI Chat API supporting Indian languages",
-    version="1.0.0",
-    redirect_slashes=False,
-    #lifespan=lifespan
-)
-
-def chunk_text(text, chunk_size):
-    words = text.split()
-    chunks = []
-    for i in range(0, len(words), chunk_size):
-        chunks.append(' '.join(words[i:i + chunk_size]))
-    return chunks
-
-import io
-import torch
-import requests
-import tempfile
-import numpy as np
-import soundfile as sf
-from fastapi import FastAPI, HTTPException
-from transformers import AutoModel
-from pydantic import BaseModel
-from typing import Optional
-from starlette.responses import StreamingResponse
+# Settings
+class Settings(BaseSettings):
+    llm_model_name: str = "google/gemma-3-4b-it"
+    max_tokens: int = 512
+    host: str = "0.0.0.0"
+    port: int = 7860
+    chat_rate_limit: str = "100/minute"
+    speech_rate_limit: str = "5/minute"
+
+    @field_validator("chat_rate_limit", "speech_rate_limit")
+    def validate_rate_limit(cls, v):
+        if not v.count("/") == 1 or not v.split("/")[0].isdigit():
+            raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
+        return v
+
+    class Config:
+        env_file = ".env"
+
+settings = Settings()

+# TTS Setup
tts_repo_id = "ai4bharat/IndicF5"
-tts_model = AutoModel.from_pretrained(tts_repo_id, trust_remote_code=True)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print("Device:", device)
-tts_model = tts_model.to(device)
+tts_model = AutoModel.from_pretrained(tts_repo_id, trust_remote_code=True).to(device)

EXAMPLES = [
    {
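Review note: the new one-liner `tts_model = AutoModel.from_pretrained(tts_repo_id, trust_remote_code=True).to(device)` relies on a module-level `device` being bound by the "# Device setup" block near the top of the file, since the old `device = torch.device(...)` assignment is deleted in this hunk. That block is outside the diff context; a minimal sketch of what it presumably provides:

```python
# Sketch only: the actual "# Device setup" block is not shown in this diff.
# The point is that it must bind a module-level `device` for `.to(device)`.
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("CUDA is not available on this system.")
```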
@@ -107,18 +82,16 @@ EXAMPLES = [
    },
]

-
-# Pydantic model for request body
+# Pydantic models for TTS
class SynthesizeRequest(BaseModel):
-    text: str  # Text to synthesize (expected in Kannada)
-    ref_audio_name: str  # Dropdown of audio names from EXAMPLES
-    ref_text: Optional[str] = None  # Optional, defaults to example ref_text if not provided
+    text: str
+    ref_audio_name: str
+    ref_text: str = None

class KannadaSynthesizeRequest(BaseModel):
-    text: str  # Text to synthesize (must be in Kannada)
-
+    text: str

-# Function to load audio from URL
+# TTS Functions
def load_audio_from_url(url: str):
    response = requests.get(url)
    if response.status_code == 200:
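Review note: `ref_text: str = None` replaces `ref_text: Optional[str] = None`, and the `from typing import Optional` import is deleted elsewhere in this commit. Pydantic does not validate field defaults unless asked to, so this works at runtime, but the annotation now contradicts the default. A type-checker-clean equivalent:

```python
# Equivalent model with an honest annotation; runtime behavior is unchanged.
from typing import Optional

from pydantic import BaseModel

class SynthesizeRequest(BaseModel):
    text: str
    ref_audio_name: str
    ref_text: Optional[str] = None
```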
@@ -126,9 +99,7 @@ def load_audio_from_url(url: str):
        return sample_rate, audio_data
    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")

-# Function to synthesize speech
def synthesize_speech(text: str, ref_audio_name: str, ref_text: str):
-    # Find the matching example
    ref_audio_url = None
    for example in EXAMPLES:
        if example["audio_name"] == ref_audio_name:
@@ -139,58 +110,25 @@ def synthesize_speech(text: str, ref_audio_name: str, ref_text: str):

    if not ref_audio_url:
        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
-
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
-
    if not ref_text or not ref_text.strip():
        raise HTTPException(status_code=400, detail="Reference text cannot be empty.")

-    # Load reference audio from URL
    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
-
-    # Save reference audio to a temporary file
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
        temp_audio.flush()
-
-        # Generate speech
        audio = tts_model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)

-    # Normalize output
    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
-
-    # Save generated audio to a BytesIO buffer
    buffer = io.BytesIO()
    sf.write(buffer, audio, 24000, format='WAV')
    buffer.seek(0)
-
    return buffer

-@app.post("/audio/speech", response_class=StreamingResponse)
-async def synthesize_kannada(request: KannadaSynthesizeRequest):
-    # Use the Kannada example as fixed reference
-    kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
-
-    if not request.text.strip():
-        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
-
-    # Use the fixed Kannada reference audio and text
-    audio_buffer = synthesize_speech(
-        text=request.text,
-        ref_audio_name="KAN_F (Happy)",
-        ref_text=kannada_example["ref_text"]
-    )
-
-    return StreamingResponse(
-        audio_buffer,
-        media_type="audio/wav",
-        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
-    )
-
-
-# Supported language codes
+# Supported languages
SUPPORTED_LANGUAGES = {
    "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
    "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
@@ -201,43 +139,9 @@ SUPPORTED_LANGUAGES = {
    "por_Latn", "rus_Cyrl", "pol_Latn"
}

-class Settings(BaseSettings):
-    llm_model_name: str = "google/gemma-3-4b-it"
-    max_tokens: int = 512
-    host: str = "0.0.0.0"
-    port: int = 7860
-    chat_rate_limit: str = "100/minute"
-    speech_rate_limit: str = "5/minute"
-
-    @field_validator("chat_rate_limit", "speech_rate_limit")
-    def validate_rate_limit(cls, v):
-        if not v.count("/") == 1 or not v.split("/")[0].isdigit():
-            raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
-        return v
-
-    class Config:
-        env_file = ".env"
-
-settings = Settings()
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=False,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-limiter = Limiter(key_func=get_remote_address)
-app.state.limiter = limiter
-
-llm_manager = LLMManager(settings.llm_model_name)
-
-# Translation Manager and Model Manager
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
+# Translation Manager
class TranslateManager:
-    def __init__(self, src_lang, tgt_lang, device_type=DEVICE, use_distilled=True):
+    def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
        self.device_type = device_type
        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)

@@ -258,55 +162,84 @@ class TranslateManager:
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2"
        ).to(self.device_type)
-
        model = torch.compile(model, mode="reduce-overhead")
        print("Model compiled with torch.compile")
        return tokenizer, model

class ModelManager:
-    def __init__(self, device_type=DEVICE, use_distilled=True, is_lazy_loading=False):
-        self.models: dict[str, TranslateManager] = {}
+    def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
+        self.models = {}
        self.device_type = device_type
        self.use_distilled = use_distilled
        self.is_lazy_loading = is_lazy_loading
-        if not is_lazy_loading:
-            self.preload_models()

-    def preload_models(self):
-        self.models['eng_indic'] = TranslateManager('eng_Latn', 'kan_Knda', self.device_type, self.use_distilled)
-        self.models['indic_eng'] = TranslateManager('kan_Knda', 'eng_Latn', self.device_type, self.use_distilled)
-        self.models['indic_indic'] = TranslateManager('kan_Knda', 'hin_Deva', self.device_type, self.use_distilled)
-
-    def get_model(self, src_lang, tgt_lang) -> TranslateManager:
+    async def load_model(self, src_lang, tgt_lang, key):
+        logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
-            key = 'eng_indic'
+            model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
-            key = 'indic_eng'
-        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
-            key = 'indic_indic'
+            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
        else:
-            raise ValueError("Invalid language combination: English to English translation is not supported.")
+            model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        model = await asyncio.to_thread(
+            AutoModelForSeq2SeqLM.from_pretrained,
+            model_name,
+            trust_remote_code=True,
+            torch_dtype=torch.float16,
+            attn_implementation="flash_attention_2"
+        )
+        model = model.to(self.device_type)
+        model = torch.compile(model, mode="reduce-overhead")
+        self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
+        logger.info(f"Loaded translation model for {key}")

+    def get_model(self, src_lang, tgt_lang):
+        key = self._get_model_key(src_lang, tgt_lang)
        if key not in self.models:
            if self.is_lazy_loading:
-                if key == 'eng_indic':
-                    self.models[key] = TranslateManager('eng_Latn', 'kan_Knda', self.device_type, self.use_distilled)
-                elif key == 'indic_eng':
-                    self.models[key] = TranslateManager('kan_Knda', 'eng_Latn', self.device_type, self.use_distilled)
-                elif key == 'indic_indic':
-                    self.models[key] = TranslateManager('kan_Knda', 'hin_Deva', self.device_type, self.use_distilled)
+                asyncio.create_task(self.load_model(src_lang, tgt_lang, key))
            else:
                raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
-        return self.models[key]
+        return self.models.get(key)

-ip = IndicProcessor(inference=True)
+    def _get_model_key(self, src_lang, tgt_lang):
+        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
+            return 'eng_indic'
+        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
+            return 'indic_eng'
+        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
+            return 'indic_indic'
+        raise ValueError("Invalid language combination")
+
+# ASR Manager
+class ASRModelManager:
+    def __init__(self, device_type="cuda"):
+        self.device_type = device_type
+        self.model = None
+        self.model_language = {"kannada": "kn"}
+
+    async def load(self):
+        logger.info("Loading ASR model...")
+        self.model = await asyncio.to_thread(
+            AutoModel.from_pretrained,
+            "ai4bharat/indic-conformer-600m-multilingual",
+            trust_remote_code=True
+        )
+        logger.info("ASR model loaded")
+
+# Global Managers
+llm_manager = LLMManager(settings.llm_model_name)
model_manager = ModelManager()
+asr_manager = ASRModelManager()
+ip = IndicProcessor(inference=True)

# Pydantic Models
class ChatRequest(BaseModel):
    prompt: str
-    src_lang: str = "kan_Knda"  # Default to Kannada
-    tgt_lang: str = "kan_Knda"  # Default to Kannada
+    src_lang: str = "kan_Knda"
+    tgt_lang: str = "kan_Knda"

    @field_validator("prompt")
    def prompt_must_be_valid(cls, v):
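Review note: with `is_lazy_loading=True`, `get_model` fires `load_model` off as a background task and immediately returns `self.models.get(key)`, which is `None` until that task finishes, so the `/translate` dependency can hand endpoints a `None` manager. `load_model` also appears to load the weights twice: once via `asyncio.to_thread` and then again inside `TranslateManager.__init__`. A sketch of an awaitable lookup that sidesteps the first problem (the function name is hypothetical):

```python
import asyncio

async def get_model_or_load(manager, src_lang: str, tgt_lang: str):
    # Illustrative alternative: await the lazy load instead of fire-and-forget.
    # `manager` is a ModelManager instance as defined in this diff.
    key = manager._get_model_key(src_lang, tgt_lang)
    if key not in manager.models:
        if not manager.is_lazy_loading:
            raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
        await manager.load_model(src_lang, tgt_lang, key)  # blocks this request, not the event loop
    return manager.models[key]
```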
@@ -331,11 +264,72 @@ class TranslationRequest(BaseModel):
class TranslationResponse(BaseModel):
    translations: List[str]

-# Dependency to get TranslateManager
+class TranscriptionResponse(BaseModel):
+    text: str
+
+# Dependency
def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
    return model_manager.get_model(src_lang, tgt_lang)

-# Internal Translation Endpoint
+# Lifespan Event Handler
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    async def load_all_models():
+        tasks = [
+            asyncio.create_task(llm_manager.load()),
+            asyncio.create_task(asr_manager.load()),
+            asyncio.create_task(model_manager.load_model('eng_Latn', 'kan_Knda', 'eng_indic')),
+            asyncio.create_task(model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng')),
+            asyncio.create_task(model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic')),
+        ]
+        await asyncio.gather(*tasks)
+        logger.info("All models loaded successfully")
+
+    logger.info("Starting model loading in background...")
+    asyncio.create_task(load_all_models())
+    yield
+    await llm_manager.unload()
+    logger.info("Server shutdown complete")
+
+# FastAPI App
+app = FastAPI(
+    title="Dhwani API",
+    description="AI Chat API supporting Indian languages",
+    version="1.0.0",
+    redirect_slashes=False,
+    lifespan=lifespan
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=False,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+limiter = Limiter(key_func=get_remote_address)
+app.state.limiter = limiter
+
+# API Endpoints
+@app.post("/audio/speech", response_class=StreamingResponse)
+async def synthesize_kannada(request: KannadaSynthesizeRequest):
+    kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
+    if not request.text.strip():
+        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
+
+    audio_buffer = synthesize_speech(
+        text=request.text,
+        ref_audio_name="KAN_F (Happy)",
+        ref_text=kannada_example["ref_text"]
+    )
+
+    return StreamingResponse(
+        audio_buffer,
+        media_type="audio/wav",
+        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
+    )
+
@app.post("/translate", response_model=TranslationResponse)
async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
    input_sentences = request.sentences
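For reference, a minimal client for the relocated `/audio/speech` endpoint; the host and port are assumptions based on the default `Settings` above:

```python
import requests

# Assumes the server is running locally on the default port 7860.
resp = requests.post(
    "http://localhost:7860/audio/speech",
    json={"text": "ನಮಸ್ಕಾರ"},  # Kannada text, per the fixed KAN_F (Happy) reference voice
    timeout=300,
)
resp.raise_for_status()
with open("synthesized_kannada_speech.wav", "wb") as f:
    f.write(resp.content)
```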
@@ -346,7 +340,6 @@ async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
        raise HTTPException(status_code=400, detail="Input sentences are required")

    batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
-
    inputs = translate_manager.tokenizer(
        batch,
        truncation=True,
@@ -375,14 +368,12 @@ async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
    translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
    return TranslationResponse(translations=translations)

-# Helper function to perform internal translation
async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    translate_manager = model_manager.get_model(src_lang, tgt_lang)
    request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    response = await translate(request, translate_manager)
    return response.translations

-# API Endpoints
@app.get("/v1/health")
async def health_check():
    return {"status": "healthy", "model": settings.llm_model_name}
@@ -395,7 +386,7 @@ async def home():
async def unload_all_models():
    try:
        logger.info("Starting to unload all models...")
-        llm_manager.unload()
+        await llm_manager.unload()
        logger.info("All models unloaded successfully")
        return {"status": "success", "message": "All models unloaded"}
    except Exception as e:
@@ -406,7 +397,7 @@ async def unload_all_models():
async def load_all_models():
    try:
        logger.info("Starting to load all models...")
-        llm_manager.load()
+        await llm_manager.load()
        logger.info("All models loaded successfully")
        return {"status": "success", "message": "All models loaded"}
    except Exception as e:
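Review note: both endpoints now `await` the manager calls, so `gemma_llm.LLMManager` must expose `load` and `unload` as coroutines (the lifespan handler awaits them too). The class below is a hypothetical sketch of that assumed interface, not the actual `gemma_llm` implementation:

```python
import asyncio

class LLMManagerSketch:
    # Hypothetical stand-in for gemma_llm.LLMManager's assumed interface.
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = None

    async def load(self):
        # Run the blocking weight load off-thread so the event loop stays free.
        self.model = await asyncio.to_thread(self._load_weights)

    async def unload(self):
        self.model = None  # drop the reference; CUDA cache cleanup omitted

    def _load_weights(self):
        raise NotImplementedError  # heavy blocking load would go here
```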
@@ -596,57 +587,10 @@ async def chat_v2(
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

-class TranscriptionResponse(BaseModel):
-    text: str
-
-class ASRModelManager:
-    def __init__(self, device_type="cuda"):
-        self.device_type = device_type
-        self.model_language = {
-            "kannada": "kn"
-        }
-        '''
-        self.model_language = {
-            "kannada": "kn", "hindi": "hi", "malayalam": "ml", "assamese": "as", "bengali": "bn",
-            "bodo": "brx", "dogri": "doi", "gujarati": "gu", "kashmiri": "ks", "konkani": "kok",
-            "maithili": "mai", "manipuri": "mni", "marathi": "mr", "nepali": "ne", "odia": "or",
-            "punjabi": "pa", "sanskrit": "sa", "santali": "sat", "sindhi": "sd", "tamil": "ta",
-            "telugu": "te", "urdu": "ur"
-        }
-        '''
-
-from fastapi import FastAPI, UploadFile
-import torch
-import torchaudio
-from transformers import AutoModel
-import argparse
-import uvicorn
-from pydantic import BaseModel
-from pydub import AudioSegment
-from fastapi import FastAPI, File, UploadFile, HTTPException, Query
-from fastapi.responses import RedirectResponse, JSONResponse
-from typing import List
-
-# Load the model
-model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
-
-asr_manager = ASRModelManager()
-
-# Language to script mapping
-LANGUAGE_TO_SCRIPT = {
-    "kannada": "kan_Knda"
-}
-'''
-LANGUAGE_TO_SCRIPT = {
-    "kannada": "kan_Knda", "hindi": "hin_Deva", "malayalam": "mal_Mlym", "tamil": "tam_Taml",
-    "telugu": "tel_Telu", "assamese": "asm_Beng", "bengali": "ben_Beng", "gujarati": "guj_Gujr",
-    "marathi": "mar_Deva", "odia": "ory_Orya", "punjabi": "pan_Guru", "urdu": "urd_Arab",
-    # Add more as needed
-}
-'''
-
@app.post("/transcribe/", response_model=TranscriptionResponse)
async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
+    if not asr_manager.model:
+        raise HTTPException(status_code=503, detail="ASR model still loading, please try again later")
    try:
        wav, sr = torchaudio.load(file.file)
        wav = torch.mean(wav, dim=0, keepdim=True)
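A quick way to exercise the new 503 guard and the endpoint itself; the URL and the sample filename are assumptions:

```python
import requests

with open("sample_kannada.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/transcribe/",
        params={"language": "kannada"},
        files={"file": ("sample_kannada.wav", f, "audio/wav")},
    )

# Expect 503 while the ASR model is still loading in the background, 200 after.
print(resp.status_code, resp.text)
```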
@@ -654,51 +598,45 @@ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
        if sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
            wav = resampler(wav)
-        transcription_rnnt = model(wav, asr_manager.model_language[language], "rnnt")
+        transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
        return TranscriptionResponse(text=transcription_rnnt)
    except Exception as e:
        logger.error(f"Error in transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
+
@app.post("/v1/speech_to_speech")
async def speech_to_speech(
-    request: Request,  # Inject Request object from FastAPI
+    request: Request,
    file: UploadFile = File(...),
    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
) -> StreamingResponse:
-    # Step 1: Transcribe audio to text
    transcription = await transcribe_audio(file, language)
    logger.info(f"Transcribed text: {transcription.text}")

-    # Step 2: Process text with chat endpoint
    chat_request = ChatRequest(
        prompt=transcription.text,
-        src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),  # Dynamic script mapping
+        src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
        tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
    )
-    processed_text = await chat(request, chat_request)  # Pass the injected request
+    processed_text = await chat(request, chat_request)
    logger.info(f"Processed text: {processed_text.response}")

    voice_request = KannadaSynthesizeRequest(text=processed_text.response)
-
-    # Step 3: Convert processed text to speech
-    audio_response = await synthesize_kannada(
-        voice_request
-    )
+    audio_response = await synthesize_kannada(voice_request)
    return audio_response

-class BatchTranscriptionResponse(BaseModel):
-    transcriptions: List[str]
-
-import json
+LANGUAGE_TO_SCRIPT = {
+    "kannada": "kan_Knda"
+}

+# Main Execution
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the FastAPI server.")
    parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
    parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
-    parser.add_argument("--config", type=str, default="config_one", help="Configuration to use (e.g., config_one, config_two, config_three, config_four)")
+    parser.add_argument("--config", type=str, default="config_one", help="Configuration to use")
    args = parser.parse_args()

-    # Load the JSON configuration file
    def load_config(config_path="dhwani_config.json"):
        with open(config_path, "r") as f:
            return json.load(f)
@@ -710,7 +648,6 @@ if __name__ == "__main__":
    selected_config = config_data["configs"][args.config]
    global_settings = config_data["global_settings"]

-    # Update settings based on selected config
    settings.llm_model_name = selected_config["components"]["LLM"]["model"]
    settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
    settings.host = global_settings["host"]
@@ -718,27 +655,19 @@ if __name__ == "__main__":
    settings.chat_rate_limit = global_settings["chat_rate_limit"]
    settings.speech_rate_limit = global_settings["speech_rate_limit"]

-    # Initialize LLMManager with the selected LLM model
    llm_manager = LLMManager(settings.llm_model_name)

-    # Initialize ASR model if present in config
    if selected_config["components"]["ASR"]:
        asr_model_name = selected_config["components"]["ASR"]["model"]
-        model = AutoModel.from_pretrained(asr_model_name, trust_remote_code=True)
        asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]

-
-
-    # Initialize Translation models - load all specified models
    if selected_config["components"]["Translation"]:
        for translation_config in selected_config["components"]["Translation"]:
            src_lang = translation_config["src_lang"]
            tgt_lang = translation_config["tgt_lang"]
-            model_manager.get_model(src_lang, tgt_lang)
+            asyncio.create_task(model_manager.load_model(src_lang, tgt_lang, model_manager._get_model_key(src_lang, tgt_lang)))

-    # Override host and port from command line arguments if provided
    host = args.host if args.host != settings.host else settings.host
    port = args.port if args.port != settings.port else settings.port

-    # Run the server
    uvicorn.run(app, host=host, port=port)
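Review note: in the `__main__` block, `asyncio.create_task(...)` runs before uvicorn has started an event loop, so it raises `RuntimeError: no running event loop`; the lifespan handler already schedules these loads once the server is up, which makes this preload redundant anyway. If an explicit preload is wanted, a hedged sketch that drives it to completion on a temporary loop before `uvicorn.run`:

```python
import asyncio

# Illustrative only: finish the preloads before uvicorn starts its own loop.
async def preload_translation_models(pairs):
    await asyncio.gather(*(
        model_manager.load_model(src, tgt, model_manager._get_model_key(src, tgt))
        for src, tgt in pairs
    ))

pairs = [(c["src_lang"], c["tgt_lang"])
         for c in selected_config["components"]["Translation"]]
asyncio.run(preload_translation_models(pairs))
```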
 