Spaces:
Running
on
Zero
Running
on
Zero
# coding=utf-8 | |
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# Reference: | |
# * transformers/models/dinov2/modeling_dinov2.py | |
# * https://github.com/facebookresearch/DiT/blob/main/models.py#L101 | |
# * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2 | |
""" PyTorch CLIP model.""" | |
from typing import Dict, List, Optional, Set, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from .modeling_clip import ( | |
CLIPConfig, | |
CLIPTextConfig, | |
CLIPVisionConfig, | |
CLIPEncoderLayer, | |
CLIPTextTransformer, | |
CLIPVisionTransformer, | |
CLIPModel, | |
CLIPVisionEmbeddings, | |
CLIPVisionModel, | |
CLIPOutput, | |
BaseModelOutput, | |
BaseModelOutputWithPooling | |
) | |
class ModLN(nn.Module):
    """Shift/scale modulation of token features driven by a conditioning vector.

    The conditioning vector is passed through ``SiLU -> Linear`` to produce a
    per-channel ``shift`` and ``scale``; the input is then transformed as
    ``x * (1 + scale) + shift``.

    Args:
        inner_dim: Channel width of the tokens being modulated.
        mod_dim: Width of the conditioning vector.
    """

    def __init__(self, inner_dim: int, mod_dim: int = 32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        )
        # Zero-initialise the projection so shift == scale == 0 at construction
        # time, i.e. a freshly built module leaves its input unchanged.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Modulate ``x`` with ``condition``.

        Args:
            x: ``[N, M, C_in]`` tokens, ``M`` being the number of tokens.
            condition: ``[N, C_mod]`` conditioning vector, or ``None``.

        Returns:
            A tensor shaped like ``x``. When ``condition`` is ``None`` the input
            is returned untouched, so callers that leave their optional
            ``condition`` argument at its default do not crash.
        """
        if condition is None:
            return x
        # unsqueeze(1) adds the token axis so shift/scale broadcast over all M tokens.
        shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
        return x * (1 + scale) + shift
class ConditionalCLIPVisionConfig(CLIPVisionConfig):
    """``CLIPVisionConfig`` that additionally records ``modulation_dim``.

    Args:
        modulation_dim: Stored on the config as ``self.modulation_dim``.
        *args: Forwarded unchanged to ``CLIPVisionConfig.__init__``.
        **kwargs: Forwarded unchanged to ``CLIPVisionConfig.__init__``.

    Note:
        ``modulation_dim`` is the first positional parameter, so any positional
        arguments intended for the base config must come after it.
    """

    def __init__(self, modulation_dim: int = 32, *args, **kwargs):
        # Let the base class populate itself first, then attach the extra field.
        super().__init__(*args, **kwargs)
        self.modulation_dim = modulation_dim
class ConditionalCLIPEncoderLayer(CLIPEncoderLayer):
    """CLIP encoder layer whose two layer-norm outputs are modulated by a condition.

    This corresponds to the Block class in the original implementation.
    ``layer_norm1``, ``layer_norm2``, ``self_attn`` and ``mlp`` are not created
    here; they are presumably provided by ``CLIPEncoderLayer``.
    """

    def __init__(self, config: ConditionalCLIPVisionConfig) -> None:
        super().__init__(config)
        width = config.hidden_size
        cond_width = config.modulation_dim
        # One modulation module per normalisation site.
        self.mod_norm1 = ModLN(width, cond_width)
        self.mod_norm2 = ModLN(width, cond_width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        condition: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run the attention and MLP sub-blocks, each with a residual connection.

        Args:
            hidden_states: Input to the layer.
            attention_mask: Forwarded to ``self.self_attn``.
            causal_attention_mask: Forwarded to ``self.self_attn``.
            condition: Conditioning tensor handed to both ``ModLN`` modules.
            output_attentions: Whether to append the attention weights to the
                returned tuple.

        Returns:
            ``(hidden_states,)`` or ``(hidden_states, attn_weights)``.
        """
        # --- attention sub-block -------------------------------------------
        attn_input = self.mod_norm1(self.layer_norm1(hidden_states), condition)
        attn_output, attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        after_attention = hidden_states + attn_output

        # --- feed-forward sub-block ----------------------------------------
        mlp_input = self.mod_norm2(self.layer_norm2(after_attention), condition)
        layer_output = after_attention + self.mlp(mlp_input)

        if output_attentions:
            return (layer_output, attn_weights)
        return (layer_output,)
class ConditionalCLIPEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ConditionalCLIPEncoderLayer`` modules.

    Args:
        config: Configuration object. The code reads ``num_hidden_layers``,
            ``output_attentions``, ``output_hidden_states`` and
            ``use_return_dict`` from it and forwards it to every layer (which
            in turn reads ``hidden_size`` and ``modulation_dim``).
    """

    def __init__(self, config: CLIPConfig) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([ConditionalCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        condition: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        """Run ``inputs_embeds`` through every layer in order.

        Args:
            inputs_embeds: Input to the first layer.
            attention_mask: Forwarded to each layer.
            causal_attention_mask: Forwarded to each layer.
            output_attentions: Collect per-layer attention weights; falls back
                to ``config.output_attentions`` when ``None``.
            output_hidden_states: Collect the input of every layer plus the
                final output; falls back to ``config.output_hidden_states``
                when ``None``.
            condition: Conditioning tensor forwarded to each layer.
            return_dict: Return a ``BaseModelOutput`` instead of a tuple; falls
                back to ``config.use_return_dict`` when ``None``.

        Returns:
            A ``BaseModelOutput``, or a tuple holding the non-``None`` entries
            of ``(last_hidden_state, hidden_states, attentions)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # `_gradient_checkpointing_func` is not defined in this class;
                # presumably the surrounding framework attaches it when
                # checkpointing is enabled. Everything is passed positionally
                # (in the layer's forward() parameter order) because reentrant
                # `torch.utils.checkpoint.checkpoint` rejects keyword arguments.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    condition,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    condition=condition,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            # Also record the output of the final layer.
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
class ConditionalCLIPVisionTransformer(CLIPVisionTransformer):
    """Vision transformer whose encoder is a ``ConditionalCLIPEncoder``.

    After the base-class constructor runs, the embeddings, both layer norms and
    the encoder are (re)assigned here, the encoder being the conditional one.
    """

    def __init__(self, config: ConditionalCLIPVisionConfig):
        super().__init__(config)
        self.config = config
        width = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.embeddings = CLIPVisionEmbeddings(config)
        # NOTE: the attribute really is spelled "pre_layrnorm"; it is a module
        # name (and thus a state-dict key), so it must not be "corrected".
        self.pre_layrnorm = nn.LayerNorm(width, eps=norm_eps)
        self.encoder = ConditionalCLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(width, eps=norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Embed ``pixel_values``, encode them under ``condition`` and pool.

        Args:
            pixel_values: Image input passed to ``self.embeddings``; required.
            condition: Conditioning tensor forwarded to the encoder.
            output_attentions: Falls back to ``config.output_attentions``.
            output_hidden_states: Falls back to ``config.output_hidden_states``.
            return_dict: Falls back to ``config.use_return_dict``.

        Returns:
            A ``BaseModelOutputWithPooling``, or the tuple
            ``(last_hidden_state, pooled_output, *extra_encoder_outputs)``.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            condition=condition,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pool by taking the token at index 0 of the sequence axis, then normalise it.
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
class ConditionalCLIPVisionModel(CLIPVisionModel):
    """Thin wrapper exposing a ``ConditionalCLIPVisionTransformer`` as ``vision_model``."""

    config_class = ConditionalCLIPVisionConfig

    def __init__(self, config: ConditionalCLIPVisionConfig):
        super().__init__(config)
        # Replace whatever the base class installed with the conditional variant.
        self.vision_model = ConditionalCLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Delegate to ``self.vision_model`` and return its result unchanged.

        Args:
            pixel_values: Image input forwarded to the vision transformer.
            condition: Conditioning tensor forwarded to the vision transformer.
            output_attentions: Forwarded as-is.
            output_hidden_states: Forwarded as-is.
            return_dict: Falls back to ``config.use_return_dict`` when ``None``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        forwarded = dict(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.vision_model(**forwarded)
class ConditionalCLIPModel(CLIPModel):
    """CLIP model whose vision tower is a ``ConditionalCLIPVisionTransformer``.

    The text tower is a plain ``CLIPTextTransformer``. The vision tower accepts
    an extra ``condition`` tensor, which ``get_image_features`` and ``forward``
    pass through to it.

    NOTE(review): the vision layers read ``config.vision_config.modulation_dim``;
    a plain ``CLIPVisionConfig`` presumably needs that attribute supplied
    explicitly -- confirm against the checkpoints in use.
    """

    config_class = CLIPConfig

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(text_config)
        self.vision_model = ConditionalCLIPVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _clip_contrastive_loss(similarity: torch.Tensor) -> torch.Tensor:
        """Return the symmetric cross-entropy loss over a square similarity matrix.

        Row ``i`` is expected to match column ``i``. The loss is the mean of the
        row-wise (text -> image) and column-wise (image -> text) cross-entropy.
        It is kept local to the class because no ``clip_loss`` is imported or
        defined in this module.
        """
        targets = torch.arange(len(similarity), device=similarity.device)
        caption_loss = nn.functional.cross_entropy(similarity, targets)
        image_loss = nn.functional.cross_entropy(similarity.t(), targets)
        return (caption_loss + image_loss) / 2.0

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return the projected pooled output of the vision tower.

        Args:
            pixel_values: Image input forwarded to the vision tower.
            condition: Conditioning tensor forwarded to the vision tower.
            output_attentions: Falls back to ``config.output_attentions``.
            output_hidden_states: Falls back to ``config.output_hidden_states``.
            return_dict: Falls back to ``config.use_return_dict``.

        Returns:
            ``self.visual_projection`` applied to the pooled vision output.
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        """Encode images and text and compute their scaled cosine-similarity logits.

        Args:
            input_ids: Token ids forwarded to the text tower.
            pixel_values: Image input forwarded to the vision tower.
            condition: Conditioning tensor forwarded to the vision tower only.
            attention_mask: Forwarded to the text tower.
            position_ids: Forwarded to the text tower.
            return_loss: When truthy, also compute the contrastive loss.
            output_attentions: Falls back to ``config.output_attentions``.
            output_hidden_states: Falls back to ``config.output_hidden_states``.
            return_dict: Falls back to ``config.use_return_dict``.

        Returns:
            A ``CLIPOutput``, or the tuple ``(logits_per_image, logits_per_text,
            text_embeds, image_embeds, text_outputs, vision_outputs)`` prefixed
            with the loss when one was computed.
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Previously called an undefined `clip_loss`, raising NameError.
            loss = self._clip_contrastive_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )