Mayank Mishra commited on
Commit
0e6c38f
·
1 Parent(s): e82da12

upload model

Browse files
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "swiglu",
3
+ "add_bias": true,
4
+ "apply_residual_connection_post_layernorm": false,
5
+ "architectures": [
6
+ "GraniteForCausalLM"
7
+ ],
8
+ "attention_head_type": "gqa",
9
+ "attention_multiplier": null,
10
+ "attention_softmax_in_fp32": true,
11
+ "attn_pdrop": 0.1,
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_granite.GraniteConfig",
14
+ "AutoModel": "modeling_granite.GraniteModel",
15
+ "AutoModelForCausalLM": "modeling_granite.GraniteForCausalLM"
16
+ },
17
+ "bos_token_id": 0,
18
+ "embd_pdrop": 0.1,
19
+ "eos_token_id": 0,
20
+ "initializer_range": 0.02,
21
+ "layer_norm_epsilon": 1e-05,
22
+ "model_type": "granite",
23
+ "n_embd": 4096,
24
+ "n_head": 32,
25
+ "n_inner": 14336,
26
+ "n_layer": 36,
27
+ "n_positions": 4096,
28
+ "normalization_function": "rmsnorm",
29
+ "num_key_value_heads": 8,
30
+ "pad_token_id": 0,
31
+ "position_embedding_type": "rope",
32
+ "resid_pdrop": 0.1,
33
+ "rope_theta": 10000,
34
+ "scale_attention_softmax_in_fp32": true,
35
+ "scale_attn_weights": true,
36
+ "torch_dtype": "float32",
37
+ "transformers_version": "4.38.1",
38
+ "use_cache": true,
39
+ "vocab_size": 49152
40
+ }
configuration_granite.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class GraniteConfig(PretrainedConfig):
5
+ model_type = "granite"
6
+
7
+ keys_to_ignore_at_inference = ["past_key_values"]
8
+ attribute_map = {
9
+ "hidden_size": "n_embd",
10
+ "max_position_embeddings": "n_positions",
11
+ "num_attention_heads": "n_head",
12
+ "num_hidden_layers": "n_layer",
13
+ }
14
+
15
+ def __init__(
16
+ self,
17
+ vocab_size: int = 50257,
18
+ n_positions: int = 1024,
19
+ n_embd: int = 768,
20
+ n_layer: int = 12,
21
+ n_head: int = 12,
22
+ num_key_value_heads: int = None,
23
+ n_inner: int = None,
24
+ activation_function: str = "gelu_pytorch_tanh",
25
+ attention_head_type: str = "mqa",
26
+ resid_pdrop: float = 0.1,
27
+ embd_pdrop: float = 0.1,
28
+ attn_pdrop: float = 0.1,
29
+ normalization_function: str = "layernorm",
30
+ layer_norm_epsilon: float = 1e-5,
31
+ initializer_range: float = 0.02,
32
+ scale_attn_weights: bool = True,
33
+ attention_multiplier: float = None,
34
+ use_cache: bool = True,
35
+ bos_token_id: int = 50256,
36
+ eos_token_id: int = 50256,
37
+ pad_token_id: int = 50256,
38
+ attention_softmax_in_fp32: bool = True,
39
+ scale_attention_softmax_in_fp32: bool = True,
40
+ add_bias: bool = True,
41
+ position_embedding_type: str = "learned_absolute",
42
+ rope_theta: int = 10000,
43
+ **kwargs,
44
+ ) -> None:
45
+ self.vocab_size = vocab_size
46
+ self.n_positions = n_positions
47
+ self.n_embd = n_embd
48
+ self.n_layer = n_layer
49
+ self.n_head = n_head
50
+ self.num_key_value_heads = num_key_value_heads
51
+ self.n_inner = 4 * n_embd if n_inner is None else n_inner
52
+ self.activation_function = activation_function
53
+ self.attention_head_type = attention_head_type
54
+ self.resid_pdrop = resid_pdrop
55
+ self.embd_pdrop = embd_pdrop
56
+ self.attn_pdrop = attn_pdrop
57
+ self.normalization_function = normalization_function
58
+ self.layer_norm_epsilon = layer_norm_epsilon
59
+ self.initializer_range = initializer_range
60
+ self.scale_attn_weights = scale_attn_weights
61
+ self.attention_multiplier = attention_multiplier
62
+ self.use_cache = use_cache
63
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
64
+ self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
65
+ self.position_embedding_type = position_embedding_type
66
+ self.add_bias = add_bias
67
+ self.rope_theta = rope_theta
68
+
69
+ if self.attention_multiplier is not None:
70
+ assert self.scale_attn_weights
71
+
72
+ # for compatibility with some features
73
+ self.multi_query = attention_head_type == "mqa"
74
+
75
+ if attention_head_type == "mha":
76
+ if self.num_key_value_heads is None:
77
+ self.num_key_value_heads = self.n_head
78
+
79
+ assert (
80
+ self.n_head == self.num_key_value_heads
81
+ ), "MultiHeadAttention should have same number of heads for query, keys and values"
82
+ elif attention_head_type == "mqa":
83
+ if self.num_key_value_heads is None:
84
+ self.num_key_value_heads = 1
85
+
86
+ assert self.num_key_value_heads == 1, "MultiQueryAttention should have 1 head for keys and values"
87
+ elif attention_head_type == "gqa":
88
+ assert (
89
+ self.num_key_value_heads is not None
90
+ ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"
91
+
92
+ assert (
93
+ self.n_head % self.num_key_value_heads == 0
94
+ ), "GroupedQueryAttention should have more than 1 head for keys and values"
95
+ else:
96
+ raise ValueError(f"unexpected attention_head_type ({attention_head_type})")
97
+
98
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 0,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.38.1"
7
+ }
model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2aa28732a7ec623510ab63c06f1af7cb86c3050cf2b84e6e288864fac86c96aa
3
+ size 4933514320
model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6227e504458fe50bdacd6dc0a38f03fc29d659c923ec704486283d74a5b7676f
3
+ size 4765849624
model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79332859e730c83aee6bb3fbcf499e6966c1bf9965f871d4e7080f25a820d792
3
+ size 4832982480
model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8403a71ea70c8dc20506a27354c7bc077681078c00c0de2d71ca2010bb2e41d
3
+ size 4765849680
model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b19f69cad2d87854233bddf0adcca89d0099add537c4e50df3dbfd67adf8f185
3
+ size 4832982480
model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bddf280f58ced5a8d4d80dd0f6ee82abc786a5dfb616f4e8a44badbffe9353df
3
+ size 4765849680
model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66e4a5f4630e6edc448e4e6aaa49da21afc94d19b3f3cecdfa56a1809cae4f33
3
+ size 3322654376
model.safetensors.index.json ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 32219643904
4
+ },
5
+ "weight_map": {
6
+ "transformer.h.0.attn.c_attn.bias": "model-00001-of-00007.safetensors",
7
+ "transformer.h.0.attn.c_attn.weight": "model-00001-of-00007.safetensors",
8
+ "transformer.h.0.attn.c_proj.bias": "model-00001-of-00007.safetensors",
9
+ "transformer.h.0.attn.c_proj.weight": "model-00001-of-00007.safetensors",
10
+ "transformer.h.0.ln_1.weight": "model-00001-of-00007.safetensors",
11
+ "transformer.h.0.ln_2.weight": "model-00001-of-00007.safetensors",
12
+ "transformer.h.0.mlp.c_fc.bias": "model-00001-of-00007.safetensors",
13
+ "transformer.h.0.mlp.c_fc.weight": "model-00001-of-00007.safetensors",
14
+ "transformer.h.0.mlp.c_proj.bias": "model-00001-of-00007.safetensors",
15
+ "transformer.h.0.mlp.c_proj.weight": "model-00001-of-00007.safetensors",
16
+ "transformer.h.1.attn.c_attn.bias": "model-00001-of-00007.safetensors",
17
+ "transformer.h.1.attn.c_attn.weight": "model-00001-of-00007.safetensors",
18
+ "transformer.h.1.attn.c_proj.bias": "model-00001-of-00007.safetensors",
19
+ "transformer.h.1.attn.c_proj.weight": "model-00001-of-00007.safetensors",
20
+ "transformer.h.1.ln_1.weight": "model-00001-of-00007.safetensors",
21
+ "transformer.h.1.ln_2.weight": "model-00001-of-00007.safetensors",
22
+ "transformer.h.1.mlp.c_fc.bias": "model-00001-of-00007.safetensors",
23
+ "transformer.h.1.mlp.c_fc.weight": "model-00001-of-00007.safetensors",
24
+ "transformer.h.1.mlp.c_proj.bias": "model-00001-of-00007.safetensors",
25
+ "transformer.h.1.mlp.c_proj.weight": "model-00001-of-00007.safetensors",
26
+ "transformer.h.10.attn.c_attn.bias": "model-00002-of-00007.safetensors",
27
+ "transformer.h.10.attn.c_attn.weight": "model-00002-of-00007.safetensors",
28
+ "transformer.h.10.attn.c_proj.bias": "model-00002-of-00007.safetensors",
29
+ "transformer.h.10.attn.c_proj.weight": "model-00002-of-00007.safetensors",
30
+ "transformer.h.10.ln_1.weight": "model-00002-of-00007.safetensors",
31
+ "transformer.h.10.ln_2.weight": "model-00002-of-00007.safetensors",
32
+ "transformer.h.10.mlp.c_fc.bias": "model-00003-of-00007.safetensors",
33
+ "transformer.h.10.mlp.c_fc.weight": "model-00003-of-00007.safetensors",
34
+ "transformer.h.10.mlp.c_proj.bias": "model-00003-of-00007.safetensors",
35
+ "transformer.h.10.mlp.c_proj.weight": "model-00003-of-00007.safetensors",
36
+ "transformer.h.11.attn.c_attn.bias": "model-00003-of-00007.safetensors",
37
+ "transformer.h.11.attn.c_attn.weight": "model-00003-of-00007.safetensors",
38
+ "transformer.h.11.attn.c_proj.bias": "model-00003-of-00007.safetensors",
39
+ "transformer.h.11.attn.c_proj.weight": "model-00003-of-00007.safetensors",
40
+ "transformer.h.11.ln_1.weight": "model-00003-of-00007.safetensors",
41
+ "transformer.h.11.ln_2.weight": "model-00003-of-00007.safetensors",
42
+ "transformer.h.11.mlp.c_fc.bias": "model-00003-of-00007.safetensors",
43
+ "transformer.h.11.mlp.c_fc.weight": "model-00003-of-00007.safetensors",
44
+ "transformer.h.11.mlp.c_proj.bias": "model-00003-of-00007.safetensors",
45
+ "transformer.h.11.mlp.c_proj.weight": "model-00003-of-00007.safetensors",
46
+ "transformer.h.12.attn.c_attn.bias": "model-00003-of-00007.safetensors",
47
+ "transformer.h.12.attn.c_attn.weight": "model-00003-of-00007.safetensors",
48
+ "transformer.h.12.attn.c_proj.bias": "model-00003-of-00007.safetensors",
49
+ "transformer.h.12.attn.c_proj.weight": "model-00003-of-00007.safetensors",
50
+ "transformer.h.12.ln_1.weight": "model-00003-of-00007.safetensors",
51
+ "transformer.h.12.ln_2.weight": "model-00003-of-00007.safetensors",
52
+ "transformer.h.12.mlp.c_fc.bias": "model-00003-of-00007.safetensors",
53
+ "transformer.h.12.mlp.c_fc.weight": "model-00003-of-00007.safetensors",
54
+ "transformer.h.12.mlp.c_proj.bias": "model-00003-of-00007.safetensors",
55
+ "transformer.h.12.mlp.c_proj.weight": "model-00003-of-00007.safetensors",
56
+ "transformer.h.13.attn.c_attn.bias": "model-00003-of-00007.safetensors",
57
+ "transformer.h.13.attn.c_attn.weight": "model-00003-of-00007.safetensors",
58
+ "transformer.h.13.attn.c_proj.bias": "model-00003-of-00007.safetensors",
59
+ "transformer.h.13.attn.c_proj.weight": "model-00003-of-00007.safetensors",
60
+ "transformer.h.13.ln_1.weight": "model-00003-of-00007.safetensors",
61
+ "transformer.h.13.ln_2.weight": "model-00003-of-00007.safetensors",
62
+ "transformer.h.13.mlp.c_fc.bias": "model-00003-of-00007.safetensors",
63
+ "transformer.h.13.mlp.c_fc.weight": "model-00003-of-00007.safetensors",
64
+ "transformer.h.13.mlp.c_proj.bias": "model-00003-of-00007.safetensors",
65
+ "transformer.h.13.mlp.c_proj.weight": "model-00003-of-00007.safetensors",
66
+ "transformer.h.14.attn.c_attn.bias": "model-00003-of-00007.safetensors",
67
+ "transformer.h.14.attn.c_attn.weight": "model-00003-of-00007.safetensors",
68
+ "transformer.h.14.attn.c_proj.bias": "model-00003-of-00007.safetensors",
69
+ "transformer.h.14.attn.c_proj.weight": "model-00003-of-00007.safetensors",
70
+ "transformer.h.14.ln_1.weight": "model-00003-of-00007.safetensors",
71
+ "transformer.h.14.ln_2.weight": "model-00003-of-00007.safetensors",
72
+ "transformer.h.14.mlp.c_fc.bias": "model-00003-of-00007.safetensors",
73
+ "transformer.h.14.mlp.c_fc.weight": "model-00003-of-00007.safetensors",
74
+ "transformer.h.14.mlp.c_proj.bias": "model-00003-of-00007.safetensors",
75
+ "transformer.h.14.mlp.c_proj.weight": "model-00003-of-00007.safetensors",
76
+ "transformer.h.15.attn.c_attn.bias": "model-00003-of-00007.safetensors",
77
+ "transformer.h.15.attn.c_attn.weight": "model-00003-of-00007.safetensors",
78
+ "transformer.h.15.attn.c_proj.bias": "model-00003-of-00007.safetensors",
79
+ "transformer.h.15.attn.c_proj.weight": "model-00003-of-00007.safetensors",
80
+ "transformer.h.15.ln_1.weight": "model-00003-of-00007.safetensors",
81
+ "transformer.h.15.ln_2.weight": "model-00003-of-00007.safetensors",
82
+ "transformer.h.15.mlp.c_fc.bias": "model-00003-of-00007.safetensors",
83
+ "transformer.h.15.mlp.c_fc.weight": "model-00003-of-00007.safetensors",
84
+ "transformer.h.15.mlp.c_proj.bias": "model-00004-of-00007.safetensors",
85
+ "transformer.h.15.mlp.c_proj.weight": "model-00004-of-00007.safetensors",
86
+ "transformer.h.16.attn.c_attn.bias": "model-00004-of-00007.safetensors",
87
+ "transformer.h.16.attn.c_attn.weight": "model-00004-of-00007.safetensors",
88
+ "transformer.h.16.attn.c_proj.bias": "model-00004-of-00007.safetensors",
89
+ "transformer.h.16.attn.c_proj.weight": "model-00004-of-00007.safetensors",
90
+ "transformer.h.16.ln_1.weight": "model-00004-of-00007.safetensors",
91
+ "transformer.h.16.ln_2.weight": "model-00004-of-00007.safetensors",
92
+ "transformer.h.16.mlp.c_fc.bias": "model-00004-of-00007.safetensors",
93
+ "transformer.h.16.mlp.c_fc.weight": "model-00004-of-00007.safetensors",
94
+ "transformer.h.16.mlp.c_proj.bias": "model-00004-of-00007.safetensors",
95
+ "transformer.h.16.mlp.c_proj.weight": "model-00004-of-00007.safetensors",
96
+ "transformer.h.17.attn.c_attn.bias": "model-00004-of-00007.safetensors",
97
+ "transformer.h.17.attn.c_attn.weight": "model-00004-of-00007.safetensors",
98
+ "transformer.h.17.attn.c_proj.bias": "model-00004-of-00007.safetensors",
99
+ "transformer.h.17.attn.c_proj.weight": "model-00004-of-00007.safetensors",
100
+ "transformer.h.17.ln_1.weight": "model-00004-of-00007.safetensors",
101
+ "transformer.h.17.ln_2.weight": "model-00004-of-00007.safetensors",
102
+ "transformer.h.17.mlp.c_fc.bias": "model-00004-of-00007.safetensors",
103
+ "transformer.h.17.mlp.c_fc.weight": "model-00004-of-00007.safetensors",
104
+ "transformer.h.17.mlp.c_proj.bias": "model-00004-of-00007.safetensors",
105
+ "transformer.h.17.mlp.c_proj.weight": "model-00004-of-00007.safetensors",
106
+ "transformer.h.18.attn.c_attn.bias": "model-00004-of-00007.safetensors",
107
+ "transformer.h.18.attn.c_attn.weight": "model-00004-of-00007.safetensors",
108
+ "transformer.h.18.attn.c_proj.bias": "model-00004-of-00007.safetensors",
109
+ "transformer.h.18.attn.c_proj.weight": "model-00004-of-00007.safetensors",
110
+ "transformer.h.18.ln_1.weight": "model-00004-of-00007.safetensors",
111
+ "transformer.h.18.ln_2.weight": "model-00004-of-00007.safetensors",
112
+ "transformer.h.18.mlp.c_fc.bias": "model-00004-of-00007.safetensors",
113
+ "transformer.h.18.mlp.c_fc.weight": "model-00004-of-00007.safetensors",
114
+ "transformer.h.18.mlp.c_proj.bias": "model-00004-of-00007.safetensors",
115
+ "transformer.h.18.mlp.c_proj.weight": "model-00004-of-00007.safetensors",
116
+ "transformer.h.19.attn.c_attn.bias": "model-00004-of-00007.safetensors",
117
+ "transformer.h.19.attn.c_attn.weight": "model-00004-of-00007.safetensors",
118
+ "transformer.h.19.attn.c_proj.bias": "model-00004-of-00007.safetensors",
119
+ "transformer.h.19.attn.c_proj.weight": "model-00004-of-00007.safetensors",
120
+ "transformer.h.19.ln_1.weight": "model-00004-of-00007.safetensors",
121
+ "transformer.h.19.ln_2.weight": "model-00004-of-00007.safetensors",
122
+ "transformer.h.19.mlp.c_fc.bias": "model-00004-of-00007.safetensors",
123
+ "transformer.h.19.mlp.c_fc.weight": "model-00004-of-00007.safetensors",
124
+ "transformer.h.19.mlp.c_proj.bias": "model-00004-of-00007.safetensors",
125
+ "transformer.h.19.mlp.c_proj.weight": "model-00004-of-00007.safetensors",
126
+ "transformer.h.2.attn.c_attn.bias": "model-00001-of-00007.safetensors",
127
+ "transformer.h.2.attn.c_attn.weight": "model-00001-of-00007.safetensors",
128
+ "transformer.h.2.attn.c_proj.bias": "model-00001-of-00007.safetensors",
129
+ "transformer.h.2.attn.c_proj.weight": "model-00001-of-00007.safetensors",
130
+ "transformer.h.2.ln_1.weight": "model-00001-of-00007.safetensors",
131
+ "transformer.h.2.ln_2.weight": "model-00001-of-00007.safetensors",
132
+ "transformer.h.2.mlp.c_fc.bias": "model-00001-of-00007.safetensors",
133
+ "transformer.h.2.mlp.c_fc.weight": "model-00001-of-00007.safetensors",
134
+ "transformer.h.2.mlp.c_proj.bias": "model-00001-of-00007.safetensors",
135
+ "transformer.h.2.mlp.c_proj.weight": "model-00001-of-00007.safetensors",
136
+ "transformer.h.20.attn.c_attn.bias": "model-00004-of-00007.safetensors",
137
+ "transformer.h.20.attn.c_attn.weight": "model-00004-of-00007.safetensors",
138
+ "transformer.h.20.attn.c_proj.bias": "model-00004-of-00007.safetensors",
139
+ "transformer.h.20.attn.c_proj.weight": "model-00004-of-00007.safetensors",
140
+ "transformer.h.20.ln_1.weight": "model-00004-of-00007.safetensors",
141
+ "transformer.h.20.ln_2.weight": "model-00004-of-00007.safetensors",
142
+ "transformer.h.20.mlp.c_fc.bias": "model-00004-of-00007.safetensors",
143
+ "transformer.h.20.mlp.c_fc.weight": "model-00004-of-00007.safetensors",
144
+ "transformer.h.20.mlp.c_proj.bias": "model-00004-of-00007.safetensors",
145
+ "transformer.h.20.mlp.c_proj.weight": "model-00004-of-00007.safetensors",
146
+ "transformer.h.21.attn.c_attn.bias": "model-00004-of-00007.safetensors",
147
+ "transformer.h.21.attn.c_attn.weight": "model-00004-of-00007.safetensors",
148
+ "transformer.h.21.attn.c_proj.bias": "model-00004-of-00007.safetensors",
149
+ "transformer.h.21.attn.c_proj.weight": "model-00004-of-00007.safetensors",
150
+ "transformer.h.21.ln_1.weight": "model-00004-of-00007.safetensors",
151
+ "transformer.h.21.ln_2.weight": "model-00004-of-00007.safetensors",
152
+ "transformer.h.21.mlp.c_fc.bias": "model-00005-of-00007.safetensors",
153
+ "transformer.h.21.mlp.c_fc.weight": "model-00005-of-00007.safetensors",
154
+ "transformer.h.21.mlp.c_proj.bias": "model-00005-of-00007.safetensors",
155
+ "transformer.h.21.mlp.c_proj.weight": "model-00005-of-00007.safetensors",
156
+ "transformer.h.22.attn.c_attn.bias": "model-00005-of-00007.safetensors",
157
+ "transformer.h.22.attn.c_attn.weight": "model-00005-of-00007.safetensors",
158
+ "transformer.h.22.attn.c_proj.bias": "model-00005-of-00007.safetensors",
159
+ "transformer.h.22.attn.c_proj.weight": "model-00005-of-00007.safetensors",
160
+ "transformer.h.22.ln_1.weight": "model-00005-of-00007.safetensors",
161
+ "transformer.h.22.ln_2.weight": "model-00005-of-00007.safetensors",
162
+ "transformer.h.22.mlp.c_fc.bias": "model-00005-of-00007.safetensors",
163
+ "transformer.h.22.mlp.c_fc.weight": "model-00005-of-00007.safetensors",
164
+ "transformer.h.22.mlp.c_proj.bias": "model-00005-of-00007.safetensors",
165
+ "transformer.h.22.mlp.c_proj.weight": "model-00005-of-00007.safetensors",
166
+ "transformer.h.23.attn.c_attn.bias": "model-00005-of-00007.safetensors",
167
+ "transformer.h.23.attn.c_attn.weight": "model-00005-of-00007.safetensors",
168
+ "transformer.h.23.attn.c_proj.bias": "model-00005-of-00007.safetensors",
169
+ "transformer.h.23.attn.c_proj.weight": "model-00005-of-00007.safetensors",
170
+ "transformer.h.23.ln_1.weight": "model-00005-of-00007.safetensors",
171
+ "transformer.h.23.ln_2.weight": "model-00005-of-00007.safetensors",
172
+ "transformer.h.23.mlp.c_fc.bias": "model-00005-of-00007.safetensors",
173
+ "transformer.h.23.mlp.c_fc.weight": "model-00005-of-00007.safetensors",
174
+ "transformer.h.23.mlp.c_proj.bias": "model-00005-of-00007.safetensors",
175
+ "transformer.h.23.mlp.c_proj.weight": "model-00005-of-00007.safetensors",
176
+ "transformer.h.24.attn.c_attn.bias": "model-00005-of-00007.safetensors",
177
+ "transformer.h.24.attn.c_attn.weight": "model-00005-of-00007.safetensors",
178
+ "transformer.h.24.attn.c_proj.bias": "model-00005-of-00007.safetensors",
179
+ "transformer.h.24.attn.c_proj.weight": "model-00005-of-00007.safetensors",
180
+ "transformer.h.24.ln_1.weight": "model-00005-of-00007.safetensors",
181
+ "transformer.h.24.ln_2.weight": "model-00005-of-00007.safetensors",
182
+ "transformer.h.24.mlp.c_fc.bias": "model-00005-of-00007.safetensors",
183
+ "transformer.h.24.mlp.c_fc.weight": "model-00005-of-00007.safetensors",
184
+ "transformer.h.24.mlp.c_proj.bias": "model-00005-of-00007.safetensors",
185
+ "transformer.h.24.mlp.c_proj.weight": "model-00005-of-00007.safetensors",
186
+ "transformer.h.25.attn.c_attn.bias": "model-00005-of-00007.safetensors",
187
+ "transformer.h.25.attn.c_attn.weight": "model-00005-of-00007.safetensors",
188
+ "transformer.h.25.attn.c_proj.bias": "model-00005-of-00007.safetensors",
189
+ "transformer.h.25.attn.c_proj.weight": "model-00005-of-00007.safetensors",
190
+ "transformer.h.25.ln_1.weight": "model-00005-of-00007.safetensors",
191
+ "transformer.h.25.ln_2.weight": "model-00005-of-00007.safetensors",
192
+ "transformer.h.25.mlp.c_fc.bias": "model-00005-of-00007.safetensors",
193
+ "transformer.h.25.mlp.c_fc.weight": "model-00005-of-00007.safetensors",
194
+ "transformer.h.25.mlp.c_proj.bias": "model-00005-of-00007.safetensors",
195
+ "transformer.h.25.mlp.c_proj.weight": "model-00005-of-00007.safetensors",
196
+ "transformer.h.26.attn.c_attn.bias": "model-00005-of-00007.safetensors",
197
+ "transformer.h.26.attn.c_attn.weight": "model-00005-of-00007.safetensors",
198
+ "transformer.h.26.attn.c_proj.bias": "model-00005-of-00007.safetensors",
199
+ "transformer.h.26.attn.c_proj.weight": "model-00005-of-00007.safetensors",
200
+ "transformer.h.26.ln_1.weight": "model-00005-of-00007.safetensors",
201
+ "transformer.h.26.ln_2.weight": "model-00005-of-00007.safetensors",
202
+ "transformer.h.26.mlp.c_fc.bias": "model-00005-of-00007.safetensors",
203
+ "transformer.h.26.mlp.c_fc.weight": "model-00005-of-00007.safetensors",
204
+ "transformer.h.26.mlp.c_proj.bias": "model-00006-of-00007.safetensors",
205
+ "transformer.h.26.mlp.c_proj.weight": "model-00006-of-00007.safetensors",
206
+ "transformer.h.27.attn.c_attn.bias": "model-00006-of-00007.safetensors",
207
+ "transformer.h.27.attn.c_attn.weight": "model-00006-of-00007.safetensors",
208
+ "transformer.h.27.attn.c_proj.bias": "model-00006-of-00007.safetensors",
209
+ "transformer.h.27.attn.c_proj.weight": "model-00006-of-00007.safetensors",
210
+ "transformer.h.27.ln_1.weight": "model-00006-of-00007.safetensors",
211
+ "transformer.h.27.ln_2.weight": "model-00006-of-00007.safetensors",
212
+ "transformer.h.27.mlp.c_fc.bias": "model-00006-of-00007.safetensors",
213
+ "transformer.h.27.mlp.c_fc.weight": "model-00006-of-00007.safetensors",
214
+ "transformer.h.27.mlp.c_proj.bias": "model-00006-of-00007.safetensors",
215
+ "transformer.h.27.mlp.c_proj.weight": "model-00006-of-00007.safetensors",
216
+ "transformer.h.28.attn.c_attn.bias": "model-00006-of-00007.safetensors",
217
+ "transformer.h.28.attn.c_attn.weight": "model-00006-of-00007.safetensors",
218
+ "transformer.h.28.attn.c_proj.bias": "model-00006-of-00007.safetensors",
219
+ "transformer.h.28.attn.c_proj.weight": "model-00006-of-00007.safetensors",
220
+ "transformer.h.28.ln_1.weight": "model-00006-of-00007.safetensors",
221
+ "transformer.h.28.ln_2.weight": "model-00006-of-00007.safetensors",
222
+ "transformer.h.28.mlp.c_fc.bias": "model-00006-of-00007.safetensors",
223
+ "transformer.h.28.mlp.c_fc.weight": "model-00006-of-00007.safetensors",
224
+ "transformer.h.28.mlp.c_proj.bias": "model-00006-of-00007.safetensors",
225
+ "transformer.h.28.mlp.c_proj.weight": "model-00006-of-00007.safetensors",
226
+ "transformer.h.29.attn.c_attn.bias": "model-00006-of-00007.safetensors",
227
+ "transformer.h.29.attn.c_attn.weight": "model-00006-of-00007.safetensors",
228
+ "transformer.h.29.attn.c_proj.bias": "model-00006-of-00007.safetensors",
229
+ "transformer.h.29.attn.c_proj.weight": "model-00006-of-00007.safetensors",
230
+ "transformer.h.29.ln_1.weight": "model-00006-of-00007.safetensors",
231
+ "transformer.h.29.ln_2.weight": "model-00006-of-00007.safetensors",
232
+ "transformer.h.29.mlp.c_fc.bias": "model-00006-of-00007.safetensors",
233
+ "transformer.h.29.mlp.c_fc.weight": "model-00006-of-00007.safetensors",
234
+ "transformer.h.29.mlp.c_proj.bias": "model-00006-of-00007.safetensors",
235
+ "transformer.h.29.mlp.c_proj.weight": "model-00006-of-00007.safetensors",
236
+ "transformer.h.3.attn.c_attn.bias": "model-00001-of-00007.safetensors",
237
+ "transformer.h.3.attn.c_attn.weight": "model-00001-of-00007.safetensors",
238
+ "transformer.h.3.attn.c_proj.bias": "model-00001-of-00007.safetensors",
239
+ "transformer.h.3.attn.c_proj.weight": "model-00001-of-00007.safetensors",
240
+ "transformer.h.3.ln_1.weight": "model-00001-of-00007.safetensors",
241
+ "transformer.h.3.ln_2.weight": "model-00001-of-00007.safetensors",
242
+ "transformer.h.3.mlp.c_fc.bias": "model-00001-of-00007.safetensors",
243
+ "transformer.h.3.mlp.c_fc.weight": "model-00001-of-00007.safetensors",
244
+ "transformer.h.3.mlp.c_proj.bias": "model-00001-of-00007.safetensors",
245
+ "transformer.h.3.mlp.c_proj.weight": "model-00001-of-00007.safetensors",
246
+ "transformer.h.30.attn.c_attn.bias": "model-00006-of-00007.safetensors",
247
+ "transformer.h.30.attn.c_attn.weight": "model-00006-of-00007.safetensors",
248
+ "transformer.h.30.attn.c_proj.bias": "model-00006-of-00007.safetensors",
249
+ "transformer.h.30.attn.c_proj.weight": "model-00006-of-00007.safetensors",
250
+ "transformer.h.30.ln_1.weight": "model-00006-of-00007.safetensors",
251
+ "transformer.h.30.ln_2.weight": "model-00006-of-00007.safetensors",
252
+ "transformer.h.30.mlp.c_fc.bias": "model-00006-of-00007.safetensors",
253
+ "transformer.h.30.mlp.c_fc.weight": "model-00006-of-00007.safetensors",
254
+ "transformer.h.30.mlp.c_proj.bias": "model-00006-of-00007.safetensors",
255
+ "transformer.h.30.mlp.c_proj.weight": "model-00006-of-00007.safetensors",
256
+ "transformer.h.31.attn.c_attn.bias": "model-00006-of-00007.safetensors",
257
+ "transformer.h.31.attn.c_attn.weight": "model-00006-of-00007.safetensors",
258
+ "transformer.h.31.attn.c_proj.bias": "model-00006-of-00007.safetensors",
259
+ "transformer.h.31.attn.c_proj.weight": "model-00006-of-00007.safetensors",
260
+ "transformer.h.31.ln_1.weight": "model-00006-of-00007.safetensors",
261
+ "transformer.h.31.ln_2.weight": "model-00006-of-00007.safetensors",
262
+ "transformer.h.31.mlp.c_fc.bias": "model-00006-of-00007.safetensors",
263
+ "transformer.h.31.mlp.c_fc.weight": "model-00006-of-00007.safetensors",
264
+ "transformer.h.31.mlp.c_proj.bias": "model-00006-of-00007.safetensors",
265
+ "transformer.h.31.mlp.c_proj.weight": "model-00006-of-00007.safetensors",
266
+ "transformer.h.32.attn.c_attn.bias": "model-00006-of-00007.safetensors",
267
+ "transformer.h.32.attn.c_attn.weight": "model-00006-of-00007.safetensors",
268
+ "transformer.h.32.attn.c_proj.bias": "model-00006-of-00007.safetensors",
269
+ "transformer.h.32.attn.c_proj.weight": "model-00006-of-00007.safetensors",
270
+ "transformer.h.32.ln_1.weight": "model-00006-of-00007.safetensors",
271
+ "transformer.h.32.ln_2.weight": "model-00006-of-00007.safetensors",
272
+ "transformer.h.32.mlp.c_fc.bias": "model-00007-of-00007.safetensors",
273
+ "transformer.h.32.mlp.c_fc.weight": "model-00007-of-00007.safetensors",
274
+ "transformer.h.32.mlp.c_proj.bias": "model-00007-of-00007.safetensors",
275
+ "transformer.h.32.mlp.c_proj.weight": "model-00007-of-00007.safetensors",
276
+ "transformer.h.33.attn.c_attn.bias": "model-00007-of-00007.safetensors",
277
+ "transformer.h.33.attn.c_attn.weight": "model-00007-of-00007.safetensors",
278
+ "transformer.h.33.attn.c_proj.bias": "model-00007-of-00007.safetensors",
279
+ "transformer.h.33.attn.c_proj.weight": "model-00007-of-00007.safetensors",
280
+ "transformer.h.33.ln_1.weight": "model-00007-of-00007.safetensors",
281
+ "transformer.h.33.ln_2.weight": "model-00007-of-00007.safetensors",
282
+ "transformer.h.33.mlp.c_fc.bias": "model-00007-of-00007.safetensors",
283
+ "transformer.h.33.mlp.c_fc.weight": "model-00007-of-00007.safetensors",
284
+ "transformer.h.33.mlp.c_proj.bias": "model-00007-of-00007.safetensors",
285
+ "transformer.h.33.mlp.c_proj.weight": "model-00007-of-00007.safetensors",
286
+ "transformer.h.34.attn.c_attn.bias": "model-00007-of-00007.safetensors",
287
+ "transformer.h.34.attn.c_attn.weight": "model-00007-of-00007.safetensors",
288
+ "transformer.h.34.attn.c_proj.bias": "model-00007-of-00007.safetensors",
289
+ "transformer.h.34.attn.c_proj.weight": "model-00007-of-00007.safetensors",
290
+ "transformer.h.34.ln_1.weight": "model-00007-of-00007.safetensors",
291
+ "transformer.h.34.ln_2.weight": "model-00007-of-00007.safetensors",
292
+ "transformer.h.34.mlp.c_fc.bias": "model-00007-of-00007.safetensors",
293
+ "transformer.h.34.mlp.c_fc.weight": "model-00007-of-00007.safetensors",
294
+ "transformer.h.34.mlp.c_proj.bias": "model-00007-of-00007.safetensors",
295
+ "transformer.h.34.mlp.c_proj.weight": "model-00007-of-00007.safetensors",
296
+ "transformer.h.35.attn.c_attn.bias": "model-00007-of-00007.safetensors",
297
+ "transformer.h.35.attn.c_attn.weight": "model-00007-of-00007.safetensors",
298
+ "transformer.h.35.attn.c_proj.bias": "model-00007-of-00007.safetensors",
299
+ "transformer.h.35.attn.c_proj.weight": "model-00007-of-00007.safetensors",
300
+ "transformer.h.35.ln_1.weight": "model-00007-of-00007.safetensors",
301
+ "transformer.h.35.ln_2.weight": "model-00007-of-00007.safetensors",
302
+ "transformer.h.35.mlp.c_fc.bias": "model-00007-of-00007.safetensors",
303
+ "transformer.h.35.mlp.c_fc.weight": "model-00007-of-00007.safetensors",
304
+ "transformer.h.35.mlp.c_proj.bias": "model-00007-of-00007.safetensors",
305
+ "transformer.h.35.mlp.c_proj.weight": "model-00007-of-00007.safetensors",
306
+ "transformer.h.4.attn.c_attn.bias": "model-00001-of-00007.safetensors",
307
+ "transformer.h.4.attn.c_attn.weight": "model-00001-of-00007.safetensors",
308
+ "transformer.h.4.attn.c_proj.bias": "model-00001-of-00007.safetensors",
309
+ "transformer.h.4.attn.c_proj.weight": "model-00001-of-00007.safetensors",
310
+ "transformer.h.4.ln_1.weight": "model-00001-of-00007.safetensors",
311
+ "transformer.h.4.ln_2.weight": "model-00001-of-00007.safetensors",
312
+ "transformer.h.4.mlp.c_fc.bias": "model-00001-of-00007.safetensors",
313
+ "transformer.h.4.mlp.c_fc.weight": "model-00001-of-00007.safetensors",
314
+ "transformer.h.4.mlp.c_proj.bias": "model-00002-of-00007.safetensors",
315
+ "transformer.h.4.mlp.c_proj.weight": "model-00002-of-00007.safetensors",
316
+ "transformer.h.5.attn.c_attn.bias": "model-00002-of-00007.safetensors",
317
+ "transformer.h.5.attn.c_attn.weight": "model-00002-of-00007.safetensors",
318
+ "transformer.h.5.attn.c_proj.bias": "model-00002-of-00007.safetensors",
319
+ "transformer.h.5.attn.c_proj.weight": "model-00002-of-00007.safetensors",
320
+ "transformer.h.5.ln_1.weight": "model-00002-of-00007.safetensors",
321
+ "transformer.h.5.ln_2.weight": "model-00002-of-00007.safetensors",
322
+ "transformer.h.5.mlp.c_fc.bias": "model-00002-of-00007.safetensors",
323
+ "transformer.h.5.mlp.c_fc.weight": "model-00002-of-00007.safetensors",
324
+ "transformer.h.5.mlp.c_proj.bias": "model-00002-of-00007.safetensors",
325
+ "transformer.h.5.mlp.c_proj.weight": "model-00002-of-00007.safetensors",
326
+ "transformer.h.6.attn.c_attn.bias": "model-00002-of-00007.safetensors",
327
+ "transformer.h.6.attn.c_attn.weight": "model-00002-of-00007.safetensors",
328
+ "transformer.h.6.attn.c_proj.bias": "model-00002-of-00007.safetensors",
329
+ "transformer.h.6.attn.c_proj.weight": "model-00002-of-00007.safetensors",
330
+ "transformer.h.6.ln_1.weight": "model-00002-of-00007.safetensors",
331
+ "transformer.h.6.ln_2.weight": "model-00002-of-00007.safetensors",
332
+ "transformer.h.6.mlp.c_fc.bias": "model-00002-of-00007.safetensors",
333
+ "transformer.h.6.mlp.c_fc.weight": "model-00002-of-00007.safetensors",
334
+ "transformer.h.6.mlp.c_proj.bias": "model-00002-of-00007.safetensors",
335
+ "transformer.h.6.mlp.c_proj.weight": "model-00002-of-00007.safetensors",
336
+ "transformer.h.7.attn.c_attn.bias": "model-00002-of-00007.safetensors",
337
+ "transformer.h.7.attn.c_attn.weight": "model-00002-of-00007.safetensors",
338
+ "transformer.h.7.attn.c_proj.bias": "model-00002-of-00007.safetensors",
339
+ "transformer.h.7.attn.c_proj.weight": "model-00002-of-00007.safetensors",
340
+ "transformer.h.7.ln_1.weight": "model-00002-of-00007.safetensors",
341
+ "transformer.h.7.ln_2.weight": "model-00002-of-00007.safetensors",
342
+ "transformer.h.7.mlp.c_fc.bias": "model-00002-of-00007.safetensors",
343
+ "transformer.h.7.mlp.c_fc.weight": "model-00002-of-00007.safetensors",
344
+ "transformer.h.7.mlp.c_proj.bias": "model-00002-of-00007.safetensors",
345
+ "transformer.h.7.mlp.c_proj.weight": "model-00002-of-00007.safetensors",
346
+ "transformer.h.8.attn.c_attn.bias": "model-00002-of-00007.safetensors",
347
+ "transformer.h.8.attn.c_attn.weight": "model-00002-of-00007.safetensors",
348
+ "transformer.h.8.attn.c_proj.bias": "model-00002-of-00007.safetensors",
349
+ "transformer.h.8.attn.c_proj.weight": "model-00002-of-00007.safetensors",
350
+ "transformer.h.8.ln_1.weight": "model-00002-of-00007.safetensors",
351
+ "transformer.h.8.ln_2.weight": "model-00002-of-00007.safetensors",
352
+ "transformer.h.8.mlp.c_fc.bias": "model-00002-of-00007.safetensors",
353
+ "transformer.h.8.mlp.c_fc.weight": "model-00002-of-00007.safetensors",
354
+ "transformer.h.8.mlp.c_proj.bias": "model-00002-of-00007.safetensors",
355
+ "transformer.h.8.mlp.c_proj.weight": "model-00002-of-00007.safetensors",
356
+ "transformer.h.9.attn.c_attn.bias": "model-00002-of-00007.safetensors",
357
+ "transformer.h.9.attn.c_attn.weight": "model-00002-of-00007.safetensors",
358
+ "transformer.h.9.attn.c_proj.bias": "model-00002-of-00007.safetensors",
359
+ "transformer.h.9.attn.c_proj.weight": "model-00002-of-00007.safetensors",
360
+ "transformer.h.9.ln_1.weight": "model-00002-of-00007.safetensors",
361
+ "transformer.h.9.ln_2.weight": "model-00002-of-00007.safetensors",
362
+ "transformer.h.9.mlp.c_fc.bias": "model-00002-of-00007.safetensors",
363
+ "transformer.h.9.mlp.c_fc.weight": "model-00002-of-00007.safetensors",
364
+ "transformer.h.9.mlp.c_proj.bias": "model-00002-of-00007.safetensors",
365
+ "transformer.h.9.mlp.c_proj.weight": "model-00002-of-00007.safetensors",
366
+ "transformer.ln_f.weight": "model-00007-of-00007.safetensors",
367
+ "transformer.wte.weight": "model-00001-of-00007.safetensors"
368
+ }
369
+ }
modeling_granite.py ADDED
@@ -0,0 +1,1376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numbers
3
+ import warnings
4
+ from enum import Enum
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from transformers import DynamicCache, PreTrainedModel
11
+ from transformers.activations import get_activation as get_base_activation
12
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
13
+ from transformers.utils import is_flash_attn_2_available
14
+
15
+ from .configuration_granite import GraniteConfig
16
+
17
+
18
+ class PositionEmbeddingType(Enum):
19
+ learned_absolute = "learned_absolute"
20
+ alibi = "alibi"
21
+ rope = "rope"
22
+
23
+
24
+ class AttentionHeadType(Enum):
25
+ mha = "mha"
26
+ mqa = "mqa"
27
+ gqa = "gqa"
28
+
29
+
30
+ if is_flash_attn_2_available():
31
+ from flash_attn.bert_padding import IndexFirstAxis, pad_input, unpad_input
32
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
33
+
34
+
35
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
36
+ def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
37
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
38
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
39
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
40
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
41
+ return indices, cu_seqlens, max_seqlen_in_batch
42
+
43
+
44
+ def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor:
45
+ num_groups = num_heads // num_key_value_heads
46
+
47
+ # mha
48
+ if num_groups == 1:
49
+ return x
50
+
51
+ # mqa
52
+ if num_key_value_heads == 1:
53
+ return x.expand(-1, num_heads, -1, -1)
54
+
55
+ # gqa
56
+ return x.repeat_interleave(num_groups, dim=1)
57
+
58
+
59
+ ##################################################
60
+ # activation functions
61
+
62
+
63
+ _GLU_BASE_MAPPING = {
64
+ "geglu": "gelu",
65
+ "miglu": "mish",
66
+ "mishglu": "mish",
67
+ "swiglu": "swish",
68
+ }
69
+
70
+
71
+ class GLUActivation(nn.Module):
72
+ def __init__(self, base_activation: nn.Module) -> None:
73
+ super().__init__()
74
+ self.base_activation = base_activation
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ x = x.chunk(2, dim=-1)
78
+ return x[0] * self.base_activation(x[1])
79
+
80
+
81
+ def is_glu(name: str) -> bool:
82
+ return name.endswith("glu")
83
+
84
+
85
+ def get_activation_function(name: str) -> nn.Module:
86
+ if is_glu(name):
87
+ # for glu and sigmoid_glu, we directly return the pytorch's GLU
88
+ if name in ["glu", "sigmoid_glu"]:
89
+ activation_function = nn.modules.GLU()
90
+ else:
91
+ if name in _GLU_BASE_MAPPING:
92
+ name = _GLU_BASE_MAPPING[name]
93
+ elif name.endswith("_glu"):
94
+ name = name.rstrip("_glu")
95
+ else:
96
+ raise ValueError("invalid activation function")
97
+
98
+ base_activation = get_base_activation(name)
99
+ activation_function = GLUActivation(base_activation)
100
+ else:
101
+ activation_function = get_base_activation(name)
102
+
103
+ return activation_function
104
+
105
+
106
+ ##################################################
107
+ # normalization functions
108
+
109
+
110
+ class RMSNorm(nn.Module):
111
+ def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
112
+ super().__init__()
113
+
114
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
115
+ self.eps = eps
116
+
117
+ if isinstance(normalized_shape, numbers.Integral):
118
+ normalized_shape = (normalized_shape,)
119
+ self.normalized_shape = normalized_shape
120
+
121
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
122
+ input_dtype = input.dtype
123
+
124
+ input = input.to(torch.float32)
125
+ variance = input.pow(2).mean(-1, keepdim=True)
126
+ input = input * torch.rsqrt(variance + self.eps)
127
+
128
+ return self.weight * input.to(input_dtype)
129
+
130
+ def extra_repr(self) -> str:
131
+ return f"{self.normalized_shape}, eps={self.eps}"
132
+
133
+ def reset_parameters(self) -> None:
134
+ nn.init.ones_(self.weight)
135
+
136
+
137
+ _NORMALIZATION_FUNCTIONS = {
138
+ "layernorm": nn.LayerNorm,
139
+ "rmsnorm": RMSNorm,
140
+ }
141
+
142
+
143
+ def get_normalization_function(name: str, normalized_shape: int, eps: float = 1e-5) -> nn.Module:
144
+ if name in _NORMALIZATION_FUNCTIONS:
145
+ return _NORMALIZATION_FUNCTIONS[name](normalized_shape, eps=eps)
146
+
147
+ raise ValueError(f"unexpected `normalization_function` {name}")
148
+
149
+
150
+ ##################################################
151
+ # attention modules
152
+
153
+
154
+ class GraniteAttention(nn.Module):
155
+ def __init__(self, config: GraniteConfig, causal: bool, layer_idx: Optional[int] = None) -> None:
156
+ super().__init__()
157
+
158
+ self.causal = causal
159
+ self.hidden_size = config.n_embd
160
+ self.num_heads = config.n_head
161
+ self.num_key_value_heads = config.num_key_value_heads
162
+ self.add_bias = config.add_bias
163
+
164
+ assert (
165
+ self.hidden_size % self.num_heads == 0
166
+ ), f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})"
167
+
168
+ self.head_dim = self.hidden_size // self.num_heads
169
+ self.attention_head_type = AttentionHeadType(config.attention_head_type)
170
+
171
+ self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
172
+ self.scale_attn_weights = config.scale_attn_weights
173
+ self.attention_multiplier = config.attention_multiplier
174
+
175
+ self.layer_idx = layer_idx
176
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
177
+ self.scale_attention_softmax_in_fp32 = (
178
+ config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
179
+ )
180
+
181
+ if self.attention_head_type == AttentionHeadType.mha:
182
+ if self.num_key_value_heads is None:
183
+ self.num_key_value_heads = self.num_heads
184
+
185
+ assert (
186
+ self.num_heads == self.num_key_value_heads
187
+ ), f"{self.__class__.__name__} should have same number of heads for query, keys and values"
188
+ elif self.attention_head_type == AttentionHeadType.gqa:
189
+ assert (
190
+ self.num_key_value_heads is not None
191
+ ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"
192
+
193
+ assert self.num_heads % self.num_key_value_heads == 0, (
194
+ f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` "
195
+ f"({self.num_key_value_heads})"
196
+ )
197
+ elif self.attention_head_type == AttentionHeadType.mqa:
198
+ if self.num_key_value_heads is None:
199
+ self.num_key_value_heads = 1
200
+
201
+ assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values"
202
+ else:
203
+ raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")
204
+
205
+ # note that the actual layout is different for the output and depends on whether we are using MHA, MQA or GQA
206
+ # (self.hidden_size + 2 * self.num_key_value_heads * self.head_dim) is just the actual number output features
207
+ self.c_attn = nn.Linear(
208
+ self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=self.add_bias
209
+ )
210
+ self.c_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.add_bias)
211
+
212
+ self.attn_pdrop = config.attn_pdrop
213
+ self.resid_pdrop = config.resid_pdrop
214
+
215
+ self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop)
216
+ self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop)
217
+
218
+ def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
219
+ # ==========================================================================================
220
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
221
+ # ==========================================================================================
222
+
223
+ # the output of following is a tuple if using MQA with tensor parallel
224
+ hidden_states = self.c_attn(hidden_states)
225
+
226
+ # ==========================================================================================
227
+ # hidden_states -> (batch_size, query_length, [num_heads + num_key_value_heads * 2] * head_dim)
228
+ # ==========================================================================================
229
+
230
+ # for MHA, we can get away with doing just 1 transpose which is not true for GQA
231
+ if self.attention_head_type == AttentionHeadType.mha:
232
+ query, key, value = self._prepare_qkv_for_forward_mha(hidden_states)
233
+ elif self.attention_head_type == AttentionHeadType.gqa:
234
+ query, key, value = self._prepare_qkv_for_forward_gqa(hidden_states)
235
+ elif self.attention_head_type == AttentionHeadType.mqa:
236
+ query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states)
237
+ else:
238
+ raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")
239
+
240
+ # ==========================================================================================
241
+ # query -> (batch_size, num_heads, query_length, head_dim)
242
+ # key -> (batch_size, num_key_value_heads, query_length, head_dim)
243
+ # value -> (batch_size, num_key_value_heads, query_length, head_dim)
244
+ # ==========================================================================================
245
+
246
+ return query, key, value
247
+
248
+ def _prepare_qkv_for_forward_mha(
249
+ self, hidden_states: torch.Tensor
250
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
251
+ batch_size, query_length = hidden_states.shape[:-1]
252
+
253
+ hidden_states = hidden_states.view(batch_size, query_length, self.num_heads, -1)
254
+ hidden_states = hidden_states.transpose(1, 2)
255
+
256
+ query, key, value = hidden_states.chunk(3, dim=-1)
257
+
258
+ return query, key, value
259
+
260
+ def _prepare_qkv_for_forward_gqa(
261
+ self, hidden_states: torch.Tensor
262
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
263
+ batch_size, query_length = hidden_states.shape[:-1]
264
+
265
+ hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1)
266
+
267
+ query, key, value = hidden_states.split(
268
+ ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1
269
+ )
270
+
271
+ # this needs to be a reshape instead of view sadly
272
+ query = query.reshape(batch_size, query_length, -1, self.head_dim)
273
+
274
+ query = query.transpose(1, 2)
275
+ key = key.transpose(1, 2)
276
+ value = value.transpose(1, 2)
277
+
278
+ return query, key, value
279
+
280
+ def _prepare_qkv_for_forward_mqa(
281
+ self, hidden_states: torch.Tensor
282
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
283
+ batch_size, query_length = hidden_states.shape[:-1]
284
+
285
+ query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1)
286
+
287
+ query = query.view(batch_size, query_length, self.num_heads, -1)
288
+
289
+ query = query.transpose(1, 2)
290
+ key = key.unsqueeze(1)
291
+ value = value.unsqueeze(1)
292
+
293
+ return query, key, value
294
+
295
+ def forward(
296
+ self,
297
+ hidden_states: torch.Tensor,
298
+ past_key_values: Optional[DynamicCache] = None,
299
+ attention_mask: Optional[torch.Tensor] = None,
300
+ rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
301
+ ) -> torch.Tensor:
302
+ # ==========================================================================================
303
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
304
+ # ==========================================================================================
305
+
306
+ query, key, value = self._prepare_qkv_for_forward(hidden_states)
307
+
308
+ # ==========================================================================================
309
+ # query -> (batch_size, num_heads, query_length, head_dim)
310
+ # key -> (batch_size, num_key_value_heads, query_length, head_dim)
311
+ # value -> (batch_size, num_key_value_heads, query_length, head_dim)
312
+ # ==========================================================================================
313
+
314
+ if self.position_embedding_type == PositionEmbeddingType.rope:
315
+ query = apply_rotary_pos_emb(query, rope_cos_sin)
316
+ key = apply_rotary_pos_emb(key, rope_cos_sin)
317
+
318
+ if past_key_values is not None:
319
+ key, value = past_key_values.update(key, value, self.layer_idx)
320
+
321
+ # ==========================================================================================
322
+ # query -> (batch_size, num_heads, query_length, head_dim)
323
+ # key -> (batch_size, num_key_value_heads, key_length, head_dim)
324
+ # value -> (batch_size, num_key_value_heads, key_length, head_dim)
325
+ # ==========================================================================================
326
+
327
+ key = key.transpose(-1, -2)
328
+
329
+ dtype = query.dtype
330
+ softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
331
+
332
+ if self.scale_attn_weights:
333
+ if self.attention_multiplier is None:
334
+ scale_factor = 1 / self.head_dim**0.5
335
+ else:
336
+ scale_factor = self.attention_multiplier
337
+ else:
338
+ scale_factor = 1
339
+
340
+ # ==========================================================================================
341
+ # query -> (batch_size, num_heads, query_length, head_dim)
342
+ # key -> (batch_size, num_key_value_heads, head_dim, key_length)
343
+ # value -> (batch_size, num_key_value_heads, key_length, head_dim)
344
+ # ==========================================================================================
345
+
346
+ batch_size = query.shape[0]
347
+ query_length = query.shape[2]
348
+ key_length = key.shape[-1]
349
+
350
+ key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
351
+ value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)
352
+
353
+ # Always copies
354
+ query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
355
+ # No copy when layer_past is provided.
356
+ key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
357
+
358
+ # ==========================================================================================
359
+ # query -> (batch_size * num_heads, query_length, head_dim)
360
+ # key -> (batch_size * num_heads, head_dim, key_length)
361
+ # value -> (batch_size, num_heads, key_length, head_dim)
362
+ # ==========================================================================================
363
+
364
+ attn_weights = torch.empty(
365
+ (batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype
366
+ )
367
+
368
+ attn_weights = torch.baddbmm(attn_weights, query, key, beta=0, alpha=scale_factor).view(
369
+ batch_size, self.num_heads, query_length, key_length
370
+ )
371
+
372
+ # ==========================================================================================
373
+ # attn_weights -> (batch_size, num_heads, query_length, key_length)
374
+ # ==========================================================================================
375
+
376
+ attn_weights = attn_weights.to(softmax_dtype)
377
+
378
+ if attention_mask is not None:
379
+ attn_weights = attn_weights + attention_mask
380
+
381
+ attn_weights = F.softmax(attn_weights, dim=-1).to(dtype)
382
+
383
+ attn_weights = self.attn_dropout(attn_weights)
384
+
385
+ # ==========================================================================================
386
+ # value -> (batch_size, num_heads, key_length, head_dim)
387
+ # attn_weights -> (batch_size, num_heads, query_length, key_length)
388
+ # ==========================================================================================
389
+
390
+ attn_output = torch.matmul(attn_weights, value)
391
+
392
+ # ==========================================================================================
393
+ # attn_output -> (batch_size, num_heads, query_length, head_dim)
394
+ # ==========================================================================================
395
+
396
+ attn_output = attn_output.transpose(1, 2)
397
+ attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
398
+
399
+ # ==========================================================================================
400
+ # attn_output -> (batch_size, query_length, num_heads * head_dim)
401
+ # ==========================================================================================
402
+
403
+ attn_output = self.c_proj(attn_output)
404
+ attn_output = self.resid_dropout(attn_output)
405
+
406
+ return attn_output
407
+
408
+
409
+ class GraniteSDPA(GraniteAttention):
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.Tensor,
413
+ past_key_values: Optional[DynamicCache] = None,
414
+ attention_mask: Optional[torch.Tensor] = None,
415
+ rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
416
+ ) -> torch.Tensor:
417
+ # ==========================================================================================
418
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
419
+ # ==========================================================================================
420
+
421
+ query, key, value = self._prepare_qkv_for_forward(hidden_states)
422
+
423
+ # ==========================================================================================
424
+ # query -> (batch_size, num_heads, query_length, head_dim)
425
+ # key -> (batch_size, num_key_value_heads, query_length, head_dim)
426
+ # value -> (batch_size, num_key_value_heads, query_length, head_dim)
427
+ # ==========================================================================================
428
+
429
+ if self.position_embedding_type == PositionEmbeddingType.rope:
430
+ query = apply_rotary_pos_emb(query, rope_cos_sin)
431
+ key = apply_rotary_pos_emb(key, rope_cos_sin)
432
+
433
+ if past_key_values is not None:
434
+ key, value = past_key_values.update(key, value, self.layer_idx)
435
+
436
+ # ==========================================================================================
437
+ # query -> (batch_size, num_heads, query_length, head_dim)
438
+ # key -> (batch_size, num_key_value_heads, key_length, head_dim)
439
+ # value -> (batch_size, num_key_value_heads, key_length, head_dim)
440
+ # ==========================================================================================
441
+
442
+ key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
443
+ value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)
444
+
445
+ # ==========================================================================================
446
+ # query -> (batch_size, num_heads, query_length, head_dim)
447
+ # key -> (batch_size, num_heads, key_length, head_dim)
448
+ # value -> (batch_size, num_heads, key_length, head_dim)
449
+ # ==========================================================================================
450
+
451
+ attn_output = F.scaled_dot_product_attention(
452
+ query,
453
+ key,
454
+ value,
455
+ attn_mask=attention_mask,
456
+ dropout_p=self.attn_pdrop if self.training else 0,
457
+ is_causal=self.causal if attention_mask is None else False,
458
+ scale=self.attention_multiplier if self.scale_attn_weights else 1,
459
+ )
460
+
461
+ # ==========================================================================================
462
+ # attn_output -> (batch_size, num_heads, query_length, head_dim)
463
+ # ==========================================================================================
464
+
465
+ batch_size = attn_output.shape[0]
466
+ attn_output = attn_output.transpose(1, 2)
467
+ attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
468
+
469
+ # ==========================================================================================
470
+ # attn_output -> (batch_size, query_length, num_heads * head_dim)
471
+ # ==========================================================================================
472
+
473
+ attn_output = self.c_proj(attn_output)
474
+ attn_output = self.resid_dropout(attn_output)
475
+
476
+ return attn_output
477
+
478
+
479
+ class GraniteFlashAttention2(GraniteAttention):
480
+ def forward(
481
+ self,
482
+ hidden_states: torch.Tensor,
483
+ past_key_values: Optional[DynamicCache] = None,
484
+ attention_mask: Optional[torch.Tensor] = None,
485
+ rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
486
+ ) -> torch.Tensor:
487
+ # ==========================================================================================
488
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
489
+ # ==========================================================================================
490
+
491
+ query, key, value = self._prepare_qkv_for_forward(hidden_states)
492
+
493
+ # ==========================================================================================
494
+ # query -> (batch_size, num_heads, query_length, head_dim)
495
+ # key -> (batch_size, num_key_value_heads, query_length, head_dim)
496
+ # value -> (batch_size, num_key_value_heads, query_length, head_dim)
497
+ # ==========================================================================================
498
+
499
+ if self.position_embedding_type == PositionEmbeddingType.rope:
500
+ query = apply_rotary_pos_emb(query, rope_cos_sin)
501
+ key = apply_rotary_pos_emb(key, rope_cos_sin)
502
+
503
+ if past_key_values is not None:
504
+ key, value = past_key_values.update(key, value, self.layer_idx)
505
+
506
+ # ==========================================================================================
507
+ # query -> (batch_size, num_heads, query_length, head_dim)
508
+ # key -> (batch_size, num_key_value_heads, key_length, head_dim)
509
+ # value -> (batch_size, num_key_value_heads, key_length, head_dim)
510
+ # ==========================================================================================
511
+
512
+ # TODO avoid this extra transpose
513
+ query = query.transpose(1, 2)
514
+ if self.attention_head_type == AttentionHeadType.mqa:
515
+ key = key.squeeze(1).unsqueeze(2)
516
+ value = value.squeeze(1).unsqueeze(2)
517
+ else:
518
+ key = key.transpose(1, 2)
519
+ value = value.transpose(1, 2)
520
+
521
+ # ==========================================================================================
522
+ # query -> (batch_size, query_length, num_heads, head_dim)
523
+ # key -> (batch_size, key_length, num_heads, head_dim)
524
+ # value -> (batch_size, key_length, num_heads, head_dim)
525
+ # ==========================================================================================
526
+
527
+ batch_size, query_length = query.shape[:2]
528
+ key_length = key.shape[1]
529
+ indices_k, cu_seqlens_k, max_seqlen_k = get_unpad_data(attention_mask)
530
+
531
+ key = IndexFirstAxis.apply(
532
+ key.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
533
+ )
534
+ value = IndexFirstAxis.apply(
535
+ value.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
536
+ )
537
+
538
+ if query_length == key_length:
539
+ query = IndexFirstAxis.apply(
540
+ query.reshape(batch_size * key_length, self.num_heads, self.head_dim), indices_k
541
+ )
542
+ cu_seqlens_q = cu_seqlens_k
543
+ max_seqlen_q = max_seqlen_k
544
+ indices_q = indices_k
545
+ elif query_length == 1:
546
+ max_seqlen_q = 1
547
+ cu_seqlens_q = torch.arange(
548
+ batch_size + 1, dtype=torch.int32, device=query.device
549
+ ) # There is a memcpy here, that is very bad.
550
+ indices_q = cu_seqlens_q[:-1]
551
+ query = query.squeeze(1)
552
+ else:
553
+ # The -q_len: slice assumes left padding.
554
+ attention_mask = attention_mask[:, -query_length:]
555
+ query, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query, attention_mask)
556
+
557
+ # ==========================================================================================
558
+ # query -> (total_q, num_heads, head_dim)
559
+ # key -> (total_q, num_heads, head_dim)
560
+ # value -> (total_q, num_heads, head_dim)
561
+ # ==========================================================================================
562
+
563
+ attn_output = flash_attn_varlen_func(
564
+ query,
565
+ key,
566
+ value,
567
+ cu_seqlens_q=cu_seqlens_q,
568
+ cu_seqlens_k=cu_seqlens_k,
569
+ max_seqlen_q=max_seqlen_q,
570
+ max_seqlen_k=max_seqlen_k,
571
+ dropout_p=self.attn_pdrop if self.training else 0,
572
+ softmax_scale=self.attention_multiplier if self.scale_attn_weights else 1,
573
+ causal=self.causal,
574
+ )
575
+
576
+ # ==========================================================================================
577
+ # attn_output -> (total_q, num_heads, head_dim)
578
+ # ==========================================================================================
579
+
580
+ attn_output = pad_input(attn_output, indices_q, batch_size, query_length)
581
+ attn_output = attn_output.view(batch_size, query_length, -1)
582
+
583
+ # ==========================================================================================
584
+ # attn_output -> (batch_size, query_length, num_heads * head_dim)
585
+ # ==========================================================================================
586
+
587
+ attn_output = self.c_proj(attn_output)
588
+ attn_output = self.resid_dropout(attn_output)
589
+
590
+ return attn_output
591
+
592
+
593
+ _ATTENTION_MODULES = {
594
+ "eager": GraniteAttention,
595
+ "sdpa": GraniteSDPA,
596
+ "flash_attention_2": GraniteFlashAttention2,
597
+ }
598
+
599
+
600
+ def get_attention_module(
601
+ config: GraniteConfig, causal: bool, attention_implementation: str, layer_idx: int
602
+ ) -> GraniteAttention:
603
+ if attention_implementation in _ATTENTION_MODULES:
604
+ return _ATTENTION_MODULES[attention_implementation](config, causal=causal, layer_idx=layer_idx)
605
+ raise ValueError(f"unexpected `attention_implementation` {attention_implementation}")
606
+
607
+
608
+ ##################################################
609
+ # position embeddings
610
+
611
+
612
+ class Alibi(nn.Module):
613
+ def __init__(self, num_heads: int) -> None:
614
+ super().__init__()
615
+ self.num_heads = num_heads
616
+
617
+ self.reset_parameters()
618
+
619
+ def forward(
620
+ self, attention_mask: torch.Tensor, batch_size: int, key_length: int, device: torch.device, dtype: torch.dtype
621
+ ) -> torch.Tensor:
622
+ """
623
+ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
624
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
625
+ `softmax(l+a) = softmax(l)`. Based on
626
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
627
+ TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
628
+
629
+ Args:
630
+ attention_mask (torch.Tensor): attention_mask tensor of shape (`batch_size`, `key_length`)
631
+ num_heads (int): `num_heads` for the model
632
+ batch_size (int): `batch_size`
633
+ key_length (int): `key_length`
634
+ device (torch.device): device for the tensors
635
+ dtype (torch.dtype): dtype to use for the tensors
636
+
637
+ Returns:
638
+ torch.Tensor: alibi tensor of shape (`batch_size`, `num_heads`, `key_length`)
639
+ """
640
+
641
+ # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
642
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
643
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
644
+ # => the query_length dimension will then be broadcasted correctly
645
+ # This is more or less identical to T5's relative position bias:
646
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
647
+ if attention_mask is None:
648
+ arange_tensor = (
649
+ torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1)
650
+ )
651
+ else:
652
+ arange_tensor = (attention_mask.cumsum(dim=-1) - 1).masked_fill_(attention_mask == 0, 0).unsqueeze(1)
653
+
654
+ alibi = self.slopes.unsqueeze(1) * arange_tensor
655
+ return alibi.to(dtype)
656
+
657
+ def reset_parameters(self) -> None:
658
+ closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads))
659
+ base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)
660
+ powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
661
+ slopes = torch.pow(base, powers)
662
+
663
+ if closest_power_of_2 != self.num_heads:
664
+ extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32)
665
+ num_remaining_heads = min(closest_power_of_2, self.num_heads - closest_power_of_2)
666
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32)
667
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
668
+
669
+ self.register_buffer("slopes", slopes, persistent=False)
670
+
671
+
672
+ class RoPE(nn.Module):
673
+ def __init__(
674
+ self,
675
+ head_dim: int,
676
+ max_position_embeddings: int = 2048,
677
+ base: int = 10000,
678
+ ) -> None:
679
+ super().__init__()
680
+
681
+ self.head_dim = head_dim
682
+ self.max_position_embeddings = max_position_embeddings
683
+ self.base = base
684
+ self.mscale = 1
685
+
686
+ self.reset_parameters()
687
+
688
+ def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
689
+ if seq_len > self.max_seq_len_cached:
690
+ self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)
691
+
692
+ cos = self.cos_cached[:seq_len].to(dtype)
693
+ sin = self.sin_cached[:seq_len].to(dtype)
694
+
695
+ return cos, sin
696
+
697
+ def reset_parameters(self) -> None:
698
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
699
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
700
+
701
+ # Build here to make `torch.jit.trace` work.
702
+ self._set_cos_sin_cache(
703
+ seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
704
+ )
705
+
706
+ @torch.no_grad()
707
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
708
+ self.max_seq_len_cached = seq_len
709
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
710
+
711
+ freqs = torch.outer(t, self.inv_freq)
712
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
713
+ emb = torch.cat((freqs, freqs), dim=-1)
714
+
715
+ self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
716
+ self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
717
+
718
+
719
+ def apply_rotary_pos_emb(x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
720
+ cos, sin = cos_sin
721
+ x = (x * cos) + (_rotate_half(x) * sin)
722
+ return x
723
+
724
+
725
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
726
+ x1, x2 = torch.chunk(x, 2, dim=-1)
727
+ return torch.cat((-x2, x1), dim=-1)
728
+
729
+
730
+ ##################################################
731
+ # MLP
732
+
733
+
734
+ class GraniteMLP(nn.Module):
735
+ def __init__(self, config: GraniteConfig) -> None:
736
+ super().__init__()
737
+
738
+ hidden_size = config.n_embd
739
+ intermediate_size = config.n_inner
740
+ activation_function = config.activation_function
741
+ add_bias = config.add_bias
742
+ residual_dropout = config.resid_pdrop
743
+
744
+ self.c_fc = nn.Linear(
745
+ hidden_size,
746
+ 2 * intermediate_size if is_glu(activation_function) else intermediate_size,
747
+ bias=add_bias,
748
+ )
749
+ self.act = get_activation_function(activation_function)
750
+ self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=add_bias)
751
+ self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout)
752
+
753
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
754
+ hidden_states = self.c_fc(hidden_states)
755
+ hidden_states = self.act(hidden_states)
756
+ hidden_states = self.c_proj(hidden_states)
757
+ hidden_states = self.dropout(hidden_states)
758
+ return hidden_states
759
+
760
+
761
+ ##################################################
762
+ # transformer layer
763
+
764
+
765
+ class GraniteBlock(nn.Module):
766
+ def __init__(
767
+ self,
768
+ config: GraniteConfig,
769
+ attention_implementation: str,
770
+ layer_idx: Optional[int] = None,
771
+ ) -> None:
772
+ super().__init__()
773
+
774
+ hidden_size = config.hidden_size
775
+ self.inner_dim = config.n_inner
776
+ self.layer_idx = layer_idx
777
+
778
+ self.ln_1 = get_normalization_function(
779
+ config.normalization_function,
780
+ hidden_size,
781
+ eps=config.layer_norm_epsilon,
782
+ )
783
+ self.attn = get_attention_module(config, True, attention_implementation, layer_idx)
784
+ self.ln_2 = get_normalization_function(
785
+ config.normalization_function,
786
+ hidden_size,
787
+ eps=config.layer_norm_epsilon,
788
+ )
789
+ self.mlp = GraniteMLP(config)
790
+
791
+ def forward(
792
+ self,
793
+ hidden_states: torch.Tensor,
794
+ past_key_values: Optional[DynamicCache] = None,
795
+ attention_mask: Optional[torch.Tensor] = None,
796
+ rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
797
+ ) -> torch.Tensor:
798
+ residual = hidden_states
799
+ hidden_states = self.ln_1(hidden_states)
800
+
801
+ attn_output = self.attn(
802
+ hidden_states,
803
+ past_key_values=past_key_values,
804
+ attention_mask=attention_mask,
805
+ rope_cos_sin=rope_cos_sin,
806
+ )
807
+
808
+ # residual connection
809
+ hidden_states = attn_output + residual
810
+
811
+ residual = hidden_states
812
+ hidden_states = self.ln_2(hidden_states)
813
+
814
+ feed_forward_hidden_states = self.mlp(hidden_states)
815
+
816
+ # residual connection
817
+ hidden_states = residual + feed_forward_hidden_states
818
+
819
+ return hidden_states
820
+
821
+
822
+ ##################################################
823
+ # model classes
824
+
825
+
826
+ class GranitePreTrainedModel(PreTrainedModel):
827
+ """
828
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
829
+ models.
830
+ """
831
+
832
+ config_class = GraniteConfig
833
+ base_model_prefix = "transformer"
834
+ causal = True
835
+ _no_split_modules = ["GraniteBlock"]
836
+ _skip_keys_device_placement = "past_key_values"
837
+ _supports_sdpa = True
838
+ _supports_flash_attn_2 = True
839
+
840
+ def __init__(self, config: GraniteConfig, *inputs, **kwargs):
841
+ super().__init__(config, *inputs, **kwargs)
842
+
843
+ self.attention_implementation = self.config._attn_implementation
844
+ self._use_eager_attention = self.attention_implementation == "eager"
845
+ self._use_sdpa = self.attention_implementation == "sdpa"
846
+ self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2"
847
+
848
+ self.initializer_range = config.initializer_range
849
+
850
+ def _init_weights(self, module: nn.Module) -> None:
851
+ if isinstance(module, (nn.LayerNorm, RMSNorm, Alibi, RoPE)):
852
+ module.reset_parameters()
853
+ elif isinstance(module, nn.Linear):
854
+ nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
855
+ if module.bias is not None:
856
+ module.bias.zero_()
857
+ elif isinstance(module, nn.Embedding):
858
+ nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
859
+ if module.padding_idx is not None:
860
+ module.weight[module.padding_idx].zero_()
861
+
862
+
863
+ class GraniteModel(GranitePreTrainedModel):
864
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
865
+ mask_value = None
866
+
867
+ def __init__(self, config: GraniteConfig, **kwargs) -> None:
868
+ super().__init__(config, **kwargs)
869
+
870
+ self.attention_head_type = AttentionHeadType(config.attention_head_type)
871
+ self.embed_dim = config.hidden_size
872
+ self.num_heads = config.num_attention_heads
873
+ self.num_key_value_heads = config.num_key_value_heads
874
+
875
+ assert (
876
+ self.embed_dim % self.num_heads == 0
877
+ ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})"
878
+
879
+ self.head_dim = self.embed_dim // self.num_heads
880
+
881
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
882
+
883
+ self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop)
884
+ self.h = nn.ModuleList(
885
+ [GraniteBlock(config, self.attention_implementation, layer_idx=i) for i in range(config.num_hidden_layers)]
886
+ )
887
+ self.ln_f = get_normalization_function(
888
+ config.normalization_function,
889
+ self.embed_dim,
890
+ eps=config.layer_norm_epsilon,
891
+ )
892
+
893
+ self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
894
+
895
+ if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
896
+ self.wpe = nn.Embedding(config.n_positions, self.embed_dim)
897
+ elif self.position_embedding_type == PositionEmbeddingType.alibi:
898
+ assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention"
899
+
900
+ self.alibi = Alibi(self.num_heads)
901
+ elif self.position_embedding_type == PositionEmbeddingType.rope:
902
+ self.rope = RoPE(self.head_dim, max_position_embeddings=config.n_positions, base=config.rope_theta)
903
+ else:
904
+ raise NotImplementedError()
905
+
906
+ # Initialize weights and apply final processing
907
+ self.post_init()
908
+
909
+ def get_input_embeddings(self) -> nn.Embedding:
910
+ return self.wte
911
+
912
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
913
+ self.wte = new_embeddings
914
+
915
+ def forward(
916
+ self,
917
+ input_ids: Optional[torch.Tensor] = None,
918
+ past_key_values: Optional[DynamicCache] = None,
919
+ attention_mask: Optional[torch.Tensor] = None,
920
+ token_type_ids: Optional[torch.Tensor] = None,
921
+ position_ids: Optional[torch.Tensor] = None,
922
+ inputs_embeds: Optional[torch.Tensor] = None,
923
+ use_cache: Optional[bool] = None,
924
+ output_hidden_states: Optional[bool] = None,
925
+ return_dict: Optional[bool] = None,
926
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
927
+ (
928
+ output_hidden_states,
929
+ use_cache,
930
+ return_dict,
931
+ input_shape,
932
+ hidden_states,
933
+ attention_mask,
934
+ position_ids,
935
+ rope_cos_sin,
936
+ past_key_values,
937
+ ) = self._prepare_a_bunch_of_stuff(
938
+ input_ids=input_ids,
939
+ past_key_values=past_key_values,
940
+ attention_mask=attention_mask,
941
+ token_type_ids=token_type_ids,
942
+ position_ids=position_ids,
943
+ inputs_embeds=inputs_embeds,
944
+ use_cache=use_cache,
945
+ output_hidden_states=output_hidden_states,
946
+ return_dict=return_dict,
947
+ )
948
+
949
+ # ==========================================================================================
950
+ # flash:
951
+ # attention_mask -> (batch_size, key_length)
952
+ # else:
953
+ # attention_mask -> (batch_size, 1, query_length, key_length)
954
+ # ==========================================================================================
955
+
956
+ output_shape = input_shape + (hidden_states.size(-1),)
957
+
958
+ past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values
959
+ all_hidden_states = () if output_hidden_states else None
960
+ for block in self.h:
961
+ if output_hidden_states:
962
+ all_hidden_states += (hidden_states,)
963
+
964
+ hidden_states = block(
965
+ hidden_states,
966
+ past_key_values=past_key_values,
967
+ attention_mask=attention_mask,
968
+ rope_cos_sin=rope_cos_sin,
969
+ )
970
+
971
+ hidden_states = self.ln_f(hidden_states)
972
+
973
+ hidden_states = hidden_states.view(output_shape)
974
+ # Add last hidden state
975
+ if output_hidden_states:
976
+ all_hidden_states += (hidden_states,)
977
+
978
+ if not return_dict:
979
+ return tuple(v for v in [hidden_states, past_key_values, all_hidden_states] if v is not None)
980
+
981
+ return BaseModelOutputWithPastAndCrossAttentions(
982
+ last_hidden_state=hidden_states,
983
+ past_key_values=past_key_values,
984
+ hidden_states=all_hidden_states,
985
+ )
986
+
987
+ def _get_position_ids(
988
+ self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device
989
+ ) -> torch.Tensor:
990
+ if attention_mask is not None and len(attention_mask.shape) == 2:
991
+ # create position_ids on the fly for batch generation
992
+ position_ids = attention_mask.long().cumsum(-1) - 1
993
+ position_ids.masked_fill_(attention_mask == 0, 0)
994
+ if past_length > 0:
995
+ position_ids = position_ids[:, past_length:key_length:]
996
+ else:
997
+ position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device)
998
+ position_ids = position_ids.unsqueeze(0).view(-1, query_length)
999
+
1000
+ return position_ids
1001
+
1002
+ def _get_alibi_bias(
1003
+ self,
1004
+ attention_mask: torch.Tensor,
1005
+ batch_size: int,
1006
+ query_length: int,
1007
+ key_length: int,
1008
+ device: torch.device,
1009
+ dtype: torch.dtype,
1010
+ ) -> torch.Tensor:
1011
+ if self.position_embedding_type != PositionEmbeddingType.alibi:
1012
+ return None
1013
+
1014
+ alibi_bias = self.alibi(attention_mask, batch_size, key_length, device, dtype)
1015
+
1016
+ # ==========================================================================================
1017
+ # alibi_bias -> (batch_size, num_heads, key_length)
1018
+ # ==========================================================================================
1019
+
1020
+ alibi_bias = alibi_bias.unsqueeze(2)
1021
+ if query_length != 1:
1022
+ alibi_bias = alibi_bias.expand(-1, -1, query_length, -1)
1023
+
1024
+ # ==========================================================================================
1025
+ # alibi_bias -> (batch_size, num_heads, query_length, key_length)
1026
+ # ==========================================================================================
1027
+
1028
+ return alibi_bias
1029
+
1030
+ def _get_rope_cos_sin(
1031
+ self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device
1032
+ ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
1033
+ if self.position_embedding_type == PositionEmbeddingType.rope:
1034
+ cos, sin = self.rope(key_length, dtype=dtype, device=device)
1035
+ cos = cos[position_ids].unsqueeze(1)
1036
+ sin = sin[position_ids].unsqueeze(1)
1037
+ return cos, sin
1038
+
1039
+ def _prepare_causal_attention_mask(
1040
+ self, attention_mask: torch.Tensor, batch_size: int, query_length: int, key_length: int, device: torch.device
1041
+ ) -> torch.Tensor:
1042
+ past_length = key_length - query_length
1043
+
1044
+ # ==========================================================================================
1045
+ # attention_mask -> (batch_size, key_length)
1046
+ # ==========================================================================================
1047
+
1048
+ if query_length > 1:
1049
+ # (query_length, key_length)
1050
+ causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device)
1051
+ causal_mask[:, past_length:] = torch.tril(
1052
+ torch.ones(query_length, query_length, dtype=torch.bool, device=device)
1053
+ )
1054
+
1055
+ if past_length > 0:
1056
+ causal_mask[:, :past_length] = True
1057
+
1058
+ # (query_length, key_length) -> (1, query_length, key_length)
1059
+ causal_mask = causal_mask.unsqueeze(0)
1060
+
1061
+ if attention_mask is None:
1062
+ # (1, query_length, key_length) -> (batch_size, query_length, key_length)
1063
+ causal_mask = causal_mask.expand(batch_size, -1, -1)
1064
+ else:
1065
+ # (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length)
1066
+ causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool)
1067
+ else:
1068
+ if attention_mask is None:
1069
+ # (batch_size, query_length, key_length)
1070
+ causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device)
1071
+ else:
1072
+ # (batch_size, query_length, key_length)
1073
+ causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device)
1074
+
1075
+ # ==========================================================================================
1076
+ # attention_mask -> (batch_size, query_length, key_length)
1077
+ # ==========================================================================================
1078
+
1079
+ causal_mask = causal_mask.unsqueeze(1)
1080
+
1081
+ # ==========================================================================================
1082
+ # attention_mask -> (batch_size, 1, query_length, key_length)
1083
+ # ==========================================================================================
1084
+
1085
+ return causal_mask
1086
+
1087
+ def _get_initial_hidden_state(
1088
+ self,
1089
+ input_ids: torch.Tensor,
1090
+ inputs_embeds: torch.Tensor,
1091
+ position_ids: torch.Tensor,
1092
+ token_type_ids: torch.Tensor,
1093
+ ) -> torch.Tensor:
1094
+ if inputs_embeds is None:
1095
+ inputs_embeds = self.wte(input_ids)
1096
+
1097
+ if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
1098
+ inputs_embeds = inputs_embeds + self.wpe(position_ids)
1099
+
1100
+ if token_type_ids is not None:
1101
+ inputs_embeds = inputs_embeds + self.wte(token_type_ids)
1102
+
1103
+ inputs_embeds = self.drop(inputs_embeds)
1104
+
1105
+ return inputs_embeds
1106
+
1107
+ def _prepare_a_bunch_of_stuff(
1108
+ self,
1109
+ input_ids: torch.Tensor,
1110
+ past_key_values: DynamicCache,
1111
+ attention_mask: torch.Tensor,
1112
+ token_type_ids: torch.Tensor,
1113
+ position_ids: torch.Tensor,
1114
+ inputs_embeds: torch.Tensor,
1115
+ use_cache: bool,
1116
+ output_hidden_states: bool,
1117
+ return_dict: bool,
1118
+ ) -> Tuple[
1119
+ bool,
1120
+ bool,
1121
+ bool,
1122
+ torch.Size,
1123
+ torch.Tensor,
1124
+ torch.Tensor,
1125
+ torch.Tensor,
1126
+ Optional[Tuple[torch.Tensor, torch.Tensor]],
1127
+ DynamicCache,
1128
+ ]:
1129
+ output_hidden_states = (
1130
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1131
+ )
1132
+
1133
+ use_cache = self.config.use_cache if use_cache is None else use_cache
1134
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1135
+
1136
+ if input_ids is not None and inputs_embeds is not None:
1137
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1138
+ elif input_ids is not None:
1139
+ input_shape = input_ids.size()
1140
+ elif inputs_embeds is not None:
1141
+ # TODO special handling for padding free transformer needed here if we support inputs_embeds argument
1142
+ input_shape = inputs_embeds.size()[:-1]
1143
+ else:
1144
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1145
+
1146
+ batch_size = input_shape[0]
1147
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1148
+
1149
+ if self.position_embedding_type == PositionEmbeddingType.alibi:
1150
+ if position_ids is not None:
1151
+ warnings.warn("`position_ids` have no functionality with Alibi.", FutureWarning)
1152
+
1153
+ if token_type_ids is not None:
1154
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
1155
+
1156
+ # ==========================================================================================
1157
+ # input_ids -> (batch_size, query_length)
1158
+ # attention_mask -> None or (batch_size, key_length)
1159
+ # position_ids -> None or (batch_size, key_length)
1160
+ # ==========================================================================================
1161
+
1162
+ past_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1163
+ query_length = input_shape[-1]
1164
+ key_length = past_length + query_length
1165
+
1166
+ if position_ids is None:
1167
+ position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device)
1168
+
1169
+ # ==========================================================================================
1170
+ # input_ids -> (batch_size, query_length)
1171
+ # attention_mask -> None or (batch_size, key_length)
1172
+ # position_ids -> (batch_size, query_length)
1173
+ # ==========================================================================================
1174
+
1175
+ hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids)
1176
+
1177
+ # ==========================================================================================
1178
+ # hidden_states -> (batch_size, query_length, num_heads * head_dim)
1179
+ # ==========================================================================================
1180
+
1181
+ alibi_bias = self._get_alibi_bias(
1182
+ attention_mask, batch_size, query_length, key_length, device, hidden_states.dtype
1183
+ )
1184
+
1185
+ # ==========================================================================================
1186
+ # alibi_bias -> (batch_size, num_heads, query_length, key_length)
1187
+ # ==========================================================================================
1188
+
1189
+ rope_cos_sin = self._get_rope_cos_sin(
1190
+ key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device
1191
+ )
1192
+
1193
+ # ==========================================================================================
1194
+ # rope_cos_sin -> 2 * (key_length, head_dim)
1195
+ # ==========================================================================================
1196
+
1197
+ # prepare causal mask only if not using flash attention
1198
+ if self._use_flash_attention_2:
1199
+ if attention_mask is None:
1200
+ attention_mask = torch.ones_like(input_ids)
1201
+ elif self._use_sdpa:
1202
+ # we use the causal/non-causal argument of SDPA for attention in this case
1203
+ if attention_mask is not None:
1204
+ attention_mask = self._prepare_causal_attention_mask(
1205
+ attention_mask, batch_size, query_length, key_length, device
1206
+ )
1207
+
1208
+ attention_mask = torch.where(
1209
+ attention_mask,
1210
+ ~attention_mask if alibi_bias is None else alibi_bias,
1211
+ self._get_mask_value(attention_mask.device, hidden_states.dtype),
1212
+ )
1213
+ else:
1214
+ attention_mask = self._prepare_causal_attention_mask(
1215
+ attention_mask, batch_size, query_length, key_length, device
1216
+ )
1217
+
1218
+ attention_mask = torch.where(
1219
+ attention_mask,
1220
+ ~attention_mask if alibi_bias is None else alibi_bias,
1221
+ self._get_mask_value(attention_mask.device, hidden_states.dtype),
1222
+ )
1223
+
1224
+ return (
1225
+ output_hidden_states,
1226
+ use_cache,
1227
+ return_dict,
1228
+ input_shape,
1229
+ hidden_states,
1230
+ attention_mask,
1231
+ position_ids,
1232
+ rope_cos_sin,
1233
+ past_key_values,
1234
+ )
1235
+
1236
+ def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
1237
+ # torch.where expects a tensor. We use a cache to avoid recreating it every time.
1238
+ if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
1239
+ self.mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=dtype, device=device)
1240
+ return self.mask_value
1241
+
1242
+
1243
+ class GraniteForCausalLM(GranitePreTrainedModel):
1244
+ _keys_to_ignore_on_load_missing = ["lm_head.weight"]
1245
+
1246
+ def __init__(self, config: GraniteConfig, **kwargs) -> None:
1247
+ super().__init__(config, **kwargs)
1248
+ self.transformer = GraniteModel(config, **kwargs)
1249
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1250
+
1251
+ # Initialize weights and apply final processing
1252
+ self.post_init()
1253
+
1254
+ def get_input_embeddings(self) -> nn.Embedding:
1255
+ return self.transformer.wte
1256
+
1257
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
1258
+ self.transformer.wte = value
1259
+
1260
+ def get_output_embeddings(self) -> nn.Linear:
1261
+ return self.lm_head
1262
+
1263
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1264
+ self.lm_head = new_embeddings
1265
+
1266
+ # FIXME typing
1267
+ def prepare_inputs_for_generation(
1268
+ self,
1269
+ input_ids: torch.Tensor,
1270
+ past_key_values: Optional[DynamicCache] = None,
1271
+ inputs_embeds: Optional[torch.Tensor] = None,
1272
+ **kwargs,
1273
+ ) -> dict:
1274
+ token_type_ids = kwargs.get("token_type_ids", None)
1275
+ # Omit tokens covered by past_key_values
1276
+ if past_key_values:
1277
+ past_length = past_key_values.get_seq_length()
1278
+
1279
+ # Some generation methods already pass only the last input ID
1280
+ if input_ids.shape[1] > past_length:
1281
+ remove_prefix_length = past_length
1282
+ else:
1283
+ # Default to old behavior: keep only final ID
1284
+ remove_prefix_length = input_ids.shape[1] - 1
1285
+
1286
+ input_ids = input_ids[:, remove_prefix_length:]
1287
+ if token_type_ids is not None:
1288
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1289
+
1290
+ attention_mask = kwargs.get("attention_mask", None)
1291
+ position_ids = kwargs.get("position_ids", None)
1292
+
1293
+ if attention_mask is not None and position_ids is None:
1294
+ # create position_ids on the fly for batch generation
1295
+ position_ids = attention_mask.long().cumsum(-1) - 1
1296
+ position_ids.masked_fill_(attention_mask == 0, 0)
1297
+ if past_key_values:
1298
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1299
+ else:
1300
+ position_ids = None
1301
+
1302
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1303
+ if inputs_embeds is not None and past_key_values is None:
1304
+ model_inputs = {"inputs_embeds": inputs_embeds}
1305
+ else:
1306
+ model_inputs = {"input_ids": input_ids}
1307
+
1308
+ model_inputs.update(
1309
+ {
1310
+ "past_key_values": past_key_values,
1311
+ "use_cache": kwargs.get("use_cache"),
1312
+ "position_ids": position_ids,
1313
+ "attention_mask": attention_mask,
1314
+ "token_type_ids": token_type_ids,
1315
+ }
1316
+ )
1317
+ return model_inputs
1318
+
1319
+ def forward(
1320
+ self,
1321
+ input_ids: Optional[Union[torch.Tensor]] = None,
1322
+ past_key_values: Optional[DynamicCache] = None,
1323
+ attention_mask: Optional[torch.Tensor] = None,
1324
+ token_type_ids: Optional[Union[torch.Tensor]] = None,
1325
+ position_ids: Optional[Union[torch.Tensor]] = None,
1326
+ inputs_embeds: Optional[Union[torch.Tensor]] = None,
1327
+ labels: Optional[Union[torch.Tensor]] = None,
1328
+ use_cache: Optional[bool] = None,
1329
+ output_attentions: Optional[bool] = None,
1330
+ output_hidden_states: Optional[bool] = None,
1331
+ return_dict: Optional[bool] = None,
1332
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1333
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1334
+
1335
+ # ==========================================================================================
1336
+ # input_ids -> (batch_size, query_length)
1337
+ # attention_mask -> None or (batch_size, key_length)
1338
+ # position_ids -> None or (batch_size, key_length)
1339
+ # ==========================================================================================
1340
+
1341
+ transformer_outputs = self.transformer(
1342
+ input_ids,
1343
+ past_key_values=past_key_values,
1344
+ attention_mask=attention_mask,
1345
+ token_type_ids=token_type_ids,
1346
+ position_ids=position_ids,
1347
+ inputs_embeds=inputs_embeds,
1348
+ use_cache=use_cache,
1349
+ output_hidden_states=output_hidden_states,
1350
+ return_dict=return_dict,
1351
+ )
1352
+ hidden_states = transformer_outputs[0]
1353
+
1354
+ lm_logits = self.lm_head(hidden_states)
1355
+
1356
+ loss = None
1357
+ # Shift so that tokens < n predict n
1358
+ if labels is not None:
1359
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1360
+ shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
1361
+
1362
+ # Flatten the tokens
1363
+ loss_fct = nn.CrossEntropyLoss()
1364
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1365
+
1366
+ if not return_dict:
1367
+ output = (lm_logits,) + transformer_outputs[1:]
1368
+ return ((loss,) + output) if loss is not None else output
1369
+
1370
+ return CausalLMOutputWithCrossAttentions(
1371
+ loss=loss,
1372
+ logits=lm_logits,
1373
+ past_key_values=transformer_outputs.past_key_values,
1374
+ hidden_states=transformer_outputs.hidden_states,
1375
+ attentions=transformer_outputs.attentions,
1376
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|endoftext|>",
4
+ "<fim_prefix>",
5
+ "<fim_middle>",
6
+ "<fim_suffix>",
7
+ "<fim_pad>",
8
+ "<filename>",
9
+ "<gh_stars>",
10
+ "<issue_start>",
11
+ "<issue_comment>",
12
+ "<issue_closed>",
13
+ "<jupyter_start>",
14
+ "<jupyter_text>",
15
+ "<jupyter_code>",
16
+ "<jupyter_output>",
17
+ "<empty_output>",
18
+ "<commit_before>",
19
+ "<commit_msg>",
20
+ "<commit_after>",
21
+ "<reponame>"
22
+ ],
23
+ "bos_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "eos_token": {
31
+ "content": "<|endoftext|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "pad_token": {
38
+ "content": "<|endoftext|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<|endoftext|>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<fim_prefix>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<fim_middle>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<fim_suffix>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "4": {
37
+ "content": "<fim_pad>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "5": {
45
+ "content": "<filename>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "6": {
53
+ "content": "<gh_stars>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "7": {
61
+ "content": "<issue_start>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "8": {
69
+ "content": "<issue_comment>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "9": {
77
+ "content": "<issue_closed>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "10": {
85
+ "content": "<jupyter_start>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "11": {
93
+ "content": "<jupyter_text>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "12": {
101
+ "content": "<jupyter_code>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "13": {
109
+ "content": "<jupyter_output>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "14": {
117
+ "content": "<empty_output>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "15": {
125
+ "content": "<commit_before>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "16": {
133
+ "content": "<commit_msg>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "17": {
141
+ "content": "<commit_after>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "18": {
149
+ "content": "<reponame>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ }
156
+ },
157
+ "additional_special_tokens": [
158
+ "<|endoftext|>",
159
+ "<fim_prefix>",
160
+ "<fim_middle>",
161
+ "<fim_suffix>",
162
+ "<fim_pad>",
163
+ "<filename>",
164
+ "<gh_stars>",
165
+ "<issue_start>",
166
+ "<issue_comment>",
167
+ "<issue_closed>",
168
+ "<jupyter_start>",
169
+ "<jupyter_text>",
170
+ "<jupyter_code>",
171
+ "<jupyter_output>",
172
+ "<empty_output>",
173
+ "<commit_before>",
174
+ "<commit_msg>",
175
+ "<commit_after>",
176
+ "<reponame>"
177
+ ],
178
+ "bos_token": "<|endoftext|>",
179
+ "clean_up_tokenization_spaces": true,
180
+ "eos_token": "<|endoftext|>",
181
+ "model_max_length": 9223372036854775807,
182
+ "pad_token": "<|endoftext|>",
183
+ "padding_side": "left",
184
+ "tokenizer_class": "GPT2Tokenizer",
185
+ "unk_token": "<|endoftext|>",
186
+ "vocab_size": 49152
187
+ }