import torch
import torch.nn as nn

PT_FEATURE_SIZE = 40
class DeepDTAF(nn.Module):

    def __init__(self, smi_charset_len):
        super().__init__()

        smi_embed_size = 128
        seq_embed_size = 128

        seq_oc = 128
        pkt_oc = 128
        smi_oc = 128

        self.smi_embed = nn.Embedding(smi_charset_len, smi_embed_size)
        self.seq_embed = nn.Linear(PT_FEATURE_SIZE, seq_embed_size)  # (N, *, H_in) -> (N, *, H_out)

        # protein sequence branch: stacked dilated residual blocks
        conv_seq = []
        ic = seq_embed_size
        for oc in [32, 64, 64, seq_oc]:
            conv_seq.append(DilatedParllelResidualBlockA(ic, oc))
            ic = oc
        conv_seq.append(nn.AdaptiveMaxPool1d(1))  # (N, oc, 1)
        conv_seq.append(Squeeze())                # (N, oc)
        self.conv_seq = nn.Sequential(*conv_seq)

        # binding-pocket branch: plain Conv1d stack over (N, C, L)
        conv_pkt = []
        ic = seq_embed_size
        for oc in [32, 64, pkt_oc]:
            conv_pkt.append(nn.Conv1d(ic, oc, 3))  # (N, C, L)
            conv_pkt.append(nn.BatchNorm1d(oc))
            conv_pkt.append(nn.PReLU())
            ic = oc
        conv_pkt.append(nn.AdaptiveMaxPool1d(1))
        conv_pkt.append(Squeeze())
        self.conv_pkt = nn.Sequential(*conv_pkt)  # (N, pkt_oc)

        # SMILES branch: stacked dilated residual blocks
        conv_smi = []
        ic = smi_embed_size
        for oc in [32, 64, smi_oc]:
            conv_smi.append(DilatedParllelResidualBlockB(ic, oc))
            ic = oc
        conv_smi.append(nn.AdaptiveMaxPool1d(1))
        conv_smi.append(Squeeze())
        self.conv_smi = nn.Sequential(*conv_smi)  # (N, smi_oc)

        self.cat_dropout = nn.Dropout(0.2)

        self.classifier = nn.Sequential(
            nn.Linear(seq_oc + pkt_oc + smi_oc, 128),
            nn.Dropout(0.5),
            nn.PReLU(),
            nn.Linear(128, 64),
            nn.Dropout(0.5),
            nn.PReLU(),
            # nn.Linear(64, 1),
            # nn.PReLU()
        )
    def forward(self, seq, pkt, smi):
        # seq: (N, L, PT_FEATURE_SIZE)
        seq_embed = self.seq_embed(seq)               # (N, L, seq_embed_size)
        seq_embed = torch.transpose(seq_embed, 1, 2)  # (N, seq_embed_size, L)
        seq_conv = self.conv_seq(seq_embed)           # (N, seq_oc)

        # pkt: (N, L, PT_FEATURE_SIZE)
        pkt_embed = self.seq_embed(pkt)               # (N, L, seq_embed_size)
        pkt_embed = torch.transpose(pkt_embed, 1, 2)
        pkt_conv = self.conv_pkt(pkt_embed)           # (N, pkt_oc)

        # smi: (N, L) integer-encoded SMILES
        smi_embed = self.smi_embed(smi)               # (N, L, smi_embed_size)
        smi_embed = torch.transpose(smi_embed, 1, 2)
        smi_conv = self.conv_smi(smi_embed)           # (N, smi_oc)

        cat = torch.cat([seq_conv, pkt_conv, smi_conv], dim=1)  # (N, seq_oc + pkt_oc + smi_oc)
        cat = self.cat_dropout(cat)

        output = self.classifier(cat)
        return output
class Squeeze(nn.Module):
    # removes the trailing length-1 dimension left by AdaptiveMaxPool1d(1)
    def forward(self, input: torch.Tensor):
        return input.squeeze()
class CDilated(nn.Module):
    # 1-D dilated convolution; padding is chosen so the sequence length is preserved
    def __init__(self, nIn, nOut, kSize, stride=1, d=1):
        super().__init__()
        padding = int((kSize - 1) / 2) * d
        self.conv = nn.Conv1d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, dilation=d)

    def forward(self, input):
        output = self.conv(input)
        return output
class DilatedParllelResidualBlockA(nn.Module):
    def __init__(self, nIn, nOut, add=True):
        super().__init__()
        n = int(nOut / 5)
        n1 = nOut - 4 * n
        self.c1 = nn.Conv1d(nIn, n, 1, padding=0)
        self.br1 = nn.Sequential(nn.BatchNorm1d(n), nn.PReLU())
        self.d1 = CDilated(n, n1, 3, 1, 1)   # dilation rate of 2^0
        self.d2 = CDilated(n, n, 3, 1, 2)    # dilation rate of 2^1
        self.d4 = CDilated(n, n, 3, 1, 4)    # dilation rate of 2^2
        self.d8 = CDilated(n, n, 3, 1, 8)    # dilation rate of 2^3
        self.d16 = CDilated(n, n, 3, 1, 16)  # dilation rate of 2^4
        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())

        if nIn != nOut:
            # print(f'{nIn}-{nOut}: add=False')
            add = False
        self.add = add

    def forward(self, input):
        # reduce
        output1 = self.c1(input)
        output1 = self.br1(output1)

        # split and transform
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)
        d16 = self.d16(output1)

        # hierarchical fusion for de-gridding
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8
        add4 = add3 + d16

        # merge
        combine = torch.cat([d1, add1, add2, add3, add4], 1)

        # if residual version
        if self.add:
            combine = input + combine
        output = self.br2(combine)
        return output
class DilatedParllelResidualBlockB(nn.Module):
    def __init__(self, nIn, nOut, add=True):
        super().__init__()
        n = int(nOut / 4)
        n1 = nOut - 3 * n
        self.c1 = nn.Conv1d(nIn, n, 1, padding=0)
        self.br1 = nn.Sequential(nn.BatchNorm1d(n), nn.PReLU())
        self.d1 = CDilated(n, n1, 3, 1, 1)  # dilation rate of 2^0
        self.d2 = CDilated(n, n, 3, 1, 2)   # dilation rate of 2^1
        self.d4 = CDilated(n, n, 3, 1, 4)   # dilation rate of 2^2
        self.d8 = CDilated(n, n, 3, 1, 8)   # dilation rate of 2^3
        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())

        if nIn != nOut:
            # print(f'{nIn}-{nOut}: add=False')
            add = False
        self.add = add

    def forward(self, input):
        # reduce
        output1 = self.c1(input)
        output1 = self.br1(output1)

        # split and transform
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)

        # hierarchical fusion for de-gridding
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8

        # merge
        combine = torch.cat([d1, add1, add2, add3], 1)

        # if residual version
        if self.add:
            combine = input + combine
        output = self.br2(combine)
        return output
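

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original file). It assumes a
# hypothetical SMILES charset of 64 tokens and arbitrary sequence, pocket,
# and SMILES lengths, purely to illustrate the expected input shapes and the
# 64-dim output of the classifier head as defined above (with the final
# nn.Linear(64, 1) left commented out).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = DeepDTAF(smi_charset_len=64)  # 64 is an assumed vocabulary size
    model.eval()

    N, L_seq, L_pkt, L_smi = 2, 1000, 63, 150
    seq = torch.rand(N, L_seq, PT_FEATURE_SIZE)                # protein sequence features
    pkt = torch.rand(N, L_pkt, PT_FEATURE_SIZE)                # binding-pocket features
    smi = torch.randint(0, 64, (N, L_smi), dtype=torch.long)   # integer-encoded SMILES

    with torch.no_grad():
        out = model(seq, pkt, smi)
    print(out.shape)  # torch.Size([2, 64])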