#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
PyTorch version of the PercepNet recurrent network.

References:
    https://github.com/jzi040941/PercepNet
    https://arxiv.org/abs/2008.04259
    https://modelzoo.co/model/percepnet

Author's note: this is too complex to use as-is.
(1) The PyTorch model is only one part of the whole pipeline.
(2) The training samples must first go through pitch analysis,
    spectral-envelope computation and similar processing.
"""
import torch
import torch.nn as nn


class PercepNet(nn.Module):
    """
    Frame-wise network: dense layer -> two causal 1-D convolutions ->
    a stack of GRUs -> two sigmoid output heads (``gb`` and ``rb``).

    Layout follows
    https://github.com/jzi040941/PercepNet/blob/main/rnn_train.py#L105

    Original note: 4.1% of an x86 CPU core.

    Args:
        input_dim: number of features per input frame.
        num_bands: width of each of the two output heads. The default of 34
            is the value that used to be hard-coded (presumably the number
            of frequency bands in PercepNet; confirm against the feature
            extractor).

    Submodule names and construction order are kept stable so that existing
    checkpoints (``state_dict``) keep loading and seeded initialisation is
    unchanged.
    """

    def __init__(self, input_dim: int = 70, num_bands: int = 34):
        super().__init__()

        fc_dim = 128       # width of the input dense layer
        hidden_dim = 512   # channel width of the convs and the main GRUs
        rb_dim = 128       # hidden size of the GRU feeding the rb head

        self.fc = nn.Sequential(
            nn.Linear(input_dim, fc_dim),
            nn.ReLU()
        )
        # Padding is kernel_size - 1 (applied to both ends by Conv1d);
        # forward() drops the trailing frames so each output frame only sees
        # current and past input. Padding chosen to align with the C++ DNN.
        self.conv1 = nn.Sequential(
            nn.Conv1d(fc_dim, hidden_dim, 5, stride=1, padding=4),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, 3, stride=1, padding=2),
            nn.Tanh()
        )

        # Separate single-layer GRUs (rather than one 3-layer GRU) so that
        # every intermediate output can be concatenated in forward().
        self.gru1 = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
        self.gru2 = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
        self.gru3 = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
        self.gru_gb = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
        # Input is gru3_out concatenated with convout -> 2 * hidden_dim.
        self.gru_rb = nn.GRU(hidden_dim * 2, rb_dim, 1, batch_first=True)

        # Input is convout + gru1/gru2/gru3/gru_gb outputs -> 5 * hidden_dim.
        self.fc_gb = nn.Sequential(
            nn.Linear(hidden_dim * 5, num_bands),
            nn.Sigmoid()
        )
        self.fc_rb = nn.Sequential(
            nn.Linear(rb_dim, num_bands),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the network on a batch of feature sequences.

        Args:
            x: tensor of shape ``[batch, time, input_dim]``.

        Returns:
            Tensor of shape ``[batch, time, 2 * num_bands]`` with values in
            ``(0, 1)``: the ``gb`` head followed by the ``rb`` head along
            the last axis.
        """
        # x shape: [b, t, f]
        x = self.fc(x)
        x = x.permute([0, 2, 1])
        # x shape: [b, f, t]

        # Causal conv: conv1 pads 4 frames on each side, so dropping the
        # last 4 outputs restores length t with no look-ahead.
        x = self.conv1(x)
        x = x[:, :, :-4]

        # Same trick for conv2 (padding 2, kernel 3).
        convout = self.conv2(x)
        convout = convout[:, :, :-2]
        convout = convout.permute([0, 2, 1])
        # convout shape: [b, t, f]

        # Final hidden states are not needed; only per-frame outputs are.
        gru1_out, _ = self.gru1(convout)
        gru2_out, _ = self.gru2(gru1_out)
        gru3_out, _ = self.gru3(gru2_out)
        gru_gb_out, _ = self.gru_gb(gru3_out)

        concat_gb_layer = torch.cat(
            tensors=(convout, gru1_out, gru2_out, gru3_out, gru_gb_out),
            dim=-1,
        )
        gb = self.fc_gb(concat_gb_layer)

        # TODO: the upstream code marks this concatenation as "need fix";
        # it is kept unchanged here.
        concat_rb_layer = torch.cat(tensors=(gru3_out, convout), dim=-1)
        rnn_rb_out, _ = self.gru_rb(concat_rb_layer)
        rb = self.fc_rb(rnn_rb_out)

        output = torch.cat((gb, rb), dim=-1)
        return output


def main():
    """Smoke test: push a random batch through the model, print its shape."""
    model = PercepNet()

    x = torch.randn(20, 8, 70)
    out = model(x)
    print(out.shape)


if __name__ == "__main__":
    main()