File size: 4,519 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import torch
import torch.nn as nn
from comfy.model_patcher import ModelPatcher
from typing import Union
# Short alias for torch.Tensor, used in the type annotations throughout this module.
T = torch.Tensor
def exists(val):
    """Report whether ``val`` is set, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``.

    Only ``None`` triggers the fallback; other falsy values are returned as-is.
    """
    return d if val is None else val
class StyleAlignedArgs:
    """Flags controlling how the attention patch modifies q/k/v.

    ``share_attn`` is expected to be one of ``SHARE_ATTN_OPTIONS``
    ("q+k", "q+k+v", "disabled"). AdaIN is enabled for each projection whose
    letter appears in the string. "disabled" also turns off sharing of the
    style keys/values.
    """

    # Class-level defaults; __init__ overrides all of them per instance.
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = True

    def __init__(self, share_attn: str) -> None:
        # AdaIN is applied to q, k and/or v depending on the letters present.
        self.adain_keys = "k" in share_attn
        self.adain_values = "v" in share_attn
        self.adain_queries = "q" in share_attn
        # Previously "disabled" only cleared the AdaIN flags, and the class
        # default kept share_attention=True, so K/V were still concatenated
        # with the style features. Make "disabled" actually disable sharing.
        self.share_attention = share_attn != "disabled"
def expand_first(
    feat: T,
    scale=1.0,
) -> T:
    """
    Replace every batch entry with the style entry of its half of the batch.

    The batch is split into two halves of ``b // 2`` entries. ``feat[0]`` is
    the style entry of the first half and ``feat[b // 2]`` that of the second.
    The result has the same shape as ``feat``: the first half holds copies of
    ``feat[0]`` and the second half copies of ``feat[b // 2]``.

    Args:
        feat: tensor whose leading dim is the batch; any number of trailing
            dims is supported. The batch size must be even, otherwise the
            final reshape fails.
        scale: multiplier applied to every copy except the style entry itself
            (position 0 of each half).

    Returns:
        Tensor with the same shape as ``feat``.
    """
    b = feat.shape[0]
    half = b // 2
    # (2, 1, *rest): the style entry of each half.
    feat_style = torch.stack((feat[0], feat[half])).unsqueeze(1)
    # (2, half, *rest) for any rank. The previous repeat(1, b // 2, 1, 1, 1)
    # only lined up for 4-D inputs. For other ranks torch padded/tiled the
    # wrong dims, so the copies came out interleaved (s0, s1, s0, ...) or the
    # call raised.
    feat_style = feat_style.expand(2, half, *feat.shape[1:])
    if scale != 1:
        # cat materialises a new tensor, so the expanded view is not mutated.
        feat_style = torch.cat(
            [feat_style[:, :1], scale * feat_style[:, 1:]], dim=1
        )
    return feat_style.reshape(*feat.shape)
def concat_first(feat: T, dim=2, scale=1.0) -> T:
    """Concatenate ``feat`` with its style-expanded copy along ``dim``.

    The second operand is ``expand_first(feat, scale=scale)``, which has the
    same shape as ``feat``, so the result is twice as large along ``dim``.
    """
    style_copy = expand_first(feat, scale=scale)
    return torch.cat([feat, style_copy], dim=dim)
def calc_mean_std(feat, eps: float = 1e-5) -> "tuple[T, T]":
    """Return ``(mean, std)`` of ``feat`` reduced over dim -2, keeping that dim.

    The std is ``sqrt(var + eps)``, using ``Tensor.var``'s default estimator.
    ``eps`` keeps it strictly positive because ``adain`` divides by it.
    """
    reduce_dim = -2
    mean = feat.mean(dim=reduce_dim, keepdim=True)
    variance = feat.var(dim=reduce_dim, keepdim=True)
    std = (variance + eps).sqrt()
    return mean, std
def adain(feat: T) -> T:
    """Re-normalise ``feat`` to the statistics of the style entries.

    Each entry is standardised with its own mean/std from ``calc_mean_std``
    (reduced over dim -2). It is then rescaled and shifted by the mean/std of
    the style entry of its batch half, as broadcast by ``expand_first``.
    """
    mean, std = calc_mean_std(feat)
    normalized = (feat - mean) / std
    return normalized * expand_first(std) + expand_first(mean)
class SharedAttentionProcessor:
    """Callable installed by ``styleAlignBatch`` as the model's attn1 patch.

    It receives ``(q, k, v, extra_options)`` and returns ``(q, k, v)``.
    AdaIN is optionally applied to each of q/k/v. When sharing is enabled,
    the style-expanded keys/values are concatenated along dim -2, presumably
    so each entry also attends to its batch half's style entry. Only the
    keys receive ``scale``.
    """

    def __init__(self, args: StyleAlignedArgs, scale: float):
        self.args = args
        self.scale = scale

    def __call__(self, q, k, v, extra_options):
        # extra_options is part of the patch calling convention but unused here.
        flags = self.args
        q = adain(q) if flags.adain_queries else q
        k = adain(k) if flags.adain_keys else k
        v = adain(v) if flags.adain_values else v
        if not flags.share_attention:
            return q, k, v
        k = concat_first(k, -2, scale=self.scale)
        v = concat_first(v, -2)
        return q, k, v
