HoneyTian committed
Commit 6b7a897 · 1 Parent(s): 365fc03
examples/nx_clean_unet/yaml/config.yaml CHANGED
@@ -12,8 +12,8 @@ down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
 down_sampling_stride: 2
 
-tsfm_hidden_size: 64
-tsfm_attention_heads: 4
+tsfm_hidden_size: 128
+tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
 tsfm_max_length: 5120
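The tsfm_* keys configure the transformer stack; this commit doubles the hidden size (64 to 128) and the attention head count (4 to 8), which keeps the per-head dimension at 16. A minimal sketch of how these values might be wired into the RelativeMultiheadAttention module added below; the yaml key-to-argument mapping and the import path are assumptions, not shown in the commit:

# Sketch only, not part of the commit: assumes the tsfm_* keys map onto the
# RelativeMultiheadAttention constructor arguments added in this commit.
import torch
import yaml

from toolbox.torchaudio.models.nx_clean_unet.transformer.embedding import RelativeMultiheadAttention

with open("examples/nx_clean_unet/yaml/config.yaml") as f:
    config = yaml.safe_load(f)

attention = RelativeMultiheadAttention(
    d_model=config["tsfm_hidden_size"],        # 128 after this commit
    num_heads=config["tsfm_attention_heads"],  # 8 after this commit
    max_len=config["tsfm_max_length"],         # 5120
    dropout=config["tsfm_dropout_rate"],       # 0.1
)

x = torch.rand(2, 100, config["tsfm_hidden_size"])
print(attention(x, x, x).shape)  # torch.Size([2, 100, 128])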
toolbox/torchaudio/models/nx_clean_unet/transformer/embedding.py ADDED
@@ -0,0 +1,95 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RelativeMultiheadAttention(nn.Module):
+    def __init__(self, d_model, num_heads, max_len, dropout=0.1):
+        super(RelativeMultiheadAttention, self).__init__()
+        self.num_heads = num_heads
+        self.d_model = d_model
+        self.head_dim = d_model // num_heads
+        self.scale = self.head_dim ** -0.5
+
+        self.query_projection = nn.Linear(d_model, d_model)
+        self.key_projection = nn.Linear(d_model, d_model)
+        self.value_projection = nn.Linear(d_model, d_model)
+        self.output_projection = nn.Linear(d_model, d_model)
+
+        self.dropout = nn.Dropout(dropout)
+
+        # Relative position encoding, registered as a buffer so it moves with the module across devices.
+        self.register_buffer(
+            "relative_positions_encoding",
+            self.generate_relative_positions_encoding(max_len, self.head_dim),
+        )
+
+    def generate_relative_positions_encoding(self, max_len, head_dim):
+        # Sinusoidal positions encoding matrix: sine on even dimensions, cosine on odd ones.
+        even_index = torch.arange(max_len)[:, None] / torch.pow(10000, torch.arange(0, head_dim, 2) / head_dim)
+        odd_index = torch.arange(max_len)[:, None] / torch.pow(10000, torch.arange(1, head_dim, 2) / head_dim)
+        even_index = torch.sin(even_index)
+        odd_index = torch.cos(odd_index)
+        pos_encoding = torch.zeros(max_len, head_dim)
+        pos_encoding[:, 0::2] = even_index
+        pos_encoding[:, 1::2] = odd_index
+        return pos_encoding
+
+    def forward(self, query, key, value, mask=None):
+        batch_size = query.size(0)
+        query_len = query.size(1)
+        key_len = key.size(1)
+
+        # Project queries, keys, and values, then split into heads: (batch, len, d_model) -> (batch, heads, len, head_dim).
+        query = self.query_projection(query).view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key = self.key_projection(key).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value = self.value_projection(value).view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Expand the position encoding across the batch and head dimensions.
+        relative_keys = self.relative_positions_encoding[:query_len, :].unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)
+        relative_values = self.relative_positions_encoding[:query_len, :].unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1)
+
+        # Attention scores: scaled content-content term plus a content-position bias.
+        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
+        scores += torch.matmul(query, relative_keys.transpose(-2, -1))
+
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, float('-inf'))
+
+        attn_weights = F.softmax(scores, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Apply attention weights to the values and to the position encodings, then merge heads.
+        output = torch.matmul(attn_weights, value) + torch.matmul(attn_weights, relative_values)
+        output = output.transpose(1, 2).contiguous().view(batch_size, query_len, self.d_model)
+
+        # Apply output projection.
+        output = self.output_projection(output)
+
+        return output
+
+
+def main():
+    # Example usage
+    batch_size = 2
+    query_len = 10
+    key_len = 10
+    d_model = 512
+    num_heads = 8
+    max_len = 100
+
+    query = torch.rand(batch_size, query_len, d_model)
+    key = torch.rand(batch_size, key_len, d_model)
+    value = torch.rand(batch_size, key_len, d_model)
+
+    attention = RelativeMultiheadAttention(d_model, num_heads, max_len)
+    output = attention(query, key, value)
+    print(output.shape)  # Output shape should be (batch_size, query_len, d_model)
+
+    return
+
+
+if __name__ == '__main__':
+    main()
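For reference, generate_relative_positions_encoding builds a sinusoidal table over the head dimension, with even slots taking the sine and odd slots the cosine of position-scaled frequencies:

$$\mathrm{PE}(p,\,2i) = \sin\!\left(\frac{p}{10000^{2i/d_{\mathrm{head}}}}\right), \qquad \mathrm{PE}(p,\,2i+1) = \cos\!\left(\frac{p}{10000^{(2i+1)/d_{\mathrm{head}}}}\right)$$

This matches the Transformer encoding of Vaswani et al. up to the odd-slot exponent (the original uses 2i/d in both halves). Note also that, despite the class name, the extra score term is a content-position bias computed from these absolute encodings of the first query_len positions, not from pairwise query-key offsets.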