import time

import neptune.new as neptune
import numpy as np
import torch

#from parser import parameter_parser
from CatGCN.clustergnn import ClusterGNNTrainer
from CatGCN.clustering import ClusteringMachine
from CatGCN.fairness import Fairness
from CatGCN.utils import (
    pos_preds_attr_distr,
    tab_printer,
    graph_reader,
    field_reader,
    target_reader,
    label_reader,
)

# (Neptune field name, attribute read from ``args``), in logging order.
_LOGGED_HYPERPARAMS = (
    ("lr", "lr"),
    ("L2", "weight_decay"),
    ("dropout", "dropout"),
    ("diag_probe", "diag_probe"),
    ("nfm_units", "nfm_units"),
    ("grn_units", "grn_units"),
    ("gnn_hops", "gnn_hops"),
    ("gnn_units", "gnn_units"),
    ("balance_ratio", "balance_ratio"),
    ("n_epochs", "epochs"),
)


def _log_run_config(neptune_run, seed, label, args):
    """Record the seed, model name, label and hyperparameters on the run."""
    neptune_run["seed"] = seed
    neptune_run["model"] = "CatGCN"
    neptune_run["label"] = label
    for field_name, attr_name in _LOGGED_HYPERPARAMS:
        neptune_run[field_name] = getattr(args, attr_name)


def train_CatGCN(user_edge, user_field, user_gender, user_labels, seed, label, args):
    """Read the data, decompose the graph, train CatGCN and report fairness.

    The function seeds NumPy and PyTorch, loads the inputs through the
    ``CatGCN.utils`` readers, opens a Neptune run, trains and evaluates a
    ``ClusterGNNTrainer``, prints the positive-prediction distribution for the
    sensitive attribute and computes four fairness metrics.

    Args:
        user_edge: Input handed to ``graph_reader``.
        user_field: Input handed to ``field_reader``.
        user_gender: Input handed to ``target_reader``.
        user_labels: Input handed to ``label_reader``.
        seed: Seed applied to NumPy and PyTorch (and CUDA when available);
            also logged to Neptune.
        label: Name of the predicted label; logged to Neptune and forwarded to
            ``pos_preds_attr_distr``.
        args: Configuration object. The attributes read here are
            ``neptune_project``, ``neptune_token``, ``lr``, ``weight_decay``,
            ``dropout``, ``diag_probe``, ``nfm_units``, ``grn_units``,
            ``gnn_hops``, ``gnn_units``, ``balance_ratio``, ``epochs``,
            ``sens_attr``, ``multiclass_pred`` and ``multiclass_sens``; the
            object is also passed on to ``ClusteringMachine`` and
            ``ClusterGNNTrainer``.

    Returns:
        None. Results are printed and logged to the Neptune run.
    """
    start_time = time.perf_counter()

    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # tab_printer(args)
    graph = graph_reader(user_edge)
    field_index = field_reader(user_field)
    target = target_reader(user_gender)
    user_labels = label_reader(user_labels)

    # NOTE(review): if ``args`` is an argparse-style namespace its repr may
    # include ``neptune_token`` -- confirm and redact if so.
    print('args', args)

    # Instantiate the Neptune client. The API token is a secret and is
    # deliberately not echoed to stdout.
    neptune_run = neptune.init(
        project=args.neptune_project,
        api_token=args.neptune_token,
    )

    # Make sure the run is closed even when training or evaluation fails.
    try:
        #neptune_run["sys/tags"].add(args.log_tags.split(","))
        #neptune_run["dataset"] = "JD-small" if "jd" in args.edge_path else "Alibaba-small"
        _log_run_config(neptune_run, seed, label, args)

        clustering_machine = ClusteringMachine(args, graph, field_index, target)
        clustering_machine.decompose()
        gnn_trainer = ClusterGNNTrainer(args, clustering_machine, neptune_run)
        gnn_trainer.train_val_test()

        test_nodes = clustering_machine.sg_test_nodes[0]

        ## Compute accuracy per sensitive attribute group
        # The call is the same for every dataset, so no dataset branch is needed.
        pos_preds_distr = pos_preds_attr_distr(
            user_labels,
            gnn_trainer.targets,
            gnn_trainer.predictions,
            test_nodes,
            label,
            args.sens_attr,
        )
        print(pos_preds_distr)
        #neptune_run["pos_preds_distr"] = pos_preds_distr

        ## Compute fairness metrics
        print("Fairness metrics on sensitive attributes '{}':".format(args.sens_attr))
        fair_obj = Fairness(
            user_labels,
            test_nodes,
            gnn_trainer.targets,
            gnn_trainer.predictions,
            args.sens_attr,
            neptune_run,
            args.multiclass_pred,
            args.multiclass_sens,
        )
        fair_obj.statistical_parity()
        fair_obj.equal_opportunity()
        fair_obj.overall_accuracy_equality()
        fair_obj.treatment_equality()

        # Wall-clock duration of the whole call, in minutes.
        elaps_time = (time.perf_counter() - start_time) / 60
        neptune_run["elaps_time"] = elaps_time
    finally:
        neptune_run.stop()


if __name__ == "__main__":
    # train_CatGCN has seven required parameters and this module defines no
    # argument parser, so a bare call can only fail; exit with a clear message.
    raise SystemExit(
        "train_CatGCN() requires user_edge, user_field, user_gender, "
        "user_labels, seed, label and args; import this module and call it "
        "from a driver script."
    )