seawolf2357 committed
Commit a34249d · verified · 1 Parent(s): 374f68b

Update app.py

Files changed (1)
  1. app.py +91 -28
app.py CHANGED
@@ -68,7 +68,7 @@ T2V_CINEMATIC_PROMPT = \
  '''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
  '''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
  '''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
- '''4. Prompts should match the user’s intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
+ '''4. Prompts should match the user's intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
  '''5. Emphasize motion information and different camera movements present in the input description;\n''' \
  '''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
  '''7. The revised prompt should be around 80-100 words long.\n''' \
@@ -273,19 +273,23 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
  conditional_dict[key] = value.to(dtype=torch.float16)
 
  rnd = torch.Generator(gpu).manual_seed(int(seed))
+ # Initialize KV cache
  pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
  pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
- noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
+
+ # Increase the noise tensor size for a 5.5-second video (21 -> 24)
+ noise = torch.randn([1, 24, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
 
  vae_cache, latents_cache = None, None
  if not APP_STATE["current_use_taehv"] and not args.trt:
  vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
 
- num_blocks = 7
+ num_blocks = 8  # Increased from 7 to 8 to generate about 5.5 seconds of video
  current_start_frame = 0
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
 
  total_frames_yielded = 0
+ all_frames_for_download = []  # Store all frames for download
 
  # Ensure temp directory exists
  os.makedirs("gradio_tmp", exist_ok=True)
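Note: the two changes in this hunk have to move together. The second dimension of the noise tensor (now 24) is the number of latent frames, and it must equal num_blocks * pipeline.num_frame_per_block. A rough length estimate, assuming num_frame_per_block is 3 and that the Wan2.1 VAE decodes T latent frames into 1 + 4*(T - 1) output frames (both values are assumptions, not stated in this diff):

    # Back-of-the-envelope duration check for the new settings (assumed factors).
    num_blocks = 8
    num_frame_per_block = 3                            # assumed pipeline.num_frame_per_block
    latent_frames = num_blocks * num_frame_per_block   # 24, matching the noise tensor
    pixel_frames = 1 + 4 * (latent_frames - 1)         # assumed 4x temporal upsampling -> 93
    for fps in (15, 16):
        print(f"{pixel_frames} frames at {fps} FPS = {pixel_frames / fps:.1f} s")

At 15-16 FPS this lands around 5.8-6.2 seconds, consistent with the roughly 5.5-6 second target mentioned in the comments and UI text below.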
@@ -352,6 +356,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
 
  all_frames_from_block.append(frame_np)
+ all_frames_for_download.append(frame_np)  # Store the frame for download
  total_frames_yielded += 1
 
  # Yield status update for each frame (cute tracking!)
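Note: keeping every decoded frame in all_frames_for_download trades host memory for the ability to write one complete MP4 at the end. A quick estimate, assuming about 93 frames of 832x480 RGB uint8 (the pixel size implied by the 104x60 latent grid if the VAE upsamples spatially by 8x, which is an assumption here):

    # Approximate RAM held by the download buffer (assumed frame count and size).
    frames, height, width = 93, 480, 832
    total_bytes = frames * height * width * 3      # 3 bytes per RGB pixel
    print(f"~{total_bytes / 2**20:.0f} MB")        # roughly 106 MB

That is modest for a clip of a few seconds, but the buffer grows linearly with video length.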
@@ -375,7 +380,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
  )
 
  # Yield None for video but update status (frame-by-frame tracking)
- yield None, frame_status_html
+ yield None, frame_status_html, gr.update()
 
  # Encode entire block as one chunk immediately
  if all_frames_from_block:
@@ -392,7 +397,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
  total_progress = (idx + 1) / num_blocks * 100
 
  # Yield the actual video chunk
- yield ts_path, gr.update()
+ yield ts_path, gr.update(), gr.update()
 
  except Exception as e:
  print(f"⚠️ Error encoding block {idx}: {e}")
@@ -400,31 +405,82 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
  traceback.print_exc()
 
  current_start_frame += current_num_frames
+
+ # Clear the GPU cache for memory efficiency
+ if idx < num_blocks - 1 and idx % 2 == 1: # clear the cache every 2 blocks
+ torch.cuda.empty_cache()
 
  # Final completion status
- final_status_html = (
- f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
- f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
- f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
- f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
- f" </div>"
- f" <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
- f" <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
- f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
- f" </p>"
- f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
- f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
- f" </p>"
- f" </div>"
- f"</div>"
- )
- yield None, final_status_html
- print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
+ video_duration = total_frames_yielded / fps
+
+ # Save the full video as an MP4
+ if all_frames_for_download:
+ output_filename = f"generated_video_{int(time.time())}_{seed}.mp4"
+ output_path = os.path.join("gradio_tmp", output_filename)
+
+ print(f"💾 Saving complete video to {output_path}")
+
+ # Write to an MP4 container
+ container = av.open(output_path, mode='w')
+ stream = container.add_stream('h264', rate=fps)
+ stream.width = all_frames_for_download[0].shape[1]
+ stream.height = all_frames_for_download[0].shape[0]
+ stream.pix_fmt = 'yuv420p'
+ stream.options = {
+ 'crf': '23',
+ 'preset': 'medium'
+ }
+
+ for frame_np in all_frames_for_download:
+ frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
+ frame = frame.reformat(format=stream.pix_fmt)
+ for packet in stream.encode(frame):
+ container.mux(packet)
+
+ for packet in stream.encode():
+ container.mux(packet)
+
+ container.close()
+
+ # Compute the file size
+ file_size_mb = os.path.getsize(output_path) / (1024 * 1024)
+
+ final_status_html = (
+ f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
+ f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
+ f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
+ f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Video Generation Complete!</h4>"
+ f" </div>"
+ f" <div style='background: rgba(255,255,255,0.7); padding: 12px; border-radius: 4px;'>"
+ f" <p style='margin: 0 0 8px 0; color: #0f5132; font-weight: 500;'>"
+ f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks ({video_duration:.1f} seconds)"
+ f" </p>"
+ f" <p style='margin: 0; color: #0f5132; font-size: 14px;'>"
+ f" 🎬 Resolution: {all_frames_for_download[0].shape[1]}x{all_frames_for_download[0].shape[0]} • FPS: {fps} • Size: {file_size_mb:.1f} MB"
+ f" </p>"
+ f" <p style='margin: 8px 0 0 0; color: #0f5132; font-size: 13px; font-style: italic;'>"
+ f" 💾 Click the download button below to save your video!"
+ f" </p>"
+ f" </div>"
+ f"</div>"
+ )
+
+ # Also return the final video file path
+ yield output_path, final_status_html, gr.update(value=output_path, visible=True)
+ else:
+ final_status_html = (
+ f"<div style='padding: 16px; border: 1px solid #dc3545; background: #f8d7da; border-radius: 8px;'>"
+ f" <h4 style='margin: 0; color: #721c24;'>⚠️ No frames were generated</h4>"
+ f"</div>"
+ )
+ yield None, final_status_html, gr.update()
+
+ print(f"✅ Video generation complete! {total_frames_yielded} frames ({video_duration:.1f} seconds)")
 
  # --- Gradio UI Layout ---
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
- gr.Markdown("# 🚀 Self-Forcing Video Generation")
- gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
+ gr.Markdown("# 🚀 Self-Forcing Video Generation (6-second)")
+ gr.Markdown("Real-time 6-second video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
 
  with gr.Row():
  with gr.Column(scale=2):
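Note: the MP4 export added in this hunk follows the usual PyAV encode loop: open a container in write mode, add an H.264 stream, convert each RGB frame to yuv420p before encoding, then flush the encoder and close the container. A self-contained sketch of the same pattern with synthetic frames (the file name and frame contents here are illustrative only):

    import av
    import numpy as np

    fps = 15
    # Fifteen flat gray frames; real code would use the decoded video frames.
    frames = [np.full((480, 832, 3), value, dtype=np.uint8) for value in range(0, 240, 16)]

    container = av.open("demo.mp4", mode="w")
    stream = container.add_stream("h264", rate=fps)
    stream.width, stream.height = frames[0].shape[1], frames[0].shape[0]
    stream.pix_fmt = "yuv420p"
    stream.options = {"crf": "23", "preset": "medium"}

    for frame_np in frames:
        frame = av.VideoFrame.from_ndarray(frame_np, format="rgb24")
        for packet in stream.encode(frame.reformat(format=stream.pix_fmt)):
            container.mux(packet)

    for packet in stream.encode():   # flush buffered packets
        container.mux(packet)
    container.close()

The final stream.encode() call with no argument matters: without it the encoder's buffered frames never reach the file.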
@@ -471,12 +527,13 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
  gr.Markdown("### 📺 Video Stream")
 
  streaming_video = gr.Video(
- label="Live Stream",
+ label="Live Stream & Download",
  streaming=True,
  loop=True,
  height=400,
  autoplay=True,
- show_label=False
+ show_label=True,
+ show_download_button=True # Enable the download button
  )
 
  status_display = gr.HTML(
@@ -488,12 +545,18 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
  ),
  label="Generation Status"
  )
+
+ # File output for download
+ download_file = gr.File(
+ label="📥 Download Video",
+ visible=False
+ )
 
  # Connect the generator to the streaming video
  start_btn.click(
  fn=video_generation_handler_streaming,
  inputs=[prompt, seed, fps],
- outputs=[streaming_video, status_display]
+ outputs=[streaming_video, status_display, download_file]
  )
 
  enhance_button.click(
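Note: once outputs lists three components, every yield inside video_generation_handler_streaming must produce a 3-tuple (video, status HTML, file update), which is why the earlier hunks pad the intermediate yields with gr.update(). A stripped-down sketch of that wiring (the component names and dummy generator below are illustrative, not the app's real code):

    import time
    import gradio as gr

    def fake_stream(prompt):
        # Intermediate updates: no video yet, new status text, leave the file component alone.
        for i in range(3):
            time.sleep(0.2)
            yield None, f"<p>step {i + 1}/3 for: {prompt}</p>", gr.update()
        # Final update: reveal the download component.
        yield None, "<p>done</p>", gr.update(visible=True)

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        video = gr.Video(streaming=True, autoplay=True)
        status = gr.HTML()
        download = gr.File(label="Download", visible=False)
        gr.Button("Start").click(fake_stream, inputs=[prompt], outputs=[video, status, download])

    if __name__ == "__main__":
        demo.launch()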
 