import copy
import pickle

from easydict import EasyDict

from ding.utils import REWARD_MODEL_REGISTRY
from .trex_reward_model import TrexRewardModel


@REWARD_MODEL_REGISTRY.register('drex')
class DrexRewardModel(TrexRewardModel):
""" | |
Overview: | |
The Drex reward model class (https://arxiv.org/pdf/1907.03976.pdf) | |
Interface: | |
``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_date``, \ | |
``__init__``, ``_train``, | |
    Config:
        == ==================== ====== ============= ======================================== ===============
        ID Symbol               Type   Default Value Description                              Other(Shape)
        == ==================== ====== ============= ======================================== ===============
        1  ``type``             str    drex          | Reward model register name, refer      |
                                                     | to registry ``REWARD_MODEL_REGISTRY``  |
        2  ``learning_rate``    float  0.00001       | Learning rate for optimizer            |
        3  | ``update_per_``    int    100           | Number of updates per collect          |
           | ``collect``                             |                                        |
        4  ``batch_size``       int    64            | How many samples in a training batch   |
        5  ``hidden_size``      int    128           | Linear model hidden size               |
        6  ``num_trajs``        int    0             | Number of downsampled full             |
                                                     | trajectories                           |
        7  ``num_snippets``     int    6000          | Number of short subtrajectories        |
                                                     | to sample                              |
        == ==================== ====== ============= ======================================== ===============
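    Examples:
        >>> # A minimal usage sketch: ``cfg`` is assumed to be a full pipeline config (EasyDict) whose
        >>> # ``reward_model`` field contains the keys listed above plus the fields consumed by
        >>> # ``TrexRewardModel``, e.g. ``offline_data_path`` pointing at a directory that holds
        >>> # ``suboptimal_data.pkl``.
        >>> from torch.utils.tensorboard import SummaryWriter
        >>> tb_logger = SummaryWriter('./log/drex_reward_model')
        >>> reward_model = DrexRewardModel(cfg, device='cpu', tb_logger=tb_logger)
        >>> reward_model.train()
        >>> # The inherited ``estimate`` method can then relabel collected transitions with the learned reward.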
""" | |
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='drex',
        # (float) The step size of gradient descent.
        learning_rate=1e-5,
        # (int) How many updates (iterations) to train after the collector's one collection.
        # A larger ``update_per_collect`` means the training is more off-policy.
        # collect data -> update policy -> collect data -> ...
        update_per_collect=100,
        # (int) How many samples in a training batch.
        batch_size=64,
        # (int) Linear model hidden size.
        hidden_size=128,
        # (int) Number of downsampled full trajectories.
        num_trajs=0,
        # (int) Number of short subtrajectories to sample.
        num_snippets=6000,
    )
    bc_cfg = None

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
        Arguments:
            - config (:obj:`EasyDict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger, by default a 'SummaryWriter' for model summary
        """
        super(DrexRewardModel, self).__init__(copy.deepcopy(config), device, tb_logger)

        self.demo_data = []
        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Get the expert data from the ``config.expert_data_path`` attribute in self (via the parent class), \
                then load the suboptimal demonstration data from ``suboptimal_data.pkl`` under \
                ``config.reward_model.offline_data_path``.
        Effects:
            This is a side effect function which updates the expert data attribute \
                (i.e. ``self.expert_data``) with ``fn:concat_state_action_pairs``, and fills ``self.demo_data``.
        """
        super(DrexRewardModel, self).load_expert_data()

        with open(self.cfg.reward_model.offline_data_path + '/suboptimal_data.pkl', 'rb') as f:
            self.demo_data = pickle.load(f)

    def train(self):
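        """
        Overview:
            Train the reward model (via the parent ``_train``), then evaluate it on the suboptimal \
                demonstrations and log real vs. predicted returns together with the pairwise ranking accuracy.
        """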
        self._train()
        return_dict = self.pred_data(self.demo_data)
        res, pred_returns = return_dict['real'], return_dict['pred']
        self._logger.info("real: " + str(res))
        self._logger.info("pred: " + str(pred_returns))

        info = {
            "min_snippet_length": self.min_snippet_length,
            "max_snippet_length": self.max_snippet_length,
            "len_num_training_obs": len(self.training_obs),
            "len_num_labels": len(self.training_labels),
            "accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
        }
        self._logger.info(
            "accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
        )