sachin
test- tts speed
8a42096
import gc
import hashlib

import torch
from fastapi import HTTPException
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, Gemma3ForConditionalGeneration

from logging_config import logger
# Device setup: run on the GPU whenever CUDA is visible, otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
# Reduced-precision bfloat16 on the GPU; full float32 precision on the CPU.
TORCH_DTYPE = torch.float32 if DEVICE == "cpu" else torch.bfloat16
class LLMManager:
    """Lazily load a 4-bit quantized Gemma 3 model and answer text and image prompts.

    The model and processor are loaded on first use (or explicitly via ``load``) and
    released via ``unload``. Decoded responses are memoized in ``token_cache``. An
    identical request (same prompt, same image content, same token budget) therefore
    returns the first sampled answer without running the model again.

    NOTE: the ``async`` entry points call the model synchronously (there is no
    ``await`` inside them), so a generation blocks the event loop while it runs.
    """

    # System prompts sent as the first chat turn of each request type.
    _GENERATE_SYSTEM_PROMPT = "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."
    _VISION_SYSTEM_PROMPT = "You are Dhwani, a helpful assistant. Summarize your answer in one sentence maximum."
    _CHAT_V2_SYSTEM_PROMPT = "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."

    def __init__(self, model_name: str, device: str = DEVICE, max_cache_entries: int = 1024):
        """Record configuration; the model itself is not loaded until first use.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
            device: Device that input tensors are moved to and that gates the
                CUDA-specific setup (defaults to ``DEVICE``).
            max_cache_entries: Maximum number of memoized responses kept in
                ``token_cache``; the least recently used entry is evicted first.
                Negative values are treated as 0 (no caching).
        """
        self.model_name = model_name
        self.device = torch.device(device)
        self.torch_dtype = TORCH_DTYPE
        self.model = None
        self.is_loaded = False
        self.processor = None
        # request key -> {"response": str}; dict insertion order doubles as LRU order.
        self.token_cache = {}
        self.max_cache_entries = max(0, max_cache_entries)
        logger.info(f"LLMManager initialized with model {model_name} on {self.device}")

    def unload(self):
        """Drop the model and processor and release cached CUDA memory; no-op if not loaded."""
        if not self.is_loaded:
            return
        # Rebind instead of ``del`` so the attributes always exist (None while unloaded).
        self.model = None
        self.processor = None
        # Collect unreachable objects first so empty_cache() can actually release their blocks.
        gc.collect()
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
            logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
        self.is_loaded = False
        logger.info(f"LLM {self.model_name} unloaded from {self.device}")

    def load(self):
        """Load the 4-bit quantized model and its processor if they are not loaded yet.

        Raises:
            HTTPException: 500 if the model or processor cannot be loaded.
        """
        if self.is_loaded:
            return
        try:
            if self.device.type == "cuda":
                # Allow TF32 matmuls on GPUs that support them.
                torch.set_float32_matmul_precision('high')
                logger.info("Enabled TF32 matrix multiplication for improved performance")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",  # NF4 (Normal Float 4) is optimized for LLMs
                bnb_4bit_compute_dtype=self.torch_dtype,  # compute in the manager's working dtype
                bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
            )
            # device_map="auto" already places the weights. There is deliberately no
            # follow-up .to(device): moving a dispatched/offloaded or bitsandbytes-quantized
            # model is redundant at best and can raise.
            self.model = Gemma3ForConditionalGeneration.from_pretrained(
                self.model_name,
                device_map="auto",
                quantization_config=quantization_config,
                torch_dtype=self.torch_dtype,
                max_memory={0: "10GiB"},  # Adjust based on your GPU capacity
            ).eval()
            # use_fast=True selects the fast processor implementation.
            self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
            self.is_loaded = True
            logger.info(f"LLM {self.model_name} loaded on {self.device} with 4-bit quantization and fast processor")
        except Exception as e:
            # Drop any half-loaded model so a retry does not hold two copies in memory.
            self.model = None
            self.processor = None
            logger.error(f"Failed to load model: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}") from e

    @staticmethod
    def _build_messages(system_prompt: str, user_text: str, image=None) -> list:
        """Build the (system, user) chat; *image*, when truthy, follows the user text."""
        user_content = [{"type": "text", "text": user_text}]
        if image:
            user_content.append({"type": "image", "image": image})
        return [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": user_content},
        ]

    @staticmethod
    def _image_fingerprint(image) -> str:
        """Return a SHA-256 of the image's mode, size and pixels, or ``"no_image"`` if absent.

        Used in cache keys so that different images asked the same question do not
        share a cached answer.
        """
        if not image:
            return "no_image"
        digest = hashlib.sha256(f"{image.mode}|{image.size}".encode())
        digest.update(image.tobytes())
        return digest.hexdigest()

    def _cache_get(self, cache_key):
        """Return the memoized response for *cache_key* (marking it recently used), or None."""
        entry = self.token_cache.get(cache_key)
        if entry is None or "response" not in entry:
            return None
        # Re-insert to move the key to the most-recently-used end of the dict.
        self.token_cache[cache_key] = self.token_cache.pop(cache_key)
        return entry["response"]

    def _cache_put(self, cache_key, response: str):
        """Memoize *response*, evicting least recently used entries beyond the size limit."""
        self.token_cache.pop(cache_key, None)
        self.token_cache[cache_key] = {"response": response}
        while len(self.token_cache) > self.max_cache_entries:
            del self.token_cache[next(iter(self.token_cache))]

    def _tokenize(self, messages, error_log: str, error_detail: str):
        """Apply the chat template and move the resulting tensors to ``self.device``.

        The dtype argument is ``self.torch_dtype``, i.e. the dtype the model was loaded
        with, so on CPU inputs are no longer forced to bfloat16.

        Raises:
            HTTPException: 500 if templating/tokenization fails. The log line is
                prefixed with *error_log* and the HTTP detail with *error_detail*.
        """
        try:
            return self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.device, dtype=self.torch_dtype)
        except Exception as e:
            logger.error(f"{error_log}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"{error_detail}: {str(e)}") from e

    def _sample(self, inputs, max_tokens: int) -> str:
        """Sample a reply for tokenized *inputs* and decode only the newly generated tokens."""
        prompt_len = inputs["input_ids"].shape[-1]
        # Never exceed the caller's budget; below it, allow up to 2x the prompt length (min 20).
        max_new_tokens = min(max_tokens, max(20, prompt_len * 2))
        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
            )
        # generate() returns prompt + continuation; keep only the continuation of row 0.
        return self.processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

    def _respond(self, cache_key, messages, max_tokens: int, log_label: str, error_log: str, error_detail: str) -> str:
        """Run the shared request pipeline: ensure loaded, check the cache, tokenize, sample, memoize.

        Responses are sampled (do_sample=True) but memoized, so a repeated request
        returns the first sample.
        """
        if not self.is_loaded:
            self.load()
        cached = self._cache_get(cache_key)
        if cached is not None:
            logger.info("Using cached response")
            return cached
        inputs = self._tokenize(messages, error_log, error_detail)
        response = self._sample(inputs, max_tokens)
        self._cache_put(cache_key, response)
        logger.info(f"{log_label}: {response}")
        return response

    async def generate(self, prompt: str, max_tokens: int = 50) -> str:
        """Answer a text-only *prompt*; the system prompt asks for a one-sentence answer.

        Args:
            prompt: The user's question.
            max_tokens: Upper bound on newly generated tokens (also part of the cache key).

        Returns:
            The decoded reply, possibly served from ``token_cache``.

        Raises:
            HTTPException: 500 if loading or tokenization fails.
        """
        messages = self._build_messages(self._GENERATE_SYSTEM_PROMPT, prompt)
        return self._respond(
            ("generate", prompt, max_tokens),
            messages,
            max_tokens,
            log_label="Generated response",
            error_log="Error in tokenization",
            error_detail="Tokenization failed",
        )

    async def vision_query(self, image: Image.Image, query: str, max_tokens: int = 50) -> str:
        """Answer *query* about *image* (optional); the system prompt asks for one sentence.

        Args:
            image: Image to attach after the query text; skipped when falsy (e.g. None).
            query: The user's question.
            max_tokens: Upper bound on newly generated tokens (default matches the former fixed cap).

        Returns:
            The decoded reply. It is cached per query, image content and ``max_tokens``.

        Raises:
            HTTPException: 500 if loading or input processing fails.
        """
        messages = self._build_messages(self._VISION_SYSTEM_PROMPT, query, image)
        cache_key = ("vision", query, self._image_fingerprint(image), max_tokens)
        return self._respond(
            cache_key,
            messages,
            max_tokens,
            log_label="Vision query response",
            error_log="Error in apply_chat_template",
            error_detail="Failed to process input",
        )

    async def chat_v2(self, image: Image.Image, query: str, max_tokens: int = 50) -> str:
        """Answer *query* with India/Karnataka context, optionally grounded on *image*.

        Args:
            image: Image to attach after the query text; skipped when falsy (e.g. None).
            query: The user's question.
            max_tokens: Upper bound on newly generated tokens (default matches the former fixed cap).

        Returns:
            The decoded reply. It is cached per query, image content and ``max_tokens``.

        Raises:
            HTTPException: 500 if loading or input processing fails.
        """
        messages = self._build_messages(self._CHAT_V2_SYSTEM_PROMPT, query, image)
        cache_key = ("chat_v2", query, self._image_fingerprint(image), max_tokens)
        return self._respond(
            cache_key,
            messages,
            max_tokens,
            log_label="Chat_v2 response",
            error_log="Error in apply_chat_template",
            error_detail="Failed to process input",
        )