add mask_first_token
- README.md +5 -1
- config.json +1 -0
- modeling_lsg_bert.py +5 -0
README.md CHANGED
@@ -52,13 +52,17 @@ You can change various parameters like :
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
 
 ```python:
+from transformers import AutoModel
+
 model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
     trust_remote_code=True,
     num_global_tokens=16,
     block_size=64,
     sparse_block_size=64,
-    sparsity_factor=4,
     attention_probs_dropout_prob=0.0
+    sparsity_factor=4,
+    sparsity_type="none",
+    mask_first_token=True
 )
 ```
 
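For reference, a runnable version of the updated README snippet (same checkpoint and parameter values, with a comma added after `attention_probs_dropout_prob=0.0` so the call parses):

```python
from transformers import AutoModel

# Override the LSG defaults at load time; mask_first_token=True zeroes the
# attention mask for the first token (see the modeling change below).
model = AutoModel.from_pretrained(
    "ccdv/legal-lsg-small-uncased-4096",
    trust_remote_code=True,
    num_global_tokens=16,
    block_size=64,
    sparse_block_size=64,
    attention_probs_dropout_prob=0.0,
    sparsity_factor=4,
    sparsity_type="none",
    mask_first_token=True,
)
```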
config.json CHANGED
@@ -28,6 +28,7 @@
   "intermediate_size": 2048,
   "layer_norm_eps": 1e-12,
   "lsh_num_pre_rounds": 1,
+  "mask_first_token": false,
   "max_position_embeddings": 4096,
   "model_type": "bert",
   "num_attention_heads": 8,
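A quick way to confirm the new default, assuming the checkpoint exposes the custom config via `trust_remote_code`:

```python
from transformers import AutoConfig

# The updated config.json ships "mask_first_token": false as the default.
config = AutoConfig.from_pretrained("ccdv/legal-lsg-small-uncased-4096", trust_remote_code=True)
print(config.mask_first_token)  # False unless overridden

# It can also be flipped on the config before instantiating the model.
config.mask_first_token = True
```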
modeling_lsg_bert.py CHANGED
@@ -31,6 +31,7 @@ class LSGBertConfig(BertConfig):
         base_model_prefix="lsg",
         block_size=128,
         lsh_num_pre_rounds=1,
+        mask_first_token=False,
         num_global_tokens=1,
         pool_with_global=True,
         sparse_block_size=128,
@@ -46,6 +47,7 @@ class LSGBertConfig(BertConfig):
         self.base_model_prefix = base_model_prefix
         self.block_size = block_size
         self.lsh_num_pre_rounds = lsh_num_pre_rounds
+        self.mask_first_token = mask_first_token
         self.num_global_tokens = num_global_tokens
         self.pool_with_global = pool_with_global
         self.sparse_block_size = sparse_block_size
@@ -1004,6 +1006,7 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         assert hasattr(config, "block_size") and hasattr(config, "adaptive")
         self.block_size = config.block_size
         self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
         self.pool_with_global = config.pool_with_global
 
         self.embeddings = LSGBertEmbeddings(config)
@@ -1040,6 +1043,8 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
 
         if attention_mask is None:
             attention_mask = torch.ones(n, t, device=inputs_.device)
+        if self.mask_first_token:
+            attention_mask[:,0] = 0
         if token_type_ids is None:
             token_type_ids = torch.zeros(n, t, device=inputs_.device).long()
 
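A minimal standalone sketch of what the new branch does to the mask, using plain PyTorch and hypothetical batch/sequence sizes:

```python
import torch

n, t = 2, 8  # batch size and sequence length (hypothetical)
attention_mask = torch.ones(n, t)

mask_first_token = True
if mask_first_token:
    # Same operation as in LSGBertModel: the first position (typically [CLS])
    # is zeroed out of the attention mask for every sequence in the batch.
    attention_mask[:, 0] = 0

print(attention_mask[0])  # tensor([0., 1., 1., 1., 1., 1., 1., 1.])
```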