import numpy as np
import raven_utils as rv
import tensorflow as tf


def np_split(ary, indices_or_sections, axis=-1):
    """Split ``ary`` into consecutive chunks of the given sizes.

    Unlike ``np.split``, ``indices_or_sections`` is a sequence of chunk
    *lengths*, not split points. The lengths are converted to split points
    with a cumulative sum. The trailing remainder after the last chunk is
    dropped; it is empty when the lengths cover the whole axis.

    Args:
        ary: Array to split.
        indices_or_sections: Sequence of chunk lengths along ``axis``.
        axis: Axis to split along (default: last).

    Returns:
        List with one sub-array per entry of ``indices_or_sections``.
    """
    return np.split(ary, np.cumsum(indices_or_sections), axis)[:-1]


def lw(data, none_empty=True, convert_tuple=True):
    """Return ``data`` as a list.

    Args:
        data: Value to wrap.
        none_empty: If true, ``None`` becomes ``[]`` instead of ``[None]``.
        convert_tuple: If true, a tuple is converted to a list of its items
            instead of being wrapped as a single element.

    Returns:
        ``data`` itself if it is already a list, otherwise a new list.
    """
    if isinstance(data, list):
        return data
    if isinstance(data, tuple) and convert_tuple:
        return list(data)
    if none_empty and data is None:
        return []
    return [data]


def ibin(x, threshold=0.5):
    """Binarize ``x`` element-wise into an int32 tensor of 0/1 values.

    Args:
        x: Array or tensor of scores.
        threshold: Entries strictly greater than this become 1, all others 0.
            NOTE(review): 0.5 assumes ``x`` holds sigmoid probabilities. If
            the slot head emits raw logits, this should be 0. Confirm against
            the model.

    Returns:
        ``tf.Tensor`` of dtype int32 with the same shape as ``x``.
    """
    # The previous body used the builtin ``bin`` (int -> '0b...' string).
    # That raises for arrays, and a string cannot be cast to int32, so it
    # could never work. A comparison is the element-wise binarization the
    # name describes.
    return tf.cast(x > threshold, dtype=tf.int32)


def output(x, split_fn=np_split, predict_fn_1=np.argmax, predict_fn_2=ibin):
    """Decode a raw model output into group, slot and per-property predictions.

    Args:
        x: Model output. Its last axis holds the sections selected by
            ``output_divide``.
        split_fn: Cuts the property section into one chunk per property
            (see ``np_split``).
        predict_fn_1: Applied to the group section and to every reshaped
            property chunk.
        predict_fn_2: Applied to the slot section.

    Returns:
        Tuple ``(group, slot, *properties)`` with one property entry per
        chunk produced by ``split_fn``.
    """
    divided = output_divide(x, split_fn=split_fn)
    group, slot, properties = output_predict(
        divided, predict_fn_1=predict_fn_1, predict_fn_2=predict_fn_2
    )
    return (group, slot) + tuple(
        output_properties(properties, predict_fn=predict_fn_1)
    )


def output_divide(output, split_fn=np_split):
    """Slice a model output into its group, slot and property sections.

    The sections are taken along the last axis using the ``rv.output``
    constants. The property section is then cut with ``split_fn`` using
    ``rv.properties.INDEX.values()``.

    Returns:
        ``(group_output, slot_output, properties_output_splited)``. The last
        item is whatever ``split_fn`` returns (a list for ``np_split``).
    """
    group_output = output[..., rv.output.GROUP_SLICE_END]
    slot_output = output[..., rv.output.SLOT_SLICE_END]
    properties_output = output[..., rv.output.PROPERTIES_SLICE_END]
    # With np_split these values are interpreted as chunk lengths.
    properties_output_splited = split_fn(
        properties_output, list(rv.properties.INDEX.values()), axis=-1
    )
    return group_output, slot_output, properties_output_splited


