import torch
from time import time
from logging_config import logger
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration
from typing import OrderedDict, Tuple
from tts_config import config

# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TORCH_DTYPE = torch.bfloat16 if DEVICE.type != "cpu" else torch.float32

class TTSModelManager:
    def __init__(self):
        self.model_tokenizer: OrderedDict[
            str, Tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]
        ] = OrderedDict()
        self.max_length = 50  # Baseline padded input length (used for the warmup below)
        self.voice_cache = {}  # Reserved for future use
        self.audio_cache = {}  # Used for caching generated audio

    def load_model(
        self, model_name: str
    ) -> Tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
        logger.debug(f"Loading {model_name}...")
        start = time()
        
        # NOTE: the requested model_name is overridden here with a fixed checkpoint.
        model_name = "ai4bharat/indic-parler-tts"
        # Flash Attention 2 requires the flash-attn package and a compatible GPU.
        attn_implementation = "flash_attention_2"
        
        model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name,
            attn_implementation=attn_implementation
        ).to(DEVICE, dtype=TORCH_DTYPE)

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if description_tokenizer.pad_token is None:
            description_tokenizer.pad_token = description_tokenizer.eos_token

        # Update model configuration (from baseline)
        model.config.pad_token_id = tokenizer.pad_token_id
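        # Pin the static KV cache to a single request at a time; guarded because
        # cache_config may be None or lack this field depending on the transformers version.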
        if hasattr(model.generation_config.cache_config, 'max_batch_size'):
            model.generation_config.cache_config.max_batch_size = 1
        model.generation_config.cache_implementation = "static"

        # Compile the model (baseline approach)
        compile_mode = "default"
        model.forward = torch.compile(model.forward, mode=compile_mode)

        # Warmup (baseline approach)
        warmup_inputs = tokenizer(
            "Warmup text for compilation",
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length
        ).to(DEVICE)
        
        model_kwargs = {
            "input_ids": warmup_inputs["input_ids"],
            "attention_mask": warmup_inputs["attention_mask"],
            "prompt_input_ids": warmup_inputs["input_ids"],
            "prompt_attention_mask": warmup_inputs["attention_mask"],
            "max_new_tokens": 100,  # Added for better graph capture
            "do_sample": True,  # Added for consistency with endpoint
            "top_p": 0.9,
            "temperature": 0.7,
        }
        
        n_steps = 1  # Baseline uses 1 step for "default" mode
        for _ in range(n_steps):
            _ = model.generate(**model_kwargs)

        logger.info(
            f"Loaded {model_name} with Flash Attention and compilation in {time() - start:.2f} seconds"
        )
        return model, tokenizer, description_tokenizer

    def get_or_load_model(
        self, model_name: str
    ) -> Tuple[ParlerTTSForConditionalGeneration, AutoTokenizer, AutoTokenizer]:
        if model_name not in self.model_tokenizer:
            logger.info(f"Model {model_name} isn't already loaded")
            if len(self.model_tokenizer) == config.max_models:
                logger.info("Unloading the oldest loaded model")
                del self.model_tokenizer[next(iter(self.model_tokenizer))]
            self.model_tokenizer[model_name] = self.load_model(model_name)
        return self.model_tokenizer[model_name]
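

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows how an
# endpoint might pull a model from the manager and synthesize audio. The
# description/prompt strings and the "out.wav" path are hypothetical; the
# generate() call follows the standard Parler-TTS pattern.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import soundfile as sf  # used only to write the example WAV file

    manager = TTSModelManager()
    model, tokenizer, description_tokenizer = manager.get_or_load_model(
        "ai4bharat/indic-parler-tts"
    )

    description = "A female speaker delivers clear, slightly expressive speech."
    prompt = "Hello, this is a quick test of the TTS model manager."

    desc_inputs = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    prompt_inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    # Parler-TTS conditions on the description via input_ids and on the text
    # to be spoken via prompt_input_ids.
    generation = model.generate(
        input_ids=desc_inputs["input_ids"],
        attention_mask=desc_inputs["attention_mask"],
        prompt_input_ids=prompt_inputs["input_ids"],
        prompt_attention_mask=prompt_inputs["attention_mask"],
    )

    # Cast to float32 before converting: numpy has no bfloat16 dtype.
    audio = generation.to(torch.float32).cpu().numpy().squeeze()
    sf.write("out.wav", audio, model.config.sampling_rate)
    logger.info(f"Wrote {audio.shape[-1]} samples to out.wav")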