import inspect
import os
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torchaudio
import torchvision.io
import torchvision.transforms as transforms
from einops import rearrange, repeat
from transformers import ImageProcessingMixin

from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.image_processor import VaeImageProcessor

from unet import AudioUNet3DConditionModel
from audio_encoder import ImageBindSegmaskAudioEncoder
from imagebind.data import waveform2melspec

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def waveform_to_melspectrogram(
        waveform: Union[np.ndarray, torch.Tensor],
        num_mel_bins=128,
        target_length=204,
        sample_rate=16000,
        clip_duration=2.,
        mean=-4.268,
        std=9.138
):
    if isinstance(waveform, np.ndarray):
        waveform = torch.from_numpy(waveform)

    audio_length = waveform.shape[1]
    audio_target_length = int(clip_duration * sample_rate)

    audio_start_idx = 0
    if audio_length > audio_target_length:
        audio_start_idx = (audio_length - audio_target_length) // 2
    audio_end_idx = audio_start_idx + audio_target_length
    waveform_clip = waveform[:, audio_start_idx:audio_end_idx]

    waveform_melspec = waveform2melspec(
        waveform_clip, sample_rate, num_mel_bins, target_length
    )  # (1, n_mel, n_frame)

    normalize = transforms.Normalize(mean=mean, std=std)
    audio_clip = normalize(waveform_melspec)

    return audio_clip  # (1, freq, time)
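
# Usage sketch (illustrative, not executed on import): a 2-second mono clip at 16 kHz,
# shaped (1, 32000), yields a normalized mel spectrogram of shape (1, 128, 204) with
# the default arguments above.
#
#   dummy_waveform = torch.zeros(1, 32000)          # hypothetical silent clip
#   melspec = waveform_to_melspectrogram(dummy_waveform)
#   assert melspec.shape == (1, 128, 204)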


class AudioMelspectrogramExtractor(ImageProcessingMixin):

    def __init__(
            self,
            num_mel_bins=128,
            target_length=204,
            sample_rate=16000,
            clip_duration=2,
            mean=-4.268,
            std=9.138
    ):
        super().__init__()
        self.num_mel_bins = num_mel_bins
        self.target_length = target_length
        self.sample_rate = sample_rate
        self.clip_duration = clip_duration
        self.mean = mean
        self.std = std

    def max_length_s(self) -> int:
        return self.clip_duration

    def sampling_rate(self) -> int:
        return self.sample_rate

    def __call__(
            self,
            waveforms: Union[
                np.ndarray,
                torch.Tensor,
                List[np.ndarray],
                List[torch.Tensor]
            ]
    ):
        if isinstance(waveforms, (np.ndarray, torch.Tensor)) and waveforms.ndim == 2:
            waveforms = [waveforms, ]

        features = []
        for waveform in waveforms:
            feature = waveform_to_melspectrogram(
                waveform=waveform,
                num_mel_bins=self.num_mel_bins,
                target_length=self.target_length,
                sample_rate=self.sample_rate,
                clip_duration=self.clip_duration,
                mean=self.mean,
                std=self.std
            )
            features.append(feature)
        features = torch.stack(features, dim=0)

        return features  # (b c n t)
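
# Usage sketch (illustrative): stacking two hypothetical 2-second clips gives a batch of
# mel spectrograms shaped (b, 1, 128, 204), i.e. (b c n t).
#
#   extractor = AudioMelspectrogramExtractor()
#   batch = extractor([torch.zeros(1, 32000), torch.zeros(1, 32000)])
#   assert batch.shape == (2, 1, 128, 204)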


class AudioCondAnimationPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    """
    Pipeline for audio-conditioned image animation: given a still image, an audio clip, and
    pre-computed text encodings, it generates a short video whose motion follows the audio.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet ([`AudioUNet3DConditionModel`]):
            Conditional 3D U-Net architecture to denoise the encoded video latents.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `unet` to denoise the encoded video latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        audio_encoder ([`ImageBindSegmaskAudioEncoder`]):
            Audio encoder that turns mel spectrograms into conditioning tokens and attention masks.
        null_text_encodings_path (`str`, *optional*):
            Path to a pre-computed null (empty-prompt) text encoding used for classifier-free guidance.
    """
    unet: AudioUNet3DConditionModel
    scheduler: KarrasDiffusionSchedulers
    vae: AutoencoderKL
    audio_encoder: ImageBindSegmaskAudioEncoder

    def __init__(
            self,
            unet: AudioUNet3DConditionModel,
            scheduler: KarrasDiffusionSchedulers,
            vae: AutoencoderKL,
            audio_encoder: ImageBindSegmaskAudioEncoder,
            null_text_encodings_path: str = ""
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            audio_encoder=audio_encoder
        )

        if null_text_encodings_path:
            self.null_text_encoding = torch.load(null_text_encodings_path).view(1, 77, 768)

        self.melspectrogram_shape = (128, 204)

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.audio_processor = AudioMelspectrogramExtractor()
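
    # `melspectrogram_shape` mirrors the AudioMelspectrogramExtractor defaults above
    # (128 mel bins x 204 time frames); it is only used to build the all-zero "null"
    # spectrogram for audio classifier-free guidance in `encode_audio`.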

    def encode_text(
            self,
            text_encodings,
            device,
            dtype,
            do_text_classifier_free_guidance,
            do_audio_classifier_free_guidance,
    ):
        if isinstance(text_encodings, (list, tuple)):
            text_encodings = torch.cat(text_encodings)
        text_encodings = text_encodings.to(dtype=dtype, device=device)
        batch_size = len(text_encodings)

        # get unconditional embeddings for classifier-free guidance
        if do_text_classifier_free_guidance:
            if not hasattr(self, "null_text_encoding"):
                uncond_token = ""

                max_length = text_encodings.shape[1]
                uncond_input = self.tokenizer(
                    uncond_token,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                    attention_mask = uncond_input.attention_mask.to(device)
                else:
                    attention_mask = None

                uncond_text_encodings = self.text_encoder(
                    uncond_input.input_ids.to(device),
                    attention_mask=attention_mask,
                )
                uncond_text_encodings = uncond_text_encodings[0]
            else:
                uncond_text_encodings = self.null_text_encoding

            uncond_text_encodings = repeat(uncond_text_encodings, "1 n d -> b n d", b=batch_size).contiguous()
            uncond_text_encodings = uncond_text_encodings.to(dtype=dtype, device=device)

        if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
            text_encodings = torch.cat([uncond_text_encodings, text_encodings, text_encodings])
        elif do_text_classifier_free_guidance:  # only text cfg
            text_encodings = torch.cat([uncond_text_encodings, text_encodings])
        elif do_audio_classifier_free_guidance:  # only audio cfg
            text_encodings = torch.cat([text_encodings, text_encodings])

        return text_encodings

    def encode_audio(
            self,
            audios: Union[List[np.ndarray], List[torch.Tensor]],
            video_length: int = 12,
            do_text_classifier_free_guidance: bool = False,
            do_audio_classifier_free_guidance: bool = False,
            device: torch.device = torch.device("cuda:0"),
            dtype: torch.dtype = torch.float32
    ):
        batch_size = len(audios)
        melspectrograms = self.audio_processor(audios).to(device=device, dtype=dtype)  # (b c n t)

        # audio_encodings: (b, n, c)
        # audio_masks: (b, s, n)
        _, audio_encodings, audio_masks = self.audio_encoder(
            melspectrograms, normalize=False, return_dict=False
        )
        audio_encodings = repeat(audio_encodings, "b n c -> b f n c", f=video_length)

        if do_audio_classifier_free_guidance:
            null_melspectrograms = torch.zeros(1, 1, *self.melspectrogram_shape).to(device=device, dtype=dtype)
            _, null_audio_encodings, null_audio_masks = self.audio_encoder(
                null_melspectrograms, normalize=False, return_dict=False
            )
            null_audio_encodings = repeat(null_audio_encodings, "1 n c -> b f n c", b=batch_size, f=video_length)

        if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
            audio_encodings = torch.cat([null_audio_encodings, null_audio_encodings, audio_encodings])
            audio_masks = torch.cat([null_audio_masks, null_audio_masks, audio_masks])
        elif do_text_classifier_free_guidance:  # only text cfg
            audio_encodings = torch.cat([audio_encodings, audio_encodings])
            audio_masks = torch.cat([audio_masks, audio_masks])
        elif do_audio_classifier_free_guidance:  # only audio cfg
            audio_encodings = torch.cat([null_audio_encodings, audio_encodings])
            audio_masks = torch.cat([null_audio_masks, audio_masks])

        return audio_encodings, audio_masks
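
    # Note on the classifier-free-guidance batch layout produced above: under dual CFG the
    # effective batch is ordered [unconditional, text-only, text+audio], so `encode_text`
    # emits [null_text, text, text] while `encode_audio` emits [null_audio, null_audio, audio].
    # The chunks are split back apart in the same order inside `__call__`.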

    def encode_latents(self, image: torch.Tensor):
        dtype = self.vae.dtype
        image = image.to(device=self.device, dtype=dtype)
        image_latents = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
        return image_latents

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        dtype = next(self.vae.parameters()).dtype
        latents = latents.to(dtype=dtype)
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1).cpu().float()  # ((b t) c h w)
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_video_latents(
            self,
            image_latents: torch.Tensor,
            num_channels_latents: int,
            video_length: int = 12,
            height: int = 256,
            width: int = 256,
            device: torch.device = torch.device("cuda"),
            dtype: torch.dtype = torch.float32,
            generator: Optional[torch.Generator] = None,
    ):
        batch_size = len(image_latents)
        shape = (
            batch_size,
            num_channels_latents,
            video_length - 1,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor
        )

        image_latents = image_latents.unsqueeze(2)  # (b c 1 h w)
        rand_noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        noise_latents = torch.cat([image_latents, rand_noise], dim=2)

        # scale the initial noise by the standard deviation required by the scheduler
        noise_latents = noise_latents * self.scheduler.init_noise_sigma

        return noise_latents
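
    # Layout note: the returned latents are (b, c, f, h, w). Index 0 along the frame axis holds
    # the encoded conditioning image and the remaining f-1 frames start from Gaussian noise;
    # both are scaled by the scheduler's `init_noise_sigma`. The denoising loop below only ever
    # updates frames 1..f-1, so the first frame acts as a fixed anchor throughout sampling.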

    def __call__(
            self,
            images: List[PIL.Image.Image],
            audios: Union[List[np.ndarray], List[torch.Tensor]],
            text_encodings: List[torch.Tensor],
            video_length: int = 12,
            height: int = 256,
            width: int = 256,
            num_inference_steps: int = 20,
            audio_guidance_scale: float = 4.0,
            text_guidance_scale: float = 1.0,
            generator: Optional[torch.Generator] = None,
            return_dict: bool = True
    ):
        # 0. Default height and width to unet
        device = self.device
        dtype = self.dtype

        batch_size = len(images)
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        do_text_classifier_free_guidance = (text_guidance_scale > 1.0)
        do_audio_classifier_free_guidance = (audio_guidance_scale > 1.0)

        # 1. Encode text into ((k b), n, d), then repeat per frame
        text_encodings = self.encode_text(
            text_encodings=text_encodings,
            device=device,
            dtype=dtype,
            do_text_classifier_free_guidance=do_text_classifier_free_guidance,
            do_audio_classifier_free_guidance=do_audio_classifier_free_guidance
        )  # ((k b), n, d)
        text_encodings = repeat(text_encodings, "b n d -> b t n d", t=video_length).to(device=device, dtype=dtype)

        # 2. Encode audio
        # audio_encodings: ((k b), f, n, d)
        # audio_masks: ((k b), s, n)
        audio_encodings, audio_masks = self.encode_audio(
            audios, video_length, do_text_classifier_free_guidance, do_audio_classifier_free_guidance, device, dtype
        )

        # 3. Prepare image latent
        image = self.image_processor.preprocess(images)
        image_latents = self.encode_latents(image).to(device=device, dtype=dtype)  # (b c h w)

        # 4. Prepare unet noising video latents
        video_latents = self.prepare_video_latents(
            image_latents=image_latents,
            num_channels_latents=self.unet.config.in_channels,
            video_length=video_length,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
            generator=generator,
        )  # (b c f h w)

        # 5. Prepare timesteps and extra step kwargs
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta=0.0)

        # 6. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            latent_model_input = [video_latents]
            if do_text_classifier_free_guidance:
                latent_model_input.append(video_latents)
            if do_audio_classifier_free_guidance:
                latent_model_input.append(video_latents)
            latent_model_input = torch.cat(latent_model_input)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_encodings,
                audio_encoder_hidden_states=audio_encodings,
                audio_attention_mask=audio_masks
            ).sample

            # perform guidance
            if do_text_classifier_free_guidance and do_audio_classifier_free_guidance:  # dual cfg
                noise_pred_uncond, noise_pred_text, noise_pred_text_audio = noise_pred.chunk(3)
                noise_pred = (
                    noise_pred_uncond
                    + text_guidance_scale * (noise_pred_text - noise_pred_uncond)
                    + audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)
                )
            elif do_text_classifier_free_guidance:  # only text cfg
                noise_pred_audio, noise_pred_text_audio = noise_pred.chunk(2)
                noise_pred = (
                    noise_pred_audio
                    + text_guidance_scale * (noise_pred_text_audio - noise_pred_audio)
                )
            elif do_audio_classifier_free_guidance:  # only audio cfg
                noise_pred_text, noise_pred_text_audio = noise_pred.chunk(2)
                noise_pred = (
                    noise_pred_text
                    + audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)
                )

            # The first-frame latent always serves as the unchanged image condition
            video_latents[:, :, 1:, :, :] = self.scheduler.step(
                noise_pred[:, :, 1:, :, :], t, video_latents[:, :, 1:, :, :], **extra_step_kwargs
            ).prev_sample
            video_latents = video_latents.contiguous()

        # 7. Post-processing
        video_latents = rearrange(video_latents, "b c f h w -> (b f) c h w")
        videos = self.decode_latents(video_latents).detach().cpu()
        videos = rearrange(videos, "(b f) c h w -> b f c h w", f=video_length)  # value range [0, 1]

        if not return_dict:
            return videos

        return {"videos": videos}


def load_and_transform_images_stable_diffusion(
        images: Union[List[np.ndarray], torch.Tensor, np.ndarray],
        size=512,
        flip=False,
        randcrop=False,
        normalize=True
):
    """
    @images: (List of) np.uint8 images of shape (h, w, 3)
            or tensor of shape (b, c, h, w) in [0., 1.0]
    """
    assert isinstance(images, (list, torch.Tensor, np.ndarray)), type(images)
    if isinstance(images, list):
        assert isinstance(images[0], np.ndarray)
        assert images[0].dtype == np.uint8
        assert images[0].shape[2] == 3

        # convert np images into torch float tensor
        images = torch.from_numpy(
            rearrange(np.stack(images, axis=0), "f h w c -> f c h w")
        ).float() / 255.
    elif isinstance(images, np.ndarray):
        assert images.dtype == np.uint8
        assert images.shape[3] == 3

        # convert np images into torch float tensor
        images = torch.from_numpy(
            rearrange(images, "f h w c -> f c h w")
        ).float() / 255.

    assert images.shape[1] == 3
    assert torch.all(images <= 1.0) and torch.all(images >= 0.0)

    h, w = images.shape[-2:]
    if isinstance(size, int):
        target_h, target_w = size, size
    else:
        target_h, target_w = size

    # first crop the image to the target aspect ratio
    target_aspect_ratio = float(target_h) / target_w
    curr_aspect_ratio = float(h) / w
    if target_aspect_ratio >= curr_aspect_ratio:  # trim w
        trimmed_w = int(h / target_aspect_ratio)
        images = images[:, :, :, (w - trimmed_w) // 2: (w - trimmed_w) // 2 + trimmed_w]
    else:  # trim h
        trimmed_h = int(w * target_aspect_ratio)
        images = images[:, :, (h - trimmed_h) // 2: (h - trimmed_h) // 2 + trimmed_h]

    transform_list = [
        transforms.Resize(
            size,
            interpolation=transforms.InterpolationMode.BILINEAR,
            antialias=True
        ),
    ]
    if randcrop:
        transform_list.append(transforms.RandomCrop(size))
    else:
        transform_list.append(transforms.CenterCrop(size))
    if flip:
        transform_list.append(transforms.RandomHorizontalFlip(p=1.0))
    if normalize:
        transform_list.append(transforms.Normalize([0.5], [0.5]))
    data_transform = transforms.Compose(transform_list)

    images = data_transform(images)
    return images
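
# Usage sketch (illustrative): three hypothetical uint8 RGB frames of shape (720, 1280, 3)
# are center-cropped to the target aspect ratio, resized, and normalized to [-1, 1]:
#
#   frames = [np.zeros((720, 1280, 3), dtype=np.uint8) for _ in range(3)]
#   batch = load_and_transform_images_stable_diffusion(frames, size=512)
#   assert batch.shape == (3, 3, 512, 512)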


def load_image(image_path):
    image = PIL.Image.open(image_path).convert('RGB')

    width, height = image.size

    if width < height:
        new_width = 256
        new_height = int((256 / width) * height)
    else:
        new_height = 256
        new_width = int((256 / height) * width)

    # Rescale the image
    image = image.resize((new_width, new_height), PIL.Image.LANCZOS)

    # Crop a 256x256 square from the center
    left = (new_width - 256) / 2
    top = (new_height - 256) / 2
    right = (new_width + 256) / 2
    bottom = (new_height + 256) / 2
    image = image.crop((left, top, right, bottom))

    return image


def load_audio(audio_path):
    audio, audio_sr = torchaudio.load(audio_path)

    # mix multi-channel audio down to mono, keeping a (1, num_samples) layout
    if audio.ndim == 1:
        audio = audio.unsqueeze(0)
    else:
        audio = audio.mean(dim=0).unsqueeze(0)

    # resample to 16 kHz and truncate/pad to exactly 2 s (32000 samples)
    audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
    audio = audio[:, :32000].contiguous().float()
    if audio.shape[1] < 32000:
        audio = torch.cat([audio, torch.ones(1, 32000 - audio.shape[1]).float()], dim=1)

    return audio.contiguous()
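
# Usage sketch (illustrative, paths are placeholders): both helpers return inputs already in
# the shape the pipeline expects.
#
#   image = load_image("example.jpg")   # 256x256 RGB PIL image
#   audio = load_audio("example.wav")   # (1, 32000) float tensor, 2 s at 16 kHz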


def generate_videos(
        pipeline,
        image_path: str = '',
        audio_path: str = '',
        category_text_encoding: Optional[torch.Tensor] = None,
        image_size: Tuple[int, int] = (256, 256),
        video_fps: int = 6,
        video_num_frame: int = 12,
        audio_guidance_scale: float = 4.0,
        denoising_step: int = 20,
        text_guidance_scale: float = 1.0,
        seed: int = 0,
        save_path: str = "",
        device: torch.device = torch.device("cuda"),
):
    image = load_image(image_path)
    audio = load_audio(audio_path)

    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    generated_video = pipeline(
        images=[image],
        audios=[audio],
        text_encodings=[category_text_encoding],
        video_length=video_num_frame,
        height=image_size[0],
        width=image_size[1],
        num_inference_steps=denoising_step,
        audio_guidance_scale=audio_guidance_scale,
        text_guidance_scale=text_guidance_scale,
        generator=generator,
        return_dict=False
    )[0]  # (f c h w) in range [0, 1]

    generated_video = (generated_video.permute(0, 2, 3, 1).contiguous() * 255).byte()

    # guard against bare filenames, where dirname would be ""
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    torchvision.io.write_video(
        filename=save_path,
        video_array=generated_video,
        fps=video_fps,
        audio_array=audio,
        audio_fps=16000,
        audio_codec="aac"
    )
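

# Minimal end-to-end sketch, guarded so it only runs when this file is executed directly.
# The checkpoint directory, text-encoding file, and media paths below are placeholders
# (assumptions for illustration), not paths shipped with this repository.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Assumes a diffusers-style checkpoint whose model_index.json registers the unet, vae,
    # scheduler, and audio_encoder components expected by AudioCondAnimationPipeline.
    pipeline = AudioCondAnimationPipeline.from_pretrained("path/to/checkpoint")
    pipeline.to(device)

    # Assumes a pre-computed (1, 77, 768) text encoding saved with torch.save.
    category_text_encoding = torch.load("path/to/category_text_encoding.pt").view(1, 77, 768)

    with torch.no_grad():
        generate_videos(
            pipeline,
            image_path="path/to/image.jpg",
            audio_path="path/to/audio.wav",
            category_text_encoding=category_text_encoding,
            image_size=(256, 256),
            video_fps=6,
            video_num_frame=12,
            audio_guidance_scale=4.0,
            text_guidance_scale=1.0,
            denoising_step=20,
            seed=0,
            save_path="outputs/generated.mp4",
            device=device,
        )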