Spaces:
Runtime error
Runtime error
from turtle import forward | |
import dgl | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import dgl.function as fn | |
from dgl.nn.functional import edge_softmax | |
from FairGNN.src.models.GCN import GCN | |
from RHGN.layers import * | |
from RHGN.layers import RHGNLayer | |
class RHGN_adv(nn.Module):
    """Sensitive-attribute estimator and adversary used next to an RHGN model.

    ``sens_model`` is a FairGNN ``GCN`` evaluated on a graph and raw node
    features; ``adv_model`` is a linear head mapping each row of the main
    model's hidden states ``h`` to a single value.
    """

    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads,
                 cid1_feature, cid2_feature, cid3_feature,
                 sens_in_dim=200, sens_hid=128, sens_dropout=0.5):
        """Build the estimator/adversary pair.

        Args:
            G, node_dict, edge_dict, n_inp, n_out, n_layers, n_heads: unused
                here; accepted so the constructor mirrors the RHGN models.
            n_hid: input size of ``adv_model`` (the main model's hidden size).
            cid1_feature, cid2_feature, cid3_feature: 2-D embedding tables,
                stored as frozen ``nn.Embedding`` modules (``forward`` does not
                read them).
            sens_in_dim, sens_hid, sens_dropout: first, second and fourth
                arguments of the ``GCN`` estimator. The defaults reproduce the
                previous hard-coded ``GCN(200, 128, 1, 0.5)``.
        """
        super(RHGN_adv, self).__init__()
        self.cid1_feature = self._frozen_embedding(cid1_feature)
        self.cid2_feature = self._frozen_embedding(cid2_feature)
        self.cid3_feature = self._frozen_embedding(cid3_feature)
        # Adversary on the main model's hidden states (input size was n_out before).
        self.adv_model = nn.Linear(n_hid, 1)
        # FairGNN GCN; positional args presumably (nfeat, nhid, nclass, dropout) -- confirm.
        self.sens_model = GCN(sens_in_dim, sens_hid, 1, sens_dropout)

    @staticmethod
    def _frozen_embedding(table):
        """Wrap the 2-D tensor ``table`` in a non-trainable ``nn.Embedding``."""
        # Build a default nn.Embedding first (instead of from_pretrained) so the
        # random-number stream is consumed exactly as before and seeded runs
        # initialise the later layers identically.
        emb = nn.Embedding(table.size(0), table.size(1))
        emb.weight = nn.Parameter(table)
        emb.weight.requires_grad = False
        return emb

    def forward(self, h, inputs, G, blocks, out_key, label_key, is_train=True, print_flag=False):
        """Run the sensitive-attribute estimator and the adversary.

        Args:
            h: hidden representations from the main model; fed to ``adv_model``,
                so its last dim must be ``n_hid``.
            inputs: indexable object; ``inputs[0]`` is the node-feature tensor
                passed to ``sens_model`` together with ``G``.
            G: graph passed to ``sens_model``.
            blocks, out_key, label_key, is_train: unused; kept for interface
                parity with the RHGN models.
            print_flag: when True, print debug information (graph and shapes).

        Returns:
            ``(s, s_g)``: ``sens_model(G, inputs[0])`` and ``adv_model(h)``.
        """
        inputs_new = inputs[0]
        if print_flag:
            print('graph:', G)
        s = self.sens_model(G, inputs_new)
        s_g = self.adv_model(h)
        if print_flag:
            # `inputs` may be a list/tuple (it is indexed above), which has no
            # `.shape`; fall back to the tensor actually fed to sens_model.
            in_shape = inputs.shape if torch.is_tensor(inputs) else inputs_new.shape
            print('inputs:', in_shape)
            print('s:', s.shape)
            print('s_g:', s_g.shape)
        return s, s_g
