HaileyStorm commited on
Commit
627fa95
·
verified ·
1 Parent(s): 1d3684d

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -93,7 +93,7 @@ class MambaPlayer:
93
  self.hooks = []
94
 
95
  for i, layer in enumerate(self.model.backbone.layers):
96
- self.activations[i] = {"won": [], "lost": [], "current": []}
97
 
98
  def hook(module, input, output, layer_idx=i):
99
  if isinstance(output, tuple):
@@ -102,8 +102,7 @@ class MambaPlayer:
102
  tensor_output = output
103
  self.activations[layer_idx]["current"] = tensor_output.detach().cpu().numpy()
104
 
105
- hook_function = partial(hook, layer_idx=i)
106
- self.hooks.append(layer.register_forward_hook(hook_function))
107
 
108
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
109
  game_state = game_state.split("\n\n")[-1].strip()
 
93
  self.hooks = []
94
 
95
  for i, layer in enumerate(self.model.backbone.layers):
96
+ self.activations[i] = {"won": [], "lost": []}
97
 
98
  def hook(module, input, output, layer_idx=i):
99
  if isinstance(output, tuple):
 
102
  tensor_output = output
103
  self.activations[layer_idx]["current"] = tensor_output.detach().cpu().numpy()
104
 
105
+ self.hooks.append(layer.register_forward_hook(hook))
 
106
 
107
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
108
  game_state = game_state.split("\n\n")[-1].strip()