# coding=utf-8
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Reference:
# * transformers/models/dinov2/modeling_dinov2.py
# * https://github.com/facebookresearch/DiT/blob/main/models.py#L101
# * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2

""" PyTorch DINOv2 model whose layer norms are modulated by a condition vector (DiT-style shift/scale)."""

from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn

from .modeling_dinov2 import (
    Dinov2Config,
    Dinov2Layer,
    Dinov2Model,
    Dinov2Embeddings,
    BaseModelOutput,
    BaseModelOutputWithPooling,
)


class ModLN(nn.Module):
    """Condition-dependent shift/scale modulation applied to a normalized tensor.

    Computes ``x * (1 + scale) + shift``. ``shift`` and ``scale`` are predicted
    from ``condition`` by ``SiLU -> Linear(mod_dim, 2 * inner_dim)``.

    The linear layer is zero-initialized in ``__init__``. As constructed by
    this class, the module is therefore the identity map on ``x``.
    """

    def __init__(self, inner_dim: int, mod_dim: int = 1024):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        )

        # Zero-init so that shift == 0 and scale == 0, i.e. the modulation
        # starts as the identity.
        # NOTE(review): if the owning PreTrainedModel's post_init() /
        # _init_weights re-initializes every nn.Linear, this zero-init is
        # overwritten. Confirm against modeling_dinov2._init_weights.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        '''
        x: [N, M, C_in], M: num of tokens
        condition: [N, C_mod]

        Returns a tensor shaped like ``x``. C_in must equal ``inner_dim`` and
        C_mod must equal ``mod_dim``.
        '''
        # [N, 2*C_in] -> [N, 1, 2*C_in], so the same shift/scale broadcasts
        # over all M tokens of a sample.
        shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
        return x * (1 + scale) + shift


class ConditionalDinov2Config(Dinov2Config):
    """``Dinov2Config`` extended with ``modulation_dim``.

    ``modulation_dim`` is the width of the condition vector consumed by every
    ``ModLN``.
    """

    # NOTE: ``modulation_dim`` is the first positional parameter, so pass the
    # regular ``Dinov2Config`` arguments by keyword.
    def __init__(self, modulation_dim: int = 1024, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.modulation_dim = modulation_dim


class ConditionalDinov2Layer(Dinov2Layer):
    """This corresponds to the Block class in the original implementation.

    Identical to ``Dinov2Layer`` except that the outputs of ``norm1`` and
    ``norm2`` are additionally modulated by ``condition`` through ``ModLN``.
    """

    def __init__(self, config: ConditionalDinov2Config) -> None:
        super().__init__(config)

        self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim)
        self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run one pre-norm transformer block with condition-modulated norms.

        Args:
            hidden_states: token sequence fed to ``norm1``/``attention``.
            head_mask: forwarded unchanged to ``self.attention``.
            condition: modulation input for ``mod_norm1``/``mod_norm2``.
                Despite the ``None`` default it is required, because it is fed
                directly into ``ModLN.mlp``.
            output_attentions: when true, the attention weights returned by
                ``self.attention`` are appended to the output tuple.

        Returns:
            ``(layer_output,)`` or ``(layer_output, attention_weights)``.
        """
        self_attention_outputs = self.attention(
            self.mod_norm1(self.norm1(hidden_states), condition),  # in Dinov2, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = self.drop_path(attention_output) + hidden_states

        # in Dinov2, layernorm is also applied after self-attention
        layer_output = self.mod_norm2(self.norm2(hidden_states), condition)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs

        return outputs


# Adapted from transformers.models.vit.modeling_vit.ViTEncoder (ViT->Dinov2).
# Modified to thread ``condition`` into every layer, so this is NOT a verbatim
# "Copied from" and must not be auto-synced with the upstream encoder.
class ConditionalDinov2Encoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ConditionalDinov2Layer`` modules."""

    def __init__(self, config: ConditionalDinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ConditionalDinov2Layer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        condition: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Apply every layer in order, passing the same ``condition`` to each.

        ``head_mask`` (when given) is indexed per layer. If
        ``output_hidden_states`` is true, the collected hidden states are the
        input to every layer plus the final output. When ``return_dict`` is
        false, a tuple of the non-``None`` values is returned, in this order:
        ``(last_hidden_state, hidden_states, attentions)``.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Arguments are positional, matching ConditionalDinov2Layer.forward:
                # (hidden_states, head_mask, condition, output_attentions).
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    condition,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    layer_head_mask,
                    condition,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ConditionalDinov2Model(Dinov2Model):
    """``Dinov2Model`` whose encoder layers are modulated by a condition vector."""

    config_class = ConditionalDinov2Config

    def __init__(self, config: ConditionalDinov2Config):
        # NOTE(review): Dinov2Model.__init__ presumably already builds its own
        # embeddings/encoder/layernorm and calls post_init(). These are then
        # replaced below, so the unconditioned modules are built and discarded.
        super().__init__(config)
        self.config = config

        self.embeddings = Dinov2Embeddings(config)
        self.encoder = ConditionalDinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode ``pixel_values`` with every layer modulated by ``condition``.

        Args:
            pixel_values: input images; required.
            bool_masked_pos: forwarded to ``Dinov2Embeddings``.
            head_mask: per-head mask, expanded by ``get_head_mask``.
            condition: modulation input shared by every ``ModLN``; required.
                Per ``ModLN.forward`` it is expected as ``[N, modulation_dim]``.
            output_attentions, output_hidden_states, return_dict: each falls
                back to the corresponding ``config`` value when ``None``.

        Returns:
            ``BaseModelOutputWithPooling``, or a tuple
            ``(sequence_output, pooled_output, *encoder_extras)`` when
            ``return_dict`` is false. ``pooled_output`` is token 0 of the
            final-layernormed sequence (presumably the CLS token).

        Raises:
            ValueError: if ``pixel_values`` or ``condition`` is ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        # Every ModLN feeds `condition` into an nn.Linear; fail early with a
        # clear message instead of an opaque TypeError deep inside a layer.
        if condition is None:
            raise ValueError("You have to specify condition")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            condition=condition,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            head_outputs = (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )