x54-729 commited on
Commit
5966e12
·
1 Parent(s): 1611f45

support flash attn 2

Browse files
Files changed (2) hide show
  1. configuration_internlm.py +32 -3
  2. modeling_internlm2.py +216 -81
configuration_internlm.py CHANGED
@@ -106,7 +106,9 @@ class InternLMConfig(PretrainedConfig):
106
  eos_token_id=2,
107
  tie_word_embeddings=False,
108
  bias=True,
109
- rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102
 
 
110
  **kwargs,
111
  ):
112
  self.vocab_size = vocab_size
@@ -115,6 +117,7 @@ class InternLMConfig(PretrainedConfig):
115
  self.intermediate_size = intermediate_size
116
  self.num_hidden_layers = num_hidden_layers
117
  self.num_attention_heads = num_attention_heads
 
118
 
119
  if num_key_value_heads is None:
120
  num_key_value_heads = num_attention_heads
@@ -124,8 +127,13 @@ class InternLMConfig(PretrainedConfig):
124
  self.initializer_range = initializer_range
125
  self.rms_norm_eps = rms_norm_eps
126
  self.use_cache = use_cache
127
- self.bias = bias
128
- self.rotary = rotary
 
 
 
 
 
129
  super().__init__(
130
  pad_token_id=pad_token_id,
131
  bos_token_id=bos_token_id,
@@ -133,3 +141,24 @@ class InternLMConfig(PretrainedConfig):
133
  tie_word_embeddings=tie_word_embeddings,
134
  **kwargs,
135
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  eos_token_id=2,
107
  tie_word_embeddings=False,
108
  bias=True,
109
+ rope_theta=10000,
110
+ rope_scaling=None,
111
+ attn_implementation="eager",
112
  **kwargs,
113
  ):
114
  self.vocab_size = vocab_size
 
117
  self.intermediate_size = intermediate_size
118
  self.num_hidden_layers = num_hidden_layers
119
  self.num_attention_heads = num_attention_heads
120
+ self.bias = bias
121
 
122
  if num_key_value_heads is None:
123
  num_key_value_heads = num_attention_heads
 
127
  self.initializer_range = initializer_range
128
  self.rms_norm_eps = rms_norm_eps
129
  self.use_cache = use_cache
130
+ self.rope_theta = rope_theta
131
+ self.rope_scaling = rope_scaling
132
+ self._rope_scaling_validation()
133
+
134
+ self.attn_implementation = attn_implementation
135
+ if self.attn_implementation is None:
136
+ self.attn_implementation = "eager"
137
  super().__init__(
138
  pad_token_id=pad_token_id,
139
  bos_token_id=bos_token_id,
 
141
  tie_word_embeddings=tie_word_embeddings,
142
  **kwargs,
143
  )
144
+
145
+ def _rope_scaling_validation(self):
146
+ """
147
+ Validate the `rope_scaling` configuration.
148
+ """
149
+ if self.rope_scaling is None:
150
+ return
151
+
152
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
153
+ raise ValueError(
154
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
155
+ f"got {self.rope_scaling}"
156
+ )
157
+ rope_scaling_type = self.rope_scaling.get("type", None)
158
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
159
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
160
+ raise ValueError(
161
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
162
+ )
163
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
164
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
modeling_internlm2.py CHANGED
@@ -1,10 +1,6 @@
1
- # coding=utf-8
2
- # # Copyright (c) InternLM. All rights reserved.
3
  #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -25,6 +21,7 @@ import warnings
25
  from typing import List, Optional, Tuple, Union
26
 
27
  import torch
 
28
  import torch.utils.checkpoint
29
  from einops import rearrange
30
  from torch import nn
@@ -54,6 +51,18 @@ logger = logging.get_logger(__name__)
54
 
