hans00 committed on
Commit
913c804
1 parent: 621c9ce

Update modeling_bert_vits2.py

Browse files
Files changed (1) hide show
  1. modeling_bert_vits2.py +16 -22
modeling_bert_vits2.py CHANGED
@@ -33,16 +33,10 @@ from transformers.modeling_outputs import (
33
  from transformers.models.bert.modeling_bert import BertModel
34
  from transformers.modeling_utils import PreTrainedModel
35
  from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
36
- from configuration_bert_vits2 import BertVits2Config
37
-
38
 
39
  logger = logging.get_logger(__name__)
40
 
41
 
42
- # General docstring
43
- _CONFIG_FOR_DOC = "BertVits2Config"
44
-
45
-
46
  @dataclass
47
  class BertVits2ModelOutput(ModelOutput):
48
  """
@@ -328,7 +322,7 @@ def _rational_quadratic_spline(
328
 
329
 
330
  class BertVits2WaveNet(torch.nn.Module):
331
- def __init__(self, config: BertVits2Config, num_layers: int):
332
  super().__init__()
333
  self.hidden_size = config.hidden_size
334
  self.num_layers = num_layers
@@ -408,7 +402,7 @@ class BertVits2WaveNet(torch.nn.Module):
408
 
409
 
410
  class BertVits2PosteriorEncoder(nn.Module):
411
- def __init__(self, config: BertVits2Config):
412
  super().__init__()
413
  self.out_channels = config.flow_size
414
 
@@ -485,7 +479,7 @@ class HifiGanResidualBlock(nn.Module):
485
 
486
 
487
  class BertVits2HifiGan(nn.Module):
488
- def __init__(self, config: BertVits2Config):
489
  super().__init__()
490
  self.config = config
491
  self.num_kernels = len(config.resblock_kernel_sizes)
@@ -571,7 +565,7 @@ class BertVits2HifiGan(nn.Module):
571
 
572
 
573
  class BertVits2ResidualCouplingLayer(nn.Module):
574
- def __init__(self, config: BertVits2Config):
575
  super().__init__()
576
  self.half_channels = config.flow_size // 2
577
 
@@ -593,7 +587,7 @@ class BertVits2ResidualCouplingLayer(nn.Module):
593
 
594
 
595
  class BertVits2ResidualCouplingBlock(nn.Module):
596
- def __init__(self, config: BertVits2Config):
597
  super().__init__()
598
  self.flows = nn.ModuleList()
599
  for _ in range(config.prior_encoder_num_flows):
@@ -608,7 +602,7 @@ class BertVits2ResidualCouplingBlock(nn.Module):
608
 
609
 
610
  class BertVits2TransformerCouplingLayer(nn.Module):
611
- def __init__(self, config: BertVits2Config):
612
  super().__init__()
613
  self.half_channels = config.flow_size // 2
614
 
@@ -653,7 +647,7 @@ class BertVits2TransformerCouplingLayer(nn.Module):
653
 
654
 
655
  class BertVits2TransformerCouplingBlock(nn.Module):
656
- def __init__(self, config: BertVits2Config):
657
  super().__init__()
658
  self.flows = nn.ModuleList([
659
  BertVits2TransformerCouplingLayer(config) for _ in range(config.prior_encoder_num_flows)
@@ -672,7 +666,7 @@ class BertVits2TransformerCouplingBlock(nn.Module):
672
 
673
 
674
  class BertVits2DilatedDepthSeparableConv(nn.Module):
675
- def __init__(self, config: BertVits2Config, dropout_rate=0.0):
676
  super().__init__()
677
  kernel_size = config.duration_predictor_kernel_size
678
  channels = config.hidden_size
@@ -718,7 +712,7 @@ class BertVits2DilatedDepthSeparableConv(nn.Module):
718
 
719
 
720
  class BertVits2ConvFlow(nn.Module):
721
- def __init__(self, config: BertVits2Config):
722
  super().__init__()
723
  self.filter_channels = config.hidden_size
724
  self.half_channels = config.depth_separable_channels // 2
@@ -761,7 +755,7 @@ class BertVits2ConvFlow(nn.Module):
761
 
762
 
763
  class BertVits2ElementwiseAffine(nn.Module):
764
- def __init__(self, config: BertVits2Config):
765
  super().__init__()
766
  self.channels = config.depth_separable_channels
767
  self.translate = nn.Parameter(torch.zeros(self.channels, 1))
@@ -918,7 +912,7 @@ class BertVits2DurationPredictor(nn.Module):
918
  class BertVits2Attention(nn.Module):
919
  """Multi-headed attention with relative positional representation."""
920
 
921
- def __init__(self, config: BertVits2Config):
922
  super().__init__()
923
  self.embed_dim = config.hidden_size
924
  self.num_heads = config.num_attention_heads
@@ -1130,7 +1124,7 @@ class BertVits2FeedForward(nn.Module):
1130
 
1131
 
1132
  class BertVits2EncoderLayer(nn.Module):
1133
- def __init__(self, config: BertVits2Config, kernel_size=None):
1134
  super().__init__()
1135
  self.attention = BertVits2Attention(config)
1136
  self.dropout = nn.Dropout(config.hidden_dropout)
@@ -1169,7 +1163,7 @@ class BertVits2EncoderLayer(nn.Module):
1169
 
1170
 
1171
  class BertVits2Encoder(nn.Module):
1172
- def __init__(self, config: BertVits2Config, kernel_size=None, n_layers=None):
1173
  super().__init__()
1174
  self.config = config
1175
  if n_layers is None:
@@ -1260,7 +1254,7 @@ class BertVits2TextEncoder(nn.Module):
1260
  Transformer encoder that uses relative positional representation instead of absolute positional encoding.
1261
  """
1262
 
1263
- def __init__(self, config: BertVits2Config):
1264
  super().__init__()
1265
  self.config = config
1266
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
@@ -1330,7 +1324,7 @@ class BertVits2TextEncoder(nn.Module):
1330
 
1331
 
1332
  class BertVits2ReferenceEncoder(nn.Module):
1333
- def __init__(self, config: BertVits2Config):
1334
  super().__init__()
1335
  self.config = config
1336
  ref_enc_filters = [32, 32, 64, 64, 128, 128]
@@ -1464,7 +1458,7 @@ BERT_VITS2_INPUTS_DOCSTRING = r"""
1464
  BERT_VITS2_START_DOCSTRING,
1465
  )
1466
  class BertVits2Model(BertVits2PreTrainedModel):
1467
- def __init__(self, config: BertVits2Config):
1468
  super().__init__(config)
1469
  self.config = config
1470
  self.text_encoder = BertVits2TextEncoder(config)
 
33
  from transformers.models.bert.modeling_bert import BertModel
34
  from transformers.modeling_utils import PreTrainedModel
35
  from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 
 
36
 
37
  logger = logging.get_logger(__name__)
38
 
39
 
 
 
 
 
40
  @dataclass
41
  class BertVits2ModelOutput(ModelOutput):
42
  """
 
322
 
323
 
324
  class BertVits2WaveNet(torch.nn.Module):
325
+ def __init__(self, config, num_layers: int):
326
  super().__init__()
327
  self.hidden_size = config.hidden_size
328
  self.num_layers = num_layers
 
402
 
403
 
404
  class BertVits2PosteriorEncoder(nn.Module):
405
+ def __init__(self, config):
406
  super().__init__()
407
  self.out_channels = config.flow_size
408
 
 
479
 
480
 
481
  class BertVits2HifiGan(nn.Module):
482
+ def __init__(self, config):
483
  super().__init__()
484
  self.config = config
485
  self.num_kernels = len(config.resblock_kernel_sizes)
 
565
 
566
 
567
  class BertVits2ResidualCouplingLayer(nn.Module):
568
+ def __init__(self, config):
569
  super().__init__()
570
  self.half_channels = config.flow_size // 2
571
 
 
587
 
588
 
589
  class BertVits2ResidualCouplingBlock(nn.Module):
590
+ def __init__(self, config):
591
  super().__init__()
592
  self.flows = nn.ModuleList()
593
  for _ in range(config.prior_encoder_num_flows):
 
602
 
603
 
604
  class BertVits2TransformerCouplingLayer(nn.Module):
605
+ def __init__(self, config):
606
  super().__init__()
607
  self.half_channels = config.flow_size // 2
608
 
 
647
 
648
 
649
  class BertVits2TransformerCouplingBlock(nn.Module):
650
+ def __init__(self, config):
651
  super().__init__()
652
  self.flows = nn.ModuleList([
653
  BertVits2TransformerCouplingLayer(config) for _ in range(config.prior_encoder_num_flows)
 
666
 
667
 
668
  class BertVits2DilatedDepthSeparableConv(nn.Module):
669
+ def __init__(self, config, dropout_rate=0.0):
670
  super().__init__()
671
  kernel_size = config.duration_predictor_kernel_size
672
  channels = config.hidden_size
 
712
 
713
 
714
  class BertVits2ConvFlow(nn.Module):
715
+ def __init__(self, config):
716
  super().__init__()
717
  self.filter_channels = config.hidden_size
718
  self.half_channels = config.depth_separable_channels // 2
 
755
 
756
 
757
  class BertVits2ElementwiseAffine(nn.Module):
758
+ def __init__(self, config):
759
  super().__init__()
760
  self.channels = config.depth_separable_channels
761
  self.translate = nn.Parameter(torch.zeros(self.channels, 1))
 
912
  class BertVits2Attention(nn.Module):
913
  """Multi-headed attention with relative positional representation."""
914
 
915
+ def __init__(self, config):
916
  super().__init__()
917
  self.embed_dim = config.hidden_size
918
  self.num_heads = config.num_attention_heads
 
1124
 
1125
 
1126
  class BertVits2EncoderLayer(nn.Module):
1127
+ def __init__(self, config, kernel_size=None):
1128
  super().__init__()
1129
  self.attention = BertVits2Attention(config)
1130
  self.dropout = nn.Dropout(config.hidden_dropout)
 
1163
 
1164
 
1165
  class BertVits2Encoder(nn.Module):
1166
+ def __init__(self, config, kernel_size=None, n_layers=None):
1167
  super().__init__()
1168
  self.config = config
1169
  if n_layers is None:
 
1254
  Transformer encoder that uses relative positional representation instead of absolute positional encoding.
1255
  """
1256
 
1257
+ def __init__(self, config):
1258
  super().__init__()
1259
  self.config = config
1260
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
 
1324
 
1325
 
1326
  class BertVits2ReferenceEncoder(nn.Module):
1327
+ def __init__(self, config):
1328
  super().__init__()
1329
  self.config = config
1330
  ref_enc_filters = [32, 32, 64, 64, 128, 128]
 
1458
  BERT_VITS2_START_DOCSTRING,
1459
  )
1460
  class BertVits2Model(BertVits2PreTrainedModel):
1461
+ def __init__(self, config):
1462
  super().__init__(config)
1463
  self.config = config
1464
  self.text_encoder = BertVits2TextEncoder(config)