#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Source: HoneyTian, commit 1af34cd ("first commit").
"""
https://github.com/jzi040941/PercepNet
https://arxiv.org/abs/2008.04259
https://modelzoo.co/model/percepnet
太复杂了。
(1)pytorch 模型只是整个 pipeline 中的一部分。
(2)训练样本需经过基音分析,频谱包络之类的计算。
"""
import torch
import torch.nn as nn
class PercepNet(nn.Module):
    """PercepNet network, ported from the reference PyTorch training script.

    Reference: https://github.com/jzi040941/PercepNet/blob/main/rnn_train.py#L105
    (the upstream notes quote a cost of about 4.1% of an x86 CPU core).

    Layout:
        fc (Linear + ReLU) -> causal conv1 (k=5, ReLU) -> causal conv2 (k=3, Tanh)
        -> gru1 -> gru2 -> gru3 -> gru_gb, followed by two sigmoid heads:

        * ``fc_gb`` reads the concatenation of the conv2 output and all four
          GRU outputs (5 x 512 features) and emits 34 values per frame;
        * ``gru_rb`` + ``fc_rb`` read gru3's output concatenated with the
          conv2 output (1024 features) and emit another 34 values per frame.

    Args:
        input_dim: Size of the per-frame input feature vector.

    Shapes:
        input:  ``[batch, time, input_dim]``
        output: ``[batch, time, 68]`` -- the ``fc_gb`` values followed by the
        ``fc_rb`` values, each in (0, 1) because of the final sigmoids.
    """

    def __init__(self, input_dim=70):
        super().__init__()
        # Submodules are created in the upstream order: with a fixed RNG seed
        # this reproduces the same initial weights. The attribute names and the
        # nn.Sequential wrappers define the state_dict keys
        # (e.g. "fc.0.weight", "conv1.0.weight"), so keep them unchanged.
        self.fc = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU())

        # Both convolutions pad by (kernel_size - 1) on each side; forward()
        # discards the trailing frames so every output frame depends only on
        # the current and past frames. (Upstream comment: this padding keeps
        # the layer aligned with the C++ DNN implementation.)
        self.conv1 = nn.Sequential(
            nn.Conv1d(128, 512, 5, stride=1, padding=4),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(512, 512, 3, stride=1, padding=2),
            nn.Tanh(),
        )

        # Separate single-layer GRUs instead of one 3-layer nn.GRU, because
        # every intermediate output sequence is consumed by the gb head.
        self.gru1 = nn.GRU(512, 512, 1, batch_first=True)
        self.gru2 = nn.GRU(512, 512, 1, batch_first=True)
        self.gru3 = nn.GRU(512, 512, 1, batch_first=True)
        self.gru_gb = nn.GRU(512, 512, 1, batch_first=True)
        # Input width = gru3 output (512) + conv2 output (512).
        self.gru_rb = nn.GRU(1024, 128, 1, batch_first=True)

        # Input width = conv2 output + gru1/gru2/gru3/gru_gb outputs (5 x 512).
        self.fc_gb = nn.Sequential(nn.Linear(512 * 5, 34), nn.Sigmoid())
        self.fc_rb = nn.Sequential(nn.Linear(128, 34), nn.Sigmoid())

    def forward(self, x: torch.Tensor):
        """Map ``[batch, time, input_dim]`` features to ``[batch, time, 68]``.

        No initial hidden state is passed to the GRUs (PyTorch then uses
        zeros), and the final hidden states are discarded.
        """
        # Per-frame projection, then channels-first layout for Conv1d.
        feats = self.fc(x).permute([0, 2, 1])      # [b, 128, t]

        # Slicing off the frames produced by the right-hand padding keeps
        # exactly t frames and makes each convolution causal.
        feats = self.conv1(feats)[:, :, :-4]       # [b, 512, t]
        conv_out = self.conv2(feats)[:, :, :-2]    # [b, 512, t]
        conv_out = conv_out.permute([0, 2, 1])     # [b, t, 512]

        # Run the GRU chain, keeping every intermediate sequence:
        # stack == [conv_out, gru1_out, gru2_out, gru3_out, gru_gb_out]
        stack = [conv_out]
        for layer in (self.gru1, self.gru2, self.gru3, self.gru_gb):
            seq, _ = layer(stack[-1])
            stack.append(seq)

        gb = self.fc_gb(torch.cat(tensors=stack, dim=-1))

        # Upstream marks this concatenation as provisional ("concat rb need fix").
        rb_seq, _ = self.gru_rb(torch.cat(tensors=(stack[3], conv_out), dim=-1))
        rb = self.fc_rb(rb_seq)

        return torch.cat((gb, rb), dim=-1)
def main():
    """Smoke test: run PercepNet on a random batch and print the output shape."""
    net = PercepNet()
    batch, frames, features = 20, 8, 70
    print(net(torch.randn(batch, frames, features)).shape)


if __name__ == "__main__":
    main()