Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import math
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
|
10 |
+
class MultiHeadedAttention(nn.Module):
|
11 |
+
def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
|
12 |
+
"""
|
13 |
+
:param n_head: int. the number of heads.
|
14 |
+
:param n_feat: int. the number of features.
|
15 |
+
:param dropout_rate: float. dropout rate.
|
16 |
+
"""
|
17 |
+
super().__init__()
|
18 |
+
assert n_feat % n_head == 0
|
19 |
+
# We assume d_v always equals d_k
|
20 |
+
self.d_k = n_feat // n_head
|
21 |
+
self.h = n_head
|
22 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
23 |
+
self.linear_k = nn.Linear(n_feat, n_feat)
|
24 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
25 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
26 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
27 |
+
|
28 |
+
def forward_qkv(self,
|
29 |
+
query: torch.Tensor,
|
30 |
+
key: torch.Tensor,
|
31 |
+
value: torch.Tensor
|
32 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
33 |
+
"""
|
34 |
+
transform query, key and value.
|
35 |
+
:param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
|
36 |
+
:param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
|
37 |
+
:param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
|
38 |
+
:return:
|
39 |
+
"""
|
40 |
+
n_batch = query.size(0)
|
41 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
42 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
43 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
44 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
45 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
46 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
47 |
+
|
48 |
+
return q, k, v
|
49 |
+
|
50 |
+
def forward_attention(self,
|
51 |
+
value: torch.Tensor,
|
52 |
+
scores: torch.Tensor,
|
53 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
54 |
+
) -> torch.Tensor:
|
55 |
+
"""
|
56 |
+
compute attention context vector.
|
57 |
+
:param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
|
58 |
+
:param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
|
59 |
+
:param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
|
60 |
+
(batch_size, time1, time2), (0, 0, 0) means fake mask.
|
61 |
+
:return: torch.Tensor. transformed value. (batch_size, time1, d_model).
|
62 |
+
weighted by the attention score (batch_size, time1, time2).
|
63 |
+
"""
|
64 |
+
n_batch = value.size(0)
|
65 |
+
# NOTE: When will `if mask.size(2) > 0` be True?
|
66 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
67 |
+
# 1st chunk to ease the onnx export.]
|
68 |
+
# 2. pytorch training
|
69 |
+
if mask.size(2) > 0: # time2 > 0
|
70 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
71 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
72 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
73 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
74 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
|
75 |
+
|
76 |
+
# NOTE: When will `if mask.size(2) > 0` be False?
|
77 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
78 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
79 |
+
else:
|
80 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
81 |
+
|
82 |
+
p_attn = self.dropout(attn)
|
83 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
84 |
+
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
|
85 |
+
|
86 |
+
return self.linear_out(x) # (batch, time1, n_feat)
|
87 |
+
|
88 |
+
def forward(self,
|
89 |
+
query: torch.Tensor,
|
90 |
+
key: torch.Tensor,
|
91 |
+
value: torch.Tensor,
|
92 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
93 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
94 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
95 |
+
|
96 |
+
q, k, v = self.forward_qkv(query, key, value)
|
97 |
+
|
98 |
+
if cache.size(0) > 0:
|
99 |
+
key_cache, value_cache = torch.split(
|
100 |
+
cache, cache.size(-1) // 2, dim=-1)
|
101 |
+
k = torch.cat([key_cache, k], dim=2)
|
102 |
+
v = torch.cat([value_cache, v], dim=2)
|
103 |
+
# NOTE: We do cache slicing in encoder.forward_chunk, since it's
|
104 |
+
# non-trivial to calculate `next_cache_start` here.
|
105 |
+
new_cache = torch.cat((k, v), dim=-1)
|
106 |
+
|
107 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
108 |
+
return self.forward_attention(v, scores, mask), new_cache
|
109 |
+
|
110 |
+
|
111 |
+
class RelativeMultiHeadedAttention(nn.Module):
|
112 |
+
|
113 |
+
def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
|
114 |
+
"""
|
115 |
+
:param n_head: int. the number of heads.
|
116 |
+
:param n_feat: int. the number of features.
|
117 |
+
:param dropout_rate: float. dropout rate.
|
118 |
+
:param max_relative_position: int. maximum relative position for relative position encoding.
|
119 |
+
"""
|
120 |
+
super().__init__()
|
121 |
+
assert n_feat % n_head == 0
|
122 |
+
# We assume d_v always equals d_k
|
123 |
+
self.d_k = n_feat // n_head
|
124 |
+
self.h = n_head
|
125 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
126 |
+
self.linear_k = nn.Linear(n_feat, n_feat)
|
127 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
128 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
129 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
130 |
+
|
131 |
+
# Relative position encoding
|
132 |
+
self.max_relative_position = max_relative_position
|
133 |
+
self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k))
|
134 |
+
|
135 |
+
def forward_qkv(self,
|
136 |
+
query: torch.Tensor,
|
137 |
+
key: torch.Tensor,
|
138 |
+
value: torch.Tensor
|
139 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
140 |
+
"""
|
141 |
+
transform query, key and value.
|
142 |
+
:param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
|
143 |
+
:param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
|
144 |
+
:param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
|
145 |
+
:return:
|
146 |
+
"""
|
147 |
+
n_batch = query.size(0)
|
148 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
149 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
150 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
151 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
152 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
153 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
154 |
+
|
155 |
+
return q, k, v
|
156 |
+
|
157 |
+
def forward_attention(self,
|
158 |
+
value: torch.Tensor,
|
159 |
+
scores: torch.Tensor,
|
160 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
161 |
+
) -> torch.Tensor:
|
162 |
+
"""
|
163 |
+
compute attention context vector.
|
164 |
+
:param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
|
165 |
+
:param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
|
166 |
+
:param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
|
167 |
+
(batch_size, time1, time2), (0, 0, 0) means fake mask.
|
168 |
+
:return: torch.Tensor. transformed value. (batch_size, time1, d_model).
|
169 |
+
weighted by the attention score (batch_size, time1, time2).
|
170 |
+
"""
|
171 |
+
n_batch = value.size(0)
|
172 |
+
if mask.size(2) > 0: # time2 > 0
|
173 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
174 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
175 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
176 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
177 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
|
178 |
+
else:
|
179 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
180 |
+
|
181 |
+
p_attn = self.dropout(attn)
|
182 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
183 |
+
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
|
184 |
+
|
185 |
+
return self.linear_out(x) # (batch, time1, n_feat)
|
186 |
+
|
187 |
+
def relative_position_encoding(self, length: int) -> torch.Tensor:
|
188 |
+
"""
|
189 |
+
Generate relative position encoding.
|
190 |
+
:param length: int. length of the sequence.
|
191 |
+
:return: torch.Tensor. relative position encoding. shape=(length, length, d_k).
|
192 |
+
"""
|
193 |
+
range_vec = torch.arange(length)
|
194 |
+
distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
|
195 |
+
distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
|
196 |
+
final_mat = distance_mat_clipped + self.max_relative_position
|
197 |
+
return final_mat
|
198 |
+
|
199 |
+
def forward(self,
|
200 |
+
query: torch.Tensor,
|
201 |
+
key: torch.Tensor,
|
202 |
+
value: torch.Tensor,
|
203 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
204 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
205 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
206 |
+
|
207 |
+
q, k, v = self.forward_qkv(query, key, value)
|
208 |
+
|
209 |
+
if cache.size(0) > 0:
|
210 |
+
key_cache, value_cache = torch.split(
|
211 |
+
cache, cache.size(-1) // 2, dim=-1)
|
212 |
+
k = torch.cat([key_cache, k], dim=2)
|
213 |
+
v = torch.cat([value_cache, v], dim=2)
|
214 |
+
# NOTE: We do cache slicing in encoder.forward_chunk, since it's
|
215 |
+
# non-trivial to calculate `next_cache_start` here.
|
216 |
+
|
217 |
+
# new_cache shape: [batch_size, self.h, time_steps, self.d_v * 2]
|
218 |
+
new_cache = torch.cat((k, v), dim=-1)
|
219 |
+
|
220 |
+
# Compute relative position encoding
|
221 |
+
length = q.size(2)
|
222 |
+
relative_position = self.relative_position_encoding(length)
|
223 |
+
relative_position_k = self.relative_position_k[relative_position.view(-1)].view(length, length, -1)
|
224 |
+
|
225 |
+
relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0) # (1, 1, length, length, d_k)
|
226 |
+
relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1) # (batch, head, length, length, d_k)
|
227 |
+
|
228 |
+
native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
229 |
+
relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
|
230 |
+
scores = native_scores + relative_position_scores
|
231 |
+
|
232 |
+
return self.forward_attention(v, scores, mask), new_cache
|
233 |
+
|
234 |
+
|
235 |
+
def main():
|
236 |
+
rel_attention = RelativeMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
|
237 |
+
|
238 |
+
# x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
|
239 |
+
|
240 |
+
x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
|
241 |
+
cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
|
242 |
+
|
243 |
+
xt, new_cache = rel_attention.forward(x, x, x, cache=cache)
|
244 |
+
print(xt.shape)
|
245 |
+
print(new_cache.shape)
|
246 |
+
return
|
247 |
+
|
248 |
+
|
249 |
+
if __name__ == '__main__':
|
250 |
+
main()
|
toolbox/torchaudio/models/nx_clean_unet/transformer/embedding.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
#!/usr/bin/python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import torch.nn as nn
|
9 |
-
import torch.nn.functional as F
|
10 |
-
|
11 |
-
class RelativeMultiheadAttention(nn.Module):
|
12 |
-
def __init__(self, d_model, num_heads, max_len, dropout=0.1):
|
13 |
-
super(RelativeMultiheadAttention, self).__init__()
|
14 |
-
self.num_heads = num_heads
|
15 |
-
self.d_model = d_model
|
16 |
-
self.head_dim = d_model // num_heads
|
17 |
-
self.scale = self.head_dim ** -0.5
|
18 |
-
|
19 |
-
self.query_projection = nn.Linear(d_model, d_model)
|
20 |
-
self.key_projection = nn.Linear(d_model, d_model)
|
21 |
-
self.value_projection = nn.Linear(d_model, d_model)
|
22 |
-
self.output_projection = nn.Linear(d_model, d_model)
|
23 |
-
|
24 |
-
self.dropout = nn.Dropout(dropout)
|
25 |
-
|
26 |
-
# Relative position encoding
|
27 |
-
self.relative_positions_encoding = self.generate_relative_positions_encoding(max_len, self.head_dim)
|
28 |
-
|
29 |
-
def generate_relative_positions_encoding(self, max_len, head_dim):
|
30 |
-
# Generate relative positions encoding matrix
|
31 |
-
even_index = torch.arange(max_len)[:, None] // torch.pow(10000, torch.arange(0, head_dim, 2) / head_dim)
|
32 |
-
odd_index = torch.arange(max_len)[:, None] // torch.pow(10000, torch.arange(1, head_dim, 2) / head_dim)
|
33 |
-
even_index = torch.sin(even_index)
|
34 |
-
odd_index = torch.cos(odd_index)
|
35 |
-
pos_encoding = torch.zeros(max_len, head_dim)
|
36 |
-
pos_encoding[:, 0::2] = even_index
|
37 |
-
pos_encoding[:, 1::2] = odd_index
|
38 |
-
return pos_encoding
|
39 |
-
|
40 |
-
def forward(self, query, key, value, mask=None):
|
41 |
-
batch_size = query.size(0)
|
42 |
-
query_len = query.size(1)
|
43 |
-
key_len = key.size(1)
|
44 |
-
|
45 |
-
# Project queries, keys, and values to multiple heads
|
46 |
-
query = self.query_projection(query).view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
|
47 |
-
key = self.key_projection(key).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
|
48 |
-
value = self.value_projection(value).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
|
49 |
-
|
50 |
-
# Apply relative position encoding
|
51 |
-
relative_keys = self.relative_positions_encoding[:query_len, :].unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)
|
52 |
-
relative_values = self.relative_positions_encoding[:query_len, :].unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)
|
53 |
-
|
54 |
-
# Compute attention scores
|
55 |
-
scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
|
56 |
-
scores += torch.matmul(query, relative_keys.transpose(-2, -1))
|
57 |
-
|
58 |
-
if mask is not None:
|
59 |
-
scores = scores.masked_fill(mask == 0, float('-inf'))
|
60 |
-
|
61 |
-
attn_weights = F.softmax(scores, dim=-1)
|
62 |
-
attn_weights = self.dropout(attn_weights)
|
63 |
-
|
64 |
-
# Apply attention weights to values
|
65 |
-
output = torch.matmul(attn_weights, value) + torch.matmul(attn_weights, relative_values)
|
66 |
-
output = output.transpose(1, 2).contiguous().view(batch_size, query_len, self.d_model)
|
67 |
-
|
68 |
-
# Apply output projection
|
69 |
-
output = self.output_projection(output)
|
70 |
-
|
71 |
-
return output
|
72 |
-
|
73 |
-
|
74 |
-
def main():
|
75 |
-
# Example usage
|
76 |
-
batch_size = 2
|
77 |
-
query_len = 10
|
78 |
-
key_len = 10
|
79 |
-
d_model = 512
|
80 |
-
num_heads = 8
|
81 |
-
max_len = 100
|
82 |
-
|
83 |
-
query = torch.rand(batch_size, query_len, d_model)
|
84 |
-
key = torch.rand(batch_size, key_len, d_model)
|
85 |
-
value = torch.rand(batch_size, key_len, d_model)
|
86 |
-
|
87 |
-
attention = RelativeMultiheadAttention(d_model, num_heads, max_len)
|
88 |
-
output = attention(query, key, value)
|
89 |
-
print(output.shape) # Output shape should be (batch_size, query_len, d_model)
|
90 |
-
|
91 |
-
return
|
92 |
-
|
93 |
-
|
94 |
-
if __name__ == '__main__':
|
95 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py
CHANGED
@@ -5,178 +5,9 @@ from typing import Dict, Optional, Tuple, List, Union
|
|
5 |
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
-
import torch.nn.functional as f
|
9 |
|
10 |
from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
|
11 |
-
|
12 |
-
|
13 |
-
class SinusoidalPositionalEncoding(nn.Module):
|
14 |
-
"""
|
15 |
-
Positional Encoding
|
16 |
-
|
17 |
-
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
18 |
-
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
19 |
-
"""
|
20 |
-
|
21 |
-
@staticmethod
|
22 |
-
def demo1():
|
23 |
-
batch_size = 2
|
24 |
-
time_steps = 10
|
25 |
-
embedding_dim = 64
|
26 |
-
|
27 |
-
pe = SinusoidalPositionalEncoding(
|
28 |
-
embedding_dim=embedding_dim,
|
29 |
-
dropout_rate=0.1,
|
30 |
-
)
|
31 |
-
|
32 |
-
x = torch.randn(size=(batch_size, time_steps, embedding_dim))
|
33 |
-
|
34 |
-
x, pos_emb = pe.forward(x)
|
35 |
-
|
36 |
-
# torch.Size([2, 10, 64])
|
37 |
-
print(x.shape)
|
38 |
-
# torch.Size([1, 10, 64])
|
39 |
-
print(pos_emb.shape)
|
40 |
-
return
|
41 |
-
|
42 |
-
@staticmethod
|
43 |
-
def demo2():
|
44 |
-
batch_size = 2
|
45 |
-
time_steps = 10
|
46 |
-
embedding_dim = 64
|
47 |
-
|
48 |
-
pe = SinusoidalPositionalEncoding(
|
49 |
-
embedding_dim=embedding_dim,
|
50 |
-
dropout_rate=0.1,
|
51 |
-
)
|
52 |
-
|
53 |
-
x = torch.randn(size=(batch_size, time_steps, embedding_dim))
|
54 |
-
offset = torch.randint(low=3, high=7, size=(batch_size,))
|
55 |
-
x, pos_emb = pe.forward(x, offset=offset)
|
56 |
-
|
57 |
-
# tensor([3, 4])
|
58 |
-
print(offset)
|
59 |
-
# torch.Size([2, 10, 64])
|
60 |
-
print(x.shape)
|
61 |
-
# torch.Size([2, 10, 64])
|
62 |
-
print(pos_emb.shape)
|
63 |
-
return
|
64 |
-
|
65 |
-
def __init__(self,
|
66 |
-
embedding_dim: int,
|
67 |
-
dropout_rate: float,
|
68 |
-
max_length: int = 5000,
|
69 |
-
reverse: bool = False
|
70 |
-
):
|
71 |
-
super().__init__()
|
72 |
-
self.embedding_dim = embedding_dim
|
73 |
-
self.dropout_rate = dropout_rate
|
74 |
-
self.max_length = max_length
|
75 |
-
self.reverse = reverse
|
76 |
-
|
77 |
-
self.x_scale = math.sqrt(self.embedding_dim)
|
78 |
-
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
79 |
-
|
80 |
-
self.pe = torch.zeros(self.max_length, self.embedding_dim)
|
81 |
-
position = torch.arange(0, self.max_length, dtype=torch.float32).unsqueeze(1)
|
82 |
-
|
83 |
-
div_term = torch.exp(
|
84 |
-
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
|
85 |
-
- (math.log(10000.0) / self.embedding_dim)
|
86 |
-
)
|
87 |
-
self.pe[:, 0::2] = torch.sin(position * div_term)
|
88 |
-
self.pe[:, 1::2] = torch.cos(position * div_term)
|
89 |
-
self.pe = self.pe.unsqueeze(0)
|
90 |
-
|
91 |
-
def forward(self,
|
92 |
-
x: torch.Tensor,
|
93 |
-
offset: Union[int, torch.Tensor] = 0
|
94 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
95 |
-
"""
|
96 |
-
Add positional encoding.
|
97 |
-
:param x: torch.Tensor. Input. shape=(batch_size, time_steps, ...).
|
98 |
-
:param offset: int or torch.Tensor. position offset.
|
99 |
-
:return:
|
100 |
-
torch.Tensor. Encoded tensor. shape=(batch_size, time_steps, ...).
|
101 |
-
torch.Tensor. for compatibility to RelPositionalEncoding. shape=(1, time_steps, ...).
|
102 |
-
"""
|
103 |
-
self.pe = self.pe.to(x.device)
|
104 |
-
pos_emb = self.position_encoding(offset, x.size(1), False)
|
105 |
-
x = x * self.x_scale + pos_emb
|
106 |
-
return self.dropout(x), self.dropout(pos_emb)
|
107 |
-
|
108 |
-
def position_encoding(self,
|
109 |
-
offset: Union[int, torch.Tensor],
|
110 |
-
size: int,
|
111 |
-
apply_dropout: bool = True
|
112 |
-
) -> torch.Tensor:
|
113 |
-
"""
|
114 |
-
For getting encoding in a streaming fashion.
|
115 |
-
|
116 |
-
Attention!!!!!
|
117 |
-
we apply dropout only once at the whole utterance level in a none
|
118 |
-
streaming way, but will call this function several times with
|
119 |
-
increasing input size in a streaming scenario, so the dropout will
|
120 |
-
be applied several times.
|
121 |
-
|
122 |
-
:param offset: int or torch.Tensor. start offset.
|
123 |
-
:param size: int. required size of position encoding.
|
124 |
-
:param apply_dropout:
|
125 |
-
:return: torch.Tensor. Corresponding encoding.
|
126 |
-
"""
|
127 |
-
if isinstance(offset, int):
|
128 |
-
assert offset + size <= self.max_length
|
129 |
-
pos_emb = self.pe[:, offset:offset + size]
|
130 |
-
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
|
131 |
-
assert offset + size <= self.max_length
|
132 |
-
pos_emb = self.pe[:, offset:offset + size]
|
133 |
-
else: # for batched streaming decoding on GPU
|
134 |
-
# offset. shape=(batch_size,)
|
135 |
-
assert torch.max(offset) + size <= self.max_length
|
136 |
-
|
137 |
-
# shape=(batch_size, time_steps)
|
138 |
-
index = offset.unsqueeze(1) + torch.arange(0, size).to(offset.device)
|
139 |
-
flag = index > 0
|
140 |
-
# remove negative offset
|
141 |
-
index = index * flag
|
142 |
-
# shape=(batch_size, time_steps, embedding_dim)
|
143 |
-
pos_emb = f.embedding(index, self.pe[0])
|
144 |
-
|
145 |
-
if apply_dropout:
|
146 |
-
pos_emb = self.dropout(pos_emb)
|
147 |
-
return pos_emb
|
148 |
-
|
149 |
-
|
150 |
-
class RelPositionalEncoding(SinusoidalPositionalEncoding):
|
151 |
-
"""
|
152 |
-
Relative positional encoding module.
|
153 |
-
|
154 |
-
See : Appendix B in https://arxiv.org/abs/1901.02860
|
155 |
-
|
156 |
-
"""
|
157 |
-
def __init__(self,
|
158 |
-
embedding_dim: int,
|
159 |
-
dropout_rate: float,
|
160 |
-
max_length: int = 5000,
|
161 |
-
):
|
162 |
-
super().__init__(embedding_dim, dropout_rate, max_length, reverse=True)
|
163 |
-
|
164 |
-
def forward(self,
|
165 |
-
x: torch.Tensor,
|
166 |
-
offset: Union[int, torch.Tensor] = 0
|
167 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
168 |
-
"""
|
169 |
-
Compute positional encoding.
|
170 |
-
:param x: torch.Tensor. Input. shape=(batch_size, time_steps, ...).
|
171 |
-
:param offset:
|
172 |
-
:return:
|
173 |
-
torch.Tensor. Encoded tensor. shape=(batch_size, time_steps, ...).
|
174 |
-
torch.Tensor. Positional embedding tensor. shape=(1, time_steps, ...).
|
175 |
-
"""
|
176 |
-
self.pe = self.pe.to(x.device)
|
177 |
-
x = x * self.x_scale
|
178 |
-
pos_emb = self.position_encoding(offset, x.size(1), False)
|
179 |
-
return self.dropout(x), self.dropout(pos_emb)
|
180 |
|
181 |
|
182 |
class PositionwiseFeedForward(nn.Module):
|
@@ -209,151 +40,20 @@ class PositionwiseFeedForward(nn.Module):
|
|
209 |
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
210 |
|
211 |
|
212 |
-
class MultiHeadedAttention(nn.Module):
|
213 |
-
def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
|
214 |
-
"""
|
215 |
-
:param n_head: int. the number of heads.
|
216 |
-
:param n_feat: int. the number of features.
|
217 |
-
:param dropout_rate: float. dropout rate.
|
218 |
-
"""
|
219 |
-
super().__init__()
|
220 |
-
assert n_feat % n_head == 0
|
221 |
-
# We assume d_v always equals d_k
|
222 |
-
self.d_k = n_feat // n_head
|
223 |
-
self.h = n_head
|
224 |
-
self.linear_q = nn.Linear(n_feat, n_feat)
|
225 |
-
self.linear_k = nn.Linear(n_feat, n_feat)
|
226 |
-
self.linear_v = nn.Linear(n_feat, n_feat)
|
227 |
-
self.linear_out = nn.Linear(n_feat, n_feat)
|
228 |
-
self.dropout = nn.Dropout(p=dropout_rate)
|
229 |
-
|
230 |
-
def forward_qkv(self,
|
231 |
-
query: torch.Tensor,
|
232 |
-
key: torch.Tensor,
|
233 |
-
value: torch.Tensor
|
234 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
235 |
-
"""
|
236 |
-
transform query, key and value.
|
237 |
-
:param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
|
238 |
-
:param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
|
239 |
-
:param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
|
240 |
-
:return:
|
241 |
-
"""
|
242 |
-
n_batch = query.size(0)
|
243 |
-
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
244 |
-
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
245 |
-
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
246 |
-
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
247 |
-
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
248 |
-
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
249 |
-
|
250 |
-
return q, k, v
|
251 |
-
|
252 |
-
def forward_attention(self,
|
253 |
-
value: torch.Tensor,
|
254 |
-
scores: torch.Tensor,
|
255 |
-
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
256 |
-
) -> torch.Tensor:
|
257 |
-
"""
|
258 |
-
compute attention context vector.
|
259 |
-
:param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
|
260 |
-
:param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
|
261 |
-
:param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
|
262 |
-
(batch_size, time1, time2), (0, 0, 0) means fake mask.
|
263 |
-
:return: torch.Tensor. transformed value. (batch_size, time1, d_model).
|
264 |
-
weighted by the attention score (batch_size, time1, time2).
|
265 |
-
"""
|
266 |
-
n_batch = value.size(0)
|
267 |
-
# NOTE: When will `if mask.size(2) > 0` be True?
|
268 |
-
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
269 |
-
# 1st chunk to ease the onnx export.]
|
270 |
-
# 2. pytorch training
|
271 |
-
if mask.size(2) > 0: # time2 > 0
|
272 |
-
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
273 |
-
# For last chunk, time2 might be larger than scores.size(-1)
|
274 |
-
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
275 |
-
scores = scores.masked_fill(mask, -float('inf'))
|
276 |
-
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
|
277 |
-
|
278 |
-
# NOTE: When will `if mask.size(2) > 0` be False?
|
279 |
-
# 1. onnx(16/-1, -1/-1, 16/0)
|
280 |
-
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
281 |
-
else:
|
282 |
-
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
283 |
-
|
284 |
-
p_attn = self.dropout(attn)
|
285 |
-
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
286 |
-
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
|
287 |
-
|
288 |
-
return self.linear_out(x) # (batch, time1, n_feat)
|
289 |
-
|
290 |
-
def forward(self,
|
291 |
-
query: torch.Tensor,
|
292 |
-
key: torch.Tensor,
|
293 |
-
value: torch.Tensor,
|
294 |
-
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
295 |
-
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
296 |
-
**kwargs,
|
297 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
298 |
-
"""
|
299 |
-
compute scaled dot product attention.
|
300 |
-
:param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
|
301 |
-
:param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
|
302 |
-
:param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
|
303 |
-
:param mask: torch.Tensor. mask tensor (batch_size, 1, time2) or
|
304 |
-
(batch_size, time1, time2).
|
305 |
-
:param cache: torch.Tensor. cache tensor. shape=(1, head, cache_t, d_k * 2),
|
306 |
-
where `cache_t == chunk_size * num_decoding_left_chunks`
|
307 |
-
and `head * d_k == n_feat`
|
308 |
-
:return:
|
309 |
-
torch.Tensor. output tensor. shape=(batch_size, time1, n_feat).
|
310 |
-
torch.Tensor. cache tensor. (1, head, cache_t + time1, d_k * 2)
|
311 |
-
where `cache_t == chunk_size * num_decoding_left_chunks`
|
312 |
-
and `head * d_k == n_feat`
|
313 |
-
"""
|
314 |
-
q, k, v = self.forward_qkv(query, key, value)
|
315 |
-
|
316 |
-
# NOTE:
|
317 |
-
# when export onnx model, for 1st chunk, we feed
|
318 |
-
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
319 |
-
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
320 |
-
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
321 |
-
# and we will always do splitting and
|
322 |
-
# concatnation(this will simplify onnx export). Note that
|
323 |
-
# it's OK to concat & split zero-shaped tensors(see code below).
|
324 |
-
# when export jit model, for 1st chunk, we always feed
|
325 |
-
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
326 |
-
# >>> a = torch.ones((1, 2, 0, 4))
|
327 |
-
# >>> b = torch.ones((1, 2, 3, 4))
|
328 |
-
# >>> c = torch.cat((a, b), dim=2)
|
329 |
-
# >>> torch.equal(b, c) # True
|
330 |
-
# >>> d = torch.split(a, 2, dim=-1)
|
331 |
-
# >>> torch.equal(d[0], d[1]) # True
|
332 |
-
if cache.size(0) > 0:
|
333 |
-
key_cache, value_cache = torch.split(
|
334 |
-
cache, cache.size(-1) // 2, dim=-1)
|
335 |
-
k = torch.cat([key_cache, k], dim=2)
|
336 |
-
v = torch.cat([value_cache, v], dim=2)
|
337 |
-
# NOTE: We do cache slicing in encoder.forward_chunk, since it's
|
338 |
-
# non-trivial to calculate `next_cache_start` here.
|
339 |
-
new_cache = torch.cat((k, v), dim=-1)
|
340 |
-
|
341 |
-
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
342 |
-
return self.forward_attention(v, scores, mask), new_cache
|
343 |
-
|
344 |
-
|
345 |
class TransformerEncoderLayer(nn.Module):
|
346 |
def __init__(self,
|
347 |
input_dim: int,
|
348 |
dropout_rate: float = 0.1,
|
349 |
n_heads: int = 4,
|
|
|
350 |
):
|
351 |
super().__init__()
|
352 |
self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
|
353 |
-
self.attention =
|
354 |
n_head=n_heads,
|
355 |
n_feat=input_dim,
|
356 |
-
dropout_rate=dropout_rate
|
|
|
357 |
)
|
358 |
|
359 |
self.dropout1 = nn.Dropout(dropout_rate)
|
@@ -370,7 +70,6 @@ class TransformerEncoderLayer(nn.Module):
|
|
370 |
self,
|
371 |
x: torch.Tensor,
|
372 |
mask: torch.Tensor,
|
373 |
-
position_embedding: torch.Tensor,
|
374 |
attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
375 |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
376 |
"""
|
@@ -388,7 +87,7 @@ class TransformerEncoderLayer(nn.Module):
|
|
388 |
xt = self.norm1(x)
|
389 |
|
390 |
x_att, new_att_cache = self.attention.forward(
|
391 |
-
xt, xt, xt, mask=mask, cache=attention_cache
|
392 |
)
|
393 |
x = x + self.dropout1(xt)
|
394 |
xt = self.norm2(x)
|
@@ -410,7 +109,7 @@ class TransformerEncoder(nn.Module):
|
|
410 |
attention_heads: int = 4,
|
411 |
num_blocks: int = 6,
|
412 |
dropout_rate: float = 0.1,
|
413 |
-
|
414 |
chunk_size: int = 1,
|
415 |
num_left_chunks: int = 128,
|
416 |
):
|
@@ -418,7 +117,7 @@ class TransformerEncoder(nn.Module):
|
|
418 |
self.input_size = input_size
|
419 |
self.hidden_size = hidden_size
|
420 |
|
421 |
-
self.
|
422 |
self.chunk_size = chunk_size
|
423 |
self.num_left_chunks = num_left_chunks
|
424 |
|
@@ -427,17 +126,12 @@ class TransformerEncoder(nn.Module):
|
|
427 |
out_features=self.hidden_size,
|
428 |
)
|
429 |
|
430 |
-
self.positional_encoding = RelPositionalEncoding(
|
431 |
-
embedding_dim=hidden_size,
|
432 |
-
dropout_rate=dropout_rate,
|
433 |
-
max_length=max_length,
|
434 |
-
)
|
435 |
-
|
436 |
self.encoder_layer_list = torch.nn.ModuleList([
|
437 |
TransformerEncoderLayer(
|
438 |
input_dim=hidden_size,
|
439 |
n_heads=attention_heads,
|
440 |
dropout_rate=dropout_rate,
|
|
|
441 |
) for _ in range(num_blocks)
|
442 |
])
|
443 |
|
@@ -458,10 +152,6 @@ class TransformerEncoder(nn.Module):
|
|
458 |
xs = self.input_linear.forward(xs)
|
459 |
# xs shape: [batch_size, time_steps, hidden_size]
|
460 |
|
461 |
-
xs, position_embedding = self.positional_encoding.forward(xs)
|
462 |
-
# xs shape: [batch_size, time_steps, hidden_size]
|
463 |
-
# position_embedding shape: [1, time_steps, hidden_size]
|
464 |
-
|
465 |
chunk_masks = subsequent_chunk_mask(
|
466 |
size=time_steps,
|
467 |
chunk_size=self.chunk_size,
|
@@ -473,7 +163,7 @@ class TransformerEncoder(nn.Module):
|
|
473 |
# chunk_masks shape: [batch_size, time_steps, time_steps]
|
474 |
|
475 |
for encoder_layer in self.encoder_layer_list:
|
476 |
-
xs, _ = encoder_layer.forward(xs, chunk_masks
|
477 |
|
478 |
# xs shape: [batch_size, time_steps, hidden_size]
|
479 |
xs = self.output_linear.forward(xs)
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.nn as nn
|
|
|
8 |
|
9 |
from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
|
10 |
+
from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadedAttention, RelativeMultiHeadedAttention
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
class PositionwiseFeedForward(nn.Module):
|
|
|
40 |
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
class TransformerEncoderLayer(nn.Module):
|
44 |
def __init__(self,
|
45 |
input_dim: int,
|
46 |
dropout_rate: float = 0.1,
|
47 |
n_heads: int = 4,
|
48 |
+
max_relative_position: int = 5120
|
49 |
):
|
50 |
super().__init__()
|
51 |
self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
|
52 |
+
self.attention = RelativeMultiHeadedAttention(
|
53 |
n_head=n_heads,
|
54 |
n_feat=input_dim,
|
55 |
+
dropout_rate=dropout_rate,
|
56 |
+
max_relative_position=max_relative_position,
|
57 |
)
|
58 |
|
59 |
self.dropout1 = nn.Dropout(dropout_rate)
|
|
|
70 |
self,
|
71 |
x: torch.Tensor,
|
72 |
mask: torch.Tensor,
|
|
|
73 |
attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
74 |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
75 |
"""
|
|
|
87 |
xt = self.norm1(x)
|
88 |
|
89 |
x_att, new_att_cache = self.attention.forward(
|
90 |
+
xt, xt, xt, mask=mask, cache=attention_cache
|
91 |
)
|
92 |
x = x + self.dropout1(xt)
|
93 |
xt = self.norm2(x)
|
|
|
109 |
attention_heads: int = 4,
|
110 |
num_blocks: int = 6,
|
111 |
dropout_rate: float = 0.1,
|
112 |
+
max_relative_position: int = 1024,
|
113 |
chunk_size: int = 1,
|
114 |
num_left_chunks: int = 128,
|
115 |
):
|
|
|
117 |
self.input_size = input_size
|
118 |
self.hidden_size = hidden_size
|
119 |
|
120 |
+
self.max_relative_position = max_relative_position
|
121 |
self.chunk_size = chunk_size
|
122 |
self.num_left_chunks = num_left_chunks
|
123 |
|
|
|
126 |
out_features=self.hidden_size,
|
127 |
)
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
self.encoder_layer_list = torch.nn.ModuleList([
|
130 |
TransformerEncoderLayer(
|
131 |
input_dim=hidden_size,
|
132 |
n_heads=attention_heads,
|
133 |
dropout_rate=dropout_rate,
|
134 |
+
max_relative_position=max_relative_position,
|
135 |
) for _ in range(num_blocks)
|
136 |
])
|
137 |
|
|
|
152 |
xs = self.input_linear.forward(xs)
|
153 |
# xs shape: [batch_size, time_steps, hidden_size]
|
154 |
|
|
|
|
|
|
|
|
|
155 |
chunk_masks = subsequent_chunk_mask(
|
156 |
size=time_steps,
|
157 |
chunk_size=self.chunk_size,
|
|
|
163 |
# chunk_masks shape: [batch_size, time_steps, time_steps]
|
164 |
|
165 |
for encoder_layer in self.encoder_layer_list:
|
166 |
+
xs, _ = encoder_layer.forward(xs, chunk_masks)
|
167 |
|
168 |
# xs shape: [batch_size, time_steps, hidden_size]
|
169 |
xs = self.output_linear.forward(xs)
|