class ali_RHGN(nn.Module):
    """Relation-aware heterogeneous GNN (RHGN) for the Alibaba user/item graph.

    Item node features are blended with an attention-weighted summary of frozen
    category (cid) embeddings, projected per node type, propagated through
    ``n_layers`` ``RHGNLayer`` blocks and mapped to ``n_out`` outputs for the
    ``out_key`` node type.
    """

    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads,
                 cid1_feature, cid2_feature, cid3_feature, use_norm=True):
        """Build the model.

        Args:
            G: unused here; accepted for a uniform constructor signature.
            node_dict: maps node-type name to an index into ``adapt_ws``
                (one input projection is created per entry).
            edge_dict: forwarded to every ``RHGNLayer``.
            n_inp: input feature size of each node type.
            n_hid: hidden size of the projections and RHGN layers.
            n_out: size of the final output layer.
            n_layers: number of ``RHGNLayer`` blocks.
            n_heads: attention heads per ``RHGNLayer``.
            cid1_feature, cid2_feature, cid3_feature: 2-D tables of category
                embeddings, stored as frozen ``nn.Embedding`` modules.
            use_norm: forwarded to every ``RHGNLayer``.
        """
        super(ali_RHGN, self).__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        # Registered first (filled below) to keep the original module/parameter order.
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        self.adapt_ws = nn.ModuleList()
        for _ in range(len(node_dict)):
            self.adapt_ws.append(nn.Linear(n_inp, n_hid))
        for _ in range(n_layers):
            self.gcs.append(RHGNLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm=use_norm))
        self.out = nn.Linear(n_hid, n_out)
        self.cid1_feature = self._frozen_embedding(cid1_feature)
        self.cid2_feature = self._frozen_embedding(cid2_feature)
        self.cid3_feature = self._frozen_embedding(cid3_feature)
        # Not used by forward(); presumably kept for checkpoint compatibility.
        self.excitation = nn.Sequential(
            nn.Linear(3, 32, bias=False),
            nn.ReLU(),
            nn.Linear(32, 3, bias=False),
            nn.ReLU(),
        )
        # Attention: keys/values from category embeddings, query from the item's
        # own features. 200 is presumably the embedding / raw item feature size.
        self.query = nn.Linear(200, n_inp)
        self.key = nn.Linear(200, n_inp)
        self.value = nn.Linear(200, n_inp)
        # Learnable gate; sigmoid(skip) weights the attention summary vs. raw features.
        self.skip = nn.Parameter(torch.ones(1))

    @staticmethod
    def _frozen_embedding(table):
        """Wrap the 2-D tensor ``table`` in a non-trainable ``nn.Embedding``."""
        # Build a default nn.Embedding first (instead of from_pretrained) so the
        # random-number stream is consumed exactly as before and seeded runs
        # initialise the later layers identically.
        emb = nn.Embedding(table.size(0), table.size(1))
        emb.weight = nn.Parameter(table)
        emb.weight.requires_grad = False
        return emb

    def _fuse_item_features(self, item_cid1, item_feature):
        """Blend item features with attention over the item's category embeddings.

        Args:
            item_cid1: integer category ids, presumably 1-D with one entry per
                source item node -- confirm.
            item_feature: raw item features; the last dim must equal ``n_inp``
                for the blend below to be well-formed.

        Returns:
            ``sigmoid(skip) * summary + (1 - sigmoid(skip)) * item_feature``.
        """
        cid1_feature = self.cid1_feature(item_cid1.unsqueeze(1))  # (N, 1, d_cid)
        # NOTE(review): the original code also looked up the cid2/cid3 embeddings
        # and then overwrote them with the cid1 embedding, so all three attention
        # slots hold cid1 and the softmax below is effectively uniform. That
        # behaviour is preserved; confirm whether it is an intended ablation.
        inputs = torch.cat((cid1_feature, cid1_feature, cid1_feature), 1)  # (N, 3, d_cid)
        k = self.key(inputs)  # (N, 3, n_inp)
        v = self.value(inputs)  # (N, 3, n_inp)
        q = self.query(item_feature.unsqueeze(-2))  # (N, 1, n_inp)
        # Scaled dot-product scores; keys and queries live in n_inp dimensions.
        att_score = torch.einsum("bij,bjk->bik", k, q.transpose(1, 2)) / math.sqrt(self.n_inp)
        att_score = torch.softmax(att_score, dim=1)  # (N, 3, 1)
        alpha = torch.sigmoid(self.skip)
        # mean() already removes the slot axis; an extra squeeze(-2) would also
        # drop the node axis whenever N == 1.
        summary = torch.mean(v * att_score, dim=-2)  # (N, n_inp)
        return alpha * summary + (1 - alpha) * item_feature

    def forward(self, input_nodes, output_nodes, blocks, out_key, label_key, is_train=True, print_flag=False):
        """Compute outputs for the ``out_key`` nodes.

        Args:
            input_nodes, output_nodes: unused; accepted for the caller's signature.
            blocks: per-layer message-flow blocks (presumably DGL blocks from a
                neighbour sampler -- confirm). ``blocks[0]`` source 'item' nodes
                must carry ``'cid1'`` and ``'inp'``; source 'user' nodes ``'inp'``.
            out_key: node type whose final representation is used.
            label_key: name of the label field on ``blocks[-1]``'s destination
                ``out_key`` nodes.
            is_train, print_flag: forwarded to every ``RHGNLayer``.

        Returns:
            ``(h_new, labels)``: ``self.out`` applied to the final ``out_key``
            representations (last dim ``n_out``) and the matching labels.
        """
        item_data = blocks[0].srcnodes['item'].data
        user_feature = blocks[0].srcnodes['user'].data['inp']
        item_feature = self._fuse_item_features(item_data['cid1'], item_data['inp'])
        h = {
            'item': F.gelu(self.adapt_ws[self.node_dict['item']](item_feature)),
            'user': F.gelu(self.adapt_ws[self.node_dict['user']](user_feature)),
        }
        for i, layer in enumerate(self.gcs):
            h = layer(blocks[i], h, is_train=is_train, print_flag=print_flag)
        h_new = self.out(h[out_key])
        labels = blocks[-1].dstnodes[out_key].data[label_key]
        return h_new, labels
