File size: 2,771 Bytes
6d1ad4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import subprocess
import sys
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import List, NamedTuple

import wandb
import wandb.apis.public


class RunGroup(NamedTuple):
    """Immutable (algorithm, environment) key used to bucket W&B runs.

    Every run path collected under the same key is handed to a single
    huggingface_publish.py invocation.
    """

    algo: str  # taken from the run's "algo" config entry
    env_id: str  # taken from the run's "env" config entry


if __name__ == "__main__":
    # Publish benchmark runs to Hugging Face:
    # - select every W&B run in the project whose "wandb_tags" config contains
    #   all of --wandb-tags;
    # - group the runs by (algo, env);
    # - invoke huggingface_publish.py once per group, with up to --pool-size
    #   invocations running concurrently.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-benchmarks",
        help="WandB project name to load runs from",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--envs", type=str, nargs="*", help="Optional filter down to these envs"
    )
    parser.add_argument(
        "--huggingface-user",
        type=str,
        default=None,
        help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
    )
    parser.add_argument(
        "--pool-size",
        type=int,
        default=3,
        help="How many publish jobs can run in parallel",
    )
    # Uncomment to hard-code a benchmark's tags and report link instead of
    # passing them on the command line.
    # parser.set_defaults(
    #     wandb_tags=["benchmark_5598ebc", "host_192-9-145-26"],
    #     wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
    # )
    args = parser.parse_args()
    # Validated here instead of via required=True so that uncommenting the
    # set_defaults() block above still satisfies it (argparse ignores defaults
    # for required options). Without tags, set(None) would raise TypeError.
    if not args.wandb_tags:
        parser.error("--wandb-tags is required (at least one tag)")
    print(args)

    api = wandb.Api()
    all_runs = api.runs(
        f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
    )

    required_tags = set(args.wandb_tags)
    runs_paths_by_group = defaultdict(list)
    for r in all_runs:
        # Only runs carrying every requested tag are considered.
        if not required_tags.issubset(r.config.get("wandb_tags", [])):
            continue
        algo = r.config["algo"]
        env = r.config["env"]
        # An absent or empty --envs means "no env filter".
        if args.envs and env not in args.envs:
            continue
        runs_paths_by_group[RunGroup(algo, env)].append("/".join(r.path))

    if not runs_paths_by_group:
        print(f"No runs found with tags {sorted(required_tags)}")

    def publish(run_paths: List[str]) -> int:
        """Run huggingface_publish.py for one run group.

        Args:
            run_paths: W&B run paths (each run's ``path`` joined with "/")
                belonging to a single (algo, env) group.

        Returns:
            The exit code of the huggingface_publish.py subprocess.
        """
        publish_args = [
            # Use the interpreter running this script rather than whatever
            # "python" happens to resolve to on PATH.
            sys.executable or "python",
            "huggingface_publish.py",
            "--wandb-run-paths",
            *run_paths,
        ]
        # Forward optional flags only when set: a None entry in the argument
        # list makes subprocess.run raise TypeError.
        if args.wandb_report_url:
            publish_args += ["--wandb-report-url", args.wandb_report_url]
        if args.huggingface_user:
            publish_args += ["--huggingface-user", args.huggingface_user]
        return subprocess.run(publish_args).returncode

    tp = ThreadPool(args.pool_size)
    results = {
        group: tp.apply_async(publish, (run_paths,))
        for group, run_paths in runs_paths_by_group.items()
    }
    tp.close()
    tp.join()

    # Inspect every result: an AsyncResult that is never .get()-ed silently
    # swallows any exception raised in its worker thread.
    failed_groups = []
    for group, result in results.items():
        try:
            return_code = result.get()
        except Exception as e:
            print(f"{group}: publish failed with {e!r}", file=sys.stderr)
            failed_groups.append(group)
            continue
        if return_code != 0:
            print(
                f"{group}: huggingface_publish.py exited with code {return_code}",
                file=sys.stderr,
            )
            failed_groups.append(group)
    if failed_groups:
        sys.exit(1)