fffiloni committed
Commit cf5b2d5 · verified · 1 Parent(s): b0ba480

Update gradio_app.py

Files changed (1)
  1. gradio_app.py +41 -40
gradio_app.py CHANGED
@@ -2,12 +2,6 @@ import os
 import gradio as gr
 import torch
 from huggingface_hub import snapshot_download
-
-# import argparse
-
-snapshot_download(repo_id="fffiloni/svd_keyframe_interpolation", local_dir="checkpoints")
-checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
-
 from diffusers.utils import load_image, export_to_video
 from diffusers import UNetSpatioTemporalConditionModel
 from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
@@ -16,8 +10,16 @@ from attn_ctrl.attention_control import (AttentionStore,
     register_temporal_self_attention_control,
     register_temporal_self_attention_flip_control,
 )
+from torch.cuda.amp import autocast
 
+# Set up device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# Download checkpoint
+snapshot_download(repo_id="fffiloni/svd_keyframe_interpolation", local_dir="checkpoints")
+checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
+
+# Initialize pipeline
 pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
 noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
 
@@ -29,14 +31,14 @@ pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
 )
 ref_unet = pipe.ori_unet
 
+# Compute delta w
 state_dict = pipe.unet.state_dict()
-# computing delta w
 finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
     checkpoint_dir,
     subfolder="unet",
     torch_dtype=torch.float16,
 )
-assert finetuned_unet.config.num_frames==14
+assert finetuned_unet.config.num_frames == 14
 ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
     "stabilityai/stable-video-diffusion-img2vid",
     subfolder="unet",
@@ -52,33 +54,33 @@ for name, param in finetuned_state_dict.items():
     state_dict[name] = state_dict[name] + delta_w
 pipe.unet.load_state_dict(state_dict)
 
-controller_ref= AttentionStore()
+controller_ref = AttentionStore()
 register_temporal_self_attention_control(ref_unet, controller_ref)
 
 controller = AttentionStore()
 register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
 
-device = "cuda"
-pipe = pipe.to(device)
+# Custom CUDA memory management function
+def cuda_memory_cleanup():
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
 
 def check_outputs_folder(folder_path):
-    # Check if the folder exists
     if os.path.exists(folder_path) and os.path.isdir(folder_path):
-        # Delete all contents inside the folder
         for filename in os.listdir(folder_path):
             file_path = os.path.join(folder_path, filename)
             try:
                 if os.path.isfile(file_path) or os.path.islink(file_path):
-                    os.unlink(file_path) # Remove file or link
+                    os.unlink(file_path)
                 elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path) # Remove directory
+                    shutil.rmtree(file_path)
             except Exception as e:
                 print(f'Failed to delete {file_path}. Reason: {e}')
     else:
         print(f'The folder {folder_path} does not exist.')
 
+@torch.no_grad()
 def infer(frame1_path, frame2_path):
-
     seed = 42
     num_inference_steps = 10
     noise_injection_steps = 0
@@ -88,7 +90,6 @@ def infer(frame1_path, frame2_path):
     generator = torch.Generator(device)
     if seed is not None:
         generator = generator.manual_seed(seed)
-
 
     frame1 = load_image(frame1_path)
     frame1 = frame1.resize((512, 288))
@@ -96,34 +97,32 @@
     frame2 = load_image(frame2_path)
     frame2 = frame2.resize((512, 288))
 
-    torch.cuda.empty_cache()
+    cuda_memory_cleanup()
 
-    frames = pipe(image1=frame1, image2=frame2,
-                  num_inference_steps=num_inference_steps, # 50
-                  generator=generator,
-                  weighted_average=weighted_average, # True
-                  noise_injection_steps=noise_injection_steps, # 0
-                  noise_injection_ratio= noise_injection_ratio, # 0.5
-    ).frames[0]
+    with autocast():
+        frames = pipe(image1=frame1, image2=frame2,
+                      num_inference_steps=num_inference_steps,
+                      generator=generator,
+                      weighted_average=weighted_average,
+                      noise_injection_steps=noise_injection_steps,
+                      noise_injection_ratio=noise_injection_ratio,
+        ).frames[0]
 
-    print(f"FRAMES: {frames}")
-
-    out_dir = "result"
+    frames = [frame.cpu() for frame in frames]
 
+    out_dir = "result"
     check_outputs_folder(out_dir)
     os.makedirs(out_dir, exist_ok=True)
     out_path = "result/video_result.gif"
 
-    '''
-    if out_path.endswith('.gif'):
-        frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
-    else:
-        export_to_video(frames, out_path, fps=7)
-    '''
     return "done"
 
-with gr.Blocks() as demo:
+@torch.no_grad()
+def load_model():
+    global pipe
+    pipe = pipe.to(device)
 
+with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
         with gr.Row():
@@ -135,10 +134,12 @@ with gr.Blocks() as demo:
         output = gr.Textbox()
 
     submit_btn.click(
-        fn = infer,
-        inputs = [image_input1, image_input2],
-        outputs = [output],
-        show_api = False
+        fn=infer,
+        inputs=[image_input1, image_input2],
+        outputs=[output],
+        show_api=False
     )
 
-demo.queue().launch(show_api=False, show_error=True)
+demo.load(load_model)
+
+demo.queue(max_size=1).launch(show_api=False, show_error=True)
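
Note on the delta-w hunk above: the app carries the finetuned reverse-motion behavior over to the SVD-XT UNet by adding a per-parameter weight delta, `state_dict[name] = state_dict[name] + delta_w`. The subtraction itself sits outside the diff context shown, but since both the finetuned checkpoint and the original img2vid UNet are loaded, `delta_w` is presumably `finetuned - original` per parameter. A minimal sketch of that merge; `apply_delta_w` and the toy state dicts are illustrative, not code from the repo:

import torch

def apply_delta_w(target_sd, finetuned_sd, original_sd):
    # Add (finetuned - original) to each matching parameter of the target.
    for name, finetuned_param in finetuned_sd.items():
        delta_w = finetuned_param - original_sd[name]   # what finetuning changed
        target_sd[name] = target_sd[name] + delta_w     # graft it onto the target
    return target_sd

# Toy check: 0.0 + (1.0 - 0.25) == 0.75
merged = apply_delta_w({"w": torch.zeros(2)},
                       {"w": torch.ones(2)},
                       {"w": torch.full((2,), 0.25)})
print(merged["w"])  # tensor([0.7500, 0.7500])

The target UNet's own weights stay in place as the base; only what finetuning changed gets carried over.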
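
Note on the launch-flow change: `pipe = pipe.to(device)` no longer runs at import time. It moves into `load_model()`, registered with `demo.load()`, so device placement happens once the front-end loads, and `demo.queue(max_size=1)` caps the waiting queue at a single pending request. A self-contained sketch of the same pattern; the tiny `Linear` model and the button wiring are stand-ins, not the app's code:

import gradio as gr
import torch

model = torch.nn.Linear(4, 4)  # hypothetical stand-in for the heavy SVD pipeline

def load_model():
    # Registered with demo.load(): runs when the UI loads, not at import time.
    global model
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

def which_device():
    # Report where the model currently lives.
    return str(next(model.parameters()).device)

with gr.Blocks() as demo:
    out = gr.Textbox()
    btn = gr.Button("Where is the model?")
    btn.click(fn=which_device, outputs=[out])

demo.load(load_model)            # defer device placement to page load
demo.queue(max_size=1).launch()  # cap the queue at one pending request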