import math
from typing import Iterable, Tuple, Union
import re
import torch
from einops import rearrange, repeat
from torch import Tensor, nn
from comfy.ldm.modules.attention import FeedForward, SpatialTransformer
from comfy.model_patcher import ModelPatcher
from comfy.ldm.modules.diffusionmodules import openaimodel
from comfy.controlnet import broadcast_image_to
from comfy.utils import repeat_to_batch_size
import comfy.ops
import comfy.model_management
from .context import ContextFuseMethod, ContextOptions, get_context_weights, get_context_windows
from .utils_motion import CrossAttentionMM, MotionCompatibilityError, extend_to_batch_size, prepare_mask_batch
from .utils_model import BetaSchedules, ModelTypeSD
from .logger import logger
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
class AnimateDiffFormat:
ANIMATEDIFF = "AnimateDiff"
HOTSHOTXL = "HotshotXL"
ANIMATELCM = "AnimateLCM"
class AnimateDiffVersion:
V1 = "v1"
V2 = "v2"
V3 = "v3"
class AnimateDiffInfo:
def __init__(self, sd_type: str, mm_format: str, mm_version: str, mm_name: str):
self.sd_type = sd_type
self.mm_format = mm_format
self.mm_version = mm_version
self.mm_name = mm_name
def get_string(self):
return f"{self.mm_name}:{self.mm_version}:{self.mm_format}:{self.sd_type}"
def is_hotshotxl(mm_state_dict: dict[str, Tensor]) -> bool:
# use pos_encoder naming to determine if hotshotxl model
for key in mm_state_dict.keys():
if key.endswith("pos_encoder.positional_encoding"):
return True
return False
def is_animatelcm(mm_state_dict: dict[str, Tensor]) -> bool:
# use lack of ANY pos_encoder keys to determine if animatelcm model
for key in mm_state_dict.keys():
if "pos_encoder" in key:
return False
return True
def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int:
# find the largest down_blocks index present in the state dict
biggest_block = 0
for key in mm_state_dict.keys():
if "down_blocks" in key:
try:
block_int = key.split(".")[1]
block_num = int(block_int)
if block_num > biggest_block:
biggest_block = block_num
except ValueError:
pass
return biggest_block
def has_mid_block(mm_state_dict: dict[str, Tensor]):
# check if keys contain mid_block
for key in mm_state_dict.keys():
if key.startswith("mid_block."):
return True
return False
def get_position_encoding_max_len(mm_state_dict: dict[str, Tensor], mm_name: str, mm_format: str) -> Union[int, None]:
# use pos_encoder.pe entries to determine max length - [1, {max_length}, {320|640|1280}]
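# e.g. a pe tensor of shape [1, 24, 320] yields a max length of 24 (24 is typical for AD v1 modules - illustrative)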
for key in mm_state_dict.keys():
if key.endswith("pos_encoder.pe"):
return mm_state_dict[key].size(1) # get middle dim
# AnimateLCM models have no pos_encoder entries; assume a max length of 64
if mm_format == AnimateDiffFormat.ANIMATELCM:
return 64
raise MotionCompatibilityError(f"No pos_encoder.pe found in mm_state_dict - {mm_name} is not a valid AnimateDiff motion module!")
_regex_hotshotxl_module_num = re.compile(r'temporal_attentions\.(\d+)\.')
def find_hotshot_module_num(key: str) -> Union[int, None]:
found = _regex_hotshotxl_module_num.search(key)
if found:
return int(found.group(1))
return None
def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> Tuple[dict[str, Tensor], AnimateDiffInfo]:
# from pathlib import Path
# with open(Path(__file__).parent.parent.parent / f"keys_{mm_name}.txt", "w") as afile:
# for key, value in mm_state_dict.items():
# afile.write(f"{key}:\t{value.shape}\n")
# remove all non-temporal keys (in case model has extra stuff in it)
for key in list(mm_state_dict.keys()):
if "temporal" not in key:
del mm_state_dict[key]
# determine what SD model the motion module is intended for
sd_type: str = None
down_block_max = get_down_block_max(mm_state_dict)
if down_block_max == 3:
sd_type = ModelTypeSD.SD1_5
elif down_block_max == 2:
sd_type = ModelTypeSD.SDXL
else:
raise ValueError(f"'{mm_name}' is not a valid SD1.5 nor SDXL motion module - contained {down_block_max} downblocks.")
# determine the model's format
mm_format = AnimateDiffFormat.ANIMATEDIFF
if is_hotshotxl(mm_state_dict):
mm_format = AnimateDiffFormat.HOTSHOTXL
if is_animatelcm(mm_state_dict):
mm_format = AnimateDiffFormat.ANIMATELCM
# determine the model's version
mm_version = AnimateDiffVersion.V1
if has_mid_block(mm_state_dict):
mm_version = AnimateDiffVersion.V2
elif sd_type==ModelTypeSD.SD1_5 and get_position_encoding_max_len(mm_state_dict, mm_name, mm_format)==32:
mm_version = AnimateDiffVersion.V3
info = AnimateDiffInfo(sd_type=sd_type, mm_format=mm_format, mm_version=mm_version, mm_name=mm_name)
# convert to AnimateDiff format, if needed
if mm_format == AnimateDiffFormat.HOTSHOTXL:
# HotshotXL is an AD-based architecture applied to SDXL instead of SD1.5;
# renaming its keys to the AnimateDiff layout means no other code needs to change.
#
# reformat temporal_attentions:
# HSXL: temporal_attentions.#.
# AD: motion_modules.#.temporal_transformer.
# HSXL: pos_encoder.positional_encoding
# AD: pos_encoder.pe
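# e.g. (illustrative) '...temporal_attentions.0.pos_encoder.positional_encoding' becomes
#      '...motion_modules.0.temporal_transformer.pos_encoder.pe'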
for key in list(mm_state_dict.keys()):
module_num = find_hotshot_module_num(key)
if module_num is not None:
new_key = key.replace(f"temporal_attentions.{module_num}",
f"motion_modules.{module_num}.temporal_transformer", 1)
new_key = new_key.replace("pos_encoder.positional_encoding", "pos_encoder.pe")
mm_state_dict[new_key] = mm_state_dict[key]
del mm_state_dict[key]
# return adjusted mm_state_dict and info
return mm_state_dict, info
class BlockType:
UP = "up"
DOWN = "down"
MID = "mid"
class AnimateDiffModel(nn.Module):
def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo):
super().__init__()
self.mm_info = mm_info
self.down_blocks: Iterable[MotionModule] = nn.ModuleList([])
self.up_blocks: Iterable[MotionModule] = nn.ModuleList([])
self.mid_block: Union[MotionModule, None] = None
self.encoding_max_len = get_position_encoding_max_len(mm_state_dict, mm_info.mm_name, mm_info.mm_format)
self.has_position_encoding = self.encoding_max_len is not None
# determine ops to use (to support fp8 properly)
if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None:
ops = comfy.ops.disable_weight_init
else:
ops = comfy.ops.manual_cast
# SDXL has 3 up/down blocks, SD1.5 has 4 up/down blocks
if mm_info.sd_type == ModelTypeSD.SDXL:
layer_channels = (320, 640, 1280)
else:
layer_channels = (320, 640, 1280, 1280)
# fill out down/up blocks and middle block, if present
for c in layer_channels:
self.down_blocks.append(MotionModule(c, temporal_position_encoding=self.has_position_encoding,
temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.DOWN, ops=ops))
for c in reversed(layer_channels):
self.up_blocks.append(MotionModule(c, temporal_position_encoding=self.has_position_encoding,
temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.UP, ops=ops))
if has_mid_block(mm_state_dict):
self.mid_block = MotionModule(1280, temporal_position_encoding=self.has_position_encoding,
temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.MID, ops=ops)
self.AD_video_length: int = 24
def get_device_debug(self):
return self.down_blocks[0].motion_modules[0].temporal_transformer.proj_in.weight.device
def is_length_valid_for_encoding_max_len(self, length: int):
if self.encoding_max_len is None:
return True
return length <= self.encoding_max_len
def get_best_beta_schedule(self, log=False) -> str:
to_return = None
if self.mm_info.sd_type == ModelTypeSD.SD1_5:
if self.mm_info.mm_format == AnimateDiffFormat.ANIMATELCM:
to_return = BetaSchedules.LCM # while LCM_100 is the intended schedule, I find LCM to have much less flicker
else:
to_return = BetaSchedules.SQRT_LINEAR
elif self.mm_info.sd_type == ModelTypeSD.SDXL:
if self.mm_info.mm_format == AnimateDiffFormat.HOTSHOTXL:
to_return = BetaSchedules.LINEAR
else:
to_return = BetaSchedules.LINEAR_ADXL
if to_return is not None:
if log: logger.info(f"[Autoselect]: '{to_return}' beta_schedule for {self.mm_info.get_string()}")
else:
to_return = BetaSchedules.USE_EXISTING
if log: logger.info(f"[Autoselect]: could not find beta_schedule for {self.mm_info.get_string()}, defaulting to '{to_return}'")
return to_return
def cleanup(self):
pass
def inject(self, model: ModelPatcher):
unet: openaimodel.UNetModel = model.model.diffusion_model
# inject input (down) blocks
# SD15 mm contains 4 downblocks, each with 2 TemporalTransformers - 8 in total
# SDXL mm contains 3 downblocks, each with 2 TemporalTransformers - 6 in total
self._inject(unet.input_blocks, self.down_blocks)
# inject output (up) blocks
# SD15 mm contains 4 upblocks, each with 3 TemporalTransformers - 12 in total
# SDXL mm contains 3 upblocks, each with 3 TemporalTransformers - 9 in total
self._inject(unet.output_blocks, self.up_blocks)
# inject mid block, if needed (encapsulate in list to make structure compatible)
if self.mid_block is not None:
self._inject([unet.middle_block], [self.mid_block])
del unet
def _inject(self, unet_blocks: nn.ModuleList, mm_blocks: nn.ModuleList):
# Rules for injection:
# For each component list in a unet block:
# if SpatialTransformer exists in list, place next block after last occurrence
# elif ResBlock exists in list, place next block after first occurrence
# else don't place block
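# e.g. a typical SD1.5 input block is [ResBlock, SpatialTransformer], so the motion module
# lands right after the SpatialTransformer; conv-only or Downsample-only entries contain
# neither and are skipped (illustrative - the actual layout comes from the loaded UNet)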
injection_count = 0
unet_idx = 0
# details about blocks passed in
per_block = len(mm_blocks[0].motion_modules)
injection_goal = len(mm_blocks) * per_block
# only stop injecting when modules exhausted
while injection_count < injection_goal:
# figure out which VanillaTemporalModule from mm to inject
mm_blk_idx, mm_vtm_idx = injection_count // per_block, injection_count % per_block
# figure out layout of unet block components
st_idx = -1 # SpatialTransformer index
res_idx = -1 # first ResBlock index
# first, figure out indices of relevant blocks
for idx, component in enumerate(unet_blocks[unet_idx]):
if type(component) == SpatialTransformer:
st_idx = idx
elif type(component).__name__ == "ResBlock" and res_idx < 0:
res_idx = idx
# if SpatialTransformer exists, inject right after
if st_idx >= 0:
#logger.info(f"AD: injecting after ST({st_idx})")
unet_blocks[unet_idx].insert(st_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx])
injection_count += 1
# otherwise, if only ResBlock exists, inject right after
elif res_idx >= 0:
#logger.info(f"AD: injecting after Res({res_idx})")
unet_blocks[unet_idx].insert(res_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx])
injection_count += 1
# increment unet_idx
unet_idx += 1
def eject(self, model: ModelPatcher):
unet: openaimodel.UNetModel = model.model.diffusion_model
# remove from input blocks (downblocks)
self._eject(unet.input_blocks)
# remove from output blocks (upblocks)
self._eject(unet.output_blocks)
# remove from middle block (encapsulate in list to make compatible)
self._eject([unet.middle_block])
del unet
def _eject(self, unet_blocks: nn.ModuleList):
# eject all VanillaTemporalModule objects from all blocks
for block in unet_blocks:
idx_to_pop = []
for idx, component in enumerate(block):
if type(component) == VanillaTemporalModule:
idx_to_pop.append(idx)
# pop in reverse order so the remaining indices are not disturbed
for idx in sorted(idx_to_pop, reverse=True):
block.pop(idx)
def set_video_length(self, video_length: int, full_length: int):
self.AD_video_length = video_length
for block in self.down_blocks:
block.set_video_length(video_length, full_length)
for block in self.up_blocks:
block.set_video_length(video_length, full_length)
if self.mid_block is not None:
self.mid_block.set_video_length(video_length, full_length)
def set_scale(self, multival: Union[float, Tensor]):
if multival is None:
multival = 1.0
if type(multival) == Tensor:
self._set_scale_multiplier(1.0)
self._set_scale_mask(multival)
else:
self._set_scale_multiplier(multival)
self._set_scale_mask(None)
def set_effect(self, multival: Union[float, Tensor]):
for block in self.down_blocks:
block.set_effect(multival)
for block in self.up_blocks:
block.set_effect(multival)
if self.mid_block is not None:
self.mid_block.set_effect(multival)
def set_sub_idxs(self, sub_idxs: list[int]):
for block in self.down_blocks:
block.set_sub_idxs(sub_idxs)
for block in self.up_blocks:
block.set_sub_idxs(sub_idxs)
if self.mid_block is not None:
self.mid_block.set_sub_idxs(sub_idxs)
def set_view_options(self, view_options: ContextOptions):
for block in self.down_blocks:
block.set_view_options(view_options)
for block in self.up_blocks:
block.set_view_options(view_options)
if self.mid_block is not None:
self.mid_block.set_view_options(view_options)
def reset(self):
self._reset_sub_idxs()
self._reset_scale_multiplier()
self._reset_temp_vars()
def _set_scale_multiplier(self, multiplier: Union[float, None]):
for block in self.down_blocks:
block.set_scale_multiplier(multiplier)
for block in self.up_blocks:
block.set_scale_multiplier(multiplier)
if self.mid_block is not None:
self.mid_block.set_scale_multiplier(multiplier)
def _set_scale_mask(self, mask: Tensor):
for block in self.down_blocks:
block.set_scale_mask(mask)
for block in self.up_blocks:
block.set_scale_mask(mask)
if self.mid_block is not None:
self.mid_block.set_scale_mask(mask)
def _reset_temp_vars(self):
for block in self.down_blocks:
block.reset_temp_vars()
for block in self.up_blocks:
block.reset_temp_vars()
if self.mid_block is not None:
self.mid_block.reset_temp_vars()
def _reset_scale_multiplier(self):
self._set_scale_multiplier(None)
def _reset_sub_idxs(self):
self.set_sub_idxs(None)
class MotionModule(nn.Module):
def __init__(self,
in_channels,
temporal_position_encoding=True,
temporal_position_encoding_max_len=24,
block_type: str=BlockType.DOWN,
ops=comfy.ops.disable_weight_init
):
super().__init__()
if block_type == BlockType.MID:
# mid blocks contain only a single VanillaTemporalModule
self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList([get_motion_module(in_channels, temporal_position_encoding, temporal_position_encoding_max_len, ops=ops)])
else:
# down blocks contain two VanillaTemporalModules
self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList(
[
get_motion_module(in_channels, temporal_position_encoding, temporal_position_encoding_max_len, ops=ops),
get_motion_module(in_channels, temporal_position_encoding, temporal_position_encoding_max_len, ops=ops)
]
)
# up blocks contain one additional VanillaTemporalModule
if block_type == BlockType.UP:
self.motion_modules.append(get_motion_module(in_channels, temporal_position_encoding, temporal_position_encoding_max_len, ops=ops))
def set_video_length(self, video_length: int, full_length: int):
for motion_module in self.motion_modules:
motion_module.set_video_length(video_length, full_length)
def set_scale_multiplier(self, multiplier: Union[float, None]):
for motion_module in self.motion_modules:
motion_module.set_scale_multiplier(multiplier)
def set_scale_mask(self, mask: Tensor):
for motion_module in self.motion_modules:
motion_module.set_scale_mask(mask)
def set_effect(self, multival: Union[float, Tensor]):
for motion_module in self.motion_modules:
motion_module.set_effect(multival)
def set_sub_idxs(self, sub_idxs: list[int]):
for motion_module in self.motion_modules:
motion_module.set_sub_idxs(sub_idxs)
def set_view_options(self, view_options: ContextOptions):
for motion_module in self.motion_modules:
motion_module.set_view_options(view_options=view_options)
def reset_temp_vars(self):
for motion_module in self.motion_modules:
motion_module.reset_temp_vars()
def get_motion_module(in_channels, temporal_position_encoding, temporal_position_encoding_max_len, ops=comfy.ops.disable_weight_init):
return VanillaTemporalModule(in_channels=in_channels, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, ops=ops)
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads=8,
num_transformer_block=1,
attention_block_types=("Temporal_Self", "Temporal_Self"),
cross_frame_attention_mode=None,
temporal_position_encoding=True,
temporal_position_encoding_max_len=24,
temporal_attention_dim_div=1,
zero_initialize=True,
ops=comfy.ops.disable_weight_init,
):
super().__init__()
self.video_length = 16
self.full_length = 16
self.sub_idxs = None
self.view_options = None
self.effect = None
self.temp_effect_mask: Tensor = None
self.prev_input_tensor_batch = 0
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels
// num_attention_heads
// temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
ops=ops
)
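# zero-initializing proj_out makes the temporal transformer output all zeros at first;
# since TemporalTransformer3DModel.forward adds its output to the residual, a freshly
# initialized motion module starts out as an identity mapping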
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(
self.temporal_transformer.proj_out
)
def set_video_length(self, video_length: int, full_length: int):
self.video_length = video_length
self.full_length = full_length
self.temporal_transformer.set_video_length(video_length, full_length)
def set_scale_multiplier(self, multiplier: Union[float, None]):
self.temporal_transformer.set_scale_multiplier(multiplier)
def set_scale_mask(self, mask: Tensor):
self.temporal_transformer.set_scale_mask(mask)
def set_effect(self, multival: Union[float, Tensor]):
if type(multival) == Tensor:
self.effect = multival
elif multival is not None and math.isclose(multival, 1.0):
self.effect = None
else:
self.effect = multival
self.temp_effect_mask = None
def set_sub_idxs(self, sub_idxs: list[int]):
self.sub_idxs = sub_idxs
self.temporal_transformer.set_sub_idxs(sub_idxs)
def set_view_options(self, view_options: ContextOptions):
self.view_options = view_options
def reset_temp_vars(self):
self.set_effect(None)
self.set_view_options(None)
self.temporal_transformer.reset_temp_vars()
def get_effect_mask(self, input_tensor: Tensor):
batch, channel, height, width = input_tensor.shape
batched_number = batch // self.video_length
full_batched_idxs = list(range(self.video_length))*batched_number
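# e.g. (illustrative numbers) video_length=16 with positive+negative conds gives batch=32 and
# batched_number=2, so full_batched_idxs repeats [0..15] twice and each cond reuses the same
# per-frame effect mask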
# if there is a cached temp_effect_mask and it is valid for current input, return it
if batch == self.prev_input_tensor_batch and self.temp_effect_mask is not None:
if self.sub_idxs is not None:
return self.temp_effect_mask[self.sub_idxs*batched_number]
return self.temp_effect_mask[full_batched_idxs]
# clear any existing mask
del self.temp_effect_mask
self.temp_effect_mask = None
# recalculate temp mask
self.prev_input_tensor_batch = batch
# make sure mask matches expected dimensions
mask = prepare_mask_batch(self.effect, shape=(self.full_length, 1, height, width))
# make sure mask is as long as full_length - repeat the last entry if it is too short
self.temp_effect_mask = extend_to_batch_size(mask, self.full_length).to(
dtype=input_tensor.dtype, device=input_tensor.device)
# return finalized mask
if self.sub_idxs is not None:
return self.temp_effect_mask[self.sub_idxs*batched_number]
return self.temp_effect_mask[full_batched_idxs]
def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_mask=None):
if self.effect is None:
return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask, self.view_options)
# return weighted average of input_tensor and AD output
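# i.e. out = (1 - effect) * input_tensor + effect * AD(input_tensor), where effect is either
# a float or a per-frame/per-pixel mask (see get_effect_mask)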
if type(self.effect) != Tensor:
effect = self.effect
# do nothing if effect is 0
if math.isclose(effect, 0.0):
return input_tensor
else:
effect = self.get_effect_mask(input_tensor)
return input_tensor*(1.0-effect) + self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask, self.view_options)*effect
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types=(
"Temporal_Self",
"Temporal_Self",
),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
ops=comfy.ops.disable_weight_init,
):
super().__init__()
self.video_length = 16
self.full_length = 16
self.raw_scale_mask: Union[Tensor, None] = None
self.temp_scale_mask: Union[Tensor, None] = None
self.sub_idxs: Union[list[int], None] = None
self.prev_hidden_states_batch = 0
inner_dim = num_attention_heads * attention_head_dim
self.norm = ops.GroupNorm(
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
self.proj_in = ops.Linear(in_channels, inner_dim)
self.transformer_blocks: Iterable[TemporalTransformerBlock] = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
ops=ops,
)
for d in range(num_layers)
]
)
self.proj_out = ops.Linear(inner_dim, in_channels)
def set_video_length(self, video_length: int, full_length: int):
self.video_length = video_length
self.full_length = full_length
def set_scale_multiplier(self, multiplier: Union[float, None]):
for block in self.transformer_blocks:
block.set_scale_multiplier(multiplier)
def set_scale_mask(self, mask: Tensor):
self.raw_scale_mask = mask
self.temp_scale_mask = None
def set_sub_idxs(self, sub_idxs: list[int]):
self.sub_idxs = sub_idxs
for block in self.transformer_blocks:
block.set_sub_idxs(sub_idxs)
def reset_temp_vars(self):
del self.temp_scale_mask
self.temp_scale_mask = None
self.prev_hidden_states_batch = 0
def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]:
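# Build (and cache) a scale mask from raw_scale_mask, reshaped/repeated to match the attention
# layout used by VersatileAttention (roughly (batched_number*h*w, frames, 1)); returns None
# when no raw mask has been set.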
# if no raw mask, return None
if self.raw_scale_mask is None:
return None
shape = hidden_states.shape
batch, channel, height, width = shape
# if temp mask already calculated, return it
if self.temp_scale_mask is not None:
# check if hidden_states batch matches
if batch == self.prev_hidden_states_batch:
if self.sub_idxs is not None:
return self.temp_scale_mask[:, self.sub_idxs, :]
return self.temp_scale_mask
# if does not match, reset cached temp_scale_mask and recalculate it
del self.temp_scale_mask
self.temp_scale_mask = None
# otherwise, calculate temp mask
self.prev_hidden_states_batch = batch
mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width))
mask = repeat_to_batch_size(mask, self.full_length)
# if mask length does not match full_length, broadcast it to match
if self.full_length != mask.shape[0]:
mask = broadcast_image_to(mask, self.full_length, 1)
# reshape mask to attention K shape (h*w, latent_count, 1)
batch, channel, height, width = mask.shape
# first, perform same operations as on hidden_states,
# turning (b, c, h, w) -> (b, h*w, c)
mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel)
# then, make it the same shape as attention's k, (h*w, b, c)
mask = mask.permute(1, 0, 2)
# repeat mask for each batched cond so dim 0 matches attention's batch dim (batched_number * h*w)
batched_number = shape[0] // self.video_length
if batched_number > 1:
mask = torch.cat([mask] * batched_number, dim=0)
# cache mask and set to proper device
self.temp_scale_mask = mask
# move temp_scale_mask to proper dtype + device
self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device)
# return subset of masks, if needed
if self.sub_idxs is not None:
return self.temp_scale_mask[:, self.sub_idxs, :]
return self.temp_scale_mask
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options: ContextOptions=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
scale_mask = self.get_scale_mask(hidden_states)
# add some casts for fp8 purposes - does not affect speed otherwise
hidden_states = self.norm(hidden_states).to(hidden_states.dtype)
inner_dim = hidden_states.shape[1]
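# flatten spatial dims: (b, c, h, w) -> (b, h*w, c), so each spatial position becomes a token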
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
batch, height * width, inner_dim
)
hidden_states = self.proj_in(hidden_states).to(hidden_states.dtype)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
video_length=self.video_length,
scale_mask=scale_mask,
view_options=view_options
)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, width, inner_dim)
.permute(0, 3, 1, 2)
.contiguous()
)
output = hidden_states + residual
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types=(
"Temporal_Self",
"Temporal_Self",
),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
ops=comfy.ops.disable_weight_init,
):
super().__init__()
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
VersatileAttention(
attention_mode=block_name.split("_")[0],
context_dim=cross_attention_dim # called context_dim for ComfyUI impl
if block_name.endswith("_Cross")
else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
#bias=attention_bias, # remove for Comfy CrossAttention
#upcast_attention=upcast_attention, # remove for Comfy CrossAttention
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
ops=ops,
)
)
norms.append(ops.LayerNorm(dim))
self.attention_blocks: Iterable[VersatileAttention] = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops)
self.ff_norm = ops.LayerNorm(dim)
def set_scale_multiplier(self, multiplier: Union[float, None]):
for block in self.attention_blocks:
block.set_scale_multiplier(multiplier)
def set_sub_idxs(self, sub_idxs: list[int]):
for block in self.attention_blocks:
block.set_sub_idxs(sub_idxs)
def forward(
self,
hidden_states: Tensor,
encoder_hidden_states: Tensor=None,
attention_mask: Tensor=None,
video_length: int=None,
scale_mask: Tensor=None,
view_options: ContextOptions=None,
):
# make view_options None if context_length > video_length, or if equal and equal not allowed
if view_options:
if view_options.context_length > video_length:
view_options = None
elif view_options.context_length == video_length and not view_options.use_on_equal_length:
view_options = None
if not view_options:
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states).to(hidden_states.dtype)
hidden_states = (
attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states
if attention_block.is_cross_attention
else None,
attention_mask=attention_mask,
video_length=video_length,
scale_mask=scale_mask
) + hidden_states
)
else:
# views idea gotten from diffusers AnimateDiff FreeNoise implementation:
# https://github.com/arthur-qiu/FreeNoise-AnimateDiff/blob/main/animatediff/models/motion_module.py
# apply sliding context windows (views)
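# e.g. (hypothetical numbers) 32 frames with a 16-frame context could yield windows like
# [0..15], [8..23], [16..31]; each window is processed separately below, and frames covered
# by multiple windows are blended via the weighted sums accumulated in value_final/count_final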
views = get_context_windows(num_frames=video_length, opts=view_options)
hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length)
value_final = torch.zeros_like(hidden_states)
count_final = torch.zeros_like(hidden_states)
# bias_final = [0.0] * video_length
batched_conds = hidden_states.size(1) // video_length
for sub_idxs in views:
sub_hidden_states = rearrange(hidden_states[:, sub_idxs], "b f d c -> (b f) d c")
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(sub_hidden_states).to(sub_hidden_states.dtype)
sub_hidden_states = (
attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states # do these need to be changed for sub_idxs too?
if attention_block.is_cross_attention
else None,
attention_mask=attention_mask,
video_length=len(sub_idxs),
scale_mask=scale_mask[:, sub_idxs, :] if scale_mask is not None else scale_mask
) + sub_hidden_states
)
sub_hidden_states = rearrange(sub_hidden_states, "(b f) d c -> b f d c", f=len(sub_idxs))
# if view_options.fuse_method == ContextFuseMethod.RELATIVE:
# for pos, idx in enumerate(sub_idxs):
# # bias is the influence of a specific index in relation to the whole context window
# bias = 1 - abs(idx - (sub_idxs[0] + sub_idxs[-1]) / 2) / ((sub_idxs[-1] - sub_idxs[0] + 1e-2) / 2)
# bias = max(1e-2, bias)
# # take weighted average relative to total bias of current idx
# bias_total = bias_final[idx]
# prev_weight = torch.tensor([bias_total / (bias_total + bias)],
# dtype=value_final.dtype, device=value_final.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
# #prev_weight = torch.cat([prev_weight]*value_final.shape[1], dim=1)
# new_weight = torch.tensor([bias / (bias_total + bias)],
# dtype=value_final.dtype, device=value_final.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
# #new_weight = torch.cat([new_weight]*value_final.shape[1], dim=1)
# test = value_final[:, idx:idx+1, :, :]
# value_final[:, idx:idx+1, :, :] = value_final[:, idx:idx+1, :, :] * prev_weight + sub_hidden_states[:, pos:pos+1, : ,:] * new_weight
# bias_final[idx] = bias_total + bias
# else:
weights = get_context_weights(len(sub_idxs), view_options.fuse_method) * batched_conds
weights_tensor = torch.Tensor(weights).to(device=hidden_states.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
value_final[:, sub_idxs] += sub_hidden_states * weights_tensor
count_final[:, sub_idxs] += weights_tensor
# get weighted average of sub_hidden_states, if fuse method requires it
# if view_options.fuse_method != ContextFuseMethod.RELATIVE:
hidden_states = value_final / count_final
hidden_states = rearrange(hidden_states, "b f d c -> (b f) d c")
del value_final
del count_final
# del bias_final
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.0, max_len=24):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
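# standard sinusoidal encoding (as in "Attention Is All You Need"):
#   pe[0, pos, 2i]   = sin(pos / 10000^(2i/d_model))
#   pe[0, pos, 2i+1] = cos(pos / 10000^(2i/d_model))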
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
self.sub_idxs = None
def set_sub_idxs(self, sub_idxs: list[int]):
self.sub_idxs = sub_idxs
def forward(self, x):
#if self.sub_idxs is not None:
# x = x + self.pe[:, self.sub_idxs]
#else:
x = x + self.pe[:, : x.size(1)]
return self.dropout(x)
class VersatileAttention(CrossAttentionMM):
def __init__(
self,
attention_mode=None,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
ops=comfy.ops.disable_weight_init,
*args,
**kwargs,
):
super().__init__(operations=ops, *args, **kwargs)
assert attention_mode == "Temporal"
self.attention_mode = attention_mode
self.is_cross_attention = kwargs["context_dim"] is not None
self.pos_encoder = (
PositionalEncoding(
kwargs["query_dim"],
dropout=0.0,
max_len=temporal_position_encoding_max_len,
)
if (temporal_position_encoding and attention_mode == "Temporal")
else None
)
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
def set_scale_multiplier(self, multiplier: Union[float, None]):
if multiplier is None or math.isclose(multiplier, 1.0):
self.scale = 1.0
else:
self.scale = multiplier
def set_sub_idxs(self, sub_idxs: list[int]):
if self.pos_encoder is not None:
self.pos_encoder.set_sub_idxs(sub_idxs)
def forward(
self,
hidden_states: Tensor,
encoder_hidden_states=None,
attention_mask=None,
video_length=None,
scale_mask=None,
):
if self.attention_mode != "Temporal":
raise NotImplementedError
d = hidden_states.shape[1]
hidden_states = rearrange(
hidden_states, "(b f) d c -> (b d) f c", f=video_length
)
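# each spatial position is now its own batch entry and the frame axis is the sequence dim,
# so the attention below only mixes information across time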
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states).to(hidden_states.dtype)
encoder_hidden_states = (
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
if encoder_hidden_states is not None
else encoder_hidden_states
)
hidden_states = super().forward(
hidden_states,
encoder_hidden_states,
value=None,
mask=attention_mask,
scale_mask=scale_mask,
)
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states