hyyh28 commited on
Commit
1aa4792
·
1 Parent(s): eb1b37d

update atari_env

Browse files

update atari_env; also add a test_atari.sh smoke-test script

deciders/parser.py CHANGED
@@ -40,6 +40,25 @@ class SixAction(BaseModel):
40
  if field not in [1, 2, 3, 4, 5, 6]:
41
  raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6])!")
42
  return field
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  class ContinuousAction(BaseModel):
45
  action: float = Field(description="the choosed action to perform")
@@ -50,4 +69,4 @@ class ContinuousAction(BaseModel):
50
  raise ValueError("Action is not valid ([-1,1])!")
51
  return field
52
 
53
- PARSERS = {1:ContinuousAction, 2: TwoAction, 3: ThreeAction, 4: FourAction, 6: SixAction}
 
40
  if field not in [1, 2, 3, 4, 5, 6]:
41
  raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6])!")
42
  return field
43
+
44
+
45
class NineAction(BaseModel):
    """Pydantic schema for a discrete choice among nine actions."""
    action: int = Field(description="the choosed action to perform")

    # Pydantic runs this on every parse; out-of-range values are rejected.
    @validator('action')
    def action_is_valid(cls, field):
        if field in (1, 2, 3, 4, 5, 6, 7, 8, 9):
            return field
        raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6, 7, 8, 9])!")
54
+
55
class FullAtariAction(BaseModel):
    """Pydantic schema for the full 18-action Atari discrete action space."""
    action: int = Field(description="the choosed action to perform")

    @validator('action')
    def action_is_valid(cls, field):
        # CONSISTENCY FIX: parameter renamed from `info` to `field` so this
        # validator matches every sibling *Action validator in this module.
        if field not in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]:
            raise ValueError("Action is not valid ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])!")
        return field
62
 
63
  class ContinuousAction(BaseModel):
64
  action: float = Field(description="the choosed action to perform")
 
69
  raise ValueError("Action is not valid ([-1,1])!")
70
  return field
71
 
72
# Maps an environment's discrete action count (1 = continuous) to the
# pydantic schema used to parse the decider's chosen action.
PARSERS = {
    1: ContinuousAction,
    2: TwoAction,
    3: ThreeAction,
    4: FourAction,
    6: SixAction,
    9: NineAction,
    18: FullAtariAction,
}
envs/__init__.py CHANGED
@@ -10,6 +10,10 @@ from .toy_text import blackjack_translator, blackjack_policies
10
  from .toy_text import taxi_translator, taxi_policies
11
  from .toy_text import cliffwalking_translator, cliffwalking_policies
12
  from .toy_text import frozenlake_translator, frozenlake_policies
 
 
 
 
13
 
14
  REGISTRY = {}
15
  REGISTRY["sampling_wrapper"] = SettableStateEnv
@@ -48,4 +52,43 @@ REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, froze
48
 
