asataura's picture
initial commit
6fa23b0
raw
history blame
7.75 kB
# My implementation of the k_Sequitur algorithm described in the papers: https://arxiv.org/pdf/cs/9709102.pdf
# and https://www.biorxiv.org/content/biorxiv/early/2018/03/13/281543.full.pdf
# The algorithm takes in a sequence and forms a grammar using two rules:
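# 1) No pair of adjacent symbols appears more than k times in the grammar
# 2) Every rule in the grammar is used more than k times
# NOTE: this implementation turns a pair into a rule as soon as it has been counted k times (i.e. >= k),
# which is why the pairs "ab" and "de" in the k=2 example below become rules after appearing twice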
#
# e.g. string "abddddeabde" with k=2 would turn to:
# "AdddBAB"
# R1: A --> ab
# R2: B --> de
# TODO fix the fact that it sometimes provides rules that have end of episode symbol in them
# TODO add an option to return rules in terms of the amount of times they appear in a set of provided episodes
from collections import defaultdict, Counter
class k_Sequitur(object):
    """Builds a grammar of repeated symbol pairs from a flat sequence of actions.

    A pair of adjacent symbols is replaced by a new rule symbol ("R0", "R1", ...)
    once it has been counted at least ``k`` times; the process is repeated on the
    rewritten sequence until no further rules are found.

    Args:
        k: count threshold at which an adjacent pair is turned into a rule.
        end_of_episode_symbol: marker separating episodes inside the action list.
            Pairs containing it are never turned into rules.
    """

    def __init__(self, k, end_of_episode_symbol="/"):
        self.k = k
        self.end_of_episode_symbol = end_of_episode_symbol
        # Counter used to give every rule created by this instance a unique name
        self.next_rule_name_ix = 0

    def generate_action_grammar(self, actions):
        """Generates a grammar given a list of actions.

        Args:
            actions: one flat, non-empty list whose first element is an int and whose
                last element is ``self.end_of_episode_symbol``.

        Returns:
            A 4-tuple ``(new_actions, all_rules, action_usage, rules_episode_appearance_count)``:
            the rewritten action list, a dict mapping rule name -> pair of symbols, a dict
            mapping each rule's raw-action tuple -> number of times the rule was applied, and a
            dict mapping each rule's raw-action tuple -> number of episodes it was applied in.

        Raises:
            AssertionError: if ``actions`` is not a list, is empty, is nested, or does not start
                with an int.
        """
        assert isinstance(actions, list), actions
        # The emptiness check must come before any actions[0] access, otherwise an empty list
        # raises IndexError instead of this message
        assert len(actions) > 0, "Need to provide a list of at least 1 action"
        assert not isinstance(actions[0], list), "Should be 1 long list of actions - {}".format(actions[0])
        assert isinstance(actions[0], int), "The actions should be integers"
        new_actions, all_rules, rule_usage, rules_episode_appearance_count = \
            self.discover_all_rules_and_new_actions_representation(actions)
        action_usage = self.extract_action_usage_from_rule_usage(rule_usage, all_rules)
        rules_episode_appearance_count = self.extract_action_usage_from_rule_usage(
            rules_episode_appearance_count, all_rules)
        return new_actions, all_rules, action_usage, rules_episode_appearance_count

    def discover_all_rules_and_new_actions_representation(self, actions):
        """Repeatedly discovers rules and rewrites the actions until the sequence stops changing.

        Returns:
            ``(new_actions, all_rules, rule_usage, rules_episode_appearance_count)`` where
            ``rule_usage`` maps rule name -> total number of times it was applied and
            ``rules_episode_appearance_count`` maps rule name -> number of episodes in which it
            was applied at least once.
        """
        all_rules = {}
        current_actions = None
        new_actions = actions
        rule_usage = defaultdict(int)
        num_episodes = Counter(actions)[self.end_of_episode_symbol]
        # One dict per episode; a rule is flagged with 1 when it is applied inside that episode
        rules_episode_appearance_tracker = {episode: defaultdict(int) for episode in range(num_episodes)}

        # Each pass adds one layer of rules; stop once a pass leaves the sequence unchanged
        while new_actions != current_actions:
            current_actions = new_actions
            rules, reverse_rules = self.generate_1_layer_of_rules(current_actions)
            all_rules.update(rules)
            new_actions, rules_usage_count = self.convert_a_string_using_reverse_rules(
                current_actions, reverse_rules, rules_episode_appearance_tracker)
            for rule_name, count in rules_usage_count.items():
                rule_usage[rule_name] += count

        rules_episode_appearance_count = defaultdict(int)
        for episode in range(num_episodes):
            for rule_name, flag in rules_episode_appearance_tracker[episode].items():
                if flag == 1:
                    rules_episode_appearance_count[rule_name] += 1
        return new_actions, all_rules, rule_usage, rules_episode_appearance_count

    def generate_1_layer_of_rules(self, string):
        """Finds the pairs of adjacent symbols that are counted at least self.k times.

        Args:
            string: sequence of symbols whose final element is ``self.end_of_episode_symbol``.

        Returns:
            ``(rules, reverse_rules)``: a dict mapping new rule name -> pair and the inverse
            dict mapping pair -> rule name.

        Raises:
            AssertionError: if the final element is not ``self.end_of_episode_symbol``.
        """
        pair_counts = defaultdict(int)
        last_pair = None
        skip_next_symbol = False
        rules = {}
        # Built alongside `rules` so the "already a rule?" check is a dict lookup
        reverse_rules = {}

        assert string[-1] == self.end_of_episode_symbol, "Final element of string must be self.end_of_episode_symbol {}".format(string)

        for ix in range(len(string) - 1):
            # We skip the next symbol if it is already being used in a rule we just made
            if skip_next_symbol:
                skip_next_symbol = False
                continue

            pair = (string[ix], string[ix + 1])

            # We don't count a pair if it was the previous pair (and therefore we have 3 of the same symbols in a row)
            if pair != last_pair:
                pair_counts[pair] += 1
                last_pair = pair
            else:
                last_pair = None

            if pair_counts[pair] >= self.k:
                # The pair ending at string[ix] shares that symbol with the pair just matched, so its
                # count is reduced. Only done when a left neighbour exists: at ix == 0 the index
                # ix - 1 would wrap around to the final element.
                if ix > 0:
                    previous_pair = (string[ix - 1], string[ix])
                    pair_counts[previous_pair] -= 1
                skip_next_symbol = True
                if pair not in reverse_rules and self.end_of_episode_symbol not in pair:
                    rule_name = self.get_next_rule_name()
                    rules[rule_name] = pair
                    reverse_rules[pair] = rule_name
        return rules, reverse_rules

    def get_next_rule_name(self):
        """Returns next rule name to use and increments count"""
        next_rule_name = "R{}".format(self.next_rule_name_ix)
        self.next_rule_name_ix += 1
        return next_rule_name

    def convert_symbol_to_raw_actions(self, symbol, rules):
        """Converts a symbol back to the tuple of raw actions it represents.

        Args:
            symbol: a single str or int symbol (not a list).
            rules: dict mapping rule name -> pair of symbols.

        Returns:
            Tuple obtained by expanding rule names until none remain.
        """
        assert not isinstance(symbol, list)
        assert isinstance(symbol, str) or isinstance(symbol, int)
        expanded = [symbol]
        while True:
            next_expanded = []
            for symbol_val in expanded:
                if symbol_val in rules:
                    next_expanded.append(rules[symbol_val][0])
                    next_expanded.append(rules[symbol_val][1])
                else:
                    next_expanded.append(symbol_val)
            # A pass that changes nothing means every symbol left is a raw action
            if next_expanded == expanded:
                return tuple(next_expanded)
            expanded = next_expanded

    def extract_action_usage_from_rule_usage(self, rule_usage, all_rules):
        """Re-keys a rule-name -> count dict by the tuple of raw actions each rule expands to.

        If two rule names expand to the same raw-action tuple, the later entry overwrites the
        earlier one.
        """
        return {self.convert_symbol_to_raw_actions(rule_name, all_rules): count
                for rule_name, count in rule_usage.items()}

    def convert_a_string_using_reverse_rules(self, string, reverse_rules, rules_episode_appearance_tracker):
        """Rewrites a string by replacing, left to right, each pair that has a rule with its rule name.

        Args:
            string: sequence of symbols to rewrite.
            reverse_rules: dict mapping pair -> rule name.
            rules_episode_appearance_tracker: dict mapping episode index -> dict; mutated in place
                so that every rule applied within an episode is flagged with 1 for that episode.

        Returns:
            ``(new_string, rules_usage_count)``: the rewritten list and a defaultdict mapping
            rule name -> number of times it was applied in this pass.
        """
        new_string = []
        skip_next_element = False
        rules_usage_count = defaultdict(int)
        episode = 0
        rules_used_this_episode = set()
        last_ix = len(string) - 1

        for ix, symbol in enumerate(string):
            if symbol == self.end_of_episode_symbol:
                # Episode finished: record which rules were applied in it, then start a new one
                for rule in rules_used_this_episode:
                    rules_episode_appearance_tracker[episode][rule] = 1
                rules_used_this_episode = set()
                episode += 1

            # This element was already consumed as the second half of a replaced pair
            if skip_next_element:
                skip_next_element = False
                continue

            # If is last element in string and wasn't just part of a pair then we add it to new string and finish
            if ix == last_ix:
                new_string.append(symbol)
                continue

            pair = (symbol, string[ix + 1])
            if pair in reverse_rules:
                result = reverse_rules[pair]
                rules_usage_count[result] += 1
                rules_used_this_episode.add(result)
                new_string.append(result)
                skip_next_element = True
            else:
                new_string.append(symbol)
        return new_string, rules_usage_count