"""Manage global variables for the app."""

from huggingface_hub import HfApi
import gradio as gr
from lczerolens import ModelWrapper
import torch

from src.constants import HF_TOKEN, ASSETS_FOLDER, DEVICE, MODEL_NAME, SAE_CONFIG, LAYER
from src.helpers import SparseAutoEncoder, OutputGenerator

# Module-level handles populated by `setup()`. They are only annotated here
# (no value is bound), so reading one before `setup()` has run raises NameError.
hf_api: HfApi
wrapper: ModelWrapper
sae: SparseAutoEncoder
generator: OutputGenerator


def setup():
    """Download the assets and initialise the module-level globals.

    Steps:
      1. Create a ``HfApi`` client authenticated with ``HF_TOKEN``.
      2. Snapshot the model and SAE repositories from the Hugging Face Hub
         into ``ASSETS_FOLDER/models`` and ``ASSETS_FOLDER/saes``.
      3. Load the ONNX model ``MODEL_NAME`` into a ``ModelWrapper`` on ``DEVICE``.
      4. Load the SAE weights stored under ``SAE_CONFIG`` and place the SAE
         on ``DEVICE`` as well.
      5. Build the ``OutputGenerator`` linking the SAE to the wrapper, using
         the module pattern built from ``LAYER``.

    Side effects: network downloads, disk writes under ``ASSETS_FOLDER``, and
    rebinding of the globals ``hf_api``, ``wrapper``, ``sae`` and ``generator``.
    """
    global hf_api
    global wrapper
    global sae
    global generator

    hf_api = HfApi(token=HF_TOKEN)

    # (local sub-folder under ASSETS_FOLDER, Hub repo id); downloaded in order.
    asset_repos = (
        ("models", "Xmaster6y/lczero-planning-models"),
        ("saes", "Xmaster6y/lczero-planning-saes"),
    )
    for subfolder, repo_id in asset_repos:
        hf_api.snapshot_download(
            local_dir=f"{ASSETS_FOLDER}/{subfolder}",
            repo_id=repo_id,
            repo_type="model",
        )

    wrapper = ModelWrapper.from_onnx_path(f"{ASSETS_FOLDER}/models/{MODEL_NAME}").to(DEVICE)

    # weights_only=True limits unpickling to tensors and plain containers,
    # so the checkpoint cannot execute arbitrary code on load.
    sae_dict = torch.load(
        f"{ASSETS_FOLDER}/saes/{SAE_CONFIG}/model.pt",
        map_location=DEVICE,
        weights_only=True,
    )
    sae = SparseAutoEncoder()
    sae.load_state_dict(sae_dict)
    # load_state_dict copies values into the module's existing tensors, which
    # keep the device the module was constructed on (CPU by default) rather
    # than the device of `sae_dict`. Move the SAE explicitly so it lives on the
    # same device as `wrapper`. (Assumes SparseAutoEncoder is a torch
    # nn.Module, as its load_state_dict suggests.)
    sae.to(DEVICE)

    generator = OutputGenerator(
        sae=sae,
        wrapper=wrapper,
        # Presumably matched against module names inside the wrapped model to
        # select the hook point; verify against OutputGenerator.
        module_exp=rf".*block{LAYER}/conv2/relu",
    )


# In Gradio's reload mode, code under `gr.NO_RELOAD` is not re-run on hot
# reload, so the downloads and model loading happen only once per process.
if gr.NO_RELOAD:
    setup()