liruiw commited on
Commit
5ed752c
·
1 Parent(s): 2c0d4d5
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -28,18 +28,18 @@ available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(
28
  # Helper function to reset GenieSimulator with the selected image
29
 
30
  @spaces.GPU
31
- def initialize_simulator(image_name, state):
32
  image_path = os.path.join(IMAGE_DIR, image_name)
33
  image = Image.open(image_path)
34
- prompt_image = np.tile(np.array(image), (state['genie'].prompt_horizon, 1, 1, 1)).astype(np.uint8)
35
- prompt_action = np.zeros((state['genie'].prompt_horizon - 1, state['genie'].action_stride, 2)).astype(np.float32)
36
- state['genie'].set_initial_state((prompt_image, prompt_action))
37
- reset_image = state['genie'].reset()
38
  reset_image = cv2.resize(reset_image, (RES, RES))
39
  return Image.fromarray(reset_image)
40
 
41
  @spaces.GPU
42
- def model(direction, state):
43
  if direction == 'right':
44
  action = np.array([0, 0.05])
45
  elif direction == 'left':
@@ -50,7 +50,7 @@ def model(direction, state):
50
  action = np.array([-0.05, 0])
51
  else:
52
  raise ValueError(f"Invalid direction: {direction}")
53
- next_image = state['genie'].step(action)['pred_next_frame']
54
  next_image = cv2.resize(next_image, (RES, RES))
55
  return Image.fromarray(next_image)
56
 
@@ -112,12 +112,12 @@ if __name__ == '__main__':
112
 
113
  # Define interactions
114
  select_button.click(
115
- fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
116
  )
117
 
118
- up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
119
- down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
120
- left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
121
- right.click(fn=lambda state: handle_input("right", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
122
 
123
  demo.launch(share=True)
 
28
  # Helper function to reset GenieSimulator with the selected image
29
 
30
  @spaces.GPU
31
+ def initialize_simulator(image_name, genie):
32
  image_path = os.path.join(IMAGE_DIR, image_name)
33
  image = Image.open(image_path)
34
+ prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
35
+ prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
36
+ genie.set_initial_state((prompt_image, prompt_action))
37
+ reset_image = genie.reset()
38
  reset_image = cv2.resize(reset_image, (RES, RES))
39
  return Image.fromarray(reset_image)
40
 
41
  @spaces.GPU
42
+ def model(direction, genie):
43
  if direction == 'right':
44
  action = np.array([0, 0.05])
45
  elif direction == 'left':
 
50
  action = np.array([-0.05, 0])
51
  else:
52
  raise ValueError(f"Invalid direction: {direction}")
53
+ next_image = genie.step(action)['pred_next_frame']
54
  next_image = cv2.resize(next_image, (RES, RES))
55
  return Image.fromarray(next_image)
56
 
 
112
 
113
  # Define interactions
114
  select_button.click(
115
+ fn=handle_image_selection, inputs=[image_selector, genie], outputs=image_display, show_progress='hidden'
116
  )
117
 
118
+ up.click(fn=lambda state: handle_input("up", state), inputs=[genie], outputs=image_display, show_progress='hidden')
119
+ down.click(fn=lambda state: handle_input("down", state), inputs=[genie], outputs=image_display, show_progress='hidden')
120
+ left.click(fn=lambda state: handle_input("left", state), inputs=[genie], outputs=image_display, show_progress='hidden')
121
+ right.click(fn=lambda state: handle_input("right", state), inputs=[genie], outputs=image_display, show_progress='hidden')
122
 
123
  demo.launch(share=True)