55
  _CONFIG_FOR_DOC = "InternLM2Config"
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
59
  def _make_causal_mask(
@@ -88,6 +97,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
88
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
89
 
90
 
 
91
  class InternLM2RMSNorm(nn.Module):
92
  def __init__(self, hidden_size, eps=1e-6):
93
  """
@@ -105,6 +115,7 @@ class InternLM2RMSNorm(nn.Module):
105
  return self.weight * hidden_states.to(input_dtype)
106
 
107
 
 
108
  class InternLM2RotaryEmbedding(nn.Module):
109
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
110
  super().__init__()
@@ -133,7 +144,7 @@ class InternLM2RotaryEmbedding(nn.Module):
133
  def forward(self, x, seq_len=None):
134
  # x: [bs, num_attention_heads, seq_len, head_size]
135
  if seq_len > self.max_seq_len_cached:
136
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
137
 
138
  return (
139
  self.cos_cached[:seq_len].to(dtype=x.dtype),
@@ -141,6 +152,7 @@ class InternLM2RotaryEmbedding(nn.Module):
141
  )
142
 
143
 
 
144
  class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
145
  """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
146
 
@@ -160,6 +172,7 @@ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
160
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
161
 
162
 
 
163
  class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
164
  """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
165
  Credits to the Reddit users /u/bloc97 and /u/emozilla.
@@ -188,6 +201,7 @@ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
188
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
189
 
190
 
 
191
  def rotate_half(x):
192
  """Rotates half the hidden dims of the input."""
193
  x1 = x[..., : x.shape[-1] // 2]
@@ -195,22 +209,13 @@ def rotate_half(x):
195
  return torch.cat((-x2, x1), dim=-1)
196
 
197
 
198
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
199
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
200
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
201
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
202
- cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
203
- sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
204
- if q.size(2) == 1:
205
- q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
206
- else:
207
- q_embed = (q * cos) + (rotate_half(q) * sin)
208
-
209
- if k.size(2) == 1:
210
- k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
211
- else:
212
- k_embed = (k * cos) + (rotate_half(k) * sin)
213
-
214
  return q_embed, k_embed
215
 
216
 
@@ -231,6 +236,7 @@ class InternLM2MLP(nn.Module):
231
  return down_proj
232
 
233
 
 
234
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
235
  """
236
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -243,6 +249,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
243
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
244
 
245
 
 
246
  class InternLM2Attention(nn.Module):
247
  """Multi-headed attention from 'Attention Is All You Need' paper"""
248
 
@@ -273,21 +280,31 @@ class InternLM2Attention(nn.Module):
273
  self._init_rope()
274
 
275
  def _init_rope(self):
276
- if self.config.rotary["type"] == "origin":
277
  self.rotary_emb = InternLM2RotaryEmbedding(
278
  self.head_dim,
279
  max_position_embeddings=self.max_position_embeddings,
280
- base=self.config.rotary["base"],
281
- )
282
- elif self.config.rotary["type"] == "dynamic":
283
- self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
284
- self.head_dim,
285
- max_position_embeddings=self.max_position_embeddings,
286
- base=self.config.rotary["base"],
287
- scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
288
  )
289
  else:
290
- raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  return self.rotary_emb
292
 
293
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -381,6 +398,7 @@ class InternLM2Attention(nn.Module):
381
  return attn_output, attn_weights, past_key_value
382
 
383
 
 
384
  class InternLM2FlashAttention2(InternLM2Attention):
385
  """
386
  InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
@@ -417,9 +435,8 @@ class InternLM2FlashAttention2(InternLM2Attention):
417
  qkv_states = rearrange(
418
  qkv_states,
419
  "b q (h gs d) -> b q h gs d",
420
- gs=self.num_heads + 2 * self.num_key_value_heads,
421
  d=self.head_dim,
422
- q=q_len,
423
  )
424
 
425
  query_states = qkv_states[..., : self.num_key_value_groups, :]
@@ -427,6 +444,10 @@ class InternLM2FlashAttention2(InternLM2Attention):
427
  key_states = qkv_states[..., -2, :]
428
  value_states = qkv_states[..., -1, :]
429
 
 
 
 
 
430
  kv_seq_len = key_states.shape[-2]
431
  if past_key_value is not None:
432
  kv_seq_len += past_key_value[0].shape[-2]
@@ -448,34 +469,9 @@ class InternLM2FlashAttention2(InternLM2Attention):
448
 
449
  dropout_rate = 0.0 if not self.training else self.attention_dropout
450
 
451
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
452
- # therefore the input hidden states gets silently casted in float32. Hence, we need
453
- # cast them back in the correct dtype just to be sure everything works as expected.
454
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
455
- # in fp32. (InternLM2RMSNorm handles it correctly)
456
-
457
- input_dtype = query_states.dtype
458
- if input_dtype == torch.float32:
459
- # Handle the case where the model is quantized
460
- if hasattr(self.config, "_pre_quantization_dtype"):
461
- target_dtype = self.config._pre_quantization_dtype
462
- else:
463
- target_dtype = self.q_proj.weight.dtype
464
-
465
- logger.warning_once(
466
- f"The input hidden states seems to be silently casted in float32, this might be related to"
467
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back "
468
- f"the input in {target_dtype}."
469
- )
470
-
471
- query_states = query_states.to(target_dtype)
472
- key_states = key_states.to(target_dtype)
473
- value_states = value_states.to(target_dtype)
474
-
475
  attn_output = self._flash_attention_forward(
476
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
477
  )
478
-
479
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
480
  attn_output = self.wo(attn_output)
481
 
@@ -484,16 +480,115 @@ class InternLM2FlashAttention2(InternLM2Attention):
484
 
485
  return attn_output, attn_weights, past_key_value
486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
 
 
488
  class InternLM2DecoderLayer(nn.Module):
489
  def __init__(self, config: InternLM2Config):
490
  super().__init__()
491
  self.hidden_size = config.hidden_size
492
- self.attention = (
493
- InternLM2Attention(config=config)
494
- if not getattr(config, "_flash_attn_2_enabled", False)
495
- else InternLM2FlashAttention2(config=config)
496
- )
497
  self.feed_forward = InternLM2MLP(config)
498
  self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
499
  self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -565,9 +660,11 @@ InternLM2_START_DOCSTRING = r"""
