import logging
import os
import subprocess
import uuid
from functools import lru_cache

import imageio
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from einops import rearrange, repeat
from tqdm import tqdm
from xfuser.core.distributed import (
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    get_sp_group,
)

logger = logging.getLogger(__name__)

VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

# Size buckets keyed by (approximately) first / second of the size pair,
# formatted with two decimals (e.g. 320 / 1216 ~= 0.26).
ASPECT_RATIO_627 = {
    '0.26': ([320, 1216], 1),
    '0.38': ([384, 1024], 1),
    '0.50': ([448, 896], 1),
    '0.67': ([512, 768], 1),
    '0.82': ([576, 704], 1),
    '1.00': ([640, 640], 1),
    '1.22': ([704, 576], 1),
    '1.50': ([768, 512], 1),
    '1.86': ([832, 448], 1),
    '2.00': ([896, 448], 1),
    '2.50': ([960, 384], 1),
    '2.83': ([1088, 384], 1),
    '3.60': ([1152, 320], 1),
    '3.80': ([1216, 320], 1),
    '4.00': ([1280, 320], 1),
}

ASPECT_RATIO_960 = {
    '0.22': ([448, 2048], 1),
    '0.29': ([512, 1792], 1),
    '0.36': ([576, 1600], 1),
    '0.45': ([640, 1408], 1),
    '0.55': ([704, 1280], 1),
    '0.63': ([768, 1216], 1),
    '0.76': ([832, 1088], 1),
    '0.88': ([896, 1024], 1),
    '1.00': ([960, 960], 1),
    '1.14': ([1024, 896], 1),
    '1.31': ([1088, 832], 1),
    '1.50': ([1152, 768], 1),
    '1.58': ([1216, 768], 1),
    '1.82': ([1280, 704], 1),
    '1.91': ([1344, 704], 1),
    '2.20': ([1408, 640], 1),
    '2.30': ([1472, 640], 1),
    '2.67': ([1536, 576], 1),
    '2.89': ([1664, 576], 1),
    '3.62': ([1856, 512], 1),
    '3.75': ([1920, 512], 1),
}


def torch_gc():
    """Release cached CUDA allocator blocks and reclaim memory freed via CUDA IPC."""
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def split_token_counts_and_frame_ids(T, token_frame, world_size, rank):
    """Describe which frames' tokens fall into ``rank``'s shard of an even token split.

    The ``T * token_frame`` tokens (``token_frame`` consecutive tokens per
    frame) are cut into ``world_size`` contiguous shards whose sizes differ by
    at most one; the first ``(T * token_frame) % world_size`` shards get the
    extra token.

    Args:
        T: number of frames.
        token_frame: number of tokens per frame.
        world_size: number of shards.
        rank: index of the shard to describe.

    Returns:
        ``(counts, frame_ids)``: for every frame with at least one token in this
        rank's shard (in increasing frame order), the number of that frame's
        tokens in the shard and the frame index.
    """
    total = T * token_frame
    base, remainder = divmod(total, world_size)
    split_sizes = [base + (1 if i < remainder else 0) for i in range(world_size)]
    start = sum(split_sizes[:rank])
    end = start + split_sizes[rank]

    counts_filtered = []
    frame_ids = []
    # Frame t owns tokens [t * token_frame, (t + 1) * token_frame); intersect that
    # range with this rank's [start, end) instead of visiting every token.
    for t in range(T):
        overlap = min(end, (t + 1) * token_frame) - max(start, t * token_frame)
        if overlap > 0:
            counts_filtered.append(overlap)
            frame_ids.append(t)
    return counts_filtered, frame_ids


def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
    """Linearly map ``column`` from ``source_range`` onto ``target_range``.

    Args:
        column: value(s) to rescale; any type supporting arithmetic.
        source_range: ``(min, max)`` of the input scale.
        target_range: ``(min, max)`` of the output scale.
        epsilon: added to the source span so a degenerate range
            (``min == max``) does not divide by zero.

    Returns:
        The rescaled value(s).
    """
    source_min, source_max = source_range
    new_min, new_max = target_range
    normalized = (column - source_min) / (source_max - source_min + epsilon)
    scaled = normalized * (new_max - new_min) + new_min
    return scaled


