Spaces:
Runtime error
Runtime error
File size: 1,673 Bytes
0d998a6 3b6ef01 0d998a6 bd23dfe 0d998a6 3b6ef01 0d998a6 3b6ef01 0d998a6 bd23dfe 0d998a6 bd23dfe 5f19208 0d998a6 bd23dfe 5f19208 0d998a6 bd23dfe 0d998a6 bd23dfe 0d998a6 bd23dfe 0d998a6 3b6ef01 0d998a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
"""Manage global variables for the app.
"""
from huggingface_hub import HfApi
import gradio as gr
from lczerolens import ModelWrapper
import torch
from datasets import load_dataset, Dataset
from src import constants
from src.helpers import SparseAutoEncoder, OutputGenerator
# Module-level handles filled in by setup(). They are only annotated here,
# not assigned, so reading any of them before setup() has run raises NameError.
hf_api: HfApi  # Hub client authenticated with constants.HF_TOKEN
wrapper: ModelWrapper  # wrapper around the downloaded ONNX model, moved to constants.DEVICE
sae: SparseAutoEncoder  # sparse autoencoder whose weights come from the downloaded saes repo
generator: OutputGenerator  # built from sae + wrapper with a module regex for constants.LAYER
f_ds: Dataset  # "test" split of constants.FEATURE_DATASET, formatted as torch
def setup():
    """Download the app's assets and populate the module-level globals.

    Steps, in order:

    1. Create an ``HfApi`` client authenticated with ``constants.HF_TOKEN``
       and snapshot the ``lczero-planning/models`` and
       ``lczero-planning/saes`` Hub repos into
       ``<constants.ASSETS_FOLDER>/models`` and ``<constants.ASSETS_FOLDER>/saes``.
    2. Load the ONNX model ``constants.MODEL_NAME`` into a ``ModelWrapper``
       placed on ``constants.DEVICE``.
    3. Build a ``SparseAutoEncoder`` from the ``constants`` hyper-parameters,
       load the ``constants.SAE_CONFIG`` checkpoint into it and move it to
       ``constants.DEVICE``.
    4. Create an ``OutputGenerator`` pairing the SAE with the model, using the
       module regex ``.*block{constants.LAYER}/conv2/relu``.
    5. Load the ``"test"`` split of ``constants.FEATURE_DATASET`` (config
       ``constants.SAE_CONFIG``) with torch formatting.

    Side effects: assigns the globals ``hf_api``, ``wrapper``, ``sae``,
    ``generator`` and ``f_ds``; performs network and disk I/O. Returns None.
    """
    global hf_api, wrapper, sae, generator, f_ds

    assets = constants.ASSETS_FOLDER

    hf_api = HfApi(token=constants.HF_TOKEN)
    # Both asset repos share one layout: Hub repo ``lczero-planning/<name>``
    # is mirrored into ``<assets>/<name>``.
    for name in ("models", "saes"):
        hf_api.snapshot_download(
            local_dir=f"{assets}/{name}",
            repo_id=f"lczero-planning/{name}",
            repo_type="model",
        )

    wrapper = ModelWrapper.from_onnx_path(
        f"{assets}/models/{constants.MODEL_NAME}"
    ).to(constants.DEVICE)

    # NOTE(security): torch.load unpickles the checkpoint, which can execute
    # arbitrary code, and this file comes from a remote Hub repo. If the
    # checkpoint is a plain tensor state dict, consider weights_only=True.
    sae_dict = torch.load(
        f"{assets}/saes/{constants.SAE_CONFIG}/model.pt",
        map_location=constants.DEVICE,
    )
    sae = SparseAutoEncoder(
        constants.ACTIVATION_DIM,
        constants.DICTIONARY_SIZE,
        pre_bias=constants.PRE_BIAS,
        init_normalise_dict=constants.INIT_NORMALISE_DICT,
    )
    sae.load_state_dict(sae_dict)
    # load_state_dict copies values into the module's existing parameters, so
    # map_location above does not by itself place the SAE on DEVICE (the
    # constructor presumably allocates on the default device). Move it
    # explicitly so it lives on the same device as ``wrapper``.
    sae = sae.to(constants.DEVICE)

    generator = OutputGenerator(
        sae=sae,
        wrapper=wrapper,
        module_exp=rf".*block{constants.LAYER}/conv2/relu",
    )

    f_ds = load_dataset(
        constants.FEATURE_DATASET,
        constants.SAE_CONFIG,
        split="test",
    ).with_format("torch")
# Run the expensive one-time setup (Hub downloads, model/SAE/dataset loading)
# under Gradio's reload guard. Per Gradio's reload-mode docs, code inside
# `if gr.NO_RELOAD:` runs on the initial launch but is not re-executed on hot
# reloads (confirm against the installed Gradio version). Keep this exact
# `if gr.NO_RELOAD:` form rather than restructuring it.
if gr.NO_RELOAD:
    setup()
|