fixed flash_attention backward_compat

#3
by itlevy - opened
Files changed (4)
  1. NOTICE +0 -5
  2. README.md +3 -5
  3. modeling_decilm.py +2 -46
  4. variable_cache.py +9 -14
NOTICE DELETED
@@ -1,5 +0,0 @@
- Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-
- NVIDIA CORPORATION, its affiliates and licensors retain all intellectual property and proprietary rights in and to this material, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this material and related documentation without an express license agreement from NVIDIA CORPORATION or its affiliates is strictly prohibited.
-
- Llama 3.1 is licensed under the Llama 3.1 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.
README.md CHANGED
@@ -8,9 +8,9 @@ tags:
  - llama-3
  - pytorch
  license: other
- license_name: nvidia-open-model-license
+ license_name: nvidia-ai-foundation-models-community-license
  license_link: >-
- https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf
+ https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-ai-foundation-models-community-license-agreement/
  ---

  # Llama-3_1-Nemotron-51B-instruct
@@ -22,8 +22,7 @@ Llama-3_1-Nemotron-51B-instruct is a model which offers a great tradeoff between


  ## License
- This model is released under the [NVIDIA Open Model License Agreement](https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf).
- Additional Information: [Llama 3.1 Community License Agreement](https://www.llama.com/llama3_1/license/). Built with Llama.
+ [NVIDIA AI Foundation Models Community License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-ai-foundation-models-community-license-agreement/). Additional Information: [Llama 3.1 Community License Agreement](https://www.llama.com/llama3_1/license/). Built with Llama.

  ## How was the model developed

@@ -33,7 +32,6 @@ The KD step included 40 billion tokens consisting of a mixture of 3 datasets - F
  Links to [NIM](https://build.nvidia.com/nvidia/llama-3_1-nemotron-51b-instruct), [blog](https://developer.nvidia.com/blog/advancing-the-accuracy-efficiency-frontier-with-llama-3-1-nemotron-51b/) and [huggingface](https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct)


-
  This results in a final model that is aligned for human chat preferences.

  **Model Developers:** NVIDIA
modeling_decilm.py CHANGED
@@ -25,7 +25,7 @@ import torch.utils.checkpoint
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  from transformers import GenerationConfig
- from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
+ from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING
  from transformers.modeling_utils import PreTrainedModel
  from transformers.utils import (
      add_start_docstrings,
@@ -1131,7 +1131,7 @@ class DeciLMModel(DeciLMPreTrainedModel):
      return causal_mask


- class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
+ class DeciLMForCausalLM(DeciLMPreTrainedModel):
      _tied_weights_keys = ["lm_head.weight"]

      def __init__(self, config):
@@ -1311,50 +1311,6 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
          )
          return model_inputs

-     def _maybe_initialize_input_ids_for_generation(
-             self,
-             inputs: Optional[torch.Tensor] = None,
-             bos_token_id: Optional[torch.Tensor] = None,
-             model_kwargs: Optional[dict[str, torch.Tensor]] = None,
-     ) -> torch.LongTensor:
-         """
-         Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
-         """
-         input_ids = super()._maybe_initialize_input_ids_for_generation(
-             inputs=inputs, bos_token_id=bos_token_id, model_kwargs=model_kwargs)
-         if (
-                 "inputs_embeds" in model_kwargs
-                 and input_ids is not None
-                 and input_ids.shape[1] == 0
-         ):
-             batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
-             input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
-         return input_ids
-
-     def generate(
-             self,
-             inputs: Optional[torch.Tensor] = None,
-             *args,
-             **kwargs,
-     ) -> Union[GenerateOutput, torch.LongTensor]:
-         """
-         Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
-         """
-         only_passed_inputs_embeds = (
-                 "inputs_embeds" in kwargs and
-                 "input_ids" not in kwargs and
-                 inputs is None
-         )
-         if only_passed_inputs_embeds:
-             input_sequence_length = kwargs["inputs_embeds"].shape[1]
-
-         generation_output = super().generate(inputs=inputs, *args, **kwargs)
-
-         if only_passed_inputs_embeds and isinstance(generation_output, torch.Tensor):
-             generation_output = generation_output[:, input_sequence_length:]
-
-         return generation_output
-

  @add_start_docstrings(
  """
variable_cache.py CHANGED
@@ -32,21 +32,17 @@ class VariableCache(Cache_4_44_2, Cache):
      The cache of each layer is allocated to the same gpu as the layer itself.
      """

-     def __init__(
-             self,
-             *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions
-             config: DeciLMConfig,
-             batch_size: int = None,
-             max_cache_len: int = None,
-             dtype: torch.dtype = torch.float32,
-             max_batch_size: Optional[int] = None,
-             **kwargs: Any,
-     ) -> None:
+     def __init__(self,
+                  config: DeciLMConfig,
+                  max_batch_size: int,
+                  max_cache_len: int | None,
+                  device: torch.device | str | None = None,
+                  dtype: torch.dtype | None = None,
+                  ):
          Cache_4_44_2.__init__(self)

-         self.config = deepcopy(config)
-         self.max_batch_size = batch_size or max_batch_size
-         self.batch_size = self.max_batch_size
+         self.config = config
+         self.max_batch_size = max_batch_size
          self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
          self.dtype = dtype

@@ -83,7 +79,6 @@ class VariableCache(Cache_4_44_2, Cache):
          if attention_config.no_op or attention_config.replace_with_linear:
              return None
          config = deepcopy(self.config)
-         config.num_hidden_layers = 1
          config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
          return StaticCache(config, self.max_batch_size, self.max_cache_len, device, self.dtype)

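
The VariableCache constructor now takes plain positional arguments (config, max_batch_size, max_cache_len, device, dtype), mirroring the positional way this same file builds StaticCache in transformers 4.44.x, instead of the previous keyword-only signature. A minimal sketch of constructing the cache directly, reusing model and inputs from the sketch above; the flat import path and the assumption that generate()'s cache-setup path (NEED_SETUP_CACHE_CLASSES_MAPPING) instantiates it the same positional way are mine, not shown in this diff:

import torch

# VariableCache is defined in this repo's variable_cache.py; with trust_remote_code it is
# loaded dynamically, so this import path is illustrative only.
from variable_cache import VariableCache

past_key_values = VariableCache(
    model.config,     # DeciLMConfig
    1,                # max_batch_size: now a required positional argument
    2048,             # max_cache_len: falls back to config.max_position_embeddings if None
    model.device,     # device
    torch.bfloat16,   # dtype
)
# Passing a pre-built cache explicitly should also work with the stock generate():
output_ids = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=64)
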