PPO playing AntBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
7de7cd2
import os | |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" | |
import argparse | |
import requests | |
import shutil | |
import subprocess | |
import tempfile | |
import wandb | |
import wandb.apis.public | |
from typing import List, Optional | |
from huggingface_hub.hf_api import HfApi, upload_folder | |
from huggingface_hub.repocard import metadata_save | |
from pyvirtualdisplay.display import Display | |
from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text | |
from rl_algo_impls.runner.config import EnvHyperparams | |
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model | |
from rl_algo_impls.runner.env import make_eval_env | |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate | |
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder | |
def _hub_metadata(algo: str, env_id: str, score: object) -> dict:
    """Build the Hugging Face model-card metadata (tags + model-index).

    ``score`` is stringified as-is into the ``mean_reward`` metric value.
    """
    return {
        "library_name": "rl-algo-impls",
        "tags": [
            env_id,
            algo,
            "deep-reinforcement-learning",
            "reinforcement-learning",
        ],
        "model-index": [
            {
                "name": algo,
                "results": [
                    {
                        "metrics": [
                            {
                                "type": "mean_reward",
                                "value": str(score),
                                "name": "mean_reward",
                            }
                        ],
                        "task": {
                            "type": "reinforcement-learning",
                            "name": "reinforcement-learning",
                        },
                        "dataset": {
                            "name": env_id,
                            "type": env_id,
                        },
                    }
                ],
            }
        ],
    }


# Evaluate the runs, assemble the repo contents in a temp dir, and push to the Hub.
def _publish_runs(
    wandb_run_paths: List[str],
    wandb_report_url: str,
    huggingface_user: Optional[str],
    huggingface_token: Optional[str],
) -> None:
    wandb_api = wandb.Api()
    runs = [wandb_api.run(rp) for rp in wandb_run_paths]
    # Algo and hyperparameter id are taken from the first run only.
    algo = runs[0].config["algo"]
    hyperparam_id = runs[0].config["env"]

    evaluations = [
        evaluate_model(
            EvalArgs(
                algo,
                hyperparam_id,
                seed=r.config.get("seed", None),
                render=False,
                best=True,
                n_envs=None,
                n_episodes=10,
                no_print_returns=True,
                wandb_run_path="/".join(r.path),
            ),
            os.getcwd(),
        )
        for r in runs
    ]
    run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
    table_data = [EvalTableData(r, e) for r, e in zip(runs, evaluations)]
    # max() returns the first of equally-scored entries, exactly like the head
    # of a stable descending sort.
    best_eval = max(table_data, key=lambda d: d.evaluation.stats.score)

    with tempfile.TemporaryDirectory() as tmpdirname:
        _, (policy, stats, config) = best_eval
        repo_name = config.model_name(include_seed=False)
        repo_dir_path = os.path.join(tmpdirname, repo_name)
        # Locally clone this repo to a temp directory. check=True makes a failed
        # clone raise here instead of surfacing later as a confusing rmtree error.
        subprocess.run(["git", "clone", ".", repo_dir_path], check=True)
        # Publish a plain snapshot of the sources, without git history.
        shutil.rmtree(os.path.join(repo_dir_path, ".git"))
        model_path = config.model_dir_path(best=True, downloaded=True)
        shutil.copytree(
            model_path,
            os.path.join(
                repo_dir_path, "saved_models", config.model_dir_name(best=True)
            ),
        )

        github_url = "https://github.com/sgoodfriend/rl-algo-impls"
        commit_hash = run_metadata.get("git", {}).get("commit", None)
        env_id = runs[0].config.get("env_id") or runs[0].config["env"]
        card_text = model_card_text(
            algo,
            env_id,
            github_url,
            commit_hash,
            wandb_report_url,
            table_data,
            best_eval,
        )
        readme_filepath = os.path.join(repo_dir_path, "README.md")
        # Mode "w" truncates any cloned README, so no separate delete is needed
        # (and a missing README no longer crashes). Write UTF-8 explicitly so the
        # card text does not depend on the platform's default encoding.
        with open(readme_filepath, "w", encoding="utf-8") as f:
            f.write(card_text)
        metadata_save(readme_filepath, _hub_metadata(algo, env_id, stats.score))

        # Record one evaluation episode of the best policy into <repo>/replay.
        video_env = VecEpisodeRecorder(
            make_eval_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                override_n_envs=1,
                normalize_load_path=model_path,
            ),
            os.path.join(repo_dir_path, "replay"),
            max_video_length=3600,
        )
        evaluate(
            video_env,
            policy,
            1,
            deterministic=config.eval_params.get("deterministic", True),
        )

        hf_api = HfApi()
        # Pass the token so an explicitly supplied token is also used to
        # resolve the default user, not just for create/upload.
        huggingface_user = (
            huggingface_user or hf_api.whoami(token=huggingface_token)["name"]
        )
        huggingface_repo = f"{huggingface_user}/{repo_name}"
        hf_api.create_repo(
            token=huggingface_token,
            repo_id=huggingface_repo,
            private=False,
            exist_ok=True,
        )
        # Upload must happen inside the TemporaryDirectory context: the folder
        # is deleted when the context exits.
        repo_url = upload_folder(
            repo_id=huggingface_repo,
            folder_path=repo_dir_path,
            path_in_repo="",
            commit_message=(
                f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}"
            ),
            token=huggingface_token,
            delete_patterns="*",
        )
        print(f"Pushed model to the hub: {repo_url}")