565
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
566
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
567
  etc.)
 
568
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
569
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
570
  and behavior.
 
571
  Parameters:
572
  config ([`InternLM2Config`]):
573
  Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -576,6 +673,7 @@ InternLM2_START_DOCSTRING = r"""
576
  """
577
 
578
 
 
579
  @add_start_docstrings(
580
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
581
  InternLM2_START_DOCSTRING,
@@ -586,7 +684,6 @@ class InternLM2PreTrainedModel(PreTrainedModel):
586
  supports_gradient_checkpointing = True
587
  _no_split_modules = ["InternLM2DecoderLayer"]
588
  _skip_keys_device_placement = "past_key_values"
589
- _supports_flash_attn_2 = True
590
 
591
  def _init_weights(self, module):
592
  std = self.config.initializer_range
@@ -605,34 +702,45 @@ InternLM2_INPUTS_DOCSTRING = r"""
605
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
606
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
607
  it.
 
608
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
609
  [`PreTrainedTokenizer.__call__`] for details.
 
610
  [What are input IDs?](../glossary#input-ids)
611
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
612
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
613
  - 1 for tokens that are **not masked**,
614
  - 0 for tokens that are **masked**.
 
615
  [What are attention masks?](../glossary#attention-mask)
 
616
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
617
  [`PreTrainedTokenizer.__call__`] for details.
 
618
  If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
619
  `past_key_values`).
 
620
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
621
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
622
  information on the default strategy.
 
623
  - 1 indicates the head is **not masked**,
624
  - 0 indicates the head is **masked**.
625
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
626
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
627
  config.n_positions - 1]`.
 
628
  [What are position IDs?](../glossary#position-ids)
629
  past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
630
  when `config.use_cache=True`):
631
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
632
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
633
  `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
 
634
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
635
  blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
 
636
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
637
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
638
  of shape `(batch_size, sequence_length)`.
@@ -654,6 +762,7 @@ InternLM2_INPUTS_DOCSTRING = r"""
654
  """
655
 
656
 
 
657
  @add_start_docstrings(
658
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
659
  InternLM2_START_DOCSTRING,
@@ -661,6 +770,7 @@ InternLM2_INPUTS_DOCSTRING = r"""
661
  class InternLM2Model(InternLM2PreTrainedModel):
662
  """
663
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
 
664
  Args:
665
  config: InternLM2Config
666
  """
@@ -671,8 +781,10 @@ class InternLM2Model(InternLM2PreTrainedModel):
671
  super().__init__(config)
672
  self.padding_idx = config.pad_token_id
673
  self.vocab_size = config.vocab_size
 
674
 
675
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
676
  self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
677
  self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
678
 
@@ -686,7 +798,6 @@ class InternLM2Model(InternLM2PreTrainedModel):
686
  def set_input_embeddings(self, value):
687
  self.tok_embeddings = value
688
 
689
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
690
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
691
  # create causal mask
692
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -756,14 +867,18 @@ class InternLM2Model(InternLM2PreTrainedModel):
756
 
757
  if inputs_embeds is None:
758
  inputs_embeds = self.tok_embeddings(input_ids)
759
- # embed positions
760
- if attention_mask is None:
761
- attention_mask = torch.ones(
762
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 
 
 
 
 
 
 
763
  )
764
- attention_mask = self._prepare_decoder_attention_mask(
765
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
766
- )
767
 
768
  # embed positions
769
  hidden_states = inputs_embeds
@@ -837,6 +952,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
837
  )
