seawolf2357 committed on
Commit
c2c95c1
·
verified ·
1 Parent(s): a34249d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -110,7 +110,7 @@ parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_f
110
  parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
111
  parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
112
  parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
113
- parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
114
  args = parser.parse_args()
115
 
116
  gpu = "cuda"
@@ -257,7 +257,7 @@ pipeline.to(dtype=torch.float16).to(gpu)
257
 
258
  @torch.no_grad()
259
  @spaces.GPU
260
- def video_generation_handler_streaming(prompt, seed=42, fps=15):
261
  """
262
  Generator function that yields .ts video chunks using PyAV for streaming.
263
  Now optimized for block-based processing.
@@ -277,14 +277,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
277
  pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
278
  pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
279
 
280
- # 5.5초 영상을 위해 노이즈 텐서 크기 증가 (21 -> 24)
281
- noise = torch.randn([1, 24, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
282
 
283
  vae_cache, latents_cache = None, None
284
  if not APP_STATE["current_use_taehv"] and not args.trt:
285
  vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
286
 
287
- num_blocks = 8 # 7 -> 8로 증가하여 약 5.5초 영상 생성
288
  current_start_frame = 0
289
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
290
 
@@ -369,7 +369,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
369
 
370
  frame_status_html = (
371
  f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
372
- f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
373
  f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
374
  f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
375
  f" </div>"
@@ -407,7 +407,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
407
  current_start_frame += current_num_frames
408
 
409
  # 메모리 효율성을 위한 GPU 캐시 정리
410
- if idx < num_blocks - 1 and idx % 2 == 1: # 2블록마다 캐시 정리
411
  torch.cuda.empty_cache()
412
 
413
  # Final completion status
@@ -456,7 +456,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
456
  f" ๐Ÿ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks ({video_duration:.1f} seconds)"
457
  f" </p>"
458
  f" <p style='margin: 0; color: #0f5132; font-size: 14px;'>"
459
- f" ๐ŸŽฌ Resolution: {all_frames_for_download[0].shape[1]}x{all_frames_for_download[0].shape[0]} โ€ข FPS: {fps} โ€ข Size: {file_size_mb:.1f} MB"
460
  f" </p>"
461
  f" <p style='margin: 8px 0 0 0; color: #0f5132; font-size: 13px; font-style: italic;'>"
462
  f" ๐Ÿ’พ Click the download button below to save your video!"
@@ -479,8 +479,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
479
 
480
  # --- Gradio UI Layout ---
481
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
482
- gr.Markdown("# ๐Ÿš€ Self-Forcing Video Generation (6-second)")
483
- gr.Markdown("Real-time 6-second video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
484
 
485
  with gr.Row():
486
  with gr.Column(scale=2):
@@ -506,6 +506,7 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
506
  )
507
 
508
  gr.Markdown("### โš™๏ธ Settings")
 
509
  with gr.Row():
510
  seed = gr.Number(
511
  label="Seed",
@@ -515,12 +516,12 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
515
  )
516
  fps = gr.Slider(
517
  label="Playback FPS",
518
- minimum=1,
519
  maximum=30,
520
  value=args.fps,
521
  step=1,
522
- visible=False,
523
- info="Frames per second for playback"
524
  )
525
 
526
  with gr.Column(scale=3):
@@ -548,8 +549,9 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
548
 
549
  # 다운로드용 파일 출력
550
  download_file = gr.File(
551
- label="๐Ÿ“ฅ Download Video",
552
- visible=False
 
553
  )
554
 
555
  # Connect the generator to the streaming video
 
110
  parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
111
  parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
112
  parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
113
+ parser.add_argument('--fps', type=float, default=12.0, help="Playback FPS for frame streaming.")
114
  args = parser.parse_args()
115
 
116
  gpu = "cuda"
 
257
 
258
  @torch.no_grad()
259
  @spaces.GPU
260
+ def video_generation_handler_streaming(prompt, seed=42, fps=12):
261
  """
262
  Generator function that yields .ts video chunks using PyAV for streaming.
263
  Now optimized for block-based processing.
 
277
  pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
278
  pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
279
 
280
+ # 노이즈 텐서 크기
281
+ noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
282
 
283
  vae_cache, latents_cache = None, None
284
  if not APP_STATE["current_use_taehv"] and not args.trt:
285
  vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
286
 
287
+ num_blocks = 7 # 원래 설정으로 복원
288
  current_start_frame = 0
289
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
290
 
 
369
 
370
  frame_status_html = (
371
  f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
372
+ f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>๐ŸŽฌ Generating Video...</p>"
373
  f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
374
  f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
375
  f" </div>"
 
407
  current_start_frame += current_num_frames
408
 
409
  # 메모리 효율성을 위한 GPU 캐시 정리
410
+ if idx < num_blocks - 1 and idx % 3 == 2: # 3블록마다 캐시 정리
411
  torch.cuda.empty_cache()
412
 
413
  # Final completion status
 
456
  f" ๐Ÿ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks ({video_duration:.1f} seconds)"
457
  f" </p>"
458
  f" <p style='margin: 0; color: #0f5132; font-size: 14px;'>"
459
+ f" ๐ŸŽฌ Resolution: {all_frames_for_download[0].shape[1]}x{all_frames_for_download[0].shape[0]} โ€ข FPS: {fps} โ€ข Duration: {video_duration:.1f}s โ€ข Size: {file_size_mb:.1f} MB"
460
  f" </p>"
461
  f" <p style='margin: 8px 0 0 0; color: #0f5132; font-size: 13px; font-style: italic;'>"
462
  f" ๐Ÿ’พ Click the download button below to save your video!"
 
479
 
480
  # --- Gradio UI Layout ---
481
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
482
+ gr.Markdown("# ๐Ÿš€ Self-Forcing Video Generation")
483
+ gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B | 5-6 seconds duration [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
484
 
485
  with gr.Row():
486
  with gr.Column(scale=2):
 
506
  )
507
 
508
  gr.Markdown("### โš™๏ธ Settings")
509
+ gr.Markdown("๐Ÿ’ก **Tip**: Adjust FPS to control video duration (8 FPS โ†’ ~10s, 10 FPS โ†’ ~8s, 12 FPS โ†’ ~6.8s, 15 FPS โ†’ ~5.4s)")
510
  with gr.Row():
511
  seed = gr.Number(
512
  label="Seed",
 
516
  )
517
  fps = gr.Slider(
518
  label="Playback FPS",
519
+ minimum=8,
520
  maximum=30,
521
  value=args.fps,
522
  step=1,
523
+ visible=True,
524
+ info="Lower FPS = longer video duration"
525
  )
526
 
527
  with gr.Column(scale=3):
 
549
 
550
  # 다운로드용 파일 출력
551
  download_file = gr.File(
552
+ label="๐Ÿ“ฅ Download Generated Video",
553
+ visible=False,
554
+ elem_id="download_file"
555
  )
556
 
557
  # Connect the generator to the streaming video