def get_norm_layers(
    layer: nn.Module,
    norm_layers_: "dict[str, list[Union[nn.GroupNorm, nn.LayerNorm]]]",
    share_layer_norm: bool,
    share_group_norm: bool,
):
    """Recursively collect norm layers under ``layer`` into ``norm_layers_``.

    LayerNorms are appended to ``norm_layers_["layer"]`` when
    ``share_layer_norm`` is set, and GroupNorms to ``norm_layers_["group"]``
    when ``share_group_norm`` is set. The walk is depth-first in
    ``children()`` order. It does not descend into a collected GroupNorm;
    every other module's children are visited.
    """
    if share_layer_norm and isinstance(layer, nn.LayerNorm):
        norm_layers_["layer"].append(layer)
    if share_group_norm and isinstance(layer, nn.GroupNorm):
        norm_layers_["group"].append(layer)
        return
    for child in layer.children():
        get_norm_layers(child, norm_layers_, share_layer_norm, share_group_norm)
def register_norm_forward(
    norm_layer: Union[nn.GroupNorm, nn.LayerNorm],
) -> Union[nn.GroupNorm, nn.LayerNorm]:
    """Patch ``norm_layer.forward`` so its statistics include the style entries.

    On the first call the layer's current ``forward`` is stashed as
    ``orig_forward``. Later calls reuse that stash, so repeated registration
    wraps the original forward rather than a previous wrapper. The same
    layer object is returned, mutated in place.
    """
    if not hasattr(norm_layer, "orig_forward"):
        setattr(norm_layer, "orig_forward", norm_layer.forward)
    wrapped = norm_layer.orig_forward

    def forward_(hidden_states: T) -> T:
        # Append the style-expanded copy along dim -2, normalise the combined
        # tensor, then keep only the positions of the original input.
        length = hidden_states.shape[-2]
        combined = concat_first(hidden_states, dim=-2)
        normed = wrapped(combined)  # type: ignore
        return normed[..., :length, :]

    norm_layer.forward = forward_  # type: ignore
    return norm_layer
def register_shared_norm(
    model: ModelPatcher,
    share_group_norm: bool = True,
    share_layer_norm: bool = True,
):
    """Patch the norm layers of ``model.model`` to share style statistics.

    ``register_norm_forward`` mutates modules in place and records the
    pre-patch forward as ``orig_forward``. Any layer patched by an earlier
    call is therefore first restored, so the requested configuration replaces
    the previous one instead of accumulating on top of it. Previously,
    re-running with both flags off (or switching from both kinds to one) left
    the earlier patches active.

    Args:
        model: patcher whose ``model`` attribute is the module tree to walk.
        share_group_norm: patch GroupNorm layers.
        share_layer_norm: patch LayerNorm layers.

    Returns:
        The patched layers: group norms first, then layer norms.
    """
    # Undo patches from previous calls (identified by the stashed forward).
    for module in model.model.modules():
        if hasattr(module, "orig_forward"):
            module.forward = module.orig_forward
    norm_layers = {"group": [], "layer": []}
    get_norm_layers(model.model, norm_layers, share_layer_norm, share_group_norm)
    print(
        f"Patching {len(norm_layers['group'])} group norms, {len(norm_layers['layer'])} layer norms."
    )
    return [register_norm_forward(layer) for layer in norm_layers["group"]] + [
        register_norm_forward(layer) for layer in norm_layers["layer"]
    ]
# Choices for styleAlignBatch's ``share_norm``: "group"/"layer"/"both" select
# which norm layer kinds are patched; "disabled" selects neither.
SHARE_NORM_OPTIONS = ["both", "group", "layer", "disabled"]
# Choices for ``share_attn``, parsed by StyleAlignedArgs: the letters present
# select which of q/k/v receive AdaIN.
SHARE_ATTN_OPTIONS = ["q+k", "q+k+v", "disabled"]
def styleAlignBatch(model, share_norm, share_attn, scale=1.0):
    """Return a clone of ``model`` with the StyleAligned attention patch set.

    ``share_norm`` (see ``SHARE_NORM_OPTIONS``) selects which norm layers are
    patched. That patching is done through ``model`` itself, since
    ``register_shared_norm`` mutates ``model.model`` in place rather than the
    clone. ``share_attn`` (see ``SHARE_ATTN_OPTIONS``) is parsed by
    ``StyleAlignedArgs``, and ``scale`` is forwarded to
    ``SharedAttentionProcessor``.
    """
    patched = model.clone()
    wants_group = share_norm == "group" or share_norm == "both"
    wants_layer = share_norm == "layer" or share_norm == "both"
    register_shared_norm(model, wants_group, wants_layer)
    processor = SharedAttentionProcessor(StyleAlignedArgs(share_attn), scale)
    patched.set_model_attn1_patch(processor)
    return patched