liruiw committed on
Commit
240cb6a
·
1 Parent(s): ca9c8aa
Files changed (2) hide show
  1. app copy 2.py +122 -0
  2. app.py +24 -37
app copy 2.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
+ from sim.simulator import GenieSimulator
10
+ import os
11
+
12
if not os.path.exists("data/mar_ckpt/langtable"):
    # The checkpoint is not present locally: download the shared Google Drive
    # folder and move it to the path the simulator is configured to load from.
    import shutil

    import gdown

    gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
    # Use the standard library rather than `os.system("mkdir -p ...; mv ...")`
    # so that a failed directory creation or move raises here instead of being
    # silently ignored by the shell.
    # NOTE(review): like the original `mv langtable ...`, this assumes gdown
    # writes the folder to ./langtable — confirm against the Drive folder name.
    os.makedirs("data/mar_ckpt", exist_ok=True)
    shutil.move("langtable", "data/mar_ckpt/langtable")
17
+
18
# Demo configuration.
RES = 512  # frames are resized to RES x RES before being shown
PROMPT_HORIZON = 3  # forwarded to GenieSimulator as prompt_horizon
IMAGE_DIR = "sim/assets/langtable_prompt/"

# Load available images: every .png file name in IMAGE_DIR, sorted by name.
available_images = sorted(
    file_name for file_name in os.listdir(IMAGE_DIR) if file_name.endswith(".png")
)
24
+
25
+
26
+
27
+
28
+ # Helper function to reset GenieSimulator with the selected image
29
+
30
@spaces.GPU
def initialize_simulator(image_name, genie):
    """Restart ``genie`` from the selected prompt image.

    Args:
        image_name: File name of an image located in ``IMAGE_DIR``.
        genie: Simulator object providing ``prompt_horizon``,
            ``action_stride``, ``set_initial_state`` and ``reset``.

    Returns:
        PIL.Image.Image: The frame returned by ``genie.reset()``, resized to
        ``RES`` x ``RES``.
    """
    image_path = os.path.join(IMAGE_DIR, image_name)
    # Read the pixels inside a context manager so the underlying file handle
    # is closed as soon as the array has been materialised.
    with Image.open(image_path) as image:
        frame = np.array(image)
    # Repeat the single frame prompt_horizon times along a leading axis.
    # NOTE(review): assumes the image decodes to an (H, W, C) array — confirm.
    prompt_image = np.tile(frame, (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
    # All-zero prompt actions; one fewer entry than there are prompt frames.
    prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
    genie.set_initial_state((prompt_image, prompt_action))
    reset_image = genie.reset()
    reset_image = cv2.resize(reset_image, (RES, RES))
    return Image.fromarray(reset_image)
40
+
41
@spaces.GPU
def model(direction, genie):
    """Advance ``genie`` by one step in ``direction`` and return the new frame.

    Args:
        direction: One of ``'right'``, ``'left'``, ``'down'`` or ``'up'``.
        genie: Simulator object whose ``step(action)`` returns a mapping with
            a ``'pred_next_frame'`` entry.

    Returns:
        PIL.Image.Image: The predicted next frame resized to ``RES`` x ``RES``.

    Raises:
        ValueError: If ``direction`` is not one of the four known names.
    """
    # Two-component action per direction, checked in the same order as before.
    moves = (
        ('right', [0, 0.05]),
        ('left', [0, -0.05]),
        ('down', [0.05, 0]),
        ('up', [-0.05, 0]),
    )
    for name, delta in moves:
        if direction == name:
            action = np.array(delta)
            break
    else:
        raise ValueError(f"Invalid direction: {direction}")

    prediction = genie.step(action)
    frame = cv2.resize(prediction['pred_next_frame'], (RES, RES))
    return Image.fromarray(frame)
56
+
57
@spaces.GPU
def handle_input(direction):
    """Gradio callback for the arrow buttons.

    Args:
        direction: Direction name forwarded to ``model`` (``'up'``,
            ``'down'``, ``'left'`` or ``'right'``).

    Returns:
        The image returned by ``model`` for this move.

    Raises:
        ValueError: Propagated from ``model`` for an unknown direction.
    """
    print(f"User clicked: {direction}")
    # Step the shared module-level simulator through model(). The previous
    # code called the simulator object itself with the direction string
    # (`genie(direction)`), bypassing the direction-to-action mapping.
    new_image = model(direction, genie)
    return new_image
62
+
63
@spaces.GPU
def handle_image_selection(image_name, state):
    """Gradio callback for the "Load Image" button.

    Args:
        image_name: File name chosen in the image dropdown.
        state: Simulator object, passed through to ``initialize_simulator``
            as its ``genie`` argument.

    Returns:
        The image returned by ``initialize_simulator``.
    """
    print(f"User selected image: {image_name}")
    first_frame = initialize_simulator(image_name, state)
    return first_frame
67
+
68
# Build the simulator once at import time; the callbacks in this file use
# this module-level instance.
genie = GenieSimulator(
    image_encoder_type="temporalvae",
    image_encoder_ckpt="stabilityai/stable-video-diffusion-img2vid",
    quantize=False,
    backbone_type="stmar",
    backbone_ckpt="data/mar_ckpt/langtable",
    prompt_horizon=PROMPT_HORIZON,
    action_stride=1,
    domain="language_table",
    device="cuda",
)

# Seed the simulator with a default prompt: one frame repeated
# prompt_horizon times, paired with all-zero actions.
image = Image.open("sim/assets/langtable_prompt/frame_06.png")
prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
# NOTE(review): this builds prompt_horizon actions, whereas
# initialize_simulator builds prompt_horizon - 1 — confirm which length
# set_initial_state expects.
prompt_action = np.zeros((genie.prompt_horizon, genie.action_stride, 2)).astype(np.float32)
genie.set_initial_state((prompt_image, prompt_action))
genie.device = "cuda"
89
+
90
+
91
+
92
if __name__ == '__main__':
    # Build the demo UI: a prompt-image picker, the current frame, and four
    # arrow buttons that each advance the simulator by one step.
    with gr.Blocks() as demo:
        genie.device = "cuda"
        with gr.Row():
            gr.Textbox(label='HMA Demo: Select a prompt initial image from the gallery and Interact with arrow keys. \n'
                             'Note: the speed is limited due to free GPU in HF and the interface supports one user at a time.', lines=1)
        with gr.Row():
            image_selector = gr.Dropdown(
                choices=available_images,
                # Avoid an IndexError at start-up when no prompt images exist.
                value=available_images[0] if available_images else None,
                label="Select an Image",
            )
            select_button = gr.Button("Load Image")

        with gr.Row():
            image_display = gr.Image(type="pil", label="Generated Image")

        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Define interactions
        # Event `inputs` must be Gradio components, so the simulator object
        # cannot be listed there; bind it through a closure instead and pass
        # only the dropdown as the event input.
        select_button.click(
            fn=lambda image_name: handle_image_selection(image_name, genie),
            inputs=image_selector,
            outputs=image_display,
            show_progress='hidden',
        )
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch(share=True)
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import gradio as gr
2
  import spaces
3
 
4
-
5
- import gradio as gr
6
  import numpy as np
7
  from PIL import Image
8
  import cv2
9
  from sim.simulator import GenieSimulator
10
  import os
 
 
11
 
12
  if not os.path.exists("data/mar_ckpt/langtable"):
13
  # download from google drive
@@ -23,12 +23,20 @@ IMAGE_DIR = "sim/assets/langtable_prompt/"
23
  available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
 
25
 
26
-
 
 
 
 
 
 
 
 
 
27
 
28
  # Helper function to reset GenieSimulator with the selected image
29
-
30
  @spaces.GPU
31
- def initialize_simulator(image_name, genie):
32
  image_path = os.path.join(IMAGE_DIR, image_name)
33
  image = Image.open(image_path)
34
  prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
@@ -38,8 +46,9 @@ def initialize_simulator(image_name, genie):
38
  reset_image = cv2.resize(reset_image, (RES, RES))
39
  return Image.fromarray(reset_image)
40
 
 
41
  @spaces.GPU
42
- def model(direction, genie):
43
  if direction == 'right':
44
  action = np.array([0, 0.05])
45
  elif direction == 'left':
@@ -54,48 +63,25 @@ def model(direction, genie):
54
  next_image = cv2.resize(next_image, (RES, RES))
55
  return Image.fromarray(next_image)
56
 
 
57
  @spaces.GPU
58
  def handle_input(direction):
59
  print(f"User clicked: {direction}")
60
- new_image = genie(direction)
61
  return new_image
62
 
 
63
  @spaces.GPU
64
- def handle_image_selection(image_name, state):
65
  print(f"User selected image: {image_name}")
66
- return initialize_simulator(image_name, state)
67
-
68
- genie = GenieSimulator(
69
- image_encoder_type='temporalvae',
70
- image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
71
- quantize=False,
72
- backbone_type='stmar',
73
- backbone_ckpt='data/mar_ckpt/langtable',
74
- prompt_horizon=PROMPT_HORIZON,
75
- action_stride=1,
76
- domain='language_table',
77
- device="cuda"
78
- )
79
-
80
- image = Image.open("sim/assets/langtable_prompt/frame_06.png")
81
- prompt_image = np.tile(
82
- np.array(image), (genie.prompt_horizon, 1, 1, 1)
83
- ).astype(np.uint8)
84
- prompt_action = np.zeros(
85
- (genie.prompt_horizon, genie.action_stride, 2)
86
- ).astype(np.float32)
87
- genie.set_initial_state((prompt_image, prompt_action))
88
- genie.device = "cuda"
89
-
90
-
91
 
92
  if __name__ == '__main__':
93
  with gr.Blocks() as demo:
94
- genie_instance = gr.State({'genie': genie})
95
- genie.device = "cuda"
96
  with gr.Row():
97
  gr.Textbox(label='HMA Demo: Select a prompt initial image from the gallery and Interact with arrow keys. \n'
98
- 'Note: the speed is limited due to free GPU in HF and the interface ', lines=1)
 
99
  with gr.Row():
100
  image_selector = gr.Dropdown(
101
  choices=available_images, value=available_images[0], label="Select an Image"
@@ -114,10 +100,11 @@ if __name__ == '__main__':
114
 
115
  # Define interactions
116
  select_button.click(
117
- fn=handle_image_selection, inputs=[image_selector, genie], outputs=image_display, show_progress='hidden'
118
  )
119
  up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
120
  down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
121
  left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
122
  right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
 
123
  demo.launch(share=True)
 
1
  import gradio as gr
2
  import spaces
3
 
 
 
4
  import numpy as np
5
  from PIL import Image
6
  import cv2
7
  from sim.simulator import GenieSimulator
8
  import os
9
+ import spaces
10
+
11
 
12
  if not os.path.exists("data/mar_ckpt/langtable"):
13
  # download from google drive
 
23
  available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
 
25
 
26
+ genie = GenieSimulator(
27
+ image_encoder_type='temporalvae',
28
+ image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
29
+ quantize=False,
30
+ backbone_type='stmar',
31
+ backbone_ckpt='data/mar_ckpt_long2/langtable',
32
+ prompt_horizon=PROMPT_HORIZON,
33
+ action_stride=1,
34
+ domain='language_table',
35
+ )
36
 
37
  # Helper function to reset GenieSimulator with the selected image
 
38
  @spaces.GPU
39
+ def initialize_simulator(image_name):
40
  image_path = os.path.join(IMAGE_DIR, image_name)
41
  image = Image.open(image_path)
42
  prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
 
46
  reset_image = cv2.resize(reset_image, (RES, RES))
47
  return Image.fromarray(reset_image)
48
 
49
+ # Example model: takes a direction and returns a random image
50
  @spaces.GPU
51
+ def model(direction: str):
52
  if direction == 'right':
53
  action = np.array([0, 0.05])
54
  elif direction == 'left':
 
63
  next_image = cv2.resize(next_image, (RES, RES))
64
  return Image.fromarray(next_image)
65
 
66
+ # Gradio function to handle user input
67
  @spaces.GPU
68
  def handle_input(direction):
69
  print(f"User clicked: {direction}")
70
+ new_image = model(direction) # Get a new image from the model
71
  return new_image
72
 
73
+ # Gradio function to handle image selection
74
  @spaces.GPU
75
+ def handle_image_selection(image_name):
76
  print(f"User selected image: {image_name}")
77
+ return initialize_simulator(image_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  if __name__ == '__main__':
80
  with gr.Blocks() as demo:
 
 
81
  with gr.Row():
82
  gr.Textbox(label='HMA Demo: Select a prompt initial image from the gallery and Interact with arrow keys. \n'
83
+ 'Note: the speed is limited due to free GPU in HF and the interface supports one user at a time.', lines=1)
84
+
85
  with gr.Row():
86
  image_selector = gr.Dropdown(
87
  choices=available_images, value=available_images[0], label="Select an Image"
 
100
 
101
  # Define interactions
102
  select_button.click(
103
+ fn=handle_image_selection, inputs=image_selector, outputs=image_display
104
  )
105
  up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
106
  down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
107
  left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
108
  right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
109
+
110
  demo.launch(share=True)