def output_predict(output, predict_fn_1=np.argmax, predict_fn_2=ibin):
    """Apply the prediction functions to the group and slot sections.

    Args:
        output: Indexable whose items 0, 1 and 2 are the group, slot and
            property sections, as returned by ``output_divide``.
        predict_fn_1: Applied to ``output[0]``.
        predict_fn_2: Applied to ``output[1]``.

    Returns:
        ``(predict_fn_1(output[0]), predict_fn_2(output[1]), output[2])``.
        The property chunks are passed through unchanged.
    """
    # NOTE(review): ``np.argmax`` without ``axis`` flattens its input and
    # returns one index for the whole array, not one per sample. Batch
    # callers likely want ``functools.partial(np.argmax, axis=-1)`` (or
    # ``tf.argmax(..., axis=-1)``). The default is kept so callers are
    # unaffected.
    group, slot, properties = output[0], output[1], output[2]
    return predict_fn_1(group), predict_fn_2(slot), properties


def output_properties(x, predict_fn=np.argmax):
    """Reshape each property chunk and apply ``predict_fn`` to it.

    Chunk ``i`` is reshaped to
    ``(-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])`` with its ``.reshape``
    method, so chunks must support that method (e.g. NumPy arrays).

    NOTE(review): the default ``np.argmax`` (no ``axis``) reduces each chunk
    to a single flat index. The reshape suggests a per-entity argmax over the
    last axis was intended. Confirm with callers.

    Returns:
        List with one ``predict_fn`` result per chunk.
    """
    return [
        predict_fn(out.reshape((-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])))
        for i, out in enumerate(x)
    ]


def output_result(output, split_fn=np_split, arg_max=np.argmax):
    """Reduce per-property outputs with ``arg_max``, leaving index 1 untouched.

    Every item except the one at index 1 is reduced with
    ``arg_max(r, axis=-1)``. Item 1 is returned as-is.

    NOTE(review): ``split_fn`` is forwarded as ``output_properties``'s
    ``predict_fn``, which is called with a single argument. With the default
    ``np_split`` this raises ``TypeError`` (``indices_or_sections`` is
    missing). The intended decomposition cannot be determined from this
    file; confirm before relying on this function.

    Returns:
        Tuple with one entry per item produced by ``output_properties``.
    """
    result = output_properties(output, predict_fn=split_fn)
    return tuple(r if i == 1 else arg_max(r, axis=-1) for i, r in enumerate(result))


def decode_inference(inference, reshape=np.reshape):
    """Split an inference output into slot and property transformation arrays.

    NOTE(review): slices are applied to the first axis (``inference[SLICE]``),
    not to ``[..., SLICE]`` as the other decoders do. That is fine for a 1-D
    vector; confirm for batched input.

    Returns:
        Tuple of:
        - ``inference[rv.inference.SLOT_SLICE]`` reshaped to
          ``[-1, rv.group.NO, rv.inference.PROPERTY_TRANSFORMATION_NO]``;
        - ``inference[rv.inference.PROPERTIES_SLICE]`` reshaped to
          ``[-1, rv.properties.NO, rv.entity.SUM,
          rv.inference.PROPERTY_TRANSFORMATION_NO]``.
    """
    slots = reshape(
        inference[rv.inference.SLOT_SLICE],
        [-1, rv.group.NO, rv.inference.PROPERTY_TRANSFORMATION_NO],
    )
    properties = reshape(
        inference[rv.inference.PROPERTIES_SLICE],
        [
            -1,
            rv.properties.NO,
            rv.entity.SUM,
            rv.inference.PROPERTY_TRANSFORMATION_NO,
        ],
    )
    return slots, properties


def decode_target(target):
    """Split a target array into group, slot and property fields.

    Along the last axis:
    - column 0 is the group;
    - columns ``1:rv.target.INDEX[0]`` are the slots;
    - columns ``rv.target.INDEX[0]:rv.target.END_INDEX`` hold the properties
      interleaved with stride ``rv.properties.NO``.

    Returns:
        ``(target_group, target_slot, [prop_0, prop_1, prop_2])``, where
        ``prop_k`` is the strided view starting at offset ``k``.
        NOTE(review): exactly three property views are extracted regardless
        of ``rv.properties.NO``. This presumably assumes NO == 3.
    """
    target_group = target[..., 0]
    target_slot = target[..., 1:rv.target.INDEX[0]]
    target_properties = target[..., rv.target.INDEX[0]:rv.target.END_INDEX]
    stride = rv.properties.NO
    target_properties_splited = [
        target_properties[..., ::stride],
        target_properties[..., 1::stride],
        target_properties[..., 2::stride],
    ]
    return target_group, target_slot, target_properties_splited


