masterdezign
commited on
Commit
•
962a215
1
Parent(s):
c927860
Training on 300 000 timesteps
Browse files
README.md
CHANGED
@@ -10,7 +10,7 @@ model-index:
|
|
10 |
results:
|
11 |
- metrics:
|
12 |
- type: mean_reward
|
13 |
-
value:
|
14 |
name: mean_reward
|
15 |
task:
|
16 |
type: reinforcement-learning
|
@@ -25,48 +25,12 @@ This is a trained model of a **DQN** agent playing **SpaceInvadersNoFrameskip-v4
|
|
25 |
using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
|
26 |
|
27 |
## Usage (with Stable-baselines3)
|
|
|
28 |
|
29 |
-
```python
|
30 |
-
from stable_baselines3.common.env_util import make_atari_env
|
31 |
-
from stable_baselines3.common.vec_env import VecFrameStack
|
32 |
-
from stable_baselines3 import DQN
|
33 |
-
from stable_baselines3.common.evaluation import evaluate_policy
|
34 |
-
from huggingface_sb3 import load_from_hub, package_to_hub
|
35 |
-
from stable_baselines3.common.utils import set_random_seed
|
36 |
-
|
37 |
-
env_id = "SpaceInvadersNoFrameskip-v4"
|
38 |
-
|
39 |
-
env = make_atari_env(env_id,
|
40 |
-
n_envs=12,
|
41 |
-
# Improving reproducibility
|
42 |
-
seed=1)
|
43 |
-
env = VecFrameStack(env, n_stack=4) # Stack last four images
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
# Using these parameters as default: https://huggingface.co/micheljperez/dqn-SpaceInvadersNoFrameskip-v4
|
49 |
-
model = DQN(policy = "CnnPolicy",
|
50 |
-
env = env,
|
51 |
-
batch_size = 32,
|
52 |
-
buffer_size = 100_000,
|
53 |
-
exploration_final_eps = 0.01,
|
54 |
-
exploration_fraction = 0.025,
|
55 |
-
gradient_steps = 1,
|
56 |
-
learning_rate = 1e-4,
|
57 |
-
learning_starts = 100_000,
|
58 |
-
optimize_memory_usage = True,
|
59 |
-
replay_buffer_kwargs = {"handle_timeout_termination": False},
|
60 |
-
target_update_interval = 1000,
|
61 |
-
train_freq = 4,
|
62 |
-
# normalize = False,
|
63 |
-
tensorboard_log = "./tensorboard",
|
64 |
-
verbose=1
|
65 |
-
)
|
66 |
-
|
67 |
-
f = load_from_hub('masterdezign/dqn-SpaceInvadersNoFrameskip-v4', 'dqn-SpaceInvadersNoFrameskip-v4.zip')
|
68 |
-
model = model.load(f)
|
69 |
|
70 |
-
|
71 |
-
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
|
72 |
```
|
|
|
10 |
results:
|
11 |
- metrics:
|
12 |
- type: mean_reward
|
13 |
+
value: 271.50 +/- 80.19
|
14 |
name: mean_reward
|
15 |
task:
|
16 |
type: reinforcement-learning
|
|
|
25 |
using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
|
26 |
|
27 |
## Usage (with Stable-baselines3)
|
28 |
+
TODO: Add your code
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
```python
|
32 |
+
from stable_baselines3 import ...
|
33 |
+
from huggingface_sb3 import load_from_hub
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
+
...
|
|
|
36 |
```
|
config.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
dqn-SpaceInvadersNoFrameskip-v4.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d6d3956633c018a24698b9d8c6e8bcb555e97a1beb3137959274b3b44a23605
|
3 |
+
size 28089676
|
dqn-SpaceInvadersNoFrameskip-v4/data
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
dqn-SpaceInvadersNoFrameskip-v4/policy.optimizer.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 13505611
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:511a074585123d95ebe155c2a0b1194250d977ca7d718e85f7041c42092644d1
|
3 |
size 13505611
|
dqn-SpaceInvadersNoFrameskip-v4/policy.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 13504937
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:71b21eb240e1accfdeef2707a95e210c322bc20636161c1f02f983684168111a
|
3 |
size 13504937
|
results.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"mean_reward":
|
|
|
1 |
+
{"mean_reward": 271.5, "std_reward": 80.18883962248113, "is_deterministic": false, "n_eval_episodes": 10, "eval_datetime": "2022-07-20T14:35:46.400659"}
|