smoothieAI committed
Commit d185fb7 · verified · 1 Parent(s): 72f3d04

Update pipeline.py

Files changed (1): pipeline.py (+176, -13)
pipeline.py CHANGED
@@ -53,6 +53,8 @@ import torchvision
 import PIL
 import PIL.Image
 import math
+import time
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -70,7 +72,62 @@ EXAMPLE_DOC_STRING = """
         >>> export_to_gif(frames, "animation.gif")
         ```
 """
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps

 def tensor2vid(video: torch.Tensor, processor, output_type="np"):
     # Based on:
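This hunk adds two module-level helpers copied from the Stable Diffusion pipelines: `retrieve_latents` pulls a latent sample out of a VAE encoder output (sampling the distribution, taking its mode, or reading `.latents` directly), and `retrieve_timesteps` wraps `scheduler.set_timesteps` so callers can pass either a step count or an explicit timestep list. A minimal usage sketch, assuming `retrieve_timesteps` is in scope and using a default-constructed diffusers scheduler (the scheduler choice and step count are illustrative, not part of this commit):

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # defaults: 1000 training timesteps

# Usual path: derive the schedule from a step count.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=25)
print(num_inference_steps, timesteps[:3])  # 25 steps, largest timesteps first

# Custom path: an explicit list is forwarded to scheduler.set_timesteps(timesteps=...),
# but only if that scheduler's set_timesteps accepts a `timesteps` argument;
# otherwise retrieve_timesteps raises a ValueError instead of silently ignoring the list.
```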
@@ -810,20 +867,104 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         # init_latents[:, :, 1:] = torch.zeros_like(init_latents[:, :, 1:])

         latents = latents.to(device)
-        return latents, init_latents
+        return latents, init_latents
+
+    def prepare_video_latents(
+        self,
+        video,
+        height,
+        width,
+        num_channels_latents,
+        batch_size,
+        timestep,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # video must be a list of list of images
+        # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
+        # as a list of images
+        if not isinstance(video[0], list):
+            video = [video]
+        if latents is None:
+            video = torch.cat(
+                [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
+            )
+            video = video.to(device=device, dtype=dtype)
+            num_frames = video.shape[1]
+        else:
+            num_frames = latents.shape[2]

-    def prepare_control_latents(self, batch_size, contorl_frames, num_channels_latents, num_frames, height, width, dtype, device):
-        shape = (
-            num_frames,
-            num_channels_latents,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
+        shape = (
+            batch_size,
+            num_channels_latents,
+            num_frames,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
         )
-
-        # convert input control image array to latents tensor array
-        latents = torch.zeros(shape, dtype=dtype, device=device)

-        return latents
+        if latents is None:
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            if self.vae.config.force_upcast:
+                video = video.float()
+                self.vae.to(dtype=torch.float32)
+
+            if isinstance(generator, list):
+                if len(generator) != batch_size:
+                    raise ValueError(
+                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                    )
+
+                init_latents = [
+                    retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
+                    for i in range(batch_size)
+                ]
+            else:
+                init_latents = [
+                    retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
+                ]
+
+            init_latents = torch.cat(init_latents, dim=0)
+
+            # restore vae to original dtype
+            if self.vae.config.force_upcast:
+                self.vae.to(dtype)
+
+            init_latents = init_latents.to(dtype)
+            init_latents = self.vae.config.scaling_factor * init_latents
+
+            if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+                # expand init_latents for batch_size
+                error_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+                    " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
+                )
+                raise ValueError(error_message)
+            elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+                )
+            else:
+                init_latents = torch.cat([init_latents], dim=0)
+
+            noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
+        else:
+            if shape != latents.shape:
+                # [B, C, F, H, W]
+                raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
+            latents = latents.to(device, dtype=dtype)
+
+        return latents
+

     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_control_frames(
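This hunk drops the zero-filled `prepare_control_latents` stub and adds `prepare_video_latents`, which follows the img2img initialization pattern: input frames are preprocessed and VAE-encoded (with a float32 upcast when `vae.config.force_upcast` is set), scaled by `vae.config.scaling_factor`, noised to the starting `timestep` with `scheduler.add_noise`, and finally permuted from `[B, F, C, H, W]` to the `[B, C, F, H, W]` layout the rest of the pipeline expects. A standalone sketch of that core flow on dummy data (the VAE encode is stubbed with random latents; the 0.18215 scaling factor is the usual SD default, assumed here):

```python
import torch
from diffusers import DDIMScheduler

batch, frames, latent_channels, h, w = 1, 16, 4, 64, 64
scheduler = DDIMScheduler()
scheduler.set_timesteps(25)

# Stand-in for per-frame `retrieve_latents(self.vae.encode(vid), ...)`.
init_latents = torch.randn(batch, frames, latent_channels, h, w)
init_latents = 0.18215 * init_latents  # vae.config.scaling_factor (assumed SD default)

# Noise the clean latents to the chosen start timestep (img2img-style init) ...
start_timestep = scheduler.timesteps[:1]
noise = torch.randn_like(init_latents)
noisy_latents = scheduler.add_noise(init_latents, noise, start_timestep)

# ... then move frames next to channels: [B, F, C, H, W] -> [B, C, F, H, W].
latents = noisy_latents.permute(0, 2, 1, 3, 4)
print(latents.shape)  # torch.Size([1, 4, 16, 64, 64])
```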
@@ -1112,6 +1253,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+

         # round num frames to the nearest multiple of context size - overlap
         num_frames = (num_frames // (context_size - overlap)) * (context_size - overlap)
@@ -1189,6 +1331,25 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 smooth_weight,
                 smooth_steps,
             )
+        elif(latent_mode == "video"):
+            # 4. Prepare timesteps
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, init_image_strength, device)
+            latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+            self._num_timesteps = len(timesteps)
+            num_channels_latents = self.unet.config.in_channels
+            latents = self.prepare_latents(
+                video=video,
+                height=height,
+                width=width,
+                num_channels_latents=num_channels_latents,
+                batch_size=batch_size * num_videos_per_prompt,
+                timestep=latent_timestep,
+                dtype=prompt_embeds.dtype,
+                device=device,
+                generator=generator,
+                latents=latents,
+            )


         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
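The new `latent_mode == "video"` branch re-derives the schedule with `retrieve_timesteps`, truncates it via `self.get_timesteps` using `init_image_strength`, repeats the first remaining timestep across the effective batch as `latent_timestep`, and then builds the initial latents from the input video. (The call site uses `self.prepare_latents(video=..., timestep=...)` even though the helper added above is named `prepare_video_latents`; the keyword arguments match the new helper's signature.) Below is a sketch of the usual strength-based truncation behind a `get_timesteps`-style helper, following the diffusers img2img convention; this pipeline's own `get_timesteps` is assumed to behave the same way:

```python
def get_timesteps_by_strength(scheduler, num_inference_steps, strength):
    # strength=1.0 keeps the full schedule (pure text-to-video noise);
    # smaller strengths start denoising later, preserving more of the input video.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start * scheduler.order :]
    return timesteps, num_inference_steps - t_start
```

The first element of the truncated schedule is what gets repeated into `latent_timestep` and used as the noise level for `scheduler.add_noise` during latent preparation.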
@@ -1263,7 +1424,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)


-                if self.controlnet != None:
+                if self.controlnet != None or i > 2:
+                    contorl_start = time.time()

                     current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
                     current_context_conditioning_frames = torch.cat([current_context_conditioning_frames] * 2) if do_classifier_free_guidance else current_context_conditioning_frames
@@ -1302,7 +1464,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                         conditioning_scale=cond_scale,
                         guess_mode=guess_mode,
                         return_dict=False,
-                    )
+                    )
+                    print("controlnet time", time.time() - contorl_start)


                     # predict the noise residual with the added controlnet residuals
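The last two hunks widen the ControlNet gate from `if self.controlnet != None:` to `if self.controlnet != None or i > 2:` and time the branch with the newly imported `time` module, printing the elapsed seconds after the ControlNet call (the commit spells the variable `contorl_start`). A minimal, self-contained sketch of the per-step pattern, with the ControlNet forward pass stubbed out and all shapes illustrative:

```python
import time
import torch

conditioning_frames = torch.randn(16, 3, 512, 512)   # all conditioning frames (illustrative)
current_context_indexes = list(range(8))              # frames in the current context window
do_classifier_free_guidance = True

control_start = time.time()
ctx_frames = conditioning_frames[current_context_indexes, :, :, :]
# Duplicate along the batch dim so the unconditional and conditional passes share one forward call.
ctx_frames = torch.cat([ctx_frames] * 2) if do_classifier_free_guidance else ctx_frames
# ... the ControlNet forward pass on ctx_frames would run here ...
print("controlnet time", time.time() - control_start)
```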