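"""Dhwani API: a FastAPI server for Indian-language AI, combining chat (Gemma LLM),
translation (AI4Bharat IndicTrans2), text-to-speech (Indic Parler-TTS), and
speech recognition (IndicConformer), plus a speech-to-speech pipeline."""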
import argparse
import io
import json
import logging
import os
import tempfile
import time
import zipfile
from contextlib import asynccontextmanager
from typing import Annotated, Any, List, OrderedDict

import numpy as np
import soundfile as sf
import torch
import torchaudio
import uvicorn
from fastapi import (
    Body,
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Query,
    Request,
    Response,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from parler_tts import ParlerTTSForConditionalGeneration
from PIL import Image
from pydantic import BaseModel, field_validator
from pydantic_settings import BaseSettings
from pydub import AudioSegment
from slowapi import Limiter
from slowapi.util import get_remote_address
from transformers import (
    AutoFeatureExtractor,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    set_seed,
)

from config import SPEED, ResponseFormat, config
from gemma_llm import LLMManager
from IndicTransToolkit import IndicProcessor
from logger import logger
from tts_config import config as tts_config
# from auth import get_api_key, settings as auth_settings
# Device setup
if torch.cuda.is_available():
    device = "cuda:0"
    logger.info("GPU will be used for inference")
else:
    device = "cpu"
    logger.info("CPU will be used for inference")
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32

# Check CUDA availability and version
cuda_available = torch.cuda.is_available()
cuda_version = torch.version.cuda if cuda_available else None
if cuda_available:
    device_idx = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability(device_idx)
    compute_capability_float = float(f"{capability[0]}.{capability[1]}")
    print(f"CUDA version: {cuda_version}")
    print(f"CUDA Compute Capability: {compute_capability_float}")
else:
    print("CUDA is not available on this system.")
class TTSModelManager:
    def __init__(self):
        self.model_tokenizer: OrderedDict[
            str, tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]
        ] = OrderedDict()
        self.max_length = 50

    def load_model(
        self, model_name: str
    ) -> tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
        logger.debug(f"Loading {model_name}...")
        start = time.perf_counter()
        # NOTE: the requested model_name is overridden; only indic-parler-tts is loaded here.
        model_name = "ai4bharat/indic-parler-tts"
        attn_implementation = "flash_attention_2"
        model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name,
            attn_implementation=attn_implementation
        ).to(device, dtype=torch_dtype)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
        # Set pad tokens
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if description_tokenizer.pad_token is None:
            description_tokenizer.pad_token = description_tokenizer.eos_token
        # TODO: static-cache configuration temporarily disabled
        '''
        # Update model configuration
        model.config.pad_token_id = tokenizer.pad_token_id
        # Update for deprecation: use max_batch_size instead of batch_size
        if hasattr(model.generation_config.cache_config, 'max_batch_size'):
            model.generation_config.cache_config.max_batch_size = 1
        model.generation_config.cache_implementation = "static"
        '''
        # Compile the model
        compile_mode = "default"
        # compile_mode = "reduce-overhead"
        model.forward = torch.compile(model.forward, mode=compile_mode)
        # Warmup to trigger compilation before serving real requests
        warmup_inputs = tokenizer("Warmup text for compilation",
                                  return_tensors="pt",
                                  padding="max_length",
                                  max_length=self.max_length).to(device)
        model_kwargs = {
            "input_ids": warmup_inputs["input_ids"],
            "attention_mask": warmup_inputs["attention_mask"],
            "prompt_input_ids": warmup_inputs["input_ids"],
            "prompt_attention_mask": warmup_inputs["attention_mask"],
        }
        n_steps = 1 if compile_mode == "default" else 2
        for _ in range(n_steps):
            _ = model.generate(**model_kwargs)
        logger.info(
            f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
        )
        return model, tokenizer, description_tokenizer

    def get_or_load_model(
        self, model_name: str
    ) -> tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
        if model_name not in self.model_tokenizer:
            logger.info(f"Model {model_name} isn't already loaded")
            if len(self.model_tokenizer) == config.max_models:
                logger.info("Unloading the oldest loaded model")
                del self.model_tokenizer[next(iter(self.model_tokenizer))]
            self.model_tokenizer[model_name] = self.load_model(model_name)
        return self.model_tokenizer[model_name]
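# The manager holds at most config.max_models models at a time, evicting the
# oldest entry (insertion order of the OrderedDict) when a new one is requested
# at capacity.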
tts_model_manager = TTSModelManager()

@asynccontextmanager
async def lifespan(_: FastAPI):
    if not config.lazy_load_model:
        tts_model_manager.get_or_load_model(config.model)
    yield

app = FastAPI(
    title="Dhwani API",
    description="AI Chat API supporting Indian languages",
    version="1.0.0",
    redirect_slashes=False,
    lifespan=lifespan
)
def chunk_text(text, chunk_size):
    words = text.split()
    chunks = []
    for i in range(0, len(words), chunk_size):
        chunks.append(' '.join(words[i:i + chunk_size]))
    return chunks
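# Illustrative example: chunk_text("one two three four five", 2)
# returns ["one two", "three four", "five"].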
@app.post("/v1/audio/speech")  # route path assumed
async def generate_audio(
    input: Annotated[str, Body()] = config.input,
    voice: Annotated[str, Body()] = config.voice,
    model: Annotated[str, Body()] = config.model,
    response_format: Annotated[ResponseFormat, Body(include_in_schema=False)] = config.response_format,
    speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
) -> StreamingResponse:
    tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
    if speed != SPEED:
        logger.warning(
            "Specifying speed isn't supported by this model. Audio will be generated with the default speed"
        )
    start = time.perf_counter()
    chunk_size = 15
    all_chunks = chunk_text(input, chunk_size)
    if len(all_chunks) <= chunk_size:
        # Few chunks: synthesize the full input in one pass with fixed-length padding
        desc_inputs = description_tokenizer(voice,
                                            return_tensors="pt",
                                            padding="max_length",
                                            max_length=tts_model_manager.max_length).to(device)
        prompt_inputs = tokenizer(input,
                                  return_tensors="pt",
                                  padding="max_length",
                                  max_length=tts_model_manager.max_length).to(device)
        input_ids = desc_inputs["input_ids"]
        attention_mask = desc_inputs["attention_mask"]
        prompt_input_ids = prompt_inputs["input_ids"]
        prompt_attention_mask = prompt_inputs["attention_mask"]
        generation = tts.generate(
            input_ids=input_ids,
            prompt_input_ids=prompt_input_ids,
            attention_mask=attention_mask,
            prompt_attention_mask=prompt_attention_mask
        ).to(torch.float32)
        audio_arr = generation.cpu().float().numpy().squeeze()
    else:
        # Many chunks: batch them through the model, then stitch the audio together
        all_descriptions = [voice] * len(all_chunks)
        description_inputs = description_tokenizer(all_descriptions,
                                                   return_tensors="pt",
                                                   padding=True).to(device)
        prompts = tokenizer(all_chunks,
                            return_tensors="pt",
                            padding=True).to(device)
        set_seed(0)
        generation = tts.generate(
            input_ids=description_inputs["input_ids"],
            attention_mask=description_inputs["attention_mask"],
            prompt_input_ids=prompts["input_ids"],
            prompt_attention_mask=prompts["attention_mask"],
            do_sample=True,
            return_dict_in_generate=True,
        )
        chunk_audios = []
        for i, audio in enumerate(generation.sequences):
            audio_data = audio[:generation.audios_length[i]].cpu().float().numpy().squeeze()
            chunk_audios.append(audio_data)
        audio_arr = np.concatenate(chunk_audios)
    device_str = str(device)
    logger.info(
        f"Took {time.perf_counter() - start:.2f} seconds to generate audio for {len(input.split())} words using {device_str.upper()}"
    )
    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, audio_arr, tts.config.sampling_rate, format=response_format)
    audio_buffer.seek(0)
    return StreamingResponse(audio_buffer, media_type=f"audio/{response_format}")
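# Illustrative request against the assumed route above (JSON fields mirror the
# Body() parameters; the voice string and output extension are placeholders,
# and the actual format depends on config.response_format):
#   curl -X POST http://localhost:7860/v1/audio/speech \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Hello world", "voice": "A female speaker with a clear voice"}' \
#        --output speech.mp3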
def create_in_memory_zip(file_data):
    in_memory_zip = io.BytesIO()
    with zipfile.ZipFile(in_memory_zip, 'w') as zipf:
        for file_name, data in file_data.items():
            zipf.writestr(file_name, data)
    in_memory_zip.seek(0)
    return in_memory_zip
@app.post("/v1/audio/speech_batch")  # route path assumed
async def generate_audio_batch(
    input: Annotated[List[str], Body()] = config.input,
    voice: Annotated[List[str], Body()] = config.voice,
    model: Annotated[str, Body(include_in_schema=False)] = config.model,
    response_format: Annotated[ResponseFormat, Body()] = config.response_format,
    speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
) -> StreamingResponse:
    tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
    if speed != SPEED:
        logger.warning(
            "Specifying speed isn't supported by this model. Audio will be generated with the default speed"
        )
    start = time.perf_counter()
    chunk_size = 15
    all_chunks = []
    all_descriptions = []
    for i, text in enumerate(input):
        chunks = chunk_text(text, chunk_size)
        all_chunks.extend(chunks)
        all_descriptions.extend([voice[i]] * len(chunks))
    description_inputs = description_tokenizer(all_descriptions,
                                               return_tensors="pt",
                                               padding=True).to(device)
    prompts = tokenizer(all_chunks,
                        return_tensors="pt",
                        padding=True).to(device)
    set_seed(0)
    generation = tts.generate(
        input_ids=description_inputs["input_ids"],
        attention_mask=description_inputs["attention_mask"],
        prompt_input_ids=prompts["input_ids"],
        prompt_attention_mask=prompts["attention_mask"],
        do_sample=True,
        return_dict_in_generate=True,
    )
    # Reassemble per-input audio from the flat list of chunk generations
    audio_outputs = []
    current_index = 0
    for i, text in enumerate(input):
        chunks = chunk_text(text, chunk_size)
        chunk_audios = []
        for j in range(len(chunks)):
            audio_arr = generation.sequences[current_index][:generation.audios_length[current_index]].cpu().float().numpy().squeeze()
            chunk_audios.append(audio_arr)
            current_index += 1
        combined_audio = np.concatenate(chunk_audios)
        audio_outputs.append(combined_audio)
    file_data = {}
    for i, audio in enumerate(audio_outputs):
        file_name = f"out_{i}.{response_format}"
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, audio, tts.config.sampling_rate, format=response_format)
        audio_bytes.seek(0)
        file_data[file_name] = audio_bytes.read()
    in_memory_zip = create_in_memory_zip(file_data)
    logger.info(
        f"Took {time.perf_counter() - start:.2f} seconds to generate audio"
    )
    return StreamingResponse(in_memory_zip, media_type="application/zip")
# Supported language codes
SUPPORTED_LANGUAGES = {
    "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
    "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
    "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
    "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
    "kan_Knda", "ory_Orya",
    "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
    "por_Latn", "rus_Cyrl", "pol_Latn"
}
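# Codes follow the FLORES-200 convention used by IndicTrans2: an ISO 639-3
# language code plus an ISO 15924 script code (e.g., "hin_Deva" is Hindi in
# Devanagari).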
class Settings(BaseSettings):
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"

    @field_validator("chat_rate_limit", "speech_rate_limit")
    @classmethod
    def validate_rate_limit(cls, v):
        if v.count("/") != 1 or not v.split("/")[0].isdigit():
            raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
        return v

    class Config:
        env_file = ".env"

settings = Settings()
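# Example .env (illustrative; pydantic-settings matches env vars to fields
# case-insensitively):
#   LLM_MODEL_NAME=google/gemma-3-4b-it
#   PORT=7860
#   CHAT_RATE_LIMIT=100/minute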
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

llm_manager = LLMManager(settings.llm_model_name)

# Translation Manager and Model Manager
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class TranslateManager:
    def __init__(self, src_lang, tgt_lang, device_type=DEVICE, use_distilled=True):
        self.device_type = device_type
        self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)

    def initialize_model(self, src_lang, tgt_lang, use_distilled):
        # Pick the IndicTrans2 variant for the requested translation direction
        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
            model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
        else:
            raise ValueError("Invalid language combination: English to English translation is not supported.")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2"
        ).to(self.device_type)
        return tokenizer, model
class ModelManager:
    def __init__(self, device_type=DEVICE, use_distilled=True, is_lazy_loading=False):
        self.models: dict[str, TranslateManager] = {}
        self.device_type = device_type
        self.use_distilled = use_distilled
        self.is_lazy_loading = is_lazy_loading
        if not is_lazy_loading:
            self.preload_models()

    def preload_models(self):
        self.models['eng_indic'] = TranslateManager('eng_Latn', 'kan_Knda', self.device_type, self.use_distilled)
        self.models['indic_eng'] = TranslateManager('kan_Knda', 'eng_Latn', self.device_type, self.use_distilled)
        self.models['indic_indic'] = TranslateManager('kan_Knda', 'hin_Deva', self.device_type, self.use_distilled)

    def get_model(self, src_lang, tgt_lang) -> TranslateManager:
        if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            key = 'eng_indic'
        elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
            key = 'indic_eng'
        elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
            key = 'indic_indic'
        else:
            raise ValueError("Invalid language combination: English to English translation is not supported.")
        if key not in self.models:
            if self.is_lazy_loading:
                if key == 'eng_indic':
                    self.models[key] = TranslateManager('eng_Latn', 'kan_Knda', self.device_type, self.use_distilled)
                elif key == 'indic_eng':
                    self.models[key] = TranslateManager('kan_Knda', 'eng_Latn', self.device_type, self.use_distilled)
                elif key == 'indic_indic':
                    self.models[key] = TranslateManager('kan_Knda', 'hin_Deva', self.device_type, self.use_distilled)
            else:
                raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
        return self.models[key]
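# Example: model_manager.get_model("eng_Latn", "hin_Deva") returns the shared
# 'eng_indic' TranslateManager; all English-to-Indic pairs reuse the same model.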
ip = IndicProcessor(inference=True)
model_manager = ModelManager()
# Pydantic Models
class ChatRequest(BaseModel):
    prompt: str
    src_lang: str = "kan_Knda"  # Default to Kannada
    tgt_lang: str = "kan_Knda"  # Default to Kannada

    @field_validator("prompt")
    @classmethod
    def prompt_must_be_valid(cls, v):
        if len(v) > 1000:
            raise ValueError("Prompt cannot exceed 1000 characters")
        return v.strip()

    @field_validator("src_lang", "tgt_lang")
    @classmethod
    def validate_language(cls, v):
        if v not in SUPPORTED_LANGUAGES:
            raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
        return v

class ChatResponse(BaseModel):
    response: str

class TranslationRequest(BaseModel):
    sentences: List[str]
    src_lang: str
    tgt_lang: str

class TranslationResponse(BaseModel):
    translations: List[str]
# Dependency to get TranslateManager
def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
    return model_manager.get_model(src_lang, tgt_lang)
# Internal Translation Endpoint
async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
    input_sentences = request.sentences
    src_lang = request.src_lang
    tgt_lang = request.tgt_lang
    if not input_sentences:
        raise HTTPException(status_code=400, detail="Input sentences are required")
    batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = translate_manager.tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(translate_manager.device_type)
    with torch.no_grad():
        generated_tokens = translate_manager.model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )
    with translate_manager.tokenizer.as_target_tokenizer():
        generated_tokens = translate_manager.tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
    return TranslationResponse(translations=translations)
# Helper function to perform internal translation
async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    translate_manager = model_manager.get_model(src_lang, tgt_lang)
    request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    response = await translate(request, translate_manager)
    return response.translations
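# Illustrative call (output depends on the loaded IndicTrans2 model):
#   translations = await perform_internal_translation(["How are you?"], "eng_Latn", "hin_Deva")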
# API Endpoints
@app.get("/health")  # route path assumed
async def health_check():
    return {"status": "healthy", "model": settings.llm_model_name}

@app.get("/")  # route path assumed
async def home():
    return RedirectResponse(url="/docs")

@app.post("/unload_all_models")  # route path assumed
async def unload_all_models():
    try:
        logger.info("Starting to unload all models...")
        llm_manager.unload()
        logger.info("All models unloaded successfully")
        return {"status": "success", "message": "All models unloaded"}
    except Exception as e:
        logger.error(f"Error unloading models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")

@app.post("/load_all_models")  # route path assumed
async def load_all_models():
    try:
        logger.info("Starting to load all models...")
        llm_manager.load()
        logger.info("All models loaded successfully")
        return {"status": "success", "message": "All models loaded"}
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
@app.post("/v1/translate")  # route path assumed
async def translate_endpoint(request: TranslationRequest):
    logger.info(f"Received translation request: {request.model_dump()}")
    try:
        translations = await perform_internal_translation(
            sentences=request.sentences,
            src_lang=request.src_lang,
            tgt_lang=request.tgt_lang
        )
        logger.info(f"Translation successful: {translations}")
        return TranslationResponse(translations=translations)
    except Exception as e:
        logger.error(f"Unexpected error during translation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
@app.post("/v1/chat")  # route path assumed
async def chat(request: Request, chat_request: ChatRequest):
    if not chat_request.prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
    EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
    try:
        if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
            translated_prompt = await perform_internal_translation(
                sentences=[chat_request.prompt],
                src_lang=chat_request.src_lang,
                tgt_lang="eng_Latn"
            )
            prompt_to_process = translated_prompt[0]
            logger.info(f"Translated prompt to English: {prompt_to_process}")
        else:
            prompt_to_process = chat_request.prompt
            logger.info("Prompt in English or a European language, no translation needed")
        response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
        logger.info(f"Generated response: {response}")
        if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
            translated_response = await perform_internal_translation(
                sentences=[response],
                src_lang="eng_Latn",
                tgt_lang=chat_request.tgt_lang
            )
            final_response = translated_response[0]
            logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
        else:
            final_response = response
            logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
        return ChatResponse(response=final_response)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
@app.post("/v1/visual_query")  # route path assumed
async def visual_query(
    file: UploadFile = File(...),
    query: str = Body(...),
    src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
    tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
):
    try:
        image = Image.open(file.file)
        if image.size == (0, 0):
            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
        if src_lang != "eng_Latn":
            translated_query = await perform_internal_translation(
                sentences=[query],
                src_lang=src_lang,
                tgt_lang="eng_Latn"
            )
            query_to_process = translated_query[0]
            logger.info(f"Translated query to English: {query_to_process}")
        else:
            query_to_process = query
            logger.info("Query already in English, no translation needed")
        answer = await llm_manager.vision_query(image, query_to_process)
        logger.info(f"Generated English answer: {answer}")
        if tgt_lang != "eng_Latn":
            translated_answer = await perform_internal_translation(
                sentences=[answer],
                src_lang="eng_Latn",
                tgt_lang=tgt_lang
            )
            final_answer = translated_answer[0]
            logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
        else:
            final_answer = answer
            logger.info("Answer kept in English, no translation needed")
        return {"answer": final_answer}
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
@app.post("/v1/chat_v2")  # route path assumed
async def chat_v2(
    request: Request,
    prompt: str = Form(...),
    image: UploadFile = File(default=None),
    src_lang: str = Form("kan_Knda"),
    tgt_lang: str = Form("kan_Knda"),
):
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
    logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
    try:
        if image:
            # Multimodal path: translate the prompt if needed, then query the vision model
            image_data = await image.read()
            if not image_data:
                raise HTTPException(status_code=400, detail="Uploaded image is empty")
            img = Image.open(io.BytesIO(image_data))
            if src_lang != "eng_Latn":
                translated_prompt = await perform_internal_translation(
                    sentences=[prompt],
                    src_lang=src_lang,
                    tgt_lang="eng_Latn"
                )
                prompt_to_process = translated_prompt[0]
                logger.info(f"Translated prompt to English: {prompt_to_process}")
            else:
                prompt_to_process = prompt
                logger.info("Prompt already in English, no translation needed")
            decoded = await llm_manager.chat_v2(img, prompt_to_process)
            logger.info(f"Generated English response: {decoded}")
            if tgt_lang != "eng_Latn":
                translated_response = await perform_internal_translation(
                    sentences=[decoded],
                    src_lang="eng_Latn",
                    tgt_lang=tgt_lang
                )
                final_response = translated_response[0]
                logger.info(f"Translated response to {tgt_lang}: {final_response}")
            else:
                final_response = decoded
                logger.info("Response kept in English, no translation needed")
        else:
            # Text-only path: same translate-generate-translate pipeline without the image
            if src_lang != "eng_Latn":
                translated_prompt = await perform_internal_translation(
                    sentences=[prompt],
                    src_lang=src_lang,
                    tgt_lang="eng_Latn"
                )
                prompt_to_process = translated_prompt[0]
                logger.info(f"Translated prompt to English: {prompt_to_process}")
            else:
                prompt_to_process = prompt
                logger.info("Prompt already in English, no translation needed")
            decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
            logger.info(f"Generated English response: {decoded}")
            if tgt_lang != "eng_Latn":
                translated_response = await perform_internal_translation(
                    sentences=[decoded],
                    src_lang="eng_Latn",
                    tgt_lang=tgt_lang
                )
                final_response = translated_response[0]
                logger.info(f"Translated response to {tgt_lang}: {final_response}")
            else:
                final_response = decoded
                logger.info("Response kept in English, no translation needed")
        return ChatResponse(response=final_response)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
class TranscriptionResponse(BaseModel):
    text: str

class ASRModelManager:
    def __init__(self, device_type="cuda"):
        self.device_type = device_type
        self.model_language = {
            "kannada": "kn", "hindi": "hi", "malayalam": "ml", "assamese": "as", "bengali": "bn",
            "bodo": "brx", "dogri": "doi", "gujarati": "gu", "kashmiri": "ks", "konkani": "kok",
            "maithili": "mai", "manipuri": "mni", "marathi": "mr", "nepali": "ne", "odia": "or",
            "punjabi": "pa", "sanskrit": "sa", "santali": "sat", "sindhi": "sd", "tamil": "ta",
            "telugu": "te", "urdu": "ur"
        }
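# ASRModelManager.model_language maps the language names exposed in the API to
# the ISO 639 codes that the IndicConformer ASR model expects.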
# Load the default ASR model
model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
asr_manager = ASRModelManager()

# Language to script mapping
LANGUAGE_TO_SCRIPT = {
    "kannada": "kan_Knda", "hindi": "hin_Deva", "malayalam": "mal_Mlym", "tamil": "tam_Taml",
    "telugu": "tel_Telu", "assamese": "asm_Beng", "bengali": "ben_Beng", "gujarati": "guj_Gujr",
    "marathi": "mar_Deva", "odia": "ory_Orya", "punjabi": "pan_Guru", "urdu": "urd_Arab",
    # Add more as needed
}
@app.post("/v1/transcribe/")  # route path assumed
async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
    try:
        wav, sr = torchaudio.load(file.file)
        wav = torch.mean(wav, dim=0, keepdim=True)  # downmix to mono
        target_sample_rate = 16000
        if sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
            wav = resampler(wav)
        transcription_rnnt = model(wav, asr_manager.model_language[language], "rnnt")
        return TranscriptionResponse(text=transcription_rnnt)
    except Exception as e:
        logger.error(f"Error in transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
@app.post("/v1/speech_to_speech")  # route path assumed
async def speech_to_speech(
    request: Request,  # Inject Request object from FastAPI
    file: UploadFile = File(...),
    language: str = Query(..., enum=list(asr_manager.model_language.keys())),
    voice: str = Body(default=config.voice)
) -> StreamingResponse:
    # Step 1: Transcribe audio to text
    transcription = await transcribe_audio(file, language)
    logger.info(f"Transcribed text: {transcription.text}")
    # Step 2: Process the text with the chat endpoint
    chat_request = ChatRequest(
        prompt=transcription.text,
        src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),  # Dynamic script mapping
        tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
    )
    processed_text = await chat(request, chat_request)  # Pass the injected request
    logger.info(f"Processed text: {processed_text.response}")
    # Step 3: Convert the processed text to speech
    audio_response = await generate_audio(
        input=processed_text.response,
        voice=voice,
        model=tts_config.model,
        response_format=config.response_format,
        speed=SPEED
    )
    return audio_response
class BatchTranscriptionResponse(BaseModel):
    transcriptions: List[str]
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the FastAPI server.")
    parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
    parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
    parser.add_argument("--config", type=str, default="config_one", help="Configuration to use (e.g., config_one, config_two, config_three, config_four)")
    args = parser.parse_args()

    # Load the JSON configuration file
    def load_config(config_path="dhwani_config.json"):
        with open(config_path, "r") as f:
            return json.load(f)
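    # Expected shape of dhwani_config.json, inferred from the lookups below (illustrative):
    # {
    #   "global_settings": {"host": "0.0.0.0", "port": 7860,
    #                       "chat_rate_limit": "100/minute", "speech_rate_limit": "5/minute"},
    #   "configs": {
    #     "config_one": {
    #       "language": "kannada",
    #       "components": {
    #         "LLM": {"model": "google/gemma-3-4b-it", "max_tokens": 512},
    #         "ASR": {"model": "ai4bharat/indic-conformer-600m-multilingual", "language_code": "kn"},
    #         "TTS": {"model": "ai4bharat/indic-parler-tts"},
    #         "Translation": [{"src_lang": "eng_Latn", "tgt_lang": "kan_Knda"}]
    #       }
    #     }
    #   }
    # }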
    config_data = load_config()
    if args.config not in config_data["configs"]:
        raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")
    selected_config = config_data["configs"][args.config]
    global_settings = config_data["global_settings"]

    # Update settings based on the selected config
    settings.llm_model_name = selected_config["components"]["LLM"]["model"]
    settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
    settings.host = global_settings["host"]
    settings.port = global_settings["port"]
    settings.chat_rate_limit = global_settings["chat_rate_limit"]
    settings.speech_rate_limit = global_settings["speech_rate_limit"]

    # Initialize LLMManager with the selected LLM model
    llm_manager = LLMManager(settings.llm_model_name)

    # Initialize the ASR model if present in the config
    if selected_config["components"]["ASR"]:
        asr_model_name = selected_config["components"]["ASR"]["model"]
        model = AutoModel.from_pretrained(asr_model_name, trust_remote_code=True)
        asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]

    # Initialize the TTS model if present in the config
    if selected_config["components"]["TTS"]:
        tts_model_name = selected_config["components"]["TTS"]["model"]
        tts_config.model = tts_model_name  # Update tts_config to use the selected model
        tts_model_manager.get_or_load_model(tts_model_name)

    # Initialize translation models: load all specified language pairs
    if selected_config["components"]["Translation"]:
        for translation_config in selected_config["components"]["Translation"]:
            src_lang = translation_config["src_lang"]
            tgt_lang = translation_config["tgt_lang"]
            model_manager.get_model(src_lang, tgt_lang)

    # Note: the argparse defaults were taken from the pre-config settings, so
    # args.host/args.port already hold either the CLI flags or those defaults.
    host = args.host
    port = args.port

    # Run the server
    uvicorn.run(app, host=host, port=port)