from typing import Optional, Tuple

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid


def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Latent grid size after VAE spatial downsampling and transformer patchification.
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    if patch_size_t is None:
        # CogVideoX 1.0: crop the RoPE grid to the region matching the base resolution.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # CogVideoX 1.5: frames are also patchified along time, so round the temporal
        # size up to a multiple of patch_size_t and build the embedding in "slice" mode.
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin
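

# A minimal usage sketch, not part of the original script: the resolution, frame
# count, and device handling below are illustrative assumptions. `num_frames` is
# taken here to be the number of latent frames the transformer sees (after the
# VAE's temporal compression), not the raw video frame count.
if __name__ == "__main__":
    freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
        height=480,
        width=720,
        num_frames=13,       # e.g. (49 - 1) // 4 + 1 latent frames for a 49-frame clip
        patch_size_t=None,   # None selects the CogVideoX 1.0 branch; set e.g. 2 for 1.5
        attention_head_dim=64,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    # The cos/sin pair is what gets passed to the CogVideoX transformer as its
    # rotary positional embedding (image_rotary_emb).
    print(freqs_cos.shape, freqs_sin.shape)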