sachin committed
Commit 773ab72 · Parent: 0a0efec
asybc
src/server/main.py CHANGED (+215 -7)
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
 from slowapi import Limiter
 from slowapi.util import get_remote_address
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel, AutoProcessor, BitsAndBytesConfig
 from IndicTransToolkit import IndicProcessor
 import json
 import asyncio
@@ -25,7 +25,6 @@ import requests
 from starlette.responses import StreamingResponse
 from logging_config import logger
 from tts_config import SPEED, ResponseFormat, config as tts_config
-from gemma_llm import LLMManager  # Assuming this is your custom LLMManager
 
 # Device setup
 if torch.cuda.is_available():
@@ -69,6 +68,209 @@ class Settings(BaseSettings):
 
 settings = Settings()
 
+# Quantization config for LLM
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
+
+# LLM Manager (adapted from gemma_llm.py)
+class LLMManager:
+    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+        self.model_name = model_name
+        self.device = torch.device(device)
+        self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
+        self.model = None
+        self.is_loaded = False
+        self.processor = None
+        logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
+
+    async def unload(self):
+        if self.is_loaded:
+            await asyncio.to_thread(self._unload_sync)
+            self.is_loaded = False
+            logger.info(f"LLM {self.model_name} unloaded from {self.device}")
+
+    def _unload_sync(self):
+        del self.model
+        del self.processor
+        if self.device.type == "cuda":
+            torch.cuda.empty_cache()
+            logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
+
+    async def load(self):
+        if not self.is_loaded:
+            try:
+                self.model = await asyncio.to_thread(
+                    AutoModel.from_pretrained,
+                    self.model_name,
+                    device_map="auto",
+                    quantization_config=quantization_config,
+                    torch_dtype=self.torch_dtype
+                )
+                self.model.eval()
+                self.processor = await asyncio.to_thread(AutoProcessor.from_pretrained, self.model_name)
+                self.is_loaded = True
+                logger.info(f"LLM {self.model_name} loaded on {self.device} with 4-bit quantization")
+            except Exception as e:
+                logger.error(f"Failed to load model: {str(e)}")
+                raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
+
+    async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
+        if not self.is_loaded:
+            await self.load()
+
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
+            },
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": prompt}]
+            }
+        ]
+
+        try:
+            inputs_vlm = await asyncio.to_thread(
+                self.processor.apply_chat_template,
+                messages_vlm,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            )
+            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
+            logger.info(f"Decoded input: {self.processor.decode(inputs_vlm['input_ids'][0])}")
+        except Exception as e:
+            logger.error(f"Error in tokenization: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = await asyncio.to_thread(
+                self.model.generate,
+                **inputs_vlm,
+                max_new_tokens=max_tokens,
+                do_sample=True,
+                temperature=temperature
+            )
+            generation = generation[0][input_len:]
+
+        response = self.processor.decode(generation, skip_special_tokens=True)
+        logger.info(f"Generated response: {response}")
+        return response
+
+    async def vision_query(self, image: Image.Image, query: str) -> str:
+        if not self.is_loaded:
+            await self.load()
+
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
+            },
+            {
+                "role": "user",
+                "content": []
+            }
+        ]
+
+        messages_vlm[1]["content"].append({"type": "text", "text": query})
+        if image and image.size[0] > 0 and image.size[1] > 0:
+            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+            logger.info(f"Received valid image for processing")
+        else:
+            logger.info("No valid image provided, processing text only")
+
+        try:
+            inputs_vlm = await asyncio.to_thread(
+                self.processor.apply_chat_template,
+                messages_vlm,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            )
+            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
+        except Exception as e:
+            logger.error(f"Error in apply_chat_template: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = await asyncio.to_thread(
+                self.model.generate,
+                **inputs_vlm,
+                max_new_tokens=512,
+                do_sample=True,
+                temperature=0.7
+            )
+            generation = generation[0][input_len:]
+
+        decoded = self.processor.decode(generation, skip_special_tokens=True)
+        logger.info(f"Vision query response: {decoded}")
+        return decoded
+
+    async def chat_v2(self, image: Image.Image, query: str) -> str:
+        if not self.is_loaded:
+            await self.load()
+
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
+            },
+            {
+                "role": "user",
+                "content": []
+            }
+        ]
+
+        messages_vlm[1]["content"].append({"type": "text", "text": query})
+        if image and image.size[0] > 0 and image.size[1] > 0:
+            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+            logger.info(f"Received valid image for processing")
+        else:
+            logger.info("No valid image provided, processing text only")
+
+        try:
+            inputs_vlm = await asyncio.to_thread(
+                self.processor.apply_chat_template,
+                messages_vlm,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            )
+            inputs_vlm = inputs_vlm.to(self.device, dtype=torch.bfloat16)
+            logger.info(f"Input IDs: {inputs_vlm['input_ids']}")
+        except Exception as e:
+            logger.error(f"Error in apply_chat_template: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        with torch.inference_mode():
+            generation = await asyncio.to_thread(
+                self.model.generate,
+                **inputs_vlm,
+                max_new_tokens=512,
+                do_sample=True,
+                temperature=0.7
+            )
+            generation = generation[0][input_len:]
+
+        decoded = self.processor.decode(generation, skip_special_tokens=True)
+        logger.info(f"Chat_v2 response: {decoded}")
+        return decoded
+
 # TTS Manager
 class TTSManager:
     def __init__(self, device_type=device):
@@ -197,7 +399,7 @@ class ModelManager:
         if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
             model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
         elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
-            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if
+            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
         else:
             model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
 
@@ -292,6 +494,8 @@ def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
     return model_manager.get_model(src_lang, tgt_lang)
 
 # Lifespan Event Handler
+translation_configs = []
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     async def load_all_models():
@@ -303,6 +507,12 @@ async def lifespan(app: FastAPI):
             asyncio.create_task(model_manager.load_model('kan_Knda', 'eng_Latn', 'indic_eng')),
             asyncio.create_task(model_manager.load_model('kan_Knda', 'hin_Deva', 'indic_indic')),
         ]
+        for config in translation_configs:
+            src_lang = config["src_lang"]
+            tgt_lang = config["tgt_lang"]
+            key = model_manager._get_model_key(src_lang, tgt_lang)
+            tasks.append(asyncio.create_task(model_manager.load_model(src_lang, tgt_lang, key)))
+
         await asyncio.gather(*tasks)
         logger.info("All models loaded successfully")
 
@@ -616,6 +826,7 @@ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(.
     if not asr_manager.model:
         raise HTTPException(status_code=503, detail="ASR model still loading, please try again later")
     try:
+        import torchaudio  # Added here for clarity
         wav, sr = torchaudio.load(file.file)
         wav = torch.mean(wav, dim=0, keepdim=True)
         target_sample_rate = 16000
@@ -688,10 +899,7 @@ if __name__ == "__main__":
     asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
 
     if selected_config["components"]["Translation"]:
-
-        src_lang = translation_config["src_lang"]
-        tgt_lang = translation_config["tgt_lang"]
-        asyncio.create_task(model_manager.load_model(src_lang, tgt_lang, model_manager._get_model_key(src_lang, tgt_lang)))
+        translation_configs.extend(selected_config["components"]["Translation"])
 
     host = args.host if args.host != settings.host else settings.host
     port = args.port if args.port != settings.port else settings.port
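
For context, a minimal, hypothetical sketch of how the LLMManager introduced by this commit might be exercised on its own. The checkpoint name and the sample image file are assumptions for illustration, not part of the diff; main.py configures its own Gemma model:

import asyncio
from PIL import Image

async def demo():
    # Assumed checkpoint name for illustration only.
    llm = LLMManager("google/gemma-3-4b-it")
    await llm.load()  # loads the 4-bit quantized model and its processor

    # Text-only generation (generate() also lazy-loads if needed).
    reply = await llm.generate("What is the capital of Karnataka?", max_tokens=64)
    print(reply)

    # Vision query: a non-empty PIL image is inserted ahead of the text prompt.
    image = Image.open("sample.jpg")  # hypothetical local file
    print(await llm.vision_query(image, "Describe this image."))

    await llm.unload()  # frees GPU memory via torch.cuda.empty_cache()

asyncio.run(demo())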
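
The new translation_configs hook works the same way at startup: any entry registered before lifespan() runs is scheduled alongside the default language pairs. A sketch of the expected entry shape, inferred from the keys read in load_all_models (the language pair below is illustrative, not one the commit registers itself):

# Hypothetical entry; mirrors the keys the lifespan loop reads.
translation_configs.append({"src_lang": "tam_Taml", "tgt_lang": "eng_Latn"})

# During startup the lifespan handler then resolves a model key and loads it:
#     key = model_manager._get_model_key("tam_Taml", "eng_Latn")
#     await model_manager.load_model("tam_Taml", "eng_Latn", key)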