from typing import Optional, Tuple

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build 3D rotary positional embeddings (RoPE) for a CogVideoX transformer.

    Pixel-space ``height``/``width`` are reduced to a latent patch grid by
    dividing out the VAE spatial downscale factor and the transformer patch
    size, then the cos/sin frequency tables are produced by diffusers'
    ``get_3d_rotary_pos_embed``.

    Args:
        height: Video height in pixels.
        width: Video width in pixels.
        num_frames: Number of latent frames to embed.
        vae_scale_factor_spatial: Spatial downscaling factor of the VAE.
        patch_size: Spatial patch size of the transformer.
        patch_size_t: Temporal patch size. ``None`` selects the CogVideoX 1.0
            code path (grid-crop RoPE); an integer selects the CogVideoX 1.5
            path (sliced RoPE with temporal patching).
        attention_head_dim: Embedding dimension per attention head.
        device: Device to move the resulting tensors to.
        base_height: Training-time base height in pixels (default 480).
        base_width: Training-time base width in pixels (default 720).

    Returns:
        Tuple ``(freqs_cos, freqs_sin)`` of RoPE frequency tensors on
        ``device``.
    """
    # Latent patch grid for the requested resolution and for the base
    # (training) resolution.
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    if patch_size_t is None:
        # CogVideoX 1.0: resize-and-crop the base grid to the target grid,
        # then build RoPE over the cropped region.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # CogVideoX 1.5: frames are patched temporally; round the frame count
        # up to a whole number of temporal patches and use slice-type RoPE
        # bounded by the base grid size.
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin