Spaces:
Runtime error
Runtime error
import scipy.io | |
import dgl | |
import math | |
import torch | |
import numpy as np | |
#from model import * | |
from RHGN.model import * | |
import argparse | |
from sklearn import metrics | |
import time | |
from sklearn.metrics import f1_score | |
#import neptune.new as neptune | |
from RHGN.fairness import Fairness | |
# Disabled standalone-script configuration. This module-level string is a bare
# expression whose value is discarded, so none of it runs. It preserves the
# argparse/Neptune setup from when this file was executed as a script; most of
# these settings are now passed as explicit arguments to tecent_training_main().
# NOTE(review): if this is ever re-enabled:
#   * `--clip` declares type=int with a float default, so a CLI value such as
#     "0.5" would be rejected by argparse;
#   * `type=bool` on --multiclass-pred/--multiclass-sens turns any non-empty
#     string (including "False") into True;
#   * `neptune` is not imported (its import above is commented out).
'''
parser = argparse.ArgumentParser(description='for JD Dataset')
parser.add_argument('--n_epoch', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--n_hid', type=int, default=32)
parser.add_argument('--n_inp', type=int, default=200)
parser.add_argument('--clip', type=int, default=1.0)
parser.add_argument('--max_lr', type=float, default=1e-2)
parser.add_argument('--label', type=str, default='gender')
parser.add_argument('--gpu', type=int, default=0,choices=[0,1,2,3,4,5,6,7])
parser.add_argument('--graph', type=str, default='G_ori')
parser.add_argument('--model', type=str, default='RHGN',choices=['RHGN','RGCN'])
parser.add_argument('--data_dir', type=str, default='../data/sample')
parser.add_argument('--sens_attr', type=str, default='gender')
parser.add_argument('--log_tags', type=str, default='')
parser.add_argument('--neptune-project', type=str, default='')
parser.add_argument('--neptune-token', type=str, default='')
parser.add_argument('--multiclass-pred', type=bool, default=False)
parser.add_argument('--multiclass-sens', type=bool, default=False)
args = parser.parse_args()
# Instantiate Neptune client and log arguments
neptune_run = neptune.init(
project=args.neptune_project,
api_token=args.neptune_token,
)
neptune_run["sys/tags"].add(args.log_tags.split(","))
neptune_run["seed"] = args.seed
neptune_run["dataset"] = "JD-small-sampled"
neptune_run["model"] = args.model
neptune_run["label"] = args.label
neptune_run["num_epochs"] = args.n_epoch
neptune_run["n_hid"] = args.n_hid
neptune_run["lr"] = args.max_lr
neptune_run["clip"] = args.clip
'''
def get_n_params(model):
    """Return the total number of scalar entries across ``model.parameters()``.

    Each parameter contributes the product of its dimensions (a 0-dim
    parameter counts as 1). Trainable and frozen parameters are both counted.
    """
    return sum(param.numel() for param in model.parameters())
def _train_one_epoch(model, optimizer, scheduler, dataloader, label, clip):
    """Run one optimisation pass over ``dataloader``; return (mean_loss, accuracy).

    The optimizer and scheduler are each stepped once per batch. Loss and
    accuracy are weighted by the number of 'user' output nodes in each batch.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    count = 0
    for input_nodes, output_nodes, blocks in dataloader:
        logits, batch_labels = model(input_nodes, output_nodes, blocks,
                                     out_key='user', label_key=label, is_train=True)
        loss = F.cross_entropy(logits, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        # Argument-less step(): passing an explicit step count is deprecated in
        # PyTorch, and the schedule is identical because the count advanced by
        # exactly one per batch.
        scheduler.step()
        n_out = len(output_nodes['user'])  # len() of a tensor needs no host copy
        total_loss += loss.item() * n_out
        total_correct += (logits.argmax(1) == batch_labels).sum().item()
        count += n_out
    if count == 0:
        raise ValueError('training dataloader yielded no output nodes')
    return total_loss / count, total_correct / count


def _evaluate(model, dataloader, label, **model_kwargs):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Extra keyword arguments are forwarded to the model call.

    Returns:
        (mean_loss, accuracy, macro_f1, labels, preds). Loss and accuracy are
        weighted by the number of 'user' output nodes per batch; ``labels`` and
        ``preds`` are flat Python lists in dataloader order.
    """
    total_loss = 0.0
    total_correct = 0
    count = 0
    preds = []
    labels = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in dataloader:
            logits, batch_labels = model(input_nodes, output_nodes, blocks,
                                         out_key='user', label_key=label,
                                         is_train=False, **model_kwargs)
            loss = F.cross_entropy(logits, batch_labels)
            batch_preds = logits.argmax(1)
            n_out = len(output_nodes['user'])
            total_loss += loss.item() * n_out
            total_correct += (batch_preds == batch_labels).sum().item()
            count += n_out
            preds.extend(batch_preds.tolist())
            labels.extend(batch_labels.tolist())
    if count == 0:
        raise ValueError('evaluation dataloader yielded no output nodes')
    f1 = metrics.f1_score(labels, preds, average='macro')
    return total_loss / count, total_correct / count, f1, labels, preds


