liruiw commited on
Commit
aca205f
·
1 Parent(s): 1df3733

try gr state

Browse files
Files changed (45) hide show
  1. __pycache__/cont_data.cpython-310.pyc +0 -0
  2. __pycache__/data.cpython-310.pyc +0 -0
  3. __pycache__/train_diffusion.cpython-310.pyc +0 -0
  4. __pycache__/visualize.cpython-310.pyc +0 -0
  5. app (Copy).py +120 -0
  6. app.py +31 -51
  7. common/__pycache__/__init__.cpython-310.pyc +0 -0
  8. common/__pycache__/eval_utils.cpython-310.pyc +0 -0
  9. data/mar_ckpt/langtable/random_states_0.pkl +0 -0
  10. datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  11. datasets/__pycache__/encode_openx_dataset.cpython-310.pyc +0 -0
  12. datasets/__pycache__/utils.cpython-310.pyc +0 -0
  13. genie/__pycache__/__init__.cpython-310.pyc +0 -0
  14. genie/__pycache__/attention.cpython-310.pyc +0 -0
  15. genie/__pycache__/config.cpython-310.pyc +0 -0
  16. genie/__pycache__/diffloss.cpython-310.pyc +0 -0
  17. genie/__pycache__/factorization_utils.cpython-310.pyc +0 -0
  18. genie/__pycache__/st_mar.cpython-310.pyc +0 -0
  19. genie/__pycache__/st_mask_git.cpython-310.pyc +0 -0
  20. genie/__pycache__/st_transformer.cpython-310.pyc +0 -0
  21. genie/diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  22. genie/diffusion/__pycache__/diffusion_utils.cpython-310.pyc +0 -0
  23. genie/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc +0 -0
  24. genie/diffusion/__pycache__/respace.cpython-310.pyc +0 -0
  25. magvit2/__pycache__/__init__.cpython-310.pyc +0 -0
  26. magvit2/__pycache__/config.cpython-310.pyc +0 -0
  27. magvit2/__pycache__/util.cpython-310.pyc +0 -0
  28. magvit2/models/__pycache__/__init__.cpython-310.pyc +0 -0
  29. magvit2/models/__pycache__/lfqgan.cpython-310.pyc +0 -0
  30. magvit2/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  31. magvit2/modules/__pycache__/ema.cpython-310.pyc +0 -0
  32. magvit2/modules/__pycache__/util.cpython-310.pyc +0 -0
  33. magvit2/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc +0 -0
  34. magvit2/modules/diffusionmodules/__pycache__/improved_model.cpython-310.pyc +0 -0
  35. magvit2/modules/discriminator/__pycache__/__init__.cpython-310.pyc +0 -0
  36. magvit2/modules/discriminator/__pycache__/model.cpython-310.pyc +0 -0
  37. magvit2/modules/losses/__pycache__/__init__.cpython-310.pyc +0 -0
  38. magvit2/modules/losses/__pycache__/lpips.cpython-310.pyc +0 -0
  39. magvit2/modules/losses/__pycache__/vqperceptual.cpython-310.pyc +0 -0
  40. magvit2/modules/scheduler/__pycache__/__init__.cpython-310.pyc +0 -0
  41. magvit2/modules/scheduler/__pycache__/lr_scheduler.cpython-310.pyc +0 -0
  42. magvit2/modules/vqvae/__pycache__/__init__.cpython-310.pyc +0 -0
  43. magvit2/modules/vqvae/__pycache__/lookup_free_quantize.cpython-310.pyc +0 -0
  44. sim/__pycache__/__init__.cpython-310.pyc +0 -0
  45. sim/__pycache__/simulator.cpython-310.pyc +0 -0
__pycache__/cont_data.cpython-310.pyc ADDED
Binary file (7.68 kB). View file
 
__pycache__/data.cpython-310.pyc ADDED
Binary file (7.77 kB). View file
 
__pycache__/train_diffusion.cpython-310.pyc ADDED
Binary file (22.3 kB). View file
 
__pycache__/visualize.cpython-310.pyc ADDED
Binary file (9.31 kB). View file
 
