hma / sim /example /genie_langtable_random.py
LeroyWaa's picture
draft
246c106
raw
history blame
2.42 kB
import functools

import imageio
import numpy as np

from sim.main import InteractiveDigitalWorld
from sim.policy import RandomPlanarQuadDirectionalPolicy
from sim.simulator import GenieSimulator
# Asset and output locations, relative to the working directory the script
# is launched from.
ARROW_ASSET_PATH = 'sim/assets/arrow.jpg'
PROMPT_IMAGE_PATH = 'sim/assets/langtable_prompt.png'
OUTPUT_VIDEO_PATH = 'test.mp4'

# Number of policy/simulator steps rolled out before the video is written.
NUM_ROLLOUT_STEPS = 50


@functools.lru_cache(maxsize=1)
def _load_arrow_image() -> np.ndarray:
    """Read the arrow overlay asset from disk once and reuse it afterwards.

    The returned array is shared between calls, so callers must treat it as
    read-only.
    """
    return imageio.imread(ARROW_ASSET_PATH)


def draw_action_arrow_to_image(
    image: np.ndarray,
    action: np.ndarray,
    arrow_image: "np.ndarray | None" = None,
) -> np.ndarray:
    """Stamp an arrow showing the action's direction onto the top-left corner.

    Args:
        image: Frame to annotate. It is modified in place.
        action: Array whose leading axis is the `stride` dimension; only
            ``action[0]`` is used, and it must hold two components of which
            at most one is non-zero.
        arrow_image: Optional arrow overlay. When omitted, the asset at
            ``ARROW_ASSET_PATH`` is loaded (once) and used. The overlay is
            used unrotated for `w`, so it is presumably an upward-pointing
            arrow -- confirm against the asset.

    Returns:
        The same ``image`` object, with the oriented arrow written into its
        top-left corner. An arrow larger than the frame is cropped to fit.

    Raises:
        ValueError: If both components of ``action[0]`` are non-zero.
    """
    step = action[0]  # remove `stride` dimension

    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if step[0] * step[1] != 0:
        raise ValueError(
            f"expected an axis-aligned action with at most one non-zero "
            f"component, got {step!r}"
        )

    if arrow_image is None:
        arrow_image = _load_arrow_image()

    # flipud / rot90 return views, so the shared overlay is never modified.
    if step[0] > 0:  # `s`
        arrow = np.flipud(arrow_image)
    elif step[1] < 0:  # `a`
        arrow = np.rot90(arrow_image)
    elif step[1] > 0:  # `d`
        arrow = np.rot90(arrow_image, -1)
    else:  # `w` (also reached when both components are zero)
        arrow = arrow_image

    # Crop to the frame so an oversized overlay cannot raise a shape mismatch.
    rows = min(arrow.shape[0], image.shape[0])
    cols = min(arrow.shape[1], image.shape[1])
    image[:rows, :cols] = arrow[:rows, :cols]
    return image


def main() -> None:
    """Roll out a random policy in the Genie language-table simulator.

    Builds the simulator, seeds it with a prompt made of one still image
    repeated over the prompt horizon (with all-zero actions), steps a random
    four-direction policy ``NUM_ROLLOUT_STEPS`` times, and saves the result to
    ``OUTPUT_VIDEO_PATH``.
    """
    genie_simulator = GenieSimulator(
        # Alternative configuration kept for reference:
        # image_encoder_type="magvit",
        # image_encoder_ckpt="data/magvit2.ckpt",
        # quantize=True,
        # backbone_type="stmaskgit",
        # backbone_ckpt="data/genie_lang/step_5",
        # prompt_horizon=8,
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/language_table_scratch_mar_dynamics_gpu_8_nodes_4_16g/step_40000",
        # backbone_ckpt="data/genie_lang/step_5",
        prompt_horizon=11,
        action_stride=1,
        domain='language_table',
        post_processor=draw_action_arrow_to_image,
    )

    # use whatever current state is as the initial state
    current_image = imageio.imread(PROMPT_IMAGE_PATH)
    # Repeat the single still along a new leading axis, one copy per prompt
    # frame.
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    # All-zero actions for every prompt frame and stride slot.
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2)
    ).astype(np.float32)
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    random_policy = RandomPlanarQuadDirectionalPolicy(increment=0.05)  # as IRASIM
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=random_policy,
        offscreen=True,
        window_size=(512, 512),
    )

    # Always release the playground, even if a step or the video export fails.
    try:
        for _ in range(NUM_ROLLOUT_STEPS):
            playground.step()
        playground.save_video(save_path=OUTPUT_VIDEO_PATH, as_gif=False)
    finally:
        playground.close()


if __name__ == '__main__':
    main()