khang119966 commited on
Commit
c0e718e
·
verified ·
1 Parent(s): a275120

Upload 2 files

Browse files
Files changed (2) hide show
  1. modeling_intern_vit.py +430 -0
  2. modeling_internvl_chat.py +348 -0
modeling_intern_vit.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from timm.models.layers import DropPath
14
+ from torch import nn
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_outputs import (BaseModelOutput,
17
+ BaseModelOutputWithPooling)
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging
20
+
21
+ from .configuration_intern_vit import InternVisionConfig
22
+
23
+ try:
24
+ from flash_attn.bert_padding import pad_input, unpad_input
25
+ from flash_attn.flash_attn_interface import \
26
+ flash_attn_varlen_qkvpacked_func
27
+ has_flash_attn = True
28
+ except:
29
+ print('FlashAttention2 is not installed.')
30
+ has_flash_attn = False
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class FlashAttention(nn.Module):
36
+ """Implement the scaled dot product attention with softmax.
37
+ Arguments
38
+ ---------
39
+ softmax_scale: The temperature to use for the softmax attention.
40
+ (default: 1/sqrt(d_keys) where d_keys is computed at
41
+ runtime)
42
+ attention_dropout: The dropout rate to apply to the attention
43
+ (default: 0.0)
44
+ """
45
+
46
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
47
+ super().__init__()
48
+ self.softmax_scale = softmax_scale
49
+ self.dropout_p = attention_dropout
50
+
51
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
52
+ max_s=None, need_weights=False):
53
+ """Implements the multihead softmax attention.
54
+ Arguments
55
+ ---------
56
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
57
+ if unpadded: (nnz, 3, h, d)
58
+ key_padding_mask: a bool tensor of shape (B, S)
59
+ """
60
+ assert not need_weights
61
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
62
+ assert qkv.is_cuda
63
+
64
+ if cu_seqlens is None:
65
+ batch_size = qkv.shape[0]
66
+ seqlen = qkv.shape[1]
67
+ if key_padding_mask is None:
68
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
69
+ max_s = seqlen
70
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
71
+ device=qkv.device)
72
+ output = flash_attn_varlen_qkvpacked_func(
73
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
74
+ softmax_scale=self.softmax_scale, causal=causal
75
+ )
76
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
77
+ else:
78
+ nheads = qkv.shape[-2]
79
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
80
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
81
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
82
+ output_unpad = flash_attn_varlen_qkvpacked_func(
83
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
84
+ softmax_scale=self.softmax_scale, causal=causal
85
+ )
86
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
87
+ indices, batch_size, seqlen),
88
+ 'b s (h d) -> b s h d', h=nheads)
89
+ else:
90
+ assert max_s is not None
91
+ output = flash_attn_varlen_qkvpacked_func(
92
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
93
+ softmax_scale=self.softmax_scale, causal=causal
94
+ )
95
+
96
+ return output, None
97
+
98
+
99
+ class InternRMSNorm(nn.Module):
100
+ def __init__(self, hidden_size, eps=1e-6):
101
+ super().__init__()
102
+ self.weight = nn.Parameter(torch.ones(hidden_size))
103
+ self.variance_epsilon = eps
104
+
105
+ def forward(self, hidden_states):
106
+ input_dtype = hidden_states.dtype
107
+ hidden_states = hidden_states.to(torch.float32)
108
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
109
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
110
+ return self.weight * hidden_states.to(input_dtype)
111
+
112
+
113
+ try:
114
+ from apex.normalization import FusedRMSNorm
115
+
116
+ InternRMSNorm = FusedRMSNorm # noqa
117
+
118
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
119
+ except ImportError:
120
+ # using the normal InternRMSNorm
121
+ pass
122
+ except Exception:
123
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
124
+ pass
125
+
126
+
127
+ NORM2FN = {
128
+ 'rms_norm': InternRMSNorm,
129
+ 'layer_norm': nn.LayerNorm,
130
+ }
131
+
132
+
133
+ class InternVisionEmbeddings(nn.Module):
134
+ def __init__(self, config: InternVisionConfig):
135
+ super().__init__()
136
+ self.config = config
137
+ self.embed_dim = config.hidden_size
138
+ self.image_size = config.image_size
139
+ self.patch_size = config.patch_size
140
+
141
+ self.class_embedding = nn.Parameter(
142
+ torch.randn(1, 1, self.embed_dim),
143
+ )
144
+
145
+ self.patch_embedding = nn.Conv2d(
146
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
147
+ )
148
+
149
+ self.num_patches = (self.image_size // self.patch_size) ** 2
150
+ self.num_positions = self.num_patches + 1
151
+
152
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
153
+
154
+ def _get_pos_embed(self, pos_embed, H, W):
155
+ target_dtype = pos_embed.dtype
156
+ pos_embed = pos_embed.float().reshape(
157
+ 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
158
+ pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
159
+ reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
160
+ return pos_embed
161
+
162
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
163
+ target_dtype = self.patch_embedding.weight.dtype
164
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
165
+ batch_size, _, height, width = patch_embeds.shape
166
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
167
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
168
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
169
+ position_embedding = torch.cat([
170
+ self.position_embedding[:, :1, :],
171
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
172
+ ], dim=1)
173
+ embeddings = embeddings + position_embedding.to(target_dtype)
174
+ return embeddings
175
+
176
+
177
+ class InternAttention(nn.Module):
178
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
179
+
180
+ def __init__(self, config: InternVisionConfig):
181
+ super().__init__()
182
+ self.config = config
183
+ self.embed_dim = config.hidden_size
184
+ self.num_heads = config.num_attention_heads
185
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
186
+ if config.use_flash_attn and not has_flash_attn:
187
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
188
+ self.head_dim = self.embed_dim // self.num_heads
189
+ if self.head_dim * self.num_heads != self.embed_dim:
190
+ raise ValueError(
191
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
192
+ f' {self.num_heads}).'
193
+ )
194
+
195
+ self.scale = self.head_dim ** -0.5
196
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
197
+ self.attn_drop = nn.Dropout(config.attention_dropout)
198
+ self.proj_drop = nn.Dropout(config.dropout)
199
+
200
+ self.qk_normalization = config.qk_normalization
201
+
202
+ if self.qk_normalization:
203
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
204
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
205
+
206
+ if self.use_flash_attn:
207
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
208
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
209
+
210
+ def _naive_attn(self, x):
211
+ B, N, C = x.shape
212
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
213
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
214
+
215
+ if self.qk_normalization:
216
+ B_, H_, N_, D_ = q.shape
217
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
218
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
219
+
220
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
221
+ attn = attn.softmax(dim=-1)
222
+ attn = self.attn_drop(attn)
223
+
224
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
225
+ x = self.proj(x)
226
+ x = self.proj_drop(x)
227
+ return x
228
+
229
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
230
+ qkv = self.qkv(x)
231
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
232
+
233
+ if self.qk_normalization:
234
+ q, k, v = qkv.unbind(2)
235
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
236
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
237
+ qkv = torch.stack([q, k, v], dim=2)
238
+
239
+ context, _ = self.inner_attn(
240
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
241
+ )
242
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
243
+ outs = self.proj_drop(outs)
244
+ return outs
245
+
246
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
247
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
248
+ return x
249
+
250
+
251
+ class InternMLP(nn.Module):
252
+ def __init__(self, config: InternVisionConfig):
253
+ super().__init__()
254
+ self.config = config
255
+ self.act = ACT2FN[config.hidden_act]
256
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
257
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
258
+
259
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
260
+ hidden_states = self.fc1(hidden_states)
261
+ hidden_states = self.act(hidden_states)
262
+ hidden_states = self.fc2(hidden_states)
263
+ return hidden_states
264
+
265
+
266
+ class InternVisionEncoderLayer(nn.Module):
267
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
268
+ super().__init__()
269
+ self.embed_dim = config.hidden_size
270
+ self.intermediate_size = config.intermediate_size
271
+ self.norm_type = config.norm_type
272
+
273
+ self.attn = InternAttention(config)
274
+ self.mlp = InternMLP(config)
275
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
276
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
277
+
278
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
279
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
280
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
281
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
282
+
283
+ def forward(
284
+ self,
285
+ hidden_states: torch.Tensor,
286
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
287
+ """
288
+ Args:
289
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
290
+ """
291
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
292
+
293
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
294
+
295
+ return hidden_states
296
+
297
+
298
+ class InternVisionEncoder(nn.Module):
299
+ """
300
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
301
+ [`InternEncoderLayer`].
302
+
303
+ Args:
304
+ config (`InternConfig`):
305
+ The corresponding vision configuration for the `InternEncoder`.
306
+ """
307
+
308
+ def __init__(self, config: InternVisionConfig):
309
+ super().__init__()
310
+ self.config = config
311
+ # stochastic depth decay rule
312
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
313
+ self.layers = nn.ModuleList([
314
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
315
+ self.gradient_checkpointing = True
316
+
317
+ def forward(
318
+ self,
319
+ inputs_embeds,
320
+ output_hidden_states: Optional[bool] = None,
321
+ return_dict: Optional[bool] = None,
322
+ ) -> Union[Tuple, BaseModelOutput]:
323
+ r"""
324
+ Args:
325
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
326
+ Embedded representation of the inputs. Should be float, not int tokens.
327
+ output_hidden_states (`bool`, *optional*):
328
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
329
+ for more detail.
330
+ return_dict (`bool`, *optional*):
331
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
332
+ """
333
+ output_hidden_states = (
334
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
335
+ )
336
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
337
+
338
+ encoder_states = () if output_hidden_states else None
339
+ hidden_states = inputs_embeds
340
+
341
+ for idx, encoder_layer in enumerate(self.layers):
342
+ if output_hidden_states:
343
+ encoder_states = encoder_states + (hidden_states,)
344
+ if self.gradient_checkpointing and self.training:
345
+ layer_outputs = torch.utils.checkpoint.checkpoint(
346
+ encoder_layer,
347
+ hidden_states)
348
+ else:
349
+ layer_outputs = encoder_layer(
350
+ hidden_states,
351
+ )
352
+ hidden_states = layer_outputs
353
+
354
+ if output_hidden_states:
355
+ encoder_states = encoder_states + (hidden_states,)
356
+
357
+ if not return_dict:
358
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
359
+ return BaseModelOutput(
360
+ last_hidden_state=hidden_states, hidden_states=encoder_states
361
+ )
362
+
363
+
364
+ class InternVisionModel(PreTrainedModel):
365
+ main_input_name = 'pixel_values'
366
+ _supports_flash_attn_2 = True
367
+ config_class = InternVisionConfig
368
+ _no_split_modules = ['InternVisionEncoderLayer']
369
+
370
+ def __init__(self, config: InternVisionConfig):
371
+ super().__init__(config)
372
+ self.config = config
373
+
374
+ self.embeddings = InternVisionEmbeddings(config)
375
+ self.encoder = InternVisionEncoder(config)
376
+
377
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
378
+ pos_emb = self.embeddings.position_embedding
379
+ _, num_positions, embed_dim = pos_emb.shape
380
+ cls_emb = pos_emb[:, :1, :]
381
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
382
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
383
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
384
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
385
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
386
+ self.embeddings.image_size = new_size
387
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
388
+
389
+ def get_input_embeddings(self):
390
+ return self.embeddings
391
+
392
+ def forward(
393
+ self,
394
+ pixel_values: Optional[torch.FloatTensor] = None,
395
+ output_hidden_states: Optional[bool] = None,
396
+ return_dict: Optional[bool] = None,
397
+ pixel_embeds: Optional[torch.FloatTensor] = None,
398
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
399
+ output_hidden_states = (
400
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
401
+ )
402
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
403
+
404
+ if pixel_values is None and pixel_embeds is None:
405
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
406
+
407
+ if pixel_embeds is not None:
408
+ hidden_states = pixel_embeds
409
+ else:
410
+ if len(pixel_values.shape) == 4:
411
+ hidden_states = self.embeddings(pixel_values)
412
+ else:
413
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
414
+ encoder_outputs = self.encoder(
415
+ inputs_embeds=hidden_states,
416
+ output_hidden_states=output_hidden_states,
417
+ return_dict=return_dict,
418
+ )
419
+ last_hidden_state = encoder_outputs.last_hidden_state
420
+ pooled_output = last_hidden_state[:, 0, :]
421
+
422
+ if not return_dict:
423
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
424
+
425
+ return BaseModelOutputWithPooling(
426
+ last_hidden_state=last_hidden_state,
427
+ pooler_output=pooled_output,
428
+ hidden_states=encoder_outputs.hidden_states,
429
+ attentions=encoder_outputs.attentions,
430
+ )
modeling_internvl_chat.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import warnings
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch.utils.checkpoint
11
+ import transformers
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
15
+ Qwen2ForCausalLM)
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import ModelOutput, logging
19
+
20
+ from .configuration_internvl_chat import InternVLChatConfig
21
+ from .conversation import get_conv_template
22
+ from .modeling_intern_vit import InternVisionModel, has_flash_attn
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ def version_cmp(v1, v2, op='eq'):
28
+ import operator
29
+
30
+ from packaging import version
31
+ op_func = getattr(operator, op)
32
+ return op_func(version.parse(v1), version.parse(v2))
33
+
34
+
35
+ class InternVLChatModel(PreTrainedModel):
36
+ config_class = InternVLChatConfig
37
+ main_input_name = 'pixel_values'
38
+ base_model_prefix = 'language_model'
39
+ _supports_flash_attn_2 = True
40
+ _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
41
+
42
+ def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
43
+ super().__init__(config)
44
+
45
+ assert version_cmp(transformers.__version__, '4.37.0', 'ge')
46
+ image_size = config.force_image_size or config.vision_config.image_size
47
+ patch_size = config.vision_config.patch_size
48
+ self.patch_size = patch_size
49
+ self.select_layer = config.select_layer
50
+ self.template = config.template
51
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
52
+ self.downsample_ratio = config.downsample_ratio
53
+ self.ps_version = config.ps_version
54
+ use_flash_attn = use_flash_attn if has_flash_attn else False
55
+ config.vision_config.use_flash_attn = True if use_flash_attn else False
56
+ config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
57
+
58
+ logger.info(f'num_image_token: {self.num_image_token}')
59
+ logger.info(f'ps_version: {self.ps_version}')
60
+ if vision_model is not None:
61
+ self.vision_model = vision_model
62
+ else:
63
+ self.vision_model = InternVisionModel(config.vision_config)
64
+ if language_model is not None:
65
+ self.language_model = language_model
66
+ else:
67
+ if config.llm_config.architectures[0] == 'LlamaForCausalLM':
68
+ self.language_model = LlamaForCausalLM(config.llm_config)
69
+ elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
70
+ self.language_model = Qwen2ForCausalLM(config.llm_config)
71
+ else:
72
+ raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
73
+
74
+ vit_hidden_size = config.vision_config.hidden_size
75
+ llm_hidden_size = config.llm_config.hidden_size
76
+
77
+ self.mlp1 = nn.Sequential(
78
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
79
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
80
+ nn.GELU(),
81
+ nn.Linear(llm_hidden_size, llm_hidden_size)
82
+ )
83
+
84
+ self.img_context_token_id = None
85
+ self.conv_template = get_conv_template(self.template)
86
+ self.system_message = self.conv_template.system_message
87
+
88
+ def forward(
89
+ self,
90
+ pixel_values: torch.FloatTensor,
91
+ input_ids: torch.LongTensor = None,
92
+ attention_mask: Optional[torch.Tensor] = None,
93
+ position_ids: Optional[torch.LongTensor] = None,
94
+ image_flags: Optional[torch.LongTensor] = None,
95
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
96
+ labels: Optional[torch.LongTensor] = None,
97
+ use_cache: Optional[bool] = None,
98
+ output_attentions: Optional[bool] = None,
99
+ output_hidden_states: Optional[bool] = None,
100
+ return_dict: Optional[bool] = None,
101
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
102
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
103
+
104
+ image_flags = image_flags.squeeze(-1)
105
+ input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
106
+
107
+ vit_embeds = self.extract_feature(pixel_values)
108
+ vit_embeds = vit_embeds[image_flags == 1]
109
+ vit_batch_size = pixel_values.shape[0]
110
+
111
+ B, N, C = input_embeds.shape
112
+ input_embeds = input_embeds.reshape(B * N, C)
113
+
114
+ if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
115
+ print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
116
+
117
+ input_ids = input_ids.reshape(B * N)
118
+ selected = (input_ids == self.img_context_token_id)
119
+ try:
120
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
121
+ except Exception as e:
122
+ vit_embeds = vit_embeds.reshape(-1, C)
123
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
124
+ f'vit_embeds.shape={vit_embeds.shape}')
125
+ n_token = selected.sum()
126
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
127
+
128
+ input_embeds = input_embeds.reshape(B, N, C)
129
+
130
+ outputs = self.language_model(
131
+ inputs_embeds=input_embeds,
132
+ attention_mask=attention_mask,
133
+ position_ids=position_ids,
134
+ past_key_values=past_key_values,
135
+ use_cache=use_cache,
136
+ output_attentions=output_attentions,
137
+ output_hidden_states=output_hidden_states,
138
+ return_dict=return_dict,
139
+ )
140
+ logits = outputs.logits
141
+
142
+ loss = None
143
+ if labels is not None:
144
+ # Shift so that tokens < n predict n
145
+ shift_logits = logits[..., :-1, :].contiguous()
146
+ shift_labels = labels[..., 1:].contiguous()
147
+ # Flatten the tokens
148
+ loss_fct = CrossEntropyLoss()
149
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
150
+ shift_labels = shift_labels.view(-1)
151
+ # Enable model parallelism
152
+ shift_labels = shift_labels.to(shift_logits.device)
153
+ loss = loss_fct(shift_logits, shift_labels)
154
+
155
+ if not return_dict:
156
+ output = (logits,) + outputs[1:]
157
+ return (loss,) + output if loss is not None else output
158
+
159
+ return CausalLMOutputWithPast(
160
+ loss=loss,
161
+ logits=logits,
162
+ past_key_values=outputs.past_key_values,
163
+ hidden_states=outputs.hidden_states,
164
+ attentions=outputs.attentions,
165
+ )
166
+
167
+ def pixel_shuffle(self, x, scale_factor=0.5):
168
+ n, w, h, c = x.size()
169
+ # N, W, H, C --> N, W, H * scale, C // scale
170
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
171
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
172
+ x = x.permute(0, 2, 1, 3).contiguous()
173
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
174
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
175
+ int(c / (scale_factor * scale_factor)))
176
+ if self.ps_version == 'v1':
177
+ warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
178
+ 'which results in a transposed image.')
179
+ else:
180
+ x = x.permute(0, 2, 1, 3).contiguous()
181
+ return x
182
+
183
+ def extract_feature(self, pixel_values):
184
+ if self.select_layer == -1:
185
+ vit_embeds = self.vision_model(
186
+ pixel_values=pixel_values,
187
+ output_hidden_states=False,
188
+ return_dict=True).last_hidden_state
189
+ else:
190
+ vit_embeds = self.vision_model(
191
+ pixel_values=pixel_values,
192
+ output_hidden_states=True,
193
+ return_dict=True).hidden_states[self.select_layer]
194
+ vit_embeds = vit_embeds[:, 1:, :]
195
+
196
+ h = w = int(vit_embeds.shape[1] ** 0.5)
197
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
198
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
199
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
200
+ vit_embeds = self.mlp1(vit_embeds)
201
+ return vit_embeds
202
+
203
+ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
204
+ history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
205
+ IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
206
+ if history is not None or return_history:
207
+ print('Now multi-turn chat is not supported in batch_chat.')
208
+ raise NotImplementedError
209
+
210
+ if image_counts is not None:
211
+ num_patches_list = image_counts
212
+ print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
213
+
214
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
215
+ self.img_context_token_id = img_context_token_id
216
+
217
+ if verbose and pixel_values is not None:
218
+ image_bs = pixel_values.shape[0]
219
+ print(f'dynamic ViT batch size: {image_bs}')
220
+
221
+ queries = []
222
+ for idx, num_patches in enumerate(num_patches_list):
223
+ question = questions[idx]
224
+ if pixel_values is not None and '<image>' not in question:
225
+ question = '<image>\n' + question
226
+ template = get_conv_template(self.template)
227
+ template.system_message = self.system_message
228
+ template.append_message(template.roles[0], question)
229
+ template.append_message(template.roles[1], None)
230
+ query = template.get_prompt()
231
+
232
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
233
+ query = query.replace('<image>', image_tokens, 1)
234
+ queries.append(query)
235
+
236
+ tokenizer.padding_side = 'left'
237
+ model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
238
+ input_ids = model_inputs['input_ids'].to(self.device)
239
+ attention_mask = model_inputs['attention_mask'].to(self.device)
240
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
241
+ generation_config['eos_token_id'] = eos_token_id
242
+ generation_output = self.generate(
243
+ pixel_values=pixel_values,
244
+ input_ids=input_ids,
245
+ attention_mask=attention_mask,
246
+ **generation_config
247
+ )
248
+ responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
249
+ responses = [response.split(template.sep.strip())[0].strip() for response in responses]
250
+ return responses
251
+
252
+ def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
253
+ num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
254
+ verbose=False):
255
+
256
+ if history is None and pixel_values is not None and '<image>' not in question:
257
+ question = '<image>\n' + question
258
+
259
+ if num_patches_list is None:
260
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
261
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
262
+
263
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
264
+ self.img_context_token_id = img_context_token_id
265
+
266
+ template = get_conv_template(self.template)
267
+ template.system_message = self.system_message
268
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
269
+
270
+ history = [] if history is None else history
271
+ for (old_question, old_answer) in history:
272
+ template.append_message(template.roles[0], old_question)
273
+ template.append_message(template.roles[1], old_answer)
274
+ template.append_message(template.roles[0], question)
275
+ template.append_message(template.roles[1], None)
276
+ query = template.get_prompt()
277
+
278
+ if verbose and pixel_values is not None:
279
+ image_bs = pixel_values.shape[0]
280
+ print(f'dynamic ViT batch size: {image_bs}')
281
+
282
+ for num_patches in num_patches_list:
283
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
284
+ query = query.replace('<image>', image_tokens, 1)
285
+
286
+ model_inputs = tokenizer(query, return_tensors='pt')
287
+ input_ids = model_inputs['input_ids'].to(self.device)
288
+ attention_mask = model_inputs['attention_mask'].to(self.device)
289
+ generation_config['eos_token_id'] = eos_token_id
290
+ generation_output = self.generate(
291
+ pixel_values=pixel_values,
292
+ input_ids=input_ids,
293
+ attention_mask=attention_mask,
294
+ **generation_config
295
+ )
296
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
297
+ response = response.split(template.sep.strip())[0].strip()
298
+ history.append((question, response))
299
+ if return_history:
300
+ return response, history
301
+ else:
302
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
303
+ query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
304
+ if verbose:
305
+ print(query_to_print, response)
306
+ return response
307
+
308
+ @torch.no_grad()
309
+ def generate(
310
+ self,
311
+ pixel_values: Optional[torch.FloatTensor] = None,
312
+ input_ids: Optional[torch.FloatTensor] = None,
313
+ attention_mask: Optional[torch.LongTensor] = None,
314
+ visual_features: Optional[torch.FloatTensor] = None,
315
+ generation_config: Optional[GenerationConfig] = None,
316
+ output_hidden_states: Optional[bool] = None,
317
+ **generate_kwargs,
318
+ ) -> torch.LongTensor:
319
+
320
+ assert self.img_context_token_id is not None
321
+ if pixel_values is not None:
322
+ if visual_features is not None:
323
+ vit_embeds = visual_features
324
+ else:
325
+ vit_embeds = self.extract_feature(pixel_values)
326
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
327
+ B, N, C = input_embeds.shape
328
+ input_embeds = input_embeds.reshape(B * N, C)
329
+
330
+ input_ids = input_ids.reshape(B * N)
331
+ selected = (input_ids == self.img_context_token_id)
332
+ assert selected.sum() != 0
333
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
334
+
335
+ input_embeds = input_embeds.reshape(B, N, C)
336
+ else:
337
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
338
+
339
+ outputs = self.language_model.generate(
340
+ inputs_embeds=input_embeds,
341
+ attention_mask=attention_mask,
342
+ generation_config=generation_config,
343
+ output_hidden_states=output_hidden_states,
344
+ use_cache=True,
345
+ **generate_kwargs,
346
+ )
347
+
348
+ return outputs