sachin committed on
Commit
f238ccb
·
1 Parent(s): 665c478

config-based start

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -1
  2. dhwani_config.json +143 -0
  3. src/server/main.py +80 -6
Dockerfile CHANGED
@@ -20,6 +20,7 @@ RUN pip install --upgrade pip setuptools setuptools-rust torch
20
  RUN pip install flash-attn --no-build-isolation
21
 
22
  COPY requirements.txt .
 
23
  #RUN pip install --no-cache-dir torch==2.6.0 torchvision
24
  #RUN pip install --no-cache-dir transformers
25
  RUN pip install --no-cache-dir -r requirements.txt
@@ -35,4 +36,4 @@ USER appuser
35
  EXPOSE 7860
36
 
37
  # Use absolute path for clarity
38
- CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860"]
 
20
  RUN pip install flash-attn --no-build-isolation
21
 
22
  COPY requirements.txt .
23
+ COPY dhwani_config.json .
24
  #RUN pip install --no-cache-dir torch==2.6.0 torchvision
25
  #RUN pip install --no-cache-dir transformers
26
  RUN pip install --no-cache-dir -r requirements.txt
 
36
  EXPOSE 7860
37
 
38
  # Use absolute path for clarity
39
+ CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
dhwani_config.json ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "variant": "base",
3
+ "hardware": "NVIDIA T4",
4
+ "configs": {
5
+ "config_one": {
6
+ "description": "Kannada - Speech to Text",
7
+ "language": "kannada",
8
+ "components": {
9
+ "ASR": {
10
+ "model": "ai4bharat/indic-conformer-600m-multilingual",
11
+ "language_code": "kn",
12
+ "decoding": "rnnt"
13
+ },
14
+ "LLM": {
15
+ "model": "google/gemma-3-1b-it",
16
+ "max_tokens": 512
17
+ },
18
+ "Vision": {
19
+ "model": "moondream2"
20
+ },
21
+ "Translation": [
22
+ {
23
+ "type": "eng_indic",
24
+ "model": "ai4bharat/indictrans2-en-indic-dist-200M",
25
+ "src_lang": "eng_Latn",
26
+ "tgt_lang": "kan_Knda"
27
+ },
28
+ {
29
+ "type": "indic_eng",
30
+ "model": "ai4bharat/indictrans2-indic-en-dist-200M",
31
+ "src_lang": "kan_Knda",
32
+ "tgt_lang": "eng_Latn"
33
+ },
34
+ {
35
+ "type": "indic_indic",
36
+ "model": "ai4bharat/indictrans2-indic-indic-dist-320M",
37
+ "src_lang": "kan_Knda",
38
+ "tgt_lang": "hin_Deva"
39
+ }
40
+ ],
41
+ "TTS": null
42
+ }
43
+ },
44
+ "config_two": {
45
+ "description": "Kannada - Speech to Speech",
46
+ "language": "kannada",
47
+ "components": {
48
+ "ASR": {
49
+ "model": "ai4bharat/indic-conformer-600m-multilingual",
50
+ "language_code": "kn",
51
+ "decoding": "rnnt"
52
+ },
53
+ "LLM": {
54
+ "model": "google/gemma-3-1b-it",
55
+ "max_tokens": 512
56
+ },
57
+ "Vision": {
58
+ "model": "moondream2"
59
+ },
60
+ "Translation": [
61
+ {
62
+ "type": "eng_indic",
63
+ "model": "ai4bharat/indictrans2-en-indic-dist-200M",
64
+ "src_lang": "eng_Latn",
65
+ "tgt_lang": "kan_Knda"
66
+ },
67
+ {
68
+ "type": "indic_eng",
69
+ "model": "ai4bharat/indictrans2-indic-en-dist-200M",
70
+ "src_lang": "kan_Knda",
71
+ "tgt_lang": "eng_Latn"
72
+ },
73
+ {
74
+ "type": "indic_indic",
75
+ "model": "ai4bharat/indictrans2-indic-indic-dist-320M",
76
+ "src_lang": "kan_Knda",
77
+ "tgt_lang": "hin_Deva"
78
+ }
79
+ ],
80
+ "TTS": {
81
+ "model": "ai4bharat/indic-parler-tts",
82
+ "voice": "default_kannada_voice",
83
+ "speed": 1.0,
84
+ "response_format": "wav"
85
+ }
86
+ }
87
+ },
88
+ "config_three": {
89
+ "description": "German - Speech to Text",
90
+ "language": "german",
91
+ "components": {
92
+ "ASR": {
93
+ "model": "openai/whisper",
94
+ "language_code": "de",
95
+ "decoding": "default"
96
+ },
97
+ "LLM": {
98
+ "model": "google/gemma-3-1b-it",
99
+ "max_tokens": 512
100
+ },
101
+ "Vision": {
102
+ "model": "moondream2"
103
+ },
104
+ "Translation": null,
105
+ "TTS": null
106
+ }
107
+ },
108
+ "config_four": {
109
+ "description": "German - Speech to Speech",
110
+ "language": "german",
111
+ "components": {
112
+ "ASR": {
113
+ "model": "openai/whisper",
114
+ "language_code": "de",
115
+ "decoding": "default"
116
+ },
117
+ "LLM": {
118
+ "model": "google/gemma-3-1b-it",
119
+ "max_tokens": 512
120
+ },
121
+ "Vision": {
122
+ "model": "moondream2"
123
+ },
124
+ "Translation": null,
125
+ "TTS": {
126
+ "model": "parler-tts",
127
+ "voice": "default_german_voice",
128
+ "speed": 1.0,
129
+ "response_format": "wav"
130
+ }
131
+ }
132
+ }
133
+ },
134
+ "global_settings": {
135
+ "host": "0.0.0.0",
136
+ "port": 7860,
137
+ "chat_rate_limit": "100/minute",
138
+ "speech_rate_limit": "5/minute",
139
+ "device": "cuda",
140
+ "dtype": "bfloat16",
141
+ "lazy_load": false
142
+ }
143
+ }
src/server/main.py CHANGED
@@ -791,14 +791,39 @@ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(.
791
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
792
  wav = resampler(wav)
793
 
794
- # Perform ASR with CTC decoding
795
- #transcription_ctc = model(wav, "kn", "ctc")
796
-
797
- # Perform ASR with RNNT decoding
798
- transcription_rnnt = model(wav, "kn", "rnnt")
799
 
800
  return JSONResponse(content={"text": transcription_rnnt})
801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
802
 
803
 
804
  class BatchTranscriptionResponse(BaseModel):
@@ -810,5 +835,54 @@ if __name__ == "__main__":
810
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
811
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
812
  parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
 
813
  args = parser.parse_args()
814
- uvicorn.run(app, host=args.host, port=args.port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
791
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
792
  wav = resampler(wav)
793
 
794
+ # Perform ASR with RNNT decoding using the provided language
795
+ transcription_rnnt = model(wav, asr_manager.model_language[language], "rnnt")
 
 
 
796
 
797
  return JSONResponse(content={"text": transcription_rnnt})
798
 
799
@app.post("/v1/speech_to_speech")
async def speech_to_speech(
    request: Request,  # injected by FastAPI; forwarded to chat() below
    file: UploadFile = File(...),
    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
    voice: str = Body(default=config.voice)
) -> StreamingResponse:
    """Speech-to-speech pipeline: ASR -> LLM chat -> TTS.

    Transcribes the uploaded audio, feeds the transcription through the
    chat endpoint, and synthesizes the response as streaming audio.
    """
    # Step 1: Transcribe audio to text.
    # transcribe_audio returns a JSONResponse, so the text must be parsed
    # out of its serialized body — the original accessed a nonexistent
    # `.text` attribute, which raised AttributeError at runtime.
    transcription_response = await transcribe_audio(file, language)
    transcription_text = json.loads(transcription_response.body)["text"]
    logger.info(f"Transcribed text: {transcription_text}")

    # Step 2: Process text with the chat endpoint.
    # NOTE(review): the Translation configs use FLORES-style codes such as
    # "kan_Knda"; the original f"{language}_Knda" produced "kannada_Knda".
    # TODO: derive this mapping from the selected config instead of a table.
    flores_codes = {"kannada": "kan_Knda", "hindi": "hin_Deva", "german": "deu_Latn"}
    lang_code = flores_codes.get(language, f"{language}_Knda")
    chat_request = ChatRequest(
        prompt=transcription_text,
        src_lang=lang_code,
        tgt_lang=lang_code,
    )
    # Forward the real incoming request — the original built a bare
    # Request(), which Starlette cannot construct without an ASGI scope.
    processed_text = await chat(request, chat_request)
    logger.info(f"Processed text: {processed_text.response}")

    # Step 3: Convert processed text to speech.
    audio_response = await generate_audio(
        input=processed_text.response,
        voice=voice,
        model=tts_config.model,
        response_format=config.response_format,
        speed=SPEED,
    )
    return audio_response
827
 
828
 
829
  class BatchTranscriptionResponse(BaseModel):
 
835
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
# Host/port default to None so "explicitly given on the CLI" is
# distinguishable from "not given". The original defaulted them to the
# pre-config settings values and then compared against settings that had
# already been overwritten from the config file, which made the
# "override if provided" check a tautology (both branches returned the
# same value) and silently ignored the config file's host/port.
parser.add_argument("--port", type=int, default=None,
                    help="Port to run the server on (overrides config).")
parser.add_argument("--host", type=str, default=None,
                    help="Host to run the server on (overrides config).")
parser.add_argument("--config", type=str, default="config_one",
                    help="Configuration to use (e.g., config_one, config_two, config_three, config_four)")
args = parser.parse_args()

def load_config(config_path="dhwani_config.json"):
    """Load the JSON file describing model configs and global settings."""
    with open(config_path, "r") as f:
        return json.load(f)

config_data = load_config()
if args.config not in config_data["configs"]:
    raise ValueError(
        f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}"
    )

selected_config = config_data["configs"][args.config]
global_settings = config_data["global_settings"]

# Update settings from the selected config and the file's global section.
settings.llm_model_name = selected_config["components"]["LLM"]["model"]
settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
settings.host = global_settings["host"]
settings.port = global_settings["port"]
settings.chat_rate_limit = global_settings["chat_rate_limit"]
settings.speech_rate_limit = global_settings["speech_rate_limit"]

# Initialize the LLM manager with the selected LLM model.
llm_manager = LLMManager(settings.llm_model_name)

# Initialize the ASR model if the config names one.
if selected_config["components"]["ASR"]:
    asr_model_name = selected_config["components"]["ASR"]["model"]
    model = AutoModel.from_pretrained(asr_model_name, trust_remote_code=True)
    asr_manager.model_language[selected_config["language"]] = \
        selected_config["components"]["ASR"]["language_code"]

# Initialize the TTS model if the config names one.
if selected_config["components"]["TTS"]:
    tts_model_name = selected_config["components"]["TTS"]["model"]
    tts_config.model = tts_model_name  # point tts_config at the selected model
    tts_model_manager.get_or_load_model(tts_model_name)

# Pre-load every translation model pair listed in the config.
if selected_config["components"]["Translation"]:
    for translation_config in selected_config["components"]["Translation"]:
        model_manager.get_model(translation_config["src_lang"],
                                translation_config["tgt_lang"])

# Precedence: explicit CLI argument > config-file global settings.
host = args.host if args.host is not None else settings.host
port = args.port if args.port is not None else settings.port

uvicorn.run(app, host=host, port=port)