liruiw commited on
Commit
a5cbf39
·
1 Parent(s): 557e9af
Files changed (2) hide show
  1. app.py +8 -3
  2. sim/simulator.py +2 -0
app.py CHANGED
@@ -46,7 +46,7 @@ def initialize_simulator(image_name):
46
  return Image.fromarray(reset_image)
47
 
48
  @spaces.GPU
49
- def model(direction, genie):
50
  if direction == 'right':
51
  action = np.array([0, 0.05])
52
  elif direction == 'left':
@@ -57,14 +57,19 @@ def model(direction, genie):
57
  action = np.array([-0.05, 0])
58
  else:
59
  raise ValueError(f"Invalid direction: {direction}")
60
- next_image = genie.step(action, cached_latent_frames=cached_latent_frames, cached_actions=cached_actions)['pred_next_frame']
 
 
 
 
 
61
  next_image = cv2.resize(next_image, (RES, RES))
62
  return Image.fromarray(next_image)
63
 
64
  @spaces.GPU
65
  def handle_input(direction):
66
  print(f"User clicked: {direction}")
67
- new_image = genie(direction)
68
  return new_image
69
 
70
  @spaces.GPU
 
46
  return Image.fromarray(reset_image)
47
 
48
  @spaces.GPU
49
+ def model(direction):
50
  if direction == 'right':
51
  action = np.array([0, 0.05])
52
  elif direction == 'left':
 
57
  action = np.array([-0.05, 0])
58
  else:
59
  raise ValueError(f"Invalid direction: {direction}")
60
+ genie_result = genie.step(action, cached_latent_frames=cached_latent_frames, cached_actions=cached_actions)
61
+ next_image = genie_result['pred_next_frame']
62
+ global cached_latent_frames
63
+ global cached_actions
64
+ cached_latent_frames = genie_result['set_cached_latent_frames']
65
+ cached_actions = genie_result['set_cached_actions']
66
  next_image = cv2.resize(next_image, (RES, RES))
67
  return Image.fromarray(next_image)
68
 
69
  @spaces.GPU
70
  def handle_input(direction):
71
  print(f"User clicked: {direction}")
72
+ new_image = model(direction)
73
  return new_image
74
 
75
  @spaces.GPU
sim/simulator.py CHANGED
@@ -366,6 +366,8 @@ class GenieSimulator(LearnedSimulator):
366
  pred_next_frame = self.post_processor(pred_next_frame, action)
367
 
368
  self.step_count += 1
 
 
369
 
370
  return step_result
371
 
 
366
  pred_next_frame = self.post_processor(pred_next_frame, action)
367
 
368
  self.step_count += 1
369
+ step_result['cached_actions'] = cached_actions
370
+ step_result['cached_latent_frames'] = cached_latent_frames
371
 
372
  return step_result
373