radna commited on
Commit
a7c6d5c
1 Parent(s): 8067ae0

Update modeling_intern_vit.py

Browse files
Files changed (1) hide show
  1. modeling_intern_vit.py +59 -267
modeling_intern_vit.py CHANGED
@@ -12,123 +12,22 @@ from einops import rearrange
12
  from timm.models.layers import DropPath
13
  from torch import nn
14
  from transformers.activations import ACT2FN
15
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
16
  from transformers.modeling_utils import PreTrainedModel
17
  from transformers.utils import logging
18
 
19
  from .configuration_intern_vit import InternVisionConfig
20
 
21
-
22
  try:
23
- from triton_flash_atn import _attention
24
-
25
- from triton_bert_pading import pad_input, unpad_input
26
-
27
  has_flash_attn = True
28
  except:
29
- print("FlashAttention is not installed.")
30
  has_flash_attn = False
31
 
32
- logger = logging.get_logger(__name__)
33
-
34
-
35
- class FlashAttention(nn.Module):
36
- """Implement the scaled dot product attention with softmax.
37
- Arguments
38
- ---------
39
- softmax_scale: The temperature to use for the softmax attention.
40
- (default: 1/sqrt(d_keys) where d_keys is computed at
41
- runtime)
42
- attention_dropout: The dropout rate to apply to the attention
43
- (default: 0.0)
44
- """
45
-
46
- def __init__(
47
- self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
48
- ):
49
- super().__init__()
50
- self.softmax_scale = softmax_scale
51
- self.dropout_p = attention_dropout
52
 
53
- def forward(
54
- self,
55
- qkv,
56
- key_padding_mask=None,
57
- causal=False,
58
- cu_seqlens=None,
59
- max_s=None,
60
- need_weights=False,
61
- ):
62
- """Implements the multihead softmax attention.
63
- Arguments
64
- ---------
65
- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
66
- if unpadded: (nnz, 3, h, d)
67
- key_padding_mask: a bool tensor of shape (B, S)
68
- """
69
- assert not need_weights
70
- assert qkv.dtype in [torch.float16, torch.bfloat16]
71
- assert qkv.is_cuda
72
-
73
- if cu_seqlens is None:
74
- batch_size = qkv.shape[0]
75
- seqlen = qkv.shape[1]
76
- if key_padding_mask is None:
77
- qkv = rearrange(qkv, "b s ... -> (b s) ...")
78
- max_s = seqlen
79
- cu_seqlens = torch.arange(
80
- 0,
81
- (batch_size + 1) * seqlen,
82
- step=seqlen,
83
- dtype=torch.int32,
84
- device=qkv.device,
85
- )
86
- output = _attention.apply(
87
- qkv,
88
- cu_seqlens,
89
- max_s,
90
- self.dropout_p if self.training else 0.0,
91
- sm_scale=self.softmax_scale,
92
- causal=causal,
93
- )
94
- output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
95
- else:
96
- nheads = qkv.shape[-2]
97
- x = rearrange(qkv, "b s three h d -> b s (three h d)")
98
- x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
99
- x_unpad = rearrange(
100
- x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
101
- )
102
- output_unpad = _attention.apply(
103
- x_unpad,
104
- cu_seqlens,
105
- max_s,
106
- self.dropout_p if self.training else 0.0,
107
- sm_scale=self.softmax_scale,
108
- causal=causal,
109
- )
110
- output = rearrange(
111
- pad_input(
112
- rearrange(output_unpad, "nnz h d -> nnz (h d)"),
113
- indices,
114
- batch_size,
115
- seqlen,
116
- ),
117
- "b s (h d) -> b s h d",
118
- h=nheads,
119
- )
120
- else:
121
- assert max_s is not None
122
- output = _attention.apply(
123
- qkv,
124
- cu_seqlens,
125
- max_s,
126
- self.dropout_p if self.training else 0.0,
127
- sm_scale=self.softmax_scale,
128
- causal=causal,
129
- )
130
-
131
- return output, None
132
 
133
 
134
  class InternRMSNorm(nn.Module):
@@ -150,25 +49,15 @@ try:
150
 
151
  InternRMSNorm = FusedRMSNorm # noqa
152
 
153
- logger.info(
154
- "Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm"
155
- )
156
  except ImportError:
