File size: 4,046 Bytes
246c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from sim.robomimic.robomimic_runner import RolloutRunner
from sim.policy import GeniePolicy
import argparse
from datetime import datetime

# Timestamp captured once at start-up; it tags the row appended to success.csv.
current_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def build_model_suffix(model: str, is_full_dynamics: bool, execution_horizon: int) -> str:
    """Return the tag used to label rollout videos and CSV rows for a checkpoint.

    The tag has the form ``<mode>_<parent dir>_<checkpoint dir>_horizon<N>``,
    where ``<mode>`` is ``dynamics`` or ``actiononly``.

    Args:
        model: Checkpoint path written with ``/`` separators.
        is_full_dynamics: Selects the ``dynamics`` prefix when true,
            ``actiononly`` otherwise.
        execution_horizon: Value appended after ``horizon``.

    Returns:
        The tag string.
    """
    # Empty components are dropped so a leading/trailing slash does not leave
    # a blank field in the tag, and at most the last two components are used
    # so a single-component path no longer raises IndexError.
    components = [part for part in model.split("/") if part]
    mode = "dynamics" if is_full_dynamics else "actiononly"
    return "_".join([mode, *components[-2:], f"horizon{execution_horizon}"])


def main() -> None:
    """Roll out one checkpoint in a robomimic environment and log the result.

    Parses the command line, builds the rollout runner and the policy, runs
    the rollout, prints the success/reward and appends one row to
    ``success.csv`` in the current working directory.
    """
    parser = argparse.ArgumentParser(description="policy to evaluate")

    # Environment, checkpoint and model-variant selection
    parser.add_argument("--env_name", type=str, default="lift")
    parser.add_argument("--num_runs", type=int, default=1)
    parser.add_argument("--save_video", action="store_true")
    parser.add_argument("--model", type=str, default="data/mar_policy_dynamics/step_30000")
    parser.add_argument("--use_magvit", action="store_true")
    parser.add_argument("--is_full_dynamics", action="store_true")
    parser.add_argument("--use_raw_image", action="store_true")

    # Inference settings
    parser.add_argument("--execution_horizon", type=int, default=4)
    parser.add_argument("--diffusion_steps", type=int, default=100)
    parser.add_argument("--inference_iterations", type=int, default=1)
    parser.add_argument("--prompt_horizon", type=int, default=1)

    args = parser.parse_args()

    env_name = args.env_name
    execution_horizon = args.execution_horizon
    diffusion_steps = args.diffusion_steps
    inference_iterations = args.inference_iterations
    is_full_dynamics = args.is_full_dynamics
    model = args.model

    rollout_runner = RolloutRunner(
        env_names=[env_name],
        episode_num=args.num_runs,
        save_video=args.save_video,
    )

    # Checkpoint paths that used to be hard-coded here (now passed via --model).
    # With --is_full_dynamics:
    #   data/mar_policy_dynamics2/final2_robomimic_scratch_mar_forward_dynamics_gpu_8_nodes_2_16g/step_50000
    #   data/final2_robomimic_scratch_mar_full_dynamics_new_gpu_8_nodes_4_16g/step_10000
    #   data/final2_robomimic_scratch_mar_dynamics_fullpastmask_new_gpu_8_nodes_4_16g/step_10000
    #   data/final2_robomimic_scratch_mar_full_dynamics_fixed_new_gpu_8_nodes_4_16g/step_20000
    # Without it:
    #   data/mar_policy2/final2_robomimic_scratch_mar_actiononly_gpu_8_nodes_4_16g/final_checkpt
    #   data/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/step_10000
    #   data/final2_robomimic_scratch_mar_fullpastmask_actiononly_fixed_gpu_8_nodes_4_16g/step_10000
    #   data/mar_policy_actiononly3/step_10000
    model_suffix = build_model_suffix(model, is_full_dynamics, execution_horizon)

    # --use_magvit switches the image encoder, its checkpoint, the backbone
    # and quantization together.
    if args.use_magvit:
        image_encoder_type = "magvit"
        image_encoder_ckpt = "data/magvit2.ckpt"
        backbone_type = "stmaskgit"
    else:
        image_encoder_type = "temporalvae"
        image_encoder_ckpt = "stabilityai/stable-video-diffusion-img2vid"
        backbone_type = "stmar"

    policy = GeniePolicy(
        image_encoder_type=image_encoder_type,
        image_encoder_ckpt=image_encoder_ckpt,
        quantize=args.use_magvit,
        backbone_type=backbone_type,
        backbone_ckpt=model,
        prompt_horizon=args.prompt_horizon,  # history step
        prediction_horizon=execution_horizon,  # future step
        execution_horizon=execution_horizon,  # open loop step
        inference_iterations=inference_iterations,  # maskgit step
        diffusion_steps=diffusion_steps,  # diffusion steps
        action_stride=1,
        domain="robomimic",
        is_full_dynamics=is_full_dynamics,
        use_raw_image=args.use_raw_image,
    )

    # Run the rollouts; the suffix is passed on as the video postfix.
    success, reward = rollout_runner.run(policy=policy, env_name=[env_name], video_postfix=model_suffix)
    print(f"success: {success}, reward: {reward}")

    # Append the result, tagged with the model suffix and settings, to the CSV log.
    with open("success.csv", "a+") as f:
        f.write(f"{model_suffix}, {success}, {reward}, {execution_horizon}, {diffusion_steps}, {inference_iterations}, {args.prompt_horizon}, {args.num_runs}, {current_date}\n")


if __name__ == "__main__":
    main()