jbilcke-hf HF Staff commited on
Commit
257dc46
·
1 Parent(s): 6373d0a

experimenting stuff

Browse files
Files changed (1) hide show
  1. app.py +52 -1
app.py CHANGED
@@ -57,6 +57,56 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
57
  DEFAULT_WIDTH = 832
58
  DEFAULT_HEIGHT = 480
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # --- Argument Parsing ---
61
  parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
62
  parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
@@ -290,7 +340,8 @@ def video_generation_handler(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, heigh
290
 
291
  vae_cache, latents_cache = None, None
292
  if not APP_STATE["current_use_taehv"] and not args.trt:
293
- vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
 
294
 
295
  # Calculate number of blocks based on duration
296
  # Current setup generates approximately 5 seconds with 7 blocks
 
57
  DEFAULT_WIDTH = 832
58
  DEFAULT_HEIGHT = 480
59
 
60
def create_vae_cache_for_resolution(latent_height, latent_width, device, dtype):
    """
    Create VAE cache tensors dynamically based on the latent resolution.

    The cache structure mirrors ZERO_VAE_CACHE but with resolution-dependent
    spatial dimensions. The original hard-coded cache assumes an 832x480 video,
    i.e. 104x60 latent dimensions (8x spatial downsampling); this builds the
    same 32-tensor layout scaled to any latent size.

    Args:
        latent_height: Height of the latent feature map (video height // 8).
        latent_width:  Width of the latent feature map (video width // 8).
        device: Torch device for the cache tensors (e.g. a CUDA device).
        dtype:  Torch dtype for the cache tensors (e.g. torch.float16).

    Returns:
        A list of 32 zero-initialized tensors of shape
        (1, channels, 2, latent_height * scale, latent_width * scale),
        ordered to match the decoder's cache consumption order.
    """
    def _zeros(channels, scale):
        # One cache entry: batch=1, temporal depth=2, spatial dims scaled
        # from the latent resolution by the decoder stage's upsample factor.
        return torch.zeros(
            1, channels, 2,
            latent_height * scale, latent_width * scale,
            device=device, dtype=dtype,
        )

    # (channels, upsample scale, repeat count) per decoder stage, in order:
    #   - 1x 16-ch + 11x 384-ch at latent resolution
    #   - 1x 192-ch + 6x 384-ch at 2x upsampled resolution
    #   - 6x 192-ch at 4x upsampled resolution
    #   - 7x 96-ch at 8x (final output) resolution
    # NOTE(review): counts/channels transcribed from ZERO_VAE_CACHE's layout —
    # confirm against the VAE definition if that structure ever changes.
    spec = [
        (16, 1, 1),
        (384, 1, 11),
        (192, 2, 1),
        (384, 2, 6),
        (192, 4, 6),
        (96, 8, 7),
    ]

    cache = []
    for channels, scale, count in spec:
        cache.extend(_zeros(channels, scale) for _ in range(count))
    return cache
110
  # --- Argument Parsing ---
111
  parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
112
  parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
 
340
 
341
  vae_cache, latents_cache = None, None
342
  if not APP_STATE["current_use_taehv"] and not args.trt:
343
+ # Create resolution-dependent VAE cache
344
+ vae_cache = create_vae_cache_for_resolution(latent_height, latent_width, device=gpu, dtype=torch.float16)
345
 
346
  # Calculate number of blocks based on duration
347
  # Current setup generates approximately 5 seconds with 7 blocks