|
|
|
|
|
import functools |
|
import torch |
|
import einops |
|
|
|
from comfy import model_management, utils |
|
from comfy.ldm.modules.attention import optimized_attention |
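
# Attention sharing for SD 1.5 UNets: every self-attention (attn1) and
# cross-attention (attn2) module listed in `module_mapping_sd15` below is
# replaced by an `AttentionSharingUnit` that runs per-frame LoRA attention plus
# an extra attention step across the packed frames/layers.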
|
|
|
module_mapping_sd15 = { |
|
0: "input_blocks.1.1.transformer_blocks.0.attn1", |
|
1: "input_blocks.1.1.transformer_blocks.0.attn2", |
|
2: "input_blocks.2.1.transformer_blocks.0.attn1", |
|
3: "input_blocks.2.1.transformer_blocks.0.attn2", |
|
4: "input_blocks.4.1.transformer_blocks.0.attn1", |
|
5: "input_blocks.4.1.transformer_blocks.0.attn2", |
|
6: "input_blocks.5.1.transformer_blocks.0.attn1", |
|
7: "input_blocks.5.1.transformer_blocks.0.attn2", |
|
8: "input_blocks.7.1.transformer_blocks.0.attn1", |
|
9: "input_blocks.7.1.transformer_blocks.0.attn2", |
|
10: "input_blocks.8.1.transformer_blocks.0.attn1", |
|
11: "input_blocks.8.1.transformer_blocks.0.attn2", |
|
12: "output_blocks.3.1.transformer_blocks.0.attn1", |
|
13: "output_blocks.3.1.transformer_blocks.0.attn2", |
|
14: "output_blocks.4.1.transformer_blocks.0.attn1", |
|
15: "output_blocks.4.1.transformer_blocks.0.attn2", |
|
16: "output_blocks.5.1.transformer_blocks.0.attn1", |
|
17: "output_blocks.5.1.transformer_blocks.0.attn2", |
|
18: "output_blocks.6.1.transformer_blocks.0.attn1", |
|
19: "output_blocks.6.1.transformer_blocks.0.attn2", |
|
20: "output_blocks.7.1.transformer_blocks.0.attn1", |
|
21: "output_blocks.7.1.transformer_blocks.0.attn2", |
|
22: "output_blocks.8.1.transformer_blocks.0.attn1", |
|
23: "output_blocks.8.1.transformer_blocks.0.attn2", |
|
24: "output_blocks.9.1.transformer_blocks.0.attn1", |
|
25: "output_blocks.9.1.transformer_blocks.0.attn2", |
|
26: "output_blocks.10.1.transformer_blocks.0.attn1", |
|
27: "output_blocks.10.1.transformer_blocks.0.attn2", |
|
28: "output_blocks.11.1.transformer_blocks.0.attn1", |
|
29: "output_blocks.11.1.transformer_blocks.0.attn2", |
|
30: "middle_block.1.transformer_blocks.0.attn1", |
|
31: "middle_block.1.transformer_blocks.0.attn2", |
|
} |
|
|
|
|
|
def compute_cond_mark(cond_or_uncond, sigmas): |
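    """Expand the sampler's per-chunk cond/uncond flags to per-sample marks.

    Each entry of ``cond_or_uncond`` is repeated once per sigma, so the result
    carries one 0/1 value per sample of the packed cond/uncond batch, on the
    same device and dtype as ``sigmas``.
    """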
|
cond_or_uncond_size = int(sigmas.shape[0]) |
|
|
|
cond_mark = [] |
|
for cx in cond_or_uncond: |
|
cond_mark += [cx] * cond_or_uncond_size |
|
|
|
cond_mark = torch.Tensor(cond_mark).to(sigmas) |
|
return cond_mark |
|
|
|
|
|
class LoRALinearLayer(torch.nn.Module): |
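    """Linear layer with a low-rank (LoRA) residual merged at forward time.

    The effective weight is ``W_org + up @ down``; the original layer is kept
    in a plain Python list so it is not registered as a submodule and its
    weights are not duplicated in this module's state dict.
    """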
|
def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None): |
|
super().__init__() |
|
self.down = torch.nn.Linear(in_features, rank, bias=False) |
|
self.up = torch.nn.Linear(rank, out_features, bias=False) |
|
self.org = [org] |
|
|
|
def forward(self, h): |
|
org_weight = self.org[0].weight.to(h) |
|
org_bias = self.org[0].bias.to(h) if self.org[0].bias is not None else None |
|
down_weight = self.down.weight |
|
up_weight = self.up.weight |
|
final_weight = org_weight + torch.mm(up_weight, down_weight) |
|
return torch.nn.functional.linear(h, final_weight, org_bias) |
|
|
|
|
|
class AttentionSharingUnit(torch.nn.Module): |
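    """Drop-in replacement for one attn1/attn2 module that shares attention
    across ``frames`` packed layers.

    For each frame it applies LoRA-adapted Q/K/V/out projections, optionally
    adds features derived from a control image, and finally runs an extra
    attention over the frame axis so the layers can exchange information.
    """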
|
|
|
|
|
    # Populated globally by the hijacked BasicTransformerBlock.forward (see
    # hijack_transformer_block below), giving every unit access to the
    # sampler's cond/uncond bookkeeping.
    transformer_options: dict = {}
|
|
|
def __init__(self, module, frames=2, use_control=True, rank=256): |
|
super().__init__() |
|
|
|
self.heads = module.heads |
|
self.frames = frames |
|
self.original_module = [module] |
|
q_in_channels, q_out_channels = ( |
|
module.to_q.in_features, |
|
module.to_q.out_features, |
|
) |
|
k_in_channels, k_out_channels = ( |
|
module.to_k.in_features, |
|
module.to_k.out_features, |
|
) |
|
v_in_channels, v_out_channels = ( |
|
module.to_v.in_features, |
|
module.to_v.out_features, |
|
) |
|
o_in_channels, o_out_channels = ( |
|
module.to_out[0].in_features, |
|
module.to_out[0].out_features, |
|
) |
|
|
|
hidden_size = k_out_channels |
|
|
|
self.to_q_lora = [ |
|
LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q) |
|
for _ in range(self.frames) |
|
] |
|
self.to_k_lora = [ |
|
LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k) |
|
for _ in range(self.frames) |
|
] |
|
self.to_v_lora = [ |
|
LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v) |
|
for _ in range(self.frames) |
|
] |
|
self.to_out_lora = [ |
|
LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0]) |
|
for _ in range(self.frames) |
|
] |
|
|
|
self.to_q_lora = torch.nn.ModuleList(self.to_q_lora) |
|
self.to_k_lora = torch.nn.ModuleList(self.to_k_lora) |
|
self.to_v_lora = torch.nn.ModuleList(self.to_v_lora) |
|
self.to_out_lora = torch.nn.ModuleList(self.to_out_lora) |
|
|
|
        # Cross-frame ("temporal") attention is built from plain linear
        # projections applied over the frame axis after the per-frame LoRA
        # attention.
        self.temporal_i = torch.nn.Linear(
|
in_features=hidden_size, out_features=hidden_size |
|
) |
|
self.temporal_n = torch.nn.LayerNorm( |
|
hidden_size, elementwise_affine=True, eps=1e-6 |
|
) |
|
self.temporal_q = torch.nn.Linear( |
|
in_features=hidden_size, out_features=hidden_size |
|
) |
|
self.temporal_k = torch.nn.Linear( |
|
in_features=hidden_size, out_features=hidden_size |
|
) |
|
self.temporal_v = torch.nn.Linear( |
|
in_features=hidden_size, out_features=hidden_size |
|
) |
|
self.temporal_o = torch.nn.Linear( |
|
in_features=hidden_size, out_features=hidden_size |
|
) |
|
|
|
        # Optional per-frame convs that project the encoded control image
        # (256-channel features from AdditionalAttentionCondsEncoder) into this
        # attention layer's hidden size.
        self.control_convs = None
|
|
|
if use_control: |
|
self.control_convs = [ |
|
torch.nn.Sequential( |
|
torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(256, hidden_size, kernel_size=1), |
|
) |
|
for _ in range(self.frames) |
|
] |
|
self.control_convs = torch.nn.ModuleList(self.control_convs) |
|
|
|
self.control_signals = None |
|
|
|
def forward(self, h, context=None, value=None): |
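        """Run shared attention over ``self.frames`` packed layers.

        ``h`` (and ``context``, when given) is a packed batch of shape
        ``(batch * frames, tokens, channels)``; ``value`` is accepted only for
        signature compatibility and is unused.  Only the delta relative to
        ``h`` is returned (see the comment at the end of this method).
        """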
|
transformer_options = self.transformer_options |
|
|
|
modified_hidden_states = einops.rearrange( |
|
h, "(b f) d c -> f b d c", f=self.frames |
|
) |
|
|
|
if self.control_convs is not None: |
|
            # Control signals are keyed by token count (H * W); pick the map
            # matching this layer's resolution and add it to each frame.
            context_dim = int(modified_hidden_states.shape[2])
|
control_outs = [] |
|
for f in range(self.frames): |
|
control_signal = self.control_signals[context_dim].to( |
|
modified_hidden_states |
|
) |
|
control = self.control_convs[f](control_signal) |
|
control = einops.rearrange(control, "b c h w -> b (h w) c") |
|
control_outs.append(control) |
|
control_outs = torch.stack(control_outs, dim=0) |
|
modified_hidden_states = modified_hidden_states + control_outs.to( |
|
modified_hidden_states |
|
) |
|
|
|
        # attn1 (self-attention) reuses the hidden states as context; attn2
        # (cross-attention) receives the packed conditioning context.
        if context is None:
|
framed_context = modified_hidden_states |
|
else: |
|
framed_context = einops.rearrange( |
|
context, "(b f) d c -> f b d c", f=self.frames |
|
) |
|
|
|
        # Per-sample 0/1 flags telling which entries of the packed batch are
        # cond vs. uncond, reshaped to the (frames, batch) layout used below.
        framed_cond_mark = einops.rearrange(
|
compute_cond_mark( |
|
transformer_options["cond_or_uncond"], |
|
transformer_options["sigmas"], |
|
), |
|
"(b f) -> f b", |
|
f=self.frames, |
|
).to(modified_hidden_states) |
|
|
|
attn_outs = [] |
|
for f in range(self.frames): |
|
fcf = framed_context[f] |
|
|
|
if context is not None: |
|
                # Optional per-frame conditioning override, blended into the
                # frame context according to the cond/uncond mark below.
                cond_overwrite = transformer_options.get("cond_overwrite", [])
|
if len(cond_overwrite) > f: |
|
cond_overwrite = cond_overwrite[f] |
|
else: |
|
cond_overwrite = None |
|
if cond_overwrite is not None: |
|
cond_mark = framed_cond_mark[f][:, None, None] |
|
fcf = cond_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark |
|
|
|
            # LoRA-adapted Q/K/V/out projections for this frame around the
            # shared attention kernel; the original to_out dropout still runs.
            q = self.to_q_lora[f](modified_hidden_states[f])
|
k = self.to_k_lora[f](fcf) |
|
v = self.to_v_lora[f](fcf) |
|
o = optimized_attention(q, k, v, self.heads) |
|
o = self.to_out_lora[f](o) |
|
o = self.original_module[0].to_out[1](o) |
|
attn_outs.append(o) |
|
|
|
attn_outs = torch.stack(attn_outs, dim=0) |
|
modified_hidden_states = modified_hidden_states + attn_outs.to( |
|
modified_hidden_states |
|
) |
|
modified_hidden_states = einops.rearrange( |
|
modified_hidden_states, "f b d c -> (b f) d c", f=self.frames |
|
) |
|
|
|
        # Cross-frame ("temporal") attention: tokens at the same spatial
        # position attend to each other across frames.
        x = modified_hidden_states
|
x = self.temporal_n(x) |
|
x = self.temporal_i(x) |
|
d = x.shape[1] |
|
|
|
x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames) |
|
|
|
q = self.temporal_q(x) |
|
k = self.temporal_k(x) |
|
v = self.temporal_v(x) |
|
|
|
x = optimized_attention(q, k, v, self.heads) |
|
x = self.temporal_o(x) |
|
x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d) |
|
|
|
modified_hidden_states = modified_hidden_states + x |
|
|
|
        # BasicTransformerBlock adds this module's output back onto its
        # residual stream, so return only the delta relative to the input.
        return modified_hidden_states - h
|
|
|
@classmethod |
|
def hijack_transformer_block(cls): |
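        """Monkey-patch BasicTransformerBlock.forward so every call records its
        transformer_options on this class before running the original forward."""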
|
def register_get_transformer_options(func): |
|
@functools.wraps(func) |
|
def forward(self, x, context=None, transformer_options={}): |
|
cls.transformer_options = transformer_options |
|
return func(self, x, context, transformer_options) |
|
|
|
return forward |
|
|
|
from comfy.ldm.modules.attention import BasicTransformerBlock |
|
|
|
BasicTransformerBlock.forward = register_get_transformer_options( |
|
BasicTransformerBlock.forward |
|
) |
|
|
|
|
|
# Install the BasicTransformerBlock hook once, at import time.
AttentionSharingUnit.hijack_transformer_block()
|
|
|
|
|
class AdditionalAttentionCondsEncoder(torch.nn.Module): |
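    """Conv pyramid that encodes a 3-channel control image into 256-channel
    feature maps at 1/8, 1/16, 1/32 and 1/64 of the input resolution."""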
|
def __init__(self): |
|
super().__init__() |
|
|
|
self.blocks_0 = torch.nn.Sequential( |
|
torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
) |
|
|
|
self.blocks_1 = torch.nn.Sequential( |
|
torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
) |
|
|
|
self.blocks_2 = torch.nn.Sequential( |
|
torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
) |
|
|
|
self.blocks_3 = torch.nn.Sequential( |
|
torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2), |
|
torch.nn.SiLU(), |
|
torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), |
|
torch.nn.SiLU(), |
|
) |
|
|
|
self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3] |
|
|
|
def __call__(self, h): |
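        """Return the feature maps keyed by their token count (H * W), which is
        how AttentionSharingUnit.forward picks the map for its resolution."""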
|
results = {} |
|
for b in self.blks: |
|
h = b(h) |
|
results[int(h.shape[2]) * int(h.shape[3])] = h |
|
return results |
|
|
|
|
|
class HookerLayers(torch.nn.Module): |
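    """Container so the patched attention units are registered as submodules of
    the patcher and get moved / cast together with it."""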
|
def __init__(self, layer_list): |
|
super().__init__() |
|
self.layers = torch.nn.ModuleList(layer_list) |
|
|
|
|
|
class AttentionSharingPatcher(torch.nn.Module): |
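    """Builds one AttentionSharingUnit per mapped SD 1.5 attention module and
    installs them on the given ModelPatcher as object patches.  When
    ``use_control`` is set, a control-image encoder is also created for use
    with ``set_control``."""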
|
def __init__(self, unet, frames=2, use_control=True, rank=256): |
|
super().__init__() |
|
model_management.unload_model_clones(unet) |
|
|
|
units = [] |
|
        # Replace all 32 mapped attention modules via object patches so the
        # swap lives on this ModelPatcher rather than the shared base model.
        for i in range(32):
|
real_key = module_mapping_sd15[i] |
|
attn_module = utils.get_attr(unet.model.diffusion_model, real_key) |
|
u = AttentionSharingUnit( |
|
attn_module, frames=frames, use_control=use_control, rank=rank |
|
) |
|
units.append(u) |
|
unet.add_object_patch("diffusion_model." + real_key, u) |
|
|
|
self.hookers = HookerLayers(units) |
|
|
|
if use_control: |
|
self.kwargs_encoder = AdditionalAttentionCondsEncoder() |
|
else: |
|
self.kwargs_encoder = None |
|
|
|
        # Create the new layers in fp16 when the current device prefers it.
        self.dtype = torch.float32
|
if model_management.should_use_fp16(model_management.get_torch_device()): |
|
self.dtype = torch.float16 |
|
self.hookers.half() |
|
return |
|
|
|
def set_control(self, img): |
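        """Encode a control image and hand the signals to every patched unit.

        ``img`` is expected channel-first (``[B, 3, H, W]``) with values in
        ``[0, 1]``; it is rescaled to ``[-1, 1]`` before encoding.
        """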
|
img = img.cpu().float() * 2.0 - 1.0 |
|
signals = self.kwargs_encoder(img) |
|
for m in self.hookers.layers: |
|
m.control_signals = signals |
|
return |
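

# Minimal usage sketch: assumes `model` is a loaded SD 1.5 ModelPatcher and
# `control_img` is a [B, 3, H, W] tensor in the 0..1 range; both names are
# illustrative and not defined in this module.  The unit weights would still
# need to be loaded before sampling; this only shows the wiring.
#
#     patcher = AttentionSharingPatcher(model, frames=2, use_control=True)
#     patcher.set_control(control_img)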
|
|