# FairUP/src/models/RHGN/model.py
import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.functional import edge_softmax
from FairGNN.src.models.GCN import GCN
from RHGN.layers import RHGNLayer


class RHGN_adv(nn.Module):
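    """Adversarial/fairness heads used alongside the main RHGN model.

    Holds frozen copies of the pretrained category embedding tables and two
    auxiliary predictors: ``sens_model``, a GCN (200 -> 128 -> 1) that
    estimates the sensitive attribute from raw node features, and
    ``adv_model``, a linear adversary over the RHGN hidden embedding.
    """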
def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, cid1_feature, cid2_feature, cid3_feature):
super(RHGN_adv, self).__init__()
self.cid1_feature = nn.Embedding(cid1_feature.size(0), cid1_feature.size(1))
self.cid1_feature.weight = nn.Parameter(cid1_feature)
self.cid1_feature.weight.requires_grad = False
self.cid2_feature = nn.Embedding(cid2_feature.size(0), cid2_feature.size(1))
self.cid2_feature.weight = nn.Parameter(cid2_feature)
self.cid2_feature.weight.requires_grad = False
self.cid3_feature= nn.Embedding(cid3_feature.size(0), cid3_feature.size(1))
self.cid3_feature.weight = nn.Parameter(cid3_feature)
self.cid3_feature.weight.requires_grad = False
self.adv_model = nn.Linear(n_hid, 1) # was n_out
#self.sens_model = nn.Linear(64, 2)
self.sens_model = GCN(200, 128, 1, 0.5)
#self.optimizer_A = torch.optim.Adam(self.adv_model.parameters(), lr=0.1, weight_decay=1e-5)
#self.A_loss = 0
def forward(self, h, inputs, G, blocks, out_key, label_key, is_train=True, print_flag=False):
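        """Return ``(s, s_g)``: the GCN estimate of the sensitive attribute
        computed from the raw inputs, and the linear adversary's prediction
        computed from the RHGN embedding ``h``."""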
        # h comes from the original (main) RHGN model
        #s = self.sens_model(h)
        inputs_new = inputs[0]
        if print_flag:
            print('graph:', G)
        s = self.sens_model(G, inputs_new)
        if print_flag:
            print('inputs:', inputs.shape)
        s_g = self.adv_model(h)
        if print_flag:
            print('s:', s.shape)
            print('s_g:', s_g.shape)
        return s, s_g


class ali_RHGN(nn.Module):
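    """RHGN variant for the Alibaba ('ali') dataset.

    Per-type linear adapters project node features to ``n_hid``; the item
    feature is fused with its frozen category embedding (cid1, attended via
    query/key/value projections and a learnable sigmoid skip gate) before
    ``n_layers`` RHGNLayer blocks and a linear output head.
    """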
    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, cid1_feature, cid2_feature, cid3_feature, use_norm=True):
super(ali_RHGN, self).__init__()
self.node_dict = node_dict
self.edge_dict = edge_dict
self.gcs = nn.ModuleList()
self.n_inp = n_inp
self.n_hid = n_hid
self.n_out = n_out
self.n_layers = n_layers
self.adapt_ws = nn.ModuleList()
        for _ in range(len(node_dict)):
self.adapt_ws.append(nn.Linear(n_inp, n_hid))
for _ in range(n_layers):
self.gcs.append(RHGNLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm = use_norm))
self.out = nn.Linear(n_hid, n_out)
        self.cid1_feature = nn.Embedding(cid1_feature.size(0), cid1_feature.size(1))
        self.cid1_feature.weight = nn.Parameter(cid1_feature)
        self.cid1_feature.weight.requires_grad = False
        self.cid2_feature = nn.Embedding(cid2_feature.size(0), cid2_feature.size(1))
        self.cid2_feature.weight = nn.Parameter(cid2_feature)
        self.cid2_feature.weight.requires_grad = False
        self.cid3_feature = nn.Embedding(cid3_feature.size(0), cid3_feature.size(1))
        self.cid3_feature.weight = nn.Parameter(cid3_feature)
        self.cid3_feature.weight.requires_grad = False
self.excitation = nn.Sequential(
nn.Linear(3, 32, bias=False),
nn.ReLU(),
nn.Linear(32, 3, bias=False),
nn.ReLU()
)
self.query = nn.Linear(200, n_inp)
self.key = nn.Linear(200, n_inp)
self.value = nn.Linear(200, n_inp)
self.skip = nn.Parameter(torch.ones(1))
print('n_out:', self.n_out)
#self.query_sens = nn.Linear(200, n_inp)
#self.key_sens = nn.Linear(200, n_inp)
#self.value_sens = nn.Linear(200, n_inp)
#self.adv_model = nn.Linear(128, 1)
#self.adv_model = nn.Linear(n_hid, n_out)
#self.sens_model = GCN(95, 128, 1, 0.5)
#self.sens_model = nn.Linear(n_hid, n_out)
#self.sens_model2 = nn.Linear(n_inp, n_hid)
#self.sens_model3 = nn.Linear(n_hid, n_out)
#self.optimizer_A = torch.optim.Adam(self.adv_model.parameters(), lr=0.1, weight_decay=1e-5)
#self.criterion = nn.BCEWithLogitsLoss()
#self.optimizer_G = torch.optim.Adam(self.parameters())
#self.A_loss = 0
#self.G_loss = 0
#self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer_G, epochs=epochs,
# steps_per_epoch=int(train_idx.shape[0]/batch_size)+1,max_lr = lr)
    def forward(self, input_nodes, output_nodes, blocks, out_key, label_key, is_train=True, print_flag=False):
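        """Run one sampled mini-batch over ``blocks``; returns logits for the
        ``out_key`` node type and the matching labels from the last block."""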
        item_cid1 = blocks[0].srcnodes['item'].data['cid1'].unsqueeze(1)  # (N,1)
        cid1_feature = self.cid1_feature(item_cid1)  # (N,1,200)
        item_cid2 = blocks[0].srcnodes['item'].data['cid2'].unsqueeze(1)  # (N,1)
        cid2_feature = self.cid2_feature(item_cid2)  # (N,1,200)
        item_cid3 = blocks[0].srcnodes['item'].data['cid3'].unsqueeze(1)  # (N,1)
        cid3_feature = self.cid3_feature(item_cid3)  # (N,1,200)
        # NOTE: cid2/cid3 are overridden with cid1 here; only the first-level
        # category embedding feeds the attention below.
        cid2_feature = cid1_feature
        cid3_feature = cid1_feature
        item_feature = blocks[0].srcnodes['item'].data['inp']
        user_feature = blocks[0].srcnodes['user'].data['inp']
        # brand_feature = blocks[0].srcnodes['brand'].data['inp']
        inputs = torch.cat((cid1_feature, cid2_feature, cid3_feature), 1)  # (N,3,200)
#print('inputs:', inputs.shape) # (455, 3, 200)
        k = self.key(inputs)  # (N,3,n_inp)
        v = self.value(inputs)  # (N,3,n_inp)
        q = self.query(item_feature.unsqueeze(-2))  # (N,1,n_inp)
        att_score = torch.einsum("bij,bjk->bik", k, q.transpose(1, 2)) / math.sqrt(200)  # (N,3,1)
        att_score = torch.softmax(att_score, dim=1)  # (N,3,1)
        alpha = torch.sigmoid(self.skip)  # (1,)
        temp = v * att_score  # (N,3,n_inp)
        item_feature = alpha * torch.mean(temp, dim=-2) + (1 - alpha) * item_feature  # (N,200)
#print('item_feature:', item_feature)
        h = {}
        h['item'] = F.gelu(self.adapt_ws[self.node_dict['item']](item_feature))
        h['user'] = F.gelu(self.adapt_ws[self.node_dict['user']](user_feature))
        # h['brand'] = F.gelu(self.adapt_ws[self.node_dict['brand']](brand_feature))
        for i in range(self.n_layers):
            h = self.gcs[i](blocks[i], h, is_train=is_train, print_flag=print_flag)
h = h[out_key]
#print('h:', h)
#self.adv_model.requires_grad_(False)
#add sens model input
#s = self.sens_model(inputs)
#s = self.sens_model2(s)
#s = self.sens_model3(s)
#add adv model input
#s_g = self.adv_model(h)
        h_new = self.out(h)
        #print('h_new:', h_new.shape)
        labels = blocks[-1].dstnodes[out_key].data[label_key]
        # h = F.log_softmax(h, dim=1)
        # return will be h, labels, and estimator output
        return h_new, labels


class jd_RHGN(nn.Module):
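    """RHGN variant for the JD dataset.

    Structurally the same as ``ali_RHGN``, plus a fourth frozen category
    table (``cid4_feature``) and per-level linear maps ``l1``-``l4`` that the
    current forward pass leaves unused.
    """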
    def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, cid1_feature, cid2_feature,
                 cid3_feature, cid4_feature, use_norm=True):
super(jd_RHGN, self).__init__()
self.node_dict = node_dict
self.edge_dict = edge_dict
self.gcs = nn.ModuleList()
self.n_inp = n_inp
self.n_hid = n_hid
self.n_out = n_out
self.n_layers = n_layers
self.adapt_ws = nn.ModuleList()
        for _ in range(len(node_dict)):
self.adapt_ws.append(nn.Linear(n_inp, n_hid))
for _ in range(n_layers):
self.gcs.append(RHGNLayer(n_hid, n_hid, node_dict, edge_dict, n_heads, use_norm=use_norm))
self.out = nn.Linear(n_hid, n_out)
self.cid1_feature = nn.Embedding(cid1_feature.size(0), cid1_feature.size(1))
self.cid1_feature.weight = nn.Parameter(cid1_feature)
self.cid1_feature.weight.requires_grad = False
self.cid2_feature = nn.Embedding(cid2_feature.size(0), cid2_feature.size(1))
self.cid2_feature.weight = nn.Parameter(cid2_feature)
self.cid2_feature.weight.requires_grad = False
self.cid3_feature = nn.Embedding(cid3_feature.size(0), cid3_feature.size(1))
self.cid3_feature.weight = nn.Parameter(cid3_feature)
self.cid3_feature.weight.requires_grad = False
self.cid4_feature = nn.Embedding(cid4_feature.size(0), cid4_feature.size(1))
self.cid4_feature.weight = nn.Parameter(cid4_feature)
self.cid4_feature.weight.requires_grad = False
self.excitation = nn.Sequential(
nn.Linear(4, 32, bias=False),
nn.ReLU(),
nn.Linear(32, 4, bias=False),
nn.ReLU()
)
self.query = nn.Linear(200, n_inp)
self.key = nn.Linear(200, n_inp)
self.value = nn.Linear(200, n_inp)
self.skip = nn.Parameter(torch.ones(1))
self.l1=nn.Linear(200, n_inp)
self.l2=nn.Linear(200, n_inp)
self.l3=nn.Linear(200, n_inp)
self.l4=nn.Linear(200, n_inp)
    def forward(self, input_nodes, output_nodes, blocks, out_key, label_key, is_train=True, print_flag=False):
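        """Run one sampled mini-batch over ``blocks``; returns logits for the
        ``out_key`` node type and the matching labels from the last block."""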
        item_cid1 = blocks[0].srcnodes['item'].data['cid1'].unsqueeze(1)  # (N,1)
        cid1_feature = self.cid1_feature(item_cid1)  # (N,1,200)
        #cid1_feature = self.l1(cid1_feature)
        item_cid2 = blocks[0].srcnodes['item'].data['cid2'].unsqueeze(1)  # (N,1)
        cid2_feature = self.cid2_feature(item_cid2)  # (N,1,200)
        #cid2_feature = self.l2(cid2_feature)
        item_cid3 = blocks[0].srcnodes['item'].data['cid3'].unsqueeze(1)  # (N,1)
        cid3_feature = self.cid3_feature(item_cid3)  # (N,1,200)
        #cid3_feature = self.l3(cid3_feature)
        # item_cid4 = blocks[0].srcnodes['item'].data['brand'].unsqueeze(1)  # (N,1)
        # cid4_feature = self.cid4_feature(item_cid4)  # (N,1,200)
        #cid4_feature = self.l4(cid4_feature)
        # NOTE: cid2/cid3 are overridden with cid1 here; only the first-level
        # category embedding feeds the attention below.
        cid2_feature = cid1_feature
        cid3_feature = cid1_feature
        # cid4_feature = cid1_feature
        item_feature = blocks[0].srcnodes['item'].data['inp']
        user_feature = blocks[0].srcnodes['user'].data['inp']
        # inputs = torch.cat((cid1_feature, cid2_feature, cid3_feature, cid4_feature), 1)  # (N,4,200)
        inputs = torch.cat((cid1_feature, cid2_feature, cid3_feature), 1)  # (N,3,200)
        k = self.key(inputs)  # (N,3,n_inp)
        v = self.value(inputs)  # (N,3,n_inp)
        q = self.query(item_feature.unsqueeze(-2))  # (N,1,n_inp)
        att_score = torch.einsum("bij,bjk->bik", k, q.transpose(1, 2)) / math.sqrt(200)  # (N,3,1)
        att_score = torch.softmax(att_score, dim=1)  # (N,3,1)
#Z = torch.mean(inputs, dim=-1, out=None) # (N,4)
#A = self.excitation(Z).unsqueeze(-1) # (N,4,1)
#att_score = att_score + A # (N,4,1)
        alpha = torch.sigmoid(self.skip)  # (1,)
        temp = v * att_score  # (N,3,n_inp)
        item_feature = alpha * torch.mean(temp, dim=-2) + (1 - alpha) * item_feature  # (N,200)
h = {}
h['item'] = F.gelu(self.adapt_ws[self.node_dict['item']](item_feature))
h['user'] = F.gelu(self.adapt_ws[self.node_dict['user']](user_feature))
        for i in range(self.n_layers):
            h = self.gcs[i](blocks[i], h, is_train=is_train, print_flag=print_flag)
h = h[out_key]
h = self.out(h)
labels = blocks[-1].dstnodes[out_key].data[label_key]
# h=F.log_softmax(h, dim=1)
return h, labels
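

# --- Illustrative smoke test (not part of the original module) ---
# A minimal, hedged sketch of how ali_RHGN is wired up: it builds a toy
# user-item heterograph with random 200-d 'inp' features and small cid
# vocabularies, then runs one sampled mini-batch. Every dimension, node/edge
# type name, and the sampler/DataLoader setup here is an assumption for
# illustration (the DataLoader name follows DGL >= 0.8); it only runs if
# RHGN.layers.RHGNLayer matches the constructor and calls used above.
if __name__ == "__main__":
    n_users, n_items, n_edges, n_cats = 8, 12, 40, 5
    toy_g = dgl.heterograph({
        ('user', 'click', 'item'): (torch.randint(0, n_users, (n_edges,)),
                                    torch.randint(0, n_items, (n_edges,))),
        ('item', 'clicked-by', 'user'): (torch.randint(0, n_items, (n_edges,)),
                                         torch.randint(0, n_users, (n_edges,))),
    })
    # Random node features and labels; names mirror what forward() reads.
    toy_g.nodes['user'].data['inp'] = torch.randn(toy_g.num_nodes('user'), 200)
    toy_g.nodes['item'].data['inp'] = torch.randn(toy_g.num_nodes('item'), 200)
    for key in ('cid1', 'cid2', 'cid3'):
        toy_g.nodes['item'].data[key] = torch.randint(0, n_cats, (toy_g.num_nodes('item'),))
    toy_g.nodes['user'].data['label'] = torch.randint(0, 2, (toy_g.num_nodes('user'),))

    model = ali_RHGN(toy_g, node_dict={'item': 0, 'user': 1},
                     edge_dict={'click': 0, 'clicked-by': 1},
                     n_inp=200, n_hid=32, n_out=2, n_layers=2, n_heads=4,
                     cid1_feature=torch.randn(n_cats, 200),
                     cid2_feature=torch.randn(n_cats, 200),
                     cid3_feature=torch.randn(n_cats, 200))

    # Two sampled layers to match n_layers=2 (forward indexes blocks[i]).
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    loader = dgl.dataloading.DataLoader(
        toy_g, {'user': torch.arange(toy_g.num_nodes('user'))}, sampler,
        batch_size=4, shuffle=False)
    for input_nodes, output_nodes, blocks in loader:
        logits, labels = model(input_nodes, output_nodes, blocks,
                               out_key='user', label_key='label')
        print('logits:', logits.shape, 'labels:', labels.shape)
        break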