File size: 663 Bytes
ab72d17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from dataclasses import dataclass
import json

@dataclass
class Config:
    """Run settings with built-in default values.

    Every field can be overridden by keyword when constructing ``Config``
    (``load_config`` does this with the keys of a JSON object). Because this
    is a dataclass, the field order below is also the positional order of the
    generated ``__init__``, so fields must not be reordered.
    """

    # Built-in defaults -- DO NOT CHANGE these values; override them per run.
    name: str = "FlexiPropagator"
    latent_dim: int = 2
    input_dim: int = 128
    batch_size: int = 128
    lr: float = 3e-4
    num_epochs: int = 200
    n_samples_train: int = 800_000  # same value as the original 8_00_000 spelling

    num_time_steps: int = 500

    # NOTE(review): presumably fractions (of num_time_steps?) -- confirm with the consumer.
    tau_left_fraction: float = 0.35
    tau_right_fraction: float = 0.85

    gamma: float = 3.25
    beta: float = 1e-4

    # NOTE(review): float-valued intervals, presumably fractions of an epoch -- confirm.
    val_every: float = 0.25
    plot_train_every: float = 0.01

    save_dir: str = "checkpoints"

def load_config(path):
    """Load a ``Config`` from a JSON file.

    The file's top-level JSON object is passed to ``Config`` as keyword
    arguments, so its keys override the corresponding defaults and any
    field not present keeps its default value.

    Args:
        path: str or os.PathLike pointing at a JSON file.

    Returns:
        A ``Config`` instance.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
        TypeError: if the top-level JSON value is not an object, or if it
            contains a key that is not a ``Config`` field.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    # `Config(**config)` needs a mapping; fail with a clear message for
    # lists/numbers/strings instead of an opaque unpacking error.
    if not isinstance(config, dict):
        raise TypeError(
            f"config file {path!r} must contain a JSON object, "
            f"got {type(config).__name__}"
        )
    return Config(**config)