File size: 8,583 Bytes
fa6856c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
from typing import Callable, Dict, List, Optional, Tuple

import networkx as nx
import numpy as np
import torch


def generate_rand_int_excluding(rng: np.random.RandomState, max: int, exclude: int) -> int:
    """Random integer generator, excluding a specific number

    Uses rejection sampling: ``rng.randint(max)`` is drawn repeatedly until
    the draw differs from ``exclude``.

    Args:
        rng: Numpy random number generator
        max: Exclusive upper bound of the range to draw from
        exclude: Number to exclude

    Returns:
        Random integer in [0, max), excluding the `exclude` integer.

    Raises:
        ValueError: If `exclude` is the only integer in [0, max) (i.e.
            `max == 1` and `exclude == 0`), as no valid value exists and the
            rejection loop would never terminate. `rng.randint` itself raises
            ValueError when `max <= 0`.
    """
    # Guard against an infinite rejection loop when nothing can be returned.
    if max == 1 and exclude == 0:
        raise ValueError("No integer in [0, 1) other than the excluded value 0")

    while True:
        # Draw from the half-open range [0, max)
        x = rng.randint(max)

        # Return the random integer if it isn't the exclude value, otherwise try
        # again
        if x != exclude:
            return x


def generate_random_walks(  # noqa: max-complexity
    n_nodes: int = 21,
    max_length: int = 10,
    n_walks: int = 1000,
    p_edge: float = 0.1,
    seed: int = 1002,
    gpt2_tokenizer: bool = False,
) -> Tuple[Callable[[List[str]], Dict[str, List[float]]], List[str], List[str], torch.Tensor,]:
    """Generate random walks over a randomly generated directed graph

    Args:
        n_nodes: Number of nodes. This should not be more than 26, as we use
            single letters ("a", "b", ...) to represent each node. (Node 27
            would map to "|", which collides with the GPT2 delimiter.)
        max_length: Maximum number of nodes in each random walk, including
            the start node
        n_walks: Number of random walks (samples) to create
        p_edge: Probability that any source node connects to any other
            destination node
        seed: Random seed
        gpt2_tokenizer: True if GPT2's tokenizer is being used. Node letters
            are then separated by a "|" delimiter.

    Returns:
        Tuple of ``(metric_fn, eval_prompts, sample_walks, logit_mask)``:

        - ``metric_fn``: maps a batch of walk strings to a dict with
          ``"lengths"`` and ``"optimality"`` lists (one value per sample).
        - ``eval_prompts``: sorted unique start-node letters of the sample
          walks, each followed by the delimiter.
        - ``sample_walks``: the ``n_walks`` generated walk strings.
        - ``logit_mask``: boolean ``(n_nodes, n_nodes)`` tensor copy of the
          adjacency matrix.
    """
    # Initialise a random state with the seed
    rng = np.random.RandomState(seed)

    # Create the adjacency matrix
    # https://en.wikipedia.org/wiki/Adjacency_matrix
    # This is a 2d matrix, where the rows represent the source nodes and the
    # columns represent the destination nodes. If a cell (i,j) is True, then
    # there is a directional edge from the source node (i) to the destination
    # node (j). If it is false there is no connection.
    while True:
        # Create the adjacency matrix, where each node is connected to each
        # other node, with probability p_edge
        adjacency_matrix: np.ndarray = rng.rand(n_nodes, n_nodes) > (1 - p_edge)

        # Nodes can't be connected to themselves, so the diagonal values must
        # all be False
        np.fill_diagonal(adjacency_matrix, 0)

        # Each source node (row) must have at least one outgoing edge, otherwise
        # a walk reaching it would have no node to move to (`rng.choice` below
        # cannot pick from an empty array). Summing along axis 1 counts the
        # outgoing edges per row; if any row has none, we generate a new
        # adjacency matrix from scratch.
        if np.all(adjacency_matrix.sum(1)):
            break

    # Set the goal node as 0
    goal: int = 0

    # The goal node is the terminal state, so we make sure that it doesn't
    # have a directional edge going to any other nodes (i.e. it can only be
    # connected to from previous nodes). We also set the connection to itself as
    # True.
    adjacency_matrix[goal, :] = 0
    adjacency_matrix[goal, goal] = 1

    # Create dicts for converting nodes into characters and vice versa
    # Nodes are converted into characters as these (when split by the delimiter) are
    # guaranteed to be tokenized as individual tokens.
    char_to_node: Dict[str, int] = {chr(ix + ord("a")): ix for ix in range(n_nodes)}
    node_to_char: Dict[int, str] = {ix: chr(ix + ord("a")) for ix in range(n_nodes)}

    # Initialise a list of sample walks
    sample_walks: List[str] = []

    # String delimiter (to force the tokenizer to keep all nodes as separate
    # tokens)
    delimiter: str = "|" if gpt2_tokenizer else ""

    # Create n_walks samples
    for _ in range(n_walks):
        # Create a random starting node (that isn't already at the goal state)
        node: int = generate_rand_int_excluding(rng, n_nodes, goal)

        # Initialise the list of nodes that we visit
        walk_nodes: List[int] = [node]

        # Take up to (max_length - 1) steps, stopping early at the goal state
        for _step in range(max_length - 1):
            # From the current node, get all the nodes we can move to. Pick one
            # of these at random, and add it to the list of visited nodes
            node = rng.choice(np.nonzero(adjacency_matrix[node])[0])
            walk_nodes.append(node)

            # If we're at the goal state, stop
            if node == goal:
                break

        # Convert the nodes visited to letters (not integers)
        walk: List[str] = [node_to_char[ix] for ix in walk_nodes]

        # Concatenate into a journey, with each node letter separated by the
        # delimiter.
        sample_walks.append(delimiter.join(walk))

    # Shortest path length (number of nodes, capped at max_length) from each
    # non-goal node to the goal. Keyed by node so lookups don't depend on
    # iteration order.
    shortest_lengths: Dict[int, int] = {}

    # Create a directional graph from the adjacency matrix
    directional_graph = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)

    # For each node (except for the goal node), find the shortest path
    for start in range(n_nodes):
        if start == goal:
            continue
        try:
            # Find the shortest path (up to the max_length)
            shortest_path = nx.shortest_path(directional_graph, start, goal)[:max_length]
            shortest_lengths[start] = len(shortest_path)
        except nx.NetworkXNoPath:
            # If there is no path, use the maximum length instead
            shortest_lengths[start] = max_length

    def metric_fn(
        samples: List[str],
    ) -> Dict[str, List[float]]:
        """Metric Function

        Args:
            samples: Batch of samples (walk strings, letters optionally
                separated by "|" when the GPT2 tokenizer is used)

        Returns:
            Dict of metrics, each with a key of the metric name and value as a
            list of metric values for each batch item:

            - "lengths": number of nodes up to and including the first visit
              to the goal, or 100 for invalid/unfinished walks.
            - "optimality": score comparing the walk to the shortest path
              from its start node (1 = optimal, 0 = invalid).
        """
        # Length to set if the path is invalid
        invalid_path_length: int = 100

        # Initialise batch lengths & reference lengths (the optimal length
        # starting from each batch item's specific start node)
        lengths: List[float] = []
        sample_optimal_lengths: List[int] = []

        for sample_str in samples:
            # Remove GPT2 specific tokenizer delimiter
            if gpt2_tokenizer:
                sample_str = sample_str.replace("|", "")

            # Convert the sample into a list of nodes (default to an unused
            # integer if the node is not found)
            sample: List[int] = [char_to_node.get(c, 1000) for c in sample_str]

            # Initialise the specific sample length
            length: Optional[float] = None

            for ix, sample_node in enumerate(sample):
                # An unknown node, or a move along a non-existent edge, makes
                # the path invalid
                if sample_node >= n_nodes or (ix > 0 and not adjacency_matrix[sample[ix - 1], sample_node]):
                    length = invalid_path_length
                    break

                # On first reaching the goal, the length is the number of
                # nodes visited so far (including the start and the goal)
                if sample_node == goal:
                    length = ix + 1
                    break

            # The goal was never reached (this includes an empty sample)
            if length is None:
                length = invalid_path_length

            # Store the batch item length & optimal length starting from the
            # start node. A walk starting at the goal is already optimal
            # (1 node, not in `shortest_lengths`); an empty sample or unknown
            # start character has an invalid length, so its optimality is 0
            # whatever fallback value is used here.
            lengths.append(float(length))
            sample_optimal_lengths.append(shortest_lengths.get(sample[0], 1) if sample else 1)

        # Invalid paths are scored as if they took max_length nodes
        lengths_tensor = torch.tensor(lengths, dtype=torch.float)
        bound_lengths: torch.Tensor = torch.where(
            lengths_tensor.eq(invalid_path_length), max_length, lengths_tensor
        ).abs()
        optimal_lengths = torch.as_tensor(sample_optimal_lengths)

        # Optimality scores, compared to the shortest path: 1 for an optimal
        # walk, 0 for an invalid one.
        # NOTE(review): a valid walk longer than max_length gives a negative
        # score, and a start node whose optimal length equals max_length makes
        # the denominator 0 — confirm whether clamping is wanted.
        optimality = (max_length - bound_lengths) / (max_length - optimal_lengths)

        return {
            "lengths": lengths,
            "optimality": optimality.tolist(),
        }

    logit_mask = torch.tensor(adjacency_matrix)

    # Set the evaluation prompts as a list of unique random walk samples, using
    # just the start point (first character) from each sample.
    eval_prompts = [start_char + delimiter for start_char in sorted({w[0] for w in sample_walks})]

    return (metric_fn, eval_prompts, sample_walks, logit_mask)