sachin committed on
Commit
7a61b58
·
1 Parent(s): 224556e
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM ubuntu:22.04
2
  WORKDIR /app
3
 
4
  RUN apt-get update && apt-get install -y \
@@ -17,6 +17,8 @@ RUN export CC=/usr/bin/gcc
17
  RUN export CXX=/usr/bin/g++
18
 
19
  RUN pip install --upgrade pip setuptools setuptools-rust torch
 
 
20
  COPY requirements.txt .
21
  #RUN pip install --no-cache-dir torch==2.6.0 torchvision
22
  #RUN pip install --no-cache-dir transformers
 
1
+ FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04
2
  WORKDIR /app
3
 
4
  RUN apt-get update && apt-get install -y \
 
17
  RUN export CXX=/usr/bin/g++
18
 
19
  RUN pip install --upgrade pip setuptools setuptools-rust torch
20
+ RUN pip install flash-attn --no-build-isolation
21
+
22
  COPY requirements.txt .
23
  #RUN pip install --no-cache-dir torch==2.6.0 torchvision
24
  #RUN pip install --no-cache-dir transformers
requirements.txt CHANGED
@@ -9,3 +9,18 @@ pydantic_settings
9
  slowapi
10
  python-multipart
11
  IndicTransToolkit @ git+https://github.com/VarunGumma/IndicTransToolkit.git@399b3fec93d2ee85cb998cb7a4fb7a7d83afcbcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  slowapi
10
  python-multipart
11
  IndicTransToolkit @ git+https://github.com/VarunGumma/IndicTransToolkit.git@399b3fec93d2ee85cb998cb7a4fb7a7d83afcbcf
12
+ packaging
13
+
14
+ sentencepiece
15
+ descript-audio-codec
16
+ descript-audiotools @ git+https://github.com/descriptinc/audiotools
17
+ protobuf>=4.0.0
18
+ fastapi
19
+ uvicorn
20
+ pydantic-settings
21
+ huggingface-hub
22
+ openai
23
+ torch
24
+ parler_tts @ git+https://github.com/slabstech/parler-tts.git
25
+ # packaging (build-time dependency of flash-attn) is already listed above
26
+ flash-attn
src/server/config.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+ from pydantic_settings import BaseSettings
4
+
5
+ SPEED = 1.0
6
+
7
+ class StrEnum(str, enum.Enum):
8
+ """Custom implementation of StrEnum for Python versions < 3.11"""
9
+ def __str__(self):
10
+ return str(self.value)
11
+
12
+ # NOTE: commented out response formats don't work
13
+ class ResponseFormat(StrEnum):
14
+ MP3 = "mp3"
15
+ # OPUS = "opus"
16
+ # AAC = "aac"
17
+ FLAC = "flac"
18
+ WAV = "wav"
19
+ # PCM = "pcm"
20
+
21
+ class Config(BaseSettings):
22
+ log_level: str = "info" # env: LOG_LEVEL
23
+ model: str = "ai4bharat/indic-parler-tts" # env: MODEL
24
+ max_models: int = 1 # env: MAX_MODELS
25
+ lazy_load_model: bool = False # env: LAZY_LOAD_MODEL
26
+ input: str = ("ನಿಮ್ಮ ಇನ್‌ಪುಟ್ ಪಠ್ಯವನ್ನು ಇಲ್ಲಿ ಸೇರಿಸಿ")
27
+ voice: str = (
28
+ "Anu speaks with a high pitch at a normal pace in a clear, close-sounding environment. Her neutral tone is captured with excellent audio quality" # env: VOICE
29
+ )
30
+ response_format: ResponseFormat = ResponseFormat.MP3 # env: RESPONSE_FORMAT
31
+
32
+ config = Config()
src/server/logger.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import logging.config
3
+
4
+ from config import config
5
+
6
+ logger = logging.getLogger("tts_indic_server")
7
+
8
+ # https://www.youtube.com/watch?v=9L77QExPmI0
9
+ # https://docs.python.org/3/library/logging.config.html
10
+ logging_config = {
11
+ "version": 1, # required
12
+ "disable_existing_loggers": False,
13
+ "formatters": {
14
+ "simple": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"},
15
+ },
16
+ "handlers": {
17
+ "stdout": {
18
+ "class": "logging.StreamHandler",
19
+ "formatter": "simple",
20
+ "stream": "ext://sys.stdout",
21
+ },
22
+ },
23
+ "loggers": {
24
+ "root": {
25
+ "level": config.log_level.upper(),
26
+ "handlers": ["stdout"],
27
+ },
28
+ },
29
+ }
30
+
31
+
32
+ logging.config.dictConfig(logging_config)
src/server/main.py CHANGED
@@ -23,6 +23,304 @@ from tts_config import SPEED, ResponseFormat, config as tts_config
23
  from gemma_llm import LLMManager
