araffin committed on
Commit 4771560 · 1 Parent(s): af5cc40

Add training/usage code

Files changed (1)
  1. README.md +89 -2
README.md CHANGED
@@ -23,6 +23,93 @@ model-index:
  # **TQC** Agent playing **BipedalWalker-v3**
  This is a trained model of a **TQC** agent playing **BipedalWalker-v3** using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
 
- ## Usage (with Stable-baselines3)
- TODO: Add your code
+ ## Usage (with Stable-baselines3)
+
+ ```python
+ from huggingface_sb3 import load_from_hub
+ from sb3_contrib import TQC
+ from stable_baselines3.common.env_util import make_vec_env
+ from stable_baselines3.common.evaluation import evaluate_policy
+
+ # Download checkpoint
+ checkpoint = load_from_hub("araffin/tqc-BipedalWalker-v3", "tqc-BipedalWalker-v3.zip")
+ # Load the model
+ model = TQC.load(checkpoint)
+
+ env = make_vec_env("BipedalWalker-v3", n_envs=1)
+
+ # Evaluate
+ print("Evaluating model")
+ mean_reward, std_reward = evaluate_policy(
+     model,
+     env,
+     n_eval_episodes=20,
+     deterministic=True,
+ )
+ print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
+
+ # Start a new episode
+ obs = env.reset()
+
+ try:
+     while True:
+         action, _states = model.predict(obs, deterministic=True)
+         obs, rewards, dones, info = env.step(action)
+         env.render()
+ except KeyboardInterrupt:
+     pass
+ ```
+
+ ## Training code (with SB3)
+
+ ```python
+ from sb3_contrib import TQC
+ from stable_baselines3.common.env_util import make_vec_env
+ from stable_baselines3.common.callbacks import EvalCallback
+
+ # Create the environment
+ env_id = "BipedalWalker-v3"
+ n_envs = 2
+ env = make_vec_env(env_id, n_envs=n_envs)
+
+ # Create the evaluation envs
+ eval_envs = make_vec_env(env_id, n_envs=5)
+
+ # Adjust evaluation interval depending on the number of envs
+ eval_freq = int(1e5)
+ eval_freq = max(eval_freq // n_envs, 1)
+
+ # Create evaluation callback to save best model
+ # and monitor agent performance
+ eval_callback = EvalCallback(
+     eval_envs,
+     best_model_save_path="./logs/",
+     eval_freq=eval_freq,
+     n_eval_episodes=10,
+ )
+
+ # Instantiate the agent
+ # Hyperparameters from https://github.com/DLR-RM/rl-baselines3-zoo
+ model = TQC(
+     "MlpPolicy",
+     env,
+     learning_starts=10000,
+     batch_size=256,
+     buffer_size=300000,
+     learning_rate=7.3e-4,
+     use_sde=True,
+     train_freq=8,
+     gradient_steps=8,
+     gamma=0.98,
+     tau=0.02,
+     policy_kwargs=dict(log_std_init=-3, net_arch=[400, 300]),
+     verbose=1,
+ )
+
+ # Train the agent (you can stop it early with Ctrl+C)
+ try:
+     model.learn(total_timesteps=int(5e5), callback=eval_callback, log_interval=10)
+ except KeyboardInterrupt:
+     pass
+ ```
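
Note on the evaluation interval in the training code: the `eval_freq` adjustment can be looked at in isolation. The sketch below assumes that the callback is invoked once per step of the vectorized environment, so each call advances training by `n_envs` timesteps. Dividing the target interval by `n_envs` then keeps evaluations roughly 100,000 timesteps apart. The helper name `calls_between_evals` is hypothetical and is not part of the commit or of Stable-Baselines3.

```python
def calls_between_evals(target_timesteps: int, n_envs: int) -> int:
    """Hypothetical helper: callback calls between two evaluations.

    Assumes one callback call per vectorized step, i.e. n_envs timesteps per call.
    The max(..., 1) guard avoids an interval of zero when n_envs > target_timesteps.
    """
    return max(target_timesteps // n_envs, 1)


n_envs = 2
eval_freq = calls_between_evals(int(1e5), n_envs)

# Timesteps actually elapsed between two evaluations under the assumption above
timesteps_between_evals = eval_freq * n_envs
print(eval_freq, timesteps_between_evals)
```

With `n_envs = 2` this gives the same `eval_freq` value as the two assignment lines in the training code.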