File size: 3,228 Bytes
e986ee1
38f87b5
e986ee1
 
38f87b5
 
 
 
 
e986ee1
 
 
 
 
 
38f87b5
 
 
 
 
 
 
e986ee1
 
 
 
 
 
 
 
 
 
 
758ad23
e986ee1
758ad23
 
 
 
e986ee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758ad23
e986ee1
 
bf96524
 
 
 
 
fadd4cc
bf96524
e986ee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
from funcy import identity

import raven_utils as rv
from raven_utils.constant import PROPERTY
from raven_utils.decode import target_mask
from raven_utils.image import draw_images
from raven_utils.render.rendering import render_panels
from raven_utils.tools import filter_keys, is_model, il
from raven_utils.uitls import get_matrix
from tensorflow.keras.models import load_model
from raven_utils.draw import render_from_model
import models
import ast

def render_from_model(data, predict, pre_fn=identity):
    """Run the model (if needed) and render its predicted answer panels.

    NOTE: this definition shadows the ``render_from_model`` imported from
    ``raven_utils.draw`` at the top of the file.

    Args:
        data: Input mapping for the model. Keys in ``PROPERTY`` are passed to
            ``filter_keys(..., reverse=True)`` first (presumably removing
            them -- TODO confirm against raven_utils.tools).
        predict: Either a callable model (Keras model, object returned by
            ``tf.saved_model.load``, or any other callable), which is invoked
            on ``data``, or an already-computed output mapping. In both cases
            the output must provide ``'predict_mask'`` and ``'predict'``.
        pre_fn: Post-processing applied to the rendered panels after a leading
            batch axis is added; the batch axis is stripped from its result.

    Returns:
        ``pre_fn(render_panels(...)[None])[0]``.
    """

    def to_numpy(value):
        # TF eager tensors expose .numpy(); plain arrays/lists go through numpy.
        return value.numpy() if hasattr(value, "numpy") else np.asarray(value)

    data = filter_keys(data, PROPERTY, reverse=True)
    # Previously this compared str(type(predict)) with a TF-internal class path
    # (".<locals>._UserObject"), which changes across TF versions. Objects from
    # tf.saved_model.load get a class-level __call__, so callable() covers them,
    # while a precomputed output dict is not callable and is used as-is.
    if is_model(predict) or callable(predict):
        predict = predict(data)
    # Combine the prediction with target_mask(predict_mask) and cast to int8
    # for the renderer (presumably zeroing slots the mask excludes).
    pro = np.array(
        target_mask(to_numpy(predict['predict_mask'])) * to_numpy(predict["predict"]),
        dtype=np.int8,
    )
    return pre_fn(render_panels(pro, target=False)[None])[0]


def load_example(index=0):
    """Render the first nine panels of a puzzle for display.

    Args:
        index: A dataset index (an int or its string form), or a nested-list
            literal describing a custom matrix, which is rendered directly via
            ``rv.draw.render_panels``. A blank string or other falsy value
            falls back to example 0. Negative indices follow Python indexing.

    Returns:
        ``(image, description)``. ``image`` is ``draw_images(panels[:9], row=3)``
        tiled 3x along its last axis (presumably grayscale -> RGB; TODO confirm
        draw_images returns an (H, W, 1) array). ``description`` is
        ``"Custom matrix"`` or the example's ``'Description'`` property.
    """
    text = str(index).strip()
    # A blank textbox would make literal_eval raise SyntaxError; treat it as
    # unset so the "falsy -> example 0" fallback below applies.
    index = ast.literal_eval(text) if text else None
    if il(index):
        example = rv.draw.render_panels(np.array(index))
        desc = "Custom matrix"
    else:
        if not index:
            index = 0
        index = int(index)

        desc = models.properties[index]['Description']

        # For index == -1 (e.g. prev_ from 0), index + 1 == 0 and the slice
        # [-1:0] would be empty; use an open end so one row is always taken.
        end = index + 1 or None
        rows = models.data[index:end]
        example = get_matrix(
            np.array(rows['inputs'], dtype="uint8"),
            np.array(rows['index'], dtype="uint8")[..., None],
        )
    result = np.tile(draw_images(example[:9], row=3), reps=(1, 1, 3))
    return result, desc


