liruiw committed on
Commit
8eeb719
1 Parent(s): 13dce27

improve pred

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. sim/simulator.py +9 -5
app.py CHANGED
@@ -14,7 +14,7 @@ genie = GenieSimulator(
14
  quantize=False,
15
  backbone_type='stmar',
16
  backbone_ckpt='data/mar_ckpt/langtable',
17
- prompt_horizon=5,
18
  action_stride=1,
19
  domain='language_table',
20
  )
 
14
  quantize=False,
15
  backbone_type='stmar',
16
  backbone_ckpt='data/mar_ckpt/langtable',
17
+ prompt_horizon=2,
18
  action_stride=1,
19
  domain='language_table',
20
  )
sim/simulator.py CHANGED
@@ -248,22 +248,26 @@ class GenieSimulator(LearnedSimulator):
248
  # encoding
249
  input_latent_states = torch.cat([
250
  self.cached_latent_frames,
251
- torch.zeros_like(self.cached_latent_frames[-1:]),
252
  ]).unsqueeze(0).to(torch.float32)
253
 
 
 
254
  # dtype conversion and mask token
255
  if self.backbone_type == "stmaskgit":
256
  input_latent_states = input_latent_states.long()
257
- input_latent_states[:, self.prompt_horizon] = self.backbone.mask_token_id
258
  elif self.backbone_type == "stmar":
259
- input_latent_states[:, self.prompt_horizon] = self.backbone.mask_token
260
 
261
  # dynamics rollout
262
  action = torch.from_numpy(action).to(device=self.device)
263
  input_actions = torch.cat([ # (1, prompt_horizon + 1, action_stride * A)
264
  self.cached_actions,
265
- action.unsqueeze(0)
266
- ]).view(1, self.prompt_horizon + 1, -1).to(torch.float32)
 
 
267
 
268
  if self.measure_step_time:
269
  start_time = time.time()
 
248
  # encoding
249
  input_latent_states = torch.cat([
250
  self.cached_latent_frames,
251
+ torch.zeros_like(self.cached_latent_frames[[0]]),
252
  ]).unsqueeze(0).to(torch.float32)
253
 
254
+ input_latent_states = input_latent_states[:, :self.prompt_horizon + 1]
255
+
256
  # dtype conversion and mask token
257
  if self.backbone_type == "stmaskgit":
258
  input_latent_states = input_latent_states.long()
259
+ input_latent_states[:, -1] = self.backbone.mask_token_id
260
  elif self.backbone_type == "stmar":
261
+ input_latent_states[:, -1] = self.backbone.mask_token
262
 
263
  # dynamics rollout
264
  action = torch.from_numpy(action).to(device=self.device)
265
  input_actions = torch.cat([ # (1, prompt_horizon + 1, action_stride * A)
266
  self.cached_actions,
267
+ action.unsqueeze(0),
268
+ action.unsqueeze(0) # the last action is not used, but we need a_{t-1}, s_{t-1} to predict s_t
269
+ ]).view(1, -1, action.shape[-1]).to(torch.float32) # + 1
270
+ input_actions = input_actions[:, :self.prompt_horizon + 1]
271
 
272
  if self.measure_step_time:
273
  start_time = time.time()