Update bert_layers.py
#3
by
clarine
- opened
- bert_layers.py +7 -0
bert_layers.py
CHANGED
@@ -51,6 +51,7 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
|
51 |
from .bert_padding import (index_first_axis,
|
52 |
index_put_first_axis, pad_input,
|
53 |
unpad_input, unpad_input_only)
|
|
|
54 |
|
55 |
try:
|
56 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
@@ -625,6 +626,8 @@ class BertModel(BertPreTrainedModel):
|
|
625 |
```
|
626 |
"""
|
627 |
|
|
|
|
|
628 |
def __init__(self, config, add_pooling_layer=True):
|
629 |
super(BertModel, self).__init__(config)
|
630 |
self.embeddings = BertEmbeddings(config)
|
@@ -758,6 +761,8 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|
758 |
|
759 |
class BertForMaskedLM(BertPreTrainedModel):
|
760 |
|
|
|
|
|
761 |
def __init__(self, config):
|
762 |
super().__init__(config)
|
763 |
|
@@ -928,6 +933,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
928 |
e.g., GLUE tasks.
|
929 |
"""
|
930 |
|
|
|
|
|
931 |
def __init__(self, config):
|
932 |
super().__init__(config)
|
933 |
self.num_labels = config.num_labels
|
|
|
51 |
from .bert_padding import (index_first_axis,
|
52 |
index_put_first_axis, pad_input,
|
53 |
unpad_input, unpad_input_only)
|
54 |
+
from .configuration_bert import BertConfig
|
55 |
|
56 |
try:
|
57 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
|
|
626 |
```
|
627 |
"""
|
628 |
|
629 |
+
config_class = BertConfig
|
630 |
+
|
631 |
def __init__(self, config, add_pooling_layer=True):
|
632 |
super(BertModel, self).__init__(config)
|
633 |
self.embeddings = BertEmbeddings(config)
|
|
|
761 |
|
762 |
class BertForMaskedLM(BertPreTrainedModel):
|
763 |
|
764 |
+
config_class = BertConfig
|
765 |
+
|
766 |
def __init__(self, config):
|
767 |
super().__init__(config)
|
768 |
|
|
|
933 |
e.g., GLUE tasks.
|
934 |
"""
|
935 |
|
936 |
+
config_class = BertConfig
|
937 |
+
|
938 |
def __init__(self, config):
|
939 |
super().__init__(config)
|
940 |
self.num_labels = config.num_labels
|