Xmaster6y commited on
Commit
bd23dfe
·
unverified ·
1 Parent(s): 0d998a6
Files changed (2) hide show
  1. src/constants.py +4 -1
  2. src/global_variables.py +14 -10
src/constants.py CHANGED
@@ -15,4 +15,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  MODEL_NAME = "lc0-10-4238.onnx"
16
  SAE_CONFIG = "debug"
17
  LAYER = 9
18
- N_FEATURES = 7680
 
 
 
 
15
  MODEL_NAME = "lc0-10-4238.onnx"
16
  SAE_CONFIG = "debug"
17
  LAYER = 9
18
+ ACTIVATION_DIM = 256
19
+ DICTIONARY_SIZE = 7680
20
+ PRE_BIAS = False
21
+ INIT_NORMALISE_DICT = None
src/global_variables.py CHANGED
@@ -7,7 +7,7 @@ import gradio as gr
7
  from lczerolens import ModelWrapper
8
  import torch
9
 
10
- from src.constants import HF_TOKEN, ASSETS_FOLDER, DEVICE, MODEL_NAME, SAE_CONFIG, LAYER
11
  from src.helpers import SparseAutoEncoder, OutputGenerator
12
 
13
  hf_api: HfApi
@@ -22,32 +22,36 @@ def setup():
22
  global sae
23
  global generator
24
 
25
- hf_api = HfApi(token=HF_TOKEN)
26
  hf_api.snapshot_download(
27
- local_dir=f"{ASSETS_FOLDER}/models",
28
  repo_id="Xmaster6y/lczero-planning-models",
29
  repo_type="model",
30
  )
31
  hf_api.snapshot_download(
32
- local_dir=f"{ASSETS_FOLDER}/saes",
33
  repo_id="Xmaster6y/lczero-planning-saes",
34
  repo_type="model",
35
  )
36
 
37
- wrapper = ModelWrapper.from_onnx_path(f"{ASSETS_FOLDER}/models/{MODEL_NAME}").to(DEVICE)
38
  sae_dict = torch.load(
39
- f"{ASSETS_FOLDER}/saes/{SAE_CONFIG}/model.pt",
40
- map_location=DEVICE,
41
- weights_only=True
 
 
 
 
 
42
  )
43
- sae = SparseAutoEncoder()
44
  sae.load_state_dict(
45
  sae_dict
46
  )
47
  generator = OutputGenerator(
48
  sae=sae,
49
  wrapper=wrapper,
50
- module_exp=rf".*block{LAYER}/conv2/relu"
51
  )
52
 
53
  if gr.NO_RELOAD:
 
7
  from lczerolens import ModelWrapper
8
  import torch
9
 
10
+ from src import constants
11
  from src.helpers import SparseAutoEncoder, OutputGenerator
12
 
13
  hf_api: HfApi
 
22
  global sae
23
  global generator
24
 
25
+ hf_api = HfApi(token=constants.HF_TOKEN)
26
  hf_api.snapshot_download(
27
+ local_dir=f"{constants.ASSETS_FOLDER}/models",
28
  repo_id="Xmaster6y/lczero-planning-models",
29
  repo_type="model",
30
  )
31
  hf_api.snapshot_download(
32
+ local_dir=f"{constants.ASSETS_FOLDER}/saes",
33
  repo_id="Xmaster6y/lczero-planning-saes",
34
  repo_type="model",
35
  )
36
 
37
+ wrapper = ModelWrapper.from_onnx_path(f"{constants.ASSETS_FOLDER}/models/{constants.MODEL_NAME}").to(constants.DEVICE)
38
  sae_dict = torch.load(
39
+ f"{constants.ASSETS_FOLDER}/saes/{constants.SAE_CONFIG}/model.pt",
40
+ map_location=constants.DEVICE,
41
+ )
42
+ sae = SparseAutoEncoder(
43
+ constants.ACTIVATION_DIM,
44
+ constants.DICTIONARY_SIZE,
45
+ pre_bias=constants.PRE_BIAS,
46
+ init_normalise_dict=constants.INIT_NORMALISE_DICT,
47
  )
 
48
  sae.load_state_dict(
49
  sae_dict
50
  )
51
  generator = OutputGenerator(
52
  sae=sae,
53
  wrapper=wrapper,
54
+ module_exp=rf".*block{constants.LAYER}/conv2/relu"
55
  )
56
 
57
  if gr.NO_RELOAD: