liruiw commited on
Commit
14420e9
Β·
1 Parent(s): f64bf43
Files changed (1) hide show
  1. app.py +48 -17
app.py CHANGED
@@ -1,36 +1,52 @@
1
  import gradio as gr
2
  import spaces
3
 
 
 
4
  import numpy as np
5
  from PIL import Image
6
  import cv2
7
  from sim.simulator import GenieSimulator
 
 
 
 
 
 
 
8
 
9
  RES = 512
10
- image = Image.open("sim/assets/langtable_prompt/frame_06.png")
 
 
 
 
 
 
11
  genie = GenieSimulator(
12
  image_encoder_type='temporalvae',
13
  image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
14
  quantize=False,
15
  backbone_type='stmar',
16
- backbone_ckpt='data/mar_ckpt/langtable',
17
- prompt_horizon=3,
18
  action_stride=1,
19
  domain='language_table',
20
  )
21
- prompt_image = np.tile(
22
- np.array(image), (genie.prompt_horizon, 1, 1, 1)
23
- ).astype(np.uint8)
24
- prompt_action = np.zeros(
25
- (genie.prompt_horizon - 1, genie.action_stride, 2)
26
- ).astype(np.float32)
27
- genie.set_initial_state((prompt_image, prompt_action))
28
- image = genie.reset()
29
- image = cv2.resize(image, (RES, RES))
30
- image = Image.fromarray(image)
 
31
 
32
  # Example model: takes a direction and returns a random image
33
- def model(direction: str, genie=genie):
34
  if direction == 'right':
35
  action = np.array([0, 0.05])
36
  elif direction == 'left':
@@ -52,10 +68,22 @@ def handle_input(direction):
52
  new_image = model(direction) # Get a new image from the model
53
  return new_image
54
 
 
 
 
 
 
55
  if __name__ == '__main__':
56
  with gr.Blocks() as demo:
57
  with gr.Row():
58
- image_display = gr.Image(value=image, type="pil", label="Generated Image")
 
 
 
 
 
 
 
59
  with gr.Row():
60
  up = gr.Button("↑ Up")
61
  with gr.Row():
@@ -63,10 +91,13 @@ if __name__ == '__main__':
63
  down = gr.Button("↓ Down")
64
  right = gr.Button("β†’ Right")
65
 
66
- # Define button interactions
 
 
 
67
  up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
68
  down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
69
  left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
70
  right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
71
 
72
- demo.launch()
 
1
  import gradio as gr
2
  import spaces
3
 
4
+
5
+ import gradio as gr
6
  import numpy as np
7
  from PIL import Image
8
  import cv2
9
  from sim.simulator import GenieSimulator
10
+ import os
11
+
12
+ if not os.path.exists("data/mar_ckpt/langtable"):
13
+ # download from google drive
14
+ import gdown
15
+ gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
16
+ os.system("mkdir -p data/mar_ckpt/; mv langtable data/mar_ckpt/")
17
 
18
  RES = 512
19
+ PROMPT_HORIZON = 3
20
+ IMAGE_DIR = "sim/assets/langtable_prompt/"
21
+
22
+ # Load available images
23
+ available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
+
25
+
26
  genie = GenieSimulator(
27
  image_encoder_type='temporalvae',
28
  image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
29
  quantize=False,
30
  backbone_type='stmar',
31
+ backbone_ckpt='data/mar_ckpt_long2/langtable',
32
+ prompt_horizon=PROMPT_HORIZON,
33
  action_stride=1,
34
  domain='language_table',
35
  )
36
+
37
+ # Helper function to reset GenieSimulator with the selected image
38
+ def initialize_simulator(image_name):
39
+ image_path = os.path.join(IMAGE_DIR, image_name)
40
+ image = Image.open(image_path)
41
+ prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
42
+ prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
43
+ genie.set_initial_state((prompt_image, prompt_action))
44
+ reset_image = genie.reset()
45
+ reset_image = cv2.resize(reset_image, (RES, RES))
46
+ return Image.fromarray(reset_image)
47
 
48
  # Example model: takes a direction and returns a random image
49
+ def model(direction: str):
50
  if direction == 'right':
51
  action = np.array([0, 0.05])
52
  elif direction == 'left':
 
68
  new_image = model(direction) # Get a new image from the model
69
  return new_image
70
 
71
+ # Gradio function to handle image selection
72
+ def handle_image_selection(image_name):
73
+ print(f"User selected image: {image_name}")
74
+ return initialize_simulator(image_name)
75
+
76
  if __name__ == '__main__':
77
  with gr.Blocks() as demo:
78
  with gr.Row():
79
+ image_selector = gr.Dropdown(
80
+ choices=available_images, value=available_images[0], label="Select an Image"
81
+ )
82
+ select_button = gr.Button("Load Image")
83
+
84
+ with gr.Row():
85
+ image_display = gr.Image(type="pil", label="Generated Image")
86
+
87
  with gr.Row():
88
  up = gr.Button("↑ Up")
89
  with gr.Row():
 
91
  down = gr.Button("↓ Down")
92
  right = gr.Button("β†’ Right")
93
 
94
+ # Define interactions
95
+ select_button.click(
96
+ fn=handle_image_selection, inputs=image_selector, outputs=image_display
97
+ )
98
  up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
99
  down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
100
  left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
101
  right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
102
 
103
+ demo.launch(share=True)