838
 
839
 
 
840
  class InternLM2ForCausalLM(InternLM2PreTrainedModel):
841
  _auto_class = "AutoModelForCausalLM"
842
 
@@ -890,14 +1006,20 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
890
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
891
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
892
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
893
  Returns:
 
894
  Example:
 
895
  ```python
896
  >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
 
897
  >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
898
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
 
899
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
900
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
901
  >>> # Generate
902
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
903
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -1000,11 +1122,15 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1000
  )
1001
  return reordered_past
1002
 
1003
- def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
1004
  prompt = ""
 
 
 
 
1005
  for record in history:
1006
- prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
1007
- prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
1008
  return tokenizer([prompt], return_tensors="pt")
1009
 
1010
  @torch.no_grad()
@@ -1018,10 +1144,15 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1018
  do_sample: bool = True,
1019
  temperature: float = 0.8,
1020
  top_p: float = 0.8,
 
 
 
1021
  **kwargs,
1022
  ):
1023
- inputs = self.build_inputs(tokenizer, query, history)
1024
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
 
 
1025
  outputs = self.generate(
1026
  **inputs,
1027
  streamer=streamer,
@@ -1029,11 +1160,12 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1029
  do_sample=do_sample,
1030
  temperature=temperature,
1031
  top_p=top_p,
 
1032
  **kwargs,
1033
  )
1034
  outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1035
  response = tokenizer.decode(outputs, skip_special_tokens=True)
1036
- response = response.split("<eoa>")[0]
1037
  history = history + [(query, response)]
1038
  return response, history
1039
 
@@ -1086,7 +1218,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1086
  return
1087
 
1088
  token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
1089
- if token.strip() != "<eoa>":
1090
  self.response = self.response + token
1091
  history = self.history + [(self.query, self.response)]
1092
  self.queue.put((self.response, history))
@@ -1119,11 +1251,14 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1119
  return consumer()
1120
 
1121
 
 
1122
  @add_start_docstrings(
1123
  """
1124
  The InternLM2 Model transformer with a sequence classification head on top (linear layer).
 
1125
  [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1126
  as other causal models (e.g. GPT-2) do.
 
1127
  Since it does classification on the last token, it requires to know the position of the last token. If a
1128
  `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1129
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
@@ -1236,4 +1371,4 @@ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1236
  past_key_values=transformer_outputs.past_key_values,
1237
  hidden_states=transformer_outputs.hidden_states,
1238
  attentions=transformer_outputs.attentions,
1239
- )
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
 
2
  #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
 
 
 
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
 
21
  from typing import List, Optional, Tuple, Union
22
 
23
  import torch
24
+ import torch.nn.functional as F
25
  import torch.utils.checkpoint
26
  from einops import rearrange
27
  from torch import nn
 
51
 
52
  _CONFIG_FOR_DOC = "InternLM2Config"
53
 
54
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
55
+ def _get_unpad_data(attention_mask):
56
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
57
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
58
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
59
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
60
+ return (
61
+ indices,
62
+ cu_seqlens,
63
+ max_seqlen_in_batch,
64
+ )
65
+
66
 