24
  # from auth import get_api_key, settings as auth_settings
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Supported language codes
27
  SUPPORTED_LANGUAGES = {
28
  "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
@@ -51,12 +349,7 @@ class Settings(BaseSettings):
51
 
52
  settings = Settings()
53
 
54
- app = FastAPI(
55
- title="Dhwani API",
56
- description="AI Chat API supporting Indian languages",
57
- version="1.0.0",
58
- redirect_slashes=False,
59
- )
60
  app.add_middleware(
61
  CORSMiddleware,
62
  allow_origins=["*"],
 
23
  from gemma_llm import LLMManager
24
  # from auth import get_api_key, settings as auth_settings
25
 
26
+
27
+ import time
28
+ from contextlib import asynccontextmanager
29
+ from typing import Annotated, Any, OrderedDict, List
30
+ import zipfile
31
+ import soundfile as sf
32
+ import torch
33
+ from fastapi import Body, FastAPI, HTTPException, Response
34
+ from parler_tts import ParlerTTSForConditionalGeneration
35
+ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
36
+ import numpy as np
37
+ from config import SPEED, ResponseFormat, config
38
+ from logger import logger
39
+ import uvicorn
40
+ import argparse
41
+ from fastapi.responses import RedirectResponse, StreamingResponse
42
+ import io
43
+ import os
44
+ import logging
45
+
46
+ # Device setup
47
+ if torch.cuda.is_available():
48
+ device = "cuda:0"
49
+ logger.info("GPU will be used for inference")
50
+ else:
51
+ device = "cpu"
52
+ logger.info("CPU will be used for inference")
53
+ torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
54
+
55
+ # Check CUDA availability and version
56
+ cuda_available = torch.cuda.is_available()
57
+ cuda_version = torch.version.cuda if cuda_available else None
58
+
59
+ if torch.cuda.is_available():
60
+ device_idx = torch.cuda.current_device()
61
+ capability = torch.cuda.get_device_capability(device_idx)
62
+ compute_capability_float = float(f"{capability[0]}.{capability[1]}")
63
+ print(f"CUDA version: {cuda_version}")
64
+ print(f"CUDA Compute Capability: {compute_capability_float}")
65
+ else:
66
+ print("CUDA is not available on this system.")
67
+
68
+ class TTSModelManager:
69
+ def __init__(self):
70
+ self.model_tokenizer: OrderedDict[
71
+ str, tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]
72
+ ] = OrderedDict()
73
+ self.max_length = 50
74
+
75
+ def load_model(
76
+ self, model_name: str
77
+ ) -> tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
78
+ logger.debug(f"Loading {model_name}...")
79
+ start = time.perf_counter()
80
+
81
+ model_name = "ai4bharat/indic-parler-tts"
82
+ attn_implementation = "flash_attention_2"
83
+
84
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
85
+ model_name,
86
+ attn_implementation=attn_implementation
87
+ ).to(device, dtype=torch_dtype)
88
+
89
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
90
+ description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
91
+
92
+ # Set pad tokens
93
+ if tokenizer.pad_token is None:
94
+ tokenizer.pad_token = tokenizer.eos_token
95
+ if description_tokenizer.pad_token is None:
96
+ description_tokenizer.pad_token = description_tokenizer.eos_token
97
+
98
+ # Update model configuration
99
+ model.config.pad_token_id = tokenizer.pad_token_id
100
+ # Update for deprecation: use max_batch_size instead of batch_size
101
+ if hasattr(model.generation_config.cache_config, 'max_batch_size'):
102
+ model.generation_config.cache_config.max_batch_size = 1
103
+ model.generation_config.cache_implementation = "static"
104
+
105
+ # Compile the model
106
+ ##compile_mode = "default"
107
+ compile_mode = "reduce-overhead"
108
+
109
+ model.forward = torch.compile(model.forward, mode=compile_mode)
110
+
111
+ # Warmup
112
+ warmup_inputs = tokenizer("Warmup text for compilation",
113
+ return_tensors="pt",
114
+ padding="max_length",
115
+ max_length=self.max_length).to(device)
116
+
117
+ model_kwargs = {
118
+ "input_ids": warmup_inputs["input_ids"],
119
+ "attention_mask": warmup_inputs["attention_mask"],
120
+ "prompt_input_ids": warmup_inputs["input_ids"],
121
+ "prompt_attention_mask": warmup_inputs["attention_mask"],
122
+ }
123
+
124
+ n_steps = 1 if compile_mode == "default" else 2
125
+ for _ in range(n_steps):
126
+ _ = model.generate(**model_kwargs)
127
+
128
+ logger.info(
129
+ f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
130
+ )
131
+ return model, tokenizer, description_tokenizer
132
+
133
+ def get_or_load_model(
134
+ self, model_name: str
135
+ ) -> tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
136
+ if model_name not in self.model_tokenizer:
137
+ logger.info(f"Model {model_name} isn't already loaded")
138
+ if len(self.model_tokenizer) == config.max_models:
139
+ logger.info("Unloading the oldest loaded model")
140
+ del self.model_tokenizer[next(iter(self.model_tokenizer))]
141
+ self.model_tokenizer[model_name] = self.load_model(model_name)
142
+ return self.model_tokenizer[model_name]
143
+
144
+ tts_model_manager = TTSModelManager()
145
+
146
+ @asynccontextmanager
147
+ async def lifespan(_: FastAPI):
148
+ if not config.lazy_load_model:
149
+ tts_model_manager.get_or_load_model(config.model)
150
+ yield
151
+
152
+ #app = FastAPI(lifespan=lifespan)
153
+ app = FastAPI(
154
+ title="Dhwani API",
155
+ description="AI Chat API supporting Indian languages",
156
+ version="1.0.0",
157
+ redirect_slashes=False,
158
+ lifespan=lifespan
159
+ )
160
+
161
+
162
+ def chunk_text(text, chunk_size):
163
+ words = text.split()
164
+ chunks = []
165
+ for i in range(0, len(words), chunk_size):
166
+ chunks.append(' '.join(words[i:i + chunk_size]))
167
+ return chunks
168
+
169
+ @app.post("/v1/audio/speech")
170
+ async def generate_audio(
171
+ input: Annotated[str, Body()] = config.input,
172
+ voice: Annotated[str, Body()] = config.voice,
173
+ model: Annotated[str, Body()] = config.model,
174
+ response_format: Annotated[ResponseFormat, Body(include_in_schema=False)] = config.response_format,
175
+ speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
176
+ ) -> StreamingResponse:
177
+ tts, tokenizer, description_tokenizer = model_manager.get_or_load_model(model)
178
+ if speed != SPEED:
179
+ logger.warning(
180
+ "Specifying speed isn't supported by this model. Audio will be generated with the default speed"
181
+ )
182
+ start = time.perf_counter()
183
+
184
+ chunk_size = 15
185
+ all_chunks = chunk_text(input, chunk_size)
186
+
187
+ if len(all_chunks) <= chunk_size:
188
+ desc_inputs = description_tokenizer(voice,
189
+ return_tensors="pt",
190
+ padding="max_length",
191
+ max_length=model_manager.max_length).to(device)
192
+ prompt_inputs = tokenizer(input,
193
+ return_tensors="pt",
194
+ padding="max_length",
195
+ max_length=model_manager.max_length).to(device)
196
+
197
+ # Use the tensor fields directly instead of BatchEncoding object
198
+ input_ids = desc_inputs["input_ids"]
199
+ attention_mask = desc_inputs["attention_mask"]
200
+ prompt_input_ids = prompt_inputs["input_ids"]
201
+ prompt_attention_mask = prompt_inputs["attention_mask"]
202
+
203
+ generation = tts.generate(
204
+ input_ids=input_ids,
205
+ prompt_input_ids=prompt_input_ids,
206
+ attention_mask=attention_mask,
207
+ prompt_attention_mask=prompt_attention_mask
208
+ ).to(torch.float32)
209
+
210
+ audio_arr = generation.cpu().float().numpy().squeeze()
211
+ else:
212
+ all_descriptions = [voice] * len(all_chunks)
213
+ description_inputs = description_tokenizer(all_descriptions,
214
+ return_tensors="pt",
215
+ padding=True).to(device)
216
+ prompts = tokenizer(all_chunks,
217
+ return_tensors="pt",
218
+ padding=True).to(device)
219
+
220
+ set_seed(0)
221
+ generation = tts.generate(
222
+ input_ids=description_inputs["input_ids"],
223
+ attention_mask=description_inputs["attention_mask"],
224
+ prompt_input_ids=prompts["input_ids"],
225
+ prompt_attention_mask=prompts["attention_mask"],
226
+ do_sample=True,
227
+ return_dict_in_generate=True,
228
+ )
229
+
230
+ chunk_audios = []
231
+ for i, audio in enumerate(generation.sequences):
232
+ audio_data = audio[:generation.audios_length[i]].cpu().float().numpy().squeeze()
233
+ chunk_audios.append(audio_data)
234
+ audio_arr = np.concatenate(chunk_audios)
235
+
236
+ device_str = str(device)
237
+ logger.info(
238
+ f"Took {time.perf_counter() - start:.2f} seconds to generate audio for {len(input.split())} words using {device_str.upper()}"
239
+ )
240
+
241
+ audio_buffer = io.BytesIO()
242
+ sf.write(audio_buffer, audio_arr, tts.config.sampling_rate, format=response_format)
243
+ audio_buffer.seek(0)
244
+
245
+ return StreamingResponse(audio_buffer, media_type=f"audio/{response_format}")
246
+
247
+ def create_in_memory_zip(file_data):
248
+ in_memory_zip = io.BytesIO()
249
+ with zipfile.ZipFile(in_memory_zip, 'w') as zipf:
250
+ for file_name, data in file_data.items():
251
+ zipf.writestr(file_name, data)
252
+ in_memory_zip.seek(0)
253
+ return in_memory_zip
254
+
255
+ @app.post("/v1/audio/speech_batch")
256
+ async def generate_audio_batch(
257
+ input: Annotated[List[str], Body()] = config.input,
258
+ voice: Annotated[List[str], Body()] = config.voice,
259
+ model: Annotated[str, Body(include_in_schema=False)] = config.model,
260
+ response_format: Annotated[ResponseFormat, Body()] = config.response_format,
261
+ speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
262
+ ) -> StreamingResponse:
263
+ tts, tokenizer, description_tokenizer = model_manager.get_or_load_model(model)
264
+ if speed != SPEED:
265
+ logger.warning(
266
+ "Specifying speed isn't supported by this model. Audio will be generated with the default speed"
267
+ )
268
+ start = time.perf_counter()
269
+
270
+ chunk_size = 15
271
+ all_chunks = []
272
+ all_descriptions = []
273
+ for i, text in enumerate(input):
274
+ chunks = chunk_text(text, chunk_size)
275
+ all_chunks.extend(chunks)
276
+ all_descriptions.extend([voice[i]] * len(chunks))
277
+
278
+ description_inputs = description_tokenizer(all_descriptions,
279
+ return_tensors="pt",
280
+ padding=True).to(device)
281
+ prompts = tokenizer(all_chunks,
282
+ return_tensors="pt",
283
+ padding=True).to(device)
284
+
285
+ set_seed(0)
286
+ generation = tts.generate(
287
+ input_ids=description_inputs["input_ids"],
288
+ attention_mask=description_inputs["attention_mask"],
289
+ prompt_input_ids=prompts["input_ids"],
290
+ prompt_attention_mask=prompts["attention_mask"],
291
+ do_sample=True,
292
+ return_dict_in_generate=True,
293
+ )
294
+
295
+ audio_outputs = []
296
+ current_index = 0
297
+ for i, text in enumerate(input):
298
+ chunks = chunk_text(text, chunk_size)
299
+ chunk_audios = []
300
+ for j in range(len(chunks)):
301
+ audio_arr = generation.sequences[current_index][:generation.audios_length[current_index]].cpu().float().numpy().squeeze()
302
+ chunk_audios.append(audio_arr)
303
+ current_index += 1
304
+ combined_audio = np.concatenate(chunk_audios)
305
+ audio_outputs.append(combined_audio)
306
+
307
+ file_data = {}
308
+ for i, audio in enumerate(audio_outputs):
309
+ file_name = f"out_{i}.{response_format}"
310
+ audio_bytes = io.BytesIO()
311
+ sf.write(audio_bytes, audio, tts.config.sampling_rate, format=response_format)
312
+ audio_bytes.seek(0)
313
+ file_data[file_name] = audio_bytes.read()
314
+
315
+ in_memory_zip = create_in_memory_zip(file_data)
316
+
317
+ logger.info(
318
+ f"Took {time.perf_counter() - start:.2f} seconds to generate audio"
319
+ )
320
+
321
+ return StreamingResponse(in_memory_zip, media_type="application/zip")
322
+
323
+
324
  # Supported language codes
325
  SUPPORTED_LANGUAGES = {
326
  "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
 
349
 
350
  settings = Settings()
351
 
352
+
 
 
 
 
 
353
  app.add_middleware(
354
  CORSMiddleware,
355
  allow_origins=["*"],