fffiloni committed · verified
Commit fdf117e · Parent(s): f479bfc

Update gradio_app.py

Files changed (1):
  gradio_app.py +41 -43
gradio_app.py CHANGED
@@ -2,6 +2,12 @@ import os
 import gradio as gr
 import torch
 from huggingface_hub import snapshot_download
+
+# import argparse
+
+snapshot_download(repo_id="fffiloni/svd_keyframe_interpolation", local_dir="checkpoints")
+checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
+
 from diffusers.utils import load_image, export_to_video
 from diffusers import UNetSpatioTemporalConditionModel
 from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
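This first hunk moves the checkpoint download to import time: snapshot_download now runs on every startup, before the custom pipeline modules are used. huggingface_hub generally skips files that are already present locally, but an explicit guard (a hypothetical variant, not what the commit does) would make the intent clearer:

import os
from huggingface_hub import snapshot_download

checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
if not os.path.isdir(checkpoint_dir):  # only hit the network on first run
    snapshot_download(repo_id="fffiloni/svd_keyframe_interpolation",
                      local_dir="checkpoints")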
@@ -10,16 +16,8 @@ from attn_ctrl.attention_control import (AttentionStore,
                                          register_temporal_self_attention_control,
                                          register_temporal_self_attention_flip_control,
                                          )
-from torch.cuda.amp import autocast
 
-# Set up device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Download checkpoint
-snapshot_download(repo_id="fffiloni/svd_keyframe_interpolation", local_dir="checkpoints")
-checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
-
-# Initialize pipeline
 pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
 noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
 
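The context lines above call EulerDiscreteScheduler without its import appearing in any shown hunk; it presumably comes from an unchanged line near the top of the file. For a self-contained reading, the call resolves as:

from diffusers import EulerDiscreteScheduler  # assumed unchanged import elsewhere in the file

noise_scheduler = EulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="scheduler")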
@@ -31,14 +29,14 @@ pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
 )
 ref_unet = pipe.ori_unet
 
-# Compute delta w
 state_dict = pipe.unet.state_dict()
+# computing delta w
 finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
     checkpoint_dir,
     subfolder="unet",
     torch_dtype=torch.float16,
 )
-assert finetuned_unet.config.num_frames == 14
+assert finetuned_unet.config.num_frames==14
 ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
     "stabilityai/stable-video-diffusion-img2vid",
     subfolder="unet",
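This hunk and the next implement a weight-delta merge: the reverse-motion checkpoint was fine-tuned from the 14-frame SVD UNet (hence the assert), and the script transplants that fine-tune onto the SVD-XT UNet by adding the difference between fine-tuned and original weights. A minimal sketch of the idea, with hypothetical names, assuming the three state dicts share keys and shapes:

import torch

def apply_delta_w(base_sd, finetuned_sd, original_sd):
    # delta_w = fine-tuned minus original is what training changed;
    # adding it to a different base model transplants the fine-tune.
    merged = {name: w.clone() for name, w in base_sd.items()}
    for name, finetuned_w in finetuned_sd.items():
        delta_w = finetuned_w - original_sd[name]
        merged[name] = merged[name] + delta_w
    return merged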
@@ -54,43 +52,43 @@ for name, param in finetuned_state_dict.items():
         state_dict[name] = state_dict[name] + delta_w
 pipe.unet.load_state_dict(state_dict)
 
-controller_ref = AttentionStore()
+controller_ref= AttentionStore()
 register_temporal_self_attention_control(ref_unet, controller_ref)
 
 controller = AttentionStore()
 register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
 
-# Custom CUDA memory management function
-def cuda_memory_cleanup():
-    torch.cuda.empty_cache()
-    torch.cuda.ipc_collect()
+device = "cuda"
+pipe = pipe.to(device)
 
 def check_outputs_folder(folder_path):
+    # Check if the folder exists
     if os.path.exists(folder_path) and os.path.isdir(folder_path):
+        # Delete all contents inside the folder
         for filename in os.listdir(folder_path):
             file_path = os.path.join(folder_path, filename)
             try:
                 if os.path.isfile(file_path) or os.path.islink(file_path):
-                    os.unlink(file_path)
+                    os.unlink(file_path)  # Remove file or link
                 elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path)
+                    shutil.rmtree(file_path)  # Remove directory
             except Exception as e:
                 print(f'Failed to delete {file_path}. Reason: {e}')
     else:
         print(f'The folder {folder_path} does not exist.')
 
-@torch.no_grad()
 def infer(frame1_path, frame2_path):
+
     seed = 42
     num_inference_steps = 10
     noise_injection_steps = 0
     noise_injection_ratio = 0.5
     weighted_average = False
-    decode_chunk_size = 8
 
     generator = torch.Generator(device)
     if seed is not None:
         generator = generator.manual_seed(seed)
+
 
     frame1 = load_image(frame1_path)
     frame1 = frame1.resize((512, 288))
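One caveat in the hunk above: check_outputs_folder() calls shutil.rmtree(), yet no import shutil is visible in any shown hunk, so unless the import sits in an unchanged part of the file the subdirectory branch would raise a NameError. A compact, self-contained alternative (a hypothetical helper, not what the commit ships):

import os
import shutil

def reset_folder(folder_path):
    # Remove the folder and everything in it, then recreate it empty.
    shutil.rmtree(folder_path, ignore_errors=True)
    os.makedirs(folder_path, exist_ok=True)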
@@ -98,33 +96,35 @@ def infer(frame1_path, frame2_path):
     frame2 = load_image(frame2_path)
     frame2 = frame2.resize((512, 288))
 
-    cuda_memory_cleanup()
-
-    with autocast():
-        frames = pipe(image1=frame1, image2=frame2,
-                      num_inference_steps=num_inference_steps,
-                      generator=generator,
-                      weighted_average=weighted_average,
-                      noise_injection_steps=noise_injection_steps,
-                      noise_injection_ratio=noise_injection_ratio,
-                      decode_chunk_size=decode_chunk_size
-                      ).frames[0]
+    torch.cuda.empty_cache()
 
-    frames = [frame.cpu() for frame in frames]
+    frames = pipe(image1=frame1, image2=frame2,
+                  num_inference_steps=num_inference_steps, # 50
+                  generator=generator,
+                  weighted_average=weighted_average, # True
+                  noise_injection_steps=noise_injection_steps, # 0
+                  noise_injection_ratio= noise_injection_ratio, # 0.5
+                  decode_chunk_size=4
+                  ).frames[0]
 
+    print(f"FRAMES: {frames}")
+
     out_dir = "result"
+
     check_outputs_folder(out_dir)
     os.makedirs(out_dir, exist_ok=True)
     out_path = "result/video_result.gif"
 
+    '''
+    if out_path.endswith('.gif'):
+        frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
+    else:
+        export_to_video(frames, out_path, fps=7)
+    '''
     return "done"
 
-@torch.no_grad()
-def load_model():
-    global pipe
-    pipe = pipe.to(device)
-
 with gr.Blocks() as demo:
+
     with gr.Column():
         gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
         with gr.Row():
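Note that the export block added here is wrapped in a triple-quoted string, i.e. it is inert: the commit renders frames, prints them, and returns "done" without writing result/video_result.gif. A sketch of the same logic as live code, assuming frames is the list of PIL images the pipeline returns and export_to_video comes from diffusers.utils as imported above:

def save_result(frames, out_path):
    if out_path.endswith('.gif'):
        # 142 ms per frame is roughly 7 fps
        frames[0].save(out_path, save_all=True,
                       append_images=frames[1:], duration=142, loop=0)
    else:
        export_to_video(frames, out_path, fps=7)
    return out_path  # returning the path would let the UI display the file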
@@ -136,12 +136,10 @@ with gr.Blocks() as demo:
         output = gr.Textbox()
 
     submit_btn.click(
-        fn=infer,
-        inputs=[image_input1, image_input2],
-        outputs=[output],
-        show_api=False
+        fn = infer,
+        inputs = [image_input1, image_input2],
+        outputs = [output],
+        show_api = False
     )
 
-demo.load(load_model)
-
-demo.queue(max_size=1).launch(show_api=False, show_error=True)
+demo.queue().launch(show_api=False, show_error=True)
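The final hunk also drops the demo.load(load_model) warm-up (the pipeline now moves to CUDA at import time) and removes max_size=1, so the default queue is used. Restoring back-pressure, should it be wanted again, is a one-line change:

# Hypothetical: cap the queue at one pending request; extra submissions are
# rejected immediately instead of piling up.
demo.queue(max_size=1).launch(show_api=False, show_error=True)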
 