Andrei Cozma committed
Commit efbb9e7 · 1 Parent(s): 4441b2d
Files changed (2):
  1. MonteCarloAgent.py +8 -7
  2. requirements.txt +1 -1
MonteCarloAgent.py CHANGED
@@ -27,6 +27,8 @@ class MonteCarloAgent:

         self.env_kwargs = kwargs
         if self.env_name == "FrozenLake-v1":
+            # Can use the defaults by defining map_name (4x4 or 8x8), or a custom map by defining desc
+            # self.env_kwargs["map_name"] = "8x8"
             self.env_kwargs["desc"] = [
                 "SFFFFFFF",
                 "FFFFFFFH",
@@ -37,7 +39,6 @@ class MonteCarloAgent:
                 "FHFFHFHF",
                 "FFFHFFFG",
             ]
-            # self.env_kwargs["map_name"] = "8x8"
             self.env_kwargs["is_slippery"] = False

         self.env = gym.make(self.env_name, **self.env_kwargs)
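
For context, a minimal sketch (assuming the gymnasium 0.28.x toy-text API) of the two configuration routes the new comment describes, using a small illustrative layout rather than the repo's 8x8 map:

import gymnasium as gym

# Route 1: a built-in layout selected by name
env_default = gym.make("FrozenLake-v1", map_name="8x8", is_slippery=False)

# Route 2: a custom layout passed as desc
# (S = start, F = frozen, H = hole, G = goal)
env_custom = gym.make(
    "FrozenLake-v1",
    desc=["SFF", "FHF", "FFG"],
    is_slippery=False,
)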
@@ -72,7 +73,7 @@ class MonteCarloAgent:

     def choose_action(self, state, epsilon_override=None, greedy=False, **kwargs):
         # Sample an action from the policy.
-        # The override_epsilon argument allows forcing the use of a possibly new self.epsilon value than the one used during training.
+        # The epsilon_override argument allows forcing the use of an epsilon value other than the one used during training.
         # The ability to override was mostly added for testing purposes and for the demo.
         greedy_action = np.argmax(self.Pi[state])
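
The rest of the method body is not shown in this hunk; below is a hypothetical sketch of how such an override is typically applied in epsilon-greedy selection. Everything beyond the signature above (the uniform exploration, the fallback to a stored epsilon) is an assumption, written here as a standalone function so it runs on its own:

import numpy as np

def choose_action(Pi, state, epsilon, epsilon_override=None, greedy=False):
    # Pi[state] is this sketch's stand-in for the agent's policy table
    greedy_action = int(np.argmax(Pi[state]))
    if greedy:
        return greedy_action  # pure exploitation, no randomness
    # Fall back to the training-time epsilon unless an override is given
    eps = epsilon if epsilon_override is None else epsilon_override
    if np.random.random() < eps:
        return int(np.random.randint(len(Pi[state])))  # explore uniformly
    return greedy_action

# Example: the greedy action is 2; epsilon_override=1.0 forces a random pick
Pi = np.array([[0.1, 0.1, 0.8]])
print(choose_action(Pi, 0, epsilon=0.1))
print(choose_action(Pi, 0, epsilon=0.1, epsilon_override=1.0))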
@@ -112,11 +113,11 @@ class MonteCarloAgent:
             episode_hist.append((state, action, reward))
             yield episode_hist, solved, rgb_array

+            # Rendering new frame if needed
             rgb_array = self.env.render() if render else None
+
             # For CliffWalking-v0 and Taxi-v3, the episode is solved when it terminates
-            if done and (
-                self.env_name == "CliffWalking-v0" or self.env_name == "Taxi-v3"
-            ):
+            if done and self.env_name in ["CliffWalking-v0", "Taxi-v3"]:
                 solved = True
                 break
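
Because the loop yields (episode_hist, solved, rgb_array) after every step, a caller can stream frames while the episode runs. A hypothetical consumer sketch follows; the enclosing generator's name is not visible in this hunk, so generate_episode is an assumed name and agent an already-constructed MonteCarloAgent:

def replay(agent, max_steps=500):
    episode_hist, solved = [], False
    # Assumed generator name; yields after every environment step
    for episode_hist, solved, rgb_array in agent.generate_episode(max_steps=max_steps, render=True):
        if rgb_array is not None:
            pass  # e.g. display or encode the current frame
    print(f"steps: {len(episode_hist)}, solved: {solved}")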
@@ -125,9 +126,10 @@ class MonteCarloAgent:
             if done and self.env_name == "FrozenLake-v1":
                 if next_state == self.env.nrow * self.env.ncol - 1:
                     solved = True
-                    # print("Solved!")
                     break
                 else:
+                    # Instead of terminating the episode when the agent moves into a hole, we reset the environment
+                    # to keep the behavior consistent with the other environments
                     done = False
                     next_state, _ = self.env.reset()
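
A note on the goal test above: FrozenLake flattens the grid position (row, col) into a single state index row * ncol + col, so the goal in the bottom-right corner of the map is exactly nrow * ncol - 1. For the 8x8 map defined earlier:

nrow, ncol = 8, 8
goal_state = (nrow - 1) * ncol + (ncol - 1)  # bottom-right cell
assert goal_state == nrow * ncol - 1 == 63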
@@ -137,7 +139,6 @@ class MonteCarloAgent:
             state = next_state

         rgb_array = self.env.render() if render else None
-
         yield episode_hist, solved, rgb_array

     def run_episode(self, max_steps=500, render=False, **kwargs):
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 gradio==3.27.0
-Gymnasium==0.26.3
+gymnasium[toy_text]==0.28.1
 numpy==1.21.5
 opencv_python_headless==4.6.0.66
 pip==22.0.2
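
The toy_text extra pulls in pygame, which the toy-text environments used here (FrozenLake-v1, CliffWalking-v0, Taxi-v3) need for rendering. A quick post-install sanity check (hypothetical snippet, not part of the repo):

import gymnasium as gym

assert gym.__version__ == "0.28.1"
env = gym.make("FrozenLake-v1", render_mode="rgb_array")
env.reset()
frame = env.render()  # numpy array of shape (H, W, 3)
print(frame.shape)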