multimodalart HF Staff committed on
Commit
aa5e39d
·
verified ·
1 Parent(s): 1b86783

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -18
app.py CHANGED
@@ -146,7 +146,7 @@ pipeline.to(dtype=torch.float16).to(gpu)
146
  # --- Frame Streaming Video Generation Handler ---
147
  @torch.no_grad()
148
  @spaces.GPU
149
- def video_generation_handler(prompt, seed, fps, progress=gr.Progress()):
150
  """
151
  Generator function that yields RGB frames for display in gr.Image.
152
  Includes timing delays for smooth playback.
@@ -230,8 +230,11 @@ def video_generation_handler(prompt, seed, fps, progress=gr.Progress()):
230
 
231
  print(f"📹 Decoded pixels shape: {pixels.shape}")
232
 
 
 
 
233
  # Yield individual frames with timing delays
234
- for frame_idx in range(pixels.shape[1]):
235
  frame_tensor = pixels[0, frame_idx] # Get single frame [C, H, W]
236
 
237
  # Normalize from [-1, 1] to [0, 255]
@@ -244,22 +247,47 @@ def video_generation_handler(prompt, seed, fps, progress=gr.Progress()):
244
  all_frames_for_video.append(frame_np)
245
  total_frames_yielded += 1
246
 
247
- # Calculate progress
248
- total_expected_frames = num_blocks * pipeline.num_frame_per_block
249
- current_frame_count = (idx * pipeline.num_frame_per_block) + frame_idx + 1
250
- frame_progress = current_frame_count / total_expected_frames
 
251
 
252
- # Update progress
253
- progress(frame_progress, desc=f"Frame {total_frames_yielded} | Block {idx+1}/{num_blocks}")
254
 
255
  print(f"📺 Yielding frame {total_frames_yielded}: shape {frame_np.shape}")
256
 
257
- # Yield frame with timing delay
258
- yield gr.update(visible=True, frame_np), gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  # Sleep between frames for smooth playback (except for the last frame)
261
- if not (frame_idx == pixels.shape[1] - 1 and idx + 1 == num_blocks):
262
- time.sleep(frame_delay)
 
263
 
264
  current_start_frame += current_num_frames
265
 
@@ -270,10 +298,26 @@ def video_generation_handler(prompt, seed, fps, progress=gr.Progress()):
270
  video_path = f"gradio_tmp/{seed}_{hashlib.md5(prompt.encode()).hexdigest()}.mp4"
271
  imageio.mimwrite(video_path, all_frames_for_video, fps=fps, quality=8)
272
  print(f"✅ Video saved to {video_path}")
273
- return gr.update(visible=False), gr.update(value=video_path, visible=True)
 
 
 
 
 
 
 
 
274
  except Exception as e:
275
  print(f"⚠️ Could not save final video: {e}")
276
- return None, None
 
 
 
 
 
 
 
 
277
 
278
  # --- Gradio UI Layout ---
279
  with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing Frame Streaming Demo") as demo:
@@ -326,17 +370,22 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing Frame Streaming Demo"
326
 
327
  final_video = gr.Video(
328
  label="Final Rendered Video",
329
- visible=True,
330
  interactive=False,
331
- height=400
 
 
 
 
 
 
332
  )
333
 
334
  # Connect the generator to the image display
335
  start_btn.click(
336
  fn=video_generation_handler,
337
  inputs=[prompt, seed, fps],
338
- outputs=[frame_display, final_video],
339
- show_progress="full"
340
  )
341
 
342
  # --- Launch App ---
 
146
  # --- Frame Streaming Video Generation Handler ---
147
  @torch.no_grad()
148
  @spaces.GPU
149
+ def video_generation_handler(prompt, seed, fps):
150
  """
151
  Generator function that yields RGB frames for display in gr.Image.
152
  Includes timing delays for smooth playback.
 
230
 
231
  print(f"📹 Decoded pixels shape: {pixels.shape}")
232
 
233
+ # Calculate actual frames that will be yielded for this block
234
+ actual_frames_this_block = pixels.shape[1]
235
+
236
  # Yield individual frames with timing delays
237
+ for frame_idx in range(actual_frames_this_block):
238
  frame_tensor = pixels[0, frame_idx] # Get single frame [C, H, W]
239
 
240
  # Normalize from [-1, 1] to [0, 255]
 
247
  all_frames_for_video.append(frame_np)
248
  total_frames_yielded += 1
249
 