@torch.compile
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, mode='mean', attn_bias=None):
    """Per-class attention mass from each visual query token onto masked reference keys.

    Computes softmax(q k^T / sqrt(K)) (plus optional ``attn_bias``). For every
    mask in ``ref_target_masks`` it averages the attention over the masked key
    positions, then reduces over heads with ``mode``.

    Args:
        visual_q: query tensor, (B, x_seqlens, H, K) judging by the transposes below.
        ref_k: key tensor with the same layout; cast to ``visual_q``'s dtype/device.
        ref_target_masks: iterable of 1-D masks over the reference sequence,
            one per class.
        mode: head reduction, ``'mean'`` or ``'max'``.
        attn_bias: optional additive bias broadcast onto the (B, H, x, ref) logits.

    Returns:
        Tensor of shape (num_classes * B, x_seqlens): per-class maps
        concatenated along dim 0.

    Raises:
        ValueError: if ``mode`` is not ``'mean'`` or ``'max'``.
    """
    if mode not in ('mean', 'max'):
        raise ValueError(f"mode must be 'mean' or 'max', got {mode!r}")

    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    attn = visual_q @ ref_k.transpose(-2, -1)

    if attn_bias is not None:
        attn = attn + attn_bias

    x_ref_attn_map_source = attn.softmax(-1)  # B, H, x_seqlens, ref_seqlens

    x_ref_attn_maps = []
    ref_target_masks = ref_target_masks.to(visual_q.dtype)
    x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)

    for ref_target_mask in ref_target_masks:
        torch_gc()
        ref_target_mask = ref_target_mask[None, None, None, ...]
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        # NOTE(review): an all-zero mask divides by zero here (inf/nan output).
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum()  # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1)  # B, x_seqlens, H

        if mode == 'mean':
            x_ref_attnmap = x_ref_attnmap.mean(-1)  # B, x_seqlens
        else:
            # Tensor.max(dim) returns a (values, indices) namedtuple; keep values.
            x_ref_attnmap = x_ref_attnmap.max(-1).values  # B, x_seqlens

        x_ref_attn_maps.append(x_ref_attnmap)

    del attn
    del x_ref_attn_map_source
    torch_gc()

    return torch.concat(x_ref_attn_maps, dim=0)


def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2, enable_sp=False):
    """Head-chunked reference-attention maps for each target mask.

    Args:
        visual_q (torch.tensor): B M H K
        ref_k (torch.tensor): B M H K; only its first ``N_h * N_w`` sequence
            positions are used (after the sequence-parallel gather when
            ``enable_sp``).
        shape (tuple): (N_t, N_h, N_w); only N_h and N_w are read.
        ref_target_masks: [class_num, N_h * N_w]. Required despite the
            ``None`` default (kept for call compatibility).
        split_num: heads are processed in this many equal chunks; the per-chunk
            maps are averaged.
        enable_sp: all-gather ``ref_k`` along dim 1 over the sequence-parallel
            group first.

    Returns:
        Tensor of shape (class_num, M) in ``visual_q``'s dtype and device.
        The in-place accumulation assumes B == 1 (TODO confirm with callers).

    Raises:
        ValueError: if ``ref_target_masks`` is None.
    """
    if ref_target_masks is None:
        raise ValueError("ref_target_masks is required")

    _, N_h, N_w = shape
    if enable_sp:
        ref_k = get_sp_group().all_gather(ref_k, dim=1)

    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(class_num, seq_lens, device=visual_q.device, dtype=visual_q.dtype)

    # NOTE(review): if heads % split_num != 0 the trailing heads are ignored.
    split_chunk = heads // split_num
    for i in range(split_num):
        head_slice = slice(i * split_chunk, (i + 1) * split_chunk)
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
            visual_q[:, :, head_slice, :],
            ref_k[:, :, head_slice, :],
            ref_target_masks,
        )
        x_ref_attn_maps += x_ref_attn_maps_perhead

    return x_ref_attn_maps / split_num


def rotate_half(x):
    """Rotate adjacent pairs of the last dim: (x1, x2) -> (-x2, x1)."""
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


class RotaryPositionalEmbedding1D(nn.Module):
    """1D rotary positional embedding over interleaved pairs of the head dim (base 10000)."""

    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    # NOTE(review): lru_cache on a method keys on (self, pos_indices). Tensors hash
    # by identity, so only the *same* tensor object hits, an in-place-modified
    # pos_indices would return stale frequencies, and cached entries keep the
    # module and tensors alive until evicted.
    @lru_cache(maxsize=32)
    def precompute_freqs_cis_1d(self, pos_indices):
        """Return rotation angles of shape [seq, head_dim], each frequency repeated twice."""
        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
        freqs = freqs.to(pos_indices.device)
        freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
        """1D RoPE.

        Args:
            query (torch.tensor): [B, head, seq, head_dim]
            pos_indices (torch.tensor): [seq,]
        Returns:
            query with the same shape as input.
        """
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)

        # Rotate in float32, then cast back to the input dtype.
        x_ = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        x_ = (x_ * cos) + (rotate_half(x_) * sin)

        return x_.type_as(x)


def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5):
    """Write a video tensor to ``save_path + ".mp4"``, muxed with the first audio file.

    Args:
        gen_video_samples: tensor laid out as C T H W with values expected in
            [-1, 1] (mapped to [0, 255]; out-of-range values are clipped).
        save_path: output path without extension; ".mp4" is appended.
        vocal_audio_list: audio file paths; only the first is used. It is cut
            to the video duration (T / fps seconds).
        fps: frame rate of the written video.
        quality: imageio quality setting for the intermediate video.

    Intermediate ``-temp.mp4`` / ``-cropaudio.wav`` files are removed even if a
    step fails.

    Raises:
        subprocess.CalledProcessError: if an ffmpeg invocation fails.
    """

    def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
        # Context manager guarantees the writer is closed even on error.
        with imageio.get_writer(
            save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
        ) as writer:
            for frame in tqdm(frames, desc="Saving video"):
                writer.append_data(np.array(frame))

    save_path_tmp = save_path + "-temp.mp4"
    save_path_crop_audio = save_path + "-cropaudio.wav"

    try:
        # Cast to float32 first: numpy has no bfloat16, so .numpy() fails on bf16 tensors.
        video_audio = (gen_video_samples.detach().float() + 1) / 2  # C T H W
        video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy()  # T H W C
        video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8)
        save_video(video_audio, save_path_tmp, fps=fps, quality=quality)

        # crop audio according to video length
        _, T, _, _ = gen_video_samples.shape
        duration = T / fps
        crop_command = [
            "ffmpeg",
            "-y",  # overwrite a stale crop file instead of prompting / failing
            "-i",
            vocal_audio_list[0],
            "-t",
            f'{duration}',
            save_path_crop_audio,
        ]
        subprocess.run(crop_command, check=True)

        # generate video with audio
        save_path = save_path + ".mp4"
        mux_command = [
            "ffmpeg",
            "-y",
            "-i",
            save_path_tmp,
            "-i",
            save_path_crop_audio,
            "-c:v",
            "libx264",
            "-c:a",
            "aac",
            "-shortest",
            save_path,
        ]
        subprocess.run(mux_command, check=True)
    finally:
        for tmp_path in (save_path_tmp, save_path_crop_audio):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


class MomentumBuffer:
    """Running accumulator: ``running_average = update + momentum * running_average``."""

    def __init__(self, momentum: float):
        self.momentum = momentum
        # Starts as the scalar 0 and becomes a tensor after the first update.
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into the running value."""
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def project(
    v0: torch.Tensor,  # [B, C, T, H, W]
    v1: torch.Tensor,  # [B, C, T, H, W]
):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    Computed in float64 over the last four dims, per leading index; both parts
    are returned in ``v0``'s original dtype as ``(parallel, orthogonal)``.
    """
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)


def adaptive_projected_guidance(
    diff: torch.Tensor,  # [B, C, T, H, W]
    pred_cond: torch.Tensor,  # [B, C, T, H, W]
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 55,
):
    """Rescale and project a guidance difference onto/away from ``pred_cond``.

    Steps:
    1. If ``momentum_buffer`` is given, ``diff`` is replaced by its running value.
    2. If ``norm_threshold > 0``, ``diff`` is scaled down so its L2 norm over
       the last four dims is at most ``norm_threshold``.
    3. The result is ``orthogonal + eta * parallel`` with respect to
       ``pred_cond``.
    """
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True)
        # Lazy %-args: the tensor is only formatted (forcing a device sync) when
        # DEBUG logging is enabled, unlike the previous unconditional print.
        logger.debug("diff_norm: %s", diff_norm)
        scale_factor = (norm_threshold / diff_norm).clamp(max=1.0)
        diff = diff * scale_factor

    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    normalized_update = diff_orthogonal + eta * diff_parallel
    return normalized_update