|
import json |
|
from dataclasses import dataclass, make_dataclass, asdict, field |
|
from typing import List |
|
|
|
|
|
@dataclass
class Config:
    """Default settings, one dataclass field per setting.

    ``load_config`` starts from ``asdict(Config())`` and then lets a JSON
    config file and the parsed arguments overwrite these defaults.
    """

    # Device name string; presumably consumed by the model code -- confirm.
    device: str = "cpu"

    # Path-like and dataset-related string settings. Note that ``load_config``
    # opens the file named by ``args.config``; the ``config`` default below is
    # only the value that ends up in the merged settings when nothing
    # overrides it.
    config: str = "config/default.json"
    loader: str = "loaders/google_sc.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # A factory is required for the mutable default so that every instance
    # receives its own fresh empty list.
    validation_datasets: List = field(default_factory=list)

    batch_size: int = 4
    verbose: bool = True

    encoder_model_id: str = "distilroberta-base"

    # Names kept as strings; how they are resolved is not visible here.
    rewards: tuple = ("FluencyReward", "CrossSimilarityReward")
|
|
|
|
|
def load_config(args):
    """
    Build the effective settings object by layering three sources.

    Sources, applied in this order (later ones overwrite overlapping fields):
      - the defaults declared on the ``Config`` dataclass
      - the JSON file at ``args.config``, if that attribute exists and is truthy
      - every attribute of ``args`` (e.g. the result of ``parser.parse_args()``)

    All attributes of ``args`` are applied as-is, so an attribute whose value
    is ``None`` also overwrites the earlier sources.

    Keys that are not fields of ``Config`` are kept too: the result is an
    instance of a dataclass generated on the fly with one field per merged
    key, not an instance of the module-level ``Config`` class.

    Args:
        args: namespace-like object whose attributes are readable via
            ``vars()``, typically an ``argparse.Namespace``.

    Returns:
        A dataclass instance exposing every merged setting as an attribute.

    Raises:
        OSError: if the file named by ``args.config`` cannot be opened.
        json.JSONDecodeError: if that file does not contain valid JSON.
        TypeError: if a merged key is not usable as a dataclass field name
            (raised by ``make_dataclass``).

    Example usage:
        (...)
        args = load_config(parser.parse_args())
        args.batch_size
    """
    settings = asdict(Config())

    # The config file is optional, so also tolerate namespaces that never
    # declared a ``config`` option instead of failing with AttributeError.
    config_path = getattr(args, "config", None)
    if config_path:
        # Read as UTF-8 explicitly rather than relying on the locale default.
        with open(config_path, encoding="utf-8") as f:
            settings.update(json.load(f))

    settings.update(vars(args))

    # make_dataclass expects (name, type) pairs; register each value's runtime
    # type instead of passing the value itself as the annotation.
    field_specs = [(name, type(value)) for name, value in settings.items()]
    Config_ = make_dataclass("Config", fields=field_specs)
    return Config_(**settings)
|
|