File size: 1,172 Bytes
6f3bdf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
from rl_algo_impls.runner.running_utils import base_parser
def enjoy() -> None:
parser = base_parser(multiple=False)
parser.add_argument("--render", default=True, type=bool)
parser.add_argument("--best", default=True, type=bool)
parser.add_argument("--n_envs", default=1, type=int)
parser.add_argument("--n_episodes", default=3, type=int)
parser.add_argument("--deterministic-eval", default=None, type=bool)
parser.add_argument(
"--no-print-returns", action="store_true", help="Limit printing"
)
# wandb-run-path overrides base RunArgs
parser.add_argument("--wandb-run-path", default=None, type=str)
parser.set_defaults(
algo=["ppo"],
wandb_run_path="sgoodfriend/rl-algo-impls/m5c1t7g5",
)
args = parser.parse_args()
args.algo = args.algo[0]
args.env = args.env[0]
args = EvalArgs(**vars(args))
evaluate_model(args, os.getcwd())
if __name__ == "__main__":
enjoy()
|