sachin
committed on
Commit
·
3f6f875
1
Parent(s):
7a61b58
add-tts
Browse files- src/server/main.py +3 -1
src/server/main.py
CHANGED
@@ -95,6 +95,8 @@ class TTSModelManager:
|
|
95 |
if description_tokenizer.pad_token is None:
|
96 |
description_tokenizer.pad_token = description_tokenizer.eos_token
|
97 |
|
|
|
|
|
98 |
# Update model configuration
|
99 |
model.config.pad_token_id = tokenizer.pad_token_id
|
100 |
# Update for deprecation: use max_batch_size instead of batch_size
|
@@ -124,7 +126,7 @@ class TTSModelManager:
|
|
124 |
n_steps = 1 if compile_mode == "default" else 2
|
125 |
for _ in range(n_steps):
|
126 |
_ = model.generate(**model_kwargs)
|
127 |
-
|
128 |
logger.info(
|
129 |
f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
|
130 |
)
|
|
|
95 |
if description_tokenizer.pad_token is None:
|
96 |
description_tokenizer.pad_token = description_tokenizer.eos_token
|
97 |
|
98 |
+
# TODO - temporary disable -torch.compile
|
99 |
+
'''
|
100 |
# Update model configuration
|
101 |
model.config.pad_token_id = tokenizer.pad_token_id
|
102 |
# Update for deprecation: use max_batch_size instead of batch_size
|
|
|
126 |
n_steps = 1 if compile_mode == "default" else 2
|
127 |
for _ in range(n_steps):
|
128 |
_ = model.generate(**model_kwargs)
|
129 |
+
'''
|
130 |
logger.info(
|
131 |
f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
|
132 |
)
|