DanielJacob committed · verified
Commit 1447616 · Parent(s): 035b4b8

Update modeling_svd_llama.py

Files changed (1): modeling_svd_llama.py (+5 −5)
modeling_svd_llama.py CHANGED

@@ -11,7 +11,7 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.utils import logging
 from transformers import LlamaForCausalLM
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel, LlamaRotaryEmbedding, LlamaRMSNorm, repeat_kv, apply_rotary_pos_emb
-from component.configuration_svd_llama import SVDLlamaConfig
+from transformers import LlamaConfig
 
 
 logger = logging.get_logger(__name__)
@@ -21,7 +21,7 @@ _CONFIG_FOR_DOC = "LlamaConfig"
 ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
 
 class SVDLlamaMLP(nn.Module):
-    def __init__(self, config: SVDLlamaConfig):
+    def __init__(self, config: LlamaConfig):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -48,7 +48,7 @@ class SVDLlamaMLP(nn.Module):
 class SVDLlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: SVDLlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -334,14 +334,14 @@ class SVDLLaMASDPA(SVDLlamaAttention):
 
 
 class SVDLlamaDecoderLayer(LlamaDecoderLayer):
-    def __init__(self, config: SVDLlamaConfig, layer_idx: int):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__(config, layer_idx)
         self.self_attn = SVDLlamaAttention(config=config, layer_idx=layer_idx)
         self.mlp = SVDLlamaMLP(config)
 
 
 class SVDLlamaForCausalLM(LlamaForCausalLM):
-    def __init__(self, config: SVDLlamaConfig):
+    def __init__(self, config: LlamaConfig):
         super().__init__(config)
         self.model = LlamaModel(config)
         self.model.layers = nn.ModuleList(
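
Note on the change (not part of the commit): the five edits all replace the repo-local SVDLlamaConfig type hint and its import from component.configuration_svd_llama with the stock LlamaConfig from transformers, so the remote modeling code no longer depends on a separate component package. A minimal usage sketch under that assumption; the repository id below is a hypothetical placeholder, not taken from this page:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo id -- substitute the Hugging Face repository that ships
# modeling_svd_llama.py as remote code.
repo_id = "<user>/<svd-llama-repo>"

# trust_remote_code=True lets transformers import modeling_svd_llama.py from
# the repo; after this commit that file only needs the standard LlamaConfig.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))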