from typing import Optional, Tuple, Union

import torch
from diffusers.configuration_utils import register_to_config
from diffusers.models.controlnet import ControlNetModel, zero_module
from diffusers.models.embeddings import (
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.unets.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import logging
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ResBlock(nn.Module):
    """Residual block computing ``x + conv(silu(norm(conv(x))))`` with shape-preserving 3x3 convs.

    Args:
        dim: Number of input (and output) channels.
        groups: Number of groups of the ``GroupNorm`` layer; must divide ``dim``. Defaults to 8,
            the value that used to be hard-coded.
    """

    def __init__(self, dim, groups=8):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.GroupNorm(num_groups=groups, num_channels=dim),
            nn.SiLU(inplace=True),
            nn.Conv2d(dim, dim, 3, 1, 1),
        )

    def forward(self, x):
        return x + self.conv(x)


class NeuralTextureEncoder(nn.Module):
    """Convolutional encoder-decoder mapping an ``in_dim``-channel image to ``out_dim`` channels.

    Two stride-2 convolutions downsample by 4x, four ``ResBlock``s process the bottleneck, and two
    stride-2 transposed convolutions (kernel 4, padding 1) upsample back, so the output has the
    input's spatial size when H and W are divisible by 4.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        dims: Channel widths at full, 1/2 and 1/4 resolution; must have exactly 3 entries.
        groups: Number of groups for every ``GroupNorm`` layer, including those inside the
            ``ResBlock``s; must divide each entry of ``dims``.
    """

    def __init__(self, in_dim=3, out_dim=16, dims=(32, 64, 128), groups=8):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_dim, dims[0], kernel_size=3, padding=1),
            nn.SiLU(inplace=True),
            # down 1
            nn.Conv2d(dims[0], dims[1], kernel_size=3, padding=1, stride=2),
            nn.GroupNorm(num_groups=groups, num_channels=dims[1]),
            nn.SiLU(inplace=True),
            # down 2
            nn.Conv2d(dims[1], dims[2], kernel_size=3, padding=1, stride=2),
            nn.GroupNorm(num_groups=groups, num_channels=dims[2]),
            nn.SiLU(inplace=True),
            # res blocks (previously ignored `groups`; identical for the default groups=8)
            ResBlock(dims[2], groups=groups),
            ResBlock(dims[2], groups=groups),
            ResBlock(dims[2], groups=groups),
            ResBlock(dims[2], groups=groups),
            # up 1
            nn.ConvTranspose2d(dims[2], dims[1], kernel_size=4, padding=1, stride=2),
            nn.GroupNorm(num_groups=groups, num_channels=dims[1]),
            nn.SiLU(inplace=True),
            # up 2
            nn.ConvTranspose2d(dims[1], dims[0], kernel_size=4, padding=1, stride=2),
            nn.GroupNorm(num_groups=groups, num_channels=dims[0]),
            nn.SiLU(inplace=True),
            # out
            nn.Conv2d(dims[0], out_dim, kernel_size=3, padding=1),
        )

        # Toggled by NeuralTextureControlNetModel._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(self, x):
        # Recompute activations in the backward pass only while training with checkpointing enabled.
        if self.training and self.gradient_checkpointing:
            x = checkpoint(self.model, x, use_reentrant=False)
        else:
            x = self.model(x)
        return x


