liruiw commited on
Commit
e176061
·
1 Parent(s): eb5272e
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -26,6 +26,7 @@ available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(
26
 
27
 
28
  # Helper function to reset GenieSimulator with the selected image
 
29
  def initialize_simulator(image_name):
30
  image_path = os.path.join(IMAGE_DIR, image_name)
31
  image = Image.open(image_path)
@@ -37,6 +38,7 @@ def initialize_simulator(image_name):
37
  return Image.fromarray(reset_image)
38
 
39
  # Example model: takes a direction and returns a random image
 
40
  def model(direction: str):
41
  if direction == 'right':
42
  action = np.array([0, 0.05])
@@ -60,13 +62,13 @@ def handle_input(direction):
60
  return new_image
61
 
62
  # Gradio function to handle image selection
 
63
  def handle_image_selection(image_name):
64
  print(f"User selected image: {image_name}")
65
  return initialize_simulator(image_name)
66
 
67
  if __name__ == '__main__':
68
- with gr.Blocks() as demo:
69
- genie = GenieSimulator(
70
  image_encoder_type='temporalvae',
71
  image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
72
  quantize=False,
@@ -77,6 +79,7 @@ if __name__ == '__main__':
77
  domain='language_table',
78
  )
79
 
 
80
  with gr.Row():
81
  image_selector = gr.Dropdown(
82
  choices=available_images, value=available_images[0], label="Select an Image"
 
26
 
27
 
28
  # Helper function to reset GenieSimulator with the selected image
29
+ @spaces.GPU
30
  def initialize_simulator(image_name):
31
  image_path = os.path.join(IMAGE_DIR, image_name)
32
  image = Image.open(image_path)
 
38
  return Image.fromarray(reset_image)
39
 
40
  # Example model: takes a direction and returns a random image
41
+ @spaces.GPU
42
  def model(direction: str):
43
  if direction == 'right':
44
  action = np.array([0, 0.05])
 
62
  return new_image
63
 
64
  # Gradio function to handle image selection
65
+ @spaces.GPU
66
  def handle_image_selection(image_name):
67
  print(f"User selected image: {image_name}")
68
  return initialize_simulator(image_name)
69
 
70
  if __name__ == '__main__':
71
+ genie = GenieSimulator(
 
72
  image_encoder_type='temporalvae',
73
  image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
74
  quantize=False,
 
79
  domain='language_table',
80
  )
81
 
82
+ with gr.Blocks() as demo:
83
  with gr.Row():
84
  image_selector = gr.Dropdown(
85
  choices=available_images, value=available_images[0], label="Select an Image"