liruiw commited on
Commit
8c3783f
·
1 Parent(s): a853f7d
Files changed (2) hide show
  1. app.py +26 -22
  2. sim/simulator.py +1 -0
app.py CHANGED
@@ -65,31 +65,35 @@ def handle_image_selection(image_name, state):
65
  print(f"User selected image: {image_name}")
66
  return initialize_simulator(image_name, state)
67
 
68
- if __name__ == '__main__':
69
- genie = GenieSimulator(
70
- image_encoder_type='temporalvae',
71
- image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
72
- quantize=False,
73
- backbone_type='stmar',
74
- backbone_ckpt='data/mar_ckpt/langtable',
75
- prompt_horizon=PROMPT_HORIZON,
76
- action_stride=1,
77
- domain='language_table',
78
- device="cpu"
79
- )
 
 
 
 
 
 
 
 
 
 
 
 
80
 
 
 
81
 
82
  with gr.Blocks() as demo:
83
- image = Image.open("sim/assets/langtable_prompt/frame_06.png")
84
- prompt_image = np.tile(
85
- np.array(image), (genie.prompt_horizon, 1, 1, 1)
86
- ).astype(np.uint8)
87
- prompt_action = np.zeros(
88
- (genie.prompt_horizon, genie.action_stride, 2)
89
- ).astype(np.float32)
90
- genie.set_initial_state((prompt_image, prompt_action))
91
- genie.device = "cuda"
92
- image = genie.reset()
93
  genie_instance = gr.State({'genie': genie})
94
 
95
  with gr.Row():
 
65
  print(f"User selected image: {image_name}")
66
  return initialize_simulator(image_name, state)
67
 
68
+ @spaces.GPU
69
+ def init_model():
70
+ genie = GenieSimulator(
71
+ image_encoder_type='temporalvae',
72
+ image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
73
+ quantize=False,
74
+ backbone_type='stmar',
75
+ backbone_ckpt='data/mar_ckpt/langtable',
76
+ prompt_horizon=PROMPT_HORIZON,
77
+ action_stride=1,
78
+ domain='language_table',
79
+ )
80
+
81
+ image = Image.open("sim/assets/langtable_prompt/frame_06.png")
82
+ prompt_image = np.tile(
83
+ np.array(image), (genie.prompt_horizon, 1, 1, 1)
84
+ ).astype(np.uint8)
85
+ prompt_action = np.zeros(
86
+ (genie.prompt_horizon, genie.action_stride, 2)
87
+ ).astype(np.float32)
88
+ genie.set_initial_state((prompt_image, prompt_action))
89
+ genie.device = "cuda"
90
+ image = genie.reset()
91
+ return genie
92
 
93
+ if __name__ == '__main__':
94
+ genie = init_model()
95
 
96
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
97
  genie_instance = gr.State({'genie': genie})
98
 
99
  with gr.Row():
sim/simulator.py CHANGED
@@ -245,6 +245,7 @@ class GenieSimulator(LearnedSimulator):
245
  if self.gauss_act_perturb_scale is not None:
246
  action = np.random.normal(action, self.gauss_act_perturb_scale)
247
 
 
248
  # encoding
249
  input_latent_states = torch.cat([
250
  self.cached_latent_frames,
 
245
  if self.gauss_act_perturb_scale is not None:
246
  action = np.random.normal(action, self.gauss_act_perturb_scale)
247
 
248
+ self.backbone = backbone.to(device=self.device).eval()
249
  # encoding
250
  input_latent_states = torch.cat([
251
  self.cached_latent_frames,