Spaces:
Sleeping
Sleeping
import torch | |
import logging | |
import math | |
from ding.torch_utils import to_list | |
from ding.utils.data import NaiveRLDataset | |
from torch.utils.data import DataLoader | |
logging.basicConfig(level=logging.INFO) | |
def test_accuracy_in_dataset(data_path, batch_size, policy): | |
""" | |
Overview: | |
Evaluate total accuracy and accuracy of each action in dataset from | |
``datapath`` using the ``policy`` for gfootball env. | |
""" | |
dataset = NaiveRLDataset(data_path) | |
dataloader = DataLoader(dataset, batch_size) | |
total_accuracy_in_dataset = [] | |
action_accuracy_in_dataset = {k: [] for k in range(19)} | |
for _, minibatch in enumerate(dataloader): | |
policy_output = policy._forward_eval(minibatch['obs']) | |
pred_action = policy_output['action'] | |
total_accuracy = (pred_action == minibatch['action'].view(-1)).float().mean() | |
total_accuracy_in_dataset.append(total_accuracy) | |
for action_unique in to_list(torch.unique(minibatch['action'])): | |
# find the index where action is `action_unique` in `pred_action` | |
action_index = (pred_action == action_unique).nonzero(as_tuple=True)[0] | |
action_accuracy = (pred_action[action_index] == minibatch['action'].view(-1)[action_index]).float().mean() | |
if math.isnan(action_accuracy): | |
action_accuracy = 0.0 | |
action_accuracy_in_dataset[action_unique].append(action_accuracy) | |
# logging.info(f'the accuracy of action {action_unique} in current train mini-batch is: {action_accuracy}') | |
logging.info(f'total accuracy in dataset is: {torch.tensor(total_accuracy_in_dataset).mean().item()}') | |
logging.info( | |
f'accuracy of each action in dataset is (nan means the action does not appear in the dataset): ' | |
f'{ {k: torch.tensor(action_accuracy_in_dataset[k]).mean().item() for k in range(19)} }' | |
) | |