update for transformers >= 4.29.1

modeling_lsg_bert.py  (+15 -22)
@@ -189,19 +189,25 @@ class CausalAttentionProduct(nn.Module):
         del key_layer
 
         if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
             # Add causal mask
             causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
             causal_mask = torch.tril(
                 torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
                 diagonal=-1
             )
-            causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]
+
+            # Min value
+            dtype_min = torch.tensor(
+                torch.finfo(attention_scores.dtype).min, device=attention_scores.device, dtype=attention_scores.dtype
+            )
 
+            # Build causal + attention_mask
+            causal_mask = torch.nn.functional.pad(causal_mask.T * dtype_min, (attention_mask.size()[-1] - self.block_size, 0), value=0)
+            attention_mask = torch.max(attention_mask + causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype_min)
+
+            attention_scores = attention_scores + attention_mask
             del attention_mask
+            del causal_mask
 
         # Normalize the attention scores to probabilities.
         attention_probs = nn.Softmax(dim=-1)(attention_scores)
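Note on the hunk above: instead of writing the causal mask into a slice of the attention scores, the mask is now folded into the additive attention_mask and clamped at the dtype minimum, which keeps masked positions from underflowing when padding and causal masking coincide. A minimal standalone sketch of what the new lines compute follows; the toy sizes and the 5-D mask shape are illustrative assumptions, not taken from the file.

    # Illustrative sketch of the new mask construction (toy shapes, not the real ones).
    import torch

    block_size, key_len = 4, 10
    # Additive mask: 0 = visible, dtype-min = masked. The real shape comes from
    # LSG's block-local attention; (1, 1, 1, 1, key_len) is only a demo assumption.
    attention_mask = torch.zeros(1, 1, 1, 1, key_len)

    dtype_min = torch.tensor(torch.finfo(attention_mask.dtype).min)

    # Strictly lower-triangular block (diagonal=-1), transposed to strictly upper
    # triangular so each query masks out the keys strictly after it.
    causal_mask = torch.tril(torch.ones(block_size, block_size), diagonal=-1)
    causal_mask = causal_mask.T * dtype_min

    # Left-pad with zeros: keys before the final block stay fully visible.
    causal_mask = torch.nn.functional.pad(causal_mask, (key_len - block_size, 0), value=0)

    # Merge with the additive mask; clamping at dtype_min prevents underflow where
    # a padding-masked and a causally-masked position overlap.
    attention_mask = torch.max(attention_mask + causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype_min)
    print(attention_mask.shape)  # torch.Size([1, 1, 1, 4, 10])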
@@ -991,8 +997,6 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
     documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-
     def __init__(self, config, add_pooling_layer=True):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1031,6 +1035,8 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
 
 class LSGBertForPreTraining(LSGBertPreTrainedModel, BertForPreTraining):
 
+    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1044,8 +1050,7 @@ class LSGBertForPreTraining(LSGBertPreTrainedModel, BertForPreTraining):
 
 class LSGBertLMHeadModel(LSGBertPreTrainedModel, BertLMHeadModel):
 
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     def __init__(self, config):
 
@@ -1067,9 +1072,7 @@ class LSGBertForMaskedLM(LSGBertPreTrainedModel, BertForMaskedLM):
     documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     def __init__(self, config):
 
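The _tied_weights_keys additions in the three hunks above track the upstream BERT classes in transformers >= 4.29, where the _keys_to_ignore_on_load_missing / _keys_to_ignore_on_load_unexpected declarations for tied parameters gave way to an explicit list of tied keys. A hedged sketch of the pattern (the subclass name here is hypothetical; the attribute and key names are the ones from the diff):

    # Hypothetical subclass showing where _tied_weights_keys lives; only the
    # attribute itself is taken from the diff.
    from transformers import BertForMaskedLM

    class MyMaskedLM(BertForMaskedLM):
        # These keys point at weights tied to other parameters (the MLM decoder
        # shares the input embedding matrix), so from_pretrained handles them as
        # tied weights instead of reporting them as missing or unexpected.
        _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]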
@@ -1107,8 +1110,6 @@ class LSGBertForSequenceClassification(LSGBertPreTrainedModel, BertForSequenceClassification):
     appropriate documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1133,8 +1134,6 @@ class LSGBertForMultipleChoice(LSGBertPreTrainedModel, BertForMultipleChoice):
     appropriate documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1156,9 +1155,6 @@ class LSGBertForTokenClassification(LSGBertPreTrainedModel, BertForTokenClassification):
     appropriate documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
@@ -1182,9 +1178,6 @@ class LSGBertForQuestionAnswering(LSGBertPreTrainedModel, BertForQuestionAnswering):
     appropriate documentation alongside usage examples.
     """
 
-    config_class = LSGBertConfig
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
 
         LSGBertPreTrainedModel.__init__(self, config)
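The recurring config_class = LSGBertConfig deletions across these hunks are consistent with the attribute being declared once on the shared base class and inherited by every head. A sketch of that pattern, assuming LSGBertPreTrainedModel carries the declaration (which this diff does not show):

    # Sketch of class-attribute inheritance; assumes the base class declares
    # config_class once, which is not visible in this diff.
    from transformers import BertConfig, PreTrainedModel

    class LSGBertConfig(BertConfig):
        pass  # stand-in for the real LSG config

    class LSGBertPreTrainedModel(PreTrainedModel):
        config_class = LSGBertConfig  # single declaration on the base class

    class LSGBertForTokenClassification(LSGBertPreTrainedModel):
        pass  # inherits config_class; no per-class declaration needed

    assert LSGBertForTokenClassification.config_class is LSGBertConfig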