67
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
68
  def _make_causal_mask(
 
97
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
98
 
99
 
100
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
101
  class InternLM2RMSNorm(nn.Module):
102
  def __init__(self, hidden_size, eps=1e-6):
103
  """
 
115
  return self.weight * hidden_states.to(input_dtype)
116
 
117
 
118
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
119
  class InternLM2RotaryEmbedding(nn.Module):
120
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
121
  super().__init__()
 
144
  def forward(self, x, seq_len=None):
145
  # x: [bs, num_attention_heads, seq_len, head_size]
146
  if seq_len > self.max_seq_len_cached:
147
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
148
 
149
  return (
150
  self.cos_cached[:seq_len].to(dtype=x.dtype),
 
152
  )
153
 
154
 
155
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
156
  class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
157
  """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
158
 
 
172
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
173
 
174
 
175
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
176
  class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
177
  """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
178
  Credits to the Reddit users /u/bloc97 and /u/emozilla.
 
201
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
202
 
203
 
204
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
205
  def rotate_half(x):
206
  """Rotates half the hidden dims of the input."""
207
  x1 = x[..., : x.shape[-1] // 2]
 
209
  return torch.cat((-x2, x1), dim=-1)
210
 
211
 
212
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
213
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
214
+ """Applies Rotary Position Embedding to the query and key tensors."""
215
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
216
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
217
+ q_embed = (q * cos) + (rotate_half(q) * sin)
218
+ k_embed = (k * cos) + (rotate_half(k) * sin)
 
 
 
 
 
 
 
 
 
219
  return q_embed, k_embed
220
 
221
 
 
236
  return down_proj
237
 
238
 
239
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
240
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
241
  """
242
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
249
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
250
 
251
 
252
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
253
  class InternLM2Attention(nn.Module):
254
  """Multi-headed attention from 'Attention Is All You Need' paper"""
255
 
 
280
  self._init_rope()
281
 
282
  def _init_rope(self):
283
+ if self.config.rope_scaling is None:
284
  self.rotary_emb = InternLM2RotaryEmbedding(
285
  self.head_dim,
286
  max_position_embeddings=self.max_position_embeddings,
287
+ base=self.config.rope_theta,
 
 
 
 
 
 
 
288
  )
289
  else:
290
+ scaling_type = self.config.rope_scaling["type"]
291
+ scaling_factor = self.config.rope_scaling["factor"]
292
+ if scaling_type == "dynamic":
293
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
294
+ self.head_dim,
295
+ max_position_embeddings=self.max_position_embeddings,
296
+ base=self.config.rope_theta,
297
+ scaling_factor=scaling_factor,
298
+ )
299
+ elif scaling_type == "linear":
300
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
301
+ self.head_dim,
302
+ max_position_embeddings=self.max_position_embeddings,
303
+ base=self.config.rope_theta,
304
+ scaling_factor=scaling_factor,
305
+ )
306
+ else:
307
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
308
  return self.rotary_emb
309
 
310
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
 
398
  return attn_output, attn_weights, past_key_value
399
 
400
 
401
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
402
  class InternLM2FlashAttention2(InternLM2Attention):
