NCTCMumbai's picture
Upload 2583 files
18ddfe2 verified
raw
history blame
16.2 kB
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Neural GPU -- data generation and batching utilities."""
import math
import os
import random
import sys
import time
import numpy as np
from six.moves import xrange
import tensorflow as tf
import program_utils
# Handle to TensorFlow's global command-line flag values.
FLAGS = tf.app.flags.FLAGS

# Available padded sequence lengths ("bins"): 2, 3, ..., 257.
bins = list(xrange(2, 2 + 256))

# Every task name that gets a slot in train_set / test_set.
all_tasks = [
    "sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left", "right",
    "left-shift", "right-shift", "bmul", "mul", "dup", "badd", "qadd",
    "search", "progeval", "progsynth",
]

# Path of the log file; print_out only appends to it when it is non-empty,
# and the tmp-file helpers resolve their paths relative to its directory.
log_filename = ""

# Vocabulary tables; rev_vocab (when set) is used by accuracy() to render
# token ids. Presumably both are filled in by the caller -- TODO confirm.
vocab, rev_vocab = None, None
def pad(l):
  """Return the smallest bin size that is >= l, or the largest bin if none."""
  return next((size for size in bins if size >= l), bins[-1])
def bin_for(l):
  """Return the index of the first bin of size >= l (the last index if none)."""
  fitting = (idx for idx, size in enumerate(bins) if size >= l)
  return next(fitting, len(bins) - 1)
# Example storage: <set>[task][bin_index] is the list of examples whose
# length falls into that bin. 10000 empty slots are pre-allocated per task.
train_set = {}
test_set = {}
for some_task in all_tasks:
  train_set[some_task] = [[] for _ in xrange(10000)]
  test_set[some_task] = [[] for _ in xrange(10000)]
def read_tmp_file(name):
  """Load the stripped lines of `name`.txt from the log directory or above.

  The file is looked up in the directory of log_filename, then one and two
  levels above it; every miss and the final hit are reported via print_out.

  Args:
    name: file name without the ".txt" extension.

  Returns:
    A list with one whitespace-stripped string per line, or None when the
    file exists in none of the three locations.
  """
  log_dir = os.path.dirname(log_filename)
  for prefix in ("", "../", "../../"):
    fname = os.path.join(log_dir, prefix + name + ".txt")
    if tf.gfile.Exists(fname):
      break
    print_out("== not found file: " + fname)
  else:
    # All three candidate locations were missing.
    return None
  print_out("== found file: " + fname)
  with tf.gfile.GFile(fname, mode="r") as f:
    return [line.strip() for line in f]
def write_tmp_file(name, lines):
  """Write `lines` to `name`.txt in the log directory, one entry per line.

  Args:
    name: file name without the ".txt" extension.
    lines: iterable of strings; a newline is appended to each.
  """
  target = os.path.join(os.path.dirname(log_filename), name + ".txt")
  with tf.gfile.GFile(target, mode="w") as out:
    for entry in lines:
      out.write(entry + "\n")
def add(n1, n2, base=10):
  """Add two numbers represented as lower-endian digit lists.

  Args:
    n1: list of digits of the first number, least-significant digit first.
    n2: list of digits of the second number, same layout.
    base: numeric base of the digits (default 10).

  Returns:
    A new list with the digits of n1 + n2, least-significant digit first and
    without most-significant zeros; [0] when the sum is zero. The input
    lists are not modified.
  """
  # One extra position leaves room for a final carry.
  k = max(len(n1), len(n2)) + 1
  d1 = n1 + [0] * (k - len(n1))
  d2 = n2 + [0] * (k - len(n2))
  res = []
  carry = 0
  for a, b in zip(d1, d2):
    digit = a + b + carry
    if digit < base:
      carry = 0
    else:
      digit -= base
      carry = 1
    res.append(digit)
  # Strip most-significant zeros in place (pop is O(1); re-slicing the list
  # on every iteration would make this quadratic).
  while res and res[-1] == 0:
    res.pop()
  return res or [0]
def init_data(task, length, nbr_cases, nclass):
  """Generate examples for `task` and append them to train_set / test_set.

  Args:
    task: task name; one of the entries of all_tasks.
    length: length parameter `l` handed to the per-task generators.
    nbr_cases: number of cases to generate. For the non-program tasks one
      training and one test example are produced per case.
    nclass: number of symbol classes; random symbols are drawn from
      1 .. nclass - 1 (0 is what get_batch pads with).

  Side effects:
    Appends to the module-level train_set / test_set. For the program tasks
    it also reads and writes cache files through read_tmp_file /
    write_tmp_file. Calls sys.exit() when the task has no generator.
  """

  def rand_pair(l, task):
    """Random data pair for a task. Total length should be <= l."""
    k = int((l-1)/2)
    base = 10
    if task[0] == "b": base = 2
    if task[0] == "q": base = 4
    d1 = [np.random.randint(base) for _ in xrange(k)]
    d2 = [np.random.randint(base) for _ in xrange(k)]
    if task in ["add", "badd", "qadd"]:
      res = add(d1, d2, base)
    elif task in ["mul", "bmul"]:
      d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
      d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
      if task == "bmul":
        # bin() yields "0b..."; after reversing, the prefix is the last two
        # characters, which [:-2] removes.
        res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
      else:
        res = [int(x) for x in list(reversed(str(d1n * d2n)))]
    else:
      sys.exit()
    # Digits are stored shifted by one; 11 and 12 are the operator symbols
    # (see to_symbol: 11 -> "+", 12 -> "*").
    sep = [12]
    if task in ["add", "badd", "qadd"]: sep = [11]
    inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
    return inp, [r + 1 for r in res]

  def rand_dup_pair(l):
    """Random data pair for duplication task. Total length should be <= l."""
    k = int(l/2)
    x = [np.random.randint(nclass - 1) + 1 for _ in xrange(k)]
    inp = x + [0 for _ in xrange(l - k)]
    res = x + x + [0 for _ in xrange(l - 2*k)]
    return inp, res

  def rand_rev2_pair(l):
    """Random data pair for reverse2 task. Total length should be <= l."""
    # Floor division keeps the pair count an int under Python 3 as well.
    inp = [(np.random.randint(nclass - 1) + 1,
            np.random.randint(nclass - 1) + 1) for _ in xrange(l // 2)]
    res = [i for i in reversed(inp)]
    return [x for p in inp for x in p], [x for p in res for x in p]

  def rand_search_pair(l):
    """Random data pair for search task. Total length should be <= l."""
    # (l - 1) // 2 key/value pairs plus the query symbol give at most l
    # input symbols.
    inp = [(np.random.randint(nclass - 1) + 1,
            np.random.randint(nclass - 1) + 1) for _ in xrange((l - 1) // 2)]
    q = np.random.randint(nclass - 1) + 1
    res = 0
    # Iterating in reverse makes the first matching key win.
    for (k, v) in reversed(inp):
      if k == q:
        res = v
    return [x for p in inp for x in p] + [q], [res]

  def rand_kvsort_pair(l):
    """Random data pair for key-value sort. Total length should be <= l."""
    keys = [(np.random.randint(nclass - 1) + 1, i) for i in xrange(l // 2)]
    vals = [np.random.randint(nclass - 1) + 1 for _ in xrange(l // 2)]
    kv = [(k, vals[i]) for (k, i) in keys]
    sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
    return [x for p in kv for x in p], [x for p in sorted_kv for x in p]

  def prog_io_pair(prog, max_len, counter=0):
    """Sample an input for `prog` and return (input tokens, output tokens).

    Retries recursively on ValueError; after more than 400 retries the
    output tokens are forced to [] so that the recursion terminates.
    """
    try:
      ilen = np.random.randint(max_len - 3) + 1
      # The value range shrinks as retries accumulate; floor division keeps
      # the bound an int so it can be passed to range().
      bound = max(15 - (counter // 20), 1)
      inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
      inp_toks = [program_utils.prog_rev_vocab[t]
                  for t in program_utils.tokenize(str(inp)) if t != ","]
      out = program_utils.evaluate(prog, {"a": inp})
      out_toks = [program_utils.prog_rev_vocab[t]
                  for t in program_utils.tokenize(str(out)) if t != ","]
      if counter > 400:
        out_toks = []
      if (out_toks and out_toks[0] == program_utils.prog_rev_vocab["["] and
          len(out_toks) != len([o for o in out if o == ","]) + 3):
        raise ValueError("generated list with too long ints")
      if (out_toks and out_toks[0] != program_utils.prog_rev_vocab["["] and
          len(out_toks) > 1):
        raise ValueError("generated one int but tokenized it to many")
      if len(out_toks) > max_len:
        raise ValueError("output too long")
      return (inp_toks, out_toks)
    except ValueError:
      return prog_io_pair(prog, max_len, counter+1)

  def spec(inp):
    """Return the target given the input for some tasks."""
    if task == "sort":
      return sorted(inp)
    elif task == "id":
      return inp
    elif task == "rev":
      return [i for i in reversed(inp)]
    elif task == "incr":
      carry = 1
      res = []
      for symbol in inp:
        if symbol + carry < nclass:
          res.append(symbol + carry)
          carry = 0
        else:
          res.append(1)
          carry = 1
      return res
    elif task == "left":
      return [inp[0]]
    elif task == "right":
      return [inp[-1]]
    elif task == "left-shift":
      # Index -1 wraps around, so this is a cyclic shift.
      return [inp[idx-1] for idx in xrange(len(inp))]
    elif task == "right-shift":
      # Wrap explicitly: idx + 1 runs past the end at the last position.
      n = len(inp)
      return [inp[(idx + 1) % n] for idx in xrange(n)]
    else:
      print_out("Unknown spec for task " + str(task))
      sys.exit()

  def add_train_and_test(make_example, bin_len=None):
    """Generate one train and one test example (in that order) and store them.

    make_example returns (key_sequence, stored_example). The example goes
    into the bin for bin_len if given, else for len(key_sequence).
    """
    for dset in (train_set, test_set):
      seq, example = make_example()
      size = len(seq) if bin_len is None else bin_len
      dset[task][bin_for(size)].append(example)

  def arithmetic_example():
    i, t = rand_pair(l, task)
    return i, [[[], i, [], []], [t]]

  def simple_example(generator):
    def make():
      i, t = generator(l)
      return i, [[i], [t]]
    return make

  def spec_example():
    inp = [np.random.randint(nclass - 1) + 1 for _ in xrange(l)]
    return inp, [[inp], [spec(inp)]]

  l = length
  cur_time = time.time()
  total_time = 0.0

  is_prog = task in ["progeval", "progsynth"]
  if is_prog:
    inputs_per_prog = 5
    # Floor division so the value stays an int under Python 3.
    prog_len = l // 10
    program_utils.make_vocab()
    progs = read_tmp_file("programs_len%d" % prog_len)
    if not progs:
      progs = program_utils.gen(prog_len, 1.2 * nbr_cases / inputs_per_prog)
      write_tmp_file("programs_len%d" % prog_len, progs)
    prog_ios = read_tmp_file("programs_len%d_io" % prog_len)
    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
    if not prog_ios:
      # Generate program io data.
      prog_ios = []
      for pidx, prog in enumerate(progs):
        if pidx % 500 == 0:
          print_out("== generating io pairs for program %d" % pidx)
        if pidx * inputs_per_prog > nbr_cases * 1.2:
          break
        ptoks = [program_utils.prog_rev_vocab[t]
                 for t in program_utils.tokenize(prog)]
        ptoks.append(program_utils.prog_rev_vocab["_EOS"])
        plen = len(ptoks)
        for _ in xrange(inputs_per_prog):
          if task == "progeval":
            inp, out = prog_io_pair(prog, plen)
            prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
          elif task == "progsynth":
            plen = max(len(ptoks), 8)
            for _ in xrange(3):
              inp, out = prog_io_pair(prog, plen // 2)
              prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
      write_tmp_file("programs_len%d_io" % prog_len, prog_ios)
    # Parse the "<input>\t<output>\t<program>" records, grouped by program.
    prog_ios_dict = {}
    for s in prog_ios:
      i, o, p = s.split("\t")
      i_clean = "".join([c for c in i if c.isdigit() or c == " "])
      o_clean = "".join([c for c in o if c.isdigit() or c == " "])
      inp = [int(x) for x in i_clean.split()]
      out = [int(x) for x in o_clean.split()]
      if inp and out:
        prog_ios_dict.setdefault(p, []).append([inp, out])
    # Use prog_ios_dict to create data.
    progs = []
    for prog in prog_ios_dict:
      if len([c for c in prog if c == ";"]) <= prog_len:
        progs.append(prog)
    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
    print_out("== %d training cases on %d progs" % (nbr_cases, len(progs)))
    for pidx, prog in enumerate(progs):
      if pidx * inputs_per_prog > nbr_cases * 1.2:
        break
      ptoks = [program_utils.prog_rev_vocab[t]
               for t in program_utils.tokenize(prog)]
      ptoks.append(program_utils.prog_rev_vocab["_EOS"])
      plen = len(ptoks)
      # The first nbr_cases / inputs_per_prog programs feed the training
      # set, the remaining ones the test set.
      dset = train_set if pidx < nbr_cases / inputs_per_prog else test_set
      for _ in xrange(inputs_per_prog):
        if task == "progeval":
          inp, out = prog_ios_dict[prog].pop()
          dset[task][bin_for(plen)].append([[ptoks, inp, [], []], [out]])
        elif task == "progsynth":
          plen, ilist = max(len(ptoks), 8), [[]]
          for _ in xrange(3):
            inp, out = prog_ios_dict[prog].pop()
            ilist.append(inp + out)
          dset[task][bin_for(plen)].append([ilist, [ptoks]])

  # Program tasks were fully handled above, so this loop is skipped for them.
  for case in xrange(0 if is_prog else nbr_cases):
    total_time += time.time() - cur_time
    cur_time = time.time()
    if l > 10000 and case % 100 == 1:
      print_out(" avg gen time %.4f s" % (total_time / float(case)))
    if task in ["add", "badd", "qadd", "bmul", "mul"]:
      add_train_and_test(arithmetic_example)
    elif task == "dup":
      add_train_and_test(simple_example(rand_dup_pair))
    elif task == "rev2":
      add_train_and_test(simple_example(rand_rev2_pair))
    elif task == "search":
      add_train_and_test(simple_example(rand_search_pair))
    elif task == "kvsort":
      add_train_and_test(simple_example(rand_kvsort_pair))
    elif task not in ["progeval", "progsynth"]:
      # Remaining tasks derive their target from spec() and are binned by l.
      add_train_and_test(spec_example, bin_len=l)
def to_symbol(i):
  """Convert a token id to text: 0 -> "", 11 -> "+", 12 -> "*", else i - 1."""
  for token_id, text in ((0, ""), (11, "+"), (12, "*")):
    if i == token_id:
      return text
  return str(i - 1)
def to_id(s):
  """Convert text to a token id: "+" -> 11, "*" -> 12, else int(s) + 1."""
  for text, token_id in (("+", 11), ("*", 12)):
    if s == text:
      return token_id
  return int(s) + 1
def get_batch(bin_id, batch_size, data_set, height, offset=None, preset=None):
  """Assemble one zero-padded batch, for training or testing.

  Args:
    bin_id: index into bins (and into data_set) selecting the padded length.
    batch_size: number of elements in the batch.
    data_set: per-bin lists of [inputs, targets] elements.
    height: required number of input rows per element; an element with a
      single input row is filled up with all-zero rows.
    offset: if given, element b is data_set[bin_id][offset + b] whenever that
      index exists; otherwise a randomly chosen element is used.
    preset: if given, this element is used for every batch position.

  Returns:
    (inputs, targets) as int32 arrays of shape [batch_size, height, pad_length]
    and [batch_size, 1, pad_length]; the shapes are checked with asserts.
  """
  pad_length = bins[bin_id]

  def padded(seq):
    # Right-pad with zeros up to the bin length.
    return seq + [0] * (pad_length - len(seq))

  inputs, targets = [], []
  for b in xrange(batch_size):
    if preset is not None:
      elem = preset
    else:
      # A random element is drawn even when `offset` ends up overriding it.
      elem = random.choice(data_set[bin_id])
      if offset is not None and offset + b < len(data_set[bin_id]):
        elem = data_set[bin_id][offset + b]
    rows = [padded(inp) for inp in elem[0]]
    if len(rows) == 1:
      rows.extend([0] * pad_length for _ in xrange(height - 1))
    inputs.append(rows)
    targets.append([padded(target) for target in elem[1]])
  res_input = np.array(inputs, dtype=np.int32)
  res_target = np.array(targets, dtype=np.int32)
  assert list(res_input.shape) == [batch_size, height, pad_length]
  assert list(res_target.shape) == [batch_size, 1, pad_length]
  return res_input, res_target
def print_out(s, newline=True):
  """Print a message to stdout and, when log_filename is set, log it to file.

  Args:
    s: the message string.
    newline: whether to terminate the message with a newline.

  Logging is best effort: a failure to append to the log file is reported on
  stderr and the message is still printed. KeyboardInterrupt and SystemExit
  are not swallowed.
  """
  text = s + ("\n" if newline else "")
  if log_filename:
    try:
      with tf.gfile.GFile(log_filename, mode="a") as f:
        f.write(text)
    except Exception:  # pylint: disable=broad-except
      sys.stderr.write("Error appending to %s\n" % log_filename)
  sys.stdout.write(text)
  sys.stdout.flush()
def decode(output):
  """Return, for every array in output, its argmax indices along axis 1."""
  decoded = []
  for scores in output:
    decoded.append(np.argmax(scores, axis=1))
  return decoded
def accuracy(inpt_t, output, target_t, batch_size, nprint,
             beam_out=None, beam_scores=None):
  """Calculate output accuracy given target.

  Args:
    inpt_t: inputs, indexed as inpt_t[:, h, l] (3-D array).
    output: per-position model outputs handed to decode(); presumably each
      entry is a (batch, classes) score array -- TODO confirm with caller.
    target_t: targets, indexed as target_t[:, 0, l] (3-D array).
    batch_size: number of batch elements to score.
    nprint: number of examples to print; must be <= batch_size.
    beam_out: optional beam-search output, indexed as beam_out[b, 0, l].
    beam_scores: per-batch-element beam scores; only read when beam_out is
      given.

  Returns:
    A tuple (errors, total, sequence_errors): the number of mismatching
    positions, the number of positions whose target is > 0, and the number
    of batch elements with at least one mismatch.
  """
  assert nprint < batch_size + 1
  # Flatten the inputs into one list of per-position batch columns, row by
  # row over the second axis.
  inpt = []
  for h in xrange(inpt_t.shape[1]):
    inpt.extend([inpt_t[:, h, l] for l in xrange(inpt_t.shape[2])])
  target = [target_t[:, 0, l] for l in xrange(target_t.shape[2])]
  def tok(i):
    # Render through rev_vocab when it is set and covers i; otherwise fall
    # back to the id shifted down by one.
    if rev_vocab and i < len(rev_vocab):
      return rev_vocab[i]
    return str(i - 1)
  def task_print(inp, output, target):
    # Output and target are printed up to the first target symbol that is
    # not > 0; of the input only the symbols > 0 are printed.
    stop_bound = 0
    print_len = 0
    while print_len < len(target) and target[print_len] > stop_bound:
      print_len += 1
    print_out(" i: " + " ".join([tok(i) for i in inp if i > 0]))
    print_out(" o: " +
              " ".join([tok(output[l]) for l in xrange(print_len)]))
    print_out(" t: " +
              " ".join([tok(target[l]) for l in xrange(print_len)]))
  decoded_target = target
  decoded_output = decode(output)
  # Use beam output if given and score is high enough.
  if beam_out is not None:
    for b in xrange(batch_size):
      if beam_scores[b] >= 10.0:
        for l in xrange(min(len(decoded_output), beam_out.shape[2])):
          decoded_output[l][b] = int(beam_out[b, 0, l])
  total = 0
  errors = 0
  # seq[b] becomes 1 once batch element b has at least one wrong position.
  seq = [0 for b in xrange(batch_size)]
  for l in xrange(len(decoded_output)):
    for b in xrange(batch_size):
      # Only positions with a target > 0 are counted.
      if decoded_target[l][b] > 0:
        total += 1
        if decoded_output[l][b] != decoded_target[l][b]:
          seq[b] = 1
          errors += 1
  e = 0  # Previous error index
  # Print up to nprint of the batch elements that contain an error.
  for _ in xrange(min(nprint, sum(seq))):
    while seq[e] == 0:
      e += 1
    task_print([inpt[l][e] for l in xrange(len(inpt))],
               [decoded_output[l][e] for l in xrange(len(decoded_target))],
               [decoded_target[l][e] for l in xrange(len(decoded_target))])
    e += 1
  # If the per-position error count is below nprint, additionally print the
  # first (nprint - errors) batch elements.
  for b in xrange(nprint - errors):
    task_print([inpt[l][b] for l in xrange(len(inpt))],
               [decoded_output[l][b] for l in xrange(len(decoded_target))],
               [decoded_target[l][b] for l in xrange(len(decoded_target))])
  return errors, total, sum(seq)
def safe_exp(x):
  """Return exp(x) capped at 10000, without overflowing for large x."""
  cap = 10000
  value = float(x)
  # Anything not strictly below 100 (including NaN) maps straight to the cap.
  if not value < 100:
    return cap
  result = math.exp(value)
  return cap if result > cap else result