chenkq commited on
Commit
2d9f231
·
1 Parent(s): 2ecffe1

Update modeling_cogvlm.py: remove the dependence of triton

Browse files
Files changed (1) hide show
  1. modeling_cogvlm.py +57 -6
modeling_cogvlm.py CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, A
5
  import math
6
  import torch
7
  from torch import nn
 
8
  from torch.nn import CrossEntropyLoss
9
  from torchvision import transforms
10
  from einops import rearrange
@@ -15,7 +16,6 @@ from transformers.activations import ACT2FN
15
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
 
17
  from .configuration_cogvlm import CogVLMConfig
18
- from .util import FastRotaryEmbedding
19
  from .visual import EVA2CLIPModel
20
 
21
  if TYPE_CHECKING:
@@ -144,6 +144,57 @@ def attention_fn(
144
  return context_layer
145
 
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  class VisionExpertAttention(nn.Module):
148
  def __init__(self, config):
149
  super().__init__()
@@ -153,8 +204,7 @@ class VisionExpertAttention(nn.Module):
153
  self.head_dim = self.hidden_size // self.num_heads
154
  self.max_position_embeddings = config.max_position_embeddings
155
 
156
- # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
157
- self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
158
  self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
159
  self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
160
  self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
@@ -193,8 +243,8 @@ class VisionExpertAttention(nn.Module):
193
  kv_seq_len = key_states.shape[-2]
194
  if past_key_value is not None:
195
  kv_seq_len += past_key_value[0].shape[-2]
196
-
197
- query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
198
 
199
  if past_key_value is not None:
200
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -706,7 +756,8 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
706
  # update token_type_ids with last value
707
  if "token_type_ids" in model_kwargs:
708
  token_type_ids = model_kwargs["token_type_ids"]
709
- new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
 
710
  model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
711
 
712
  if not is_encoder_decoder:
 
5
  import math
6
  import torch
7
  from torch import nn
8
+ from torch.nn import functional as F
9
  from torch.nn import CrossEntropyLoss
10
  from torchvision import transforms
11
  from einops import rearrange
 
16
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
17
 
18
  from .configuration_cogvlm import CogVLMConfig
 
19
  from .visual import EVA2CLIPModel
20
 
21
  if TYPE_CHECKING:
 
144
  return context_layer
145
 
146
 
147
+ class RotaryEmbedding(torch.nn.Module):
148
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
149
+ super().__init__()
150
+
151
+ self.dim = dim
152
+ self.max_position_embeddings = max_position_embeddings
153
+ self.base = base
154
+ inv_freq = self._compute_inv_freq(device)
155
+ self.register_buffer("inv_freq", inv_freq)
156
+ self.max_seq_len_cached = 0
157
+
158
+ def _compute_inv_freq(self, device=None):
159
+ return 1.0 / (
160
+ self.base
161
+ ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
162
+ )
163
+
164
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
165
+ self.max_seq_len_cached = seq_len
166
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
167
+
168
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
169
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
170
+ emb = torch.cat((freqs, freqs), dim=-1)
171
+ self.register_buffer("cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False)
172
+ self.register_buffer("sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False)
173
+
174
+ def forward(self, x, seq_len):
175
+ # x: [bs, num_attention_heads, seq_len, head_size]
176
+ if seq_len > self.max_seq_len_cached:
177
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
178
+
179
+ return (
180
+ self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
181
+ self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
182
+ )
183
+
184
+
185
+ def rotate_half(x):
186
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
187
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
188
+
189
+
190
+ def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
191
+ # batch_size, num_head, seq_len, hidden_size
192
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
193
+ F.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
194
+ q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
195
+ return q, k
196
+
197
+
198
  class VisionExpertAttention(nn.Module):
199
  def __init__(self, config):
200
  super().__init__()
 
204
  self.head_dim = self.hidden_size // self.num_heads
205
  self.max_position_embeddings = config.max_position_embeddings
206
 
207
+ self.rotary_emb = RotaryEmbedding(self.head_dim)
 
208
  self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
209
  self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
210
  self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
 
243
  kv_seq_len = key_states.shape[-2]
244
  if past_key_value is not None:
245
  kv_seq_len += past_key_value[0].shape[-2]
246
+ cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
247
+ query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
248
 
249
  if past_key_value is not None:
250
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
 
756
  # update token_type_ids with last value
757
  if "token_type_ids" in model_kwargs:
758
  token_type_ids = model_kwargs["token_type_ids"]
759
+ new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
760
+ device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
761
  model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
762
 
763
  if not is_encoder_decoder: