"""Helpers for loading examples, loading a model and rendering its predictions."""
import ast

import numpy as np
from funcy import identity
from tensorflow.keras.models import load_model

import raven_utils as rv
from raven_utils.constant import PROPERTY
from raven_utils.decode import target_mask
# NOTE: this name is shadowed by the local ``render_from_model`` defined below.
# The import is kept because it also loads the ``raven_utils.draw`` submodule
# that the ``rv.draw.*`` calls rely on.
from raven_utils.draw import render_from_model
from raven_utils.image import draw_images
from raven_utils.render.rendering import render_panels
from raven_utils.tools import filter_keys, is_model, il
from raven_utils.uitls import get_matrix

import models

# Checkpoint loaded when ``load_model_`` is asked for the "Transformer" model.
_TRANSFORMER_PATH = "/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09"


def _parse_index(index):
    """Parse *index* into a Python literal.

    The value is stringified and fed to ``ast.literal_eval`` so that both real
    Python objects and their textual form (e.g. ``"5"`` or ``"[1, 2]"``) are
    accepted.  A blank string is not a valid literal and would raise
    ``SyntaxError``; it is mapped to ``None`` so callers fall back to their
    default index instead.
    """
    text = str(index).strip()
    if not text:
        return None
    return ast.literal_eval(text)


def render_from_model(data, predict, pre_fn=identity):
    """Render panels from a model's prediction.

    Args:
        data: Mapping of model inputs.  It is passed through
            ``filter_keys(data, PROPERTY, reverse=True)`` before use.
        predict: Either a model (called with the filtered *data*) or an
            already computed prediction.  The prediction must provide
            ``'predict_mask'`` and ``'predict'`` entries exposing ``.numpy()``.
        pre_fn: Callable applied to the rendered panels after a leading axis
            has been added; defaults to ``identity``.

    Returns:
        The first element of ``pre_fn``'s result.
    """
    data = filter_keys(data, PROPERTY, reverse=True)
    # ``str(type(obj))`` always looks like "<class '...'>", so the type name
    # has to be matched by suffix; an equality test against the bare suffix
    # can never succeed.
    if is_model(predict) or str(type(predict)).endswith("._UserObject'>"):
        predict = predict(data)
    masked = target_mask(predict['predict_mask'].numpy()) * predict["predict"].numpy()
    pro = np.array(masked, dtype=np.int8)
    return pre_fn(render_panels(pro, target=False)[None])[0]


def load_example(index=0):
    """Build the image and description for an example.

    Args:
        index: Either an integer position into ``models.data`` or a value for
            which ``il`` is true (presumably a list describing a custom
            matrix -- TODO confirm).  Textual forms of both are accepted.
            Falsy values select example ``0``.

    Returns:
        Tuple ``(image, description)``; the image is the drawn example tiled
        three times along its last axis.
    """
    index = _parse_index(index)
    if il(index):
        example = rv.draw.render_panels(np.array(index))
        desc = "Custom matrix"
    else:
        if not index:
            index = 0
        index = int(index)
        desc = models.properties[index]['Description']
        # Slice (rather than index) so the single-row batch is preserved.
        row = models.data[index:index + 1]
        example = get_matrix(
            np.array(row['inputs'], dtype="uint8"),
            np.array(row['index'], dtype="uint8")[..., None],
        )
    result = np.tile(draw_images(example[:9], row=3), reps=(1, 1, 3))
    return result, desc


def load_model_(name):
    """Load a model into ``models.model`` and return a status message.

    Args:
        name: ``"Transformer"`` to load the bundled checkpoint path, otherwise
            a path handed directly to ``load_model``.

    Returns:
        A ``"Success loading: <name>"`` message.
    """
    path = _TRANSFORMER_PATH if name == "Transformer" else name
    models.model = load_model(path)
    return f"Success loading: {name}"


def run_nn(index=0):
    """Run ``models.model`` on an example and return the rendered prediction.

    Args:
        index: Same accepted forms as in ``load_example``.  Falsy values
            select ``models.START_IMAGE``.

    Returns:
        The rendered prediction tiled three times along a new last axis.
    """
    index = _parse_index(index)
    if il(index):
        data = rv.draw.render_panels(np.array(index))
        # Append the first seven panels again and add a leading axis.
        data = np.concatenate([data, data[:7]])[None]
    else:
        if not index:
            index = models.START_IMAGE
        index = int(index)
        data = models.data[index:index + 1]['inputs']
    # 'index' and 'target' are zero-filled placeholders
    # (presumably required by the model's input signature -- TODO confirm).
    data = {
        'inputs': data,
        'index': np.zeros(shape=(1, 1), dtype="uint8"),
        'target': np.zeros(shape=(1, 16, 113), dtype="int8"),
    }
    rendered = render_from_model(data, models.model)
    return np.tile(rendered[0, ..., None], reps=(1, 1, 3))


def _step_example(index, step):
    """Move *step* positions from *index* and load the resulting example."""
    index = _parse_index(index)
    if not isinstance(index, int):
        # Anything that is not a plain integer restarts from the default image.
        index = models.START_IMAGE
    index = int(index) + step
    return (index,) + load_example(index)


def next_(index=0):
    """Return ``(new_index, image, description)`` for the example after *index*."""
    return _step_example(index, 1)


def prev_(index=0):
    """Return ``(new_index, image, description)`` for the example before *index*."""
    return _step_example(index, -1)


if __name__ == '__main__':
    # ``run_nn`` parses its argument as an index literal, so it must be given
    # the example index rather than the rendered image array.
    load_example(5)
    run_nn(5)