import torch import torch.nn.functional as F from torch.autograd import Variable import time import numpy as np from sklearn import metrics from sklearn.metrics import f1_score from tqdm import trange, tqdm import torch.nn as nn from CatGCN.layers import StackedGNN from FairGNN.src.models.GCN import GCN class ClusterGNNTrainer(object): """ Training a huge graph cluster partition strategy. """ def __init__(self, args, clustering_machine, neptune_run): self.args = args self.clustering_machine = clustering_machine self.neptune_run = neptune_run self.device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu") print('device:', self.device) self.class_weight = clustering_machine.class_weight.to(self.device) self.create_model() # mew part -- input? #self.sens_model = GCN(95, 128, 0.5) # number of feat, number of hidden, dropout percentage #self.adv_model = nn.Linear(128, 1) # adversary optimizer #self.optimizer_A = torch.optim.Adam(self.adv_model.parameters(), lr = args.lr, weight_decay = args.weight_decay) ## can add weight decay #self.criterion = nn.BCEWithLogitsLoss() #self.A_loss = 0 def create_model(self): """ Creating a StackedGNN and transferring to CPU/GPU. """ self.model = StackedGNN(self.args, self.clustering_machine.field_count, self.clustering_machine.field_size, self.clustering_machine.class_count) self.model = self.model.to(self.device) def generate_field_adjs(self, node_count): # Normalization by P'' = Q^{-1/2}*P'*Q^{-1/2}, P' = P+probe*O. field_adjs = torch.ones((node_count, self.clustering_machine.field_size, self.clustering_machine.field_size)) field_adjs += self.args.diag_probe * torch.eye(self.clustering_machine.field_size) row_sum = self.clustering_machine.field_size + self.args.diag_probe field_adjs = (1. / row_sum) * field_adjs return field_adjs def do_forward_pass(self, cluster): """ Making a forward pass with data from a given partition. :param cluster: Cluster index. :return average_loss: Average loss on the cluster. :return node_count: Number of nodes. """ edges = self.clustering_machine.sg_edges[cluster].to(self.device) macro_nodes = self.clustering_machine.sg_nodes[cluster].to(self.device) train_nodes = self.clustering_machine.sg_train_nodes[cluster].to(self.device) field_index = self.clustering_machine.sg_field_index[cluster].to(self.device) field_adjs = self.generate_field_adjs(field_index.shape[0]).to(self.device) target = self.clustering_machine.sg_targets[cluster].to(self.device).squeeze() prediction = self.model(edges, field_index, field_adjs) # todo add forward pass for estimator and adversary average_loss = F.nll_loss(prediction[train_nodes], target[train_nodes], self.class_weight) node_count = train_nodes.shape[0] return average_loss, node_count def do_validation(self, cluster, epoch): """ Making a validation with data from a given partition. :param cluster: Cluster index. :return average_loss: Average loss on the cluster. :return node_count: Number of nodes. """ edges = self.clustering_machine.sg_edges[cluster].to(self.device) macro_nodes = self.clustering_machine.sg_nodes[cluster].to(self.device) val_nodes = self.clustering_machine.sg_val_nodes[cluster].to(self.device) field_index = self.clustering_machine.sg_field_index[cluster].to(self.device) field_adjs = self.generate_field_adjs(field_index.shape[0]).to(self.device) target = self.clustering_machine.sg_targets[cluster].to(self.device).squeeze() prediction = self.model(edges, field_index, field_adjs) average_loss = F.nll_loss(prediction[val_nodes], target[val_nodes], self.class_weight) node_count = val_nodes.shape[0] return average_loss, node_count def do_prediction(self, cluster): """ Scoring a cluster. :param cluster: Cluster index. :return average_loss: Average loss on the cluster. :return node_count: Number of nodes. :return prediction: Prediction matrix with probabilities. :return target: Target vector. """ edges = self.clustering_machine.sg_edges[cluster].to(self.device) macro_nodes = self.clustering_machine.sg_nodes[cluster].to(self.device) test_nodes = self.clustering_machine.sg_test_nodes[cluster].to(self.device) field_index = self.clustering_machine.sg_field_index[cluster].to(self.device) field_adjs = self.generate_field_adjs(field_index.shape[0]).to(self.device) target = self.clustering_machine.sg_targets[cluster].to(self.device).squeeze() prediction = self.model(edges, field_index, field_adjs) average_loss = F.nll_loss(prediction[test_nodes], target[test_nodes], self.class_weight) node_count = test_nodes.shape[0] target = target[test_nodes] prediction = prediction[test_nodes,:] return average_loss, node_count, prediction, target def update_average_loss(self, batch_average_loss, node_count): """ Updating the average loss in the epoch. :param batch_average_loss: Loss of the cluster. :param node_count: Number of nodes in currently processed cluster. :return average_loss: Average loss in the epoch. """ self.accumulated_loss = self.accumulated_loss + batch_average_loss.item()*node_count self.node_count_seen = self.node_count_seen + node_count average_loss = self.accumulated_loss / self.node_count_seen return average_loss def train_val_test(self): """ Training, validation, and test a model per epoch. """ print("Training, validation, and test started.\n") train_start_time = time.perf_counter() bad_counter = 0 best_loss = np.inf best_epoch = 0 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) # test for epochs print(self.args.epochs) for epoch in range(1, self.args.epochs+1): epoch_start_time = time.time() np.random.shuffle(self.clustering_machine.clusters) self.model.train() self.node_count_seen = 0 self.accumulated_loss = 0 for cluster in self.clustering_machine.clusters: self.optimizer.zero_grad() batch_average_loss, node_count = self.do_forward_pass(cluster) batch_average_loss.backward() self.optimizer.step() average_loss = self.update_average_loss(batch_average_loss, node_count) train_loss = average_loss self.model.eval() self.node_count_seen = 0 self.accumulated_loss = 0 for cluster in self.clustering_machine.clusters: batch_average_loss, node_count = self.do_validation(cluster, epoch) average_loss = self.update_average_loss(batch_average_loss, node_count) val_loss = average_loss print("Epoch: {:04d}".format(epoch), "||", "time cost: {:.2f}s".format(time.time() - epoch_start_time), "||", "train loss: {:.4f}".format(train_loss), "val loss: {:.4f}".format(val_loss)) if val_loss < best_loss: best_loss = val_loss best_epoch = epoch bad_counter = 0 best_model_state = self.model.state_dict() else: bad_counter += 1 if bad_counter == self.args.patience: break self.model.load_state_dict(best_model_state) self.model.eval() self.node_count_seen = 0 self.accumulated_loss = 0 self.predictions = [] self.targets = [] for cluster in self.clustering_machine.clusters: batch_average_loss, node_count, prediction, target = self.do_prediction(cluster) average_loss = self.update_average_loss(batch_average_loss, node_count) self.predictions.append(prediction.cpu().detach().numpy()) self.targets.append(target.cpu().detach().numpy()) test_loss = average_loss self.targets = np.concatenate(self.targets) self.predictions = np.concatenate(self.predictions).argmax(1) acc_score = metrics.accuracy_score(self.targets, self.predictions) macro_f1 = metrics.f1_score(self.targets, self.predictions, average="macro") classification_report = metrics.classification_report(self.targets, self.predictions, digits=4) print(classification_report) # Confusion matrics and AUC confusion_matrix = metrics.confusion_matrix(self.targets, self.predictions) print(confusion_matrix) #F1 f1 = f1_score(self.targets, self.predictions, average='macro') print('F1 score:', f1) # fpr, tpr, _ = metrics.roc_curve(self.targets, self.predictions) # auc = metrics.auc(fpr, tpr) # print("AUC:", auc) train_time = (time.perf_counter() - train_start_time)/60 print("Optimization Finished!") print("Total time elapsed: {:.2f}min".format(train_time)) print("Best Result:\n", "best epoch: {:04d}".format(best_epoch), "||", "test loss: {:.4f}".format(test_loss), "||", "accuracy: {:.4f}".format(acc_score), "macro-f1: {:.4f}".format(macro_f1)) # Save results on Neptune self.neptune_run["best_epoch"] = best_epoch self.neptune_run["test/loss"] = test_loss self.neptune_run["test/acc"] = acc_score self.neptune_run["test/f1"] = macro_f1 # self.neptune_run["test/auc"] = auc # self.neptune_run["test/tpr"] = tpr # self.neptune_run["test/fpr"] = fpr self.neptune_run["conf_matrix"] = confusion_matrix self.neptune_run["train_time"] = train_time