HoneyTian committed on
Commit d983ee9 · 1 Parent(s): b2f977d
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py ADDED
@@ -0,0 +1,250 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import math
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+
+
+ class MultiHeadedAttention(nn.Module):
+     def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
+         """
+         :param n_head: int. the number of heads.
+         :param n_feat: int. the number of features.
+         :param dropout_rate: float. dropout rate.
+         """
+         super().__init__()
+         assert n_feat % n_head == 0
+         # We assume d_v always equals d_k
+         self.d_k = n_feat // n_head
+         self.h = n_head
+         self.linear_q = nn.Linear(n_feat, n_feat)
+         self.linear_k = nn.Linear(n_feat, n_feat)
+         self.linear_v = nn.Linear(n_feat, n_feat)
+         self.linear_out = nn.Linear(n_feat, n_feat)
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+     def forward_qkv(self,
+                     query: torch.Tensor,
+                     key: torch.Tensor,
+                     value: torch.Tensor
+                     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         transform query, key and value.
+         :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
+         :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
+         :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
+         :return:
+         """
+         n_batch = query.size(0)
+         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+         k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+         v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+         q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+         k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+         v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+         return q, k, v
+
+     def forward_attention(self,
+                           value: torch.Tensor,
+                           scores: torch.Tensor,
+                           mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+                           ) -> torch.Tensor:
+         """
+         compute attention context vector.
+         :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
+         :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
+         :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
+             (batch_size, time1, time2), (0, 0, 0) means fake mask.
+         :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
+             weighted by the attention score (batch_size, time1, time2).
+         """
+         n_batch = value.size(0)
+         # NOTE: When will `if mask.size(2) > 0` be True?
+         # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+         #    1st chunk to ease the onnx export.]
+         # 2. pytorch training
+         if mask.size(2) > 0:  # time2 > 0
+             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+             # For last chunk, time2 might be larger than scores.size(-1)
+             mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
+             scores = scores.masked_fill(mask, -float('inf'))
+             attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
+
+         # NOTE: When will `if mask.size(2) > 0` be False?
+         # 1. onnx(16/-1, -1/-1, 16/0)
+         # 2. jit (16/-1, -1/-1, 16/0, 16/4)
+         else:
+             attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+         p_attn = self.dropout(attn)
+         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+         x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
+
+         return self.linear_out(x)  # (batch, time1, n_feat)
+
+     def forward(self,
+                 query: torch.Tensor,
+                 key: torch.Tensor,
+                 value: torch.Tensor,
+                 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+                 cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+                 ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+         q, k, v = self.forward_qkv(query, key, value)
+
+         if cache.size(0) > 0:
+             key_cache, value_cache = torch.split(
+                 cache, cache.size(-1) // 2, dim=-1)
+             k = torch.cat([key_cache, k], dim=2)
+             v = torch.cat([value_cache, v], dim=2)
+         # NOTE: We do cache slicing in encoder.forward_chunk, since it's
+         # non-trivial to calculate `next_cache_start` here.
+         new_cache = torch.cat((k, v), dim=-1)
+
+         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+         return self.forward_attention(v, scores, mask), new_cache
+
+
+ class RelativeMultiHeadedAttention(nn.Module):
+
+     def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
+         """
+         :param n_head: int. the number of heads.
+         :param n_feat: int. the number of features.
+         :param dropout_rate: float. dropout rate.
+         :param max_relative_position: int. maximum relative position for relative position encoding.
+         """
+         super().__init__()
+         assert n_feat % n_head == 0
+         # We assume d_v always equals d_k
+         self.d_k = n_feat // n_head
+         self.h = n_head
+         self.linear_q = nn.Linear(n_feat, n_feat)
+         self.linear_k = nn.Linear(n_feat, n_feat)
+         self.linear_v = nn.Linear(n_feat, n_feat)
+         self.linear_out = nn.Linear(n_feat, n_feat)
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+         # Relative position encoding
+         self.max_relative_position = max_relative_position
+         self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k))
+
+     def forward_qkv(self,
+                     query: torch.Tensor,
+                     key: torch.Tensor,
+                     value: torch.Tensor
+                     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         transform query, key and value.
+         :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
+         :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
+         :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
+         :return:
+         """
+         n_batch = query.size(0)
+         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+         k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+         v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+         q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+         k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+         v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+         return q, k, v
+
+     def forward_attention(self,
+                           value: torch.Tensor,
+                           scores: torch.Tensor,
+                           mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+                           ) -> torch.Tensor:
+         """
+         compute attention context vector.
+         :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
+         :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
+         :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
+             (batch_size, time1, time2), (0, 0, 0) means fake mask.
+         :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
+             weighted by the attention score (batch_size, time1, time2).
+         """
+         n_batch = value.size(0)
+         if mask.size(2) > 0:  # time2 > 0
+             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+             # For last chunk, time2 might be larger than scores.size(-1)
+             mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
+             scores = scores.masked_fill(mask, -float('inf'))
+             attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
+         else:
+             attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+         p_attn = self.dropout(attn)
+         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+         x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
+
+         return self.linear_out(x)  # (batch, time1, n_feat)
+
+     def relative_position_encoding(self, length: int) -> torch.Tensor:
+         """
+         Generate relative position encoding.
+         :param length: int. length of the sequence.
+         :return: torch.Tensor. clipped, non-negative relative position indices. shape=(length, length).
+         """
+         range_vec = torch.arange(length)
+         distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
+         distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
+         final_mat = distance_mat_clipped + self.max_relative_position
+         return final_mat
+
+     def forward(self,
+                 query: torch.Tensor,
+                 key: torch.Tensor,
+                 value: torch.Tensor,
+                 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+                 cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+                 ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+         q, k, v = self.forward_qkv(query, key, value)
+
+         if cache.size(0) > 0:
+             key_cache, value_cache = torch.split(
+                 cache, cache.size(-1) // 2, dim=-1)
+             k = torch.cat([key_cache, k], dim=2)
+             v = torch.cat([value_cache, v], dim=2)
+         # NOTE: We do cache slicing in encoder.forward_chunk, since it's
+         # non-trivial to calculate `next_cache_start` here.
+
+         # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
+         new_cache = torch.cat((k, v), dim=-1)
+
+         # Compute relative position encoding
+         length = q.size(2)
+         relative_position = self.relative_position_encoding(length)
+         relative_position_k = self.relative_position_k[relative_position.view(-1)].view(length, length, -1)
+
+         relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0)  # (1, 1, length, length, d_k)
+         relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1)  # (batch, head, length, length, d_k)
+
+         native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+         relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
+         scores = native_scores + relative_position_scores
+
+         return self.forward_attention(v, scores, mask), new_cache
+
+
+ def main():
+     rel_attention = RelativeMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
+
+     # x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
+
+     x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
+     cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
+
+     xt, new_cache = rel_attention.forward(x, x, x, cache=cache)
+     print(xt.shape)
+     print(new_cache.shape)
+     return
+
+
+ if __name__ == '__main__':
+     main()
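For reference (not part of the commit): a minimal sketch of the index matrix that relative_position_encoding builds. Entry [i, j] is the relative distance j - i, clipped to [-max_relative_position, max_relative_position] and shifted to be non-negative so it can index the relative_position_k table; the toy length and clipping window below are chosen only for illustration.

import torch

length = 4
max_relative_position = 2
range_vec = torch.arange(length)
distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)  # [i, j] = j - i
final_mat = torch.clamp(distance_mat, -max_relative_position, max_relative_position) + max_relative_position
print(final_mat)
# tensor([[2, 3, 4, 4],
#         [1, 2, 3, 4],
#         [0, 1, 2, 3],
#         [0, 0, 1, 2]])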
toolbox/torchaudio/models/nx_clean_unet/transformer/embedding.py DELETED
@@ -1,95 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import torch
- import torch.nn as nn
-
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- class RelativeMultiheadAttention(nn.Module):
-     def __init__(self, d_model, num_heads, max_len, dropout=0.1):
-         super(RelativeMultiheadAttention, self).__init__()
-         self.num_heads = num_heads
-         self.d_model = d_model
-         self.head_dim = d_model // num_heads
-         self.scale = self.head_dim ** -0.5
-
-         self.query_projection = nn.Linear(d_model, d_model)
-         self.key_projection = nn.Linear(d_model, d_model)
-         self.value_projection = nn.Linear(d_model, d_model)
-         self.output_projection = nn.Linear(d_model, d_model)
-
-         self.dropout = nn.Dropout(dropout)
-
-         # Relative position encoding
-         self.relative_positions_encoding = self.generate_relative_positions_encoding(max_len, self.head_dim)
-
-     def generate_relative_positions_encoding(self, max_len, head_dim):
-         # Generate relative positions encoding matrix
-         even_index = torch.arange(max_len)[:, None] // torch.pow(10000, torch.arange(0, head_dim, 2) / head_dim)
-         odd_index = torch.arange(max_len)[:, None] // torch.pow(10000, torch.arange(1, head_dim, 2) / head_dim)
-         even_index = torch.sin(even_index)
-         odd_index = torch.cos(odd_index)
-         pos_encoding = torch.zeros(max_len, head_dim)
-         pos_encoding[:, 0::2] = even_index
-         pos_encoding[:, 1::2] = odd_index
-         return pos_encoding
-
-     def forward(self, query, key, value, mask=None):
-         batch_size = query.size(0)
-         query_len = query.size(1)
-         key_len = key.size(1)
-
-         # Project queries, keys, and values to multiple heads
-         query = self.query_projection(query).view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key = self.key_projection(key).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
-         value = self.value_projection(value).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-         # Apply relative position encoding
-         relative_keys = self.relative_positions_encoding[:query_len, :].unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)
-         relative_values = self.relative_positions_encoding[:query_len, :].unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)
-
-         # Compute attention scores
-         scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
-         scores += torch.matmul(query, relative_keys.transpose(-2, -1))
-
-         if mask is not None:
-             scores = scores.masked_fill(mask == 0, float('-inf'))
-
-         attn_weights = F.softmax(scores, dim=-1)
-         attn_weights = self.dropout(attn_weights)
-
-         # Apply attention weights to values
-         output = torch.matmul(attn_weights, value) + torch.matmul(attn_weights, relative_values)
-         output = output.transpose(1, 2).contiguous().view(batch_size, query_len, self.d_model)
-
-         # Apply output projection
-         output = self.output_projection(output)
-
-         return output
-
-
- def main():
-     # Example usage
-     batch_size = 2
-     query_len = 10
-     key_len = 10
-     d_model = 512
-     num_heads = 8
-     max_len = 100
-
-     query = torch.rand(batch_size, query_len, d_model)
-     key = torch.rand(batch_size, key_len, d_model)
-     value = torch.rand(batch_size, key_len, d_model)
-
-     attention = RelativeMultiheadAttention(d_model, num_heads, max_len)
-     output = attention(query, key, value)
-     print(output.shape)  # Output shape should be (batch_size, query_len, d_model)
-
-     return
-
-
- if __name__ == '__main__':
-     main()
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py CHANGED
@@ -5,178 +5,9 @@ from typing import Dict, Optional, Tuple, List, Union
  
  import torch
  import torch.nn as nn
- import torch.nn.functional as f
  
  from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
+ from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadedAttention, RelativeMultiHeadedAttention
-
-
- class SinusoidalPositionalEncoding(nn.Module):
-     """
-     Positional Encoding
-
-     PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
-     PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
-     """
-
-     @staticmethod
-     def demo1():
-         batch_size = 2
-         time_steps = 10
-         embedding_dim = 64
-
-         pe = SinusoidalPositionalEncoding(
-             embedding_dim=embedding_dim,
-             dropout_rate=0.1,
-         )
-
-         x = torch.randn(size=(batch_size, time_steps, embedding_dim))
-
-         x, pos_emb = pe.forward(x)
-
-         # torch.Size([2, 10, 64])
-         print(x.shape)
-         # torch.Size([1, 10, 64])
-         print(pos_emb.shape)
-         return
-
-     @staticmethod
-     def demo2():
-         batch_size = 2
-         time_steps = 10
-         embedding_dim = 64
-
-         pe = SinusoidalPositionalEncoding(
-             embedding_dim=embedding_dim,
-             dropout_rate=0.1,
-         )
-
-         x = torch.randn(size=(batch_size, time_steps, embedding_dim))
-         offset = torch.randint(low=3, high=7, size=(batch_size,))
-         x, pos_emb = pe.forward(x, offset=offset)
-
-         # tensor([3, 4])
-         print(offset)
-         # torch.Size([2, 10, 64])
-         print(x.shape)
-         # torch.Size([2, 10, 64])
-         print(pos_emb.shape)
-         return
-
-     def __init__(self,
-                  embedding_dim: int,
-                  dropout_rate: float,
-                  max_length: int = 5000,
-                  reverse: bool = False
-                  ):
-         super().__init__()
-         self.embedding_dim = embedding_dim
-         self.dropout_rate = dropout_rate
-         self.max_length = max_length
-         self.reverse = reverse
-
-         self.x_scale = math.sqrt(self.embedding_dim)
-         self.dropout = torch.nn.Dropout(p=dropout_rate)
-
-         self.pe = torch.zeros(self.max_length, self.embedding_dim)
-         position = torch.arange(0, self.max_length, dtype=torch.float32).unsqueeze(1)
-
-         div_term = torch.exp(
-             torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
-             - (math.log(10000.0) / self.embedding_dim)
-         )
-         self.pe[:, 0::2] = torch.sin(position * div_term)
-         self.pe[:, 1::2] = torch.cos(position * div_term)
-         self.pe = self.pe.unsqueeze(0)
-
-     def forward(self,
-                 x: torch.Tensor,
-                 offset: Union[int, torch.Tensor] = 0
-                 ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Add positional encoding.
-         :param x: torch.Tensor. Input. shape=(batch_size, time_steps, ...).
-         :param offset: int or torch.Tensor. position offset.
-         :return:
-             torch.Tensor. Encoded tensor. shape=(batch_size, time_steps, ...).
-             torch.Tensor. for compatibility to RelPositionalEncoding. shape=(1, time_steps, ...).
-         """
-         self.pe = self.pe.to(x.device)
-         pos_emb = self.position_encoding(offset, x.size(1), False)
-         x = x * self.x_scale + pos_emb
-         return self.dropout(x), self.dropout(pos_emb)
-
-     def position_encoding(self,
-                           offset: Union[int, torch.Tensor],
-                           size: int,
-                           apply_dropout: bool = True
-                           ) -> torch.Tensor:
-         """
-         For getting encoding in a streaming fashion.
-
-         Attention!!!!!
-         we apply dropout only once at the whole utterance level in a none
-         streaming way, but will call this function several times with
-         increasing input size in a streaming scenario, so the dropout will
-         be applied several times.
-
-         :param offset: int or torch.Tensor. start offset.
-         :param size: int. required size of position encoding.
-         :param apply_dropout:
-         :return: torch.Tensor. Corresponding encoding.
-         """
-         if isinstance(offset, int):
-             assert offset + size <= self.max_length
-             pos_emb = self.pe[:, offset:offset + size]
-         elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
-             assert offset + size <= self.max_length
-             pos_emb = self.pe[:, offset:offset + size]
-         else:  # for batched streaming decoding on GPU
-             # offset. shape=(batch_size,)
-             assert torch.max(offset) + size <= self.max_length
-
-             # shape=(batch_size, time_steps)
-             index = offset.unsqueeze(1) + torch.arange(0, size).to(offset.device)
-             flag = index > 0
-             # remove negative offset
-             index = index * flag
-             # shape=(batch_size, time_steps, embedding_dim)
-             pos_emb = f.embedding(index, self.pe[0])
-
-         if apply_dropout:
-             pos_emb = self.dropout(pos_emb)
-         return pos_emb
-
-
- class RelPositionalEncoding(SinusoidalPositionalEncoding):
-     """
-     Relative positional encoding module.
-
-     See : Appendix B in https://arxiv.org/abs/1901.02860
-
-     """
-     def __init__(self,
-                  embedding_dim: int,
-                  dropout_rate: float,
-                  max_length: int = 5000,
-                  ):
-         super().__init__(embedding_dim, dropout_rate, max_length, reverse=True)
-
-     def forward(self,
-                 x: torch.Tensor,
-                 offset: Union[int, torch.Tensor] = 0
-                 ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Compute positional encoding.
-         :param x: torch.Tensor. Input. shape=(batch_size, time_steps, ...).
-         :param offset:
-         :return:
-             torch.Tensor. Encoded tensor. shape=(batch_size, time_steps, ...).
-             torch.Tensor. Positional embedding tensor. shape=(1, time_steps, ...).
-         """
-         self.pe = self.pe.to(x.device)
-         x = x * self.x_scale
-         pos_emb = self.position_encoding(offset, x.size(1), False)
-         return self.dropout(x), self.dropout(pos_emb)
  
  
  class PositionwiseFeedForward(nn.Module):
@@ -209,151 +40,20 @@ class PositionwiseFeedForward(nn.Module):
          return self.w_2(self.dropout(self.activation(self.w_1(xs))))
  
  
- class MultiHeadedAttention(nn.Module):
-     def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
-         """
-         :param n_head: int. the number of heads.
-         :param n_feat: int. the number of features.
-         :param dropout_rate: float. dropout rate.
-         """
-         super().__init__()
-         assert n_feat % n_head == 0
-         # We assume d_v always equals d_k
-         self.d_k = n_feat // n_head
-         self.h = n_head
-         self.linear_q = nn.Linear(n_feat, n_feat)
-         self.linear_k = nn.Linear(n_feat, n_feat)
-         self.linear_v = nn.Linear(n_feat, n_feat)
-         self.linear_out = nn.Linear(n_feat, n_feat)
-         self.dropout = nn.Dropout(p=dropout_rate)
-
-     def forward_qkv(self,
-                     query: torch.Tensor,
-                     key: torch.Tensor,
-                     value: torch.Tensor
-                     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         """
-         transform query, key and value.
-         :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
-         :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
-         :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
-         :return:
-         """
-         n_batch = query.size(0)
-         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-         k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-         v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-         q = q.transpose(1, 2)  # (batch, head, time1, d_k)
-         k = k.transpose(1, 2)  # (batch, head, time2, d_k)
-         v = v.transpose(1, 2)  # (batch, head, time2, d_k)
-
-         return q, k, v
-
-     def forward_attention(self,
-                           value: torch.Tensor,
-                           scores: torch.Tensor,
-                           mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
-                           ) -> torch.Tensor:
-         """
-         compute attention context vector.
-         :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
-         :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
-         :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
-             (batch_size, time1, time2), (0, 0, 0) means fake mask.
-         :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
-             weighted by the attention score (batch_size, time1, time2).
-         """
-         n_batch = value.size(0)
-         # NOTE: When will `if mask.size(2) > 0` be True?
-         # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
-         #    1st chunk to ease the onnx export.]
-         # 2. pytorch training
-         if mask.size(2) > 0:  # time2 > 0
-             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-             # For last chunk, time2 might be larger than scores.size(-1)
-             mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
-             scores = scores.masked_fill(mask, -float('inf'))
-             attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
-
-         # NOTE: When will `if mask.size(2) > 0` be False?
-         # 1. onnx(16/-1, -1/-1, 16/0)
-         # 2. jit (16/-1, -1/-1, 16/0, 16/4)
-         else:
-             attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-         p_attn = self.dropout(attn)
-         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-         x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
-
-         return self.linear_out(x)  # (batch, time1, n_feat)
-
-     def forward(self,
-                 query: torch.Tensor,
-                 key: torch.Tensor,
-                 value: torch.Tensor,
-                 mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
-                 cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
-                 **kwargs,
-                 ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         compute scaled dot product attention.
-         :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
-         :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
-         :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
-         :param mask: torch.Tensor. mask tensor (batch_size, 1, time2) or
-             (batch_size, time1, time2).
-         :param cache: torch.Tensor. cache tensor. shape=(1, head, cache_t, d_k * 2),
-             where `cache_t == chunk_size * num_decoding_left_chunks`
-             and `head * d_k == n_feat`
-         :return:
-             torch.Tensor. output tensor. shape=(batch_size, time1, n_feat).
-             torch.Tensor. cache tensor. (1, head, cache_t + time1, d_k * 2)
-             where `cache_t == chunk_size * num_decoding_left_chunks`
-             and `head * d_k == n_feat`
-         """
-         q, k, v = self.forward_qkv(query, key, value)
-
-         # NOTE:
-         # when export onnx model, for 1st chunk, we feed
-         # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
-         # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
-         # In all modes, `if cache.size(0) > 0` will alwayse be `True`
-         # and we will always do splitting and
-         # concatnation(this will simplify onnx export). Note that
-         # it's OK to concat & split zero-shaped tensors(see code below).
-         # when export jit model, for 1st chunk, we always feed
-         # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
-         # >>> a = torch.ones((1, 2, 0, 4))
-         # >>> b = torch.ones((1, 2, 3, 4))
-         # >>> c = torch.cat((a, b), dim=2)
-         # >>> torch.equal(b, c)        # True
-         # >>> d = torch.split(a, 2, dim=-1)
-         # >>> torch.equal(d[0], d[1])  # True
-         if cache.size(0) > 0:
-             key_cache, value_cache = torch.split(
-                 cache, cache.size(-1) // 2, dim=-1)
-             k = torch.cat([key_cache, k], dim=2)
-             v = torch.cat([value_cache, v], dim=2)
-         # NOTE: We do cache slicing in encoder.forward_chunk, since it's
-         # non-trivial to calculate `next_cache_start` here.
-         new_cache = torch.cat((k, v), dim=-1)
-
-         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-         return self.forward_attention(v, scores, mask), new_cache
-
-
  class TransformerEncoderLayer(nn.Module):
      def __init__(self,
                   input_dim: int,
                   dropout_rate: float = 0.1,
                   n_heads: int = 4,
+                  max_relative_position: int = 5120
                   ):
          super().__init__()
          self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
-         self.attention = MultiHeadedAttention(
+         self.attention = RelativeMultiHeadedAttention(
              n_head=n_heads,
              n_feat=input_dim,
-             dropout_rate=dropout_rate
+             dropout_rate=dropout_rate,
+             max_relative_position=max_relative_position,
          )
  
          self.dropout1 = nn.Dropout(dropout_rate)
@@ -370,7 +70,6 @@ class TransformerEncoderLayer(nn.Module):
                  self,
                  x: torch.Tensor,
                  mask: torch.Tensor,
-                 position_embedding: torch.Tensor,
                  attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
                  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """
@@ -388,7 +87,7 @@ class TransformerEncoderLayer(nn.Module):
          xt = self.norm1(x)
  
          x_att, new_att_cache = self.attention.forward(
-             xt, xt, xt, mask=mask, cache=attention_cache, position_embedding=position_embedding
+             xt, xt, xt, mask=mask, cache=attention_cache
          )
          x = x + self.dropout1(xt)
          xt = self.norm2(x)
@@ -410,7 +109,7 @@ class TransformerEncoder(nn.Module):
                   attention_heads: int = 4,
                   num_blocks: int = 6,
                   dropout_rate: float = 0.1,
-                  max_length: int = 1024,
+                  max_relative_position: int = 1024,
                   chunk_size: int = 1,
                   num_left_chunks: int = 128,
                   ):
@@ -418,7 +117,7 @@
          self.input_size = input_size
          self.hidden_size = hidden_size
  
-         self.max_length = max_length
+         self.max_relative_position = max_relative_position
          self.chunk_size = chunk_size
          self.num_left_chunks = num_left_chunks
  
@@ -427,17 +126,12 @@
              out_features=self.hidden_size,
          )
  
-         self.positional_encoding = RelPositionalEncoding(
-             embedding_dim=hidden_size,
-             dropout_rate=dropout_rate,
-             max_length=max_length,
-         )
-
          self.encoder_layer_list = torch.nn.ModuleList([
              TransformerEncoderLayer(
                  input_dim=hidden_size,
                  n_heads=attention_heads,
                  dropout_rate=dropout_rate,
+                 max_relative_position=max_relative_position,
              ) for _ in range(num_blocks)
          ])
  
@@ -458,10 +152,6 @@
          xs = self.input_linear.forward(xs)
          # xs shape: [batch_size, time_steps, hidden_size]
  
-         xs, position_embedding = self.positional_encoding.forward(xs)
-         # xs shape: [batch_size, time_steps, hidden_size]
-         # position_embedding shape: [1, time_steps, hidden_size]
-
          chunk_masks = subsequent_chunk_mask(
              size=time_steps,
              chunk_size=self.chunk_size,
@@ -473,7 +163,7 @@
          # chunk_masks shape: [batch_size, time_steps, time_steps]
  
          for encoder_layer in self.encoder_layer_list:
-             xs, _ = encoder_layer.forward(xs, chunk_masks, position_embedding)
+             xs, _ = encoder_layer.forward(xs, chunk_masks)
  
          # xs shape: [batch_size, time_steps, hidden_size]
          xs = self.output_linear.forward(xs)
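As a quick sanity check (a sketch, not part of the commit): because the encoder layer now passes only mask and cache to the attention, the new RelativeMultiHeadedAttention can be driven frame by frame in a streaming fashion, mirroring main() in attention.py above. The cache concatenates keys and values along the last dimension, so with n_head=4 and n_feat=256 it grows as [1, 4, t, 128]; the helper name below is illustrative only.

import torch

from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import RelativeMultiHeadedAttention


def demo_streaming_cache():
    attention = RelativeMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    cache = torch.zeros(size=(0, 0, 0, 0))

    for step in range(3):
        chunk = torch.randn(size=(1, 1, 256))
        out, cache = attention.forward(chunk, chunk, chunk, cache=cache)
        # out shape: [1, 1, 256]; cache shape: [1, 4, step + 1, 128]
        print(step, out.shape, cache.shape)


if __name__ == '__main__':
    demo_streaming_cache()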