import math

import dgl
import dgl.function as fn
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.functional import edge_softmax
class RHGNLayer(nn.Module):
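    """Relation-aware heterogeneous graph attention layer.

    One round of HGT-style message passing: every node type has its own
    key/query/value projections, every edge type has its own attention
    and message transforms plus a learned per-head priority, and each
    node type's aggregated messages are blended with its input features
    through a learned skip gate (optionally followed by LayerNorm).
    """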
    def __init__(self,
                 in_dim,
                 out_dim,
                 node_dict,
                 edge_dict,
                 n_heads,
                 dropout=0.2,
                 use_norm=False):
        super(RHGNLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.node_dict = node_dict    # node type name -> type id
        self.edge_dict = edge_dict    # edge type name -> relation id
        self.num_types = len(node_dict)
        self.num_relations = len(edge_dict)
        self.total_rel = self.num_types * self.num_relations * self.num_types
        self.n_heads = n_heads
        # Guard against silent truncation in the integer division below.
        assert out_dim % n_heads == 0, 'out_dim must be divisible by n_heads'
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.att = None

        # Per-node-type projections for keys, queries, values, and the
        # output transform applied after aggregation.
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.use_norm = use_norm
        for t in range(self.num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        # Per-relation parameters: a per-head priority scalar plus head-wise
        # transforms for attention keys and for messages.
        self.relation_pri = nn.Parameter(torch.ones(self.num_relations, self.n_heads))
        self.relation_att = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
        self.skip = nn.Parameter(torch.ones(self.num_types))
        self.drop = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)
    def forward(self, G, h, is_batch=True, is_train=True, print_flag=False):
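        """Run one message-passing step.

        G is a DGLHeteroGraph, or a DGL block (message flow graph) when
        is_batch=True; h maps each node type to a (num_nodes, in_dim)
        feature tensor. Returns a dict with the updated
        (num_dst_nodes, out_dim) features per node type.
        """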
        with G.local_scope():
            node_dict, edge_dict = self.node_dict, self.edge_dict
            for srctype, etype, dsttype in G.canonical_etypes:
                sub_graph = G[srctype, etype, dsttype]

                # Type-specific K/Q/V projections, reshaped to
                # (num_nodes, n_heads, d_k).
                k_linear = self.k_linears[node_dict[srctype]]
                v_linear = self.v_linears[node_dict[srctype]]
                q_linear = self.q_linears[node_dict[dsttype]]
                k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                if is_batch:
                    # In a sampled block, only the first
                    # number_of_dst_nodes() rows are destination nodes.
                    q = q_linear(h[dsttype][:sub_graph.number_of_dst_nodes()]).view(-1, self.n_heads, self.d_k)
                else:
                    q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)

                # Relation-specific transforms of keys and messages.
                e_id = self.edge_dict[etype]
                relation_att = self.relation_att[e_id]
                relation_pri = self.relation_pri[e_id]
                relation_msg = self.relation_msg[e_id]
                k = torch.einsum("bij,ijk->bik", k, relation_att)
                v = torch.einsum("bij,ijk->bik", v, relation_msg)

                # Scaled dot-product attention per edge, normalized over
                # each destination node's incoming edges.
                sub_graph.srcdata['k'] = k
                sub_graph.dstdata['q'] = q
                sub_graph.srcdata['v_%d' % e_id] = v
                sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
                attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
                attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')
                sub_graph.edata['t'] = attn_score.unsqueeze(-1)
                # Optional debug dump of the per-edge attention weights for
                # the 'click' and 'purchase' relations (disabled):
                # if print_flag:
                #     import time
                #     srcnode = sub_graph.edges()[0].cpu().numpy()
                #     dstnode = sub_graph.edges()[1].cpu().numpy()
                #     attweight = attn_score.mean(dim=-1).cpu().detach().numpy()
                #     if etype in ('click', 'purchase'):
                #         df = pd.DataFrame({srctype: srcnode, dsttype: dstnode, etype: attweight})
                #         df.to_csv('../data/attweight/{}.csv'.format(time.time()), index=False)
            # Aggregate messages: for every relation, weight each source
            # value by its edge attention score and sum per destination
            # node; results from different relations are averaged.
            G.multi_update_all({etype: (fn.u_mul_e('v_%d' % e_id, 't', 'm'), fn.sum('m', 't'))
                                for etype, e_id in edge_dict.items()},
                               cross_reducer='mean')
            new_h = {}
            for ntype in G.ntypes:
                # Step 3: target-specific aggregation with a learned skip gate:
                #   x = norm( alpha * drop(W[node_type] * Agg(x)) + (1 - alpha) * x )
                n_id = node_dict[ntype]
                alpha = torch.sigmoid(self.skip[n_id])
                if is_batch:
                    t = G.dstnodes[ntype].data['t'].view(-1, self.out_dim)
                else:
                    t = G.nodes[ntype].data['t'].view(-1, self.out_dim)
                trans_out = self.a_linears[n_id](t)
                if is_train:
                    trans_out = self.drop(trans_out)
                trans_out = trans_out * alpha + h[ntype][:G.number_of_dst_nodes(ntype)] * (1 - alpha)
                if self.use_norm:
                    new_h[ntype] = self.norms[n_id](trans_out)
                else:
                    new_h[ntype] = trans_out
            return new_h
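

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; the node/edge type names and
# dimensions below are assumptions, not part of the original model). It
# builds a tiny user/item heterograph and runs one RHGNLayer pass in
# full-graph mode (is_batch=False).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    graph = dgl.heterograph({
        ('user', 'click', 'item'): (torch.tensor([0, 1, 2]), torch.tensor([1, 0, 1])),
        ('item', 'clicked-by', 'user'): (torch.tensor([1, 0, 1]), torch.tensor([0, 1, 2])),
    })
    node_dict = {ntype: i for i, ntype in enumerate(graph.ntypes)}
    edge_dict = {etype: i for i, etype in enumerate(graph.etypes)}
    feats = {ntype: torch.randn(graph.number_of_nodes(ntype), 16) for ntype in graph.ntypes}

    layer = RHGNLayer(in_dim=16, out_dim=16, node_dict=node_dict,
                      edge_dict=edge_dict, n_heads=4, use_norm=True)
    out = layer(graph, feats, is_batch=False, is_train=False)
    for ntype, feat in out.items():
        print(ntype, tuple(feat.shape))  # user (3, 16), item (2, 16)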