from typing import List, Dict

from ditk import logging
import numpy as np
import torch
import pickle
try:
    from sklearn.svm import SVC
except ImportError:
    SVC = None

from ding.torch_utils import cov
from ding.utils import REWARD_MODEL_REGISTRY, one_time_warning
from .base_reward_model import BaseRewardModel


@REWARD_MODEL_REGISTRY.register('pdeil')
class PdeilRewardModel(BaseRewardModel):
    """
    Overview:
        The Pdeil reward model class (https://arxiv.org/abs/2112.06746)
    Interface:
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
        ``__init__``, ``_train``, ``_batch_mn_pdf``
    Config:
        == ==================== ===== ============= ======================================= =======================
        ID Symbol               Type  Default Value Description                             Other(Shape)
        == ==================== ===== ============= ======================================= =======================
        1  ``type``             str   pdeil         | Reward model register name, refer     |
                                                    | to registry ``REWARD_MODEL_REGISTRY`` |
        2  | ``expert_data_``   str   expert_data.  | Path to the expert dataset            | Should be a '.pkl'
           | ``path``                 .pkl          |                                       | file
        3  | ``discrete_``      bool  False         | Whether the action is discrete        |
           | ``action``                             |                                       |
        4  | ``alpha``          float 0.5           | Coefficient for Probability           |
           |                                        | Density Estimator                     |
        5  | ``clear_buffer``   int   1             | Clear buffer per fixed iters          | Make sure replay
           | ``_per_iters``                         |                                       | buffer's data count
           |                                        |                                       | isn't too few
           |                                        |                                       | (code works in entry)
        == ==================== ===== ============= ======================================= =======================
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='pdeil',
        # (str) Path to the expert dataset.
        # expert_data_path='expert_data.pkl',
        # (bool) Whether the action is discrete.
        discrete_action=False,
        # (float) Coefficient for the Probability Density Estimator.
        # alpha + beta = 1, alpha is in [0, 1];
        # when alpha is close to 0, the estimator has high variance and low bias;
        # when alpha is close to 1, the estimator has high bias and low variance.
        alpha=0.5,
        # (int) Clear buffer per fixed iters.
        clear_buffer_per_iters=1,
    )
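
    # How ``alpha`` enters the reward (a short sketch of the computation done in
    # ``estimate`` below, written out here only for readability):
    #   rho_1 = p_expert(s), rho_2 = p_policy(s), rho_3 = p_expert(a | s)
    #   reward(s, a) = rho_1 * rho_3 / (alpha * rho_1 + (1 - alpha) * rho_2)
    # A small ``alpha`` weights the denominator towards the current policy state
    # density (low bias, high variance); a large ``alpha`` weights it towards the
    # expert state density (high bias, low variance).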

    def __init__(self, cfg: dict, device, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
            Some rules in naming the attributes of ``self.``:
                - ``e_`` : expert values
                - ``_sigma_`` : standard deviation values
                - ``p_`` : current policy values
                - ``_s_`` : states
                - ``_a_`` : actions
        Arguments:
            - cfg (:obj:`Dict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger, by default a TensorBoard ``SummaryWriter``, for model summary
        """
        super(PdeilRewardModel, self).__init__()
        try:
            import scipy.stats as stats
            self.stats = stats
        except ImportError:
            import sys
            logging.warning("Please install scipy first, such as `pip3 install scipy`.")
            sys.exit(1)
        self.cfg: dict = cfg
        self.e_u_s = None
        self.e_sigma_s = None
        if cfg.discrete_action:
            self.svm = None
        else:
            self.e_u_s_a = None
            self.e_sigma_s_a = None
        self.p_u_s = None
        self.p_sigma_s = None
        self.expert_data = None
        self.train_data: list = []
        assert device in ["cpu", "cuda"] or "cuda" in device
        # pdeil uses the cpu device by default
        self.device = 'cpu'

        self.load_expert_data()
        states: list = []
        actions: list = []
        for item in self.expert_data:
            states.append(item['obs'])
            actions.append(item['action'])
        states: torch.Tensor = torch.stack(states, dim=0)
        actions: torch.Tensor = torch.stack(actions, dim=0)
        # fit a Gaussian estimator to the expert state distribution
        self.e_u_s: torch.Tensor = torch.mean(states, axis=0)
        self.e_sigma_s: torch.Tensor = cov(states, rowvar=False)
        if self.cfg.discrete_action and SVC is None:
            one_time_warning("You are using discrete action while the SVC is not installed!")
        if self.cfg.discrete_action and SVC is not None:
            # discrete actions: fit an SVM classifier as the expert action model p(a | s)
            self.svm: SVC = SVC(probability=True)
            self.svm.fit(states.cpu().numpy(), actions.cpu().numpy())
        else:
            # continuous actions: fit a Gaussian to the joint state-action distribution
            state_actions = torch.cat((states, actions.float()), dim=-1)
            self.e_u_s_a = torch.mean(state_actions, axis=0)
            self.e_sigma_s_a = cov(state_actions, rowvar=False)

    def load_expert_data(self) -> None:
        """
        Overview:
            Getting the expert data from ``config['expert_data_path']`` attribute in self.
        Effects:
            This is a side effect function which updates the expert data attribute (e.g. ``self.expert_data``)
        """
        expert_data_path: str = self.cfg.expert_data_path
        with open(expert_data_path, 'rb') as f:
            self.expert_data: list = pickle.load(f)
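
    # Note on the expert file format (inferred from how ``__init__`` consumes
    # ``self.expert_data``; treat it as an assumption rather than a spec): the pickle
    # is expected to hold a list of transition dicts, each carrying at least the
    # torch.Tensor fields ``obs`` and ``action``, e.g.:
    #   [{'obs': torch.tensor([...]), 'action': torch.tensor(...)}, ...]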

    def _train(self, states: torch.Tensor) -> None:
        """
        Overview:
            Helper function for ``train`` which fits the mean and covariance of the current policy
            state distribution.
        Arguments:
            - states (:obj:`torch.Tensor`): current policy states
        Effects:
            - Update attributes of ``p_u_s`` and ``p_sigma_s``
        """
        # we only need to collect the current policy state
        self.p_u_s = torch.mean(states, axis=0)
        self.p_sigma_s = cov(states, rowvar=False)

    def train(self):
        """
        Overview:
            Training the Pdeil reward model.
        """
        states = torch.stack([item['obs'] for item in self.train_data], dim=0)
        self._train(states)

    def _batch_mn_pdf(self, x: np.ndarray, mean: np.ndarray, cov: np.ndarray) -> np.ndarray:
        """
        Overview:
            Get multivariate normal pdf of given np array.
        """
        return np.asarray(
            self.stats.multivariate_normal.pdf(x, mean=mean, cov=cov, allow_singular=False), dtype=np.float32
        )
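
    # Shape convention of ``_batch_mn_pdf`` (an illustrative note, not part of the
    # original interface): for ``x`` of shape (N, d), ``mean`` of shape (d,) and
    # ``cov`` of shape (d, d), scipy evaluates the density row-wise, so the wrapper
    # returns a float32 array of shape (N,), e.g.:
    #   pdf = self._batch_mn_pdf(states.cpu().numpy(),
    #                            self.e_u_s.cpu().numpy(),
    #                            self.e_sigma_s.cpu().numpy())  # -> (N,) densities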

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward keys.
        Arguments:
            - data (:obj:`list`): the list of data used for estimation, \
                with at least ``obs`` and ``action`` keys.
        Effects:
            - This is a side effect function which updates the reward values in place.
        """
        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)
        s = torch.stack([item['obs'] for item in train_data_augmented], dim=0)
        a = torch.stack([item['action'] for item in train_data_augmented], dim=0)
        if self.p_u_s is None:
            logging.warning("you need to train your reward model first")
            for item in train_data_augmented:
                item['reward'].zero_()
        else:
            # rho_1: expert state density p_expert(s)
            rho_1 = self._batch_mn_pdf(s.cpu().numpy(), self.e_u_s.cpu().numpy(), self.e_sigma_s.cpu().numpy())
            rho_1 = torch.from_numpy(rho_1)
            # rho_2: current policy state density p_policy(s)
            rho_2 = self._batch_mn_pdf(s.cpu().numpy(), self.p_u_s.cpu().numpy(), self.p_sigma_s.cpu().numpy())
            rho_2 = torch.from_numpy(rho_2)
            if self.cfg.discrete_action:
                # rho_3: probability of each sample's own action under the expert SVM, p_expert(a|s);
                # index per row so that sample i gets the probability of action a[i]
                proba = self.svm.predict_proba(s.cpu().numpy())
                rho_3 = proba[np.arange(proba.shape[0]), a.cpu().numpy().reshape(-1)].astype(np.float32)
                rho_3 = torch.from_numpy(rho_3)
            else:
                # rho_3: expert conditional action density p_expert(a|s) = p_expert(s, a) / p_expert(s)
                s_a = torch.cat([s, a.float()], dim=-1)
                rho_3 = self._batch_mn_pdf(
                    s_a.cpu().numpy(), self.e_u_s_a.cpu().numpy(), self.e_sigma_s_a.cpu().numpy()
                )
                rho_3 = torch.from_numpy(rho_3)
                rho_3 = rho_3 / rho_1
            alpha = self.cfg.alpha
            beta = 1 - alpha
            den = rho_1 * rho_3
            frac = alpha * rho_1 + beta * rho_2
            if frac.abs().max() < 1e-4:
                for item in train_data_augmented:
                    item['reward'].zero_()
            else:
                reward = den / frac
                reward = torch.chunk(reward, reward.shape[0], dim=0)
                for item, rew in zip(train_data_augmented, reward):
                    item['reward'] = rew
        return train_data_augmented

    def collect_data(self, item: list):
        """
        Overview:
            Collecting training data by iterating data items in the input list
        Arguments:
            - item (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc)
        Effects:
            - This is a side effect function which updates the data attribute in ``self`` by \
                iterating data items in the input data items' list
        """
        self.train_data.extend(item)

    def clear_data(self):
        """
        Overview:
            Clearing training data. \
            This is a side effect function which clears the data attribute in ``self``
        """
        self.train_data.clear()
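

# A minimal usage sketch (illustrative only; ``cfg``, ``tb_logger`` and ``policy_data``
# are hypothetical stand-ins, not defined in this module):
#   reward_model = PdeilRewardModel(cfg, 'cpu', tb_logger)
#   reward_model.collect_data(policy_data)          # list of {'obs', 'action', 'reward'} dicts
#   reward_model.train()                            # fit the current policy state Gaussian
#   new_data = reward_model.estimate(policy_data)   # copy of data with rewritten 'reward'
#   reward_model.clear_data()                       # e.g. every ``clear_buffer_per_iters`` iterations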