DanielJacob committed
Update modeling_svd_llama.py

modeling_svd_llama.py (+5 -5)
CHANGED
@@ -11,7 +11,7 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.utils import logging
 from transformers import LlamaForCausalLM
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel, LlamaRotaryEmbedding, LlamaRMSNorm, repeat_kv, apply_rotary_pos_emb
-from
+from transformers import LlamaConfig


 logger = logging.get_logger(__name__)
@@ -21,7 +21,7 @@ _CONFIG_FOR_DOC = "LlamaConfig"
 ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)

 class SVDLlamaMLP(nn.Module):
-    def __init__(self, config:
+    def __init__(self, config: LlamaConfig):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -48,7 +48,7 @@ class SVDLlamaMLP(nn.Module):
 class SVDLlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config:
+    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -334,14 +334,14 @@ class SVDLLaMASDPA(SVDLlamaAttention):


 class SVDLlamaDecoderLayer(LlamaDecoderLayer):
-    def __init__(self, config:
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__(config, layer_idx)
         self.self_attn = SVDLlamaAttention(config=config, layer_idx=layer_idx)
         self.mlp = SVDLlamaMLP(config)


 class SVDLlamaForCausalLM(LlamaForCausalLM):
-    def __init__(self, config:
+    def __init__(self, config: LlamaConfig):
         super().__init__(config)
         self.model = LlamaModel(config)
         self.model.layers = nn.ModuleList(
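For context, the annotated signatures introduced by this commit can be exercised as in the sketch below. This is not part of the commit: it assumes modeling_svd_llama.py is importable from the working directory and that the unchanged remainder of the file (the SVD-factored projections inside SVDLlamaMLP and SVDLlamaAttention, and the rest of SVDLlamaForCausalLM.__init__) accepts a plain LlamaConfig without extra low-rank settings.

# Minimal sketch (not part of the commit). Assumes modeling_svd_llama.py is on the
# Python path and that a plain LlamaConfig is sufficient; any additional
# compression/rank attributes the SVD modules may read are not shown in this diff.
import torch
from transformers import LlamaConfig

from modeling_svd_llama import SVDLlamaForCausalLM

# Tiny configuration so the example runs quickly on CPU.
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)

# Per the diff, __init__ builds a LlamaModel and then replaces its layers
# with the SVD decoder layers.
model = SVDLlamaForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    logits = model(input_ids).logits  # shape: (1, 8, vocab_size)
print(logits.shape)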