File size: 14,871 Bytes
38a7973 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 |
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.layers import DropPath
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_intern_vit import InternVisionConfig
# FlashAttention is optional: if importing the local `flash_attention` module
# fails for any reason, record that fact so attention layers can fall back to
# the plain PyTorch implementation.
try:
    from .flash_attention import FlashAttention
    has_flash_attn = True
except Exception:  # deliberately not a bare `except:` so Ctrl-C / SystemExit still propagate
    print('FlashAttention is not installed.')
    has_flash_attn = False
# Module-level logger, created through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
class InternRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    The statistics are computed in float32; the normalised tensor is cast back
    to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: size of the last (normalised) dimension; also the length of ``weight``.
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Return ``weight * (x * rsqrt(mean(x**2, dim=-1) + eps))`` for input ``x``."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        # Mean of squares along the feature axis, kept as a size-1 dim so it broadcasts.
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (upcast * inv_rms).to(original_dtype)
        return self.weight * normalised
# If apex is importable, rebind the module-level name `InternRMSNorm` to apex's
# `FusedRMSNorm`, so code that looks the name up afterwards gets the apex class.
try:
    from apex.normalization import FusedRMSNorm
    InternRMSNorm = FusedRMSNorm  # noqa
    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
    # apex cannot be imported: keep the pure-PyTorch InternRMSNorm
    pass
except Exception:
    # apex is present but broke while loading: report it and keep InternRMSNorm
    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
# Lookup table from a normalisation name to the class that implements it.
# 'rms_norm' resolves to whatever `InternRMSNorm` is bound to when this
# statement executes.
NORM2FN = {
'rms_norm': InternRMSNorm,
'layer_norm': nn.LayerNorm,
}
class InternVisionEmbeddings(nn.Module):
    """Turn an image batch into a token sequence: one class token followed by patch tokens.

    Patches are extracted with a strided convolution. A learnable position
    table is added to the tokens; its patch part is bicubically resampled to
    the patch grid actually produced by the input.
    """

    def __init__(self, config: InternVisionConfig):
        """
        Args:
            config: reads ``hidden_size``, ``image_size`` and ``patch_size``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side ** 2
        self.num_positions = self.num_patches + 1  # +1 for the class token

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed, H, W):
        """Resample the patch position table to an ``H`` x ``W`` grid.

        Args:
            pos_embed: patch position embeddings (class-token entry excluded),
                reshaped here to the square grid ``image_size // patch_size``.
            H: target grid height, in patches.
            W: target grid width, in patches.

        Returns:
            Tensor of shape ``(1, H * W, channels)`` in the dtype of ``pos_embed``.
        """
        out_dtype = pos_embed.dtype
        side = self.image_size // self.patch_size
        # (1, side * side, C) -> (1, C, side, side), the layout F.interpolate expects.
        grid = pos_embed.float().reshape(1, side, side, -1).permute(0, 3, 1, 2)
        resized = F.interpolate(grid, size=(H, W), mode='bicubic', align_corners=False)
        # Back to a token sequence: (1, C, H, W) -> (1, H * W, C).
        sequence = resized.reshape(1, -1, H * W).permute(0, 2, 1)
        return sequence.to(out_dtype)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` and add position information.

        Returns:
            Tensor of shape ``(batch, 1 + grid_h * grid_w, embed_dim)``.
        """
        target_dtype = self.patch_embedding.weight.dtype
        patch_map = self.patch_embedding(pixel_values)  # shape = [batch, embed_dim, grid_h, grid_w]
        batch_size, _, height, width = patch_map.shape
        patch_tokens = patch_map.flatten(2).transpose(1, 2)
        cls_tokens = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([cls_tokens, patch_tokens], dim=1)

        # The class-token position is used as is; only the patch positions are resampled.
        cls_position = self.position_embedding[:, :1, :]
        patch_positions = self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
        position_embedding = torch.cat([cls_position, patch_positions], dim=1)
        return embeddings + position_embedding.to(target_dtype)
class InternAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Self-attention with a fused QKV projection, optional RMS normalisation of
    queries and keys, and two interchangeable back-ends: a plain PyTorch
    implementation and the project's ``FlashAttention`` module.
    """

    def __init__(self, config: InternVisionConfig):
        """
        Args:
            config: reads ``hidden_size``, ``num_attention_heads``, ``use_flash_attn``,
                ``qkv_bias``, ``attention_dropout``, ``dropout``, ``qk_normalization``
                and ``layer_norm_eps``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        # Flash attention is only used when it was requested AND could be imported.
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')

        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        # One projection produces queries, keys and values together.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization
        if self.qk_normalization:
            # The norms act on the full embed_dim, i.e. across all heads at once.
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        """Scaled dot-product self-attention in plain PyTorch; ``x`` is ``(B, N, C)``."""
        B, N, C = x.shape
        fused = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        # (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            # Merge the heads, normalise over the full embedding, then split the heads again.
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        """Self-attention computed by ``self.inner_attn`` (the ``FlashAttention`` module).

        ``key_padding_mask`` and ``need_weights`` are forwarded unchanged;
        attention is always requested as non-causal.
        """
        packed = rearrange(self.qkv(x), 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = packed.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            packed = torch.stack([q, k, v], dim=2)

        # The second return value of the attention module is not used here.
        context, _ = self.inner_attn(
            packed, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
        )
        projected = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
        return self.proj_drop(projected)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the back-end selected at construction time on ``hidden_states``."""
        if self.use_flash_attn:
            return self._flash_attn(hidden_states)
        return self._naive_attn(hidden_states)
class InternMLP(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear."""

    def __init__(self, config: InternVisionConfig):
        """
        Args:
            config: reads ``hidden_act`` (a key of ``ACT2FN``), ``hidden_size``
                and ``intermediate_size``.
        """
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand to ``intermediate_size``, apply the activation, project back to ``hidden_size``."""
        expanded = self.fc1(hidden_states)
        activated = self.act(expanded)
        return self.fc2(activated)
class InternVisionEncoderLayer(nn.Module):
    """One transformer block: normalise -> attention -> residual, then normalise -> MLP -> residual.

    Each branch output is multiplied by a learnable per-channel scale
    (``ls1`` / ``ls2``) and passed through a drop-path module before being
    added back to the residual stream.
    """

    def __init__(self, config: InternVisionConfig, drop_path_rate: float):
        """
        Args:
            config: reads ``hidden_size``, ``intermediate_size``, ``norm_type``
                (a key of ``NORM2FN``), ``layer_norm_eps`` and ``initializer_factor``.
            drop_path_rate: rate given to ``DropPath``; when it is not greater
                than 0 an ``nn.Identity`` is used instead.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = InternAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        # Per-channel branch scales, initialised to `config.initializer_factor`.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`

        Returns:
            `torch.Tensor`: the updated hidden states. This is a single tensor,
            not a tuple as the previous annotation suggested.
        """
        attn_branch = self.attn(self.norm1(hidden_states)) * self.ls1
        hidden_states = hidden_states + self.drop_path1(attn_branch)

        mlp_branch = self.mlp(self.norm2(hidden_states)) * self.ls2
        hidden_states = hidden_states + self.drop_path2(mlp_branch)

        return hidden_states
class InternVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`InternVisionEncoderLayer`].

    Args:
        config (`InternVisionConfig`):
            The vision configuration used to build every layer of the encoder.
    """

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        # stochastic depth decay rule: rates grow linearly from 0 to config.drop_path_rate
        drop_rates = torch.linspace(0, config.drop_path_rate, config.num_hidden_layers).tolist()
        self.layers = nn.ModuleList([
            InternVisionEncoderLayer(config, drop_rates[layer_idx])
            for layer_idx in range(config.num_hidden_layers)
        ])
        # Checkpointing is switched on here and takes effect only in training mode (see forward).
        self.gradient_checkpointing = True

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. Falls back to
                `config.output_hidden_states` when `None`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Falls back to
                `config.use_return_dict` when `None`.

        Returns:
            A `BaseModelOutput`, or, when `return_dict` is false, a tuple holding the last hidden state
            followed by the tuple of per-layer hidden states if those were requested.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            # Record the input of every layer; the final output is appended after the loop.
            if output_hidden_states:
                encoder_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    encoder_layer,
                    hidden_states)
            else:
                hidden_states = encoder_layer(
                    hidden_states,
                )

        if output_hidden_states:
            encoder_states += (hidden_states,)

        if not return_dict:
            candidates = [hidden_states, encoder_states]
            return tuple(item for item in candidates if item is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )
class InternVisionModel(PreTrainedModel):
    """Vision transformer: `InternVisionEmbeddings` followed by `InternVisionEncoder`.

    The pooled output is the hidden state at sequence position 0 of the last layer.
    """

    main_input_name = 'pixel_values'
    config_class = InternVisionConfig
    _no_split_modules = ['InternVisionEncoderLayer']

    def __init__(self, config: InternVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        """Bicubically resample the learned position table to a new image size.

        The class-token entry is kept unchanged; only the patch entries are
        interpolated. The table is replaced by a new `nn.Parameter`, and the
        embedding module's `image_size`, `num_patches` and `num_positions`
        are updated to match.

        Args:
            old_size: image side length the current table was built for.
            new_size: image side length to resize the table to.
            patch_size: patch side length used to derive both grids.
        """
        pos_emb = self.embeddings.position_embedding
        embed_dim = pos_emb.shape[-1]
        old_grid = old_size // patch_size
        new_grid = new_size // patch_size

        cls_emb = pos_emb[:, :1, :]
        # (1, old_grid**2, C) -> (1, C, old_grid, old_grid) so F.interpolate can resample it.
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
        pos_emb = F.interpolate(pos_emb.float(), size=new_grid, mode='bicubic', align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)

        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        self.embeddings.image_size = new_size
        # Keep the bookkeeping attributes consistent with the resized table.
        self.embeddings.num_patches = new_grid ** 2
        self.embeddings.num_positions = self.embeddings.num_patches + 1
        logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))

    def get_input_embeddings(self):
        """Return the embedding module that consumes `pixel_values`."""
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images, or already-embedded tokens, with the vision transformer.

        Args:
            pixel_values: 4-D image batch; ignored when `pixel_embeds` is given.
            output_hidden_states: whether to return all hidden states; falls back to
                `config.output_hidden_states` when `None`.
            return_dict: whether to return a `BaseModelOutputWithPooling` instead of a tuple;
                falls back to `config.use_return_dict` when `None`.
            pixel_embeds: pre-computed embeddings fed straight to the encoder.

        Returns:
            `BaseModelOutputWithPooling`, or, when `return_dict` is false, the tuple
            `(last_hidden_state, pooled_output, *remaining_encoder_outputs)`.

        Raises:
            ValueError: if neither input is provided, or `pixel_values` is not 4-D.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif len(pixel_values.shape) == 4:
            hidden_states = self.embeddings(pixel_values)
        else:
            raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Positional access works for both return types: with return_dict=False the
        # encoder yields a plain tuple, which has no `.last_hidden_state` attribute.
        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|