def load_model_(name):
    """Load a Keras model into ``models.model`` and return a status message.

    ``"Transformer"`` is an alias for a fixed, machine-specific checkpoint
    directory; any other ``name`` is passed to ``load_model`` as the path.
    """
    is_alias = name == "Transformer"
    model_path = (
        "/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09"
        if is_alias
        else name
    )
    # Stored on the shared `models` module so the other callbacks can use it.
    models.model = load_model(model_path)
    return f"Success loading: {name}"


def run_nn(index=None):
    """Run ``models.model`` on one puzzle and return the rendered prediction.

    Args:
        index: A dataset index (an int or its string form), or a nested-list
            literal describing a custom matrix. ``None`` (the default) or a
            blank string falls back to ``models.START_IMAGE``. Index 0 selects
            example 0, matching ``load_example``/``next_``/``prev_``;
            previously ``if not index`` silently replaced 0 with START_IMAGE.
            Negative indices follow Python indexing.

    Returns:
        The first item of ``render_from_model``'s output with a channel axis
        added and tiled 3x along it.
    """
    text = str(index).strip()
    # A blank textbox would make literal_eval raise SyntaxError; treat as unset.
    index = ast.literal_eval(text) if text else None
    if il(index):
        panels = rv.draw.render_panels(np.array(index))
        # Append copies of the first 7 panels and add a batch axis. Presumably
        # this pads 9 panels up to the 16 the model expects (cf. the (1, 16, 113)
        # target below) -- TODO confirm.
        inputs = np.concatenate([panels, panels[:7]])[None]
    else:
        if index is None:
            index = models.START_IMAGE
        index = int(index)
        # For index == -1, [-1:0] would be empty; use an open end instead.
        end = index + 1 or None
        inputs = models.data[index:end]['inputs']

    # 'index' and 'target' are zero-filled; presumably placeholders required
    # by the model's input signature -- TODO confirm.
    batch = {
        'inputs': np.asarray(inputs, dtype="uint8"),
        'index': np.zeros(shape=(1, 1), dtype="uint8"),
        'target': np.zeros(shape=(1, 16, 113), dtype="int8"),
    }
    rendered = render_from_model(batch, models.model)
    return np.tile(rendered[0, ..., None], reps=(1, 1, 3))


def _shift_example(index, step):
    """Parse ``index``, move it by ``step``, and render the resulting example.

    Anything that does not parse to an ``int`` restarts from
    ``models.START_IMAGE``: ``None``, a blank textbox, a custom-matrix list
    literal, or malformed text.

    Returns:
        ``(new_index, image, description)``, where the last two come from
        ``load_example(new_index)``.
    """
    text = str(index).strip()
    try:
        parsed = ast.literal_eval(text) if text else None
    except (ValueError, TypeError, SyntaxError):
        # Malformed input is just another non-integer value; take the
        # START_IMAGE fallback below instead of crashing the callback.
        parsed = None
    if not isinstance(parsed, int):
        parsed = models.START_IMAGE
    new_index = int(parsed) + step
    return (new_index,) + load_example(new_index)


def next_(index=0):
    """Step to the next example; returns ``(index, image, description)``."""
    return _shift_example(index, 1)


def prev_(index=0):
    """Step to the previous example; returns ``(index, image, description)``."""
    return _shift_example(index, -1)


if __name__ == '__main__':
    # Smoke test: render example 5, then run the model on the same example.
    # run_nn takes an index (or list literal), not the rendered image. It
    # str()s and literal_evals its argument, and an ndarray's str() is not a
    # valid Python literal, so passing `image` always raised.
    image, _ = load_example(5)
    run_nn(5)