class NeuralTextureEmbedding(nn.Module):
    """ControlNet conditioning embedding that modulates encoded reference features by shading hints.

    The input is split along channels into a reference part (``conditioning_channels``) and a
    shading hint (``shading_hint_channels``). The reference part is encoded by a
    ``NeuralTextureEncoder`` to ``shading_hint_channels`` channels and multiplied element-wise by
    the shading hint. The product is then passed through a stack of convolutions, each pair ending
    in a stride-2 conv (so ``len(block_out_channels) - 1`` halvings of H and W). Finally a
    ``zero_module``-wrapped conv projects it to ``conditioning_embedding_channels``.

    Args:
        conditioning_embedding_channels: Channels of the returned embedding.
        conditioning_channels: Number of leading input channels fed to the encoder.
        block_out_channels: Channel widths of the convolution stages.
        shading_hint_channels: Number of trailing input channels used as the shading hint.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
        shading_hint_channels: int = 12,  # diffuse + 3 * ggx
    ):
        super().__init__()
        self.conditioning_channels = conditioning_channels
        self.shading_hint_channels = shading_hint_channels

        self.conv_in = nn.Conv2d(shading_hint_channels, block_out_channels[0], kernel_size=3, padding=1)
        self.neural_texture_encoder = NeuralTextureEncoder(in_dim=conditioning_channels, out_dim=shading_hint_channels)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, all_conditioning):
        # all_conditioning: (B, conditioning_channels + shading_hint_channels, H, W).
        # With the default shading_hint_channels=12 the hint is diffuse + 3 * ggx (per the
        # constructor's default-argument comment).
        conditioning, shading_hint = torch.split(
            all_conditioning, [self.conditioning_channels, self.shading_hint_channels], dim=1
        )

        # (B, shading_hint_channels, H, W) when H and W are divisible by 4.
        embedding = self.neural_texture_encoder(conditioning)

        # Modulate every encoded channel by the matching shading-hint channel.
        embedding = embedding * shading_hint

        embedding = self.conv_in(embedding)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class NeuralTextureControlNetModel(ControlNetModel):
    """
    A Neural Texture ControlNet Model.

    A ``ControlNetModel`` whose conditioning embedding is a ``NeuralTextureEmbedding`` producing
    ``block_out_channels[0]`` channels (the same width as ``conv_in``'s output). The remaining
    arguments mirror ``diffusers.ControlNetModel``.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample (the tensor fed to ``conv_in``).
        conditioning_channels (`int`, defaults to 3):
            Number of leading channels of the conditioning tensor passed to the neural texture encoder.
        shading_hint_channels (`int`, defaults to 12):
            Number of trailing channels of the conditioning tensor used as shading hints.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
        shading_hint_channels: int = 12,
    ):
        # Deliberately skip ControlNetModel.__init__ and go straight to ModelMixin.__init__.
        # The parent __init__ is itself wrapped by @register_to_config. Calling it with no
        # arguments would:
        #   * re-register the parent's *default* config on top of the config this class's
        #     decorator has just registered;
        #   * allocate a full default-sized ControlNet whose modules are thrown away.
        # This method assigns conv_in, the time/class/addition embeddings,
        # controlnet_cond_embedding and the down/mid blocks itself.
        super(ControlNetModel, self).__init__()

        num_attention_heads = num_attention_heads or attention_head_dim

        assert controlnet_conditioning_channel_order == "rgb", "Only RGB channel order is supported."
        assert global_pool_conditions is False, "Global pooling conditions is not supported."

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )

        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'."
            )

        # control net conditioning embedding
        self.controlnet_cond_embedding = NeuralTextureEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
            shading_hint_channels=shading_hint_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        # One zero-initialised 1x1 conv per residual produced: conv_in output, each resnet layer,
        # and each downsampler.
        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=mid_block_channel,
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2D":
            self.mid_block = UNetMidBlock2D(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                num_layers=0,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                add_attention=False,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
        shading_hint_channels: int = 12,
        conditioning_channels: int = 4,
    ):
        r"""
        Instantiate a [`NeuralTextureControlNetModel`] from [`UNet2DConditionModel`].

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model weights to copy to the [`NeuralTextureControlNetModel`]. All configuration options
                are also copied where applicable.
            controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
                Forwarded to the constructor, which only accepts `"rgb"`.
            conditioning_embedding_out_channels (`tuple`, defaults to `(16, 32, 96, 256)`):
                Channel widths of the `NeuralTextureEmbedding` convolution stages.
            load_weights_from_unet (`bool`, defaults to `True`):
                If `True`, copy `conv_in`, `time_proj`, `time_embedding`, the class embedding and the
                addition embedding (when present), the down blocks and the mid block from `unet`.
            shading_hint_channels (`int`, defaults to 12):
                Number of trailing shading-hint channels in the conditioning tensor.
            conditioning_channels (`int`, defaults to 4):
                Number of leading reference channels in the conditioning tensor. Note this default differs
                from the constructor's default of 3.
        """
        transformer_layers_per_block = (
            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
        )
        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
        addition_time_embed_dim = (
            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
        )

        controlnet = cls(
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            shading_hint_channels=shading_hint_channels,
            conditioning_channels=conditioning_channels,
        )

        if load_weights_from_unet:
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            if controlnet.class_embedding is not None:
                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

            # Only built by __init__ when `addition_embed_type` is set; without this copy it would
            # stay randomly initialised.
            if hasattr(controlnet, "add_embedding"):
                controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())

            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet

    def _set_gradient_checkpointing(self, module, value=False):
        """Enable/disable gradient checkpointing on the sub-module types that support it here."""
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, NeuralTextureEncoder)):
            module.gradient_checkpointing = value