Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py
CHANGED
@@ -7,7 +7,7 @@ import torch
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
|
10 |
-
class
|
11 |
def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
|
12 |
"""
|
13 |
:param n_head: int. the number of heads.
|
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
|
10 |
+
class MultiHeadAttention(nn.Module):
|
11 |
def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
|
12 |
"""
|
13 |
:param n_head: int. the number of heads.
|
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py
CHANGED
@@ -7,7 +7,7 @@ import torch
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
|
10 |
-
from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import
|
11 |
|
12 |
|
13 |
class PositionwiseFeedForward(nn.Module):
|
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
|
10 |
+
from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadAttention, RelativeMultiHeadSelfAttention
|
11 |
|
12 |
|
13 |
class PositionwiseFeedForward(nn.Module):
|