# My implementation of the k_Sequitur algorithm described in the papers: https://arxiv.org/pdf/cs/9709102.pdf # and https://www.biorxiv.org/content/biorxiv/early/2018/03/13/281543.full.pdf # The algorithm takes in a sequence and forms a grammar using two rules: # 1) No pair of adjacent symbols appears more than k times in the grammar # 2) Every rule in the grammar is used more than k times # # e.g. string "abddddeabde" with k=2 would turn to: # "AdddBAB" # R1: A --> ab # R2: B --> de # TODO fix the fact that it sometimes provides rules that have end of episode symbol in them # TODO add an option to return rules in terms of the amount of times they appear in a set of provided episodes from collections import defaultdict, Counter class k_Sequitur(object): def __init__(self, k, end_of_episode_symbol="/"): self.k = k self.end_of_episode_symbol = end_of_episode_symbol self.next_rule_name_ix = 0 def generate_action_grammar(self, actions): """Generates a grammar given a list of actions""" assert isinstance(actions, list), actions assert not isinstance(actions[0], list), "Should be 1 long list of actions - {}".format(actions[0]) assert len(actions) > 0, "Need to provide a list of at least 1 action" assert isinstance(actions[0], int), "The actions should be integers" new_actions, all_rules, rule_usage, rules_episode_appearance_count = self.discover_all_rules_and_new_actions_representation(actions) action_usage = self.extract_action_usage_from_rule_usage(rule_usage, all_rules) rules_episode_appearance_count = self.extract_action_usage_from_rule_usage(rules_episode_appearance_count, all_rules) return new_actions, all_rules, action_usage, rules_episode_appearance_count def discover_all_rules_and_new_actions_representation(self, actions): """Takes in a list of actions and discovers all the rules present that get used more than self.k times and the subsequent new actions list when all rules are applied recursively""" all_rules = {} current_actions = None new_actions = actions rule_usage = defaultdict(int) num_episodes = Counter(actions)[self.end_of_episode_symbol] rules_episode_appearance_tracker = {k: defaultdict(int) for k in range(num_episodes)} while new_actions != current_actions: current_actions = new_actions rules, reverse_rules = self.generate_1_layer_of_rules(current_actions) all_rules.update(rules) new_actions, rules_usage_count = self.convert_a_string_using_reverse_rules(current_actions, reverse_rules, rules_episode_appearance_tracker) for key in rules_usage_count.keys(): rule_usage[key] += rules_usage_count[key] rules_episode_appearance_count = defaultdict(int) for episode in range(num_episodes): rule_apperance_tracker = rules_episode_appearance_tracker[episode] for key in rule_apperance_tracker.keys(): if rule_apperance_tracker[key] == 1: rules_episode_appearance_count[key] += 1 return new_actions, all_rules, rule_usage, rules_episode_appearance_count def generate_1_layer_of_rules(self, string): """Generate dictionaries indicating the pair of symbols that appear next to each other more than self.k times""" pairs_of_symbols = defaultdict(int) last_pair = None skip_next_symbol = False rules = {} assert string[-1] == self.end_of_episode_symbol, "Final element of string must be self.end_of_episode_symbol {}".format(string) for ix in range(len(string) - 1): # We skip the next symbol if it is already being used in a rule we just made if skip_next_symbol: skip_next_symbol = False continue pair = (string[ix], string[ix+1]) # We don't count a pair if it was the previous pair (and therefore we have 3 of the same symbols in a row) if pair != last_pair: pairs_of_symbols[pair] += 1 last_pair = pair else: last_pair = None if pairs_of_symbols[pair] >= self.k: previous_pair = (string[ix-1], string[ix]) pairs_of_symbols[previous_pair] -= 1 skip_next_symbol = True if pair not in rules.values() and self.end_of_episode_symbol not in pair: rule_name = self.get_next_rule_name() rules[rule_name] = pair reverse_rules = {v: k for k, v in rules.items()} return rules, reverse_rules def get_next_rule_name(self): """Returns next rule name to use and increments count """ next_rule_name = "R{}".format(self.next_rule_name_ix) self.next_rule_name_ix += 1 return next_rule_name def convert_symbol_to_raw_actions(self, symbol, rules): """Converts a symbol back to the sequence of raw actions it represents""" assert not isinstance(symbol, list) assert isinstance(symbol, str) or isinstance(symbol, int) symbol = [symbol] finished = False while not finished: new_symbol = [] for symbol_val in symbol: if symbol_val in rules.keys(): new_symbol.append(rules[symbol_val][0]) new_symbol.append(rules[symbol_val][1]) else: new_symbol.append(symbol_val) if new_symbol == symbol: finished = True else: symbol = new_symbol new_symbol = tuple(new_symbol) return new_symbol def extract_action_usage_from_rule_usage(self, rule_usage, all_rules): """Extracts the usage of each action (of 2 or more primitive actions) out from the usage of each rule""" action_usage = {} for key in rule_usage.keys(): action_usage[self.convert_symbol_to_raw_actions(key, all_rules)] = rule_usage[key] return action_usage def convert_a_string_using_reverse_rules(self, string, reverse_rules, rules_episode_appearance_tracker): """Converts a string using the rules we have previously generated""" new_string = [] skip_next_element = False rules_usage_count = defaultdict(int) episode = 0 rules_used_this_episode = [] for ix in range(len(string)): if string[ix] == self.end_of_episode_symbol: rules_used_this_episode = set(rules_used_this_episode) for rule in rules_used_this_episode: rules_episode_appearance_tracker[episode][rule] = 1 rules_used_this_episode = [] episode += 1 if skip_next_element: skip_next_element = False continue # If is last element in string and wasn't just part of a pair then we add it to new string and finish if ix == len(string) - 1: new_string.append(string[ix]) continue pair = (string[ix], string[ix+1]) if pair in reverse_rules.keys(): result = reverse_rules[pair] rules_usage_count[result] += 1 rules_used_this_episode.append(result) new_string.append(result) skip_next_element = True else: new_string.append(string[ix]) return new_string, rules_usage_count