49
  REGISTRY["mountaincarContinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
50
  REGISTRY["mountaincarContinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
51
- REGISTRY["mountaincarContinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from .toy_text import taxi_translator, taxi_policies
11
  from .toy_text import cliffwalking_translator, cliffwalking_policies
12
  from .toy_text import frozenlake_translator, frozenlake_policies
13
+ from .atari import register_environments
14
+ from .atari import Boxing_policies, Boxing_translator, Pong_policies, Pong_translator
15
+
16
+ register_environments()
17
 
18
  REGISTRY = {}
19
  REGISTRY["sampling_wrapper"] = SettableStateEnv
 
52
 
53
  REGISTRY["mountaincarContinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
54
  REGISTRY["mountaincarContinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
55
+ REGISTRY["mountaincarContinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
56
+
57
+
58
# Represented-Atari registrations. Policy list order matters: the two random
# policies come first, then dedicated_1 .. dedicated_N in numeric order.
REGISTRY["RepresentedBoxing_init_translator"] = Boxing_translator.GameDescriber
REGISTRY["RepresentedBoxing_basic_translator"] = Boxing_translator.BasicStateSequenceTranslator
REGISTRY["RepresentedBoxing_basic_policies"] = [
    Boxing_policies.real_random_policy,
    Boxing_policies.pseudo_random_policy,
] + [getattr(Boxing_policies, f"dedicated_{i}_policy") for i in range(1, 19)]

REGISTRY["RepresentedPong_init_translator"] = Pong_translator.GameDescriber
REGISTRY["RepresentedPong_basic_translator"] = Pong_translator.BasicStateSequenceTranslator
REGISTRY["RepresentedPong_basic_policies"] = [
    Pong_policies.real_random_policy,
    Pong_policies.pseudo_random_policy,
] + [getattr(Pong_policies, f"dedicated_{i}_policy") for i in range(1, 7)]
envs/atari/Boxing_policies.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def dedicated_1_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 1 (NOOP). `state`/`pre_action` are ignored."""
    dedicated_1_policy.description = "Always select action 1 which does NOOP (no operation)"
    return 1


def dedicated_2_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 2 (hit the enemy).

    BUG FIX: the original assigned to ``dedicated_1_policy.description``,
    overwriting the NOOP policy's label instead of setting its own.
    """
    dedicated_2_policy.description = "Always select action 2 which hits the enemy"
    return 2
18
+
19
+
20
def dedicated_3_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 3 (move up)."""
    dedicated_3_policy.description = "Always select action 3 which moves the agent up"
    return 3


def dedicated_4_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 4 (move right)."""
    dedicated_4_policy.description = "Always select action 4 which moves the agent right"
    return 4


def dedicated_5_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 5 (move left)."""
    dedicated_5_policy.description = "Always select action 5 which moves the agent left"
    return 5
42
+
43
+
44
def pseudo_random_policy(state, pre_action):
    """Deterministically cycle through actions 1..18, advancing from `pre_action`."""
    pseudo_random_policy.description = "Select an action among 1 to 18 alternatively"
    return pre_action % 18 + 1


def real_random_policy(state, pre_action=1):
    """Sample a uniformly random Boxing action in 1..18."""
    real_random_policy.description = "Select action with a random policy"
    return np.random.choice(range(0, 18)) + 1
57
+
58
+
59
+ # Complete set of dedicated action policies
60
def dedicated_6_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 6 (move down)."""
    dedicated_6_policy.description = "Always select action 6 which moves the agent down"
    return 6


def dedicated_7_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 7 (move up-right)."""
    dedicated_7_policy.description = "Always select action 7 which moves the agent up and to the right"
    return 7
74
+
75
+
76
def dedicated_8_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 8 (move up-left).

    BUG FIX: the original wrote to the undefined name ``dedicated_8``,
    raising NameError the first time the policy was called.
    """
    dedicated_8_policy.description = "Always select action 8 which moves the agent up and to the left"
    return 8


def dedicated_9_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 9 (move down-right).

    BUG FIX: the original wrote to the undefined name ``dedicated_9``,
    raising NameError the first time the policy was called.
    """
    dedicated_9_policy.description = "Always select action 9 which moves the agent down and to the right"
    return 9
90
+
91
+
92
def dedicated_10_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 10 (move down-left)."""
    dedicated_10_policy.description = "Always select action 10 which moves the agent down and to the left"
    return 10


def dedicated_11_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 11 (punch while moving up)."""
    dedicated_11_policy.description = "Always select action 11 which moves the agent up while hiting the enemy"
    return 11


def dedicated_12_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 12 (punch while moving right)."""
    dedicated_12_policy.description = "Always select action 12 which moves the agent right while hiting the enemy"
    return 12


def dedicated_13_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 13 (punch while moving left)."""
    dedicated_13_policy.description = "Always select action 13 which moves the agent left while hiting the enemy"
    return 13


def dedicated_14_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 14 (punch while moving down)."""
    dedicated_14_policy.description = "Always select action 14 which moves the agent down while hiting the enemy"
    return 14


def dedicated_15_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 15 (punch while moving up-right)."""
    dedicated_15_policy.description = "Always select action 15 which moves the agent up and to the right while hiting the enemy"
    return 15


def dedicated_16_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 16 (punch while moving up-left)."""
    dedicated_16_policy.description = "Always select action 16 which moves the agent up and to the left while hiting the enemy"
    return 16


def dedicated_17_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 17 (punch while moving down-right)."""
    dedicated_17_policy.description = "Always select action 17 which moves the agent down and to the right while hiting the enemy"
    return 17


def dedicated_18_policy(state, pre_action=1):
    """Fixed Boxing policy: always action 18 (punch while moving down-left)."""
    dedicated_18_policy.description = "Always select action 18 which moves the agent down and to the left while hiting the enemy"
    return 18
envs/atari/Boxing_translator.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [Translator classes and functions for Atari Boxing environment]
2
+
3
class BasicLevelTranslator:
    """Renders a raw Boxing label-state tuple as an English description."""

    def __init__(self):
        pass

    def translate(self, state):
        """Describe one state.

        state: 7-tuple (player_x, player_y, enemy_x, enemy_y, enemy_score,
        clock, player_score) as produced by the represented Boxing env.
        """
        player_x, player_y, enemy_x, enemy_y, enemy_score, clock, player_score = state
        # BUG FIX: the original f-string embedded `{player_x, player_y}` — a tuple
        # expression — which rendered as "((188, 110)" with unbalanced parentheses;
        # also fixed the "oppoent" typo.
        return f"The player is at position ({player_x}, {player_y}), your opponent is at position ({enemy_x}, {enemy_y}), " \
               f"your opponent's score is {enemy_score}, your score is {player_score}. Move left and right will change the player_x while move up and down will change the player_y"
11
+
12
+
13
class GameDescriber:
    """Static English descriptions of the Boxing game, its goal and action set."""

    def __init__(self, args):
        # Mirrors the other GameDescriber classes in this package.
        self.is_only_local_obs = args.is_only_local_obs == 1
        self.max_episode_len = args.max_episode_len
        self.action_desc_dict = {}
        self.reward_desc_dict = {}

    def describe_goal(self):
        return "The goal is to knock out your opponent."

    def translate_terminate_state(self, state, episode_len, max_episode_len):
        # No special terminal-state narration for Boxing.
        return ""

    def translate_potential_next_state(self, state, action):
        # Next-state prediction text is not provided for Boxing.
        return ""

    def describe_game(self):
        # BUG FIX: restored the missing spaces in "out.Scoring" and "caughton".
        return "In the Boxing game, you fight an opponent in a boxing ring. You score points for hitting the opponent. " \
               "If you score 100 points, your opponent is knocked out. Scoring Points: When you get near enough to your opponent to throw a punch, " \
               "press the red button. Each punch moves your opponent slightly back and away from the punch." \
               " If you move him to the ropes, he can't easily duck the next punch, " \
               "and you can set up a real scoring barrage. But don't get caught on the ropes yourself! " \
               "Watch your distance. If you move in too close, the computer gets tougher;" \
               " but if you're too far away, you can't land scoring punches. "

    def describe_action(self):
        return "Your Next Move: \n Please choose an action. Type '1' for NOOP (no operation), '2' to hit your opponent, " \
               "'3' to move up, '4' to move right, '5' to move left, '6' to move down, '7' to move up-right, " \
               "'8' to move up-left, '9' to move down-right, '10' to move down-left, '11' to hit your opponent and move up, " \
               "'12' to hit your opponent and move right, '13' to hit your opponent and move left, '14' to hit your opponent and move down, " \
               "'15' to hit your opponent and move up-right, '16' to hit your opponent and move up-left, '17' to hit your opponent and move down-right, " \
               "or '18' to hit your opponent and move down-left. Ensure you only provide the action number " \
               "from the valid action list, i.e., [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]."
48
+
49
+
50
class BasicStateSequenceTranslator(BasicLevelTranslator):
    """Turns a sequence of transition infos into per-step English narrations."""

    # Narration for each discrete action id; 18 (and any other value) falls
    # through to the down-left default, matching the original elif-chain.
    _ACTION_TEXT = {
        1: "Take Action: 'Do nothing'",
        2: "Take Action: 'Hit your opponent'",
        3: "Take Action: 'Move up'",
        4: "Take Action: 'Move right'",
        5: "Take Action: 'Move left'",
        6: "Take Action: 'Move down'",
        7: "Take Action: 'Move up-right'",
        8: "Take Action: 'Move up-lefr'",
        9: "Take Action: 'Move down-right'",
        10: "Take Action: 'Move down-left'",
        11: "Take Action: 'Hit your opponent and move up'",
        12: "Take Action: 'Hit your opponent and move right'",
        13: "Take Action: 'Hit your opponent and move left'",
        14: "Take Action: 'Hit your opponent and move down'",
        15: "Take Action: 'Hit your opponent and move up-right'",
        16: "Take Action: 'Hit your opponent and move up-left'",
        17: "Take Action: 'Hit your opponent and move down-right'",
    }

    def translate(self, infos, is_current=False):
        if is_current:
            # Only narrate the most recent state.
            return BasicLevelTranslator().translate(infos[-1]['state'])
        descriptions = []
        for info in infos:
            assert 'state' in info, "info should contain state information"
            state_desc = BasicLevelTranslator().translate(info['state'])
            action_desc = self._ACTION_TEXT.get(
                info['action'], "Take Action: 'Hit your opponent and move down-left'")
            reward_desc = f"Result: Reward of {info['reward']}, "
            next_state_desc = BasicLevelTranslator().translate(info['next_state'])
            descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
        return descriptions
envs/atari/Pong_policies.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def dedicated_1_policy(state, pre_action=1):
    """Fixed Pong policy: always action 1 (NOOP). `state`/`pre_action` are ignored."""
    dedicated_1_policy.description = "Always select action 1 which does NOOP (no operation)"
    return 1


def dedicated_2_policy(state, pre_action=1):
    """Fixed Pong policy: always action 2 (hit the ball).

    BUG FIX: the original assigned to ``dedicated_1_policy.description``,
    overwriting the NOOP policy's label instead of setting its own.
    """
    dedicated_2_policy.description = "Always select action 2 which hits the ball"
    return 2
18
+
19
+
20
def dedicated_3_policy(state, pre_action=1):
    """Fixed Pong policy: always action 3 (move right)."""
    dedicated_3_policy.description = "Always select action 3 which moves the agent right"
    return 3


def dedicated_4_policy(state, pre_action=1):
    """Fixed Pong policy: always action 4 (move left)."""
    dedicated_4_policy.description = "Always select action 4 which moves the agent left"
    return 4


def dedicated_5_policy(state, pre_action=1):
    """Fixed Pong policy: always action 5 (move right while hitting the ball)."""
    dedicated_5_policy.description = "Always select action 5 which moves the agent right while hiting the ball"
    return 5
42
+
43
+
44
def pseudo_random_policy(state, pre_action):
    """Deterministically cycle through actions 1..6, advancing from `pre_action`."""
    pseudo_random_policy.description = "Select an action among 1 to 6 alternatively"
    return pre_action % 6 + 1


def real_random_policy(state, pre_action=1):
    """Sample a uniformly random Pong action in 1..6."""
    real_random_policy.description = "Select action with a random policy"
    return np.random.choice(range(0, 6)) + 1
57
+
58
+
59
+ # Complete set of dedicated action policies
60
def dedicated_6_policy(state, pre_action=1):
    """Fixed Pong policy: always action 6 (move left while hitting the ball).

    BUG FIX: the original description claimed "action 5" although this policy
    returns action 6.
    """
    dedicated_6_policy.description = "Always select action 6 which moves the agent left while hiting the ball"
    return 6
envs/atari/Pong_translator.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [Translator classes and functions for Atari Boxing environment]
2
+ #'labels': {'player_y': 109, 'player_x': 188, 'enemy_y': 20, 'enemy_x': 64, 'ball_x': 0, 'ball_y': 0, 'enemy_score': 0, 'player_score': 0}
3
class BasicLevelTranslator:
    """Renders a raw Pong label-state tuple as an English description."""

    def __init__(self):
        pass

    def translate(self, state):
        """Describe one state.

        state: 8-tuple (player_y, player_x, enemy_y, enemy_x, ball_x, ball_y,
        enemy_score, player_score) as produced by the represented Pong env.
        """
        player_y, player_x, enemy_y, enemy_x, ball_x, ball_y, enemy_score, player_score = state
        # BUG FIX: the original f-string embedded `{player_y, player_x}` — a tuple
        # expression — which rendered as "((109, 188)" with unbalanced parentheses,
        # and ran the ball clause straight into the score clause with no separator.
        return f"The player is at position ({player_y}, {player_x}), your opponent is at position ({enemy_y}, {enemy_x}), the ball is at ({ball_y}, {ball_x}), " \
               f"your opponent's score is {enemy_score}, your score is {player_score}."
11
+
12
+
13
class GameDescriber:
    """Static English descriptions of the Pong game, its goal and action set."""

    def __init__(self, args):
        # Mirrors the other GameDescriber classes in this package.
        self.is_only_local_obs = args.is_only_local_obs == 1
        self.max_episode_len = args.max_episode_len
        self.action_desc_dict = {}
        self.reward_desc_dict = {}

    def describe_goal(self):
        # BUG FIX: this previously said "knock out your opponent" (copy-pasted
        # from the Boxing describer); Pong is won by reaching 21 points first,
        # per describe_game below.
        return "The goal is to be the first player to score 21 points."

    def translate_terminate_state(self, state, episode_len, max_episode_len):
        # No special terminal-state narration for Pong.
        return ""

    def translate_potential_next_state(self, state, action):
        # Next-state prediction text is not provided for Pong.
        return ""

    def describe_game(self):
        return "In the Pong game, you play the ball with your opponent, each player rallys the ball by moving the paddles on the playfield. " \
               "Paddles move only vertically on the playfield. A player scores one point when the opponent hits the ball out of bounds or misses a hit. " \
               "The first player to score 21 points wins the game."

    def describe_action(self):
        return "Your Next Move: \n Please choose an action. Type '1' for NOOP (no operation), '2' to hit the ball, " \
               "'3' to move right, '4' to move left, '5' to move right while hit the ball, '6' to move left while hit the ball. Ensure you only provide the action number " \
               "from the valid action list, i.e., [1, 2, 3, 4, 5, 6]."
40
+
41
+
42
class BasicStateSequenceTranslator(BasicLevelTranslator):
    """Turns a sequence of transition infos into per-step English narrations."""

    # Narration for each discrete action id; 6 (and any other value) falls
    # through to the move-left default, matching the original elif-chain.
    _ACTION_TEXT = {
        1: "Take Action: 'Do nothing'",
        2: "Take Action: 'Hit your ball'",
        3: "Take Action: 'Move right'",
        4: "Take Action: 'Move left'",
        5: "Take Action: 'Move right while hiting the ball'",
    }

    def translate(self, infos, is_current=False):
        if is_current:
            # Only narrate the most recent state.
            return BasicLevelTranslator().translate(infos[-1]['state'])
        descriptions = []
        for info in infos:
            assert 'state' in info, "info should contain state information"
            state_desc = BasicLevelTranslator().translate(info['state'])
            action_desc = self._ACTION_TEXT.get(
                info['action'], "Take Action: 'Move left while hiting the ball'")
            reward_desc = f"Result: Reward of {info['reward']}, "
            next_state_desc = BasicLevelTranslator().translate(info['next_state'])
            descriptions.append(f"{state_desc}.\n {action_desc} \n {reward_desc} \n Transit to {next_state_desc}")
        return descriptions
envs/atari/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .represented_atari_game import register_environments
envs/atari/represented_atari_game.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import ale_py
3
+ import numpy as np
4
+ from atariari.benchmark.wrapper import AtariARIWrapper
5
+ from typing import Optional, Union
6
+
7
+
8
+
9
class RepresentedAtariEnv(gym.Wrapper):
    """Gym wrapper that exposes an Atari game's AtariARI RAM labels as a flat Box observation.

    Wraps ``AtariARIWrapper(gym.make(env_name))`` and, instead of pixel frames,
    returns the values of ``env.labels()`` as a 1-D array.
    # NOTE(review): labels() presumably returns an ordered dict of RAM-derived
    # features — confirm against the AtariARIWrapper implementation.
    """

    def __init__(self, env_name, render_mode=None):
        super().__init__(AtariARIWrapper(gym.make(env_name, render_mode=render_mode)))
        self.metadata = self.env.metadata
        self.env_name = env_name
        self.observation = None  # most recent labels dict; set by reset()/step()
        self.info = {}  # NOTE(review): never updated after construction — confirm intended
        self.action_space = self.env.action_space
        # Throwaway reset so labels() is populated; the observation space is
        # sized from the number of label entries.
        _ = self.env.reset()
        obs = self.env.labels()
        obs_dim = len(obs)
        self.obs_label = obs.keys()
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)

    def step(self, action):
        """Step the wrapped env; return (label-values array, reward, done, truncated, info)."""
        original_next_obs, reward, env_done, env_truncated, info = self.env.step(action)
        next_obs = self.env.labels()
        self.obs_label = next_obs.keys()
        self.observation = next_obs
        # Observation is the label dict's values, not the wrapped env's pixels.
        return np.array(list(next_obs.values())), reward, env_done, env_truncated, info

    def reset(self, seed=0):
        """Reset the wrapped env; return (label-values array, info).

        NOTE(review): the seed defaults to 0 rather than None, so every
        un-seeded reset replays the same episode — confirm this is intended.
        """
        obs_original, info = self.env.reset(seed=seed)
        obs = self.env.labels()
        self.obs_label = obs.keys()
        self.observation = obs
        return np.array(list(obs.values())), info

    def get_info(self):
        # Returns the most recent labels dict (not the gym `info` dict).
        return self.observation

    def render(self, render_mode=None):
        # `render_mode` is accepted but ignored; the wrapped env renders with
        # the mode chosen at construction time.
        return self.env.render()
42
+
43
+
44
# Concrete represented-environment wrappers. Each one pins its underlying
# NoFrameskip-v4 ROM id and forwards the optional render mode.

class RepresentedMsPacman(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("MsPacmanNoFrameskip-v4", render_mode=render_mode)


class RepresentedBowling(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("BowlingNoFrameskip-v4", render_mode=render_mode)


class RepresentedBoxing(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("BoxingNoFrameskip-v4", render_mode=render_mode)


class RepresentedBreakout(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("BreakoutNoFrameskip-v4", render_mode=render_mode)


class RepresentedDemonAttack(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("DemonAttackNoFrameskip-v4", render_mode=render_mode)


class RepresentedFreeway(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("FreewayNoFrameskip-v4", render_mode=render_mode)


class RepresentedFrostbite(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("FrostbiteNoFrameskip-v4", render_mode=render_mode)


class RepresentedHero(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("HeroNoFrameskip-v4", render_mode=render_mode)


class RepresentedMontezumaRevenge(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("MontezumaRevengeNoFrameskip-v4", render_mode=render_mode)


class RepresentedPitfall(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("PitfallNoFrameskip-v4", render_mode=render_mode)


class RepresentedPong(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("PongNoFrameskip-v4", render_mode=render_mode)


class RepresentedPrivateEye(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("PrivateEyeNoFrameskip-v4", render_mode=render_mode)


class RepresentedQbert(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("QbertNoFrameskip-v4", render_mode=render_mode)


class RepresentedRiverraid(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("RiverraidNoFrameskip-v4", render_mode=render_mode)


class RepresentedSeaquest(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("SeaquestNoFrameskip-v4", render_mode=render_mode)


class RepresentedSpaceInvaders(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("SpaceInvadersNoFrameskip-v4", render_mode=render_mode)


class RepresentedTennis(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("TennisNoFrameskip-v4", render_mode=render_mode)


class RepresentedVenture(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("VentureNoFrameskip-v4", render_mode=render_mode)


class RepresentedVideoPinball(RepresentedAtariEnv):
    def __init__(self, render_mode: Optional[str] = None):
        super().__init__("VideoPinballNoFrameskip-v4", render_mode=render_mode)
156
+
157
+
158
def env_factory(env_class):
    """Return a constructor for `env_class`, usable as a gym entry point."""
    def _make(render_mode=None):
        return env_class(render_mode=render_mode)
    return _make
162
+
163
+
164
def register_environments():
    """Register every Represented* Atari environment with gym under a -v0 id."""
    env_classes = {
        'RepresentedMsPacman-v0': RepresentedMsPacman,
        'RepresentedBowling-v0': RepresentedBowling,
        'RepresentedBoxing-v0': RepresentedBoxing,
        'RepresentedBreakout-v0': RepresentedBreakout,
        'RepresentedDemonAttack-v0': RepresentedDemonAttack,
        'RepresentedFreeway-v0': RepresentedFreeway,
        'RepresentedFrostbite-v0': RepresentedFrostbite,
        'RepresentedHero-v0': RepresentedHero,
        'RepresentedMontezumaRevenge-v0': RepresentedMontezumaRevenge,
        'RepresentedPitfall-v0': RepresentedPitfall,
        'RepresentedPong-v0': RepresentedPong,
        'RepresentedPrivateEye-v0': RepresentedPrivateEye,
        'RepresentedQbert-v0': RepresentedQbert,
        'RepresentedRiverraid-v0': RepresentedRiverraid,
        'RepresentedSeaquest-v0': RepresentedSeaquest,
        'RepresentedSpaceInvaders-v0': RepresentedSpaceInvaders,
        'RepresentedTennis-v0': RepresentedTennis,
        'RepresentedVenture-v0': RepresentedVenture,
        'RepresentedVideoPinball-v0': RepresentedVideoPinball,
    }
    # Each entry point is a small factory so gym can build instances lazily.
    for env_id, env_cls in env_classes.items():
        gym.register(id=env_id, entry_point=env_factory(env_cls))
192
+
193
+
194
+ # register_environments()
195
+ # env_classes = {
196
+ # 'RepresentedMsPacman-v0': RepresentedMsPacman,
197
+ # 'RepresentedBowling-v0': RepresentedBowling,
198
+ # 'RepresentedBoxing-v0': RepresentedBoxing,
199
+ # 'RepresentedBreakout-v0': RepresentedBreakout,
200
+ # 'RepresentedDemonAttack-v0': RepresentedDemonAttack,
201
+ # 'RepresentedFreeway-v0': RepresentedFreeway,
202
+ # 'RepresentedFrostbite-v0': RepresentedFrostbite,
203
+ # 'RepresentedHero-v0': RepresentedHero,
204
+ # 'RepresentedMontezumaRevenge-v0': RepresentedMontezumaRevenge,
205
+ # 'RepresentedPitfall-v0': RepresentedPitfall,
206
+ # 'RepresentedPong-v0': RepresentedPong,
207
+ # 'RepresentedPrivateEye-v0': RepresentedPrivateEye,
208
+ # 'RepresentedQbert-v0': RepresentedQbert,
209
+ # 'RepresentedRiverraid-v0': RepresentedRiverraid,
210
+ # 'RepresentedSeaquest-v0': RepresentedSeaquest,
211
+ # 'RepresentedSpaceInvaders-v0': RepresentedSpaceInvaders,
212
+ # 'RepresentedTennis-v0': RepresentedTennis,
213
+ # 'RepresentedVenture-v0': RepresentedVenture,
214
+ # 'RepresentedVideoPinball-v0': RepresentedVideoPinball
215
+ # }
216
+ #
217
+ # for env, env_class in env_classes.items():
218
+ # env_1 = env_class()
219
+ # env_name = env_1.env_name
220
+ # env_2 = gym.make(env_name)
221
+ # print(env_name, env_1.action_space == env_2.action_space, env_1.action_space)
record_reflexion.csv CHANGED
@@ -8,4 +8,6 @@ Taxi-v3,1,expert,200.0
8
  CliffWalking-v0,1,expert,200.0
9
  FrozenLake-v1,1,expert,200.0
10
  MountainCarContinuous-v0,1,expert,200.0
 
 
11
 
 
8
  CliffWalking-v0,1,expert,200.0
9
  FrozenLake-v1,1,expert,200.0
10
  MountainCarContinuous-v0,1,expert,200.0
11
+ RepresentedBoxing-v0,1,expert,200.0
12
+ RepresentedPong-v0,1,expert,200.0
13
 
test_atari.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ python main_reflexion.py --env_name RepresentedBoxing-v0 --init_summarizer RepresentedBoxing_init_translator --curr_summarizer RepresentedBoxing_basic_translator --decider naive_actor --prompt_level 1 --num_trails 1 --seed 0