sachin committed
Commit 8bf941a · 1 Parent(s): d1c9225

disable torch.compile

Files changed (1)
src/server/main.py  +5 -4
src/server/main.py CHANGED
@@ -95,14 +95,15 @@ class TTSModelManager:
         if description_tokenizer.pad_token is None:
             description_tokenizer.pad_token = description_tokenizer.eos_token
 
+        '''
         # TODO - temporary disable -torch.compile
 
         # Update model configuration
         model.config.pad_token_id = tokenizer.pad_token_id
         # Update for deprecation: use max_batch_size instead of batch_size
-        #if hasattr(model.generation_config.cache_config, 'max_batch_size'):
-        #    model.generation_config.cache_config.max_batch_size = 1
-        #model.generation_config.cache_implementation = "static"
+        if hasattr(model.generation_config.cache_config, 'max_batch_size'):
+            model.generation_config.cache_config.max_batch_size = 1
+        model.generation_config.cache_implementation = "static"
 
         # Compile the model
         ##compile_mode = "default"
@@ -126,7 +127,7 @@ class TTSModelManager:
         n_steps = 1 if compile_mode == "default" else 2
         for _ in range(n_steps):
             _ = model.generate(**model_kwargs)
-
+        '''
         logger.info(
             f"Loaded {model_name} with Flash Attention and compilation in {time.perf_counter() - start:.2f} seconds"
         )
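
The block this commit turns into an unused triple-quoted string follows the usual torch.compile warm-up pattern for transformers generation: pin a static KV cache, compile the model's forward pass, and run one or two throwaway generations. A minimal sketch of that pattern, assuming a model and model_kwargs like the ones TTSModelManager prepares (the actual torch.compile call is not shown in this hunk, so that line is an assumption):

# Sketch only: the compile/warm-up pattern this commit disables.
# `model` and `model_kwargs` are assumed to come from TTSModelManager;
# the exact "# Compile the model" line is not visible in the diff.
import torch

def warm_up_compiled_generation(model, model_kwargs, compile_mode="default"):
    # Static KV cache so the compiled graph sees fixed shapes
    if hasattr(model.generation_config.cache_config, "max_batch_size"):
        model.generation_config.cache_config.max_batch_size = 1
    model.generation_config.cache_implementation = "static"

    # Compile the forward pass (assumed; mirrors the hidden compile section)
    model.forward = torch.compile(model.forward, mode=compile_mode)

    # Non-default modes such as "reduce-overhead" need an extra warm-up pass
    n_steps = 1 if compile_mode == "default" else 2
    for _ in range(n_steps):
        _ = model.generate(**model_kwargs)

Wrapping the block in a bare ''' ... ''' leaves it in the file as a discarded string expression rather than deleting it, which makes it easy to re-enable later but would break if the block ever contained an unescaped triple quote.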