app (Copy).py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
+ from sim.simulator import GenieSimulator
10
+ import os
11
+
12
+ if not os.path.exists("data/mar_ckpt/langtable"):
13
+ # download from google drive
14
+ import gdown
15
+ gdown.download_folder("https://drive.google.com/drive/u/2/folders/1XU87cRqV-IMZA6RLiabIR_uZngynvUFN")
16
+ os.system("mkdir -p data/mar_ckpt/; mv langtable data/mar_ckpt/")
17
+
18
+ RES = 512
19
+ PROMPT_HORIZON = 3
20
+ IMAGE_DIR = "sim/assets/langtable_prompt/"
21
+
22
+ # Load available images
23
+ available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
+
25
+
26
+ # Helper function to reset GenieSimulator with the selected image
27
+ @spaces.GPU
28
+ def initialize_simulator(image_name):
29
+ global genie
30
+ image_path = os.path.join(IMAGE_DIR, image_name)
31
+ image = Image.open(image_path)
32
+ prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
33
+ prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
34
+ genie.set_initial_state((prompt_image, prompt_action))
35
+ reset_image = genie.reset()
36
+ reset_image = cv2.resize(reset_image, (RES, RES))
37
+ return Image.fromarray(reset_image)
38
+
39
+ # Example model: takes a direction and returns a random image
40
+ @spaces.GPU
41
+ def model(direction: str):
42
+ global genie
43
+ if direction == 'right':
44
+ action = np.array([0, 0.05])
45
+ elif direction == 'left':
46
+ action = np.array([0, -0.05])
47
+ elif direction == 'down':
48
+ action = np.array([0.05, 0])
49
+ elif direction == 'up':
50
+ action = np.array([-0.05, 0])
51
+ else:
52
+ raise ValueError(f"Invalid direction: {direction}")
53
+ next_image = genie.step(action)['pred_next_frame']
54
+ next_image = cv2.resize(next_image, (RES, RES))
55
+ return Image.fromarray(next_image)
56
+
57
+ # Gradio function to handle user input
58
+ @spaces.GPU
59
+ def handle_input(direction):
60
+ print(f"User clicked: {direction}")
61
+ new_image = model(direction) # Get a new image from the model
62
+ return new_image
63
+
64
+ # Gradio function to handle image selection
65
+ @spaces.GPU
66
+ def handle_image_selection(image_name):
67
+ print(f"User selected image: {image_name}")
68
+ return initialize_simulator(image_name)
69
+
70
+ if __name__ == '__main__':
71
+ genie = GenieSimulator(
72
+ image_encoder_type='temporalvae',
73
+ image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
74
+ quantize=False,
75
+ backbone_type='stmar',
76
+ backbone_ckpt='data/mar_ckpt/langtable',
77
+ prompt_horizon=PROMPT_HORIZON,
78
+ action_stride=1,
79
+ domain='language_table',
80
+ )
81
+ image = Image.open("sim/assets/langtable_prompt/frame_06.png")
82
+ prompt_image = np.tile(
83
+ np.array(image), (genie.prompt_horizon, 1, 1, 1)
84
+ ).astype(np.uint8)
85
+ prompt_action = np.zeros(
86
+ (genie.prompt_horizon, genie.action_stride, 2)
87
+ ).astype(np.float32)
88
+
89
+ genie.set_initial_state((prompt_image, prompt_action))
90
+ image = genie.reset()
91
+
92
+ with gr.Blocks() as demo:
93
+ with gr.Row():
94
+ image_selector = gr.Dropdown(
95
+ choices=available_images, value=available_images[0], label="Select an Image"
96
+ )
97
+ select_button = gr.Button("Load Image")
98
+
99
+ with gr.Row():
100
+ image_display = gr.Image(type="pil", label="Generated Image")
101
+
102
+ with gr.Row():
103
+ up = gr.Button("↑ Up")
104
+ with gr.Row():
105
+ left = gr.Button("← Left")
106
+ down = gr.Button("↓ Down")
107
+ right = gr.Button("→ Right")
108
+
109
+
110
+
111
+ # Define interactions
112
+ select_button.click(
113
+ fn=handle_image_selection, inputs=image_selector, outputs=image_display
114
+ )
115
+ up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
116
+ down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
117
+ left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
118
+ right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
119
+
120
+ demo.launch()
app.py CHANGED
@@ -23,23 +23,18 @@ IMAGE_DIR = "sim/assets/langtable_prompt/"
23
  available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
 
25
 
26
- # Helper function to reset GenieSimulator with the selected image
27
- @spaces.GPU
28
- def initialize_simulator(image_name):
29
- global genie
30
  image_path = os.path.join(IMAGE_DIR, image_name)
31
  image = Image.open(image_path)
32
- prompt_image = np.tile(np.array(image), (genie.prompt_horizon, 1, 1, 1)).astype(np.uint8)
33
- prompt_action = np.zeros((genie.prompt_horizon - 1, genie.action_stride, 2)).astype(np.float32)
34
- genie.set_initial_state((prompt_image, prompt_action))
35
- reset_image = genie.reset()
36
  reset_image = cv2.resize(reset_image, (RES, RES))
