from typing import Callable, Dict, List, Optional, Union
import gc

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
    StableVideoDiffusionPipeline,
    retrieve_timesteps,
)
from diffusers.utils import logging
from kornia.utils import create_meshgrid
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@torch.no_grad()
def normalize_point_map(point_map, valid_mask):
    # point_map: T,H,W,3  valid_mask: T,H,W
    norm_factor = (point_map[..., 2] * valid_mask.float()).mean() / (valid_mask.float().mean() + 1e-8)
    norm_factor = norm_factor.clip(min=1e-3)
    return point_map / norm_factor


def point_map_xy2intrinsic_map(point_map_xy):
    # point_map_xy: *,h,w,2
    height, width = point_map_xy.shape[-3], point_map_xy.shape[-2]
    assert height % 2 == 0
    assert width % 2 == 0
    mesh_grid = create_meshgrid(
        height=height,
        width=width,
        normalized_coordinates=True,
        device=point_map_xy.device,
        dtype=point_map_xy.dtype,
    )[0]  # h,w,2
    assert mesh_grid.abs().min() > 1e-4
    mesh_grid = mesh_grid.expand_as(point_map_xy)  # *,h,w,2
    nc = point_map_xy.mean(dim=-2).mean(dim=-2)  # *,2
    nc_map = nc[..., None, None, :].expand_as(point_map_xy)
    nf = ((point_map_xy - nc_map) / mesh_grid).mean(dim=-2).mean(dim=-2)
    nf_map = nf[..., None, None, :].expand_as(point_map_xy)
    # print((mesh_grid * nf_map + nc_map - point_map_xy).abs().max())
    return torch.cat([nc_map, nf_map], dim=-1)


def robust_min_max(tensor, quantile=0.99):
    T, H, W = tensor.shape
    min_vals = []
    max_vals = []
    for i in range(T):
        min_vals.append(torch.quantile(tensor[i], q=1 - quantile, interpolation='nearest').item())
        max_vals.append(torch.quantile(tensor[i], q=quantile, interpolation='nearest').item())
    return min(min_vals), max(max_vals)

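
# Illustrative sketch (not used by the pipeline): `point_map_xy2intrinsic_map` above
# parameterizes the normalized xy point map as an affine function of a normalized mesh
# grid, i.e. point_map_xy ≈ nc_map + nf_map * mesh_grid, where nc is a per-frame
# principal-point offset and nf a per-frame focal scale. The hypothetical helper below
# only documents that convention by inverting it; it assumes the same (h, w) grid
# layout as the function above.
def _intrinsic_map_to_point_map_xy(intrinsic_map: torch.Tensor) -> torch.Tensor:
    # intrinsic_map: *,h,w,4 with channels [nc_x, nc_y, nf_x, nf_y]
    height, width = intrinsic_map.shape[-3], intrinsic_map.shape[-2]
    mesh_grid = create_meshgrid(
        height=height,
        width=width,
        normalized_coordinates=True,
        device=intrinsic_map.device,
        dtype=intrinsic_map.dtype,
    )[0]  # h,w,2
    nc_map, nf_map = intrinsic_map.split(2, dim=-1)
    return nc_map + nf_map * mesh_grid  # *,h,w,2
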

class GeometryCrafterDiffPipeline(StableVideoDiffusionPipeline):

    @torch.inference_mode()
    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        :param video: [b, c, h, w] in range [-1, 1]; the batch dimension may contain multiple videos or frames
        :param chunk_size: the chunk size used to encode the video
        :return: image embeddings in shape [b, 1024]
        """
        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0  # [-1, 1] -> [0, 1]

        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            emb = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(emb).image_embeds)  # [b, 1024]

        embeddings = torch.cat(embeddings, dim=0)  # [t, 1024]
        return embeddings

    @torch.inference_mode()
    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ):
        """
        :param video: [b, c, h, w] in range [-1, 1]; the batch dimension may contain multiple videos or frames
        :param chunk_size: the chunk size used to encode the video
        :return: VAE latents in shape [b, c, h, w]
        """
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents

    @torch.inference_mode()
    def produce_priors(self, prior_model, frame, chunk_size=8):
        T, _, H, W = frame.shape
        frame = (frame + 1) / 2
        pred_point_maps = []
        pred_masks = []
        for i in range(0, len(frame), chunk_size):
            pred_p, pred_m = prior_model.forward_image(frame[i : i + chunk_size])
            pred_point_maps.append(pred_p)
            pred_masks.append(pred_m)
        pred_point_maps = torch.cat(pred_point_maps, dim=0)
        pred_masks = torch.cat(pred_masks, dim=0)
        pred_masks = pred_masks.float() * 2 - 1

        # pred_point_maps: T,H,W,3  pred_masks: T,H,W
        pred_point_maps = normalize_point_map(pred_point_maps, pred_masks > 0)

        pred_disps = 1.0 / pred_point_maps[..., 2].clamp_min(1e-3)
        pred_disps = pred_disps * (pred_masks > 0)
        min_disparity, max_disparity = robust_min_max(pred_disps)
        pred_disps = ((pred_disps - min_disparity) / (max_disparity - min_disparity + 1e-4)).clamp(0, 1)
        pred_disps = pred_disps * 2 - 1

        pred_point_maps[..., :2] = pred_point_maps[..., :2] / (pred_point_maps[..., 2:3] + 1e-7)
        pred_point_maps[..., 2] = torch.log(pred_point_maps[..., 2] + 1e-7) * (pred_masks > 0)  # [x/z, y/z, log(z)]

        pred_intr_maps = point_map_xy2intrinsic_map(pred_point_maps[..., :2]).permute(0, 3, 1, 2)  # T,H,W,4 -> T,4,H,W
        pred_point_maps = pred_point_maps.permute(0, 3, 1, 2)  # T,3,H,W

        return pred_disps, pred_masks, pred_point_maps, pred_intr_maps

    @torch.inference_mode()
    def encode_point_map(self, point_map_vae, disparity, valid_mask, point_map, intrinsic_map, chunk_size=8):
        T, _, H, W = point_map.shape
        latents = []

        pseudo_image = disparity[:, None].repeat(1, 3, 1, 1)
        intrinsic_map = torch.norm(intrinsic_map[:, 2:4], p=2, dim=1, keepdim=False)

        for i in range(0, T, chunk_size):
            latent_dist = self.vae.encode(pseudo_image[i : i + chunk_size].to(self.vae.dtype)).latent_dist
            latent_dist = point_map_vae.encode(
                torch.cat(
                    [
                        intrinsic_map[i : i + chunk_size, None],
                        point_map[i : i + chunk_size, 2:3],
                        disparity[i : i + chunk_size, None],
                        valid_mask[i : i + chunk_size, None],
                    ],
                    dim=1,
                ),
                latent_dist,
            )
            if isinstance(latent_dist, DiagonalGaussianDistribution):
                latent = latent_dist.mode()
            else:
                latent = latent_dist
            assert isinstance(latent, torch.Tensor)
            latents.append(latent)

        latents = torch.cat(latents, dim=0)
        latents = latents * self.vae.config.scaling_factor
        return latents
    @torch.no_grad()
    def decode_point_map(
        self,
        point_map_vae,
        latents,
        chunk_size=8,
        force_projection=True,
        force_fixed_focal=True,
        use_extract_interp=False,
        need_resize=False,
        height=None,
        width=None,
    ):
        T = latents.shape[0]
        rec_intrinsic_maps = []
        rec_depth_maps = []
        rec_valid_masks = []
        for i in range(0, T, chunk_size):
            lat = latents[i : i + chunk_size]
            rec_imap, rec_dmap, rec_vmask = point_map_vae.decode(
                lat,
                num_frames=lat.shape[0],
            )
            rec_intrinsic_maps.append(rec_imap)
            rec_depth_maps.append(rec_dmap)
            rec_valid_masks.append(rec_vmask)

        rec_intrinsic_maps = torch.cat(rec_intrinsic_maps, dim=0)
        rec_depth_maps = torch.cat(rec_depth_maps, dim=0)
        rec_valid_masks = torch.cat(rec_valid_masks, dim=0)

        if need_resize:
            rec_depth_maps = (
                F.interpolate(rec_depth_maps, (height, width), mode='nearest-exact')
                if use_extract_interp
                else F.interpolate(rec_depth_maps, (height, width), mode='bilinear', align_corners=False)
            )
            rec_valid_masks = (
                F.interpolate(rec_valid_masks, (height, width), mode='nearest-exact')
                if use_extract_interp
                else F.interpolate(rec_valid_masks, (height, width), mode='bilinear', align_corners=False)
            )
            rec_intrinsic_maps = F.interpolate(rec_intrinsic_maps, (height, width), mode='bilinear', align_corners=False)

        H, W = rec_intrinsic_maps.shape[-2], rec_intrinsic_maps.shape[-1]
        mesh_grid = create_meshgrid(
            H,
            W,
            normalized_coordinates=True,
        ).to(rec_intrinsic_maps.device, rec_intrinsic_maps.dtype, non_blocking=True)  # 1,h,w,2
        rec_intrinsic_maps = torch.cat(
            [
                rec_intrinsic_maps * W / np.sqrt(W**2 + H**2),
                rec_intrinsic_maps * H / np.sqrt(W**2 + H**2),
            ],
            dim=1,
        )  # t,2,h,w
        mesh_grid = mesh_grid.permute(0, 3, 1, 2)
        rec_valid_masks = rec_valid_masks.squeeze(1) > 0

        if force_projection:
            if force_fixed_focal:
                nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
                nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
                rec_intrinsic_maps = torch.tensor([nfx, nfy], device=rec_intrinsic_maps.device)[None, :, None, None].repeat(T, 1, 1, 1)
            else:
                nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
                nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
                rec_intrinsic_maps = torch.stack([nfx, nfy], dim=-1)[:, :, None, None]  # t,2,1,1

        rec_point_maps = torch.cat([rec_intrinsic_maps * mesh_grid, rec_depth_maps], dim=1).permute(0, 2, 3, 1)
        xy, z = rec_point_maps.split([2, 1], dim=-1)
        z = torch.clamp_max(z, 10)  # for numerical stability
        z = torch.exp(z)
        rec_point_maps = torch.cat([xy * z, z], dim=-1)
        return rec_point_maps, rec_valid_masks

    @torch.no_grad()
    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        point_map_vae,
        prior_model,
        height: int = 320,
        width: int = 640,
        num_inference_steps: int = 5,
        guidance_scale: float = 1.0,
        window_size: Optional[int] = 14,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        overlap: int = 4,
        force_projection: bool = True,
        force_fixed_focal: bool = True,
        use_extract_interp: bool = False,
        track_time: bool = False,
    ):
        # video: [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]

        # 0. Default height and width to unet
        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))
        else:
            assert isinstance(video, torch.Tensor)
        height = height or video.shape[-2]
        width = width or video.shape[-1]
        original_height = video.shape[-2]
        original_width = video.shape[-1]
        num_frames = video.shape[0]
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
        if num_frames <= window_size:
            window_size = num_frames
            overlap = 0
        stride = window_size - overlap

        # 1. Check inputs. Raise error if not correct
        assert height % 64 == 0 and width % 64 == 0
        if original_height != height or original_width != width:
            need_resize = True
        else:
            need_resize = False

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        self._guidance_scale = guidance_scale

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            prior_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
        # 3. Produce per-frame geometric priors
        pred_disparity, pred_valid_mask, pred_point_map, pred_intrinsic_map = self.produce_priors(
            prior_model,
            video.to(device=device, dtype=torch.float32),
            chunk_size=decode_chunk_size,
        )  # T,H,W  T,H,W  T,3,H,W  T,4,H,W

        if need_resize:
            pred_disparity = F.interpolate(pred_disparity.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_valid_mask = F.interpolate(pred_valid_mask.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_point_map = F.interpolate(pred_point_map, (height, width), mode='bilinear', align_corners=False)
            pred_intrinsic_map = F.interpolate(pred_intrinsic_map, (height, width), mode='bilinear', align_corners=False)

        if track_time:
            prior_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(prior_event)
            print(f"Elapsed time for computing per-frame prior: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        # 4. Encode input video
        if need_resize:
            video = F.interpolate(video, (height, width), mode="bicubic", align_corners=False, antialias=True).clamp(0, 1)
        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0  # [0, 1] -> [-1, 1], in [t, c, h, w]

        video_embeddings = self.encode_video(video, chunk_size=decode_chunk_size).unsqueeze(0)

        prior_latents = self.encode_point_map(
            point_map_vae,
            pred_disparity,
            pred_valid_mask,
            pred_point_map,
            pred_intrinsic_map,
            chunk_size=decode_chunk_size,
        ).unsqueeze(0).to(video_embeddings.dtype)  # 1,T,C,H,W

        # 5. Encode video frames using VAE
        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(0).to(video_embeddings.dtype)  # [1, t, c, h, w]

        torch.cuda.empty_cache()

        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = prior_event.elapsed_time(encode_event)
            print(f"Elapsed time for encoding priors and frames: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 6. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        )  # [1 or 2, 3]
        added_time_ids = added_time_ids.to(device)

        # 7. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # 8. Prepare latent variables
        # num_channels_latents = self.unet.config.in_channels - prior_latents.shape[1]
        num_channels_latents = 8
        latents_init = self.prepare_latents(
            batch_size,
            window_size,
            num_channels_latents,
            height,
            width,
            video_embeddings.dtype,
            device,
            generator,
            latents,
        )  # [1, t, c, h, w]

        latents_all = None

        idx_start = 0
        if overlap > 0:
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None

        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            self.scheduler.set_timesteps(num_inference_steps, device=device)
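            # Note on the sliding-window scheme: the video is denoised in windows of
            # `window_size` frames that share `overlap` frames with the previous window.
            # At the first denoising step the overlapping latents are re-noised from the
            # already denoised tail of `latents_all`, and after the window finishes they
            # are linearly blended with that tail using `weights` (computed above), so the
            # stitched sequence stays temporally consistent across window boundaries.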
            # 9. Denoising loop
            # latents_init = latents_init.flip(1)
            latents = latents_init[:, : idx_end - idx_start].clone()
            latents_init = torch.cat(
                [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
            )

            video_latents_current = video_latents[:, idx_start:idx_end]
            prior_latents_current = prior_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    if latents_all is not None and i == 0:
                        latents[:, :overlap] = (
                            latents_all[:, -overlap:]
                            + latents[:, :overlap]
                            / self.scheduler.init_noise_sigma
                            * self.scheduler.sigmas[i]
                        )

                    latent_model_input = latents
                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    )  # [1 or 2, t, c, h, w]
                    latent_model_input = torch.cat(
                        [latent_model_input, video_latents_current, prior_latents_current],
                        dim=2,
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=video_embeddings_current,
                        added_time_ids=added_time_ids,
                        return_dict=False,
                    )[0]

                    # perform classifier-free guidance
                    if self.do_classifier_free_guidance:
                        latent_model_input = latents
                        latent_model_input = self.scheduler.scale_model_input(
                            latent_model_input, t
                        )
                        latent_model_input = torch.cat(
                            [
                                latent_model_input,
                                torch.zeros_like(latent_model_input),
                                torch.zeros_like(latent_model_input),
                            ],
                            dim=2,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=torch.zeros_like(
                                video_embeddings_current
                            ),
                            added_time_ids=added_time_ids,
                            return_dict=False,
                        )[0]
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )

                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(
                            self, i, t, callback_kwargs
                        )
                        latents = callback_outputs.pop("latents", latents)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()

            if latents_all is None:
                latents_all = latents.clone()
            else:
                if overlap > 0:
                    latents_all[:, -overlap:] = latents[
                        :, :overlap
                    ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)

            idx_start += stride

        latents_all = 1 / self.vae.config.scaling_factor * latents_all.squeeze(0).to(torch.float32)

        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoising latents: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        point_map, valid_mask = self.decode_point_map(
            point_map_vae,
            latents_all,
            chunk_size=decode_chunk_size,
            force_projection=force_projection,
            force_fixed_focal=force_fixed_focal,
            use_extract_interp=use_extract_interp,
            need_resize=need_resize,
            height=original_height,
            width=original_width,
        )

        if track_time:
            decode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = denoise_event.elapsed_time(decode_event)
            print(f"Elapsed time for decoding latents: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        self.maybe_free_model_hooks()

        # point_map: t,h,w,3  valid_mask: t,h,w
        return point_map, valid_mask
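
# Minimal usage sketch. This is an assumption-laden illustration, not part of the module:
# the checkpoint path and the construction of `point_map_vae` (the point-map VAE) and
# `prior_model` (the per-frame geometry prior exposing `forward_image`) are placeholders;
# the project's own inference scripts are the authoritative reference.
#
#   pipe = GeometryCrafterDiffPipeline.from_pretrained(
#       "<svd-based-geometrycrafter-checkpoint>", torch_dtype=torch.float16
#   ).to("cuda")
#   video = torch.rand(28, 3, 384, 640)  # [t, c, h, w], values in [0, 1]
#   point_map, valid_mask = pipe(
#       video,
#       point_map_vae=point_map_vae,  # placeholder: pretrained point-map VAE
#       prior_model=prior_model,      # placeholder: per-frame prior model
#       height=320,
#       width=640,
#       num_inference_steps=5,
#       window_size=14,
#       overlap=4,
#   )
#   # point_map: [t, h, w, 3] camera-space points, valid_mask: [t, h, w] boolean mask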