157
  # using the normal InternRMSNorm
158
  pass
159
  except Exception:
160
- logger.warning(
161
- "discovered apex but it failed to load, falling back to InternRMSNorm"
162
- )
163
  pass
164
 
165
 
166
- NORM2FN = {
167
- "rms_norm": InternRMSNorm,
168
- "layer_norm": nn.LayerNorm,
169
- }
170
-
171
-
172
  class InternVisionEmbeddings(nn.Module):
173
  def __init__(self, config: InternVisionConfig):
174
  super().__init__()
@@ -182,55 +71,22 @@ class InternVisionEmbeddings(nn.Module):
182
  )
183
 
184
  self.patch_embedding = nn.Conv2d(
185
- in_channels=3,
186
- out_channels=self.embed_dim,
187
- kernel_size=self.patch_size,
188
- stride=self.patch_size,
189
  )
190
 
191
  self.num_patches = (self.image_size // self.patch_size) ** 2
192
  self.num_positions = self.num_patches + 1
193
 
194
- self.position_embedding = nn.Parameter(
195
- torch.randn(1, self.num_positions, self.embed_dim)
196
- )
197
-
198
- def _get_pos_embed(self, pos_embed, H, W):
199
- target_dtype = pos_embed.dtype
200
- pos_embed = (
201
- pos_embed.float()
202
- .reshape(
203
- 1,
204
- self.image_size // self.patch_size,
205
- self.image_size // self.patch_size,
206
- -1,
207
- )
208
- .permute(0, 3, 1, 2)
209
- )
210
- pos_embed = (
211
- F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
212
- .reshape(1, -1, H * W)
213
- .permute(0, 2, 1)
214
- .to(target_dtype)
215
- )
216
- return pos_embed
217
 
218
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
 
219
  target_dtype = self.patch_embedding.weight.dtype
220
- # shape = [*, channel, width, height]
221
- patch_embeds = self.patch_embedding(pixel_values)
222
- batch_size, _, height, width = patch_embeds.shape
223
  patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
224
  class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
225
  embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
226
- position_embedding = torch.cat(
227
- [
228
- self.position_embedding[:, :1, :],
229
- self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
230
- ],
231
- dim=1,
232
- )
233
- embeddings = embeddings + position_embedding.to(target_dtype)
234
  return embeddings
235
 
236
 
@@ -244,17 +100,15 @@ class InternAttention(nn.Module):
244
  self.num_heads = config.num_attention_heads
245
  self.use_flash_attn = config.use_flash_attn and has_flash_attn
246
  if config.use_flash_attn and not has_flash_attn:
247
- print(
248
- "Warning: Flash Attention is not available, use_flash_attn is set to False."
249
- )
250
  self.head_dim = self.embed_dim // self.num_heads
251
  if self.head_dim * self.num_heads != self.embed_dim:
252
  raise ValueError(
253
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
254
- f" {self.num_heads})."
255
  )
256
 
257
- self.scale = self.head_dim**-0.5
258
  self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
259
  self.attn_drop = nn.Dropout(config.attention_dropout)
260
  self.proj_drop = nn.Dropout(config.dropout)
@@ -271,28 +125,15 @@ class InternAttention(nn.Module):
271
 
272
  def _naive_attn(self, x):
273
  B, N, C = x.shape
274
- qkv = (
275
- self.qkv(x)
276
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
277
- .permute(2, 0, 3, 1, 4)
278
- )
279
- # make torchscript happy (cannot use tensor as tuple)
280
- q, k, v = qkv.unbind(0)
281
 
282
  if self.qk_normalization:
283
  B_, H_, N_, D_ = q.shape
284
- q = (
285
- self.q_norm(q.transpose(1, 2).flatten(-2, -1))
286
- .view(B_, N_, H_, D_)
287
- .transpose(1, 2)
288
- )
289
- k = (
290
- self.k_norm(k.transpose(1, 2).flatten(-2, -1))
291
- .view(B_, N_, H_, D_)
292
- .transpose(1, 2)
293
- )
294
 
295
- attn = (q * self.scale) @ k.transpose(-2, -1)
296
  attn = attn.softmax(dim=-1)
297
  attn = self.attn_drop(attn)
298
 
@@ -303,9 +144,7 @@ class InternAttention(nn.Module):
303
 
304
  def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
305
  qkv = self.qkv(x)
306
- qkv = rearrange(
307
- qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
308
- )
309
 
310
  if self.qk_normalization:
311
  q, k, v = qkv.unbind(2)
@@ -314,21 +153,14 @@ class InternAttention(nn.Module):
314
  qkv = torch.stack([q, k, v], dim=2)
315
 
316
  context, _ = self.inner_attn(
317
- qkv,
318
- key_padding_mask=key_padding_mask,
319
- need_weights=need_weights,
320
- causal=False,
321
  )
322
- outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
323
  outs = self.proj_drop(outs)
324
  return outs
325
 
326
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
327
- x = (
328
- self._naive_attn(hidden_states)
329
- if not self.use_flash_attn
330
- else self._flash_attn(hidden_states)
331
- )
332
  return x
333
 
334
 
@@ -352,41 +184,28 @@ class InternVisionEncoderLayer(nn.Module):
352
  super().__init__()
353
  self.embed_dim = config.hidden_size
354
  self.intermediate_size = config.intermediate_size
355
- self.norm_type = config.norm_type
356
 
357
  self.attn = InternAttention(config)
358
  self.mlp = InternMLP(config)
359
- self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
360
- self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
361
 
362
  self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
363
  self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
364
- self.drop_path1 = (
365
- DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
366
- )
367
- self.drop_path2 = (
368
- DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
369
- )
370
 
371
  def forward(
372
- self,
373
- hidden_states: torch.Tensor,
374
- ) -> Tuple[
375
- torch.FloatTensor,
376
- Optional[torch.FloatTensor],
377
- Optional[Tuple[torch.FloatTensor]],
378
- ]:
379
  """
380
  Args:
381
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
382
  """
383
- hidden_states = hidden_states + self.drop_path1(
384
- self.attn(self.norm1(hidden_states)) * self.ls1
385
- )
386
 
387
- hidden_states = hidden_states + self.drop_path2(
388
- self.mlp(self.norm2(hidden_states)) * self.ls2
389
- )
390
 
391
  return hidden_states
392
 
@@ -405,23 +224,16 @@ class InternVisionEncoder(nn.Module):
405
  super().__init__()
406
  self.config = config
407
  # stochastic depth decay rule
408
- dpr = [
409
- x.item()
410
- for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
411
- ]
412
- self.layers = nn.ModuleList(
413
- [
414
- InternVisionEncoderLayer(config, dpr[idx])
415
- for idx in range(config.num_hidden_layers)
416
- ]
417
- )
418
  self.gradient_checkpointing = True
419
 
420
  def forward(
421
- self,
422
- inputs_embeds,
423
- output_hidden_states: Optional[bool] = None,
424
- return_dict: Optional[bool] = None,
425
  ) -> Union[Tuple, BaseModelOutput]:
426
  r"""
427
  Args:
@@ -434,13 +246,9 @@ class InternVisionEncoder(nn.Module):
434
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
435
  """
436
  output_hidden_states = (
437
- output_hidden_states
438
- if output_hidden_states is not None
439
- else self.config.output_hidden_states
440
- )
441
- return_dict = (
442
- return_dict if return_dict is not None else self.config.use_return_dict
443
  )
 
444
 
445
  encoder_states = () if output_hidden_states else None
446
  hidden_states = inputs_embeds
@@ -450,8 +258,8 @@ class InternVisionEncoder(nn.Module):
450
  encoder_states = encoder_states + (hidden_states,)
451
  if self.gradient_checkpointing and self.training:
452
  layer_outputs = torch.utils.checkpoint.checkpoint(
453
- encoder_layer, hidden_states
454
- )
455
  else:
456
  layer_outputs = encoder_layer(
457
  hidden_states,
@@ -469,9 +277,9 @@ class InternVisionEncoder(nn.Module):
469
 
470
 
471
  class InternVisionModel(PreTrainedModel):
472
- main_input_name = "pixel_values"
473
  config_class = InternVisionConfig
474
- _no_split_modules = ["InternVisionEncoderLayer"]
475
 
476
  def __init__(self, config: InternVisionConfig):
477
  super().__init__(config)
@@ -484,46 +292,30 @@ class InternVisionModel(PreTrainedModel):
484
  pos_emb = self.embeddings.position_embedding
485
  _, num_positions, embed_dim = pos_emb.shape
486
  cls_emb = pos_emb[:, :1, :]
487
- pos_emb = (
488
- pos_emb[:, 1:, :]
489
- .reshape(1, old_size // patch_size, old_size // patch_size, -1)
490
- .permute(0, 3, 1, 2)
491
- )
492
- pos_emb = F.interpolate(
493
- pos_emb.float(),
494
- size=new_size // patch_size,
495
- mode="bicubic",
496
- align_corners=False,
497
- )
498
  pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
499
  pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
500
  self.embeddings.position_embedding = nn.Parameter(pos_emb)
501
- self.embeddings.image_size = new_size
502
- logger.info(
503
- "Resized position embeddings from {} to {}".format(old_size, new_size)
504
- )
505
 
506
  def get_input_embeddings(self):
507
  return self.embeddings
508
 
509
  def forward(
510
- self,
511
- pixel_values: Optional[torch.FloatTensor] = None,
512
- output_hidden_states: Optional[bool] = None,
513
- return_dict: Optional[bool] = None,
514
- pixel_embeds: Optional[torch.FloatTensor] = None,
515
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
516
  output_hidden_states = (
517
- output_hidden_states
518
- if output_hidden_states is not None
519
- else self.config.output_hidden_states
520
- )
521
- return_dict = (
522
- return_dict if return_dict is not None else self.config.use_return_dict
523
  )
 
524
 
525
  if pixel_values is None and pixel_embeds is None:
526
- raise ValueError("You have to specify pixel_values or pixel_embeds")
527
 
528
  if pixel_embeds is not None:
529
  hidden_states = pixel_embeds
@@ -531,7 +323,7 @@ class InternVisionModel(PreTrainedModel):
531
  if len(pixel_values.shape) == 4:
532
  hidden_states = self.embeddings(pixel_values)
533
  else:
534
- raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
535
  encoder_outputs = self.encoder(
536
  inputs_embeds=hidden_states,
537
  output_hidden_states=output_hidden_states,
 
12
  from timm.models.layers import DropPath
13
  from torch import nn
14
  from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import (BaseModelOutput,
16
+ BaseModelOutputWithPooling)
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers.utils import logging
19
 
20
  from .configuration_intern_vit import InternVisionConfig
21
 
 
22
  try:
23
+ from .flash_attention import FlashAttention
 
 
 
24
  has_flash_attn = True
25
  except:
26
+ print('FlashAttention is not installed.')
27
  has_flash_attn = False
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ logger = logging.get_logger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  class InternRMSNorm(nn.Module):
 
49
 
50
  InternRMSNorm = FusedRMSNorm # noqa
51
 
52
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
 
 
53
  except ImportError:
54
  # using the normal InternRMSNorm
55
  pass
56
  except Exception:
57
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
 
 
58
  pass
59
 
60
 
 
 
 
 
 
 
61
  class InternVisionEmbeddings(nn.Module):
62
  def __init__(self, config: InternVisionConfig):
63
  super().__init__()
 
71
  )
72
 
73
  self.patch_embedding = nn.Conv2d(
74
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
 
 
 
75
  )
76
 
77
  self.num_patches = (self.image_size // self.patch_size) ** 2
78
  self.num_positions = self.num_patches + 1
79
 
80
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
83
+ batch_size = pixel_values.shape[0]
84
  target_dtype = self.patch_embedding.weight.dtype
85
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
 
 
86
  patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
87
  class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
88
  embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
89
+ embeddings = embeddings + self.position_embedding.to(target_dtype)
 
 
 
 
 
 
 
90
  return embeddings
91
 
92
 
 
100
  self.num_heads = config.num_attention_heads
101
  self.use_flash_attn = config.use_flash_attn and has_flash_attn
102
  if config.use_flash_attn and not has_flash_attn:
103
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
 
 
104
  self.head_dim = self.embed_dim // self.num_heads
105
  if self.head_dim * self.num_heads != self.embed_dim:
106
  raise ValueError(
107
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
108
+ f' {self.num_heads}).'
109
  )
110
 
111
+ self.scale = self.head_dim ** -0.5
112
  self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
113
  self.attn_drop = nn.Dropout(config.attention_dropout)
114
  self.proj_drop = nn.Dropout(config.dropout)
 
125
 
126
  def _naive_attn(self, x):
127
  B, N, C = x.shape
128
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
129
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
 
 
 
 
 
130
 
131
  if self.qk_normalization:
132
  B_, H_, N_, D_ = q.shape
133
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
134
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
 
 
 
 
 
 
 
 
135
 
136
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
137
  attn = attn.softmax(dim=-1)
138
  attn = self.attn_drop(attn)
139
 
 
144
 
145
  def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
146
  qkv = self.qkv(x)
147
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
 
 
148
 
149
  if self.qk_normalization:
150
  q, k, v = qkv.unbind(2)
 
153
  qkv = torch.stack([q, k, v], dim=2)
154
 
155
  context, _ = self.inner_attn(
156
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
 
 
 
157
  )
158
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
159
  outs = self.proj_drop(outs)
160
  return outs
161
 
162
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
163
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
 
 
 
 
164
  return x
165
 
166
 
 
184
  super().__init__()
185
  self.embed_dim = config.hidden_size
186
  self.intermediate_size = config.intermediate_size
 
187
 
188
  self.attn = InternAttention(config)
189
  self.mlp = InternMLP(config)
190
+ self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
191
+ self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
192
 
193
  self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
194
  self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
195
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
196
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
 
 
 
197
 
198
  def forward(
199
+ self,
200
+ hidden_states: torch.Tensor,
201
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
 
 
 
 
202
  """
203
  Args:
204
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
205
  """
206
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
 
 
207
 
208
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
 
 
209
 
210
  return hidden_states
211
 
 
224
  super().__init__()
225
  self.config = config
226
  # stochastic depth decay rule
227
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
228
+ self.layers = nn.ModuleList([
229
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
 
 
 
 
 
 
 
230
  self.gradient_checkpointing = True
231
 
232
  def forward(
233
+ self,
234
+ inputs_embeds,
235
+ output_hidden_states: Optional[bool] = None,
236
+ return_dict: Optional[bool] = None,
237
  ) -> Union[Tuple, BaseModelOutput]:
238
  r"""
239
  Args:
 
246
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
247
  """
248
  output_hidden_states = (
249
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
250
  )
251
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
252
 
253
  encoder_states = () if output_hidden_states else None
254
  hidden_states = inputs_embeds
 
258
  encoder_states = encoder_states + (hidden_states,)
259
  if self.gradient_checkpointing and self.training:
260
  layer_outputs = torch.utils.checkpoint.checkpoint(
261
+ encoder_layer,
262
+ hidden_states)
263
  else:
264
  layer_outputs = encoder_layer(
265
  hidden_states,
 
277
 
278
 
279
  class InternVisionModel(PreTrainedModel):
280
+ main_input_name = 'pixel_values'
281
  config_class = InternVisionConfig
282
+ _no_split_modules = ['InternVisionEncoderLayer']
283
 
284
  def __init__(self, config: InternVisionConfig):
285
  super().__init__(config)
 
292
  pos_emb = self.embeddings.position_embedding
293
  _, num_positions, embed_dim = pos_emb.shape
294
  cls_emb = pos_emb[:, :1, :]
295
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
296
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
 
 
 
 
 
 
 
 
 
297
  pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
298
  pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
299
  self.embeddings.position_embedding = nn.Parameter(pos_emb)
300
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
 
 
 
301
 
302
  def get_input_embeddings(self):
303
  return self.embeddings
304
 
305
  def forward(
306
+ self,
307
+ pixel_values: Optional[torch.FloatTensor] = None,
308
+ output_hidden_states: Optional[bool] = None,
309
+ return_dict: Optional[bool] = None,
310
+ pixel_embeds: Optional[torch.FloatTensor] = None,
311
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
312
  output_hidden_states = (
313
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
314
  )
315
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
316
 
317
  if pixel_values is None and pixel_embeds is None:
318
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
319
 
320
  if pixel_embeds is not None:
321
  hidden_states = pixel_embeds
 
323
  if len(pixel_values.shape) == 4:
324
  hidden_states = self.embeddings(pixel_values)
325
  else:
326
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
327
  encoder_outputs = self.encoder(
328
  inputs_embeds=hidden_states,
329
  output_hidden_states=output_hidden_states,