Spaces: Running on T4
sachin committed
Commit · 8bf941a
1 Parent(s): d1c9225
disable torch.compile
Files changed: src/server/main.py (+5 -4)
src/server/main.py CHANGED
@@ -95,14 +95,15 @@ class TTSModelManager:
         if description_tokenizer.pad_token is None:
             description_tokenizer.pad_token = description_tokenizer.eos_token

+        '''
         # TODO - temporary disable -torch.compile

         # Update model configuration
         model.config.pad_token_id = tokenizer.pad_token_id
         # Update for deprecation: use max_batch_size instead of batch_size
-
-
-
+        if hasattr(model.generation_config.cache_config, 'max_batch_size'):
+            model.generation_config.cache_config.max_batch_size = 1
+        model.generation_config.cache_implementation = "static"

         # Compile the model
         ##compile_mode = "default"
@@ -126,7 +127,7 @@ class TTSModelManager:
         n_steps = 1 if compile_mode == "default" else 2
         for _ in range(n_steps):
             _ = model.generate(**model_kwargs)
-
+        '''
         logger.info(
             f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
         )
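For reference, the block this commit disables follows the usual static-cache warm-up pattern for torch.compile. Below is a minimal sketch of that pattern, not the repo's exact code: the helper name compile_and_warm_up is illustrative, while model, tokenizer, model_kwargs, compile_mode, and the cache settings mirror the names in the diff.

import time
import torch

def compile_and_warm_up(model, tokenizer, model_kwargs, compile_mode="default"):
    # Sketch only: assumes a transformers-style model (e.g. Parler-TTS).
    model.config.pad_token_id = tokenizer.pad_token_id
    # A static KV cache gives torch.compile fixed tensor shapes to specialize on.
    if hasattr(model.generation_config.cache_config, "max_batch_size"):
        model.generation_config.cache_config.max_batch_size = 1
    model.generation_config.cache_implementation = "static"

    # Compile the forward pass; "reduce-overhead" records CUDA graphs and
    # needs one more warm-up generate() than "default".
    model.forward = torch.compile(model.forward, mode=compile_mode)

    start = time.perf_counter()
    n_steps = 1 if compile_mode == "default" else 2
    for _ in range(n_steps):
        _ = model.generate(**model_kwargs)  # first calls trigger compilation
    print(f"Warm-up finished in {time.perf_counter() - start:.2f} seconds")

Because the commit wraps the block in a triple-quoted string rather than deleting it, the server now skips the cache configuration and warm-up entirely and model.generate runs eagerly; note that the surviving log line still reports "with Flash Attention and compilation".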