import torch
from logging_config import logger
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration
from collections import OrderedDict
from typing import Tuple
from time import time
from tts_config import config
# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TORCH_DTYPE = torch.bfloat16 if DEVICE.type != "cpu" else torch.float32


class TTSModelManager:
    def __init__(self):
        # Insertion order doubles as eviction order: the oldest model is dropped first.
        self.model_tokenizer: OrderedDict[
            str, Tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]
        ] = OrderedDict()
        self.max_length = 50  # Reverted to baseline value
        self.voice_cache = {}  # Reserved for future use
        self.audio_cache = {}  # Used for caching generated audio
    def load_model(
        self, model_name: str
    ) -> Tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
        logger.debug(f"Loading {model_name}...")
        start = time()
        # NOTE: the requested name is overridden; this manager always loads the
        # Indic Parler-TTS checkpoint regardless of the argument.
        model_name = "ai4bharat/indic-parler-tts"
        attn_implementation = "flash_attention_2"
        model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name,
            attn_implementation=attn_implementation,
        ).to(DEVICE, dtype=TORCH_DTYPE)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        description_tokenizer = AutoTokenizer.from_pretrained(
            model.config.text_encoder._name_or_path
        )
        # Some tokenizers ship without a pad token; fall back to EOS so padded
        # batches can be built.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if description_tokenizer.pad_token is None:
            description_tokenizer.pad_token = description_tokenizer.eos_token
        # Update model configuration (from baseline)
        model.config.pad_token_id = tokenizer.pad_token_id
        cache_config = getattr(model.generation_config, "cache_config", None)
        if cache_config is not None and hasattr(cache_config, "max_batch_size"):
            cache_config.max_batch_size = 1
        model.generation_config.cache_implementation = "static"
        # Compile the model (baseline approach)
        compile_mode = "default"
        model.forward = torch.compile(model.forward, mode=compile_mode)
        # Warmup (baseline approach): run generate once so torch.compile
        # captures the graph before the first real request.
        warmup_inputs = tokenizer(
            "Warmup text for compilation",
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
        ).to(DEVICE)
        # The same tokenized text is reused for both the description inputs and
        # the prompt inputs; for warmup only the tensor shapes matter.
        model_kwargs = {
            "input_ids": warmup_inputs["input_ids"],
            "attention_mask": warmup_inputs["attention_mask"],
            "prompt_input_ids": warmup_inputs["input_ids"],
            "prompt_attention_mask": warmup_inputs["attention_mask"],
            "max_new_tokens": 100,  # Added for better graph capture
            "do_sample": True,  # Added for consistency with endpoint
            "top_p": 0.9,
            "temperature": 0.7,
        }
        n_steps = 1  # Baseline uses 1 step for "default" mode
        for _ in range(n_steps):
            _ = model.generate(**model_kwargs)
        logger.info(
            f"Loaded {model_name} with Flash Attention and compilation in {time() - start:.2f} seconds"
        )
        return model, tokenizer, description_tokenizer
    def get_or_load_model(
        self, model_name: str
    ) -> Tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
        if model_name not in self.model_tokenizer:
            logger.info(f"Model {model_name} is not loaded yet")
            if len(self.model_tokenizer) >= config.max_models:
                logger.info("Unloading the oldest loaded model")
                # Pop the first (oldest) entry in insertion order.
                del self.model_tokenizer[next(iter(self.model_tokenizer))]
            self.model_tokenizer[model_name] = self.load_model(model_name)
        return self.model_tokenizer[model_name]
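

# Minimal usage sketch (an illustration, not part of the original module): how a
# caller such as an API endpoint might drive TTSModelManager. The description
# and prompt strings below are hypothetical placeholders; kept commented out so
# importing this module stays side-effect free.
#
# tts_manager = TTSModelManager()
# model, tokenizer, description_tokenizer = tts_manager.get_or_load_model(
#     "ai4bharat/indic-parler-tts"
# )
# description = description_tokenizer(
#     "A female speaker with a clear, expressive voice.", return_tensors="pt"
# ).to(DEVICE)
# prompt = tokenizer("Hello, world.", return_tensors="pt").to(DEVICE)
# audio = model.generate(
#     input_ids=description["input_ids"],
#     attention_mask=description["attention_mask"],
#     prompt_input_ids=prompt["input_ids"],
#     prompt_attention_mask=prompt["attention_mask"],
# )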