def Batch_train(model, optimizer, scheduler, train_dataloader, val_dataloader,
                test_dataloader, epochs, label, clip, device,
                checkpoint_path='models.pth', early_stop_loss=0.4):
    """Train ``model``, keep the best-validation-accuracy checkpoint, then test it.

    Each epoch runs one training pass, then evaluates on the validation and
    test loaders. Whenever validation accuracy improves, the model state and
    epoch number are saved to ``checkpoint_path``. Training runs for at most
    ``epochs`` epochs and stops early once the validation loss drops below
    ``early_stop_loss``. The best checkpoint is then reloaded and evaluated on
    the test loader; a confusion matrix and classification report are printed.

    Args:
        model: called as ``model(input_nodes, output_nodes, blocks, out_key='user',
            label_key=label, is_train=...)``; must return ``(logits, labels)``.
        optimizer: stepped once per training batch.
        scheduler: LR scheduler, stepped once per training batch.
        train_dataloader, val_dataloader, test_dataloader: iterables yielding
            ``(input_nodes, output_nodes, blocks)`` batches.
        epochs: maximum number of epochs (must be >= 1).
        label: target name forwarded to the model as ``label_key``.
        clip: max gradient norm passed to ``clip_grad_norm_``.
        device: device the reloaded best model is moved to.
        checkpoint_path: file the best checkpoint is written to and read back from.
        early_stop_loss: stop training once validation loss is below this value.

    Returns:
        ``(labels, preds)``: true and predicted classes on the test set for the
        best checkpoint.

    Raises:
        ValueError: if ``epochs < 1`` or a dataloader yields no output nodes.
    """
    if epochs < 1:
        raise ValueError('epochs must be >= 1, got {}'.format(epochs))
    tic = time.perf_counter()  # start counting time
    # Start at -inf so the first epoch always writes a checkpoint; with 0, a run
    # whose val accuracy never exceeds 0 would reload a missing or stale file.
    best_val_acc = float('-inf')
    best_test_acc = 0.0
    # range() yields plain ints. np.arange yielded numpy scalars that were stored
    # in the checkpoint and rejected by torch.load's weights-only unpickler
    # (the default since PyTorch 2.6).
    for epoch in range(1, epochs + 1):
        train_loss, _ = _train_one_epoch(model, optimizer, scheduler,
                                         train_dataloader, label, clip)
        model.eval()
        val_loss, val_acc, val_f1, _, _ = _evaluate(model, val_dataloader, label)
        _, test_acc, test_f1, _, _ = _evaluate(model, test_dataloader, label)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            torch.save({'model_params': model.state_dict(), 'epoch': epoch}, checkpoint_path)
        print('Epoch: %d LR: %.5f Loss %.4f, val loss %.4f, Val Acc %.4f (Best %.4f), Test Acc %.4f (Best %.4f)' % (
            epoch,
            optimizer.param_groups[0]['lr'],
            train_loss,
            val_loss,
            val_acc,
            best_val_acc,
            test_acc,
            best_test_acc,
        ))
        print('\t\tval_f1 %.4f test_f1 \033[1;33m %.4f \033[0m' % (val_f1, test_f1))
        torch.cuda.empty_cache()
        if val_loss < early_stop_loss:
            break

    # Reload the best-validation checkpoint and run the final test evaluation.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_params'])
    model.to(device)
    model.eval()
    test_loss, test_acc, test_f1, labels, preds = _evaluate(
        model, test_dataloader, label, print_flag=True)
    print('Epoch: %d , test loss %.4f,, Test Acc %.4f (f1 %.4f)' % (
        epoch,
        test_loss,
        test_acc,
        test_f1,
    ))
    # Classification reports
    print(metrics.confusion_matrix(labels, preds))
    print(metrics.classification_report(labels, preds))
    print('F1 score:', test_f1)
    elapsed_time = (time.perf_counter() - tic) / 60  # minutes
    print("\nElapsed time: {:.4f} minutes".format(elapsed_time))
    return labels, preds
#################################################################################### | |
def _make_user_loader(G, idx, sampler, batch_size, device):
    """Return an unshuffled DGL node loader over the 'user' node ids in ``idx``."""
    # NOTE(review): newer DGL releases expose this as dgl.dataloading.DataLoader;
    # confirm NodeDataLoader exists in the pinned DGL version.
    return dgl.dataloading.NodeDataLoader(
        G, {'user': idx.to(device)}, sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        device=device)


