Spaces:
Sleeping
Sleeping
import random | |
import numpy as np | |
from dizoo.gfootball.envs.obs.gfootball_obs import PlayerObs, MatchObs | |
from ding.utils.data import default_collate | |
def generate_data(player_obs: dict) -> np.ndarray:
    """Sample one random feature vector for an observation template entry.

    Args:
        player_obs: Template entry. Must provide ``'dim'`` (vector length) and
            a ``'value'`` dict holding ``'dinfo'``; for ``'float'`` entries
            the ``'value'`` dict must also hold indexable ``'min'`` and
            ``'max'`` bounds.

    Returns:
        * ``'one-hot'`` / ``'boolean vector'``: a float32 vector of length
          ``dim`` that is all zeros except for a single 1 at a random index.
        * ``'float'``: a float64 vector whose i-th entry is drawn uniformly
          from ``[min[i], max[i])``.
        * any other ``dinfo``: ``None`` (no data is generated).
    """
    dim = player_obs['dim']
    value_spec = player_obs['value']
    dinfo = value_spec['dinfo']

    if dinfo in ('one-hot', 'boolean vector'):
        data = np.zeros((dim, ), dtype=np.float32)
        # random.randint is inclusive on both ends, hence ``dim - 1``.
        data[random.randint(0, dim - 1)] = 1
        return data

    if dinfo == 'float':
        # Only the first ``dim`` bounds take part, matching per-index access.
        low = np.asarray(value_spec['min'], dtype=np.float64)[:dim]
        high = np.asarray(value_spec['max'], dtype=np.float64)[:dim]
        # Rescale uniform [0, 1) samples into [low, high) in one vector op.
        return low + (high - low) * np.random.rand(dim)

    # Unknown kinds are tolerated: nothing is generated for them.
    return None
class FakeGfootballDataset:
    """Generator of random, template-shaped gfootball transitions.

    Observation layouts are read from the ``template`` attributes of
    ``MatchObs`` and ``PlayerObs``; every value is randomly sampled, so the
    data is only useful for exercising code paths, not for training.
    """

    # Number of per-player observation dicts placed under ``'players'``.
    # NOTE(review): presumably 11 players per team -- confirm.
    NUM_PLAYERS = 22

    def __init__(self):
        """Load the observation templates and set the default sizes."""
        self.match_obs_info = MatchObs({}).template
        self.player_obs_info = PlayerObs({}).template
        self.action_dim = 19
        self.batch_size = 4

    def __len__(self) -> int:
        """Return the configured batch size."""
        return self.batch_size

    def get_random_action(self) -> np.ndarray:
        """Return a shape ``(1,)`` integer array holding a random action index.

        The index is uniform over ``0 .. action_dim - 1`` inclusive.
        """
        # np.random.randint excludes ``high``, so pass ``action_dim`` itself;
        # passing ``action_dim - 1`` would never produce the last action.
        return np.random.randint(0, self.action_dim, size=(1, ))

    def get_random_obs(self) -> dict:
        """Build one random observation.

        Returns:
            A dict with one entry per match template item (keyed by its
            ``'ret_key'``) plus ``'players'``: a list of ``NUM_PLAYERS``
            dicts, each with one entry per player template item.
        """
        inputs = {info['ret_key']: generate_data(info) for info in self.match_obs_info}
        inputs['players'] = [
            {info['ret_key']: generate_data(info) for info in self.player_obs_info}
            for _ in range(self.NUM_PLAYERS)
        ]
        return inputs

    def get_batched_obs(self, bs: int) -> dict:
        """Return ``bs`` random observations merged by ``default_collate``."""
        return default_collate([self.get_random_obs() for _ in range(bs)])

    def get_random_reward(self) -> np.ndarray:
        """Return a shape ``(1,)`` array with a reward uniform in ``[-0.5, 0.5)``."""
        return np.array([random.random() - 0.5])

    def get_random_terminals(self) -> int:
        """Return 1 (episode done) with roughly 1% probability, otherwise 0."""
        return 1 if random.random() > 0.99 else 0

    def get_batch_sample(self, bs: int) -> list:
        """Return a list of ``bs`` random transition dicts.

        Each dict has the keys ``'obs'``, ``'next_obs'``, ``'action'``,
        ``'done'`` and ``'reward'``.
        """
        batch = []
        for _ in range(bs):
            # Key order fixes the order of the random draws.
            batch.append(
                {
                    'obs': self.get_random_obs(),
                    'next_obs': self.get_random_obs(),
                    'action': self.get_random_action(),
                    'done': self.get_random_terminals(),
                    'reward': self.get_random_reward(),
                }
            )
        return batch