37
  return Image.fromarray(reset_image)
38
 
39
- # Example model: takes a direction and returns a random image
40
- @spaces.GPU
41
- def model(direction: str):
42
- global genie
43
  if direction == 'right':
44
  action = np.array([0, 0.05])
45
  elif direction == 'left':
@@ -50,46 +45,34 @@ def model(direction: str):
50
  action = np.array([-0.05, 0])
51
  else:
52
  raise ValueError(f"Invalid direction: {direction}")
53
- next_image = genie.step(action)['pred_next_frame']
54
  next_image = cv2.resize(next_image, (RES, RES))
55
  return Image.fromarray(next_image)
56
 
57
- # Gradio function to handle user input
58
- @spaces.GPU
59
- def handle_input(direction):
60
  print(f"User clicked: {direction}")
61
- new_image = model(direction) # Get a new image from the model
62
  return new_image
63
 
64
- # Gradio function to handle image selection
65
- @spaces.GPU
66
- def handle_image_selection(image_name):
67
  print(f"User selected image: {image_name}")
68
- return initialize_simulator(image_name)
69
 
70
  if __name__ == '__main__':
71
- genie = GenieSimulator(
72
- image_encoder_type='temporalvae',
73
- image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
74
- quantize=False,
75
- backbone_type='stmar',
76
- backbone_ckpt='data/mar_ckpt/langtable',
77
- prompt_horizon=PROMPT_HORIZON,
78
- action_stride=1,
79
- domain='language_table',
80
- )
81
- image = Image.open("sim/assets/langtable_prompt/frame_06.png")
82
- prompt_image = np.tile(
83
- np.array(image), (genie.prompt_horizon, 1, 1, 1)
84
- ).astype(np.uint8)
85
- prompt_action = np.zeros(
86
- (genie.prompt_horizon, genie.action_stride, 2)
87
- ).astype(np.float32)
88
-
89
- genie.set_initial_state((prompt_image, prompt_action))
90
- image = genie.reset()
91
-
92
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  with gr.Row():
94
  image_selector = gr.Dropdown(
95
  choices=available_images, value=available_images[0], label="Select an Image"
@@ -106,15 +89,12 @@ if __name__ == '__main__':
106
  down = gr.Button("↓ Down")
107
  right = gr.Button("→ Right")
108
 
109
-
110
-
111
- # Define interactions
112
  select_button.click(
113
- fn=handle_image_selection, inputs=image_selector, outputs=image_display
114
  )
115
- up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
116
- down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
117
- left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
118
- right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')
119
 
120
- demo.launch()
 
23
  available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
 
25
 
26
+
27
+ def initialize_simulator(image_name, state):
 
 
28
  image_path = os.path.join(IMAGE_DIR, image_name)
29
  image = Image.open(image_path)
30
+ prompt_image = np.tile(np.array(image), (state['genie'].prompt_horizon, 1, 1, 1)).astype(np.uint8)
31
+ prompt_action = np.zeros((state['genie'].prompt_horizon - 1, state['genie'].action_stride, 2)).astype(np.float32)
32
+ state['genie'].set_initial_state((prompt_image, prompt_action))
33
+ reset_image = state['genie'].reset()
34
  reset_image = cv2.resize(reset_image, (RES, RES))
35
  return Image.fromarray(reset_image)
36
 
37
+ def model(direction, state):
 
 
 
38
  if direction == 'right':
39
  action = np.array([0, 0.05])
40
  elif direction == 'left':
 
45
  action = np.array([-0.05, 0])
46
  else:
47
  raise ValueError(f"Invalid direction: {direction}")
48
+ next_image = state['genie'].step(action)['pred_next_frame']
49
  next_image = cv2.resize(next_image, (RES, RES))
50
  return Image.fromarray(next_image)
51
 
52
+ def handle_input(direction, state):
 
 
53
  print(f"User clicked: {direction}")
54
+ new_image = model(direction, state)
55
  return new_image
56
 
57
+ def handle_image_selection(image_name, state):
 
 
58
  print(f"User selected image: {image_name}")
59
+ return initialize_simulator(image_name, state)
60
 
61
  if __name__ == '__main__':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  with gr.Blocks() as demo:
63
+ genie_instance = gr.State({
64
+ 'genie': GenieSimulator(
65
+ image_encoder_type='temporalvae',
66
+ image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
67
+ quantize=False,
68
+ backbone_type='stmar',
69
+ backbone_ckpt='data/mar_ckpt/langtable',
70
+ prompt_horizon=PROMPT_HORIZON,
71
+ action_stride=1,
72
+ domain='language_table',
73
+ )
74
+ })
75
+
76
  with gr.Row():
77
  image_selector = gr.Dropdown(
78
  choices=available_images, value=available_images[0], label="Select an Image"
 
89
  down = gr.Button("↓ Down")
90
  right = gr.Button("→ Right")
91
 
 
 
 
92
  select_button.click(
93
+ fn=handle_image_selection, inputs=[image_selector, genie_instance], outputs=image_display, show_progress='hidden'
94
  )
95
+ up.click(fn=lambda state: handle_input("up", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
96
+ down.click(fn=lambda state: handle_input("down", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
97
+ left.click(fn=lambda state: handle_input("left", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
98
+ right.click(fn=lambda state: handle_input("right", state), inputs=[genie_instance], outputs=image_display, show_progress='hidden')
99
 
100
+ demo.launch()
common/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (137 Bytes). View file
 
common/__pycache__/eval_utils.cpython-310.pyc ADDED
Binary file (4.6 kB). View file
 
data/mar_ckpt/langtable/random_states_0.pkl CHANGED
Binary files a/data/mar_ckpt/langtable/random_states_0.pkl and b/data/mar_ckpt/langtable/random_states_0.pkl differ
 
datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (139 Bytes). View file
 
datasets/__pycache__/encode_openx_dataset.cpython-310.pyc ADDED
Binary file (13.1 kB). View file
 
datasets/__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.36 kB). View file
 
genie/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (136 Bytes). View file
 
genie/__pycache__/attention.cpython-310.pyc ADDED
Binary file (4.4 kB). View file
 
genie/__pycache__/config.cpython-310.pyc ADDED
Binary file (4.37 kB). View file
 
genie/__pycache__/diffloss.cpython-310.pyc ADDED
Binary file (8.17 kB). View file
 
genie/__pycache__/factorization_utils.cpython-310.pyc ADDED
Binary file (4.05 kB). View file
 
genie/__pycache__/st_mar.cpython-310.pyc ADDED
Binary file (13.8 kB). View file
 
genie/__pycache__/st_mask_git.cpython-310.pyc ADDED
Binary file (20.6 kB). View file
 
genie/__pycache__/st_transformer.cpython-310.pyc ADDED
Binary file (5.18 kB). View file
 
genie/diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
genie/diffusion/__pycache__/diffusion_utils.cpython-310.pyc ADDED
Binary file (2.26 kB). View file
 
genie/diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc ADDED
Binary file (24.3 kB). View file
 
genie/diffusion/__pycache__/respace.cpython-310.pyc ADDED
Binary file (4.97 kB). View file
 
magvit2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (138 Bytes). View file
 
magvit2/__pycache__/config.cpython-310.pyc ADDED
Binary file (2.07 kB). View file
 
magvit2/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.61 kB). View file
 
magvit2/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes). View file
 
magvit2/models/__pycache__/lfqgan.cpython-310.pyc ADDED
Binary file (8.87 kB). View file
 
magvit2/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes). View file
 
magvit2/modules/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.3 kB). View file
 
magvit2/modules/__pycache__/util.cpython-310.pyc ADDED
Binary file (4.45 kB). View file
 
magvit2/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes). View file
 
magvit2/modules/diffusionmodules/__pycache__/improved_model.cpython-310.pyc ADDED
Binary file (5.57 kB). View file
 
magvit2/modules/discriminator/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes). View file
 
magvit2/modules/discriminator/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.34 kB). View file
 
magvit2/modules/losses/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (222 Bytes). View file
 
magvit2/modules/losses/__pycache__/lpips.cpython-310.pyc ADDED
Binary file (5.37 kB). View file
 
magvit2/modules/losses/__pycache__/vqperceptual.cpython-310.pyc ADDED
Binary file (7.42 kB). View file
 
magvit2/modules/scheduler/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (156 Bytes). View file
 
magvit2/modules/scheduler/__pycache__/lr_scheduler.cpython-310.pyc ADDED
Binary file (989 Bytes). View file
 
magvit2/modules/vqvae/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
magvit2/modules/vqvae/__pycache__/lookup_free_quantize.cpython-310.pyc ADDED
Binary file (8.29 kB). View file
 
sim/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (134 Bytes). View file
 
sim/__pycache__/simulator.cpython-310.pyc ADDED
Binary file (13.4 kB). View file