Spaces: Running on Zero

Commit: "improve pred"
Browse files — files changed:
- app.py (+1, -1)
- sim/simulator.py (+9, -5)
app.py
CHANGED
@@ -14,7 +14,7 @@ genie = GenieSimulator(
 14     quantize=False,
 15     backbone_type='stmar',
 16     backbone_ckpt='data/mar_ckpt/langtable',
-17     prompt_horizon=…,            [old value truncated in page extraction — not recoverable here]
+17     prompt_horizon=2,
 18     action_stride=1,
 19     domain='language_table',
 20 )
sim/simulator.py
CHANGED
@@ -248,22 +248,26 @@ class GenieSimulator(LearnedSimulator):
 248         # encoding
 249         input_latent_states = torch.cat([
 250             self.cached_latent_frames,
-251             torch.zeros_like(self.cached_latent_frames[…])   [old index truncated in page extraction]
+251             torch.zeros_like(self.cached_latent_frames[[0]]),
 252         ]).unsqueeze(0).to(torch.float32)
 253
+254         input_latent_states = input_latent_states[:, :self.prompt_horizon + 1]
+255
 256         # dtype conversion and mask token
 257         if self.backbone_type == "stmaskgit":
 258             input_latent_states = input_latent_states.long()
-                input_latent_states[:, …                        [old line truncated in page extraction]
+259             input_latent_states[:, -1] = self.backbone.mask_token_id
 260         elif self.backbone_type == "stmar":
-                input_latent_states[:, …                        [old line truncated in page extraction]
+261             input_latent_states[:, -1] = self.backbone.mask_token
 262
 263         # dynamics rollout
 264         action = torch.from_numpy(action).to(device=self.device)
 265         input_actions = torch.cat([  # (1, prompt_horizon + 1, action_stride * A)
 266             self.cached_actions,
-                action.unsqueeze(0)
-                …                                               [old closing line(s) truncated in page extraction]
+267             action.unsqueeze(0),
+268             action.unsqueeze(0)  # the last action is not used, but we need a_{t-1}, s_{t-1} to predict s_t
+269         ]).view(1, -1, action.shape[-1]).to(torch.float32)  # + 1
+270         input_actions = input_actions[:, :self.prompt_horizon + 1]
 271
 272         if self.measure_step_time:
 273             start_time = time.time()