class jd_RHGN(nn.Module):
    """Relation-aware heterogeneous GNN (RHGN) for the JD user/item graph.

    Item node features are blended with an attention-weighted summary of frozen
    category (cid) embeddings, projected per node type, propagated through
    ``n_layers`` ``RHGNLayer`` blocks and mapped to ``n_out`` outputs for the
    ``out_key`` node type.
    """

    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, cid1_feature, cid2_feature,
                 cid3_feature, cid4_feature, use_norm=True):
        """Build the model.

        Args:
            G: unused here; accepted for a uniform constructor signature.
            node_dict: maps node-type name to an index into ``adapt_ws``
                (one input projection is created per entry).
            edge_dict: forwarded to every ``RHGNLayer``.
            n_inp: input feature size of each node type.
            n_hid: hidden size of the projections and RHGN layers.
            n_out: size of the final output layer.
            n_layers: number of ``RHGNLayer`` blocks.
            n_heads: attention heads per ``RHGNLayer``.
            cid1_feature, cid2_feature, cid3_feature, cid4_feature: 2-D tables
                of category embeddings, stored as frozen ``nn.Embedding`` modules.
            use_norm: forwarded to every ``RHGNLayer``.
        """
        super(jd_RHGN, self).__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        # Registered first (filled below) to keep the original module/parameter order.
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        self.adapt_ws = nn.ModuleList()
        for _ in range(len(node_dict)):
            self.adapt_ws.append(nn.Linear(n_inp, n_hid))
        for _ in range(n_layers):
            self.gcs.append(RHGNLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm=use_norm))
        self.out = nn.Linear(n_hid, n_out)
        self.cid1_feature = self._frozen_embedding(cid1_feature)
        self.cid2_feature = self._frozen_embedding(cid2_feature)
        self.cid3_feature = self._frozen_embedding(cid3_feature)
        self.cid4_feature = self._frozen_embedding(cid4_feature)
        # Not used by forward(); presumably kept for checkpoint compatibility.
        self.excitation = nn.Sequential(
            nn.Linear(4, 32, bias=False),
            nn.ReLU(),
            nn.Linear(32, 4, bias=False),
            nn.ReLU(),
        )
        # Attention: keys/values from category embeddings, query from the item's
        # own features. 200 is presumably the embedding / raw item feature size.
        self.query = nn.Linear(200, n_inp)
        self.key = nn.Linear(200, n_inp)
        self.value = nn.Linear(200, n_inp)
        # Learnable gate; sigmoid(skip) weights the attention summary vs. raw features.
        self.skip = nn.Parameter(torch.ones(1))
        # Per-category projections; not used by forward(), presumably kept for
        # checkpoint compatibility.
        self.l1 = nn.Linear(200, n_inp)
        self.l2 = nn.Linear(200, n_inp)
        self.l3 = nn.Linear(200, n_inp)
        self.l4 = nn.Linear(200, n_inp)

    @staticmethod
    def _frozen_embedding(table):
        """Wrap the 2-D tensor ``table`` in a non-trainable ``nn.Embedding``."""
        # Build a default nn.Embedding first (instead of from_pretrained) so the
        # random-number stream is consumed exactly as before and seeded runs
        # initialise the later layers identically.
        emb = nn.Embedding(table.size(0), table.size(1))
        emb.weight = nn.Parameter(table)
        emb.weight.requires_grad = False
        return emb

    def _fuse_item_features(self, item_cid1, item_feature):
        """Blend item features with attention over the item's category embeddings.

        Args:
            item_cid1: integer category ids, presumably 1-D with one entry per
                source item node -- confirm.
            item_feature: raw item features; the last dim must equal ``n_inp``
                for the blend below to be well-formed.

        Returns:
            ``sigmoid(skip) * summary + (1 - sigmoid(skip)) * item_feature``.
        """
        cid1_feature = self.cid1_feature(item_cid1.unsqueeze(1))  # (N, 1, d_cid)
        # NOTE(review): the original code also looked up the cid2/cid3 embeddings
        # and then overwrote them with the cid1 embedding (cid4/brand was already
        # disabled), so all three attention slots hold cid1 and the softmax below
        # is effectively uniform. That behaviour is preserved; confirm whether it
        # is an intended ablation.
        inputs = torch.cat((cid1_feature, cid1_feature, cid1_feature), 1)  # (N, 3, d_cid)
        k = self.key(inputs)  # (N, 3, n_inp)
        v = self.value(inputs)  # (N, 3, n_inp)
        q = self.query(item_feature.unsqueeze(-2))  # (N, 1, n_inp)
        # Scaled dot-product scores; keys and queries live in n_inp dimensions.
        att_score = torch.einsum("bij,bjk->bik", k, q.transpose(1, 2)) / math.sqrt(self.n_inp)
        att_score = torch.softmax(att_score, dim=1)  # (N, 3, 1)
        alpha = torch.sigmoid(self.skip)
        # mean() already removes the slot axis; an extra squeeze(-2) would also
        # drop the node axis whenever N == 1.
        summary = torch.mean(v * att_score, dim=-2)  # (N, n_inp)
        return alpha * summary + (1 - alpha) * item_feature

    def forward(self, input_nodes, output_nodes, blocks, out_key, label_key, is_train=True, print_flag=False):
        """Compute outputs for the ``out_key`` nodes.

        Args:
            input_nodes, output_nodes: unused; accepted for the caller's signature.
            blocks: per-layer message-flow blocks (presumably DGL blocks from a
                neighbour sampler -- confirm). ``blocks[0]`` source 'item' nodes
                must carry ``'cid1'`` and ``'inp'``; source 'user' nodes ``'inp'``.
            out_key: node type whose final representation is used.
            label_key: name of the label field on ``blocks[-1]``'s destination
                ``out_key`` nodes.
            is_train, print_flag: forwarded to every ``RHGNLayer``.

        Returns:
            ``(h, labels)``: ``self.out`` applied to the final ``out_key``
            representations (last dim ``n_out``) and the matching labels.
        """
        item_data = blocks[0].srcnodes['item'].data
        user_feature = blocks[0].srcnodes['user'].data['inp']
        item_feature = self._fuse_item_features(item_data['cid1'], item_data['inp'])
        h = {
            'item': F.gelu(self.adapt_ws[self.node_dict['item']](item_feature)),
            'user': F.gelu(self.adapt_ws[self.node_dict['user']](user_feature)),
        }
        for i, layer in enumerate(self.gcs):
            h = layer(blocks[i], h, is_train=is_train, print_flag=print_flag)
        logits = self.out(h[out_key])
        labels = blocks[-1].dstnodes[out_key].data[label_key]
        return logits, labels