File size: 470 Bytes
a162e39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17


from swarm_policy import SwarmPolicy


def run_episode(env, obs, blues: int, reds: int):
    """Roll out one episode, pitting a blue SwarmPolicy against a red one.

    Two policies are constructed fresh for this episode (blue first, then
    red), both from the same ``blues``/``reds`` arguments. On every step each
    policy predicts from the current observation, and the pair
    ``(blue_action, red_action)`` is passed to ``env.step``. Stepping repeats
    until the environment reports ``done``.

    Args:
        env: Environment whose ``step(action)`` returns
            ``(obs, reward, done, info)``.
        obs: Observation to start the rollout from.
        blues: Forwarded to both policies as ``blues`` (presumably the blue
            team size — confirm against SwarmPolicy).
        reds: Forwarded to both policies as ``reds`` (presumably the red
            team size — confirm against SwarmPolicy).

    Returns:
        ``(obs, total_reward, done, info)``, where ``obs``, ``done`` and
        ``info`` come from the final ``env.step`` call and ``total_reward``
        is the sum of all step rewards, starting from 0.
    """
    # Index 0 is the blue side, index 1 the red side; this order fixes both
    # construction order and the layout of the joint action tuple.
    policies = [
        SwarmPolicy(blues=blues, reds=reds, is_blue=side)
        for side in (True, False)
    ]

    total = 0
    while True:
        joint_action = tuple(policy.predict(obs) for policy in policies)
        obs, reward, done, info = env.step(joint_action)
        total += reward
        if done:
            break
    return obs, total, done, info