dinalt committed on
Commit bb91586 · verified · 1 Parent(s): a4d2883

Add support for inference cache to model.
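Usage note (not part of the diff below): with the inference cache in place, generation goes through the standard generate() API. This is a minimal sketch; the repo id is a placeholder, and trust_remote_code=True is needed because walsh-causal-v1 is a custom architecture.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-namespace/your-walsh-model"  # placeholder repo id, illustration only
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

inputs = tokenizer("Walsh codes are", return_tensors="pt")
with torch.no_grad():
    # use_cache=True routes past_key_values through the new forward() path added by this commit.
    output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))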

Files changed (1)
  1. modelling_walsh.py +364 -61
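For reference, the updated forward() converts a legacy tuple-of-tuples cache into a transformers DynamicCache and back. A small standalone sketch of that round trip (shapes are illustrative, not taken from the model config):

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each shaped
# (batch, num_heads, seq_len, head_dim).
legacy = tuple(
    (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 4

# Attention layers append new states per layer, just as the code below does
# with past_key_values.update(key, value, layer_idx).
new_k = torch.zeros(1, 8, 1, 64)
new_v = torch.zeros(1, 8, 1, 64)
cache.update(new_k, new_v, layer_idx=0)

# Converted back to the legacy format before being returned to older callers.
legacy_again = cache.to_legacy_cache()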
modelling_walsh.py CHANGED
@@ -1,5 +1,5 @@
1
  # See: https://huggingface.co/docs/transformers/custom_models
2
- from typing import Optional, Tuple, Union
3
  import math
4
  import copy
5
  import sys
@@ -9,7 +9,7 @@ import torch
9
  from torch import nn, Tensor
10
  import torch.nn.init as init
11
  from torch.nn import functional as F
12
- from transformers.modeling_outputs import CausalLMOutput
13
  from transformers import (
14
  PreTrainedModel,
15
  PretrainedConfig,
@@ -18,6 +18,10 @@ from transformers import (
18
  AutoModelForCausalLM,
19
  )
20
21
  from transformers.utils import (
22
  is_flash_attn_2_available,
23
  is_flash_attn_greater_or_equal_2_10,
@@ -26,6 +30,8 @@ from transformers.utils import (
26
  if is_flash_attn_2_available():
27
  from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
28
 
 
 
29
  # The model type string to bind.
30
  model_type = "walsh-causal-v1"
31
 
@@ -78,6 +84,10 @@ class Config(PretrainedConfig):
78
  layer_args=dict(),
79
  embedding_args=dict(),
80
  output_proj_args=dict(),
 
 
 
 
81
 
82
  **kwargs,
83
  ):
@@ -113,6 +123,10 @@ class Config(PretrainedConfig):
113
  self.layer_args = layer_args
114
  self.embedding_args = embedding_args
115
  self.output_proj_args = output_proj_args
 
 
 
 
116
 
117
  super().__init__(**kwargs)
118
 
@@ -204,6 +218,8 @@ class HFCausalModel(PreTrainedModel):
204
  _no_split_modules = ["DeepNetLayer"]
205
  _supports_flash_attn_2 = True
206
  _supports_sdpa = True
 
 
207
 
208
  def __init__(self, config):
209
  super().__init__(config)
@@ -221,40 +237,144 @@ class HFCausalModel(PreTrainedModel):
221
  token_type_ids: Optional[torch.LongTensor] = None,
222
  position_ids: Optional[torch.LongTensor] = None,
223
  labels: Optional[torch.LongTensor] = None,
 
 
224
  output_attentions: Optional[bool] = None,
225
  output_hidden_states: Optional[bool] = None,
226
  return_dict: Optional[bool] = None,
227
  **kwargs,
228
  ) -> (Tensor, dict[str, Tensor]):
229
230
  if self.gradient_checkpointing and self.training:
231
  gradient_checkpointing_func = self._gradient_checkpointing_func
232
  else:
233
  gradient_checkpointing_func = None
 
234
 
235
- logits, attentions = self.transformer_head(
236
  input_ids=input_ids,
237
- need_weights=output_attentions,
 
238
  gradient_checkpointing_func=gradient_checkpointing_func,
 
 
 
239
  )
 
 
 
240
 
241
  # Compute loss.
242
  if labels is not None:
243
  loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids)
244
  else:
245
  loss = None
246
 
247
- return CausalLMOutput(loss=loss, logits=logits, attentions=attentions)
248
-
249
- # Needed for generate() method.
250
- def prepare_inputs_for_generation(self, input_ids, **kwargs):
251
- attention_mask = kwargs.get("attention_mask", None)
252
- model_inputs = {
253
- "input_ids": input_ids,
254
- "attention_mask": attention_mask,
255
- }
256
  return model_inputs
257
258
  def _make_embedding(self, config):
259
  embedding_cls = get_dynamic_class(config.embdding_cls)
260
  return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args)
@@ -278,7 +398,7 @@ class HFCausalModel(PreTrainedModel):
278
  norm_cls = get_dynamic_class(config.norm_cls)
279
  return norm_cls(self.d_model)
280
 
281
- def _make_self_attention(self, config):
282
  attention_cls = get_dynamic_class(config.attention_cls)
283
  # Map HF _attn_implementation to attn_type
284
  match config._attn_implementation:
@@ -299,28 +419,31 @@ class HFCausalModel(PreTrainedModel):
299
  d_model=self.d_model,
300
  num_heads=config.num_attention_heads,
301
  attn_type=attn_type,
 
302
  **config.attention_args,
303
  )
304
 
305
- def _make_feedforward(self, config):
306
  feedforward_cls = get_dynamic_class(config.feedforward_cls)
307
  return feedforward_cls(
308
  d_model=self.d_model,
309
  feedforward_dim=config.dim_feedforward,
310
  dropout=config.dropout,
311
  activation=self._make_activation(config),
 
312
  **config.feedforward_args,
313
  )
314
 
315
- def _make_layer(self, config):
316
  layer_cls = get_dynamic_class(config.layer_cls)
317
  return layer_cls(
318
  d_model=self.d_model,
319
  dropout=self._make_dropout(config),
320
- attention=self._make_self_attention(config),
321
- feedforward=self._make_feedforward(config),
322
  norm1=self._make_norm(config),
323
  norm2=self._make_norm(config),
 
324
  **config.layer_args,
325
  )
326
 
@@ -328,7 +451,7 @@ class HFCausalModel(PreTrainedModel):
328
  layer_stack_cls = get_dynamic_class(config.layer_stack_cls)
329
  return layer_stack_cls(
330
  layers=nn.ModuleList([
331
- self._make_layer(config) for _ in range(config.num_hidden_layers)
332
  ]),
333
  **config.layer_stack_args,
334
  )
@@ -364,18 +487,29 @@ class Transformer(nn.Module):
364
  self.sqrt_d_model = d_model**0.5
365
  self.reset_parameters()
366
 
367
- def forward(self, input_ids, need_weights, gradient_checkpointing_func):
368
- x = self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model)
369
-
370
- x, attentions = self.layer_stack(
371
- x,
372
- need_weights,
373
- gradient_checkpointing_func,
374
  )
375
 
376
- # Translate output embedding ot logits.
377
- logits = self.output_projection(x)
378
- return logits, attentions
 
379
 
380
  def reset_parameters(self):
381
  init.xavier_uniform_(self.output_projection.weight)
@@ -472,7 +606,7 @@ class RSWalshPositionalEncoder(nn.Module):
472
  # walsh = (hadamard_walsh_matrix(k)[:bits,:d_embed] -0.5) * self.gain
473
  self.register_buffer('walsh', walsh, persistent=False)
474
 
475
- def forward(self, x):
476
  seq_len = x.size(-2)
477
 
478
  # Get sequence of binary codes...
@@ -486,6 +620,12 @@ class RSWalshPositionalEncoder(nn.Module):
486
  shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
487
  seq = self.binary_code[shift:seq_len + shift,:]
488
489
  # Disable shifting when not training. This does not appear to change the evaluation loss, but
490
  # it does make predictions easier to analyse when the attention weights are not shifting with each step.
491
  else:
@@ -508,25 +648,58 @@ class TransformerLayerStack(nn.Module):
508
  super().__init__()
509
  self.layers = layers
510
 
511
- def forward(self, x, need_weights, gradient_checkpointing_func=None):
512
- attentions = []
513
  for layer in self.layers:
514
  if gradient_checkpointing_func is not None:
515
- x, attention_weights = gradient_checkpointing_func(
516
  layer.__call__,
517
- x,
518
- need_weights,
519
- use_reentrant=False
 
 
520
  )
521
  else:
522
- x, attention_weights = layer(x, need_weights=need_weights)
523
- if need_weights:
524
- attentions.append(attention_weights)
525
 
526
- return x, attentions
527
 
528
  # DeepNet: Scaling Transformers to 1,000 Layers
529
  # https://arxiv.org/abs/2203.00555
 
530
  class DeepnetLayer(nn.Module):
531
  def __init__(
532
  self,
@@ -536,6 +709,7 @@ class DeepnetLayer(nn.Module):
536
  norm1,
537
  norm2,
538
  dropout,
 
539
  alpha=1.0,
540
  ):
541
  super().__init__()
@@ -547,27 +721,45 @@ class DeepnetLayer(nn.Module):
547
  self.dropout = dropout
548
  # Deepnet alpha
549
  self.alpha = alpha
 
550
 
551
- def forward(self, x, need_weights=False):
552
  # Keep input as residual
553
- residual = x * self.alpha
554
 
555
  # Compute attention
556
- x, attention_weights = self.attention(x, need_weights)
557
 
558
  # Add attention with residual and normalize.
559
- x = self.norm1(residual + self.dropout(x))
560
 
561
  # Keep output as next residual.
562
- residual = x * self.alpha
563
 
564
  # Pass through feedforward network.
565
- x = self.feedforward(x)
566
 
567
  # Combine residual and ff output, then normalize again.
568
- x = self.norm2(residual + self.dropout(x))
569
 
570
- return x, attention_weights
 
 
 
 
571
 
572
  # A vanilla MLP transformer layer.
573
  class FeedforwardLayer(nn.Module):
@@ -576,6 +768,7 @@ class FeedforwardLayer(nn.Module):
576
  d_model: int,
577
  feedforward_dim: int,
578
  dropout,
 
579
  activation=nn.ReLU(),
580
  beta=1.0,
581
  bias=True,
@@ -605,6 +798,7 @@ class SwiGLUFeedforwardLayer(nn.Module):
605
  self,
606
  d_model,
607
  d_feedforward,
 
608
  beta=1.0,
609
  dropout=0.1
610
  ):
@@ -643,6 +837,7 @@ class CausalSelfAttention(nn.Module):
643
  # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
644
  # flash2: Use Flash-Attention2 implementation; fastest; limited to float16 and bfloat16 types; least memory usage.
645
  attn_type,
 
646
  beta=1.0,
647
  dropout=0.1,
648
  ):
@@ -651,6 +846,7 @@ class CausalSelfAttention(nn.Module):
651
  self.num_heads = num_heads
652
  self.beta = beta
653
  self.attn_type = attn_type
 
654
 
655
  assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
656
 
@@ -685,9 +881,18 @@ class CausalSelfAttention(nn.Module):
685
  proj = self.in_proj(qkv)
686
  return proj.chunk(chunks=3, dim=-1)
687
 
688
- def forward(self, qkv, need_weights):
689
  if self.attn_type == "flash2":
690
- return self.flash2_forward(qkv)
 
 
 
691
 
692
  # qkv: (batch_size, seq_len, d_embed)
693
  batch_size, seq_len, d_embed = qkv.shape
@@ -700,8 +905,12 @@ class CausalSelfAttention(nn.Module):
700
  key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
701
  value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
702
 
 
 
 
 
703
  # Default to returning empty attention weights.
704
- attention_weights = None
705
 
706
  if self.attn_type == "torch":
707
  # This context manager can be used to force which implementation to use.
@@ -730,28 +939,40 @@ class CausalSelfAttention(nn.Module):
730
  )
731
 
732
  # Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
733
- attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
734
  del scores
735
 
736
  # Use the attention weights to get a weighted combination of value vectors
737
- attended_values = torch.matmul(attention_weights, value)
738
- if not need_weights:
739
- del attention_weights
740
- attention_weights = None
741
 
742
  # Concatenate attention heads and project to original embedding size using the output linear layer
743
  attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
744
 
745
  # Project the concatenated output through the output matrix.
746
  attended_values = self.output_linear(attended_values)
747
- return attended_values, attention_weights
748
 
749
- def flash2_forward(self, qkv):
 
 
 
750
  batch_size, seq_len, d_embed = qkv.shape
751
 
752
  # Feed the inputs through the K, Q, V matrices.
753
  # query : (batch_size, seq_len, d_model)
754
  # qkv : (batch_size, seq_len, 3, num_heads, d_kq)
 
 
 
 
755
  qkv = self.in_proj(qkv).unflatten(
756
  -1,
757
  (3, self.num_heads, self.d_head)
@@ -770,7 +991,89 @@ class CausalSelfAttention(nn.Module):
770
 
771
  # Project the concatenated output through the output matrix.
772
  attended_values = self.output_linear(attended_values)
773
- return attended_values, None
774
 
775
  # Attention layer with ALiBi relative positional encoding
776
  # TRAIN SHORT, TEST LONG: ATTENTION WITH LINEAR BIASES ENABLES INPUT LENGTH EXTRAPOLATION
@@ -907,7 +1210,7 @@ class CausalAlibiAttention(nn.Module):
907
 
908
  # Use the attention weights to get a weighted combination of value vectors
909
  attended_values = torch.matmul(attention_weights, value)
910
- if not need_weights:
911
  attention_weights = None
912
 
913
  # Concatenate attention heads and project to original embedding size using the output linear layer
@@ -946,4 +1249,4 @@ class CausalAlibiAttention(nn.Module):
946
 
947
  # Project the concatenated output through the output matrix.
948
  attended_values = self.output_linear(attended_values)
949
- return attended_values, None
 
1
  # See: https://huggingface.co/docs/transformers/custom_models
2
+ from typing import Optional, Tuple, Union, List
3
  import math
4
  import copy
5
  import sys
 
9
  from torch import nn, Tensor
10
  import torch.nn.init as init
11
  from torch.nn import functional as F
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutput, CausalLMOutputWithPast
13
  from transformers import (
14
  PreTrainedModel,
15
  PretrainedConfig,
 
18
  AutoModelForCausalLM,
19
  )
20
 
21
+ from transformers.utils import logging
22
+
23
+ from transformers.cache_utils import Cache, DynamicCache
24
+
25
  from transformers.utils import (
26
  is_flash_attn_2_available,
27
  is_flash_attn_greater_or_equal_2_10,
 
30
  if is_flash_attn_2_available():
31
  from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
32
 
33
+ logger = logging.get_logger(__name__)
34
+
35
  # The model type string to bind.
36
  model_type = "walsh-causal-v1"
37
 
 
84
  layer_args=dict(),
85
  embedding_args=dict(),
86
  output_proj_args=dict(),
87
+
88
+ output_attentions=False,
89
+ output_hidden_states=False,
90
+ use_cache=True,
91
 
92
  **kwargs,
93
  ):
 
123
  self.layer_args = layer_args
124
  self.embedding_args = embedding_args
125
  self.output_proj_args = output_proj_args
126
+
127
+ self.output_attentions = output_attentions
128
+ self.output_hidden_states = output_hidden_states
129
+ self.use_cache = use_cache
130
 
131
  super().__init__(**kwargs)
132
 
 
218
  _no_split_modules = ["DeepNetLayer"]
219
  _supports_flash_attn_2 = True
220
  _supports_sdpa = True
221
+ _supports_cache_class = True
222
+ _skip_keys_device_placement = "past_key_values"
223
 
224
  def __init__(self, config):
225
  super().__init__(config)
 
237
  token_type_ids: Optional[torch.LongTensor] = None,
238
  position_ids: Optional[torch.LongTensor] = None,
239
  labels: Optional[torch.LongTensor] = None,
240
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
241
+ use_cache: Optional[bool] = None,
242
  output_attentions: Optional[bool] = None,
243
  output_hidden_states: Optional[bool] = None,
244
  return_dict: Optional[bool] = None,
245
  **kwargs,
246
  ) -> (Tensor, dict[str, Tensor]):
247
 
248
+ batch_size, seq_len = input_ids.shape
249
+
250
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
251
+ output_hidden_states = (
252
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
253
+ )
254
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
255
+
256
+ if use_cache:
257
+ # If legacy cache, convert to DynamicCache
258
+ use_legacy_cache = not isinstance(past_key_values, Cache)
259
+ if use_legacy_cache:
260
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
261
+
262
+
263
  if self.gradient_checkpointing and self.training:
264
+ if use_cache:
265
+ logger.warning_once(
266
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
267
+ )
268
+ use_cache = False
269
  gradient_checkpointing_func = self._gradient_checkpointing_func
270
  else:
271
  gradient_checkpointing_func = None
272
+
273
 
274
+ outputs = self.transformer_head(
275
  input_ids=input_ids,
276
+ position_ids=position_ids,
277
+ output_attentions=output_attentions,
278
  gradient_checkpointing_func=gradient_checkpointing_func,
279
+ past_key_values=past_key_values,
280
+ use_cache=use_cache,
281
+ output_hidden_states=output_hidden_states,
282
  )
283
+
284
+ logits = outputs["logits"].float()
285
+ attentions = outputs["attentions"]
286
 
287
  # Compute loss.
288
  if labels is not None:
289
  loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids)
290
  else:
291
  loss = None
292
+
293
+ # Convert back to legacy cache, if that's what we received
294
+ new_cache = outputs["past_key_values"]
295
+ if use_cache and new_cache is not None and use_legacy_cache:
296
+ new_cache = new_cache.to_legacy_cache()
297
 
298
+ return CausalLMOutputWithPast(
299
+ loss=loss,
300
+ logits=logits,
301
+ past_key_values=new_cache,
302
+ hidden_states=outputs["hidden_states"],
303
+ attentions=outputs["attentions"],
304
+ )
305
+
306
+ # Implementation adapted from Hugging Face Transformers:
307
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py
308
+ # Note: We do not implement the attention mask at present, so some of this code is not applicable.
309
+ # TODO: Re-enable attention mask support for batch inference.
310
+ def prepare_inputs_for_generation(
311
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
312
+ ):
313
+ # Omit tokens covered by past_key_values
314
+ if past_key_values is not None:
315
+ if isinstance(past_key_values, Cache):
316
+ cache_length = past_key_values.get_seq_length()
317
+ past_length = past_key_values.seen_tokens
318
+ max_cache_length = past_key_values.get_max_length()
319
+ else:
320
+ cache_length = past_length = past_key_values[0][0].shape[2]
321
+ max_cache_length = None
322
+
323
+ # Keep only the unprocessed tokens:
324
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
325
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
326
+ # input)
327
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
328
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
329
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
330
+ # input_ids based on the past_length.
331
+ elif past_length < input_ids.shape[1]:
332
+ input_ids = input_ids[:, past_length:]
333
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
334
+
335
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
336
+ if (
337
+ max_cache_length is not None
338
+ and attention_mask is not None
339
+ and cache_length + input_ids.shape[1] > max_cache_length
340
+ ):
341
+ attention_mask = attention_mask[:, -max_cache_length:]
342
+
343
+ # NOTE: "RSWalsh" models don't need to have their absolute positions adjusted to zero; they are trained for this.
344
+ position_ids = kwargs.get("position_ids", None)
345
+ if attention_mask is not None and position_ids is None:
346
+ # create position_ids on the fly for batch generation
347
+ position_ids = attention_mask.long().cumsum(-1) - 1
348
+ position_ids.masked_fill_(attention_mask == 0, 1)
349
+ if past_key_values:
350
+ position_ids = position_ids[:, -input_ids.shape[1] :]
351
+
352
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
353
+ # NOTE: Injecting positional embeddings is not yet supported.
354
+ if inputs_embeds is not None and past_key_values is None:
355
+ model_inputs = {"inputs_embeds": inputs_embeds}
356
+ else:
357
+ model_inputs = {"input_ids": input_ids}
358
+
359
+ model_inputs.update(
360
+ {
361
+ "position_ids": position_ids,
362
+ "past_key_values": past_key_values,
363
+ "use_cache": kwargs.get("use_cache"),
364
+ "attention_mask": attention_mask,
365
+ }
366
+ )
367
  return model_inputs
368
 
369
+ @staticmethod
370
+ def _reorder_cache(past_key_values, beam_idx):
371
+ reordered_past = ()
372
+ for layer_past in past_key_values:
373
+ reordered_past += (
374
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
375
+ )
376
+ return reordered_past
377
+
378
  def _make_embedding(self, config):
379
  embedding_cls = get_dynamic_class(config.embdding_cls)
380
  return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args)
 
398
  norm_cls = get_dynamic_class(config.norm_cls)
399
  return norm_cls(self.d_model)
400
 
401
+ def _make_self_attention(self, layer_idx, config):
402
  attention_cls = get_dynamic_class(config.attention_cls)
403
  # Map HF _attn_implementation to attn_type
404
  match config._attn_implementation:
 
419
  d_model=self.d_model,
420
  num_heads=config.num_attention_heads,
421
  attn_type=attn_type,
422
+ layer_idx=layer_idx,
423
  **config.attention_args,
424
  )
425
 
426
+ def _make_feedforward(self, layer_idx, config):
427
  feedforward_cls = get_dynamic_class(config.feedforward_cls)
428
  return feedforward_cls(
429
  d_model=self.d_model,
430
  feedforward_dim=config.dim_feedforward,
431
  dropout=config.dropout,
432
  activation=self._make_activation(config),
433
+ layer_idx=layer_idx,
434
  **config.feedforward_args,
435
  )
436
 
437
+ def _make_layer(self, layer_idx, config):
438
  layer_cls = get_dynamic_class(config.layer_cls)
439
  return layer_cls(
440
  d_model=self.d_model,
441
  dropout=self._make_dropout(config),
442
+ attention=self._make_self_attention(layer_idx, config),
443
+ feedforward=self._make_feedforward(layer_idx, config),
444
  norm1=self._make_norm(config),
445
  norm2=self._make_norm(config),
446
+ layer_idx=layer_idx,
447
  **config.layer_args,
448
  )
449
 
 
451
  layer_stack_cls = get_dynamic_class(config.layer_stack_cls)
452
  return layer_stack_cls(
453
  layers=nn.ModuleList([
454
+ self._make_layer(layer_idx, config) for layer_idx in range(config.num_hidden_layers)
455
  ]),
456
  **config.layer_stack_args,
457
  )
 
487
  self.sqrt_d_model = d_model**0.5
488
  self.reset_parameters()
489
 
490
+ def forward(
491
+ self,
492
+ input_ids,
493
+ position_ids,
494
+ output_attentions,
495
+ gradient_checkpointing_func,
496
+ past_key_values,
497
+ use_cache,
498
+ output_hidden_states,
499
+ ):
500
+ outputs = self.layer_stack(
501
+ self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model, position_ids),
502
+ output_attentions=output_attentions,
503
+ gradient_checkpointing_func=gradient_checkpointing_func,
504
+ past_key_values=past_key_values,
505
+ use_cache=use_cache,
506
+ output_hidden_states=output_hidden_states,
507
  )
508
 
509
+ # Translate output states to logits.
510
+ outputs["logits"] = self.output_projection(outputs["last_hidden_state"])
511
+ del outputs["last_hidden_state"]
512
+ return outputs
513
 
514
  def reset_parameters(self):
515
  init.xavier_uniform_(self.output_projection.weight)
 
606
  # walsh = (hadamard_walsh_matrix(k)[:bits,:d_embed] -0.5) * self.gain
607
  self.register_buffer('walsh', walsh, persistent=False)
608
 
609
+ def forward(self, x, position_ids=None):
610
  seq_len = x.size(-2)
611
 
612
  # Get sequence of binary codes...
 
620
  shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
621
  seq = self.binary_code[shift:seq_len + shift,:]
622
 
623
+ # When the cache is used for generation, after the first call, we are only passed a single token at a time,
624
+ # with the remaining tokens being in the cache. We need to make sure that the newly injected tokens have the
625
+ # correct relative position by indexing the codes with the position_ids.
626
+ elif position_ids is not None:
627
+ seq = self.binary_code[position_ids, :]
628
+
629
  # Disable shifting when not training. This does not appear to change the evaluation loss, but
630
  # it does make predictions easier to analyse when the attention weights are not shifting with each step.
631
  else:
 
648
  super().__init__()
649
  self.layers = layers
650
 
651
+ def forward(
652
+ self,
653
+ hidden_states,
654
+ output_attentions,
655
+ past_key_values,
656
+ use_cache,
657
+ output_hidden_states,
658
+ gradient_checkpointing_func=None,
659
+ ):
660
+ present_key_value = None
661
+ all_attentions = [] if output_attentions else None
662
+ all_hidden_states = [hidden_states] if output_hidden_states else None
663
+
664
  for layer in self.layers:
665
  if gradient_checkpointing_func is not None:
666
+ layer_outputs = gradient_checkpointing_func(
667
  layer.__call__,
668
+ hidden_states,
669
+ output_attentions,
670
+ past_key_values,
671
+ use_cache,
672
+ use_reentrant=False,
673
  )
674
  else:
675
+ layer_outputs = layer(
676
+ hidden_states,
677
+ output_attentions,
678
+ past_key_values,
679
+ use_cache,
680
+ )
681
+
682
+ hidden_states = layer_outputs["hidden_states"]
683
 
684
+ if output_hidden_states:
685
+ all_hidden_states.append(hidden_states)
686
+
687
+ if use_cache:
688
+ present_key_value = layer_outputs["past_key_values"]
689
+
690
+ if output_attentions:
691
+ all_attentions.append(layer_outputs["attentions"])
692
+
693
+ return dict(
694
+ last_hidden_state=hidden_states,
695
+ past_key_values=present_key_value,
696
+ hidden_states=hidden_states,
697
+ attentions=all_attentions,
698
+ )
699
 
700
  # DeepNet: Scaling Transformers to 1,000 Layers
701
  # https://arxiv.org/abs/2203.00555
702
+ # Note: This is a type of Post-Layer-Norm Transformer layer.
703
  class DeepnetLayer(nn.Module):
704
  def __init__(
705
  self,
 
709
  norm1,
710
  norm2,
711
  dropout,
712
+ layer_idx,
713
  alpha=1.0,
714
  ):
715
  super().__init__()
 
721
  self.dropout = dropout
722
  # Deepnet alpha
723
  self.alpha = alpha
724
+ self.layer_idx = layer_idx
725
 
726
+ def forward(
727
+ self,
728
+ hidden_states,
729
+ output_attentions,
730
+ past_key_values,
731
+ use_cache,
732
+ ):
733
  # Keep input as residual
734
+ residual = hidden_states * self.alpha
735
 
736
  # Compute attention
737
+ attn_outputs = self.attention(
738
+ hidden_states,
739
+ past_key_values=past_key_values,
740
+ use_cache=use_cache,
741
+ output_attentions=output_attentions
742
+ )
743
+
744
+ hidden_states = attn_outputs["hidden_states"]
745
 
746
  # Add attention with residual and normalize.
747
+ hidden_states = self.norm1(residual + self.dropout(hidden_states))
748
 
749
  # Keep output as next residual.
750
+ residual = hidden_states * self.alpha
751
 
752
  # Pass through feedforward network.
753
+ hidden_states = self.feedforward(hidden_states)
754
 
755
  # Combine residual and ff output, then normalize again.
756
+ hidden_states = self.norm2(residual + self.dropout(hidden_states))
757
 
758
+ return dict(
759
+ hidden_states=hidden_states,
760
+ attentions=attn_outputs["attentions"],
761
+ past_key_values=attn_outputs["past_key_values"]
762
+ )
763
 
764
  # A vanilla MLP transformer layer.
765
  class FeedforwardLayer(nn.Module):
 
768
  d_model: int,
769
  feedforward_dim: int,
770
  dropout,
771
+ layer_idx,
772
  activation=nn.ReLU(),
773
  beta=1.0,
774
  bias=True,
 
798
  self,
799
  d_model,
800
  d_feedforward,
801
+ layer_idx,
802
  beta=1.0,
803
  dropout=0.1
804
  ):
 
837
  # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
838
  # flash2: Use Flash-Attention2 implementation; fastest; limited to float16 and bfloat16 types; least memory usage.
839
  attn_type,
840
+ layer_idx,
841
  beta=1.0,
842
  dropout=0.1,
843
  ):
 
846
  self.num_heads = num_heads
847
  self.beta = beta
848
  self.attn_type = attn_type
849
+ self.layer_idx = layer_idx
850
 
851
  assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
852
 
 
881
  proj = self.in_proj(qkv)
882
  return proj.chunk(chunks=3, dim=-1)
883
 
884
+ def forward(
885
+ self,
886
+ qkv,
887
+ output_attentions,
888
+ past_key_values,
889
+ use_cache,
890
+ ):
891
  if self.attn_type == "flash2":
892
+ if not use_cache:
893
+ return self.flash2_forward(qkv)
894
+ else:
895
+ return self.flash2_forward_cached(qkv, past_key_values)
896
 
897
  # qkv: (batch_size, seq_len, d_embed)
898
  batch_size, seq_len, d_embed = qkv.shape
 
905
  key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
906
  value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
907
 
908
+ # Update the cache values.
909
+ if past_key_values is not None:
910
+ key, value = past_key_values.update(key, value, self.layer_idx)
911
+
912
  # Default to returning empty attention weights.
913
+ attentions = None
914
 
915
  if self.attn_type == "torch":
916
  # This context manager can be used to force which implementation to use.
 
939
  )
940
 
941
  # Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
942
+ attentions = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
943
  del scores
944
 
945
  # Use the attention weights to get a weighted combination of value vectors
946
+ attended_values = torch.matmul(attentions, value)
947
+ if not output_attentions:
948
+ del attentions
949
+ attentions = None
950
 
951
  # Concatenate attention heads and project to original embedding size using the output linear layer
952
  attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
953
 
954
  # Project the concatenated output through the output matrix.
955
  attended_values = self.output_linear(attended_values)
956
+ return dict(
957
+ hidden_states=attended_values,
958
+ attentions=attentions,
959
+ # Unimplemented...
960
+ past_key_values=None
961
+ )
962
 
963
+ def flash2_forward(
964
+ self,
965
+ qkv,
966
+ ):
967
  batch_size, seq_len, d_embed = qkv.shape
968
 
969
  # Feed the inputs through the K, Q, V matrices.
970
  # query : (batch_size, seq_len, d_model)
971
  # qkv : (batch_size, seq_len, 3, num_heads, d_kq)
972
+
976
  qkv = self.in_proj(qkv).unflatten(
977
  -1,
978
  (3, self.num_heads, self.d_head)
 
991
 
992
  # Project the concatenated output through the output matrix.
993
  attended_values = self.output_linear(attended_values)
994
+ return dict(
995
+ hidden_states=attended_values,
996
+ attentions=None,
997
+ past_key_values=None
998
+ )
999
+
1000
+ # See https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
1001
+ # https://huggingface.co/docs/transformers/internal/generation_utils
1002
+ def flash2_forward_cached(
1003
+ self,
1004
+ qkv,
1005
+ past_key_values,
1006
+ ):
1007
+ batch_size, seq_len, d_embed = qkv.shape
1008
+
1009
+ # Feed the inputs through the K, Q, V matrices.
1010
+ query, key, value = self.project_input(qkv)
1011
+
1012
+ # TODO: Refactor -- this code is repeated in the baseline implementation.
1013
+ # Split projections into multiple heads and swap position of sequence / heads dimension
1014
+ query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1015
+ key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1016
+ value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
1017
+
1018
+ if past_key_values is not None:
1019
+ key, value = past_key_values.update(key, value, self.layer_idx)
1020
+
1021
+ #query, key, value = self._downcast_to_float16(query, key, value)
1022
+
1023
+ # Expected inputs to flash2:
1024
+ # q: (batch_size, seqlen, nheads, headdim)
1025
+ # k: (batch_size, seqlen, nheads_k, headdim)
1026
+ # v: (batch_size, seqlen, nheads_k, headdim)
1027
+ query = query.transpose(1, 2)
1028
+ key = key.transpose(1, 2)
1029
+ value = value.transpose(1, 2)
1030
+
1031
+ attended_values = flash_attn_func(
1032
+ q=query,
1033
+ k=key,
1034
+ v=value,
1035
+ dropout_p=self.dropout.p if self.training else 0.0,
1036
+ softmax_scale=self.dot_product_scale,
1037
+ causal=True,
1038
+ )
1039
+ # attended_values: (batch_size, seqlen, nheads, headdim)
1040
+
1041
+ # Concatenate heads back into d_embed
1042
+ attended_values = attended_values.view(batch_size, seq_len, d_embed)
1043
+
1044
+ # Project the concatenated output through the output matrix.
1045
+ attended_values = self.output_linear(attended_values)
1046
+ return dict(
1047
+ hidden_states=attended_values,
1048
+ attentions=None,
1049
+ past_key_values=past_key_values
1050
+ )
1051
+
1052
+ def _downcast_to_float16(self, query, key, value):
1053
+ # Section copied from Transformers to handle this case.
1054
+ # TODO: Revisit the other Flash2 implementation, above.
1056
+ input_dtype = query.dtype
1057
+ if input_dtype == torch.float32:
1058
+ if torch.is_autocast_enabled():
1059
+ target_dtype = torch.get_autocast_gpu_dtype()
1060
+ # Handle the case where the model is quantized
1061
+ elif hasattr(self.config, "_pre_quantization_dtype"):
1062
+ target_dtype = self.config._pre_quantization_dtype
1063
+ else:
1064
+ target_dtype = self.q_proj.weight.dtype
1065
+ logger.warning_once(
1066
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
1067
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
1068
+ f" {target_dtype}."
1069
+ )
1070
+ query = query.to(target_dtype)
1071
+ key = key.to(target_dtype)
1072
+ value = value.to(target_dtype)
1073
+ return query, key, value
1074
+
1075
+
1076
+ ########### TODO: Update to newer API, with inference cache
1077
 
1078
  # Attention layer with ALiBi relative positional encoding
1079
  # TRAIN SHORT, TEST LONG: ATTENTION WITH LINEAR BIASES ENABLES INPUT LENGTH EXTRAPOLATION
 
1210
 
1211
  # Use the attention weights to get a weighted combination of value vectors
1212
  attended_values = torch.matmul(attention_weights, value)
1213
+ if not output_attentions:
1214
  attention_weights = None
1215
 
1216
  # Concatenate attention heads and project to original embedding size using the output linear layer
 
1249
 
1250
  # Project the concatenated output through the output matrix.
1251
  attended_values = self.output_linear(attended_values)
1252
+ return attended_values, None
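
A hedged sketch (not from the repository) of driving the updated forward() manually for incremental decoding. It assumes `model` and `tokenizer` were loaded as in the first sketch above. The DynamicCache is updated in place by each attention layer, so we keep our own handle on it, and position_ids carries the absolute position of each new token so RSWalshPositionalEncoder selects the matching code.

import torch
from transformers.cache_utils import DynamicCache

model.eval()
input_ids = tokenizer("Walsh codes are", return_tensors="pt").input_ids

past = DynamicCache()  # populated in place by past_key_values.update() in each layer
with torch.no_grad():
    # Prefill: run the full prompt once to fill the cache.
    out = model(input_ids=input_ids, past_key_values=past, use_cache=True)

    generated = input_ids
    for _ in range(16):
        next_token = out.logits[:, -1:].argmax(dim=-1)
        generated = torch.cat([generated, next_token], dim=-1)
        # Only the newest token is fed; its absolute position goes in position_ids.
        position_ids = torch.tensor([[generated.shape[1] - 1]])
        out = model(
            input_ids=next_token,
            position_ids=position_ids,
            past_key_values=past,
            use_cache=True,
        )

print(tokenizer.decode(generated[0], skip_special_tokens=True))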