andybi7676 commited on
Commit
eef5961
·
verified ·
1 Parent(s): ff859b6

Upload model

Browse files
Files changed (4) hide show
  1. config.json +38 -0
  2. configuration_reborn.py +72 -0
  3. modeling_reborn.py +184 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RebornUASRModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_reborn.RebornUASRConfig",
7
+ "AutoModel": "modeling_reborn.RebornUASRModel"
8
+ },
9
+ "discriminator_act_after_linear": false,
10
+ "discriminator_causal": true,
11
+ "discriminator_depth": 1,
12
+ "discriminator_dilation": 1,
13
+ "discriminator_dim": 256,
14
+ "discriminator_dropout": 0.0,
15
+ "discriminator_input_dim": 512,
16
+ "discriminator_kernel": 3,
17
+ "discriminator_linear_emb": false,
18
+ "discriminator_max_pool": false,
19
+ "discriminator_spectral_norm": false,
20
+ "discriminator_weight_norm": false,
21
+ "generator_bias": false,
22
+ "generator_bn_apply": false,
23
+ "generator_bn_init_weight": 30.0,
24
+ "generator_dilation": 1,
25
+ "generator_dropout": 0.0,
26
+ "generator_input_dim": 512,
27
+ "generator_kernel": 4,
28
+ "generator_output_dim": 40,
29
+ "generator_stride": 1,
30
+ "model_type": "reborn_uasr",
31
+ "segmenter_dropout": 0.1,
32
+ "segmenter_hidden_dim": 512,
33
+ "segmenter_input_dim": 512,
34
+ "segmenter_kernel_size": 7,
35
+ "segmenter_type": "cnn",
36
+ "torch_dtype": "float32",
37
+ "transformers_version": "4.24.0"
38
+ }
configuration_reborn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class RebornUASRConfig(PretrainedConfig):
4
+ '''
5
+ We can use this class to define the configuration of the reborn model.
6
+ The reborn UASR is composed of a segmenter, a discriminator, and a generator.
7
+ We only include the required configurations for the discriminator and the generator from fairseq's wav2vec-U model configuration.
8
+ '''
9
+ model_type = "reborn_uasr"
10
+
11
+ def __init__(self,
12
+ segmenter_type: str = "cnn",
13
+ segmenter_input_dim: int = 512,
14
+ segmenter_hidden_dim: int = 512,
15
+ segmenter_dropout: float = 0.1,
16
+ segmenter_kernel_size: int = 7,
17
+
18
+ discriminator_input_dim: int = 512,
19
+ discriminator_kernel: int = 3,
20
+ discriminator_dilation: int = 1,
21
+ discriminator_dim: int = 256,
22
+ discriminator_causal: bool = True,
23
+ discriminator_linear_emb: bool = False,
24
+ discriminator_depth: int = 1,
25
+ discriminator_max_pool: bool = False,
26
+ discriminator_act_after_linear: bool = False,
27
+ discriminator_dropout: float = 0.0,
28
+ discriminator_spectral_norm: bool = False,
29
+ discriminator_weight_norm: bool = False,
30
+
31
+ generator_input_dim: int = 512,
32
+ generator_output_dim: int = 40,
33
+ generator_kernel: int = 4,
34
+ generator_dilation: int = 1,
35
+ generator_stride: int = 1,
36
+ generator_bias: bool = False,
37
+ generator_dropout: float = 0.0,
38
+ generator_bn_apply: bool = False,
39
+ generator_bn_init_weight: float = 30.0,
40
+ **kwargs
41
+ ):
42
+ super().__init__(**kwargs)
43
+ # read in all the configurations
44
+ self.segmenter_type = segmenter_type
45
+ self.segmenter_input_dim = segmenter_input_dim
46
+ self.segmenter_hidden_dim = segmenter_hidden_dim
47
+ self.segmenter_dropout = segmenter_dropout
48
+ self.segmenter_kernel_size = segmenter_kernel_size
49
+
50
+ self.discriminator_input_dim = discriminator_input_dim
51
+ self.discriminator_kernel = discriminator_kernel
52
+ self.discriminator_dilation = discriminator_dilation
53
+ self.discriminator_dim = discriminator_dim
54
+ self.discriminator_causal = discriminator_causal
55
+ self.discriminator_linear_emb = discriminator_linear_emb
56
+ self.discriminator_depth = discriminator_depth
57
+ self.discriminator_max_pool = discriminator_max_pool
58
+ self.discriminator_act_after_linear = discriminator_act_after_linear
59
+ self.discriminator_dropout = discriminator_dropout
60
+ self.discriminator_spectral_norm = discriminator_spectral_norm
61
+ self.discriminator_weight_norm = discriminator_weight_norm
62
+
63
+ self.generator_input_dim = generator_input_dim
64
+ self.generator_output_dim = generator_output_dim
65
+ self.generator_kernel = generator_kernel
66
+ self.generator_dilation = generator_dilation
67
+ self.generator_stride = generator_stride
68
+ self.generator_bias = generator_bias
69
+ self.generator_dropout = generator_dropout
70
+ self.generator_bn_apply = generator_bn_apply
71
+ self.generator_bn_init_weight = generator_bn_init_weight
72
+
modeling_reborn.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from .configuration_reborn import RebornUASRConfig
5
+ from typing import Optional, Tuple, Union
6
+
7
+ class RebornSegmenter(nn.Module):
8
+ def __init__(self, config):
9
+ super().__init__()
10
+ self.config = config
11
+ self.conv1 = nn.Conv1d(config.segmenter_input_dim, config.segmenter_hidden_dim, config.segmenter_kernel_size, padding=config.segmenter_kernel_size//2)
12
+ self.conv2 = nn.Conv1d(config.segmenter_hidden_dim, config.segmenter_hidden_dim, 3, padding=1)
13
+ self.conv3 = nn.Conv1d(config.segmenter_hidden_dim, 2, 1)
14
+ self.dropout = nn.Dropout(config.segmenter_dropout)
15
+ self.relu = nn.ReLU()
16
+
17
+ def forward(self, x):
18
+ """
19
+ Input:
20
+ x: (B, T, C)
21
+ padding_mask: (B, T) # 0: not padding; 1: padding
22
+ Output:
23
+ boundary: (B, T, 2) # 0: not boundary; 1: boundary
24
+ """
25
+ x = x.transpose(1, 2)
26
+ x = self.dropout(self.relu(self.conv1(x)))
27
+ x = self.dropout(self.relu(self.conv2(x)))
28
+ x = self.conv3(x)
29
+ x = x.transpose(1, 2)
30
+ return x
31
+
32
+ def boundary_predict(self, x, padding_mask, deterministic=False):
33
+ """
34
+ Input:
35
+ x: (B, T, C)
36
+ padding_mask: (B, T)
37
+ Output:
38
+ boundary: (B, T) # 0: not boundary; 1: boundary
39
+ boundary_logits: (B, T, 2) # 0: not boundary; 1: boundary
40
+ """
41
+ boundary_logits = self.forward(x)
42
+ if deterministic:
43
+ boundary = boundary_logits.argmax(-1)
44
+ boundary[padding_mask] = -1
45
+ else:
46
+ boundary = torch.distributions.Categorical(logits=boundary_logits).sample()
47
+ boundary[padding_mask] = -1
48
+ return boundary, boundary_logits
49
+
50
+ def pre_segment(self, logits, padding_mask, return_boundary=False, deterministic=True):
51
+ """
52
+ Input:
53
+ logits: (B, T, C)
54
+ padding_mask: (B, T)
55
+ Output:
56
+ new_logits: (B, T', C)
57
+ new_padding_mask: (B, T')
58
+ """
59
+
60
+ bsz, tsz, csz = logits.size()
61
+
62
+ boundary, boundary_logits = self.boundary_predict(logits, padding_mask, deterministic=deterministic)
63
+
64
+ # max boundary number
65
+ # print("boundary", boundary)
66
+ # print(torch.sum(boundary==1, dim=1))
67
+ new_tsz = int(torch.max(torch.sum(boundary==1, dim=1)).item())+1 # add <bos>
68
+ new_logits = logits.new_zeros(bsz, new_tsz, csz)
69
+ new_pad = padding_mask.new_zeros(bsz, new_tsz)
70
+
71
+ for b in range(bsz):
72
+ # merge consecutive segments when meeting a boundary (mean_pool_join)
73
+ new_idx = 0
74
+ count = 0
75
+ for t in range(tsz):
76
+ if padding_mask[b, t] == 1:
77
+ break
78
+ if boundary[b, t] == 1:
79
+ new_logits[b, new_idx] /= count
80
+ new_idx += 1
81
+ count = 0
82
+ new_logits[b, new_idx] += logits[b, t]
83
+ count += 1
84
+ if count > 0:
85
+ # last segment
86
+ new_logits[b, new_idx] /= count
87
+ new_idx += 1
88
+ count = 0
89
+ if new_idx < new_tsz:
90
+ pad = new_tsz - new_idx
91
+ new_logits[b, -pad:] = 0
92
+ new_pad[b, -pad:] = True
93
+
94
+ if return_boundary:
95
+ return new_logits, new_pad, boundary, boundary_logits
96
+ return new_logits, new_pad
97
+
98
+ class RebornGenerator(nn.Module):
99
+ def __init__(self, config):
100
+ super().__init__()
101
+
102
+ self.config = config
103
+ self.output_dim = config.generator_output_dim
104
+ self.stride = config.generator_stride
105
+ self.dropout = nn.Dropout(config.generator_dropout)
106
+ cnn_input_dim = config.generator_input_dim
107
+ cnn_output_dim = config.generator_output_dim
108
+
109
+ padding = config.generator_kernel // 2
110
+ self.proj = nn.Sequential(
111
+ nn.Conv1d(
112
+ cnn_input_dim,
113
+ cnn_output_dim,
114
+ kernel_size=config.generator_kernel,
115
+ stride=config.generator_stride,
116
+ dilation=config.generator_dilation,
117
+ padding=padding,
118
+ bias=config.generator_bias,
119
+ ),
120
+ )
121
+
122
+ def forward(self, dense_x, tokens, dense_padding_mask):
123
+ dense_x = self.dropout(dense_x)
124
+ # (B, T, C) -> (B, C, T)
125
+ dense_x = dense_x.transpose(-2, -1)
126
+
127
+ dense_x = self.proj(dense_x)
128
+ # (B, C, T) -> (B, T, C)
129
+ dense_x = dense_x.transpose(-2, -1)
130
+ if self.stride > 1:
131
+ dense_padding_mask = dense_padding_mask[:, :: self.stride]
132
+
133
+ if dense_padding_mask.size(1) != dense_x.size(1):
134
+ new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1])
135
+ diff = new_padding.size(1) - dense_padding_mask.size(1)
136
+ assert (
137
+ diff > 0
138
+ ), f"{new_padding.shape}, {dense_padding_mask.shape}, {dense_x.shape}, {diff}"
139
+ if diff > 0:
140
+ new_padding[:, diff:] = dense_padding_mask
141
+ else:
142
+ assert diff < 0
143
+ new_padding = dense_padding_mask[:, :diff]
144
+
145
+ dense_padding_mask = new_padding
146
+
147
+ result = {}
148
+
149
+ token_x = None
150
+ if tokens is not None:
151
+ token_x = dense_x.new_zeros(tokens.numel(), self.output_dim)
152
+ token_x.scatter_(1, tokens.view(-1, 1).long(), 1)
153
+ token_x = token_x.view(tokens.shape + (self.output_dim,))
154
+
155
+ result["dense_x"] = dense_x
156
+ result["token_x"] = token_x
157
+ result["dense_padding_mask"] = dense_padding_mask
158
+
159
+ return result
160
+
161
+ class RebornUASRModel(PreTrainedModel):
162
+ config_class = RebornUASRConfig
163
+
164
+ def __init__(self, config):
165
+ super().__init__(config)
166
+ self.pca = nn.Linear(1024, 512)
167
+ self.segmenter = RebornSegmenter(config)
168
+ self.generator = RebornGenerator(config)
169
+
170
+ def forward(
171
+ self,
172
+ x: Optional[torch.Tensor], # (B, T, C)
173
+ padding_mask: Optional[torch.Tensor], # (B, T)
174
+ ):
175
+ x_reduced = self.pca(x)
176
+ x_segmented, segmented_padding_mask = self.segmenter.pre_segment(x_reduced, padding_mask, deterministic=True)
177
+ x_generated = self.generator(x_segmented, None, segmented_padding_mask)
178
+
179
+ return {
180
+ 'x_reduced': x_reduced,
181
+ 'x_segmented': x_segmented,
182
+ 'x_generated': x_generated
183
+ }
184
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7d0dbe5553999bf8cdc8d7a2d678fee7d169b4514f983fa0e8597e9504f02a6
3
+ size 12923917