sachin
committed on
Commit
·
6173695
1
Parent(s):
af923b7
enable-torch
Browse files- src/server/main.py +2 -2
src/server/main.py
CHANGED
@@ -96,7 +96,7 @@ class TTSModelManager:
|
|
96 |
description_tokenizer.pad_token = description_tokenizer.eos_token
|
97 |
|
98 |
# TODO - temporary disable -torch.compile
|
99 |
-
|
100 |
# Update model configuration
|
101 |
model.config.pad_token_id = tokenizer.pad_token_id
|
102 |
# Update for deprecation: use max_batch_size instead of batch_size
|
@@ -126,7 +126,7 @@ class TTSModelManager:
|
|
126 |
n_steps = 1 if compile_mode == "default" else 2
|
127 |
for _ in range(n_steps):
|
128 |
_ = model.generate(**model_kwargs)
|
129 |
-
|
130 |
logger.info(
|
131 |
f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
|
132 |
)
|
|
|
96 |
description_tokenizer.pad_token = description_tokenizer.eos_token
|
97 |
|
98 |
# TODO - temporary disable -torch.compile
|
99 |
+
|
100 |
# Update model configuration
|
101 |
model.config.pad_token_id = tokenizer.pad_token_id
|
102 |
# Update for deprecation: use max_batch_size instead of batch_size
|
|
|
126 |
n_steps = 1 if compile_mode == "default" else 2
|
127 |
for _ in range(n_steps):
|
128 |
_ = model.generate(**model_kwargs)
|
129 |
+
|
130 |
logger.info(
|
131 |
f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
|
132 |
)
|