403
  """
404
  InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
 
435
  qkv_states = rearrange(
436
  qkv_states,
437
  "b q (h gs d) -> b q h gs d",
438
+ gs=2 + self.num_key_value_groups,
439
  d=self.head_dim,
 
440
  )
441
 
442
  query_states = qkv_states[..., : self.num_key_value_groups, :]
 
444
  key_states = qkv_states[..., -2, :]
445
  value_states = qkv_states[..., -1, :]
446
 
447
+ query_states = query_states.transpose(1, 2)
448
+ key_states = key_states.transpose(1, 2)
449
+ value_states = value_states.transpose(1, 2)
450
+
451
  kv_seq_len = key_states.shape[-2]
452
  if past_key_value is not None:
453
  kv_seq_len += past_key_value[0].shape[-2]
 
469
 
470
  dropout_rate = 0.0 if not self.training else self.attention_dropout
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  attn_output = self._flash_attention_forward(
473
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
474
  )
 
475
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
476
  attn_output = self.wo(attn_output)
477
 
 
480
 
481
  return attn_output, attn_weights, past_key_value
482
 
483
+ def _flash_attention_forward(
484
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
485
+ ):
486
+ """
487
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
488
+ first unpad the input, then computes the attention scores and pad the final attention scores.
489
+
490
+ Args:
491
+ query_states (`torch.Tensor`):
492
+ Input query states to be passed to Flash Attention API
493
+ key_states (`torch.Tensor`):
494
+ Input key states to be passed to Flash Attention API
495
+ value_states (`torch.Tensor`):
496
+ Input value states to be passed to Flash Attention API
497
+ attention_mask (`torch.Tensor`):
498
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
499
+ position of padding tokens and 1 for the position of non-padding tokens.
500
+ dropout (`int`, *optional*):
501
+ Attention dropout
502
+ softmax_scale (`float`, *optional*):
503
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
504
+ """
505
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
506
+ from flash_attn.bert_padding import pad_input
507
+ # Contains at least one padding token in the sequence
508
+ causal = self.is_causal and query_length != 1
509
+ if attention_mask is not None:
510
+ batch_size = query_states.shape[0]
511
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
512
+ query_states, key_states, value_states, attention_mask, query_length
513
+ )
514
+
515
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
516
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
517
+
518
+ attn_output_unpad = flash_attn_varlen_func(
519
+ query_states,
520
+ key_states,
521
+ value_states,
522
+ cu_seqlens_q=cu_seqlens_q,
523
+ cu_seqlens_k=cu_seqlens_k,
524
+ max_seqlen_q=max_seqlen_in_batch_q,
525
+ max_seqlen_k=max_seqlen_in_batch_k,
526
+ dropout_p=dropout,
527
+ softmax_scale=softmax_scale,
528
+ causal=causal,
529
+ )
530
+
531
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
532
+ else:
533
+ attn_output = flash_attn_func(
534
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
535
+ )
536
+
537
+ return attn_output
538
+
539
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
540
+ from flash_attn.bert_padding import index_first_axis, unpad_input
541
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
542
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
543
+
544
+ key_layer = index_first_axis(
545
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
546
+ )
547
+ value_layer = index_first_axis(
548
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
549
+ )
550
+
551
+ if query_length == kv_seq_len:
552
+ query_layer = index_first_axis(
553
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
554
+ )
555
+ cu_seqlens_q = cu_seqlens_k
556
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
557
+ indices_q = indices_k
558
+ elif query_length == 1:
559
+ max_seqlen_in_batch_q = 1
560
+ cu_seqlens_q = torch.arange(
561
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
562
+ ) # There is a memcpy here, that is very bad.
563
+ indices_q = cu_seqlens_q[:-1]
564
+ query_layer = query_layer.squeeze(1)
565
+ else:
566
+ # The -q_len: slice assumes left padding.
567
+ attention_mask = attention_mask[:, -query_length:]
568
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
569
+
570
+ return (
571
+ query_layer,
572
+ key_layer,
573
+ value_layer,
574
+ indices_q.to(torch.int64),
575
+ (cu_seqlens_q, cu_seqlens_k),
576
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
577
+ )
578
+
579
+ INTERNLM2_ATTENTION_CLASSES = {
580
+ "eager": InternLM2Attention,
581
+ "flash_attention_2": InternLM2FlashAttention2,
582
+ }
583
 
584
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
585
  class InternLM2DecoderLayer(nn.Module):
586
  def __init__(self, config: InternLM2Config):
587
  super().__init__()
588
  self.hidden_size = config.hidden_size
589
+
590
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
591
+
 
 
592
  self.feed_forward = InternLM2MLP(config)
593
  self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
594
  self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
660
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
661
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
662
  etc.)
663
+
664
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
665
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
666
  and behavior.
667
+
668
  Parameters:
669
  config ([`InternLM2Config`]):
670
  Model configuration class with all the parameters of the model. Initializing with a config file does not
 
673
  """
674
 
675
 
