# FairUP / src/models/CatGCN/train_main.py
# NOTE(review): the original lines here were HuggingFace file-viewer residue
# (author "erasmopurif", commit "First commit", hash d2a8669) — not Python code;
# converted to comments so the module is importable.
import torch
import numpy as np
import neptune.new as neptune
#from parser import parameter_parser
from CatGCN.clustering import ClusteringMachine
from CatGCN.clustergnn import ClusterGNNTrainer
from CatGCN.utils import pos_preds_attr_distr, tab_printer, graph_reader, field_reader, target_reader, label_reader
import time
from CatGCN.fairness import Fairness
def train_CatGCN(user_edge, user_field, user_gender, user_labels, seed, label, args):
    """
    Train and evaluate a CatGCN model: read data, decompose the graph,
    fit the model, then compute and log accuracy/fairness metrics to Neptune.

    :param user_edge: input for ``graph_reader`` (user-user edge data).
    :param user_field: input for ``field_reader`` (user-field features).
    :param user_gender: input for ``target_reader`` (prediction targets).
    :param user_labels: input for ``label_reader`` (sensitive-attribute labels).
    :param seed: random seed applied to numpy and torch (CPU and, if available, CUDA).
    :param label: name of the predicted label, logged to Neptune.
    :param args: hyper-parameter namespace (lr, weight_decay, dropout, neptune
        credentials, ...) — presumably produced by the project's argument
        parser; see the commented-out ``parameter_parser`` import.
    """
    start_time = time.perf_counter()

    # Seed every RNG the training touches for reproducibility.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    graph = graph_reader(user_edge)
    field_index = field_reader(user_field)
    target = target_reader(user_gender)
    user_labels = label_reader(user_labels)
    # NOTE(review): args may itself contain the Neptune token; consider
    # redacting before printing in shared environments.
    print('args', args)

    # Instantiate Neptune client and log run hyper-parameters.
    # SECURITY FIX: the original printed args.neptune_token to stdout,
    # leaking the API credential into logs — removed.
    neptune_run = neptune.init(
        project=args.neptune_project,
        api_token=args.neptune_token,
    )
    neptune_run["seed"] = seed
    neptune_run["model"] = "CatGCN"
    neptune_run["label"] = label
    neptune_run["lr"] = args.lr
    neptune_run["L2"] = args.weight_decay
    neptune_run["dropout"] = args.dropout
    neptune_run["diag_probe"] = args.diag_probe
    neptune_run["nfm_units"] = args.nfm_units
    neptune_run["grn_units"] = args.grn_units
    neptune_run["gnn_hops"] = args.gnn_hops
    neptune_run["gnn_units"] = args.gnn_units
    neptune_run["balance_ratio"] = args.balance_ratio
    neptune_run["n_epochs"] = args.epochs

    # Graph decomposition + training.
    clustering_machine = ClusteringMachine(args, graph, field_index, target)
    clustering_machine.decompose()
    gnn_trainer = ClusterGNNTrainer(args, clustering_machine, neptune_run)
    gnn_trainer.train_val_test()

    # Positive-prediction distribution per sensitive-attribute group.
    # BUG FIX: the original branched on args.dataset_name but both branches
    # were identical — collapsed into a single call.
    pos_preds_distr = pos_preds_attr_distr(
        user_labels, gnn_trainer.targets, gnn_trainer.predictions,
        clustering_machine.sg_test_nodes[0], label, args.sens_attr)
    print(pos_preds_distr)

    # Fairness metrics on the sensitive attribute.
    print("Fairness metrics on sensitive attributes '{}':".format(args.sens_attr))
    fair_obj = Fairness(
        user_labels, clustering_machine.sg_test_nodes[0],
        gnn_trainer.targets, gnn_trainer.predictions,
        args.sens_attr, neptune_run,
        args.multiclass_pred, args.multiclass_sens)
    fair_obj.statistical_parity()
    fair_obj.equal_opportunity()
    fair_obj.overall_accuracy_equality()
    fair_obj.treatment_equality()

    # Elapsed wall-clock time in minutes.
    elaps_time = (time.perf_counter() - start_time) / 60
    neptune_run["elaps_time"] = elaps_time
    neptune_run.stop()
if __name__ == "__main__":
    # BUG FIX: the original called train_CatGCN() with no arguments, but the
    # function requires 7 positional parameters, so running this file directly
    # always crashed with an opaque TypeError. The argument parser import is
    # commented out above, so fail with an explicit, actionable message instead.
    raise SystemExit(
        "train_main.py cannot be run standalone: train_CatGCN() requires "
        "(user_edge, user_field, user_gender, user_labels, seed, label, args). "
        "Invoke it from the FairUP pipeline instead."
    )