cfli committed on
Commit 3027a2e · verified · 1 Parent(s): f97b1ee

Delete modeling_minicpm.py

Files changed (1)
  1. modeling_minicpm.py +0 -1494
modeling_minicpm.py DELETED
@@ -1,1494 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """ PyTorch MiniCPM model."""
21
- import sys
22
-
23
- import math
24
- import warnings
25
- from typing import List, Optional, Tuple, Union, Dict
26
-
27
- import torch
28
- import torch.nn.functional as F
29
- import torch.utils.checkpoint
30
- from torch import nn
31
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
-
33
- from transformers.activations import ACT2FN
34
- from transformers.cache_utils import Cache, DynamicCache
35
- from transformers.modeling_attn_mask_utils import (
36
- AttentionMaskConverter,
37
- _prepare_4d_attention_mask,
38
- _prepare_4d_causal_attention_mask,
39
- _prepare_4d_causal_attention_mask_for_sdpa,
40
- )
41
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
42
- SequenceClassifierOutputWithPast
43
- from transformers.modeling_utils import PreTrainedModel
44
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
45
- from transformers.utils import (
46
- add_start_docstrings,
47
- add_start_docstrings_to_model_forward,
48
- is_flash_attn_2_available,
49
- is_flash_attn_greater_or_equal_2_10,
50
- logging,
51
- replace_return_docstrings,
52
- )
53
- from transformers.utils.import_utils import is_torch_fx_available
54
- from configuration_minicpm_reranker import MiniCPMConfig
55
- import re
56
-
57
- try:
58
- from flash_attn import flash_attn_func, flash_attn_varlen_func
59
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
60
- except ImportError:
61
- pass
62
-
63
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
64
- # It means that the function will not be traced through and simply appear as a node in the graph.
65
- if is_torch_fx_available():
66
- if not is_torch_greater_or_equal_than_1_13:
67
- import torch.fx
68
-
69
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
70
-
71
- logger = logging.get_logger(__name__)
72
-
73
- _CONFIG_FOR_DOC = "MiniCPMConfig"
74
-
75
-
76
- def _get_unpad_data(attention_mask):
77
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
78
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
79
- max_seqlen_in_batch = seqlens_in_batch.max().item()
80
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
81
- return (
82
- indices,
83
- cu_seqlens,
84
- max_seqlen_in_batch,
85
- )
86
-
87
-
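As an illustrative sketch (not part of the deleted file), this is what `_get_unpad_data` produces for a small left-padded batch, using the same 2-D `attention_mask` convention as the rest of this module:

```python
import torch
import torch.nn.functional as F

# Two sequences of real lengths 2 and 3, left-padded to length 4.
attention_mask = torch.tensor([[0, 0, 1, 1],
                               [0, 1, 1, 1]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([2, 3])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()           # flat positions of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 2, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                                   # 3
```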
88
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
89
- warnings.warn(
90
- "Calling `transformers.models.minicpm.modeling_minicpm._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
91
- )
92
- return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
93
-
94
-
95
- def _make_causal_mask(
96
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
97
- ):
98
- warnings.warn(
99
- "Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask"
100
- )
101
- return AttentionMaskConverter._make_causal_mask(
102
- input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
103
- )
104
-
105
-
106
- # @torch.jit.script # type: ignore
107
- def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
108
- old_dtype = hidden.dtype
109
- variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
110
- hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
111
- return hidden * weight
112
-
113
-
114
- class MiniCPMRMSNorm(nn.Module):
115
- def __init__(self, hidden_size, eps=1e-6):
116
- """
117
- MiniCPMRMSNorm is equivalent to T5LayerNorm
118
- """
119
- super().__init__()
120
- self.weight = nn.Parameter(torch.ones(hidden_size))
121
- self.variance_epsilon = eps
122
-
123
- def forward(self, hidden_states):
124
- return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
125
-
126
-
127
- ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
128
-
129
-
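A small numerical check (illustrative only, assuming the classes above are in scope) that `MiniCPMRMSNorm` normalizes by the root-mean-square of the last dimension, i.e. matches the T5/LLaMA-style RMS layer norm:

```python
import torch

x = torch.randn(2, 4, 8)
norm = MiniCPMRMSNorm(hidden_size=8, eps=1e-6)  # weight is initialized to ones

# Manual RMS normalization over the last dimension, scaled by the learned weight.
inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(1e-6).rsqrt()
manual = x * inv_rms * norm.weight

assert torch.allclose(norm(x), manual, atol=1e-5)
```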
130
- class MiniCPMRotaryEmbedding(nn.Module):
131
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
132
- super().__init__()
133
-
134
- self.dim = dim
135
- self.max_position_embeddings = max_position_embeddings
136
- self.base = base
137
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
138
- self.register_buffer("inv_freq", inv_freq, persistent=False)
139
-
140
- # Build here to make `torch.jit.trace` work.
141
- self._set_cos_sin_cache(
142
- # seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
143
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
144
- )
145
-
146
- def _set_cos_sin_cache(self, seq_len, device, dtype):
147
- self.max_seq_len_cached = seq_len
148
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
149
- freqs = torch.outer(t, self.inv_freq)
150
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
151
- emb = torch.cat((freqs, freqs), dim=-1)
152
-
153
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
154
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
155
-
156
- def forward(self, x, seq_len=None):
157
- # x: [bs, num_attention_heads, seq_len, head_size]
158
- if seq_len > self.max_seq_len_cached:
159
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
160
-
161
- return (
162
- self.cos_cached[:seq_len].to(dtype=x.dtype),
163
- self.sin_cached[:seq_len].to(dtype=x.dtype),
164
- )
165
-
166
-
167
- class MiniCPMLinearScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
168
- """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
169
-
170
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
171
- self.scaling_factor = scaling_factor
172
- super().__init__(dim, max_position_embeddings, base, device)
173
-
174
- def _set_cos_sin_cache(self, seq_len, device, dtype):
175
- self.max_seq_len_cached = seq_len
176
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
177
- t = t / self.scaling_factor
178
-
179
- freqs = torch.outer(t, self.inv_freq)
180
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
181
- emb = torch.cat((freqs, freqs), dim=-1)
182
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
183
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
184
-
185
-
186
- class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
187
- """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
188
-
189
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
190
- self.scaling_factor = scaling_factor
191
- super().__init__(dim, max_position_embeddings, base, device)
192
-
193
- def _set_cos_sin_cache(self, seq_len, device, dtype):
194
- self.max_seq_len_cached = seq_len
195
-
196
- if seq_len > self.max_position_embeddings:
197
- base = self.base * (
198
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
199
- ) ** (self.dim / (self.dim - 2))
200
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
201
- self.register_buffer("inv_freq", inv_freq, persistent=False)
202
-
203
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
204
-
205
- freqs = torch.outer(t, self.inv_freq)
206
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
207
- emb = torch.cat((freqs, freqs), dim=-1)
208
-
209
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
210
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
211
-
212
-
213
- def rotate_half(x):
214
- """Rotates half the hidden dims of the input."""
215
- x1 = x[..., : x.shape[-1] // 2]
216
- x2 = x[..., x.shape[-1] // 2:]
217
- return torch.cat((-x2, x1), dim=-1)
218
-
219
-
220
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
221
- """Applies Rotary Position Embedding to the query and key tensors.
222
-
223
- Args:
224
- q (`torch.Tensor`): The query tensor.
225
- k (`torch.Tensor`): The key tensor.
226
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
227
- sin (`torch.Tensor`): The sine part of the rotary embedding.
228
- position_ids (`torch.Tensor`):
229
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
230
- used to pass offsetted position ids when working with a KV-cache.
231
- unsqueeze_dim (`int`, *optional*, defaults to 1):
232
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
233
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
234
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
235
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
236
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
237
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
238
- Returns:
239
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
240
- """
241
- # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
242
- # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
243
- # q_embed = (q * cos) + (rotate_half(q) * sin)
244
- # k_embed = (k * cos) + (rotate_half(k) * sin)
245
- orig_dtype = k.dtype
246
- cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
247
- sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
248
- q_fp32 = q.to(dtype=torch.float32, device=q.device)
249
- k_fp32 = k.to(dtype=torch.float32, device=k.device)
250
- q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
251
- k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
252
- return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
253
-
254
-
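A short usage sketch (shapes are illustrative) showing how the attention classes below call `apply_rotary_pos_emb` with the default `unsqueeze_dim=1` and the `[batch, heads, seq_len, head_dim]` layout:

```python
import torch

bsz, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

rope = MiniCPMRotaryEmbedding(head_dim)                      # defined above
cos, sin = rope(k, seq_len=seq_len)                          # each of shape [seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape     # shapes are preserved
```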
255
- class MiniCPMMLP(nn.Module):
256
- def __init__(self, config):
257
- super().__init__()
258
- self.config = config
259
- self.hidden_size = config.hidden_size
260
- self.intermediate_size = config.intermediate_size
261
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
262
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
263
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
264
- self.act_fn = ACT2FN[config.hidden_act]
265
-
266
- def forward(self, x):
267
- if self.config.pretraining_tp > 1:
268
- slice = self.intermediate_size // self.config.pretraining_tp
269
- gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
270
- up_proj_slices = self.up_proj.weight.split(slice, dim=0)
271
- down_proj_slices = self.down_proj.weight.split(slice, dim=1)
272
-
273
- gate_proj = torch.cat(
274
- [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
275
- )
276
- up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
277
-
278
- intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
279
- down_proj = [
280
- F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
281
- ]
282
- down_proj = sum(down_proj)
283
- else:
284
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
285
-
286
- return down_proj
287
-
288
-
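In the common `pretraining_tp == 1` case, the MLP above is a gated (SwiGLU-style) feed-forward. A minimal functional sketch of the same computation with plain tensors (the SiLU activation and the sizes are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 44
x = torch.randn(2, 5, hidden_size)

gate_w = torch.randn(intermediate_size, hidden_size)   # stands in for gate_proj.weight
up_w   = torch.randn(intermediate_size, hidden_size)   # stands in for up_proj.weight
down_w = torch.randn(hidden_size, intermediate_size)   # stands in for down_proj.weight

# down_proj(act_fn(gate_proj(x)) * up_proj(x))
out = F.linear(F.silu(F.linear(x, gate_w)) * F.linear(x, up_w), down_w)
assert out.shape == x.shape
```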
289
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
290
- """
291
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
292
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
293
- """
294
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
295
- if n_rep == 1:
296
- return hidden_states
297
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
298
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
299
-
300
-
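A quick illustrative check that `repeat_kv` is equivalent to `torch.repeat_interleave` along the key/value head dimension, as its docstring states:

```python
import torch

kv = torch.randn(2, 4, 10, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3                        # num_attention_heads // num_key_value_heads

assert torch.equal(repeat_kv(kv, n_rep),
                   torch.repeat_interleave(kv, repeats=n_rep, dim=1))
```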
301
- class MiniCPMAttention(nn.Module):
302
- """Multi-headed attention from 'Attention Is All You Need' paper"""
303
-
304
- def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
305
- super().__init__()
306
- self.config = config
307
- self.layer_idx = layer_idx
308
- if layer_idx is None:
309
- logger.warning_once(
310
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
311
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
312
- "when creating this class."
313
- )
314
-
315
- self.attention_dropout = config.attention_dropout
316
- self.hidden_size = config.hidden_size
317
- self.num_heads = config.num_attention_heads
318
- self.head_dim = self.hidden_size // self.num_heads
319
- self.num_key_value_heads = config.num_key_value_heads
320
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
321
- self.max_position_embeddings = config.max_position_embeddings
322
- self.rope_theta = config.rope_theta
323
- self.is_causal = True
324
-
325
- if (self.head_dim * self.num_heads) != self.hidden_size:
326
- raise ValueError(
327
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
328
- f" and `num_heads`: {self.num_heads})."
329
- )
330
-
331
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
332
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
333
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
334
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
335
- self._init_rope()
336
-
337
- def _init_rope(self):
338
- if self.config.rope_scaling is None:
339
- self.rotary_emb = MiniCPMRotaryEmbedding(
340
- self.head_dim,
341
- max_position_embeddings=self.max_position_embeddings,
342
- base=self.rope_theta,
343
- )
344
- else:
345
- scaling_type = self.config.rope_scaling["type"]
346
- scaling_factor = self.config.rope_scaling["factor"]
347
- if scaling_type == "linear":
348
- self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
349
- self.head_dim,
350
- max_position_embeddings=self.max_position_embeddings,
351
- scaling_factor=scaling_factor,
352
- base=self.rope_theta,
353
- )
354
- elif scaling_type == "dynamic":
355
- self.rotary_emb = MiniCPMDynamicNTKScalingRotaryEmbedding(
356
- self.head_dim,
357
- max_position_embeddings=self.max_position_embeddings,
358
- scaling_factor=scaling_factor,
359
- base=self.rope_theta,
360
- )
361
- else:
362
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
363
-
364
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
365
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
366
-
367
- def forward(
368
- self,
369
- hidden_states: torch.Tensor,
370
- attention_mask: Optional[torch.Tensor] = None,
371
- position_ids: Optional[torch.LongTensor] = None,
372
- past_key_value: Optional[Cache] = None,
373
- output_attentions: bool = False,
374
- use_cache: bool = False,
375
- **kwargs,
376
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
377
- if "padding_mask" in kwargs:
378
- warnings.warn(
379
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
380
- )
381
-
382
- bsz, q_len, _ = hidden_states.size()
383
-
384
- if self.config.pretraining_tp > 1:
385
- key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
386
- query_slices = self.q_proj.weight.split(
387
- (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
388
- )
389
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
390
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
391
-
392
- query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
393
- query_states = torch.cat(query_states, dim=-1)
394
-
395
- key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
396
- key_states = torch.cat(key_states, dim=-1)
397
-
398
- value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
399
- value_states = torch.cat(value_states, dim=-1)
400
-
401
- else:
402
- query_states = self.q_proj(hidden_states)
403
- key_states = self.k_proj(hidden_states)
404
- value_states = self.v_proj(hidden_states)
405
-
406
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
407
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
408
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
409
-
410
- kv_seq_len = key_states.shape[-2]
411
- if past_key_value is not None:
412
- if self.layer_idx is None:
413
- raise ValueError(
414
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
415
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
416
- "with a layer index."
417
- )
418
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
419
- cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
420
-
421
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
422
-
423
- if past_key_value is not None:
424
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
425
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
426
-
427
- key_states = repeat_kv(key_states, self.num_key_value_groups)
428
- value_states = repeat_kv(value_states, self.num_key_value_groups)
429
-
430
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
431
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
432
- raise ValueError(
433
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
434
- f" {attn_weights.size()}"
435
- )
436
-
437
- if attention_mask is not None:
438
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
439
- raise ValueError(
440
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
441
- )
442
- attn_weights = attn_weights + attention_mask
443
-
444
- # upcast attention to fp32
445
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
446
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
447
- attn_output = torch.matmul(attn_weights, value_states)
448
-
449
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
450
- raise ValueError(
451
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
452
- f" {attn_output.size()}"
453
- )
454
-
455
- attn_output = attn_output.transpose(1, 2).contiguous()
456
-
457
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
458
-
459
- if self.config.pretraining_tp > 1:
460
- attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
461
- o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
462
- attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
463
- else:
464
- attn_output = self.o_proj(attn_output)
465
-
466
- if not output_attentions:
467
- attn_weights = None
468
-
469
- return attn_output, attn_weights, past_key_value
470
-
471
-
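The eager path above and the SDPA path further down compute the same causal attention; a small standalone check (illustrative shapes, padding ignored):

```python
import math
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 2, 4, 8
q, k, v = (torch.randn(bsz, num_heads, q_len, head_dim) for _ in range(3))

# Additive causal mask: -inf above the diagonal, analogous to the 4-D mask used by the eager path.
causal_mask = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)

attn_weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(head_dim) + causal_mask, dim=-1)
manual_out = attn_weights @ v

sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(manual_out, sdpa_out, atol=1e-5)
```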
472
- class MiniCPMFlashAttention2(MiniCPMAttention):
473
- """
474
- MiniCPM flash attention module. This module inherits from `MiniCPMAttention` as the weights of the module stay
475
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
476
- flash attention and deal with padding tokens in case the input contains any of them.
477
- """
478
-
479
- def __init__(self, *args, **kwargs):
480
- super().__init__(*args, **kwargs)
481
-
482
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
483
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
484
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
485
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
486
-
487
- def forward(
488
- self,
489
- hidden_states: torch.Tensor,
490
- attention_mask: Optional[torch.LongTensor] = None,
491
- position_ids: Optional[torch.LongTensor] = None,
492
- past_key_value: Optional[Cache] = None,
493
- output_attentions: bool = False,
494
- use_cache: bool = False,
495
- **kwargs,
496
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
497
- # MiniCPMFlashAttention2 attention does not support output_attentions
498
- if "padding_mask" in kwargs:
499
- warnings.warn(
500
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
501
- )
502
-
503
- # overwrite attention_mask with padding_mask
504
- attention_mask = kwargs.pop("padding_mask")
505
-
506
- output_attentions = False
507
-
508
- bsz, q_len, _ = hidden_states.size()
509
-
510
- query_states = self.q_proj(hidden_states)
511
- key_states = self.k_proj(hidden_states)
512
- value_states = self.v_proj(hidden_states)
513
-
514
- # Flash attention requires the input to have the shape
515
- # batch_size x seq_length x num_heads x head_dim
516
- # therefore we just need to keep the original shape
517
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
518
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
519
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
520
-
521
- kv_seq_len = key_states.shape[-2]
522
- if past_key_value is not None:
523
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
524
- cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
525
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
526
-
527
- if past_key_value is not None:
528
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
529
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
530
-
531
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
532
- # to be able to avoid many of these transpose/reshape/view.
533
- query_states = query_states.transpose(1, 2)
534
- key_states = key_states.transpose(1, 2)
535
- value_states = value_states.transpose(1, 2)
536
-
537
- dropout_rate = self.attention_dropout if self.training else 0.0
538
-
539
- # In PEFT, we usually cast the layer norms to float32 for training stability,
540
- # so the input hidden states may be silently cast to float32. Hence, we need to
541
- # cast them back to the correct dtype just to be sure everything works as expected.
542
- # This might slow down training & inference, so it is recommended not to cast the LayerNorms
543
- # to fp32. (MiniCPMRMSNorm handles it correctly)
544
-
545
- input_dtype = query_states.dtype
546
- if input_dtype == torch.float32:
547
- # Handle the case where the model is quantized
548
- if hasattr(self.config, "_pre_quantization_dtype"):
549
- target_dtype = self.config._pre_quantization_dtype
550
- else:
551
- target_dtype = self.q_proj.weight.dtype
552
-
553
- logger.warning_once(
554
- f"The input hidden states seems to be silently casted in float32, this might be related to"
555
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
556
- f" {target_dtype}."
557
- )
558
-
559
- query_states = query_states.to(target_dtype)
560
- key_states = key_states.to(target_dtype)
561
- value_states = value_states.to(target_dtype)
562
-
563
- attn_output = self._flash_attention_forward(
564
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
565
- )
566
-
567
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
568
- attn_output = self.o_proj(attn_output)
569
-
570
- if not output_attentions:
571
- attn_weights = None
572
-
573
- return attn_output, attn_weights, past_key_value
574
-
575
- def _flash_attention_forward(
576
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
577
- ):
578
- """
579
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
580
- first unpad the input, then computes the attention scores and pad the final attention scores.
581
-
582
- Args:
583
- query_states (`torch.Tensor`):
584
- Input query states to be passed to Flash Attention API
585
- key_states (`torch.Tensor`):
586
- Input key states to be passed to Flash Attention API
587
- value_states (`torch.Tensor`):
588
- Input value states to be passed to Flash Attention API
589
- attention_mask (`torch.Tensor`):
590
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
591
- position of padding tokens and 1 for the position of non-padding tokens.
592
- dropout (`int`, *optional*):
593
- Attention dropout
594
- softmax_scale (`float`, *optional*):
595
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
596
- """
597
- if not self._flash_attn_uses_top_left_mask:
598
- causal = self.is_causal
599
- else:
600
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
601
- causal = self.is_causal and query_length != 1
602
- # Contains at least one padding token in the sequence
603
- if attention_mask is not None:
604
- batch_size = query_states.shape[0]
605
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
606
- query_states, key_states, value_states, attention_mask, query_length
607
- )
608
-
609
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
610
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
611
- attn_output_unpad = flash_attn_varlen_func(
612
- query_states,
613
- key_states,
614
- value_states,
615
- cu_seqlens_q=cu_seqlens_q,
616
- cu_seqlens_k=cu_seqlens_k,
617
- max_seqlen_q=max_seqlen_in_batch_q,
618
- max_seqlen_k=max_seqlen_in_batch_k,
619
- dropout_p=dropout,
620
- softmax_scale=softmax_scale,
621
- causal=causal,
622
- )
623
-
624
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
625
- else:
626
- attn_output = flash_attn_func(
627
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
628
- )
629
-
630
- return attn_output
631
-
632
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
633
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
634
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
635
-
636
- key_layer = index_first_axis(
637
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
638
- )
639
- value_layer = index_first_axis(
640
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
641
- )
642
- if query_length == kv_seq_len:
643
- query_layer = index_first_axis(
644
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
645
- )
646
- cu_seqlens_q = cu_seqlens_k
647
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
648
- indices_q = indices_k
649
- elif query_length == 1:
650
- max_seqlen_in_batch_q = 1
651
- cu_seqlens_q = torch.arange(
652
- batch_size + 1, dtype=torch.int32, device=query_layer.device
653
- ) # There is a memcpy here, that is very bad.
654
- indices_q = cu_seqlens_q[:-1]
655
- query_layer = query_layer.squeeze(1)
656
- else:
657
- # The -q_len: slice assumes left padding.
658
- attention_mask = attention_mask[:, -query_length:]
659
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
660
-
661
- return (
662
- query_layer,
663
- key_layer,
664
- value_layer,
665
- indices_q,
666
- (cu_seqlens_q, cu_seqlens_k),
667
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
668
- )
669
-
670
-
671
- class MiniCPMSdpaAttention(MiniCPMAttention):
672
- """
673
- MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
674
- `MiniCPMAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
675
- SDPA API.
676
- """
677
-
678
- # Adapted from MiniCPMAttention.forward
679
- def forward(
680
- self,
681
- hidden_states: torch.Tensor,
682
- attention_mask: Optional[torch.Tensor] = None,
683
- position_ids: Optional[torch.LongTensor] = None,
684
- past_key_value: Optional[Cache] = None,
685
- output_attentions: bool = False,
686
- use_cache: bool = False,
687
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
688
- if output_attentions:
689
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
690
- logger.warning_once(
691
- "MiniCPMModel is using MiniCPMSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
692
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
693
- )
694
- return super().forward(
695
- hidden_states=hidden_states,
696
- attention_mask=attention_mask,
697
- position_ids=position_ids,
698
- past_key_value=past_key_value,
699
- output_attentions=output_attentions,
700
- use_cache=use_cache,
701
- )
702
-
703
- bsz, q_len, _ = hidden_states.size()
704
-
705
- query_states = self.q_proj(hidden_states)
706
- key_states = self.k_proj(hidden_states)
707
- value_states = self.v_proj(hidden_states)
708
-
709
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
710
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
711
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
712
-
713
- kv_seq_len = key_states.shape[-2]
714
- if past_key_value is not None:
715
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
716
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
717
-
718
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
719
-
720
- if past_key_value is not None:
721
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
722
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
723
-
724
- key_states = repeat_kv(key_states, self.num_key_value_groups)
725
- value_states = repeat_kv(value_states, self.num_key_value_groups)
726
-
727
- if attention_mask is not None:
728
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
729
- raise ValueError(
730
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
731
- )
732
-
733
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
734
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
735
- if query_states.device.type == "cuda" and attention_mask is not None:
736
- query_states = query_states.contiguous()
737
- key_states = key_states.contiguous()
738
- value_states = value_states.contiguous()
739
-
740
- attn_output = torch.nn.functional.scaled_dot_product_attention(
741
- query_states,
742
- key_states,
743
- value_states,
744
- attn_mask=attention_mask,
745
- dropout_p=self.attention_dropout if self.training else 0.0,
746
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
747
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
748
- )
749
-
750
- attn_output = attn_output.transpose(1, 2).contiguous()
751
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
752
-
753
- attn_output = self.o_proj(attn_output)
754
-
755
- return attn_output, None, past_key_value
756
-
757
-
758
- MINICPM_ATTENTION_CLASSES = {
759
- "eager": MiniCPMAttention,
760
- "flash_attention_2": MiniCPMFlashAttention2,
761
- "sdpa": MiniCPMSdpaAttention,
762
- }
763
-
764
-
765
- class MiniCPMDecoderLayer(nn.Module):
766
- def __init__(self, config: MiniCPMConfig, layer_idx: int):
767
- super().__init__()
768
- self.hidden_size = config.hidden_size
769
- self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
770
-
771
- self.mlp = MiniCPMMLP(config)
772
- self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
773
- self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
774
-
775
- self.scale_depth = config.scale_depth
776
- self.num_hidden_layers = config.num_hidden_layers
777
-
778
- def forward(
779
- self,
780
- hidden_states: torch.Tensor,
781
- attention_mask: Optional[torch.Tensor] = None,
782
- position_ids: Optional[torch.LongTensor] = None,
783
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
784
- output_attentions: Optional[bool] = False,
785
- use_cache: Optional[bool] = False,
786
- **kwargs,
787
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
788
- """
789
- Args:
790
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
791
- attention_mask (`torch.FloatTensor`, *optional*):
792
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
793
- query_sequence_length, key_sequence_length)` if default attention is used.
794
- output_attentions (`bool`, *optional*):
795
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
796
- returned tensors for more detail.
797
- use_cache (`bool`, *optional*):
798
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
799
- (see `past_key_values`).
800
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
801
- """
802
- if "padding_mask" in kwargs:
803
- warnings.warn(
804
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
805
- )
806
-
807
- residual = hidden_states
808
- hidden_states = self.input_layernorm(hidden_states)
809
- # Self Attention
810
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
811
- hidden_states=hidden_states,
812
- attention_mask=attention_mask,
813
- position_ids=position_ids,
814
- past_key_value=past_key_value,
815
- output_attentions=output_attentions,
816
- use_cache=use_cache,
817
- **kwargs,
818
- )
819
-
820
- hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
821
-
822
- # Fully Connected
823
- residual = hidden_states
824
- hidden_states = self.post_attention_layernorm(hidden_states)
825
-
826
- hidden_states = self.mlp(hidden_states)
827
- hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
828
-
829
- outputs = (hidden_states,)
830
-
831
- if output_attentions:
832
- outputs += (self_attn_weights,)
833
-
834
- if use_cache:
835
- outputs += (present_key_value,)
836
-
837
- return outputs
838
-
839
-
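A distinctive MiniCPM detail in the decoder layer above: each sublayer output is scaled by `scale_depth / sqrt(num_hidden_layers)` before being added to the residual. A minimal sketch with illustrative numbers (both values come from `MiniCPMConfig` in the real layer):

```python
import math
import torch

scale_depth, num_hidden_layers = 1.4, 40        # illustrative; read from the config in practice
residual = torch.randn(2, 5, 16)
sublayer_out = torch.randn(2, 5, 16)            # output of self-attention or of the MLP

hidden_states = residual + sublayer_out * (scale_depth / math.sqrt(num_hidden_layers))
```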
840
- MINICPM_START_DOCSTRING = r"""
841
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
842
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
843
- etc.)
844
-
845
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
846
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
847
- and behavior.
848
-
849
- Parameters:
850
- config ([`MiniCPMConfig`]):
851
- Model configuration class with all the parameters of the model. Initializing with a config file does not
852
- load the weights associated with the model, only the configuration. Check out the
853
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
854
- """
855
-
856
-
857
- @add_start_docstrings(
858
- "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
859
- MINICPM_START_DOCSTRING,
860
- )
861
- class MiniCPMPreTrainedModel(PreTrainedModel):
862
- config_class = MiniCPMConfig
863
- base_model_prefix = "model"
864
- supports_gradient_checkpointing = True
865
- _no_split_modules = ["MiniCPMDecoderLayer"]
866
- _skip_keys_device_placement = "past_key_values"
867
- _supports_flash_attn_2 = True
868
- _supports_sdpa = True
869
- _supports_cache_class = True
870
-
871
- def _init_weights(self, module):
872
- std = self.config.initializer_range
873
- if isinstance(module, nn.Linear):
874
- module.weight.data.normal_(mean=0.0, std=std)
875
- if module.bias is not None:
876
- module.bias.data.zero_()
877
- elif isinstance(module, nn.Embedding):
878
- module.weight.data.normal_(mean=0.0, std=std)
879
- if module.padding_idx is not None:
880
- module.weight.data[module.padding_idx].zero_()
881
-
882
-
883
- MINICPM_INPUTS_DOCSTRING = r"""
884
- Args:
885
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
886
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
887
- it.
888
-
889
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
890
- [`PreTrainedTokenizer.__call__`] for details.
891
-
892
- [What are input IDs?](../glossary#input-ids)
893
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
894
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
895
-
896
- - 1 for tokens that are **not masked**,
897
- - 0 for tokens that are **masked**.
898
-
899
- [What are attention masks?](../glossary#attention-mask)
900
-
901
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
902
- [`PreTrainedTokenizer.__call__`] for details.
903
-
904
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
905
- `past_key_values`).
906
-
907
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
908
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
909
- information on the default strategy.
910
-
911
- - 1 indicates the head is **not masked**,
912
- - 0 indicates the head is **masked**.
913
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
914
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
915
- config.n_positions - 1]`.
916
-
917
- [What are position IDs?](../glossary#position-ids)
918
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
919
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
920
- blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
921
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
922
-
923
- Two formats are allowed:
924
- - a [`~cache_utils.Cache`] instance;
925
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
926
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
927
- cache format.
928
-
929
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
930
- legacy cache format will be returned.
931
-
932
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
933
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
934
- of shape `(batch_size, sequence_length)`.
935
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
936
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
937
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
938
- model's internal embedding lookup matrix.
939
- use_cache (`bool`, *optional*):
940
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
941
- `past_key_values`).
942
- output_attentions (`bool`, *optional*):
943
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
944
- tensors for more detail.
945
- output_hidden_states (`bool`, *optional*):
946
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
947
- more detail.
948
- return_dict (`bool`, *optional*):
949
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
950
- """
951
-
952
-
953
- @add_start_docstrings(
954
- "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
955
- MINICPM_START_DOCSTRING,
956
- )
957
- class LayerWiseMiniCPMModel(MiniCPMPreTrainedModel):
958
- """
959
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
960
-
961
- Args:
962
- config: MiniCPMConfig
963
- """
964
-
965
- def __init__(self, config: MiniCPMConfig):
966
- super().__init__(config)
967
- self.padding_idx = config.pad_token_id
968
- self.vocab_size = config.vocab_size
969
-
970
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
971
- self.layers = nn.ModuleList(
972
- [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
973
- )
974
- self._use_sdpa = config._attn_implementation == "sdpa"
975
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
976
-
977
- self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
978
-
979
- self.gradient_checkpointing = False
980
- # Initialize weights and apply final processing
981
- self.post_init()
982
-
983
- def get_input_embeddings(self):
984
- return self.embed_tokens
985
-
986
- def set_input_embeddings(self, value):
987
- self.embed_tokens = value
988
-
989
- @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
990
- def forward(
991
- self,
992
- input_ids: torch.LongTensor = None,
993
- attention_mask: Optional[torch.Tensor] = None,
994
- position_ids: Optional[torch.LongTensor] = None,
995
- past_key_values: Optional[List[torch.FloatTensor]] = None,
996
- inputs_embeds: Optional[torch.FloatTensor] = None,
997
- use_cache: Optional[bool] = None,
998
- output_attentions: Optional[bool] = None,
999
- output_hidden_states: Optional[bool] = None,
1000
- return_dict: Optional[bool] = None,
1001
- cutoff_layers: Optional[Union[int, List]] = None,
1002
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1003
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1004
- output_hidden_states = (
1005
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1006
- )
1007
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1008
-
1009
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1010
-
1011
- # retrieve input_ids and inputs_embeds
1012
- if input_ids is not None and inputs_embeds is not None:
1013
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1014
- elif input_ids is not None:
1015
- batch_size, seq_length = input_ids.shape[:2]
1016
- elif inputs_embeds is not None:
1017
- batch_size, seq_length = inputs_embeds.shape[:2]
1018
- else:
1019
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1020
-
1021
- if self.gradient_checkpointing and self.training:
1022
- if use_cache:
1023
- logger.warning_once(
1024
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1025
- )
1026
- use_cache = False
1027
-
1028
- past_key_values_length = 0
1029
- if use_cache:
1030
- use_legacy_cache = not isinstance(past_key_values, Cache)
1031
- if use_legacy_cache:
1032
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1033
- past_key_values_length = past_key_values.get_usable_length(seq_length)
1034
-
1035
- if position_ids is None:
1036
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1037
- position_ids = torch.arange(
1038
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1039
- )
1040
- position_ids = position_ids.unsqueeze(0)
1041
-
1042
- if inputs_embeds is None:
1043
- inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb
1044
-
1045
- if self._use_flash_attention_2:
1046
- # 2d mask is passed through the layers
1047
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1048
- elif self._use_sdpa and not output_attentions:
1049
- # output_attentions=True can not be supported when using SDPA, and we fall back on
1050
- # the manual implementation that requires a 4D causal mask in all cases.
1051
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1052
- attention_mask,
1053
- (batch_size, seq_length),
1054
- inputs_embeds,
1055
- past_key_values_length,
1056
- )
1057
- else:
1058
- # 4d mask is passed through the layers
1059
- attention_mask = _prepare_4d_causal_attention_mask(
1060
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1061
- )
1062
-
1063
- # embed positions
1064
- hidden_states = inputs_embeds
1065
-
1066
- # decoder layers
1067
- all_hidden_states = () if output_hidden_states else None
1068
- all_self_attns = () if output_attentions else None
1069
- next_decoder_cache = None
1070
-
1071
- if cutoff_layers is None:
1072
- max_layer = self.config.num_hidden_layers
1073
- cutoff_layers = [max_layer]
1074
- if isinstance(cutoff_layers, int):
1075
- max_layer = cutoff_layers
1076
- cutoff_layers = [cutoff_layers]
1077
- else:
1078
- max_layer = max(cutoff_layers)
1079
-
1080
- for idx, decoder_layer in enumerate(self.layers):
1081
- if idx in cutoff_layers and output_hidden_states:
1082
- all_hidden_states += (self.norm(hidden_states),)
1083
-
1084
- if idx == max_layer:
1085
- break
1086
-
1087
- if self.gradient_checkpointing and self.training:
1088
- layer_outputs = self._gradient_checkpointing_func(
1089
- decoder_layer.__call__,
1090
- hidden_states,
1091
- attention_mask,
1092
- position_ids,
1093
- past_key_values,
1094
- output_attentions,
1095
- use_cache,
1096
- )
1097
- else:
1098
- layer_outputs = decoder_layer(
1099
- hidden_states,
1100
- attention_mask=attention_mask,
1101
- position_ids=position_ids,
1102
- past_key_value=past_key_values,
1103
- output_attentions=output_attentions,
1104
- use_cache=use_cache,
1105
- )
1106
-
1107
- hidden_states = layer_outputs[0]
1108
-
1109
- if use_cache:
1110
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1111
-
1112
- if output_attentions:
1113
- all_self_attns += (layer_outputs[1],)
1114
-
1115
- hidden_states = self.norm(hidden_states)
1116
-
1117
- # add hidden states from the last decoder layer
1118
- if output_hidden_states and self.config.num_hidden_layers == max_layer:
1119
- all_hidden_states += (hidden_states,)
1120
-
1121
- next_cache = None
1122
- if use_cache:
1123
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1124
- if not return_dict:
1125
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1126
- return BaseModelOutputWithPast(
1127
- last_hidden_state=hidden_states,
1128
- past_key_values=next_cache,
1129
- hidden_states=all_hidden_states,
1130
- attentions=all_self_attns,
1131
- )
1132
-
1133
-
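A hedged usage sketch of the layer-wise forward pass (the `config` object below is a placeholder for a real `MiniCPMConfig`): with `output_hidden_states=True`, the model returns one RMS-normalized hidden state per requested `cutoff_layers` entry and stops computing at the deepest requested layer.

```python
import torch

# Placeholder setup; in practice the model is loaded from a checkpoint that ships this file.
model = LayerWiseMiniCPMModel(config)                       # config: MiniCPMConfig
input_ids = torch.randint(0, config.vocab_size, (1, 12))

out = model(input_ids=input_ids, output_hidden_states=True, cutoff_layers=[8, 16])

# One tensor per requested layer, each of shape [batch, seq_len, hidden_size];
# layers at or beyond max(cutoff_layers) are never run.
print(len(out.hidden_states), out.hidden_states[0].shape)
```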
1134
- class LayerWiseHead(nn.Module):
1135
- """Head for sentence-level classification tasks."""
1136
-
1137
- def __init__(self, input_size, output_size):
1138
- super().__init__()
1139
- self.linear_head = nn.Linear(input_size, output_size, bias=False)
1140
-
1141
- def forward(self, hidden_states):
1142
- return self.linear_head(hidden_states)
1143
-
1144
- class LayerWiseMiniCPMForCausalLM(MiniCPMPreTrainedModel):
1145
- _tied_weights_keys = ["lm_head.weight"]
1146
-
1147
- def __init__(self, config):
1148
- super().__init__(config)
1149
- self.model = LayerWiseMiniCPMModel(config)
1150
- self.vocab_size = config.vocab_size
1151
-
1152
- if self.config.head_type == 'raw':
1153
- if not self.config.head_multi:
1154
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1155
- else:
1156
- self.lm_head = nn.ModuleList([nn.Linear(
1157
- config.hidden_size, config.vocab_size, bias=False) for _ in range(
1158
- self.config.start_layer,
1159
- self.model.config.num_hidden_layers + 1)])
1160
- elif self.config.head_type == 'complex':
1161
- if not self.config.head_multi:
1162
- # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1163
- self.lm_head = LayerWiseHead(config.hidden_size, config.vocab_size)
1164
- else:
1165
- # self.lm_head = nn.ModuleList([nn.Linear(
1166
- # config.hidden_size, config.vocab_size, bias=False) for _ in range(
1167
- # self.config.start_layer,
1168
- # self.model.config.num_hidden_layers + 1)])
1169
- self.lm_head = nn.ModuleList([LayerWiseHead(
1170
- config.hidden_size, config.vocab_size) for _ in range(
1171
- self.config.start_layer,
1172
- self.model.config.num_hidden_layers + 1)])
1173
- else:
1174
- if not self.config.head_multi:
1175
- # self.lm_head = nn.Linear(config.hidden_size, 1, bias=False)
1176
- self.lm_head = LayerWiseHead(config.hidden_size, 1)
1177
- else:
1178
- # self.lm_head = nn.ModuleList([nn.Linear(
1179
- # config.hidden_size, 1, bias=False) for _ in range(
1180
- # self.config.start_layer,
1181
- # self.model.config.num_hidden_layers + 1)])
1182
- self.lm_head = nn.ModuleList([LayerWiseHead(
1183
- config.hidden_size, 1) for _ in range(
1184
- self.config.start_layer,
1185
- self.model.config.num_hidden_layers + 1)])
1186
-
1187
- # Initialize weights and apply final processing
1188
- self.post_init()
1189
-
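For reference, a hedged summary of how `head_type` and `head_multi` (defined in the companion configuration file) shape the scoring head built above:

```python
# head_type == 'raw'     -> plain nn.Linear head(s) over the vocabulary
# head_type == 'complex' -> LayerWiseHead wrapper(s) over the vocabulary
# otherwise (reranker)   -> LayerWiseHead(s) projecting hidden_size -> 1 relevance logit
# head_multi == True     -> one head per layer in [start_layer, num_hidden_layers]
# head_multi == False    -> a single head shared by all selected layers
```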
1190
- def get_input_embeddings(self):
1191
- return self.model.embed_tokens
1192
-
1193
- def set_input_embeddings(self, value):
1194
- self.model.embed_tokens = value
1195
-
1196
- def get_output_embeddings(self):
1197
- return self.lm_head
1198
-
1199
- def set_output_embeddings(self, new_embeddings):
1200
- self.lm_head = new_embeddings
1201
-
1202
- def set_decoder(self, decoder):
1203
- self.model = decoder
1204
-
1205
- def get_decoder(self):
1206
- return self.model
1207
-
1208
- @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
1209
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1210
- def forward(
1211
- self,
1212
- input_ids: torch.LongTensor = None,
1213
- attention_mask: Optional[torch.Tensor] = None,
1214
- position_ids: Optional[torch.LongTensor] = None,
1215
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1216
- inputs_embeds: Optional[torch.FloatTensor] = None,
1217
- labels: Optional[torch.LongTensor] = None,
1218
- use_cache: Optional[bool] = None,
1219
- output_attentions: Optional[bool] = None,
1220
- output_hidden_states: Optional[bool] = None,
1221
- return_dict: Optional[bool] = None,
1222
- cutoff_layers: Optional[Union[int, List]] = None,
1223
- only_for_one_logit: Optional[int] = None
1224
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1225
- r"""
1226
- Args:
1227
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1228
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1229
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1230
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1231
-
1232
- Returns:
1233
-
1234
- Example:
1235
-
1236
- ```python
1237
- >>> from transformers import AutoTokenizer, MiniCPMForCausalLM
1238
-
1239
- >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1240
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1241
-
1242
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1243
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1244
-
1245
- >>> # Generate
1246
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1247
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1248
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1249
- ```"""
1250
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         if cutoff_layers is None:
-             cutoff_layers = [self.config.num_hidden_layers]
-         elif isinstance(cutoff_layers, int):
-             cutoff_layers = [cutoff_layers]
-
-         remove_layers = [i for i in cutoff_layers if self.config.start_layer > i or i > self.config.num_hidden_layers]
-         if len(remove_layers) > 0:
-             logger.warning_once(
-                 f"layers {remove_layers} are incompatible with the current setting and will be removed..."
-             )
-
-         cutoff_layers = [i for i in cutoff_layers if i not in remove_layers]
-
-         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=True,
-             return_dict=return_dict,
-             cutoff_layers=cutoff_layers
-         )
-
-         hidden_states = outputs[0]
-
-         all_logits = ()
-         if only_for_one_logit is None and (self.config.head_type == 'complex' or self.config.head_type == 'raw'):
-             if self.config.head_type == 'raw':
-                 for i in range(len(outputs.hidden_states)):
-                     if self.config.head_multi == False:
-                         if self.config.pretraining_tp > 1:
-                             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-                             logits = [F.linear(outputs.hidden_states[i], lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-                             logits = torch.cat(logits, dim=-1)
-                         else:
-                             logits = self.lm_head(outputs.hidden_states[i] / (self.config.hidden_size / self.config.dim_model_base))
-                     else:
-                         if self.config.pretraining_tp > 1:
-                             lm_head_slices = self.lm_head[cutoff_layers[i] - self.config.start_layer].weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-                             logits = [F.linear(outputs.hidden_states[i], lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-                             logits = torch.cat(logits, dim=-1)
-                         else:
-                             logits = self.lm_head[cutoff_layers[i] - self.config.start_layer](outputs.hidden_states[i] / (self.config.hidden_size / self.config.dim_model_base))
-                     logits = logits.float()
-                     logits = logits.reshape(input_ids.shape[0], -1)
-                     all_logits = all_logits + (logits, )
-             else:
-                 for i in range(len(outputs.hidden_states)):
-                     if self.config.head_multi == False:
-                         if self.config.pretraining_tp > 1:
-                             lm_head_slices = self.lm_head.linear_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-                             logits = [F.linear(outputs.hidden_states[i], lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-                             logits = torch.cat(logits, dim=-1)
-                         else:
-                             logits = self.lm_head.linear_head(outputs.hidden_states[i] / (self.config.hidden_size / self.config.dim_model_base))
-                     else:
-                         if self.config.pretraining_tp > 1:
-                             lm_head_slices = self.lm_head[cutoff_layers[i] - self.config.start_layer].linear_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-                             logits = [F.linear(outputs.hidden_states[i], lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-                             logits = torch.cat(logits, dim=-1)
-                         else:
-                             logits = self.lm_head[cutoff_layers[i] - self.config.start_layer].linear_head(outputs.hidden_states[i] / (self.config.hidden_size / self.config.dim_model_base))
-                     logits = logits.float()
-                     logits = logits.reshape(input_ids.shape[0], -1)
-                     all_logits = all_logits + (logits, )
-         else:
-             if self.config.head_type == 'raw':
-                 if only_for_one_logit is None:
-                     raise ValueError("`only_for_one_logit` must not be None when the head type is 'raw'.")
-
-                 if self.config.head_multi == False:
-                     lm_head_slices = self.lm_head.weight.split(1, dim=0)
-                     for i in range(len(outputs.hidden_states)):
-                         logits = F.linear(outputs.hidden_states[i], lm_head_slices[only_for_one_logit])
-                         logits = logits.float()
-                         logits = logits.reshape(input_ids.shape[0], -1)
-                         all_logits = all_logits + (logits,)
-                 else:
-                     for i in range(len(outputs.hidden_states)):
-                         lm_head_slices = self.lm_head[cutoff_layers[i] - self.config.start_layer].weight.split(1, dim=0)
-                         logits = F.linear(outputs.hidden_states[i], lm_head_slices[only_for_one_logit])
-                         logits = logits.float()
-                         logits = logits.reshape(input_ids.shape[0], -1)
-                         all_logits = all_logits + (logits, )
-             elif self.config.head_type == 'complex':
-                 if only_for_one_logit is None:
-                     raise ValueError("`only_for_one_logit` must not be None when the head type is 'complex'.")
-
-                 if self.config.head_multi == False:
-                     lm_head_slices = self.lm_head.linear_head.weight.split(1, dim=0)
-                     for i in range(len(outputs.hidden_states)):
-                         logits = F.linear(outputs.hidden_states[i], lm_head_slices[only_for_one_logit])
-                         logits = logits.float()
-                         logits = logits.reshape(input_ids.shape[0], -1)
-                         all_logits = all_logits + (logits,)
-                 else:
-                     for i in range(len(outputs.hidden_states)):
-                         lm_head_slices = self.lm_head[cutoff_layers[i] - self.config.start_layer].linear_head.weight.split(1, dim=0)
-                         logits = F.linear(outputs.hidden_states[i], lm_head_slices[only_for_one_logit])
-                         logits = logits.float()
-                         logits = logits.reshape(input_ids.shape[0], -1)
-                         all_logits = all_logits + (logits, )
-             else:
-                 if self.config.head_multi == False:
-                     for i in range(len(outputs.hidden_states)):
-                         logits = self.lm_head.linear_head(outputs.hidden_states[i])
-                         logits = logits.float()
-                         logits = logits.reshape(input_ids.shape[0], -1)
-                         all_logits = all_logits + (logits,)
-                 else:
-                     for i in range(len(outputs.hidden_states)):
-                         logits = self.lm_head[cutoff_layers[i] - self.config.start_layer].linear_head(outputs.hidden_states[i])
-                         logits = logits.float()
-                         logits = logits.reshape(input_ids.shape[0], -1)
-                         all_logits = all_logits + (logits,)
-
-         loss = None
-         if labels is not None and not only_for_one_logit and self.config.head_type == 'complex':
-             # Shift so that tokens < n predict n
-             loss = 0
-             for logits in all_logits:
-                 shift_logits = logits[..., :-1, :].contiguous()
-                 shift_labels = labels[..., 1:].contiguous()
-                 # Flatten the tokens
-                 loss_fct = CrossEntropyLoss()
-                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                 shift_labels = shift_labels.view(-1)
-                 # Enable model parallelism
-                 shift_labels = shift_labels.to(shift_logits.device)
-                 loss += loss_fct(shift_logits, shift_labels)
-
-         outputs.hidden_states = None if not output_hidden_states else outputs.hidden_states
-
-         if not return_dict:
-             output = (all_logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=all_logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
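`forward` returns one logits tensor of shape `(batch_size, sequence_length)` per layer requested via `cutoff_layers`. A minimal scoring sketch under that assumption; the checkpoint path, input format, and the layer index `28` are placeholders rather than values taken from this file:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path; any checkpoint that ships this modeling file via trust_remote_code should load it.
model_path = "path/to/layerwise-minicpm-reranker"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model.eval()

# Naive query + passage concatenation; the real reranker defines its own prompt template.
text = "what is panda? [SEP] The giant panda is a bear species endemic to China."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    # `cutoff_layers` must lie in [config.start_layer, config.num_hidden_layers];
    # 28 is only an illustrative value.
    outputs = model(**inputs, cutoff_layers=[28], return_dict=True)

# `outputs.logits` is a tuple with one (batch_size, sequence_length) tensor per requested layer;
# the logit at the last position is typically read out as the relevance score.
scores = [layer_logits[:, -1] for layer_logits in outputs.logits]
print(scores)
```

Passing several layer indices in `cutoff_layers` yields one score per layer, which is what makes early-exit reranking with the shallower heads possible.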
-     def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-     ):
-         if past_key_values is not None:
-             if isinstance(past_key_values, Cache):
-                 cache_length = past_key_values.get_seq_length()
-                 past_length = past_key_values.seen_tokens
-                 max_cache_length = past_key_values.get_max_length()
-             else:
-                 cache_length = past_length = past_key_values[0][0].shape[2]
-                 max_cache_length = None
-
-             # Keep only the unprocessed tokens:
-             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-             # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
-             # input)
-             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
-             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-             # input_ids based on the past_length.
-             elif past_length < input_ids.shape[1]:
-                 input_ids = input_ids[:, past_length:]
-             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-             if (
-                 max_cache_length is not None
-                 and attention_mask is not None
-                 and cache_length + input_ids.shape[1] > max_cache_length
-             ):
-                 attention_mask = attention_mask[:, -max_cache_length:]
-
-         position_ids = kwargs.get("position_ids", None)
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -input_ids.shape[1]:]
-
-         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids}
-
-         model_inputs.update(
-             {
-                 "position_ids": position_ids,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "attention_mask": attention_mask,
-             }
-         )
-         return model_inputs
-
-     @staticmethod
-     def _reorder_cache(past_key_values, beam_idx):
-         reordered_past = ()
-         for layer_past in past_key_values:
-             reordered_past += (
-                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-             )
-         return reordered_past
-
-     @torch.inference_mode()
-     def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-              max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
-              **kwargs):
-         if history is None:
-             history = []
-         # `generate` treats `logits_processor=None` the same as omitting it, so one dict covers both cases.
-         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-
-         history.append({"role": role, "content": query})
-         history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
-         inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
-         outputs = self.generate(**inputs, **gen_kwargs)
-         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
-         response = tokenizer.decode(outputs)
-         pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
-         matches = pattern.findall(response)
-         if len(matches) > 0:
-             response = matches[0]
-         history.append({"role": "assistant", "content": response})
-         return response, history
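A minimal sketch of driving the `chat` helper above; the checkpoint path is a placeholder, and any checkpoint that loads this class with `trust_remote_code=True` is assumed to behave the same way:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path (assumption): a MiniCPM chat checkpoint that ships this modeling file.
model_path = "path/to/minicpm-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

# First turn: `history` starts empty and comes back updated for reuse.
response, history = model.chat(tokenizer, "Introduce yourself in one sentence.", history=[])
print(response)

# Follow-up turn: pass the accumulated history back in.
response, history = model.chat(tokenizer, "Now say it more formally.", history=history)
print(response)
```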