liruiw commited on
Commit
a853f7d
·
1 Parent(s): 9c0c09e
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -26,20 +26,20 @@ available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(
26
 
27
 
28
  # Helper function to reset GenieSimulator with the selected image
 
29
  @spaces.GPU
30
- def initialize_simulator(image_name):
31
  image_path = os.path.join(IMAGE_DIR, image_name)
32
  image = Image.open(image_path)
33
- prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
34
- prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
35
- genie.set_initial_state((prompt_image, prompt_action))
36
- reset_image = genie.reset()
37
  reset_image = cv2.resize(reset_image, (RES, RES))
38
  return Image.fromarray(reset_image)
39
 
40
- # Example model: takes a direction and returns a random image
41
  @spaces.GPU
42
- def model(direction: str):
43
  if direction == 'right':
44
  action = np.array([0, 0.05])
45
  elif direction == 'left':
@@ -50,22 +50,20 @@ def model(direction: str):
50
  action = np.array([-0.05, 0])
51
  else:
52
  raise ValueError(f"Invalid direction: {direction}")
53
- next_image = genie.step(action)['pred_next_frame']
54
  next_image = cv2.resize(next_image, (RES, RES))
55
  return Image.fromarray(next_image)
56
 
57
- # Gradio function to handle user input
58
  @spaces.GPU
59
- def handle_input(direction):
60
  print(f"User clicked: {direction}")
61
- new_image = model(direction) # Get a new image from the model
62
  return new_image
63
 
64
- # Gradio function to handle image selection
65
  @spaces.GPU
66
- def handle_image_selection(image_name):
67
  print(f"User selected image: {image_name}")
68
- return initialize_simulator(image_name)
69
 
70
  if __name__ == '__main__':
71
  genie = GenieSimulator(
@@ -77,8 +75,10 @@ if __name__ == '__main__':
77
  prompt_horizon=PROMPT_HORIZON,
78
  action_stride=1,
79
  domain='language_table',
 
80
  )
81
 
 
82
  with gr.Blocks() as demo:
83
  image = Image.open("sim/assets/langtable_prompt/frame_06.png")
84
  prompt_image = np.tile(
@@ -88,7 +88,10 @@ if __name__ == '__main__':
88
  (genie.prompt_horizon, genie.action_stride, 2)
89
  ).astype(np.float32)
90
  genie.set_initial_state((prompt_image, prompt_action))
 
91
  image = genie.reset()
 
 
92
  with gr.Row():
93
  image_selector = gr.Dropdown(
94
  choices=available_images, value=available_images[0], label="Select an Image"
@@ -107,11 +110,12 @@ if __name__ == '__main__':
107
 
108
  # Define interactions
109
  select_button.click(
110
- fn=handle_image_selection, inputs=image_selector, outputs=image_display
111
  )
112
- up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
113
- down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
114
- left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
115
- right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
 
116
 
117
  demo.launch(share=True)
 
26
 
27
 
28
  # Helper function to reset GenieSimulator with the selected image
29
+
30
  @spaces.GPU
31
+ def initialize_simulator(image_name, state):
32
  image_path = os.path.join(IMAGE_DIR, image_name)
33
  image = Image.open(image_path)
34
+ prompt_image = np.tile(np.array(image), (state['genie'].prompt_horizon, 1, 1, 1)).astype(np.uint8)
35
+ prompt_action = np.zeros((state['genie'].prompt_horizon - 1, state['genie'].action_stride, 2)).astype(np.float32)
36
+ state['genie'].set_initial_state((prompt_image, prompt_action))
37
+ reset_image = state['genie'].reset()
38
  reset_image = cv2.resize(reset_image, (RES, RES))
39
  return Image.fromarray(reset_image)
40
 
 
41
  @spaces.GPU
42
+ def model(direction, state):
43
  if direction == 'right':
44
  action = np.array([0, 0.05])
45
  elif direction == 'left':
 
50
  action = np.array([-0.05, 0])
51
  else:
52
  raise ValueError(f"Invalid direction: {direction}")
53
+ next_image = state['genie'].step(action)['pred_next_frame']
54
  next_image = cv2.resize(next_image, (RES, RES))
55
  return Image.fromarray(next_image)
56
 
 
57
  @spaces.GPU
58
+ def handle_input(direction, state):
59
  print(f"User clicked: {direction}")
60
+ new_image = model(direction, state)
61
  return new_image
62
 
 
63
  @spaces.GPU
64
+ def handle_image_selection(image_name, state):
65
  print(f"User selected image: {image_name}")
66
+ return initialize_simulator(image_name, state)
67
 
68
  if __name__ == '__main__':
69
  genie = GenieSimulator(
 
75
  prompt_horizon=PROMPT_HORIZON,
76
  action_stride=1,
77
  domain='language_table',
78
+ device="cpu"
79
  )
80
 
81
+
82
  with gr.Blocks() as demo:
83
  image = Image.open("sim/assets/langtable_prompt/frame_06.png")
84
  prompt_image = np.tile(
 
88
  (genie.prompt_horizon, genie.action_stride, 2)
89
  ).astype(np.float32)
90
  genie.set_initial_state((prompt_image, prompt_action))
91
+ genie.device = "cuda"
92
  image = genie.reset()
93
+ genie_instance = gr.State({'genie': genie})
94
+
95
  with gr.Row():
96
  image_selector = gr.Dropdown(
97
  choices=available_images, value=available_images[0], label="Select an Image"
 
110
 
111
  # Define interactions
112
  select_button.click(
113
+ fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
114
  )
115
+
116
+ up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
117
+ down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
118
+ left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
119
+ right.click(fn=lambda state: handle_input("right", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
120
 
121
  demo.launch(share=True)