# UltraFlow/layers/utils.py
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import edge_softmax
from dgl import softmax_edges
class FC(nn.Module):
    """Fully connected prediction head: stacked Linear/Dropout/LeakyReLU/BatchNorm blocks."""

    def __init__(self, d_graph_layer, fc_hidden_dim, dropout, n_tasks):
        super(FC, self).__init__()
        self.predict = nn.ModuleList()
        for index, dim in enumerate(fc_hidden_dim):
            self.predict.append(nn.Linear(d_graph_layer, dim))
            self.predict.append(nn.Dropout(dropout))
            self.predict.append(nn.LeakyReLU())
            self.predict.append(nn.BatchNorm1d(dim))
            d_graph_layer = dim
        self.predict.append(nn.Linear(d_graph_layer, n_tasks))

    def forward(self, h):
        for layer in self.predict:
            h = layer(h)
        # return torch.sigmoid(h)
        return h

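# A minimal smoke-test sketch for FC; the hyper-parameters below are illustrative,
# not taken from the original training config. Note the batch size must be > 1
# because of BatchNorm1d in training mode.
def _demo_fc():
    fc = FC(d_graph_layer=64, fc_hidden_dim=[32, 16], dropout=0.1, n_tasks=1)
    h = torch.randn(8, 64)  # 8 graph embeddings of size 64
    return fc(h)            # -> tensor of shape (8, 1)
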
class EdgeWeightAndSum(nn.Module):
    """
    Weighted-sum readout over edges: each edge gets a scalar tanh gate in (-1, 1).
    For normal use, delete the lines marked 'temporary version' and restore the
    ones marked 'normal version'.
    """

    def __init__(self, in_feats):
        super(EdgeWeightAndSum, self).__init__()
        self.in_feats = in_feats
        self.atom_weighting = nn.Sequential(
            nn.Linear(in_feats, 1),
            nn.Tanh()
        )

    def forward(self, g, edge_feats):
        with g.local_scope():
            g.edata['e'] = edge_feats
            g.edata['w'] = self.atom_weighting(g.edata['e'])  # scalar weight per edge
            weights = g.edata['w']                            # temporary version
            h_g_sum = dgl.sum_edges(g, 'e', 'w')              # per-graph weighted sum
            # return h_g_sum, g.edata['w']  # normal version
            return h_g_sum, weights  # temporary version

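# A hedged usage sketch for EdgeWeightAndSum on a toy DGL graph; the topology and
# feature size are made up for illustration, not from the original pipeline.
def _demo_edge_weight_and_sum():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))    # 3-node cycle, 3 edges
    feats = torch.randn(g.num_edges(), 16)
    readout = EdgeWeightAndSum(in_feats=16)
    h_g, w = readout(g, feats)               # h_g: (1, 16), w: (3, 1)
    return h_g, w
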
class MultiHeadAttention(nn.Module):
    """Runs several EdgeWeightAndSum heads and merges them by 'concat' or mean."""

    def __init__(self, in_feats, num_head, merge):
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_head):
            self.heads.append(EdgeWeightAndSum(in_feats))
        self.merge = merge

    def forward(self, g, edge_feats):
        h_g_heads, weight_heads = [], []
        for attn_head in self.heads:
            h_g_head, weight = attn_head(g, edge_feats)
            h_g_heads.append(h_g_head)
            weight_heads.append(weight)
        if self.merge == 'concat':
            return torch.cat(h_g_heads, dim=1), torch.cat(weight_heads, dim=1)
        # average over heads; dim=0 keeps the per-graph/per-edge shapes intact
        # (torch.mean without dim would collapse everything to a scalar)
        return torch.mean(torch.stack(h_g_heads), dim=0), torch.mean(torch.stack(weight_heads), dim=0)

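# Sketch: two tanh-weighted heads on the same toy graph as above (sizes are
# illustrative). With merge='concat' the readout doubles from 16 to 32 dims and
# the per-edge weights stack to (E, 2); with merge='mean' both are averaged.
def _demo_multi_head():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    feats = torch.randn(g.num_edges(), 16)
    attn = MultiHeadAttention(in_feats=16, num_head=2, merge='concat')
    h_g, w = attn(g, feats)  # h_g: (1, 32), w: (3, 2)
    return h_g, w
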
class EdgeWeightAndSum_v2(nn.Module):
    """
    Like EdgeWeightAndSum, but the LeakyReLU scores are normalized with
    edge_softmax (softmax over the incoming edges of each destination node).
    For normal use, delete the lines marked 'temporary version' and restore the
    ones marked 'normal version'.
    """

    def __init__(self, in_feats):
        super(EdgeWeightAndSum_v2, self).__init__()
        self.in_feats = in_feats
        self.atom_weighting = nn.Sequential(
            nn.Linear(in_feats, 1),
            nn.LeakyReLU()
        )

    def forward(self, g, edge_feats):
        with g.local_scope():
            g.edata['e'] = edge_feats
            g.edata['w'] = edge_softmax(g, self.atom_weighting(g.edata['e']))
            weights = g.edata['w']                # temporary version
            h_g_sum = dgl.sum_edges(g, 'e', 'w')  # per-graph weighted sum
            # return h_g_sum, g.edata['w']  # normal version
            return h_g_sum, weights  # temporary version

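# Note on the v2 normalization: edge_softmax normalizes scores over the incoming
# edges of each destination node, so weights of edges sharing a target sum to 1
# per node, not per graph. A quick check on a made-up two-edge graph:
def _demo_edge_softmax_scope():
    g = dgl.graph(([0, 1], [2, 2]))     # both edges point at node 2
    scores = torch.randn(g.num_edges(), 1)
    w = edge_softmax(g, scores)
    return w.sum()                      # -> 1.0: normalized per destination node
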
class MultiHeadAttention_v2(nn.Module):
    """Runs several EdgeWeightAndSum_v2 heads and merges them by 'concat' or mean."""

    def __init__(self, in_feats, num_head, merge):
        super(MultiHeadAttention_v2, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_head):
            self.heads.append(EdgeWeightAndSum_v2(in_feats))
        self.merge = merge

    def forward(self, g, edge_feats):
        h_g_heads, weight_heads = [], []
        for attn_head in self.heads:
            h_g_head, weight = attn_head(g, edge_feats)
            h_g_heads.append(h_g_head)
            weight_heads.append(weight)
        if self.merge == 'concat':
            return torch.cat(h_g_heads, dim=1), torch.cat(weight_heads, dim=1)
        # average over heads; dim=0 keeps the per-graph/per-edge shapes intact
        return torch.mean(torch.stack(h_g_heads), dim=0), torch.mean(torch.stack(weight_heads), dim=0)

class EdgeWeightAndSum_v3(nn.Module):
    """
    Like EdgeWeightAndSum_v2, but normalizes the LeakyReLU scores with
    dgl.softmax_edges, i.e. a softmax over all edges of each graph in the batch.
    For normal use, delete the lines marked 'temporary version' and restore the
    ones marked 'normal version'.
    """

    def __init__(self, in_feats):
        super(EdgeWeightAndSum_v3, self).__init__()
        self.in_feats = in_feats
        self.atom_weighting = nn.Sequential(
            nn.Linear(in_feats, 1),
            nn.LeakyReLU()
        )

    def forward(self, g, edge_feats):
        with g.local_scope():
            g.edata['e'] = edge_feats
            g.edata['e2'] = self.atom_weighting(g.edata['e'])  # raw scalar scores
            g.edata['w'] = softmax_edges(g, 'e2')              # graph-wise softmax
            weights = g.edata['w']                # temporary version
            h_g_sum = dgl.sum_edges(g, 'e', 'w')  # per-graph weighted sum
            # return h_g_sum, g.edata['w']  # normal version
            return h_g_sum, weights  # temporary version

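# By contrast with v2, dgl.softmax_edges normalizes over *all* edges of each graph
# in the batch, so every graph's edge weights sum to 1 regardless of topology
# (toy graph for illustration):
def _demo_softmax_edges_scope():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    g.edata['s'] = torch.randn(g.num_edges(), 1)
    w = softmax_edges(g, 's')
    return w.sum()  # -> 1.0: normalized per graph
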
class MultiHeadAttention_v3(nn.Module):
    """Runs several EdgeWeightAndSum_v3 heads and merges them by 'concat' or mean."""

    def __init__(self, in_feats, num_head, merge):
        super(MultiHeadAttention_v3, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_head):
            self.heads.append(EdgeWeightAndSum_v3(in_feats))
        self.merge = merge

    def forward(self, g, edge_feats):
        h_g_heads, weight_heads = [], []
        for attn_head in self.heads:
            h_g_head, weight = attn_head(g, edge_feats)
            h_g_heads.append(h_g_head)
            weight_heads.append(weight)
        if self.merge == 'concat':
            return torch.cat(h_g_heads, dim=1), torch.cat(weight_heads, dim=1)
        # average over heads; dim=0 keeps the per-graph/per-edge shapes intact
        return torch.mean(torch.stack(h_g_heads), dim=0), torch.mean(torch.stack(weight_heads), dim=0)

class ReadsOutLayer(nn.Module):
    """
    Graph readout: pools edge features ('mean', 'sum', or one of the weighted
    variants above) and concatenates the result with a per-graph edge-wise max,
    so the output has 2 * in_feats dimensions (more for a 'concat' multi-head).
    For normal use, delete the lines marked 'temporary version' and restore the
    ones marked 'normal version'.
    """

    def __init__(self, in_feats, pooling, num_head=None, attn_merge=None):
        super(ReadsOutLayer, self).__init__()
        self.pooling = pooling
        if self.pooling == 'w_sum':
            self.weight_and_sum = EdgeWeightAndSum(in_feats)
        elif self.pooling == 'multi_head':
            self.weight_and_sum = MultiHeadAttention(in_feats, num_head, attn_merge)
        elif self.pooling == 'w_sum_v2':
            self.weight_and_sum = EdgeWeightAndSum_v2(in_feats)
        elif self.pooling == 'multi_head_v2':
            self.weight_and_sum = MultiHeadAttention_v2(in_feats, num_head, attn_merge)
        elif self.pooling == 'w_sum_v3':
            self.weight_and_sum = EdgeWeightAndSum_v3(in_feats)
        elif self.pooling == 'multi_head_v3':
            self.weight_and_sum = MultiHeadAttention_v3(in_feats, num_head, attn_merge)

    def forward(self, bg, edge_feats):
        # h_g_sum, weights = self.weight_and_sum(bg, edge_feats)  # temporary version
        with bg.local_scope():
            bg.edata['e'] = edge_feats
            h_g_max = dgl.max_edges(bg, 'e')  # per-graph max over edge features
            if self.pooling == 'mean':
                h_p = dgl.mean_edges(bg, 'e')
            elif self.pooling == 'sum':
                h_p = dgl.sum_edges(bg, 'e')
            else:
                h_p, weights = self.weight_and_sum(bg, edge_feats)  # normal version
                bg.edata['weights'] = weights
            return torch.cat([h_p, h_g_max], dim=1)  # normal version

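# End-to-end sketch of the readout plus prediction head on a batch of toy graphs
# (all sizes below are illustrative). The readout concatenates the pooled
# representation with the edge-wise max, so FC must take 2 * in_feats inputs.
def _demo_readout_pipeline():
    g1 = dgl.graph(([0, 1], [1, 0]))
    g2 = dgl.graph(([0, 1, 2], [1, 2, 0]))
    bg = dgl.batch([g1, g2])
    edge_feats = torch.randn(bg.num_edges(), 16)
    readout = ReadsOutLayer(in_feats=16, pooling='w_sum')
    fc = FC(d_graph_layer=32, fc_hidden_dim=[16], dropout=0.1, n_tasks=1)
    h = readout(bg, edge_feats)  # -> (2, 32)
    return fc(h)                 # -> (2, 1)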