"""Random-walk task generation: sample a random directed graph and random walks over it."""
from typing import Callable, Dict, List, Optional, Tuple
import networkx as nx
import numpy as np
import torch
def generate_rand_int_excluding(rng: np.random.RandomState, max: int, exclude: int) -> int:
"""Random integer generator, excluding a specific number
Args:
rng: Numpy random number generator
max: Max number
exclude: Number to exclude
Returns:
Random integer in [0, max], excluding the `exclude` integer.
"""
while True:
# Create the random integer
x = rng.randint(max)
# Return the random integer if it isn't the exclude value, otherwise try
# again
if x != exclude:
return x
def generate_random_walks(  # noqa: max-complexity
    n_nodes: int = 21,
    max_length: int = 10,
    n_walks: int = 1000,
    p_edge: float = 0.1,
    seed: int = 1002,
    gpt2_tokenizer: bool = False,
) -> Tuple[Callable[[List[str]], Dict[str, List[float]]], List[str], List[str], torch.Tensor,]:
    """Generate random walks over a randomly sampled directed graph.

    Samples a random adjacency matrix, designates node 0 as the terminal
    "goal" state, and generates `n_walks` random walks encoded as strings
    of single-letter node names.

    Args:
        n_nodes: Number of nodes. This should not be more than 26, as we use
            single letters to represent each node.
        max_length: Maximum number of steps in each random walk
        n_walks: Number of random walks (samples) to create
        p_edge: Probability that any source node connects to any other
            destination node
        seed: Random seed
        gpt2_tokenizer: True if GPT2's tokenizer is being used

    Returns:
        Tuple of (metric function scoring batches of sample strings,
        evaluation prompts — one per distinct start node, the generated
        sample walks, and a logit mask built from the adjacency matrix).
    """
    # Initialise a random state with the seed
    rng = np.random.RandomState(seed)
    # Create the adjacency matrix
    # https://en.wikipedia.org/wiki/Adjacency_matrix
    # This is a 2d matrix, where the rows represent the source nodes and the
    # columns represent the destination nodes. If a cell (i,j) is True, then
    # there is a directional edge from the source node (i) to the destination
    # node (j). If it is false there is no connection.
    while True:
        # Create the adjacency matrix, where each node is connected to each
        # other node, with probability p_edge
        adjacency_matrix: np.ndarray = rng.rand(n_nodes, n_nodes) > (1 - p_edge)
        # Nodes can't be connected to themselves, so the diagonal values must
        # all be False
        np.fill_diagonal(adjacency_matrix, 0)
        # Every node must have at least one OUTGOING edge, so that a walk can
        # always take a step from any node it lands on. `.sum(1)` sums each
        # row (per-source outgoing edge counts); if any row is all-False we
        # resample the whole matrix from scratch (in the while loop).
        if np.all(adjacency_matrix.sum(1)):
            break
    # Set the goal node as 0
    goal: int = 0
    # The goal node is the terminal state, so we make sure that it doesn't
    # have a directional edge going to any other nodes (i.e. it can only be
    # connected to from previous nodes). We also set the connection to itself as
    # True, so walks that reach the goal can only stay there.
    adjacency_matrix[goal, :] = 0
    adjacency_matrix[goal, goal] = 1
    # Create dicts for converting nodes into characters and vice versa
    # Nodes are converted into characters as these (when split by the delimiter) are
    # guaranteed to be tokenized as individual tokens.
    char_to_node: Dict[str, int] = {chr(ix + ord("a")): ix for ix in range(n_nodes)}
    node_to_char: Dict[int, str] = {ix: chr(ix + ord("a")) for ix in range(n_nodes)}
    # Initialise a list of sample walks
    sample_walks: List[str] = []
    # String delimiter (to force the tokenizer to keep all nodes as separate
    # tokens)
    delimiter: str = "|" if gpt2_tokenizer else ""
    # Create n_walks samples
    for _ in range(n_walks):
        # Create a random starting node (that isn't already at the goal state)
        node: int = generate_rand_int_excluding(rng, n_nodes, goal)
        # Initialise the list of nodes that we visit
        walk_nodes: List[int] = [node]
        # Do a series of steps, until we hit the maximum number of steps or the
        # goal state (whichever comes first)
        for _step in range(max_length - 1):
            # From the starting node, get all the nodes we can move to. Pick one
            # of these at random, and add it to the list of visited nodes.
            # Safe: every row has at least one True (checked above; goal row
            # has its self-loop), so nonzero() is never empty.
            node = rng.choice(np.nonzero(adjacency_matrix[node])[0])
            walk_nodes.append(node)
            # If we're at the goal state, stop
            if node == goal:
                break
        # Convert the nodes visited to letters (not integers)
        walk: List[str] = [node_to_char[ix] for ix in walk_nodes]
        # Concatenate into a journey, with each node letter separated by the
        # delimiter.
        sample_walks.append(delimiter.join(walk))
    # Initialise list of shortest lengths for each node (to the goal node)
    shortest_lengths: List[int] = []
    # Create a directional graph from the adjacency list
    directional_graph = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)
    # For each node (except for the goal node), find the shortest path.
    # NOTE(review): later indexing (`shortest_lengths[sample[0] - 1]`) assumes
    # this appends in ascending node order (1, 2, ..., n_nodes-1); CPython
    # iterates small-int sets in that order — confirm before porting.
    for start in set(range(n_nodes)) - {goal}:
        try:
            # Find the shortest path (up to the max_length); length is counted
            # in nodes visited, matching `length = node + 1` in metric_fn
            shortest_path = nx.shortest_path(directional_graph, start, goal)[:max_length]
            shortest_lengths.append(len(shortest_path))
        except Exception:
            # If there is no path, use the maximum length instead
            shortest_lengths.append(max_length)

    def metric_fn(
        samples: List[str],
    ) -> Dict[str, List[float]]:
        """Metric Function

        Args:
            samples: Batch of samples (walk strings, e.g. "dba" or "d|b|a")

        Returns:
            Dict of metrics, each with a key of the metric name and value as a
            list of metric values for each batch item.
        """
        # Length to set if the path is invalid
        invalid_path_length: int = 100
        # Initialise batch lengths & reference lengths (the optimal length
        # starting from each batch items specific start node)
        lengths: List[float] = []
        sample_optimal_lengths: List[int] = []
        for sample_str in samples:
            # Remove GPT2 specific tokenizer delimiter
            if gpt2_tokenizer:
                sample_str = sample_str.replace("|", "")
            # Convert the sample into a list of nodes (default to an unused
            # integer if the node is not found)
            sample: List[int] = [char_to_node.get(c, 1000) for c in sample_str]
            # Initialise the specific sample length
            length: Optional[float] = None
            for node in range(len(sample)):
                # If an invalid path is taken, set the length to the invalid
                # path score. Note `or` binds looser than `and`, so this reads
                # as: unknown node, OR (not the first node AND no edge from
                # the previous node). The unknown-node check short-circuits
                # before the adjacency lookup can go out of bounds.
                if sample[node] >= n_nodes or node > 0 and not adjacency_matrix[sample[node - 1], sample[node]]:
                    length = invalid_path_length
                    break
                # Otherwise stop when we reach the goal node (0); length is
                # the number of nodes visited so far
                elif sample[node] == 0:
                    length = node + 1
                    break
            # Catch the case where there are no moves (never reached the goal
            # and never took an invalid step, or the sample was empty)
            if length is None:
                length = invalid_path_length
            # Store the batch item length & optimal length starting from the
            # start node. NOTE(review): `sample[0] - 1` assumes the sample is
            # non-empty and starts with a valid non-goal node — an empty or
            # malformed sample would index out of range or wrap; verify
            # against how prompts are constructed.
            lengths.append(float(length))
            sample_optimal_lengths.append(shortest_lengths[sample[0] - 1])
        # Clamp invalid-path scores (100) down to max_length before scoring
        lengths_tensor = torch.tensor(lengths, dtype=torch.float)
        bound_lengths: torch.Tensor = torch.where(
            lengths_tensor.eq(invalid_path_length), max_length, lengths_tensor
        ).abs()
        optimal_lengths = torch.as_tensor(sample_optimal_lengths)
        # Optimality scores, in [0, 1], as compared to the shortest path.
        # NOTE(review): divides by zero when an optimal length equals
        # max_length (unreachable start node) — confirm that case is excluded
        # upstream.
        optimality = (max_length - bound_lengths) / (max_length - optimal_lengths)
        return {
            "lengths": lengths,
            "optimality": optimality.tolist(),
        }

    # Mask of legal next-node transitions, from the adjacency matrix
    logit_mask = torch.tensor(adjacency_matrix)
    # Set the evaluation prompts as a list of unique random walk samples, using
    # just the start point (first character) from each samples.
    eval_prompts = list(sorted(set(w[0] for w in sample_walks)))
    eval_prompts = [prompt + delimiter for prompt in eval_prompts]
    return (metric_fn, eval_prompts, sample_walks, logit_mask)