juewang commited on
Commit
4ec6edc
1 Parent(s): c6a4053

remove torch.jit

Browse files

This decorator will fail in some cases. Remove it until we figure out a solution.

Files changed (1) hide show
  1. modeling_flash_llama.py +1 -1
modeling_flash_llama.py CHANGED
@@ -62,7 +62,7 @@ logger = logging.get_logger(__name__)
62
  _CONFIG_FOR_DOC = "LlamaConfig"
63
 
64
 
65
- @torch.jit.script
66
  def rmsnorm_func(hidden_states, weight, variance_epsilon):
67
  input_dtype = hidden_states.dtype
68
  hidden_states = hidden_states.to(torch.float32)
 
62
  _CONFIG_FOR_DOC = "LlamaConfig"
63
 
64
 
65
+ # @torch.jit.script
66
  def rmsnorm_func(hidden_states, weight, variance_epsilon):
67
  input_dtype = hidden_states.dtype
68
  hidden_states = hidden_states.to(torch.float32)