---
library_name: stable-baselines3
tags:
- Pendulum-v1
- deep-reinforcement-learning
- reinforcement-learning
- stable-baselines3
model-index:
- name: A2C
  results:
  - metrics:
    - type: mean_reward
      value: -141.19 +/- 122.27
      name: mean_reward
    task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: Pendulum-v1
      type: Pendulum-v1
---
# **A2C** Agent playing **Pendulum-v1**
This is a trained model of an **A2C** agent playing **Pendulum-v1**
using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
## Usage (with Stable-Baselines3)
```python
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize
# Download checkpoint and stats
env_id = "Pendulum-v1"
checkpoint = load_from_hub(f"araffin/a2c-{env_id}", f"a2c-{env_id}.zip")
vec_normalize_stats = load_from_hub(f"araffin/a2c-{env_id}", "vec_normalize.pkl")
# Load the model
model = A2C.load(checkpoint)
env = make_vec_env(env_id, n_envs=1)
env = VecNormalize.load(vec_normalize_stats, env)
# Do not update the normalization statistics at test time
env.training = False
# Reward normalization is not needed at test time
env.norm_reward = False
# Evaluate
print("Evaluating model")
mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=20,
    deterministic=True,
)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
# Start a new episode
obs = env.reset()
try:
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        env.render()
except KeyboardInterrupt:
    pass
```
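
The usage snippet sets `env.training = False` because `VecNormalize` keeps running statistics of the observations, and the policy expects inputs on the scale it saw during training. As a rough illustration only, here is a simplified pure-Python stand-in. It is not the Stable-Baselines3 implementation, and the `RunningNorm` class with its attributes is a hypothetical name made up for this sketch. It shows that freezing the statistics leaves them unchanged when an unusual observation arrives at test time:

```python
import math


class RunningNorm:
    """Hypothetical, simplified stand-in for the running observation statistics."""

    def __init__(self, clip=10.0, epsilon=1e-8):
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the mean
        self.count = 0
        self.clip = clip
        self.epsilon = epsilon
        self.training = True

    @property
    def var(self):
        # Population variance; falls back to 1.0 before any data is seen
        return self.m2 / self.count if self.count > 0 else 1.0

    def update(self, x):
        # Welford-style incremental update of mean and variance
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def __call__(self, x):
        # Statistics are only updated while `training` is True
        if self.training:
            self.update(x)
        z = (x - self.mean) / math.sqrt(self.var + self.epsilon)
        return max(-self.clip, min(self.clip, z))


norm = RunningNorm()
for obs in [2.0, 4.0, 6.0]:  # observations seen during "training"
    norm(obs)

norm.training = False  # freeze, analogous to `env.training = False`
frozen_mean = norm.mean
outlier = norm(100.0)  # normalized with the frozen statistics, then clipped
stats_unchanged = norm.mean == frozen_mean
print(stats_unchanged, outlier)
```

If `training` were left at `True`, the outlier would shift the mean and variance, so later observations would be scaled differently from what the policy was trained on.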
## Training Code
```python
from huggingface_sb3 import package_to_hub
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, sync_envs_normalization
# Create the environment
env_id = "Pendulum-v1"
env = make_vec_env(env_id, n_envs=8)
# Normalize
env = VecNormalize(env, gamma=0.9)
# Create the evaluation env (could be used in `EvalCallback`)
eval_env = make_vec_env(env_id, n_envs=1)
eval_env = VecNormalize(eval_env, gamma=0.9, training=False, norm_reward=False)
# Instantiate the agent
model = A2C(
    "MlpPolicy",
    env,
    n_steps=8,
    gamma=0.9,
    gae_lambda=0.9,
    use_sde=True,
    policy_kwargs=dict(log_std_init=-2),
    verbose=1,
)
# Train the agent
try:
    model.learn(total_timesteps=int(1e6))
except KeyboardInterrupt:
    pass
# Synchronize stats (done automatically in `EvalCallback`)
sync_envs_normalization(env, eval_env)
package_to_hub(
    model=model,
    model_name=f"a2c-{env_id}",
    model_architecture="A2C",
    env_id=env_id,
    eval_env=eval_env,
    repo_id=f"araffin/a2c-{env_id}",
    commit_message="Initial commit",
)
```
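
The `gamma=0.9` and `gae_lambda=0.9` arguments in the training code control how strongly future rewards count when advantages are estimated. The following is a minimal standalone sketch of generalized advantage estimation, not the library's own code. The `compute_gae` helper and the sample numbers are made up for illustration, and episode terminations are ignored for brevity:

```python
def compute_gae(rewards, values, last_value, gamma=0.9, gae_lambda=0.9):
    """Return GAE advantages for one rollout (hypothetical helper, no episode ends)."""
    advantages = [0.0] * len(rewards)
    next_value = last_value
    running = 0.0
    # Walk the rollout backwards, accumulating discounted TD errors
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]  # one-step TD error
        running = delta + gamma * gae_lambda * running
        advantages[t] = running
        next_value = values[t]
    return advantages


# Illustrative inputs only: two steps with reward -1 and zero value estimates
adv = compute_gae([-1.0, -1.0], [0.0, 0.0], last_value=0.0)
print(adv)

# With gae_lambda=0 the estimate reduces to the one-step TD error
td_only = compute_gae([1.0, 1.0], [0.5, 0.5], last_value=0.5, gae_lambda=0.0)
print(td_only)
```

With `n_envs=8` and `n_steps=8` as configured above, each update uses 8 × 8 = 64 transitions.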