import torch from dataclasses import dataclass, field @dataclass class SAETrainingConfig: d_model: int n_dirs: int k: int block_name: str bs: int save_path_base: str auxk: int = 256 lr: float = 1e-4 eps: float = 6.25e-10 dead_toks_threshold: int = 10_000_000 auxk_coef: float = 1/32 @property def sae_name(self): return f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}' @property def save_path(self): return f'/dlabscratch1/surkov/sae_models/{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}' @dataclass class Config: saes: list[SAETrainingConfig] paths_to_latents: list[str] log_interval: int save_interval: int bs: int block_name: str wandb_project: str = 'sdxl_sae_train' wandb_name: str = 'multiple_sae' def __init__(self, cfg_json): self.saes = [SAETrainingConfig(**sae_cfg, block_name=cfg_json['block_name'], bs=cfg_json['bs'], save_path_base=cfg_json['save_path_base']) for sae_cfg in cfg_json['sae_configs']] self.save_path_base = cfg_json['save_path_base'] self.paths_to_latents = cfg_json['paths_to_latents'] self.log_interval = cfg_json['log_interval'] self.save_interval = cfg_json['save_interval'] self.bs = cfg_json['bs'] self.block_name = cfg_json['block_name']