raven / raven_utils /decode.py
Jakub Kwiatkowski
Refactor.
38f87b5
raw
history blame
3.84 kB
import numpy as np
import raven_utils as rv
import tensorflow as tf
def np_split(ary, indices_or_sections, axis=-1):
return np.split(ary, np.cumsum(indices_or_sections), axis)[:-1]
def lw(data, none_empty=True, convert_tuple=True):
if isinstance(data, list):
return data
elif isinstance(data, tuple) and convert_tuple:
return list(data)
if none_empty and data is None:
return []
return [data]
def ibin(x):
return tf.cast(bin(x), dtype=tf.int32)
def output(x, split_fn=np_split, predict_fn_1=np.argmax, predict_fn_2=ibin):
res = output_divide(x, split_fn=split_fn)
res = output_predict(res, predict_fn_1=predict_fn_1, predict_fn_2=predict_fn_2)
return (res[0], res[1]) + tuple(output_properties(res[2], predict_fn=predict_fn_1))
def output_divide(output, split_fn=np_split):
group_output = output[..., rv.output.GROUP_SLICE_END]
slot_output = output[..., rv.output.SLOT_SLICE_END]
properties_output = output[..., rv.output.PROPERTIES_SLICE_END]
properties_output_splited = split_fn(properties_output, list(rv.properties.INDEX.values()), axis=-1)
return group_output, slot_output, properties_output_splited
def output_predict(output, predict_fn_1=np.argmax, predict_fn_2=ibin):
return predict_fn_1(output[0]), predict_fn_2(output[1]), output[2]
def output_properties(x, predict_fn=np.argmax):
out_reshaped = []
for i, out in enumerate(x):
shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
out_reshaped.append(predict_fn(out.reshape(shape)))
return out_reshaped
def output_result(output, split_fn=np_split, arg_max=np.argmax):
result = output_properties(output, predict_fn=split_fn)
res = []
for i, r in enumerate(result):
if i == 1:
res.append(r)
else:
res.append(arg_max(r, axis=-1))
return tuple(res)
def decode_inference(inference, reshape=np.reshape):
return reshape(inference[rv.inference.SLOT_SLICE],
[-1, rv.group.NO, rv.inference.PROPERTY_TRANSFORMATION_NO]), reshape(
inference[rv.inference.PROPERTIES_SLICE],
[-1, rv.properties.NO, rv.entity.SUM, rv.inference.PROPERTY_TRANSFORMATION_NO])
def decode_target(target):
target_group = target[..., 0]
target_slot = target[..., 1:rv.target.INDEX[0]]
target_properties = target[..., rv.target.INDEX[0]:rv.target.END_INDEX]
target_properties_splited = [
target_properties[..., ::rv.properties.NO],
target_properties[..., 1::rv.properties.NO],
target_properties[..., 2::rv.properties.NO]
]
return target_group, target_slot, target_properties_splited
def decode_target_flat(target):
t = decode_target(target)
return t[0], t[1], t[2][0], t[2][1], t[2][2]
def demask(target, mask=None, group=None, zeroes=None):
if mask is None:
if group is None:
group = target[0]
# todo Use numpy range Mask
from models.uitls_ import RangeMask
mask = RangeMask()(group).numpy()
if zeroes is None:
return np.concatenate([t[mask] for t in lw(target[1:])])
return np.concatenate([target[0][None]] + [t * mask for t in lw(target[1:])], axis=-1)
def target_mask(mask, right=1):
shape = mask.shape
return np.concatenate([np.ones([shape[0], 1]), mask, np.repeat(mask, 3, axis=1), np.ones([shape[0], right])],
axis=1)
def get_full_range_mask(mask):
return np.concatenate([mask, np.repeat(mask, 3, axis=-1)], axis=-1)
def compare(target, predict, mask):
target_comp = target[:, 1:rv.target.END_INDEX]
predict_comp = predict[:, 1:rv.target.END_INDEX]
mask = get_full_range_mask(mask)
target_masked = target_comp * mask
predict_masked = predict_comp * mask
return target_masked == predict_masked