liruiw committed on
Commit
aa82009
·
1 Parent(s): 7140e01
Files changed (3) hide show
  1. .gradio/certificate.pem +31 -0
  2. app copy.py +117 -0
  3. app.py +35 -49
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
app copy.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
+ from sim.simulator import GenieSimulator
10
+ import os
11
+
12
+ if not os.path.exists("data/mar_ckpt/langtable"):
13
+ # download from google drive
14
+ import gdown
15
+ gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
16
+ os.system("mkdir -p data/mar_ckpt/; mv langtable data/mar_ckpt/")
17
+
18
+ RES = 512
19
+ PROMPT_HORIZON = 3
20
+ IMAGE_DIR = "sim/assets/langtable_prompt/"
21
+
22
+ # Load available images
23
+ available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
+
25
+
26
+ @spaces.GPU
27
+ def initialize_simulator(image_name, state):
28
+ image_path = os.path.join(IMAGE_DIR, image_name)
29
+ image = Image.open(image_path)
30
+ prompt_image = np.tile(np.array(image), (state['genie'].prompt_horizon, 1, 1, 1)).astype(np.uint8)
31
+ prompt_action = np.zeros((state['genie'].prompt_horizon - 1, state['genie'].action_stride, 2)).astype(np.float32)
32
+ state['genie'].set_initial_state((prompt_image, prompt_action))
33
+ reset_image = state['genie'].reset()
34
+ reset_image = cv2.resize(reset_image, (RES, RES))
35
+ return Image.fromarray(reset_image)
36
+
37
+ @spaces.GPU
38
+ def model(direction, state):
39
+ if direction == 'right':
40
+ action = np.array([0, 0.05])
41
+ elif direction == 'left':
42
+ action = np.array([0, -0.05])
43
+ elif direction == 'down':
44
+ action = np.array([0.05, 0])
45
+ elif direction == 'up':
46
+ action = np.array([-0.05, 0])
47
+ else:
48
+ raise ValueError(f"Invalid direction: {direction}")
49
+ next_image = state['genie'].step(action)['pred_next_frame']
50
+ next_image = cv2.resize(next_image, (RES, RES))
51
+ return Image.fromarray(next_image)
52
+
53
+ @spaces.GPU
54
+ def handle_input(direction, state):
55
+ print(f"User clicked: {direction}")
56
+ new_image = model(direction, state)
57
+ return new_image
58
+
59
+ @spaces.GPU
60
+ def handle_image_selection(image_name, state):
61
+ print(f"User selected image: {image_name}")
62
+ return initialize_simulator(image_name, state)
63
+
64
+ def init_model():
65
+ genie = GenieSimulator(
66
+ image_encoder_type='temporalvae',
67
+ image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
68
+ quantize=False,
69
+ backbone_type='stmar',
70
+ backbone_ckpt='data/mar_ckpt/langtable',
71
+ prompt_horizon=PROMPT_HORIZON,
72
+ action_stride=1,
73
+ domain='language_table',
74
+ )
75
+
76
+ image = Image.open("sim/assets/langtable_prompt/frame_06.png")
77
+ prompt_image = np.tile(
78
+ np.array(image), (genie.prompt_horizon, 1, 1, 1)
79
+ ).astype(np.uint8)
80
+ prompt_action = np.zeros(
81
+ (genie.prompt_horizon, genie.action_stride, 2)
82
+ ).astype(np.float32)
83
+ return genie
84
+
85
+ if __name__ == '__main__':
86
+
87
+ with gr.Blocks() as demo:
88
+ genie = init_model()
89
+ genie_instance = gr.State({
90
+ 'genie': genie})
91
+ with gr.Row():
92
+ image_selector = gr.Dropdown(
93
+ choices=available_images, value=available_images[0], label="Select an Image"
94
+ )
95
+ select_button = gr.Button("Load Image")
96
+
97
+ with gr.Row():
98
+ image_display = gr.Image(type="pil", label="Generated Image")
99
+
100
+ select_button.click(
101
+ fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
102
+ )
103
+
104
+ with gr.Row():
105
+ up = gr.Button("↑ Up")
106
+ with gr.Row():
107
+ left = gr.Button("← Left")
108
+ down = gr.Button("↓ Down")
109
+ right = gr.Button("→ Right")
110
+
111
+
112
+ up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
113
+ down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
114
+ left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
115
+ right.click(fn=lambda state: handle_input("right", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
116
+
117
+ demo.launch()
app.py CHANGED
@@ -23,19 +23,30 @@ IMAGE_DIR = "sim/assets/langtable_prompt/"
23
  available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
 
25
 
26
- @spaces.GPU
27
- def initialize_simulator(image_name, state):
 
 
 
 
 
 
 
 
 
 
 
28
  image_path = os.path.join(IMAGE_DIR, image_name)
29
  image = Image.open(image_path)
30
- prompt_image = np.tile(np.array(image), (state['genie'].prompt_horizon, 1, 1, 1)).astype(np.uint8)
31
- prompt_action = np.zeros((state['genie'].prompt_horizon - 1, state['genie'].action_stride, 2)).astype(np.float32)
32
- state['genie'].set_initial_state((prompt_image, prompt_action))
33
- reset_image = state['genie'].reset()
34
  reset_image = cv2.resize(reset_image, (RES, RES))
35
  return Image.fromarray(reset_image)
36
 
37
- @spaces.GPU
38
- def model(direction, state):
39
  if direction == 'right':
40
  action = np.array([0, 0.05])
41
  elif direction == 'left':
@@ -46,48 +57,24 @@ def model(direction, state):
46
  action = np.array([-0.05, 0])
47
  else:
48
  raise ValueError(f"Invalid direction: {direction}")
49
- next_image = state['genie'].step(action)['pred_next_frame']
50
  next_image = cv2.resize(next_image, (RES, RES))
51
  return Image.fromarray(next_image)
52
 
 
53
  @spaces.GPU
54
- def handle_input(direction, state):
55
  print(f"User clicked: {direction}")
56
- new_image = model(direction, state)
57
  return new_image
58
 
59
- @spaces.GPU
60
- def handle_image_selection(image_name, state):
61
  print(f"User selected image: {image_name}")
62
- return initialize_simulator(image_name, state)
63
-
64
- def init_model():
65
- genie = GenieSimulator(
66
- image_encoder_type='temporalvae',
67
- image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
68
- quantize=False,
69
- backbone_type='stmar',
70
- backbone_ckpt='data/mar_ckpt/langtable',
71
- prompt_horizon=PROMPT_HORIZON,
72
- action_stride=1,
73
- domain='language_table',
74
- )
75
-
76
- image = Image.open("sim/assets/langtable_prompt/frame_06.png")
77
- prompt_image = np.tile(
78
- np.array(image), (genie.prompt_horizon, 1, 1, 1)
79
- ).astype(np.uint8)
80
- prompt_action = np.zeros(
81
- (genie.prompt_horizon, genie.action_stride, 2)
82
- ).astype(np.float32)
83
- return genie
84
 
85
  if __name__ == '__main__':
86
-
87
  with gr.Blocks() as demo:
88
- genie = init_model()
89
- genie_instance = gr.State({
90
- 'genie': genie})
91
  with gr.Row():
92
  image_selector = gr.Dropdown(
93
  choices=available_images, value=available_images[0], label="Select an Image"
@@ -97,10 +84,6 @@ if __name__ == '__main__':
97
  with gr.Row():
98
  image_display = gr.Image(type="pil", label="Generated Image")
99
 
100
- select_button.click(
101
- fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
102
- )
103
-
104
  with gr.Row():
105
  up = gr.Button("↑ Up")
106
  with gr.Row():
@@ -108,10 +91,13 @@ if __name__ == '__main__':
108
  down = gr.Button("↓ Down")
109
  right = gr.Button("→ Right")
110
 
 
 
 
 
 
 
 
 
111
 
112
- up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
113
- down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
114
- left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
115
- right.click(fn=lambda state: handle_input("right", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
116
-
117
- demo.launch()
 
23
  available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
 
25
 
26
+ genie = GenieSimulator(
27
+ image_encoder_type='temporalvae',
28
+ image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
29
+ quantize=False,
30
+ backbone_type='stmar',
31
+ backbone_ckpt='data/mar_ckpt/langtable',
32
+ prompt_horizon=PROMPT_HORIZON,
33
+ action_stride=1,
34
+ domain='language_table',
35
+ )
36
+
37
+ # Helper function to reset GenieSimulator with the selected image
38
+ def initialize_simulator(image_name):
39
  image_path = os.path.join(IMAGE_DIR, image_name)
40
  image = Image.open(image_path)
41
+ prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
42
+ prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
43
+ genie.set_initial_state((prompt_image, prompt_action))
44
+ reset_image = genie.reset()
45
  reset_image = cv2.resize(reset_image, (RES, RES))
46
  return Image.fromarray(reset_image)
47
 
48
+ # Model step: takes a direction and returns the simulator's predicted next frame
49
+ def model(direction: str):
50
  if direction == 'right':
51
  action = np.array([0, 0.05])
52
  elif direction == 'left':
 
57
  action = np.array([-0.05, 0])
58
  else:
59
  raise ValueError(f"Invalid direction: {direction}")
60
+ next_image = genie.step(action)['pred_next_frame']
61
  next_image = cv2.resize(next_image, (RES, RES))
62
  return Image.fromarray(next_image)
63
 
64
+ # Gradio function to handle user input
65
  @spaces.GPU
66
+ def handle_input(direction):
67
  print(f"User clicked: {direction}")
68
+ new_image = model(direction) # Get a new image from the model
69
  return new_image
70
 
71
+ # Gradio function to handle image selection
72
+ def handle_image_selection(image_name):
73
  print(f"User selected image: {image_name}")
74
+ return initialize_simulator(image_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  if __name__ == '__main__':
 
77
  with gr.Blocks() as demo:
 
 
 
78
  with gr.Row():
79
  image_selector = gr.Dropdown(
80
  choices=available_images, value=available_images[0], label="Select an Image"
 
84
  with gr.Row():
85
  image_display = gr.Image(type="pil", label="Generated Image")
86
 
 
 
 
 
87
  with gr.Row():
88
  up = gr.Button("↑ Up")
89
  with gr.Row():
 
91
  down = gr.Button("↓ Down")
92
  right = gr.Button("→ Right")
93
 
94
+ # Define interactions
95
+ select_button.click(
96
+ fn=handle_image_selection, inputs=image_selector, outputs=image_display
97
+ )
98
+ up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
99
+ down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
100
+ left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
101
+ right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
102
 
103
+ demo.launch(share=True)