Update modeling_bert_vits2.py
Browse files- modeling_bert_vits2.py +16 -22
modeling_bert_vits2.py
CHANGED
@@ -33,16 +33,10 @@ from transformers.modeling_outputs import (
|
|
33 |
from transformers.models.bert.modeling_bert import BertModel
|
34 |
from transformers.modeling_utils import PreTrainedModel
|
35 |
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
36 |
-
from configuration_bert_vits2 import BertVits2Config
|
37 |
-
|
38 |
|
39 |
logger = logging.get_logger(__name__)
|
40 |
|
41 |
|
42 |
-
# General docstring
|
43 |
-
_CONFIG_FOR_DOC = "BertVits2Config"
|
44 |
-
|
45 |
-
|
46 |
@dataclass
|
47 |
class BertVits2ModelOutput(ModelOutput):
|
48 |
"""
|
@@ -328,7 +322,7 @@ def _rational_quadratic_spline(
|
|
328 |
|
329 |
|
330 |
class BertVits2WaveNet(torch.nn.Module):
|
331 |
-
def __init__(self, config
|
332 |
super().__init__()
|
333 |
self.hidden_size = config.hidden_size
|
334 |
self.num_layers = num_layers
|
@@ -408,7 +402,7 @@ class BertVits2WaveNet(torch.nn.Module):
|
|
408 |
|
409 |
|
410 |
class BertVits2PosteriorEncoder(nn.Module):
|
411 |
-
def __init__(self, config
|
412 |
super().__init__()
|
413 |
self.out_channels = config.flow_size
|
414 |
|
@@ -485,7 +479,7 @@ class HifiGanResidualBlock(nn.Module):
|
|
485 |
|
486 |
|
487 |
class BertVits2HifiGan(nn.Module):
|
488 |
-
def __init__(self, config
|
489 |
super().__init__()
|
490 |
self.config = config
|
491 |
self.num_kernels = len(config.resblock_kernel_sizes)
|
@@ -571,7 +565,7 @@ class BertVits2HifiGan(nn.Module):
|
|
571 |
|
572 |
|
573 |
class BertVits2ResidualCouplingLayer(nn.Module):
|
574 |
-
def __init__(self, config
|
575 |
super().__init__()
|
576 |
self.half_channels = config.flow_size // 2
|
577 |
|
@@ -593,7 +587,7 @@ class BertVits2ResidualCouplingLayer(nn.Module):
|
|
593 |
|
594 |
|
595 |
class BertVits2ResidualCouplingBlock(nn.Module):
|
596 |
-
def __init__(self, config
|
597 |
super().__init__()
|
598 |
self.flows = nn.ModuleList()
|
599 |
for _ in range(config.prior_encoder_num_flows):
|
@@ -608,7 +602,7 @@ class BertVits2ResidualCouplingBlock(nn.Module):
|
|
608 |
|
609 |
|
610 |
class BertVits2TransformerCouplingLayer(nn.Module):
|
611 |
-
def __init__(self, config
|
612 |
super().__init__()
|
613 |
self.half_channels = config.flow_size // 2
|
614 |
|
@@ -653,7 +647,7 @@ class BertVits2TransformerCouplingLayer(nn.Module):
|
|
653 |
|
654 |
|
655 |
class BertVits2TransformerCouplingBlock(nn.Module):
|
656 |
-
def __init__(self, config
|
657 |
super().__init__()
|
658 |
self.flows = nn.ModuleList([
|
659 |
BertVits2TransformerCouplingLayer(config) for _ in range(config.prior_encoder_num_flows)
|
@@ -672,7 +666,7 @@ class BertVits2TransformerCouplingBlock(nn.Module):
|
|
672 |
|
673 |
|
674 |
class BertVits2DilatedDepthSeparableConv(nn.Module):
|
675 |
-
def __init__(self, config
|
676 |
super().__init__()
|
677 |
kernel_size = config.duration_predictor_kernel_size
|
678 |
channels = config.hidden_size
|
@@ -718,7 +712,7 @@ class BertVits2DilatedDepthSeparableConv(nn.Module):
|
|
718 |
|
719 |
|
720 |
class BertVits2ConvFlow(nn.Module):
|
721 |
-
def __init__(self, config
|
722 |
super().__init__()
|
723 |
self.filter_channels = config.hidden_size
|
724 |
self.half_channels = config.depth_separable_channels // 2
|
@@ -761,7 +755,7 @@ class BertVits2ConvFlow(nn.Module):
|
|
761 |
|
762 |
|
763 |
class BertVits2ElementwiseAffine(nn.Module):
|
764 |
-
def __init__(self, config
|
765 |
super().__init__()
|
766 |
self.channels = config.depth_separable_channels
|
767 |
self.translate = nn.Parameter(torch.zeros(self.channels, 1))
|
@@ -918,7 +912,7 @@ class BertVits2DurationPredictor(nn.Module):
|
|
918 |
class BertVits2Attention(nn.Module):
|
919 |
"""Multi-headed attention with relative positional representation."""
|
920 |
|
921 |
-
def __init__(self, config
|
922 |
super().__init__()
|
923 |
self.embed_dim = config.hidden_size
|
924 |
self.num_heads = config.num_attention_heads
|
@@ -1130,7 +1124,7 @@ class BertVits2FeedForward(nn.Module):
|
|
1130 |
|
1131 |
|
1132 |
class BertVits2EncoderLayer(nn.Module):
|
1133 |
-
def __init__(self, config
|
1134 |
super().__init__()
|
1135 |
self.attention = BertVits2Attention(config)
|
1136 |
self.dropout = nn.Dropout(config.hidden_dropout)
|
@@ -1169,7 +1163,7 @@ class BertVits2EncoderLayer(nn.Module):
|
|
1169 |
|
1170 |
|
1171 |
class BertVits2Encoder(nn.Module):
|
1172 |
-
def __init__(self, config
|
1173 |
super().__init__()
|
1174 |
self.config = config
|
1175 |
if n_layers is None:
|
@@ -1260,7 +1254,7 @@ class BertVits2TextEncoder(nn.Module):
|
|
1260 |
Transformer encoder that uses relative positional representation instead of absolute positional encoding.
|
1261 |
"""
|
1262 |
|
1263 |
-
def __init__(self, config
|
1264 |
super().__init__()
|
1265 |
self.config = config
|
1266 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
@@ -1330,7 +1324,7 @@ class BertVits2TextEncoder(nn.Module):
|
|
1330 |
|
1331 |
|
1332 |
class BertVits2ReferenceEncoder(nn.Module):
|
1333 |
-
def __init__(self, config
|
1334 |
super().__init__()
|
1335 |
self.config = config
|
1336 |
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
@@ -1464,7 +1458,7 @@ BERT_VITS2_INPUTS_DOCSTRING = r"""
|
|
1464 |
BERT_VITS2_START_DOCSTRING,
|
1465 |
)
|
1466 |
class BertVits2Model(BertVits2PreTrainedModel):
|
1467 |
-
def __init__(self, config
|
1468 |
super().__init__(config)
|
1469 |
self.config = config
|
1470 |
self.text_encoder = BertVits2TextEncoder(config)
|
|
|
33 |
from transformers.models.bert.modeling_bert import BertModel
|
34 |
from transformers.modeling_utils import PreTrainedModel
|
35 |
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
|
|
|
|
36 |
|
37 |
logger = logging.get_logger(__name__)
|
38 |
|
39 |
|
|
|
|
|
|
|
|
|
40 |
@dataclass
|
41 |
class BertVits2ModelOutput(ModelOutput):
|
42 |
"""
|
|
|
322 |
|
323 |
|
324 |
class BertVits2WaveNet(torch.nn.Module):
|
325 |
+
def __init__(self, config, num_layers: int):
|
326 |
super().__init__()
|
327 |
self.hidden_size = config.hidden_size
|
328 |
self.num_layers = num_layers
|
|
|
402 |
|
403 |
|
404 |
class BertVits2PosteriorEncoder(nn.Module):
|
405 |
+
def __init__(self, config):
|
406 |
super().__init__()
|
407 |
self.out_channels = config.flow_size
|
408 |
|
|
|
479 |
|
480 |
|
481 |
class BertVits2HifiGan(nn.Module):
|
482 |
+
def __init__(self, config):
|
483 |
super().__init__()
|
484 |
self.config = config
|
485 |
self.num_kernels = len(config.resblock_kernel_sizes)
|
|
|
565 |
|
566 |
|
567 |
class BertVits2ResidualCouplingLayer(nn.Module):
|
568 |
+
def __init__(self, config):
|
569 |
super().__init__()
|
570 |
self.half_channels = config.flow_size // 2
|
571 |
|
|
|
587 |
|
588 |
|
589 |
class BertVits2ResidualCouplingBlock(nn.Module):
|
590 |
+
def __init__(self, config):
|
591 |
super().__init__()
|
592 |
self.flows = nn.ModuleList()
|
593 |
for _ in range(config.prior_encoder_num_flows):
|
|
|
602 |
|
603 |
|
604 |
class BertVits2TransformerCouplingLayer(nn.Module):
|
605 |
+
def __init__(self, config):
|
606 |
super().__init__()
|
607 |
self.half_channels = config.flow_size // 2
|
608 |
|
|
|
647 |
|
648 |
|
649 |
class BertVits2TransformerCouplingBlock(nn.Module):
|
650 |
+
def __init__(self, config):
|
651 |
super().__init__()
|
652 |
self.flows = nn.ModuleList([
|
653 |
BertVits2TransformerCouplingLayer(config) for _ in range(config.prior_encoder_num_flows)
|
|
|
666 |
|
667 |
|
668 |
class BertVits2DilatedDepthSeparableConv(nn.Module):
|
669 |
+
def __init__(self, config, dropout_rate=0.0):
|
670 |
super().__init__()
|
671 |
kernel_size = config.duration_predictor_kernel_size
|
672 |
channels = config.hidden_size
|
|
|
712 |
|
713 |
|
714 |
class BertVits2ConvFlow(nn.Module):
|
715 |
+
def __init__(self, config):
|
716 |
super().__init__()
|
717 |
self.filter_channels = config.hidden_size
|
718 |
self.half_channels = config.depth_separable_channels // 2
|
|
|
755 |
|
756 |
|
757 |
class BertVits2ElementwiseAffine(nn.Module):
|
758 |
+
def __init__(self, config):
|
759 |
super().__init__()
|
760 |
self.channels = config.depth_separable_channels
|
761 |
self.translate = nn.Parameter(torch.zeros(self.channels, 1))
|
|
|
912 |
class BertVits2Attention(nn.Module):
|
913 |
"""Multi-headed attention with relative positional representation."""
|
914 |
|
915 |
+
def __init__(self, config):
|
916 |
super().__init__()
|
917 |
self.embed_dim = config.hidden_size
|
918 |
self.num_heads = config.num_attention_heads
|
|
|
1124 |
|
1125 |
|
1126 |
class BertVits2EncoderLayer(nn.Module):
|
1127 |
+
def __init__(self, config, kernel_size=None):
|
1128 |
super().__init__()
|
1129 |
self.attention = BertVits2Attention(config)
|
1130 |
self.dropout = nn.Dropout(config.hidden_dropout)
|
|
|
1163 |
|
1164 |
|
1165 |
class BertVits2Encoder(nn.Module):
|
1166 |
+
def __init__(self, config, kernel_size=None, n_layers=None):
|
1167 |
super().__init__()
|
1168 |
self.config = config
|
1169 |
if n_layers is None:
|
|
|
1254 |
Transformer encoder that uses relative positional representation instead of absolute positional encoding.
|
1255 |
"""
|
1256 |
|
1257 |
+
def __init__(self, config):
|
1258 |
super().__init__()
|
1259 |
self.config = config
|
1260 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
|
|
1324 |
|
1325 |
|
1326 |
class BertVits2ReferenceEncoder(nn.Module):
|
1327 |
+
def __init__(self, config):
|
1328 |
super().__init__()
|
1329 |
self.config = config
|
1330 |
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
|
|
1458 |
BERT_VITS2_START_DOCSTRING,
|
1459 |
)
|
1460 |
class BertVits2Model(BertVits2PreTrainedModel):
|
1461 |
+
def __init__(self, config):
|
1462 |
super().__init__(config)
|
1463 |
self.config = config
|
1464 |
self.text_encoder = BertVits2TextEncoder(config)
|