liruiw commited on
Commit
7140e01
Β·
1 Parent(s): a15acd9
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -62,7 +62,7 @@ def handle_image_selection(image_name, state):
62
  return initialize_simulator(image_name, state)
63
 
64
  def init_model():
65
- return GenieSimulator(
66
  image_encoder_type='temporalvae',
67
  image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
68
  quantize=False,
@@ -73,6 +73,14 @@ def init_model():
73
  domain='language_table',
74
  )
75
 
 
 
 
 
 
 
 
 
76
 
77
  if __name__ == '__main__':
78
 
@@ -89,6 +97,10 @@ if __name__ == '__main__':
89
  with gr.Row():
90
  image_display = gr.Image(type="pil", label="Generated Image")
91
 
 
 
 
 
92
  with gr.Row():
93
  up = gr.Button("↑ Up")
94
  with gr.Row():
@@ -96,9 +108,7 @@ if __name__ == '__main__':
96
  down = gr.Button("↓ Down")
97
  right = gr.Button("β†’ Right")
98
 
99
- select_button.click(
100
- fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
101
- )
102
  up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
103
  down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
104
  left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
 
62
  return initialize_simulator(image_name, state)
63
 
64
  def init_model():
65
+ genie = GenieSimulator(
66
  image_encoder_type='temporalvae',
67
  image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
68
  quantize=False,
 
73
  domain='language_table',
74
  )
75
 
76
+ image = Image.open("sim/assets/langtable_prompt/frame_06.png")
77
+ prompt_image = np.tile(
78
+ np.array(image), (genie.prompt_horizon, 1, 1, 1)
79
+ ).astype(np.uint8)
80
+ prompt_action = np.zeros(
81
+ (genie.prompt_horizon, genie.action_stride, 2)
82
+ ).astype(np.float32)
83
+ return genie
84
 
85
  if __name__ == '__main__':
86
 
 
97
  with gr.Row():
98
  image_display = gr.Image(type="pil", label="Generated Image")
99
 
100
+ select_button.click(
101
+ fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
102
+ )
103
+
104
  with gr.Row():
105
  up = gr.Button("↑ Up")
106
  with gr.Row():
 
108
  down = gr.Button("↓ Down")
109
  right = gr.Button("β†’ Right")
110
 
111
+
 
 
112
  up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
113
  down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
114
  left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')