File size: 1,300 Bytes
0d998a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Manage global variables for the app.
"""

from huggingface_hub import HfApi

import gradio as gr
from lczerolens import ModelWrapper
import torch

from src.constants import HF_TOKEN, ASSETS_FOLDER, DEVICE, MODEL_NAME, SAE_CONFIG, LAYER
from src.helpers import SparseAutoEncoder, OutputGenerator

# Shared application state. These are bare annotations only: the names are
# not bound until setup() has run.
hf_api: HfApi  # Hugging Face Hub client created with HF_TOKEN
wrapper: ModelWrapper  # model loaded from the ONNX file named by MODEL_NAME
sae: SparseAutoEncoder  # autoencoder whose weights come from the SAE_CONFIG checkpoint
generator: OutputGenerator  # built from `sae` and `wrapper`


def setup():
    """Populate the module-level globals used by the app.

    Creates the Hugging Face API client, downloads the model and SAE
    snapshots under ``ASSETS_FOLDER``, loads the ONNX model wrapper and the
    sparse autoencoder weights, and finally builds the output generator
    from both.
    """
    global hf_api, wrapper, sae, generator

    hf_api = HfApi(token=HF_TOKEN)

    # Each asset repository is mirrored into its own sub-folder of
    # ASSETS_FOLDER; models are fetched first, then the SAEs.
    snapshots = (
        ("models", "Xmaster6y/lczero-planning-models"),
        ("saes", "Xmaster6y/lczero-planning-saes"),
    )
    for subfolder, repo in snapshots:
        hf_api.snapshot_download(
            repo_id=repo,
            repo_type="model",
            local_dir=f"{ASSETS_FOLDER}/{subfolder}",
        )

    onnx_path = f"{ASSETS_FOLDER}/models/{MODEL_NAME}"
    wrapper = ModelWrapper.from_onnx_path(onnx_path).to(DEVICE)

    checkpoint_path = f"{ASSETS_FOLDER}/saes/{SAE_CONFIG}/model.pt"
    # weights_only=True restricts unpickling to tensors/plain containers.
    state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)

    # NOTE(review): unlike `wrapper`, `sae` is never explicitly moved to
    # DEVICE here -- confirm SparseAutoEncoder/OutputGenerator handle that.
    sae = SparseAutoEncoder()
    sae.load_state_dict(state)

    # module_exp is presumably a regex matched against module names to pick
    # the activation at layer LAYER -- verify against OutputGenerator.
    generator = OutputGenerator(
        sae=sae,
        wrapper=wrapper,
        module_exp=rf".*block{LAYER}/conv2/relu",
    )

# Run the download/model-loading setup once at import time. The gr.NO_RELOAD
# guard is Gradio's documented way to keep a block from being re-executed
# when the app is hot-reloaded in reload mode.
if gr.NO_RELOAD:
    setup()