def decode_target_flat(target):
    """Like ``decode_target`` but returns the three property views flattened.

    Returns:
        ``(group, slot, prop_0, prop_1, prop_2)``.
    """
    group, slot, properties = decode_target(target)
    return group, slot, properties[0], properties[1], properties[2]


def demask(target, mask=None, group=None, zeroes=None):
    """Apply an entity mask to every field of ``target`` after the first.

    Args:
        target: Sequence whose first item is the group field and whose
            remaining items are arrays to mask. Example: the tuple returned
            by ``decode_target_flat``.
        mask: Mask to apply. If ``None``, it is built with ``RangeMask`` from
            ``group``.
        group: Input for ``RangeMask`` when ``mask`` is ``None``. Defaults to
            ``target[0]``.
        zeroes: If ``None``, each field is indexed with ``t[mask]`` and the
            results are concatenated along axis 0. The group field is not
            included. (This is boolean selection if the mask is boolean,
            which is presumably what ``RangeMask`` yields; confirm.)
            Otherwise, each field is multiplied by ``mask`` and concatenated
            along the last axis after a leading group column.

    Returns:
        The concatenated NumPy array described above.
    """
    if mask is None:
        if group is None:
            group = target[0]
        # todo Use numpy range Mask
        # NOTE(review): module path kept verbatim. "uitls_" may be a typo
        # of "utils_", but it must match the real package name; verify.
        from models.uitls_ import RangeMask

        mask = RangeMask()(group).numpy()
    if zeroes is None:
        return np.concatenate([t[mask] for t in lw(target[1:])])
    # Turn the group into a single trailing column: (B,) -> (B, 1), and
    # scalar -> (1,). The former ``target[0][None]`` produced (1, B), which
    # cannot be concatenated with (B, k) fields on the last axis when B > 1.
    group_column = np.expand_dims(target[0], -1)
    return np.concatenate(
        [group_column] + [t * mask for t in lw(target[1:])], axis=-1
    )


def target_mask(mask, right=1):
    """Expand an entity mask to cover a full target row.

    Concatenates along axis 1:
    - one column of ones (the group column);
    - ``mask`` (the slots);
    - ``mask`` with every column repeated three times in place, which lines
      up with ``decode_target``'s interleaved property layout if
      ``rv.properties.NO == 3``;
    - ``right`` columns of ones.

    Args:
        mask: 2-D array; ``mask.shape[0]`` sets the number of rows.
        right: Number of trailing all-ones columns.
    """
    rows = mask.shape[0]
    return np.concatenate(
        [
            np.ones([rows, 1]),
            mask,
            np.repeat(mask, 3, axis=1),
            np.ones([rows, right]),
        ],
        axis=1,
    )


def get_full_range_mask(mask):
    """Return ``mask`` followed by ``mask`` with each column repeated 3 times.

    The concatenation is along the last axis. The result matches
    ``target_mask`` without its leading and trailing ones columns.
    """
    return np.concatenate([mask, np.repeat(mask, 3, axis=-1)], axis=-1)


def compare(target, predict, mask):
    """Compare target and prediction element-wise, ignoring masked entries.

    Columns ``1:rv.target.END_INDEX`` (everything after the group column) of
    both arrays are multiplied by the full-range mask from
    ``get_full_range_mask``. Masked-out positions therefore compare equal
    when the mask is 0 there.

    Returns:
        Boolean array ``target_masked == predict_masked``.
    """
    target_comp = target[:, 1:rv.target.END_INDEX]
    predict_comp = predict[:, 1:rv.target.END_INDEX]
    full_mask = get_full_range_mask(mask)
    target_masked = target_comp * full_mask
    predict_masked = predict_comp * full_mask
    return target_masked == predict_masked