|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Neural GPU -- data generation and batching utilities.""" |
|
|
|
import math |
|
import os |
|
import random |
|
import sys |
|
import time |
|
|
|
import numpy as np |
|
from six.moves import xrange |
|
import tensorflow as tf |
|
|
|
import program_utils |
|
|
|
# Shared handle to TensorFlow's command-line flags object (tf.app.flags).
FLAGS = tf.app.flags.FLAGS
|
|
|
# Padded sequence length of each bin: 256 consecutive sizes 2, 3, ..., 257.
bins = list(range(2, 2 + 256))

# Names of all tasks that get a slot in the train/test data sets.
all_tasks = [
    "sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left", "right",
    "left-shift", "right-shift", "bmul", "mul", "dup", "badd", "qadd",
    "search", "progeval", "progsynth",
]

# Path of the log file; the empty string turns file logging off in print_out.
log_filename = ""

# Vocabulary tables; both start out unset.
vocab = None
rev_vocab = None
|
|
|
|
|
def pad(l):
    """Return the first bin size that is at least ``l``.

    When no entry of ``bins`` is large enough, the last entry is returned.
    """
    return next((size for size in bins if size >= l), bins[-1])
|
|
|
|
|
def bin_for(l):
    """Return the index of the first bin whose size is at least ``l``.

    When no entry of ``bins`` is large enough, the last index is returned.
    """
    fitting = (index for index, size in enumerate(bins) if size >= l)
    return next(fitting, len(bins) - 1)
|
|
|
|
|
# Example storage: <set>[task][index] is a list of examples. Every task in
# all_tasks starts with 10000 empty slots.
train_set = {task_name: [[] for _ in xrange(10000)] for task_name in all_tasks}
test_set = {task_name: [[] for _ in xrange(10000)] for task_name in all_tasks}
|
|
|
|
|
def read_tmp_file(name):
    """Read ``<name>.txt`` from the log directory or one of its two parents.

    The log directory is tried first, then its parent, then its grandparent;
    every miss and the final hit are reported through ``print_out``.

    Returns:
      The whitespace-stripped lines of the first file found, or None when the
      file exists in none of the three locations.
    """
    base_dir = os.path.dirname(log_filename)
    for prefix in ("", "../", "../../"):
        candidate = os.path.join(base_dir, prefix + name + ".txt")
        if tf.gfile.Exists(candidate):
            break
        print_out("== not found file: " + candidate)
    else:
        # Every location was tried without success.
        return None
    print_out("== found file: " + candidate)
    with tf.gfile.GFile(candidate, mode="r") as handle:
        return [row.strip() for row in handle]
|
|
|
|
|
def write_tmp_file(name, lines):
    """Write ``lines`` to ``<name>.txt`` in the log directory, one per line.

    Each entry of ``lines`` is written followed by a newline; an existing file
    is opened in "w" mode.
    """
    target_path = os.path.join(os.path.dirname(log_filename), name + ".txt")
    with tf.gfile.GFile(target_path, mode="w") as out_file:
        for entry in lines:
            out_file.write(entry + "\n")
|
|
|
|
|
def add(n1, n2, base=10):
    """Add two numbers given as little-endian digit lists.

    Args:
      n1: list of digits, least significant digit first.
      n2: list of digits, least significant digit first.
      base: numeric base of the digits.

    Returns:
      The sum as a little-endian digit list with no most-significant zeros;
      ``[0]`` when the sum is zero.
    """
    # One extra position leaves room for a final carry.
    width = max(len(n1), len(n2)) + 1
    lhs = n1 + [0] * (width - len(n1))
    rhs = n2 + [0] * (width - len(n2))

    digits = []
    carry = 0
    for a, b in zip(lhs, rhs):
        total = a + b + carry
        carry = 0 if total < base else 1
        digits.append(total - carry * base)

    # Drop zeros at the most significant end.
    while digits and digits[-1] == 0:
        digits.pop()
    return digits or [0]
|
|
|
|
|
def init_data(task, length, nbr_cases, nclass):
    """Generate examples for ``task`` and store them in the global data sets.

    Examples are appended to the module-level ``train_set[task]`` and
    ``test_set[task]`` lists, in the slot chosen by ``bin_for``. Nothing is
    returned.

    For the non-program tasks, each of the ``nbr_cases`` iterations draws one
    fresh training example followed by one fresh test example. For "progeval"
    and "progsynth", programs and their input/output pairs are read from (or
    generated and written to) cache files via ``read_tmp_file`` /
    ``write_tmp_file``, and each program goes entirely to the training set or
    entirely to the test set.

    An unknown task name ends the process through ``sys.exit()``.

    Args:
      task: task name (see ``all_tasks``).
      length: maximum example length; for program tasks ``length // 10`` is
        passed on as the program-length parameter.
      nbr_cases: number of cases to generate.
      nclass: random symbols are drawn from 1 .. nclass - 1.
    """

    def rand_pair(l, task):
        """Random data pair for a task. Total length should be <= l."""
        k = int((l-1)/2)
        base = 10
        if task[0] == "b":
            base = 2
        if task[0] == "q":
            base = 4
        d1 = [np.random.randint(base) for _ in xrange(k)]
        d2 = [np.random.randint(base) for _ in xrange(k)]
        if task in ["add", "badd", "qadd"]:
            res = add(d1, d2, base)
        elif task in ["mul", "bmul"]:
            d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
            d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
            if task == "bmul":
                # bin() yields "0b..."; reversing and dropping the last two
                # characters removes that prefix, leaving little-endian bits.
                res = [int(x)
                       for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
            else:
                res = [int(x) for x in list(reversed(str(d1n * d2n)))]
        else:
            sys.exit()
        sep = [12]
        if task in ["add", "badd", "qadd"]:
            sep = [11]
        # Digits are stored shifted by one so that 0 stays free for padding.
        inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
        return inp, [r + 1 for r in res]

    def rand_dup_pair(l):
        """Random data pair for duplication task. Total length should be <= l."""
        k = int(l/2)
        x = [np.random.randint(nclass - 1) + 1 for _ in xrange(k)]
        inp = x + [0 for _ in xrange(l - k)]
        res = x + x + [0 for _ in xrange(l - 2*k)]
        return inp, res

    def rand_rev2_pair(l):
        """Random data pair for reverse2 task. Total length should be <= l."""
        # Floor division keeps the count an int under Python 3 as well.
        inp = [(np.random.randint(nclass - 1) + 1,
                np.random.randint(nclass - 1) + 1) for _ in xrange(l // 2)]
        res = [i for i in reversed(inp)]
        return [x for p in inp for x in p], [x for p in res for x in p]

    def rand_search_pair(l):
        """Random data pair for search task. Total length should be <= l."""
        # (l - 1) // 2 key/value pairs plus the query symbol fit into l.
        # The previous `l-1/2` parsed as `l - (1/2)` and produced l pairs.
        inp = [(np.random.randint(nclass - 1) + 1,
                np.random.randint(nclass - 1) + 1)
               for _ in xrange((l - 1) // 2)]
        q = np.random.randint(nclass - 1) + 1
        res = 0
        # Walking backwards leaves the value of the first matching key in res.
        for (k, v) in reversed(inp):
            if k == q:
                res = v
        return [x for p in inp for x in p] + [q], [res]

    def rand_kvsort_pair(l):
        """Random data pair for key-value sort. Total length should be <= l."""
        keys = [(np.random.randint(nclass - 1) + 1, i) for i in xrange(l // 2)]
        vals = [np.random.randint(nclass - 1) + 1 for _ in xrange(l // 2)]
        kv = [(k, vals[i]) for (k, i) in keys]
        sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
        return [x for p in kv for x in p], [x for p in sorted_kv for x in p]

    def prog_io_pair(prog, max_len, counter=0):
        """Sample an input for ``prog`` and evaluate it, retrying on ValueError.

        ``counter`` is the retry number: the range of sampled input values
        shrinks as it grows, and past 400 retries an empty output is accepted.
        """
        try:
            ilen = np.random.randint(max_len - 3) + 1
            # Floor division: range() below needs an int bound.
            bound = max(15 - (counter // 20), 1)
            inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
            inp_toks = [program_utils.prog_rev_vocab[t]
                        for t in program_utils.tokenize(str(inp)) if t != ","]
            out = program_utils.evaluate(prog, {"a": inp})
            out_toks = [program_utils.prog_rev_vocab[t]
                        for t in program_utils.tokenize(str(out)) if t != ","]
            if counter > 400:
                out_toks = []
            # NOTE(review): this counts "," entries of `out` itself, not of
            # str(out) -- confirm against program_utils that this is intended.
            if (out_toks and
                    out_toks[0] == program_utils.prog_rev_vocab["["] and
                    len(out_toks) != len([o for o in out if o == ","]) + 3):
                raise ValueError("generated list with too long ints")
            if (out_toks and
                    out_toks[0] != program_utils.prog_rev_vocab["["] and
                    len(out_toks) > 1):
                raise ValueError("generated one int but tokenized it to many")
            if len(out_toks) > max_len:
                raise ValueError("output too long")
            return (inp_toks, out_toks)
        except ValueError:
            return prog_io_pair(prog, max_len, counter+1)

    def spec(inp):
        """Return the target given the input for some tasks."""
        if task == "sort":
            return sorted(inp)
        elif task == "id":
            return inp
        elif task == "rev":
            return [i for i in reversed(inp)]
        elif task == "incr":
            carry = 1
            res = []
            for i in xrange(len(inp)):
                if inp[i] + carry < nclass:
                    res.append(inp[i] + carry)
                    carry = 0
                else:
                    res.append(1)
                    carry = 1
            return res
        elif task == "left":
            return [inp[0]]
        elif task == "right":
            return [inp[-1]]
        elif task == "left-shift":
            # Index -1 wraps around, so this is a rotation by one position.
            return [inp[idx - 1] for idx in xrange(len(inp))]
        elif task == "right-shift":
            # Mirror of left-shift. The modulo wraps the last position; the
            # plain `idx + 1` always ran past the end with an IndexError.
            return [inp[(idx + 1) % len(inp)] for idx in xrange(len(inp))]
        else:
            print_out("Unknown spec for task " + str(task))
            sys.exit()

    l = length
    cur_time = time.time()
    total_time = 0.0

    is_prog = task in ["progeval", "progsynth"]
    if is_prog:
        inputs_per_prog = 5
        # Integer program-length parameter (plain `/` is a float in Python 3).
        max_prog_len = l // 10
        program_utils.make_vocab()
        progs = read_tmp_file("programs_len%d" % max_prog_len)
        if not progs:
            progs = program_utils.gen(max_prog_len,
                                      1.2 * nbr_cases / inputs_per_prog)
            write_tmp_file("programs_len%d" % max_prog_len, progs)
        prog_ios = read_tmp_file("programs_len%d_io" % max_prog_len)
        nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
        if not prog_ios:
            # Generate program io data and cache it on disk.
            prog_ios = []
            for pidx, prog in enumerate(progs):
                if pidx % 500 == 0:
                    print_out("== generating io pairs for program %d" % pidx)
                if pidx * inputs_per_prog > nbr_cases * 1.2:
                    break
                ptoks = [program_utils.prog_rev_vocab[t]
                         for t in program_utils.tokenize(prog)]
                ptoks.append(program_utils.prog_rev_vocab["_EOS"])
                plen = len(ptoks)
                for _ in xrange(inputs_per_prog):
                    if task == "progeval":
                        inp, out = prog_io_pair(prog, plen)
                        prog_ios.append(
                            str(inp) + "\t" + str(out) + "\t" + prog)
                    elif task == "progsynth":
                        plen = max(len(ptoks), 8)
                        for _ in xrange(3):
                            inp, out = prog_io_pair(prog, plen // 2)
                            prog_ios.append(
                                str(inp) + "\t" + str(out) + "\t" + prog)
            write_tmp_file("programs_len%d_io" % max_prog_len, prog_ios)

        # Parse the cached "input<TAB>output<TAB>program" lines back into
        # integer lists, grouped by program text.
        prog_ios_dict = {}
        for s in prog_ios:
            i, o, p = s.split("\t")
            i_clean = "".join([c for c in i if c.isdigit() or c == " "])
            o_clean = "".join([c for c in o if c.isdigit() or c == " "])
            inp = [int(x) for x in i_clean.split()]
            out = [int(x) for x in o_clean.split()]
            if inp and out:
                prog_ios_dict.setdefault(p, []).append([inp, out])

        # Use prog_ios_dict to create data; keep only programs with at most
        # max_prog_len ";" characters.
        progs = []
        for prog in prog_ios_dict:
            if len([c for c in prog if c == ";"]) <= max_prog_len:
                progs.append(prog)
        nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
        print_out("== %d training cases on %d progs" % (nbr_cases, len(progs)))
        for pidx, prog in enumerate(progs):
            if pidx * inputs_per_prog > nbr_cases * 1.2:
                break
            ptoks = [program_utils.prog_rev_vocab[t]
                     for t in program_utils.tokenize(prog)]
            ptoks.append(program_utils.prog_rev_vocab["_EOS"])
            plen = len(ptoks)
            # Early programs go to the training set, the remainder to test.
            dset = train_set if pidx < nbr_cases / inputs_per_prog else test_set
            for _ in xrange(inputs_per_prog):
                if task == "progeval":
                    inp, out = prog_ios_dict[prog].pop()
                    dset[task][bin_for(plen)].append(
                        [[ptoks, inp, [], []], [out]])
                elif task == "progsynth":
                    plen, ilist = max(len(ptoks), 8), [[]]
                    for _ in xrange(3):
                        inp, out = prog_ios_dict[prog].pop()
                        ilist.append(inp + out)
                    dset[task][bin_for(plen)].append([ilist, [ptoks]])

    arithmetic_tasks = ["add", "badd", "qadd", "bmul", "mul"]
    pair_generators = {
        "dup": rand_dup_pair,
        "rev2": rand_rev2_pair,
        "search": rand_search_pair,
        "kvsort": rand_kvsort_pair,
    }

    # Program tasks were fully handled above, so the loop is skipped for them.
    for case in xrange(0 if is_prog else nbr_cases):
        total_time += time.time() - cur_time
        cur_time = time.time()
        if l > 10000 and case % 100 == 1:
            print_out(" avg gen time %.4f s" % (total_time / float(case)))
        # One training example first, then one test example, per case.
        for dset in (train_set, test_set):
            if task in arithmetic_tasks:
                i, t = rand_pair(l, task)
                dset[task][bin_for(len(i))].append([[[], i, [], []], [t]])
            elif task in pair_generators:
                i, t = pair_generators[task](l)
                dset[task][bin_for(len(i))].append([[i], [t]])
            else:
                inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
                target = spec(inp)
                dset[task][bin_for(l)].append([[inp], [target]])
|
|
|
|
|
def to_symbol(i):
    """Convert an id to its text symbol.

    0 maps to the empty string, 11 to "+", 12 to "*"; any other id ``i`` maps
    to the decimal text of ``i - 1``.
    """
    for special_id, symbol in ((0, ""), (11, "+"), (12, "*")):
        if i == special_id:
            return symbol
    return str(i-1)
|
|
|
|
|
def to_id(s):
    """Convert a text symbol to its id.

    "+" maps to 11 and "*" to 12; anything else is parsed with ``int`` and
    shifted up by one.
    """
    for symbol, special_id in (("+", 11), ("*", 12)):
        if s == symbol:
            return special_id
    return int(s) + 1
|
|
|
|
|
def get_batch(bin_id, batch_size, data_set, height, offset=None, preset=None):
    """Assemble a zero-padded batch of inputs and targets.

    Args:
      bin_id: index into ``bins`` (padded length) and into ``data_set``.
      batch_size: number of examples in the batch.
      data_set: per-bin lists of ``[input_rows, target_rows]`` examples.
      height: number of input rows per example; an example with a single input
        row is filled up with all-zero rows to this height.
      offset: when given, example ``offset + b`` of the bin is used for batch
        slot ``b`` as long as that index exists; otherwise a random example
        is used.
      preset: when given, this example fills every batch slot.

    Returns:
      A pair of int32 arrays: inputs of shape [batch_size, height, pad_length]
      and targets of shape [batch_size, 1, pad_length].
    """
    pad_length = bins[bin_id]

    def zero_padded(seq):
        # Extend seq with zeros up to pad_length (no-op if already that long).
        return seq + [0] * (pad_length - len(seq))

    inputs, targets = [], []
    for b in xrange(batch_size):
        if preset is not None:
            elem = preset
        else:
            # The random draw happens even when offset ends up selecting the
            # example, which keeps the random stream unchanged.
            elem = random.choice(data_set[bin_id])
            if offset is not None and offset + b < len(data_set[bin_id]):
                elem = data_set[bin_id][offset + b]
        input_rows, target_rows = elem[0], elem[1]
        padded_inputs = [zero_padded(row) for row in input_rows]
        if len(padded_inputs) == 1:
            padded_inputs.extend(
                [[0] * pad_length for _ in xrange(height - 1)])
        padded_targets = [zero_padded(row) for row in target_rows]
        inputs.append(padded_inputs)
        targets.append(padded_targets)

    res_input = np.array(inputs, dtype=np.int32)
    res_target = np.array(targets, dtype=np.int32)
    assert list(res_input.shape) == [batch_size, height, pad_length]
    assert list(res_target.shape) == [batch_size, 1, pad_length]
    return res_input, res_target
|
|
|
|
|
def print_out(s, newline=True):
    """Print a message to stdout and append it to the log file.

    File logging is best effort: it only happens when ``log_filename`` is
    non-empty, and a failure to append is reported on stderr without
    interrupting the caller. The message always goes to stdout, which is
    flushed afterwards.

    Args:
      s: message text.
      newline: whether to terminate the message with a newline.
    """
    message = s + ("\n" if newline else "")
    if log_filename:
        try:
            with tf.gfile.GFile(log_filename, mode="a") as f:
                f.write(message)
        # Ordinary errors are tolerated; KeyboardInterrupt and SystemExit are
        # not Exception subclasses and therefore still propagate.
        except Exception:  # pylint: disable=broad-except
            sys.stderr.write("Error appending to %s\n" % log_filename)
    sys.stdout.write(message)
    sys.stdout.flush()
|
|
|
|
|
def decode(output):
    """Return, for each element of ``output``, its argmax along axis 1."""
    decoded = []
    for scores in output:
        decoded.append(np.argmax(scores, axis=1))
    return decoded
|
|
|
|
|
def accuracy(inpt_t, output, target_t, batch_size, nprint,
             beam_out=None, beam_scores=None):
    """Count prediction errors against the target and print some examples.

    Only positions whose target symbol is greater than 0 are scored.

    Args:
      inpt_t: 3-D input array, indexed as [batch, row, position].
      output: per-position scores; ``decode`` turns each entry into the
        predicted symbol for every batch element.
      target_t: 3-D target array; only row 0 is used.
      batch_size: number of examples in the batch.
      nprint: number of examples to print; must not exceed ``batch_size``.
      beam_out: optional array indexed as [batch, 0, position]; when given, it
        overrides the decoded output of every batch element whose beam score
        is at least 10.0.
      beam_scores: per-batch-element scores, needed when ``beam_out`` is given.

    Returns:
      A triple (number of wrong symbols, number of scored symbols, number of
      batch elements with at least one wrong symbol).
    """
    assert nprint < batch_size + 1

    # Per-position columns (each indexed by batch element), row after row.
    inpt = [inpt_t[:, row, pos]
            for row in xrange(inpt_t.shape[1])
            for pos in xrange(inpt_t.shape[2])]
    decoded_target = [target_t[:, 0, pos] for pos in xrange(target_t.shape[2])]

    def tok(i):
        # Use the vocabulary when there is one covering i, else the raw digit.
        in_vocab = rev_vocab and i < len(rev_vocab)
        return rev_vocab[i] if in_vocab else str(i - 1)

    def task_print(inp, output, target):
        # Print up to (not including) the first non-positive target symbol.
        print_len = 0
        for symbol in target:
            if not symbol > 0:
                break
            print_len += 1
        print_out(" i: " + " ".join(tok(i) for i in inp if i > 0))
        print_out(" o: " + " ".join(tok(o) for o in output[:print_len]))
        print_out(" t: " + " ".join(tok(t) for t in target[:print_len]))

    def show_example(b):
        # Gather batch element b across all positions and print it.
        task_print([column[b] for column in inpt],
                   [decoded_output[pos][b]
                    for pos in xrange(len(decoded_target))],
                   [column[b] for column in decoded_target])

    decoded_output = decode(output)

    # Use beam output if given and score is high enough.
    if beam_out is not None:
        for b in xrange(batch_size):
            if beam_scores[b] >= 10.0:
                beam_len = min(len(decoded_output), beam_out.shape[2])
                for pos in xrange(beam_len):
                    decoded_output[pos][b] = int(beam_out[b, 0, pos])

    total = 0
    errors = 0
    seq = [0] * batch_size  # 1 marks a batch element with any wrong symbol.
    for pos in xrange(len(decoded_output)):
        for b in xrange(batch_size):
            if not decoded_target[pos][b] > 0:
                continue
            total += 1
            if decoded_output[pos][b] != decoded_target[pos][b]:
                seq[b] = 1
                errors += 1

    # Print the first nprint erroneous examples...
    failing = [b for b, flagged in enumerate(seq) if flagged]
    for b in failing[:max(nprint, 0)]:
        show_example(b)
    # ...and, when nprint exceeds the error count, the leading batch elements.
    for b in xrange(nprint - errors):
        show_example(b)

    return errors, total, sum(seq)
|
|
|
|
|
def safe_exp(x):
    """Return ``exp(x)`` capped at 10000.

    ``x`` is converted with ``float`` first. Anything not below 100 yields the
    cap directly, so ``math.exp`` is never called on a huge argument.
    """
    cap = 10000
    value = float(x)
    if not value < 100:
        return cap
    result = math.exp(value)
    return cap if result > cap else result
|
|