File size: 5,536 Bytes
d2a8669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch

##import metis
import numpy as np
import networkx as nx
from sklearn.model_selection import train_test_split

class ClusteringMachine(object):
    """
    Clustering the graph, feature set, and target.

    Every node is assigned to a cluster; for each cluster the class stores the
    member nodes, the edges re-indexed to cluster-local ids, the matching rows
    of the field-index matrix and of the target, and a train/val/test split.
    Graph nodes are used directly as row indices into ``field_index`` and
    ``target``.

    With ``args.clustering_method == "none"`` all nodes go into a single
    cluster, which is the suggested setting unless the graph is very large.
    """

    def __init__(self, args, graph, field_index, target):
        """
        :param args: Arguments object with parameters. This class reads
            ``clustering_method``, ``cluster_number``, ``train_ratio``,
            ``seed`` and ``weight_balanced``.
        :param graph: Networkx Graph.
        :param field_index: field_index matrix (ndarray).
        :param target: Target vector (ndarray), indexed as a 2-D array.
        """
        self.args = args
        self.graph = graph
        self.field_index = field_index
        self.target = target
        self._set_sizes()
        self._set_loss_weight()

    def _set_sizes(self):
        """
        Setting the field and class count.

        Counts are derived from the largest value present plus one, so ids are
        taken to start at zero.
        """
        self.user_count = self.field_index.shape[0]
        self.field_count = np.max(self.field_index)+1
        self.field_size = self.field_index.shape[1]
        self.class_count = np.max(self.target)+1
        print("####\tData Info\t####")
        print("user count:\t", self.user_count)
        print("field count:\t", self.field_count)
        print("field size:\t", self.field_size)
        print("class count:\t", self.class_count)

    def _set_loss_weight(self):
        """
        Setting the per-class loss weights in ``self.class_weight``.

        When ``args.weight_balanced`` is the string ``'True'`` each class gets
        ``n_samples / (n_classes * class_frequency)``; otherwise every class
        gets weight one.
        """
        if self.args.weight_balanced == 'True':
            # Flatten with reshape(-1) instead of squeeze(): a single-sample
            # target would squeeze to a 0-d array, which np.bincount rejects.
            class_frequency = np.bincount(self.target.reshape(-1))
            class_weight = self.target.shape[0] / (self.class_count * class_frequency)
            self.class_weight = torch.FloatTensor(class_weight)
        else:
            # The balanced weights are not needed here, so they are not computed.
            self.class_weight = torch.ones(int(self.class_count))

    def decompose(self):
        """
        Decomposing the graph, partitioning the features and target, creating Torch arrays.

        ``"none"`` places every node in cluster 0. Any other method name
        (including ``"metis"``, whose backend is disabled below) falls back to
        random clustering.
        """
        if self.args.clustering_method == "none":
            print("\nWithout graph clustering.\n")
            self.clusters = [0]
            self.cluster_membership = {node: 0 for node in self.graph.nodes()}
            print('cluster memebership', self.cluster_membership)
        #elif self.args.clustering_method == "metis":
        #    print("\nMetis graph clustering started.\n")
        #    self.metis_clustering()
        else:
            print("\nRandom graph clustering started.\n")
            self.random_clustering()
        self.generate_data_partitioning()
        self.transfer_edges_and_nodes()

    def random_clustering(self):
        """
        Random clustering the nodes.

        Each node independently draws one of ``args.cluster_number`` cluster
        ids from NumPy's global random state.
        """
        self.clusters = list(range(self.args.cluster_number))
        self.cluster_membership = {node: np.random.choice(self.clusters) for node in self.graph.nodes()}

    # Disabled Metis backend (requires the optional `metis` package):
    #
    #def metis_clustering(self):
    #    """
    #    Clustering the graph with Metis.
    #    """
    #    (st, parts) = metis.part_graph(self.graph, self.args.cluster_number, seed=self.args.seed)
    #    self.clusters = list(set(parts))
    #    self.cluster_membership = {node: membership for node, membership in enumerate(parts)}

    def generate_data_partitioning(self):
        """
        Creating data partitions and train-val-test splits.

        For every cluster this fills ``sg_nodes``, ``sg_targets``, ``sg_edges``
        (both directions of every edge, in cluster-local ids),
        ``sg_field_index`` and the sorted ``sg_train_nodes`` / ``sg_val_nodes``
        / ``sg_test_nodes`` lists of cluster-local ids. The non-train remainder
        is split evenly between validation and test.
        """
        # The partitions are built for every clustering method: decompose()
        # always assigns memberships, and transfer_edges_and_nodes() needs
        # these dictionaries to exist.
        self.sg_nodes = {}
        self.sg_targets = {}
        self.sg_edges = {}
        self.sg_train_nodes = {}
        self.sg_val_nodes = {}
        self.sg_test_nodes = {}
        self.sg_field_index = {}
        for cluster in self.clusters:
            print('Cluster', cluster)
            members = sorted(node for node in self.graph.nodes() if self.cluster_membership[node] == cluster)
            subgraph = self.graph.subgraph(members)
            self.sg_nodes[cluster] = members
            self.sg_targets[cluster] = self.target[members, :]
            # Map global node ids to consecutive cluster-local ids.
            mapper = {node: i for i, node in enumerate(members)}
            local_edges = [(mapper[edge[0]], mapper[edge[1]]) for edge in subgraph.edges()]
            # Store each edge in both directions: all forward pairs, then all reversed pairs.
            self.sg_edges[cluster] = [[source, sink] for source, sink in local_edges] \
                + [[sink, source] for source, sink in local_edges]
            local_ids = list(range(len(members)))
            train_nodes, val_test_nodes = train_test_split(
                local_ids, test_size=1-self.args.train_ratio, random_state=self.args.seed, shuffle=True)
            val_nodes, test_nodes = train_test_split(
                val_test_nodes, test_size=0.5, random_state=self.args.seed, shuffle=True)
            self.sg_train_nodes[cluster] = sorted(train_nodes)
            self.sg_val_nodes[cluster] = sorted(val_nodes)
            self.sg_test_nodes[cluster] = sorted(test_nodes)
            self.sg_field_index[cluster] = self.field_index[members, :]

    def transfer_edges_and_nodes(self):
        """
        Transfering the data to PyTorch format.

        Every per-cluster container is replaced in place by a ``LongTensor``;
        the edge list becomes a tensor with one row of sources and one row of
        targets.
        """
        for cluster in self.clusters:
            self.sg_nodes[cluster] = torch.LongTensor(self.sg_nodes[cluster])
            self.sg_targets[cluster] = torch.LongTensor(self.sg_targets[cluster])
            # reshape(-1, 2) keeps an edgeless cluster as an empty (2, 0)
            # tensor instead of a flat (0,) one after the transpose.
            self.sg_edges[cluster] = torch.LongTensor(self.sg_edges[cluster]).reshape(-1, 2).t()
            self.sg_train_nodes[cluster] = torch.LongTensor(self.sg_train_nodes[cluster])
            self.sg_val_nodes[cluster] = torch.LongTensor(self.sg_val_nodes[cluster])
            self.sg_test_nodes[cluster] = torch.LongTensor(self.sg_test_nodes[cluster])
            self.sg_field_index[cluster] = torch.LongTensor(self.sg_field_index[cluster])