File size: 7,753 Bytes
6fa23b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167

# My implementation of the k_Sequitur algorithm described in the papers: https://arxiv.org/pdf/cs/9709102.pdf
# and https://www.biorxiv.org/content/biorxiv/early/2018/03/13/281543.full.pdf
# The algorithm takes in a sequence and forms a grammar using two rules:
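# 1) No pair of adjacent symbols appears k or more times in the grammar
# 2) Every rule in the grammar is used at least k times
# (this implementation turns a pair into a rule as soon as it has been counted k times, i.e. the threshold is ">= k")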
#
# e.g. string "abddddeabde" with k=2 would turn to:
# "AdddBAB"
# R1: A --> ab
# R2: B --> de

# TODO fix the fact that it sometimes provides rules that have end of episode symbol in them
# TODO add an option to return rules in terms of the amount of times they appear in a set of provided episodes

from collections import defaultdict, Counter


class k_Sequitur(object):

    def __init__(self, k, end_of_episode_symbol="/"):
        self.k = k
        self.end_of_episode_symbol = end_of_episode_symbol
        self.next_rule_name_ix = 0

    def generate_action_grammar(self, actions):
        """Generates a grammar given a list of actions"""
        assert isinstance(actions, list), actions
        assert not isinstance(actions[0], list), "Should be 1 long list of actions - {}".format(actions[0])
        assert len(actions) > 0, "Need to provide a list of at least 1 action"
        assert isinstance(actions[0], int), "The actions should be integers"
        new_actions, all_rules, rule_usage, rules_episode_appearance_count = self.discover_all_rules_and_new_actions_representation(actions)
        action_usage = self.extract_action_usage_from_rule_usage(rule_usage, all_rules)
        rules_episode_appearance_count = self.extract_action_usage_from_rule_usage(rules_episode_appearance_count,
                                                                                   all_rules)
        return new_actions, all_rules, action_usage, rules_episode_appearance_count

    def discover_all_rules_and_new_actions_representation(self, actions):
        """Takes in a list of actions and discovers all the rules present that get used more than self.k times and the
        subsequent new actions list when all rules are applied recursively"""
        all_rules = {}
        current_actions = None
        new_actions = actions
        rule_usage = defaultdict(int)
        num_episodes = Counter(actions)[self.end_of_episode_symbol]
        rules_episode_appearance_tracker = {k: defaultdict(int) for k in range(num_episodes)}

        while new_actions != current_actions:
            current_actions = new_actions
            rules, reverse_rules = self.generate_1_layer_of_rules(current_actions)
            all_rules.update(rules)
            new_actions, rules_usage_count = self.convert_a_string_using_reverse_rules(current_actions, reverse_rules,
                                                                                       rules_episode_appearance_tracker)
            for key in rules_usage_count.keys():
                rule_usage[key] += rules_usage_count[key]

        rules_episode_appearance_count = defaultdict(int)

        for episode in range(num_episodes):
            rule_apperance_tracker = rules_episode_appearance_tracker[episode]
            for key in rule_apperance_tracker.keys():
                if rule_apperance_tracker[key] == 1:
                    rules_episode_appearance_count[key] += 1

        return new_actions, all_rules, rule_usage, rules_episode_appearance_count

    def generate_1_layer_of_rules(self, string):
        """Generate dictionaries indicating the pair of symbols that appear next to each other more than self.k times"""
        pairs_of_symbols = defaultdict(int)
        last_pair = None
        skip_next_symbol = False
        rules = {}

        assert string[-1] == self.end_of_episode_symbol, "Final element of string must be self.end_of_episode_symbol {}".format(string)

        for ix in range(len(string) - 1):
            # We skip the next symbol if it is already being used in a rule we just made
            if skip_next_symbol:
                skip_next_symbol = False
                continue

            pair = (string[ix], string[ix+1])

            # We don't count a pair if it was the previous pair (and therefore we have 3 of the same symbols in a row)
            if pair != last_pair:
                pairs_of_symbols[pair] += 1
                last_pair = pair
            else: last_pair = None
            if pairs_of_symbols[pair] >= self.k:
                previous_pair = (string[ix-1], string[ix])
                pairs_of_symbols[previous_pair] -= 1
                skip_next_symbol = True
                if pair not in rules.values() and self.end_of_episode_symbol not in pair:
                    rule_name = self.get_next_rule_name()
                    rules[rule_name] = pair
        reverse_rules = {v: k for k, v in rules.items()}
        return rules, reverse_rules

    def get_next_rule_name(self):
        """Returns next rule name to use and increments count """
        next_rule_name = "R{}".format(self.next_rule_name_ix)
        self.next_rule_name_ix += 1
        return next_rule_name

    def convert_symbol_to_raw_actions(self, symbol, rules):
        """Converts a symbol back to the sequence of raw actions it represents"""
        assert not isinstance(symbol, list)
        assert isinstance(symbol, str) or isinstance(symbol, int)
        symbol = [symbol]
        finished = False
        while not finished:
            new_symbol = []
            for symbol_val in symbol:
                if symbol_val in rules.keys():
                    new_symbol.append(rules[symbol_val][0])
                    new_symbol.append(rules[symbol_val][1])
                else:
                    new_symbol.append(symbol_val)
            if new_symbol == symbol: finished = True
            else: symbol = new_symbol
        new_symbol = tuple(new_symbol)
        return new_symbol

    def extract_action_usage_from_rule_usage(self, rule_usage, all_rules):
        """Extracts the usage of each action (of 2 or more primitive actions) out from the usage of each rule"""
        action_usage = {}
        for key in rule_usage.keys():
            action_usage[self.convert_symbol_to_raw_actions(key, all_rules)] = rule_usage[key]
        return action_usage

    def convert_a_string_using_reverse_rules(self, string, reverse_rules, rules_episode_appearance_tracker):
        """Converts a string using the rules we have previously generated"""
        new_string = []
        skip_next_element = False
        rules_usage_count = defaultdict(int)

        episode = 0

        rules_used_this_episode = []

        for ix in range(len(string)):
            if string[ix] == self.end_of_episode_symbol:
                rules_used_this_episode = set(rules_used_this_episode)
                for rule in rules_used_this_episode:
                    rules_episode_appearance_tracker[episode][rule] = 1
                rules_used_this_episode = []
                episode += 1

            if skip_next_element:
                skip_next_element = False
                continue
            # If is last element in string and wasn't just part of a pair then we add it to new string and finish
            if ix == len(string) - 1:
                new_string.append(string[ix])
                continue
            pair = (string[ix], string[ix+1])
            if pair in reverse_rules.keys():
                result = reverse_rules[pair]
                rules_usage_count[result] += 1
                rules_used_this_episode.append(result)
                new_string.append(result)
                skip_next_element = True
            else:
                new_string.append(string[ix])
        return new_string, rules_usage_count