|
|
|
import math |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import Tensor, nn |
|
|
|
import comfy.ops |
|
import comfy.model_management |
|
|
|
|
|
def zero_module(module):
    """
    Zero every parameter of ``module`` in place and hand the same module back.

    Each parameter is zeroed through a ``detach()``-ed view, which shares
    storage with the parameter, so the real weights are cleared without the
    in-place write being recorded by autograd.
    """
    detached_views = (param.detach() for param in module.parameters())
    for view in detached_views:
        view.zero_()
    return module
|
|
|
|
|
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average-pooling module for 1D, 2D or 3D signals.

    :param dims: spatial rank of the signal (1, 2 or 3).
    :param args, kwargs: forwarded unchanged to the ``nn.AvgPool{1,2,3}d`` constructor.
    :raises ValueError: when ``dims`` is not one of 1, 2, 3.
    """
    candidates = (
        (1, nn.AvgPool1d),
        (2, nn.AvgPool2d),
        (3, nn.AvgPool3d),
    )
    for rank, pool_cls in candidates:
        # Same ``dims == rank`` comparison as an if/elif chain would perform.
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
|
|
def fixed_positional_embedding(t, d_model):
    """
    Return a ``(t, d_model)`` float32 sinusoidal positional-embedding table.

    Even columns hold ``sin(pos * f_k)`` and odd columns ``cos(pos * f_k)``,
    with ``f_k = exp(-2k * ln(10000) / d_model)``.

    :param t: number of positions (rows).
    :param d_model: embedding width (columns); odd widths are supported.
    :return: a CPU ``torch.float32`` tensor of shape ``(t, d_model)``.
    """
    position = torch.arange(0, t, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float()
                         * (-np.log(10000.0) / d_model))
    pos_embedding = torch.zeros(t, d_model)
    # (t, ceil(d_model / 2)) matrix of angles, one column per frequency.
    angles = position * div_term
    pos_embedding[:, 0::2] = torch.sin(angles)
    # The odd columns number floor(d_model / 2): for odd d_model there is one
    # fewer cos column than frequencies, so trim the angle matrix to fit.
    pos_embedding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pos_embedding
|
|
|
|
|
class AdapterEmbed(nn.Module):
    """
    Adapter that embeds a conditioning batch into one feature map per level.

    A zero-initialised 3x3 ``conv_in`` maps ``cin`` input channels to
    ``channels[0]``. A sinusoidal embedding over ``video_length`` positions,
    scaled by ``ref_drift``, is then added. The result runs through
    ``len(channels)`` levels of ``nums_rb`` ``ResnetBlockEmbed`` blocks each;
    the first block of every level after the first changes width and
    downsamples. The output of the last block of each level is repeated
    ``batched_number`` times along dim 0 and scaled by the matching entry of
    ``insertion_weights``.
    """

    # Values restored by ``cleanup`` and by passing ``None`` to the setters.
    _DEFAULT_REF_DRIFT = 0.5
    _DEFAULT_INSERTION_WEIGHTS = (1.0, 1.0, 1.0, 1.0)

    def __init__(self, channels=(320, 640, 1280, 1280),
                 nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True,
                 ops=comfy.ops.disable_weight_init):
        super().__init__()
        # Copy: the default must not be shared between instances, and a
        # caller-owned list must not be aliased (later mutation would
        # desynchronise ``forward`` from ``self.body``).
        self.channels = list(channels)
        self.nums_rb = nums_rb

        blocks = []
        for level, out_ch in enumerate(self.channels):
            for j in range(nums_rb):
                if level != 0 and j == 0:
                    # Level transition: widen from the previous level and downsample.
                    blocks.append(ResnetBlockEmbed(
                        self.channels[level - 1], out_ch, down=True, ksize=ksize, sk=sk, use_conv=use_conv, ops=ops
                    ))
                else:
                    blocks.append(ResnetBlockEmbed(
                        out_ch, out_ch, down=False, ksize=ksize, sk=sk, use_conv=use_conv, ops=ops
                    ))
        self.body = nn.ModuleList(blocks)
        self.conv_in = zero_module(ops.Conv2d(in_channels=cin, out_channels=self.channels[0],
                                              kernel_size=3, stride=1, padding=1))
        # Width of the positional embedding added right after conv_in.
        self.d_model = self.channels[0]

        self.ref_drift = self._DEFAULT_REF_DRIFT
        self.insertion_weights = list(self._DEFAULT_INSERTION_WEIGHTS)

    def set_ref_drift(self, ref_drift: float):
        """Set the scale of the positional embedding; ``None`` restores 0.5."""
        if ref_drift is None:
            ref_drift = self._DEFAULT_REF_DRIFT
        self.ref_drift = ref_drift

    def set_insertion_weights(self, insertion_weights: list[float]):
        """
        Set the per-level output scales; ``None`` restores all-ones.

        Exactly four weights are required. NOTE(review): ``forward`` zips these
        with the per-level features, so if ``len(self.channels) != 4`` the
        surplus features (or weights) are silently dropped.
        """
        if insertion_weights is None:
            insertion_weights = self._DEFAULT_INSERTION_WEIGHTS
        assert len(insertion_weights) == 4
        # Private copy so later mutation of the caller's sequence has no effect.
        self.insertion_weights = list(insertion_weights)

    def cleanup(self):
        """Restore ``ref_drift`` and ``insertion_weights`` to their defaults."""
        self.set_ref_drift(None)
        self.set_insertion_weights(None)

    def forward(self, x: Tensor, video_length: int, batched_number: int):
        """
        Compute one scaled feature map per level.

        :param x: 4D input whose dim 1 has ``cin`` channels. NOTE(review): the
            positional embedding has ``video_length`` rows and is added along
            dim 0, so the batch size presumably equals ``video_length``
            (or 1); confirm with callers.
        :param video_length: number of positions in the sinusoidal embedding.
        :param batched_number: how many times each feature is tiled along dim 0.
        :return: list of per-level tensors, each multiplied by its insertion weight.
        """
        _, _, h, w = x.shape

        use_dtype = comfy.model_management.unet_dtype()
        # When the UNet dtype is 1 byte wide (presumably an fp8 storage type),
        # compute in the input's dtype instead.
        if comfy.model_management.dtype_size(use_dtype) == 1:
            use_dtype = x.dtype

        x = self.conv_in(x.to(use_dtype))

        # (video_length, d_model) -> (video_length, d_model, h, w). conv_in is
        # 3x3 / stride 1 / padding 1, so h and w are unchanged by it.
        pos_embedding = fixed_positional_embedding(video_length, self.d_model)
        pos_embedding = pos_embedding.to(use_dtype).to(x.device)
        pos_embedding = pos_embedding.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, h, w)
        x = x + (pos_embedding * self.ref_drift)

        features = []
        for level in range(len(self.channels)):
            for j in range(self.nums_rb):
                x = self.body[level * self.nums_rb + j](x)
            # Keep the last block's output of every level, tiled per batch copy.
            features.append(x.repeat(batched_number, 1, 1, 1))

        return [feature * weight for feature, weight in zip(features, self.insertion_weights)]
|
|
|
|
|
class ResnetBlockEmbed(nn.Module):
    """
    Residual convolution block used by the adapter, optionally downsampling first.

    Data path in ``forward``: optional ``down_opt`` -> optional ``in_conv``
    (``in_c`` -> ``out_c``, zero-initialised) -> ``block1`` (3x3) -> ReLU ->
    ``block2`` (zero-initialised) -> plus a shortcut that is either ``skep(x)``
    or ``x`` itself.

    ``sk == False`` builds the ``skep`` shortcut conv and always builds
    ``in_conv``; otherwise ``skep`` is ``None`` and ``in_conv`` exists only
    when ``in_c != out_c``.
    """

    def __init__(self, in_c, out_c, down: bool, ksize=3, sk=False, use_conv=True,
                 ops=comfy.ops.disable_weight_init):
        super().__init__()
        pad = ksize // 2
        # Equality (not truthiness) is used deliberately to keep the exact
        # semantics for non-bool ``sk`` values.
        conv_shortcut = sk == False  # noqa: E712

        # Layers are created in a fixed order: in_conv, block1, act, block2,
        # skep, down_opt (this order defines the state_dict key order).
        needs_in_conv = in_c != out_c or conv_shortcut
        self.in_conv = zero_module(ops.Conv2d(in_c, out_c, ksize, 1, pad)) if needs_in_conv else None
        self.block1 = ops.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = zero_module(ops.Conv2d(out_c, out_c, ksize, 1, pad))
        # NOTE(review): skep takes in_c channels but forward applies it after
        # in_conv has produced out_c channels; with sk == False and
        # in_c != out_c this looks like a channel mismatch at runtime.
        # Confirm that real configs use sk=True or in_c == out_c.
        self.skep = ops.Conv2d(in_c, out_c, ksize, 1, pad) if conv_shortcut else None

        self.down = down
        if self.down == True:  # noqa: E712
            self.down_opt = DownsampleEmbed(in_c, use_conv=use_conv, ops=ops)

    def forward(self, x: Tensor):
        if self.down == True:  # noqa: E712
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        out = self.block2(self.act(self.block1(x)))
        shortcut = x if self.skep is None else self.skep(x)
        return out + shortcut
|
|
|
|
|
class DownsampleEmbed(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; defaults to
                         ``channels`` (must equal it when ``use_conv`` is False).
    :param padding: padding of the stride-2, 3-wide convolution.
    :param ops: operation namespace providing ``conv_nd``.
    """

    def __init__(self, channels, use_conv: bool, dims=2, out_channels=None, padding=1,
                 ops=comfy.ops.disable_weight_init):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals keep their first (depth) axis and halve the inner two.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.operation = ops.conv_nd(dims, in_channels=self.channels, out_channels=self.out_channels,
                                         kernel_size=3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.operation = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x: Tensor):
        assert x.shape[1] == self.channels

        # Replicate-pad each downsampled (trailing) axis up to an even length
        # before the stride-2 operation. ``F.pad`` replicate mode needs the pad
        # tuple length to match the tensor rank, so build it per ``dims``:
        # 1D -> last axis only, 2D -> last two axes, 3D -> last three axes
        # with zero padding on the (non-downsampled) depth axis.
        pad_w = x.size(-1) % 2
        if self.dims == 1:
            pad = (0, pad_w)
        else:
            pad_h = x.size(-2) % 2
            pad = (0, pad_w, 0, pad_h)
            if self.dims == 3:
                pad = pad + (0, 0)

        x = F.pad(x, pad, mode='replicate')
        return self.operation(x)
|
|