liruiw committed on
Commit
3e59747
·
1 Parent(s): fc52fb9
app.py CHANGED
@@ -61,9 +61,9 @@ def handle_image_selection(image_name, state):
61
  print(f"User selected image: {image_name}")
62
  return initialize_simulator(image_name, state)
63
 
64
- if __name__ == '__main__':
65
- with gr.Blocks() as demo:
66
- genie_instance = gr.State({
67
  'genie': GenieSimulator(
68
  image_encoder_type='temporalvae',
69
  image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
@@ -76,6 +76,10 @@ if __name__ == '__main__':
76
  )
77
  })
78
 
 
 
 
 
79
  with gr.Row():
80
  image_selector = gr.Dropdown(
81
  choices=available_images, value=available_images[0], label="Select an Image"
 
61
  print(f"User selected image: {image_name}")
62
  return initialize_simulator(image_name, state)
63
 
64
+ @spaces.GPU
65
+ def init_model():
66
+ return gr.State({
67
  'genie': GenieSimulator(
68
  image_encoder_type='temporalvae',
69
  image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
 
76
  )
77
  })
78
 
79
+ if __name__ == '__main__':
80
+ with gr.Blocks() as demo:
81
+ genie_instance = init_model()
82
+
83
  with gr.Row():
84
  image_selector = gr.Dropdown(
85
  choices=available_images, value=available_images[0], label="Select an Image"
sim/__pycache__/simulator.cpython-310.pyc CHANGED
Binary files a/sim/__pycache__/simulator.cpython-310.pyc and b/sim/__pycache__/simulator.cpython-310.pyc differ
 
sim/simulator.py CHANGED
@@ -137,6 +137,7 @@ class GenieSimulator(LearnedSimulator):
137
  allow_external_prompt: bool = False
138
  ):
139
  super().__init__()
 
140
 
141
  assert quantize == (image_encoder_type == "magvit"), \
142
  "Currently quantization if and only if magvit is the image encoder."
@@ -378,10 +379,11 @@ class GenieSimulator(LearnedSimulator):
378
  W //= self.quant_slice_size
379
  _, _, indices, _ = self.image_encoder.encode(image, flip=True)
380
  indices = einops.rearrange(indices, "(h w) -> h w", h=H, w=W)
381
- indices = indices.to(torch.int32)
382
  return indices
383
 
384
  else:
 
385
  if self.image_encoder_type == "magvit":
386
  latent = self.image_encoder.encode_without_quantize(image)
387
  elif self.image_encoder_type == "temporalvae":
@@ -391,7 +393,7 @@ class GenieSimulator(LearnedSimulator):
391
  latent = einops.rearrange(latent, "b c h w -> b h w c")
392
  else:
393
  pass
394
- latent = latent.squeeze(0).to(torch.float32)
395
  return latent
396
 
397
 
 
137
  allow_external_prompt: bool = False
138
  ):
139
  super().__init__()
140
+ device = "cuda"
141
 
142
  assert quantize == (image_encoder_type == "magvit"), \
143
  "Currently quantization if and only if magvit is the image encoder."
 
379
  W //= self.quant_slice_size
380
  _, _, indices, _ = self.image_encoder.encode(image, flip=True)
381
  indices = einops.rearrange(indices, "(h w) -> h w", h=H, w=W)
382
+ indices = indices.to(torch.int32).to(self.device)
383
  return indices
384
 
385
  else:
386
+ self.image_encoder = self.image_encoder.to(device=self.device)
387
  if self.image_encoder_type == "magvit":
388
  latent = self.image_encoder.encode_without_quantize(image)
389
  elif self.image_encoder_type == "temporalvae":
 
393
  latent = einops.rearrange(latent, "b c h w -> b h w c")
394
  else:
395
  pass
396
+ latent = latent.squeeze(0).to(torch.float32).to(self.device)
397
  return latent
398
 
399