Bo1015 committed on
Commit 1a3263b · verified · 1 Parent(s): b017cc5

Delete modeling_xtrimopglm.py

Files changed (1)
  1. modeling_xtrimopglm.py +0 -1566
modeling_xtrimopglm.py DELETED
@@ -1,1566 +0,0 @@
1
- """ PyTorch xTrimoPGLM model. """
2
-
3
- import math
4
- import copy
5
- import warnings
6
- import re
7
- import sys
8
- import os
9
- import pathlib
10
- import time
11
- import random
12
- import numpy as np
13
- from tqdm.auto import tqdm
14
-
15
- import torch
16
- import torch.utils.checkpoint
17
- import torch.nn.functional as F
18
- from torch import nn
19
- from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
20
- from torch.nn.utils import skip_init
21
- from typing import Optional, Tuple, Union, List, Callable, Dict, Any
22
- from copy import deepcopy
23
- from collections import namedtuple
24
-
25
- from transformers.modeling_outputs import (
26
- BaseModelOutputWithPast,
27
- MaskedLMOutput,
28
- CausalLMOutputWithPast,
29
- SequenceClassifierOutput,
30
- TokenClassifierOutput
31
- )
32
- from transformers import PreTrainedModel
33
- from transformers.utils import logging
34
- from transformers.generation.logits_process import LogitsProcessor
35
- from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
36
-
37
- from .configuration_xtrimopglm import xTrimoPGLMConfig
38
- from .quantization import quantize
39
-
40
- # flags required to enable jit fusion kernels
41
-
42
- if sys.platform != 'darwin':
43
- torch._C._jit_set_profiling_mode(False)
44
- torch._C._jit_set_profiling_executor(False)
45
- torch._C._jit_override_can_fuse_on_cpu(True)
46
- torch._C._jit_override_can_fuse_on_gpu(True)
47
-
48
- logger = logging.get_logger(__name__)
49
-
50
- _CHECKPOINT_FOR_DOC = "BioMap/xtrimopglm-100b-int4"
51
- _CONFIG_FOR_DOC = "xTrimoPGLMConfig"
52
- DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])
53
-
54
- def default_init(cls, *args, **kwargs):
55
- return cls(*args, **kwargs)
56
-
57
-
58
- def get_deepnorm_coefficients(config: xTrimoPGLMConfig):
59
- """
60
- DeepNorm coefficients from : https://kexue.fm/archives/8978
61
- """
62
- num_layers = config.num_layers
63
- return DeepNormCoefficients(alpha=(2 * num_layers) ** 0.5, beta=(2 * num_layers) ** -0.5)
64
-
65
-
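As a quick, illustrative check of the DeepNorm scaling defined above (the layer count below is hypothetical, not read from any config):

    from collections import namedtuple

    DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])

    def deepnorm_coefficients(num_layers: int) -> DeepNormCoefficients:
        # alpha scales the residual branch, beta the initialization:
        # alpha = (2N) ** 0.5, beta = (2N) ** -0.5 for an N-layer stack
        return DeepNormCoefficients(alpha=(2 * num_layers) ** 0.5,
                                    beta=(2 * num_layers) ** -0.5)

    coeff = deepnorm_coefficients(num_layers=72)  # hypothetical depth
    print(coeff.alpha, coeff.beta)                # 12.0 0.0833...; note alpha * beta == 1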
66
- class InvalidScoreLogitsProcessor(LogitsProcessor):
67
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
68
- if torch.isnan(scores).any() or torch.isinf(scores).any():
69
- scores.zero_()
70
- scores[..., 5] = 5e4
71
- return scores
72
-
73
-
74
- def split_tensor_along_last_dim(
75
- tensor: torch.Tensor,
76
- num_partitions: int,
77
- contiguous_split_chunks: bool = False,
78
- ) -> List[torch.Tensor]:
79
- """Split a tensor along its last dimension.
80
-
81
- Arguments:
82
- tensor: input tensor.
83
- num_partitions: number of partitions to split the tensor
84
- contiguous_split_chunks: If True, make each chunk contiguous
85
- in memory.
86
-
87
- Returns:
88
- A list of Tensors
89
- """
90
- # Get the size and dimension.
91
- last_dim = tensor.dim() - 1
92
- last_dim_size = tensor.size()[last_dim] // num_partitions
93
- # Split.
94
- tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
95
- # Note: torch.split does not create contiguous tensors by default.
96
- if contiguous_split_chunks:
97
- return tuple(chunk.contiguous() for chunk in tensor_list)
98
-
99
- return tensor_list
100
-
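A brief usage sketch of split_tensor_along_last_dim (the shapes are made up for illustration):

    import torch

    x = torch.randn(5, 2, 12)                    # e.g. [sq, b, 3 * hn] with hn = 4
    q, k, v = split_tensor_along_last_dim(x, 3)  # three chunks of shape [5, 2, 4]
    assert q.shape == k.shape == v.shape == (5, 2, 4)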
101
- class RotaryEmbedding(torch.nn.Module):
102
-
103
- def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
104
- super().__init__()
105
- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)).to(precision)
106
- self.dim = dim
107
- self.base = base
108
- self.learnable = learnable
109
- if learnable:
110
- self.inv_freq = torch.nn.Parameter(inv_freq)
111
- self.max_seq_len_cached = None
112
- else:
113
- self.register_buffer('inv_freq', inv_freq)
114
- self.max_seq_len_cached = None
115
- self.cos_cached = None
116
- self.sin_cached = None
117
- self.precision = precision
118
-
119
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
120
- if f'{prefix}inv_freq' in state_dict:
121
- super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
122
- else:
123
- self.inv_freq.copy_(1. / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(self.precision))
124
-
125
- def forward(self, x, seq_dim=1, seq_len=None):
126
- if seq_len is None:
127
- seq_len = x.shape[seq_dim]
128
- if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
129
- self.max_seq_len_cached = None if self.learnable else seq_len
130
- t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
131
- freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(x.device))
132
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
133
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
134
- if self.precision == torch.bfloat16 or self.precision == torch.half:
135
- emb = emb.float()
136
- # [sx, 1 (b * np), hn]
137
- cos_cached = emb.cos()[:, None, :]
138
- sin_cached = emb.sin()[:, None, :]
139
- if self.precision == torch.bfloat16:
140
- cos_cached = cos_cached.bfloat16()
141
- sin_cached = sin_cached.bfloat16()
142
- elif self.precision == torch.half:
143
- cos_cached = cos_cached.half()
144
- sin_cached = sin_cached.half()
145
- if self.learnable:
146
- return cos_cached, sin_cached
147
- self.cos_cached, self.sin_cached = cos_cached, sin_cached
148
- return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
149
-
150
- def rotate_half(x):
151
- x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
152
- return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
153
-
154
- def assert_dim_check(tensor, ndim=None, shape=None):
155
- if ndim is not None:
156
-         assert tensor.ndim == ndim, f"Expect tensor.ndim={ndim}, but got tensor.shape={tensor.shape}"
157
- if shape is not None:
158
-         assert list(tensor.shape) == list(shape), f"Expect tensor.shape={shape}, but got tensor.shape={tensor.shape}"
159
-
160
- def apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_id): # jitting fails with bf16
161
- # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
162
- cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
163
- F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
164
- q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
165
- return q, k
166
-
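A self-contained sketch of how the rotary helpers above fit together (all sizes are hypothetical; RotaryEmbedding, rotate_half and apply_rotary_pos_emb_index_torch are the definitions from this file):

    import torch

    sq, b, heads, hn = 8, 2, 4, 16                 # seq_len, batch, num_heads, head_dim
    q = torch.randn(sq, b, heads, hn)
    k = torch.randn(sq, b, heads, hn)

    rot = RotaryEmbedding(hn, precision=torch.float32)
    cos, sin = rot(q, seq_len=sq)                  # each [sq, 1, hn]

    position_ids = torch.arange(sq).unsqueeze(1).expand(sq, b)   # [sq, b]
    q_rot, k_rot = apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_ids)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape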
167
- class RMSNorm(torch.nn.Module):
168
- def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
169
- super().__init__()
170
- self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
171
- self.eps = eps
172
-
173
- def forward(self, hidden_states: torch.Tensor):
174
- input_dtype = hidden_states.dtype
175
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
176
- hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
177
-
178
- return (self.weight * hidden_states).to(input_dtype)
179
-
180
- class CoreAttention(torch.nn.Module):
181
- def __init__(self, config: xTrimoPGLMConfig, layer_number):
182
- super(CoreAttention, self).__init__()
183
-
184
- self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
185
- self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
186
- if self.apply_query_key_layer_scaling:
187
- self.attention_softmax_in_fp32 = True
188
- self.layer_number = max(1, layer_number)
189
-
190
- projection_size = config.kv_channels * config.num_attention_heads
191
-
192
- # Per attention head and per partition values.
193
- self.hidden_size_per_partition = projection_size
194
- self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
195
- self.num_attention_heads_per_partition = config.num_attention_heads
196
-
197
- coeff = None
198
- self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
199
- if self.apply_query_key_layer_scaling:
200
- coeff = self.layer_number
201
- self.norm_factor *= coeff
202
- self.coeff = coeff
203
-
204
- self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
205
-
206
- self.is_causal = config.is_causal
207
- self.use_pytorch_sdpa = config.use_pytorch_sdpa
208
-
209
- def forward(self, query_layer, key_layer, value_layer, attention_mask):
210
- # query_layer, key_layer, value_layer: [seq_len, batch_size, num_heads, head_dim]
211
- # import pdb; pdb.set_trace();
212
- pytorch_major_version = int(torch.__version__.split('.')[0])
213
- # assert pytorch_major_version >= 2, f"Expect PyTorch version > 2.0"
214
- if pytorch_major_version >= 2 and self.use_pytorch_sdpa:
215
- dropout_p = self.attention_dropout.p if self.training else 0
216
- # [seq_len, batch_size, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
217
- query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
218
- # import pdb; pdb.set_trace();
219
- if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
220
- # context_layer: [batch_size, num_heads, seq_len, head_dim]
221
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=self.is_causal, dropout_p=dropout_p)
222
- else:
223
- if (attention_mask is not None) and (attention_mask.dtype == torch.bool):
224
-                 attention_mask = attention_mask.logical_not()  # do NOT negate the mask in place
225
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, dropout_p=dropout_p)
226
- # [batch_size, num_heads, seq_len, head_dim] -> [seq_len, batch_size, num_heads, head_dim]
227
- context_layer = context_layer.permute(2, 0, 1, 3)
228
- # [seq_len, batch_size, 2560]
229
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
230
- context_layer = context_layer.reshape(*new_context_layer_shape)
231
- else:
232
- # Raw attention scores
233
-
234
- # [b, np, sq, sk]
235
- output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
236
-
237
- # [sq, b, np, hn] -> [sq, b * np, hn]
238
- query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
239
- # [sk, b, np, hn] -> [sk, b * np, hn]
240
- key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
241
-
242
-             # preallocating input tensor: [b * np, sq, sk]
243
- matmul_input_buffer = torch.empty(
244
- output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
245
- device=query_layer.device
246
- )
247
-
248
- # Raw attention scores. [b * np, sq, sk]
249
- matmul_result = torch.baddbmm(
250
- matmul_input_buffer,
251
- query_layer.transpose(0, 1), # [b * np, sq, hn]
252
- key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
253
- beta=0.0,
254
- alpha=(1.0 / self.norm_factor),
255
- )
256
-
257
- # change view to [b, np, sq, sk]
258
- attention_scores = matmul_result.view(*output_size)
259
-
260
- # ===========================
261
- # Attention probs and dropout
262
- # ===========================
263
-
264
- # attention scores and attention mask [b, np, sq, sk]
265
- if self.attention_softmax_in_fp32:
266
- attention_scores = attention_scores.float()
267
- if self.coeff is not None:
268
- attention_scores = attention_scores * self.coeff
269
- if self.is_causal and attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
270
- attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
271
- device=attention_scores.device, dtype=torch.bool)
272
- attention_mask.tril_()
273
- attention_mask = ~attention_mask
274
- if attention_mask is not None:
275
- attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
276
- attention_probs = F.softmax(attention_scores, dim=-1)
277
- attention_probs = attention_probs.type_as(value_layer)
278
-
279
- # This is actually dropping out entire tokens to attend to, which might
280
- # seem a bit unusual, but is taken from the original Transformer paper.
281
- attention_probs = self.attention_dropout(attention_probs)
282
- # =========================
283
- # Context layer. [sq, b, hp]
284
- # =========================
285
-
286
- # value_layer -> context layer.
287
- # [sk, b, np, hn] --> [b, np, sq, hn]
288
-
289
- # context layer shape: [b, np, sq, hn]
290
- output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
291
- # change view [sk, b * np, hn]
292
- value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
293
- # change view [b * np, sq, sk]
294
- attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
295
- # matmul: [b * np, sq, hn]
296
- context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
297
- # change view [b, np, sq, hn]
298
- context_layer = context_layer.view(*output_size)
299
- # [b, np, sq, hn] --> [sq, b, np, hn]
300
- context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
301
- # [sq, b, np, hn] --> [sq, b, hp]
302
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
303
- context_layer = context_layer.view(*new_context_layer_shape)
304
-
305
- return context_layer
306
-
307
-
308
- class SelfAttention(torch.nn.Module):
309
- """Parallel self-attention layer abstract class.
310
-
311
- Self-attention layer takes input with size [s, b, h]
312
- and returns output of the same size.
313
- """
314
-
315
- def __init__(self, config: xTrimoPGLMConfig, layer_number, device=None):
316
- super(SelfAttention, self).__init__()
317
- self.layer_number = max(1, layer_number)
318
-
319
- self.projection_size = config.kv_channels * config.num_attention_heads
320
-
321
- # Per attention head and per partition values.
322
- self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
323
- self.num_attention_heads_per_partition = config.num_attention_heads
324
-
325
- self.multi_query_attention = config.multi_query_attention
326
- self.qkv_hidden_size = 3 * self.projection_size
327
- if self.multi_query_attention:
328
- self.num_multi_query_groups_per_partition = config.multi_query_group_num
329
- self.qkv_hidden_size = (
330
- self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
331
- )
332
- self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
333
- bias=config.add_bias_linear or config.add_qkv_bias,
334
- device=device, **_config_to_kwargs(config)
335
- )
336
-
337
- self.core_attention = CoreAttention(config, self.layer_number)
338
-
339
- # Output.
340
- self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config))
341
-
342
- self.rotary_embedding_2d = config.rotary_embedding_2d
343
- # dim, base=10000, precision=torch.half, learnable=False
344
- self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head // 2 if self.rotary_embedding_2d else self.hidden_size_per_attention_head,
345
- base=10000, precision=config.torch_dtype, learnable=False)
346
-
347
-
348
- def forward(
349
- self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True
350
- ):
351
- # hidden_states: [sq, b, h]
352
-
353
- # =================================================
354
- # Pre-allocate memory for key-values for inference.
355
- # =================================================
356
- # =====================
357
- # Query, Key, and Value
358
- # =====================
359
-
360
- # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
361
- mixed_x_layer = self.query_key_value(hidden_states)
362
-
363
- if self.multi_query_attention:
364
- (query_layer, key_layer, value_layer) = mixed_x_layer.split(
365
- [
366
- self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
367
- self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
368
- self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
369
- ],
370
- dim=-1,
371
- )
372
- query_layer = query_layer.view(
373
- query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
374
- )
375
- key_layer = key_layer.view(
376
- key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
377
- )
378
- value_layer = value_layer.view(
379
- value_layer.size()[:-1]
380
- + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
381
- )
382
- else:
383
- new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head)
384
- mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
385
- # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
386
- (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
387
-
388
- # apply relative positional encoding (rotary embedding)
389
- if position_ids is not None: # [seq_len, 2, batch_size, 32, 2]
390
-
391
- if self.rotary_embedding_2d:
392
- q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) # 32
393
- k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
394
- # import pdb; pdb.set_trace();
395
- cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) # 32
396
- position_ids, block_position_ids = \
397
- position_ids[:, 0, :].transpose(0, 1).contiguous(), \
398
- position_ids[:, 1, :].transpose(0, 1).contiguous()
399
- q1, k1 = apply_rotary_pos_emb_index_torch(q1, k1, cos, sin, position_ids)
400
- q2, k2 = apply_rotary_pos_emb_index_torch(q2, k2, cos, sin, block_position_ids)
401
- query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
402
- key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
403
- else:
404
- # [b, sq] -> [sq, b]
405
- position_ids = position_ids.transpose(0, 1)
406
- cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
407
- query_layer, key_layer = apply_rotary_pos_emb_index_torch(query_layer, key_layer, cos, sin, position_ids)
408
-
409
- # adjust key and value for inference
410
- if kv_cache is not None:
411
- cache_k, cache_v = kv_cache
412
- key_layer = torch.cat((cache_k, key_layer), dim=0)
413
- value_layer = torch.cat((cache_v, value_layer), dim=0)
414
- if use_cache:
415
- kv_cache = (key_layer, value_layer)
416
- else:
417
- kv_cache = None
418
-
419
- if self.multi_query_attention:
420
- key_layer = key_layer.unsqueeze(-2)
421
- key_layer = key_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
422
- key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
423
- value_layer = value_layer.unsqueeze(-2)
424
- value_layer = value_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
425
- value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
426
-
427
- # ==================================
428
- # core attention computation
429
- # ==================================
430
-
431
- context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # context_layer: [seq_len, batch_size, num_heads*head_dim]
432
- output = self.dense(context_layer)
433
- # =================
434
- # Output. [sq, b, h]
435
- # =================
436
-
437
- # output = context_layer @ self.dense.weight.T + self.dense.bias
438
- return output, kv_cache
439
-
440
-
441
- def _config_to_kwargs(args):
442
- common_kwargs = {
443
- "dtype": args.torch_dtype,
444
- }
445
- return common_kwargs
446
-
447
-
448
- class MLP(torch.nn.Module):
449
- """MLP.
450
-
451
- MLP will take the input with h hidden state, project it to 4*h
452
- hidden dimension, perform nonlinear transformation, and project the
453
- state back into h hidden dimension.
454
- """
455
-
456
- def __init__(self, config: xTrimoPGLMConfig, device=None):
457
- super(MLP, self).__init__()
458
-
459
- self.add_bias = config.add_bias_linear
460
- self.moe = config.moe
461
- self.num_experts = config.num_experts
462
- self.experts_per_token = config.experts_per_token # 2
463
-
464
- # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
465
- self.dense_h_to_4h = nn.Linear(
466
- config.hidden_size,
467
- config.ffn_hidden_size * 2,
468
- bias=self.add_bias,
469
- device=device,
470
- **_config_to_kwargs(config)
471
- )
472
-
473
- def swiglu(x):
474
- x = torch.chunk(x, 2, dim=-1)
475
- return x[0] * F.silu(x[1])
476
-
477
- def geglu(x):
478
- x = torch.chunk(x, 2, dim=-1)
479
- return x[0] * F.gelu(x[1])
480
-
481
- if config.glu_activation == 'geglu':
482
- self.activation_func = geglu
483
- elif config.glu_activation == 'swiglu':
484
- self.activation_func = swiglu
485
- else:
486
-             raise RuntimeError(f"Unsupported glu_activation: {config.glu_activation}")
487
-
488
- # Project back to h.
489
- self.dense_4h_to_h = nn.Linear(
490
- config.ffn_hidden_size,
491
- config.hidden_size,
492
- bias=self.add_bias,
493
- device=device,
494
- **_config_to_kwargs(config)
495
- )
496
-
497
- if self.moe:
498
- assert self.num_experts > 1
499
- del self.dense_h_to_4h
500
- del self.dense_4h_to_h
501
- self.router = nn.Linear(
502
- config.hidden_size,
503
- config.num_experts,
504
- bias=False,
505
- device=device,
506
- dtype=torch.float32
507
- )
508
- for i in range(0, self.num_experts):
509
- self.register_module(f"dense_h_to_4h_{i}", nn.Linear(
510
- config.hidden_size,
511
- config.ffn_hidden_size * 2,
512
- bias=self.add_bias,
513
- device=device,
514
- **_config_to_kwargs(config)
515
- ))
516
- self.register_module(f"dense_4h_to_h_{i}", nn.Linear(
517
- config.ffn_hidden_size,
518
- config.hidden_size,
519
- bias=self.add_bias,
520
- device=device,
521
- **_config_to_kwargs(config)
522
- ))
523
-
524
- def moe_forward(self, hidden_states, expert_idx):
525
- intermediate_parallel = getattr(self, f"dense_h_to_4h_{expert_idx}")(hidden_states)
526
- intermediate_parallel = self.activation_func(intermediate_parallel)
527
- output = getattr(self, f"dense_4h_to_h_{expert_idx}")(intermediate_parallel)
528
- return output
529
-
530
- def forward(self, hidden_states):
531
- if self.moe:
532
- # import pdb; pdb.set_trace();
533
- s, b, n = hidden_states.shape
534
- dtype = hidden_states.dtype
535
- hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
536
- route = self.router(hidden_states).to(dtype)
537
-
538
- weights, selected_experts = torch.topk(route, self.experts_per_token)
539
- weights = F.softmax(weights, dim=1, dtype=torch.float).to(hidden_states.dtype)
540
- output = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
541
- for expert_idx in range(self.num_experts):
542
- batch_idx, nth_expert = torch.where(selected_experts == expert_idx)
543
- if nth_expert.shape[0] == 0:
544
- continue
545
- cur_out = self.moe_forward(hidden_states[batch_idx], expert_idx)
546
- output[batch_idx] += weights[batch_idx, nth_expert, None] * cur_out
547
- output = output.reshape(s, b, n)
548
- else:
549
- # [s, b, 4hp]
550
- intermediate_parallel = self.dense_h_to_4h(hidden_states)
551
- intermediate_parallel = self.activation_func(intermediate_parallel)
552
- # [s, b, h]
553
- output = self.dense_4h_to_h(intermediate_parallel)
554
- return output
555
-
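To make the top-k expert routing in the MoE branch above easier to follow, here is a minimal standalone sketch of the same weighting logic (the expert count, shapes, and plain Linear experts are invented for illustration; this is not the model's own code path):

    import torch
    import torch.nn.functional as F

    tokens, hidden, num_experts, experts_per_token = 6, 8, 4, 2
    hidden_states = torch.randn(tokens, hidden)
    route = torch.randn(tokens, num_experts)                  # router logits

    weights, selected = torch.topk(route, experts_per_token)  # both [tokens, k]
    weights = F.softmax(weights, dim=1)                       # renormalize over the k chosen experts

    experts = [torch.nn.Linear(hidden, hidden) for _ in range(num_experts)]
    output = torch.zeros_like(hidden_states)
    for expert_idx in range(num_experts):
        token_idx, nth = torch.where(selected == expert_idx)
        if token_idx.numel() == 0:
            continue
        expert_out = experts[expert_idx](hidden_states[token_idx])
        output[token_idx] += weights[token_idx, nth, None] * expert_out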
556
- class xTrimoPGLMBlock(torch.nn.Module):
557
- """A single transformer layer.
558
-
559
- Transformer layer takes input with size [s, b, h] and returns an
560
- output of the same size.
561
- """
562
-
563
- def __init__(self, config: xTrimoPGLMConfig, layer_number, device=None):
564
- super(xTrimoPGLMBlock, self).__init__()
565
- self.layer_number = layer_number
566
-
567
- self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
568
-
569
- self.fp32_residual_connection = config.fp32_residual_connection
570
-
571
- LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
572
- # Layernorm on the input data.
573
- self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
574
-
575
- # Self attention.
576
- self.self_attention = SelfAttention(config, layer_number, device=device)
577
- self.hidden_dropout = config.hidden_dropout
578
-
579
- # Layernorm on the attention output
580
- self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
581
-
582
- # MLP
583
- self.mlp = MLP(config, device=device)
584
-
585
- self.deepnorm_coeff = get_deepnorm_coefficients(config) if config.deepnorm else None
586
-
587
- def forward(
588
- self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True,
589
- ):
590
- # hidden_states: [s, b, h]
591
- # Layer norm at the beginning of the transformer layer.
592
- layernorm_output = self.input_layernorm(hidden_states)
593
- # Self attention.
594
- attention_output, kv_cache = self.self_attention(
595
- layernorm_output,
596
- attention_mask,
597
- position_ids, # [batch_size, 2, seq_len, 32, 2]
598
- kv_cache=kv_cache,
599
- use_cache=use_cache
600
- )
601
-
602
- # Residual connection.
603
- if self.apply_residual_connection_post_layernorm:
604
- residual = layernorm_output
605
- else:
606
- residual = hidden_states
607
-
608
- layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
609
- if self.deepnorm_coeff is not None:
610
- layernorm_input = residual*self.deepnorm_coeff.alpha + layernorm_input
611
- else:
612
- layernorm_input = residual + layernorm_input
613
-
614
- # Layer norm post the self attention.
615
- layernorm_output = self.post_attention_layernorm(layernorm_input)
616
-
617
- # MLP.
618
- mlp_output = self.mlp(layernorm_output)
619
-
620
- # Second residual connection.
621
- if self.apply_residual_connection_post_layernorm:
622
- residual = layernorm_output
623
- else:
624
- residual = layernorm_input
625
-
626
- output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
627
- if self.deepnorm_coeff is not None:
628
- output = residual*self.deepnorm_coeff.alpha + output
629
- else:
630
- #print(f"2 self.deepnorm_coeff is None")
631
- output = residual + output
632
-
633
- return output, kv_cache
634
-
635
-
636
- class xTrimoPGLMTransformer(torch.nn.Module):
637
- """Transformer class."""
638
-
639
- def __init__(self, config: xTrimoPGLMConfig, device=None):
640
- super(xTrimoPGLMTransformer, self).__init__()
641
-
642
- self.fp32_residual_connection = config.fp32_residual_connection
643
- self.post_layer_norm = config.post_layer_norm
644
-
645
- # Number of layers.
646
- self.num_layers = config.num_layers
647
-
648
- # Transformer layers.
649
- def build_layer(layer_number):
650
- return xTrimoPGLMBlock(config, layer_number, device=device)
651
-
652
- self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
653
-
654
- if self.post_layer_norm:
655
- LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
656
- # Final layer norm before output.
657
- self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
658
-
659
- self.gradient_checkpointing = False
660
-
661
- def _get_layer(self, layer_number):
662
- return self.layers[layer_number]
663
-
664
- def forward(
665
- self, hidden_states, attention_mask, position_ids, kv_caches=None,
666
- use_cache: Optional[bool] = True,
667
- output_hidden_states: Optional[bool] = False,
668
- ):
669
- if not kv_caches:
670
- kv_caches = [None for _ in range(self.num_layers)]
671
- presents = () if use_cache else None
672
- if self.gradient_checkpointing and self.training:
673
- if use_cache:
674
- logger.warning_once(
675
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
676
- )
677
- use_cache = False
678
-
679
- all_self_attentions = None
680
- all_hidden_states = () if output_hidden_states else None
681
- for index in range(self.num_layers):
682
- if output_hidden_states:
683
- all_hidden_states = all_hidden_states + (hidden_states,)
684
-
685
- layer = self._get_layer(index)
686
- if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
687
- layer_ret = get_checkpoint_fn()(
688
- layer,
689
- hidden_states,
690
- attention_mask,
691
- position_ids,
692
- kv_caches[index],
693
- use_cache
694
- )
695
- else:
696
- layer_ret = layer(
697
- hidden_states,
698
- attention_mask,
699
- position_ids,
700
- kv_cache=kv_caches[index],
701
- use_cache=use_cache
702
- )
703
- hidden_states, kv_cache = layer_ret
704
- if use_cache:
705
- presents = presents + (kv_cache,)
706
-
707
-
708
- # Final layer norm.
709
- if self.post_layer_norm:
710
- hidden_states = self.final_layernorm(hidden_states)
711
-
712
- if output_hidden_states:
713
- all_hidden_states = all_hidden_states + (hidden_states,)
714
-
715
- return hidden_states, presents, all_hidden_states, all_self_attentions
716
-
717
-
718
- class xTrimoPGLMPreTrainedModel(PreTrainedModel):
719
- """
720
- An abstract class to handle weights initialization and
721
- a simple interface for downloading and loading pretrained models.
722
- """
723
-
724
- is_parallelizable = False
725
- supports_gradient_checkpointing = True
726
- config_class = xTrimoPGLMConfig
727
- base_model_prefix = "transformer"
728
- _no_split_modules = ["xTrimoPGLMBlock"]
729
-
730
- _quantized = False
731
-
732
-
733
- def get_masks(self, input_ids, past_key_values, padding_mask=None, is_causal=True):
734
- batch_size, seq_length = input_ids.shape
735
- full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
736
- if is_causal:
737
- full_attention_mask.tril_()
738
- past_length = 0
739
- if past_key_values:
740
- past_length = past_key_values[0][0].shape[0]
741
- if past_length:
742
- full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
743
- device=input_ids.device), full_attention_mask), dim=-1)
744
- if padding_mask is not None:
745
- full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
746
- if not past_length and padding_mask is not None:
747
- full_attention_mask -= padding_mask.unsqueeze(-1) - 1
748
- full_attention_mask = (full_attention_mask < 0.5).bool()
749
- full_attention_mask.unsqueeze_(1)
750
- return full_attention_mask
751
-
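A small check of the mask convention built by get_masks above: after the final `< 0.5` comparison, True marks positions that must NOT be attended to. The 3-token causal case below is illustrative only:

    import torch

    seq_length = 3
    full = torch.ones(1, seq_length, seq_length).tril_()   # causal lower triangle of ones
    full = (full < 0.5).bool().unsqueeze(1)                # [1, 1, 3, 3]; True = masked
    print(full[0, 0])
    # tensor([[False,  True,  True],
    #         [False, False,  True],
    #         [False, False, False]])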
752
- def get_position_ids(self, input_ids, device, context_length=0):
753
- batch_size, seq_length = input_ids.shape
754
- if self.config.rotary_embedding_2d:
755
- if self.config.is_causal: # 100b model
756
- position_ids_1 = torch.zeros(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
757
- position_ids_2 = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
758
- position_ids = torch.stack([position_ids_1, position_ids_2], axis=1) # [batch_size, 2, seq_len]
759
- else:
760
- position_ids_1 = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
761
- position_ids_2 = torch.zeros(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
762
- position_ids = torch.stack([position_ids_1, position_ids_2], axis=1) # [batch_size, 2, seq_len]
763
- else:
764
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, 1, seq_len]
765
- return position_ids
766
-
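For concreteness, the 2D position ids produced above by the non-causal branch look like this for a hypothetical batch_size=1, seq_length=4:

    import torch

    batch_size, seq_length = 1, 4
    pos_1 = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    pos_2 = torch.zeros(seq_length, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    position_ids = torch.stack([pos_1, pos_2], dim=1)      # [batch_size, 2, seq_len]
    print(position_ids)                                    # tensor([[[0, 1, 2, 3], [0, 0, 0, 0]]])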
767
- def _set_gradient_checkpointing(self, module, value=False):
768
- if isinstance(module, xTrimoPGLMTransformer):
769
- module.gradient_checkpointing = value
770
-
771
-
772
- # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
773
- def _init_weights(self, module):
774
- std = self.config.initializer_range
775
- """Initialize the weights"""
776
- if isinstance(module, nn.Linear):
777
- # Slightly different from the TF version which uses truncated_normal for initialization
778
- # cf https://github.com/pytorch/pytorch/pull/5617
779
- module.weight.data.normal_(mean=0.0, std=std)
780
- if module.bias is not None:
781
- module.bias.data.zero_()
782
- elif isinstance(module, nn.Embedding):
783
- module.weight.data.normal_(mean=0.0, std=std)
784
- if module.padding_idx is not None:
785
- module.weight.data[module.padding_idx].zero_()
786
- elif isinstance(module, nn.LayerNorm):
787
- module.bias.data.zero_()
788
- module.weight.data.fill_(1.0)
789
-
790
- def quantize(self, weight_bit_width: int, empty_init=True, device=None):
791
- if self._quantized:
792
- print(f"Model has been quantized...")
793
- return
794
- self.transformer.encoder = quantize(self.transformer.encoder, weight_bit_width, empty_init, device)
795
- self._quantized = True
796
- return self
797
-
798
- class Embedding(torch.nn.Module):
799
- """Language model embeddings."""
800
-
801
- def __init__(self, config: xTrimoPGLMConfig, device=None):
802
- super(Embedding, self).__init__()
803
-
804
- self.hidden_size = config.hidden_size
805
- # Word embeddings (parallel).
806
- self.word_embeddings = nn.Embedding(
807
- config.padded_vocab_size,
808
- self.hidden_size,
809
- dtype=config.torch_dtype,
810
- device=device
811
- )
812
- self.fp32_residual_connection = config.fp32_residual_connection
813
-
814
-
815
- def forward(self, input_ids):
816
- # Embeddings.
817
- words_embeddings = self.word_embeddings(input_ids)
818
- embeddings = words_embeddings
819
-         # Data format change to avoid explicit transposes: [b s h] --> [s b h].
820
- embeddings = embeddings.transpose(0, 1).contiguous()
821
- # If the input flag for fp32 residual connection is set, convert for float.
822
- if self.fp32_residual_connection:
823
- embeddings = embeddings.float()
824
- return embeddings
825
-
826
- class xTrimoPGLMModel(xTrimoPGLMPreTrainedModel):
827
- def __init__(self, config: xTrimoPGLMConfig, device=None, empty_init=True):
828
- super().__init__(config)
829
- if empty_init:
830
- init_method = skip_init
831
- else:
832
- init_method = default_init
833
- init_kwargs = {}
834
- if device is not None:
835
- init_kwargs["device"] = device
836
- self.embedding = init_method(Embedding, config, **init_kwargs)
837
- self.num_layers = config.num_layers
838
- self.multi_query_group_num = config.multi_query_group_num
839
- self.kv_channels = config.kv_channels
840
-
841
- # Rotary positional embeddings
842
- self.seq_length = config.seq_length
843
- rotary_dim = (
844
- config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
845
- )
846
-
847
- # self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, base=10000, precision=config.torch_dtype, learnable=False)
848
- self.encoder = init_method(xTrimoPGLMTransformer, config, **init_kwargs)
849
-
850
- self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
851
- dtype=config.torch_dtype, **init_kwargs)
852
-
853
- def get_input_embeddings(self):
854
- return self.embedding.word_embeddings
855
-
856
- def set_input_embeddings(self, value):
857
- self.embedding.word_embeddings = value
858
-
859
- def forward(
860
- self,
861
- input_ids,
862
- position_ids: Optional[torch.Tensor] = None, # position_ids: [batch_size, 2, seq_len]
863
- attention_mask: Optional[torch.BoolTensor] = None,
864
- full_attention_mask: Optional[torch.BoolTensor] = None,
865
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
866
- inputs_embeds: Optional[torch.Tensor] = None,
867
- use_cache: Optional[bool] = None,
868
- output_hidden_states: Optional[bool] = None,
869
- return_dict: Optional[bool] = None,
870
- ):
871
- output_hidden_states = (
872
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
- )
874
- if self.config.is_causal:
875
- use_cache = use_cache if use_cache is not None else self.config.use_cache
876
- else:
877
- use_cache = False
878
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
879
-
880
- batch_size, seq_length = input_ids.shape
881
-
882
- if inputs_embeds is None:
883
- inputs_embeds = self.embedding(input_ids)
884
-
885
- if full_attention_mask is None:
886
- if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
887
- full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
888
- # Run encoder.
889
- hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
890
- inputs_embeds, full_attention_mask, position_ids=position_ids,
891
- kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
892
- )
893
-
894
- if not return_dict:
895
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
896
-
897
- return BaseModelOutputWithPast(
898
- last_hidden_state=hidden_states,
899
- past_key_values=presents,
900
- hidden_states=all_hidden_states,
901
- attentions=all_self_attentions,
902
- )
903
-
904
-
905
- class xTrimoPGLMForMaskedLM(xTrimoPGLMPreTrainedModel):
906
- def __init__(self, config: xTrimoPGLMConfig, empty_init=True, device=None):
907
- super().__init__(config)
908
-
909
- self.max_sequence_length = config.max_length
910
- self.transformer = xTrimoPGLMModel(config, empty_init=empty_init, device=device)
911
- self.config = config
912
- if self.config.quantization_bit:
913
- print(f"Begin Quantization to {self.config.quantization_bit} bit")
914
- self.quantize(self.config.quantization_bit, empty_init=True, device=device)
915
-
916
- def forward(
917
- self,
918
- input_ids: Optional[torch.Tensor] = None,
919
- position_ids: Optional[torch.Tensor] = None,
920
- attention_mask: Optional[torch.Tensor] = None,
921
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
922
- inputs_embeds: Optional[torch.Tensor] = None,
923
- labels: Optional[torch.Tensor] = None,
924
- use_cache: Optional[bool] = None,
925
- output_attentions: Optional[bool] = None,
926
- output_hidden_states: Optional[bool] = None,
927
- return_dict: Optional[bool] = None,
928
- return_last_logit: Optional[bool] = None,
929
- return_last_hidden_state: Optional[bool] = None
930
- ):
931
- if self.config.is_causal:
932
- use_cache = use_cache if use_cache is not None else self.config.use_cache
933
- else:
934
- use_cache = False
935
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
936
-
937
- if position_ids is None:
938
- position_ids = self.get_position_ids(input_ids, device=input_ids.device)
939
-
940
- full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal=self.config.is_causal)
941
-
942
- transformer_outputs = self.transformer(
943
- input_ids=input_ids,
944
- position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
945
- full_attention_mask=full_attention_mask,
946
- past_key_values=past_key_values,
947
- inputs_embeds=inputs_embeds,
948
- use_cache=use_cache,
949
- output_hidden_states=output_hidden_states,
950
- return_dict=return_dict,
951
- )
952
-
953
- hidden_states = transformer_outputs[0]
954
- if return_last_logit:
955
- hidden_states = hidden_states[-1:]
956
- lm_logits = self.transformer.output_layer(hidden_states)
957
- lm_logits = lm_logits.transpose(0, 1).contiguous()
958
-
959
- masked_lm_loss = None
960
- if labels is not None:
961
- lm_logits = lm_logits.to(torch.float32)
962
-
963
- # Flatten the tokens
964
- loss_fct = CrossEntropyLoss(ignore_index=-100) # -100 for padding token.
965
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
966
-
967
- lm_logits = lm_logits.to(hidden_states.dtype)
968
-             masked_lm_loss = masked_lm_loss.to(hidden_states.dtype)
969
-
970
- if not return_dict:
971
- output = (lm_logits,) + transformer_outputs[1:]
972
-             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
973
- return MaskedLMOutput(
974
- loss = masked_lm_loss,
975
- logits=lm_logits,
976
- hidden_states=transformer_outputs.last_hidden_state if return_last_hidden_state else transformer_outputs.hidden_states,
977
- attentions=transformer_outputs.attentions,
978
- )
979
-
980
-
981
-
982
-
983
- class xTrimoPGLMForSequenceClassification(xTrimoPGLMPreTrainedModel):
984
- def __init__(self, config: xTrimoPGLMConfig, empty_init=True, device=None):
985
- super().__init__(config)
986
- self.config = config
987
- self.num_labels = config.num_labels
988
-
989
- self.transformer = xTrimoPGLMModel(config, empty_init=empty_init, device=device)
990
- self.classifier = xTrimoPGLMClassificationHead(config)
991
- if self.config.quantization_bit:
992
- print(f"Begin Quantization to {self.config.quantization_bit} bit")
993
- self.quantize(self.config.quantization_bit, empty_init=True, device=device)
994
-
995
- def forward(
996
- self,
997
- input_ids: Optional[torch.Tensor] = None,
998
- position_ids: Optional[torch.Tensor] = None,
999
- attention_mask: Optional[torch.Tensor] = None,
1000
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
1001
- inputs_embeds: Optional[torch.Tensor] = None,
1002
- labels: Optional[torch.Tensor] = None,
1003
- use_cache: Optional[bool] = None,
1004
- output_attentions: Optional[bool] = None,
1005
- output_hidden_states: Optional[bool] = None,
1006
- return_dict: Optional[bool] = None,
1007
- return_last_logit: Optional[bool] = None,
1008
- return_last_hidden_state: Optional[bool] = None,
1009
- **kwargs
1010
- ) -> Union[Tuple, SequenceClassifierOutput]:
1011
- r"""
1012
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1013
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1014
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1015
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1016
- """
1017
- if self.config.is_causal:
1018
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1019
- else:
1020
- use_cache = False
1021
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1022
-
1023
- if position_ids is None:
1024
- position_ids = self.get_position_ids(input_ids, device=input_ids.device)
1025
-
1026
- full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal=self.config.is_causal)
1027
-
1028
- transformer_outputs = self.transformer(
1029
- input_ids=input_ids,
1030
- position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
1031
- full_attention_mask=full_attention_mask,
1032
- past_key_values=past_key_values,
1033
- inputs_embeds=inputs_embeds,
1034
- use_cache=use_cache,
1035
- output_hidden_states=output_hidden_states,
1036
- return_dict=return_dict,
1037
- )
1038
- if self.config.add_special_tokens:
1039
- hidden_states = transformer_outputs[0][:-1] # get rid of <eos> token
1040
- else:
1041
- hidden_states = transformer_outputs[0]
1042
- logits = self.classifier(hidden_states, add_pooling=True)
1043
- loss = None
1044
- if labels is not None:
1045
- labels = labels.to(logits.device)
1046
-
1047
- if self.config.problem_type is None:
1048
- if self.num_labels == 1:
1049
- self.config.problem_type = "regression"
1050
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1051
- self.config.problem_type = "single_label_classification"
1052
- else:
1053
- self.config.problem_type = "multi_label_classification"
1054
-
1055
- if self.config.problem_type == "regression":
1056
- loss_fct = MSELoss()
1057
- if self.num_labels == 1:
1058
- loss = loss_fct(logits.squeeze(), labels.squeeze())
1059
- else:
1060
- loss = loss_fct(logits, labels)
1061
- elif self.config.problem_type == "single_label_classification":
1062
- loss_fct = CrossEntropyLoss()
1063
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1064
- elif self.config.problem_type == "multi_label_classification":
1065
- loss_fct = BCEWithLogitsLoss()
1066
- loss = loss_fct(logits, labels)
1067
-
1068
- if not return_dict:
1069
- output = (logits,) + transformer_outputs[2:]
1070
- return ((loss,) + output) if loss is not None else output
1071
-
1072
- return SequenceClassifierOutput(
1073
- loss=loss,
1074
- logits=logits,
1075
- hidden_states=transformer_outputs.hidden_states,
1076
- attentions=transformer_outputs.attentions,
1077
- )
1078
-
1079
- class xTrimoPGLMForTokenClassification(xTrimoPGLMPreTrainedModel):
1080
- def __init__(self, config: xTrimoPGLMConfig, empty_init=True, device=None):
1081
- super().__init__(config)
1082
- self.config = config
1083
- self.num_labels = config.num_labels
1084
-
1085
- self.transformer = xTrimoPGLMModel(config, empty_init=empty_init, device=device)
1086
- if config.task_modality == "token":
1087
- self.classifier = xTrimoPGLMClassificationHead(config)
1088
- elif config.task_modality == 'pair':
1089
- self.classifier = xTrimoPGLMContactHead(config)
1090
-
1091
- self.quantized = False
1092
-
1093
- if self.config.quantization_bit:
1094
- print(f"Begin Quantization to {self.config.quantization_bit} bit")
1095
- self.quantize(self.config.quantization_bit, empty_init=True, device=device)
1096
-
1097
-
1098
- def forward(
1099
- self,
1100
- input_ids: Optional[torch.Tensor] = None,
1101
- position_ids: Optional[torch.Tensor] = None,
1102
- attention_mask: Optional[torch.Tensor] = None,
1103
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
1104
- inputs_embeds: Optional[torch.Tensor] = None,
1105
- labels: Optional[torch.Tensor] = None,
1106
- use_cache: Optional[bool] = None,
1107
- output_attentions: Optional[bool] = None,
1108
- output_hidden_states: Optional[bool] = None,
1109
- return_dict: Optional[bool] = None,
1110
- return_last_logit: Optional[bool] = None,
1111
- return_last_hidden_state: Optional[bool] = None,
1112
- **kwargs
1113
- ) -> Union[Tuple, SequenceClassifierOutput]:
1114
- r"""
1115
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1116
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1117
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1118
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1119
- """
1120
- if self.config.is_causal:
1121
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1122
- else:
1123
- use_cache = False
1124
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1125
-
1126
- if position_ids is None:
1127
- position_ids = self.get_position_ids(input_ids, device=input_ids.device)
1128
-
1129
- full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal = self.config.is_causal)
1130
-
1131
- transformer_outputs = self.transformer(
1132
- input_ids=input_ids,
1133
- position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
1134
- full_attention_mask=full_attention_mask,
1135
- past_key_values=past_key_values,
1136
- inputs_embeds=inputs_embeds,
1137
- use_cache=use_cache,
1138
- output_hidden_states=output_hidden_states,
1139
- return_dict=return_dict,
1140
- )
1141
- if self.config.add_special_tokens:
1142
- hidden_states = transformer_outputs[0][:-1] # get rid of <eos> token
1143
- else:
1144
- hidden_states = transformer_outputs[0]
1145
-
1146
- logits = self.classifier(hidden_states, add_pooling=False)
1147
- loss = None
1148
- if labels is not None:
1149
- labels = labels.to(logits.device)
1150
- loss_fct = CrossEntropyLoss()
1151
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1152
-
1153
- if not return_dict:
1154
- output = (logits,) + transformer_outputs[2:]
1155
- return ((loss,) + output) if loss is not None else output
1156
-
1157
-
1158
- return TokenClassifierOutput(
1159
- loss=loss,
1160
- logits=logits,
1161
- hidden_states=transformer_outputs.hidden_states,
1162
- attentions=transformer_outputs.attentions,
1163
- )
1164
-
1165
-
1166
-
1167
- class xTrimoPGLMClassificationHead(nn.Module):
1168
- """Head for classification tasks."""
1169
- def __init__(self, config):
1170
- super().__init__()
1171
- self.activation_func = config.activation_func
1172
- self.layers = torch.nn.ModuleList()
1173
- last_size = config.hidden_size
1174
- for sz in config.inter_hidden_size:
1175
- this_layer = torch.nn.Linear(last_size, sz, bias=config.bias)
1176
- last_size = sz
1177
- self.layers.append(this_layer)
1178
-
1179
- def forward(self,
1180
- input_features,
1181
- add_pooling: Optional[bool] = True
1182
- ):
1183
- # [s, b, h] -> [b, s ,h]
1184
- input_features = input_features.transpose(0,1).contiguous()
1185
- if add_pooling:
1186
- # [b, h]
1187
- input_features = torch.mean(input_features, dim = 1)
1188
- for i, layer in enumerate(self.layers):
1189
- if i > 0:
1190
- input_features = self.activation_func(input_features)
1191
- input_features = layer(input_features)
1192
- return input_features
1193
-
1194
- class xTrimoPGLMContactHead(nn.Module):
1195
- """Head for sentence-level classification tasks."""
1196
- def __init__(self, config):
1197
- super().__init__()
1198
- self.activation_func = config.activation_func
1199
- self.layers = torch.nn.ModuleList()
1200
- last_size = config.hidden_size * 2
1201
- for sz in config.inter_hidden_size:
1202
- this_layer = torch.nn.Linear(last_size, sz, bias=config.bias)
1203
- last_size = sz
1204
- self.layers.append(this_layer)
1205
-
1206
- def outer_concat(self, x):
1207
- batch_size, seq_len, features = x.shape
1208
-
1209
- # Permute to [batch_size, features, seq_len]
1210
- x = x.permute(0, 2, 1)
1211
-
1212
- # Introduce new dimensions for broadcasting
1213
- x_1 = x[:, None, :, :, None] # [batch_size, 1, features, seq_len, 1]
1214
- x_2 = x[:, None, :, None, :] # [batch_size, 1, features, 1, seq_len]
1215
-
1216
- # Repeat along new dimensions
1217
- x_1 = x_1.repeat(1, 1, 1, 1, seq_len) # [batch_size, 1, features, seq_len, seq_len]
1218
- x_2 = x_2.repeat(1, 1, 1, seq_len, 1) # [batch_size, 1, features, seq_len, seq_len]
1219
-
1220
- # Concatenate along the second dimension
1221
- x = torch.cat((x_1, x_2), dim=1) # [batch_size, 2, features, seq_len, seq_len]
1222
-
1223
- # Get lower triangular indices
1224
- I, J = torch.tril_indices(seq_len, seq_len, -1)
1225
-
1226
- # Symmetrize
1227
- x[:, :, :, I, J] = x[:, :, :, J, I]
1228
-
1229
- # Permute to desired shape and make contiguous
1230
- x = x.permute(0, 3, 4, 2, 1).contiguous() # [batch_size, seq_len, seq_len, features, 2]
1231
-
1232
- # Reshape to combine the last two dimensions
1233
- x = x.view(batch_size, seq_len, seq_len, features * 2) # [batch_size, seq_len, seq_len, features * 2]
1234
-
1235
- return x
1236
-
1237
- def forward(self,
1238
- input_features,
1239
- add_pooling: Optional[bool] = True
1240
- ):
1241
- # [s, b, h] -> [b, s ,h]
1242
- input_features = input_features.transpose(0,1).contiguous()
1243
- input_features = self.outer_concat(input_features)
1244
- for i, layer in enumerate(self.layers):
1245
- if i > 0:
1246
- input_features = self.activation_func(input_features)
1247
- input_features = layer(input_features)
1248
- return input_features
1249
-
1250
-
1251
-
1252
-
1253
-
1254
- class xTrimoPGLMForCasualLM(xTrimoPGLMPreTrainedModel):
1255
- def __init__(self, config: xTrimoPGLMConfig, empty_init=True, device=None):
1256
- super().__init__(config)
1257
-
1258
- self.max_sequence_length = config.max_length
1259
- self.transformer = xTrimoPGLMModel(config, empty_init=empty_init, device=device)
1260
- self.config = config
1261
- if self.config.quantization_bit:
1262
- print(f"Begin Quantization to {self.config.quantization_bit} bit")
1263
- self.quantize(self.config.quantization_bit, empty_init=True, device=device)
1264
-
1265
- def _update_model_kwargs_for_generation(
1266
- self,
1267
- outputs: ModelOutput,
1268
- model_kwargs: Dict[str, Any],
1269
- is_encoder_decoder: bool = False,
1270
- standardize_cache_format: bool = False,
1271
- ) -> Dict[str, Any]:
1272
- # update past_key_values
1273
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
1274
- outputs, standardize_cache_format=standardize_cache_format
1275
- )
1276
-
1277
- # update attention mask
1278
- if "attention_mask" in model_kwargs:
1279
- attention_mask = model_kwargs["attention_mask"]
1280
- model_kwargs["attention_mask"] = torch.cat(
1281
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
1282
- )
1283
-
1284
- # update position ids
1285
- if "position_ids" in model_kwargs:
1286
- position_ids = model_kwargs["position_ids"]
1287
- new_position_id = position_ids[..., -1:].clone() # [batch_size, 2, 1]
1288
- if self.config.rotary_embedding_2d:
1289
- new_position_id[:, 1] += 1 # Only update the 2nd dimension
1290
- else:
1291
- new_position_id[:] += 1
1292
- model_kwargs["position_ids"] = torch.cat(
1293
- [position_ids, new_position_id], dim=-1
1294
- ) # [batch_size, 2, seq_len+1]
1295
-
1296
- model_kwargs["is_first_forward"] = False
1297
- return model_kwargs
1298
-
1299
- def prepare_inputs_for_generation(
1300
- self,
1301
- input_ids: torch.LongTensor,
1302
- past_key_values: Optional[torch.Tensor] = None,
1303
- attention_mask: Optional[torch.Tensor] = None,
1304
- position_ids: Optional[torch.Tensor] = None,
1305
- use_cache: Optional[bool] = None,
1306
- is_first_forward: bool = True,
1307
- **kwargs
1308
- ) -> dict:
1309
- # only last token for input_ids if past is not None
1310
- if position_ids is None:
1311
- position_ids = self.get_position_ids(input_ids, device=input_ids.device) # position_ids: [batch_size, 2, seq_len]
1312
- if not is_first_forward:
1313
- if past_key_values is not None:
1314
- position_ids = position_ids[..., -1:]
1315
- input_ids = input_ids[:, -1:]
1316
- return {
1317
- "input_ids": input_ids,
1318
- "past_key_values": past_key_values,
1319
- "position_ids": position_ids,
1320
- "attention_mask": attention_mask,
1321
- "return_last_logit": True,
1322
- "use_cache": use_cache
1323
- }
1324
-
1325
-     def forward(
-             self,
-             input_ids: Optional[torch.Tensor] = None,
-             position_ids: Optional[torch.Tensor] = None,
-             attention_mask: Optional[torch.Tensor] = None,
-             past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
-             inputs_embeds: Optional[torch.Tensor] = None,
-             labels: Optional[torch.Tensor] = None,
-             use_cache: Optional[bool] = None,
-             output_attentions: Optional[bool] = None,
-             output_hidden_states: Optional[bool] = None,
-             return_dict: Optional[bool] = None,
-             return_last_logit: Optional[bool] = False
-     ):
-         if self.config.is_causal:
-             use_cache = use_cache if use_cache is not None else self.config.use_cache
-         else:
-             use_cache = False
-
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         if position_ids is None:
-             position_ids = self.get_position_ids(input_ids, device=input_ids.device)
-
-         transformer_outputs = self.transformer(
-             input_ids=input_ids,
-             position_ids=position_ids,  # position_ids: [batch_size, 2, seq_len]
-             attention_mask=attention_mask,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict
-         )
-         hidden_states = transformer_outputs[0]
-         if return_last_logit:
-             hidden_states = hidden_states[-1:]
-         lm_logits = self.transformer.output_layer(hidden_states)
-         lm_logits = lm_logits.transpose(0, 1).contiguous()
-
-         loss = None
-         if labels is not None:
-             lm_logits = lm_logits.to(torch.float32)
-
-             # Shift so that tokens < n predict n
-             shift_logits = lm_logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = CrossEntropyLoss(ignore_index=-100)
-             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-             lm_logits = lm_logits.to(hidden_states.dtype)
-             loss = loss.to(hidden_states.dtype)
-
-         if not return_dict:
-             output = (lm_logits,) + transformer_outputs[1:]
-             return ((loss,) + output) if loss is not None else output
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=lm_logits,
-             past_key_values=transformer_outputs.past_key_values,
-             hidden_states=transformer_outputs.hidden_states,
-             attentions=transformer_outputs.attentions,
-         )
-
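-     # Loss convention in `forward` above: the logits at position i are scored against the
-     # label at position i + 1, and labels set to -100 are ignored. Minimal sketch with a
-     # hypothetical 5-token sequence whose first two (prompt) positions are masked:
-     #
-     #     labels       = [-100, -100, t2, t3, t4]
-     #     shift_labels = [-100,  t2,  t3, t4]    # predicted from positions 0..3
-     #
-     # so the cross-entropy is taken only over t2..t4.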
-     @staticmethod
-     def _reorder_cache(
-             past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
-     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
-         """
-         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-         beam_idx at every generation step.
-
-         Output shares the same memory storage as `past`.
-         """
-         return tuple(
-             (
-                 layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
-                 layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
-             )
-             for layer_past in past
-         )
-
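-     # `_reorder_cache` gathers along dim 1 because the key/value cache, like the hidden
-     # states consumed in `forward` above, is laid out sequence-first, so the batch/beam
-     # axis it must permute sits at dimension 1.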
-     @torch.inference_mode()
-     def chat(self, tokenizer, query: str, max_length: int = 256, num_beams=1, do_sample=True,
-              top_p=1.0, temperature=1.0, logits_processor=None, **kwargs):
-         if logits_processor is None:
-             logits_processor = LogitsProcessorList()
-         logits_processor.append(InvalidScoreLogitsProcessor())
-         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-         inputs = tokenizer.apply_chat_template(query, add_generation_prompt=True, tokenize=True,
-                                                return_tensors="pt", return_dict=True)
-         position_ids = self.get_position_ids(inputs['input_ids'], device=self.device)  # TODO: ADD BATCH
-         eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<eop>")]
-         inputs["position_ids"] = position_ids
-         inputs = inputs.to(self.device)
-         outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
-         outputs = outputs.tolist()[0][3:]  # 3 for generation prompt "<gmask><sop><eos>"
-         if outputs[-1] in eos_token_id:
-             outputs = outputs[:-1]
-         response = tokenizer.decode(outputs)
-         return response
-
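-     # Usage sketch for `chat` (illustrative only: the checkpoint name mirrors
-     # `_CHECKPOINT_FOR_DOC`, the auto class depends on this repo's `auto_map`, and the
-     # query below is just an example amino-acid prompt):
-     #
-     #     >>> from transformers import AutoTokenizer, AutoModelForCausalLM
-     #     >>> tokenizer = AutoTokenizer.from_pretrained("BioMap/xtrimopglm-100b-int4", trust_remote_code=True)
-     #     >>> model = AutoModelForCausalLM.from_pretrained("BioMap/xtrimopglm-100b-int4", trust_remote_code=True)
-     #     >>> print(model.chat(tokenizer, "MSHHWGYGKHNGPEHWHKDF", max_length=128, top_p=0.9))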
-     # TODO: fix bug in streaming chat
-     @torch.inference_mode()
-     def stream_chat(self, tokenizer, query: str, max_length: int = 56, num_beams=1, do_sample=True,
-                     top_p=0.8, temperature=0.8, logits_processor=None, past_key_values=None, **kwargs):
-         if logits_processor is None:
-             logits_processor = LogitsProcessorList()
-         logits_processor.append(InvalidScoreLogitsProcessor())
-         eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<eop>")]
-         gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
-                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-         inputs = tokenizer.apply_chat_template(query, add_generation_prompt=True, tokenize=True,
-                                                return_tensors="pt", return_dict=True)
-         position_ids = self.get_position_ids(inputs['input_ids'], device=self.device)  # TODO: ADD BATCH
-         inputs["position_ids"] = position_ids
-         inputs = inputs.to(self.device)
-         offset = 3  # 3 for generation prompt "<gmask><sop><eos>"
-         for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
-                                             eos_token_id=eos_token_id, return_past_key_values=False,
-                                             **gen_kwargs):
-             outputs = outputs.tolist()[0][offset:]
-             if outputs[-1] in eos_token_id:
-                 outputs = outputs[:-1]
-             # offset = 3 + len(outputs)
-             response = tokenizer.decode(outputs)
-             if response:
-                 yield response
-
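-     # `stream_chat` re-decodes and yields the full response generated so far on every
-     # iteration rather than a delta; callers wanting incremental text must diff successive
-     # yields (the commented-out `offset` bookkeeping above hints at the unfinished fix
-     # referenced by the TODO).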
-     @torch.inference_mode()
-     def stream_generate(
-             self,
-             input_ids,
-             generation_config: Optional[GenerationConfig] = None,
-             logits_processor: Optional[LogitsProcessorList] = None,
-             stopping_criteria: Optional[StoppingCriteriaList] = None,
-             prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-             return_past_key_values=False,
-             **kwargs,
-     ):
-         batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
-
-         if generation_config is None:
-             generation_config = self.generation_config
-         generation_config = copy.deepcopy(generation_config)
-         model_kwargs = generation_config.update(**kwargs)
-         model_kwargs["use_cache"] = generation_config.use_cache
-         bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
-
-         if isinstance(eos_token_id, int):
-             eos_token_id = [eos_token_id]
-         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
-         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-         if has_default_max_length and generation_config.max_new_tokens is None:
-             warnings.warn(
-                 f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-                 "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
-                 UserWarning,
-             )
-         elif generation_config.max_new_tokens is not None:
-             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
-             if not has_default_max_length:
-                 logger.warning(
-                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
-                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
-                     "Please refer to the documentation for more information. "
-                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-                 )
-
-         if input_ids_seq_length >= generation_config.max_length:
-             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-             logger.warning(
-                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-                 " increasing `max_new_tokens`."
-             )
-
-         # 2. Set generation parameters if not already defined
-         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-         logits_processor = self._get_logits_processor(
-             generation_config=generation_config,
-             input_ids_seq_length=input_ids_seq_length,
-             encoder_input_ids=input_ids,
-             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-             logits_processor=logits_processor,
-         )
-
-         stopping_criteria = self._get_stopping_criteria(
-             generation_config=generation_config, stopping_criteria=stopping_criteria
-         )
-         logits_warper = self._get_logits_warper(generation_config)
-
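-         # Decode loop: processors apply hard constraints, warpers reshape the distribution
-         # (temperature / top-p), then the next token is drawn by multinomial sampling or
-         # argmax. `unfinished_sequences` keeps a 0/1 flag per sequence that drops to 0 once
-         # any configured EOS token is produced; decoding stops when every sequence has
-         # finished or a stopping criterion fires.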
-         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-         scores = None
-         while True:
-             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-             # forward pass to get next token
-             outputs = self(
-                 **model_inputs,
-                 return_dict=True,
-                 output_attentions=False,
-                 output_hidden_states=False,
-             )
-
-             next_token_logits = outputs.logits[:, -1, :]
-
-             # pre-process distribution
-             next_token_scores = logits_processor(input_ids, next_token_logits)
-             next_token_scores = logits_warper(input_ids, next_token_scores)
-
-             # sample
-             probs = nn.functional.softmax(next_token_scores, dim=-1)
-             if generation_config.do_sample:
-                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-             else:
-                 next_tokens = torch.argmax(probs, dim=-1)
-             # update generated ids, model inputs, and length for next step
-             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-             model_kwargs = self._update_model_kwargs_for_generation(
-                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-             )
-             unfinished_sequences = unfinished_sequences.mul(
-                 next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-             )
-             if return_past_key_values:
-                 yield input_ids, outputs.past_key_values
-             else:
-                 yield input_ids
-             # stop when each sentence is finished, or if we exceed the maximum length
-             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
-                 break