Spaces:
Running
on
Zero
Running
on
Zero
update
Browse files- app.py +8 -3
- sim/simulator.py +2 -0
app.py
CHANGED
@@ -46,7 +46,7 @@ def initialize_simulator(image_name):
|
|
46 |
return Image.fromarray(reset_image)
|
47 |
|
48 |
@spaces.GPU
|
49 |
-
def model(direction
|
50 |
if direction == 'right':
|
51 |
action = np.array([0, 0.05])
|
52 |
elif direction == 'left':
|
@@ -57,14 +57,19 @@ def model(direction, genie):
|
|
57 |
action = np.array([-0.05, 0])
|
58 |
else:
|
59 |
raise ValueError(f"Invalid direction: {direction}")
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
next_image = cv2.resize(next_image, (RES, RES))
|
62 |
return Image.fromarray(next_image)
|
63 |
|
64 |
@spaces.GPU
|
65 |
def handle_input(direction):
|
66 |
print(f"User clicked: {direction}")
|
67 |
-
new_image =
|
68 |
return new_image
|
69 |
|
70 |
@spaces.GPU
|
|
|
46 |
return Image.fromarray(reset_image)
|
47 |
|
48 |
@spaces.GPU
|
49 |
+
def model(direction):
|
50 |
if direction == 'right':
|
51 |
action = np.array([0, 0.05])
|
52 |
elif direction == 'left':
|
|
|
57 |
action = np.array([-0.05, 0])
|
58 |
else:
|
59 |
raise ValueError(f"Invalid direction: {direction}")
|
60 |
+
genie_result = genie.step(action, cached_latent_frames=cached_latent_frames, cached_actions=cached_actions)
|
61 |
+
next_image = genie_result['pred_next_frame']
|
62 |
+
global cached_latent_frames
|
63 |
+
global cached_actions
|
64 |
+
cached_latent_frames = genie_result['set_cached_latent_frames']
|
65 |
+
cached_actions = genie_result['set_cached_actions']
|
66 |
next_image = cv2.resize(next_image, (RES, RES))
|
67 |
return Image.fromarray(next_image)
|
68 |
|
69 |
@spaces.GPU
|
70 |
def handle_input(direction):
|
71 |
print(f"User clicked: {direction}")
|
72 |
+
new_image = model(direction)
|
73 |
return new_image
|
74 |
|
75 |
@spaces.GPU
|
sim/simulator.py
CHANGED
@@ -366,6 +366,8 @@ class GenieSimulator(LearnedSimulator):
|
|
366 |
pred_next_frame = self.post_processor(pred_next_frame, action)
|
367 |
|
368 |
self.step_count += 1
|
|
|
|
|
369 |
|
370 |
return step_result
|
371 |
|
|
|
366 |
pred_next_frame = self.post_processor(pred_next_frame, action)
|
367 |
|
368 |
self.step_count += 1
|
369 |
+
step_result['cached_actions'] = cached_actions
|
370 |
+
step_result['cached_latent_frames'] = cached_latent_frames
|
371 |
|
372 |
return step_result
|
373 |
|