# Spaces:
# Sleeping
# Sleeping
import copy
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from ditk import logging
from easydict import EasyDict

from ding.utils import REWARD_MODEL_REGISTRY, import_module, save_file
class BaseRewardModel(ABC):
    """
    Overview:
        The base class of reward model.
    Interface:
        ``default_config``, ``estimate``, ``train``, ``clear_data``, ``collect_data``, ``load_expert_data``,
        ``reward_deepcopy``, ``state_dict``, ``load_state_dict``, ``save``
    """

    # NOTE(review): ``estimate``, ``train``, ``collect_data`` and ``clear_data`` are enforced only at call time
    # (they raise ``NotImplementedError``). Marking them ``@abstractmethod`` would instead reject incomplete
    # subclasses at construction time - confirm that every subclass overrides all four before tightening this.

    @classmethod
    def default_config(cls: type) -> 'EasyDict':
        """
        Overview:
            Build the default config of the reward model class this is called on.
        Returns:
            - cfg (:obj:`EasyDict`): A deep copy of ``cls.config`` with an extra ``cfg_type`` field set to \
                the class name followed by ``'Dict'``.

        .. note::
            ``config`` is not defined on this base class; the class this is called on has to provide it.
        """
        # Deep copy so that edits to the returned config never leak back into the class-level ``config``.
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def estimate(self, data: list) -> Any:
        """
        Overview:
            Estimate reward.
        Arguments:
            - data (:obj:`List`): The list of data used for estimation.
        Returns / Effects:
            - This can be a side effect function which updates the reward value.
            - If this function returns, an example returned object can be reward (:obj:`Any`): \
                the estimated reward.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def train(self, data: Any) -> None:
        """
        Overview:
            Train the reward model.
        Arguments:
            - data (:obj:`Any`): Data used for training.
        Effects:
            - This is mostly a side effect function which updates the reward model.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def collect_data(self, data: Any) -> None:
        """
        Overview:
            Collect training data in the designated format or with the designated transition.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc).
        Returns / Effects:
            - This can be a side effect function which updates the data attribute in ``self``.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def clear_data(self) -> None:
        """
        Overview:
            Clear training data.
            This can be a side effect function which clears the data attribute in ``self``.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def load_expert_data(self, data: Any) -> None:
        """
        Overview:
            Get the expert data, usually used in inverse RL reward model. The base implementation does nothing.
        Arguments:
            - data (:obj:`Any`): Expert data.
        Effects:
            This is mostly a side effect function which updates the expert data attribute \
            (e.g. ``self.expert_data``).
        """
        pass

    def reward_deepcopy(self, train_data: list) -> Any:
        """
        Overview:
            Deep-copy the reward part of ``train_data`` while keeping every other part as a shallow copy,
            to avoid the reward part of train_data in the replay buffer being incorrectly modified.
        Arguments:
            - train_data (:obj:`List`): The list of train data (dict samples) whose ``'reward'`` entries \
                will be deep-copied.
        Returns:
            - train_data_reward_deepcopy (:obj:`List`): A new list of new dicts; values under ``'reward'`` \
                are deep copies, all other values are the same objects as in the input.
        """
        return [
            {key: copy.deepcopy(value) if key == 'reward' else value
             for key, value in sample.items()}
            for sample in train_data
        ]

    def state_dict(self) -> Dict:
        """
        Overview:
            Return the state to checkpoint. The base implementation returns an empty dict;
            this method should be overridden by subclass.
        """
        return {}

    def load_state_dict(self, _state_dict) -> None:
        """
        Overview:
            Restore state from ``_state_dict``. The base implementation does nothing;
            this method should be overridden by subclass.
        """
        pass

    def save(self, path: Optional[str] = None, name: str = 'best') -> None:
        """
        Overview:
            Save ``self.state_dict()`` to ``<path>/reward_model/ckpt/ckpt_<name>.pth.tar`` via ``save_file``.
        Arguments:
            - path (:obj:`Optional[str]`): Root directory of the checkpoint; ``self.cfg.exp_name`` is used \
                when it is ``None``.
            - name (:obj:`str`): Tag inserted into the checkpoint file name.
        """
        if path is None:
            path = self.cfg.exp_name
        ckpt_dir = os.path.join(path, 'reward_model', 'ckpt')
        # ``exist_ok=True`` tolerates a concurrent creator without a separate existence check.
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, 'ckpt_{}.pth.tar'.format(name))
        save_file(ckpt_path, self.state_dict())
        logging.info('Saved reward model ckpt in {}'.format(ckpt_path))
def create_reward_model(cfg: dict, device: str, tb_logger: 'SummaryWriter') -> BaseRewardModel:  # noqa
    """
    Overview:
        Build a reward model from a config through ``REWARD_MODEL_REGISTRY``.
    Arguments:
        - cfg (:obj:`Dict`): Training config. It is deep-copied first, so the caller's object is not modified.
        - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
        - tb_logger (:obj:`SummaryWriter`): Logger forwarded to the reward model as ``tb_logger``
    Returns:
        - reward (:obj:`Any`): The reward model
    """
    cfg_copy = copy.deepcopy(cfg)
    if 'import_names' in cfg_copy:
        extra_modules = cfg_copy.pop('import_names')
        import_module(extra_modules)
    # The ``type`` entry is taken from the ``reward_model`` sub-config when the config exposes one as an
    # attribute, otherwise from the top level; either way it is removed before the config is handed on.
    type_holder = cfg_copy.reward_model if hasattr(cfg_copy, 'reward_model') else cfg_copy
    model_type = type_holder.pop('type')
    return REWARD_MODEL_REGISTRY.build(model_type, cfg_copy, device=device, tb_logger=tb_logger)
def get_reward_model_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Look up the reward model registered under ``cfg.type`` in ``REWARD_MODEL_REGISTRY``.
    Arguments:
        - cfg (:obj:`EasyDict`): Config providing the ``type`` attribute and, optionally, an ``import_names`` \
            entry that is passed to ``import_module`` first (an empty list is passed when it is absent).
    Returns:
        - reward_model_cls (:obj:`type`): The object returned by ``REWARD_MODEL_REGISTRY.get`` for ``cfg.type``.
    """
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)
    model_type = cfg.type
    return REWARD_MODEL_REGISTRY.get(model_type)