# Flexi-Propagator / config_adv_dif.py
# Last commit ab72d17 by Khalid Rafiq: "Add all required modules and requirements.txt"
from dataclasses import dataclass
import json
@dataclass
class Config:
    """Default run settings for the FlexiPropagator_2D model.

    Every field has a default, so ``Config()`` is a complete configuration.
    Individual values are overridden per run by passing keyword arguments
    (or by loading a JSON object with ``load_config``). The defaults below
    are the shared baseline and must not be edited in place.
    """

    # --- identity / model size ---
    name: str = "FlexiPropagator_2D"
    latent_dim: int = 3  # size of the latent space (assumed; confirm with model code)

    # --- optimisation ---
    batch_size: int = 64
    lr: float = 3e-4
    num_epochs: int = 25
    num_time_steps: int = 500

    # NOTE(review): gamma / beta look like loss-weighting coefficients — confirm
    # against the training code before relying on this description.
    gamma: float = 3.25
    beta: float = 1e-3

    # --- monitoring cadence ---
    # NOTE(review): float values suggest these are fractions of an epoch — confirm.
    val_every: float = 0.25
    plot_train_every: float = 0.01

    # --- output ---
    save_dir: str = "checkpoints"
def load_config(path):
    """Build a ``Config`` from a JSON file.

    The file must contain a single JSON object whose keys are ``Config``
    field names. Fields absent from the file keep their default values.

    Args:
        path: Path (``str`` or path-like) of the JSON file to read.

    Returns:
        A ``Config`` instance with the file's values applied over the defaults.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        TypeError: If the top-level JSON value is not an object, or if it
            contains keys that are not ``Config`` fields.
    """
    from dataclasses import fields

    # JSON text is UTF-8 by specification (RFC 8259); an explicit encoding
    # keeps decoding independent of the platform's locale (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # ``Config(**config)`` on a list/str/number fails with an opaque message;
    # raise the same exception type with a clearer explanation instead.
    if not isinstance(config, dict):
        raise TypeError(
            f"{path}: expected a JSON object at top level, "
            f"got {type(config).__name__}"
        )

    # Report every unknown key at once (the dataclass constructor only names
    # the first one), so typos in the config file are easy to spot.
    known = {fld.name for fld in fields(Config)}
    unknown = sorted(set(config) - known)
    if unknown:
        raise TypeError(f"{path}: unknown config key(s): {', '.join(unknown)}")

    return Config(**config)