File size: 2,275 Bytes
a162e39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import gym


class SymetryWrapper(gym.Wrapper):
    """
    Gym wrapper that presents observations (and receives actions) in a mirrored frame
    whenever ``get_symetry`` says the current observation should be mirrored.

    The mirroring decision is refreshed on every observation returned by ``reset`` and
    ``step`` and is stored in ``self.symetry``; the following call to ``step`` uses that
    flag to decide whether the incoming action goes through ``symetrise_action`` before
    being handed to the wrapped environment.

    :param env: (gym.Env) Gym environment that will be wrapped
    """

    def __init__(self, env):
        # Parent constructor first: it is what makes self.env available
        super(SymetryWrapper, self).__init__(env)

        # Nothing has been observed yet, so no mirroring is pending
        self.symetry = False

    def reset(self):
        """
        Reset the wrapped environment.

        :return: the first observation, passed through ``post_obs``
        """
        return self.post_obs(self.env.reset())

    def step(self, action):
        """
        Advance the wrapped environment by one step.

        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
        """
        # The flag was set by the previous observation: if that observation was shown
        # mirrored, the action is mirrored as well before reaching the environment.
        env_action = symetrise_action(action) if self.symetry else action

        raw_obs, reward, done, info = self.env.step(env_action)

        return self.post_obs(raw_obs), reward, done, info

    def post_obs(self, obs):
        """
        Update the mirroring flag from ``obs`` and mirror ``obs`` if required.

        :param obs: observation coming out of the wrapped environment
        :return: ``symetrise_obs(obs)`` when the flag is set, otherwise ``obs`` unchanged
        """
        self.symetry = get_symetry(obs)
        return symetrise_obs(obs) if self.symetry else obs


def get_symetry(obs):
    """
    Decide whether an observation should be mirrored.

    :param obs: (tuple) 4-tuple (blue_obs, red_obs, blue_fire, red_fire); only the first
        two entries are read. Each of them is iterated drone by drone and must support len().
    :return: (bool) True when strictly fewer than half of all drones (blue and red
        together) have their index-1 component below 0.5, False otherwise (ties and
        empty observations included)
    """
    blue_obs, red_obs, _blue_fire, _red_fire = obs

    # Number of drones, both teams together, whose index-1 component is below 0.5.
    # NOTE(review): index 1 is presumably a normalised coordinate -- confirm with the env.
    low_side = sum(
        1
        for team in (blue_obs, red_obs)
        for drone in team
        if drone[1] < 0.5
    )

    total = len(blue_obs) + len(red_obs)

    # Mirror only when the drones below 0.5 are a strict minority
    return bool(2 * low_side < total)


def symetrise_obs(obs):
    """
    Mirror an observation IN PLACE by replacing columns 1 and 4 of both team arrays
    with ``1 - value``.

    :param obs: (tuple) 4-tuple (blue_obs, red_obs, blue_fire, red_fire); blue_obs and
        red_obs must accept 2-D ``[:, col]`` indexing with at least 5 columns
    :return: (tuple) the same four objects; blue_obs and red_obs have been modified,
        blue_fire and red_fire are passed through untouched
    """
    blue_obs, red_obs, blue_fire, red_fire = obs

    # The original comment calls these two columns "positions and speeds".
    # NOTE(review): 1 - x presumably assumes values normalised to [0, 1] -- confirm.
    for team in (blue_obs, red_obs):
        for col in (1, 4):
            team[:, col] = 1 - team[:, col]

    return blue_obs, red_obs, blue_fire, red_fire


def symetrise_action(action):
    """
    Mirror a joint action by negating component 1 of every per-drone action.

    The input is left untouched: the negation is applied to deep copies, so the action
    object held by the caller keeps the values it was given.

    :param action: (tuple) pair (blue_action, red_action); each one is an iterable of
        per-drone actions supporting item assignment at index 1 (e.g. a 2-D array or a
        list of lists)
    :return: (tuple) (blue_action, red_action) copies with the index-1 component negated
    """
    # Local import: the module does not import copy at file level
    import copy

    blue_action, red_action = action

    mirrored = []
    for this_action in (blue_action, red_action):
        # Copy first so the in-place negation below never reaches the caller's data
        this_action = copy.deepcopy(this_action)
        for act in this_action:

            # symetrise action
            act[1] = -act[1]

        mirrored.append(this_action)

    return tuple(mirrored)


def test_symetrise_obs():
    """
    Check symetrise_obs on a small hand-made observation.

    Both team arrays are ``np.arange(12).reshape(2, 6)``; after the call, columns 1 and 4
    must hold ``1 - value`` (in place) and the two fire arrays must be unchanged. The
    observation is printed before and after the call for visual inspection.
    """

    obs = np.arange(12).reshape(2, 6), np.arange(12).reshape(2, 6), np.random.random((1, 1)), np.random.random((1, 1))
    # Keep copies of the random fire arrays to verify they are not touched
    fire_before = [fire.copy() for fire in obs[2:]]

    print(obs)
    symetrise_obs(obs)
    print(obs)

    # Columns 1 and 4 mirrored: 1 -> 0, 4 -> -3, 7 -> -6, 10 -> -9
    expected = np.array([[0, 0, 2, 3, -3, 5], [6, -6, 8, 9, -9, 11]])
    for team_obs in obs[:2]:
        assert np.array_equal(team_obs, expected)
    for fire, before in zip(obs[2:], fire_before):
        assert np.array_equal(fire, before)