liruiw commited on
Commit
b258516
·
1 Parent(s): aca205f
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -23,7 +23,7 @@ IMAGE_DIR = "sim/assets/langtable_prompt/"
23
  available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
 
25
 
26
-
27
  def initialize_simulator(image_name, state):
28
  image_path = os.path.join(IMAGE_DIR, image_name)
29
  image = Image.open(image_path)
@@ -34,6 +34,7 @@ def initialize_simulator(image_name, state):
34
  reset_image = cv2.resize(reset_image, (RES, RES))
35
  return Image.fromarray(reset_image)
36
 
 
37
  def model(direction, state):
38
  if direction == 'right':
39
  action = np.array([0, 0.05])
@@ -49,11 +50,13 @@ def model(direction, state):
49
  next_image = cv2.resize(next_image, (RES, RES))
50
  return Image.fromarray(next_image)
51
 
 
52
  def handle_input(direction, state):
53
  print(f"User clicked: {direction}")
54
  new_image = model(direction, state)
55
  return new_image
56
 
 
57
  def handle_image_selection(image_name, state):
58
  print(f"User selected image: {image_name}")
59
  return initialize_simulator(image_name, state)
 
23
  available_images = sorted([img for img in os.listdir(IMAGE_DIR) if img.endswith(".png")])
24
 
25
 
26
+ @spaces.GPU
27
  def initialize_simulator(image_name, state):
28
  image_path = os.path.join(IMAGE_DIR, image_name)
29
  image = Image.open(image_path)
 
34
  reset_image = cv2.resize(reset_image, (RES, RES))
35
  return Image.fromarray(reset_image)
36
 
37
+ @spaces.GPU
38
  def model(direction, state):
39
  if direction == 'right':
40
  action = np.array([0, 0.05])
 
50
  next_image = cv2.resize(next_image, (RES, RES))
51
  return Image.fromarray(next_image)
52
 
53
+ @spaces.GPU
54
  def handle_input(direction, state):
55
  print(f"User clicked: {direction}")
56
  new_image = model(direction, state)
57
  return new_image
58
 
59
+ @spaces.GPU
60
  def handle_image_selection(image_name, state):
61
  print(f"User selected image: {image_name}")
62
  return initialize_simulator(image_name, state)