Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -68,7 +68,7 @@ T2V_CINEMATIC_PROMPT = \
|
|
68 |
'''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
|
69 |
'''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
|
70 |
'''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
|
71 |
-
'''4. Prompts should match the user
|
72 |
'''5. Emphasize motion information and different camera movements present in the input description;\n''' \
|
73 |
'''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
|
74 |
'''7. The revised prompt should be around 80-100 words long.\n''' \
|
@@ -273,19 +273,23 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
273 |
conditional_dict[key] = value.to(dtype=torch.float16)
|
274 |
|
275 |
rnd = torch.Generator(gpu).manual_seed(int(seed))
|
|
|
276 |
pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
|
277 |
pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
|
278 |
-
|
|
|
|
|
279 |
|
280 |
vae_cache, latents_cache = None, None
|
281 |
if not APP_STATE["current_use_taehv"] and not args.trt:
|
282 |
vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
|
283 |
|
284 |
-
num_blocks = 7
|
285 |
current_start_frame = 0
|
286 |
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
287 |
|
288 |
total_frames_yielded = 0
|
|
|
289 |
|
290 |
# Ensure temp directory exists
|
291 |
os.makedirs("gradio_tmp", exist_ok=True)
|
@@ -352,6 +356,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
352 |
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
353 |
|
354 |
all_frames_from_block.append(frame_np)
|
|
|
355 |
total_frames_yielded += 1
|
356 |
|
357 |
# Yield status update for each frame (cute tracking!)
|
@@ -375,7 +380,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
375 |
)
|
376 |
|
377 |
# Yield None for video but update status (frame-by-frame tracking)
|
378 |
-
yield None, frame_status_html
|
379 |
|
380 |
# Encode entire block as one chunk immediately
|
381 |
if all_frames_from_block:
|
@@ -392,7 +397,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
392 |
total_progress = (idx + 1) / num_blocks * 100
|
393 |
|
394 |
# Yield the actual video chunk
|
395 |
-
yield ts_path, gr.update()
|
396 |
|
397 |
except Exception as e:
|
398 |
print(f"⚠️ Error encoding block {idx}: {e}")
|
@@ -400,31 +405,82 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
400 |
traceback.print_exc()
|
401 |
|
402 |
current_start_frame += current_num_frames
|
|
|
|
|
|
|
|
|
403 |
|
404 |
# Final completion status
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
f"
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
|
424 |
# --- Gradio UI Layout ---
|
425 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
426 |
-
gr.Markdown("# ๐ Self-Forcing Video Generation")
|
427 |
-
gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
|
428 |
|
429 |
with gr.Row():
|
430 |
with gr.Column(scale=2):
|
@@ -471,12 +527,13 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
471 |
gr.Markdown("### 📺 Video Stream")
|
472 |
|
473 |
streaming_video = gr.Video(
|
474 |
-
label="Live Stream",
|
475 |
streaming=True,
|
476 |
loop=True,
|
477 |
height=400,
|
478 |
autoplay=True,
|
479 |
-
show_label=
|
|
|
480 |
)
|
481 |
|
482 |
status_display = gr.HTML(
|
@@ -488,12 +545,18 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
488 |
),
|
489 |
label="Generation Status"
|
490 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
|
492 |
# Connect the generator to the streaming video
|
493 |
start_btn.click(
|
494 |
fn=video_generation_handler_streaming,
|
495 |
inputs=[prompt, seed, fps],
|
496 |
-
outputs=[streaming_video, status_display]
|
497 |
)
|
498 |
|
499 |
enhance_button.click(
|
|
|
68 |
'''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
|
69 |
'''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
|
70 |
'''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
|
71 |
+
'''4. Prompts should match the user's intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
|
72 |
'''5. Emphasize motion information and different camera movements present in the input description;\n''' \
|
73 |
'''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
|
74 |
'''7. The revised prompt should be around 80-100 words long.\n''' \
|
|
|
273 |
conditional_dict[key] = value.to(dtype=torch.float16)
|
274 |
|
275 |
rnd = torch.Generator(gpu).manual_seed(int(seed))
|
276 |
+
# Initialize the KV cache
|
277 |
pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
|
278 |
pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
|
279 |
+
|
280 |
+
# Increase the noise tensor size for a 5.5-second video (21 -> 24)
|
281 |
+
noise = torch.randn([1, 24, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
|
282 |
|
283 |
vae_cache, latents_cache = None, None
|
284 |
if not APP_STATE["current_use_taehv"] and not args.trt:
|
285 |
vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
|
286 |
|
287 |
+
num_blocks = 8 # Increased from 7 to 8 to generate about 5.5 seconds of video
|
288 |
current_start_frame = 0
|
289 |
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
290 |
|
291 |
total_frames_yielded = 0
|
292 |
+
all_frames_for_download = [] # Store all frames for download
|
293 |
|
294 |
# Ensure temp directory exists
|
295 |
os.makedirs("gradio_tmp", exist_ok=True)
|
|
|
356 |
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
357 |
|
358 |
all_frames_from_block.append(frame_np)
|
359 |
+
all_frames_for_download.append(frame_np) # Store the frame for download
|
360 |
total_frames_yielded += 1
|
361 |
|
362 |
# Yield status update for each frame (cute tracking!)
|
|
|
380 |
)
|
381 |
|
382 |
# Yield None for video but update status (frame-by-frame tracking)
|
383 |
+
yield None, frame_status_html, gr.update()
|
384 |
|
385 |
# Encode entire block as one chunk immediately
|
386 |
if all_frames_from_block:
|
|
|
397 |
total_progress = (idx + 1) / num_blocks * 100
|
398 |
|
399 |
# Yield the actual video chunk
|
400 |
+
yield ts_path, gr.update(), gr.update()
|
401 |
|
402 |
except Exception as e:
|
403 |
print(f"⚠️ Error encoding block {idx}: {e}")
|
|
|
405 |
traceback.print_exc()
|
406 |
|
407 |
current_start_frame += current_num_frames
|
408 |
+
|
409 |
+
# Clear the GPU cache for memory efficiency
|
410 |
+
if idx < num_blocks - 1 and idx % 2 == 1: # Clear the cache every 2 blocks
|
411 |
+
torch.cuda.empty_cache()
|
412 |
|
413 |
# Final completion status
|
414 |
+
video_duration = total_frames_yielded / fps
|
415 |
+
|
416 |
+
# Save the entire video as MP4
|
417 |
+
if all_frames_for_download:
|
418 |
+
output_filename = f"generated_video_{int(time.time())}_{seed}.mp4"
|
419 |
+
output_path = os.path.join("gradio_tmp", output_filename)
|
420 |
+
|
421 |
+
print(f"💾 Saving complete video to {output_path}")
|
422 |
+
|
423 |
+
# Save into an MP4 container
|
424 |
+
container = av.open(output_path, mode='w')
|
425 |
+
stream = container.add_stream('h264', rate=fps)
|
426 |
+
stream.width = all_frames_for_download[0].shape[1]
|
427 |
+
stream.height = all_frames_for_download[0].shape[0]
|
428 |
+
stream.pix_fmt = 'yuv420p'
|
429 |
+
stream.options = {
|
430 |
+
'crf': '23',
|
431 |
+
'preset': 'medium'
|
432 |
+
}
|
433 |
+
|
434 |
+
for frame_np in all_frames_for_download:
|
435 |
+
frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
|
436 |
+
frame = frame.reformat(format=stream.pix_fmt)
|
437 |
+
for packet in stream.encode(frame):
|
438 |
+
container.mux(packet)
|
439 |
+
|
440 |
+
for packet in stream.encode():
|
441 |
+
container.mux(packet)
|
442 |
+
|
443 |
+
container.close()
|
444 |
+
|
445 |
+
# Compute the file size
|
446 |
+
file_size_mb = os.path.getsize(output_path) / (1024 * 1024)
|
447 |
+
|
448 |
+
final_status_html = (
|
449 |
+
f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
|
450 |
+
f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
|
451 |
+
f" <span style='font-size: 24px; margin-right: 12px;'>๐</span>"
|
452 |
+
f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Video Generation Complete!</h4>"
|
453 |
+
f" </div>"
|
454 |
+
f" <div style='background: rgba(255,255,255,0.7); padding: 12px; border-radius: 4px;'>"
|
455 |
+
f" <p style='margin: 0 0 8px 0; color: #0f5132; font-weight: 500;'>"
|
456 |
+
f" ๐ Generated {total_frames_yielded} frames across {num_blocks} blocks ({video_duration:.1f} seconds)"
|
457 |
+
f" </p>"
|
458 |
+
f" <p style='margin: 0; color: #0f5132; font-size: 14px;'>"
|
459 |
+
f" 🎬 Resolution: {all_frames_for_download[0].shape[1]}x{all_frames_for_download[0].shape[0]} • FPS: {fps} • Size: {file_size_mb:.1f} MB"
|
460 |
+
f" </p>"
|
461 |
+
f" <p style='margin: 8px 0 0 0; color: #0f5132; font-size: 13px; font-style: italic;'>"
|
462 |
+
f" 💾 Click the download button below to save your video!"
|
463 |
+
f" </p>"
|
464 |
+
f" </div>"
|
465 |
+
f"</div>"
|
466 |
+
)
|
467 |
+
|
468 |
+
# Also return the final video file path
|
469 |
+
yield output_path, final_status_html, gr.update(value=output_path, visible=True)
|
470 |
+
else:
|
471 |
+
final_status_html = (
|
472 |
+
f"<div style='padding: 16px; border: 1px solid #dc3545; background: #f8d7da; border-radius: 8px;'>"
|
473 |
+
f" <h4 style='margin: 0; color: #721c24;'>⚠️ No frames were generated</h4>"
|
474 |
+
f"</div>"
|
475 |
+
)
|
476 |
+
yield None, final_status_html, gr.update()
|
477 |
+
|
478 |
+
print(f"✅ Video generation complete! {total_frames_yielded} frames ({video_duration:.1f} seconds)")
|
479 |
|
480 |
# --- Gradio UI Layout ---
|
481 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
482 |
+
gr.Markdown("# ๐ Self-Forcing Video Generation (6-second)")
|
483 |
+
gr.Markdown("Real-time 6-second video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
|
484 |
|
485 |
with gr.Row():
|
486 |
with gr.Column(scale=2):
|
|
|
527 |
gr.Markdown("### 📺 Video Stream")
|
528 |
|
529 |
streaming_video = gr.Video(
|
530 |
+
label="Live Stream & Download",
|
531 |
streaming=True,
|
532 |
loop=True,
|
533 |
height=400,
|
534 |
autoplay=True,
|
535 |
+
show_label=True,
|
536 |
+
show_download_button=True # Enable the download button
|
537 |
)
|
538 |
|
539 |
status_display = gr.HTML(
|
|
|
545 |
),
|
546 |
label="Generation Status"
|
547 |
)
|
548 |
+
|
549 |
+
# File output for download
|
550 |
+
download_file = gr.File(
|
551 |
+
label="📥 Download Video",
|
552 |
+
visible=False
|
553 |
+
)
|
554 |
|
555 |
# Connect the generator to the streaming video
|
556 |
start_btn.click(
|
557 |
fn=video_generation_handler_streaming,
|
558 |
inputs=[prompt, seed, fps],
|
559 |
+
outputs=[streaming_video, status_display, download_file]
|
560 |
)
|
561 |
|
562 |
enhance_button.click(
|