Spaces:
Sleeping
Sleeping
File size: 7,753 Bytes
6fa23b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# My implementation of the k_Sequitur algorithm described in the papers: https://arxiv.org/pdf/cs/9709102.pdf
# and https://www.biorxiv.org/content/biorxiv/early/2018/03/13/281543.full.pdf
# The algorithm takes in a sequence and forms a grammar aiming to satisfy two rules:
# 1) No pair of adjacent symbols appears k or more times in the grammar
# 2) Every rule in the grammar is used at least k times
#
# e.g. string "abddddeabde" with k=2 would turn to:
# "ABBeAde"
# A --> ab
# B --> dd
# TODO fix the fact that it sometimes provides rules that have end of episode symbol in them
# TODO add an option to return rules in terms of the amount of times they appear in a set of provided episodes
from collections import defaultdict, Counter
class k_Sequitur(object):
    """Builds a k-Sequitur style grammar over a flat list of actions.

    The action list may hold several episodes, each terminated by ``end_of_episode_symbol``; the list itself must
    end with that symbol. Adjacent pairs of symbols whose count reaches ``k`` are turned into rules named
    ``"R0"``, ``"R1"``, ... and the sequence is rewritten with them. This repeats on the rewritten sequence until a
    pass leaves it unchanged, so rules can be built out of other rules.
    """

    def __init__(self, k, end_of_episode_symbol="/"):
        self.k = k
        self.end_of_episode_symbol = end_of_episode_symbol
        # Running index used to give every rule created by this instance a unique name ("R0", "R1", ...).
        self.next_rule_name_ix = 0

    def generate_action_grammar(self, actions):
        """Generates a grammar given a list of actions.

        :param actions: flat list of integer actions; episodes are separated by (and the list is terminated by)
            ``self.end_of_episode_symbol``.
        :return: tuple ``(new_actions, all_rules, action_usage, rules_episode_appearance_count)`` where
            ``new_actions`` is ``actions`` rewritten with rule names, ``all_rules`` maps rule name -> pair of symbols,
            ``action_usage`` maps each rule's expansion (a tuple of raw actions) to how many times the rule was
            applied, and ``rules_episode_appearance_count`` maps each expansion to the number of episodes in which
            the rule was applied.
        :raises AssertionError: if ``actions`` is not a non-empty flat list whose first element is an int.
        """
        assert isinstance(actions, list), actions
        # Emptiness must be checked before ``actions[0]`` is touched, otherwise an empty list raises IndexError
        # instead of this message.
        assert len(actions) > 0, "Need to provide a list of at least 1 action"
        assert not isinstance(actions[0], list), "Should be 1 long list of actions - {}".format(actions[0])
        assert isinstance(actions[0], int), "The actions should be integers"
        new_actions, all_rules, rule_usage, rules_episode_appearance_count = \
            self.discover_all_rules_and_new_actions_representation(actions)
        action_usage = self.extract_action_usage_from_rule_usage(rule_usage, all_rules)
        rules_episode_appearance_count = self.extract_action_usage_from_rule_usage(rules_episode_appearance_count,
                                                                                   all_rules)
        return new_actions, all_rules, action_usage, rules_episode_appearance_count

    def discover_all_rules_and_new_actions_representation(self, actions):
        """Takes in a list of actions and discovers all the rules whose pair count reaches self.k, together with the
        new actions list obtained when all rules are applied recursively.

        :return: ``(new_actions, all_rules, rule_usage, rules_episode_appearance_count)``; the last two are keyed by
            rule name and give, respectively, the number of times each rule was applied and the number of episodes
            in which it was applied.
        """
        all_rules = {}
        current_actions = None
        new_actions = actions
        rule_usage = defaultdict(int)
        num_episodes = Counter(actions)[self.end_of_episode_symbol]
        # episode index -> {rule name: 1 if the rule was applied in that episode}; filled in by
        # convert_a_string_using_reverse_rules.
        rules_episode_appearance_tracker = {episode: defaultdict(int) for episode in range(num_episodes)}
        # Keep adding layers of rules until a rewriting pass leaves the sequence unchanged.
        while new_actions != current_actions:
            current_actions = new_actions
            rules, reverse_rules = self.generate_1_layer_of_rules(current_actions)
            all_rules.update(rules)
            new_actions, rules_usage_count = self.convert_a_string_using_reverse_rules(
                current_actions, reverse_rules, rules_episode_appearance_tracker)
            for rule_name, count in rules_usage_count.items():
                rule_usage[rule_name] += count
        rules_episode_appearance_count = defaultdict(int)
        for appearance_tracker in rules_episode_appearance_tracker.values():
            for rule_name, appeared in appearance_tracker.items():
                if appeared == 1:
                    rules_episode_appearance_count[rule_name] += 1
        return new_actions, all_rules, rule_usage, rules_episode_appearance_count

    def generate_1_layer_of_rules(self, string):
        """Generate dictionaries indicating the pairs of symbols that appear next to each other at least self.k times.

        Pairs are counted left to right. Once a pair's count reaches self.k it becomes a rule (if it is not one
        already), and the following symbol is skipped because it has been absorbed into that pair. Pairs containing
        the end of episode symbol are never turned into rules.

        :param string: sequence of symbols whose last element must be ``self.end_of_episode_symbol``.
        :return: ``(rules, reverse_rules)`` mapping rule name -> pair and pair -> rule name.
        """
        pairs_of_symbols = defaultdict(int)
        last_pair = None
        skip_next_symbol = False
        rules = {}
        assert string[-1] == self.end_of_episode_symbol, "Final element of string must be self.end_of_episode_symbol {}".format(string)
        for ix in range(len(string) - 1):
            # We skip the next symbol if it is already being used in a rule we just made
            if skip_next_symbol:
                skip_next_symbol = False
                continue
            pair = (string[ix], string[ix + 1])
            # We don't count a pair if it was the previous pair (and therefore we have 3 of the same symbols in a row)
            if pair != last_pair:
                pairs_of_symbols[pair] += 1
                last_pair = pair
            else:
                last_pair = None
            # A pair containing the end of episode symbol can never become a rule, so it must not absorb its
            # neighbours either: skipping the next symbol here would hide a genuine pair from the count.
            if self.end_of_episode_symbol in pair:
                continue
            if pairs_of_symbols[pair] >= self.k:
                # The overlapping pair ending at this position can no longer be formed. At ix == 0 there is no such
                # pair (string[ix - 1] would wrap around to the last element).
                if ix > 0:
                    previous_pair = (string[ix - 1], string[ix])
                    pairs_of_symbols[previous_pair] -= 1
                skip_next_symbol = True
                if pair not in rules.values():
                    rule_name = self.get_next_rule_name()
                    rules[rule_name] = pair
        reverse_rules = {rule_pair: rule_name for rule_name, rule_pair in rules.items()}
        return rules, reverse_rules

    def get_next_rule_name(self):
        """Returns next rule name to use ("R<index>") and increments the instance's rule counter."""
        next_rule_name = "R{}".format(self.next_rule_name_ix)
        self.next_rule_name_ix += 1
        return next_rule_name

    def convert_symbol_to_raw_actions(self, symbol, rules):
        """Converts a symbol back to the tuple of raw actions it represents.

        Rule names are expanded repeatedly until no element of the result is a key of ``rules``; a symbol that is
        not a rule name comes back as a 1-tuple.
        """
        assert not isinstance(symbol, list)
        assert isinstance(symbol, (str, int))
        symbol = [symbol]
        finished = False
        while not finished:
            new_symbol = []
            for symbol_val in symbol:
                if symbol_val in rules:
                    # Every rule maps to a pair, so this adds exactly two symbols.
                    new_symbol.extend(rules[symbol_val])
                else:
                    new_symbol.append(symbol_val)
            if new_symbol == symbol:
                finished = True
            else:
                symbol = new_symbol
        return tuple(new_symbol)

    def extract_action_usage_from_rule_usage(self, rule_usage, all_rules):
        """Re-keys a rule-name -> count mapping by each rule's raw-action expansion (a tuple of 2 or more actions).

        NOTE(review): if two different rules expand to the same raw actions, the one iterated last overwrites the
        other rather than being combined.
        """
        action_usage = {}
        for rule_name, count in rule_usage.items():
            action_usage[self.convert_symbol_to_raw_actions(rule_name, all_rules)] = count
        return action_usage

    def convert_a_string_using_reverse_rules(self, string, reverse_rules, rules_episode_appearance_tracker):
        """Converts a string using the rules we have previously generated.

        Scans left to right. Whenever the pair starting at the current position is in ``reverse_rules``, it is
        replaced by the rule name and the following symbol is consumed. The rules applied in each episode are
        recorded in ``rules_episode_appearance_tracker[episode][rule] = 1``; this argument is mutated in place.

        :return: ``(new_string, rules_usage_count)`` where ``rules_usage_count`` maps rule name -> times applied.
        """
        new_string = []
        skip_next_element = False
        rules_usage_count = defaultdict(int)
        episode = 0
        rules_used_this_episode = set()
        for ix in range(len(string)):
            # Episode bookkeeping happens even for elements that are skipped below.
            if string[ix] == self.end_of_episode_symbol:
                for rule in rules_used_this_episode:
                    rules_episode_appearance_tracker[episode][rule] = 1
                rules_used_this_episode = set()
                episode += 1
            if skip_next_element:
                skip_next_element = False
                continue
            # If is last element in string and wasn't just part of a pair then we add it to new string and finish
            if ix == len(string) - 1:
                new_string.append(string[ix])
                continue
            pair = (string[ix], string[ix + 1])
            if pair in reverse_rules:
                result = reverse_rules[pair]
                rules_usage_count[result] += 1
                rules_used_this_episode.add(result)
                new_string.append(result)
                skip_next_element = True
            else:
                new_string.append(string[ix])
        return new_string, rules_usage_count
|