File size: 470 Bytes
a162e39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from swarm_policy import SwarmPolicy
def run_episode(env, obs, blues: int, reds: int):
    """Play one episode with a blue and a red ``SwarmPolicy`` until done.

    Two policies are built from the same ``blues``/``reds`` counts, one with
    ``is_blue=True`` and one with ``is_blue=False``.  On every step both
    policies are asked to predict from the *same* current observation, and
    their outputs are handed to ``env.step`` as a ``(blue, red)`` tuple.

    Args:
        env: Object whose ``step(action)`` returns a 4-tuple
            ``(obs, reward, done, info)``.
        obs: Observation to start from; the caller is expected to have
            obtained it already (this function never resets ``env``).
        blues: Forwarded to both policies as the ``blues`` keyword.
        reds: Forwarded to both policies as the ``reds`` keyword.

    Returns:
        A tuple ``(obs, sum_reward, done, info)`` holding the last
        observation, the sum of all step rewards, the final done flag and
        the ``info`` value from the last step.
    """
    # Order matters for the action tuple: blue first, red second.
    policies = (
        SwarmPolicy(blues=blues, reds=reds, is_blue=True),
        SwarmPolicy(blues=blues, reds=reds, is_blue=False),
    )

    total_reward = 0
    while True:
        joint_action = tuple(policy.predict(obs) for policy in policies)
        obs, step_reward, finished, info = env.step(joint_action)
        total_reward += step_reward
        # At least one step is always taken before the done flag is checked.
        if finished:
            break

    return obs, total_reward, finished, info
|