import torch import torch.nn as nn import torch.utils.checkpoint import math from .utils.modules import PatchEmbed, TimestepEmbedder from .utils.modules import PE_wrapper, RMSNorm from .blocks import DiTBlock, JointDiTBlock, FinalBlock class UDiT(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, input_type='2d', out_chans=None, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, qk_norm=None, act_layer='gelu', norm_layer='layernorm', context_norm=False, use_checkpoint=False, # time fusion ada or token time_fusion='token', ada_lora_rank=None, ada_lora_alpha=None, cls_dim=None, # max length is only used for concat context_dim=768, context_fusion='concat', context_max_length=128, context_pe_method='sinu', pe_method='abs', rope_mode='none', use_conv=True, skip=True, skip_norm=True): super().__init__() self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models # input self.in_chans = in_chans self.input_type = input_type if self.input_type == '2d': num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) elif self.input_type == '1d': num_patches = img_size // patch_size self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, input_type=input_type) out_chans = in_chans if out_chans is None else out_chans self.out_chans = out_chans # position embedding self.rope = rope_mode self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method, length=num_patches) print(f'x position embedding: {pe_method}') print(f'rope mode: {self.rope}') # time embed self.time_embed = TimestepEmbedder(embed_dim) self.time_fusion = time_fusion self.use_adanorm = False # cls embed if cls_dim is not None: self.cls_embed = nn.Sequential( nn.Linear(cls_dim, embed_dim, bias=True), nn.SiLU(), nn.Linear(embed_dim, embed_dim, bias=True),) else: self.cls_embed = None # time fusion if time_fusion == 'token': # put token at the beginning of sequence self.extras = 2 if self.cls_embed else 1 self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras) elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']: self.use_adanorm = True # aviod repetitive silu for each adaln block self.time_act = nn.SiLU() self.extras = 0 self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True) if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']: # shared adaln self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True) else: self.time_ada = None else: raise NotImplementedError print(f'time fusion mode: {self.time_fusion}') # context # use a simple projection self.use_context = False self.context_cross = False self.context_max_length = context_max_length self.context_fusion = 'none' if context_dim is not None: self.use_context = True self.context_embed = nn.Sequential( nn.Linear(context_dim, embed_dim, bias=True), nn.SiLU(), nn.Linear(embed_dim, embed_dim, bias=True),) self.context_fusion = context_fusion if context_fusion == 'concat' or context_fusion == 'joint': self.extras += context_max_length self.context_pe = PE_wrapper(dim=embed_dim, method=context_pe_method, length=context_max_length) # no cross attention layers context_dim = None elif context_fusion == 'cross': self.context_pe = PE_wrapper(dim=embed_dim, method=context_pe_method, length=context_max_length) self.context_cross = True context_dim = embed_dim else: raise NotImplementedError print(f'context fusion mode: {context_fusion}') print(f'context position embedding: {context_pe_method}') if self.context_fusion == 'joint': Block = JointDiTBlock self.use_skip = skip[0] else: Block = DiTBlock self.use_skip = skip # norm layers if norm_layer == 'layernorm': norm_layer = nn.LayerNorm elif norm_layer == 'rmsnorm': norm_layer = RMSNorm else: raise NotImplementedError print(f'use long skip connection: {skip}') self.in_blocks = nn.ModuleList([ Block( dim=embed_dim, context_dim=context_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, act_layer=act_layer, norm_layer=norm_layer, time_fusion=time_fusion, ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, skip=False, skip_norm=False, rope_mode=self.rope, context_norm=context_norm, use_checkpoint=use_checkpoint) for _ in range(depth // 2)]) self.mid_block = Block( dim=embed_dim, context_dim=context_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, act_layer=act_layer, norm_layer=norm_layer, time_fusion=time_fusion, ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, skip=False, skip_norm=False, rope_mode=self.rope, context_norm=context_norm, use_checkpoint=use_checkpoint) self.out_blocks = nn.ModuleList([ Block( dim=embed_dim, context_dim=context_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, act_layer=act_layer, norm_layer=norm_layer, time_fusion=time_fusion, ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, skip=skip, skip_norm=skip_norm, rope_mode=self.rope, context_norm=context_norm, use_checkpoint=use_checkpoint) for _ in range(depth // 2)]) # FinalLayer block self.use_conv = use_conv self.final_block = FinalBlock(embed_dim=embed_dim, patch_size=patch_size, img_size=img_size, in_chans=out_chans, input_type=input_type, norm_layer=norm_layer, use_conv=use_conv, use_adanorm=self.use_adanorm) self.initialize_weights() def _init_ada(self): if self.time_fusion == 'ada': nn.init.constant_(self.time_ada_final.weight, 0) nn.init.constant_(self.time_ada_final.bias, 0) for block in self.in_blocks: nn.init.constant_(block.adaln.time_ada.weight, 0) nn.init.constant_(block.adaln.time_ada.bias, 0) nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0) nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0) for block in self.out_blocks: nn.init.constant_(block.adaln.time_ada.weight, 0) nn.init.constant_(block.adaln.time_ada.bias, 0) elif self.time_fusion == 'ada_single': nn.init.constant_(self.time_ada.weight, 0) nn.init.constant_(self.time_ada.bias, 0) nn.init.constant_(self.time_ada_final.weight, 0) nn.init.constant_(self.time_ada_final.bias, 0) elif self.time_fusion in ['ada_lora', 'ada_lora_bias']: nn.init.constant_(self.time_ada.weight, 0) nn.init.constant_(self.time_ada.bias, 0) nn.init.constant_(self.time_ada_final.weight, 0) nn.init.constant_(self.time_ada_final.bias, 0) for block in self.in_blocks: nn.init.kaiming_uniform_(block.adaln.lora_a.weight, a=math.sqrt(5)) nn.init.constant_(block.adaln.lora_b.weight, 0) nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight, a=math.sqrt(5)) nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0) for block in self.out_blocks: nn.init.kaiming_uniform_(block.adaln.lora_a.weight, a=math.sqrt(5)) nn.init.constant_(block.adaln.lora_b.weight, 0) def initialize_weights(self): # Basic init for all layers def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # init patch Conv like Linear w = self.patch_embed.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) nn.init.constant_(self.patch_embed.proj.bias, 0) # Zero-out AdaLN if self.use_adanorm: self._init_ada() # Zero-out Cross Attention if self.context_cross: for block in self.in_blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0) nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0) for block in self.out_blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) # Zero-out cls embedding if self.cls_embed: if self.use_adanorm: nn.init.constant_(self.cls_embed[-1].weight, 0) nn.init.constant_(self.cls_embed[-1].bias, 0) # Zero-out Output # might not zero-out this when using v-prediction # it could be good when using noise-prediction # nn.init.constant_(self.final_block.linear.weight, 0) # nn.init.constant_(self.final_block.linear.bias, 0) # if self.use_conv: # nn.init.constant_(self.final_block.final_layer.weight.data, 0) # nn.init.constant_(self.final_block.final_layer.bias, 0) # init out Conv if self.use_conv: nn.init.xavier_uniform_(self.final_block.final_layer.weight) nn.init.constant_(self.final_block.final_layer.bias, 0) def _concat_x_context(self, x, context, x_mask=None, context_mask=None): assert context.shape[-2] == self.context_max_length # Check if either x_mask or context_mask is provided B = x.shape[0] # Create default masks if they are not provided if x_mask is None: x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() if context_mask is None: context_mask = torch.ones(B, context.shape[-2], device=context.device).bool() # Concatenate the masks along the second dimension (dim=1) x_mask = torch.cat([context_mask, x_mask], dim=1) # Concatenate context and x along the second dimension (dim=1) x = torch.cat((context, x), dim=1) return x, x_mask def forward(self, x, timesteps, context, x_mask=None, context_mask=None, cls_token=None ): # make it compatible with int time step during inference if timesteps.dim() == 0: timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long) x = self.patch_embed(x) x = self.x_pe(x) B, L, D = x.shape if self.use_context: context_token = self.context_embed(context) context_token = self.context_pe(context_token) if self.context_fusion == 'concat' or self.context_fusion == 'joint': x, x_mask = self._concat_x_context(x=x, context=context_token, x_mask=x_mask, context_mask=context_mask) context_token, context_mask = None, None else: context_token, context_mask = None, None time_token = self.time_embed(timesteps) if self.cls_embed: cls_token = self.cls_embed(cls_token) time_ada = None time_ada_final = None if self.use_adanorm: if self.cls_embed: time_token = time_token + cls_token time_token = self.time_act(time_token) time_ada_final = self.time_ada_final(time_token) if self.time_ada is not None: time_ada = self.time_ada(time_token) else: time_token = time_token.unsqueeze(dim=1) if self.cls_embed: cls_token = cls_token.unsqueeze(dim=1) time_token = torch.cat([time_token, cls_token], dim=1) time_token = self.time_pe(time_token) x = torch.cat((time_token, x), dim=1) if x_mask is not None: x_mask = torch.cat( [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(), x_mask], dim=1) time_token = None skips = [] for blk in self.in_blocks: x = blk(x=x, time_token=time_token, time_ada=time_ada, skip=None, context=context_token, x_mask=x_mask, context_mask=context_mask, extras=self.extras) if self.use_skip: skips.append(x) x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada, skip=None, context=context_token, x_mask=x_mask, context_mask=context_mask, extras=self.extras) for blk in self.out_blocks: skip = skips.pop() if self.use_skip else None x = blk(x=x, time_token=time_token, time_ada=time_ada, skip=skip, context=context_token, x_mask=x_mask, context_mask=context_mask, extras=self.extras) x = self.final_block(x, time_ada=time_ada_final, extras=self.extras) return x