Update pipeline.py
pipeline.py (CHANGED): +176 -13
@@ -53,6 +53,8 @@ import torchvision
 import PIL
 import PIL.Image
 import math
+import time
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -70,7 +72,62 @@ EXAMPLE_DOC_STRING = """
         >>> export_to_gif(frames, "animation.gif")
         ```
 """
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
 
 def tensor2vid(video: torch.Tensor, processor, output_type="np"):
     # Based on:
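
A quick, hypothetical usage sketch of the retrieve_timesteps helper added above (not part of the commit): it assumes the helper is in scope from this pipeline.py and uses diffusers' stock DDIMScheduler on the default-spacing path.

# Hypothetical usage sketch, assuming retrieve_timesteps (defined above) is in scope.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # stock scheduler with its default 1000 training timesteps

# No custom `timesteps` list is passed, so the helper takes the else-branch and
# simply calls scheduler.set_timesteps(num_inference_steps, device=...).
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=25, device="cpu")
print(num_inference_steps)  # 25
print(timesteps[:3])        # first entries of the 25-step schedule
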
@@ -810,20 +867,104 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         # init_latents[:, :, 1:] = torch.zeros_like(init_latents[:, :, 1:])
 
         latents = latents.to(device)
-        return latents, init_latents
+        return latents, init_latents
+
+    def prepare_video_latents(
+        self,
+        video,
+        height,
+        width,
+        num_channels_latents,
+        batch_size,
+        timestep,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # video must be a list of list of images
+        # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
+        # as a list of images
+        if not isinstance(video[0], list):
+            video = [video]
+        if latents is None:
+            video = torch.cat(
+                [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
+            )
+            video = video.to(device=device, dtype=dtype)
+            num_frames = video.shape[1]
+        else:
+            num_frames = latents.shape[2]
 
-
-
-
-
-
-
+        shape = (
+            batch_size,
+            num_channels_latents,
+            num_frames,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
-        # convert input control image array to latents tensor array
-        latents = torch.zeros(shape, dtype=dtype, device=device)
 
-
+        if latents is None:
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            if self.vae.config.force_upcast:
+                video = video.float()
+                self.vae.to(dtype=torch.float32)
+
+            if isinstance(generator, list):
+                if len(generator) != batch_size:
+                    raise ValueError(
+                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                    )
+
+                init_latents = [
+                    retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
+                    for i in range(batch_size)
+                ]
+            else:
+                init_latents = [
+                    retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
+                ]
+
+            init_latents = torch.cat(init_latents, dim=0)
+
+            # restore vae to original dtype
+            if self.vae.config.force_upcast:
+                self.vae.to(dtype)
+
+            init_latents = init_latents.to(dtype)
+            init_latents = self.vae.config.scaling_factor * init_latents
+
+            if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+                # expand init_latents for batch_size
+                error_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+                    " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
+                )
+                raise ValueError(error_message)
+            elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+                )
+            else:
+                init_latents = torch.cat([init_latents], dim=0)
+
+            noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
+        else:
+            if shape != latents.shape:
+                # [B, C, F, H, W]
+                raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
+            latents = latents.to(device, dtype=dtype)
+
+        return latents
+
 
     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_control_frames(
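
A shape-bookkeeping sketch of what prepare_video_latents returns (illustrative only; dummy tensors stand in for the VAE encode and scheduler.add_noise, and none of this is part of the commit):

# Illustrative only: dummy tensors in place of the VAE and scheduler.
import torch

batch_size, num_frames, latent_channels = 1, 16, 4
h, w = 512 // 8, 512 // 8  # height // vae_scale_factor, width // vae_scale_factor

# Encoding the input video yields latents stacked as [B, F, C, h, w] ...
init_latents = torch.randn(batch_size, num_frames, latent_channels, h, w)
noise = torch.randn_like(init_latents)
noisy = init_latents + noise  # stand-in for self.scheduler.add_noise(init_latents, noise, timestep)

# ... and the trailing permute(0, 2, 1, 3, 4) swaps the frame and channel axes,
# giving the [B, C, F, h, w] layout the rest of the pipeline works with.
latents = noisy.permute(0, 2, 1, 3, 4)
print(latents.shape)  # torch.Size([1, 4, 16, 64, 64])
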
@@ -1112,6 +1253,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+
 
         # round num frames to the nearest multiple of context size - overlap
         num_frames = (num_frames // (context_size - overlap)) * (context_size - overlap)
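
The context lines above floor num_frames to a multiple of context_size - overlap; a worked example with illustrative numbers (not from the commit):

# Illustrative numbers only.
context_size, overlap = 16, 4
num_frames = 50
num_frames = (num_frames // (context_size - overlap)) * (context_size - overlap)
print(num_frames)  # 48, the nearest lower multiple of 12
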
@@ -1189,6 +1331,25 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 smooth_weight,
                 smooth_steps,
             )
+        elif(latent_mode == "video"):
+            # 4. Prepare timesteps
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, init_image_strength, device)
+            latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+            self._num_timesteps = len(timesteps)
+            num_channels_latents = self.unet.config.in_channels
+            latents = self.prepare_latents(
+                video=video,
+                height=height,
+                width=width,
+                num_channels_latents=num_channels_latents,
+                batch_size=batch_size * num_videos_per_prompt,
+                timestep=latent_timestep,
+                dtype=prompt_embeds.dtype,
+                device=device,
+                generator=generator,
+                latents=latents,
+            )
 
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
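
The new "video" branch starts denoising partway through the schedule according to init_image_strength, via the pipeline's get_timesteps. A hedged sketch of the usual img2img truncation convention this relies on (this pipeline's get_timesteps may differ in detail):

# Sketch of img2img-style strength truncation; numbers are illustrative.
num_inference_steps = 25
init_image_strength = 0.6

init_timestep = min(int(num_inference_steps * init_image_strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)

# Denoising then covers only the last `init_timestep` steps, and
# latent_timestep = timesteps[t_start] is the noise level added to the encoded video.
print(t_start, init_timestep)  # 10 15
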
@@ -1263,7 +1424,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
 
-                if self.controlnet != None:
+                if self.controlnet != None or i > 2:
+                    contorl_start = time.time()
 
                     current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
                     current_context_conditioning_frames = torch.cat([current_context_conditioning_frames] * 2) if do_classifier_free_guidance else current_context_conditioning_frames
@@ -1302,7 +1464,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                         conditioning_scale=cond_scale,
                         guess_mode=guess_mode,
                         return_dict=False,
-                    )
+                    )
+                    print("controlnet time", time.time() - contorl_start)
 
 
                 # predict the noise residual with the added controlnet residuals
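
A note on the timing instrumentation added in the last two hunks: time.time() around asynchronous CUDA work mostly measures kernel launch unless the device is synchronized first. A hedged variant of the same measurement (not part of the commit):

# Sketch of a synchronized version of the controlnet timing above.
import time
import torch

start = time.time()
# ... the controlnet forward pass would run here ...
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for queued GPU work so the wall-clock delta is meaningful
print("controlnet time", time.time() - start)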