# demo/src/global_variables.py
# Author: Xmaster6y — commit 0d998a6 (unverified): "new repo structure" (1.3 kB)
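"""Hold the app's shared global state and initialise it at startup.

Exposes the Hugging Face Hub client, the model wrapper, the sparse autoencoder
and the output generator as module-level globals populated by ``setup()``.
"""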
from huggingface_hub import HfApi
import gradio as gr
from lczerolens import ModelWrapper
import torch
from src.constants import HF_TOKEN, ASSETS_FOLDER, DEVICE, MODEL_NAME, SAE_CONFIG, LAYER
from src.helpers import SparseAutoEncoder, OutputGenerator
# Module-level state shared across the app. These names are only annotated here,
# not bound: they become attributes of this module once setup() has run (which
# happens at import time when gr.NO_RELOAD is truthy). Reading them before that
# raises NameError/AttributeError.
hf_api: HfApi  # Hub client authenticated with HF_TOKEN
wrapper: ModelWrapper  # model built from the ONNX file MODEL_NAME, moved to DEVICE
sae: SparseAutoEncoder  # sparse autoencoder loaded from saes/<SAE_CONFIG>/model.pt
generator: OutputGenerator  # built in setup() from `sae` and `wrapper`
def setup():
    """Download the app's assets from the Hugging Face Hub and initialise the globals.

    Fetches the ``Xmaster6y/lczero-planning-models`` and
    ``Xmaster6y/lczero-planning-saes`` repositories into ``ASSETS_FOLDER``.
    It then builds a ``ModelWrapper`` from the ONNX file ``MODEL_NAME`` on
    ``DEVICE`` and loads the sparse-autoencoder weights for ``SAE_CONFIG``.
    Finally it wires both into an ``OutputGenerator`` that targets modules
    matching ``.*block{LAYER}/conv2/relu``.

    The module-level globals ``hf_api``, ``wrapper``, ``sae`` and ``generator``
    are assigned together, only after every step has succeeded. A failed
    download or load therefore propagates its exception without leaving the
    globals partially initialised.
    """
    global hf_api, wrapper, sae, generator

    models_dir = f"{ASSETS_FOLDER}/models"
    saes_dir = f"{ASSETS_FOLDER}/saes"

    api = HfApi(token=HF_TOKEN)
    api.snapshot_download(
        local_dir=models_dir,
        repo_id="Xmaster6y/lczero-planning-models",
        repo_type="model",
    )
    api.snapshot_download(
        local_dir=saes_dir,
        repo_id="Xmaster6y/lczero-planning-saes",
        repo_type="model",
    )

    model_wrapper = ModelWrapper.from_onnx_path(f"{models_dir}/{MODEL_NAME}").to(DEVICE)

    sae_state = torch.load(
        f"{saes_dir}/{SAE_CONFIG}/model.pt",
        map_location=DEVICE,
        weights_only=True,  # restrict unpickling to tensors and plain containers
    )
    autoencoder = SparseAutoEncoder()
    autoencoder.load_state_dict(sae_state)
    # load_state_dict copies values into the module's existing parameters, which
    # keep the device they were created on. Move the module explicitly so it
    # actually lives on DEVICE, like the loaded state and the wrapper.
    # NOTE(review): assumes SparseAutoEncoder is a torch.nn.Module (implied by
    # load_state_dict); Module.to() moves in place and returns self.
    autoencoder = autoencoder.to(DEVICE)

    output_generator = OutputGenerator(
        sae=autoencoder,
        wrapper=model_wrapper,
        module_exp=rf".*block{LAYER}/conv2/relu",
    )

    # Publish all globals at once, only after everything above succeeded.
    hf_api, wrapper, sae, generator = api, model_wrapper, autoencoder, output_generator
# Initialise the globals once when the module is imported. Per Gradio's reload-mode
# docs, code under `gr.NO_RELOAD` is skipped when Gradio re-executes this module on
# hot reload. The Hub downloads and model loading in setup() are therefore not
# repeated on every source change.
if gr.NO_RELOAD:
    setup()