250
+ # Calculate progress based on blocks completed + current block progress
251
+ blocks_completed = idx
252
+ current_block_progress = (frame_idx + 1) / actual_frames_this_block
253
+ total_block_progress = (blocks_completed + current_block_progress) / num_blocks
254
+ frame_progress_percent = total_block_progress * 100
255
 
256
+ # Cap at 100% to avoid going over
257
+ frame_progress_percent = min(frame_progress_percent, 100.0)
258
 
259
  print(f"📺 Yielding frame {total_frames_yielded}: shape {frame_np.shape}")
260
 
261
+ # Create HTML status update
262
+ if frame_idx == actual_frames_this_block - 1 and idx + 1 == num_blocks: # Last frame
263
+ status_html = (
264
+ f"<div style='padding: 16px; border: 1px solid #198754; background-color: #d1e7dd; border-radius: 8px; font-family: sans-serif; text-align: center;'>"
265
+ f" <h4 style='margin: 0 0 8px 0; color: #0f5132; font-size: 18px;'>🎉 Generation Complete!</h4>"
266
+ f" <p style='margin: 0; color: #0f5132;'>"
267
+ f" Total frames: {total_frames_yielded}. The final video is now available."
268
+ f" </p>"
269
+ f"</div>"
270
+ )
271
+ else: # Regular frames
272
+ status_html = (
273
+ f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
274
+ f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
275
+ f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
276
+ f" <div style='width: {frame_progress_percent:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
277
+ f" </div>"
278
+ f" <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
279
+ f" Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {frame_progress_percent:.1f}%"
280
+ f" </p>"
281
+ f"</div>"
282
+ )
283
+
284
+ # Yield frame with a small delay to ensure UI updates
285
+ yield gr.update(visible=True, value=frame_np), gr.update(visible=False), status_html
286
 
287
  # Sleep between frames for smooth playback (except for the last frame)
288
+ # Add minimum delay to ensure UI can update
289
+ if not (frame_idx == actual_frames_this_block - 1 and idx + 1 == num_blocks):
290
+ time.sleep(max(frame_delay, 0.1)) # Minimum 100ms delay
291
 
292
  current_start_frame += current_num_frames
293
 
 
298
  video_path = f"gradio_tmp/{seed}_{hashlib.md5(prompt.encode()).hexdigest()}.mp4"
299
  imageio.mimwrite(video_path, all_frames_for_video, fps=fps, quality=8)
300
  print(f"✅ Video saved to {video_path}")
301
+ final_status_html = (
302
+ f"<div style='padding: 16px; border: 1px solid #198754; background-color: #d1e7dd; border-radius: 8px; font-family: sans-serif; text-align: center;'>"
303
+ f" <h4 style='margin: 0 0 8px 0; color: #0f5132; font-size: 18px;'>🎉 Generation Complete!</h4>"
304
+ f" <p style='margin: 0; color: #0f5132;'>"
305
+ f" Video saved successfully with {total_frames_yielded} frames at {fps} FPS."
306
+ f" </p>"
307
+ f"</div>"
308
+ )
309
+ yield gr.update(visible=False), gr.update(value=video_path, visible=True), final_status_html
310
  except Exception as e:
311
  print(f"⚠️ Could not save final video: {e}")
312
+ error_status_html = (
313
+ f"<div style='padding: 16px; border: 1px solid #dc3545; background-color: #f8d7da; border-radius: 8px; font-family: sans-serif; text-align: center;'>"
314
+ f" <h4 style='margin: 0 0 8px 0; color: #721c24; font-size: 18px;'>⚠️ Video Save Error</h4>"
315
+ f" <p style='margin: 0; color: #721c24;'>"
316
+ f" Could not save final video: {str(e)}"
317
+ f" </p>"
318
+ f"</div>"
319
+ )
320
+ yield None, None, error_status_html
321
 
322
  # --- Gradio UI Layout ---
323
  with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing Frame Streaming Demo") as demo:
 
370
 
371
  final_video = gr.Video(
372
  label="Final Rendered Video",
373
+ visible=False,
374
  interactive=False,
375
+ height=400,
376
+ autoplay=True
377
+ )
378
+
379
+ status_html = gr.HTML(
380
+ value="<div style='text-align: center; padding: 20px; color: #666;'>Ready to start generation...</div>",
381
+ label="Generation Status"
382
  )
383
 
384
  # Connect the generator to the image display
385
  start_btn.click(
386
  fn=video_generation_handler,
387
  inputs=[prompt, seed, fps],
388
+ outputs=[frame_display, final_video, status_html]
 
389
  )
390
 
391
  # --- Launch App ---