|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
from diffusers.configuration_utils import register_to_config |
|
from diffusers.models.controlnet import ControlNetModel, zero_module |
|
from diffusers.models.embeddings import ( |
|
TextImageProjection, |
|
TextImageTimeEmbedding, |
|
TextTimeEmbedding, |
|
TimestepEmbedding, |
|
Timesteps, |
|
) |
|
from diffusers.models.unets.unet_2d_blocks import ( |
|
CrossAttnDownBlock2D, |
|
DownBlock2D, |
|
UNetMidBlock2D, |
|
UNetMidBlock2DCrossAttn, |
|
get_down_block, |
|
) |
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
|
from diffusers.utils import logging |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from torch.utils.checkpoint import checkpoint |
|
|
|
# Module-level logger obtained through diffusers' logging utilities
# (`diffusers.utils.logging`), keyed by this module's name.
logger = logging.get_logger(__name__)
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with an identity skip.

    The residual branch is ``Conv3x3 -> GroupNorm(8 groups) -> SiLU -> Conv3x3``.
    Channel count and spatial size are unchanged (stride 1, padding 1), so the
    branch output can be added straight back onto the input.

    Args:
        dim: Number of input/output channels. ``GroupNorm`` with 8 groups
            requires it to be divisible by 8.
    """

    def __init__(self, dim):
        super().__init__()
        branch = [
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.GroupNorm(num_groups=8, num_channels=dim),
            # In-place is safe here: it only overwrites the GroupNorm output,
            # never the block input that the skip connection reuses.
            nn.SiLU(inplace=True),
            nn.Conv2d(dim, dim, 3, 1, 1),
        ]
        self.conv = nn.Sequential(*branch)

    def forward(self, x):
        """Return ``x + conv(x)``."""
        residual = self.conv(x)
        return x + residual
|
|
|
|
|
class NeuralTextureEncoder(nn.Module):
    """Small conv encoder/decoder producing a per-pixel "neural texture".

    Layout (all inside one ``nn.Sequential`` called ``model``):
    stem conv -> two stride-2 downsampling convs -> four ``ResBlock``s ->
    two stride-2 transposed convs -> 3x3 output projection to ``out_dim``.

    Each down/up stage is followed by ``GroupNorm(groups)`` and SiLU.

    Spatial sizes: the stride-2 3x3 convs (padding 1) map ``H`` to
    ``ceil(H / 2)``, and the 4x4 transposed convs (stride 2, padding 1) map
    ``H`` to ``2 * H``. The output therefore has the input's spatial size when
    ``H`` and ``W`` are multiples of 4.

    Args:
        in_dim: Channels of the input image.
        out_dim: Channels of the produced texture features.
        dims: Channel widths of the three resolution levels; only the first
            three entries are used.
        groups: Number of groups for the stage ``GroupNorm`` layers. The
            ``ResBlock``s use their own fixed group count.
    """

    def __init__(self, in_dim=3, out_dim=16, dims=(32, 64, 128), groups=8):
        super().__init__()
        fine, mid, coarse = dims[0], dims[1], dims[2]

        def norm_act(channels):
            # GroupNorm + in-place SiLU pair that follows every resampling conv.
            return [nn.GroupNorm(num_groups=groups, num_channels=channels), nn.SiLU(inplace=True)]

        # Layer order fixes the `model.<idx>` parameter names; keep it stable for checkpoints.
        layers = [nn.Conv2d(in_dim, fine, kernel_size=3, padding=1), nn.SiLU(inplace=True)]
        layers += [nn.Conv2d(fine, mid, kernel_size=3, padding=1, stride=2), *norm_act(mid)]
        layers += [nn.Conv2d(mid, coarse, kernel_size=3, padding=1, stride=2), *norm_act(coarse)]
        layers += [ResBlock(coarse) for _ in range(4)]
        layers += [nn.ConvTranspose2d(coarse, mid, kernel_size=4, padding=1, stride=2), *norm_act(mid)]
        layers += [nn.ConvTranspose2d(mid, fine, kernel_size=4, padding=1, stride=2), *norm_act(fine)]
        layers.append(nn.Conv2d(fine, out_dim, kernel_size=3, padding=1))

        self.model = nn.Sequential(*layers)
        # Toggled externally (e.g. by a parent model's gradient-checkpointing hook).
        self.gradient_checkpointing = False

    def forward(self, x):
        """Encode ``x``, recomputing activations in backward when checkpointing is on in training mode."""
        if not (self.training and self.gradient_checkpointing):
            return self.model(x)
        return checkpoint(self.model, x, use_reentrant=False)
|
|
|
|
|
class NeuralTextureEmbedding(nn.Module):
    """ControlNet conditioning embedding built from a neural texture and shading hints.

    The incoming tensor is split along dim 1 into ``conditioning_channels``
    channels (fed to a ``NeuralTextureEncoder`` that outputs
    ``shading_hint_channels`` channels) and ``shading_hint_channels`` hint
    channels.

    The encoded texture is multiplied element-wise with the hints, then passed
    through ``conv_in`` and a stack of conv pairs:
    ``conv(c_i -> c_i)``, then ``conv(c_i -> c_{i+1}, stride 2)``, each followed
    by SiLU. A final ``conv_out`` is wrapped in ``zero_module``.

    Args:
        conditioning_embedding_channels: Output channels of ``conv_out``.
        conditioning_channels: Leading channels of the input routed to the
            texture encoder.
        block_out_channels: Channel widths of the conv stack; one stride-2
            conv is added per consecutive pair.
        shading_hint_channels: Trailing hint channels of the input (also the
            encoder's output width).
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
        shading_hint_channels: int = 12,
    ):
        super().__init__()
        self.conditioning_channels = conditioning_channels
        self.shading_hint_channels = shading_hint_channels

        # Registration order (conv_in, encoder, blocks, conv_out) is kept fixed.
        self.conv_in = nn.Conv2d(shading_hint_channels, block_out_channels[0], kernel_size=3, padding=1)
        self.neural_texture_encoder = NeuralTextureEncoder(in_dim=conditioning_channels, out_dim=shading_hint_channels)

        self.blocks = nn.ModuleList([])
        for c_in, c_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            self.blocks.append(nn.Conv2d(c_in, c_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2))

        last_width = block_out_channels[-1]
        self.conv_out = zero_module(
            nn.Conv2d(last_width, conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, all_conditioning):
        """Embed ``all_conditioning`` (image channels followed by shading-hint channels along dim 1)."""
        sizes = [self.conditioning_channels, self.shading_hint_channels]
        conditioning, shading_hint = torch.split(all_conditioning, sizes, dim=1)

        # Modulate the learned texture features by the per-pixel shading hints.
        features = self.neural_texture_encoder(conditioning) * shading_hint
        features = F.silu(self.conv_in(features))

        for layer in self.blocks:
            features = F.silu(layer(features))

        return self.conv_out(features)
|
|
|
|
|
class NeuralTextureControlNetModel(ControlNetModel):
    """
    A ControlNet whose conditioning branch is a neural-texture encoder modulated by shading hints.

    The trunk follows the same layout as diffusers' `ControlNetModel`:
    - `conv_in` and the time embedding;
    - down blocks and a mid block;
    - zero-initialised 1x1 `controlnet_down_blocks` and `controlnet_mid_block`.

    Only the conditioning embedding is replaced, by a `NeuralTextureEmbedding`. That embedding splits its input
    along dim 1 into `conditioning_channels` channels (encoded by a small conv encoder) followed by
    `shading_hint_channels` channels (multiplied with the encoded features).

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample consumed by `conv_in` (`from_unet` copies it from the
            UNet config). This is not the channel count of the conditioning tensor.
        conditioning_channels (`int`, defaults to 3):
            Number of leading channels of the conditioning tensor that go through the neural texture encoder.
        shading_hint_channels (`int`, defaults to 12):
            Number of trailing shading-hint channels of the conditioning tensor.
        conditioning_embedding_out_channels (`Tuple[int, ...]`, defaults to `(16, 32, 96, 256)`):
            Channel widths of the conv stack inside the `NeuralTextureEmbedding`.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            Only `"rgb"` is supported.
        global_pool_conditions (`bool`, defaults to `False`):
            Only `False` is supported.

    All other arguments are passed to the same building blocks as in `ControlNetModel` and have the same meaning.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
        shading_hint_channels: int = 12,
    ):
        # Deliberately skip `ControlNetModel.__init__` and initialise the next class in the MRO
        # (diffusers' `ModelMixin` / `torch.nn.Module`) directly. `ControlNetModel.__init__` is itself
        # wrapped in `@register_to_config`, so calling it with no arguments would:
        #   (a) build, and then immediately discard, a complete default-sized ControlNet; and
        #   (b) re-register the parent's default arguments, overwriting the config that was just
        #       registered for *this* class (e.g. `cross_attention_dim`, `addition_embed_type`).
        # Every submodule the parent would create is created below, so nothing is lost.
        super(ControlNetModel, self).__init__()

        num_attention_heads = num_attention_heads or attention_head_dim

        # Explicit raises instead of `assert`, so the checks still run under `python -O`.
        if controlnet_conditioning_channel_order != "rgb":
            raise ValueError("Only RGB channel order is supported.")
        if global_pool_conditions is not False:
            raise ValueError("Global pooling conditions is not supported.")

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # encoder hidden-state projection
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # additional (text / image / time) embedding
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'."
            )

        # neural-texture conditioning embedding (replaces ControlNetConditioningEmbedding)
        self.controlnet_cond_embedding = NeuralTextureEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
            shading_hint_channels=shading_hint_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        # One zero-initialised 1x1 projection for the `conv_in` output ...
        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            # ... one per layer of the down block ...
            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            # ... and one for the downsampler output of every non-final block.
            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=mid_block_channel,
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2D":
            self.mid_block = UNetMidBlock2D(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                num_layers=0,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                add_attention=False,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
        shading_hint_channels: int = 12,
        conditioning_channels: int = 4,
    ):
        r"""
        Instantiate a [`NeuralTextureControlNetModel`] from a [`UNet2DConditionModel`].

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet whose configuration (and, optionally, weights) is copied where applicable.
            controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
                Passed to the constructor; only `"rgb"` is accepted.
            conditioning_embedding_out_channels (`Tuple[int, ...]`, defaults to `(16, 32, 96, 256)`):
                Channel widths of the `NeuralTextureEmbedding` conv stack.
            load_weights_from_unet (`bool`, defaults to `True`):
                If true, copy the weights of `conv_in`, `time_proj`, `time_embedding`, `class_embedding` and
                `add_embedding` (when present), and of the down and mid blocks from `unet`.
            shading_hint_channels (`int`, defaults to 12):
                Number of shading-hint channels in the conditioning tensor.
            conditioning_channels (`int`, defaults to 4):
                Number of image channels in the conditioning tensor. Note this default differs from the
                constructor's default of 3.
        """
        transformer_layers_per_block = (
            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
        )
        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
        addition_time_embed_dim = (
            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
        )

        controlnet = cls(
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            shading_hint_channels=shading_hint_channels,
            conditioning_channels=conditioning_channels,
        )

        if load_weights_from_unet:
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            if controlnet.class_embedding is not None:
                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

            # `add_embedding` only exists when the UNet config sets `addition_embed_type`. It is built with
            # the same arguments as the UNet's, so copy it rather than leave it randomly initialised.
            if getattr(controlnet, "add_embedding", None) is not None:
                controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())

            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle the `gradient_checkpointing` flag on the UNet down blocks and on the neural texture
        # encoder (whose forward checks the flag). This follows the older diffusers hook signature
        # `(module, value)`; presumably invoked per-submodule by `enable_gradient_checkpointing` — confirm
        # against the installed diffusers version.
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, NeuralTextureEncoder)):
            module.gradient_checkpointing = value
|
|