676
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
677
  @add_start_docstrings(
678
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
679
  InternLM2_START_DOCSTRING,
 
684
  supports_gradient_checkpointing = True
685
  _no_split_modules = ["InternLM2DecoderLayer"]
686
  _skip_keys_device_placement = "past_key_values"
 
687
 
688
  def _init_weights(self, module):
689
  std = self.config.initializer_range
 
702
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
703
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
704
  it.
705
+
706
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
707
  [`PreTrainedTokenizer.__call__`] for details.
708
+
709
  [What are input IDs?](../glossary#input-ids)
710
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
711
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
712
+
713
  - 1 for tokens that are **not masked**,
714
  - 0 for tokens that are **masked**.
715
+
716
  [What are attention masks?](../glossary#attention-mask)
717
+
718
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
719
  [`PreTrainedTokenizer.__call__`] for details.
720
+
721
  If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
722
  `past_key_values`).
723
+
724
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
725
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
726
  information on the default strategy.
727
+
728
  - 1 indicates the head is **not masked**,
729
  - 0 indicates the head is **masked**.
730
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
731
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
732
  config.n_positions - 1]`.
733
+
734
  [What are position IDs?](../glossary#position-ids)
735
  past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
736
  when `config.use_cache=True`):
737
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
738
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
739
  `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
740
+
741
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
742
  blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
743
+
744
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
745
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
746
  of shape `(batch_size, sequence_length)`.
 
762
  """
763
 
764
 
765
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
766
  @add_start_docstrings(
767
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
768
  InternLM2_START_DOCSTRING,
 
770
  class InternLM2Model(InternLM2PreTrainedModel):
771
  """
772
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
773
+
774
  Args:
775
  config: InternLM2Config
776
  """
 
781
  super().__init__(config)
782
  self.padding_idx = config.pad_token_id
783
  self.vocab_size = config.vocab_size
784
+ self.config = config
785
 
786
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
787
+
788
  self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
789
  self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
790
 
 
798
  def set_input_embeddings(self, value):
799
  self.tok_embeddings = value
800
 
 
801
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
802
  # create causal mask
803
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
 
867
 
868
  if inputs_embeds is None:
869
  inputs_embeds = self.tok_embeddings(input_ids)
870
+
871
+ if self.config.attn_implementation == "flash_attention_2":
872
+ # 2d mask is passed through the layers
873
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
874
+ else:
875
+ if attention_mask is None:
876
+ attention_mask = torch.ones(
877
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
878
+ )
879
+ attention_mask = self._prepare_decoder_attention_mask(
880
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
881
  )
 
 
 
882
 
883
  # embed positions
884
  hidden_states = inputs_embeds
 
952
  )
953
 
954
 
955
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
956
  class InternLM2ForCausalLM(InternLM2PreTrainedModel):
957
  _auto_class = "AutoModelForCausalLM"
958
 
 
1006
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1007
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1008
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1009
+
1010
  Returns:
1011
+
1012
  Example:
1013
+
1014
  ```python
1015
  >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1016
+
1017
  >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1018
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1019
+
1020
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1021
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1022
+
1023
  >>> # Generate
1024
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1025
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
1122
  )
1123
  return reordered_past
1124
 
1125
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
1126
  prompt = ""
1127
+ if meta_instruction:
1128
+ prompt += f"""<s>[UNUSED_TOKEN_146]system\n{meta_instruction}[UNUSED_TOKEN_145]\n"""
1129
+ else:
1130
+ prompt += "<s>"
1131
  for record in history:
1132
+ prompt += f"""[UNUSED_TOKEN_146]user\n{record[0]}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n{record[1]}[UNUSED_TOKEN_145]\n"""
1133
+ prompt += f"""[UNUSED_TOKEN_146]user\n{query}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"""
1134
  return tokenizer([prompt], return_tensors="pt")
1135
 
1136
  @torch.no_grad()
 
1144
  do_sample: bool = True,
1145
  temperature: float = 0.8,
1146
  top_p: float = 0.8,
1147
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1148
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1149
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1150
  **kwargs,
1151
  ):
1152
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1153
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1154
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1155
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["[UNUSED_TOKEN_145]"])[0]]
1156
  outputs = self.generate(
1157
  **inputs,
1158
  streamer=streamer,
 
1160
  do_sample=do_sample,
1161
  temperature=temperature,
1162
  top_p=top_p,
1163
+ eos_token_id=eos_token_id,
1164
  **kwargs,
1165
  )
1166
  outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1167
  response = tokenizer.decode(outputs, skip_special_tokens=True)
1168
+ response = response.split("[UNUSED_TOKEN_145]")[0]
1169
  history = history + [(query, response)]
1170
  return response, history
1171
 
 
1218
  return
1219
 
1220
  token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
1221
+ if token.strip() != "[UNUSED_TOKEN_145]":
1222
  self.response = self.response + token
1223
  history = self.history + [(self.query, self.response)]
1224
  self.queue.put((self.response, history))
 
1251
  return consumer()
1252
 
1253
 
1254
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1255
  @add_start_docstrings(
1256
  """
1257
  The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1258
+
1259
  [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1260
  as other causal models (e.g. GPT-2) do.
1261
+
1262
  Since it does classification on the last token, it requires to know the position of the last token. If a
1263
  `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1264
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
 
1371
  past_key_values=transformer_outputs.past_key_values,
1372
  hidden_states=transformer_outputs.hidden_states,
1373
  attentions=transformer_outputs.attentions,
1374
+ )