from typing import Callable, Dict, List, Optional, Tuple import networkx as nx import numpy as np import torch def generate_rand_int_excluding(rng: np.random.RandomState, max: int, exclude: int) -> int: """Random integer generator, excluding a specific number Args: rng: Numpy random number generator max: Max number exclude: Number to exclude Returns: Random integer in [0, max], excluding the `exclude` integer. """ while True: # Create the random integer x = rng.randint(max) # Return the random integer if it isn't the exclude value, otherwise try # again if x != exclude: return x def generate_random_walks( # noqa: max-complexity n_nodes: int = 21, max_length: int = 10, n_walks: int = 1000, p_edge: float = 0.1, seed: int = 1002, gpt2_tokenizer: bool = False, ) -> Tuple[Callable[[List[str]], Dict[str, List[float]]], List[str], List[str], torch.Tensor,]: """Generate random walks Args: n_nodes: Number of nodes. This should not be more than 26, as we use single letters to represent each node. max_length: Maximum number of steps in each random walk n_walks: Number of random walks (samples) to create p_edge: Probability that any source node connects to any other destination node seed: Random seed gpt2_tokenizer: True if GPT2's tokenizer is being used Returns: Tuple of metric function, """ # Initialise a random state with the seed rng = np.random.RandomState(seed) # Create the adjacency matrix # https://en.wikipedia.org/wiki/Adjacency_matrix # This is a 2d matrix, where the rows represent the source nodes and the # columns represent the destination nodes. If a cell (i,j) is True, then # there is a directional edge from the source node (i) to the destination # node (j). If it is false there is no connection. while True: # Create the adjacency matrix, where each node is connected to each # other node, with probability p_edge adjacency_matrix: np.ndarray = rng.rand(n_nodes, n_nodes) > (1 - p_edge) # Nodes can't be connected to themselves, so the diagonal values must # all be False np.fill_diagonal(adjacency_matrix, 0) # Each destination node (column) must be connected to at least one # source node. This checks if this is the case, by checking there is a # True value in every column. If it is not the case, we try to generate # a new adjacency matrix again from scratch (in the while loop). if np.all(adjacency_matrix.sum(1)): break # Set the goal node as 0 goal: int = 0 # The goal node is the terminal state, so we make sure that it doesn't # have a directional edge going to any other nodes (i.e. it can only be # connected to from previous nodes). We also set the connection to itself as # True. adjacency_matrix[goal, :] = 0 adjacency_matrix[goal, goal] = 1 # Create dicts for converting nodes into characters and vice versa # Nodes are converted into characters as these (when split by the delimiter) are # guaranteed to be tokenized as individual tokens. char_to_node: Dict[str, int] = {chr(ix + ord("a")): ix for ix in range(n_nodes)} node_to_char: Dict[int, str] = {ix: chr(ix + ord("a")) for ix in range(n_nodes)} # Initialise a list of sample walks sample_walks: List[str] = [] # String delimiter (to force the tokenizer to keep all nodes as separate # tokens) delimiter: str = "|" if gpt2_tokenizer else "" # Create n_walks samples for _ in range(n_walks): # Create a random starting node (that isn't already at the goal state) node: int = generate_rand_int_excluding(rng, n_nodes, goal) # Initialise the list of nodes that we visit walk_nodes: List[int] = [node] # Do a series of steps, until we hit the maximum number of steps or the # goal state (whichever comes first) for _step in range(max_length - 1): # From the starting node, get all the nodes we can move to. Pick one # of these at random, and add it to the list of visited nodes node = rng.choice(np.nonzero(adjacency_matrix[node])[0]) walk_nodes.append(node) # If we're at the goal state, stop if node == goal: break # Convert the nodes visited to letters (not integers) walk: List[str] = [node_to_char[ix] for ix in walk_nodes] # Concatenate into a journey, with each node letter separated by the # delimiter. sample_walks.append(delimiter.join(walk)) # Initialise list of shortest lengths for each node (to the goal node) shortest_lengths: List[int] = [] # Create a directional graph from the adjacency list directional_graph = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph) # Fore each node (except for the goal node), find the shortest path for start in set(range(n_nodes)) - {goal}: try: # Find the shortest path (up to the max_length) shortest_path = nx.shortest_path(directional_graph, start, goal)[:max_length] shortest_lengths.append(len(shortest_path)) except Exception: # If there is no path, use the maximum length instead shortest_lengths.append(max_length) def metric_fn( samples: List[str], ) -> Dict[str, List[float]]: """Metric Function Args: samples: Batch of samples Returns: Dict of metrics, each with a key of the metric name and value as a list of metric values for each batch item. """ # Length to set if the path is invalid invalid_path_length: int = 100 # Initialise batch lengths & reference lengths (the optimal length # starting from each batch items specific start node) lengths: List[float] = [] sample_optimal_lengths: List[int] = [] for sample_str in samples: # Remove GPT2 specific tokenizer delimiter if gpt2_tokenizer: sample_str = sample_str.replace("|", "") # Convert the sample into a list of nodes (default to an unused # integer if the node is not found) sample: List[int] = [char_to_node.get(c, 1000) for c in sample_str] # Initialise the specific sample length length: Optional[float] = None for node in range(len(sample)): # If an invalid path is taken, set the length to the invalid # path score if sample[node] >= n_nodes or node > 0 and not adjacency_matrix[sample[node - 1], sample[node]]: length = invalid_path_length break # Otherwise increment the length for each move (where we don't # end up at the goal node) elif sample[node] == 0: length = node + 1 break # Catch the case where there are no moves if length is None: length = invalid_path_length # Store the batch item length & optimal length staring from the # start node lengths.append(float(length)) sample_optimal_lengths.append(shortest_lengths[sample[0] - 1]) # Calculate optimality scores, in [0, 1], as compared to the shortest # path lengths_tensor = torch.tensor(lengths, dtype=torch.float) bound_lengths: torch.Tensor = torch.where( lengths_tensor.eq(invalid_path_length), max_length, lengths_tensor ).abs() optimal_lengths = torch.as_tensor(sample_optimal_lengths) # Optimality scores, in [0, 1], as compared to the shortest path optimality = (max_length - bound_lengths) / (max_length - optimal_lengths) return { "lengths": lengths, "optimality": optimality.tolist(), } logit_mask = torch.tensor(adjacency_matrix) # Set the evaluation prompts as a list of unique random walk samples, using # just the start point (first character) from each samples. eval_prompts = list(sorted(set(w[0] for w in sample_walks))) eval_prompts = [prompt + delimiter for prompt in eval_prompts] return (metric_fn, eval_prompts, sample_walks, logit_mask)