update
examples/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -12,8 +12,8 @@ down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
 down_sampling_stride: 2
 
-tsfm_hidden_size:
-tsfm_attention_heads:
+tsfm_hidden_size: 128
+tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
 tsfm_max_length: 5120
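The two values filled in here have to stay consistent with the attention module added below: the hidden size must divide evenly by the number of heads, giving a per-head dimension of 128 / 8 = 16. A minimal sanity-check sketch, assuming the file is read with PyYAML (this loading code is not part of the commit):

import yaml

# Hypothetical check of the updated values; the path is the file changed above.
with open("examples/nx_clean_unet/yaml/config.yaml") as f:
    config = yaml.safe_load(f)

assert config["tsfm_hidden_size"] % config["tsfm_attention_heads"] == 0
print(config["tsfm_hidden_size"] // config["tsfm_attention_heads"])  # per-head dim: 16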
toolbox/torchaudio/models/nx_clean_unet/transformer/embedding.py
ADDED
@@ -0,0 +1,95 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F


class RelativeMultiheadAttention(nn.Module):
    def __init__(self, d_model, num_heads, max_len, dropout=0.1):
        super(RelativeMultiheadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5

        self.query_projection = nn.Linear(d_model, d_model)
        self.key_projection = nn.Linear(d_model, d_model)
        self.value_projection = nn.Linear(d_model, d_model)
        self.output_projection = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Relative position encoding, registered as a buffer so it follows the module's device.
        self.register_buffer(
            "relative_positions_encoding",
            self.generate_relative_positions_encoding(max_len, self.head_dim),
        )

    def generate_relative_positions_encoding(self, max_len, head_dim):
        # Sinusoidal relative positions encoding matrix: sin on even dimensions, cos on odd ones.
        even_index = torch.arange(max_len)[:, None] / torch.pow(10000, torch.arange(0, head_dim, 2) / head_dim)
        odd_index = torch.arange(max_len)[:, None] / torch.pow(10000, torch.arange(1, head_dim, 2) / head_dim)
        even_index = torch.sin(even_index)
        odd_index = torch.cos(odd_index)
        pos_encoding = torch.zeros(max_len, head_dim)
        pos_encoding[:, 0::2] = even_index
        pos_encoding[:, 1::2] = odd_index
        return pos_encoding

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        query_len = query.size(1)
        key_len = key.size(1)

        # Project queries, keys, and values to multiple heads.
        query = self.query_projection(query).view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key_projection(key).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.value_projection(value).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Relative position encoding for the key positions, broadcast over batch and heads.
        relative_keys = self.relative_positions_encoding[:key_len, :].unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)
        relative_values = self.relative_positions_encoding[:key_len, :].unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)

        # Compute attention scores: content-content term plus content-position term.
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        scores += torch.matmul(query, relative_keys.transpose(-2, -1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention weights to the values and to the relative position values.
        output = torch.matmul(attn_weights, value) + torch.matmul(attn_weights, relative_values)
        output = output.transpose(1, 2).contiguous().view(batch_size, query_len, self.d_model)

        # Apply output projection.
        output = self.output_projection(output)

        return output


def main():
    # Example usage
    batch_size = 2
    query_len = 10
    key_len = 10
    d_model = 512
    num_heads = 8
    max_len = 100

    query = torch.rand(batch_size, query_len, d_model)
    key = torch.rand(batch_size, key_len, d_model)
    value = torch.rand(batch_size, key_len, d_model)

    attention = RelativeMultiheadAttention(d_model, num_heads, max_len)
    output = attention(query, key, value)
    print(output.shape)  # Output shape should be (batch_size, query_len, d_model)

    return


if __name__ == '__main__':
    main()
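How the new module is wired to the updated config is not shown in this commit; the sketch below assumes the tsfm_* keys map directly onto the constructor arguments (d_model, num_heads, max_len, dropout), which is an assumption rather than part of the diff:

import torch
import yaml

from toolbox.torchaudio.models.nx_clean_unet.transformer.embedding import RelativeMultiheadAttention

with open("examples/nx_clean_unet/yaml/config.yaml") as f:
    config = yaml.safe_load(f)

# Assumed mapping of config keys to constructor arguments.
attention = RelativeMultiheadAttention(
    d_model=config["tsfm_hidden_size"],        # 128
    num_heads=config["tsfm_attention_heads"],  # 8
    max_len=config["tsfm_max_length"],         # 5120
    dropout=config["tsfm_dropout_rate"],       # 0.1
)

x = torch.rand(1, 100, config["tsfm_hidden_size"])
print(attention(x, x, x).shape)  # torch.Size([1, 100, 128])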