def publish(
    wandb_run_paths: List[str],
    wandb_report_url: str,
    huggingface_user: Optional[str] = None,
    huggingface_token: Optional[str] = None,
    virtual_display: bool = False,
) -> None:
    """Evaluate WandB runs and publish the best one as a Hugging Face model repo.

    Each run is re-evaluated on its best checkpoint for 10 episodes. For the
    highest-scoring run, the following are uploaded to
    ``{huggingface_user}/{repo_name}``:

    - a snapshot of the current git checkout;
    - the saved model;
    - a generated README model card with metadata;
    - a replay video.

    Args:
        wandb_run_paths: WandB run paths (``entity/project/run_id``). The algo
            and env hyperparameter id are read from the first run.
        wandb_report_url: Link to the WandB report embedded in the model card.
        huggingface_user: Hub user or org to publish under. Defaults to the
            account returned by ``whoami`` for ``huggingface_token``.
        huggingface_token: Hub token. ``None`` falls back to huggingface_hub's
            default token resolution.
        virtual_display: Start a headless virtual display (1400x900) while
            evaluating and recording.

    Raises:
        ValueError: If ``wandb_run_paths`` is empty or None.
        subprocess.CalledProcessError: If cloning the current repo fails.
    """
    if not wandb_run_paths:
        raise ValueError("wandb_run_paths must contain at least one run path")

    display = None
    if virtual_display:
        display = Display(visible=False, size=(1400, 900))
        display.start()
    try:
        _publish_runs(
            wandb_run_paths,
            wandb_report_url,
            huggingface_user,
            huggingface_token,
        )
    finally:
        # Always stop the display we started, even if publishing fails, so it
        # is not left running after this function returns.
        if display is not None:
            display.stop()
def huggingface_publish():
    """Command-line entry point: parse arguments and forward them to ``publish``.

    The argparse destinations (``wandb_run_paths``, ``wandb_report_url``,
    ``huggingface_user``, ``virtual_display``) match ``publish``'s keyword
    parameters. ``huggingface_token`` is not exposed, so ``publish`` uses its
    default of ``None``.

    ``--wandb-run-paths`` is required because ``publish`` cannot do anything
    without at least one run. Omitting it makes argparse exit with a usage
    error (status 2) instead of passing ``None`` through.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-run-paths",
        type=str,
        nargs="+",
        required=True,
        help="Run paths of the form entity/project/run_id",
    )
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--huggingface-user",
        type=str,
        help="Huggingface user or team to upload model cards",
        default=None,
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    args = parser.parse_args()
    print(args)
    publish(**vars(args))
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    huggingface_publish()