# hma/sim/evaluate_policy_gt_sim.py
# Author: LeroyWaa; commit 246c106 ("draft"); 4.05 kB.
import argparse
import csv
from datetime import datetime
from pathlib import Path
from typing import Optional, Sequence

from sim.policy import GeniePolicy
from sim.robomimic.robomimic_runner import RolloutRunner
current_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if __name__ == "__main__":
# initialize environment
parser = argparse.ArgumentParser(description="policy to evaluate")
# Data
parser.add_argument( "--env_name", type=str, default="lift")
parser.add_argument( "--num_runs", type=int, default=1)
parser.add_argument( "--save_video", action="store_true")
parser.add_argument( "--model", type=str, default="data/mar_policy_dynamics/step_30000")
parser.add_argument( "--use_magvit", action="store_true")
parser.add_argument( "--is_full_dynamics", action="store_true")
parser.add_argument( "--use_raw_image", action="store_true")
parser.add_argument( "--execution_horizon", type=int, default=4)
parser.add_argument( "--diffusion_steps", type=int, default=100)
parser.add_argument( "--inference_iterations", type=int, default=1)
parser.add_argument( "--prompt_horizon", type=int, default=1)
args = parser.parse_args()
env_name = args.env_name
rollout_runner = RolloutRunner( env_names=[env_name], episode_num=args.num_runs, save_video=args.save_video)
execution_horizon = args.execution_horizon
diffusion_steps = args.diffusion_steps
inference_iterations = args.inference_iterations
prompt_horizon = args.prompt_horizon
is_full_dynamics = args.is_full_dynamics
model = args.model
if is_full_dynamics:
# model = "data/mar_policy_dynamics2/final2_robomimic_scratch_mar_forward_dynamics_gpu_8_nodes_2_16g/step_50000"
# model = "data/final2_robomimic_scratch_mar_full_dynamics_new_gpu_8_nodes_4_16g/step_10000"
# model = "data/final2_robomimic_scratch_mar_dynamics_fullpastmask_new_gpu_8_nodes_4_16g/step_10000"
# model = "data/final2_robomimic_scratch_mar_full_dynamics_fixed_new_gpu_8_nodes_4_16g/step_20000"
model_suffix = f"dynamics_{model.split('/')[-2]}_{model.split('/')[-1]}_horizon{execution_horizon}"
else:
# model = "data/mar_policy2/final2_robomimic_scratch_mar_actiononly_gpu_8_nodes_4_16g/final_checkpt"
# model = "data/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/final2_robomimic_scratch_mar_actiononly_new_gpu_8_nodes_4_16g/step_10000"
# model = "data/final2_robomimic_scratch_mar_fullpastmask_actiononly_fixed_gpu_8_nodes_4_16g/step_10000"
# model = "data/final2_robomimic_scratch_mar_fullpastmask_actiononly_fixed_gpu_8_nodes_4_16g/step_10000"
# model = "data/mar_policy_actiononly3/step_10000"
model_suffix = f"actiononly_{model.split('/')[-2]}_{model.split('/')[-1]}_horizon{execution_horizon}"
policy = GeniePolicy(
image_encoder_type="temporalvae" if not args.use_magvit else "magvit",
image_encoder_ckpt="stabilityai/stable-video-diffusion-img2vid" if not args.use_magvit else "data/magvit2.ckpt",
quantize=False if not args.use_magvit else True,
backbone_type="stmar" if not args.use_magvit else "stmaskgit",
backbone_ckpt=model,
prompt_horizon=prompt_horizon, # history step
prediction_horizon=execution_horizon, # future step
execution_horizon=execution_horizon, # open loop step
inference_iterations=inference_iterations, # maskgit step
diffusion_steps=diffusion_steps, # diffusion steps
action_stride=1,
domain="robomimic",
is_full_dynamics=is_full_dynamics,
use_raw_image=args.use_raw_image,
)
# initialize policy
success, reward = rollout_runner.run(policy=policy, env_name=[env_name], video_postfix=model_suffix)
print(f"success: {success}, reward: {reward}")
# dump the success with model name to csv
with open("success.csv", "a+") as f:
f.write(f"{model_suffix}, {success}, {reward}, {execution_horizon}, {diffusion_steps}, {inference_iterations}, {args.prompt_horizon}, {args.num_runs}, {current_date}\n")