|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
"""Tasks that test correctness of algorithms.""" |
|
|
|
from six.moves import xrange |
|
from common import reward as reward_lib |
|
from single_task import misc |
|
|
|
|
|
class BasicTaskManager(object):
    """Adapter that scores code strings with a user-supplied reward function.

    The wrapped `reward_fn` is called with the token list produced by
    `misc.bf_string_to_tokens` and must return a `(reward, correct)` pair.
    """

    def __init__(self, reward_fn):
        """Store the reward function.

        Args:
          reward_fn: Callable taking a list of tokens and returning a
              `(reward, correct)` tuple.
        """
        # NOTE(review): `good_reward` is not read inside this class; presumably
        # it is consumed by external code as a "solved" threshold -- confirm.
        self.good_reward = 1.0
        self.reward_fn = reward_fn

    def _score_string(self, string):
        """Score one code string and package the result as a `RewardInfo`.

        Args:
          string: The code string to evaluate.

        Returns:
          A `misc.RewardInfo` whose `episode_rewards` is zero everywhere except
          the final entry, which holds the reward returned by `reward_fn`.
        """
        tokens = misc.bf_string_to_tokens(string)
        final_reward, is_correct = self.reward_fn(tokens)

        # The whole reward is assigned to the last step of the episode.
        leading_zeros = [0.0] * (len(string) - 1)
        if is_correct:
            reason = 'correct'
        else:
            reason = 'wrong'

        return misc.RewardInfo(
            episode_rewards=leading_zeros + [final_reward],
            input_case=None,
            correct_output=None,
            code_output=tokens,
            input_type=None,
            output_type=misc.IOType.integer,
            reason=reason)

    def rl_batch(self, batch_size):
        """Return a list of `batch_size` references to the scoring function."""
        scorer = self._score_string
        return [scorer] * batch_size
|
|
|
|
|
class Trie(object):
    """Trie over sequences of hashable elements, stored as nested dicts.

    Every node is a dict mapping an element to its child node. A node at which
    an inserted sequence ends additionally contains the reserved `EOS` key.
    Sequence elements must therefore never be equal to `EOS` (the empty tuple).
    """

    # Reserved key marking the end of an inserted sequence.
    EOS = ()

    def __init__(self):
        # Root node of the trie.
        self.trie = {}

    def insert(self, sequence):
        """Add `sequence` (an iterable of hashable elements) to the trie."""
        node = self.trie
        for element in sequence:
            node = node.setdefault(element, {})
        node[self.EOS] = True

    def prefix_match(self, sequence):
        """Return prefix of `sequence` which exists in the trie.

        Args:
          sequence: A sliceable sequence (list or tuple) of hashable elements.

        Returns:
          A `(prefix, complete)` tuple. `prefix` is the longest leading slice
          of `sequence` that can be walked in the trie. `complete` is True only
          when all of `sequence` was walked and it is exactly one of the
          inserted sequences; in that case `prefix` is `sequence` itself.
        """
        node = self.trie
        for index, element in enumerate(sequence):
            # An element equal to EOS can never be a real trie edge; treating
            # it as a mismatch avoids reporting a complete match while trailing
            # elements are still unchecked.
            if element == self.EOS or element not in node:
                return sequence[:index], False
            node = node[element]
        return sequence, self.EOS in node

    def next_choices(self, sequence):
        """Return the keys available after walking `sequence` from the root.

        Args:
          sequence: An iterable of elements forming a path in the trie.

        Returns:
          A list of the child keys of the node reached. The list contains
          `EOS` when an inserted sequence ends at that node.

        Raises:
          ValueError: If `sequence` cannot be walked in the trie.
        """
        node = self.trie
        for element in sequence:
            if element == self.EOS or element not in node:
                raise ValueError('Sequence not a prefix: %s' % (sequence,))
            node = node[element]
        # Materialize a list so the result compares equal to list literals on
        # Python 3, where dict.keys() is a view object rather than a list.
        return list(node)
|
|
|
|
|
class HillClimbingTask(object):
    """Simple task that tests reward hill climbing ability.

    There are a set of paths (sequences of tokens) which are rewarded. The
    total reward for a path is proportional to its length, so the longest path
    is the target. Shorter paths can be dead ends.

    Paths are stored in a `Trie` whose elements are 3-token tuples. Calling
    the task with a flat list of integer actions returns a
    `(reward, correct)` pair.
    """

    def __init__(self):
        # Building blocks: each path is a sequence of these 3-token chunks.
        a = (1, 2, 3)
        b = (4, 5, 6)
        c = (7, 8, 7)
        d = (6, 5, 4)
        e = (3, 2, 1)
        f = (8, 5, 1)
        g = (6, 4, 2)
        h = (1, 8, 3)
        self.paths = Trie()
        self.paths.insert([a, b, h])
        self.paths.insert([a, b, c, d, e, f, g, h])  # Longest path: target.
        self.paths.insert([a, b, c, d, e, b, a])
        self.paths.insert([a, b, g, h])
        self.paths.insert([a, e, f, g])
        self.correct_sequence = misc.flatten([a, b, c, d, e, f, g, h])

        def distance_fn(a, b):
            """Distance between two chunks; each missing token costs 4."""
            len_diff = abs(len(a) - len(b))
            # NOTE(review): `mod_abs_diff(x, y, 8)` is presumably a circular
            # distance over 8 token values (max 4 per token) -- confirm.
            return sum(reward_lib.mod_abs_diff(ai - 1, bi - 1, 8)
                       for ai, bi in zip(a, b)) + len_diff * 4
        self.distance_fn = distance_fn

    def __call__(self, actions):
        """Score a flat action sequence against the rewarded paths.

        Args:
          actions: Iterable of integer tokens. Tokens that are not positive
              are dropped before scoring.

        Returns:
          A `(reward, correct)` tuple. `reward` grows with the number of
          leading 3-token chunks that follow a rewarded path; `correct` is
          True only when the filtered actions equal `self.correct_sequence`.
        """
        actions = [a for a in actions if a > 0]
        # Group the tokens into chunks of 3; the last chunk may be shorter.
        sequence = [tuple(actions[i:i + 3])
                    for i in range(0, len(actions), 3)]
        prefix, complete = self.paths.prefix_match(sequence)
        if complete:
            return float(len(prefix)), actions == self.correct_sequence
        if len(prefix) == len(sequence):
            # Every chunk is on a rewarded path, but the path is unfinished.
            return float(len(prefix)), False

        next_pred = sequence[len(prefix)]
        # Materialize as a list: on Python 3 a dict-keys view never compares
        # equal to a list, which would make the dead-end branch unreachable.
        choices = list(self.paths.next_choices(prefix))
        if choices == [self.paths.EOS]:
            # The prefix is already a finished path, so anything appended is
            # penalized in proportion to the number of extra tokens.
            return (len(prefix) - len(next_pred) / 3.0), False

        min_dist = min(self.distance_fn(c, next_pred) for c in choices)
        # Partial credit for getting close to a valid next chunk. The divisor
        # 12 matches 3 tokens times the per-token cost of 4 used for length
        # differences in `distance_fn`.
        return (len(prefix) + (1 - min_dist / 12.0)), False
|
|