liruiw commited on
Commit
bd39096
·
1 Parent(s): 178cd5c
Files changed (1) hide show
  1. app.py +22 -25
app.py CHANGED
@@ -65,34 +65,31 @@ def handle_image_selection(image_name, state):
65
  print(f"User selected image: {image_name}")
66
  return initialize_simulator(image_name, state)
67
 
68
- @spaces.GPU
69
- def init_model():
70
- genie = GenieSimulator(
71
- image_encoder_type='temporalvae',
72
- image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
73
- quantize=False,
74
- backbone_type='stmar',
75
- backbone_ckpt='data/mar_ckpt/langtable',
76
- prompt_horizon=PROMPT_HORIZON,
77
- action_stride=1,
78
- domain='language_table',
79
- device="cpu"
80
- )
 
 
 
 
 
 
 
 
 
81
 
82
- image = Image.open("sim/assets/langtable_prompt/frame_06.png")
83
- prompt_image = np.tile(
84
- np.array(image), (genie.prompt_horizon, 1, 1, 1)
85
- ).astype(np.uint8)
86
- prompt_action = np.zeros(
87
- (genie.prompt_horizon, genie.action_stride, 2)
88
- ).astype(np.float32)
89
- genie.set_initial_state((prompt_image, prompt_action))
90
- genie.device = "cuda"
91
- image = genie.reset()
92
- return genie
93
 
94
  if __name__ == '__main__':
95
- genie = init_model()
96
  with gr.Blocks() as demo:
97
  genie_instance = gr.State({'genie': genie})
98
  genie.device = "cuda"
 
65
  print(f"User selected image: {image_name}")
66
  return initialize_simulator(image_name, state)
67
 
68
+ genie = GenieSimulator(
69
+ image_encoder_type='temporalvae',
70
+ image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
71
+ quantize=False,
72
+ backbone_type='stmar',
73
+ backbone_ckpt='data/mar_ckpt/langtable',
74
+ prompt_horizon=PROMPT_HORIZON,
75
+ action_stride=1,
76
+ domain='language_table',
77
+ device="cpu"
78
+ )
79
+
80
+ image = Image.open("sim/assets/langtable_prompt/frame_06.png")
81
+ prompt_image = np.tile(
82
+ np.array(image), (genie.prompt_horizon, 1, 1, 1)
83
+ ).astype(np.uint8)
84
+ prompt_action = np.zeros(
85
+ (genie.prompt_horizon, genie.action_stride, 2)
86
+ ).astype(np.float32)
87
+ genie.set_initial_state((prompt_image, prompt_action))
88
+ genie.device = "cuda"
89
+ image = genie.reset()
90
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  if __name__ == '__main__':
 
93
  with gr.Blocks() as demo:
94
  genie_instance = gr.State({'genie': genie})
95
  genie.device = "cuda"