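# Captured from a Hugging Face Spaces page (Space status: "Runtime error").
"""Manage global variables for the app.

Holds the shared Hub client, model wrapper, sparse autoencoder, output
generator and feature dataset, all of which are populated by ``setup()``.
"""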
import os

import gradio as gr
import torch
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi
from lczerolens import ModelWrapper

from src import constants
from src.helpers import OutputGenerator, SparseAutoEncoder
# Shared application state, populated by setup().
# These are bare annotations: no value is bound until setup() runs, so reading
# any of them earlier raises NameError here (AttributeError when accessed as a
# module attribute from elsewhere).
# NOTE(review): presumably left unassigned on purpose so that re-executing this
# module (e.g. under Gradio reload mode) does not overwrite already-loaded
# state -- do not add default values such as `= None`.
hf_api: HfApi
wrapper: ModelWrapper
sae: SparseAutoEncoder
generator: OutputGenerator
f_ds: Dataset
def setup():
    """Download the app's assets and populate this module's global state.

    Side effects:
        * Mirrors the ``lczero-planning/models`` and ``lczero-planning/saes``
          Hub repositories into ``constants.ASSETS_FOLDER``.
        * Rebinds the module globals ``hf_api``, ``wrapper``, ``sae``,
          ``generator`` and ``f_ds``.

    Returns:
        None.
    """
    global hf_api, wrapper, sae, generator, f_ds

    hf_api = HfApi(token=constants.HF_TOKEN)

    # Both repos are mirrored into a sub-folder named after the repo itself.
    for repo_name in ("models", "saes"):
        hf_api.snapshot_download(
            local_dir=os.path.join(constants.ASSETS_FOLDER, repo_name),
            repo_id=f"lczero-planning/{repo_name}",
            repo_type="model",
        )

    wrapper = ModelWrapper.from_onnx_path(
        os.path.join(constants.ASSETS_FOLDER, "models", constants.MODEL_NAME)
    ).to(constants.DEVICE)

    # The checkpoint comes from a remote repository, so refuse to unpickle
    # arbitrary objects; a plain tensor state dict loads fine with this flag.
    sae_dict = torch.load(
        os.path.join(constants.ASSETS_FOLDER, "saes", constants.SAE_CONFIG, "model.pt"),
        map_location=constants.DEVICE,
        weights_only=True,
    )
    sae = SparseAutoEncoder(
        constants.ACTIVATION_DIM,
        constants.DICTIONARY_SIZE,
        pre_bias=constants.PRE_BIAS,
        init_normalise_dict=constants.INIT_NORMALISE_DICT,
    )
    sae.load_state_dict(sae_dict)
    # load_state_dict copies values into the existing parameters, which keep
    # the device they were constructed on (presumably CPU -- the constructor
    # gets no device). Move the SAE explicitly so it matches `wrapper`.
    sae = sae.to(constants.DEVICE)

    # Hook the SAE onto the ReLU after conv2 of the configured residual block.
    generator = OutputGenerator(
        sae=sae,
        wrapper=wrapper,
        module_exp=rf".*block{constants.LAYER}/conv2/relu",
    )

    f_ds = load_dataset(
        constants.FEATURE_DATASET,
        constants.SAE_CONFIG,
        split="test",
    ).with_format("torch")
# Load everything once at startup. Per Gradio's reload-mode documentation,
# code under `gr.NO_RELOAD` is skipped when the app is re-run on file changes,
# so the downloads and model loading are not repeated.
if gr.NO_RELOAD:
    setup()