Spaces:
Runtime error
Runtime error
weifeng.genius
commited on
Commit
•
3f4baa8
1
Parent(s):
a2f07f6
first init
Browse files- app.py +202 -0
- model/annotator/canny/__init__.py +6 -0
- model/annotator/hed/__init__.py +133 -0
- model/annotator/hed/__pycache__/__init__.cpython-39.pyc +0 -0
- model/annotator/util.py +38 -0
- model/video_diffusion/__init__.py +0 -0
- model/video_diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
- model/video_diffusion/models/__init__.py +0 -0
- model/video_diffusion/models/__pycache__/__init__.cpython-39.pyc +0 -0
- model/video_diffusion/models/__pycache__/attention.cpython-39.pyc +0 -0
- model/video_diffusion/models/__pycache__/controlnet3d.cpython-39.pyc +0 -0
- model/video_diffusion/models/__pycache__/resnet.cpython-39.pyc +0 -0
- model/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-39.pyc +0 -0
- model/video_diffusion/models/__pycache__/unet_3d_blocks_control.cpython-39.pyc +0 -0
- model/video_diffusion/models/__pycache__/unet_3d_condition.cpython-39.pyc +0 -0
- model/video_diffusion/models/attention.py +454 -0
- model/video_diffusion/models/controlnet3d.py +580 -0
- model/video_diffusion/models/resnet.py +486 -0
- model/video_diffusion/models/unet_3d_blocks.py +622 -0
- model/video_diffusion/models/unet_3d_blocks_control.py +116 -0
- model/video_diffusion/models/unet_3d_condition.py +571 -0
- model/video_diffusion/pipelines/__init__.py +0 -0
- model/video_diffusion/pipelines/__pycache__/__init__.cpython-39.pyc +0 -0
- model/video_diffusion/pipelines/__pycache__/pipeline_st_stable_diffusion.cpython-39.pyc +0 -0
- model/video_diffusion/pipelines/__pycache__/pipeline_stable_diffusion_controlnet3d.cpython-39.pyc +0 -0
- model/video_diffusion/pipelines/pipeline_st_stable_diffusion.py +618 -0
- model/video_diffusion/pipelines/pipeline_stable_diffusion_controlnet3d.py +482 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model.video_diffusion.models.controlnet3d import ControlNet3DModel
|
2 |
+
from model.video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
|
3 |
+
from model.video_diffusion.pipelines.pipeline_stable_diffusion_controlnet3d import Controlnet3DStableDiffusionPipeline
|
4 |
+
from transformers import DPTForDepthEstimation
|
5 |
+
from model.annotator.hed import HEDNetwork
|
6 |
+
import torch
|
7 |
+
from einops import rearrange,repeat
|
8 |
+
import imageio
|
9 |
+
import numpy as np
|
10 |
+
import cv2
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from PIL import Image
|
13 |
+
import argparse
|
14 |
+
import tempfile
|
15 |
+
import os
|
16 |
+
import gradio as gr
|
17 |
+
|
18 |
+
|
19 |
+
control_mode = 'depth'
|
20 |
+
control_net_path = f"wf-genius/controlavideo-{control_mode}"
|
21 |
+
unet = UNetPseudo3DConditionModel.from_pretrained(control_net_path,
|
22 |
+
torch_dtype = torch.float16,
|
23 |
+
subfolder='unet',
|
24 |
+
).to("cuda")
|
25 |
+
controlnet = ControlNet3DModel.from_pretrained(control_net_path,
|
26 |
+
torch_dtype = torch.float16,
|
27 |
+
subfolder='controlnet',
|
28 |
+
).to("cuda")
|
29 |
+
|
30 |
+
if control_mode == 'depth':
|
31 |
+
annotator_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
|
32 |
+
elif control_mode == 'canny':
|
33 |
+
annotator_model = None
|
34 |
+
elif control_mode == 'hed':
|
35 |
+
# firstly download from https://huggingface.co/wf-genius/controlavideo-hed/resolve/main/hed-network.pth
|
36 |
+
annotator_model = HEDNetwork('hed-network.pth').to("cuda")
|
37 |
+
|
38 |
+
video_controlnet_pipe = Controlnet3DStableDiffusionPipeline.from_pretrained(control_net_path, unet=unet,
|
39 |
+
controlnet=controlnet, annotator_model=annotator_model,
|
40 |
+
torch_dtype = torch.float16,
|
41 |
+
).to("cuda")
|
42 |
+
|
43 |
+
|
44 |
+
def to_video(frames, fps: int) -> str:
|
45 |
+
out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
46 |
+
writer = imageio.get_writer(out_file.name, format='FFMPEG', fps=fps)
|
47 |
+
for frame in frames:
|
48 |
+
writer.append_data(np.array(frame))
|
49 |
+
writer.close()
|
50 |
+
return out_file.name
|
51 |
+
|
52 |
+
def inference(input_video,
|
53 |
+
prompt,
|
54 |
+
seed,
|
55 |
+
num_inference_steps,
|
56 |
+
guidance_scale,
|
57 |
+
sampling_rate,
|
58 |
+
video_scale,
|
59 |
+
init_noise_thres,
|
60 |
+
each_sample_frame,
|
61 |
+
iter_times,
|
62 |
+
h,
|
63 |
+
w,
|
64 |
+
):
|
65 |
+
num_sample_frames = iter_times * each_sample_frame
|
66 |
+
testing_prompt = [prompt]
|
67 |
+
np_frames, fps_vid = Controlnet3DStableDiffusionPipeline.get_frames_preprocess(input_video, num_frames=num_sample_frames, sampling_rate=sampling_rate, return_np=True)
|
68 |
+
if control_mode == 'depth':
|
69 |
+
frames = torch.from_numpy(np_frames).div(255) * 2 - 1
|
70 |
+
frames = rearrange(frames, "f h w c -> c f h w").unsqueeze(0)
|
71 |
+
frames = rearrange(frames, 'b c f h w -> (b f) c h w')
|
72 |
+
control_maps = video_controlnet_pipe.get_depth_map(frames, h, w, return_standard_norm=False) # (b f) 1 h w
|
73 |
+
elif control_mode == 'canny':
|
74 |
+
control_maps = np.stack([cv2.Canny(inp, 100, 200) for inp in np_frames])
|
75 |
+
control_maps = repeat(control_maps, 'f h w -> f c h w',c=1)
|
76 |
+
control_maps = torch.from_numpy(control_maps).div(255) # 0~1
|
77 |
+
elif control_mode == 'hed':
|
78 |
+
control_maps = np.stack([video_controlnet_pipe.get_hed_map(inp) for inp in np_frames])
|
79 |
+
control_maps = repeat(control_maps, 'f h w -> f c h w',c=1)
|
80 |
+
control_maps = torch.from_numpy(control_maps).div(255) # 0~1
|
81 |
+
control_maps = control_maps.to(dtype=controlnet.dtype, device=controlnet.device)
|
82 |
+
control_maps = F.interpolate(control_maps, size=(h,w), mode='bilinear', align_corners=False)
|
83 |
+
control_maps = rearrange(control_maps, "(b f) c h w -> b c f h w", f=num_sample_frames)
|
84 |
+
if control_maps.shape[1] == 1:
|
85 |
+
control_maps = repeat(control_maps, 'b c f h w -> b (n c) f h w', n=3)
|
86 |
+
|
87 |
+
frames = torch.from_numpy(np_frames).div(255)
|
88 |
+
frames = rearrange(frames, 'f h w c -> f c h w')
|
89 |
+
v2v_input_frames = torch.nn.functional.interpolate(
|
90 |
+
frames,
|
91 |
+
size=(h, w),
|
92 |
+
mode="bicubic",
|
93 |
+
antialias=True,
|
94 |
+
)
|
95 |
+
v2v_input_frames = rearrange(v2v_input_frames, '(b f) c h w -> b c f h w ', f=num_sample_frames)
|
96 |
+
|
97 |
+
out = []
|
98 |
+
for i in range(num_sample_frames//each_sample_frame):
|
99 |
+
out1 = video_controlnet_pipe(
|
100 |
+
# controlnet_hint= control_maps[:,:,:each_sample_frame,:,:],
|
101 |
+
# images= v2v_input_frames[:,:,:each_sample_frame,:,:],
|
102 |
+
controlnet_hint=control_maps[:,:,i*each_sample_frame-1:(i+1)*each_sample_frame-1,:,:] if i>0 else control_maps[:,:,:each_sample_frame,:,:],
|
103 |
+
images=v2v_input_frames[:,:,i*each_sample_frame-1:(i+1)*each_sample_frame-1,:,:] if i>0 else v2v_input_frames[:,:,:each_sample_frame,:,:],
|
104 |
+
first_frame_output=out[-1] if i>0 else None,
|
105 |
+
prompt=testing_prompt,
|
106 |
+
num_inference_steps=num_inference_steps,
|
107 |
+
width=w,
|
108 |
+
height=h,
|
109 |
+
guidance_scale=guidance_scale,
|
110 |
+
generator=[torch.Generator(device="cuda").manual_seed(seed)],
|
111 |
+
video_scale = video_scale,
|
112 |
+
init_noise_by_residual_thres = init_noise_thres, # residual-based init. larger thres ==> more smooth.
|
113 |
+
controlnet_conditioning_scale=1.0,
|
114 |
+
fix_first_frame=True,
|
115 |
+
in_domain=True,
|
116 |
+
)
|
117 |
+
out1 = out1.images[0]
|
118 |
+
if len(out1) > 1:
|
119 |
+
out1 = out1[1:] # drop the first frame
|
120 |
+
out.extend(out1)
|
121 |
+
|
122 |
+
return to_video(out, 8)
|
123 |
+
|
124 |
+
|
125 |
+
examples = [
|
126 |
+
["__assets__/depth_videos_depth/girl_dancing.mp4",
|
127 |
+
"A stormtrooper, masterpiece, a high-quality, detailed, and professional photo"],
|
128 |
+
]
|
129 |
+
|
130 |
+
def preview_inference(
|
131 |
+
input_video,
|
132 |
+
prompt, seed,
|
133 |
+
num_inference_steps, guidance_scale,
|
134 |
+
sampling_rate, video_scale, init_noise_thres,
|
135 |
+
each_sample_frame,iter_times, h, w,
|
136 |
+
):
|
137 |
+
return inference(input_video,
|
138 |
+
prompt, seed,
|
139 |
+
num_inference_steps, guidance_scale,
|
140 |
+
sampling_rate, 0.0, 0.0, 1, 1, h, w,)
|
141 |
+
|
142 |
+
if __name__ == '__main__':
|
143 |
+
with gr.Blocks() as demo:
|
144 |
+
with gr.Row():
|
145 |
+
with gr.Column():
|
146 |
+
input_video = gr.Video(
|
147 |
+
label="Input Video", source='upload', format="mp4", visible=True)
|
148 |
+
with gr.Column():
|
149 |
+
init_noise_thres = gr.Slider(0, 1, value=0.1, step=0.1, label="init_noise_thress")
|
150 |
+
each_sample_frame = gr.Slider(6, 16, value=8, step=1, label="each_sample_frame")
|
151 |
+
iter_times = gr.Slider(1, 4, value=1, step=1, label="iter_times")
|
152 |
+
sampling_rate = gr.Slider(1, 8, value=3, step=1, label="sampling_rate")
|
153 |
+
h = gr.Slider(256, 768, value=512, step=64, label="height")
|
154 |
+
w = gr.Slider(256, 768, value=512, step=64, label="width")
|
155 |
+
with gr.Column():
|
156 |
+
seed = gr.Slider(0, 6666, value=1, step=1, label="seed")
|
157 |
+
num_inference_steps = gr.Slider(5, 50, value=20, step=1, label="num_inference_steps")
|
158 |
+
guidance_scale = gr.Slider(1, 20, value=7.5, step=0.5, label="guidance_scale")
|
159 |
+
video_scale = gr.Slider(0, 2.5, value=1.5, step=0.1, label="video_scale")
|
160 |
+
prompt = gr.Textbox(label='Prompt')
|
161 |
+
# preview_button = gr.Button('Preview')
|
162 |
+
run_button = gr.Button('Generate Video')
|
163 |
+
|
164 |
+
with gr.Column():
|
165 |
+
result = gr.Video(label="Generated Video")
|
166 |
+
|
167 |
+
inputs = [
|
168 |
+
input_video,
|
169 |
+
prompt,
|
170 |
+
seed,
|
171 |
+
num_inference_steps,
|
172 |
+
guidance_scale,
|
173 |
+
sampling_rate,
|
174 |
+
video_scale,
|
175 |
+
init_noise_thres,
|
176 |
+
each_sample_frame,
|
177 |
+
iter_times,
|
178 |
+
h,
|
179 |
+
w,
|
180 |
+
]
|
181 |
+
|
182 |
+
gr.Examples(examples=examples,
|
183 |
+
inputs=inputs,
|
184 |
+
outputs=result,
|
185 |
+
fn=inference,
|
186 |
+
cache_examples=False,
|
187 |
+
run_on_click=False,
|
188 |
+
)
|
189 |
+
|
190 |
+
run_button.click(fn=inference,
|
191 |
+
inputs=inputs,
|
192 |
+
outputs=result,)
|
193 |
+
# preview_button.click(fn=preview_inference,
|
194 |
+
# inputs=inputs,
|
195 |
+
# outputs=result,)
|
196 |
+
|
197 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
198 |
+
|
199 |
+
|
200 |
+
# TODO
|
201 |
+
# 1. preview
|
202 |
+
# 2. params
|
model/annotator/canny/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
|
4 |
+
class CannyDetector:
|
5 |
+
def __call__(self, img, low_threshold, high_threshold):
|
6 |
+
return cv2.Canny(img, low_threshold, high_threshold)
|
model/annotator/hed/__init__.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
|
8 |
+
class HEDNetwork(torch.nn.Module):
|
9 |
+
def __init__(self, model_path):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.netVggOne = torch.nn.Sequential(
|
13 |
+
torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
|
14 |
+
torch.nn.ReLU(inplace=False),
|
15 |
+
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
|
16 |
+
torch.nn.ReLU(inplace=False)
|
17 |
+
)
|
18 |
+
|
19 |
+
self.netVggTwo = torch.nn.Sequential(
|
20 |
+
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
21 |
+
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
|
22 |
+
torch.nn.ReLU(inplace=False),
|
23 |
+
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
24 |
+
torch.nn.ReLU(inplace=False)
|
25 |
+
)
|
26 |
+
|
27 |
+
self.netVggThr = torch.nn.Sequential(
|
28 |
+
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
29 |
+
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
|
30 |
+
torch.nn.ReLU(inplace=False),
|
31 |
+
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
32 |
+
torch.nn.ReLU(inplace=False),
|
33 |
+
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
34 |
+
torch.nn.ReLU(inplace=False)
|
35 |
+
)
|
36 |
+
|
37 |
+
self.netVggFou = torch.nn.Sequential(
|
38 |
+
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
39 |
+
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
|
40 |
+
torch.nn.ReLU(inplace=False),
|
41 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
42 |
+
torch.nn.ReLU(inplace=False),
|
43 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
44 |
+
torch.nn.ReLU(inplace=False)
|
45 |
+
)
|
46 |
+
|
47 |
+
self.netVggFiv = torch.nn.Sequential(
|
48 |
+
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
49 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
50 |
+
torch.nn.ReLU(inplace=False),
|
51 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
52 |
+
torch.nn.ReLU(inplace=False),
|
53 |
+
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
54 |
+
torch.nn.ReLU(inplace=False)
|
55 |
+
)
|
56 |
+
|
57 |
+
self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
|
58 |
+
self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
|
59 |
+
self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
|
60 |
+
self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
|
61 |
+
self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
|
62 |
+
|
63 |
+
self.netCombine = torch.nn.Sequential(
|
64 |
+
torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
|
65 |
+
torch.nn.Sigmoid()
|
66 |
+
)
|
67 |
+
|
68 |
+
self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
|
69 |
+
|
70 |
+
def forward(self, tenInput):
|
71 |
+
tenInput = tenInput * 255.0
|
72 |
+
tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
|
73 |
+
|
74 |
+
tenVggOne = self.netVggOne(tenInput)
|
75 |
+
tenVggTwo = self.netVggTwo(tenVggOne)
|
76 |
+
tenVggThr = self.netVggThr(tenVggTwo)
|
77 |
+
tenVggFou = self.netVggFou(tenVggThr)
|
78 |
+
tenVggFiv = self.netVggFiv(tenVggFou)
|
79 |
+
|
80 |
+
tenScoreOne = self.netScoreOne(tenVggOne)
|
81 |
+
tenScoreTwo = self.netScoreTwo(tenVggTwo)
|
82 |
+
tenScoreThr = self.netScoreThr(tenVggThr)
|
83 |
+
tenScoreFou = self.netScoreFou(tenVggFou)
|
84 |
+
tenScoreFiv = self.netScoreFiv(tenVggFiv)
|
85 |
+
|
86 |
+
tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
87 |
+
tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
88 |
+
tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
89 |
+
tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
90 |
+
tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
|
91 |
+
|
92 |
+
return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))
|
93 |
+
|
94 |
+
|
95 |
+
class HEDdetector:
|
96 |
+
def __init__(self, network ):
|
97 |
+
self.netNetwork = network
|
98 |
+
|
99 |
+
def __call__(self, input_image):
|
100 |
+
if isinstance(input_image, torch.Tensor):
|
101 |
+
# 输入的就是 b c h w的tensor 范围是-1~1,需要转换为0~1
|
102 |
+
input_image = (input_image + 1) / 2
|
103 |
+
input_image = input_image.float().cuda()
|
104 |
+
edge = self.netNetwork(input_image) # 范围也是0~1, 不用转了直接用
|
105 |
+
return edge
|
106 |
+
else:
|
107 |
+
assert input_image.ndim == 3
|
108 |
+
input_image = input_image[:, :, ::-1].copy()
|
109 |
+
with torch.no_grad():
|
110 |
+
image_hed = torch.from_numpy(input_image).float().cuda()
|
111 |
+
image_hed = image_hed / 255.0
|
112 |
+
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
113 |
+
edge = self.netNetwork(image_hed)[0]
|
114 |
+
edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
|
115 |
+
return edge[0]
|
116 |
+
|
117 |
+
|
118 |
+
def nms(x, t, s):
|
119 |
+
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
120 |
+
|
121 |
+
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
122 |
+
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
123 |
+
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
124 |
+
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
125 |
+
|
126 |
+
y = np.zeros_like(x)
|
127 |
+
|
128 |
+
for f in [f1, f2, f3, f4]:
|
129 |
+
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
130 |
+
|
131 |
+
z = np.zeros_like(y, dtype=np.uint8)
|
132 |
+
z[y > t] = 255
|
133 |
+
return z
|
model/annotator/hed/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (4.31 kB). View file
|
|
model/annotator/util.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
|
7 |
+
|
8 |
+
|
9 |
+
def HWC3(x):
|
10 |
+
assert x.dtype == np.uint8
|
11 |
+
if x.ndim == 2:
|
12 |
+
x = x[:, :, None]
|
13 |
+
assert x.ndim == 3
|
14 |
+
H, W, C = x.shape
|
15 |
+
assert C == 1 or C == 3 or C == 4
|
16 |
+
if C == 3:
|
17 |
+
return x
|
18 |
+
if C == 1:
|
19 |
+
return np.concatenate([x, x, x], axis=2)
|
20 |
+
if C == 4:
|
21 |
+
color = x[:, :, 0:3].astype(np.float32)
|
22 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
23 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
24 |
+
y = y.clip(0, 255).astype(np.uint8)
|
25 |
+
return y
|
26 |
+
|
27 |
+
|
28 |
+
def resize_image(input_image, resolution):
|
29 |
+
H, W, C = input_image.shape
|
30 |
+
H = float(H)
|
31 |
+
W = float(W)
|
32 |
+
k = float(resolution) / min(H, W)
|
33 |
+
H *= k
|
34 |
+
W *= k
|
35 |
+
H = int(np.round(H / 64.0)) * 64
|
36 |
+
W = int(np.round(W / 64.0)) * 64
|
37 |
+
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
38 |
+
return img
|
model/video_diffusion/__init__.py
ADDED
File without changes
|
model/video_diffusion/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (168 Bytes). View file
|
|
model/video_diffusion/models/__init__.py
ADDED
File without changes
|
model/video_diffusion/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (175 Bytes). View file
|
|
model/video_diffusion/models/__pycache__/attention.cpython-39.pyc
ADDED
Binary file (9.27 kB). View file
|
|
model/video_diffusion/models/__pycache__/controlnet3d.cpython-39.pyc
ADDED
Binary file (15.7 kB). View file
|
|
model/video_diffusion/models/__pycache__/resnet.cpython-39.pyc
ADDED
Binary file (11.5 kB). View file
|
|
model/video_diffusion/models/__pycache__/unet_3d_blocks.cpython-39.pyc
ADDED
Binary file (10.7 kB). View file
|
|
model/video_diffusion/models/__pycache__/unet_3d_blocks_control.cpython-39.pyc
ADDED
Binary file (3.82 kB). View file
|
|
model/video_diffusion/models/__pycache__/unet_3d_condition.cpython-39.pyc
ADDED
Binary file (14.5 kB). View file
|
|
model/video_diffusion/models/attention.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.models.modeling_utils import ModelMixin
|
23 |
+
from diffusers.models.attention import FeedForward, CrossAttention, AdaLayerNorm
|
24 |
+
from diffusers.utils import BaseOutput
|
25 |
+
from diffusers.utils.import_utils import is_xformers_available
|
26 |
+
from diffusers.models.cross_attention import XFormersCrossAttnProcessor
|
27 |
+
from einops import rearrange
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class SpatioTemporalTransformerModelOutput(BaseOutput):
|
32 |
+
"""torch.FloatTensor of shape [batch x channel x frames x height x width]"""
|
33 |
+
|
34 |
+
sample: torch.FloatTensor
|
35 |
+
|
36 |
+
|
37 |
+
if is_xformers_available():
|
38 |
+
import xformers
|
39 |
+
import xformers.ops
|
40 |
+
else:
|
41 |
+
xformers = None
|
42 |
+
|
43 |
+
|
44 |
+
class SpatioTemporalTransformerModel(ModelMixin, ConfigMixin):
|
45 |
+
@register_to_config
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
num_attention_heads: int = 16,
|
49 |
+
attention_head_dim: int = 88,
|
50 |
+
in_channels: Optional[int] = None,
|
51 |
+
num_layers: int = 1,
|
52 |
+
dropout: float = 0.0,
|
53 |
+
norm_num_groups: int = 32,
|
54 |
+
cross_attention_dim: Optional[int] = None,
|
55 |
+
attention_bias: bool = False,
|
56 |
+
activation_fn: str = "geglu",
|
57 |
+
num_embeds_ada_norm: Optional[int] = None,
|
58 |
+
use_linear_projection: bool = False,
|
59 |
+
only_cross_attention: bool = False,
|
60 |
+
upcast_attention: bool = False,
|
61 |
+
**transformer_kwargs,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.use_linear_projection = use_linear_projection
|
65 |
+
self.num_attention_heads = num_attention_heads
|
66 |
+
self.attention_head_dim = attention_head_dim
|
67 |
+
inner_dim = num_attention_heads * attention_head_dim
|
68 |
+
|
69 |
+
# Define input layers
|
70 |
+
self.in_channels = in_channels
|
71 |
+
|
72 |
+
self.norm = torch.nn.GroupNorm(
|
73 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
74 |
+
)
|
75 |
+
if use_linear_projection:
|
76 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
77 |
+
else:
|
78 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
79 |
+
|
80 |
+
# Define transformers blocks
|
81 |
+
self.transformer_blocks = nn.ModuleList(
|
82 |
+
[
|
83 |
+
SpatioTemporalTransformerBlock(
|
84 |
+
inner_dim,
|
85 |
+
num_attention_heads,
|
86 |
+
attention_head_dim,
|
87 |
+
dropout=dropout,
|
88 |
+
cross_attention_dim=cross_attention_dim,
|
89 |
+
activation_fn=activation_fn,
|
90 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
91 |
+
attention_bias=attention_bias,
|
92 |
+
only_cross_attention=only_cross_attention,
|
93 |
+
upcast_attention=upcast_attention,
|
94 |
+
**transformer_kwargs,
|
95 |
+
)
|
96 |
+
for d in range(num_layers)
|
97 |
+
]
|
98 |
+
)
|
99 |
+
|
100 |
+
# Define output layers
|
101 |
+
if use_linear_projection:
|
102 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
103 |
+
else:
|
104 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
105 |
+
|
106 |
+
def forward(
|
107 |
+
self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True
|
108 |
+
):
|
109 |
+
# 1. Input
|
110 |
+
clip_length = None
|
111 |
+
is_video = hidden_states.ndim == 5
|
112 |
+
if is_video:
|
113 |
+
clip_length = hidden_states.shape[2]
|
114 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
115 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(clip_length, 0)
|
116 |
+
|
117 |
+
*_, h, w = hidden_states.shape
|
118 |
+
residual = hidden_states
|
119 |
+
|
120 |
+
hidden_states = self.norm(hidden_states)
|
121 |
+
if not self.use_linear_projection:
|
122 |
+
hidden_states = self.proj_in(hidden_states)
|
123 |
+
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
124 |
+
else:
|
125 |
+
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
126 |
+
hidden_states = self.proj_in(hidden_states)
|
127 |
+
|
128 |
+
# 2. Blocks
|
129 |
+
for block in self.transformer_blocks:
|
130 |
+
hidden_states = block(
|
131 |
+
hidden_states,
|
132 |
+
encoder_hidden_states=encoder_hidden_states,
|
133 |
+
timestep=timestep,
|
134 |
+
clip_length=clip_length,
|
135 |
+
)
|
136 |
+
|
137 |
+
# 3. Output
|
138 |
+
if not self.use_linear_projection:
|
139 |
+
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
140 |
+
hidden_states = self.proj_out(hidden_states)
|
141 |
+
else:
|
142 |
+
hidden_states = self.proj_out(hidden_states)
|
143 |
+
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
144 |
+
|
145 |
+
output = hidden_states + residual
|
146 |
+
if is_video:
|
147 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=clip_length)
|
148 |
+
|
149 |
+
if not return_dict:
|
150 |
+
return (output,)
|
151 |
+
|
152 |
+
return SpatioTemporalTransformerModelOutput(sample=output)
|
153 |
+
|
154 |
+
|
155 |
+
class SpatioTemporalTransformerBlock(nn.Module):
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
dim: int,
|
159 |
+
num_attention_heads: int,
|
160 |
+
attention_head_dim: int,
|
161 |
+
dropout=0.0,
|
162 |
+
cross_attention_dim: Optional[int] = None,
|
163 |
+
activation_fn: str = "geglu",
|
164 |
+
num_embeds_ada_norm: Optional[int] = None,
|
165 |
+
attention_bias: bool = False,
|
166 |
+
only_cross_attention: bool = False,
|
167 |
+
upcast_attention: bool = False,
|
168 |
+
use_sparse_causal_attention: bool = False,
|
169 |
+
use_full_sparse_causal_attention: bool = True,
|
170 |
+
temporal_attention_position: str = "after_feedforward",
|
171 |
+
use_gamma = False,
|
172 |
+
):
|
173 |
+
super().__init__()
|
174 |
+
self.only_cross_attention = only_cross_attention
|
175 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
176 |
+
self.use_sparse_causal_attention = use_sparse_causal_attention
|
177 |
+
self.use_full_sparse_causal_attention = use_full_sparse_causal_attention
|
178 |
+
self.use_gamma = use_gamma
|
179 |
+
|
180 |
+
self.temporal_attention_position = temporal_attention_position
|
181 |
+
temporal_attention_positions = ["after_spatial", "after_cross", "after_feedforward"]
|
182 |
+
if temporal_attention_position not in temporal_attention_positions:
|
183 |
+
raise ValueError(
|
184 |
+
f"`temporal_attention_position` must be one of {temporal_attention_positions}"
|
185 |
+
)
|
186 |
+
|
187 |
+
# 1. Spatial-Attn
|
188 |
+
if use_sparse_causal_attention:
|
189 |
+
spatial_attention = SparseCausalAttention
|
190 |
+
elif use_full_sparse_causal_attention:
|
191 |
+
spatial_attention = SparseCausalFullAttention
|
192 |
+
else:
|
193 |
+
spatial_attention = CrossAttention
|
194 |
+
|
195 |
+
self.attn1 = spatial_attention(
|
196 |
+
query_dim=dim,
|
197 |
+
heads=num_attention_heads,
|
198 |
+
dim_head=attention_head_dim,
|
199 |
+
dropout=dropout,
|
200 |
+
bias=attention_bias,
|
201 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
202 |
+
upcast_attention=upcast_attention,
|
203 |
+
processor=XFormersCrossAttnProcessor(),
|
204 |
+
) # is a self-attention
|
205 |
+
self.norm1 = (
|
206 |
+
AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
207 |
+
)
|
208 |
+
if use_gamma:
|
209 |
+
self.attn1_gamma = nn.Parameter(torch.ones(dim))
|
210 |
+
|
211 |
+
# 2. Cross-Attn
|
212 |
+
if cross_attention_dim is not None:
|
213 |
+
self.attn2 = CrossAttention(
|
214 |
+
query_dim=dim,
|
215 |
+
cross_attention_dim=cross_attention_dim,
|
216 |
+
heads=num_attention_heads,
|
217 |
+
dim_head=attention_head_dim,
|
218 |
+
dropout=dropout,
|
219 |
+
bias=attention_bias,
|
220 |
+
upcast_attention=upcast_attention,
|
221 |
+
processor=XFormersCrossAttnProcessor(),
|
222 |
+
) # is self-attn if encoder_hidden_states is none
|
223 |
+
self.norm2 = (
|
224 |
+
AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
225 |
+
)
|
226 |
+
if use_gamma:
|
227 |
+
self.attn2_gamma = nn.Parameter(torch.ones(dim))
|
228 |
+
else:
|
229 |
+
self.attn2 = None
|
230 |
+
self.norm2 = None
|
231 |
+
|
232 |
+
# 3. Temporal-Attn
|
233 |
+
self.attn_temporal = CrossAttention(
|
234 |
+
query_dim=dim,
|
235 |
+
heads=num_attention_heads,
|
236 |
+
dim_head=attention_head_dim,
|
237 |
+
dropout=dropout,
|
238 |
+
bias=attention_bias,
|
239 |
+
upcast_attention=upcast_attention,
|
240 |
+
processor=XFormersCrossAttnProcessor()
|
241 |
+
)
|
242 |
+
zero_module(self.attn_temporal) # 默认参数置0
|
243 |
+
|
244 |
+
self.norm_temporal = (
|
245 |
+
AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
246 |
+
)
|
247 |
+
|
248 |
+
# 4. Feed-forward
|
249 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
250 |
+
self.norm3 = nn.LayerNorm(dim)
|
251 |
+
if use_gamma:
|
252 |
+
self.ff_gamma = nn.Parameter(torch.ones(dim))
|
253 |
+
|
254 |
+
|
255 |
+
def forward(
|
256 |
+
self,
|
257 |
+
hidden_states,
|
258 |
+
encoder_hidden_states=None,
|
259 |
+
timestep=None,
|
260 |
+
attention_mask=None,
|
261 |
+
clip_length=None,
|
262 |
+
):
|
263 |
+
# 1. Self-Attention
|
264 |
+
norm_hidden_states = (
|
265 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
266 |
+
)
|
267 |
+
|
268 |
+
kwargs = dict(
|
269 |
+
hidden_states=norm_hidden_states,
|
270 |
+
attention_mask=attention_mask,
|
271 |
+
)
|
272 |
+
if self.only_cross_attention:
|
273 |
+
kwargs.update(encoder_hidden_states=encoder_hidden_states)
|
274 |
+
if self.use_sparse_causal_attention or self.use_full_sparse_causal_attention:
|
275 |
+
kwargs.update(clip_length=clip_length)
|
276 |
+
|
277 |
+
if self.use_gamma:
|
278 |
+
hidden_states = hidden_states + self.attn1(**kwargs) * self.attn1_gamma # NOTE gamma
|
279 |
+
else:
|
280 |
+
hidden_states = hidden_states + self.attn1(**kwargs)
|
281 |
+
|
282 |
+
|
283 |
+
if clip_length is not None and self.temporal_attention_position == "after_spatial":
|
284 |
+
hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
|
285 |
+
|
286 |
+
if self.attn2 is not None:
|
287 |
+
# 2. Cross-Attention
|
288 |
+
norm_hidden_states = (
|
289 |
+
self.norm2(hidden_states, timestep)
|
290 |
+
if self.use_ada_layer_norm
|
291 |
+
else self.norm2(hidden_states)
|
292 |
+
)
|
293 |
+
if self.use_gamma:
|
294 |
+
hidden_states = (
|
295 |
+
self.attn2(
|
296 |
+
norm_hidden_states,
|
297 |
+
encoder_hidden_states=encoder_hidden_states,
|
298 |
+
attention_mask=attention_mask,
|
299 |
+
) * self.attn2_gamma
|
300 |
+
+ hidden_states
|
301 |
+
)
|
302 |
+
else:
|
303 |
+
hidden_states = (
|
304 |
+
self.attn2(
|
305 |
+
norm_hidden_states,
|
306 |
+
encoder_hidden_states=encoder_hidden_states,
|
307 |
+
attention_mask=attention_mask,
|
308 |
+
)
|
309 |
+
+ hidden_states
|
310 |
+
)
|
311 |
+
|
312 |
+
if clip_length is not None and self.temporal_attention_position == "after_cross":
|
313 |
+
hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
|
314 |
+
|
315 |
+
# 3. Feed-forward
|
316 |
+
if self.use_gamma:
|
317 |
+
hidden_states = self.ff(self.norm3(hidden_states)) * self.ff_gamma + hidden_states
|
318 |
+
else:
|
319 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
320 |
+
|
321 |
+
if clip_length is not None and self.temporal_attention_position == "after_feedforward":
|
322 |
+
hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length)
|
323 |
+
|
324 |
+
return hidden_states
|
325 |
+
|
326 |
+
def apply_temporal_attention(self, hidden_states, timestep, clip_length):
|
327 |
+
d = hidden_states.shape[1]
|
328 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=clip_length)
|
329 |
+
norm_hidden_states = (
|
330 |
+
self.norm_temporal(hidden_states, timestep)
|
331 |
+
if self.use_ada_layer_norm
|
332 |
+
else self.norm_temporal(hidden_states)
|
333 |
+
)
|
334 |
+
hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
|
335 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
336 |
+
return hidden_states
|
337 |
+
|
338 |
+
|
339 |
+
class SparseCausalAttention(CrossAttention):
|
340 |
+
def forward(
|
341 |
+
self,
|
342 |
+
hidden_states,
|
343 |
+
encoder_hidden_states=None,
|
344 |
+
attention_mask=None,
|
345 |
+
clip_length: int = None,
|
346 |
+
):
|
347 |
+
if (
|
348 |
+
self.added_kv_proj_dim is not None
|
349 |
+
or encoder_hidden_states is not None
|
350 |
+
or attention_mask is not None
|
351 |
+
):
|
352 |
+
raise NotImplementedError
|
353 |
+
|
354 |
+
if self.group_norm is not None:
|
355 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
356 |
+
|
357 |
+
query = self.to_q(hidden_states)
|
358 |
+
dim = query.shape[-1]
|
359 |
+
query = self.head_to_batch_dim(query) # 64 4096 40
|
360 |
+
|
361 |
+
key = self.to_k(hidden_states)
|
362 |
+
value = self.to_v(hidden_states)
|
363 |
+
|
364 |
+
if clip_length is not None and clip_length > 1:
|
365 |
+
# spatial temporal
|
366 |
+
prev_frame_index = torch.arange(clip_length) - 1
|
367 |
+
prev_frame_index[0] = 0
|
368 |
+
key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)
|
369 |
+
key = torch.cat([key[:, [0] * clip_length], key[:, prev_frame_index]], dim=2)
|
370 |
+
key = rearrange(key, "b f d c -> (b f) d c", f=clip_length)
|
371 |
+
|
372 |
+
value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)
|
373 |
+
value = torch.cat([value[:, [0] * clip_length], value[:, prev_frame_index]], dim=2)
|
374 |
+
value = rearrange(value, "b f d c -> (b f) d c", f=clip_length)
|
375 |
+
|
376 |
+
|
377 |
+
key = self.head_to_batch_dim(key)
|
378 |
+
value = self.head_to_batch_dim(value)
|
379 |
+
# use xfromers by default~
|
380 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
381 |
+
query, key, value, attn_bias=attention_mask, op=None
|
382 |
+
)
|
383 |
+
hidden_states = hidden_states.to(query.dtype)
|
384 |
+
hidden_states = self.batch_to_head_dim(hidden_states)
|
385 |
+
|
386 |
+
# linear proj
|
387 |
+
hidden_states = self.to_out[0](hidden_states)
|
388 |
+
|
389 |
+
# dropout
|
390 |
+
hidden_states = self.to_out[1](hidden_states)
|
391 |
+
return hidden_states
|
392 |
+
|
393 |
+
def zero_module(module):
|
394 |
+
for p in module.parameters():
|
395 |
+
nn.init.zeros_(p)
|
396 |
+
return module
|
397 |
+
|
398 |
+
|
399 |
+
class SparseCausalFullAttention(CrossAttention):
|
400 |
+
def forward(
|
401 |
+
self,
|
402 |
+
hidden_states,
|
403 |
+
encoder_hidden_states=None,
|
404 |
+
attention_mask=None,
|
405 |
+
clip_length: int = None,
|
406 |
+
):
|
407 |
+
if (
|
408 |
+
self.added_kv_proj_dim is not None
|
409 |
+
or encoder_hidden_states is not None
|
410 |
+
or attention_mask is not None
|
411 |
+
):
|
412 |
+
raise NotImplementedError
|
413 |
+
|
414 |
+
if self.group_norm is not None:
|
415 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
416 |
+
|
417 |
+
query = self.to_q(hidden_states)
|
418 |
+
dim = query.shape[-1]
|
419 |
+
query = self.head_to_batch_dim(query) # 64 4096 40
|
420 |
+
|
421 |
+
key = self.to_k(hidden_states)
|
422 |
+
value = self.to_v(hidden_states)
|
423 |
+
|
424 |
+
if clip_length is not None and clip_length > 1:
|
425 |
+
# 和所有帧做 spatial temporal attention
|
426 |
+
key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)
|
427 |
+
# cat full frames
|
428 |
+
key = torch.cat([key[:, [iii] * clip_length] for iii in range(clip_length) ], dim=2) # concat第一帧+第i帧。以此为key, value。而非自己这一帧。
|
429 |
+
key = rearrange(key, "b f d c -> (b f) d c", f=clip_length)
|
430 |
+
|
431 |
+
value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)
|
432 |
+
value = torch.cat([value[:, [iii] * clip_length] for iii in range(clip_length) ], dim=2) # concat第一帧+第i帧。以此为key, value。而非自己这一帧。
|
433 |
+
value = rearrange(value, "b f d c -> (b f) d c", f=clip_length)
|
434 |
+
|
435 |
+
key = self.head_to_batch_dim(key)
|
436 |
+
value = self.head_to_batch_dim(value)
|
437 |
+
# use xfromers by default~
|
438 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
439 |
+
query, key, value, attn_bias=attention_mask, op=None
|
440 |
+
)
|
441 |
+
hidden_states = hidden_states.to(query.dtype)
|
442 |
+
hidden_states = self.batch_to_head_dim(hidden_states)
|
443 |
+
|
444 |
+
# linear proj
|
445 |
+
hidden_states = self.to_out[0](hidden_states)
|
446 |
+
|
447 |
+
# dropout
|
448 |
+
hidden_states = self.to_out[1](hidden_states)
|
449 |
+
return hidden_states
|
450 |
+
|
451 |
+
def zero_module(module):
|
452 |
+
for p in module.parameters():
|
453 |
+
nn.init.zeros_(p)
|
454 |
+
return module
|
model/video_diffusion/models/controlnet3d.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import functional as F
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.utils import BaseOutput, logging
|
25 |
+
from diffusers.models.cross_attention import AttnProcessor
|
26 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
28 |
+
|
29 |
+
from .unet_3d_blocks import (
|
30 |
+
CrossAttnDownBlockPseudo3D,
|
31 |
+
DownBlockPseudo3D,
|
32 |
+
UNetMidBlockPseudo3DCrossAttn,
|
33 |
+
get_down_block,
|
34 |
+
)
|
35 |
+
from .resnet import PseudoConv3d
|
36 |
+
from diffusers.models.cross_attention import AttnProcessor
|
37 |
+
from typing import Dict
|
38 |
+
from .unet_3d_blocks_control import ControlNetPseudoZeroConv3dBlock, ControlNetInputHintBlock
|
39 |
+
import glob
|
40 |
+
import os
|
41 |
+
import json
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
44 |
+
@dataclass
|
45 |
+
class ControlNetOutput(BaseOutput):
|
46 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
47 |
+
mid_block_res_sample: torch.Tensor
|
48 |
+
|
49 |
+
|
50 |
+
class ControlNetConditioningEmbedding(nn.Module):
|
51 |
+
"""
|
52 |
+
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
53 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
54 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
55 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
56 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
57 |
+
model) to encode image-space conditions ... into feature maps ..."
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
conditioning_embedding_channels: int,
|
63 |
+
conditioning_channels: int = 3,
|
64 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
65 |
+
):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
self.conv_in = PseudoConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
69 |
+
|
70 |
+
self.blocks = nn.ModuleList([])
|
71 |
+
|
72 |
+
for i in range(len(block_out_channels) - 1):
|
73 |
+
channel_in = block_out_channels[i]
|
74 |
+
channel_out = block_out_channels[i + 1]
|
75 |
+
self.blocks.append(PseudoConv3d(channel_in, channel_in, kernel_size=3, padding=1))
|
76 |
+
self.blocks.append(PseudoConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
77 |
+
|
78 |
+
# self.conv_out = zero_module(
|
79 |
+
# PseudoConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
80 |
+
# )
|
81 |
+
self.conv_out = PseudoConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
82 |
+
|
83 |
+
def forward(self, conditioning):
|
84 |
+
embedding = self.conv_in(conditioning)
|
85 |
+
embedding = F.silu(embedding)
|
86 |
+
|
87 |
+
for block in self.blocks:
|
88 |
+
embedding = block(embedding)
|
89 |
+
embedding = F.silu(embedding)
|
90 |
+
|
91 |
+
embedding = self.conv_out(embedding)
|
92 |
+
|
93 |
+
return embedding
|
94 |
+
|
95 |
+
|
96 |
+
class ControlNet3DModel(ModelMixin, ConfigMixin):
|
97 |
+
_supports_gradient_checkpointing = True
|
98 |
+
|
99 |
+
@register_to_config
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
in_channels: int = 4,
|
103 |
+
flip_sin_to_cos: bool = True,
|
104 |
+
freq_shift: int = 0,
|
105 |
+
down_block_types: Tuple[str] = (
|
106 |
+
"CrossAttnDownBlockPseudo3D",
|
107 |
+
"CrossAttnDownBlockPseudo3D",
|
108 |
+
"CrossAttnDownBlockPseudo3D",
|
109 |
+
"DownBlockPseudo3D",
|
110 |
+
),
|
111 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
112 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
113 |
+
layers_per_block: int = 2,
|
114 |
+
downsample_padding: int = 1,
|
115 |
+
mid_block_scale_factor: float = 1,
|
116 |
+
act_fn: str = "silu",
|
117 |
+
norm_num_groups: Optional[int] = 32,
|
118 |
+
norm_eps: float = 1e-5,
|
119 |
+
cross_attention_dim: int = 1280,
|
120 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
121 |
+
use_linear_projection: bool = False,
|
122 |
+
class_embed_type: Optional[str] = None,
|
123 |
+
num_class_embeds: Optional[int] = None,
|
124 |
+
upcast_attention: bool = False,
|
125 |
+
resnet_time_scale_shift: str = "default",
|
126 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
127 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
128 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
129 |
+
):
|
130 |
+
super().__init__()
|
131 |
+
|
132 |
+
# Check inputs
|
133 |
+
if len(block_out_channels) != len(down_block_types):
|
134 |
+
raise ValueError(
|
135 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
136 |
+
)
|
137 |
+
|
138 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
139 |
+
raise ValueError(
|
140 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
141 |
+
)
|
142 |
+
|
143 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
144 |
+
raise ValueError(
|
145 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
146 |
+
)
|
147 |
+
|
148 |
+
# input
|
149 |
+
conv_in_kernel = 3
|
150 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
151 |
+
self.conv_in = PseudoConv3d(
|
152 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
153 |
+
)
|
154 |
+
|
155 |
+
# time
|
156 |
+
time_embed_dim = block_out_channels[0] * 4
|
157 |
+
|
158 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
159 |
+
timestep_input_dim = block_out_channels[0]
|
160 |
+
|
161 |
+
self.time_embedding = TimestepEmbedding(
|
162 |
+
timestep_input_dim,
|
163 |
+
time_embed_dim,
|
164 |
+
act_fn=act_fn,
|
165 |
+
)
|
166 |
+
|
167 |
+
# class embedding
|
168 |
+
if class_embed_type is None and num_class_embeds is not None:
|
169 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
170 |
+
elif class_embed_type == "timestep":
|
171 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
172 |
+
elif class_embed_type == "identity":
|
173 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
174 |
+
elif class_embed_type == "projection":
|
175 |
+
if projection_class_embeddings_input_dim is None:
|
176 |
+
raise ValueError(
|
177 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
178 |
+
)
|
179 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
180 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
181 |
+
# 2. it projects from an arbitrary input dimension.
|
182 |
+
#
|
183 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
184 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
185 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
186 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
187 |
+
else:
|
188 |
+
self.class_embedding = None
|
189 |
+
|
190 |
+
# control net conditioning embedding
|
191 |
+
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
192 |
+
conditioning_embedding_channels=block_out_channels[0],
|
193 |
+
block_out_channels=conditioning_embedding_out_channels,
|
194 |
+
)
|
195 |
+
|
196 |
+
self.down_blocks = nn.ModuleList([])
|
197 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
198 |
+
|
199 |
+
if isinstance(only_cross_attention, bool):
|
200 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
201 |
+
|
202 |
+
if isinstance(attention_head_dim, int):
|
203 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
204 |
+
|
205 |
+
# down
|
206 |
+
output_channel = block_out_channels[0]
|
207 |
+
|
208 |
+
controlnet_block = PseudoConv3d(output_channel, output_channel, kernel_size=1)
|
209 |
+
# controlnet_block = zero_module(controlnet_block)
|
210 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
211 |
+
|
212 |
+
for i, down_block_type in enumerate(down_block_types):
|
213 |
+
input_channel = output_channel
|
214 |
+
output_channel = block_out_channels[i]
|
215 |
+
is_final_block = i == len(block_out_channels) - 1
|
216 |
+
|
217 |
+
down_block = get_down_block(
|
218 |
+
down_block_type,
|
219 |
+
num_layers=layers_per_block,
|
220 |
+
in_channels=input_channel,
|
221 |
+
out_channels=output_channel,
|
222 |
+
temb_channels=time_embed_dim,
|
223 |
+
add_downsample=not is_final_block,
|
224 |
+
resnet_eps=norm_eps,
|
225 |
+
resnet_act_fn=act_fn,
|
226 |
+
resnet_groups=norm_num_groups,
|
227 |
+
cross_attention_dim=cross_attention_dim,
|
228 |
+
attn_num_head_channels=attention_head_dim[i],
|
229 |
+
downsample_padding=downsample_padding,
|
230 |
+
use_linear_projection=use_linear_projection,
|
231 |
+
only_cross_attention=only_cross_attention[i],
|
232 |
+
upcast_attention=upcast_attention,
|
233 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
234 |
+
)
|
235 |
+
self.down_blocks.append(down_block)
|
236 |
+
|
237 |
+
for _ in range(layers_per_block):
|
238 |
+
controlnet_block = PseudoConv3d(output_channel, output_channel, kernel_size=1)
|
239 |
+
# controlnet_block = zero_module(controlnet_block)
|
240 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
241 |
+
|
242 |
+
if not is_final_block:
|
243 |
+
controlnet_block = PseudoConv3d(output_channel, output_channel, kernel_size=1)
|
244 |
+
# controlnet_block = zero_module(controlnet_block)
|
245 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
246 |
+
|
247 |
+
# mid
|
248 |
+
mid_block_channel = block_out_channels[-1]
|
249 |
+
|
250 |
+
controlnet_block = PseudoConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
|
251 |
+
# controlnet_block = zero_module(controlnet_block)
|
252 |
+
self.controlnet_mid_block = controlnet_block
|
253 |
+
|
254 |
+
self.mid_block = UNetMidBlockPseudo3DCrossAttn(
|
255 |
+
in_channels=mid_block_channel,
|
256 |
+
temb_channels=time_embed_dim,
|
257 |
+
resnet_eps=norm_eps,
|
258 |
+
resnet_act_fn=act_fn,
|
259 |
+
output_scale_factor=mid_block_scale_factor,
|
260 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
261 |
+
cross_attention_dim=cross_attention_dim,
|
262 |
+
attn_num_head_channels=attention_head_dim[-1],
|
263 |
+
resnet_groups=norm_num_groups,
|
264 |
+
use_linear_projection=use_linear_projection,
|
265 |
+
upcast_attention=upcast_attention,
|
266 |
+
)
|
267 |
+
|
268 |
+
@property
|
269 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
270 |
+
def attn_processors(self) -> Dict[str, AttnProcessor]:
|
271 |
+
r"""
|
272 |
+
Returns:
|
273 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
274 |
+
indexed by its weight name.
|
275 |
+
"""
|
276 |
+
# set recursively
|
277 |
+
processors = {}
|
278 |
+
|
279 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
|
280 |
+
if hasattr(module, "set_processor"):
|
281 |
+
processors[f"{name}.processor"] = module.processor
|
282 |
+
|
283 |
+
for sub_name, child in module.named_children():
|
284 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
285 |
+
|
286 |
+
return processors
|
287 |
+
|
288 |
+
for name, module in self.named_children():
|
289 |
+
fn_recursive_add_processors(name, module, processors)
|
290 |
+
|
291 |
+
return processors
|
292 |
+
|
293 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
294 |
+
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
|
295 |
+
r"""
|
296 |
+
Parameters:
|
297 |
+
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
|
298 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
299 |
+
of **all** `CrossAttention` layers.
|
300 |
+
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
|
301 |
+
|
302 |
+
"""
|
303 |
+
count = len(self.attn_processors.keys())
|
304 |
+
|
305 |
+
if isinstance(processor, dict) and len(processor) != count:
|
306 |
+
raise ValueError(
|
307 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
308 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
309 |
+
)
|
310 |
+
|
311 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
312 |
+
if hasattr(module, "set_processor"):
|
313 |
+
if not isinstance(processor, dict):
|
314 |
+
module.set_processor(processor)
|
315 |
+
else:
|
316 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
317 |
+
|
318 |
+
for sub_name, child in module.named_children():
|
319 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
320 |
+
|
321 |
+
for name, module in self.named_children():
|
322 |
+
fn_recursive_attn_processor(name, module, processor)
|
323 |
+
|
324 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
325 |
+
def set_attention_slice(self, slice_size):
|
326 |
+
r"""
|
327 |
+
Enable sliced attention computation.
|
328 |
+
|
329 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
330 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
334 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
335 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
336 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
337 |
+
must be a multiple of `slice_size`.
|
338 |
+
"""
|
339 |
+
sliceable_head_dims = []
|
340 |
+
|
341 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
342 |
+
if hasattr(module, "set_attention_slice"):
|
343 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
344 |
+
|
345 |
+
for child in module.children():
|
346 |
+
fn_recursive_retrieve_slicable_dims(child)
|
347 |
+
|
348 |
+
# retrieve number of attention layers
|
349 |
+
for module in self.children():
|
350 |
+
fn_recursive_retrieve_slicable_dims(module)
|
351 |
+
|
352 |
+
num_slicable_layers = len(sliceable_head_dims)
|
353 |
+
|
354 |
+
if slice_size == "auto":
|
355 |
+
# half the attention head size is usually a good trade-off between
|
356 |
+
# speed and memory
|
357 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
358 |
+
elif slice_size == "max":
|
359 |
+
# make smallest slice possible
|
360 |
+
slice_size = num_slicable_layers * [1]
|
361 |
+
|
362 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
363 |
+
|
364 |
+
if len(slice_size) != len(sliceable_head_dims):
|
365 |
+
raise ValueError(
|
366 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
367 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
368 |
+
)
|
369 |
+
|
370 |
+
for i in range(len(slice_size)):
|
371 |
+
size = slice_size[i]
|
372 |
+
dim = sliceable_head_dims[i]
|
373 |
+
if size is not None and size > dim:
|
374 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
375 |
+
|
376 |
+
# Recursively walk through all the children.
|
377 |
+
# Any children which exposes the set_attention_slice method
|
378 |
+
# gets the message
|
379 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
380 |
+
if hasattr(module, "set_attention_slice"):
|
381 |
+
module.set_attention_slice(slice_size.pop())
|
382 |
+
|
383 |
+
for child in module.children():
|
384 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
385 |
+
|
386 |
+
reversed_slice_size = list(reversed(slice_size))
|
387 |
+
for module in self.children():
|
388 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
389 |
+
|
390 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
391 |
+
if isinstance(module, (CrossAttnDownBlockPseudo3D, DownBlockPseudo3D)):
|
392 |
+
module.gradient_checkpointing = value
|
393 |
+
|
394 |
+
def forward(
|
395 |
+
self,
|
396 |
+
sample: torch.FloatTensor,
|
397 |
+
timestep: Union[torch.Tensor, float, int],
|
398 |
+
encoder_hidden_states: torch.Tensor,
|
399 |
+
controlnet_cond: torch.FloatTensor,
|
400 |
+
class_labels: Optional[torch.Tensor] = None,
|
401 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
402 |
+
attention_mask: Optional[torch.Tensor] = None,
|
403 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
404 |
+
return_dict: bool = True,
|
405 |
+
) -> Union[ControlNetOutput, Tuple]:
|
406 |
+
# check channel order
|
407 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
408 |
+
|
409 |
+
if channel_order == "rgb":
|
410 |
+
# in rgb order by default
|
411 |
+
...
|
412 |
+
elif channel_order == "bgr":
|
413 |
+
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
414 |
+
else:
|
415 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
416 |
+
|
417 |
+
# prepare attention_mask
|
418 |
+
if attention_mask is not None:
|
419 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
420 |
+
attention_mask = attention_mask.unsqueeze(1)
|
421 |
+
|
422 |
+
# 1. time
|
423 |
+
timesteps = timestep
|
424 |
+
if not torch.is_tensor(timesteps):
|
425 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
426 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
427 |
+
is_mps = sample.device.type == "mps"
|
428 |
+
if isinstance(timestep, float):
|
429 |
+
dtype = torch.float32 if is_mps else torch.float64
|
430 |
+
else:
|
431 |
+
dtype = torch.int32 if is_mps else torch.int64
|
432 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
433 |
+
elif len(timesteps.shape) == 0:
|
434 |
+
timesteps = timesteps[None].to(sample.device)
|
435 |
+
|
436 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
437 |
+
timesteps = timesteps.expand(sample.shape[0])
|
438 |
+
|
439 |
+
t_emb = self.time_proj(timesteps)
|
440 |
+
|
441 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
442 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
443 |
+
# there might be better ways to encapsulate this.
|
444 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
445 |
+
|
446 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
447 |
+
|
448 |
+
if self.class_embedding is not None:
|
449 |
+
if class_labels is None:
|
450 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
451 |
+
|
452 |
+
if self.config.class_embed_type == "timestep":
|
453 |
+
class_labels = self.time_proj(class_labels)
|
454 |
+
|
455 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
456 |
+
emb = emb + class_emb
|
457 |
+
|
458 |
+
# 2. pre-process
|
459 |
+
sample = self.conv_in(sample)
|
460 |
+
|
461 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
462 |
+
# print(sample.shape, controlnet_cond.shape)
|
463 |
+
|
464 |
+
sample += controlnet_cond
|
465 |
+
# 3. down
|
466 |
+
|
467 |
+
down_block_res_samples = (sample,)
|
468 |
+
for downsample_block in self.down_blocks:
|
469 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
470 |
+
sample, res_samples = downsample_block(
|
471 |
+
hidden_states=sample,
|
472 |
+
temb=emb,
|
473 |
+
encoder_hidden_states=encoder_hidden_states,
|
474 |
+
attention_mask=attention_mask,
|
475 |
+
)
|
476 |
+
else:
|
477 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
478 |
+
|
479 |
+
down_block_res_samples += res_samples
|
480 |
+
|
481 |
+
# 4. mid
|
482 |
+
if self.mid_block is not None:
|
483 |
+
sample = self.mid_block(
|
484 |
+
sample,
|
485 |
+
emb,
|
486 |
+
encoder_hidden_states=encoder_hidden_states,
|
487 |
+
attention_mask=attention_mask,
|
488 |
+
)
|
489 |
+
|
490 |
+
# 5. Control net blocks
|
491 |
+
|
492 |
+
controlnet_down_block_res_samples = ()
|
493 |
+
|
494 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
495 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
496 |
+
controlnet_down_block_res_samples += (down_block_res_sample,)
|
497 |
+
|
498 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
499 |
+
|
500 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
501 |
+
|
502 |
+
if not return_dict:
|
503 |
+
return (down_block_res_samples, mid_block_res_sample)
|
504 |
+
|
505 |
+
return ControlNetOutput(
|
506 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
507 |
+
)
|
508 |
+
|
509 |
+
@classmethod
|
510 |
+
def from_2d_model(cls, model_path, condition_on_fps=False, controlnet_hint_channels: Optional[int] = None,):
|
511 |
+
'''
|
512 |
+
load a 2d model and convert it to a pseudo 3d model
|
513 |
+
'''
|
514 |
+
config_path = os.path.join(model_path, "config.json")
|
515 |
+
if not os.path.isfile(config_path):
|
516 |
+
raise RuntimeError(f"{config_path} does not exist")
|
517 |
+
with open(config_path, "r") as f:
|
518 |
+
config = json.load(f)
|
519 |
+
|
520 |
+
config.pop("_class_name")
|
521 |
+
config.pop("_diffusers_version")
|
522 |
+
|
523 |
+
block_replacer = {
|
524 |
+
"CrossAttnDownBlock2D": "CrossAttnDownBlockPseudo3D",
|
525 |
+
"DownBlock2D": "DownBlockPseudo3D",
|
526 |
+
"UNetMidBlock2DCrossAttn": "UNetMidBlockPseudo3DCrossAttn",
|
527 |
+
}
|
528 |
+
|
529 |
+
def convert_2d_to_3d_block(block):
|
530 |
+
return block_replacer[block] if block in block_replacer else block
|
531 |
+
|
532 |
+
config["down_block_types"] = [
|
533 |
+
convert_2d_to_3d_block(block) for block in config["down_block_types"]
|
534 |
+
]
|
535 |
+
|
536 |
+
if "mid_block_type" in config:
|
537 |
+
config["mid_block_type"] = convert_2d_to_3d_block(config["mid_block_type"])
|
538 |
+
|
539 |
+
if condition_on_fps:
|
540 |
+
config["fps_embed_type"] = "timestep" # 和timestep保持一致的type。
|
541 |
+
|
542 |
+
if controlnet_hint_channels:
|
543 |
+
config["controlnet_hint_channels"] = controlnet_hint_channels
|
544 |
+
|
545 |
+
print(config)
|
546 |
+
|
547 |
+
model = cls(**config) # 调用自身(init), 传入config参数全换成3d的setting
|
548 |
+
state_dict_path_condidates = glob.glob(os.path.join(model_path, "*.bin"))
|
549 |
+
if state_dict_path_condidates:
|
550 |
+
state_dict = torch.load(state_dict_path_condidates[0], map_location="cpu")
|
551 |
+
model.load_2d_state_dict(state_dict=state_dict)
|
552 |
+
|
553 |
+
return model
|
554 |
+
|
555 |
+
def load_2d_state_dict(self, state_dict, **kwargs):
|
556 |
+
'''
|
557 |
+
2D 部分的参数名完全不变。
|
558 |
+
'''
|
559 |
+
state_dict_3d = self.state_dict()
|
560 |
+
# print("diff params list:", list(set(state_dict_3d.keys()) - set(state_dict.keys())))
|
561 |
+
|
562 |
+
for k, v in state_dict.items():
|
563 |
+
if k not in state_dict_3d:
|
564 |
+
raise KeyError(f"2d state_dict key {k} does not exist in 3d model")
|
565 |
+
|
566 |
+
for k, v in state_dict_3d.items():
|
567 |
+
if "_temporal" in k:
|
568 |
+
continue
|
569 |
+
if "gamma" in k:
|
570 |
+
continue
|
571 |
+
if k not in state_dict:
|
572 |
+
raise KeyError(f"3d state_dict key {k} does not exist in 2d model")
|
573 |
+
state_dict_3d.update(state_dict)
|
574 |
+
self.load_state_dict(state_dict_3d, strict=True, **kwargs)
|
575 |
+
|
576 |
+
|
577 |
+
def zero_module(module):
|
578 |
+
for p in module.parameters():
|
579 |
+
nn.init.zeros_(p)
|
580 |
+
return module
|
model/video_diffusion/models/resnet.py
ADDED
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from functools import partial
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
from einops import rearrange
|
22 |
+
|
23 |
+
|
24 |
+
class PseudoConv3d(nn.Conv2d):
|
25 |
+
def __init__(self, in_channels, out_channels, kernel_size, temporal_kernel_size=None, **kwargs):
|
26 |
+
super().__init__(
|
27 |
+
in_channels=in_channels,
|
28 |
+
out_channels=out_channels,
|
29 |
+
kernel_size=kernel_size,
|
30 |
+
**kwargs,
|
31 |
+
)
|
32 |
+
if temporal_kernel_size is None:
|
33 |
+
temporal_kernel_size = kernel_size
|
34 |
+
|
35 |
+
self.conv_temporal = (
|
36 |
+
nn.Conv1d(
|
37 |
+
out_channels,
|
38 |
+
out_channels,
|
39 |
+
kernel_size=temporal_kernel_size,
|
40 |
+
padding=temporal_kernel_size // 2,
|
41 |
+
)
|
42 |
+
if kernel_size > 1
|
43 |
+
else None
|
44 |
+
)
|
45 |
+
|
46 |
+
if self.conv_temporal is not None:
|
47 |
+
nn.init.dirac_(self.conv_temporal.weight.data) # initialized to be identity
|
48 |
+
nn.init.zeros_(self.conv_temporal.bias.data)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
b = x.shape[0]
|
52 |
+
|
53 |
+
is_video = x.ndim == 5
|
54 |
+
if is_video:
|
55 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
56 |
+
|
57 |
+
x = super().forward(x)
|
58 |
+
|
59 |
+
if is_video:
|
60 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", b=b)
|
61 |
+
|
62 |
+
if self.conv_temporal is None or not is_video:
|
63 |
+
return x
|
64 |
+
|
65 |
+
*_, h, w = x.shape
|
66 |
+
|
67 |
+
x = rearrange(x, "b c f h w -> (b h w) c f")
|
68 |
+
|
69 |
+
x = self.conv_temporal(x) # 加入空间1D的时序卷积。channel不变。(建模时序信息)
|
70 |
+
|
71 |
+
x = rearrange(x, "(b h w) c f -> b c f h w", h=h, w=w)
|
72 |
+
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class UpsamplePseudo3D(nn.Module):
|
77 |
+
"""
|
78 |
+
An upsampling layer with an optional convolution.
|
79 |
+
|
80 |
+
Parameters:
|
81 |
+
channels: channels in the inputs and outputs.
|
82 |
+
use_conv: a bool determining if a convolution is applied.
|
83 |
+
use_conv_transpose:
|
84 |
+
out_channels:
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"
|
89 |
+
):
|
90 |
+
super().__init__()
|
91 |
+
self.channels = channels
|
92 |
+
self.out_channels = out_channels or channels
|
93 |
+
self.use_conv = use_conv
|
94 |
+
self.use_conv_transpose = use_conv_transpose
|
95 |
+
self.name = name
|
96 |
+
|
97 |
+
conv = None
|
98 |
+
if use_conv_transpose:
|
99 |
+
raise NotImplementedError
|
100 |
+
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
101 |
+
elif use_conv:
|
102 |
+
conv = PseudoConv3d(self.channels, self.out_channels, 3, padding=1)
|
103 |
+
|
104 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
105 |
+
if name == "conv":
|
106 |
+
self.conv = conv
|
107 |
+
else:
|
108 |
+
self.Conv2d_0 = conv
|
109 |
+
|
110 |
+
def forward(self, hidden_states, output_size=None):
|
111 |
+
assert hidden_states.shape[1] == self.channels
|
112 |
+
|
113 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
114 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
115 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
116 |
+
dtype = hidden_states.dtype
|
117 |
+
if dtype == torch.bfloat16:
|
118 |
+
hidden_states = hidden_states.to(torch.float32)
|
119 |
+
|
120 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
121 |
+
if hidden_states.shape[0] >= 64:
|
122 |
+
hidden_states = hidden_states.contiguous()
|
123 |
+
|
124 |
+
b = hidden_states.shape[0]
|
125 |
+
is_video = hidden_states.ndim == 5
|
126 |
+
if is_video:
|
127 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
128 |
+
|
129 |
+
# if `output_size` is passed we force the interpolation output
|
130 |
+
# size and do not make use of `scale_factor=2`
|
131 |
+
if output_size is None:
|
132 |
+
# 先插值再用conv
|
133 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
134 |
+
else:
|
135 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
136 |
+
|
137 |
+
# If the input is bfloat16, we cast back to bfloat16
|
138 |
+
if dtype == torch.bfloat16:
|
139 |
+
hidden_states = hidden_states.to(dtype)
|
140 |
+
|
141 |
+
if is_video:
|
142 |
+
hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b)
|
143 |
+
|
144 |
+
if self.use_conv:
|
145 |
+
if self.name == "conv":
|
146 |
+
hidden_states = self.conv(hidden_states)
|
147 |
+
else:
|
148 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
149 |
+
|
150 |
+
return hidden_states
|
151 |
+
|
152 |
+
|
153 |
+
class DownsamplePseudo3D(nn.Module):
|
154 |
+
"""
|
155 |
+
A downsampling layer with an optional convolution.
|
156 |
+
|
157 |
+
Parameters:
|
158 |
+
channels: channels in the inputs and outputs.
|
159 |
+
use_conv: a bool determining if a convolution is applied.
|
160 |
+
out_channels:
|
161 |
+
padding:
|
162 |
+
"""
|
163 |
+
|
164 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
165 |
+
super().__init__()
|
166 |
+
self.channels = channels
|
167 |
+
self.out_channels = out_channels or channels
|
168 |
+
self.use_conv = use_conv
|
169 |
+
self.padding = padding
|
170 |
+
stride = 2
|
171 |
+
self.name = name
|
172 |
+
|
173 |
+
if use_conv:
|
174 |
+
conv = PseudoConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
175 |
+
else:
|
176 |
+
assert self.channels == self.out_channels
|
177 |
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
178 |
+
|
179 |
+
if name == "conv":
|
180 |
+
self.Conv2d_0 = conv
|
181 |
+
self.conv = conv
|
182 |
+
elif name == "Conv2d_0":
|
183 |
+
self.conv = conv
|
184 |
+
else:
|
185 |
+
self.conv = conv
|
186 |
+
|
187 |
+
def forward(self, hidden_states):
|
188 |
+
assert hidden_states.shape[1] == self.channels
|
189 |
+
if self.use_conv and self.padding == 0:
|
190 |
+
pad = (0, 1, 0, 1)
|
191 |
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
192 |
+
|
193 |
+
assert hidden_states.shape[1] == self.channels
|
194 |
+
if self.use_conv:
|
195 |
+
hidden_states = self.conv(hidden_states)
|
196 |
+
else:
|
197 |
+
b = hidden_states.shape[0]
|
198 |
+
is_video = hidden_states.ndim == 5
|
199 |
+
if is_video:
|
200 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
201 |
+
hidden_states = self.conv(hidden_states)
|
202 |
+
if is_video:
|
203 |
+
hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b)
|
204 |
+
|
205 |
+
return hidden_states
|
206 |
+
|
207 |
+
|
208 |
+
class ResnetBlockPseudo3D(nn.Module):
|
209 |
+
def __init__(
|
210 |
+
self,
|
211 |
+
*,
|
212 |
+
in_channels,
|
213 |
+
out_channels=None,
|
214 |
+
conv_shortcut=False,
|
215 |
+
dropout=0.0,
|
216 |
+
temb_channels=512,
|
217 |
+
groups=32,
|
218 |
+
groups_out=None,
|
219 |
+
pre_norm=True,
|
220 |
+
eps=1e-6,
|
221 |
+
non_linearity="swish",
|
222 |
+
time_embedding_norm="default",
|
223 |
+
kernel=None,
|
224 |
+
output_scale_factor=1.0,
|
225 |
+
use_in_shortcut=None,
|
226 |
+
up=False,
|
227 |
+
down=False,
|
228 |
+
):
|
229 |
+
super().__init__()
|
230 |
+
self.pre_norm = pre_norm
|
231 |
+
self.pre_norm = True
|
232 |
+
self.in_channels = in_channels
|
233 |
+
out_channels = in_channels if out_channels is None else out_channels
|
234 |
+
self.out_channels = out_channels
|
235 |
+
self.use_conv_shortcut = conv_shortcut
|
236 |
+
self.time_embedding_norm = time_embedding_norm
|
237 |
+
self.up = up
|
238 |
+
self.down = down
|
239 |
+
self.output_scale_factor = output_scale_factor
|
240 |
+
|
241 |
+
if groups_out is None:
|
242 |
+
groups_out = groups
|
243 |
+
|
244 |
+
self.norm1 = torch.nn.GroupNorm(
|
245 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
246 |
+
)
|
247 |
+
|
248 |
+
self.conv1 = PseudoConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
249 |
+
|
250 |
+
if temb_channels is not None:
|
251 |
+
if self.time_embedding_norm == "default":
|
252 |
+
time_emb_proj_out_channels = out_channels
|
253 |
+
elif self.time_embedding_norm == "scale_shift":
|
254 |
+
time_emb_proj_out_channels = out_channels * 2
|
255 |
+
else:
|
256 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
257 |
+
|
258 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
259 |
+
else:
|
260 |
+
self.time_emb_proj = None
|
261 |
+
|
262 |
+
self.norm2 = torch.nn.GroupNorm(
|
263 |
+
num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
|
264 |
+
)
|
265 |
+
self.dropout = torch.nn.Dropout(dropout)
|
266 |
+
self.conv2 = PseudoConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
267 |
+
|
268 |
+
if non_linearity == "swish":
|
269 |
+
self.nonlinearity = lambda x: F.silu(x)
|
270 |
+
elif non_linearity == "mish":
|
271 |
+
self.nonlinearity = Mish()
|
272 |
+
elif non_linearity == "silu":
|
273 |
+
self.nonlinearity = nn.SiLU()
|
274 |
+
|
275 |
+
self.upsample = self.downsample = None
|
276 |
+
if self.up:
|
277 |
+
if kernel == "fir":
|
278 |
+
fir_kernel = (1, 3, 3, 1)
|
279 |
+
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
|
280 |
+
elif kernel == "sde_vp":
|
281 |
+
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
282 |
+
else:
|
283 |
+
self.upsample = UpsamplePseudo3D(in_channels, use_conv=False)
|
284 |
+
elif self.down:
|
285 |
+
if kernel == "fir":
|
286 |
+
fir_kernel = (1, 3, 3, 1)
|
287 |
+
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
|
288 |
+
elif kernel == "sde_vp":
|
289 |
+
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
290 |
+
else:
|
291 |
+
self.downsample = DownsamplePseudo3D(in_channels, use_conv=False, padding=1, name="op")
|
292 |
+
|
293 |
+
self.use_in_shortcut = (
|
294 |
+
self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
295 |
+
)
|
296 |
+
|
297 |
+
self.conv_shortcut = None
|
298 |
+
if self.use_in_shortcut:
|
299 |
+
self.conv_shortcut = PseudoConv3d(
|
300 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
301 |
+
)
|
302 |
+
|
303 |
+
def forward(self, input_tensor, temb):
|
304 |
+
hidden_states = input_tensor
|
305 |
+
|
306 |
+
hidden_states = self.norm1(hidden_states)
|
307 |
+
hidden_states = self.nonlinearity(hidden_states)
|
308 |
+
|
309 |
+
if self.upsample is not None:
|
310 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
311 |
+
if hidden_states.shape[0] >= 64:
|
312 |
+
input_tensor = input_tensor.contiguous()
|
313 |
+
hidden_states = hidden_states.contiguous()
|
314 |
+
input_tensor = self.upsample(input_tensor)
|
315 |
+
hidden_states = self.upsample(hidden_states)
|
316 |
+
elif self.downsample is not None:
|
317 |
+
input_tensor = self.downsample(input_tensor)
|
318 |
+
hidden_states = self.downsample(hidden_states)
|
319 |
+
|
320 |
+
hidden_states = self.conv1(hidden_states)
|
321 |
+
|
322 |
+
if temb is not None:
|
323 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
324 |
+
|
325 |
+
if temb is not None and self.time_embedding_norm == "default":
|
326 |
+
is_video = hidden_states.ndim == 5
|
327 |
+
if is_video:
|
328 |
+
b, c, f, h, w = hidden_states.shape
|
329 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
330 |
+
temb = temb.repeat_interleave(f, 0)
|
331 |
+
|
332 |
+
hidden_states = hidden_states + temb
|
333 |
+
|
334 |
+
if is_video:
|
335 |
+
hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b)
|
336 |
+
|
337 |
+
hidden_states = self.norm2(hidden_states)
|
338 |
+
|
339 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
340 |
+
is_video = hidden_states.ndim == 5
|
341 |
+
if is_video:
|
342 |
+
b, c, f, h, w = hidden_states.shape
|
343 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
344 |
+
temb = temb.repeat_interleave(f, 0)
|
345 |
+
|
346 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
347 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
348 |
+
|
349 |
+
if is_video:
|
350 |
+
hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b)
|
351 |
+
|
352 |
+
hidden_states = self.nonlinearity(hidden_states)
|
353 |
+
|
354 |
+
hidden_states = self.dropout(hidden_states)
|
355 |
+
hidden_states = self.conv2(hidden_states)
|
356 |
+
|
357 |
+
if self.conv_shortcut is not None:
|
358 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
359 |
+
|
360 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
361 |
+
|
362 |
+
return output_tensor
|
363 |
+
|
364 |
+
|
365 |
+
class Mish(torch.nn.Module):
|
366 |
+
def forward(self, hidden_states):
|
367 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
368 |
+
|
369 |
+
|
370 |
+
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
371 |
+
r"""Upsample2D a batch of 2D images with the given filter.
|
372 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
373 |
+
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
374 |
+
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
375 |
+
a: multiple of the upsampling factor.
|
376 |
+
|
377 |
+
Args:
|
378 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
379 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
380 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
381 |
+
factor: Integer upsampling factor (default: 2).
|
382 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
383 |
+
|
384 |
+
Returns:
|
385 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]`
|
386 |
+
"""
|
387 |
+
assert isinstance(factor, int) and factor >= 1
|
388 |
+
if kernel is None:
|
389 |
+
kernel = [1] * factor
|
390 |
+
|
391 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
392 |
+
if kernel.ndim == 1:
|
393 |
+
kernel = torch.outer(kernel, kernel)
|
394 |
+
kernel /= torch.sum(kernel)
|
395 |
+
|
396 |
+
kernel = kernel * (gain * (factor**2))
|
397 |
+
pad_value = kernel.shape[0] - factor
|
398 |
+
output = upfirdn2d_native(
|
399 |
+
hidden_states,
|
400 |
+
kernel.to(device=hidden_states.device),
|
401 |
+
up=factor,
|
402 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
403 |
+
)
|
404 |
+
return output
|
405 |
+
|
406 |
+
|
407 |
+
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
408 |
+
r"""Downsample2D a batch of 2D images with the given filter.
|
409 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
410 |
+
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
411 |
+
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
412 |
+
shape is a multiple of the downsampling factor.
|
413 |
+
|
414 |
+
Args:
|
415 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
416 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
417 |
+
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
418 |
+
factor: Integer downsampling factor (default: 2).
|
419 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]`
|
423 |
+
"""
|
424 |
+
|
425 |
+
assert isinstance(factor, int) and factor >= 1
|
426 |
+
if kernel is None:
|
427 |
+
kernel = [1] * factor
|
428 |
+
|
429 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
430 |
+
if kernel.ndim == 1:
|
431 |
+
kernel = torch.outer(kernel, kernel)
|
432 |
+
kernel /= torch.sum(kernel)
|
433 |
+
|
434 |
+
kernel = kernel * gain
|
435 |
+
pad_value = kernel.shape[0] - factor
|
436 |
+
output = upfirdn2d_native(
|
437 |
+
hidden_states,
|
438 |
+
kernel.to(device=hidden_states.device),
|
439 |
+
down=factor,
|
440 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
441 |
+
)
|
442 |
+
return output
|
443 |
+
|
444 |
+
|
445 |
+
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
|
446 |
+
up_x = up_y = up
|
447 |
+
down_x = down_y = down
|
448 |
+
pad_x0 = pad_y0 = pad[0]
|
449 |
+
pad_x1 = pad_y1 = pad[1]
|
450 |
+
|
451 |
+
_, channel, in_h, in_w = tensor.shape
|
452 |
+
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
453 |
+
|
454 |
+
_, in_h, in_w, minor = tensor.shape
|
455 |
+
kernel_h, kernel_w = kernel.shape
|
456 |
+
|
457 |
+
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
458 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
459 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
460 |
+
|
461 |
+
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
462 |
+
out = out.to(tensor.device) # Move back to mps if necessary
|
463 |
+
out = out[
|
464 |
+
:,
|
465 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
466 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
467 |
+
:,
|
468 |
+
]
|
469 |
+
|
470 |
+
out = out.permute(0, 3, 1, 2)
|
471 |
+
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
472 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
473 |
+
out = F.conv2d(out, w)
|
474 |
+
out = out.reshape(
|
475 |
+
-1,
|
476 |
+
minor,
|
477 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
478 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
479 |
+
)
|
480 |
+
out = out.permute(0, 2, 3, 1)
|
481 |
+
out = out[:, ::down_y, ::down_x, :]
|
482 |
+
|
483 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
484 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
485 |
+
|
486 |
+
return out.view(-1, channel, out_h, out_w)
|
model/video_diffusion/models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .attention import SpatioTemporalTransformerModel
|
19 |
+
from .resnet import DownsamplePseudo3D, ResnetBlockPseudo3D, UpsamplePseudo3D
|
20 |
+
|
21 |
+
|
22 |
+
def get_down_block(
|
23 |
+
down_block_type,
|
24 |
+
num_layers,
|
25 |
+
in_channels,
|
26 |
+
out_channels,
|
27 |
+
temb_channels,
|
28 |
+
add_downsample,
|
29 |
+
resnet_eps,
|
30 |
+
resnet_act_fn,
|
31 |
+
attn_num_head_channels,
|
32 |
+
resnet_groups=None,
|
33 |
+
cross_attention_dim=None,
|
34 |
+
downsample_padding=None,
|
35 |
+
dual_cross_attention=False,
|
36 |
+
use_linear_projection=False,
|
37 |
+
only_cross_attention=False,
|
38 |
+
upcast_attention=False,
|
39 |
+
resnet_time_scale_shift="default",
|
40 |
+
):
|
41 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
42 |
+
if down_block_type == "DownBlockPseudo3D":
|
43 |
+
return DownBlockPseudo3D(
|
44 |
+
num_layers=num_layers,
|
45 |
+
in_channels=in_channels,
|
46 |
+
out_channels=out_channels,
|
47 |
+
temb_channels=temb_channels,
|
48 |
+
add_downsample=add_downsample,
|
49 |
+
resnet_eps=resnet_eps,
|
50 |
+
resnet_act_fn=resnet_act_fn,
|
51 |
+
resnet_groups=resnet_groups,
|
52 |
+
downsample_padding=downsample_padding,
|
53 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
54 |
+
)
|
55 |
+
elif down_block_type == "CrossAttnDownBlockPseudo3D":
|
56 |
+
if cross_attention_dim is None:
|
57 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockPseudo3D")
|
58 |
+
return CrossAttnDownBlockPseudo3D(
|
59 |
+
num_layers=num_layers,
|
60 |
+
in_channels=in_channels,
|
61 |
+
out_channels=out_channels,
|
62 |
+
temb_channels=temb_channels,
|
63 |
+
add_downsample=add_downsample,
|
64 |
+
resnet_eps=resnet_eps,
|
65 |
+
resnet_act_fn=resnet_act_fn,
|
66 |
+
resnet_groups=resnet_groups,
|
67 |
+
downsample_padding=downsample_padding,
|
68 |
+
cross_attention_dim=cross_attention_dim,
|
69 |
+
attn_num_head_channels=attn_num_head_channels,
|
70 |
+
dual_cross_attention=dual_cross_attention,
|
71 |
+
use_linear_projection=use_linear_projection,
|
72 |
+
only_cross_attention=only_cross_attention,
|
73 |
+
upcast_attention=upcast_attention,
|
74 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
75 |
+
)
|
76 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
77 |
+
|
78 |
+
|
79 |
+
def get_up_block(
|
80 |
+
up_block_type,
|
81 |
+
num_layers,
|
82 |
+
in_channels,
|
83 |
+
out_channels,
|
84 |
+
prev_output_channel,
|
85 |
+
temb_channels,
|
86 |
+
add_upsample,
|
87 |
+
resnet_eps,
|
88 |
+
resnet_act_fn,
|
89 |
+
attn_num_head_channels,
|
90 |
+
resnet_groups=None,
|
91 |
+
cross_attention_dim=None,
|
92 |
+
dual_cross_attention=False,
|
93 |
+
use_linear_projection=False,
|
94 |
+
only_cross_attention=False,
|
95 |
+
upcast_attention=False,
|
96 |
+
resnet_time_scale_shift="default",
|
97 |
+
):
|
98 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
99 |
+
if up_block_type == "UpBlockPseudo3D":
|
100 |
+
return UpBlockPseudo3D(
|
101 |
+
num_layers=num_layers,
|
102 |
+
in_channels=in_channels,
|
103 |
+
out_channels=out_channels,
|
104 |
+
prev_output_channel=prev_output_channel,
|
105 |
+
temb_channels=temb_channels,
|
106 |
+
add_upsample=add_upsample,
|
107 |
+
resnet_eps=resnet_eps,
|
108 |
+
resnet_act_fn=resnet_act_fn,
|
109 |
+
resnet_groups=resnet_groups,
|
110 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
111 |
+
)
|
112 |
+
elif up_block_type == "CrossAttnUpBlockPseudo3D":
|
113 |
+
if cross_attention_dim is None:
|
114 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockPseudo3D")
|
115 |
+
return CrossAttnUpBlockPseudo3D(
|
116 |
+
num_layers=num_layers,
|
117 |
+
in_channels=in_channels,
|
118 |
+
out_channels=out_channels,
|
119 |
+
prev_output_channel=prev_output_channel,
|
120 |
+
temb_channels=temb_channels,
|
121 |
+
add_upsample=add_upsample,
|
122 |
+
resnet_eps=resnet_eps,
|
123 |
+
resnet_act_fn=resnet_act_fn,
|
124 |
+
resnet_groups=resnet_groups,
|
125 |
+
cross_attention_dim=cross_attention_dim,
|
126 |
+
attn_num_head_channels=attn_num_head_channels,
|
127 |
+
dual_cross_attention=dual_cross_attention,
|
128 |
+
use_linear_projection=use_linear_projection,
|
129 |
+
only_cross_attention=only_cross_attention,
|
130 |
+
upcast_attention=upcast_attention,
|
131 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
132 |
+
)
|
133 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
134 |
+
|
135 |
+
|
136 |
+
class UNetMidBlockPseudo3DCrossAttn(nn.Module):
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
in_channels: int,
|
140 |
+
temb_channels: int,
|
141 |
+
dropout: float = 0.0,
|
142 |
+
num_layers: int = 1,
|
143 |
+
resnet_eps: float = 1e-6,
|
144 |
+
resnet_time_scale_shift: str = "default",
|
145 |
+
resnet_act_fn: str = "swish",
|
146 |
+
resnet_groups: int = 32,
|
147 |
+
resnet_pre_norm: bool = True,
|
148 |
+
attn_num_head_channels=1,
|
149 |
+
output_scale_factor=1.0,
|
150 |
+
cross_attention_dim=1280,
|
151 |
+
dual_cross_attention=False,
|
152 |
+
use_linear_projection=False,
|
153 |
+
upcast_attention=False,
|
154 |
+
):
|
155 |
+
super().__init__()
|
156 |
+
|
157 |
+
self.has_cross_attention = True
|
158 |
+
self.attn_num_head_channels = attn_num_head_channels
|
159 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
160 |
+
|
161 |
+
# there is always at least one resnet
|
162 |
+
resnets = [
|
163 |
+
ResnetBlockPseudo3D(
|
164 |
+
in_channels=in_channels,
|
165 |
+
out_channels=in_channels,
|
166 |
+
temb_channels=temb_channels,
|
167 |
+
eps=resnet_eps,
|
168 |
+
groups=resnet_groups,
|
169 |
+
dropout=dropout,
|
170 |
+
time_embedding_norm=resnet_time_scale_shift,
|
171 |
+
non_linearity=resnet_act_fn,
|
172 |
+
output_scale_factor=output_scale_factor,
|
173 |
+
pre_norm=resnet_pre_norm,
|
174 |
+
)
|
175 |
+
]
|
176 |
+
attentions = []
|
177 |
+
|
178 |
+
for _ in range(num_layers):
|
179 |
+
if dual_cross_attention:
|
180 |
+
raise NotImplementedError
|
181 |
+
attentions.append(
|
182 |
+
SpatioTemporalTransformerModel(
|
183 |
+
attn_num_head_channels,
|
184 |
+
in_channels // attn_num_head_channels,
|
185 |
+
in_channels=in_channels,
|
186 |
+
num_layers=1,
|
187 |
+
cross_attention_dim=cross_attention_dim,
|
188 |
+
norm_num_groups=resnet_groups,
|
189 |
+
use_linear_projection=use_linear_projection,
|
190 |
+
upcast_attention=upcast_attention,
|
191 |
+
)
|
192 |
+
)
|
193 |
+
resnets.append(
|
194 |
+
ResnetBlockPseudo3D(
|
195 |
+
in_channels=in_channels,
|
196 |
+
out_channels=in_channels,
|
197 |
+
temb_channels=temb_channels,
|
198 |
+
eps=resnet_eps,
|
199 |
+
groups=resnet_groups,
|
200 |
+
dropout=dropout,
|
201 |
+
time_embedding_norm=resnet_time_scale_shift,
|
202 |
+
non_linearity=resnet_act_fn,
|
203 |
+
output_scale_factor=output_scale_factor,
|
204 |
+
pre_norm=resnet_pre_norm,
|
205 |
+
)
|
206 |
+
)
|
207 |
+
|
208 |
+
self.attentions = nn.ModuleList(attentions)
|
209 |
+
self.resnets = nn.ModuleList(resnets)
|
210 |
+
|
211 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
212 |
+
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
|
213 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
214 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
215 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
216 |
+
hidden_states = resnet(hidden_states, temb)
|
217 |
+
|
218 |
+
return hidden_states
|
219 |
+
|
220 |
+
|
221 |
+
class CrossAttnDownBlockPseudo3D(nn.Module):
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
in_channels: int,
|
225 |
+
out_channels: int,
|
226 |
+
temb_channels: int,
|
227 |
+
dropout: float = 0.0,
|
228 |
+
num_layers: int = 1,
|
229 |
+
resnet_eps: float = 1e-6,
|
230 |
+
resnet_time_scale_shift: str = "default",
|
231 |
+
resnet_act_fn: str = "swish",
|
232 |
+
resnet_groups: int = 32,
|
233 |
+
resnet_pre_norm: bool = True,
|
234 |
+
attn_num_head_channels=1,
|
235 |
+
cross_attention_dim=1280,
|
236 |
+
output_scale_factor=1.0,
|
237 |
+
downsample_padding=1,
|
238 |
+
add_downsample=True,
|
239 |
+
dual_cross_attention=False,
|
240 |
+
use_linear_projection=False,
|
241 |
+
only_cross_attention=False,
|
242 |
+
upcast_attention=False,
|
243 |
+
):
|
244 |
+
super().__init__()
|
245 |
+
resnets = []
|
246 |
+
attentions = []
|
247 |
+
|
248 |
+
self.has_cross_attention = True
|
249 |
+
self.attn_num_head_channels = attn_num_head_channels
|
250 |
+
|
251 |
+
for i in range(num_layers):
|
252 |
+
in_channels = in_channels if i == 0 else out_channels
|
253 |
+
resnets.append(
|
254 |
+
ResnetBlockPseudo3D(
|
255 |
+
in_channels=in_channels,
|
256 |
+
out_channels=out_channels,
|
257 |
+
temb_channels=temb_channels,
|
258 |
+
eps=resnet_eps,
|
259 |
+
groups=resnet_groups,
|
260 |
+
dropout=dropout,
|
261 |
+
time_embedding_norm=resnet_time_scale_shift,
|
262 |
+
non_linearity=resnet_act_fn,
|
263 |
+
output_scale_factor=output_scale_factor,
|
264 |
+
pre_norm=resnet_pre_norm,
|
265 |
+
)
|
266 |
+
)
|
267 |
+
if dual_cross_attention:
|
268 |
+
raise NotImplementedError
|
269 |
+
attentions.append(
|
270 |
+
SpatioTemporalTransformerModel(
|
271 |
+
attn_num_head_channels,
|
272 |
+
out_channels // attn_num_head_channels,
|
273 |
+
in_channels=out_channels,
|
274 |
+
num_layers=1,
|
275 |
+
cross_attention_dim=cross_attention_dim,
|
276 |
+
norm_num_groups=resnet_groups,
|
277 |
+
use_linear_projection=use_linear_projection,
|
278 |
+
only_cross_attention=only_cross_attention,
|
279 |
+
upcast_attention=upcast_attention,
|
280 |
+
)
|
281 |
+
)
|
282 |
+
self.attentions = nn.ModuleList(attentions)
|
283 |
+
self.resnets = nn.ModuleList(resnets)
|
284 |
+
|
285 |
+
if add_downsample:
|
286 |
+
self.downsamplers = nn.ModuleList(
|
287 |
+
[
|
288 |
+
DownsamplePseudo3D(
|
289 |
+
out_channels,
|
290 |
+
use_conv=True,
|
291 |
+
out_channels=out_channels,
|
292 |
+
padding=downsample_padding,
|
293 |
+
name="op",
|
294 |
+
)
|
295 |
+
]
|
296 |
+
)
|
297 |
+
else:
|
298 |
+
self.downsamplers = None
|
299 |
+
|
300 |
+
self.gradient_checkpointing = False
|
301 |
+
|
302 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
303 |
+
# TODO(Patrick, William) - attention mask is not used
|
304 |
+
output_states = ()
|
305 |
+
|
306 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
307 |
+
if self.training and self.gradient_checkpointing:
|
308 |
+
|
309 |
+
def create_custom_forward(module, return_dict=None):
|
310 |
+
def custom_forward(*inputs):
|
311 |
+
if return_dict is not None:
|
312 |
+
return module(*inputs, return_dict=return_dict)
|
313 |
+
else:
|
314 |
+
return module(*inputs)
|
315 |
+
|
316 |
+
return custom_forward
|
317 |
+
|
318 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
319 |
+
create_custom_forward(resnet), hidden_states, temb
|
320 |
+
)
|
321 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
322 |
+
create_custom_forward(attn, return_dict=False),
|
323 |
+
hidden_states,
|
324 |
+
encoder_hidden_states,
|
325 |
+
)[0]
|
326 |
+
else:
|
327 |
+
hidden_states = resnet(hidden_states, temb)
|
328 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
329 |
+
|
330 |
+
output_states += (hidden_states,)
|
331 |
+
|
332 |
+
if self.downsamplers is not None:
|
333 |
+
for downsampler in self.downsamplers:
|
334 |
+
hidden_states = downsampler(hidden_states)
|
335 |
+
|
336 |
+
output_states += (hidden_states,)
|
337 |
+
|
338 |
+
return hidden_states, output_states
|
339 |
+
|
340 |
+
|
341 |
+
class DownBlockPseudo3D(nn.Module):
|
342 |
+
def __init__(
|
343 |
+
self,
|
344 |
+
in_channels: int,
|
345 |
+
out_channels: int,
|
346 |
+
temb_channels: int,
|
347 |
+
dropout: float = 0.0,
|
348 |
+
num_layers: int = 1,
|
349 |
+
resnet_eps: float = 1e-6,
|
350 |
+
resnet_time_scale_shift: str = "default",
|
351 |
+
resnet_act_fn: str = "swish",
|
352 |
+
resnet_groups: int = 32,
|
353 |
+
resnet_pre_norm: bool = True,
|
354 |
+
output_scale_factor=1.0,
|
355 |
+
add_downsample=True,
|
356 |
+
downsample_padding=1,
|
357 |
+
):
|
358 |
+
super().__init__()
|
359 |
+
resnets = []
|
360 |
+
|
361 |
+
for i in range(num_layers):
|
362 |
+
in_channels = in_channels if i == 0 else out_channels
|
363 |
+
resnets.append(
|
364 |
+
ResnetBlockPseudo3D(
|
365 |
+
in_channels=in_channels,
|
366 |
+
out_channels=out_channels,
|
367 |
+
temb_channels=temb_channels,
|
368 |
+
eps=resnet_eps,
|
369 |
+
groups=resnet_groups,
|
370 |
+
dropout=dropout,
|
371 |
+
time_embedding_norm=resnet_time_scale_shift,
|
372 |
+
non_linearity=resnet_act_fn,
|
373 |
+
output_scale_factor=output_scale_factor,
|
374 |
+
pre_norm=resnet_pre_norm,
|
375 |
+
)
|
376 |
+
)
|
377 |
+
|
378 |
+
self.resnets = nn.ModuleList(resnets)
|
379 |
+
|
380 |
+
if add_downsample:
|
381 |
+
self.downsamplers = nn.ModuleList(
|
382 |
+
[
|
383 |
+
DownsamplePseudo3D(
|
384 |
+
out_channels,
|
385 |
+
use_conv=True,
|
386 |
+
out_channels=out_channels,
|
387 |
+
padding=downsample_padding,
|
388 |
+
name="op",
|
389 |
+
)
|
390 |
+
]
|
391 |
+
)
|
392 |
+
else:
|
393 |
+
self.downsamplers = None
|
394 |
+
|
395 |
+
self.gradient_checkpointing = False
|
396 |
+
|
397 |
+
def forward(self, hidden_states, temb=None):
|
398 |
+
output_states = ()
|
399 |
+
|
400 |
+
for resnet in self.resnets:
|
401 |
+
if self.training and self.gradient_checkpointing:
|
402 |
+
|
403 |
+
def create_custom_forward(module):
|
404 |
+
def custom_forward(*inputs):
|
405 |
+
return module(*inputs)
|
406 |
+
|
407 |
+
return custom_forward
|
408 |
+
|
409 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
410 |
+
create_custom_forward(resnet), hidden_states, temb
|
411 |
+
)
|
412 |
+
else:
|
413 |
+
hidden_states = resnet(hidden_states, temb)
|
414 |
+
|
415 |
+
output_states += (hidden_states,)
|
416 |
+
|
417 |
+
if self.downsamplers is not None:
|
418 |
+
for downsampler in self.downsamplers:
|
419 |
+
hidden_states = downsampler(hidden_states)
|
420 |
+
|
421 |
+
output_states += (hidden_states,)
|
422 |
+
|
423 |
+
return hidden_states, output_states
|
424 |
+
|
425 |
+
|
426 |
+
class CrossAttnUpBlockPseudo3D(nn.Module):
|
427 |
+
def __init__(
|
428 |
+
self,
|
429 |
+
in_channels: int,
|
430 |
+
out_channels: int,
|
431 |
+
prev_output_channel: int,
|
432 |
+
temb_channels: int,
|
433 |
+
dropout: float = 0.0,
|
434 |
+
num_layers: int = 1,
|
435 |
+
resnet_eps: float = 1e-6,
|
436 |
+
resnet_time_scale_shift: str = "default",
|
437 |
+
resnet_act_fn: str = "swish",
|
438 |
+
resnet_groups: int = 32,
|
439 |
+
resnet_pre_norm: bool = True,
|
440 |
+
attn_num_head_channels=1,
|
441 |
+
cross_attention_dim=1280,
|
442 |
+
output_scale_factor=1.0,
|
443 |
+
add_upsample=True,
|
444 |
+
dual_cross_attention=False,
|
445 |
+
use_linear_projection=False,
|
446 |
+
only_cross_attention=False,
|
447 |
+
upcast_attention=False,
|
448 |
+
):
|
449 |
+
super().__init__()
|
450 |
+
resnets = []
|
451 |
+
attentions = []
|
452 |
+
|
453 |
+
self.has_cross_attention = True
|
454 |
+
self.attn_num_head_channels = attn_num_head_channels
|
455 |
+
|
456 |
+
for i in range(num_layers):
|
457 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
458 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
459 |
+
|
460 |
+
resnets.append(
|
461 |
+
ResnetBlockPseudo3D(
|
462 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
463 |
+
out_channels=out_channels,
|
464 |
+
temb_channels=temb_channels,
|
465 |
+
eps=resnet_eps,
|
466 |
+
groups=resnet_groups,
|
467 |
+
dropout=dropout,
|
468 |
+
time_embedding_norm=resnet_time_scale_shift,
|
469 |
+
non_linearity=resnet_act_fn,
|
470 |
+
output_scale_factor=output_scale_factor,
|
471 |
+
pre_norm=resnet_pre_norm,
|
472 |
+
)
|
473 |
+
)
|
474 |
+
if dual_cross_attention:
|
475 |
+
raise NotImplementedError
|
476 |
+
attentions.append(
|
477 |
+
SpatioTemporalTransformerModel(
|
478 |
+
attn_num_head_channels,
|
479 |
+
out_channels // attn_num_head_channels,
|
480 |
+
in_channels=out_channels,
|
481 |
+
num_layers=1,
|
482 |
+
cross_attention_dim=cross_attention_dim,
|
483 |
+
norm_num_groups=resnet_groups,
|
484 |
+
use_linear_projection=use_linear_projection,
|
485 |
+
only_cross_attention=only_cross_attention,
|
486 |
+
upcast_attention=upcast_attention,
|
487 |
+
)
|
488 |
+
)
|
489 |
+
self.attentions = nn.ModuleList(attentions)
|
490 |
+
self.resnets = nn.ModuleList(resnets)
|
491 |
+
|
492 |
+
if add_upsample:
|
493 |
+
self.upsamplers = nn.ModuleList(
|
494 |
+
[UpsamplePseudo3D(out_channels, use_conv=True, out_channels=out_channels)]
|
495 |
+
)
|
496 |
+
else:
|
497 |
+
self.upsamplers = None
|
498 |
+
|
499 |
+
self.gradient_checkpointing = False
|
500 |
+
|
501 |
+
def forward(
|
502 |
+
self,
|
503 |
+
hidden_states,
|
504 |
+
res_hidden_states_tuple,
|
505 |
+
temb=None,
|
506 |
+
encoder_hidden_states=None,
|
507 |
+
upsample_size=None,
|
508 |
+
attention_mask=None,
|
509 |
+
):
|
510 |
+
# TODO(Patrick, William) - attention mask is not used
|
511 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
512 |
+
# pop res hidden states
|
513 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
514 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
515 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
516 |
+
|
517 |
+
if self.training and self.gradient_checkpointing:
|
518 |
+
|
519 |
+
def create_custom_forward(module, return_dict=None):
|
520 |
+
def custom_forward(*inputs):
|
521 |
+
if return_dict is not None:
|
522 |
+
return module(*inputs, return_dict=return_dict)
|
523 |
+
else:
|
524 |
+
return module(*inputs)
|
525 |
+
|
526 |
+
return custom_forward
|
527 |
+
|
528 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
529 |
+
create_custom_forward(resnet), hidden_states, temb
|
530 |
+
)
|
531 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
532 |
+
create_custom_forward(attn, return_dict=False),
|
533 |
+
hidden_states,
|
534 |
+
encoder_hidden_states,
|
535 |
+
)[0]
|
536 |
+
else:
|
537 |
+
hidden_states = resnet(hidden_states, temb)
|
538 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
539 |
+
|
540 |
+
if self.upsamplers is not None:
|
541 |
+
for upsampler in self.upsamplers:
|
542 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
543 |
+
|
544 |
+
return hidden_states
|
545 |
+
|
546 |
+
|
547 |
+
class UpBlockPseudo3D(nn.Module):
|
548 |
+
def __init__(
|
549 |
+
self,
|
550 |
+
in_channels: int,
|
551 |
+
prev_output_channel: int,
|
552 |
+
out_channels: int,
|
553 |
+
temb_channels: int,
|
554 |
+
dropout: float = 0.0,
|
555 |
+
num_layers: int = 1,
|
556 |
+
resnet_eps: float = 1e-6,
|
557 |
+
resnet_time_scale_shift: str = "default",
|
558 |
+
resnet_act_fn: str = "swish",
|
559 |
+
resnet_groups: int = 32,
|
560 |
+
resnet_pre_norm: bool = True,
|
561 |
+
output_scale_factor=1.0,
|
562 |
+
add_upsample=True,
|
563 |
+
):
|
564 |
+
super().__init__()
|
565 |
+
resnets = []
|
566 |
+
|
567 |
+
for i in range(num_layers):
|
568 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
569 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
570 |
+
|
571 |
+
resnets.append(
|
572 |
+
ResnetBlockPseudo3D(
|
573 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
574 |
+
out_channels=out_channels,
|
575 |
+
temb_channels=temb_channels,
|
576 |
+
eps=resnet_eps,
|
577 |
+
groups=resnet_groups,
|
578 |
+
dropout=dropout,
|
579 |
+
time_embedding_norm=resnet_time_scale_shift,
|
580 |
+
non_linearity=resnet_act_fn,
|
581 |
+
output_scale_factor=output_scale_factor,
|
582 |
+
pre_norm=resnet_pre_norm,
|
583 |
+
)
|
584 |
+
)
|
585 |
+
|
586 |
+
self.resnets = nn.ModuleList(resnets)
|
587 |
+
|
588 |
+
if add_upsample:
|
589 |
+
self.upsamplers = nn.ModuleList(
|
590 |
+
[UpsamplePseudo3D(out_channels, use_conv=True, out_channels=out_channels)]
|
591 |
+
)
|
592 |
+
else:
|
593 |
+
self.upsamplers = None
|
594 |
+
|
595 |
+
self.gradient_checkpointing = False
|
596 |
+
|
597 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
598 |
+
for resnet in self.resnets:
|
599 |
+
# pop res hidden states
|
600 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
601 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
602 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
603 |
+
|
604 |
+
if self.training and self.gradient_checkpointing:
|
605 |
+
|
606 |
+
def create_custom_forward(module):
|
607 |
+
def custom_forward(*inputs):
|
608 |
+
return module(*inputs)
|
609 |
+
|
610 |
+
return custom_forward
|
611 |
+
|
612 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
613 |
+
create_custom_forward(resnet), hidden_states, temb
|
614 |
+
)
|
615 |
+
else:
|
616 |
+
hidden_states = resnet(hidden_states, temb)
|
617 |
+
|
618 |
+
if self.upsamplers is not None:
|
619 |
+
for upsampler in self.upsamplers:
|
620 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
621 |
+
|
622 |
+
return hidden_states
|
model/video_diffusion/models/unet_3d_blocks_control.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
from .attention import SpatioTemporalTransformerModel
|
18 |
+
from .resnet import DownsamplePseudo3D, ResnetBlockPseudo3D, UpsamplePseudo3D
|
19 |
+
import glob
|
20 |
+
import json
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from typing import List, Optional, Tuple, Union
|
23 |
+
import torch
|
24 |
+
import torch.nn as nn
|
25 |
+
import torch.utils.checkpoint
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
28 |
+
from diffusers.utils import BaseOutput, logging
|
29 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
30 |
+
from .unet_3d_blocks import (
|
31 |
+
CrossAttnDownBlockPseudo3D,
|
32 |
+
CrossAttnUpBlockPseudo3D,
|
33 |
+
DownBlockPseudo3D,
|
34 |
+
UNetMidBlockPseudo3DCrossAttn,
|
35 |
+
UpBlockPseudo3D,
|
36 |
+
get_down_block,
|
37 |
+
get_up_block,
|
38 |
+
)
|
39 |
+
from .resnet import PseudoConv3d
|
40 |
+
from diffusers.models.cross_attention import AttnProcessor
|
41 |
+
from typing import Dict
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def set_zero_parameters(module):
|
46 |
+
for p in module.parameters():
|
47 |
+
p.detach().zero_()
|
48 |
+
return module
|
49 |
+
|
50 |
+
# ControlNet: Zero Convolution
|
51 |
+
def zero_conv(channels):
|
52 |
+
return set_zero_parameters(PseudoConv3d(channels, channels, 1, padding=0))
|
53 |
+
|
54 |
+
class ControlNetInputHintBlock(nn.Module):
|
55 |
+
def __init__(self, hint_channels: int = 3, channels: int = 320):
|
56 |
+
super().__init__()
|
57 |
+
# Layer configurations are from reference implementation.
|
58 |
+
self.input_hint_block = nn.Sequential(
|
59 |
+
PseudoConv3d(hint_channels, 16, 3, padding=1),
|
60 |
+
nn.SiLU(),
|
61 |
+
PseudoConv3d(16, 16, 3, padding=1),
|
62 |
+
nn.SiLU(),
|
63 |
+
PseudoConv3d(16, 32, 3, padding=1, stride=2),
|
64 |
+
nn.SiLU(),
|
65 |
+
PseudoConv3d(32, 32, 3, padding=1),
|
66 |
+
nn.SiLU(),
|
67 |
+
PseudoConv3d(32, 96, 3, padding=1, stride=2),
|
68 |
+
nn.SiLU(),
|
69 |
+
PseudoConv3d(96, 96, 3, padding=1),
|
70 |
+
nn.SiLU(),
|
71 |
+
PseudoConv3d(96, 256, 3, padding=1, stride=2),
|
72 |
+
nn.SiLU(),
|
73 |
+
set_zero_parameters(PseudoConv3d(256, channels, 3, padding=1)),
|
74 |
+
)
|
75 |
+
def forward(self, hint: torch.Tensor):
|
76 |
+
return self.input_hint_block(hint)
|
77 |
+
|
78 |
+
|
79 |
+
class ControlNetPseudoZeroConv3dBlock(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
83 |
+
down_block_types: Tuple[str] = (
|
84 |
+
"CrossAttnDownBlockPseudo3D",
|
85 |
+
"CrossAttnDownBlockPseudo3D",
|
86 |
+
"CrossAttnDownBlockPseudo3D",
|
87 |
+
"DownBlockPseudo3D",
|
88 |
+
),
|
89 |
+
layers_per_block: int = 2,
|
90 |
+
):
|
91 |
+
super().__init__()
|
92 |
+
self.input_zero_conv = zero_conv(block_out_channels[0])
|
93 |
+
zero_convs = []
|
94 |
+
for i, down_block_type in enumerate(down_block_types):
|
95 |
+
output_channel = block_out_channels[i]
|
96 |
+
is_final_block = i == len(block_out_channels) - 1
|
97 |
+
for _ in range(layers_per_block):
|
98 |
+
zero_convs.append(zero_conv(output_channel))
|
99 |
+
if not is_final_block:
|
100 |
+
zero_convs.append(zero_conv(output_channel))
|
101 |
+
self.zero_convs = nn.ModuleList(zero_convs)
|
102 |
+
self.mid_zero_conv = zero_conv(block_out_channels[-1])
|
103 |
+
|
104 |
+
def forward(
|
105 |
+
self,
|
106 |
+
down_block_res_samples: List[torch.Tensor],
|
107 |
+
mid_block_sample: torch.Tensor,
|
108 |
+
) -> List[torch.Tensor]:
|
109 |
+
outputs = []
|
110 |
+
outputs.append(self.input_zero_conv(down_block_res_samples[0]))
|
111 |
+
for res_sample, zero_conv in zip(down_block_res_samples[1:], self.zero_convs):
|
112 |
+
outputs.append(zero_conv(res_sample))
|
113 |
+
outputs.append(self.mid_zero_conv(mid_block_sample))
|
114 |
+
return outputs
|
115 |
+
|
116 |
+
|
model/video_diffusion/models/unet_3d_condition.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import glob
|
17 |
+
import json
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.utils.checkpoint
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
from diffusers.models.modeling_utils import ModelMixin
|
27 |
+
from diffusers.utils import BaseOutput, logging
|
28 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
29 |
+
from .unet_3d_blocks import (
|
30 |
+
CrossAttnDownBlockPseudo3D,
|
31 |
+
CrossAttnUpBlockPseudo3D,
|
32 |
+
DownBlockPseudo3D,
|
33 |
+
UNetMidBlockPseudo3DCrossAttn,
|
34 |
+
UpBlockPseudo3D,
|
35 |
+
get_down_block,
|
36 |
+
get_up_block,
|
37 |
+
)
|
38 |
+
from .resnet import PseudoConv3d
|
39 |
+
from diffusers.models.cross_attention import AttnProcessor
|
40 |
+
from typing import Dict
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class UNetPseudo3DConditionOutput(BaseOutput):
|
47 |
+
sample: torch.FloatTensor
|
48 |
+
|
49 |
+
|
50 |
+
class UNetPseudo3DConditionModel(ModelMixin, ConfigMixin):
|
51 |
+
"""
|
52 |
+
这里把原来2D Unet的 2D卷积全换成新定义的PseudoConv3d。并且定义了从2D卷积继承的模型参数。
|
53 |
+
"""
|
54 |
+
_supports_gradient_checkpointing = True
|
55 |
+
|
56 |
+
@register_to_config
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
sample_size: Optional[int] = None,
|
60 |
+
in_channels: int = 4,
|
61 |
+
out_channels: int = 4,
|
62 |
+
center_input_sample: bool = False,
|
63 |
+
flip_sin_to_cos: bool = True,
|
64 |
+
freq_shift: int = 0,
|
65 |
+
down_block_types: Tuple[str] = (
|
66 |
+
"CrossAttnDownBlockPseudo3D",
|
67 |
+
"CrossAttnDownBlockPseudo3D",
|
68 |
+
"CrossAttnDownBlockPseudo3D",
|
69 |
+
"DownBlockPseudo3D",
|
70 |
+
),
|
71 |
+
mid_block_type: str = "UNetMidBlockPseudo3DCrossAttn",
|
72 |
+
up_block_types: Tuple[str] = (
|
73 |
+
"UpBlockPseudo3D",
|
74 |
+
"CrossAttnUpBlockPseudo3D",
|
75 |
+
"CrossAttnUpBlockPseudo3D",
|
76 |
+
"CrossAttnUpBlockPseudo3D",
|
77 |
+
),
|
78 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
79 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
80 |
+
layers_per_block: int = 2,
|
81 |
+
downsample_padding: int = 1,
|
82 |
+
mid_block_scale_factor: float = 1,
|
83 |
+
act_fn: str = "silu",
|
84 |
+
norm_num_groups: int = 32,
|
85 |
+
norm_eps: float = 1e-5,
|
86 |
+
cross_attention_dim: int = 1280,
|
87 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
88 |
+
dual_cross_attention: bool = False,
|
89 |
+
use_linear_projection: bool = False,
|
90 |
+
fps_embed_type: Optional[str] = None,
|
91 |
+
num_fps_embeds: Optional[int] = None,
|
92 |
+
upcast_attention: bool = False,
|
93 |
+
resnet_time_scale_shift: str = "default",
|
94 |
+
num_class_embeds=None,
|
95 |
+
|
96 |
+
):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
|
100 |
+
self.sample_size = sample_size
|
101 |
+
time_embed_dim = block_out_channels[0] * 4
|
102 |
+
|
103 |
+
# input
|
104 |
+
self.conv_in = PseudoConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
105 |
+
|
106 |
+
# time
|
107 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
108 |
+
timestep_input_dim = block_out_channels[0]
|
109 |
+
|
110 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
111 |
+
|
112 |
+
# class embedding
|
113 |
+
if fps_embed_type is None and num_fps_embeds is not None:
|
114 |
+
self.fps_embedding = nn.Embedding(num_fps_embeds, time_embed_dim)
|
115 |
+
elif fps_embed_type == "timestep":
|
116 |
+
self.fps_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
117 |
+
elif fps_embed_type == "identity":
|
118 |
+
self.fps_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
119 |
+
else:
|
120 |
+
self.fps_embedding = None
|
121 |
+
|
122 |
+
self.down_blocks = nn.ModuleList([])
|
123 |
+
self.mid_block = None
|
124 |
+
self.up_blocks = nn.ModuleList([])
|
125 |
+
|
126 |
+
if isinstance(only_cross_attention, bool):
|
127 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
128 |
+
|
129 |
+
if isinstance(attention_head_dim, int):
|
130 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
131 |
+
|
132 |
+
# down
|
133 |
+
output_channel = block_out_channels[0]
|
134 |
+
for i, down_block_type in enumerate(down_block_types):
|
135 |
+
input_channel = output_channel
|
136 |
+
output_channel = block_out_channels[i]
|
137 |
+
is_final_block = i == len(block_out_channels) - 1
|
138 |
+
|
139 |
+
down_block = get_down_block(
|
140 |
+
down_block_type,
|
141 |
+
num_layers=layers_per_block,
|
142 |
+
in_channels=input_channel,
|
143 |
+
out_channels=output_channel,
|
144 |
+
temb_channels=time_embed_dim,
|
145 |
+
add_downsample=not is_final_block,
|
146 |
+
resnet_eps=norm_eps,
|
147 |
+
resnet_act_fn=act_fn,
|
148 |
+
resnet_groups=norm_num_groups,
|
149 |
+
cross_attention_dim=cross_attention_dim,
|
150 |
+
attn_num_head_channels=attention_head_dim[i],
|
151 |
+
downsample_padding=downsample_padding,
|
152 |
+
dual_cross_attention=dual_cross_attention,
|
153 |
+
use_linear_projection=use_linear_projection,
|
154 |
+
only_cross_attention=only_cross_attention[i],
|
155 |
+
upcast_attention=upcast_attention,
|
156 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
157 |
+
)
|
158 |
+
self.down_blocks.append(down_block)
|
159 |
+
|
160 |
+
# mid
|
161 |
+
if mid_block_type == "UNetMidBlockPseudo3DCrossAttn":
|
162 |
+
self.mid_block = UNetMidBlockPseudo3DCrossAttn(
|
163 |
+
in_channels=block_out_channels[-1],
|
164 |
+
temb_channels=time_embed_dim,
|
165 |
+
resnet_eps=norm_eps,
|
166 |
+
resnet_act_fn=act_fn,
|
167 |
+
output_scale_factor=mid_block_scale_factor,
|
168 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
169 |
+
cross_attention_dim=cross_attention_dim,
|
170 |
+
attn_num_head_channels=attention_head_dim[-1],
|
171 |
+
resnet_groups=norm_num_groups,
|
172 |
+
dual_cross_attention=dual_cross_attention,
|
173 |
+
use_linear_projection=use_linear_projection,
|
174 |
+
upcast_attention=upcast_attention,
|
175 |
+
)
|
176 |
+
else:
|
177 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
178 |
+
|
179 |
+
# count how many layers upsample the images
|
180 |
+
self.num_upsamplers = 0
|
181 |
+
|
182 |
+
# up
|
183 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
184 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
185 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
186 |
+
output_channel = reversed_block_out_channels[0]
|
187 |
+
for i, up_block_type in enumerate(up_block_types):
|
188 |
+
is_final_block = i == len(block_out_channels) - 1
|
189 |
+
|
190 |
+
prev_output_channel = output_channel
|
191 |
+
output_channel = reversed_block_out_channels[i]
|
192 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
193 |
+
|
194 |
+
# add upsample block for all BUT final layer
|
195 |
+
if not is_final_block:
|
196 |
+
add_upsample = True
|
197 |
+
self.num_upsamplers += 1
|
198 |
+
else:
|
199 |
+
add_upsample = False
|
200 |
+
|
201 |
+
up_block = get_up_block(
|
202 |
+
up_block_type,
|
203 |
+
num_layers=layers_per_block + 1,
|
204 |
+
in_channels=input_channel,
|
205 |
+
out_channels=output_channel,
|
206 |
+
prev_output_channel=prev_output_channel,
|
207 |
+
temb_channels=time_embed_dim,
|
208 |
+
add_upsample=add_upsample,
|
209 |
+
resnet_eps=norm_eps,
|
210 |
+
resnet_act_fn=act_fn,
|
211 |
+
resnet_groups=norm_num_groups,
|
212 |
+
cross_attention_dim=cross_attention_dim,
|
213 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
214 |
+
dual_cross_attention=dual_cross_attention,
|
215 |
+
use_linear_projection=use_linear_projection,
|
216 |
+
only_cross_attention=only_cross_attention[i],
|
217 |
+
upcast_attention=upcast_attention,
|
218 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
219 |
+
)
|
220 |
+
self.up_blocks.append(up_block)
|
221 |
+
prev_output_channel = output_channel
|
222 |
+
|
223 |
+
# out
|
224 |
+
self.conv_norm_out = nn.GroupNorm(
|
225 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
226 |
+
)
|
227 |
+
self.conv_act = nn.SiLU()
|
228 |
+
self.conv_out = PseudoConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
229 |
+
@property
|
230 |
+
def attn_processors(self) -> Dict[str, AttnProcessor]:
|
231 |
+
r"""
|
232 |
+
Returns:
|
233 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
234 |
+
indexed by its weight name.
|
235 |
+
"""
|
236 |
+
# set recursively
|
237 |
+
processors = {}
|
238 |
+
|
239 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
|
240 |
+
if hasattr(module, "set_processor"):
|
241 |
+
processors[f"{name}.processor"] = module.processor
|
242 |
+
|
243 |
+
for sub_name, child in module.named_children():
|
244 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
245 |
+
|
246 |
+
return processors
|
247 |
+
|
248 |
+
for name, module in self.named_children():
|
249 |
+
fn_recursive_add_processors(name, module, processors)
|
250 |
+
|
251 |
+
return processors
|
252 |
+
|
253 |
+
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
|
254 |
+
r"""
|
255 |
+
Parameters:
|
256 |
+
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
|
257 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
258 |
+
of **all** `CrossAttention` layers.
|
259 |
+
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
|
260 |
+
|
261 |
+
"""
|
262 |
+
count = len(self.attn_processors.keys())
|
263 |
+
|
264 |
+
if isinstance(processor, dict) and len(processor) != count:
|
265 |
+
raise ValueError(
|
266 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
267 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
268 |
+
)
|
269 |
+
|
270 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
271 |
+
if hasattr(module, "set_processor"):
|
272 |
+
if not isinstance(processor, dict):
|
273 |
+
module.set_processor(processor)
|
274 |
+
else:
|
275 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
276 |
+
|
277 |
+
for sub_name, child in module.named_children():
|
278 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
279 |
+
|
280 |
+
for name, module in self.named_children():
|
281 |
+
fn_recursive_attn_processor(name, module, processor)
|
282 |
+
|
283 |
+
|
284 |
+
def set_attention_slice(self, slice_size):
|
285 |
+
r"""
|
286 |
+
Enable sliced attention computation.
|
287 |
+
|
288 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
289 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
290 |
+
|
291 |
+
Args:
|
292 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
293 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
294 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
295 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
296 |
+
must be a multiple of `slice_size`.
|
297 |
+
"""
|
298 |
+
sliceable_head_dims = []
|
299 |
+
|
300 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
301 |
+
if hasattr(module, "set_attention_slice"):
|
302 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
303 |
+
|
304 |
+
for child in module.children():
|
305 |
+
fn_recursive_retrieve_slicable_dims(child)
|
306 |
+
|
307 |
+
# retrieve number of attention layers
|
308 |
+
for module in self.children():
|
309 |
+
fn_recursive_retrieve_slicable_dims(module)
|
310 |
+
|
311 |
+
num_slicable_layers = len(sliceable_head_dims)
|
312 |
+
|
313 |
+
if slice_size == "auto":
|
314 |
+
# half the attention head size is usually a good trade-off between
|
315 |
+
# speed and memory
|
316 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
317 |
+
elif slice_size == "max":
|
318 |
+
# make smallest slice possible
|
319 |
+
slice_size = num_slicable_layers * [1]
|
320 |
+
|
321 |
+
slice_size = (
|
322 |
+
num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
323 |
+
)
|
324 |
+
|
325 |
+
if len(slice_size) != len(sliceable_head_dims):
|
326 |
+
raise ValueError(
|
327 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
328 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
329 |
+
)
|
330 |
+
|
331 |
+
for i in range(len(slice_size)):
|
332 |
+
size = slice_size[i]
|
333 |
+
dim = sliceable_head_dims[i]
|
334 |
+
if size is not None and size > dim:
|
335 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
336 |
+
|
337 |
+
# Recursively walk through all the children.
|
338 |
+
# Any children which exposes the set_attention_slice method
|
339 |
+
# gets the message
|
340 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
341 |
+
if hasattr(module, "set_attention_slice"):
|
342 |
+
module.set_attention_slice(slice_size.pop())
|
343 |
+
|
344 |
+
for child in module.children():
|
345 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
346 |
+
|
347 |
+
reversed_slice_size = list(reversed(slice_size))
|
348 |
+
for module in self.children():
|
349 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
350 |
+
|
351 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
352 |
+
if isinstance(
|
353 |
+
module,
|
354 |
+
(CrossAttnDownBlockPseudo3D, DownBlockPseudo3D, CrossAttnUpBlockPseudo3D, UpBlockPseudo3D),
|
355 |
+
):
|
356 |
+
module.gradient_checkpointing = value
|
357 |
+
|
358 |
+
def forward(
|
359 |
+
self,
|
360 |
+
sample: torch.FloatTensor,
|
361 |
+
timestep: Union[torch.Tensor, float, int],
|
362 |
+
encoder_hidden_states: torch.Tensor,
|
363 |
+
fps_labels: Optional[torch.Tensor] = None,
|
364 |
+
attention_mask: Optional[torch.Tensor] = None,
|
365 |
+
cross_attention_kwargs=None,
|
366 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
367 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
368 |
+
return_dict: bool = True,
|
369 |
+
) -> Union[UNetPseudo3DConditionOutput, Tuple]:
|
370 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
371 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
372 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
373 |
+
# on the fly if necessary.
|
374 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
375 |
+
|
376 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
377 |
+
forward_upsample_size = False
|
378 |
+
upsample_size = None
|
379 |
+
|
380 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
381 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
382 |
+
forward_upsample_size = True
|
383 |
+
|
384 |
+
# prepare attention_mask
|
385 |
+
if attention_mask is not None:
|
386 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
387 |
+
attention_mask = attention_mask.unsqueeze(1)
|
388 |
+
|
389 |
+
# 0. center input if necessary
|
390 |
+
if self.config.center_input_sample:
|
391 |
+
sample = 2 * sample - 1.0
|
392 |
+
|
393 |
+
# 1. time
|
394 |
+
timesteps = timestep
|
395 |
+
if not torch.is_tensor(timesteps):
|
396 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
397 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
398 |
+
is_mps = sample.device.type == "mps"
|
399 |
+
if isinstance(timestep, float):
|
400 |
+
dtype = torch.float32 if is_mps else torch.float64
|
401 |
+
else:
|
402 |
+
dtype = torch.int32 if is_mps else torch.int64
|
403 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
404 |
+
elif len(timesteps.shape) == 0:
|
405 |
+
timesteps = timesteps[None].to(sample.device)
|
406 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
407 |
+
timesteps = timesteps.expand(sample.shape[0])
|
408 |
+
|
409 |
+
t_emb = self.time_proj(timesteps)
|
410 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
411 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
412 |
+
# there might be better ways to encapsulate this.
|
413 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
414 |
+
emb = self.time_embedding(t_emb)
|
415 |
+
|
416 |
+
if self.fps_embedding is not None:
|
417 |
+
if fps_labels is None:
|
418 |
+
raise ValueError("fps_labels should be provided when num_fps_embeds > 0")
|
419 |
+
|
420 |
+
if self.config.fps_embed_type == "timestep":
|
421 |
+
fps_labels = self.time_proj(fps_labels) # 和timesteps共用,都是sin embedding?这里的weight不更新的。
|
422 |
+
|
423 |
+
# 这里和上面timesteps does not contain any weights and will always return f32 tensors的bug一样。需要先cast过去,不然多机多卡就有问题了。
|
424 |
+
fps_labels = fps_labels.to(dtype=self.dtype)
|
425 |
+
class_emb = self.fps_embedding(fps_labels)
|
426 |
+
|
427 |
+
emb = emb + class_emb
|
428 |
+
|
429 |
+
# 2. pre-process
|
430 |
+
sample = self.conv_in(sample)
|
431 |
+
|
432 |
+
# 3. down
|
433 |
+
down_block_res_samples = (sample,)
|
434 |
+
for downsample_block in self.down_blocks:
|
435 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
436 |
+
sample, res_samples = downsample_block(
|
437 |
+
hidden_states=sample,
|
438 |
+
temb=emb,
|
439 |
+
encoder_hidden_states=encoder_hidden_states,
|
440 |
+
attention_mask=attention_mask,
|
441 |
+
)
|
442 |
+
else:
|
443 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
444 |
+
|
445 |
+
down_block_res_samples += res_samples
|
446 |
+
|
447 |
+
if down_block_additional_residuals is not None:
|
448 |
+
new_down_block_res_samples = ()
|
449 |
+
|
450 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
451 |
+
down_block_res_samples, down_block_additional_residuals
|
452 |
+
):
|
453 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
454 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
455 |
+
|
456 |
+
down_block_res_samples = new_down_block_res_samples
|
457 |
+
|
458 |
+
# 4. mid
|
459 |
+
sample = self.mid_block(
|
460 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
461 |
+
)
|
462 |
+
if mid_block_additional_residual is not None:
|
463 |
+
sample = sample + mid_block_additional_residual
|
464 |
+
|
465 |
+
# 5. up
|
466 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
467 |
+
is_final_block = i == len(self.up_blocks) - 1
|
468 |
+
|
469 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
470 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
471 |
+
|
472 |
+
# if we have not reached the final block and need to forward the
|
473 |
+
# upsample size, we do it here
|
474 |
+
if not is_final_block and forward_upsample_size:
|
475 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
476 |
+
|
477 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
478 |
+
sample = upsample_block(
|
479 |
+
hidden_states=sample,
|
480 |
+
temb=emb,
|
481 |
+
res_hidden_states_tuple=res_samples,
|
482 |
+
encoder_hidden_states=encoder_hidden_states,
|
483 |
+
upsample_size=upsample_size,
|
484 |
+
attention_mask=attention_mask,
|
485 |
+
)
|
486 |
+
else:
|
487 |
+
sample = upsample_block(
|
488 |
+
hidden_states=sample,
|
489 |
+
temb=emb,
|
490 |
+
res_hidden_states_tuple=res_samples,
|
491 |
+
upsample_size=upsample_size,
|
492 |
+
)
|
493 |
+
# 6. post-process
|
494 |
+
sample = self.conv_norm_out(sample)
|
495 |
+
sample = self.conv_act(sample)
|
496 |
+
sample = self.conv_out(sample)
|
497 |
+
|
498 |
+
if not return_dict:
|
499 |
+
return (sample,)
|
500 |
+
|
501 |
+
return UNetPseudo3DConditionOutput(sample=sample)
|
502 |
+
|
503 |
+
@classmethod
|
504 |
+
def from_2d_model(cls, model_path, condition_on_fps=False):
|
505 |
+
'''
|
506 |
+
load a 2d model and convert it to a pseudo 3d model
|
507 |
+
'''
|
508 |
+
config_path = os.path.join(model_path, "config.json")
|
509 |
+
if not os.path.isfile(config_path):
|
510 |
+
raise RuntimeError(f"{config_path} does not exist")
|
511 |
+
with open(config_path, "r") as f:
|
512 |
+
config = json.load(f)
|
513 |
+
|
514 |
+
config.pop("_class_name")
|
515 |
+
config.pop("_diffusers_version")
|
516 |
+
|
517 |
+
block_replacer = {
|
518 |
+
"CrossAttnDownBlock2D": "CrossAttnDownBlockPseudo3D",
|
519 |
+
"DownBlock2D": "DownBlockPseudo3D",
|
520 |
+
"UpBlock2D": "UpBlockPseudo3D",
|
521 |
+
"CrossAttnUpBlock2D": "CrossAttnUpBlockPseudo3D",
|
522 |
+
}
|
523 |
+
|
524 |
+
def convert_2d_to_3d_block(block):
|
525 |
+
return block_replacer[block] if block in block_replacer else block
|
526 |
+
|
527 |
+
config["down_block_types"] = [
|
528 |
+
convert_2d_to_3d_block(block) for block in config["down_block_types"]
|
529 |
+
]
|
530 |
+
config["up_block_types"] = [convert_2d_to_3d_block(block) for block in config["up_block_types"]]
|
531 |
+
|
532 |
+
if condition_on_fps:
|
533 |
+
# config["num_fps_embeds"] = 60 # 这个在 trainable embeding时候才需要~
|
534 |
+
config["fps_embed_type"] = "timestep" # 和timestep保持一致的type。
|
535 |
+
|
536 |
+
|
537 |
+
model = cls(**config) # 调用自身(init), 传入config参数全换成3d的setting
|
538 |
+
|
539 |
+
state_dict_path_condidates = glob.glob(os.path.join(model_path, "*.bin"))
|
540 |
+
if state_dict_path_condidates:
|
541 |
+
state_dict = torch.load(state_dict_path_condidates[0], map_location="cpu")
|
542 |
+
model.load_2d_state_dict(state_dict=state_dict)
|
543 |
+
|
544 |
+
return model
|
545 |
+
|
546 |
+
def load_2d_state_dict(self, state_dict, **kwargs):
|
547 |
+
'''
|
548 |
+
2D 部分的参数名完全不变。
|
549 |
+
'''
|
550 |
+
state_dict_3d = self.state_dict()
|
551 |
+
|
552 |
+
for k, v in state_dict.items():
|
553 |
+
if k not in state_dict_3d:
|
554 |
+
raise KeyError(f"2d state_dict key {k} does not exist in 3d model")
|
555 |
+
elif v.shape != state_dict_3d[k].shape:
|
556 |
+
raise ValueError(f"state_dict shape mismatch, 2d {v.shape}, 3d {state_dict_3d[k].shape}")
|
557 |
+
|
558 |
+
for k, v in state_dict_3d.items():
|
559 |
+
if "_temporal" in k:
|
560 |
+
continue
|
561 |
+
if "gamma" in k:
|
562 |
+
continue
|
563 |
+
|
564 |
+
if k not in state_dict:
|
565 |
+
if "fps_embedding" in k:
|
566 |
+
# 忽略检查fps_embedding
|
567 |
+
continue
|
568 |
+
raise KeyError(f"3d state_dict key {k} does not exist in 2d model")
|
569 |
+
|
570 |
+
state_dict_3d.update(state_dict)
|
571 |
+
self.load_state_dict(state_dict_3d, **kwargs)
|
model/video_diffusion/pipelines/__init__.py
ADDED
File without changes
|
model/video_diffusion/pipelines/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (178 Bytes). View file
|
|
model/video_diffusion/pipelines/__pycache__/pipeline_st_stable_diffusion.cpython-39.pyc
ADDED
Binary file (19.3 kB). View file
|
|
model/video_diffusion/pipelines/__pycache__/pipeline_stable_diffusion_controlnet3d.cpython-39.pyc
ADDED
Binary file (12.1 kB). View file
|
|
model/video_diffusion/pipelines/pipeline_st_stable_diffusion.py
ADDED
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Callable, List, Optional, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from einops import rearrange
|
20 |
+
|
21 |
+
from diffusers.utils import is_accelerate_available
|
22 |
+
from packaging import version
|
23 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import FrozenDict
|
26 |
+
from diffusers.models import AutoencoderKL
|
27 |
+
from diffusers.pipeline_utils import DiffusionPipeline
|
28 |
+
from diffusers.schedulers import (
|
29 |
+
DDIMScheduler,
|
30 |
+
DPMSolverMultistepScheduler,
|
31 |
+
EulerAncestralDiscreteScheduler,
|
32 |
+
EulerDiscreteScheduler,
|
33 |
+
LMSDiscreteScheduler,
|
34 |
+
PNDMScheduler,
|
35 |
+
)
|
36 |
+
from diffusers.utils import deprecate, logging
|
37 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
38 |
+
|
39 |
+
from ..models.unet_3d_condition import UNetPseudo3DConditionModel
|
40 |
+
import os, importlib
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
+
|
44 |
+
|
45 |
+
class SpatioTemporalStableDiffusionPipeline(DiffusionPipeline):
|
46 |
+
r"""
|
47 |
+
Pipeline for text-to-video generation using Spatio-Temporal Stable Diffusion.
|
48 |
+
改变了unet的输入, unet换成3d unet, 其他部分完全和原来2D的一致。
|
49 |
+
latents的变为 b,c,f,h,w 原来是 b,c,h,w。
|
50 |
+
要用VAE的decoder的时候, 把输入reshape 成 (b f) c h w
|
51 |
+
"""
|
52 |
+
_optional_components = []
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
vae: AutoencoderKL,
|
57 |
+
text_encoder: CLIPTextModel,
|
58 |
+
tokenizer: CLIPTokenizer,
|
59 |
+
unet: UNetPseudo3DConditionModel,
|
60 |
+
scheduler: Union[
|
61 |
+
DDIMScheduler,
|
62 |
+
PNDMScheduler,
|
63 |
+
LMSDiscreteScheduler,
|
64 |
+
EulerDiscreteScheduler,
|
65 |
+
EulerAncestralDiscreteScheduler,
|
66 |
+
DPMSolverMultistepScheduler,
|
67 |
+
],
|
68 |
+
):
|
69 |
+
super().__init__()
|
70 |
+
|
71 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
72 |
+
deprecation_message = (
|
73 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
74 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
75 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
76 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
77 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
78 |
+
" file"
|
79 |
+
)
|
80 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
81 |
+
new_config = dict(scheduler.config)
|
82 |
+
new_config["steps_offset"] = 1
|
83 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
84 |
+
|
85 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
86 |
+
deprecation_message = (
|
87 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
88 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
89 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
90 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
91 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
92 |
+
)
|
93 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
94 |
+
new_config = dict(scheduler.config)
|
95 |
+
new_config["clip_sample"] = False
|
96 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
97 |
+
|
98 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
99 |
+
version.parse(unet.config._diffusers_version).base_version
|
100 |
+
) < version.parse("0.9.0.dev0")
|
101 |
+
is_unet_sample_size_less_64 = (
|
102 |
+
hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
103 |
+
)
|
104 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
105 |
+
deprecation_message = (
|
106 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
107 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
108 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
109 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
110 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
111 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
112 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
113 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
114 |
+
" the `unet/config.json` file"
|
115 |
+
)
|
116 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
117 |
+
new_config = dict(unet.config)
|
118 |
+
new_config["sample_size"] = 64
|
119 |
+
unet._internal_dict = FrozenDict(new_config)
|
120 |
+
|
121 |
+
self.register_modules(
|
122 |
+
vae=vae,
|
123 |
+
text_encoder=text_encoder,
|
124 |
+
tokenizer=tokenizer,
|
125 |
+
unet=unet,
|
126 |
+
scheduler=scheduler,
|
127 |
+
)
|
128 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
129 |
+
|
130 |
+
def enable_vae_slicing(self):
|
131 |
+
r"""
|
132 |
+
Enable sliced VAE decoding.
|
133 |
+
|
134 |
+
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
135 |
+
steps. This is useful to save some memory and allow larger batch sizes.
|
136 |
+
"""
|
137 |
+
self.vae.enable_slicing()
|
138 |
+
|
139 |
+
def disable_vae_slicing(self):
|
140 |
+
r"""
|
141 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
142 |
+
computing decoding in one step.
|
143 |
+
"""
|
144 |
+
self.vae.disable_slicing()
|
145 |
+
|
146 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
147 |
+
r"""
|
148 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
149 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
150 |
+
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
151 |
+
"""
|
152 |
+
if is_accelerate_available():
|
153 |
+
from accelerate import cpu_offload
|
154 |
+
else:
|
155 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
156 |
+
|
157 |
+
device = torch.device(f"cuda:{gpu_id}")
|
158 |
+
|
159 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
160 |
+
if cpu_offloaded_model is not None:
|
161 |
+
cpu_offload(cpu_offloaded_model, device)
|
162 |
+
|
163 |
+
@property
|
164 |
+
def _execution_device(self):
|
165 |
+
r"""
|
166 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
167 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
168 |
+
hooks.
|
169 |
+
"""
|
170 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
171 |
+
return self.device
|
172 |
+
for module in self.unet.modules():
|
173 |
+
if (
|
174 |
+
hasattr(module, "_hf_hook")
|
175 |
+
and hasattr(module._hf_hook, "execution_device")
|
176 |
+
and module._hf_hook.execution_device is not None
|
177 |
+
):
|
178 |
+
return torch.device(module._hf_hook.execution_device)
|
179 |
+
return self.device
|
180 |
+
|
181 |
+
def _encode_prompt(
|
182 |
+
self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
183 |
+
):
|
184 |
+
r"""
|
185 |
+
Encodes the prompt into text encoder hidden states.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
prompt (`str` or `list(int)`):
|
189 |
+
prompt to be encoded
|
190 |
+
device: (`torch.device`):
|
191 |
+
torch device
|
192 |
+
num_images_per_prompt (`int`):
|
193 |
+
number of images that should be generated per prompt
|
194 |
+
do_classifier_free_guidance (`bool`):
|
195 |
+
whether to use classifier free guidance or not
|
196 |
+
negative_prompt (`str` or `List[str]`):
|
197 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
198 |
+
if `guidance_scale` is less than `1`).
|
199 |
+
"""
|
200 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
201 |
+
|
202 |
+
text_inputs = self.tokenizer(
|
203 |
+
prompt,
|
204 |
+
padding="max_length",
|
205 |
+
max_length=self.tokenizer.model_max_length,
|
206 |
+
truncation=True,
|
207 |
+
return_tensors="pt",
|
208 |
+
)
|
209 |
+
text_input_ids = text_inputs.input_ids
|
210 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
211 |
+
|
212 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
213 |
+
text_input_ids, untruncated_ids
|
214 |
+
):
|
215 |
+
removed_text = self.tokenizer.batch_decode(
|
216 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
217 |
+
)
|
218 |
+
logger.warning(
|
219 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
220 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
221 |
+
)
|
222 |
+
|
223 |
+
if (
|
224 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
225 |
+
and self.text_encoder.config.use_attention_mask
|
226 |
+
):
|
227 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
228 |
+
else:
|
229 |
+
attention_mask = None
|
230 |
+
|
231 |
+
text_embeddings = self.text_encoder(
|
232 |
+
text_input_ids.to(self.text_encoder.device), # FIXME 强制对齐device的位置
|
233 |
+
attention_mask=attention_mask,
|
234 |
+
)
|
235 |
+
text_embeddings = text_embeddings[0]
|
236 |
+
|
237 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
238 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
239 |
+
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
240 |
+
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
241 |
+
|
242 |
+
# get unconditional embeddings for classifier free guidance
|
243 |
+
if do_classifier_free_guidance:
|
244 |
+
uncond_tokens: List[str]
|
245 |
+
if negative_prompt is None:
|
246 |
+
uncond_tokens = [""] * batch_size
|
247 |
+
elif type(prompt) is not type(negative_prompt):
|
248 |
+
raise TypeError(
|
249 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
250 |
+
f" {type(prompt)}."
|
251 |
+
)
|
252 |
+
elif isinstance(negative_prompt, str):
|
253 |
+
uncond_tokens = [negative_prompt]
|
254 |
+
elif batch_size != len(negative_prompt):
|
255 |
+
raise ValueError(
|
256 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
257 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
258 |
+
" the batch size of `prompt`."
|
259 |
+
)
|
260 |
+
else:
|
261 |
+
uncond_tokens = negative_prompt
|
262 |
+
|
263 |
+
max_length = text_input_ids.shape[-1]
|
264 |
+
uncond_input = self.tokenizer(
|
265 |
+
uncond_tokens,
|
266 |
+
padding="max_length",
|
267 |
+
max_length=max_length,
|
268 |
+
truncation=True,
|
269 |
+
return_tensors="pt",
|
270 |
+
)
|
271 |
+
|
272 |
+
if (
|
273 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
274 |
+
and self.text_encoder.config.use_attention_mask
|
275 |
+
):
|
276 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
277 |
+
else:
|
278 |
+
attention_mask = None
|
279 |
+
|
280 |
+
uncond_embeddings = self.text_encoder(
|
281 |
+
uncond_input.input_ids.to(self.text_encoder.device), # 同上,强制位置对齐。
|
282 |
+
attention_mask=attention_mask,
|
283 |
+
)
|
284 |
+
uncond_embeddings = uncond_embeddings[0]
|
285 |
+
|
286 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
287 |
+
seq_len = uncond_embeddings.shape[1]
|
288 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
289 |
+
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
290 |
+
|
291 |
+
# For classifier free guidance, we need to do two forward passes.
|
292 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
293 |
+
# to avoid doing two forward passes
|
294 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
295 |
+
|
296 |
+
return text_embeddings
|
297 |
+
|
298 |
+
def decode_latents(self, latents):
|
299 |
+
b = latents.shape[0]
|
300 |
+
latents = 1 / 0.18215 * latents
|
301 |
+
|
302 |
+
is_video = len(latents.shape) == 5
|
303 |
+
if is_video:
|
304 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
305 |
+
|
306 |
+
image = self.vae.decode(latents).sample
|
307 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
308 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
309 |
+
|
310 |
+
image = image.cpu().float().numpy()
|
311 |
+
if is_video:
|
312 |
+
image = rearrange(image, "(b f) c h w -> b f h w c", b=b)
|
313 |
+
else:
|
314 |
+
image = rearrange(image, "b c h w -> b h w c")
|
315 |
+
return image
|
316 |
+
|
317 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
318 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
319 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
320 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
321 |
+
# and should be between [0, 1]
|
322 |
+
|
323 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
324 |
+
extra_step_kwargs = {}
|
325 |
+
if accepts_eta:
|
326 |
+
extra_step_kwargs["eta"] = eta
|
327 |
+
|
328 |
+
# check if the scheduler accepts generator
|
329 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
330 |
+
if accepts_generator:
|
331 |
+
extra_step_kwargs["generator"] = generator
|
332 |
+
return extra_step_kwargs
|
333 |
+
|
334 |
+
def check_inputs(self, prompt, height, width, callback_steps):
|
335 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
336 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
337 |
+
|
338 |
+
if height % 8 != 0 or width % 8 != 0:
|
339 |
+
raise ValueError(
|
340 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
341 |
+
)
|
342 |
+
|
343 |
+
if (callback_steps is None) or (
|
344 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
345 |
+
):
|
346 |
+
raise ValueError(
|
347 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
348 |
+
f" {type(callback_steps)}."
|
349 |
+
)
|
350 |
+
|
351 |
+
def prepare_latents(
|
352 |
+
self,
|
353 |
+
batch_size,
|
354 |
+
num_channels_latents,
|
355 |
+
clip_length,
|
356 |
+
height,
|
357 |
+
width,
|
358 |
+
dtype,
|
359 |
+
device,
|
360 |
+
generator,
|
361 |
+
latents=None,
|
362 |
+
):
|
363 |
+
if clip_length>0:
|
364 |
+
shape = (
|
365 |
+
batch_size,
|
366 |
+
num_channels_latents,
|
367 |
+
clip_length,
|
368 |
+
height // self.vae_scale_factor,
|
369 |
+
width // self.vae_scale_factor,
|
370 |
+
)
|
371 |
+
else:
|
372 |
+
shape = (
|
373 |
+
batch_size,
|
374 |
+
num_channels_latents,
|
375 |
+
height // self.vae_scale_factor,
|
376 |
+
width // self.vae_scale_factor,
|
377 |
+
)
|
378 |
+
|
379 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
380 |
+
raise ValueError(
|
381 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
382 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
383 |
+
)
|
384 |
+
|
385 |
+
if latents is None:
|
386 |
+
rand_device = "cpu" if device.type == "mps" else device
|
387 |
+
|
388 |
+
if isinstance(generator, list):
|
389 |
+
shape = (1,) + shape[1:]
|
390 |
+
latents = [
|
391 |
+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
392 |
+
for i in range(batch_size)
|
393 |
+
]
|
394 |
+
latents = torch.cat(latents, dim=0).to(device)
|
395 |
+
else:
|
396 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(
|
397 |
+
device
|
398 |
+
)
|
399 |
+
else:
|
400 |
+
if latents.shape != shape:
|
401 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
402 |
+
latents = latents.to(device)
|
403 |
+
|
404 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
405 |
+
latents = latents * self.scheduler.init_noise_sigma
|
406 |
+
return latents
|
407 |
+
|
408 |
+
@torch.no_grad()
|
409 |
+
def __call__(
|
410 |
+
self,
|
411 |
+
prompt: Union[str, List[str]],
|
412 |
+
height: Optional[int] = None,
|
413 |
+
width: Optional[int] = None,
|
414 |
+
fps_labels = None,
|
415 |
+
num_inference_steps: int = 50,
|
416 |
+
clip_length: int = 8,
|
417 |
+
guidance_scale: float = 7.5,
|
418 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
419 |
+
num_images_per_prompt: Optional[int] = 1,
|
420 |
+
eta: float = 0.0,
|
421 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
422 |
+
latents: Optional[torch.FloatTensor] = None,
|
423 |
+
output_type: Optional[str] = "pil",
|
424 |
+
return_dict: bool = True,
|
425 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
426 |
+
callback_steps: Optional[int] = 1,
|
427 |
+
):
|
428 |
+
r"""
|
429 |
+
Function invoked when calling the pipeline for generation.
|
430 |
+
|
431 |
+
Args:
|
432 |
+
prompt (`str` or `List[str]`):
|
433 |
+
The prompt or prompts to guide the image generation.
|
434 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
435 |
+
The height in pixels of the generated image.
|
436 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
437 |
+
The width in pixels of the generated image.
|
438 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
439 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
440 |
+
expense of slower inference.
|
441 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
442 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
443 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
444 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
445 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
446 |
+
usually at the expense of lower image quality.
|
447 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
448 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
449 |
+
if `guidance_scale` is less than `1`).
|
450 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
451 |
+
The number of images to generate per prompt.
|
452 |
+
eta (`float`, *optional*, defaults to 0.0):
|
453 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
454 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
455 |
+
generator (`torch.Generator`, *optional*):
|
456 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
457 |
+
to make generation deterministic.
|
458 |
+
latents (`torch.FloatTensor`, *optional*):
|
459 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
460 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
461 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
462 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
463 |
+
The output format of the generate image. Choose between
|
464 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
465 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
466 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
467 |
+
plain tuple.
|
468 |
+
callback (`Callable`, *optional*):
|
469 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
470 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
471 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
472 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
473 |
+
called at every step.
|
474 |
+
|
475 |
+
Returns:
|
476 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
477 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
478 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
479 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
480 |
+
(nsfw) content, according to the `safety_checker`.
|
481 |
+
"""
|
482 |
+
# 0. Default height and width to unet
|
483 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
484 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
485 |
+
|
486 |
+
# 1. Check inputs. Raise error if not correct
|
487 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
488 |
+
|
489 |
+
# 2. Define call parameters
|
490 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
491 |
+
device = self._execution_device
|
492 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
493 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
494 |
+
# corresponds to doing no classifier free guidance.
|
495 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
496 |
+
|
497 |
+
# 3. Encode input prompt
|
498 |
+
text_embeddings = self._encode_prompt(
|
499 |
+
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
500 |
+
)
|
501 |
+
|
502 |
+
# 4. Prepare timesteps
|
503 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
504 |
+
timesteps = self.scheduler.timesteps
|
505 |
+
|
506 |
+
# 5. Prepare latent variables
|
507 |
+
num_channels_latents = self.unet.in_channels
|
508 |
+
|
509 |
+
latents = self.prepare_latents(
|
510 |
+
batch_size * num_images_per_prompt,
|
511 |
+
num_channels_latents,
|
512 |
+
clip_length,
|
513 |
+
height,
|
514 |
+
width,
|
515 |
+
text_embeddings.dtype,
|
516 |
+
device,
|
517 |
+
generator,
|
518 |
+
latents,
|
519 |
+
)
|
520 |
+
latents_dtype = latents.dtype
|
521 |
+
|
522 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
523 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
524 |
+
|
525 |
+
# 7. Denoising loop
|
526 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
527 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
528 |
+
for i, t in enumerate(timesteps):
|
529 |
+
# expand the latents if we are doing classifier free guidance
|
530 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
531 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
532 |
+
# print(latent_model_input.shape, )
|
533 |
+
# predict the noise residual
|
534 |
+
if fps_labels:
|
535 |
+
if isinstance(fps_labels, list):
|
536 |
+
fps_labels = torch.tensor(fps_labels).to(self.unet.device)
|
537 |
+
# 控制帧率
|
538 |
+
noise_pred = self.unet(
|
539 |
+
latent_model_input, t, encoder_hidden_states=text_embeddings, fps_labels=fps_labels,
|
540 |
+
).sample.to(dtype=latents_dtype)
|
541 |
+
else:
|
542 |
+
noise_pred = self.unet(
|
543 |
+
latent_model_input, t, encoder_hidden_states=text_embeddings
|
544 |
+
).sample.to(dtype=latents_dtype)
|
545 |
+
|
546 |
+
# perform guidance
|
547 |
+
if do_classifier_free_guidance:
|
548 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
549 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
550 |
+
noise_pred_text - noise_pred_uncond
|
551 |
+
)
|
552 |
+
|
553 |
+
# compute the previous noisy sample x_t -> x_t-1
|
554 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
555 |
+
|
556 |
+
# call the callback, if provided
|
557 |
+
if i == len(timesteps) - 1 or (
|
558 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
559 |
+
):
|
560 |
+
progress_bar.update()
|
561 |
+
if callback is not None and i % callback_steps == 0:
|
562 |
+
callback(i, t, latents)
|
563 |
+
|
564 |
+
# 8. Post-processing
|
565 |
+
image = self.decode_latents(latents)
|
566 |
+
# image[:, 1:, :, :, :] = image[:, 1:, :, :, :] + image[:, 0:1, :, :, :] # 叠加残差
|
567 |
+
|
568 |
+
# 9. Run safety checker
|
569 |
+
has_nsfw_concept = None
|
570 |
+
|
571 |
+
# 10. Convert to PIL
|
572 |
+
if output_type == "pil":
|
573 |
+
image = self.numpy_to_pil(image)
|
574 |
+
|
575 |
+
if not return_dict:
|
576 |
+
return (image, has_nsfw_concept)
|
577 |
+
|
578 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
579 |
+
|
580 |
+
@staticmethod
|
581 |
+
def numpy_to_pil(images):
|
582 |
+
if len(images.shape)==5:
|
583 |
+
pil_images = []
|
584 |
+
for sequence in images:
|
585 |
+
pil_images.append(DiffusionPipeline.numpy_to_pil(sequence))
|
586 |
+
return pil_images
|
587 |
+
else:
|
588 |
+
return DiffusionPipeline.numpy_to_pil(images)
|
589 |
+
|
590 |
+
|
591 |
+
# 改写一下 model_index.json的保存内容, Unet是新定义的,直接保存会导致读取的时候出错~
|
592 |
+
def to_json_string(self) -> str:
|
593 |
+
from diffusers import __version__
|
594 |
+
import json
|
595 |
+
import numpy as np
|
596 |
+
|
597 |
+
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
598 |
+
config_dict["_class_name"] = self.__class__.__name__
|
599 |
+
config_dict["_diffusers_version"] = __version__
|
600 |
+
|
601 |
+
def to_json_saveable(value):
|
602 |
+
if isinstance(value, np.ndarray):
|
603 |
+
value = value.tolist()
|
604 |
+
return value
|
605 |
+
|
606 |
+
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
607 |
+
if 'unet' in config_dict:
|
608 |
+
config_dict["unet"] = [
|
609 |
+
"diffusers",
|
610 |
+
"UNet2DConditionModel"
|
611 |
+
]
|
612 |
+
if 'controlnet' in config_dict:
|
613 |
+
config_dict['controlnet'] = [
|
614 |
+
"diffusers",
|
615 |
+
"UNet2DConditionModel"
|
616 |
+
]
|
617 |
+
# 覆盖
|
618 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
model/video_diffusion/pipelines/pipeline_stable_diffusion_controlnet3d.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from .pipeline_st_stable_diffusion import SpatioTemporalStableDiffusionPipeline
|
17 |
+
from typing import Callable, List, Optional, Union
|
18 |
+
from diffusers.schedulers import (
|
19 |
+
DDIMScheduler,
|
20 |
+
DPMSolverMultistepScheduler,
|
21 |
+
EulerAncestralDiscreteScheduler,
|
22 |
+
EulerDiscreteScheduler,
|
23 |
+
LMSDiscreteScheduler,
|
24 |
+
PNDMScheduler,
|
25 |
+
)
|
26 |
+
from transformers import DPTForDepthEstimation
|
27 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
28 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
29 |
+
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler
|
30 |
+
import torch
|
31 |
+
from einops import rearrange, repeat
|
32 |
+
import decord
|
33 |
+
import cv2
|
34 |
+
import random
|
35 |
+
import numpy as np
|
36 |
+
from ..models.unet_3d_condition import UNetPseudo3DConditionModel
|
37 |
+
from ..models.controlnet3d import ControlNet3DModel
|
38 |
+
|
39 |
+
|
40 |
+
class Controlnet3DStableDiffusionPipeline(SpatioTemporalStableDiffusionPipeline):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
vae: AutoencoderKL,
|
44 |
+
text_encoder: CLIPTextModel,
|
45 |
+
tokenizer: CLIPTokenizer,
|
46 |
+
unet: UNetPseudo3DConditionModel,
|
47 |
+
controlnet: ControlNet3DModel,
|
48 |
+
scheduler: Union[
|
49 |
+
DDIMScheduler,
|
50 |
+
PNDMScheduler,
|
51 |
+
LMSDiscreteScheduler,
|
52 |
+
EulerDiscreteScheduler,
|
53 |
+
EulerAncestralDiscreteScheduler,
|
54 |
+
DPMSolverMultistepScheduler,
|
55 |
+
],
|
56 |
+
annotator_model=None,
|
57 |
+
|
58 |
+
):
|
59 |
+
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
|
60 |
+
|
61 |
+
self.annotator_model = annotator_model
|
62 |
+
self.controlnet = controlnet
|
63 |
+
self.unet = unet
|
64 |
+
self.vae = vae
|
65 |
+
self.tokenizer = tokenizer
|
66 |
+
self.text_encoder = text_encoder
|
67 |
+
self.scheduler = scheduler
|
68 |
+
self.register_modules(
|
69 |
+
vae=vae,
|
70 |
+
text_encoder=text_encoder,
|
71 |
+
tokenizer=tokenizer,
|
72 |
+
unet=unet,
|
73 |
+
controlnet=controlnet,
|
74 |
+
scheduler=scheduler,
|
75 |
+
)
|
76 |
+
|
77 |
+
@staticmethod
|
78 |
+
def get_frames_preprocess(data_path, num_frames=24, sampling_rate=1, begin_indice=0, return_np=False):
|
79 |
+
vr = decord.VideoReader(data_path,)
|
80 |
+
n_images = len(vr)
|
81 |
+
fps_vid = round(vr.get_avg_fps())
|
82 |
+
frame_indices = [begin_indice + i*sampling_rate for i in range(num_frames)] # 随机取n帧
|
83 |
+
|
84 |
+
|
85 |
+
while n_images <= frame_indices[-1]:
|
86 |
+
# 超过视频长度,采样率减小直至不超过。
|
87 |
+
sampling_rate -= 1
|
88 |
+
if sampling_rate == 0:
|
89 |
+
# NOTE 边界检查
|
90 |
+
return None, None
|
91 |
+
frame_indices = [i*sampling_rate for i in range(num_frames)]
|
92 |
+
frames = vr.get_batch(frame_indices).asnumpy()
|
93 |
+
|
94 |
+
if return_np:
|
95 |
+
return frames, fps_vid
|
96 |
+
|
97 |
+
frames = torch.from_numpy(frames).div(255) * 2 - 1
|
98 |
+
frames = rearrange(frames, "f h w c -> c f h w").unsqueeze(0)
|
99 |
+
return frames, fps_vid
|
100 |
+
|
101 |
+
@torch.no_grad()
|
102 |
+
def get_canny_edge_map(self, frames, ):
|
103 |
+
# (b f) c h w"
|
104 |
+
# from tensor to numpy
|
105 |
+
inputs = frames.cpu().numpy()
|
106 |
+
inputs = rearrange(inputs, 'f c h w -> f h w c')
|
107 |
+
# inputs from [-1, 1] to [0, 255]
|
108 |
+
inputs = (inputs + 1) * 127.5
|
109 |
+
inputs = inputs.astype(np.uint8)
|
110 |
+
lower_threshold = 100
|
111 |
+
higher_threshold = 200
|
112 |
+
edge_images = np.stack([cv2.Canny(inp, lower_threshold, higher_threshold) for inp in inputs])
|
113 |
+
# from numpy to tensors
|
114 |
+
edge_images = torch.from_numpy(edge_images).unsqueeze(1) # f, 1, h, w
|
115 |
+
edge_images = edge_images.div(255)*2 - 1
|
116 |
+
# print(torch.max(out_images), torch.min(out_images), out_images.dtype)
|
117 |
+
return edge_images.to(dtype= self.controlnet.dtype, device=self.controlnet.device)
|
118 |
+
|
119 |
+
@torch.no_grad()
|
120 |
+
def get_depth_map(self, frames, height, width, return_standard_norm=False ):
|
121 |
+
"""
|
122 |
+
frames should be like: (f c h w), you may turn b f c h w -> (b f) c h w first
|
123 |
+
"""
|
124 |
+
h,w = height, width
|
125 |
+
inputs = torch.nn.functional.interpolate(
|
126 |
+
frames,
|
127 |
+
size=(384, 384),
|
128 |
+
mode="bicubic",
|
129 |
+
antialias=True,
|
130 |
+
)
|
131 |
+
# 转类型和设备
|
132 |
+
inputs = inputs.to(dtype= self.annotator_model.dtype, device=self.annotator_model.device)
|
133 |
+
|
134 |
+
outputs = self.annotator_model(inputs)
|
135 |
+
predicted_depths = outputs.predicted_depth
|
136 |
+
|
137 |
+
# interpolate to original size
|
138 |
+
predictions = torch.nn.functional.interpolate(
|
139 |
+
predicted_depths.unsqueeze(1),
|
140 |
+
size=(h, w),
|
141 |
+
mode="bicubic",
|
142 |
+
)
|
143 |
+
|
144 |
+
# normalize output
|
145 |
+
if return_standard_norm:
|
146 |
+
depth_min = torch.amin(predictions, dim=[1, 2, 3], keepdim=True)
|
147 |
+
depth_max = torch.amax(predictions, dim=[1, 2, 3], keepdim=True)
|
148 |
+
predictions = 2.0 * (predictions - depth_min) / (depth_max - depth_min) - 1.0
|
149 |
+
else:
|
150 |
+
predictions -= torch.min(predictions)
|
151 |
+
predictions /= torch.max(predictions)
|
152 |
+
|
153 |
+
return predictions
|
154 |
+
|
155 |
+
|
156 |
+
@torch.no_grad()
|
157 |
+
def get_hed_map(self, frames,):
|
158 |
+
if isinstance(frames, torch.Tensor):
|
159 |
+
# 输入的就是 b c h w的tensor 范围是-1~1,需要转换为0~1
|
160 |
+
frames = (frames + 1) / 2
|
161 |
+
#rgb转bgr
|
162 |
+
bgr_frames = frames.clone()
|
163 |
+
bgr_frames[:, 0, :, :] = frames[:, 2, :, :]
|
164 |
+
bgr_frames[:, 2, :, :] = frames[:, 0, :, :]
|
165 |
+
|
166 |
+
edge = self.annotator_model(bgr_frames) # 范围也是0~1
|
167 |
+
return edge
|
168 |
+
else:
|
169 |
+
assert frames.ndim == 3
|
170 |
+
frames = frames[:, :, ::-1].copy()
|
171 |
+
with torch.no_grad():
|
172 |
+
image_hed = torch.from_numpy(frames).to(next(self.annotator_model.parameters()).device, dtype=next(self.annotator_model.parameters()).dtype )
|
173 |
+
image_hed = image_hed / 255.0
|
174 |
+
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
175 |
+
edge = self.annotator_model(image_hed)[0]
|
176 |
+
edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
|
177 |
+
return edge[0]
|
178 |
+
|
179 |
+
@torch.no_grad()
|
180 |
+
def get_pose_map(self, frames,):
|
181 |
+
if isinstance(frames, torch.Tensor):
|
182 |
+
# 输入的就是 b c h w的tensor 范围是-1~1,需要转换为0~1
|
183 |
+
frames = (frames + 1) / 2
|
184 |
+
np_frames = frames.cpu().numpy() * 255
|
185 |
+
np_frames = np.array(np_frames, dtype=np.uint8)
|
186 |
+
np_frames = rearrange(np_frames, 'f c h w-> f h w c')
|
187 |
+
poses = np.stack([self.annotator_model(inp) for inp in np_frames])
|
188 |
+
else:
|
189 |
+
poses = self.annotator_model(frames)
|
190 |
+
return poses
|
191 |
+
|
192 |
+
def get_timesteps(self, num_inference_steps, strength,):
|
193 |
+
# get the original timestep using init_timestep
|
194 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
195 |
+
|
196 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
197 |
+
timesteps = self.scheduler.timesteps[t_start:]
|
198 |
+
|
199 |
+
return timesteps, num_inference_steps - t_start
|
200 |
+
|
201 |
+
@torch.no_grad()
|
202 |
+
def __call__(
|
203 |
+
self,
|
204 |
+
prompt: Union[str, List[str]],
|
205 |
+
controlnet_hint = None,
|
206 |
+
fps_labels = None,
|
207 |
+
height: Optional[int] = None,
|
208 |
+
width: Optional[int] = None,
|
209 |
+
num_inference_steps: int = 50,
|
210 |
+
clip_length: int = 8, # NOTE clip_length和images的帧数一致。
|
211 |
+
guidance_scale: float = 7.5,
|
212 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
213 |
+
num_images_per_prompt: Optional[int] = 1,
|
214 |
+
eta: float = 0.0,
|
215 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
216 |
+
latents: Optional[torch.FloatTensor] = None,
|
217 |
+
output_type: Optional[str] = "pil",
|
218 |
+
return_dict: bool = True,
|
219 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
220 |
+
callback_steps: Optional[int] = 1,
|
221 |
+
cross_attention_kwargs = None,
|
222 |
+
video_scale: float = 0.0,
|
223 |
+
controlnet_conditioning_scale: float = 1.0,
|
224 |
+
fix_first_frame=True,
|
225 |
+
first_frame_output = None , # 也可以允许挑好图后传入。
|
226 |
+
first_frame_output_latent = None,
|
227 |
+
first_frame_control_hint = None, # 维持第一帧
|
228 |
+
add_first_frame_by_concat = False,
|
229 |
+
controlhint_in_uncond = False,
|
230 |
+
init_same_noise_per_frame=False,
|
231 |
+
init_noise_by_residual_thres=0.0,
|
232 |
+
images=None,
|
233 |
+
in_domain=False, # 是否调用视频模型生成图片
|
234 |
+
residual_control_steps=1,
|
235 |
+
first_frame_ddim_strength=1.0,
|
236 |
+
return_last_latent = False,
|
237 |
+
):
|
238 |
+
'''
|
239 |
+
add origin video frames to get depth maps
|
240 |
+
'''
|
241 |
+
|
242 |
+
if fix_first_frame and first_frame_output is None and first_frame_output_latent is None:
|
243 |
+
first_frame_output = self.__call__(
|
244 |
+
prompt=prompt,
|
245 |
+
controlnet_hint=controlnet_hint[:,:,0,:,:] if not in_domain else controlnet_hint[:,:,0:1,:,:],
|
246 |
+
# b c f h w
|
247 |
+
num_inference_steps=20,
|
248 |
+
width=width,
|
249 |
+
height=height,
|
250 |
+
guidance_scale=guidance_scale,
|
251 |
+
num_images_per_prompt=1,
|
252 |
+
generator=generator,
|
253 |
+
fix_first_frame=False,
|
254 |
+
controlhint_in_uncond=controlhint_in_uncond,
|
255 |
+
).images[0]
|
256 |
+
|
257 |
+
|
258 |
+
if first_frame_output is not None:
|
259 |
+
if isinstance(first_frame_output, list):
|
260 |
+
first_frame_output = first_frame_output[0]
|
261 |
+
first_frame_output = torch.from_numpy(np.array(first_frame_output)).div(255) * 2 - 1
|
262 |
+
first_frame_output = rearrange(first_frame_output, "h w c -> c h w").unsqueeze(0) # FIXME 目前不允许多个batch 先设置为1
|
263 |
+
first_frame_output = first_frame_output.to(dtype= self.vae.dtype, device=self.vae.device)
|
264 |
+
|
265 |
+
first_frame_output_latent = self.vae.encode(first_frame_output).latent_dist.sample()
|
266 |
+
first_frame_output_latent = first_frame_output_latent * 0.18215
|
267 |
+
# 0. Default height and width to unet
|
268 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
269 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
270 |
+
|
271 |
+
# 1. Check inputs. Raise error if not correct
|
272 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
273 |
+
|
274 |
+
# 2. Define call parameters
|
275 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
276 |
+
device = self._execution_device
|
277 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
278 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
279 |
+
# corresponds to doing no classifier free guidance.
|
280 |
+
do_classifier_free_guidance = guidance_scale > 5.0
|
281 |
+
|
282 |
+
# 3. Encode input prompt
|
283 |
+
text_embeddings = self._encode_prompt(
|
284 |
+
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
285 |
+
)
|
286 |
+
|
287 |
+
# 4. Prepare timesteps
|
288 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
289 |
+
timesteps = self.scheduler.timesteps
|
290 |
+
|
291 |
+
# 5. Prepare latent variables
|
292 |
+
num_channels_latents = self.unet.in_channels
|
293 |
+
if controlnet_hint is not None:
|
294 |
+
if len(controlnet_hint.shape) == 5:
|
295 |
+
clip_length = controlnet_hint.shape[2]
|
296 |
+
else:
|
297 |
+
clip_length = 0
|
298 |
+
|
299 |
+
latents = self.prepare_latents(
|
300 |
+
batch_size * num_images_per_prompt,
|
301 |
+
num_channels_latents,
|
302 |
+
clip_length,
|
303 |
+
height,
|
304 |
+
width,
|
305 |
+
text_embeddings.dtype,
|
306 |
+
device,
|
307 |
+
generator,
|
308 |
+
latents,
|
309 |
+
)
|
310 |
+
latents_dtype = latents.dtype
|
311 |
+
|
312 |
+
|
313 |
+
if len(latents.shape) == 5 and init_same_noise_per_frame:
|
314 |
+
latents[:,:,1:,:,:] = latents[:,:,0:1,:,:]
|
315 |
+
|
316 |
+
if len(latents.shape) == 5 and init_noise_by_residual_thres > 0.0 and images is not None:
|
317 |
+
|
318 |
+
images = images.to(device=device, dtype=latents_dtype) # b c f h w
|
319 |
+
image_residual = torch.abs(images[:,:,1:,:,:] - images[:,:,:-1,:,:])
|
320 |
+
images = rearrange(images, "b c f h w -> (b f) c h w")
|
321 |
+
|
322 |
+
# norm residual
|
323 |
+
image_residual = image_residual / torch.max(image_residual)
|
324 |
+
|
325 |
+
image_residual = rearrange(image_residual, "b c f h w -> (b f) c h w")
|
326 |
+
image_residual = torch.nn.functional.interpolate(
|
327 |
+
image_residual,
|
328 |
+
size=(latents.shape[-2], latents.shape[-1]),
|
329 |
+
mode='bilinear')
|
330 |
+
image_residual = torch.mean(image_residual, dim=1)
|
331 |
+
|
332 |
+
image_residual_mask = (image_residual > init_noise_by_residual_thres).float()
|
333 |
+
image_residual_mask = repeat(image_residual_mask, '(b f) h w -> b f h w', b=batch_size)
|
334 |
+
image_residual_mask = repeat(image_residual_mask, 'b f h w -> b c f h w', c=latents.shape[1])
|
335 |
+
|
336 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
337 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
338 |
+
|
339 |
+
# 7. Denoising loop
|
340 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
341 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
342 |
+
if fix_first_frame:
|
343 |
+
if add_first_frame_by_concat:
|
344 |
+
if len(first_frame_output_latent.shape) == 4:
|
345 |
+
latents = torch.cat([first_frame_output_latent.unsqueeze(2), latents], dim=2)
|
346 |
+
else:
|
347 |
+
latents = torch.cat([first_frame_output_latent, latents], dim=2)
|
348 |
+
if first_frame_control_hint is not None:
|
349 |
+
controlnet_hint = torch.cat([first_frame_control_hint, controlnet_hint], dim=2)
|
350 |
+
else:
|
351 |
+
controlnet_hint = torch.cat([controlnet_hint[:,:,0:1 ,:,:], controlnet_hint], dim=2)
|
352 |
+
|
353 |
+
if controlhint_in_uncond:
|
354 |
+
controlnet_hint = torch.cat([controlnet_hint] * 2) if do_classifier_free_guidance else controlnet_hint
|
355 |
+
for i, t in enumerate(timesteps):
|
356 |
+
# expand the latents if we are doing classifier free guidance
|
357 |
+
if i<residual_control_steps and len(latents.shape) == 5 and init_noise_by_residual_thres > 0.0 and images is not None :
|
358 |
+
if first_frame_ddim_strength < 1.0 and i == 0 :
|
359 |
+
# NOTE DDIM to get the first noise
|
360 |
+
first_frame_output_latent_DDIM = first_frame_output_latent.clone()
|
361 |
+
full_noise_timestep, _ = self.get_timesteps(num_inference_steps, strength=first_frame_ddim_strength)
|
362 |
+
latent_timestep = full_noise_timestep[:1].repeat(batch_size * num_images_per_prompt)
|
363 |
+
first_frame_output_latent_DDIM = self.scheduler.add_noise(first_frame_output_latent_DDIM, latents[:,:,0,:,:], latent_timestep)
|
364 |
+
latents[:,:,0,:,:]=first_frame_output_latent_DDIM
|
365 |
+
begin_frame = 1
|
366 |
+
for n_frame in range(begin_frame, latents.shape[2]):
|
367 |
+
latents[:,:, n_frame, :, :] = \
|
368 |
+
(latents[:,:, n_frame, :, :] - latents[:,:, n_frame-1, :, :]) \
|
369 |
+
* image_residual_mask[:,:, n_frame-1, :, :] + \
|
370 |
+
latents[:,:, n_frame-1, :, :]
|
371 |
+
if fix_first_frame:
|
372 |
+
latents[:,:,0 ,:,:] = first_frame_output_latent
|
373 |
+
|
374 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
375 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
376 |
+
if controlnet_hint is not None:
|
377 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
378 |
+
latent_model_input,
|
379 |
+
t,
|
380 |
+
encoder_hidden_states=text_embeddings,
|
381 |
+
controlnet_cond=controlnet_hint,
|
382 |
+
return_dict=False,
|
383 |
+
)
|
384 |
+
down_block_res_samples = [
|
385 |
+
down_block_res_sample * controlnet_conditioning_scale
|
386 |
+
for down_block_res_sample in down_block_res_samples
|
387 |
+
]
|
388 |
+
mid_block_res_sample *= controlnet_conditioning_scale
|
389 |
+
|
390 |
+
noise_pred = self.unet(
|
391 |
+
latent_model_input,
|
392 |
+
t,
|
393 |
+
encoder_hidden_states=text_embeddings,
|
394 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
395 |
+
down_block_additional_residuals=down_block_res_samples,
|
396 |
+
mid_block_additional_residual=mid_block_res_sample,
|
397 |
+
).sample.to(dtype=latents_dtype)
|
398 |
+
else:
|
399 |
+
# predict the noise residual
|
400 |
+
noise_pred = self.unet(
|
401 |
+
latent_model_input,
|
402 |
+
t,
|
403 |
+
encoder_hidden_states=text_embeddings,
|
404 |
+
).sample.to(dtype=latents_dtype)
|
405 |
+
|
406 |
+
if video_scale > 0 and controlnet_hint is not None:
|
407 |
+
bsz = latents.shape[0]
|
408 |
+
f = latents.shape[2]
|
409 |
+
# 逐帧预测
|
410 |
+
latent_model_input_single_frame = rearrange(latent_model_input, 'b c f h w -> (b f) c h w')
|
411 |
+
text_embeddings_single_frame = torch.cat([text_embeddings] * f, dim=0)
|
412 |
+
control_maps_single_frame = rearrange(controlnet_hint, 'b c f h w -> (b f) c h w')
|
413 |
+
latent_model_input_single_frame = latent_model_input_single_frame.chunk(2, dim=0)[0]
|
414 |
+
text_embeddings_single_frame = text_embeddings_single_frame.chunk(2, dim=0)[0]
|
415 |
+
if controlhint_in_uncond:
|
416 |
+
control_maps_single_frame = control_maps_single_frame.chunk(2, dim=0)[0]
|
417 |
+
|
418 |
+
down_block_res_samples_single_frame, mid_block_res_sample_single_frame = self.controlnet(
|
419 |
+
latent_model_input_single_frame,
|
420 |
+
t,
|
421 |
+
encoder_hidden_states=text_embeddings_single_frame,
|
422 |
+
controlnet_cond=control_maps_single_frame,
|
423 |
+
return_dict=False,
|
424 |
+
)
|
425 |
+
down_block_res_samples_single_frame = [
|
426 |
+
down_block_res_sample_single_frame * controlnet_conditioning_scale
|
427 |
+
for down_block_res_sample_single_frame in down_block_res_samples_single_frame
|
428 |
+
]
|
429 |
+
mid_block_res_sample_single_frame *= controlnet_conditioning_scale
|
430 |
+
|
431 |
+
noise_pred_single_frame_uncond = self.unet(
|
432 |
+
latent_model_input_single_frame,
|
433 |
+
t,
|
434 |
+
encoder_hidden_states = text_embeddings_single_frame,
|
435 |
+
down_block_additional_residuals=down_block_res_samples_single_frame,
|
436 |
+
mid_block_additional_residual=mid_block_res_sample_single_frame,
|
437 |
+
).sample
|
438 |
+
noise_pred_single_frame_uncond = rearrange(noise_pred_single_frame_uncond, '(b f) c h w -> b c f h w', f=f)
|
439 |
+
# perform guidance
|
440 |
+
if do_classifier_free_guidance:
|
441 |
+
if video_scale > 0 and controlnet_hint is not None:
|
442 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
443 |
+
noise_pred = noise_pred_single_frame_uncond + video_scale * (
|
444 |
+
noise_pred_uncond - noise_pred_single_frame_uncond
|
445 |
+
) + guidance_scale * (
|
446 |
+
noise_pred_text - noise_pred_uncond
|
447 |
+
)
|
448 |
+
else:
|
449 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
450 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
451 |
+
noise_pred_text - noise_pred_uncond
|
452 |
+
)
|
453 |
+
|
454 |
+
# compute the previous noisy sample x_t -> x_t-1
|
455 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
456 |
+
|
457 |
+
# call the callback, if provided
|
458 |
+
if i == len(timesteps) - 1 or (
|
459 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
460 |
+
):
|
461 |
+
progress_bar.update()
|
462 |
+
if callback is not None and i % callback_steps == 0:
|
463 |
+
callback(i, t, latents)
|
464 |
+
# 8. Post-processing
|
465 |
+
image = self.decode_latents(latents)
|
466 |
+
if add_first_frame_by_concat:
|
467 |
+
image = image[:,1:,:,:,:]
|
468 |
+
|
469 |
+
# 9. Run safety checker
|
470 |
+
has_nsfw_concept = None
|
471 |
+
# 10. Convert to PIL
|
472 |
+
if output_type == "pil":
|
473 |
+
image = self.numpy_to_pil(image)
|
474 |
+
|
475 |
+
if not return_dict:
|
476 |
+
return (image, has_nsfw_concept)
|
477 |
+
|
478 |
+
if return_last_latent:
|
479 |
+
last_latent = latents[:,:,-1,:,:]
|
480 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), last_latent
|
481 |
+
else:
|
482 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.14.0
|
2 |
+
transformers==4.27.3
|
3 |
+
accelerate==0.18.0
|
4 |
+
xformers==0.0.16
|
5 |
+
imageio==2.27.0
|
6 |
+
decord==0.6.0
|
7 |
+
opencv-python==4.7.0.72
|
8 |
+
einops==0.6.0
|