Spaces:
Runtime error
Runtime error
hyyh28
committed on
Commit
·
1aa4792
1
Parent(s):
eb1b37d
update atari_env
Browse filesupdate atari_env, also make a test_atari.sh
- deciders/parser.py +20 -1
- envs/__init__.py +44 -1
- envs/atari/Boxing_policies.py +161 -0
- envs/atari/Boxing_translator.py +99 -0
- envs/atari/Pong_policies.py +65 -0
- envs/atari/Pong_translator.py +67 -0
- envs/atari/__init__.py +1 -0
- envs/atari/represented_atari_game.py +221 -0
- record_reflexion.csv +2 -0
- test_atari.sh +1 -0
deciders/parser.py
CHANGED
@@ -40,6 +40,25 @@ class SixAction(BaseModel):
|
|
40 |
if field not in [1, 2, 3, 4, 5, 6]:
|
41 |
raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6])!")
|
42 |
return field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
class ContinuousAction(BaseModel):
|
45 |
action: float = Field(description="the choosed action to perform")
|
@@ -50,4 +69,4 @@ class ContinuousAction(BaseModel):
|
|
50 |
raise ValueError("Action is not valid ([-1,1])!")
|
51 |
return field
|
52 |
|
53 |
-
PARSERS = {1:ContinuousAction, 2: TwoAction, 3: ThreeAction, 4: FourAction, 6: SixAction}
|
|
|
40 |
if field not in [1, 2, 3, 4, 5, 6]:
|
41 |
raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6])!")
|
42 |
return field
|
43 |
+
|
44 |
+
|
45 |
+
class NineAction(BaseModel):
|
46 |
+
action: int = Field(description="the choosed action to perform")
|
47 |
+
|
48 |
+
# You can add custom validation logic easily with Pydantic.
|
49 |
+
@validator('action')
|
50 |
+
def action_is_valid(cls, field):
|
51 |
+
if field not in [1, 2, 3, 4, 5, 6, 7, 8, 9]:
|
52 |
+
raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6, 7, 8, 9])!")
|
53 |
+
return field
|
54 |
+
|
55 |
+
class FullAtariAction(BaseModel):
|
56 |
+
action: int = Field(description="the choosed action to perform")
|
57 |
+
@validator('action')
|
58 |
+
def action_is_valid(cls, info):
|
59 |
+
if info not in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]:
|
60 |
+
raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])!")
|
61 |
+
return info
|
62 |
|
63 |
class ContinuousAction(BaseModel):
|
64 |
action: float = Field(description="the choosed action to perform")
|
|
|
69 |
raise ValueError("Action is not valid ([-1,1])!")
|
70 |
return field
|
71 |
|
72 |
+
PARSERS = {1:ContinuousAction, 2: TwoAction, 3: ThreeAction, 4: FourAction, 6: SixAction, 9:NineAction, 18: FullAtariAction}
|
envs/__init__.py
CHANGED
@@ -10,6 +10,10 @@ from .toy_text import blackjack_translator, blackjack_policies
|
|
10 |
from .toy_text import taxi_translator, taxi_policies
|
11 |
from .toy_text import cliffwalking_translator, cliffwalking_policies
|
12 |
from .toy_text import frozenlake_translator, frozenlake_policies
|
|
|
|
|
|
|
|
|
13 |
|
14 |
REGISTRY = {}
|
15 |
REGISTRY["sampling_wrapper"] = SettableStateEnv
|
@@ -48,4 +52,43 @@ REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, froze
|
|
48 |
|
49 |
REGISTRY["mountaincarContinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
|
50 |
REGISTRY["mountaincarContinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
|
51 |
-
REGISTRY["mountaincarContinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from .toy_text import taxi_translator, taxi_policies
|
11 |
from .toy_text import cliffwalking_translator, cliffwalking_policies
|
12 |
from .toy_text import frozenlake_translator, frozenlake_policies
|
13 |
+
from .atari import register_environments
|
14 |
+
from .atari import Boxing_policies, Boxing_translator, Pong_policies, Pong_translator
|
15 |
+
|
16 |
+
register_environments()
|
17 |
|
18 |
REGISTRY = {}
|
19 |
REGISTRY["sampling_wrapper"] = SettableStateEnv
|
|
|
52 |
|
53 |
REGISTRY["mountaincarContinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
|
54 |
REGISTRY["mountaincarContinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
|
55 |
+
REGISTRY["mountaincarContinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
|
56 |
+
|
57 |
+
|
58 |
+
REGISTRY["RepresentedBoxing_init_translator"] = Boxing_translator.GameDescriber
|
59 |
+
REGISTRY["RepresentedBoxing_basic_translator"] = Boxing_translator.BasicStateSequenceTranslator
|
60 |
+
REGISTRY["RepresentedBoxing_basic_policies"] = [
|
61 |
+
Boxing_policies.real_random_policy,
|
62 |
+
Boxing_policies.pseudo_random_policy,
|
63 |
+
Boxing_policies.dedicated_1_policy,
|
64 |
+
Boxing_policies.dedicated_2_policy,
|
65 |
+
Boxing_policies.dedicated_3_policy,
|
66 |
+
Boxing_policies.dedicated_4_policy,
|
67 |
+
Boxing_policies.dedicated_5_policy,
|
68 |
+
Boxing_policies.dedicated_6_policy,
|
69 |
+
Boxing_policies.dedicated_7_policy,
|
70 |
+
Boxing_policies.dedicated_8_policy,
|
71 |
+
Boxing_policies.dedicated_9_policy,
|
72 |
+
Boxing_policies.dedicated_10_policy,
|
73 |
+
Boxing_policies.dedicated_11_policy,
|
74 |
+
Boxing_policies.dedicated_12_policy,
|
75 |
+
Boxing_policies.dedicated_13_policy,
|
76 |
+
Boxing_policies.dedicated_14_policy,
|
77 |
+
Boxing_policies.dedicated_15_policy,
|
78 |
+
Boxing_policies.dedicated_16_policy,
|
79 |
+
Boxing_policies.dedicated_17_policy,
|
80 |
+
Boxing_policies.dedicated_18_policy
|
81 |
+
]
|
82 |
+
|
83 |
+
REGISTRY["RepresentedPong_init_translator"] = Pong_translator.GameDescriber
|
84 |
+
REGISTRY["RepresentedPong_basic_translator"] = Pong_translator.BasicStateSequenceTranslator
|
85 |
+
REGISTRY["RepresentedPong_basic_policies"] = [
|
86 |
+
Pong_policies.real_random_policy,
|
87 |
+
Pong_policies.pseudo_random_policy,
|
88 |
+
Pong_policies.dedicated_1_policy,
|
89 |
+
Pong_policies.dedicated_2_policy,
|
90 |
+
Pong_policies.dedicated_3_policy,
|
91 |
+
Pong_policies.dedicated_4_policy,
|
92 |
+
Pong_policies.dedicated_5_policy,
|
93 |
+
Pong_policies.dedicated_6_policy,
|
94 |
+
]
|
envs/atari/Boxing_policies.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def dedicated_1_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 1 (NOOP)."""
    def get_description():
        return "Always select action 1 which does NOOP (no operation)"

    dedicated_1_policy.description = get_description()
    return 1


def dedicated_2_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 2 (hit the enemy)."""
    def get_description():
        return "Always select action 2 which hits the enemy"

    # Bug fix: the description was previously attached to dedicated_1_policy,
    # overwriting that policy's text and leaving this one undocumented.
    dedicated_2_policy.description = get_description()
    return 2
|
18 |
+
|
19 |
+
|
20 |
+
def dedicated_3_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 3 (move up)."""
    dedicated_3_policy.description = "Always select action 3 which moves the agent up"
    return 3


def dedicated_4_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 4 (move right)."""
    dedicated_4_policy.description = "Always select action 4 which moves the agent right"
    return 4


def dedicated_5_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 5 (move left)."""
    dedicated_5_policy.description = "Always select action 5 which moves the agent left"
    return 5
|
42 |
+
|
43 |
+
|
44 |
+
def pseudo_random_policy(state, pre_action):
    """Deterministically cycle through Boxing actions 1..18, advancing one step
    from the previously taken action."""
    pseudo_random_policy.description = "Select an action among 1 to 18 alternatively"
    return pre_action % 18 + 1


def real_random_policy(state, pre_action=1):
    """Pick one of the 18 Boxing actions uniformly at random."""
    real_random_policy.description = "Select action with a random policy"
    return np.random.choice(range(0, 18)) + 1
|
57 |
+
|
58 |
+
|
59 |
+
# Complete set of dedicated action policies
|
60 |
+
def dedicated_6_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 6 (move down)."""
    dedicated_6_policy.description = "Always select action 6 which moves the agent down"
    return 6


def dedicated_7_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 7 (move up-right)."""
    dedicated_7_policy.description = "Always select action 7 which moves the agent up and to the right"
    return 7
|
74 |
+
|
75 |
+
|
76 |
+
def dedicated_8_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 8 (move up-left)."""
    def get_description():
        return "Always select action 8 which moves the agent up and to the left"

    # Bug fix: was `dedicated_8.description`, a NameError on first call
    # (no object named `dedicated_8` exists).
    dedicated_8_policy.description = get_description()
    return 8


def dedicated_9_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 9 (move down-right)."""
    def get_description():
        return "Always select action 9 which moves the agent down and to the right"

    # Bug fix: was `dedicated_9.description`, a NameError on first call.
    dedicated_9_policy.description = get_description()
    return 9
|
90 |
+
|
91 |
+
|
92 |
+
def dedicated_10_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 10 (move down-left)."""
    def get_description():
        return "Always select action 10 which moves the agent down and to the left"

    dedicated_10_policy.description = get_description()
    return 10


def dedicated_11_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 11 (hit while moving up)."""
    def get_description():
        # Typo fix: "hiting" -> "hitting" in the user-facing description.
        return "Always select action 11 which moves the agent up while hitting the enemy"

    dedicated_11_policy.description = get_description()
    return 11


def dedicated_12_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 12 (hit while moving right)."""
    def get_description():
        return "Always select action 12 which moves the agent right while hitting the enemy"

    dedicated_12_policy.description = get_description()
    return 12


def dedicated_13_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 13 (hit while moving left)."""
    def get_description():
        return "Always select action 13 which moves the agent left while hitting the enemy"

    dedicated_13_policy.description = get_description()
    return 13


def dedicated_14_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 14 (hit while moving down)."""
    def get_description():
        return "Always select action 14 which moves the agent down while hitting the enemy"

    dedicated_14_policy.description = get_description()
    return 14


def dedicated_15_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 15 (hit while moving up-right)."""
    def get_description():
        return "Always select action 15 which moves the agent up and to the right while hitting the enemy"

    dedicated_15_policy.description = get_description()
    return 15


def dedicated_16_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 16 (hit while moving up-left)."""
    def get_description():
        return "Always select action 16 which moves the agent up and to the left while hitting the enemy"

    dedicated_16_policy.description = get_description()
    return 16


def dedicated_17_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 17 (hit while moving down-right)."""
    def get_description():
        return "Always select action 17 which moves the agent down and to the right while hitting the enemy"

    dedicated_17_policy.description = get_description()
    return 17


def dedicated_18_policy(state, pre_action=1):
    """Fixed Boxing policy: always return action 18 (hit while moving down-left)."""
    def get_description():
        return "Always select action 18 which moves the agent down and to the left while hitting the enemy"

    dedicated_18_policy.description = get_description()
    return 18
|
envs/atari/Boxing_translator.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# [Translator classes and functions for Atari Boxing environment]
|
2 |
+
|
3 |
+
class BasicLevelTranslator:
    """Translate a raw Boxing label vector into a natural-language description."""

    def __init__(self,):
        pass

    def translate(self, state):
        """Describe positions and scores from a 7-element Boxing label state."""
        player_x, player_y, enemy_x, enemy_y, enemy_score, clock, player_score = state
        # Bug fix: the original interpolated `{player_x, player_y}` (a tuple
        # expression) inside already-open parentheses, producing unbalanced
        # text like "((12, 34)"; it also misspelled "opponent's".
        return f"The player is at position ({player_x}, {player_y}), your opponent is at position ({enemy_x}, {enemy_y}), " \
               f"your opponent's score is {enemy_score}, your score is {player_score}. Move left and right will change the player_x while move up and down will change the player_y"
|
11 |
+
|
12 |
+
|
13 |
+
class GameDescriber:
    """Static, state-independent prompt text for the Atari Boxing game."""

    def __init__(self, args):
        # `args` is the experiment argument namespace; only these two fields
        # are read here.
        self.is_only_local_obs = args.is_only_local_obs == 1
        self.max_episode_len = args.max_episode_len
        self.action_desc_dict = {
        }
        self.reward_desc_dict = {
        }

    def describe_goal(self):
        """One-sentence objective of the game."""
        return "The goal is to knock out your opponent."

    def translate_terminate_state(self, state, episode_len, max_episode_len):
        # No terminal-state text is provided for this game.
        return ""

    def translate_potential_next_state(self, state, action):
        # No lookahead text is provided for this game.
        return ""

    def describe_game(self):
        """Rules summary shown to the decider."""
        # Typo fixes: "knocked out.Scoring" and "caughton" were missing spaces.
        return "In the Boxing game, you fight an opponent in a boxing ring. You score points for hitting the opponent. " \
               "If you score 100 points, your opponent is knocked out. Scoring Points: When you get near enough to your opponent to throw a punch, " \
               "press the red button. Each punch moves your opponent slightly back and away from the punch." \
               " If you move him to the ropes, he can't easily duck the next punch, " \
               "and you can set up a real scoring barrage. But don't get caught on the ropes yourself! " \
               "Watch your distance. If you move in too close, the computer gets tougher;" \
               " but if you're too far away, you can't land scoring punches. "

    def describe_action(self):
        """Enumerate the 18 legal Boxing actions for the decider."""
        return "Your Next Move: \n Please choose an action. Type '1' for NOOP (no operation), '2' to hit your opponent, " \
               "'3' to move up, '4' to move right, '5' to move left, '6' to move down, '7' to move up-right, " \
               "'8' to move up-left, '9' to move down-right, '10' to move down-left, '11' to hit your opponent and move up, " \
               "'12' to hit your opponent and move right, '13' to hit your opponent and move left, '14' to hit your opponent and move down, " \
               "'15' to hit your opponent and move up-right, '16' to hit your opponent and move up-left, '17' to hit your opponent and move down-right, " \
               "or '18' to hit your opponent and move down-left. Ensure you only provide the action number " \
               "from the valid action list, i.e., [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]."
|
48 |
+
|
49 |
+
|
50 |
+
class BasicStateSequenceTranslator(BasicLevelTranslator):
    """Translate a sequence of Boxing transitions into text descriptions."""

    # Action id -> human-readable description. Replaces the original 18-branch
    # if/elif chain; also fixes the "up-lefr" typo for action 8.
    _ACTION_DESC = {
        1: "Do nothing",
        2: "Hit your opponent",
        3: "Move up",
        4: "Move right",
        5: "Move left",
        6: "Move down",
        7: "Move up-right",
        8: "Move up-left",
        9: "Move down-right",
        10: "Move down-left",
        11: "Hit your opponent and move up",
        12: "Hit your opponent and move right",
        13: "Hit your opponent and move left",
        14: "Hit your opponent and move down",
        15: "Hit your opponent and move up-right",
        16: "Hit your opponent and move up-left",
        17: "Hit your opponent and move down-right",
        18: "Hit your opponent and move down-left",
    }

    def translate(self, infos, is_current=False):
        """Describe each (state, action, reward, next_state) transition.

        When `is_current` is True, only the latest state is described and a
        single string is returned; otherwise a list of transition strings.
        """
        if is_current:
            return BasicLevelTranslator().translate(infos[-1]['state'])
        descriptions = []
        for info in infos:
            assert 'state' in info, "info should contain state information"
            state_desc = BasicLevelTranslator().translate(info['state'])
            # Unknown ids fall back to action 18's text, matching the original
            # chain's final `else` branch.
            desc = self._ACTION_DESC.get(info['action'], self._ACTION_DESC[18])
            action_desc = f"Take Action: '{desc}'"
            reward_desc = f"Result: Reward of {info['reward']}, "
            next_state_desc = BasicLevelTranslator().translate(info['next_state'])
            descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
        return descriptions
|
envs/atari/Pong_policies.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def dedicated_1_policy(state, pre_action=1):
    """Fixed Pong policy: always return action 1 (NOOP)."""
    def get_description():
        return "Always select action 1 which does NOOP (no operation)"

    dedicated_1_policy.description = get_description()
    return 1


def dedicated_2_policy(state, pre_action=1):
    """Fixed Pong policy: always return action 2 (hit the ball)."""
    def get_description():
        return "Always select action 2 which hits the ball"

    # Bug fix: the description was previously attached to dedicated_1_policy,
    # overwriting that policy's text and leaving this one undocumented.
    dedicated_2_policy.description = get_description()
    return 2
18 |
+
|
19 |
+
|
20 |
+
def dedicated_3_policy(state, pre_action=1):
    """Fixed Pong policy: always return action 3 (move right)."""
    def get_description():
        return "Always select action 3 which moves the agent right"

    dedicated_3_policy.description = get_description()
    return 3


def dedicated_4_policy(state, pre_action=1):
    """Fixed Pong policy: always return action 4 (move left)."""
    def get_description():
        return "Always select action 4 which moves the agent left"

    dedicated_4_policy.description = get_description()
    return 4


def dedicated_5_policy(state, pre_action=1):
    """Fixed Pong policy: always return action 5 (move right while hitting)."""
    def get_description():
        # Typo fix: "hiting" -> "hitting" in the user-facing description.
        return "Always select action 5 which moves the agent right while hitting the ball"

    dedicated_5_policy.description = get_description()
    return 5
|
42 |
+
|
43 |
+
|
44 |
+
def pseudo_random_policy(state, pre_action):
    """Deterministically cycle through Pong actions 1..6, advancing one step
    from the previously taken action."""
    pseudo_random_policy.description = "Select an action among 1 to 6 alternatively"
    return pre_action % 6 + 1


def real_random_policy(state, pre_action=1):
    """Pick one of the 6 Pong actions uniformly at random."""
    real_random_policy.description = "Select action with a random policy"
    return np.random.choice(range(0, 6)) + 1
|
57 |
+
|
58 |
+
|
59 |
+
# Complete set of dedicated action policies
|
60 |
+
def dedicated_6_policy(state, pre_action=1):
    """Fixed Pong policy: always return action 6 (move left while hitting)."""
    def get_description():
        # Bug fix: the description said "action 5" although the policy
        # returns 6; also fixes the "hiting" typo.
        return "Always select action 6 which moves the agent left while hitting the ball"

    dedicated_6_policy.description = get_description()
    return 6
|
envs/atari/Pong_translator.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# [Translator classes and functions for Atari Pong environment]
|
2 |
+
#'labels': {'player_y': 109, 'player_x': 188, 'enemy_y': 20, 'enemy_x': 64, 'ball_x': 0, 'ball_y': 0, 'enemy_score': 0, 'player_score': 0}
|
3 |
+
class BasicLevelTranslator:
    """Translate a raw Pong label vector into a natural-language description."""

    def __init__(self, ):
        pass

    def translate(self, state):
        """Describe paddle/ball positions and scores from an 8-element state."""
        player_y, player_x, enemy_y, enemy_x, ball_x, ball_y, enemy_score, player_score = state
        # Bug fixes: tuple expressions inside already-open parentheses made
        # unbalanced text; a separator was missing before the score clause;
        # "opponent's" was misspelled.
        return f"The player is at position ({player_y}, {player_x}), your opponent is at position ({enemy_y}, {enemy_x}), the ball is at ({ball_y}, {ball_x}), " \
               f"your opponent's score is {enemy_score}, your score is {player_score}."
|
11 |
+
|
12 |
+
|
13 |
+
class GameDescriber:
    """Static, state-independent prompt text for the Atari Pong game."""

    def __init__(self, args):
        # `args` is the experiment argument namespace; only these two fields
        # are read here.
        self.is_only_local_obs = args.is_only_local_obs == 1
        self.max_episode_len = args.max_episode_len
        self.action_desc_dict = {
        }
        self.reward_desc_dict = {
        }

    def describe_goal(self):
        """One-sentence objective of the game."""
        # Bug fix: the goal text was copied from Boxing ("knock out your
        # opponent"); Pong is won by reaching 21 points first, as stated
        # in describe_game().
        return "The goal is to score 21 points before your opponent."

    def translate_terminate_state(self, state, episode_len, max_episode_len):
        # No terminal-state text is provided for this game.
        return ""

    def translate_potential_next_state(self, state, action):
        # No lookahead text is provided for this game.
        return ""

    def describe_game(self):
        """Rules summary shown to the decider."""
        return "In the Pong game, you play the ball with your opponent, each player rallys the ball by moving the paddles on the playfield. " \
               "Paddles move only vertically on the playfield. A player scores one point when the opponent hits the ball out of bounds or misses a hit. " \
               "The first player to score 21 points wins the game."

    def describe_action(self):
        """Enumerate the 6 legal Pong actions for the decider."""
        return "Your Next Move: \n Please choose an action. Type '1' for NOOP (no operation), '2' to hit the ball, " \
               "'3' to move right, '4' to move left, '5' to move right while hit the ball, '6' to move left while hit the ball. Ensure you only provide the action number " \
               "from the valid action list, i.e., [1, 2, 3, 4, 5, 6]."
|
40 |
+
|
41 |
+
|
42 |
+
class BasicStateSequenceTranslator(BasicLevelTranslator):
    """Translate a sequence of Pong transitions into text descriptions."""

    # Action id -> human-readable description. Replaces the original if/elif
    # chain; fixes "hiting" typos and makes "Hit the ball" consistent with
    # describe_action().
    _ACTION_DESC = {
        1: "Do nothing",
        2: "Hit the ball",
        3: "Move right",
        4: "Move left",
        5: "Move right while hitting the ball",
        6: "Move left while hitting the ball",
    }

    def translate(self, infos, is_current=False):
        """Describe each (state, action, reward, next_state) transition.

        When `is_current` is True, only the latest state is described and a
        single string is returned; otherwise a list of transition strings.
        """
        if is_current:
            return BasicLevelTranslator().translate(infos[-1]['state'])
        descriptions = []
        for info in infos:
            assert 'state' in info, "info should contain state information"
            state_desc = BasicLevelTranslator().translate(info['state'])
            # Unknown ids fall back to action 6's text, matching the original
            # chain's final `else` branch.
            desc = self._ACTION_DESC.get(info['action'], self._ACTION_DESC[6])
            action_desc = f"Take Action: '{desc}'"
            reward_desc = f"Result: Reward of {info['reward']}, "
            next_state_desc = BasicLevelTranslator().translate(info['next_state'])
            descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
        return descriptions
|
envs/atari/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .represented_atari_game import register_environments
|
envs/atari/represented_atari_game.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gym
|
2 |
+
import ale_py
|
3 |
+
import numpy as np
|
4 |
+
from atariari.benchmark.wrapper import AtariARIWrapper
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
class RepresentedAtariEnv(gym.Wrapper):
    """Gym wrapper exposing AtariARI RAM-derived labels as the observation.

    Instead of pixel frames, observations are the values of the AtariARI
    label dict (semantic RAM fields such as positions and scores), returned
    as a flat numpy vector.
    """

    def __init__(self, env_name, render_mode=None):
        super().__init__(AtariARIWrapper(gym.make(env_name, render_mode=render_mode)))
        self.metadata = self.env.metadata
        self.env_name = env_name
        self.observation = None
        self.info = {}
        self.action_space = self.env.action_space
        # Reset once so env.labels() is populated before the observation
        # space is sized from it.
        _ = self.env.reset()
        obs = self.env.labels()
        obs_dim = len(obs)
        self.obs_label = obs.keys()
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)

    def step(self, action):
        """Step the wrapped env; return the label vector, not the pixels."""
        # The pixel observation is intentionally discarded.
        original_next_obs, reward, env_done, env_truncated, info = self.env.step(action)
        next_obs = self.env.labels()
        self.obs_label = next_obs.keys()
        self.observation = next_obs
        return np.array(list(next_obs.values())), reward, env_done, env_truncated, info

    def reset(self, seed=0):
        """Reset the wrapped env; return (label vector, info).

        NOTE(review): the default seed=0 makes resets deterministic unless a
        caller passes a seed explicitly — confirm this is intended.
        """
        obs_original, info = self.env.reset(seed=seed)
        obs = self.env.labels()
        self.obs_label = obs.keys()
        self.observation = obs
        return np.array(list(obs.values())), info

    def get_info(self):
        # Returns the most recent label dict (not the `info` from step/reset).
        return self.observation

    def render(self, render_mode=None):
        # `render_mode` is accepted for API compatibility but ignored; the
        # underlying env renders with the mode given at construction.
        return self.env.render()
|
42 |
+
|
43 |
+
|
44 |
+
# Thin subclasses that bind RepresentedAtariEnv to a specific NoFrameskip ROM,
# so each game can be registered under its own gym id.
class RepresentedMsPacman(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "MsPacmanNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedBowling(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "BowlingNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedBoxing(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "BoxingNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedBreakout(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "BreakoutNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedDemonAttack(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "DemonAttackNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedFreeway(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "FreewayNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedFrostbite(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "FrostbiteNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedHero(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "HeroNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedMontezumaRevenge(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "MontezumaRevengeNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedPitfall(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "PitfallNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedPong(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "PongNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedPrivateEye(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "PrivateEyeNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedQbert(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "QbertNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedRiverraid(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "RiverraidNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedSeaquest(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "SeaquestNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedSpaceInvaders(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "SpaceInvadersNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedTennis(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "TennisNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedVenture(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "VentureNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)


class RepresentedVideoPinball(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str]=None):
        env_name = "VideoPinballNoFrameskip-v4"
        super().__init__(env_name=env_name, render_mode=render_mode)
|
156 |
+
|
157 |
+
|
158 |
+
def env_factory(env_class):
    """Return a deferred constructor for *env_class*.

    gym.register expects a zero-argument-capable callable as the entry
    point, so wrap the class in a closure instead of instantiating it now.
    """
    def _make(render_mode=None):
        return env_class(render_mode=render_mode)

    return _make
|
162 |
+
|
163 |
+
|
164 |
+
def register_environments():
    """Register every Represented* Atari env under a `Represented<Game>-v0` id.

    Safe to call once at import time (envs/atari/__init__.py exposes it and
    envs/__init__.py invokes it); calling it twice would make gym raise on
    the duplicate ids.
    """
    # Gym id -> wrapper class; ids mirror the underlying NoFrameskip-v4 ROMs.
    env_classes = {
        'RepresentedMsPacman-v0': RepresentedMsPacman,
        'RepresentedBowling-v0': RepresentedBowling,
        'RepresentedBoxing-v0': RepresentedBoxing,
        'RepresentedBreakout-v0': RepresentedBreakout,
        'RepresentedDemonAttack-v0': RepresentedDemonAttack,
        'RepresentedFreeway-v0': RepresentedFreeway,
        'RepresentedFrostbite-v0': RepresentedFrostbite,
        'RepresentedHero-v0': RepresentedHero,
        'RepresentedMontezumaRevenge-v0': RepresentedMontezumaRevenge,
        'RepresentedPitfall-v0': RepresentedPitfall,
        'RepresentedPong-v0': RepresentedPong,
        'RepresentedPrivateEye-v0': RepresentedPrivateEye,
        'RepresentedQbert-v0': RepresentedQbert,
        'RepresentedRiverraid-v0': RepresentedRiverraid,
        'RepresentedSeaquest-v0': RepresentedSeaquest,
        'RepresentedSpaceInvaders-v0': RepresentedSpaceInvaders,
        'RepresentedTennis-v0': RepresentedTennis,
        'RepresentedVenture-v0': RepresentedVenture,
        'RepresentedVideoPinball-v0': RepresentedVideoPinball
    }

    for env_name, env_class in env_classes.items():
        # env_factory defers construction until gym.make() is called.
        gym.register(
            id=env_name,
            entry_point=env_factory(env_class),
        )
|
192 |
+
|
193 |
+
|
194 |
+
# register_environments()
|
195 |
+
# env_classes = {
|
196 |
+
# 'RepresentedMsPacman-v0': RepresentedMsPacman,
|
197 |
+
# 'RepresentedBowling-v0': RepresentedBowling,
|
198 |
+
# 'RepresentedBoxing-v0': RepresentedBoxing,
|
199 |
+
# 'RepresentedBreakout-v0': RepresentedBreakout,
|
200 |
+
# 'RepresentedDemonAttack-v0': RepresentedDemonAttack,
|
201 |
+
# 'RepresentedFreeway-v0': RepresentedFreeway,
|
202 |
+
# 'RepresentedFrostbite-v0': RepresentedFrostbite,
|
203 |
+
# 'RepresentedHero-v0': RepresentedHero,
|
204 |
+
# 'RepresentedMontezumaRevenge-v0': RepresentedMontezumaRevenge,
|
205 |
+
# 'RepresentedPitfall-v0': RepresentedPitfall,
|
206 |
+
# 'RepresentedPong-v0': RepresentedPong,
|
207 |
+
# 'RepresentedPrivateEye-v0': RepresentedPrivateEye,
|
208 |
+
# 'RepresentedQbert-v0': RepresentedQbert,
|
209 |
+
# 'RepresentedRiverraid-v0': RepresentedRiverraid,
|
210 |
+
# 'RepresentedSeaquest-v0': RepresentedSeaquest,
|
211 |
+
# 'RepresentedSpaceInvaders-v0': RepresentedSpaceInvaders,
|
212 |
+
# 'RepresentedTennis-v0': RepresentedTennis,
|
213 |
+
# 'RepresentedVenture-v0': RepresentedVenture,
|
214 |
+
# 'RepresentedVideoPinball-v0': RepresentedVideoPinball
|
215 |
+
# }
|
216 |
+
#
|
217 |
+
# for env, env_class in env_classes.items():
|
218 |
+
# env_1 = env_class()
|
219 |
+
# env_name = env_1.env_name
|
220 |
+
# env_2 = gym.make(env_name)
|
221 |
+
# print(env_name, env_1.action_space == env_2.action_space, env_1.action_space)
|
record_reflexion.csv
CHANGED
@@ -8,4 +8,6 @@ Taxi-v3,1,expert,200.0
|
|
8 |
CliffWalking-v0,1,expert,200.0
|
9 |
FrozenLake-v1,1,expert,200.0
|
10 |
MountainCarContinuous-v0,1,expert,200.0
|
|
|
|
|
11 |
|
|
|
8 |
CliffWalking-v0,1,expert,200.0
|
9 |
FrozenLake-v1,1,expert,200.0
|
10 |
MountainCarContinuous-v0,1,expert,200.0
|
11 |
+
RepresentedBoxing-v0,1,expert,200.0
|
12 |
+
RepresentedPong-v0,1,expert,200.0
|
13 |
|
test_atari.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python main_reflexion.py --env_name RepresentedBoxing-v0 --init_summarizer RepresentedBoxing_init_translator --curr_summarizer RepresentedBoxing_basic_translator --decider naive_actor --prompt_level 1 --num_trails 1 --seed 0
|