File size: 2,114 Bytes
76ee962
 
 
 
 
 
 
 
5ee99e9
 
76ee962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from multiprocessing import Pool

from rl_algo_impls.runner.running_utils import base_parser
from rl_algo_impls.runner.train import TrainArgs
from rl_algo_impls.runner.train import train as runner_train


def train() -> None:
    """Parse command-line arguments and run one or more training jobs.

    Extends ``base_parser()`` with WandB, pool-size and virtual-display
    options. It then expands the parsed arguments into a list of ``TrainArgs``
    via ``TrainArgs.expand_from_dict`` and trains each one with
    ``runner_train``:

    * exactly one job runs in the current process;
    * otherwise the jobs run in a ``multiprocessing.Pool`` with one task per
      worker process.

    If ``--virtual-display`` is passed, a headless ``pyvirtualdisplay``
    display is started before training and stopped afterwards, even if
    training raises.
    """
    parser = base_parser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls",
        help="WandB project name to upload training data to. If none, won't upload.",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    parser.add_argument(
        "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
    )
    parser.add_argument(
        "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    # Uncomment to hard-code defaults when debugging without CLI arguments.
    # parser.set_defaults(
    #     algo=["ppo"],
    #     env=["CartPole-v1"],
    #     seed=[10],
    #     pool_size=3,
    # )
    args = parser.parse_args()
    print(args)

    # virtual_display and pool_size aren't TrainArgs fields, so they must be
    # removed from args before expanding it into TrainArgs.
    use_virtual_display = args.virtual_display
    delattr(args, "virtual_display")
    requested_pool_size = args.pool_size
    delattr(args, "pool_size")

    virtual_display = None
    if use_virtual_display:
        from pyvirtualdisplay.display import Display

        virtual_display = Display(visible=False, size=(1400, 900))
        virtual_display.start()

    try:
        train_args = TrainArgs.expand_from_dict(vars(args))
        if len(train_args) == 1:
            runner_train(train_args[0])
        else:
            # Bound the pool by the number of jobs actually produced rather
            # than the number of seeds. The expansion can yield more jobs than
            # seeds (presumably one per algo/env/seed combination — confirm in
            # TrainArgs.expand_from_dict), and an empty --seed list would
            # otherwise give Pool(0), which raises ValueError.
            pool_size = max(1, min(requested_pool_size, len(train_args)))
            # Force a new process for each job to get around wandb not allowing
            # more than one wandb.tensorboard.patch call per process.
            with Pool(pool_size, maxtasksperchild=1) as p:
                p.map(runner_train, train_args)
    finally:
        # Stop the headless display so it doesn't outlive training.
        if virtual_display is not None:
            virtual_display.stop()


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and launch training.
    train()