anonymous8
update
d65ddc0
raw
history blame
2.1 kB
from argparse import Namespace
import torch
# Module-level set, empty at import time. Nothing in the visible part of this
# file reads or writes it -- presumably it de-duplicates messages that should
# be emitted only once; confirm against callers before relying on that.
one_shot_messages = set()
def config_check(args):
    """Placeholder hook for validating a configuration.

    The function currently performs no checks at all: it ignores ``args``
    and returns ``None``.

    :param args: the configuration object to validate (unused for now)
    :return: always ``None``
    """
    return None
class ConfigManager(Namespace):
    """Namespace backed by a parameter dict that counts parameter reads.

    Parameters are kept in ``self.args`` and one read counter per parameter
    is kept in ``self.args_call_count``.  Attribute reads and writes on the
    instance are redirected to those two dicts; the names ``args`` and
    ``args_call_count`` themselves are always ordinary instance attributes.
    """

    def __init__(self, args=None, **kwargs):
        """
        The ConfigManager is a subclass of argparse.Namespace and based on parameter dict and count the call-frequency of each parameter
        :param args: A parameter dict, an argparse.Namespace or another
            ConfigManager.  A non-empty dict is stored by reference, not copied.
        :param kwargs: Same param as Namespace; they are stored as plain
            instance attributes and are not tracked in ``args``.
        """
        if not args:
            args = {}
        # Namespace.__init__ calls setattr() for every kwarg while the tracking
        # dicts do not exist yet; __setattr__ handles that case explicitly.
        super().__init__(**kwargs)
        if isinstance(args, ConfigManager):
            # vars() of a ConfigManager is {'args': ..., 'args_call_count': ...},
            # not its parameters, so reuse the wrapped parameter dict instead.
            args = args.args
        elif isinstance(args, Namespace):
            args = vars(args)
        self.args = args
        self.args_call_count = {arg: 0 for arg in args}

    def __getattribute__(self, arg_name):
        """Return a tracked parameter and bump its read counter.

        Names that are not keys of ``self.args`` fall back to the regular
        attribute lookup, which raises ``AttributeError`` when it also fails.

        :param arg_name: name of the attribute being read
        :return: the parameter value or the regular attribute
        """
        if arg_name == "args" or arg_name == "args_call_count":
            return super().__getattribute__(arg_name)
        try:
            value = super().__getattribute__("args")[arg_name]
        except (AttributeError, KeyError, TypeError):
            # No parameter storage yet, or not a tracked parameter.
            return super().__getattribute__(arg_name)
        counts = super().__getattribute__("args_call_count")
        # A key may have been added straight to the ``args`` dict, in which
        # case it has no counter yet: start from zero instead of failing.
        counts[arg_name] = counts.get(arg_name, 0) + 1
        return value

    def __setattr__(self, arg_name, value):
        """Store ``value`` as a tracked parameter and re-run ``config_check``.

        Writing a parameter does not count as a read: an existing counter is
        left untouched and a new parameter starts with a counter of zero.

        :param arg_name: name of the attribute being written
        :param value: the new value
        """
        if arg_name == "args" or arg_name == "args_call_count":
            super().__setattr__(arg_name, value)
            return
        try:
            args = super().__getattribute__("args")
            counts = super().__getattribute__("args_call_count")
        except AttributeError:
            # The tracking dicts are not set up yet (Namespace.__init__ is
            # applying **kwargs): keep the value as a plain attribute.  There
            # is no parameter dict to validate at this point.
            super().__setattr__(arg_name, value)
            return
        args[arg_name] = value
        counts.setdefault(arg_name, 0)
        config_check(args)