tomer-deci commited on
Commit
83e1abe
·
1 Parent(s): 0be2d64

added support for text-generation pipeline

Browse files
Files changed (3) hide show
  1. config.json +2 -1
  2. modeling_decilm.py +6 -6
  3. version_check.py +3 -3
config.json CHANGED
@@ -25,5 +25,6 @@
25
  "use_bfloat16": true,
26
  "transformers_version": "4.35.2",
27
  "use_cache": true,
28
- "vocab_size": 32000
 
29
  }
 
25
  "use_bfloat16": true,
26
  "transformers_version": "4.35.2",
27
  "use_cache": true,
28
+ "vocab_size": 32000,
29
+ "tokenizer_class": "LlamaTokenizer"
30
  }
modeling_decilm.py CHANGED
@@ -11,18 +11,18 @@ import torch
11
  import torch.nn.functional as F
12
  import torch.utils.checkpoint
13
  from torch import nn
 
 
14
 
 
 
15
  from .transformers_v4_35_2__modeling_llama import LlamaMLP, LlamaRMSNorm, LlamaAttention, apply_rotary_pos_emb, \
16
  repeat_kv, LlamaPreTrainedModel, LLAMA_START_DOCSTRING, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, \
17
  BaseModelOutputWithPast, LLAMA_INPUTS_DOCSTRING
18
- from .transformers_v4_35_2__modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
19
- from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
20
-
21
- from .configuration_decilm import DeciLMConfig
22
-
23
- logger = logging.get_logger(__name__)
24
 
 
25
  _CONFIG_FOR_DOC = "DeciLMConfig"
 
26
 
27
 
28
  class DeciLMAttention(LlamaAttention):
 
11
  import torch.nn.functional as F
12
  import torch.utils.checkpoint
13
  from torch import nn
14
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
15
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
16
 
17
+ from .configuration_decilm import DeciLMConfig
18
+ from .transformers_v4_35_2__modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
19
  from .transformers_v4_35_2__modeling_llama import LlamaMLP, LlamaRMSNorm, LlamaAttention, apply_rotary_pos_emb, \
20
  repeat_kv, LlamaPreTrainedModel, LLAMA_START_DOCSTRING, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, \
21
  BaseModelOutputWithPast, LLAMA_INPUTS_DOCSTRING
 
 
 
 
 
 
22
 
23
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["deci"] = "DeciLMForCausalLM"
24
  _CONFIG_FOR_DOC = "DeciLMConfig"
25
+ logger = logging.get_logger(__name__)
26
 
27
 
28
  class DeciLMAttention(LlamaAttention):
version_check.py CHANGED
@@ -1,11 +1,11 @@
1
  import transformers
2
  from packaging import version
3
 
4
- VERSION = "4.35.2"
5
 
6
 
7
  def check_transformers_version():
8
- if version.parse(transformers.__version__) < version.parse(VERSION):
9
  raise ImportError(
10
- f"You are using transformers=={transformers.__version__}, but transformers>={VERSION} is required to use DeciLM. Please upgrade transformers."
11
  )
 
1
  import transformers
2
  from packaging import version
3
 
4
+ MIN_VERSION = "4.35.2"
5
 
6
 
7
  def check_transformers_version():
8
+ if version.parse(transformers.__version__) < version.parse(MIN_VERSION):
9
  raise ImportError(
10
+ f"You are using transformers=={transformers.__version__}, but transformers>={MIN_VERSION} is required to use DeciLM. Please upgrade transformers."
11
  )