Commit 257dc46
Parent(s): 6373d0a
experimenting stuff

app.py CHANGED
@@ -57,6 +57,56 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 DEFAULT_WIDTH = 832
 DEFAULT_HEIGHT = 480

+def create_vae_cache_for_resolution(latent_height, latent_width, device, dtype):
+    """
+    Create VAE cache tensors dynamically based on the latent resolution.
+    The cache structure mirrors ZERO_VAE_CACHE but with resolution-dependent dimensions.
+    """
+    # Scale dimensions based on latent resolution
+    # The original cache assumes 832x480 -> 104x60 latent dimensions
+    # We need to scale proportionally for other resolutions
+
+    cache = [
+        torch.zeros(1, 16, 2, latent_height, latent_width, device=device, dtype=dtype),
+        # First set of 384-channel caches at latent resolution
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
+        # Second set at 2x upsampled resolution
+        torch.zeros(1, 192, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
+        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
+        # Third set at 4x upsampled resolution
+        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
+        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
+        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
+        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
+        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
+        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
+        # Fourth set at 8x upsampled resolution (final output resolution)
+        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
+        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
+        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
+        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
+        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
+        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
+        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype)
+    ]
+
+    return cache
+
 # --- Argument Parsing ---
 parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
 parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
@@ -290,7 +340,8 @@ def video_generation_handler(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, heigh

     vae_cache, latents_cache = None, None
     if not APP_STATE["current_use_taehv"] and not args.trt:
-
+        # Create resolution-dependent VAE cache
+        vae_cache = create_vae_cache_for_resolution(latent_height, latent_width, device=gpu, dtype=torch.float16)

     # Calculate number of blocks based on duration
     # Current setup generates approximately 5 seconds with 7 blocks