import torch
import torch.nn as nn

PT_FEATURE_SIZE = 40  # per-position feature dimension of the protein sequence / pocket inputs

class DeepDTAF(nn.Module):

    def __init__(self, smi_charset_len):
        super().__init__()

        smi_embed_size = 128
        seq_embed_size = 128

        seq_oc = 128
        pkt_oc = 128
        smi_oc = 128

        self.smi_embed = nn.Embedding(smi_charset_len, smi_embed_size)

        self.seq_embed = nn.Linear(PT_FEATURE_SIZE, seq_embed_size)  # (N, *, H_in) -> (N, *, H_out)

        # full protein sequence branch: stacked parallel dilated residual blocks (variant A)
        conv_seq = []
        ic = seq_embed_size
        for oc in [32, 64, 64, seq_oc]:
            conv_seq.append(DilatedParllelResidualBlockA(ic, oc))
            ic = oc
        conv_seq.append(nn.AdaptiveMaxPool1d(1))  # (N, oc, 1)
        conv_seq.append(Squeeze())                # (N, oc)
        self.conv_seq = nn.Sequential(*conv_seq)

        # binding pocket branch: plain Conv1d / BatchNorm / PReLU stack
        conv_pkt = []
        ic = seq_embed_size
        for oc in [32, 64, pkt_oc]:
            conv_pkt.append(nn.Conv1d(ic, oc, 3))  # (N, C, L)
            conv_pkt.append(nn.BatchNorm1d(oc))
            conv_pkt.append(nn.PReLU())
            ic = oc
        conv_pkt.append(nn.AdaptiveMaxPool1d(1))
        conv_pkt.append(Squeeze())
        self.conv_pkt = nn.Sequential(*conv_pkt)  # (N, oc)

        # ligand SMILES branch: stacked parallel dilated residual blocks (variant B)
        conv_smi = []
        ic = smi_embed_size
        for oc in [32, 64, smi_oc]:
            conv_smi.append(DilatedParllelResidualBlockB(ic, oc))
            ic = oc
        conv_smi.append(nn.AdaptiveMaxPool1d(1))
        conv_smi.append(Squeeze())
        self.conv_smi = nn.Sequential(*conv_smi)  # (N, 128)

        self.cat_dropout = nn.Dropout(0.2)

        self.classifier = nn.Sequential(
            nn.Linear(seq_oc + pkt_oc + smi_oc, 128),
            nn.Dropout(0.5),
            nn.PReLU(),
            nn.Linear(128, 64),
            nn.Dropout(0.5),
            nn.PReLU(),
            # nn.Linear(64, 1),
            # nn.PReLU()
        )

    def forward(self, seq, pkt, smi):
        # assert seq.shape == (N, L, PT_FEATURE_SIZE)
        seq_embed = self.seq_embed(seq)               # (N, L, 128)
        seq_embed = torch.transpose(seq_embed, 1, 2)  # (N, 128, L)
        seq_conv = self.conv_seq(seq_embed)           # (N, 128)

        # assert pkt.shape == (N, L, PT_FEATURE_SIZE)
        pkt_embed = self.seq_embed(pkt)               # (N, L, 128)
        pkt_embed = torch.transpose(pkt_embed, 1, 2)  # (N, 128, L)
        pkt_conv = self.conv_pkt(pkt_embed)           # (N, 128)

        # assert smi.shape == (N, L)
        smi_embed = self.smi_embed(smi)               # (N, L, 128)
        smi_embed = torch.transpose(smi_embed, 1, 2)  # (N, 128, L)
        smi_conv = self.conv_smi(smi_embed)           # (N, 128)

        cat = torch.cat([seq_conv, pkt_conv, smi_conv], dim=1)  # (N, 128*3)
        cat = self.cat_dropout(cat)

        output = self.classifier(cat)
        return output

class Squeeze(nn.Module):
    # Drops the trailing length-1 dimension left by AdaptiveMaxPool1d: (N, C, 1) -> (N, C).
    # Note that squeeze() removes every size-1 dimension, so a batch of one also loses its batch axis.

    def forward(self, input: torch.Tensor):
        return input.squeeze()

class CDilated(nn.Module):
    # 1D dilated convolution with "same" padding, so the sequence length is preserved at stride 1.

    def __init__(self, nIn, nOut, kSize, stride=1, d=1):
        super().__init__()
        padding = int((kSize - 1) / 2) * d
        self.conv = nn.Conv1d(nIn, nOut, kSize, stride=stride,
                              padding=padding, bias=False, dilation=d)

    def forward(self, input):
        output = self.conv(input)
        return output
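
# Length check for CDilated (illustrative note, not from the original source): with
# kSize = 3 and dilation d, padding = (3 - 1) // 2 * d = d, so for stride 1
# L_out = L_in + 2*d - d*(3 - 1) - 1 + 1 = L_in, i.e. the sequence length is unchanged.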

class DilatedParllelResidualBlockA(nn.Module):
    # Parallel dilated residual block: a 1x1 channel reduction feeds five dilated
    # convolution branches whose outputs are hierarchically fused and concatenated
    # back to nOut channels, with an optional residual connection.

    def __init__(self, nIn, nOut, add=True):
        super().__init__()
        n = int(nOut / 5)
        n1 = nOut - 4 * n
        self.c1 = nn.Conv1d(nIn, n, 1, padding=0)
        self.br1 = nn.Sequential(nn.BatchNorm1d(n), nn.PReLU())
        self.d1 = CDilated(n, n1, 3, 1, 1)    # dilation rate of 2^0
        self.d2 = CDilated(n, n, 3, 1, 2)     # dilation rate of 2^1
        self.d4 = CDilated(n, n, 3, 1, 4)     # dilation rate of 2^2
        self.d8 = CDilated(n, n, 3, 1, 8)     # dilation rate of 2^3
        self.d16 = CDilated(n, n, 3, 1, 16)   # dilation rate of 2^4
        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())
        if nIn != nOut:
            # print(f'{nIn}-{nOut}: add=False')
            add = False
        self.add = add

    def forward(self, input):
        # reduce channels with a 1x1 convolution
        output1 = self.c1(input)
        output1 = self.br1(output1)
        # split and transform with parallel dilated convolutions
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)
        d16 = self.d16(output1)
        # hierarchical fusion for de-gridding
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8
        add4 = add3 + d16
        # merge back to nOut channels
        combine = torch.cat([d1, add1, add2, add3, add4], 1)
        # residual connection only when input and output channels match
        if self.add:
            combine = input + combine
        output = self.br2(combine)
        return output
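
# Channel bookkeeping for the block above (illustrative example, not from the original
# source): with nOut = 32, n = 32 // 5 = 6 channels per dilated branch and
# n1 = 32 - 4 * 6 = 8 channels for d1, so concatenating [d1, add1, add2, add3, add4]
# restores 8 + 4 * 6 = 32 = nOut channels before the residual add and BatchNorm/PReLU.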

class DilatedParllelResidualBlockB(nn.Module):
    # Same parallel dilated residual block as variant A, but with four dilated
    # branches (dilation 1, 2, 4, 8) instead of five.

    def __init__(self, nIn, nOut, add=True):
        super().__init__()
        n = int(nOut / 4)
        n1 = nOut - 3 * n
        self.c1 = nn.Conv1d(nIn, n, 1, padding=0)
        self.br1 = nn.Sequential(nn.BatchNorm1d(n), nn.PReLU())
        self.d1 = CDilated(n, n1, 3, 1, 1)   # dilation rate of 2^0
        self.d2 = CDilated(n, n, 3, 1, 2)    # dilation rate of 2^1
        self.d4 = CDilated(n, n, 3, 1, 4)    # dilation rate of 2^2
        self.d8 = CDilated(n, n, 3, 1, 8)    # dilation rate of 2^3
        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())
        if nIn != nOut:
            # print(f'{nIn}-{nOut}: add=False')
            add = False
        self.add = add

    def forward(self, input):
        # reduce channels with a 1x1 convolution
        output1 = self.c1(input)
        output1 = self.br1(output1)
        # split and transform with parallel dilated convolutions
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)
        # hierarchical fusion for de-gridding
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8
        # merge back to nOut channels
        combine = torch.cat([d1, add1, add2, add3], 1)
        # residual connection only when input and output channels match
        if self.add:
            combine = input + combine
        output = self.br2(combine)
        return output
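

if __name__ == '__main__':
    # Minimal shape smoke test (sketch added for illustration; the charset size and
    # sequence lengths below are arbitrary assumptions, not values from the original
    # training setup).
    model = DeepDTAF(smi_charset_len=64).eval()
    N = 2
    seq = torch.rand(N, 1000, PT_FEATURE_SIZE)  # full protein sequence features
    pkt = torch.rand(N, 63, PT_FEATURE_SIZE)    # binding-pocket features
    smi = torch.randint(0, 64, (N, 150))        # SMILES token indices
    with torch.no_grad():
        out = model(seq, pkt, smi)
    print(out.shape)  # torch.Size([2, 64]); the final Linear(64, 1) head is commented out above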