def tecent_training_main(G, cid1_feature, cid2_feature, cid3_feature, cid4_feature, model_type, seed, gpu, label, n_inp, batch_size, num_hidden, epochs, lr, sens_attr, multiclass_pred, multiclass_sens, clip):
    """Train the RHGN user classifier on graph ``G`` and print fairness metrics.

    Steps:
      1. Seed NumPy/PyTorch (and CUDA when available).
      2. Randomly split the 'user' nodes 75% / 12.5% / 12.5% into train/val/test.
      3. Store an integer edge-type id in ``G.edges[etype].data['id']`` and a
         randomly initialised (Xavier-uniform, non-trainable) 200-dim feature in
         ``G.nodes[ntype].data['inp']`` for every node type.
      4. Train ``jd_RHGN`` with ``Batch_train`` using AdamW + OneCycleLR.
      5. Run the ``Fairness`` checks on the test-set predictions.

    Args:
        G: DGL heterograph; ``G.nodes['user'].data[label]`` holds the class labels.
        cid1_feature..cid4_feature: forwarded unchanged to ``jd_RHGN``.
        model_type: only ``'RHGN'`` is implemented.
        seed: random seed for the split, feature init and model init.
        gpu: CUDA device index, used only when CUDA is available (else CPU).
        label: name of the user-node target field.
        n_inp, num_hidden: forwarded to ``jd_RHGN`` as ``n_inp`` / ``n_hid``.
        batch_size: nodes per loader batch.
        epochs, lr, clip: training epochs, OneCycleLR max LR, gradient-norm clip.
        sens_attr, multiclass_pred, multiclass_sens: forwarded to ``Fairness``.

    Raises:
        ValueError: if ``model_type`` is not ``'RHGN'``.
    """
    if model_type != 'RHGN':
        # Previously an unsupported model type fell through and crashed later
        # with a NameError on the (never assigned) training results.
        raise ValueError("Unsupported model_type {!r}; only 'RHGN' is implemented".format(model_type))

    # Fixed random seeds
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        device = torch.device("cuda:{}".format(gpu))
    else:
        # Fall back to CPU instead of failing on hosts without CUDA.
        device = torch.device("cpu")

    print(G)
    labels = G.nodes['user'].data[label]

    # Random train/val/test split of the user nodes.
    n_users = len(labels)
    perm = np.random.permutation(np.arange(n_users))
    train_end = int(n_users * 0.75)
    val_end = int(n_users * 0.875)
    train_idx = torch.tensor(perm[:train_end]).long()
    val_idx = torch.tensor(perm[train_end:val_end]).long()
    test_idx = torch.tensor(perm[val_end:]).long()
    print(train_idx.shape)
    print(val_idx.shape)
    print(test_idx.shape)

    node_dict = {ntype: i for i, ntype in enumerate(G.ntypes)}
    edge_dict = {}
    for etype in G.etypes:
        edge_dict[etype] = len(edge_dict)
        G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * edge_dict[etype]

    # Random (not pretrained) input features for every node type.
    # NOTE(review): the width is fixed at 200 rather than n_inp; confirm that
    # jd_RHGN expects 200-dim 'inp' features when n_inp != 200.
    for ntype in G.ntypes:
        emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), 200), requires_grad=False)
        nn.init.xavier_uniform_(emb)
        G.nodes[ntype].data['inp'] = emb
    G = G.to(device)

    # Sampling: full 2-hop neighbourhoods for each batch of user nodes.
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    train_dataloader = _make_user_loader(G, train_idx, sampler, batch_size, device)
    val_dataloader = _make_user_loader(G, val_idx, sampler, batch_size, device)
    test_dataloader = _make_user_loader(G, test_idx, sampler, batch_size, device)

    model = jd_RHGN(G,
                    node_dict, edge_dict,
                    n_inp=n_inp,
                    n_hid=num_hidden,
                    n_out=labels.max().item() + 1,
                    n_layers=2,
                    n_heads=4,
                    cid1_feature=cid1_feature,
                    cid2_feature=cid2_feature,
                    cid3_feature=cid3_feature,
                    cid4_feature=cid4_feature,
                    use_norm=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters())
    # One scheduler step per training batch; this bound is >= the real batch count.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, epochs=epochs,
                                                    steps_per_epoch=int(train_idx.shape[0] / batch_size) + 1,
                                                    max_lr=lr)
    print('Training RHGN with #param: %d' % (get_n_params(model)))
    targets, predictions = Batch_train(model, optimizer, scheduler, train_dataloader, val_dataloader,
                                       test_dataloader, epochs, label, clip, device)

    # Compute fairness
    fair_obj = Fairness(G, test_idx, targets, predictions, sens_attr, multiclass_pred, multiclass_sens)
    fair_obj.statistical_parity()
    fair_obj.equal_opportunity()
    fair_obj.overall_accuracy_equality()
    fair_obj.treatment_equality()