import time

import numpy as np
import torch
import neptune.new as neptune

# from parser import parameter_parser
from CatGCN.clustering import ClusteringMachine
from CatGCN.clustergnn import ClusterGNNTrainer
from CatGCN.fairness import Fairness
from CatGCN.utils import (
    pos_preds_attr_distr,
    tab_printer,
    graph_reader,
    field_reader,
    target_reader,
    label_reader,
)

def train_CatGCN(user_edge, user_field, user_gender, user_labels, seed, label, args):
    """
    Read the data, decompose the graph, fit and score the model, and log
    metrics (including fairness metrics) to Neptune.
    """
    start_time = time.perf_counter()

    # args = parameter_parser()  # args are now supplied by the caller
    # Fix all random seeds for reproducibility.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # tab_printer(args)

    # Read the user graph, the categorical field indices, the target
    # attribute, and the per-user labels.
    graph = graph_reader(user_edge)
    field_index = field_reader(user_field)
    target = target_reader(user_gender)
    user_labels = label_reader(user_labels)
    print('args:', args)

    # Instantiate the Neptune client and log the run configuration.
    # Note: args.neptune_token is an API secret and should not be printed.
    neptune_run = neptune.init(
        project=args.neptune_project,
        api_token=args.neptune_token,
    )
    # neptune_run["sys/tags"].add(args.log_tags.split(","))
    neptune_run["seed"] = seed
    # neptune_run["dataset"] = "JD-small" if "jd" in args.edge_path else "Alibaba-small"
    neptune_run["model"] = "CatGCN"
    neptune_run["label"] = label
    neptune_run["lr"] = args.lr
    neptune_run["L2"] = args.weight_decay
    neptune_run["dropout"] = args.dropout
    neptune_run["diag_probe"] = args.diag_probe
    neptune_run["nfm_units"] = args.nfm_units
    neptune_run["grn_units"] = args.grn_units
    neptune_run["gnn_hops"] = args.gnn_hops
    neptune_run["gnn_units"] = args.gnn_units
    neptune_run["balance_ratio"] = args.balance_ratio
    neptune_run["n_epochs"] = args.epochs
    
    # Decompose the graph into clusters, then train and evaluate the model.
    clustering_machine = ClusteringMachine(args, graph, field_index, target)
    clustering_machine.decompose()
    gnn_trainer = ClusterGNNTrainer(args, clustering_machine, neptune_run)
    gnn_trainer.train_val_test()
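    # After training, gnn_trainer exposes .targets and .predictions, which
    # the group-level analyses below consume together with the test-node
    # indices in clustering_machine.sg_test_nodes[0].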

    ## Compute the distribution of positive predictions per sensitive-attribute group
    pos_preds_distr = pos_preds_attr_distr(
        user_labels,
        gnn_trainer.targets,
        gnn_trainer.predictions,
        clustering_machine.sg_test_nodes[0],
        label,
        args.sens_attr,
    )
    print(pos_preds_distr)
    # neptune_run["pos_preds_distr"] = pos_preds_distr

    ## Compute fairness metrics
    print("Fairness metrics on sensitive attribute '{}':".format(args.sens_attr))
    fair_obj = Fairness(
        user_labels,
        clustering_machine.sg_test_nodes[0],
        gnn_trainer.targets,
        gnn_trainer.predictions,
        args.sens_attr,
        neptune_run,
        args.multiclass_pred,
        args.multiclass_sens,
    )
    fair_obj.statistical_parity()
    fair_obj.equal_opportunity()
    fair_obj.overall_accuracy_equality()
    fair_obj.treatment_equality()

    # Elapsed wall-clock time in minutes.
    elaps_time = (time.perf_counter() - start_time) / 60
    neptune_run["elaps_time"] = elaps_time

    neptune_run.stop()

if __name__ == "__main__":
    # Direct execution is not supported as-is: train_CatGCN() requires the
    # data objects and an args namespace from a driver script (the original
    # CLI entry point, parameter_parser, is commented out above).
    raise SystemExit("train_CatGCN requires data and args; call it from a driver script.")
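
# --- Hedged usage sketch (illustrative only) --------------------------------
# A minimal driver might look like the following. Every field and value in
# the Namespace below is an assumption for illustration, not the project's
# actual configuration; the real values come from parameter_parser or the
# experiment runner, and user_edge / user_field / user_gender / user_labels
# are the data objects that the runner loads. The API token should be read
# from the environment rather than hard-coded.
#
#   import os
#   from argparse import Namespace
#
#   args = Namespace(
#       neptune_project="workspace/CatGCN",            # hypothetical project
#       neptune_token=os.environ["NEPTUNE_API_TOKEN"],
#       lr=0.1, weight_decay=1e-5, dropout=0.1,
#       diag_probe=1.0, nfm_units="64", grn_units="64",
#       gnn_hops=2, gnn_units="64", balance_ratio=0.7,
#       epochs=100, dataset_name="alibaba",            # hypothetical dataset
#       sens_attr="gender", multiclass_pred=False, multiclass_sens=False,
#   )
#   train_CatGCN(user_edge, user_field, user_gender, user_labels,
#                seed=42, label="age", args=args)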