sachin
commited on
Commit
·
f238ccb
1
Parent(s):
665c478
config-based start
Browse files- Dockerfile +2 -1
- dhwani_config.json +143 -0
- src/server/main.py +80 -6
Dockerfile
CHANGED
@@ -20,6 +20,7 @@ RUN pip install --upgrade pip setuptools setuptools-rust torch
|
|
20 |
RUN pip install flash-attn --no-build-isolation
|
21 |
|
22 |
COPY requirements.txt .
|
|
|
23 |
#RUN pip install --no-cache-dir torch==2.6.0 torchvision
|
24 |
#RUN pip install --no-cache-dir transformers
|
25 |
RUN pip install --no-cache-dir -r requirements.txt
|
@@ -35,4 +36,4 @@ USER appuser
|
|
35 |
EXPOSE 7860
|
36 |
|
37 |
# Use absolute path for clarity
|
38 |
-
CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
20 |
RUN pip install flash-attn --no-build-isolation
|
21 |
|
22 |
COPY requirements.txt .
|
23 |
+
COPY dhwani_config.json .
|
24 |
#RUN pip install --no-cache-dir torch==2.6.0 torchvision
|
25 |
#RUN pip install --no-cache-dir transformers
|
26 |
RUN pip install --no-cache-dir -r requirements.txt
|
|
|
36 |
EXPOSE 7860
|
37 |
|
38 |
# Use absolute path for clarity
|
39 |
+
CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
|
dhwani_config.json
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"variant": "base",
|
3 |
+
"hardware": "NVIDIA T4",
|
4 |
+
"configs": {
|
5 |
+
"config_one": {
|
6 |
+
"description": "Kannada - Speech to Text",
|
7 |
+
"language": "kannada",
|
8 |
+
"components": {
|
9 |
+
"ASR": {
|
10 |
+
"model": "ai4bharat/indic-conformer-600m-multilingual",
|
11 |
+
"language_code": "kn",
|
12 |
+
"decoding": "rnnt"
|
13 |
+
},
|
14 |
+
"LLM": {
|
15 |
+
"model": "google/gemma-3-1b-it",
|
16 |
+
"max_tokens": 512
|
17 |
+
},
|
18 |
+
"Vision": {
|
19 |
+
"model": "moondream2"
|
20 |
+
},
|
21 |
+
"Translation": [
|
22 |
+
{
|
23 |
+
"type": "eng_indic",
|
24 |
+
"model": "ai4bharat/indictrans2-en-indic-dist-200M",
|
25 |
+
"src_lang": "eng_Latn",
|
26 |
+
"tgt_lang": "kan_Knda"
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"type": "indic_eng",
|
30 |
+
"model": "ai4bharat/indictrans2-indic-en-dist-200M",
|
31 |
+
"src_lang": "kan_Knda",
|
32 |
+
"tgt_lang": "eng_Latn"
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"type": "indic_indic",
|
36 |
+
"model": "ai4bharat/indictrans2-indic-indic-dist-320M",
|
37 |
+
"src_lang": "kan_Knda",
|
38 |
+
"tgt_lang": "hin_Deva"
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"TTS": null
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"config_two": {
|
45 |
+
"description": "Kannada - Speech to Speech",
|
46 |
+
"language": "kannada",
|
47 |
+
"components": {
|
48 |
+
"ASR": {
|
49 |
+
"model": "ai4bharat/indic-conformer-600m-multilingual",
|
50 |
+
"language_code": "kn",
|
51 |
+
"decoding": "rnnt"
|
52 |
+
},
|
53 |
+
"LLM": {
|
54 |
+
"model": "google/gemma-3-1b-it",
|
55 |
+
"max_tokens": 512
|
56 |
+
},
|
57 |
+
"Vision": {
|
58 |
+
"model": "moondream2"
|
59 |
+
},
|
60 |
+
"Translation": [
|
61 |
+
{
|
62 |
+
"type": "eng_indic",
|
63 |
+
"model": "ai4bharat/indictrans2-en-indic-dist-200M",
|
64 |
+
"src_lang": "eng_Latn",
|
65 |
+
"tgt_lang": "kan_Knda"
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"type": "indic_eng",
|
69 |
+
"model": "ai4bharat/indictrans2-indic-en-dist-200M",
|
70 |
+
"src_lang": "kan_Knda",
|
71 |
+
"tgt_lang": "eng_Latn"
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"type": "indic_indic",
|
75 |
+
"model": "ai4bharat/indictrans2-indic-indic-dist-320M",
|
76 |
+
"src_lang": "kan_Knda",
|
77 |
+
"tgt_lang": "hin_Deva"
|
78 |
+
}
|
79 |
+
],
|
80 |
+
"TTS": {
|
81 |
+
"model": "ai4bharat/indic-parler-tts",
|
82 |
+
"voice": "default_kannada_voice",
|
83 |
+
"speed": 1.0,
|
84 |
+
"response_format": "wav"
|
85 |
+
}
|
86 |
+
}
|
87 |
+
},
|
88 |
+
"config_three": {
|
89 |
+
"description": "German - Speech to Text",
|
90 |
+
"language": "german",
|
91 |
+
"components": {
|
92 |
+
"ASR": {
|
93 |
+
"model": "openai/whisper",
|
94 |
+
"language_code": "de",
|
95 |
+
"decoding": "default"
|
96 |
+
},
|
97 |
+
"LLM": {
|
98 |
+
"model": "google/gemma-3-1b-it",
|
99 |
+
"max_tokens": 512
|
100 |
+
},
|
101 |
+
"Vision": {
|
102 |
+
"model": "moondream2"
|
103 |
+
},
|
104 |
+
"Translation": null,
|
105 |
+
"TTS": null
|
106 |
+
}
|
107 |
+
},
|
108 |
+
"config_four": {
|
109 |
+
"description": "German - Speech to Speech",
|
110 |
+
"language": "german",
|
111 |
+
"components": {
|
112 |
+
"ASR": {
|
113 |
+
"model": "openai/whisper",
|
114 |
+
"language_code": "de",
|
115 |
+
"decoding": "default"
|
116 |
+
},
|
117 |
+
"LLM": {
|
118 |
+
"model": "google/gemma-3-1b-it",
|
119 |
+
"max_tokens": 512
|
120 |
+
},
|
121 |
+
"Vision": {
|
122 |
+
"model": "moondream2"
|
123 |
+
},
|
124 |
+
"Translation": null,
|
125 |
+
"TTS": {
|
126 |
+
"model": "parler-tts",
|
127 |
+
"voice": "default_german_voice",
|
128 |
+
"speed": 1.0,
|
129 |
+
"response_format": "wav"
|
130 |
+
}
|
131 |
+
}
|
132 |
+
}
|
133 |
+
},
|
134 |
+
"global_settings": {
|
135 |
+
"host": "0.0.0.0",
|
136 |
+
"port": 7860,
|
137 |
+
"chat_rate_limit": "100/minute",
|
138 |
+
"speech_rate_limit": "5/minute",
|
139 |
+
"device": "cuda",
|
140 |
+
"dtype": "bfloat16",
|
141 |
+
"lazy_load": false
|
142 |
+
}
|
143 |
+
}
|
src/server/main.py
CHANGED
@@ -791,14 +791,39 @@ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(.
|
|
791 |
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
|
792 |
wav = resampler(wav)
|
793 |
|
794 |
-
# Perform ASR with
|
795 |
-
|
796 |
-
|
797 |
-
# Perform ASR with RNNT decoding
|
798 |
-
transcription_rnnt = model(wav, "kn", "rnnt")
|
799 |
|
800 |
return JSONResponse(content={"text": transcription_rnnt})
|
801 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
802 |
|
803 |
|
804 |
class BatchTranscriptionResponse(BaseModel):
|
@@ -810,5 +835,54 @@ if __name__ == "__main__":
|
|
810 |
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
|
811 |
parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
|
812 |
parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
|
|
|
813 |
args = parser.parse_args()
|
814 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
791 |
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
|
792 |
wav = resampler(wav)
|
793 |
|
794 |
+
# Perform ASR with RNNT decoding using the provided language
|
795 |
+
transcription_rnnt = model(wav, asr_manager.model_language[language], "rnnt")
|
|
|
|
|
|
|
796 |
|
797 |
return JSONResponse(content={"text": transcription_rnnt})
|
798 |
|
799 |
+
@app.post("/v1/speech_to_speech")
|
800 |
+
async def speech_to_speech(
|
801 |
+
file: UploadFile = File(...),
|
802 |
+
language: str = Query(..., enum=list(asr_manager.model_language.keys())),
|
803 |
+
voice: str = Body(default=config.voice)
|
804 |
+
) -> StreamingResponse:
|
805 |
+
# Step 1: Transcribe audio to text
|
806 |
+
transcription = await transcribe_audio(file, language)
|
807 |
+
logger.info(f"Transcribed text: {transcription.text}")
|
808 |
+
|
809 |
+
# Step 2: Process text with chat endpoint
|
810 |
+
chat_request = ChatRequest(
|
811 |
+
prompt=transcription.text,
|
812 |
+
src_lang=f"{language}_Knda", # Assuming script for Indian languages
|
813 |
+
tgt_lang=f"{language}_Knda"
|
814 |
+
)
|
815 |
+
processed_text = await chat(Request(), chat_request)
|
816 |
+
logger.info(f"Processed text: {processed_text.response}")
|
817 |
+
|
818 |
+
# Step 3: Convert processed text to speech
|
819 |
+
audio_response = await generate_audio(
|
820 |
+
input=processed_text.response,
|
821 |
+
voice=voice,
|
822 |
+
model=tts_config.model,
|
823 |
+
response_format=config.response_format,
|
824 |
+
speed=SPEED
|
825 |
+
)
|
826 |
+
return audio_response
|
827 |
|
828 |
|
829 |
class BatchTranscriptionResponse(BaseModel):
|
|
|
835 |
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
|
836 |
parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
|
837 |
parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
|
838 |
+
parser.add_argument("--config", type=str, default="config_one", help="Configuration to use (e.g., config_one, config_two, config_three, config_four)")
|
839 |
args = parser.parse_args()
|
840 |
+
|
841 |
+
# Load the JSON configuration file
|
842 |
+
def load_config(config_path="dhwani_config.json"):
|
843 |
+
with open(config_path, "r") as f:
|
844 |
+
return json.load(f)
|
845 |
+
|
846 |
+
config_data = load_config()
|
847 |
+
if args.config not in config_data["configs"]:
|
848 |
+
raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")
|
849 |
+
|
850 |
+
selected_config = config_data["configs"][args.config]
|
851 |
+
global_settings = config_data["global_settings"]
|
852 |
+
|
853 |
+
# Update settings based on selected config
|
854 |
+
settings.llm_model_name = selected_config["components"]["LLM"]["model"]
|
855 |
+
settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
|
856 |
+
settings.host = global_settings["host"]
|
857 |
+
settings.port = global_settings["port"]
|
858 |
+
settings.chat_rate_limit = global_settings["chat_rate_limit"]
|
859 |
+
settings.speech_rate_limit = global_settings["speech_rate_limit"]
|
860 |
+
|
861 |
+
# Initialize LLMManager with the selected LLM model
|
862 |
+
llm_manager = LLMManager(settings.llm_model_name)
|
863 |
+
|
864 |
+
# Initialize ASR model if present in config
|
865 |
+
if selected_config["components"]["ASR"]:
|
866 |
+
asr_model_name = selected_config["components"]["ASR"]["model"]
|
867 |
+
model = AutoModel.from_pretrained(asr_model_name, trust_remote_code=True)
|
868 |
+
asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
|
869 |
+
|
870 |
+
# Initialize TTS model if present in config
|
871 |
+
if selected_config["components"]["TTS"]:
|
872 |
+
tts_model_name = selected_config["components"]["TTS"]["model"]
|
873 |
+
tts_config.model = tts_model_name # Update tts_config to use the selected model
|
874 |
+
tts_model_manager.get_or_load_model(tts_model_name)
|
875 |
+
|
876 |
+
# Initialize Translation models - load all specified models
|
877 |
+
if selected_config["components"]["Translation"]:
|
878 |
+
for translation_config in selected_config["components"]["Translation"]:
|
879 |
+
src_lang = translation_config["src_lang"]
|
880 |
+
tgt_lang = translation_config["tgt_lang"]
|
881 |
+
model_manager.get_model(src_lang, tgt_lang)
|
882 |
+
|
883 |
+
# Override host and port from command line arguments if provided
|
884 |
+
host = args.host if args.host != settings.host else settings.host
|
885 |
+
port = args.port if args.port != settings.port else settings.port
|
886 |
+
|
887 |
+
# Run the server
|
888 |
+
uvicorn.run(app, host=host, port=port)
|