Spaces:
Runtime error
Runtime error
Commit
·
80a5e58
1
Parent(s):
1f3aec7
fix seed
Browse files
- app.py +22 -11
- video-app/rl-video-episode-0.mp4 +0 -0
- video-app/rl-video-episode-1.meta.json +1 -1
- video-app/rl-video-episode-1.mp4 +0 -0
app.py
CHANGED
@@ -157,6 +157,7 @@ def test_c51(config : dict) -> None:
|
|
157 |
# frame_stack=config["frames_stack"],
|
158 |
# )
|
159 |
env_wrap = gym.make(config["task"],render_mode = 'rgb_array')
|
|
|
160 |
env_deep = wrap_deepmind(env_wrap)
|
161 |
rec_env = DummyVectorEnv(
|
162 |
[
|
@@ -174,8 +175,9 @@ def test_c51(config : dict) -> None:
|
|
174 |
# seed
|
175 |
np.random.seed(config["seed"])
|
176 |
torch.manual_seed(config["seed"])
|
177 |
-
rec_env.seed(config["seed"])
|
178 |
# test_envs.seed(config["seed"])
|
|
|
179 |
|
180 |
|
181 |
net = C51(*state_shape, action_shape, config["num_atoms"], config["device"])
|
@@ -218,8 +220,10 @@ def test_FQF(config : dict) -> None:
|
|
218 |
# frame_stack=config["frames_stack"],
|
219 |
# )
|
220 |
|
221 |
-
|
222 |
-
|
|
|
|
|
223 |
rec_env = DummyVectorEnv(
|
224 |
[
|
225 |
lambda: gym.wrappers.RecordVideo(
|
@@ -235,9 +239,11 @@ def test_FQF(config : dict) -> None:
|
|
235 |
print("Observations shape:", state_shape)
|
236 |
print("Actions shape:", action_shape)
|
237 |
# seed
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
241 |
feature_net = DQN(*state_shape, action_shape, config["device"], features_only=True)
|
242 |
|
243 |
# Create FullQuantileFunction net
|
@@ -297,8 +303,9 @@ def test_fqf_rainbow(config: dict) -> None:
|
|
297 |
# scale=config['scale_obs'],
|
298 |
# frame_stack=config['frames_stack'],
|
299 |
# )
|
300 |
-
|
301 |
-
|
|
|
302 |
rec_env = DummyVectorEnv(
|
303 |
[
|
304 |
lambda: gym.wrappers.RecordVideo(
|
@@ -319,10 +326,11 @@ def test_fqf_rainbow(config: dict) -> None:
|
|
319 |
# print("Observations shape:", config['state_shape'])
|
320 |
# print("Actions shape:", config['action_shape'])
|
321 |
# seed
|
322 |
-
|
323 |
-
|
|
|
324 |
# test_envs.seed(config['seed'])
|
325 |
-
rec_env.seed(config['seed'])
|
326 |
# define model
|
327 |
feature_net = DQN(*config['state_shape'], config['action_shape'], config['device'], features_only=True)
|
328 |
preprocess_net_output_dim = feature_net.output_dim # Ensure this is correctly set
|
@@ -377,6 +385,7 @@ def display_choice(algo, game,slider):
|
|
377 |
# Dictionary to store mean scores for each algorithm and game
|
378 |
match algo:
|
379 |
case "C51":
|
|
|
380 |
match game:
|
381 |
case "Freeway":
|
382 |
config_c51["resume_path"] = "models/c51_freeway.pth"
|
@@ -387,6 +396,7 @@ def display_choice(algo, game,slider):
|
|
387 |
return 19
|
388 |
|
389 |
case "FQF":
|
|
|
390 |
match game:
|
391 |
case "Freeway":
|
392 |
config_fqf["resume_path"] = "models/fqf_freeway.pth"
|
@@ -397,6 +407,7 @@ def display_choice(algo, game,slider):
|
|
397 |
return 20
|
398 |
|
399 |
case "FQF-Rainbow":
|
|
|
400 |
match game:
|
401 |
case "Freeway":
|
402 |
config_fqf_r["resume_path"] = "models/fqf-rainbow_freeway.pth"
|
|
|
157 |
# frame_stack=config["frames_stack"],
|
158 |
# )
|
159 |
env_wrap = gym.make(config["task"],render_mode = 'rgb_array')
|
160 |
+
env_wrap.action_space.seed(config["seed"])
|
161 |
env_deep = wrap_deepmind(env_wrap)
|
162 |
rec_env = DummyVectorEnv(
|
163 |
[
|
|
|
175 |
# seed
|
176 |
np.random.seed(config["seed"])
|
177 |
torch.manual_seed(config["seed"])
|
178 |
+
# rec_env.seed(config["seed"])
|
179 |
# test_envs.seed(config["seed"])
|
180 |
+
print("seed is ",config["seed"])
|
181 |
|
182 |
|
183 |
net = C51(*state_shape, action_shape, config["num_atoms"], config["device"])
|
|
|
220 |
# frame_stack=config["frames_stack"],
|
221 |
# )
|
222 |
|
223 |
+
env_wrap = gym.make(config["task"],render_mode = 'rgb_array')
|
224 |
+
env_wrap.action_space.seed(config["seed"])
|
225 |
+
env_deep = wrap_deepmind(env_wrap)
|
226 |
+
|
227 |
rec_env = DummyVectorEnv(
|
228 |
[
|
229 |
lambda: gym.wrappers.RecordVideo(
|
|
|
239 |
print("Observations shape:", state_shape)
|
240 |
print("Actions shape:", action_shape)
|
241 |
# seed
|
242 |
+
print(config["seed"])
|
243 |
+
# np.random.seed(config["seed"])
|
244 |
+
# torch.manual_seed(config["seed"])
|
245 |
+
# rec_env.seed(config["seed"])
|
246 |
+
|
247 |
feature_net = DQN(*state_shape, action_shape, config["device"], features_only=True)
|
248 |
|
249 |
# Create FullQuantileFunction net
|
|
|
303 |
# scale=config['scale_obs'],
|
304 |
# frame_stack=config['frames_stack'],
|
305 |
# )
|
306 |
+
env_wrap = gym.make(config["task"],render_mode = 'rgb_array')
|
307 |
+
env_wrap.action_space.seed(config["seed"])
|
308 |
+
env_deep = wrap_deepmind(env_wrap)
|
309 |
rec_env = DummyVectorEnv(
|
310 |
[
|
311 |
lambda: gym.wrappers.RecordVideo(
|
|
|
326 |
# print("Observations shape:", config['state_shape'])
|
327 |
# print("Actions shape:", config['action_shape'])
|
328 |
# seed
|
329 |
+
print(config["seed"])
|
330 |
+
# np.random.seed(config['seed'])
|
331 |
+
# torch.manual_seed(config['seed'])
|
332 |
# test_envs.seed(config['seed'])
|
333 |
+
# rec_env.seed(config['seed'])
|
334 |
# define model
|
335 |
feature_net = DQN(*config['state_shape'], config['action_shape'], config['device'], features_only=True)
|
336 |
preprocess_net_output_dim = feature_net.output_dim # Ensure this is correctly set
|
|
|
385 |
# Dictionary to store mean scores for each algorithm and game
|
386 |
match algo:
|
387 |
case "C51":
|
388 |
+
config_c51["seed"] = slider
|
389 |
match game:
|
390 |
case "Freeway":
|
391 |
config_c51["resume_path"] = "models/c51_freeway.pth"
|
|
|
396 |
return 19
|
397 |
|
398 |
case "FQF":
|
399 |
+
config_fqf["seed"] = slider
|
400 |
match game:
|
401 |
case "Freeway":
|
402 |
config_fqf["resume_path"] = "models/fqf_freeway.pth"
|
|
|
407 |
return 20
|
408 |
|
409 |
case "FQF-Rainbow":
|
410 |
+
config_fqf_r["seed"] = slider
|
411 |
match game:
|
412 |
case "Freeway":
|
413 |
config_fqf_r["resume_path"] = "models/fqf-rainbow_freeway.pth"
|
video-app/rl-video-episode-0.mp4
CHANGED
Binary files a/video-app/rl-video-episode-0.mp4 and b/video-app/rl-video-episode-0.mp4 differ
|
|
video-app/rl-video-episode-1.meta.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"step_id":
|
|
|
1 |
+
{"step_id": 2044, "episode_id": 1, "content_type": "video/mp4"}
|
video-app/rl-video-episode-1.mp4
CHANGED
Binary files a/video-app/rl-video-episode-1.mp4 and b/video-app/rl-video-episode-1.mp4 differ
|
|