Plachta committed on
Commit
bdf3420
1 Parent(s): 3ec1f67

Update modules/length_regulator.py

Files changed (1)
  1. modules/length_regulator.py +102 -96
modules/length_regulator.py CHANGED
@@ -1,96 +1,102 @@
-from typing import Tuple
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from modules.commons import sequence_mask
-
-
-class InterpolateRegulator(nn.Module):
-    def __init__(
-        self,
-        channels: int,
-        sampling_ratios: Tuple,
-        is_discrete: bool = False,
-        codebook_size: int = 1024,  # for discrete only
-        out_channels: int = None,
-        groups: int = 1,
-        token_dropout_prob: float = 0.5,  # randomly drop out input tokens
-        token_dropout_range: float = 0.5,  # randomly drop out input tokens
-        n_codebooks: int = 1,  # number of codebooks
-        quantizer_dropout: float = 0.0,  # dropout for quantizer
-        f0_condition: bool = False,
-        n_f0_bins: int = 512,
-    ):
-        super().__init__()
-        self.sampling_ratios = sampling_ratios
-        out_channels = out_channels or channels
-        model = nn.ModuleList([])
-        if len(sampling_ratios) > 0:
-            for _ in sampling_ratios:
-                module = nn.Conv1d(channels, channels, 3, 1, 1)
-                norm = nn.GroupNorm(groups, channels)
-                act = nn.Mish()
-                model.extend([module, norm, act])
-        model.append(
-            nn.Conv1d(channels, out_channels, 1, 1)
-        )
-        self.model = nn.Sequential(*model)
-        self.embedding = nn.Embedding(codebook_size, channels)
-        self.is_discrete = is_discrete
-
-        self.mask_token = nn.Parameter(torch.zeros(1, channels))
-
-        self.n_codebooks = n_codebooks
-        if n_codebooks > 1:
-            self.extra_codebooks = nn.ModuleList([
-                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
-            ])
-        self.token_dropout_prob = token_dropout_prob
-        self.token_dropout_range = token_dropout_range
-        self.quantizer_dropout = quantizer_dropout
-
-        if f0_condition:
-            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
-            self.f0_condition = f0_condition
-            self.n_f0_bins = n_f0_bins
-            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
-            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
-        else:
-            self.f0_condition = False
-
-    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
-        # apply token drop
-        if self.training:
-            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
-            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
-            n_dropout = int(x.shape[0] * self.quantizer_dropout)
-            n_quantizers[:n_dropout] = dropout[:n_dropout]
-            n_quantizers = n_quantizers.to(x.device)
-            # decide whether to drop for each sample in batch
-        else:
-            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
-        if self.is_discrete:
-            if self.n_codebooks > 1:
-                assert len(x.size()) == 3
-                x_emb = self.embedding(x[:, 0])
-                for i, emb in enumerate(self.extra_codebooks):
-                    x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
-                x = x_emb
-            elif self.n_codebooks == 1:
-                if len(x.size()) == 2:
-                    x = self.embedding(x)
-                else:
-                    x = self.embedding(x[:, 0])
-        # x in (B, T, D)
-        mask = sequence_mask(ylens).unsqueeze(-1)
-        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
-        if self.f0_condition:
-            quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
-            drop_f0 = torch.rand(quantized_f0.size(0)).to(f0.device) < self.quantizer_dropout
-            f0_emb = self.f0_embedding(quantized_f0)
-            f0_emb[drop_f0] = self.f0_mask
-            f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
-            x = x + f0_emb
-        out = self.model(x).transpose(1, 2).contiguous()
-        olens = ylens
-        return out * mask, olens
+from typing import Tuple
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from modules.commons import sequence_mask
+
+
+class InterpolateRegulator(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        sampling_ratios: Tuple,
+        is_discrete: bool = False,
+        codebook_size: int = 1024,  # for discrete only
+        out_channels: int = None,
+        groups: int = 1,
+        token_dropout_prob: float = 0.5,  # randomly drop out input tokens
+        token_dropout_range: float = 0.5,  # randomly drop out input tokens
+        n_codebooks: int = 1,  # number of codebooks
+        quantizer_dropout: float = 0.0,  # dropout for quantizer
+        f0_condition: bool = False,
+        n_f0_bins: int = 512,
+    ):
+        super().__init__()
+        self.sampling_ratios = sampling_ratios
+        out_channels = out_channels or channels
+        model = nn.ModuleList([])
+        if len(sampling_ratios) > 0:
+            for _ in sampling_ratios:
+                module = nn.Conv1d(channels, channels, 3, 1, 1)
+                norm = nn.GroupNorm(groups, channels)
+                act = nn.Mish()
+                model.extend([module, norm, act])
+        model.append(
+            nn.Conv1d(channels, out_channels, 1, 1)
+        )
+        self.model = nn.Sequential(*model)
+        self.embedding = nn.Embedding(codebook_size, channels)
+        self.is_discrete = is_discrete
+
+        self.mask_token = nn.Parameter(torch.zeros(1, channels))
+
+        self.n_codebooks = n_codebooks
+        if n_codebooks > 1:
+            self.extra_codebooks = nn.ModuleList([
+                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
+            ])
+        self.token_dropout_prob = token_dropout_prob
+        self.token_dropout_range = token_dropout_range
+        self.quantizer_dropout = quantizer_dropout
+
+        if f0_condition:
+            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
+            self.f0_condition = f0_condition
+            self.n_f0_bins = n_f0_bins
+            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
+            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
+        else:
+            self.f0_condition = False
+
+    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
+        # apply token drop
+        if self.training:
+            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
+            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
+            n_dropout = int(x.shape[0] * self.quantizer_dropout)
+            n_quantizers[:n_dropout] = dropout[:n_dropout]
+            n_quantizers = n_quantizers.to(x.device)
+            # decide whether to drop for each sample in batch
+        else:
+            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
+        if self.is_discrete:
+            if self.n_codebooks > 1:
+                assert len(x.size()) == 3
+                x_emb = self.embedding(x[:, 0])
+                for i, emb in enumerate(self.extra_codebooks):
+                    x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
+                x = x_emb
+            elif self.n_codebooks == 1:
+                if len(x.size()) == 2:
+                    x = self.embedding(x)
+                else:
+                    x = self.embedding(x[:, 0])
+        # x in (B, T, D)
+        mask = sequence_mask(ylens).unsqueeze(-1)
+        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+        if self.f0_condition:
+            if f0 is None:
+                x = x + self.f0_mask.unsqueeze(-1)
+            else:
+                quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
+                if self.training:
+                    drop_f0 = torch.rand(quantized_f0.size(0)).to(f0.device) < self.quantizer_dropout
+                else:
+                    drop_f0 = torch.zeros(quantized_f0.size(0)).to(f0.device).bool()
+                f0_emb = self.f0_embedding(quantized_f0)
+                f0_emb[drop_f0] = self.f0_mask
+                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+                x = x + f0_emb
+        out = self.model(x).transpose(1, 2).contiguous()
+        olens = ylens
+        return out * mask, olens
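For context, the diff makes the F0-conditioned path usable when no F0 contour is supplied (the learned f0_mask is added instead of calling torch.bucketize on None) and applies random F0 dropout only during training. Below is a minimal usage sketch, assuming it runs from the repository root with an f0_condition=True regulator; the channel size, token values, lengths, and F0 values are illustrative placeholders, not taken from the commit.

# Hypothetical sketch; hyperparameters and inputs are assumptions.
import torch
from modules.length_regulator import InterpolateRegulator

regulator = InterpolateRegulator(
    channels=512,
    sampling_ratios=[1, 1, 1, 1],
    is_discrete=True,
    codebook_size=1024,
    f0_condition=True,
    n_f0_bins=512,
)
regulator.eval()

tokens = torch.randint(0, 1024, (2, 100))  # (B, T) discrete content tokens
ylens = torch.tensor([120, 90])            # target lengths after regulation
f0 = torch.rand(2, 100) * 400              # rough F0 contour in Hz

with torch.no_grad():
    # With an F0 contour: bucketized, embedded, interpolated, and added to x.
    out_f0, olens = regulator(tokens, ylens=ylens, f0=f0)
    # Without an F0 contour: the new branch adds the learned f0_mask bias.
    out_no_f0, _ = regulator(tokens, ylens=ylens, f0=None)

print(out_f0.shape, out_no_f0.shape, olens)  # (2, 120, 512), (2, 120, 512), tensor([120, 90])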