# FairUP/src/models/CatGCN/clustergnn.py
import copy
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics

from CatGCN.layers import StackedGNN
from FairGNN.src.models.GCN import GCN  # used by the (currently disabled) adversarial components below

class ClusterGNNTrainer(object):
"""
Training a huge graph cluster partition strategy.
"""
def __init__(self, args, clustering_machine, neptune_run):
self.args = args
self.clustering_machine = clustering_machine
self.neptune_run = neptune_run
self.device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")
print('device:', self.device)
self.class_weight = clustering_machine.class_weight.to(self.device)
self.create_model()
        # New part (currently disabled): sensitive-attribute estimator and
        # adversary for fairness-aware training.
        #self.sens_model = GCN(95, 128, 0.5) # number of features, hidden size, dropout rate
        #self.adv_model = nn.Linear(128, 1)
        # adversary optimizer (weight decay is already included)
        #self.optimizer_A = torch.optim.Adam(self.adv_model.parameters(), lr = args.lr, weight_decay = args.weight_decay)
        #self.criterion = nn.BCEWithLogitsLoss()
        #self.A_loss = 0
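        # A possible wiring for the disabled branch (a sketch under assumptions,
        # not the authors' confirmed design): self.sens_model would estimate the
        # sensitive attribute from node features, self.adv_model would try to
        # recover it from the embeddings, and self.criterion would score the
        # adversary. The classification loss could then be penalized when the
        # adversary succeeds, e.g.:
        #   total_loss = cls_loss - self.args.alpha * self.A_loss
        # where args.alpha is a hypothetical trade-off hyperparameter.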
def create_model(self):
"""
Creating a StackedGNN and transferring to CPU/GPU.
"""
self.model = StackedGNN(self.args, self.clustering_machine.field_count, self.clustering_machine.field_size, self.clustering_machine.class_count)
self.model = self.model.to(self.device)
def generate_field_adjs(self, node_count):
        """
        Build the normalized field-level adjacency shared by all nodes in a cluster.
        Normalization: P'' = Q^{-1/2} * P' * Q^{-1/2}, with P' = P + diag_probe * I.
        Since P is all-ones, every row of P' sums to field_size + diag_probe, so
        the symmetric normalization reduces to dividing by that constant row sum.
        :param node_count: Number of nodes in the cluster.
        :return field_adjs: Tensor of shape (node_count, field_size, field_size).
        """
field_adjs = torch.ones((node_count, self.clustering_machine.field_size, self.clustering_machine.field_size))
field_adjs += self.args.diag_probe * torch.eye(self.clustering_machine.field_size)
row_sum = self.clustering_machine.field_size + self.args.diag_probe
field_adjs = (1. / row_sum) * field_adjs
return field_adjs
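
    # Worked example for generate_field_adjs (illustrative, assuming
    # field_size = 3 and diag_probe = 1.0):
    #   P' = ones(3, 3) + 1.0 * eye(3)  -> every row sums to 3 + 1.0 = 4.0
    #   P'' = P' / 4.0                  -> off-diagonal entries 0.25, diagonal 0.50
    # Each of the node_count field graphs is this same dense, normalized matrix.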
def do_forward_pass(self, cluster):
"""
Making a forward pass with data from a given partition.
:param cluster: Cluster index.
:return average_loss: Average loss on the cluster.
:return node_count: Number of nodes.
"""
edges = self.clustering_machine.sg_edges[cluster].to(self.device)
macro_nodes = self.clustering_machine.sg_nodes[cluster].to(self.device)
train_nodes = self.clustering_machine.sg_train_nodes[cluster].to(self.device)
field_index = self.clustering_machine.sg_field_index[cluster].to(self.device)
field_adjs = self.generate_field_adjs(field_index.shape[0]).to(self.device)
target = self.clustering_machine.sg_targets[cluster].to(self.device).squeeze()
prediction = self.model(edges, field_index, field_adjs)
# todo add forward pass for estimator and adversary
average_loss = F.nll_loss(prediction[train_nodes], target[train_nodes], self.class_weight)
node_count = train_nodes.shape[0]
return average_loss, node_count
def do_validation(self, cluster, epoch):
"""
Making a validation with data from a given partition.
:param cluster: Cluster index.
:return average_loss: Average loss on the cluster.
:return node_count: Number of nodes.
"""
edges = self.clustering_machine.sg_edges[cluster].to(self.device)
macro_nodes = self.clustering_machine.sg_nodes[cluster].to(self.device)
val_nodes = self.clustering_machine.sg_val_nodes[cluster].to(self.device)
field_index = self.clustering_machine.sg_field_index[cluster].to(self.device)
field_adjs = self.generate_field_adjs(field_index.shape[0]).to(self.device)
target = self.clustering_machine.sg_targets[cluster].to(self.device).squeeze()
prediction = self.model(edges, field_index, field_adjs)
average_loss = F.nll_loss(prediction[val_nodes], target[val_nodes], self.class_weight)
node_count = val_nodes.shape[0]
return average_loss, node_count
def do_prediction(self, cluster):
"""
Scoring a cluster.
:param cluster: Cluster index.
:return average_loss: Average loss on the cluster.
:return node_count: Number of nodes.
:return prediction: Prediction matrix with probabilities.
:return target: Target vector.
"""
edges = self.clustering_machine.sg_edges[cluster].to(self.device)
macro_nodes = self.clustering_machine.sg_nodes[cluster].to(self.device)
test_nodes = self.clustering_machine.sg_test_nodes[cluster].to(self.device)
field_index = self.clustering_machine.sg_field_index[cluster].to(self.device)
field_adjs = self.generate_field_adjs(field_index.shape[0]).to(self.device)
target = self.clustering_machine.sg_targets[cluster].to(self.device).squeeze()
prediction = self.model(edges, field_index, field_adjs)
average_loss = F.nll_loss(prediction[test_nodes], target[test_nodes], self.class_weight)
node_count = test_nodes.shape[0]
target = target[test_nodes]
prediction = prediction[test_nodes,:]
return average_loss, node_count, prediction, target
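
    # Note: the returned prediction holds log-probabilities; to recover class
    # probabilities, exponentiate, e.g. probs = torch.exp(prediction).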
def update_average_loss(self, batch_average_loss, node_count):
"""
Updating the average loss in the epoch.
:param batch_average_loss: Loss of the cluster.
:param node_count: Number of nodes in currently processed cluster.
:return average_loss: Average loss in the epoch.
"""
self.accumulated_loss = self.accumulated_loss + batch_average_loss.item()*node_count
self.node_count_seen = self.node_count_seen + node_count
average_loss = self.accumulated_loss / self.node_count_seen
return average_loss
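
    # Example of the node-weighted running mean (illustrative numbers): clusters
    # of 100 and 300 nodes with per-cluster losses 0.5 and 0.7 give
    #   (0.5 * 100 + 0.7 * 300) / (100 + 300) = 260 / 400 = 0.65,
    # so larger clusters contribute proportionally more to the epoch loss.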
def train_val_test(self):
"""
Training, validation, and test a model per epoch.
"""
print("Training, validation, and test started.\n")
train_start_time = time.perf_counter()
bad_counter = 0
best_loss = np.inf
best_epoch = 0
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        # Log the configured number of epochs.
        print("Epochs:", self.args.epochs)
for epoch in range(1, self.args.epochs+1):
epoch_start_time = time.time()
np.random.shuffle(self.clustering_machine.clusters)
self.model.train()
self.node_count_seen = 0
self.accumulated_loss = 0
for cluster in self.clustering_machine.clusters:
self.optimizer.zero_grad()
batch_average_loss, node_count = self.do_forward_pass(cluster)
batch_average_loss.backward()
self.optimizer.step()
average_loss = self.update_average_loss(batch_average_loss, node_count)
train_loss = average_loss
self.model.eval()
self.node_count_seen = 0
self.accumulated_loss = 0
for cluster in self.clustering_machine.clusters:
batch_average_loss, node_count = self.do_validation(cluster, epoch)
average_loss = self.update_average_loss(batch_average_loss, node_count)
val_loss = average_loss
print("Epoch: {:04d}".format(epoch),
"||",
"time cost: {:.2f}s".format(time.time() - epoch_start_time),
"||",
"train loss: {:.4f}".format(train_loss),
"val loss: {:.4f}".format(val_loss))
            # Early stopping: snapshot the weights with the best validation loss
            # and stop after `patience` epochs without improvement.
            if val_loss < best_loss:
                best_loss = val_loss
                best_epoch = epoch
                bad_counter = 0
                # Deep-copy so later optimizer steps do not mutate the snapshot.
                best_model_state = copy.deepcopy(self.model.state_dict())
            else:
                bad_counter += 1
                if bad_counter == self.args.patience:
                    break
self.model.load_state_dict(best_model_state)
self.model.eval()
self.node_count_seen = 0
self.accumulated_loss = 0
self.predictions = []
self.targets = []
for cluster in self.clustering_machine.clusters:
batch_average_loss, node_count, prediction, target = self.do_prediction(cluster)
average_loss = self.update_average_loss(batch_average_loss, node_count)
self.predictions.append(prediction.cpu().detach().numpy())
self.targets.append(target.cpu().detach().numpy())
test_loss = average_loss
self.targets = np.concatenate(self.targets)
self.predictions = np.concatenate(self.predictions).argmax(1)
acc_score = metrics.accuracy_score(self.targets, self.predictions)
macro_f1 = metrics.f1_score(self.targets, self.predictions, average="macro")
classification_report = metrics.classification_report(self.targets, self.predictions, digits=4)
print(classification_report)
        # Confusion matrix
        confusion_matrix = metrics.confusion_matrix(self.targets, self.predictions)
        print(confusion_matrix)
        # Macro-F1 (same value as macro_f1 computed above)
        print('F1 score:', macro_f1)
# fpr, tpr, _ = metrics.roc_curve(self.targets, self.predictions)
# auc = metrics.auc(fpr, tpr)
# print("AUC:", auc)
train_time = (time.perf_counter() - train_start_time)/60
print("Optimization Finished!")
print("Total time elapsed: {:.2f}min".format(train_time))
print("Best Result:\n",
"best epoch: {:04d}".format(best_epoch),
"||",
"test loss: {:.4f}".format(test_loss),
"||",
"accuracy: {:.4f}".format(acc_score),
"macro-f1: {:.4f}".format(macro_f1))
# Save results on Neptune
self.neptune_run["best_epoch"] = best_epoch
self.neptune_run["test/loss"] = test_loss
self.neptune_run["test/acc"] = acc_score
self.neptune_run["test/f1"] = macro_f1
# self.neptune_run["test/auc"] = auc
# self.neptune_run["test/tpr"] = tpr
# self.neptune_run["test/fpr"] = fpr
self.neptune_run["conf_matrix"] = confusion_matrix
self.neptune_run["train_time"] = train_time