# Hugging Face Space status: Running on Zero
# File size: 2,422 Bytes · revision 246c106
import functools

import imageio
import numpy as np

from sim.main import InteractiveDigitalWorld
from sim.policy import RandomPlanarQuadDirectionalPolicy
from sim.simulator import GenieSimulator
# Default sprite stamped onto frames by `draw_action_arrow_to_image`.
ARROW_IMAGE_PATH = 'sim/assets/arrow.jpg'


@functools.lru_cache(maxsize=8)
def _load_arrow_image(path: str) -> np.ndarray:
    """Read the arrow sprite at *path* once and return it as a read-only array."""
    arrow = np.asarray(imageio.imread(path))
    # The cached array is shared by every call; forbid accidental in-place edits.
    arrow.setflags(write=False)
    return arrow


def draw_action_arrow_to_image(
    image: np.ndarray,
    action: np.ndarray,
    *,
    arrow_image: "np.ndarray | None" = None,
) -> np.ndarray:
    """Stamp an arrow for the action's direction into the top-left corner of *image*.

    Passed to the simulator as its ``post_processor``. The sprite is used as-is
    for ``w`` and flipped / rotated for ``s`` / ``a`` / ``d``.

    Args:
        image: Frame to draw on; it is modified in place.
        action: Action whose first entry along the ``stride`` axis is a
            2-vector with at most one non-zero component.
        arrow_image: Optional sprite to use instead of the one read from
            ``ARROW_IMAGE_PATH``.

    Returns:
        *image* with the arrow written into it. A sprite larger than the frame
        is cropped to fit.

    Raises:
        ValueError: if both action components are non-zero (a diagonal move).
    """
    step = action[0]  # drop the `stride` dimension
    if step[0] * step[1] != 0:
        raise ValueError(f"expected an axis-aligned action, got {step!r}")
    if arrow_image is None:
        arrow_image = _load_arrow_image(ARROW_IMAGE_PATH)

    if step[0] > 0:  # `s`
        arrow = np.flipud(arrow_image)
    elif step[1] < 0:  # `a`
        arrow = np.rot90(arrow_image)
    elif step[1] > 0:  # `d`
        arrow = np.rot90(arrow_image, -1)
    else:  # `w` (this branch also covers an all-zero action)
        arrow = arrow_image

    # Crop the sprite instead of failing with a broadcast error when it does not fit.
    height = min(arrow.shape[0], image.shape[0])
    width = min(arrow.shape[1], image.shape[1])
    image[:height, :width] = arrow[:height, :width]
    return image


def main(num_steps: int = 50, save_path: str = 'test.mp4') -> None:
    """Build the simulator, roll out a random policy for *num_steps*, and save a video."""
    genie_simulator = GenieSimulator(
        # Alternative configuration (magvit encoder + stmaskgit backbone):
        # image_encoder_type="magvit",
        # image_encoder_ckpt="data/magvit2.ckpt",
        # quantize=True,
        # backbone_type="stmaskgit",
        # backbone_ckpt="data/genie_lang/step_5",
        # prompt_horizon=8,
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/language_table_scratch_mar_dynamics_gpu_8_nodes_4_16g/step_40000",
        # backbone_ckpt="data/genie_lang/step_5",
        prompt_horizon=11,
        action_stride=1,
        domain='language_table',
        post_processor=draw_action_arrow_to_image,
    )

    # Use the current state as the initial state: the same frame is repeated
    # across the whole prompt window, paired with all-zero actions.
    current_image = imageio.imread('sim/assets/langtable_prompt.png')
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2),
        dtype=np.float32,
    )
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    random_policy = RandomPlanarQuadDirectionalPolicy(increment=0.05)  # as IRASIM
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=random_policy,
        offscreen=True,
        window_size=(512, 512),
    )
    try:
        for _ in range(num_steps):
            playground.step()
        playground.save_video(save_path=save_path, as_gif=False)
    finally:
        # Always close the playground, even if a step or the video export raises.
        playground.close()


if __name__ == '__main__':
    main()
# (end of file listing)