import inspect
import os
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torchaudio
import torchvision.io
import torchvision.transforms as transforms
from einops import rearrange, repeat
from transformers import ImageProcessingMixin

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import KarrasDiffusionSchedulers, PNDMScheduler
from diffusers.utils import logging

from unet import AudioUNet3DConditionModel
from audio_encoder import ImageBindSegmaskAudioEncoder
from imagebind.data import waveform2melspec

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def waveform_to_melspectrogram(
    waveform: Union[np.ndarray, torch.Tensor],
    num_mel_bins=128,
    target_length=204,
    sample_rate=16000,
    clip_duration=2.,
    mean=-4.268,
    std=9.138
):
    if isinstance(waveform, np.ndarray):
        waveform = torch.from_numpy(waveform)

    audio_length = waveform.shape[1]
    audio_target_length = int(clip_duration * sample_rate)

    # take the centered clip of `clip_duration` seconds
    audio_start_idx = 0
    if audio_length > audio_target_length:
        audio_start_idx = (audio_length - audio_target_length) // 2
    audio_end_idx = audio_start_idx + audio_target_length
    waveform_clip = waveform[:, audio_start_idx:audio_end_idx]

    waveform_melspec = waveform2melspec(
        waveform_clip, sample_rate, num_mel_bins, target_length
    )  # (1, n_mel, n_frame)

    normalize = transforms.Normalize(mean=mean, std=std)
    audio_clip = normalize(waveform_melspec)

    return audio_clip  # (1, freq, time)


class AudioMelspectrogramExtractor(ImageProcessingMixin):

    def __init__(
        self,
        num_mel_bins=128,
        target_length=204,
        sample_rate=16000,
        clip_duration=2,
        mean=-4.268,
        std=9.138
    ):
        super().__init__()
        self.num_mel_bins = num_mel_bins
        self.target_length = target_length
        self.sample_rate = sample_rate
        self.clip_duration = clip_duration
        self.mean = mean
        self.std = std

    @property
    def max_length_s(self) -> int:
        return self.clip_duration

    @property
    def sampling_rate(self) -> int:
        return self.sample_rate

    def __call__(
        self,
        waveforms: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor]
        ]
    ):
        if isinstance(waveforms, (np.ndarray, torch.Tensor)) and waveforms.ndim == 2:
            waveforms = [waveforms, ]

        features = []
        for waveform in waveforms:
            feature = waveform_to_melspectrogram(
                waveform=waveform,
                num_mel_bins=self.num_mel_bins,
                target_length=self.target_length,
                sample_rate=self.sample_rate,
                clip_duration=self.clip_duration,
                mean=self.mean,
                std=self.std
            )
            features.append(feature)
        features = torch.stack(features, dim=0)

        return features  # (b c n t)
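

# Usage sketch for the extractor above (illustrative only): a 2 s mono clip sampled at
# 16 kHz maps to a normalized log-mel spectrogram batch of shape (1, 1, 128, 204).
# The random waveform below is just a stand-in for real audio.
def _example_extract_melspectrogram() -> torch.Tensor:
    extractor = AudioMelspectrogramExtractor()
    waveform = torch.randn(1, 32000)  # (channels, samples): 2 s at 16 kHz
    features = extractor([waveform])  # (b, 1, num_mel_bins, target_length) = (1, 1, 128, 204)
    return features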


class AudioCondAnimationPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    """
    Pipeline for audio-conditioned image animation: given a still image, an audio clip and a
    precomputed text encoding, it generates a short video whose first frame is the input image.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic
    methods the library implements for all pipelines (such as downloading or saving, running on a
    particular device, etc.)

    Args:
        unet ([`AudioUNet3DConditionModel`]):
            Conditional 3D U-Net architecture to denoise the encoded video latents.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `unet` to denoise the encoded video latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent
            representations.
        audio_encoder ([`ImageBindSegmaskAudioEncoder`]):
            Audio encoder that turns mel-spectrograms into conditioning tokens and attention masks.
        null_text_encodings_path (`str`, *optional*):
            Path to a precomputed text encoding of the empty prompt, used for text classifier-free guidance.
    """
    unet: AudioUNet3DConditionModel
    scheduler: KarrasDiffusionSchedulers
    vae: AutoencoderKL
    audio_encoder: ImageBindSegmaskAudioEncoder

    def __init__(
        self,
        unet: AudioUNet3DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vae: AutoencoderKL,
        audio_encoder: ImageBindSegmaskAudioEncoder,
        null_text_encodings_path: str = ""
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            audio_encoder=audio_encoder
        )

        if null_text_encodings_path:
            self.null_text_encoding = torch.load(null_text_encodings_path).view(1, 77, 768)

        self.melspectrogram_shape = (128, 204)

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.audio_processor = AudioMelspectrogramExtractor()

    @torch.no_grad()
    def encode_text(
        self,
        text_encodings,
        device,
        dtype,
        do_text_classifier_free_guidance,
        do_audio_classifier_free_guidance,
    ):
        if isinstance(text_encodings, (List, Tuple)):
            text_encodings = torch.cat(text_encodings)
        text_encodings = text_encodings.to(dtype=dtype, device=device)
        batch_size = len(text_encodings)

        # get unconditional embeddings for classifier-free guidance
        if do_text_classifier_free_guidance:
            if not hasattr(self, "null_text_encoding"):
                # Fall back to encoding an empty prompt; this path requires `self.tokenizer`
                # and `self.text_encoder` to be attached to the pipeline.
                uncond_token = ""
                max_length = text_encodings.shape[1]
                uncond_input = self.tokenizer(
                    uncond_token,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                    attention_mask = uncond_input.attention_mask.to(device)
                else:
                    attention_mask = None

                uncond_text_encodings = self.text_encoder(
                    uncond_input.input_ids.to(device),
                    attention_mask=attention_mask,
                )
                uncond_text_encodings = uncond_text_encodings[0]
            else:
                uncond_text_encodings = self.null_text_encoding

            uncond_text_encodings = repeat(uncond_text_encodings, "1 n d -> b n d", b=batch_size).contiguous()
            uncond_text_encodings = uncond_text_encodings.to(dtype=dtype, device=device)

        if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
            text_encodings = torch.cat([uncond_text_encodings, text_encodings, text_encodings])
        elif do_text_classifier_free_guidance:  # only text cfg
            text_encodings = torch.cat([uncond_text_encodings, text_encodings])
        elif do_audio_classifier_free_guidance:  # only audio cfg
            text_encodings = torch.cat([text_encodings, text_encodings])

        return text_encodings

    @torch.no_grad()
    def encode_audio(
        self,
        audios: Union[List[np.ndarray], List[torch.Tensor]],
        video_length: int = 12,
        do_text_classifier_free_guidance: bool = False,
        do_audio_classifier_free_guidance: bool = False,
        device: torch.device = torch.device("cuda:0"),
        dtype: torch.dtype = torch.float32
    ):
        batch_size = len(audios)
        melspectrograms = self.audio_processor(audios).to(device=device, dtype=dtype)  # (b c n t)

        # audio_encodings: (b, n, c)
        # audio_masks: (b, s, n)
        _, audio_encodings, audio_masks = self.audio_encoder(
            melspectrograms, normalize=False, return_dict=False
        )
        audio_encodings = repeat(audio_encodings, "b n c -> b f n c", f=video_length)

        if do_audio_classifier_free_guidance:
            null_melspectrograms = torch.zeros(1, 1, *self.melspectrogram_shape).to(device=device, dtype=dtype)
            _, null_audio_encodings, null_audio_masks = self.audio_encoder(
                null_melspectrograms, normalize=False, return_dict=False
            )
            null_audio_encodings = repeat(null_audio_encodings, "1 n c -> b f n c", b=batch_size, f=video_length)
            # broadcast the null mask to the batch size so it matches `audio_masks` in the concatenations below
            null_audio_masks = repeat(null_audio_masks, "1 s n -> b s n", b=batch_size)

        if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
            audio_encodings = torch.cat([null_audio_encodings, null_audio_encodings, audio_encodings])
            audio_masks = torch.cat([null_audio_masks, null_audio_masks, audio_masks])
        elif do_text_classifier_free_guidance:  # only text cfg
            audio_encodings = torch.cat([audio_encodings, audio_encodings])
            audio_masks = torch.cat([audio_masks, audio_masks])
        elif do_audio_classifier_free_guidance:  # only audio cfg
            audio_encodings = torch.cat([null_audio_encodings, audio_encodings])
            audio_masks = torch.cat([null_audio_masks, audio_masks])

        return audio_encodings, audio_masks

    @torch.no_grad()
    def encode_latents(self, image: torch.Tensor):
        dtype = self.vae.dtype
        image = image.to(device=self.device, dtype=dtype)
        image_latents = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
        return image_latents

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    @torch.no_grad()
    def decode_latents(self, latents):
        dtype = next(self.vae.parameters()).dtype
        latents = latents.to(dtype=dtype)
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1).cpu().float()  # ((b t) c h w)
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_video_latents(
        self,
        image_latents: torch.Tensor,
        num_channels_latents: int,
        video_length: int = 12,
        height: int = 256,
        width: int = 256,
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.float32,
        generator: Optional[torch.Generator] = None,
    ):
        batch_size = len(image_latents)
        shape = (
            batch_size,
            num_channels_latents,
            video_length - 1,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor
        )

        image_latents = image_latents.unsqueeze(2)  # (b c 1 h w)
        rand_noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        noise_latents = torch.cat([image_latents, rand_noise], dim=2)

        # scale the initial noise by the standard deviation required by the scheduler
        noise_latents = noise_latents * self.scheduler.init_noise_sigma

        return noise_latents

    @torch.no_grad()
    def __call__(
        self,
        images: List[PIL.Image.Image],
        audios: Union[List[np.ndarray], List[torch.Tensor]],
        text_encodings: List[torch.Tensor],
        video_length: int = 12,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 20,
        audio_guidance_scale: float = 4.0,
        text_guidance_scale: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True
    ):
        # 0. Default height and width to unet
        device = self.device
        dtype = self.dtype

        batch_size = len(images)
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        do_text_classifier_free_guidance = (text_guidance_scale > 1.0)
        do_audio_classifier_free_guidance = (audio_guidance_scale > 1.0)

        # 1. Encode text into ((k b), f, n, d)
        text_encodings = self.encode_text(
            text_encodings=text_encodings,
            device=device,
            dtype=dtype,
            do_text_classifier_free_guidance=do_text_classifier_free_guidance,
            do_audio_classifier_free_guidance=do_audio_classifier_free_guidance
        )  # ((k b), n, d)
        text_encodings = repeat(text_encodings, "b n d -> b t n d", t=video_length).to(device=device, dtype=dtype)

        # 2. Encode audio
        # audio_encodings: ((k b), n, d)
        # audio_masks: ((k b), s, n)
        audio_encodings, audio_masks = self.encode_audio(
            audios, video_length, do_text_classifier_free_guidance, do_audio_classifier_free_guidance, device, dtype
        )

        # 3. Prepare image latent
        image = self.image_processor.preprocess(images)
        image_latents = self.encode_latents(image).to(device=device, dtype=dtype)  # (b c h w)

        # 4. Prepare unet noising video latents
        video_latents = self.prepare_video_latents(
            image_latents=image_latents,
            num_channels_latents=self.unet.config.in_channels,
            video_length=video_length,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
            generator=generator,
        )  # (b c f h w)

        # 5. Prepare timesteps and extra step kwargs
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta=0.0)

        # 6. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            latent_model_input = [video_latents]
            if do_text_classifier_free_guidance:
                latent_model_input.append(video_latents)
            if do_audio_classifier_free_guidance:
                latent_model_input.append(video_latents)
            latent_model_input = torch.cat(latent_model_input)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_encodings,
                audio_encoder_hidden_states=audio_encodings,
                audio_attention_mask=audio_masks
            ).sample

            # perform guidance
            if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
                noise_pred_uncond, noise_pred_text, noise_pred_text_audio = noise_pred.chunk(3)
                noise_pred = noise_pred_uncond + \
                             text_guidance_scale * (noise_pred_text - noise_pred_uncond) + \
                             audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)
            elif do_text_classifier_free_guidance:  # only text cfg
                noise_pred_audio, noise_pred_text_audio = noise_pred.chunk(2)
                noise_pred = noise_pred_audio + \
                             text_guidance_scale * (noise_pred_text_audio - noise_pred_audio)
            elif do_audio_classifier_free_guidance:  # only audio cfg
                noise_pred_text, noise_pred_text_audio = noise_pred.chunk(2)
                noise_pred = noise_pred_text + \
                             audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)

            # The first-frame latent always serves as the unchanged image condition
            video_latents[:, :, 1:, :, :] = self.scheduler.step(
                noise_pred[:, :, 1:, :, :], t, video_latents[:, :, 1:, :, :], **extra_step_kwargs
            ).prev_sample
            video_latents = video_latents.contiguous()

        # 7. Post-processing
        video_latents = rearrange(video_latents, "b c f h w -> (b f) c h w")
        videos = self.decode_latents(video_latents).detach().cpu()
        videos = rearrange(videos, "(b f) c h w -> b f c h w", f=video_length)  # value range [0, 1]

        if not return_dict:
            return videos

        return {"videos": videos}
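

# Assembly sketch (illustrative only): how the pipeline might be put together once the
# individual components exist. How `unet`, `vae`, `scheduler`, `audio_encoder` and the
# null-text-encoding file are created or loaded is project-specific and not shown in this
# file; the path below is a placeholder.
def _example_build_pipeline(
    unet: AudioUNet3DConditionModel,
    vae: AutoencoderKL,
    scheduler: KarrasDiffusionSchedulers,
    audio_encoder: ImageBindSegmaskAudioEncoder,
    null_text_encodings_path: str = "null_text_encodings.pt",  # hypothetical path
    device: torch.device = torch.device("cuda"),
) -> AudioCondAnimationPipeline:
    pipeline = AudioCondAnimationPipeline(
        unet=unet,
        scheduler=scheduler,
        vae=vae,
        audio_encoder=audio_encoder,
        null_text_encodings_path=null_text_encodings_path,
    )
    pipeline.to(device)
    return pipeline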


def load_and_transform_images_stable_diffusion(
    images: Union[List[np.ndarray], torch.Tensor, np.ndarray],
    size=512,
    flip=False,
    randcrop=False,
    normalize=True
):
    """
    @images: (List of) np.uint8 images of shape (h, w, 3),
             or tensor of shape (b, c, h, w) in [0., 1.0]
    """
    assert isinstance(images, (List, torch.Tensor, np.ndarray)), type(images)

    if isinstance(images, List):
        assert isinstance(images[0], np.ndarray)
        assert images[0].dtype == np.uint8
        assert images[0].shape[2] == 3

        # convert np images into torch float tensor
        images = torch.from_numpy(
            rearrange(np.stack(images, axis=0), "f h w c -> f c h w")
        ).float() / 255.
    elif isinstance(images, np.ndarray):
        assert images.dtype == np.uint8
        assert images.shape[3] == 3

        # convert np images into torch float tensor
        images = torch.from_numpy(
            rearrange(images, "f h w c -> f c h w")
        ).float() / 255.

    assert images.shape[1] == 3
    assert torch.all(images <= 1.0) and torch.all(images >= 0.0)

    h, w = images.shape[-2:]
    if isinstance(size, int):
        target_h, target_w = size, size
    else:
        target_h, target_w = size

    # first center-crop to the target aspect ratio
    target_aspect_ratio = float(target_h) / target_w
    curr_aspect_ratio = float(h) / w
    if target_aspect_ratio >= curr_aspect_ratio:  # trim w
        trimmed_w = int(h / target_aspect_ratio)
        images = images[:, :, :, (w - trimmed_w) // 2: (w - trimmed_w) // 2 + trimmed_w]
    else:  # trim h
        trimmed_h = int(w * target_aspect_ratio)
        images = images[:, :, (h - trimmed_h) // 2: (h - trimmed_h) // 2 + trimmed_h]

    transform_list = [
        transforms.Resize(
            size,
            interpolation=transforms.InterpolationMode.BILINEAR,
            antialias=True
        ),
    ]
    if randcrop:
        transform_list.append(transforms.RandomCrop(size))
    else:
        transform_list.append(transforms.CenterCrop(size))

    if flip:
        transform_list.append(transforms.RandomHorizontalFlip(p=1.0))

    if normalize:
        transform_list.append(transforms.Normalize([0.5], [0.5]))

    data_transform = transforms.Compose(transform_list)
    images = data_transform(images)

    return images


def load_image(image_path):
    image = PIL.Image.open(image_path).convert('RGB')

    width, height = image.size
    if width < height:
        new_width = 256
        new_height = int((256 / width) * height)
    else:
        new_height = 256
        new_width = int((256 / height) * width)

    # Rescale the image so that the short side is 256 pixels
    image = image.resize((new_width, new_height), PIL.Image.LANCZOS)

    # Crop a 256x256 square from the center
    left = (new_width - 256) / 2
    top = (new_height - 256) / 2
    right = (new_width + 256) / 2
    bottom = (new_height + 256) / 2
    image = image.crop((left, top, right, bottom))

    return image


def load_audio(audio_path):
    audio, audio_sr = torchaudio.load(audio_path)
    if audio.ndim == 1:
        audio = audio.unsqueeze(0)
    else:
        audio = audio.mean(dim=0).unsqueeze(0)  # downmix to mono

    audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
    audio = audio[:, :32000].contiguous().float()
    if audio.shape[1] < 32000:
        # pad to a full 2 s clip (32000 samples at 16 kHz)
        audio = torch.cat([audio, torch.ones(1, 32000 - audio.shape[1]).float()], dim=1)

    return audio.contiguous()
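

# Sketch of producing a category text encoding of shape (1, 77, 768) for `generate_videos`
# below. The `.view(1, 77, 768)` in the pipeline's `__init__` suggests a CLIP ViT-L/14-style
# text encoder; using `transformers` CLIP here is an assumption, since the encoder actually
# used to precompute the category encodings is not shown in this file.
def _example_category_text_encoding(category: str = "dog barking") -> torch.Tensor:
    from transformers import CLIPTextModel, CLIPTokenizer  # assumed encoder choice

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    tokens = tokenizer(
        category, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        text_encoding = text_encoder(tokens.input_ids)[0]  # (1, 77, 768) last hidden state
    return text_encoding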


@torch.no_grad()
def generate_videos(
    pipeline,
    image_path: str = '',
    audio_path: str = '',
    category_text_encoding: Optional[torch.Tensor] = None,
    image_size: Tuple[int, int] = (256, 256),
    video_fps: int = 6,
    video_num_frame: int = 12,
    audio_guidance_scale: float = 4.0,
    denoising_step: int = 20,
    text_guidance_scale: float = 1.0,
    seed: int = 0,
    save_path: str = "",
    device: torch.device = torch.device("cuda"),
):
    image = load_image(image_path)
    audio = load_audio(audio_path)

    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    generated_video = pipeline(
        images=[image],
        audios=[audio],
        text_encodings=[category_text_encoding],
        video_length=video_num_frame,
        height=image_size[0],
        width=image_size[1],
        num_inference_steps=denoising_step,
        audio_guidance_scale=audio_guidance_scale,
        text_guidance_scale=text_guidance_scale,
        generator=generator,
        return_dict=False
    )[0]  # (f c h w) in range [0, 1]

    generated_video = (generated_video.permute(0, 2, 3, 1).contiguous() * 255).byte()

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torchvision.io.write_video(
        filename=save_path,
        video_array=generated_video,
        fps=video_fps,
        audio_array=audio,
        audio_fps=16000,
        audio_codec="aac"
    )

    return
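

# End-to-end usage sketch (illustrative only): run `generate_videos` with a pipeline that has
# already been built, e.g. via `_example_build_pipeline` above. All file paths and the category
# text below are placeholders, not assets shipped with this file.
def _example_run(pipeline: AudioCondAnimationPipeline, device: torch.device = torch.device("cuda")) -> None:
    category_text_encoding = _example_category_text_encoding("dog barking").to(device)  # (1, 77, 768)
    generate_videos(
        pipeline,
        image_path="examples/dog.jpg",        # hypothetical path
        audio_path="examples/dog_bark.wav",   # hypothetical path
        category_text_encoding=category_text_encoding,
        image_size=(256, 256),
        video_fps=6,
        video_num_frame=12,
        audio_guidance_scale=4.0,
        denoising_step=20,
        text_guidance_scale=1.0,
        seed=0,
        save_path="outputs/dog_bark.mp4",     # hypothetical output path
        device=device,
    )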