Update modeling_bert_vits2.py
Browse files- modeling_bert_vits2.py +2 -5
modeling_bert_vits2.py
CHANGED
@@ -33,6 +33,7 @@ from transformers.modeling_outputs import (
|
|
33 |
from transformers.models.bert.modeling_bert import BertModel
|
34 |
from transformers.modeling_utils import PreTrainedModel
|
35 |
from transformers.utils import logging
|
|
|
36 |
|
37 |
logger = logging.get_logger(__name__)
|
38 |
|
@@ -1379,7 +1380,7 @@ class BertVits2PreTrainedModel(PreTrainedModel):
|
|
1379 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
1380 |
models.
|
1381 |
"""
|
1382 |
-
|
1383 |
base_model_prefix = "vits"
|
1384 |
main_input_name = "input_ids"
|
1385 |
supports_gradient_checkpointing = True
|
@@ -1404,10 +1405,6 @@ class BertVits2PreTrainedModel(PreTrainedModel):
|
|
1404 |
module.weight.data[module.padding_idx].zero_()
|
1405 |
|
1406 |
|
1407 |
-
# drop config_class for BertVits2PreTrainedModel
|
1408 |
-
del BertVits2PreTrainedModel.config_class
|
1409 |
-
|
1410 |
-
|
1411 |
class BertVits2Model(BertVits2PreTrainedModel):
|
1412 |
def __init__(self, config):
|
1413 |
super().__init__(config)
|
|
|
33 |
from transformers.models.bert.modeling_bert import BertModel
|
34 |
from transformers.modeling_utils import PreTrainedModel
|
35 |
from transformers.utils import logging
|
36 |
+
from configuration_bert_vits2 import BertVits2Config
|
37 |
|
38 |
logger = logging.get_logger(__name__)
|
39 |
|
|
|
1380 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
1381 |
models.
|
1382 |
"""
|
1383 |
+
config_class = BertVits2Config
|
1384 |
base_model_prefix = "vits"
|
1385 |
main_input_name = "input_ids"
|
1386 |
supports_gradient_checkpointing = True
|
|
|
1405 |
module.weight.data[module.padding_idx].zero_()
|
1406 |
|
1407 |
|
|
|
|
|
|
|
|
|
1408 |
class BertVits2Model(BertVits2PreTrainedModel):
|
1409 |
def __init__(self, config):
|
1410 |
super().__init__(config)
|