from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from einops import rearrange

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import Downsample2D, ResnetBlock2D
from diffusers.utils import BaseOutput, logging

logger = logging.get_logger(__name__)
|
|
|
|
|
@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor
|
|
|
|
|
class Block2D(nn.Module):
    """A stack of `ResnetBlock2D` layers with an optional `Downsample2D` tail."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
|
|
|
|
|
class IdentityModule(nn.Module):
    """A pass-through placeholder that returns its first argument unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, *args):
        if len(args) > 0:
            return args[0]
        return None
|
|
|
|
|
class BasicBlock(nn.Module):
    """A plain ResNet basic block (1x1 conv + 3x3 conv, each with BatchNorm).

    The extra keyword arguments mirror `ResnetBlock2D`'s signature so that
    `BasicBlock` can be swapped in at the same call sites; only `in_channels`,
    `out_channels`, and `stride` affect the computation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        stride: int = 1,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Project the residual whenever the spatial size or channel count changes.
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3 if stride != 1 else 1,
                    stride=stride,
                    padding=1 if stride != 1 else 0,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x, *args):
        # `*args` absorbs the unused time embedding passed in by `Block2D`.
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
|
|
|
|
|
class Block2D(nn.Module):
    """A `BasicBlock`-based variant with an optional strided `BasicBlock` tail.

    Note: this definition shadows the `ResnetBlock2D`-based `Block2D` above.
    Only the last of the `num_layers` slots holds a real block; the earlier
    slots are `IdentityModule` placeholders.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            resnets.append(
                BasicBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                if i == num_layers - 1
                else IdentityModule()
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    BasicBlock(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        stride=2,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
|
|
|
|
|
class ControlProject(nn.Module):
    """Downscales a control feature map and applies a zero-initialized 1x1 projection."""

    def __init__(self, num_channels, scale=8, is_empty=False) -> None:
        super().__init__()
        assert scale and scale & (scale - 1) == 0, "`scale` must be a power of two"
        self.is_empty = is_empty
        self.scale = scale
        if not is_empty:
            if scale > 1:
                self.down_scale = nn.AvgPool2d(scale, scale)
            else:
                self.down_scale = nn.Identity()
            # Zero-initialized so the projection contributes nothing at the
            # start of training (the usual ControlNet "zero conv" trick).
            self.out = nn.Conv2d(num_channels, num_channels, kernel_size=1, stride=1, bias=False)
            for p in self.out.parameters():
                nn.init.zeros_(p)

    def forward(self, hidden_states: torch.FloatTensor):
        if self.is_empty:
            # Return an all-zero tensor with the downscaled spatial shape.
            shape = list(hidden_states.shape)
            shape[-2] = shape[-2] // self.scale
            shape[-1] = shape[-1] // self.scale
            return torch.zeros(shape).to(hidden_states)

        if len(hidden_states.shape) == 5:
            # Video input: fold the frame axis into the batch for the 2D ops.
            B, F, C, H, W = hidden_states.shape
            hidden_states = rearrange(hidden_states, "B F C H W -> (B F) C H W")
            hidden_states = self.down_scale(hidden_states)
            hidden_states = self.out(hidden_states)
            hidden_states = rearrange(hidden_states, "(B F) C H W -> B F C H W", F=F)
        else:
            hidden_states = self.down_scale(hidden_states)
            hidden_states = self.out(hidden_states)
        return hidden_states
|
|
|
|
|
class ControlNetModel(ModelMixin, ConfigMixin):

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: List[int] = [128, 128],
        out_channels: List[int] = [128, 256],
        groups: List[int] = [4, 8],
        time_embed_dim: int = 256,
        final_out_channels: int = 320,
    ):
        super().__init__()

        self.time_proj = Timesteps(128, True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(128, time_embed_dim)

        # Encode the 3-channel control image into a 128-channel feature map
        # at half the input resolution.
        self.embedding = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(2, 64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.GroupNorm(2, 64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.GroupNorm(2, 128),
            nn.ReLU(),
        )

        self.down_res = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(len(in_channels)):
            self.down_res.append(
                ResnetBlock2D(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    temb_channels=time_embed_dim,
                    groups=groups[i],
                ),
            )
            self.down_sample.append(
                Downsample2D(
                    out_channels[i],
                    use_conv=True,
                    out_channels=out_channels[i],
                    padding=1,
                    name="op",
                )
            )

        self.mid_convs = nn.ModuleList()
        self.mid_convs.append(
            nn.Sequential(
                nn.Conv2d(
                    in_channels=out_channels[-1],
                    out_channels=out_channels[-1],
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(),
                nn.GroupNorm(8, out_channels[-1]),
                nn.Conv2d(
                    in_channels=out_channels[-1],
                    out_channels=out_channels[-1],
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.GroupNorm(8, out_channels[-1]),
            )
        )
        self.mid_convs.append(
            nn.Conv2d(
                in_channels=out_channels[-1],
                out_channels=final_out_channels,
                kernel_size=1,
                stride=1,
            )
        )
        self.scale = 1.0

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
|
|
|
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)
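
    # Note: the recursion above only touches children that expose
    # `set_chunk_feed_forward`; the conv/resnet submodules used in this model
    # do not, so e.g. `model.enable_forward_chunking(chunk_size=2, dim=1)` is
    # currently a no-op unless such submodules are added.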
|
|
|
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
    ) -> torch.FloatTensor:
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # `mps` does not support float64/int64, so pick a dtype the
            # device can represent.
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # Broadcast to the batch dimension.
        batch_size = sample.shape[0]
        timesteps = timesteps.expand(batch_size)
        t_emb = self.time_proj(timesteps)

        # `Timesteps` contains no weights and always returns float32 tensors,
        # so cast back to the sample dtype before embedding.
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb)

        sample = self.embedding(sample)
        for res, downsample in zip(self.down_res, self.down_sample):
            sample = res(sample, emb)
            sample = downsample(sample, emb)

        # Residual refinement at the lowest resolution, then a 1x1 projection
        # to the UNet's channel width.
        sample = self.mid_convs[0](sample) + sample
        sample = self.mid_convs[1](sample)
        return sample
|
|
|
|
|
def zero_module(module):
    """Zero out all parameters of `module` and return it."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
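

# Minimal smoke test: a sketch under assumed shapes (a 3-channel control image
# at 512x512 with the default config); not part of the public API.
if __name__ == "__main__":
    model = ControlNetModel()
    control = torch.randn(1, 3, 512, 512)
    out = model(control, timestep=10)
    # The embedding stem halves the resolution and each Downsample2D halves it
    # again (8x total); `mid_convs[1]` projects to `final_out_channels`.
    print(out.shape)  # expected: torch.Size([1, 320, 64, 64])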
|
|