"""Global application state.

Holds the shared objects used across the app: the Hugging Face API client,
the ONNX model wrapper, the sparse autoencoder, the output generator built on
top of them, and the feature dataset. They are bound by ``setup()``.
"""

from huggingface_hub import HfApi
import gradio as gr
from lczerolens import ModelWrapper
import torch
from datasets import load_dataset, Dataset

from src import constants
from src.helpers import SparseAutoEncoder, OutputGenerator

# Module-level handles; they are only bound once `setup()` has run.
hf_api: HfApi
wrapper: ModelWrapper
sae: SparseAutoEncoder
generator: OutputGenerator
f_ds: Dataset


def setup():
    """Download the assets and bind every module-level handle.

    Steps, in order:
      1. create an ``HfApi`` client using ``constants.HF_TOKEN``;
      2. snapshot the model and SAE repositories into
         ``constants.ASSETS_FOLDER`` (``models/`` and ``saes/`` subfolders);
      3. load the ONNX model and move the wrapper to ``constants.DEVICE``;
      4. build the ``SparseAutoEncoder`` and load its saved state dict;
      5. build the ``OutputGenerator`` over the SAE and model wrapper;
      6. load the ``test`` split of the feature dataset in torch format.
    """
    global hf_api, wrapper, sae, generator, f_ds

    hf_api = HfApi(token=constants.HF_TOKEN)

    models_dir = f"{constants.ASSETS_FOLDER}/models"
    saes_dir = f"{constants.ASSETS_FOLDER}/saes"

    # Models first, then SAEs -- same order as the original two calls.
    for target_dir, repo in (
        (models_dir, "Xmaster6y/lczero-planning-models"),
        (saes_dir, "Xmaster6y/lczero-planning-saes"),
    ):
        hf_api.snapshot_download(
            local_dir=target_dir,
            repo_id=repo,
            repo_type="model",
        )

    onnx_wrapper = ModelWrapper.from_onnx_path(f"{models_dir}/{constants.MODEL_NAME}")
    wrapper = onnx_wrapper.to(constants.DEVICE)

    # NOTE(review): torch.load unpickles a file fetched from the Hub above.
    # Consider `weights_only=True` if the installed torch supports it and the
    # checkpoint is a plain tensor state dict.
    state_dict = torch.load(
        f"{saes_dir}/{constants.SAE_CONFIG}/model.pt",
        map_location=constants.DEVICE,
    )
    sae = SparseAutoEncoder(
        constants.ACTIVATION_DIM,
        constants.DICTIONARY_SIZE,
        pre_bias=constants.PRE_BIAS,
        init_normalise_dict=constants.INIT_NORMALISE_DICT,
    )
    sae.load_state_dict(state_dict)

    # Regex selecting the `conv2/relu` module of block `constants.LAYER`.
    layer_pattern = rf".*block{constants.LAYER}/conv2/relu"
    generator = OutputGenerator(sae=sae, wrapper=wrapper, module_exp=layer_pattern)

    feature_ds = load_dataset(constants.FEATURE_DATASET, constants.SAE_CONFIG, split="test")
    f_ds = feature_ds.with_format("torch")


# Gradio's reload-mode guard: skip the heavy setup on hot reloads.
if gr.NO_RELOAD:
    setup()