Spaces:
Build error
Build error
Jakub Kwiatkowski
committed on
Commit
·
e986ee1
1
Parent(s):
9396266
Add model.
Browse files- main.py +49 -0
- models.py +13 -0
- raven_utils/__init__.py +10 -0
- raven_utils/config/__init__.py +0 -0
- raven_utils/config/constant.py +54 -0
- raven_utils/config/models.py +9 -0
- raven_utils/const.py +2 -0
- raven_utils/constant.py +53 -0
- raven_utils/data.py +46 -0
- raven_utils/decode.py +100 -0
- raven_utils/depricated/__init__.py +0 -0
- raven_utils/depricated/old_raven.py +490 -0
- raven_utils/draw.py +174 -0
- raven_utils/entity.py +6 -0
- raven_utils/group.py +11 -0
- raven_utils/inference.py +15 -0
- raven_utils/models/__init__.py +0 -0
- raven_utils/models/attn.py +187 -0
- raven_utils/models/attn2.py +187 -0
- raven_utils/models/augment.py +0 -0
- raven_utils/models/body.py +276 -0
- raven_utils/models/class_.py +31 -0
- raven_utils/models/head.py +159 -0
- raven_utils/models/loss.py +630 -0
- raven_utils/models/loss_3.py +638 -0
- raven_utils/models/multi_transformer.py +274 -0
- raven_utils/models/raven.py +239 -0
- raven_utils/models/trans.py +74 -0
- raven_utils/models/transformer.py +133 -0
- raven_utils/models/transformer_2.py +146 -0
- raven_utils/models/transformer_3.py +206 -0
- raven_utils/models/uitls_.py +16 -0
- raven_utils/output.py +16 -0
- raven_utils/params.py +110 -0
- raven_utils/properties.py +16 -0
- raven_utils/range_mask.py +16 -0
- raven_utils/render/__init__.py +0 -0
- raven_utils/render/const.py +86 -0
- raven_utils/render/rendering.py +304 -0
- raven_utils/render_.py +104 -0
- raven_utils/rules.py +21 -0
- raven_utils/target.py +50 -0
- raven_utils/uitls.py +64 -0
- saved_model/1/keras_metadata.pb +3 -0
- saved_model/1/saved_model.pb +3 -0
- saved_model/1/variables/variables.data-00000-of-00001 +3 -0
- saved_model/1/variables/variables.index +3 -0
- utils.py +84 -0
main.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Gradio demo: shows a RAVEN matrix from the validation set and renders the
# answer panel predicted by the model.
import gradio as gr

import models
from utils import load_example, run_nn, load_model_, next_, prev_

demo = gr.Blocks()

with demo:
    headline = gr.Markdown("## Raven resolver ")
    markdown = gr.Markdown("Below we show all 9 images from raven matrix. "
                           "Model gets 8 images and predicts the properties of last one. "
                           "Based on this properties the answer image is render in the right panel. <br />"
                           "Note that angle rotation is only used as a noise. "
                           "There are not rules applied to angle property, so angle rotation of final output do not need to be the same as in example. "
                           "Additionally there are cases that other properties could be used as noise.")

    # Load the start example once; element 0 feeds the image widget and
    # element 1 the description widget.
    start_example = load_example(models.START_IMAGE)

    with gr.Row():
        # Left column: example selection and the matrix itself.
        with gr.Column():
            with gr.Row():
                text = gr.Textbox(models.START_IMAGE,
                                  label="Write the example number from validation dataset (0, 14,000). You can also paste here matrix representation from generator.")
            with gr.Row():
                # Suffix "_button" so the builtin ``next`` is not shadowed.
                prev_button = gr.Button("Prev")
                show_button = gr.Button("Show")
                next_button = gr.Button("Next")
            with gr.Row():
                image = gr.Image(value=start_example[0], label="Raven matrix")
                desc = gr.Markdown(value=start_example[1])

        # Right column: the rendered prediction and the button that triggers it.
        with gr.Column():
            with gr.Row():
                output = gr.Image(label="Generated image", shape=(200, 200))
            with gr.Row():
                run_button = gr.Button("Run")

    # "Show" reloads whatever the textbox currently holds.
    show_button.click(load_example, inputs=text, outputs=[image, desc])
    # The network is fed the textbox content, not the displayed image.
    run_button.click(run_nn, inputs=text, outputs=output)

    # Prev/Next update the textbox, the matrix and its description in one step.
    next_button.click(next_, inputs=text, outputs=[text, image, desc])
    prev_button.click(prev_, inputs=text, outputs=[text, image, desc])

demo.launch(debug=True)
|
models.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Loads the trained model and the RAVEN validation data shared by the demo.
import os

from tensorflow.keras.models import load_model

from data_utils import nload, ims, DataSetFromFolder

# Example shown when the demo starts.
START_IMAGE = 12000

# Root folder of the dataset files. The default is the original hard-coded
# location; set RAVEN_DATASET_DIR to run on any other machine.
DATASET_DIR = os.environ.get("RAVEN_DATASET_DIR", "/home/jkwiatkowski/all/dataset/arr")

model = load_model("saved_model/1")

data = nload(os.path.join(DATASET_DIR, "val.npy"))
indexes = nload(os.path.join(DATASET_DIR, "val_target.npy"))

folders = DataSetFromFolder(os.path.join(DATASET_DIR, "RAVEN-10000-release", "RAVEN-10000"), file_type="dir")
properties = DataSetFromFolder(folders[:], file_type="xml", extension="val")
|
raven_utils/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import raven_utils.group as group
|
2 |
+
import raven_utils.entity as entity
|
3 |
+
import raven_utils.properties as properties
|
4 |
+
import raven_utils.target as target
|
5 |
+
import raven_utils.rules as rules
|
6 |
+
import raven_utils.output as output
|
7 |
+
import raven_utils.inference as inference
|
8 |
+
import raven_utils.decode as decode
|
9 |
+
import raven_utils.render_ as render_
|
10 |
+
import raven_utils.draw as draw
|
raven_utils/config/__init__.py
ADDED
File without changes
|
raven_utils/config/constant.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RAVEN = "arr"
|
2 |
+
RAVEN_BIG = "arrb"
|
3 |
+
INDEX = "index"
|
4 |
+
LABELS = "labels"
|
5 |
+
TARGET_LABELS = "target_labels"
|
6 |
+
FEATURES = "features"
|
7 |
+
ACC_SAME = "acc_same"
|
8 |
+
ACC_CHOOSE_LOWER = "acc_choose_lower"
|
9 |
+
ACC_CHOOSE_UPPER = "acc_choose_upper"
|
10 |
+
ACC_NO_GROUP = "acc_NO_group"
|
11 |
+
CLASSIFICATION = "classification"
|
12 |
+
INFERENCE = "inference"
|
13 |
+
# PROPERTIES = "properties"
|
14 |
+
PROPERTY = "property"
|
15 |
+
MEMORY = "memory"
|
16 |
+
CONTROL = "control"
|
17 |
+
LATENT = "latent"
|
18 |
+
TARGET = "target"
|
19 |
+
INPUTS = "inputs"
|
20 |
+
RES = "res"
|
21 |
+
RESULT = "result"
|
22 |
+
MERGE = "merge"
|
23 |
+
MEMORY_STATE = "memory_state"
|
24 |
+
CONTROL_STATE = "control_state"
|
25 |
+
CONCAT = "concat"
|
26 |
+
FLATTEN = "flatten"
|
27 |
+
CROSS_ENTROPY = "cross_entropy"
|
28 |
+
SLOT = "slot"
|
29 |
+
PROPERTIES = "properties"
|
30 |
+
ACC = "acc"
|
31 |
+
GROUP = 'group'
|
32 |
+
NUMBER = 'number'
|
33 |
+
TRANS = 'trans'
|
34 |
+
TAIL = "tail"
|
35 |
+
MASK = "mask"
|
36 |
+
|
37 |
+
# Accuracy metrics singled out from the full list below.
IMP_RAV_METRICS = [
    ACC_NO_GROUP,
    ACC_SAME,
    ACC_CHOOSE_UPPER,
    ACC_CHOOSE_LOWER,
    ACC,
]

# Full metric list: the accuracy metrics above followed by the remaining names.
RAV_METRICS = [
    *IMP_RAV_METRICS,
    "c_acc_NO_group",
    "c_acc",
    "loss",
]
|
raven_utils/config/models.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
AVAILABLE_MODELS = [
|
3 |
+
"197-0.31",
|
4 |
+
"53-0.48",
|
5 |
+
"74-0.50",
|
6 |
+
"21-0.48",
|
7 |
+
"10-0.52",
|
8 |
+
"179-0.50"
|
9 |
+
]
|
raven_utils/const.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
HORIZONTAL = "horizontal"
|
2 |
+
VERTICAL = "vertical"
|
raven_utils/constant.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RAVEN = "arr"
|
2 |
+
RAVEN_BIG = "arrb"
|
3 |
+
INDEX = "index"
|
4 |
+
LABELS = "labels"
|
5 |
+
TARGET_LABELS = "target_labels"
|
6 |
+
FEATURES = "features"
|
7 |
+
ACC_SAME = "acc_same"
|
8 |
+
ACC_CHOOSE_LOWER = "acc_choose_lower"
|
9 |
+
ACC_CHOOSE_UPPER = "acc_choose_upper"
|
10 |
+
ACC_NO_GROUP = "acc_NO_group"
|
11 |
+
CLASSIFICATION = "classification"
|
12 |
+
INFERENCE = "inference"
|
13 |
+
# PROPERTIES = "properties"
|
14 |
+
PROPERTY = "property"
|
15 |
+
MEMORY = "memory"
|
16 |
+
CONTROL = "control"
|
17 |
+
LATENT = "latent"
|
18 |
+
TARGET = "target"
|
19 |
+
INPUTS = "inputs"
|
20 |
+
RES = "res"
|
21 |
+
RESULT = "result"
|
22 |
+
MERGE = "merge"
|
23 |
+
MEMORY_STATE = "memory_state"
|
24 |
+
CONTROL_STATE = "control_state"
|
25 |
+
CONCAT = "concat"
|
26 |
+
FLATTEN = "flatten"
|
27 |
+
CROSS_ENTROPY = "cross_entropy"
|
28 |
+
SLOT = "slot"
|
29 |
+
PROPERTIES = "properties"
|
30 |
+
ACC = "acc"
|
31 |
+
GROUP = 'group'
|
32 |
+
NUMBER = 'number'
|
33 |
+
TRANS = 'trans'
|
34 |
+
TAIL = "tail"
|
35 |
+
MASK = "mask"
|
36 |
+
|
37 |
+
# Group-related accuracy metrics singled out from the full list below.
IMP_RAV_METRICS = [
    ACC_NO_GROUP,
    ACC_SAME,
    ACC_CHOOSE_UPPER,
    ACC_CHOOSE_LOWER,
]

# Full metric list: the metrics above, overall accuracy, then the remaining names.
RAV_METRICS = [
    *IMP_RAV_METRICS,
    ACC,
    "c_acc_NO_group",
    "c_acc",
    "loss",
]
|
raven_utils/data.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
|
5 |
+
from models_utils import INPUTS, TARGET
|
6 |
+
|
7 |
+
from raven_utils.config.constant import RAVEN, LABELS, INDEX, FEATURES, RAV_METRICS, IMP_RAV_METRICS, ACC_NO_GROUP
|
8 |
+
|
9 |
+
|
10 |
+
from typing import Any
|
11 |
+
|
12 |
+
from data_utils import pre, Data, gather, vec, resize
|
13 |
+
from data_utils.data_generator import DataGenerator
|
14 |
+
from funcy import identity
|
15 |
+
|
16 |
+
|
17 |
+
def get_data(data, batch_size, steps=None, val_steps=None):
    """Build the training and validation generators from a positional bundle.

    ``data`` is read by position: 0/1 are the train/validation inputs, 2/3 the
    train/validation targets (used for both the target and the labels entry)
    and 4/5 the train/validation target indexes, each shifted by 8.
    NOTE(review): the +8 presumably skips the eight context panels — confirm.

    Args:
        data: Indexable bundle as described above.
        batch_size: Passed to both generators as ``batch``.
        steps: Steps for the training generator.
        val_steps: Steps for the validation generator; falls back to ``steps``.

    Returns:
        Tuple ``(train_generator, val_generator)``.
    """

    def _fields(inputs, targets, target_index):
        # One generator payload; the index gets a trailing axis.
        return {
            INPUTS: Data(inputs, identity),
            TARGET: Data(targets, identity),
            LABELS: Data(targets, identity),
            INDEX: target_index[:, None],
        }

    if val_steps is None:
        val_steps = steps

    train_generator = DataGenerator(
        _fields(data[0], data[2], data[4] + 8),
        batch=batch_size,
        steps=steps,
    )
    val_generator = DataGenerator(
        _fields(data[1], data[3], data[5] + 8),
        batch=batch_size,
        sampler="val",
        steps=val_steps,
    )
    return train_generator, val_generator
|
raven_utils/decode.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from data_utils import np_split
|
3 |
+
from ml_utils import lw
|
4 |
+
from models_utils.ops import ibin
|
5 |
+
|
6 |
+
import raven_utils as rv
|
7 |
+
|
8 |
+
|
9 |
+
def output(x, split_fn=np_split, predict_fn_1=np.argmax, predict_fn_2=ibin):
    """Decode a raw model output into group, slot and per-property predictions.

    The output is divided with ``output_divide``, the group and slot parts are
    passed through ``predict_fn_1`` / ``predict_fn_2`` and every property part
    is reshaped and reduced with ``predict_fn_1``.

    Returns:
        Tuple ``(group, slot, *properties)``.
    """
    divided = output_divide(x, split_fn=split_fn)
    predicted = output_predict(divided, predict_fn_1=predict_fn_1, predict_fn_2=predict_fn_2)
    properties = output_properties(predicted[2], predict_fn=predict_fn_1)
    return (predicted[0], predicted[1], *properties)
|
13 |
+
|
14 |
+
|
15 |
+
def output_divide(output, split_fn=np_split):
    """Cut the last axis of ``output`` into group, slot and property parts.

    The three parts are taken with the ``*_SLICE_END`` slices from
    ``rv.output``; the property part is further split by ``split_fn`` using the
    values of ``rv.properties.INDEX``.

    Returns:
        Tuple ``(group_output, slot_output, properties_output_splited)``.
    """
    slices = rv.output
    group_part, slot_part, properties_part = (
        output[..., part]
        for part in (slices.GROUP_SLICE_END, slices.SLOT_SLICE_END, slices.PROPERTIES_SLICE_END)
    )
    split_points = list(rv.properties.INDEX.values())
    return group_part, slot_part, split_fn(properties_part, split_points, axis=-1)
|
21 |
+
|
22 |
+
|
23 |
+
def output_predict(output, predict_fn_1=np.argmax, predict_fn_2=ibin):
    """Apply ``predict_fn_1`` to the group part and ``predict_fn_2`` to the slot part.

    The third element of ``output`` (the property parts) is returned untouched.
    """
    group_prediction = predict_fn_1(output[0])
    slot_prediction = predict_fn_2(output[1])
    return group_prediction, slot_prediction, output[2]
|
25 |
+
|
26 |
+
|
27 |
+
def output_properties(x, predict_fn=np.argmax):
    """Reshape every property part and reduce it with ``predict_fn``.

    Part ``i`` is reshaped to ``(-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])``
    before ``predict_fn`` is applied.

    Returns:
        List with one reduced array per element of ``x``.
    """
    return [
        predict_fn(part.reshape((-1, rv.entity.SUM, rv.properties.RAW_SIZE[position])))
        for position, part in enumerate(x)
    ]
|
33 |
+
|
34 |
+
|
35 |
+
def output_result(output, split_fn=np_split, arg_max=np.argmax):
    """Reduce the reshaped parts of ``output`` with ``arg_max``, except part 1.

    Part 1 is passed through unchanged; every other part is reduced over its
    last axis.

    Returns:
        Tuple with one entry per part.
    """
    # NOTE(review): split_fn is forwarded as predict_fn of output_properties,
    # which calls it with a single argument — confirm this is intended.
    reshaped = output_properties(output, predict_fn=split_fn)
    return tuple(
        part if position == 1 else arg_max(part, axis=-1)
        for position, part in enumerate(reshaped)
    )
|
44 |
+
|
45 |
+
|
46 |
+
def decode_inference(inference, reshape=np.reshape):
    """Split an inference array into its slot and properties parts and reshape them.

    Returns:
        Tuple ``(slot_part, properties_part)`` where the slot part is reshaped to
        ``[-1, rv.group.NO, PROPERTY_TRANSFORMATION_NO]`` and the properties part
        to ``[-1, rv.properties.NO, rv.entity.SUM, PROPERTY_TRANSFORMATION_NO]``.
    """
    transformations = rv.inference.PROPERTY_TRANSFORMATION_NO
    slot_part = reshape(
        inference[rv.inference.SLOT_SLICE],
        [-1, rv.group.NO, transformations],
    )
    properties_part = reshape(
        inference[rv.inference.PROPERTIES_SLICE],
        [-1, rv.properties.NO, rv.entity.SUM, transformations],
    )
    return slot_part, properties_part
|
51 |
+
|
52 |
+
|
53 |
+
def decode_target(target):
    """Split a target along its last axis into group, slot and property parts.

    Column 0 is the group, columns ``1:rv.target.INDEX[0]`` the slots and
    columns up to ``rv.target.END_INDEX`` the properties. The properties are
    de-interleaved with stride ``rv.properties.NO`` at offsets 0, 1 and 2.

    Returns:
        Tuple ``(target_group, target_slot, [property_0, property_1, property_2])``.
    """
    properties_start = rv.target.INDEX[0]
    stride = rv.properties.NO
    flat_properties = target[..., properties_start:rv.target.END_INDEX]
    per_property = [flat_properties[..., offset::stride] for offset in (0, 1, 2)]
    return target[..., 0], target[..., 1:properties_start], per_property
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
def decode_target_flat(target):
    """Like ``decode_target`` but with the three property arrays unpacked inline.

    Returns:
        Tuple ``(group, slot, property_0, property_1, property_2)``.
    """
    group, slot, properties = decode_target(target)
    return group, slot, properties[0], properties[1], properties[2]
|
70 |
+
|
71 |
+
|
72 |
+
def demask(target, mask=None, group=None, zeroes=None):
    """Apply a range mask to every part of ``target`` after the first.

    When ``mask`` is missing it is built with ``RangeMask`` from ``group``
    (``target[0]`` by default). With ``zeroes`` left at ``None`` each part is
    indexed with the mask and the results are concatenated; otherwise each part
    is multiplied by the mask and concatenated behind ``target[0][None]`` along
    the last axis.
    """
    if mask is None:
        if group is None:
            group = target[0]
        # todo Use numpy range Mask
        # NOTE(review): confirm that ``models.uitls_`` resolves in this
        # repository layout (there is also raven_utils/models/uitls_.py).
        from models.uitls_ import RangeMask
        mask = RangeMask()(group).numpy()

    parts = lw(target[1:])
    if zeroes is None:
        selected = [part[mask] for part in parts]
        return np.concatenate(selected)
    zeroed = [part * mask for part in parts]
    return np.concatenate([target[0][None]] + zeroed, axis=-1)
|
82 |
+
|
83 |
+
|
84 |
+
def target_mask(mask, right=1):
    """Expand a 2-D mask to ``[ones | mask | mask repeated 3x | ones]``.

    Args:
        mask: 2-D array with one row per sample.
        right: Number of all-ones columns appended on the right.

    Returns:
        Array with the parts joined along axis 1, preceded by a single
        all-ones column.
    """
    rows = mask.shape[0]
    leading = np.ones([rows, 1])
    repeated = np.repeat(mask, 3, axis=1)
    trailing = np.ones([rows, right])
    return np.concatenate([leading, mask, repeated, trailing], axis=1)
|
87 |
+
|
88 |
+
|
89 |
+
def get_full_range_mask(mask):
    """Return ``mask`` followed by ``mask`` with every entry repeated 3 times.

    Both the repetition and the join happen along the last axis.
    """
    repeated = np.repeat(mask, 3, axis=-1)
    return np.concatenate((mask, repeated), axis=-1)
|
91 |
+
|
92 |
+
def compare(target, predict, mask):
    """Compare target and prediction element-wise after masking.

    Columns ``1:rv.target.END_INDEX`` of both arrays are multiplied by the
    expanded mask from ``get_full_range_mask`` and then tested for equality.

    Returns:
        Result of the element-wise ``==`` between the masked arrays.
    """
    compared_columns = np.s_[:, 1:rv.target.END_INDEX]
    target_comp = target[compared_columns]
    predict_comp = predict[compared_columns]
    full_mask = get_full_range_mask(mask)
    return target_comp * full_mask == predict_comp * full_mask
|
raven_utils/depricated/__init__.py
ADDED
File without changes
|
raven_utils/depricated/old_raven.py
ADDED
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from data_utils import take, EXIST, COR
|
5 |
+
from data_utils.image import draw_images, add_text
|
6 |
+
from data_utils.op import np_split
|
7 |
+
from ml_utils import lu, dict_from_list2, filter_keys, none
|
8 |
+
from data_utils import ops as K
|
9 |
+
|
10 |
+
from config.constant import PROPERTY, TARGET, INPUTS
|
11 |
+
# from raven_utils.render.rendering import render_panels
|
12 |
+
|
13 |
+
RENDER_POSITIONS = [
|
14 |
+
[(0.5, 0.5, 1, 1)],
|
15 |
+
# ...
|
16 |
+
[(0.25, 0.25, 0.5, 0.5),
|
17 |
+
(0.25, 0.75, 0.5, 0.5),
|
18 |
+
(0.75, 0.25, 0.5, 0.5),
|
19 |
+
(0.75, 0.75, 0.5, 0.5)],
|
20 |
+
# ...
|
21 |
+
[(0.16, 0.16, 0.33, 0.33),
|
22 |
+
(0.16, 0.5, 0.33, 0.33),
|
23 |
+
(0.16, 0.83, 0.33, 0.33),
|
24 |
+
(0.5, 0.16, 0.33, 0.33),
|
25 |
+
(0.5, 0.5, 0.33, 0.33),
|
26 |
+
(0.5, 0.83, 0.33, 0.33),
|
27 |
+
(0.83, 0.16, 0.33, 0.33),
|
28 |
+
(0.83, 0.5, 0.33, 0.33),
|
29 |
+
(0.83, 0.83, 0.33, 0.33)],
|
30 |
+
# ...
|
31 |
+
[(0.5, 0.25, 0.5, 0.5)],
|
32 |
+
[(0.5, 0.75, 0.5, 0.5)],
|
33 |
+
# ...
|
34 |
+
[(0.25, 0.5, 0.5, 0.5)],
|
35 |
+
[(0.75, 0.5, 0.5, 0.5)],
|
36 |
+
# ...
|
37 |
+
[(0.5, 0.5, 1, 1)],
|
38 |
+
[(0.5, 0.5, 0.33, 0.33)],
|
39 |
+
# ...
|
40 |
+
[(0.5, 0.5, 1, 1)],
|
41 |
+
[(0.42, 0.42, 0.15, 0.15),
|
42 |
+
(0.42, 0.58, 0.15, 0.15),
|
43 |
+
(0.58, 0.42, 0.15, 0.15),
|
44 |
+
(0.58, 0.58, 0.15, 0.15)],
|
45 |
+
# ...
|
46 |
+
|
47 |
+
]
|
48 |
+
|
49 |
+
HORIZONTAL = "horizontal"
|
50 |
+
VERTICAL = "vertical"
|
51 |
+
|
52 |
+
NAMES = ['center_single',
|
53 |
+
'distribute_four',
|
54 |
+
'distribute_nine',
|
55 |
+
'in_center_single_out_center_single',
|
56 |
+
'in_distribute_four_out_center_single',
|
57 |
+
'left_center_single_right_center_single',
|
58 |
+
'up_center_single_down_center_single']
|
59 |
+
|
60 |
+
PROPERTIES_NAMES = [
|
61 |
+
'Color',
|
62 |
+
'Size',
|
63 |
+
'Type',
|
64 |
+
|
65 |
+
]
|
66 |
+
PROPERTIES = dict_from_list2(PROPERTIES_NAMES, [10, 6, 5])
|
67 |
+
ANGLE_MAX = 7
|
68 |
+
|
69 |
+
PROPERTIES_NO = len(PROPERTIES)
|
70 |
+
|
71 |
+
RULES_COMBINE = "Number/Position"
|
72 |
+
|
73 |
+
# Attributes a rule can act on, with their count and lookup index.
RULES_ATTRIBUTES = [
    "Number",
    "Position",
    "Color",
    "Size",
    "Type",
]
RULES_ATTRIBUTES_LEN = len(RULES_ATTRIBUTES)

RULES_ATTRIBUTES_INDEX = dict_from_list2(RULES_ATTRIBUTES)

# Rule kinds, with their lookup index and count.
RULES_TYPES = [
    "Constant",
    "Arithmetic",
    "Progression",
    "Distribute_Three",
]
RULES_TYPES_INDEX = dict_from_list2(RULES_TYPES)
# Count the rule types themselves; this previously measured RULES_ATTRIBUTES
# (a copy-paste slip that yielded 5 instead of 4).
RULES_TYPES_LEN = len(RULES_TYPES)
|
92 |
+
|
93 |
+
GROUPS_NO = len(NAMES)
|
94 |
+
ENTITY_NO = dict(zip(NAMES, [1, 4, 9, 2, 5, 2, 2]))
|
95 |
+
ENTITY_SUM = sum(list(ENTITY_NO.values()))
|
96 |
+
ENTITY_INDEX = np.concatenate([[0], np.cumsum(list(ENTITY_NO.values()))])
|
97 |
+
ENTITY_INDEX_TARGET = ENTITY_INDEX + 1
|
98 |
+
ENTITY_DICT = dict(zip(NAMES, ENTITY_INDEX_TARGET[:-1]))
|
99 |
+
NAMES_ORDER = dict(zip(NAMES, np.arange(len(NAMES))))
|
100 |
+
PROPERTIES_INDEXES = np.cumsum(np.array(list(ENTITY_NO.values())) * len(PROPERTIES))
|
101 |
+
INDEX = np.concatenate([[0], PROPERTIES_INDEXES]) + ENTITY_SUM + 1 # +2 type and uniformity
|
102 |
+
|
103 |
+
SECOND_LAYOUT = [i - 1 for i in [
|
104 |
+
ENTITY_DICT["in_center_single_out_center_single"] + 1,
|
105 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 1,
|
106 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 2,
|
107 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 3,
|
108 |
+
ENTITY_DICT["left_center_single_right_center_single"] + 1,
|
109 |
+
ENTITY_DICT["up_center_single_down_center_single"] + 1
|
110 |
+
]]
|
111 |
+
|
112 |
+
FIRST_LAYOUT = list(set(range(ENTITY_SUM)) - set(SECOND_LAYOUT))
|
113 |
+
LAYOUT_NO = 2
|
114 |
+
|
115 |
+
START_INDEX = dict(zip(NAMES, INDEX[:-1]))
|
116 |
+
END_INDEX = INDEX[-1]
|
117 |
+
|
118 |
+
RULES_ATTRIBUTES_ALL_LEN = RULES_ATTRIBUTES_LEN * LAYOUT_NO
|
119 |
+
UNIFORMITY_NO = 2
|
120 |
+
UNIFORMITY_INDEX = END_INDEX + RULES_ATTRIBUTES_ALL_LEN
|
121 |
+
|
122 |
+
FEATURE_NO = UNIFORMITY_INDEX + UNIFORMITY_NO
|
123 |
+
MAPPING = {
|
124 |
+
"distribute_nine":
|
125 |
+
{0.16: 0,
|
126 |
+
0.5: 1,
|
127 |
+
0.83: 2},
|
128 |
+
"distribute_four":
|
129 |
+
{0.25: 0,
|
130 |
+
0.75: 1},
|
131 |
+
'in_distribute_four_out_center_single':
|
132 |
+
{0.42: 0,
|
133 |
+
0.58: 1}
|
134 |
+
}
|
135 |
+
MUL = {
|
136 |
+
"distribute_nine": 3,
|
137 |
+
"distribute_four": 2,
|
138 |
+
'in_distribute_four_out_center_single': 2
|
139 |
+
}
|
140 |
+
|
141 |
+
# SIZES = np.linspace(0.4, 0.9, 6)
|
142 |
+
TYPES = ["triangle", "square", "pentagon", "hexagon", "circle"]
|
143 |
+
# TYPES = ["triangle", "square", "pentagon", "circle", "circle"]
|
144 |
+
SIZES = ["vs", "s", "m", "h", "vh", "e"]
|
145 |
+
COLORS = ["vs", "s", "m", "h", "vh", "e"]
|
146 |
+
# TYPES = ["", "", "circle", "hexagon", "square"]
|
147 |
+
|
148 |
+
ENTITY_PROPERTIES_VALUES = list(PROPERTIES.values())
|
149 |
+
ENTITY_PROPERTIES_KEYS = list(PROPERTIES.keys())
|
150 |
+
ENTITY_PROPERTIES_NO = len(PROPERTIES)
|
151 |
+
INDEX = dict(zip(PROPERTIES, np.array(ENTITY_PROPERTIES_VALUES) * ENTITY_SUM))
|
152 |
+
ENTITY_PROPERTIES_SUM = sum(list(PROPERTIES.values()))
|
153 |
+
|
154 |
+
OUTPUT_SIZE = ENTITY_SUM * ENTITY_PROPERTIES_SUM + GROUPS_NO + ENTITY_SUM
|
155 |
+
|
156 |
+
SLOT_AND_GROUP = ENTITY_SUM + GROUPS_NO
|
157 |
+
|
158 |
+
OUTPUT_GROUP_SLICE = np.s_[:, -GROUPS_NO:]
|
159 |
+
OUTPUT_SLOT_SLICE = np.s_[:, -SLOT_AND_GROUP:-GROUPS_NO]
|
160 |
+
OUTPUT_PROPERTIES_SLICE = np.s_[:, :-SLOT_AND_GROUP]
|
161 |
+
|
162 |
+
OUTPUT_GROUP_SLICE_END = np.s_[-GROUPS_NO:]
|
163 |
+
OUTPUT_SLOT_SLICE_END = np.s_[-SLOT_AND_GROUP:-GROUPS_NO]
|
164 |
+
OUTPUT_PROPERTIES_SLICE_END = np.s_[:-SLOT_AND_GROUP]
|
165 |
+
|
166 |
+
# Transformation
|
167 |
+
# constant
|
168 |
+
# progression -2, -1,1 ,2
|
169 |
+
# arithmetic -/+ Position set arithmetic
|
170 |
+
# distribute three
|
171 |
+
|
172 |
+
# todo
|
173 |
+
SLOTS_GROUPS = GROUPS_NO
|
174 |
+
|
175 |
+
SLOT_TRANSFORMATION_NO = 4
|
176 |
+
PROPERTY_TRANSFORMATION_NO = 8
|
177 |
+
PROPERTIES_TRANSFORMATION_NO = PROPERTY_TRANSFORMATION_NO * PROPERTIES_NO
|
178 |
+
PROPERTIES_TRANSFORMATION_SIZE = PROPERTIES_TRANSFORMATION_NO * ENTITY_SUM
|
179 |
+
|
180 |
+
SLOT_TRANSFORMATION_SIZE = PROPERTY_TRANSFORMATION_NO * SLOTS_GROUPS
|
181 |
+
INFERENCE_SIZE = SLOT_TRANSFORMATION_SIZE + PROPERTIES_TRANSFORMATION_SIZE
|
182 |
+
|
183 |
+
INFERENCE_SLOT_SLICE = np.s_[:, :SLOT_TRANSFORMATION_SIZE]
|
184 |
+
INFERENCE_PROPERTIES_SLICE = np.s_[:, -PROPERTIES_TRANSFORMATION_SIZE:]
|
185 |
+
from operator import add
|
186 |
+
|
187 |
+
|
188 |
+
# todo Refactor
|
189 |
+
# Maybe properties should be on same level as rest.
|
190 |
+
def decode_output(output, split_fn=np_split):
    """Cut the last axis of ``output`` into group, slot and property parts.

    The parts are taken with the module's ``OUTPUT_*_SLICE_END`` slices; the
    property part is further split by ``split_fn``.

    Returns:
        Tuple ``(group_output, slot_output, properties_output_splited)``.
    """
    group_part, slot_part, properties_part = (
        output[..., part]
        for part in (OUTPUT_GROUP_SLICE_END, OUTPUT_SLOT_SLICE_END, OUTPUT_PROPERTIES_SLICE_END)
    )
    # NOTE(review): ``rv`` is not imported in the visible part of this module —
    # confirm it is defined, or whether this module's own INDEX was meant.
    split_points = list(rv.properties.INDEX.values())
    return group_part, slot_part, split_fn(properties_part, split_points, axis=-1)
|
196 |
+
|
197 |
+
|
198 |
+
def decode_inference(inference, reshape=np.reshape):
    """Split an inference array into its slot and properties parts and reshape them.

    Returns:
        Tuple ``(slot_part, properties_part)`` where the slot part is reshaped to
        ``[-1, SLOTS_GROUPS, PROPERTY_TRANSFORMATION_NO]`` and the properties part
        to ``[-1, PROPERTIES_NO, ENTITY_SUM, PROPERTY_TRANSFORMATION_NO]``.
    """
    slot_part = reshape(
        inference[INFERENCE_SLOT_SLICE],
        [-1, SLOTS_GROUPS, PROPERTY_TRANSFORMATION_NO],
    )
    properties_part = reshape(
        inference[INFERENCE_PROPERTIES_SLICE],
        [-1, PROPERTIES_NO, ENTITY_SUM, PROPERTY_TRANSFORMATION_NO],
    )
    return slot_part, properties_part
|
203 |
+
|
204 |
+
|
205 |
+
def decode_output_reshape(output, split_fn=np_split):
    """Decode ``output`` and reshape each property part per entity.

    Property part ``i`` is reshaped to
    ``(-1, ENTITY_SUM, ENTITY_PROPERTIES_VALUES[i])``.

    Returns:
        Tuple ``(group, slot, *reshaped_properties)``.
    """
    decoded = decode_output(output, split_fn=split_fn)
    reshaped = tuple(
        part.reshape((-1, ENTITY_SUM, ENTITY_PROPERTIES_VALUES[position]))
        for position, part in enumerate(decoded[2])
    )
    return decoded[:2] + reshaped
|
212 |
+
|
213 |
+
|
214 |
+
def take_target(target):
    """Return elements 1 and 2 of ``target`` as a tuple."""
    return tuple(target[position] for position in (1, 2))
|
216 |
+
|
217 |
+
|
218 |
+
def create_target(images, index, pattern_index=(2, 5), full_index=False, arrange=np.arange, shape=lambda x: x.shape):
    """Pick two pattern panels and the indexed panel for every sample.

    Args:
        images: Array indexed as ``[sample, panel, ...]``.
        index: Array whose first column selects a panel per sample.
        pattern_index: Panel positions of the two pattern images.
        full_index: When false, 8 is added to the values taken from ``index``.
        arrange: Range constructor used to enumerate the samples.
        shape: Function returning the shape of ``index``.

    Returns:
        List ``[first_pattern, second_pattern, indexed_panel]``.
    """
    offset = 0 if full_index else 8
    first_pattern = images[:, pattern_index[0]]
    second_pattern = images[:, pattern_index[1]]
    sample_rows = arrange(shape(index)[0])
    indexed_panel = images[sample_rows, offset + index[:, 0]]
    return [first_pattern, second_pattern, indexed_panel]
|
221 |
+
|
222 |
+
|
223 |
+
def take_target_simple(target):
    """Return elements 1 and 0 of ``target``, in that order, as a tuple."""
    return tuple(target[position] for position in (1, 0))
|
225 |
+
|
226 |
+
|
227 |
+
def create_target_simple(images, target, index=slice(None), pattern_index=(2, 5)):
    """Build ``[first_pattern, second_pattern, target]`` and index into it.

    Args:
        images: Array indexed as ``[sample, panel, ...]``.
        target: Object placed as the third candidate.
        index: Index or slice applied to the candidate list (all by default).
        pattern_index: Panel positions of the two pattern images.
    """
    candidates = [images[:, pattern_index[0]], images[:, pattern_index[1]], target]
    return candidates[index]
|
229 |
+
|
230 |
+
|
231 |
+
def decode_output_result(output, split_fn=np_split, arg_max=np.argmax):
    """Decode ``output`` and reduce every part except part 1 with ``arg_max``.

    Part 1 (the slot part of ``decode_output_reshape``) is passed through
    unchanged; all other parts are reduced over their last axis.

    Returns:
        Tuple with one entry per decoded part.
    """
    decoded = decode_output_reshape(output, split_fn=split_fn)
    return tuple(
        part if position == 1 else arg_max(part, axis=-1)
        for position, part in enumerate(decoded)
    )
|
240 |
+
|
241 |
+
|
242 |
+
def decode_target(target):
    """Split a target along its last axis into group, slot and property parts.

    Column 0 is the group, the next ``ENTITY_SUM`` columns are the slots and
    the columns up to ``END_INDEX`` hold the properties, which are
    de-interleaved with stride ``PROPERTIES_NO`` at offsets 0, 1 and 2.

    Returns:
        Tuple ``(target_group, target_slot, [property_0, property_1, property_2])``.
    """
    # The property columns start at ENTITY_SUM + 1, the first entry of the
    # array originally bound to INDEX. INDEX itself cannot be used here: the
    # module later rebinds it to a dict built from PROPERTIES, so INDEX[0]
    # would raise KeyError when this function runs.
    properties_start = ENTITY_SUM + 1
    target_group = target[..., 0]
    target_slot = target[..., 1:properties_start]
    target_properties = target[..., properties_start:END_INDEX]
    target_properties_splited = [
        target_properties[..., offset::PROPERTIES_NO] for offset in (0, 1, 2)
    ]
    return target_group, target_slot, target_properties_splited
|
252 |
+
|
253 |
+
|
254 |
+
def decode_target_flat(target):
    """Like ``decode_target`` but with the three property arrays unpacked inline.

    Returns:
        Tuple ``(group, slot, property_0, property_1, property_2)``.
    """
    group, slot, properties = decode_target(target)
    return group, slot, properties[0], properties[1], properties[2]
|
257 |
+
|
258 |
+
|
259 |
+
def draw_board(images, target=None, predict=None,image=None, desc=None, layout=None, break_=20):
    """Draw one board: the context panels on top of the answer panels.

    The first 8 entries of ``images`` (plus ``image`` when given) form the
    first board, the remaining entries the second one; both are combined with
    ``draw_images`` and optionally annotated with ``desc``.

    Args:
        images: Panels of one example; entries from position 8 on are drawn
            as the answer board.
        target: Panel position of the target; ``target - 8`` is used as border.
        predict: Predicted panel position(s), or ``None``.
        image: Extra panel appended to the context board, or ``None``.
        desc: Text added to the finished board when truthy.
        layout: ``1`` draws the answers with per-panel borders and a spacer.
        break_: Height of the spacer placed above the answers for layout 1.

    Returns:
        The combined board image.
    """
    if image != "target" and predict is not None:
        image = images[predict:predict + 1]
    elif images is None and target is not None:
        # NOTE(review): this tests ``images`` (plural) and then indexes it;
        # ``image is None`` may have been intended — confirm.
        image = images[target:target + 1]
    # image = False to not draw anything
    border = [{COR: target - 8, EXIST: (1, 3)}] + [{COR: p, EXIST: (0, 2)} for p in none(predict)]

    if image is None:
        context = images[:8]
    else:
        # A single panel (3 dimensions) gets a leading axis before joining.
        extra = image[None] if len(image.shape)==3 else image
        context = np.concatenate([images[:8], extra])
    boards = [draw_images(context)]

    if layout == 1:
        answers = draw_images(images[8:], column=4, border=border)
        if break_:
            spacer = np.zeros([ break_, answers.shape[1],1])
            answers = np.concatenate([spacer, answers],axis=0)
        boards.append(answers)
    else:
        candidates = images[8:] if predict is None else np.concatenate([images[8:], predict])
        boards.append(draw_images(candidates, column=4, border=target - 8))

    full_board = draw_images(boards, grid=False)
    if desc:
        full_board = add_text(full_board, desc)
    return full_board
|
283 |
+
|
284 |
+
|
285 |
+
def draw_boards(images, target=None, predict=None, image=None, desc=None, no=1, layout=None):
    """Draw one board per example with ``draw_board``.

    Args:
        images: Sequence of per-example panel arrays.
        target: Per-example targets; element ``[i][0]`` is passed on, or ``None``.
        predict: Per-example predictions, or ``None``.
        image: Per-example extra panels, or ``None``.
        desc: Per-example descriptions, or ``None``.
        no: Unused; kept for interface compatibility.
        layout: Forwarded to ``draw_board``.

    Returns:
        List with one board per entry of ``images``.
    """

    def _item(values, position):
        # Optional per-example argument: element ``position`` or None.
        return None if values is None else values[position]

    # The loop variable must not be called ``image``: that shadowed the
    # ``image`` parameter, so the panels of the current example were indexed
    # instead of the caller-supplied extra images.
    return [
        draw_board(
            panels,
            None if target is None else target[i][0],
            _item(predict, i),
            _item(image, i),
            _item(desc, i),
            layout=layout,
        )
        for i, panels in enumerate(images)
    ]
|
293 |
+
|
294 |
+
|
295 |
+
def draw_raven(generator, predict=None, no=1, add_target_desc=True, indexes=None, types=TYPES,
               layout=1):
    """Draw boards for selected examples together with their rule descriptions.

    Args:
        generator: Object whose ``data[indexes]`` yields the example batch.
        predict: A model, an array of predictions/images/indexes, or ``None``.
        no: Number of examples sampled when ``indexes`` is not given.
        add_target_desc: Unused in this function.
        indexes: Example indexes; defaults to ``val_sample(no)``.
        types: Unused in this function.
        layout: Forwarded to ``draw_boards``.

    Returns:
        ``lu`` applied to the list of ``(board, description_lines)`` pairs.
    """
    # NOTE(review): ``is_model`` and ``render_panels`` are not defined or
    # imported in the visible part of this module — confirm where they come from.
    if indexes is None:
        indexes = val_sample(no)
    data = generator.data[indexes]
    if is_model(predict):
        model_inputs = filter_keys(data, PROPERTY,reverse=True)
        # tmp change
        pro = predict(model_inputs)['predict']
        print(pro)
        predict = render_panels(pro, target=False)
    target = data[TARGET]
    target_index = data["index"]
    images = data[INPUTS]

    if not hasattr(predict, "shape"):
        # No array prediction: show the target panels themselves.
        image = K.gather(images, target_index[:, 0])
        predict_index = None
        predict = None
    elif len(predict.shape) > 3:
        # Already images.
        image = predict
        # todo create index and output based on image
        predict = None
        predict_index = None
    elif len(predict.shape) == 3:
        image = render_panels(predict, target=False)
        # Create index based on predict.
        predict_index = None
    else:
        # Panel indexes: take the panels and describe the target instead.
        image = images[predict]
        predict_index = predict
        predict = target

    # One list of "attr - name" lines per example, with a blank line after
    # every rule group.
    all_rules = []
    for properties in data[PROPERTY]:
        rules = []
        for rule_group in properties.findAll("Rule_Group"):
            rules.extend(f"{rule['attr']} - {rule['name']}" for rule in rule_group.findAll("Rule"))
            rules.append("")
        all_rules.append(rules)

    target_desc = get_desc(target)
    if predict is None:
        predict_desc = [None] * len(target_desc)
    elif predict.shape[-1] == OUTPUT_SIZE:
        predict_desc = decode_output_result(predict)
    else:
        predict_desc = get_desc(predict)

    # Append the target figures (and, when available, the predicted ones).
    for lines, predicted, expected in zip(all_rules, predict_desc, target_desc):
        if predicted is None:
            predicted = [None] * len(expected)
        for p, t in zip(predicted, expected):
            lines.append(" ".join([str(i) for i in t]))
            if p is not None:
                lines.extend([" ".join([str(i) for i in p]), ""])

    boards = draw_boards(images, target=target_index, predict=predict_index, image=image, desc=None, no=no,
                         layout=layout)
    return lu(list(zip(boards, all_rules)))
|
373 |
+
|
374 |
+
|
375 |
+
def val_sample(no=GROUPS_NO, base=3):
    """Return ``no`` sample indexes spaced 2000 apart, starting at ``base``.

    The result is the NumPy array ``base, base + 2000, base + 4000, ...``.
    """
    # NOTE(review): the stride of 2000 presumably matches the number of
    # validation samples per group -- confirm against the dataset layout.
    return base + 2000 * np.arange(no)
|
378 |
+
|
379 |
+
|
380 |
+
def get_desc(target, exist=None, types=TYPES, sizes=SIZES):
    """Describe the existing figures of every decoded target.

    ``target`` is decoded with ``decode_target``. Element 1 of the decoded
    result is used as the existence mask unless ``exist`` is supplied, and
    element 2 holds the properties that are selected through that mask.

    Returns:
        A list with one entry per chunk; each entry is a list of
        ``[p[0], sizes[p[1]], types[p[2]]]`` rows, one per existing figure.
    """
    decoded = decode_target(target)
    if exist is None:
        exist = decoded[1]
    selected = np.stack(take(decoded[2], np.array(exist, dtype=bool))).T

    # Cut the flat run of existing figures back into per-row chunks. The last
    # split point is the total count, so the trailing chunk is dropped.
    boundaries = np.cumsum(np.sum(exist, axis=-1))
    chunks = np.split(selected, boundaries)[:-1]

    return [
        [[fig[0], sizes[fig[1]], types[fig[2]]] for fig in chunk]
        for chunk in chunks
    ]
|
397 |
+
|
398 |
+
|
399 |
+
# def get
|
400 |
+
|
401 |
+
|
402 |
+
def get_description(inputs, predict, pro, no, types=TYPES, sizes=SIZES):
    """Build readable property rows for predicted and target figures.

    The target array is read from ``inputs[TARGET]``. Columns
    ``1 .. ENTITY_SUM`` are used as the boolean existence mask and the next
    ``ENTITY_SUM * PROPERTIES_NO`` columns as the figure properties. ``pro``
    is reshaped the same way and filtered with the *target* mask.

    ``predict`` and ``no`` are accepted but not used.

    Returns:
        ``(pro_rows, target_rows)``: two lists with one entry per row of the
        target array, each a list of ``[p[0], sizes[p[1]], types[p[2]]]``.
    """
    target = inputs[TARGET]
    exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
    props_start = ENTITY_SUM + 1
    target_props = target[:, props_start:props_start + ENTITY_SUM * PROPERTIES_NO]

    def describe(flat):
        # (rows, slots, PROPERTIES_NO) -> keep existing slots -> split per row.
        grouped = np.reshape(flat, (flat.shape[0], -1, PROPERTIES_NO))
        boundaries = np.cumsum(np.sum(exist, axis=-1))
        chunks = np.split(grouped[exist], boundaries)[:-1]
        return [
            [[fig[0], sizes[fig[1]], types[fig[2]]] for fig in chunk]
            for chunk in chunks
        ]

    return describe(pro), describe(target_props)
|
432 |
+
|
433 |
+
|
434 |
+
def get_properties(target, types=TYPES, sizes=SIZES):
    """Return readable property rows for the existing figures in ``target``.

    Columns ``1 .. ENTITY_SUM`` of ``target`` are used as the boolean
    existence mask; the following ``ENTITY_SUM * PROPERTIES_NO`` columns are
    reshaped to one row of ``PROPERTIES_NO`` values per slot.

    Returns:
        One list per row of ``target``, each holding
        ``[t[0], sizes[t[1]], types[t[2]]]`` for every existing figure.
    """
    exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
    props_start = ENTITY_SUM + 1
    props = target[:, props_start:props_start + ENTITY_SUM * PROPERTIES_NO]
    grouped = np.reshape(props, (props.shape[0], -1, PROPERTIES_NO))

    # The last split point equals the total count, so the trailing chunk
    # produced by np.split is dropped.
    boundaries = np.cumsum(np.sum(exist, axis=-1))
    chunks = np.split(grouped[exist], boundaries)[:-1]

    return [
        [[fig[0], sizes[fig[1]], types[fig[2]]] for fig in chunk]
        for chunk in chunks
    ]
|
448 |
+
|
449 |
+
|
450 |
+
def desc_properties(target, decode_fn=None, types=TYPES, sizes=SIZES):
    """Decode ``target`` and turn the decoded property groups into rows.

    When ``decode_fn`` is not given it is chosen from the width of
    ``target``: ``decode_output_result`` if ``target.shape[1] == OUTPUT_SIZE``,
    otherwise ``decode_target``. Everything from index 2 onward of the decoded
    result is iterated.

    Returns:
        A list of lists of ``[t[0], sizes[t[1]], types[t[2]]]`` rows.
    """
    if decode_fn is None:
        is_model_output = target.shape[1] == OUTPUT_SIZE
        decode_fn = decode_output_result if is_model_output else decode_target

    return [
        [[fig[0], sizes[fig[1]], types[fig[2]]] for fig in group]
        for group in decode_fn(target)[2:]
    ]
|
465 |
+
|
466 |
+
|
467 |
+
def get_pro(t, types=TYPES, sizes=SIZES):
    """Return ``[int(t[0]), sizes[t[1]], types[t[2]]]`` for one figure row."""
    first = int(t[0])
    size_name = sizes[t[1]]
    type_name = types[t[2]]
    return [first, size_name, type_name]
|
469 |
+
|
470 |
+
|
471 |
+
def get_pro2(td, types=TYPES, sizes=SIZES):
    """Return ``[int(t[0]), sizes[t[1]], types[t[2]]]`` for every row in ``td``."""
    return [[int(row[0]), sizes[row[1]], types[row[2]]] for row in td]
|
476 |
+
|
477 |
+
|
478 |
+
def get_pro3(target_div, types=TYPES, sizes=SIZES):
    """Describe every row of every group in ``target_div``.

    ``target_div`` must provide ``to_list()``; each element of that list is
    iterated as a group of figure rows.

    Returns:
        One list per group of ``[int(t[0]), sizes[t[1]], types[t[2]]]`` rows.
    """
    return [
        [[int(row[0]), sizes[row[1]], types[row[2]]] for row in group]
        for group in target_div.to_list()
    ]
|
486 |
+
|
487 |
+
|
488 |
+
from models_utils import init_image as def_init_image, is_model
|
489 |
+
|
490 |
+
init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
|
raven_utils/draw.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from data_utils import take, EXIST, COR
|
3 |
+
from data_utils.image import draw_images, add_text
|
4 |
+
from funcy import identity
|
5 |
+
from ml_utils import none, filter_keys, lu
|
6 |
+
from models_utils import is_model
|
7 |
+
from models_utils import ops as K
|
8 |
+
|
9 |
+
from raven_utils.constant import PROPERTY, TARGET, INPUTS
|
10 |
+
from raven_utils.decode import decode_target, target_mask
|
11 |
+
from raven_utils.render.rendering import render_panels
|
12 |
+
from raven_utils.render_ import TYPES, SIZES
|
13 |
+
from raven_utils.uitls import get_val_index
|
14 |
+
|
15 |
+
|
16 |
+
def draw_board(images, target=None, predict=None, image=None, desc=None, layout=None, break_=20):
    """Draw one board: the context panels on top, the answer panels below.

    Args:
        images: panels of one problem. ``images[:8]`` form the top block and
            ``images[8:]`` the answer block.
        target: index into ``images`` of the correct answer, or None.
        predict: for ``layout == 1`` the index of the predicted answer; for
            any other layout it is concatenated to the answer panels.
        image: extra panel appended to the top block. ``"target"`` forces the
            target panel, ``None`` falls back to the predicted panel (when
            ``predict`` is given) or the target panel, ``False`` draws no
            extra panel, and an array is used as given unless ``predict``
            is set.
        desc: optional text rendered onto the finished board.
        layout: ``1`` selects the bordered answer grid with a blank spacer.
        break_: height of the blank spacer above the answers (layout 1 only).

    Returns:
        The combined board produced by ``draw_images`` (with ``desc`` added
        through ``add_text`` when given).
    """
    # Only compare real strings with "target": `ndarray != "target"` is an
    # elementwise comparison whose truth value is ambiguous.
    use_target = isinstance(image, str) and image == "target"
    if not use_target and predict is not None:
        image = images[predict:predict + 1]
    elif (image is None or use_target) and target is not None:
        # This branch used to test `images is None`, which could never work
        # because `images` is indexed right here.
        image = images[target:target + 1]
    elif use_target or image is False:
        # Nothing can (or should) be drawn as the extra panel.
        image = None

    top = images[:8]
    if image is not None:
        extra = image[None] if len(image.shape) == 3 else image
        top = np.concatenate([top, extra])
    boards = [draw_images(top)]

    if layout == 1:
        border = []
        if target is not None:
            border.append({COR: target - 8, EXIST: list(range(4)) if predict is None else (1, 3)})
        # NOTE(review): the target index is shifted by 8 but predicted ones
        # are not -- confirm which convention the caller uses.
        border.extend({COR: p, EXIST: (0, 2)} for p in none(predict))
        border_kwargs = {"border": border} if border else {}
        answers = draw_images(images[8:], column=4, **border_kwargs)
        if break_:
            spacer = np.zeros([break_, answers.shape[1], 1])
            answers = np.concatenate([spacer, answers], axis=0)
        boards.append(answers)
    else:
        panels = images[8:] if predict is None else np.concatenate([images[8:], predict])
        border_kwargs = {"border": target - 8} if target is not None else {}
        boards.append(draw_images(panels, column=4, **border_kwargs))

    full_board = draw_images(boards, grid=False)
    if desc:
        full_board = add_text(full_board, desc)
    return full_board
|
43 |
+
|
44 |
+
|
45 |
+
def draw_boards(images, target=None, predict=None, image=None, desc=None, layout=None):
    """Call ``draw_board`` once per problem in ``images``.

    ``target``, ``predict``, ``image`` and ``desc`` are optional per-problem
    sequences; a missing sequence is forwarded as ``None``. For ``target`` the
    first element of each entry (``target[i][0]``) is used.

    Returns:
        A list with one drawn board per problem.
    """

    def pick(seq, idx):
        return None if seq is None else seq[idx]

    boards = []
    for idx, panels in enumerate(images):
        answer = None if target is None else target[idx][0]
        boards.append(
            draw_board(panels, answer, pick(predict, idx), pick(image, idx),
                       pick(desc, idx), layout=layout)
        )
    return boards
|
53 |
+
|
54 |
+
|
55 |
+
def draw_from_generator(generator, predict=None, no=1, indexes=None, layout=1):
    """Sample validation data from ``generator`` and draw it with ``draw_raven``.

    The preprocessing function passed to ``draw_raven`` is taken from
    ``generator.data.data["inputs"].fn``.
    """
    sample = val_sample(generator, no, indexes)[0]
    preprocess = generator.data.data["inputs"].fn
    return draw_raven(sample, predict=predict, pre_fn=preprocess, layout=layout)
|
58 |
+
|
59 |
+
|
60 |
+
def val_sample(generator, no=1, indexes=None):
    """Fetch validation items from ``generator.data``.

    When ``indexes`` is None they are obtained from ``get_val_index``.

    Returns:
        ``(data, indexes)``: the selected items and the indexes used.
    """
    # NOTE(review): `no` is forwarded as `base`, not as a sample count --
    # confirm this matches the signature of get_val_index.
    chosen = get_val_index(base=no) if indexes is None else indexes
    return generator.data[chosen], chosen
|
65 |
+
|
66 |
+
def render_from_model(data, predict, pre_fn=identity):
    """Render panels from a model (or from its already computed output).

    ``data`` is stripped of the ``PROPERTY`` key. If ``predict`` is a model it
    is called on the remaining data; otherwise it is used directly as the
    output mapping. The ``"predict"`` entry is multiplied by the mask derived
    from ``"predict_mask"``, cast to int8, rendered with ``render_panels`` and
    passed through ``pre_fn`` with a leading axis that is removed again.
    """
    model_inputs = filter_keys(data, PROPERTY, reverse=True)
    output = predict(model_inputs) if is_model(predict) else predict
    mask = target_mask(output['predict_mask'].numpy())
    pro = np.array(mask * output["predict"].numpy(), dtype=np.int8)
    rendered = render_panels(pro, target=False)
    return pre_fn(rendered[None])[0]
|
72 |
+
|
73 |
+
def draw_raven(data, predict=None, pre_fn=identity, layout=1):
    """Draw boards for a batch and pair each with its textual description.

    ``predict`` may be:
      * a model -- it is run on ``data`` (without ``PROPERTY``) and its masked
        output is rendered to images;
      * an array with more than 3 dimensions -- used directly as images;
      * a 3-dimensional array -- rendered with ``render_panels``;
      * any other array -- used as answer indexes into the input images;
      * anything without ``shape`` (e.g. None) -- the target panels are shown.

    Returns:
        ``lu`` applied to the list of ``(board, description_lines)`` pairs.
    """
    if is_model(predict):
        model_inputs = filter_keys(data, PROPERTY, reverse=True)
        res = predict(model_inputs)
        pro = np.array(target_mask(res['mask'].numpy()) * res["predict"].numpy(), dtype=np.int8)
        predict = pre_fn(render_panels(pro, target=False)[None])[0]

    target = data[TARGET]
    target_index = data["index"]
    images = data[INPUTS]

    if not hasattr(predict, "shape"):
        image = K.gather(images, target_index[:, 0])
        predict_index = None
        predict = None
    elif len(predict.shape) > 3:
        # Already images. TODO: derive index and output from the image.
        image = predict
        predict = None
        predict_index = None
    elif len(predict.shape) == 3:
        image = render_panels(predict, target=False)
        predict_index = None
    else:
        image = images[predict]
        predict_index = predict
        predict = target

    boards = draw_boards(images, target=target_index, predict=predict_index, image=image, desc=None,
                         layout=layout)

    all_rules = extract_rules(data[PROPERTY])
    target_desc = get_desc(target)
    if predict is None:
        predict_desc = [None] * len(target_desc)
    else:
        predict_desc = get_desc(predict)

    for lines, predicted, expected in zip(all_rules, predict_desc, target_desc):
        if predicted is None:
            predicted = [None] * len(expected)
        for p, t in zip(predicted, expected):
            lines.append(" ".join([str(v) for v in t]))
            if p is not None:
                lines.append(" ".join([str(v) for v in p]))
                lines.append("")

    return lu(list(zip(boards, all_rules)))
|
143 |
+
|
144 |
+
|
145 |
+
def extract_rules(data):
    """Collect the rule descriptions of every item in ``data``.

    Each item must provide ``findAll`` (as BeautifulSoup tags do). For every
    ``Rule_Group`` of an item, one ``"<attr> - <name>"`` line is emitted per
    ``Rule`` it contains, followed by an empty string that separates groups.

    Args:
        data: iterable of parsed items.

    Returns:
        A list with one list of description lines per item.
    """
    all_rules = []
    for item in data:
        rules = []
        for rule_group in item.findAll("Rule_Group"):
            rules.extend(f"{rule['attr']} - {rule['name']}" for rule in rule_group.findAll("Rule"))
            # Blank line between rule groups.
            rules.append("")
        all_rules.append(rules)
    return all_rules
|
156 |
+
|
157 |
+
|
158 |
+
def get_desc(target, exist=None, types=TYPES, sizes=SIZES):
    """Describe the existing figures of every decoded target.

    ``target`` is decoded with ``decode_target``. Element 1 of the decoded
    result is used as the existence mask unless ``exist`` is supplied, and
    element 2 holds the properties that are selected through that mask.

    Returns:
        A list with one entry per chunk; each entry is a list of
        ``[p[0], sizes[p[1]], types[p[2]]]`` rows, one per existing figure.
    """
    decoded = decode_target(target)
    if exist is None:
        exist = decoded[1]
    selected = np.stack(take(decoded[2], np.array(exist, dtype=bool))).T

    # Cut the flat run of existing figures back into per-row chunks. The last
    # split point is the total count, so the trailing chunk is dropped.
    boundaries = np.cumsum(np.sum(exist, axis=-1))
    chunks = np.split(selected, boundaries)[:-1]

    return [
        [[fig[0], sizes[fig[1]], types[fig[2]]] for fig in chunk]
        for chunk in chunks
    ]
|
raven_utils/entity.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import raven_utils.group as group
|
2 |
+
import numpy as np
|
3 |
+
# Value associated with each group name, in the order of group.NAMES.
NO = dict(zip(group.NAMES, [1, 4, 9, 2, 5, 2, 2]))
# Total over all groups.
SUM = sum(NO.values())

# Cumulative offsets: INDEX[i] is the sum of the values of the first i groups.
INDEX = np.concatenate([[0], np.cumsum(list(NO.values()))])
|
raven_utils/group.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
# Names of the groups; the order matters because other modules zip
# per-group values against this list.
NAMES = ['center_single',
         'distribute_four',
         'distribute_nine',
         'in_center_single_out_center_single',
         'in_distribute_four_out_center_single',
         'left_center_single_right_center_single',
         'up_center_single_down_center_single']

# Number of groups.
NO = len(NAMES)
|
raven_utils/inference.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import raven_utils.properties as properties
|
3 |
+
import raven_utils.group as group
|
4 |
+
|
5 |
+
|
6 |
+
SLOT_TRANSFORMATION_NO = 4
PROPERTY_TRANSFORMATION_NO = 8
# Per-property count times the number of properties, then times the number of groups.
PROPERTIES_TRANSFORMATION_NO = PROPERTY_TRANSFORMATION_NO * properties.NO
PROPERTIES_TRANSFORMATION_SIZE = PROPERTIES_TRANSFORMATION_NO * group.NO

# NOTE(review): SLOT_TRANSFORMATION_NO is defined above but never used here;
# this size is built from PROPERTY_TRANSFORMATION_NO -- confirm that is intended.
SLOT_TRANSFORMATION_SIZE = PROPERTY_TRANSFORMATION_NO * group.NO
SIZE = SLOT_TRANSFORMATION_SIZE + PROPERTIES_TRANSFORMATION_SIZE

# Column slices: the first SLOT_TRANSFORMATION_SIZE columns and the last
# PROPERTIES_TRANSFORMATION_SIZE columns of a 2-D array.
SLOT_SLICE = np.s_[:, :SLOT_TRANSFORMATION_SIZE]
PROPERTIES_SLICE = np.s_[:, -PROPERTIES_TRANSFORMATION_SIZE:]
|
raven_utils/models/__init__.py
ADDED
File without changes
|
raven_utils/models/attn.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras.layers import LSTMCell
|
6 |
+
from tensorflow.keras.models import Model
|
7 |
+
from tensorflow.keras.layers import Conv2D, Dense
|
8 |
+
from tensorflow.keras.losses import mse
|
9 |
+
from tensorflow.keras.models import clone_model
|
10 |
+
from tensorflow.layers.base import InputSpec, Layer
|
11 |
+
|
12 |
+
from models.dense import create_conv_model
|
13 |
+
from models.utils import broadcast
|
14 |
+
|
15 |
+
|
16 |
+
class ReflectionPadding2D(Layer):
    """Pad axes 1 and 2 of a 4-D tensor using reflection.

    ``padding`` is ``(w_pad, h_pad)``: axis 1 is padded by ``h_pad`` and
    axis 2 by ``w_pad`` on both sides.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialise the base layer first so that attributes assigned below
        # cannot be reset by the base initializer.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Return the padded shape (assumes the "channels_last" layout)."""
        w_pad, h_pad = self.padding
        # Mirror `call`: axis 1 grows by h_pad and axis 2 by w_pad. The pads
        # used to be swapped here, which was wrong for asymmetric padding.
        return (s[0], s[1] + 2 * h_pad, s[2] + 2 * w_pad, s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
29 |
+
|
30 |
+
|
31 |
+
class Conv2Ref(Layer):
    """Reflection padding of axes 1 and 2 of a 4-D tensor.

    ``padding`` is ``(w_pad, h_pad)``: axis 1 is padded by ``h_pad`` and
    axis 2 by ``w_pad`` on both sides.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # `super` must name this class: Conv2Ref does not derive from
        # ReflectionPadding2D, so super(ReflectionPadding2D, self) raised
        # TypeError on construction.
        super(Conv2Ref, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Return the padded shape (assumes the "channels_last" layout)."""
        w_pad, h_pad = self.padding
        # Mirror `call`: axis 1 grows by h_pad and axis 2 by w_pad.
        return (s[0], s[1] + 2 * h_pad, s[2] + 2 * w_pad, s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
44 |
+
|
45 |
+
|
46 |
+
class SegmentationNetwork(Model):
    """Two-convolution residual block: ``conv(relu(conv(relu(x)))) + x``.

    The residual addition requires the convolutions to keep the spatial size
    and ``filters`` to equal the number of input channels.
    """

    def __init__(self, filters=64, kernels=(3, 3)):
        # `super` must name this class; super(RecAE, self) raised TypeError
        # because SegmentationNetwork is not a RecAE.
        super(SegmentationNetwork, self).__init__()
        # "same" padding written as a literal: no `SAME` constant is imported
        # in this module.
        self.conv_1 = Conv2D(filters, kernels, padding="same")
        self.conv_2 = Conv2D(filters, kernels, padding="same")

    def call(self, inputs):
        x = K.relu(inputs)
        x = self.conv_1(x)
        x = K.relu(x)
        x = self.conv_2(x)
        return x + inputs
|
59 |
+
|
60 |
+
|
61 |
+
class QueryNetwork(Model):
    """Two-layer dense residual block: ``dense(relu(dense(relu(x)))) + x``.

    The residual addition requires ``units`` to equal the input width.
    """

    def __init__(self, units=64):
        # `super` must name this class; super(RecAE, self) raised TypeError
        # because QueryNetwork is not a RecAE.
        super(QueryNetwork, self).__init__()
        # The attribute names say "conv" but both layers are Dense.
        self.conv_1 = Dense(units)
        self.conv_2 = Dense(units)

    def call(self, inputs):
        x = K.relu(inputs)
        x = self.conv_1(x)
        x = K.relu(x)
        x = self.conv_2(x)
        return x + inputs
|
74 |
+
|
75 |
+
|
76 |
+
class RecAE(Model):
    """Recurrent attention auto-encoder.

    In each of 4 steps a control LSTM cell produces a query, the query is
    scored against the segmentation features to get a sigmoid attention map,
    and the part of the remaining ``scope`` claimed by that map becomes the
    step's mask. The masked features are encoded with ``bottle``, decoded with
    ``decoder`` and accumulated into the reconstruction, weighted by the mask
    resized to the input resolution.
    """

    def __init__(self, head, bottle, decoder):
        super(RecAE, self).__init__()
        self.head = head
        self.bottle = bottle
        # Separate copy of the bottleneck used to feed the control cell.
        self.base = clone_model(bottle)
        self.decoder = decoder
        self.segmentation_network = SegmentationNetwork()
        self.query_network = QueryNetwork()
        self.control = LSTMCell(64)
        self.memory = LSTMCell(64)

    def call(self, inputs):
        """Return ``(full_image, masks)`` and register the reconstruction MSE as a loss."""
        feature = self.head(inputs)
        segmentation = self.segmentation_network(feature)
        control_base = self.base(feature)

        batch = K.shape(inputs)[0]
        # Both LSTM states start from random noise; the same tensor is used
        # for the two entries of each state.
        h_c = [tf.random.normal([batch, self.control.units])] * 2
        h_m = [tf.random.normal([batch, self.memory.units])] * 2

        # `scope` is the part of the feature map not yet claimed by a mask.
        scope = tf.ones(K.shape(feature)[:-1])[..., tf.newaxis]
        full_image = tf.zeros(K.shape(inputs))
        masks = []
        for _ in range(4):
            _, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
            query = self.query_network(h_c[0])
            attention = K.sigmoid(image_attention(segmentation, query))
            mask = attention * scope
            scope = scope - mask

            latent = self.bottle(feature * mask)
            decoded = self.decoder(latent)
            big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
            full_image += K.sigmoid(decoded) * big_mask

            _, h_m = self.memory(latent, h_m)
            masks.append(big_mask)
        self.add_loss(K.mean(mse(inputs, full_image)))
        return full_image, masks
|
122 |
+
|
123 |
+
|
124 |
+
# def image_attention(image, query, scale=True):
|
125 |
+
@tf.function
def image_attention(image, query):
    """Score ``query`` against every spatial position of ``image``.

    The query is broadcast over axes 1 and 2, multiplied with ``image`` and
    summed over the last axis (kept as a size-1 axis); the result is divided
    by the square root of the size of the last axis of ``image``.
    """
    channels = tf.cast(K.shape(image)[-1], dtype=float)
    scores = K.sum(query[:, tf.newaxis, tf.newaxis, :] * image, axis=-1, keepdims=True)
    return scores / tf.sqrt(channels)
|
131 |
+
|
132 |
+
|
133 |
+
class RecAE_2(Model):
    """Recurrent attention auto-encoder with a convolutional attention head.

    Works in 4 steps. In the first three a control LSTM cell state is
    broadcast over the feature map, concatenated with the features and passed
    to the segmentation network to get a sigmoid attention map; the last step
    takes whatever ``scope`` is left. Each mask is encouraged to cover a
    quarter of the map and to differ from the masks of the earlier steps.
    """

    def __init__(self, head, bottle, decoder):
        super(RecAE_2, self).__init__()
        self.head = head
        self.bottle = bottle
        # The bottleneck is shared (not cloned) with the control path.
        self.base = self.bottle
        self.decoder = decoder
        self.segmentation_network = create_conv_model((64, 64, 1))
        self.control = LSTMCell(64)
        self.memory = LSTMCell(64)

    def call(self, inputs):
        """Return ``(full_image, big_masks)`` and register the losses."""
        feature = self.head(inputs)
        control_base = self.base(feature)

        batch = K.shape(inputs)[0]
        h_c = [tf.random.normal([batch, self.control.units])] * 2
        h_m = [tf.random.normal([batch, self.memory.units])] * 2

        # `scope` is the part of the feature map not yet claimed by a mask.
        scope = tf.ones(K.shape(feature)[:-1])[..., tf.newaxis]
        full_image = tf.zeros(K.shape(inputs))
        big_masks = []
        masks = []
        for i in range(4):
            if i == 3:
                # Last step: take everything that is left.
                mask = scope
            else:
                _, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
                query = broadcast(h_c[0], feature.shape[1:])
                log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
                attention = K.sigmoid(log_attention)
                mask = attention * scope
                scope = scope - mask
                masks.append(mask)

            latent = self.bottle(feature * mask)
            decoded = self.decoder(latent)

            # Penalise masks whose total weight deviates from a quarter of
            # the map (named mask_size to avoid shadowing the builtin `sum`).
            mask_size = K.sum(tf.ones(K.shape(mask)))
            self.add_loss(K.abs((mask_size / 4) - K.sum(mask)) / mask_size)
            # Reward masks that differ from the ones collected so far.
            for m in masks:
                self.add_loss(K.mean(-mse(m, mask)))

            big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
            full_image += K.sigmoid(decoded) * big_mask
            _, h_m = self.memory(latent, h_m)
            big_masks.append(big_mask)
        self.add_loss(K.mean(mse(inputs, full_image)))
        return full_image, big_masks
|
raven_utils/models/attn2.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras.layers import LSTMCell
|
6 |
+
from tensorflow.keras.models import Model
|
7 |
+
from tensorflow.keras.layers import Conv2D, Dense
|
8 |
+
from tensorflow.keras.losses import mse
|
9 |
+
from tensorflow.keras.models import clone_model
|
10 |
+
from tensorflow.layers.base import InputSpec, Layer
|
11 |
+
|
12 |
+
from models.dense import create_conv_model
|
13 |
+
from models.utils import broadcast
|
14 |
+
|
15 |
+
|
16 |
+
class ReflectionPadding2D(Layer):
    """Pad axes 1 and 2 of a 4-D tensor using reflection.

    ``padding`` is ``(w_pad, h_pad)``: axis 1 is padded by ``h_pad`` and
    axis 2 by ``w_pad`` on both sides.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialise the base layer first so that attributes assigned below
        # cannot be reset by the base initializer.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Return the padded shape (assumes the "channels_last" layout)."""
        w_pad, h_pad = self.padding
        # Mirror `call`: axis 1 grows by h_pad and axis 2 by w_pad. The pads
        # used to be swapped here, which was wrong for asymmetric padding.
        return (s[0], s[1] + 2 * h_pad, s[2] + 2 * w_pad, s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
29 |
+
|
30 |
+
|
31 |
+
class Conv2Ref(Layer):
    """Reflection padding of axes 1 and 2 of a 4-D tensor.

    ``padding`` is ``(w_pad, h_pad)``: axis 1 is padded by ``h_pad`` and
    axis 2 by ``w_pad`` on both sides.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # `super` must name this class: Conv2Ref does not derive from
        # ReflectionPadding2D, so super(ReflectionPadding2D, self) raised
        # TypeError on construction.
        super(Conv2Ref, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Return the padded shape (assumes the "channels_last" layout)."""
        w_pad, h_pad = self.padding
        # Mirror `call`: axis 1 grows by h_pad and axis 2 by w_pad.
        return (s[0], s[1] + 2 * h_pad, s[2] + 2 * w_pad, s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
44 |
+
|
45 |
+
|
46 |
+
class SegmentationNetwork(Model):
    """Two-convolution residual block: ``conv(relu(conv(relu(x)))) + x``.

    The residual addition requires the convolutions to keep the spatial size
    and ``filters`` to equal the number of input channels.
    """

    def __init__(self, filters=64, kernels=(3, 3)):
        # `super` must name this class; super(RecAE, self) raised TypeError
        # because SegmentationNetwork is not a RecAE.
        super(SegmentationNetwork, self).__init__()
        # "same" padding keeps the spatial size; with the default "valid"
        # padding the output shrinks and `x + inputs` in `call` cannot be
        # computed for 3x3 kernels.
        self.conv_1 = Conv2D(filters, kernels, padding="same")
        self.conv_2 = Conv2D(filters, kernels, padding="same")

    def call(self, inputs):
        x = K.relu(inputs)
        x = self.conv_1(x)
        x = K.relu(x)
        x = self.conv_2(x)
        return x + inputs
|
59 |
+
|
60 |
+
|
61 |
+
class QueryNetwork(Model):
    """Two-layer dense residual block: ``dense(relu(dense(relu(x)))) + x``.

    The residual addition requires ``units`` to equal the input width.
    """

    def __init__(self, units=64):
        # `super` must name this class; super(RecAE, self) raised TypeError
        # because QueryNetwork is not a RecAE.
        super(QueryNetwork, self).__init__()
        # The attribute names say "conv" but both layers are Dense.
        self.conv_1 = Dense(units)
        self.conv_2 = Dense(units)

    def call(self, inputs):
        x = K.relu(inputs)
        x = self.conv_1(x)
        x = K.relu(x)
        x = self.conv_2(x)
        return x + inputs
|
74 |
+
|
75 |
+
|
76 |
+
class RecAE(Model):
    """Recurrent attention auto-encoder (softmax attention, 10 steps).

    In each step a control LSTM cell produces a query, the query is scored
    against the segmentation features, and the part of the remaining
    ``scope`` claimed by the resulting attention becomes the step's mask. The
    masked features are encoded with ``bottle``, decoded with ``decoder`` and
    accumulated into the reconstruction, weighted by the mask resized to the
    input resolution.
    """

    def __init__(self, head, bottle, decoder):
        super(RecAE, self).__init__()
        self.head = head
        self.bottle = bottle
        # Separate copy of the bottleneck used to feed the control cell.
        self.base = clone_model(bottle)
        self.decoder = decoder
        self.segmentation_network = SegmentationNetwork()
        self.query_network = QueryNetwork()
        self.control = LSTMCell(64)
        self.memory = LSTMCell(64)

    def call(self, inputs):
        """Return ``(full_image, masks)`` and register the reconstruction MSE as a loss."""
        feature = self.head(inputs)
        segmentation = self.segmentation_network(feature)
        control_base = self.base(feature)

        batch = K.shape(inputs)[0]
        # Both LSTM states start from random noise; the same tensor is used
        # for the two entries of each state.
        h_c = [tf.random.normal([batch, self.control.units])] * 2
        h_m = [tf.random.normal([batch, self.memory.units])] * 2

        # `scope` is the part of the feature map not yet claimed by a mask.
        scope = tf.ones(K.shape(feature)[:-1])[..., tf.newaxis]
        full_image = tf.zeros(K.shape(inputs))
        masks = []
        for _ in range(10):
            _, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
            query = self.query_network(h_c[0])
            log_attention = image_attention(segmentation, query)
            # NOTE(review): image_attention keeps a size-1 last axis, so this
            # softmax over the last axis yields all ones; a softmax over the
            # spatial positions was probably intended -- confirm before changing.
            attention = K.softmax(log_attention)
            mask = attention * scope
            scope = scope - mask

            latent = self.bottle(feature * mask)
            decoded = self.decoder(latent)
            big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
            full_image += K.sigmoid(decoded) * big_mask

            _, h_m = self.memory(latent, h_m)
            masks.append(big_mask)
        self.add_loss(K.mean(mse(inputs, full_image)))
        return full_image, masks
|
122 |
+
|
123 |
+
|
124 |
+
# def image_attention(image, query, scale=True):
|
125 |
+
@tf.function
def image_attention(image, query):
    """Score ``query`` against every spatial position of ``image``.

    The query is broadcast over axes 1 and 2, multiplied with ``image`` and
    summed over the last axis (kept as a size-1 axis); the result is divided
    by the square root of the size of the last axis of ``image``.
    """
    channels = tf.cast(K.shape(image)[-1], dtype=float)
    scores = K.sum(query[:, tf.newaxis, tf.newaxis, :] * image, axis=-1, keepdims=True)
    return scores / tf.sqrt(channels)
|
131 |
+
|
132 |
+
|
133 |
+
class RecAE_2(Model):
    """Recurrent attention auto-encoder with a convolutional attention head.

    Works in 4 steps. In the first three a control LSTM cell state is
    broadcast over the feature map, concatenated with the features and passed
    to the segmentation network to get a sigmoid attention map; the last step
    takes whatever ``scope`` is left. Each mask is encouraged to cover a
    quarter of the map and to differ from the masks of the earlier steps.
    """

    def __init__(self, head, bottle, decoder):
        super(RecAE_2, self).__init__()
        self.head = head
        self.bottle = bottle
        # The bottleneck is shared (not cloned) with the control path.
        self.base = self.bottle
        self.decoder = decoder
        self.segmentation_network = create_conv_model((64, 64, 1))
        self.control = LSTMCell(64)
        self.memory = LSTMCell(64)

    def call(self, inputs):
        """Return ``(full_image, big_masks)`` and register the losses."""
        feature = self.head(inputs)
        control_base = self.base(feature)

        batch = K.shape(inputs)[0]
        h_c = [tf.random.normal([batch, self.control.units])] * 2
        h_m = [tf.random.normal([batch, self.memory.units])] * 2

        # `scope` is the part of the feature map not yet claimed by a mask.
        scope = tf.ones(K.shape(feature)[:-1])[..., tf.newaxis]
        full_image = tf.zeros(K.shape(inputs))
        big_masks = []
        masks = []
        for i in range(4):
            if i == 3:
                # Last step: take everything that is left.
                mask = scope
            else:
                _, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
                query = broadcast(h_c[0], feature.shape[1:])
                log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
                attention = K.sigmoid(log_attention)
                mask = attention * scope
                scope = scope - mask
                masks.append(mask)

            latent = self.bottle(feature * mask)
            decoded = self.decoder(latent)

            # Penalise masks whose total weight deviates from a quarter of
            # the map (named mask_size to avoid shadowing the builtin `sum`).
            mask_size = K.sum(tf.ones(K.shape(mask)))
            self.add_loss(K.abs((mask_size / 4) - K.sum(mask)) / mask_size)
            # Reward masks that differ from the ones collected so far.
            for m in masks:
                self.add_loss(K.mean(-mse(m, mask)))

            big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
            full_image += K.sigmoid(decoded) * big_mask
            _, h_m = self.memory(latent, h_m)
            big_masks.append(big_mask)
        self.add_loss(K.mean(mse(inputs, full_image)))
        return full_image, big_masks
|
raven_utils/models/augment.py
ADDED
File without changes
|
raven_utils/models/body.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from ml_utils import self_product, lw
|
5 |
+
|
6 |
+
from models_utils import DictModel, ListModel, Flat, bm, Base, Cat, Res, Flat2, conv, KERNEL_SIZE, FILTERS, SAME, \
|
7 |
+
Get, SM, bs, RELU, ACTIVATION, dense, bd, HardBlock, MaxBlock
|
8 |
+
import models_utils.ops as K
|
9 |
+
from models_utils import Merge, SoftBlock
|
10 |
+
from models_utils.build import build_multi_dense, build_multi_conv, build_conv_model, build_encoder
|
11 |
+
from tensorflow.keras.layers import Lambda, Dense
|
12 |
+
from tensorflow.keras.layers import Conv2D
|
13 |
+
|
14 |
+
from config.constant import MEMORY, CONTROL, LATENT, MERGE, CONCAT, INFERENCE, FLATTEN
|
15 |
+
from models_utils.config import config
|
16 |
+
|
17 |
+
|
18 |
+
class RavRes(Res):
    """Residual wrapper whose shortcut is a slice of the first input.

    The shortcut keeps every entry of ``inputs[0]`` from index ``latent``
    onwards along the last axis and is added to the wrapped model's output.

    Args:
        model: forwarded unchanged to ``Res.__init__``.
        latent: start index of the last-axis slice used as the shortcut.
        act: accepted for signature parity with the sibling wrappers; it is
            not used by this class.
    """

    def __init__(self, model="v2", latent=256, act=RELU):
        super().__init__(model=model)
        self.latent = latent

    def call(self, inputs):
        """Return ``model(inputs)`` plus the sliced first input."""
        transformed = self.model(inputs)
        shortcut = inputs[0][:, ..., self.latent:]
        return transformed + shortcut
|
25 |
+
|
26 |
+
|
27 |
+
# not working
|
28 |
+
class RavResConv(Res):
    """Residual wrapper whose shortcut is a 1x1 convolution of the first input.

    Args:
        model: forwarded unchanged to ``Res.__init__``.
        latent: number of filters of the shortcut convolution.
        act: activation passed to the shortcut convolution.
    """

    def __init__(self, model="v2", latent=256, act=RELU):
        super().__init__(model=model)
        self.latent = latent
        # Shortcut projection built with the project's `conv` helper.
        self.conv = conv(latent, (1, 1), activation=act)

    def call(self, inputs):
        """Return ``model(inputs)`` plus the projected first input."""
        transformed = self.model(inputs)
        shortcut = self.conv(inputs[0])
        return transformed + shortcut
|
36 |
+
|
37 |
+
|
38 |
+
class RavResDense(Res):
    """Residual wrapper whose shortcut is a dense projection of the first input.

    Args:
        model: forwarded unchanged to ``Res.__init__``.
        latent: number of units of the shortcut projection.
        act: activation passed to the shortcut projection; defaults to the
            value of ``config.DEF_DENSE.activation`` at class-definition time.
    """

    def __init__(self, model="v2", latent=256, act=config.DEF_DENSE.activation):
        super().__init__(model=model)
        self.latent = latent
        # The attribute keeps the name `conv` although it holds a dense layer.
        self.conv = dense(latent, activation=act)

    def call(self, inputs):
        """Return ``model(inputs)`` plus the projected first input."""
        transformed = self.model(inputs)
        shortcut = self.conv(inputs[0])
        return transformed + shortcut
|
46 |
+
|
47 |
+
|
48 |
+
def create_dense_block(latent=256, loop=1):
    """Build the dense reasoning block as a ``ListModel`` of ``DictModel`` cells.

    Args:
        latent: size passed to the dense transformations and to ``Merge``.
        loop: how many times the list of cells is repeated. The repetition
            reuses the same cell objects rather than creating new ones.

    Returns:
        A ``ListModel`` taking ``[LATENT, INFERENCE]`` and producing ``MEMORY``.
    """
    # Built first so that layer creation order stays the same as before.
    transformation = Res(
        SoftBlock(build_multi_dense(latent), add_identity=None, score_activation=tf.sigmoid),
        latent=latent,
    )

    # Each tuple is unpacked positionally into DictModel.
    steps = [
        (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
        (None, CONCAT, MEMORY),
        (Dense(latent), CONCAT, MERGE),
        (Merge(latent), [INFERENCE, MERGE], CONTROL),
        (transformation, [MEMORY, CONTROL], MEMORY),
    ]
    stages = [DictModel(*step) for step in steps]

    return ListModel(stages * loop, [LATENT, INFERENCE], MEMORY)
|
60 |
+
|
61 |
+
|
62 |
+
def build_multi_conv(filters=32, end_filters=64, padding="same",mul=1, norm=None, **kwargs):
    """Build a list of convolutional encoders covering many kernel layouts.

    Note that this definition shadows the ``build_multi_conv`` imported from
    ``models_utils.build`` at the top of the file.

    Every encoder starts with a 1x1 convolution of ``filters`` channels and
    ends with a 1x1 convolution of ``end_filters`` channels without
    activation; the layers in between differ per encoder.

    Args:
        filters: filter count of the opening and of every intermediate layer.
        end_filters: filter count of the closing layer.
        padding: forwarded to ``build_encoder``.
        mul: number of times the whole list of architectures is repeated.
        norm: forwarded as ``add_norm`` when truthy; also switches ``order``
            to ``(1, 0)``.
        **kwargs: extra entries merged into every intermediate layer spec
            (they override ``FILTERS`` / ``KERNEL_SIZE`` on key clash).

    Returns:
        A list with one ``build_encoder`` result per architecture, named
        ``b0``, ``b1``, ...
    """
    kernel_pairs = list(self_product([(1, 3), (3, 1), (3, 3)]))
    pairs_plus_first = [pair + pair[0:1] for pair in kernel_pairs]
    pairs_doubled = [pair + pair for pair in kernel_pairs]
    square_stacks = ([[(3, 3)]] + [[(3, 3), (3, 3)]] + [[(3, 3), (3, 3), (3, 3)]]) * 2
    no_middle = [[], []]

    kernel_layouts = list(itertools.chain(
        kernel_pairs, pairs_plus_first, pairs_doubled, square_stacks, no_middle))

    # The opening / closing specs are shared objects across all architectures.
    first_layer = {FILTERS: filters, KERNEL_SIZE: (1, 1)}
    last_layer = {FILTERS: end_filters, KERNEL_SIZE: (1, 1), ACTIVATION: None}

    architectures = [
        [first_layer,
         *[{FILTERS: filters, KERNEL_SIZE: kernel, **kwargs} for kernel in layout],
         last_layer]
        for layout in kernel_layouts
    ]
    architectures = architectures * mul

    encoders = []
    for index, architecture in enumerate(architectures):
        encoders.append(
            build_encoder(architecture, add_norm=norm if norm else None, padding=padding,
                          name=f"b{index}", order=(1, 0) if norm else None))
    return encoders
|
95 |
+
|
96 |
+
|
97 |
+
def create_block(latent=256, simpler=0, loop=1, padding=SAME, norm=None, trans_div=2, act="pass", type_="conv",
                 block_=SoftBlock,max_k=16,
                 **kwargs):
    """Assemble the reasoning body as a ``ListModel`` of ``DictModel`` cells.

    NOTE(review): the scraped source lost its indentation. The nesting below
    is a reconstruction; in particular, which statements belong to the first
    ``if l:`` inside the loop should be confirmed against the repository.

    Args:
        latent: base channel/unit count.
        simpler: when equal to 1 an alternative opening cell list is used and
            the first loop iteration wires the transformation to ``CONCAT``;
            any truthy value also makes the conv reduction use ``trans_size``.
        loop: passed through ``lw`` and iterated; a truthy entry creates new
            control layers for that iteration (presumably ``lw`` wraps a
            scalar into a list - TODO confirm).
        padding: forwarded to the conv builders.
        norm: forwarded to ``build_multi_conv``.
        trans_div: ``latent`` is divided by this to obtain ``trans_size``.
        act: ``"pass"`` selects ``RavRes``; any other value selects
            ``RavResDense`` / ``RavResConv`` depending on ``type_``.
        type_: ``"dense"`` selects the dense builders, anything else the
            convolutional ones.
        block_: class wrapping the list of transformations.
        max_k: forwarded as ``max_k`` only when ``block_`` is ``MaxBlock``.
        **kwargs: forwarded to ``build_multi_conv`` in the conv case.

    Returns:
        A ``ListModel`` taking ``[LATENT, INFERENCE]`` and producing ``MEMORY``.
    """
    trans_size = int(latent / trans_div)
    # if block_ == HardBlock:
    #     mul = 2
    # elif block_ == MaxBlock:
    #     mul = int(38/max_k)
    # else:
    #     mul = 1

    # Pick the residual wrapper used around the transformation block.
    if act == "pass":
        res_class = RavRes
    else:
        if type_ == "dense":
            res_class = RavResDense
        else:
            res_class = RavResConv

    # Factories are lambdas so that every call creates a fresh layer.
    if type_ == "dense":
        build_res = lambda: Res(model="dv2")
        # build_reduction = lambda: bm([dense(latent), "IN"])
        build_reduction = lambda: dense(latent)
        build_flatten = lambda: bd([latent] * 2)
    else:
        build_res = lambda: Res(padding=padding)
        build_reduction = lambda: bm([conv(trans_size if simpler else latent, 1, padding=padding), "BN"])
        # build_reduction = lambda: bm([conv(latent, 1, padding=padding), "BN"])
        # build_reduction = lambda: bm([conv(trans_size, 1, padding=padding), "BN"])
        # build_reduction = lambda: conv(trans_size, 1, padding=padding)
        # build_flatten = lambda: Flat2(filters=trans_size,res_no=2, padding=padding, units=64)
        build_flatten = lambda: Flat2(filters=trans_size,padding=padding, units=64)

    # Opening cells; each tuple is unpacked positionally into DictModel.
    if simpler == 1:
        cells = [
            (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT,"concatenation"),
            # (None, CONCAT, MEMORY),
            (build_reduction(), CONCAT, MERGE,"Start_resnet_block"),
            # (Get(), INFERENCE, INFERENCE),
            (K.cat, [INFERENCE, MERGE], CONTROL,"concatenation"),
        ]
    else:
        cells = [
            (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
            (build_reduction(), CONCAT, MEMORY),
            (build_reduction(), INFERENCE, CONTROL),
        ]
    for i, l in enumerate(lw(loop)):
        # A falsy `l` leaves the control layers of the previous iteration in
        # place, so they are reused for this iteration.
        if l:
            concat = K.cat
            control_reduction = build_reduction()
            control_res = build_res()
            control_flatten = build_flatten()
        # Only the first iteration of the `simpler == 1` layout forwards
        # `latent` and `act` to the residual wrapper.
        if i == 0 and simpler == 1:
            rest_params = {
                "latent": latent,
                "act": act
            }
        else:
            rest_params = {
                "latent": 0
            }


        if block_ == SoftBlock:
            block_params = {
            }
        else:
            block_params = {
                "trans_output_shape": latent
            }
            if block_ == MaxBlock:
                block_params['max_k'] = max_k


        # todo change name
        soft_block = res_class(
            block_(
                build_multi_dense(latent) if type_ == "dense" else build_multi_conv(trans_size, end_filters=latent,
                                                                                    norm=norm, padding=padding,
                                                                                    **kwargs),
                add_identity=None,
                score_activation=tf.sigmoid,
                **block_params

            ),
            **rest_params)

        if i == 0 and simpler == 1:
            # First iteration of the simpler layout: transform CONCAT directly.
            cells.extend([
                (control_reduction, CONTROL, CONTROL,"Reduction"),
                (control_res, CONTROL, CONTROL,"Control_resnet_block"),
                (control_flatten, CONTROL, FLATTEN,"Weights"),
                (soft_block, [CONCAT, FLATTEN], MEMORY,"Transformation"),
                # (soft_block, [MEMORY, FLATTEN], MEMORY,"Transformation"),
            ])
        else:
            if l:
                memory_res = build_res()

            cells.extend([
                (memory_res, MEMORY, MEMORY,"Memory_resnet_block"),
                (concat, [CONTROL, MEMORY], CONTROL,"concatenation"),
                (control_reduction, CONTROL, CONTROL,"Reduction"),
                (control_res, CONTROL, CONTROL,"Control_resnet_block"),
                (control_flatten, CONTROL, FLATTEN,"Weights"),
                (soft_block, [MEMORY, FLATTEN], MEMORY, "Transformation"),
            ])
    return ListModel([DictModel(*cell) for cell in cells], [LATENT, INFERENCE], MEMORY, debug_=False)
|
206 |
+
|
207 |
+
#
|
208 |
+
#
|
209 |
+
# def test(x):
|
210 |
+
# np.zeros(4)
|
211 |
+
# self_product((1, 3))
|
212 |
+
#
|
213 |
+
#
|
214 |
+
# list(itertools.product())
|
215 |
+
# u.layers[0].layers[-1].model.layers[1]
|
216 |
+
|
217 |
+
# class RecurrentBodyDict(Model):
|
218 |
+
# # def __init__(self, start=None, cell=None, output_network=None, output_activation="tanh", latent=64, loop_no=5):
|
219 |
+
# def __init__(self, start=None, cell=None, output_network=None, output_activation=None, latent=64, loop_no=5):
|
220 |
+
# super().__init__()
|
221 |
+
# self.start = sm(start, lambda: SubClassingModel([StartLSTMControl(latent), StartLSTMMemory(latent)]),
|
222 |
+
# latent=latent)
|
223 |
+
# self.cell = sm(cell, lambda: SubClassingModel([LSTMControl(latent), LSTMMemory(latent)]), latent=latent)
|
224 |
+
# self.output_network = sm(output_network, lf(take_memory_states))
|
225 |
+
# self.loop_no = loop_no
|
226 |
+
# # tmp
|
227 |
+
# self.activation = Activation(output_activation)
|
228 |
+
#
|
229 |
+
# def call(self, inputs):
|
230 |
+
# outputs = []
|
231 |
+
# for j in range(3):
|
232 |
+
# outputs.append(self.start({"latent": inputs[0][j], "inference": inputs[1]}))
|
233 |
+
# for i in range(self.loop_no):
|
234 |
+
# for j in range(3):
|
235 |
+
# outputs[j] = self.cell(outputs[j])
|
236 |
+
#
|
237 |
+
# return self.activation(self.output_network(outputs))
|
238 |
+
#
|
239 |
+
#
|
240 |
+
# class RecurrentBodySimpleMix4Dict(RecurrentBodyDict):
|
241 |
+
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
242 |
+
# super().__init__(
|
243 |
+
# start=SubClassingModel(
|
244 |
+
# [ConcatCell(), DenseCell(latent), InfMergeCell(latent),
|
245 |
+
# WeigthCell(latent, layer_no=np.repeat([1, 2, 3, 4, 5, 6, 7, 8], 4),
|
246 |
+
# add_identity=Lambda(lambda x: x[:, latent:]))]),
|
247 |
+
# cell=False,
|
248 |
+
# output_network=output_network, loop_no=0)
|
249 |
+
# class RecurrentBodySimpleMix4Conv(RecurrentBodyDict):
|
250 |
+
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
251 |
+
# super().__init__(
|
252 |
+
# start=SubClassingModel(
|
253 |
+
# [ConcatCell(), ConvCell(latent), ReduceCell(latent), InfMergeCell(latent),
|
254 |
+
# ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
|
255 |
+
# WeigthCell(latent,
|
256 |
+
# transformation_network=[build_conv_model2([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
|
257 |
+
# range(1, 5) for _ in range(1)],
|
258 |
+
# add_identity=Lambda(lambda x: x[:, ..., latent:]))
|
259 |
+
# ]),
|
260 |
+
# cell=False,
|
261 |
+
# output_network=output_network, loop_no=0)
|
262 |
+
#
|
263 |
+
#
|
264 |
+
# class RecurrentBodySimpleMix4Conv2(RecurrentBodyDict):
|
265 |
+
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
266 |
+
# super().__init__(
|
267 |
+
# start=SubClassingModel(
|
268 |
+
# [ConcatCell(), ConvCell(latent), ReduceCell2(latent), InfMergeCell(latent),
|
269 |
+
# ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
|
270 |
+
# WeigthCell(latent,
|
271 |
+
# transformation_network=[bc([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
|
272 |
+
# range(1, 5) for _ in range(1)],
|
273 |
+
# add_identity=Lambda(lambda x: x[:, ..., latent:]))
|
274 |
+
# ]),
|
275 |
+
# cell=False,
|
276 |
+
# output_network=output_network, loop_no=0)
|
raven_utils/models/class_.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ml_utils import lw
|
2 |
+
from models_utils import SubClassingModel, ops as K, Base
|
3 |
+
import tensorflow as tf
|
4 |
+
|
5 |
+
|
6 |
+
class Merge(SubClassingModel):
    """Apply one sub-model per input, concatenate, then apply the last model."""

    def call(self, inputs):
        """Run ``self.model[i]`` on ``inputs[i]`` and merge the results.

        All sub-models except the last are branch models; their outputs are
        concatenated on the last axis and fed to ``self.model[-1]``.
        """
        branch_outputs = [
            branch(inputs[index]) for index, branch in enumerate(self.model[:-1])
        ]
        # todo why K.cat not working
        joined = tf.concat(branch_outputs, axis=-1)
        return self.model[-1](joined)
|
14 |
+
|
15 |
+
|
16 |
+
class RavenClass(Base):
    """Run the wrapped model once per image index on selected inputs.

    Args:
        model: forwarded to ``Base.__init__``; called once per image index.
        scales: indices into the input list that are fed to ``model``.
            ``None`` selects every input.
        no: number of image indices to iterate over (second axis).
        name: forwarded to ``Base.__init__``.
    """

    def __init__(self, model, scales=None, no=3, name=None):
        super().__init__(model=model, name=name)
        self.scales = scales
        self.no = no

    def call(self, inputs):
        """Return ``[[model(d_0), ..., model(d_{no-1})]]``.

        For every selected input with more than two dimensions the ``i``-th
        slice along axis 1 is taken; lower-rank inputs are passed unchanged.
        """
        inputs = lw(inputs)
        # Fix: the default `scales=None` used to fail when iterated; fall
        # back to all inputs. Assumes `lw` returns a sized sequence.
        scales = range(len(inputs)) if self.scales is None else self.scales
        class_res = []
        # for i in range(inputs[0].shape[1]):
        for i in range(self.no):
            # Fix: the pass-through branch used to hand over the whole
            # `inputs` list instead of the selected element `inputs[s]`.
            d = [inputs[s][:, i] if inputs[s].ndim > 2 else inputs[s] for s in scales]
            class_res.append(self.model(d))
        # return tf.stack(class_res,axis=1)
        return [class_res]
|
raven_utils/models/head.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from ml_utils import set_default
|
3 |
+
from models_utils import build_dense_model, bm, ActivationModel, sm, large_conv_dense_encoder, Pass
|
4 |
+
from models_utils import res
|
5 |
+
from tensorflow.keras import Model
|
6 |
+
from models_utils import ops as K
|
7 |
+
from tensorflow.keras.layers import Dense, Conv2D, Flatten
|
8 |
+
from keras.backend import batch_flatten
|
9 |
+
|
10 |
+
|
11 |
+
# todo Refactoring
|
12 |
+
class HeadModel(Model):
    """Container holding the encoder, inference network and stem of a head.

    Args:
        encoder: per-image encoder; when falsy a default encoder ending in
            ``Dense(output_size)`` is built.
        inference_network: network producing the inference vector; when falsy
            a flatten + dense stack is built.
        output_size: unit count of the default encoder's last layer.
        inference_output_size: last layer size of the default inference
            network; falls back to ``output_size`` when falsy.
        inference_activation: last activation of the default inference network.
        stem: model applied to the encoded images; defaults to ``Pass()``.
        images_no: number of images stored for use by subclasses.
        inference_image_no: number of images for inference; ``None`` means
            ``images_no``.
    """

    def __init__(self, encoder=None, inference_network=None, output_size=64, inference_output_size=None,
                 inference_activation="relu", stem=None, images_no=8, inference_image_no=None):
        super().__init__()

        if not encoder:
            encoder = bm([large_conv_dense_encoder(), Dense(output_size)])
        self.encoder = encoder

        if not inference_output_size:
            inference_output_size = output_size

        if not inference_network:
            dense_stack = build_dense_model([1028, 512, 512, inference_output_size],
                                            last_activation=inference_activation)
            inference_network = bm([K.flat, dense_stack])
        self.inference_network = inference_network

        self.stem = stem if stem else Pass()
        self.images_no = images_no
        if inference_image_no is None:
            inference_image_no = self.images_no
        self.inference_image_no = inference_image_no
|
28 |
+
|
29 |
+
|
30 |
+
class LatentHeadModel(HeadModel):
    """Head that encodes each image and derives latents and an inference vector."""

    def call(self, inputs):
        """Return ``[latents, inference, encoded]``.

        The first ``images_no`` entries along axis 1 are encoded through
        ``K.map_batch``; the inference network sees the first
        ``inference_image_no`` encodings and the stem sees all of them.
        """
        context = inputs[:, :self.images_no]
        encoded = K.map_batch(context, self.encoder)
        inference = self.inference_network(encoded[:, :self.inference_image_no])
        latents = self.stem(encoded)
        return [latents, inference, encoded]
|
36 |
+
|
37 |
+
|
38 |
+
# # todo use map_batch
|
39 |
+
# class HeadBatch(Model):
|
40 |
+
# def __init__(self, encoder=None, output_size=64):
|
41 |
+
# super().__init__()
|
42 |
+
# self.encoder = sm(encoder, bm([large_conv_dense_encoder(), Dense(output_size)], False))
|
43 |
+
#
|
44 |
+
# def call(self, inputs):
|
45 |
+
# shape = tf.shape(inputs)
|
46 |
+
# latents = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
|
47 |
+
# latents = K.reshape(latents, tf.concat([[-1, shape[1]], latents.shape[1:]], axis=-1))
|
48 |
+
# return latents
|
49 |
+
|
50 |
+
|
51 |
+
# Not working
|
52 |
+
class DuoHeadModel(HeadModel):
    """Head that reads two activations of the encoder instead of one output.

    Args:
        encoder: forwarded to ``HeadModel``; afterwards wrapped in
            ``ActivationModel``.
        inference_network: forwarded to ``HeadModel``.
        images_no: number of images per sample.
        filters: forwarded to ``ActivationModel`` as ``filters``.
    """

    def __init__(self, encoder=None, inference_network=None, images_no=8, filters=-4):
        super().__init__(encoder=encoder, inference_network=inference_network, images_no=images_no)
        self.encoder = ActivationModel(self.encoder, filters=filters, include_input=False)

    def call(self, inputs):
        """Return ``[latents, inference]`` from the reversed encoder outputs."""
        shape = inputs.shape
        # Merge the first two axes so the encoder sees one image per row.
        flat_images = K.reshape(inputs, shape=[-1] + list(shape[2:]))
        # Fix: `reversed()` yields an iterator, which cannot be indexed with
        # result[0] / result[1]; materialise it first.
        result = list(reversed(self.encoder(flat_images)))
        latents = K.reshape(result[0], [-1, self.images_no] + [result[0].shape[-1]])
        inference = self.inference_network(K.flat(result[1]))
        return [latents, inference]
|
63 |
+
|
64 |
+
|
65 |
+
class MultiHeadModel(Model):
    """Head exposing three encoder activations and a merged inference vector.

    Args:
        encoder: model wrapped in ``ActivationModel``.
        images_no: number of images per sample.
        filters: forwarded to ``ActivationModel`` as ``filters``.
    """

    def __init__(self, encoder=None, images_no=8, filters=(1, 3, 6)):
        super().__init__()
        self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
        self.merge = MergeSacles()
        self.images_no = images_no

    @staticmethod
    def _fold_images_into_channels(latent):
        """Move axis 1 next to the last axis and merge the two."""
        moved = tf.transpose(latent, (0, 2, 3, 1, 4))
        dims = tf.shape(moved)
        return tf.reshape(moved, tf.concat([[-1], dims[1:3], [dims[-2] * dims[-1]]], axis=-1))

    @staticmethod
    def _fold_last_two(latent):
        """Merge the last two axes into one and flatten the rest into rows."""
        dims = tf.shape(latent)
        return tf.reshape(latent, tf.concat([[-1], [dims[-2] * dims[-1]]], axis=-1))

    def call(self, inputs):
        """Return ``[latents, inference]`` for a batch of image groups."""
        in_dims = tf.shape(inputs)
        # Merge the first two axes so the encoder sees single images.
        single_images = tf.reshape(inputs, shape=tf.concat([[-1], in_dims[2:]], axis=-1))
        results = self.encoder(single_images)

        latents = []
        for result in results:
            grouped = tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)
            latents.append(tf.reshape(result, shape=grouped))

        l1 = self._fold_images_into_channels(latents[0])
        l2 = self._fold_images_into_channels(latents[1])
        l3 = self._fold_last_two(latents[2])

        inference = self.merge([l1, l2, l3])
        return [latents, inference]
|
95 |
+
|
96 |
+
|
97 |
+
class MergeSacles(Model):
    """Reduce three inputs to 256-unit vectors each and concatenate them.

    ``call`` expects a sequence of three tensors; the first two go through
    small convolutional stacks, the third through a single dense layer.
    """

    def __init__(self):
        super().__init__()
        # Fix: the name `SAME` is not imported in this module, so building
        # the layer raised NameError; use Keras' "same" padding value.
        self.inf_1 = bm([Conv2D(64, 1, activation="relu"), res(64),
                         Conv2D(64, 3, strides=2, padding="same", activation="relu"),
                         res(64),
                         Flatten(),
                         Dense(256, "relu")])
        self.inf_2 = bm([Conv2D(128, 1, activation="relu"),
                         res(128),
                         Flatten(),
                         Dense(256, "relu")])
        self.inf_3 = Dense(256, "relu")

    def call(self, inputs):
        """Return the three branch outputs concatenated along axis 1."""
        il1 = self.inf_1(inputs[0])
        il2 = self.inf_2(inputs[1])
        il3 = self.inf_3(inputs[2])
        inference = tf.concat([il1, il2, il3], axis=1)
        return inference
|
117 |
+
|
118 |
+
|
119 |
+
class MultiHeadModel2(Model):
    """Head exposing two encoder activations and a merged inference vector.

    Args:
        encoder: model wrapped in ``ActivationModel``.
        images_no: number of images per sample.
        filters: forwarded to ``ActivationModel`` as ``filters``.
    """

    def __init__(self, encoder=None, images_no=8, filters=(3, 6)):
        super().__init__()
        self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
        self.merge = MergeSacles2()
        self.images_no = images_no

    def call(self, inputs):
        """Return ``[latents, inference]`` for a batch of image groups."""
        in_dims = tf.shape(inputs)
        # Merge the first two axes so the encoder sees single images.
        single_images = tf.reshape(inputs, shape=tf.concat([[-1], in_dims[2:]], axis=-1))
        results = self.encoder(single_images)

        latents = []
        for result in results:
            grouped = tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)
            latents.append(tf.reshape(result, shape=grouped))

        # First activation: move axis 1 next to the last axis and merge them.
        spatial = tf.transpose(latents[0], (0, 2, 3, 1, 4))
        spatial_dims = tf.shape(spatial)
        spatial = tf.reshape(
            spatial, tf.concat([[-1], spatial_dims[1:3], [spatial_dims[-2] * spatial_dims[-1]]], axis=-1))

        # Second activation: merge the last two axes into one.
        vector = latents[1]
        vector_dims = tf.shape(vector)
        vector = tf.reshape(vector, tf.concat([[-1], [vector_dims[-2] * vector_dims[-1]]], axis=-1))

        inference = self.merge([spatial, vector])
        return [latents, inference]
|
144 |
+
|
145 |
+
|
146 |
+
class MergeSacles2(Model):
    """Reduce two inputs to 256-unit vectors each and concatenate them."""

    def __init__(self):
        super().__init__()
        conv_branch = [
            Conv2D(128, 1, activation="relu"),
            res(128),
            Flatten(),
            Dense(256, "relu"),
        ]
        self.inf_1 = bm(conv_branch)
        self.inf_2 = Dense(256, "relu")

    def call(self, inputs):
        """Return both branch outputs concatenated along axis 1."""
        first = self.inf_1(inputs[0])
        second = self.inf_2(inputs[1])
        return tf.concat([first, second], axis=1)
|
raven_utils/models/loss.py
ADDED
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
import tensorflow.experimental.numpy as tnp
|
5 |
+
from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
|
6 |
+
from models_utils import SubClassingModel
|
7 |
+
from models_utils.models.utils import interleave
|
8 |
+
from models_utils.op import reshape
|
9 |
+
from tensorflow.keras import Model
|
10 |
+
# from tensorflow.keras import backend as K
|
11 |
+
from tensorflow.keras.layers import Lambda
|
12 |
+
from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
|
13 |
+
from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
|
14 |
+
import models_utils.ops as K
|
15 |
+
|
16 |
+
import raven_utils.decode
|
17 |
+
import raven_utils as rv
|
18 |
+
from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
|
19 |
+
SLOT, \
|
20 |
+
PROPERTIES, ACC, GROUP, NUMBER, MASK
|
21 |
+
from raven_utils.models.uitls_ import RangeMask
|
22 |
+
from raven_utils.const import VERTICAL, HORIZONTAL
|
23 |
+
|
24 |
+
|
25 |
+
def get_properties_mask(target):
    """Return a boolean mask of the positive entries in the properties columns.

    The columns are those from ``rv.target.END_INDEX`` up to (excluding)
    ``rv.target.UNIFORMITY_INDEX`` on axis 1.
    """
    first = rv.target.END_INDEX
    last = rv.target.UNIFORMITY_INDEX
    properties = target[:, first:last]
    return properties > 0
|
27 |
+
|
28 |
+
|
29 |
+
def create_change_mask(target):
    """Return one mask per attribute built from the properties mask.

    ``create_mask`` is called once per entry of ``rv.rules.ATTRIBUTES`` with
    the entry's position as second argument.
    """
    properties_mask = get_properties_mask(target)
    masks = []
    for position, _attribute in enumerate(rv.rules.ATTRIBUTES):
        masks.append(create_mask(properties_mask, position))
    return masks
|
32 |
+
|
33 |
+
|
34 |
+
def create_uniform_mask(target):
    """Return one mask per attribute, widened by the two uniformity columns.

    A uniformity column equal to 3 switches on all ``ATTRIBUTES_LEN`` entries
    of its half; the result is OR-ed with the properties mask before
    ``create_mask`` is applied per attribute.
    """

    def column_is_three(offset):
        # Boolean column, repeated ATTRIBUTES_LEN times along axis 1.
        flag = target[:, rv.target.UNIFORMITY_INDEX + offset, None] == 3
        return tf.tile(flag, [1, rv.rules.ATTRIBUTES_LEN])

    uniform_part = tf.concat([column_is_three(0), column_is_three(1)], axis=-1)
    properties_mask = uniform_part | get_properties_mask(target)

    masks = []
    for position, _attribute in enumerate(rv.rules.ATTRIBUTES):
        masks.append(create_mask(properties_mask, position))
    return masks
|
38 |
+
|
39 |
+
|
40 |
+
def create_all_mask(target):
    """Return one all-True mask per attribute.

    Each mask has ``tf.shape(target)[0]`` rows and ``rv.entity.SUM`` columns.
    """
    masks = []
    for _position, _attribute in enumerate(rv.rules.ATTRIBUTES):
        mask_shape = tf.stack([tf.shape(target)[0], rv.entity.SUM])
        masks.append(tf.cast(tf.ones(mask_shape), dtype=tf.bool))
    return masks
|
44 |
+
|
45 |
+
|
46 |
+
class BaselineClassificationLossModel(Model):
    """Loss and similarity metric wrapper for the baseline classifier.

    Args:
        mode: mask factory forwarded to the loss and to the metric.
        number_loss: forwarded to ``ClassRavenModel``.
        slot_loss: forwarded to ``ClassRavenModel``.
        group_loss: forwarded to ``ClassRavenModel``.
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
        super().__init__()
        self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
        self.loss_fn = ClassRavenModel(
            mode=mode,
            number_loss=number_loss,
            slot_loss=slot_loss,
            group_loss=group_loss,
        )
        self.metric_fn = SimilarityRaven(mode=mode)

    def call(self, inputs):
        """Return ``[classification loss, similarity metric]``.

        ``inputs`` is indexed positionally: ``inputs[0]`` holds at least three
        items, ``inputs[1]`` is the model output and ``inputs[3][0]`` the
        prediction handed to the metric.
        """
        output = inputs[1]
        class_loss = self.loss_fn([inputs[0][0], output])
        similarity = self.metric_fn([inputs[0][2], inputs[3][0], inputs[0][1][:, 8:]])
        return [class_loss, similarity]
|
60 |
+
|
61 |
+
|
62 |
+
class RavenLoss(Model):
    """Collect the training losses and metrics for the Raven model.

    NOTE(review): the scraped source lost its indentation. Which statements
    sit under ``if self.trans:`` (here: ``loss_fn``, ``loss_fn_2`` and
    ``metric_fn`` in ``__init__``, and all four appends in ``call``) is a
    reconstruction - confirm against the repository.

    Args:
        mode: mask factory forwarded to the losses and the metric.
        number_loss, slot_loss, group_loss: forwarded to ``ClassRavenModel``.
        lw: pair of loss weights; ``lw[0]`` is used for the main loss and
            ``lw[1]`` for the classification loss.
        classification: enables the per-image classification loss.
        trans: enables the main losses and the similarity metric.
        anneal: when truthy, ``self.weight_scheduler`` is accessed.
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
                 classification=False, trans=True, anneal=False):
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute access with no effect; nothing in
            # this class defines `weight_scheduler`, so this presumably raises
            # AttributeError - annealing looks unimplemented.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        # Prediction is taken from the last element of OUTPUT.
        self.predict_fn = DictModel(SubClassingModel([lambda x: x[-1], PredictModel()]), in_=OUTPUT,
                                    out=[PREDICT, MASK], name="pred")
        if self.trans:
            # Here the weight lw[0] is given to ClassRavenModel itself.
            self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                                    group_loss=group_loss, enable_metrics=False, lw=lw[0]),
                                    name="main_loss")
            self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                                      group_loss=group_loss), name="add_loss")
            self.metric_fn = SimilarityRaven(mode=mode)
        if self.classification:
            # Metric names get the "c" prefix only when the main losses exist.
            self.loss_fn_3 = add_loss(
                ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss, enable_metrics="c" if self.trans else True), lw=lw[1],
                name="class_loss")

    def call(self, inputs):
        """Return ``inputs`` extended with a ``LOSS`` entry listing all losses.

        ``inputs`` is a mapping providing ``OUTPUT``, ``TARGET`` and
        ``LABELS``, plus ``INDEX`` / ``PREDICT`` when ``trans`` is set and
        ``CLASSIFICATION`` when ``classification`` is set.
        """
        losses = []
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]

        if self.trans:
            # output[0] / output[1] are compared with label columns 2 and 5,
            # output[2] with the target.
            losses.append(self.loss_fn([labels[:, 2], output[0]]))
            losses.append(self.loss_fn([labels[:, 5], output[1]]))
            losses.append(self.loss_fn_2([target, output[2]]))
            losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
        if self.classification:
            for i in range(8):
                losses.append(self.loss_fn_3([labels[:, i], inputs[CLASSIFICATION][i]]))
        return {**inputs, LOSS: losses}
|
100 |
+
|
101 |
+
|
102 |
+
class VTRavenLoss(Model):
    """Collect the training losses and metrics for the VT Raven model.

    Args:
        mode: mask factory forwarded to the main loss and the metric.
        number_loss, slot_loss, group_loss: forwarded to ``ClassRavenModel``.
        lw: pair of loss weights; ``lw[0]`` weights the main loss and
            ``lw[1]`` the classification loss.
        classification: enables the per-position classification loss.
        trans: stored on the instance; not read by this class.
        anneal: when truthy, ``self.weight_scheduler`` is accessed.
        plw: forwarded to ``ClassRavenModel``.
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
                 classification=False, trans=True, anneal=False, plw=None):
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute access with no effect; nothing in
            # this class defines `weight_scheduler`, so this presumably raises
            # AttributeError - annealing looks unimplemented.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        # Prediction is taken from the last position along axis 1 of OUTPUT.
        self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
                                    out=[PREDICT, MASK], name="pred")
        self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                                group_loss=group_loss, plw=plw), lw=lw[0], name="add_loss")
        self.metric_fn = SimilarityRaven(mode=mode)
        if self.classification:
            self.loss_fn_2 = add_loss(
                ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")

    def call(self, inputs):
        """Return ``inputs`` extended with a ``LOSS`` entry listing all losses.

        ``inputs`` is a mapping providing ``OUTPUT``, ``TARGET``, ``LABELS``,
        ``INDEX`` and ``PREDICT``.
        """
        losses = []
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]

        # Fix: `loss_fn_2` only exists when classification is enabled, so the
        # default configuration used to fail here with AttributeError.
        if self.classification:
            # NOTE(review): nine positions are compared although position 8
            # is also scored against TARGET below - confirm the range.
            for i in range(9):
                losses.append(self.loss_fn_2([labels[:, i], output[:, i]]))
        losses.append(self.loss_fn([target, output[:, 8]]))
        losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
        return {**inputs, LOSS: losses}
|
131 |
+
|
132 |
+
|
133 |
+
class SingleVTRavenLoss(Model):
    """Loss wrapper for a model emitting a single output per sample.

    Args:
        mode: mask factory forwarded to the loss and the metric.
        number_loss, slot_loss, group_loss: forwarded to ``ClassRavenModel``.
        lw: loss weights; only ``lw[0]`` is used.
        classification: stored on the instance; not read by this class.
        trans: stored on the instance; not read by this class.
        anneal: when truthy, ``self.weight_scheduler`` is accessed.
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
                 classification=False, trans=True, anneal=False):
        super().__init__()
        if anneal:
            # Bare attribute access kept exactly as in the original.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
        raven_loss = ClassRavenModel(
            mode=mode,
            number_loss=number_loss,
            slot_loss=slot_loss,
            group_loss=group_loss,
        )
        self.loss_fn = add_loss(raven_loss, lw=lw[0], name="add_loss")
        self.metric_fn = SimilarityRaven(mode=mode)

    def call(self, inputs):
        """Return ``inputs`` extended with a ``LOSS`` entry of two items.

        The first item is the loss between ``TARGET`` and ``OUTPUT``; the
        second is the similarity metric over ``INDEX``, ``PREDICT`` and
        ``LABELS``.
        """
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]

        main_loss = self.loss_fn([target, output])
        similarity = self.metric_fn([inputs[INDEX], inputs[PREDICT], labels])

        result = {**inputs}
        result[LOSS] = [main_loss, similarity]
        return result
|
155 |
+
|
156 |
+
|
157 |
+
class ClassRavenModel(Model):
    """Weighted classification losses and accuracy metrics for one panel prediction.

    ``call`` receives ``[target, output]``, splits both with the helpers in
    ``raven_utils.decode`` and returns a list of loss terms: the group term,
    the optional number term, the slot term and one term per property output
    (each only when the corresponding switch is on).
    """

    def __init__(self, mode=create_all_mask,plw=None, number_loss=False, slot_loss=True, group_loss=True, enable_metrics=True,
                 lw=1.0):
        """Configure which loss terms and metrics are produced.

        Args:
            mode: Callable applied to the target in ``call``; its result is
                indexed as ``[0]`` (number mask), ``[1]`` (slot mask) and
                ``[i + 2]`` (mask for the i-th property output).
            plw: Per-term weights indexed as ``[0]`` group, ``[1]`` number,
                ``[2]`` slot and ``[3 + i]`` properties. A falsy value selects
                the built-in defaults; a single int/float replaces only the
                number weight.
            number_loss: Falsy disables the number term. ``2`` compares the
                summed sigmoid outputs; any other truthy value compares the
                thresholded slot count. Only evaluated when ``slot_loss`` is on.
            slot_loss: Enables the number/slot section of ``call``.
            group_loss: Enables the group term.
            enable_metrics: Falsy disables metrics. A string becomes the
                metric-name prefix ``"<value>_"``; any other truthy value
                gives an empty prefix.
            lw: Multiplier applied to every loss term.
        """
        super().__init__()
        self.number_loss = number_loss
        self.group_loss = group_loss
        self.enable_metrics = enable_metrics
        self.slot_loss = slot_loss
        self.predict_fn = PredictModel()
        self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
        if self.slot_loss:
            self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
        if self.enable_metrics:
            # From here on enable_metrics is a str (the name prefix); ``call``
            # relies on isinstance(..., str) to decide whether to record metrics.
            self.enable_metrics = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
            self.metric_fn = [
                SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{property_}") for property_ in
                rv.properties.NAMES]
            if self.group_loss:
                self.metric_fn_group = SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{GROUP}")
            if self.slot_loss:
                self.metric_fn_2 = BinaryAccuracy(name=f"{self.enable_metrics}{ACC}_{SLOT}")
        self.range_mask = RangeMask()
        self.mode = mode
        self.lw = lw
        if not plw:
            plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
        elif isinstance(plw, int) or isinstance(plw, float):
            # A scalar only overrides the number-loss weight (index 1).
            plw = [1., plw, 2.83426987, 0.85212836, 1.096005, 1.21943385]
            # plw = [plw] * 6
        self.plw = plw

        # self.predict_fn = partial(tf.argmax, axis=-1)

    def call(self, inputs):
        """Return the list of weighted loss terms for ``inputs = [target, output]``.

        Metrics are recorded through ``add_metric`` as a side effect when
        metrics were enabled in ``__init__``.
        """
        losses = []
        metrics = {}  # NOTE(review): never read in this method.
        target = inputs[0]
        output = inputs[1]

        target_group, target_slot, target_all = raven_utils.decode.decode_target(target)

        group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)

        # group
        if self.group_loss:
            group_loss = self.lw * self.plw[0] * self.loss_fn(target_group, group_output)
            losses.append(group_loss)

            if isinstance(self.enable_metrics, str):
                group_metric = self.metric_fn_group(target_group, group_output)
                # metrics[GROUP] = group_metric
                self.add_metric(group_metric)
                # Also feeds the aggregated "<prefix>ACC" metric.
                self.add_metric(tf.reduce_sum(group_metric), f"{self.enable_metrics}{ACC}")

        # setting uniformity mask
        full_properties_musks = self.mode(target)

        range_mask = self.range_mask(target_group)

        if self.slot_loss:
            # number
            number_mask = range_mask & full_properties_musks[0]
            number_mask = tf.cast(number_mask, tf.float32)
            # Count of masked target slots per sample.
            target_number = tf.reduce_sum(
                tf.cast(target_slot, "float32") * number_mask, axis=-1)
            # Count of masked slots whose sigmoid output reaches 0.5.
            output_number = tf.reduce_sum(
                tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)

            # output_number = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
            if self.number_loss:
                scale = 1 / 9
                if self.number_loss == 2:
                    # Soft count: sum of sigmoid outputs instead of thresholded values.
                    output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
                else:
                    output_number_2 = output_number
                number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale, output_number_2 * scale)
                losses.append(number_loss)

            # metrics[NUMBER] = number_acc

            if isinstance(self.enable_metrics, str):
                # Fraction of samples whose thresholded count equals the target count.
                number_acc = tf.reduce_mean(
                    tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
                self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
                self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}")
                self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_NO_{GROUP}")

            # position/slot
            slot_mask = range_mask & full_properties_musks[1]
            # tf.boolean_mask(target_slot,slot_mask)

            if tf.reduce_any(slot_mask):
                # if tf.reduce_mean(tf.cast(slot_mask, dtype=tf.int32)) > 0:
                target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
                output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
                loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
                    self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
                if isinstance(self.enable_metrics, str):
                    acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
                    self.add_metric(acc_slot)
                    self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}")
                    self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
            else:
                # Nothing selected by the mask: contribute a zero loss term.
                loss_slot = 0.0
                # NOTE(review): acc_slot is not read afterwards; presumably kept so
                # both branches define the same names - confirm before removing.
                acc_slot = -1.0

            losses.append(loss_slot)
            # metrics[SLOT] = acc_slot
            # if loss_slot != 0:

            # if tf.reduce_any(slot_mask):

            # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_{NUMBER}")
            # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}")
            # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_NO_{GROUP}")

        # properties
        for i, out in enumerate(outputs):
            shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
            out_reshaped = tf.reshape(out, shape)
            # Only slots set in the target and selected by the i-th property mask count.
            properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]

            if tf.reduce_any(properties_mask):
                out_masked = tf.boolean_mask(out_reshaped, properties_mask)
                out_target = tf.boolean_mask(target_all[i], properties_mask)
                loss = self.lw * self.plw[3+i] * self.loss_fn(out_target, out_masked)
                if isinstance(self.enable_metrics, str):
                    metric = self.metric_fn[i](out_target, out_masked)
                    self.add_metric(metric)
                    # self.add_metric(metric, f"{self.enable_metrics}{ACC}")
                    self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}")
                    self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
                    self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
            else:
                loss = 0.0
                # NOTE(review): metric is not read afterwards; see acc_slot above.
                metric = -1.0

            losses.append(loss)
        return losses
|
296 |
+
|
297 |
+
|
298 |
+
class FullMask(Model):
    """Build the slot, per-property and number masks for a batch of targets."""

    def __init__(self, mode=create_uniform_mask):
        """Store the mask factory ``mode`` and create the range-mask helper."""
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    def call(self, inputs):
        """Return ``[slot_mask, property_masks, number_mask]`` for ``inputs``.

        ``property_masks`` is a list with one entry per mask produced by
        ``mode`` from index 2 onwards.
        """
        group, slots, _ = raven_utils.decode.decode_target(inputs)
        mode_masks = self.mode(inputs)
        in_range = self.range_mask(group)

        number = in_range & mode_masks[0]
        slot = in_range & mode_masks[1]
        # Property masks apply only where the target marks the slot as set.
        per_property = [tf.cast(slots, "bool") & single for single in mode_masks[2:]]
        return [slot, per_property, number]
|
316 |
+
|
317 |
+
|
318 |
+
def create_mask(rules, i):
    """Expand rule columns ``i`` and ``i + 5`` into a per-slot mask.

    Column ``i`` of ``rules`` is written to the rows listed in
    ``rv.target.FIRST_LAYOUT`` and column ``i + 5`` to the rows listed in
    ``rv.target.SECOND_LAYOUT`` of a ``(rv.entity.SUM, batch)`` tensor, which
    is returned transposed.
    """
    first_slots = rv.target.FIRST_LAYOUT
    second_slots = rv.target.SECOND_LAYOUT

    # One copy of the rule column per slot of the corresponding layout.
    first_rows = tf.tile(rules[:, i][None], [len(first_slots), 1])
    second_rows = tf.tile(rules[:, i + 5][None], [len(second_slots), 1])

    batch = tf.shape(rules)[0]
    scattered = tf.scatter_nd(tnp.array(first_slots)[:, None], first_rows, shape=(rv.entity.SUM, batch))
    scattered = tf.tensor_scatter_nd_update(scattered, tnp.array(second_slots)[:, None], second_rows)
    return tf.transpose(scattered)
|
325 |
+
|
326 |
+
|
327 |
+
# class PredictModel(Model):
|
328 |
+
# def __init__(self):
|
329 |
+
# super().__init__()
|
330 |
+
# self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
|
331 |
+
# self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
332 |
+
# self.range_mask = RangeMask()
|
333 |
+
#
|
334 |
+
# # self.predict_fn = partial(tf.argmax, axis=-1)
|
335 |
+
#
|
336 |
+
# def call(self, inputs):
|
337 |
+
# group_output = inputs[rv.OUTPUT_GROUP_SLICE]
|
338 |
+
# group_loss = self.predict_fn(group_output)[:, None]
|
339 |
+
#
|
340 |
+
# output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
|
341 |
+
# range_mask = self.range_mask(group_loss[:, 0])
|
342 |
+
# loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
|
343 |
+
#
|
344 |
+
# properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
|
345 |
+
# properties = []
|
346 |
+
# outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
|
347 |
+
# for i, out in enumerate(outputs):
|
348 |
+
# shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
349 |
+
# out_reshaped = tf.reshape(out, shape)
|
350 |
+
# properties.append(self.predict_fn(out_reshaped))
|
351 |
+
# number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
|
352 |
+
#
|
353 |
+
# result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
|
354 |
+
#
|
355 |
+
# return [result, range_mask, range_mask, range_mask, range_mask]
|
356 |
+
|
357 |
+
class PredictModel(Model):
    """Turn raw model output into a flat prediction row plus a range mask."""

    def __init__(self):
        """Create the prediction layers used by ``call``."""
        super().__init__()
        self.predict_fn = Predict()
        # Slot decision: sigmoid output strictly above 0.5.
        self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
        self.range_mask = RangeMask()

    def call(self, inputs):
        """Return ``[result, range_mask]``.

        ``result`` concatenates, along the last axis, the group prediction,
        the slot decisions (as int64), the interleaved property predictions
        and a trailing slot-count column.
        """
        group, slots, *per_property = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
        slot_count = K.int64(K.sum(slots))

        columns = [
            group[:, None],
            tf.cast(slots, dtype=tf.int64),
            interleave(per_property),
            slot_count[:, None],
        ]
        result = tf.concat(columns, axis=-1)

        return [result, self.range_mask(group)]
|
376 |
+
|
377 |
+
|
378 |
+
# todo change slices
|
379 |
+
class PredictModelMasked(Model):
    """Predict group, slots and properties with slot outputs range-masked.

    NOTE(review): the slicing below relies on ``rv.GROUPS_NO``,
    ``rv.ENTITY_SUM``, ``rv.ENTITY_PROPERTIES_INDEX`` and
    ``rv.ENTITY_PROPERTIES_VALUES``; the surrounding file carries a
    "todo change slices" note for this class.
    """

    def __init__(self):
        """Create the argmax and sigmoid-threshold layers used by ``call``."""
        super().__init__()
        self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
        # ``call`` reads ``predict_fn_2``. The layer used to be stored only as
        # ``loss_fn_2``, which made ``call`` fail with AttributeError.
        self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
        # Legacy attribute name, kept so existing references keep working.
        self.loss_fn_2 = self.predict_fn_2
        self.range_mask = RangeMask()

    def call(self, inputs):
        """Return one row per sample: group, slots, properties, slot count."""
        # The last rv.GROUPS_NO columns hold the group output.
        group_output = inputs[:, -rv.GROUPS_NO:]
        group_loss = self.predict_fn(group_output)[:, None]

        # The first rv.ENTITY_SUM columns hold the slot output.
        output_slot = inputs[:, :rv.ENTITY_SUM]
        range_mask = self.range_mask(group_loss[:, 0])
        loss_slot = tf.cast(self.predict_fn_2(output_slot * range_mask), dtype=tf.int64)

        properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]

        properties = []
        outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
        for i, out in enumerate(outputs):
            shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
            out_reshaped = tf.reshape(out, shape)
            # Zero the property outputs of slots predicted as unset.
            out_masked = out_reshaped * loss_slot[..., None]
            properties.append(self.predict_fn(out_masked))
        number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)

        result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)

        return result
|
411 |
+
|
412 |
+
|
413 |
+
def final_predict_mask(x, mask):
    """Select property triples from ``x[0]`` according to ``mask``.

    Takes the columns of ``x[0]`` from ``rv.INDEX[0]`` up to (excluding) the
    last one, reshapes them with ``[-1, 3]`` and returns the ragged
    boolean-masked result.
    """
    first_property = rv.INDEX[0]
    property_columns = x[0][:, first_property:-1]
    triples = reshape(property_columns, [-1, 3])
    return tf.ragged.boolean_mask(triples, mask)
|
416 |
+
|
417 |
+
|
418 |
+
def final_predict(x, mode=False):
    """Return the ragged property predictions of ``x`` for the selected slots.

    Args:
        x: Sequence whose first element is the prediction tensor and whose
            second element is a mask (used only when ``mode`` is truthy).
        mode: When truthy use ``x[1]`` as the mask; otherwise derive the mask
            from columns ``1:rv.INDEX[0]`` of ``x[0]`` cast to bool.
    """
    m = x[1] if mode else tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
    # final_predict_mask indexes ``x[0]`` itself, so hand it the whole
    # sequence. Passing ``x[0]`` made it index the prediction tensor a second
    # time, dropping the batch axis before its two-axis slice.
    return final_predict_mask(x, m)
|
421 |
+
|
422 |
+
|
423 |
+
def final_predict_2(x):
    """Ragged-mask ``x[0]`` with ``x[1]`` tiled four times along the last axis.

    The first and last columns of ``x[0]`` are always kept: a column of
    ``True`` is added on both sides of the tiled mask.
    """
    rows = tf.shape(x[0])[0]
    always_keep = tf.cast(tf.ones(rows), tf.bool)[:, None]
    tiled = tf.tile(x[1], [1, 4])
    keep = tf.concat([always_keep, tiled, always_keep], axis=-1)
    return tf.ragged.boolean_mask(x[0], keep)
|
427 |
+
|
428 |
+
|
429 |
+
class PredictModelOld(Model):
    """Legacy prediction model: argmax of every property head, interleaved."""

    def call(self, inputs):
        """Return the interleaved int8 property predictions.

        Reads the second-to-last element of ``inputs``, drops its last
        ``rv.GROUPS_NO`` columns and splits the remainder into one chunk per
        entry of ``rv.ENTITY_PROPERTIES_INDEX``.
        """
        # NOTE(review): the source index is -2; confirm against the caller.
        output = inputs[-2]

        rest_output = output[:, :-rv.GROUPS_NO]

        # The split, the reshape wildcard and the argmax all act on the last
        # axis. They were written as -3, which tf.reshape rejects (only -1 is
        # a valid wildcard) and which is out of range for the 2-D slice above.
        outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
        result_all = []
        for i, out in enumerate(outputs):
            shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
            out_reshaped = tf.reshape(out, shape)

            result = tf.cast(tf.argmax(out_reshaped, axis=-1), dtype="int8")
            result_all.append(result)

        return interleave(result_all)
|
447 |
+
|
448 |
+
|
449 |
+
def get_matches(diff, target_index):
    """Check whether the target index is among the lowest-scoring candidates.

    Returns:
        A tuple ``(matches, more_matches, once_matches)``: whether the target
        index is one of the candidates tied for the lowest summed ``diff``,
        whether more than one candidate shares that lowest value, and the sum
        of matches that are not ties.
    """
    scores = K.sum(diff)
    order = tf.argsort(scores, axis=-1)
    ordered_scores = tf.sort(scores)

    # Candidates sharing the lowest score keep their index; the rest become -1.
    is_best = ordered_scores[:, 0, None] == ordered_scores
    best_candidates = tf.where(is_best, order, -1 * tf.ones_like(order))

    hits = best_candidates == target_index
    # setting shape needed for TensorFlow graph
    hits.set_shape(best_candidates.shape)

    matches = K.any(hits)
    more_matches = K.sum(is_best) > 1
    once_matches = K.sum(matches & tf.math.logical_not(more_matches))
    return matches, more_matches, once_matches
|
462 |
+
|
463 |
+
|
464 |
+
class SimilarityRaven(Model):
    """Similarity metrics between a predicted panel and the answer panels.

    ``call`` expects ``[index, predict, labels]`` and records three metrics
    (ACC_SAME, ACC_CHOOSE_UPPER, ACC_CHOOSE_LOWER), which it also returns.
    """

    def __init__(self, mode=create_all_mask, number_loss=False):
        """Store the mask factory; ``number_loss`` is accepted but not used here."""
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    # INDEX, PREDICT, LABELS
    def call(self, inputs):
        """Return ``[acc_same, acc_choose_upper, acc_choose_lower]``."""
        # Answer panels start at position 8 of the label sequence.
        answer_index = inputs[0] - 8
        predicted = inputs[1]
        candidates = inputs[2][:, 8:]
        batch_size = tf.shape(predicted)[0]

        truth = K.gather(candidates, answer_index[:, 0])
        end = rv.target.END_INDEX

        attribute_masks = self.mode(truth)
        attribute_mask = K.cat([attribute_masks[0], interleave(attribute_masks[2:])])

        in_range = self.range_mask(truth[:, 0])
        in_range_full = K.cat([in_range, tf.repeat(in_range, 3, axis=-1)], axis=-1)

        keep = attribute_mask & in_range_full

        # Compare only columns 1:END_INDEX, with excluded positions zeroed.
        truth_kept = truth[:, 1:end] * keep
        predicted_kept = predicted[:, 1:end] * keep
        candidates_kept = candidates[:, :, 1:end] * tf.tile(keep[:, None], [1, 8, 1])

        acc_same = K.mean(K.all(truth_kept == predicted_kept))
        self.add_metric(acc_same, ACC_SAME)

        distance = tf.abs(predicted_kept[:, None] - candidates_kept)
        differs = tf.cast(distance != 0, dtype=tf.int32)

        # First pass: rank candidates by the number of differing positions.
        hit, tied, unique_hits = get_matches(differs, answer_index)

        # Second pass: re-rank tied-but-matching samples by absolute distance.
        tie_break = (tied & hit)
        tied_distance = tf.boolean_mask(distance, tie_break)
        tied_index = tf.boolean_mask(answer_index, tie_break, axis=0)
        hit_2, _, unique_hits_2 = get_matches(tied_distance, tied_index)

        acc_choose_upper = (unique_hits + K.sum(hit_2)) / batch_size
        self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)

        acc_choose_lower = (unique_hits + unique_hits_2) / batch_size
        self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)

        return [acc_same, acc_choose_upper, acc_choose_lower]
|
526 |
+
|
527 |
+
|
528 |
+
class SimilarityRaven2(Model):
    """Variant of the similarity metrics that also reports the mismatch count.

    ``call`` expects ``[index, predict, labels]``, records ACC_SAME,
    ACC_CHOOSE_UPPER and ACC_CHOOSE_LOWER, and returns those three values
    followed by the sum of positions where target and prediction differ.
    """

    def __init__(self, mode=create_all_mask, number_loss=False):
        """Store the mask factory; ``number_loss`` is accepted but not used here."""
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    # INDEX, PREDICT, LABELS
    def call(self, inputs):
        """Return the three accuracies plus the masked mismatch count."""
        results = []

        # Answer panels start at position 8 of the label sequence.
        answer_index = inputs[0] - 8
        predicted = inputs[1]
        candidates = inputs[2][:, 8:]
        predicted_shape = tf.shape(predicted)

        truth = K.gather(candidates, answer_index[:, 0])
        truth_group = truth[:, 0]

        body = slice(1, rv.target.END_INDEX)
        truth_body = truth[:, body]
        predicted_body = predicted[:, body]
        candidates_body = candidates[:, :, body]

        mode_masks = self.mode(truth)
        attribute_mask = K.cat([mode_masks[0], interleave(mode_masks[2:])])

        in_range = self.range_mask(truth_group)
        keep = attribute_mask & K.cat([in_range, tf.repeat(in_range, 3, axis=-1)], axis=-1)

        truth_kept = truth_body * keep
        predicted_kept = predicted_body * keep
        candidates_kept = candidates_body * tf.tile(keep[:, None], [1, 8, 1])

        acc_same = K.mean(K.all(truth_kept == predicted_kept))
        self.add_metric(acc_same, ACC_SAME)
        results.append(acc_same)

        distance = tf.abs(predicted_kept[:, None] - candidates_kept)

        # First pass ranks by number of differing positions, second pass
        # re-ranks the tied-but-matching samples by absolute distance.
        hit, tied, unique_hits = get_matches(tf.cast(distance != 0, dtype=tf.int32), answer_index)
        second_pass = (tied & hit)
        hit_2, _, unique_hits_2 = get_matches(
            tf.boolean_mask(distance, second_pass),
            tf.boolean_mask(answer_index, second_pass, axis=0))

        acc_choose_upper = (unique_hits + K.sum(hit_2)) / predicted_shape[0]
        self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
        results.append(acc_choose_upper)

        acc_choose_lower = (unique_hits + unique_hits_2) / predicted_shape[0]
        self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
        results.append(acc_choose_lower)

        results.append(K.sum(truth_kept != predicted_kept))

        return results
|
592 |
+
|
593 |
+
|
594 |
+
class LatentLossModel(Model):
    """MSE loss between selected latent vectors and the model output."""

    def __init__(self, dir_=HORIZONTAL):
        """Choose which latent positions are targets.

        ``VERTICAL`` selects positions ``(6, 7)``; any other value selects
        ``(2, 5)``.
        """
        super().__init__()
        self.metric_fn = Accuracy(name="acc_latent")
        self.dir = (6, 7) if dir_ == VERTICAL else (2, 5)

    def call(self, inputs):
        """Register and return the latent MSE loss.

        Also records the ``acc_latent`` metric between ``inputs[3]`` and the
        flattened target image index taken from ``inputs[0][2]``.
        """
        answer = tf.reshape(inputs[0][2], [-1])
        prediction = inputs[1]
        latents = tnp.asarray(inputs[2])

        context_latents = latents[:, self.dir]
        # Latent of the correct answer panel (answers start at position 8).
        answer_latent = latents[tf.range(latents.shape[0]), answer + 8][:, None]
        wanted = tf.concat([context_latents, answer_latent], axis=1)

        loss = mse(K.stop_gradient(wanted), prediction)
        self.add_loss(loss)

        self.add_metric(self.metric_fn(inputs[3], answer))

        return loss
|
623 |
+
|
624 |
+
|
625 |
+
class PredRav(Model):
    """Pick the answer panel closest (L1 distance) to the last output."""

    def call(self, inputs):
        """Return the index of the nearest answer for every sample.

        Uses the last entry along axis 1 of ``inputs[0]`` as the prediction
        and positions 8 onwards of ``inputs[1]`` as the candidate answers.
        """
        prediction = inputs[0][:, -1]
        candidates = inputs[1][:, 8:]
        distance = tf.reduce_sum(tf.abs(prediction[:, None] - candidates), axis=-1)
        return tf.argmin(distance, axis=-1)
|
raven_utils/models/loss_3.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
import tensorflow.experimental.numpy as tnp
|
5 |
+
from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
|
6 |
+
from models_utils import SubClassingModel
|
7 |
+
from models_utils.models.utils import interleave
|
8 |
+
from models_utils.op import reshape
|
9 |
+
from tensorflow.keras import Model
|
10 |
+
# from tensorflow.keras import backend as K
|
11 |
+
from tensorflow.keras.layers import Lambda
|
12 |
+
from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
|
13 |
+
from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
|
14 |
+
import models_utils.ops as K
|
15 |
+
|
16 |
+
import raven_utils.decode
|
17 |
+
import raven_utils as rv
|
18 |
+
from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
|
19 |
+
SLOT, \
|
20 |
+
PROPERTIES, ACC, GROUP, NUMBER, MASK
|
21 |
+
from raven_utils.models.uitls_ import RangeMask
|
22 |
+
from raven_utils.const import VERTICAL, HORIZONTAL
|
23 |
+
|
24 |
+
|
25 |
+
def get_properties_mask(target):
    """Return a boolean mask of the positive rule entries of ``target``.

    The inspected columns run from ``rv.target.END_INDEX`` up to (excluding)
    ``rv.target.UNIFORMITY_INDEX``.
    """
    rule_columns = target[:, rv.target.END_INDEX:rv.target.UNIFORMITY_INDEX]
    return rule_columns > 0
|
27 |
+
|
28 |
+
|
29 |
+
def create_change_mask(target):
    """Build one slot mask per entry of ``rv.rules.ATTRIBUTES``.

    Each mask is produced by ``create_mask`` from the properties mask of
    ``target``.
    """
    rule_mask = get_properties_mask(target)
    masks = []
    for position, _ in enumerate(rv.rules.ATTRIBUTES):
        masks.append(create_mask(rule_mask, position))
    return masks
|
32 |
+
|
33 |
+
|
34 |
+
def create_uniform_mask(target):
    """Build one slot mask per attribute, also enabling uniformity-flagged rules.

    A rule entry is enabled when the properties mask marks it, or when the
    matching uniformity column of ``target`` equals 3.
    """

    def uniformity_block(offset):
        # One column compared with 3, repeated for every attribute.
        flagged = target[:, rv.target.UNIFORMITY_INDEX + offset, None] == 3
        return tf.tile(flagged, [1, rv.rules.ATTRIBUTES_LEN])

    first_block = uniformity_block(0)
    second_block = uniformity_block(1)
    rule_mask = tf.concat([first_block, second_block], axis=-1) | get_properties_mask(target)

    masks = []
    for position, _ in enumerate(rv.rules.ATTRIBUTES):
        masks.append(create_mask(rule_mask, position))
    return masks
|
38 |
+
|
39 |
+
|
40 |
+
def create_all_mask(target):
    """Return an all-``True`` ``(batch, rv.entity.SUM)`` mask per attribute.

    The list has one entry for every element of ``rv.rules.ATTRIBUTES``; the
    batch size is read from ``target``.
    """
    masks = []
    for _ in rv.rules.ATTRIBUTES:
        mask_shape = tf.stack([tf.shape(target)[0], rv.entity.SUM])
        masks.append(tf.cast(tf.ones(mask_shape), dtype=tf.bool))
    return masks
|
44 |
+
|
45 |
+
|
46 |
+
class BaselineClassificationLossModel(Model):
    """Classification loss plus similarity metrics for the baseline model."""

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
        """Create the prediction, loss and metric sub-models."""
        super().__init__()
        self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
        self.loss_fn = ClassRavenModel(
            mode=mode,
            number_loss=number_loss,
            slot_loss=slot_loss,
            group_loss=group_loss,
        )
        self.metric_fn = SimilarityRaven(mode=mode)

    def call(self, inputs):
        """Return ``[loss_fn result, metric_fn result]``.

        The target is ``inputs[0][0]`` and the output is ``inputs[1]``; the
        metric receives ``inputs[0][2]``, ``inputs[3][0]`` and the columns of
        ``inputs[0][1]`` from position 8 onwards.
        """
        data = inputs[0]
        output = inputs[1]

        loss = self.loss_fn([data[0], output])
        similarity = self.metric_fn([data[2], inputs[3][0], data[1][:, 8:]])
        return [loss, similarity]
|
60 |
+
|
61 |
+
|
62 |
+
class RavenLoss(Model):
    """Combined loss for the transformer outputs and optional panel classification."""

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
                 classification=False, trans=True, anneal=False):
        """Create the sub-models selected by ``trans`` and ``classification``.

        ``lw[0]`` is passed to the "main_loss" classifier and ``lw[1]`` to the
        "class_loss" wrapper.
        """
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute read; nothing visible in this class
            # assigns ``weight_scheduler``.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans

        term_switches = dict(number_loss=number_loss, slot_loss=slot_loss, group_loss=group_loss)

        last_output_predictor = SubClassingModel([lambda x: x[-1], PredictModel()])
        self.predict_fn = DictModel(last_output_predictor, in_=OUTPUT, out=[PREDICT, MASK], name="pred")

        if self.trans:
            self.loss_fn = add_loss(
                ClassRavenModel(mode=mode, enable_metrics=False, lw=lw[0], **term_switches),
                name="main_loss")
            self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, **term_switches), name="add_loss")
            self.metric_fn = SimilarityRaven(mode=mode)

        if self.classification:
            # Metric names get the "c_" prefix only when the transformer losses are active too.
            metric_setting = "c" if self.trans else True
            self.loss_fn_3 = add_loss(
                ClassRavenModel(mode=create_all_mask, enable_metrics=metric_setting, **term_switches),
                lw=lw[1],
                name="class_loss")

    def call(self, inputs):
        """Return ``inputs`` extended with a LOSS entry listing every computed term."""
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]
        losses = []

        if self.trans:
            # Outputs 0 and 1 are scored against label panels 2 and 5.
            for panel, position in ((2, 0), (5, 1)):
                losses.append(self.loss_fn([labels[:, panel], output[position]]))
            losses.append(self.loss_fn_2([target, output[2]]))
            losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))

        if self.classification:
            # One classification term for each of the first eight panels.
            for panel in range(8):
                losses.append(self.loss_fn_3([labels[:, panel], inputs[CLASSIFICATION][panel]]))

        return {**inputs, LOSS: losses}
|
100 |
+
|
101 |
+
|
102 |
+
class VTRavenLoss(Model):
    """Loss for masked/unmasked positions plus similarity metrics."""

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(2.0, 1.0),
                 classification=False, trans=True, anneal=False, plw=None):
        """Create the sub-models.

        Args:
            mode: Mask factory forwarded to the main loss and the metric.
            number_loss, slot_loss, group_loss, plw: Forwarded to
                ``ClassRavenModel``.
            lw: ``lw[0]`` weights the masked-position loss ("add_loss"),
                ``lw[1]`` the unmasked-position loss ("class_loss").
            classification: When truthy, also create and use the
                unmasked-position loss.
            trans: Stored on the instance; not read in this class.
            anneal: See the note in the body.
        """
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute read; nothing visible in this class
            # assigns ``weight_scheduler``.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
                                    out=[PREDICT, "predict_mask"], name="pred")
        self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                                group_loss=group_loss, plw=plw), lw=lw[0], name="add_loss")
        self.metric_fn = SimilarityRaven(mode=mode)
        if self.classification:
            self.loss_fn_2 = add_loss(
                ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")

    def call(self, inputs):
        """Return ``inputs`` extended with a LOSS entry.

        The LOSS list holds the masked-position loss, the unmasked-position
        loss (only when classification is enabled) and the similarity metrics.
        """
        losses = []
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]
        mask = inputs[MASK]

        target_masked = target[mask]
        output_masked = output[mask]
        losses.append(self.loss_fn([target_masked, output_masked]))

        # ``loss_fn_2`` exists only when classification was requested; using it
        # unconditionally raised AttributeError with the default settings.
        if self.classification:
            target_unmasked = target[~mask]
            output_unmasked = output[~mask]
            losses.append(self.loss_fn_2([target_unmasked, output_unmasked]))

        losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
        return {**inputs, LOSS: losses}
|
137 |
+
|
138 |
+
|
139 |
+
class SingleVTRavenLoss(Model):
    """Single classification loss plus similarity metrics."""

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
                 classification=False, trans=True, anneal=False):
        """Create the prediction, loss and metric sub-models.

        Only ``lw[0]`` is used; ``classification`` and ``trans`` are stored on
        the instance.
        """
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute read; nothing visible in this class
            # assigns ``weight_scheduler``.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")

        classifier = ClassRavenModel(
            mode=mode,
            number_loss=number_loss,
            slot_loss=slot_loss,
            group_loss=group_loss,
        )
        self.loss_fn = add_loss(classifier, lw=lw[0], name="add_loss")
        self.metric_fn = SimilarityRaven(mode=mode)

    def call(self, inputs):
        """Return ``inputs`` extended with ``LOSS: [loss, similarity metrics]``."""
        prediction, truth, panel_labels = inputs[OUTPUT], inputs[TARGET], inputs[LABELS]

        main_loss = self.loss_fn([truth, prediction])
        similarity = self.metric_fn([inputs[INDEX], inputs[PREDICT], panel_labels])
        return {**inputs, LOSS: [main_loss, similarity]}
|
161 |
+
|
162 |
+
|
163 |
+
class ClassRavenModel(Model):
|
164 |
+
def __init__(self, mode=create_all_mask, plw=None, number_loss=False, slot_loss=True, group_loss=True,
             enable_metrics=True,
             lw=1.0):
    """Configure loss terms, metrics and per-term weights.

    A truthy ``enable_metrics`` is replaced by the metric-name prefix: a
    string value becomes ``"<value>_"``, anything else becomes ``""``. A
    falsy ``plw`` selects the default weights; a single int/float replaces
    only the weight at index 1.
    """
    super().__init__()
    self.number_loss = number_loss
    self.group_loss = group_loss
    self.enable_metrics = enable_metrics
    self.slot_loss = slot_loss
    self.predict_fn = PredictModel()
    self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
    if self.slot_loss:
        self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits

    if self.enable_metrics:
        prefix = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
        self.enable_metrics = prefix
        self.metric_fn = [
            SparseCategoricalAccuracy(name=f"{prefix}{ACC}_{property_name}")
            for property_name in rv.properties.NAMES
        ]
        if self.group_loss:
            self.metric_fn_group = SparseCategoricalAccuracy(name=f"{prefix}{ACC}_{GROUP}")
        if self.slot_loss:
            self.metric_fn_2 = BinaryAccuracy(name=f"{prefix}{ACC}_{SLOT}")

    self.range_mask = RangeMask()
    self.mode = mode
    self.lw = lw

    default_weights = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
    if not plw:
        plw = default_weights
    elif isinstance(plw, (int, float)):
        # A scalar overrides only the weight at index 1.
        default_weights[1] = plw
        plw = default_weights
    self.plw = plw
|
194 |
+
|
195 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
196 |
+
|
197 |
+
def call(self, inputs):
|
198 |
+
losses = []
|
199 |
+
metrics = {}
|
200 |
+
target = inputs[0]
|
201 |
+
output = inputs[1]
|
202 |
+
|
203 |
+
target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
|
204 |
+
|
205 |
+
group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)
|
206 |
+
|
207 |
+
# group
|
208 |
+
if self.group_loss:
|
209 |
+
group_loss = self.lw * self.plw[0] * self.loss_fn(target_group, group_output)
|
210 |
+
losses.append(group_loss)
|
211 |
+
|
212 |
+
if isinstance(self.enable_metrics, str):
|
213 |
+
group_metric = self.metric_fn_group(target_group, group_output)
|
214 |
+
# metrics[GROUP] = group_metric
|
215 |
+
self.add_metric(group_metric)
|
216 |
+
self.add_metric(tf.reduce_sum(group_metric), f"{self.enable_metrics}{ACC}")
|
217 |
+
|
218 |
+
# setting uniformity mask
|
219 |
+
full_properties_musks = self.mode(target)
|
220 |
+
|
221 |
+
range_mask = self.range_mask(target_group)
|
222 |
+
|
223 |
+
if self.slot_loss:
|
224 |
+
# number
|
225 |
+
number_mask = range_mask & full_properties_musks[0]
|
226 |
+
number_mask = tf.cast(number_mask, tf.float32)
|
227 |
+
target_number = tf.reduce_sum(
|
228 |
+
tf.cast(target_slot, "float32") * number_mask, axis=-1)
|
229 |
+
output_number = tf.reduce_sum(
|
230 |
+
tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)
|
231 |
+
|
232 |
+
# output_number = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
|
233 |
+
if self.number_loss:
|
234 |
+
scale = 1 / 9
|
235 |
+
if self.number_loss == 2:
|
236 |
+
output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
|
237 |
+
else:
|
238 |
+
output_number_2 = output_number
|
239 |
+
number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale,
|
240 |
+
output_number_2 * scale)
|
241 |
+
losses.append(number_loss)
|
242 |
+
|
243 |
+
# metrics[NUMBER] = number_acc
|
244 |
+
|
245 |
+
if isinstance(self.enable_metrics, str):
|
246 |
+
number_acc = tf.reduce_mean(
|
247 |
+
tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
|
248 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
|
249 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}")
|
250 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
251 |
+
|
252 |
+
# position/slot
|
253 |
+
slot_mask = range_mask & full_properties_musks[1]
|
254 |
+
# tf.boolean_mask(target_slot,slot_mask)
|
255 |
+
|
256 |
+
if tf.reduce_any(slot_mask):
|
257 |
+
# if tf.reduce_mean(tf.cast(slot_mask, dtype=tf.int32)) > 0:
|
258 |
+
target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
|
259 |
+
output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
|
260 |
+
loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
|
261 |
+
self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
|
262 |
+
if isinstance(self.enable_metrics, str):
|
263 |
+
acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
|
264 |
+
self.add_metric(acc_slot)
|
265 |
+
self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}")
|
266 |
+
self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
267 |
+
else:
|
268 |
+
loss_slot = 0.0
|
269 |
+
acc_slot = -1.0
|
270 |
+
|
271 |
+
losses.append(loss_slot)
|
272 |
+
# metrics[SLOT] = acc_slot
|
273 |
+
# if loss_slot != 0:
|
274 |
+
|
275 |
+
# if tf.reduce_any(slot_mask):
|
276 |
+
|
277 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_{NUMBER}")
|
278 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}")
|
279 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
280 |
+
|
281 |
+
# properties
|
282 |
+
for i, out in enumerate(outputs):
|
283 |
+
shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
|
284 |
+
out_reshaped = tf.reshape(out, shape)
|
285 |
+
properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]
|
286 |
+
|
287 |
+
if tf.reduce_any(properties_mask):
|
288 |
+
out_masked = tf.boolean_mask(out_reshaped, properties_mask)
|
289 |
+
out_target = tf.boolean_mask(target_all[i], properties_mask)
|
290 |
+
loss = self.lw * self.plw[3 + i] * self.loss_fn(out_target, out_masked)
|
291 |
+
if isinstance(self.enable_metrics, str):
|
292 |
+
metric = self.metric_fn[i](out_target, out_masked)
|
293 |
+
self.add_metric(metric)
|
294 |
+
# self.add_metric(metric, f"{self.enable_metrics}{ACC}")
|
295 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}")
|
296 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
|
297 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
298 |
+
else:
|
299 |
+
loss = 0.0
|
300 |
+
metric = -1.0
|
301 |
+
|
302 |
+
losses.append(loss)
|
303 |
+
return losses
|
304 |
+
|
305 |
+
|
306 |
+
class FullMask(Model):
    """Build the slot, per-property and number masks for an encoded target."""

    def __init__(self, mode=create_uniform_mask):
        """Store the mask factory ``mode`` and create the range-mask layer."""
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    def call(self, inputs):
        """Return ``[slot_mask, properties_mask, number_mask]``.

        ``properties_mask`` is a list with one mask per entry of
        ``self.mode(inputs)[2:]``, each AND-ed with the boolean target slots.
        """
        target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
        masks = self.mode(inputs)
        in_range = self.range_mask(target_group)

        number_mask = in_range & masks[0]
        slot_mask = in_range & masks[1]
        properties_mask = [tf.cast(target_slot, "bool") & mask for mask in masks[2:]]
        return [slot_mask, properties_mask, number_mask]
|
324 |
+
|
325 |
+
|
326 |
+
def create_mask(rules, i):
    """Expand rule columns ``i`` and ``i + 5`` into a per-entity mask.

    Column ``i`` of ``rules`` is written to the rows listed in
    ``rv.target.FIRST_LAYOUT`` and column ``i + 5`` to the rows listed in
    ``rv.target.SECOND_LAYOUT`` of a ``(rv.entity.SUM, batch)`` tensor, which is
    returned transposed.
    """
    first_rows = tf.tile(rules[:, i][None], [len(rv.target.FIRST_LAYOUT), 1])
    second_rows = tf.tile(rules[:, i + 5][None], [len(rv.target.SECOND_LAYOUT), 1])
    rules_shape = tf.shape(rules)

    first_indices = tnp.array(rv.target.FIRST_LAYOUT)[:, None]
    scattered = tf.scatter_nd(first_indices, first_rows, shape=(rv.entity.SUM, rules_shape[0]))

    second_indices = tnp.array(rv.target.SECOND_LAYOUT)[:, None]
    combined = tf.tensor_scatter_nd_update(scattered, second_indices, second_rows)
    return tf.transpose(combined)
|
333 |
+
|
334 |
+
|
335 |
+
# class PredictModel(Model):
|
336 |
+
# def __init__(self):
|
337 |
+
# super().__init__()
|
338 |
+
# self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
|
339 |
+
# self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
340 |
+
# self.range_mask = RangeMask()
|
341 |
+
#
|
342 |
+
# # self.predict_fn = partial(tf.argmax, axis=-1)
|
343 |
+
#
|
344 |
+
# def call(self, inputs):
|
345 |
+
# group_output = inputs[rv.OUTPUT_GROUP_SLICE]
|
346 |
+
# group_loss = self.predict_fn(group_output)[:, None]
|
347 |
+
#
|
348 |
+
# output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
|
349 |
+
# range_mask = self.range_mask(group_loss[:, 0])
|
350 |
+
# loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
|
351 |
+
#
|
352 |
+
# properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
|
353 |
+
# properties = []
|
354 |
+
# outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
|
355 |
+
# for i, out in enumerate(outputs):
|
356 |
+
# shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
357 |
+
# out_reshaped = tf.reshape(out, shape)
|
358 |
+
# properties.append(self.predict_fn(out_reshaped))
|
359 |
+
# number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
|
360 |
+
#
|
361 |
+
# result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
|
362 |
+
#
|
363 |
+
# return [result, range_mask, range_mask, range_mask, range_mask]
|
364 |
+
|
365 |
+
class PredictModel(Model):
    """Turn a raw network output into a flat prediction row plus a range mask."""

    def __init__(self):
        """Create the predictors and the range-mask layer."""
        super().__init__()
        self.predict_fn = Predict()
        # Slot predictor: a slot is on when its sigmoid exceeds 0.5.
        self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
        self.range_mask = RangeMask()

    def call(self, inputs):
        """Return ``[result, range_mask]``.

        ``result`` concatenates, along the last axis: the group prediction, the
        slot predictions cast to int64, the interleaved property predictions and
        the slot count.
        """
        group_output, output_slot, *properties = rv.decode.output(
            inputs, tf.split, self.predict_fn, self.predict_fn_2)
        number_loss = K.int64(K.sum(output_slot))

        parts = [
            group_output[:, None],
            tf.cast(output_slot, dtype=tf.int64),
            interleave(properties),
            number_loss[:, None],
        ]
        result = tf.concat(parts, axis=-1)

        range_mask = self.range_mask(group_output)
        return [result, range_mask]
|
383 |
+
# return [result, range_mask, range_mask, range_mask, range_mask]
|
384 |
+
|
385 |
+
|
386 |
+
# todo change slices
|
387 |
+
class PredictModelMasked(Model):
    """Predict group, slots, properties and slot count from a flat output.

    The output is sliced positionally: the first ``rv.ENTITY_SUM`` columns are
    slot logits, the last ``rv.GROUPS_NO`` columns are group logits, and the
    columns in between are property logits.
    """

    def __init__(self):
        """Create the argmax/threshold predictors and the range-mask layer."""
        super().__init__()
        self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
        # ``call`` reads ``predict_fn_2``. The layer used to be stored only as
        # ``loss_fn_2``, which made ``call`` fail with AttributeError.
        self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
        # Keep the old attribute name as an alias for existing users.
        self.loss_fn_2 = self.predict_fn_2
        self.range_mask = RangeMask()

    def call(self, inputs):
        """Return the concatenated prediction ``[group, slots, properties, number]``."""
        group_output = inputs[:, -rv.GROUPS_NO:]
        group_loss = self.predict_fn(group_output)[:, None]

        output_slot = inputs[:, :rv.ENTITY_SUM]
        range_mask = self.range_mask(group_loss[:, 0])
        # NOTE(review): multiplies slot logits by ``range_mask``; this assumes
        # the two dtypes are multiplicable. Confirm against RangeMask.
        loss_slot = tf.cast(self.predict_fn_2(output_slot * range_mask), dtype=tf.int64)

        properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]

        properties = []
        outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
        for i, out in enumerate(outputs):
            shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
            out_reshaped = tf.reshape(out, shape)
            # Zero the logits of entities whose slot is predicted off.
            out_masked = out_reshaped * loss_slot[..., None]
            properties.append(self.predict_fn(out_masked))
        number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)

        result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)

        return result
|
419 |
+
|
420 |
+
|
421 |
+
def final_predict_mask(x, mask):
    """Reshape the property columns of ``x[0]`` to width 3 and mask them.

    Columns ``rv.INDEX[0]`` up to (not including) the last one are taken from
    ``x[0]``, reshaped with ``[-1, 3]`` and filtered with
    ``tf.ragged.boolean_mask``.
    """
    property_columns = x[0][:, rv.INDEX[0]:-1]
    triples = reshape(property_columns, [-1, 3])
    return tf.ragged.boolean_mask(triples, mask)
|
424 |
+
|
425 |
+
|
426 |
+
def final_predict(x, mode=False):
    """Mask the predicted properties of ``x``.

    With ``mode`` truthy the mask is ``x[1]``; otherwise it is columns
    ``1:rv.INDEX[0]`` of ``x[0]`` cast to bool.
    """
    # NOTE(review): ``x[0]`` is passed on, and final_predict_mask indexes
    # ``[0]`` again. Confirm that the double indexing is intended.
    if mode:
        m = x[1]
    else:
        m = tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
    return final_predict_mask(x[0], m)
|
429 |
+
|
430 |
+
|
431 |
+
def final_predict_2(x):
    """Filter ``x[0]`` with ``x[1]`` tiled four times, keeping the end columns.

    The mask is ``[True, x[1] repeated 4 times along the last axis, True]``, so
    the first and last columns of ``x[0]`` are always kept.
    """
    batch = tf.shape(x[0])[0]
    always_on = tf.cast(tf.ones(batch), tf.bool)[:, None]
    tiled = tf.tile(x[1], [1, 4])
    mask = tf.concat([always_on, tiled, always_on], axis=-1)
    return tf.ragged.boolean_mask(x[0], mask)
|
435 |
+
|
436 |
+
|
437 |
+
class PredictModelOld(Model):
    """Legacy predictor: argmax every property head of ``inputs[-2]``."""

    def call(self, inputs):
        """Return the interleaved int8 argmax predictions of all property heads.

        The last ``rv.GROUPS_NO`` columns of the output are dropped and the rest
        is split by ``rv.ENTITY_PROPERTIES_INDEX``.
        """
        output = inputs[-2]

        rest_output = output[:, :-rv.GROUPS_NO]

        result_all = []
        # The split, reshape wildcard and argmax all use ``-1``. The previous
        # ``-3`` is not a valid ``tf.reshape`` dimension and is out of range as
        # an axis for the 2-D slice above.
        outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
        for i, out in enumerate(outputs):
            shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
            out_reshaped = tf.reshape(out, shape)

            result = tf.cast(tf.argmax(out_reshaped, axis=-1), dtype="int8")
            result_all.append(result)

        return interleave(result_all)
|
455 |
+
|
456 |
+
|
457 |
+
def get_matches(diff, target_index):
    """Check whether the target is among the candidates with the smallest score.

    ``diff`` is reduced with ``K.sum`` (presumably over the last axis; verify
    against ``models_utils.ops``) to one score per candidate.

    Returns:
        ``(matches, more_matches, once_matches)`` where ``matches`` marks rows
        whose target index is among the lowest-scoring candidates,
        ``more_matches`` marks rows where more than one candidate shares the
        lowest score, and ``once_matches`` is the ``K.sum`` of rows that match
        without such a tie.
    """
    scores = K.sum(diff)
    order = tf.argsort(scores, axis=-1)
    ranked = tf.sort(scores)

    # Candidates tied with the best (first) sorted score.
    tied_best = ranked[:, 0, None] == ranked
    best_ids = tf.where(tied_best, order, -1 * tf.ones_like(order))

    hits = best_ids == target_index
    # setting shape needed for TensorFlow graph
    hits.set_shape(best_ids.shape)

    matches = K.any(hits)
    more_matches = K.sum(tied_best) > 1
    once_matches = K.sum(matches & tf.math.logical_not(more_matches))
    return matches, more_matches, once_matches
|
470 |
+
|
471 |
+
|
472 |
+
class SimilarityRaven(Model):
    """Answer-selection metrics from comparing a prediction with the candidates.

    ``call`` expects ``[index, predict, labels]`` and both records and returns
    three metrics: exact-match accuracy and an upper and lower choice accuracy.
    """

    def __init__(self, mode=create_all_mask, number_loss=False):
        """Store the mask factory ``mode``; ``number_loss`` is accepted but unused."""
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    # INDEX, PREDICT, LABELS
    def call(self, inputs):
        """Return ``[acc_same, acc_choose_upper, acc_choose_lower]``."""
        # Entries from position 8 on are the answer candidates, so the target
        # index is shifted to be relative to them.
        target_index = inputs[0] - 8
        predict = inputs[1]
        answers = inputs[2][:, 8:]
        batch_shape = tf.shape(predict)

        target = K.gather(answers, target_index[:, 0])
        target_group = target[:, 0]

        # Compared columns: everything after the group up to END_INDEX.
        target_comp = target[:, 1:rv.target.END_INDEX]
        predict_comp = predict[:, 1:rv.target.END_INDEX]
        answers_comp = answers[:, :, 1:rv.target.END_INDEX]

        property_masks = self.mode(target)
        fpm = K.cat([property_masks[0], interleave(property_masks[2:])])

        range_mask = self.range_mask(target_group)
        full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)

        final_mask = fpm & full_range_mask

        target_masked = target_comp * final_mask
        predict_masked = predict_comp * final_mask
        answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])

        acc_same = K.mean(K.all(target_masked == predict_masked))
        self.add_metric(acc_same, ACC_SAME)

        diff = tf.abs(predict_masked[:, None] - answers_masked)
        differs = diff != 0

        # Phase 1: rank candidates by the number of differing fields.
        matches, more_matches, once_matches = get_matches(tf.cast(differs, dtype=tf.int32), target_index)

        # Phase 2: for rows that tie in phase 1, rank again by the absolute
        # difference.
        second_phase_mask = (more_matches & matches)
        diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
        target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)

        matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
        matches_2_no = K.sum(matches_2)

        acc_choose_upper = (once_matches + matches_2_no) / batch_shape[0]
        self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)

        acc_choose_lower = (once_matches + once_matches_2) / batch_shape[0]
        self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)

        return [acc_same, acc_choose_upper, acc_choose_lower]
|
534 |
+
|
535 |
+
|
536 |
+
class SimilarityRaven2(Model):
    """Answer-selection metrics plus a count of mismatching fields.

    ``call`` expects ``[index, predict, labels]`` and returns four values: exact
    match accuracy, upper and lower choice accuracy, and the ``K.sum`` of
    masked target/prediction mismatches. Only the first three are recorded via
    ``add_metric``.
    """

    def __init__(self, mode=create_all_mask, number_loss=False):
        """Store the mask factory ``mode``; ``number_loss`` is accepted but unused."""
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    # INDEX, PREDICT, LABELS
    def call(self, inputs):
        """Return ``[acc_same, acc_choose_upper, acc_choose_lower, mismatches]``."""
        # Entries from position 8 on are the answer candidates, so the target
        # index is shifted to be relative to them.
        target_index = inputs[0] - 8
        predict = inputs[1]
        answers = inputs[2][:, 8:]
        batch_shape = tf.shape(predict)

        target = K.gather(answers, target_index[:, 0])
        target_group = target[:, 0]

        # Compared columns: everything after the group up to END_INDEX.
        target_comp = target[:, 1:rv.target.END_INDEX]
        predict_comp = predict[:, 1:rv.target.END_INDEX]
        answers_comp = answers[:, :, 1:rv.target.END_INDEX]

        property_masks = self.mode(target)
        fpm = K.cat([property_masks[0], interleave(property_masks[2:])])

        range_mask = self.range_mask(target_group)
        full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)

        final_mask = fpm & full_range_mask

        target_masked = target_comp * final_mask
        predict_masked = predict_comp * final_mask
        answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])

        acc_same = K.mean(K.all(target_masked == predict_masked))
        self.add_metric(acc_same, ACC_SAME)

        diff = tf.abs(predict_masked[:, None] - answers_masked)
        differs = diff != 0

        # Phase 1: rank candidates by the number of differing fields.
        matches, more_matches, once_matches = get_matches(tf.cast(differs, dtype=tf.int32), target_index)

        # Phase 2: for rows that tie in phase 1, rank again by the absolute
        # difference.
        second_phase_mask = (more_matches & matches)
        diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
        target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)

        matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
        matches_2_no = K.sum(matches_2)

        acc_choose_upper = (once_matches + matches_2_no) / batch_shape[0]
        self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)

        acc_choose_lower = (once_matches + once_matches_2) / batch_shape[0]
        self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)

        mismatches = K.sum(target_masked != predict_masked)

        return [acc_same, acc_choose_upper, acc_choose_lower, mismatches]
|
600 |
+
|
601 |
+
|
602 |
+
class LatentLossModel(Model):
    """MSE loss between predicted latents and latents gathered from the panels."""

    def __init__(self, dir_=HORIZONTAL):
        """Pick the two context latent indices: (6, 7) for VERTICAL, else (2, 5)."""
        super().__init__()
        self.metric_fn = Accuracy(name="acc_latent")
        self.dir = (6, 7) if dir_ == VERTICAL else (2, 5)

    def call(self, inputs):
        """Add the latent MSE loss and the ``acc_latent`` metric; return the loss.

        Reads ``inputs[0][2]`` (flattened, used as the answer index),
        ``inputs[1]`` (prediction), ``inputs[2]`` (latents) and ``inputs[3]``
        (compared with the answer index for the accuracy metric).
        """
        target_image = tf.reshape(inputs[0][2], [-1])
        output = inputs[1]
        latents = tnp.asarray(inputs[2])

        context = latents[:, self.dir]
        # Latent of the correct answer; answers start at position 8.
        answer = latents[tf.range(latents.shape[0]), target_image + 8][:, None]
        target_hor = tf.concat([context, answer], axis=1)

        loss_hor = mse(K.stop_gradient(target_hor), output)
        self.add_loss(loss_hor)

        self.add_metric(self.metric_fn(inputs[3], target_image))

        return loss_hor
|
631 |
+
|
632 |
+
|
633 |
+
class PredRav(Model):
    """Pick the answer candidate closest (L1 distance) to the predicted entry."""

    def call(self, inputs):
        """Return the index of the candidate with the smallest summed absolute difference.

        The prediction is the last entry along axis 1 of ``inputs[0]``; the
        candidates are entries 8 onward along axis 1 of ``inputs[1]``.
        """
        predicted = inputs[0][:, -1]
        candidates = inputs[1][:, 8:]
        distance = tf.reduce_sum(tf.abs(predicted[:, None] - candidates), axis=-1)
        return tf.argmin(distance, axis=-1)
|
raven_utils/models/multi_transformer.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from functools import partial
|
3 |
+
from tensorflow.keras.layers import Lambda
|
4 |
+
from tensorflow.keras.layers import Dense
|
5 |
+
from tensorflow.keras import Input, Model
|
6 |
+
from tensorflow.python.keras import Sequential
|
7 |
+
|
8 |
+
from config.constant import TRANS
|
9 |
+
from ml_utils import filter_init
|
10 |
+
from models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
11 |
+
from models_utils import pmodel, DictModel, bt, INPUTS, bm, OUTPUT, LATENTS, transformer, BatchModel, get_extractor, \
|
12 |
+
build_seq_model, BUILD, build_train_list, InitialWeight
|
13 |
+
from models_utils import SumPositionEmbedding, TransformerBlock, CatPositionEmbedding, transformer, BatchInitialWeight
|
14 |
+
import models_utils.ops as K
|
15 |
+
from models_utils.image import inverse_fn
|
16 |
+
from models_utils.ops_core import IndexReshape
|
17 |
+
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
18 |
+
from models_utils.step import StepDict
|
19 |
+
|
20 |
+
|
21 |
+
def init_weights(shape, dtype=None):
    """Initializer that returns ``K.var.image(shape, pre=True)`` cast to float32.

    ``dtype`` is accepted for initializer-signature compatibility but is not
    used; the result is always float32.
    """
    image = K.var.image(shape=shape, pre=True)
    return tf.cast(image, dtype=tf.float32)
|
23 |
+
|
24 |
+
|
25 |
+
def conversion(x, max_=45):
    """Keep the first ``max_`` entries along axis 1 and reshape to ``(batch, 9, -1)``."""
    batch = tf.shape(x)[0]
    trimmed = x[:, :max_]
    return tf.reshape(trimmed, tf.stack([batch, 9, -1]))
|
28 |
+
|
29 |
+
|
30 |
+
def take_left(x):
    """Return index 7 of the last axis, keeping that axis with length 1."""
    start = 7
    return x[..., start:start + 1]
|
32 |
+
|
33 |
+
|
34 |
+
def take_by_index(x, i=8):
    """Return index ``i`` of the last axis, keeping that axis with length 1."""
    stop = i + 1
    return x[..., i:stop]
|
36 |
+
|
37 |
+
|
38 |
+
def mix(x):
    """Return the mean of indices 7 and 5 of the last axis (axis kept, length 1)."""
    panel_7 = x[..., 7:8]
    panel_5 = x[..., 5:6]
    return (panel_7 + panel_5) / 2
|
40 |
+
|
41 |
+
|
42 |
+
def empty_last(x):
    """Return zeros shaped like index 7 of the last axis (axis kept, length 1)."""
    template = x[..., 7:8]
    return tf.zeros_like(template)
|
44 |
+
|
45 |
+
|
46 |
+
class Conversion(Model):
    """Keep the first 45 entries along axis 1 and reshape them with ``IndexReshape``."""

    def __init__(self):
        """Create the ``IndexReshape((0, "9", None))`` layer."""
        super().__init__()
        self.model = IndexReshape((0, "9", None))

    def call(self, inputs):
        """Apply the reshape layer to ``inputs[:, :45]``."""
        trimmed = inputs[:, :45]
        return self.model(trimmed)
|
53 |
+
|
54 |
+
|
55 |
+
class RandomImageMask(Model):
    """Replace one randomly chosen channel per sample with ``last(inputs)``."""

    def __init__(self, last, last_index=9):
        """Store the fill function ``last`` and the number of channels ``last_index``."""
        super().__init__()
        self.get_last = last
        self.last_index = last_index

    def call(self, inputs):
        """Return the first ``last_index`` channels with one channel per sample replaced.

        The replaced channel is drawn uniformly from ``[0, last_index)`` for each
        sample. Its content is ``self.get_last(inputs)`` tiled along the last axis.
        """
        batch = tf.shape(inputs)[0:1]
        chosen = tf.random.uniform(shape=batch, maxval=self.last_index, dtype=tf.int32)
        # One-hot over the channel axis, with singleton axes inserted after
        # the batch axis so it broadcasts over the middle dimensions.
        mask = tf.one_hot(chosen, self.last_index)[:, None, None]

        kept = (1 - mask) * inputs[..., :self.last_index]
        replacement = mask * tf.tile(self.get_last(inputs), (1, 1, 1, self.last_index))
        return kept + replacement
|
68 |
+
|
69 |
+
|
70 |
+
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
71 |
+
# (1, 1, 1, self.last_index))
|
72 |
+
|
73 |
+
|
74 |
+
# from data_utils import ims
|
75 |
+
# for i in range(50):
|
76 |
+
# ims(res[i].numpy().swapaxes(0, 2))
|
77 |
+
# res[12].numpy()
|
78 |
+
# self.get_last(inputs).numpy()
|
79 |
+
# import tensorflow as tf
|
80 |
+
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
81 |
+
# from ml_utils import print_error
|
82 |
+
# ims(mask[0].numpy())
|
83 |
+
# print_error(lambda :ims(mask[0]))
|
84 |
+
# from models_utils import ops as K
|
85 |
+
|
86 |
+
|
87 |
+
class ImageMask(Model):
    """Keep the first ``index`` channels and append ``last(inputs)`` as the final one."""

    def __init__(self, last, index=8, last_index=9):
        """Store the fill function and channel indices.

        Args:
            last: callable producing the replacement channel from the inputs.
            index: number of leading channels to keep (default 8).
            last_index: stored for parity with ``RandomImageMask``; not used
                by ``call``.
        """
        super().__init__()
        self.get_last = last
        self.index = index
        self.last_index = last_index

    def call(self, inputs):
        """Concatenate ``inputs[..., :index]`` with ``self.get_last(inputs)``."""
        # Honour the configured ``index`` instead of a hard-coded 8. The default
        # of 8 keeps the previous behaviour.
        return tf.concat([inputs[..., :self.index], self.get_last(inputs)], axis=-1)
|
96 |
+
|
97 |
+
|
98 |
+
class CreateGrid(Model):
    """Resize and mask the input panels, arrange them, and run the wrapped model."""

    def __init__(self,
                 no=4,
                 extractor="ef",
                 type_=3,
                 base="seq",
                 last=take_left,
                 epsilon=None,
                 pooling=None,
                 mask_fn=None,
                 model=None,
                 **kwargs
                 ):
        """Configure the grid layout, masking and the wrapped model.

        Args:
            type_: ``9`` builds a single 3x3 image grid (panels resized to 75);
                any other value stacks three 3-channel groups (panels resized
                to 84).
            extractor, base: used to build a feature extractor in the non-grid
                branch (see the note in the body).
            last: callable producing the fill for the masked panel.
            epsilon: stored as-is.
            mask_fn: ``"random"`` selects ``RandomImageMask``, ``None`` selects
                ``ImageMask``; any other value is used directly.
            model: callable applied to the arranged panels in ``call``.
            no, pooling, **kwargs: accepted but unused here.
        """
        super().__init__()
        self.type_ = type_
        # NOTE(review): ``data``, ``conv`` and the rebuilt ``extractor`` are
        # computed but never used afterwards. They are kept in case the
        # construction has side effects. Confirm and remove if not needed.
        if type_ == 9:
            self.start_shape = 75
            data = (224, 224, 3)
            conv = lambda: Conversion()
        else:
            self.start_shape = 84
            data = (84, 84, 3)
            extractor = BUILD[base]([
                BatchModel(get_extractor(data=data, model=extractor)),
                lambda x: tf.transpose(x, (1, 0, 2, 3, 4))
                # lambda x: tf.tile(x[:, :224, :224], (1, 1, 1, 3))
            ])
            conv = lambda: conversion

        self.epsilon = epsilon
        if mask_fn == "random":
            mask_fn = RandomImageMask(last=last)
        elif mask_fn is None:
            mask_fn = ImageMask(last=last)

        self.mask_fn = mask_fn
        # ``call`` applies ``self.model``. It was never assigned before, so every
        # call failed with AttributeError.
        self.model = model

    def call(self, inputs):
        """Arrange the masked panels and return ``self.model`` applied to them.

        Raises:
            ValueError: if no ``model`` was supplied to the constructor.
        """
        if self.model is None:
            raise ValueError("CreateGrid requires a `model` to be passed to the constructor.")

        # Move axis 1 to the end, then resize to the configured side length.
        transposed = tf.image.resize(tf.transpose(inputs, (0, 2, 3, 1)), (self.start_shape, self.start_shape))
        re = self.mask_fn(transposed)

        if self.type_ == 9:
            # One 3x3 grid image, cropped to 224x224, with the last axis tiled
            # 3 times.
            x = tf.transpose(re, [0, 3, 1, 2])[..., None]
            x = K.create_image_grid(x, 3, 3)
            x = x[:, :224, :224]
            x = tf.tile(x, [1, 1, 1, 3])
        else:
            # Three groups of three channels, stacked on a new leading axis.
            x = tf.stack([
                re[..., :3],
                re[..., 3:6],
                re[..., 6:9],
            ])
        return self.model(x)
|
154 |
+
|
155 |
+
|
156 |
+
# self.model.layers[0](x)
|
157 |
+
|
158 |
+
|
159 |
+
def grid_transformer(
        *args,
        type_=9,
        no=4,
        extractor="ef",
        loss_mode=create_uniform_mask,
        output_size=10,
        loss_weight=1.0,
        out_layers=(1000, 1000, 1000),
        pos_emd="cat",
        base="seq",
        inverse_image=True,
        last="left",
        mask_fn=None,
        model=None,
        trans=None,
        **kwargs):
    """Build a ``CreateGrid`` front end and, when ``model`` is None, a transformer.

    ``last`` may be one of ``"left"``, ``"mix"``, ``"empty"`` or ``"start"``
    (mapped to the matching fill function) or any other value, which is passed
    through unchanged.

    NOTE(review): this function builds ``create_grid`` and ``trans`` but returns
    nothing. That is unchanged here because the intended return value cannot be
    determined from this file. ``args``, ``loss_mode``, ``loss_weight`` and
    ``inverse_image`` are accepted but unused.
    """
    if last == "left":
        last = take_left
    elif last == "mix":
        last = mix
    elif last == "empty":
        last = empty_last
    elif last == "start":
        last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])

    create_grid = CreateGrid(
        type_=type_,
        no=no,
        extractor=extractor,
        model=model,
        output_size=output_size,
        out_layer=out_layers,
        pos_emd=pos_emd,
        base=base,
        last=last,
        mask_fn=mask_fn,
        **kwargs
    )

    if model is None:
        # ``data`` and ``conv`` were referenced here without being defined,
        # which raised NameError. They are derived the same way
        # ``CreateGrid.__init__`` derives them from ``type_``.
        # NOTE(review): confirm this is the intended mapping.
        if type_ == 9:
            data = (224, 224, 3)
            conv = lambda: Conversion()
        else:
            data = (84, 84, 3)
            conv = lambda: conversion

        trans = transformer(
            extractor=extractor,
            pos_emd=pos_emd,
            data=data,
            output_size=output_size,
            # was ``out_layer``, an undefined name; the parameter is
            # ``out_layers``
            out_layers=out_layers,
            pooling=conv,
            no=no,
            base=base,
            **kwargs
        )
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
def get_rav_trans(
        *args,
        type_=9,
        no=4,
        extractor="ef",
        loss_mode=create_uniform_mask,
        output_size=10,
        loss_weight=1.0,
        out_layers=(1000, 1000, 1000),
        pos_emd="cat",
        base="seq",
        inverse_image=True,
        last="left",
        epsilon="greedy",
        epsilon_step=500,
        mask_fn=None,
        model=None,
        loss="multi",
        **kwargs):
    """Assemble the trainable model: a ``CreateGrid`` body plus a VT-Raven loss.

    ``last`` may be ``"left"``, ``"mix"``, ``"empty"`` or ``"start"`` (mapped to
    the matching fill function); any other value is passed through unchanged.
    ``loss == "single"`` selects ``SingleVTRavenLoss``, anything else
    ``VTRavenLoss``. With ``inverse_image`` truthy the input is inverted
    (``255 - x``) before the body. ``args`` and ``epsilon_step`` are accepted
    but unused.
    """
    named_fills = {"left": take_left, "mix": mix, "empty": empty_last}
    if isinstance(last, str):
        if last in named_fills:
            last = named_fills[last]
        elif last == "start":
            last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])

    trans_raven = CreateGrid(
        type_=type_,
        no=no,
        extractor=extractor,
        model=model,
        output_size=output_size,
        out_layer=out_layers,
        pos_emd=pos_emd,
        base=base,
        last=last,
        epsilon=epsilon,
        mask_fn=mask_fn,
        **kwargs
    )

    loss_cls = SingleVTRavenLoss if loss == "single" else VTRavenLoss

    if inverse_image:
        body = Sequential([Lambda(lambda x: 255 - x), trans_raven])
    else:
        body = trans_raven

    return bt(
        DictModel(body, in_=INPUTS, name="Body"),
        loss=loss_cls(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
        loss_wrap=False
    )
|
raven_utils/models/raven.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ml_utils import lw, lu
|
2 |
+
from models_utils import bm, Base, res, bt, DictModel, dense_drop, drop, build_encoder, MODEL_ARCH, ListModel, short, \
|
3 |
+
dense, Flatten, Cat, CatDenseBefore, \
|
4 |
+
CatDense, CatBefore, Drop, Flat2, down, Pass, conv, Flat, Get, bs, Res, SoftBlock
|
5 |
+
from models_utils import SubClassingModel
|
6 |
+
from models_utils.config.constants import *
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from config.constant import *
|
10 |
+
from tensorflow.keras.layers import Dense, Activation, BatchNormalization
|
11 |
+
import tensorflow as tf
|
12 |
+
|
13 |
+
import raven_utils as rv
|
14 |
+
|
15 |
+
from models.body import create_block
|
16 |
+
from models.class_ import Merge, RavenClass
|
17 |
+
from models.head import LatentHeadModel
|
18 |
+
|
19 |
+
from models.loss import RavenLoss
|
20 |
+
from models.trans import TransModel, FullTrans
|
21 |
+
from raven_utils.const import HORIZONTAL
|
22 |
+
|
23 |
+
|
24 |
+
def raven_model(scales,
|
25 |
+
out_layers,
|
26 |
+
latent=(64, 128, 256),
|
27 |
+
output_size=None,
|
28 |
+
padding=SAME,
|
29 |
+
body_layers=1,
|
30 |
+
encoder=None,
|
31 |
+
loop=1,
|
32 |
+
model=None,
|
33 |
+
act=None,
|
34 |
+
simpler=0,
|
35 |
+
loss_mode=None,
|
36 |
+
loss_weight=0.3,
|
37 |
+
dir_=HORIZONTAL,
|
38 |
+
global_context=False,
|
39 |
+
images_no=8,
|
40 |
+
context_mul=2,
|
41 |
+
res_act="pass",
|
42 |
+
drop_latent=0,
|
43 |
+
drop_inference=0,
|
44 |
+
drop_end=0,
|
45 |
+
ga=False,
|
46 |
+
trans_norm=None,
|
47 |
+
trans_act="relu",
|
48 |
+
arch=HEAD3,
|
49 |
+
encoder_norm=False,
|
50 |
+
encoder_pool=False,
|
51 |
+
encoder_global="GM",
|
52 |
+
encoder_before=False,
|
53 |
+
tail_units=256,
|
54 |
+
tail_flatten=None,
|
55 |
+
# for now by default
|
56 |
+
tail_down="MP",
|
57 |
+
trans_no=1,
|
58 |
+
trans_score_activation=tf.nn.softmax,
|
59 |
+
block_=SoftBlock,
|
60 |
+
**kwargs):
|
61 |
+
if isinstance(latent, int):
|
62 |
+
latent = (latent, 128, 256)
|
63 |
+
scales = lw(scales)
|
64 |
+
|
65 |
+
context_size = np.array(latent) * context_mul
|
66 |
+
# context_size = latent[scales] * context_mul
|
67 |
+
|
68 |
+
# if scales == 2:
|
69 |
+
# arch = HEAD
|
70 |
+
# elif scales == 1:
|
71 |
+
# arch = HEAD2
|
72 |
+
# else:
|
73 |
+
# arch = VERY2
|
74 |
+
|
75 |
+
if encoder_pool:
|
76 |
+
strides = (1, 1)
|
77 |
+
else:
|
78 |
+
strides = (2, 2)
|
79 |
+
if not isinstance(encoder_before, tuple):
|
80 |
+
encoder_before = [encoder_before] * 3
|
81 |
+
|
82 |
+
# if trans == 1:
|
83 |
+
# trans_model = TransModel2
|
84 |
+
# else:
|
85 |
+
# trans_model = TransModel
|
86 |
+
|
87 |
+
# if scales == 3:
|
88 |
+
# head = MultiHeadModel(encoder=encoder)
|
89 |
+
arch = MODEL_ARCH[arch]
|
90 |
+
heads = []
|
91 |
+
for s in list(range(0, max(scales) + 1)):
|
92 |
+
if s in (0, 1):
|
93 |
+
if s == 0:
|
94 |
+
encoder = build_encoder(arch[:3], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
95 |
+
strides=strides)
|
96 |
+
else:
|
97 |
+
encoder = build_encoder(arch[3:4], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
98 |
+
strides=strides)
|
99 |
+
head = LatentHeadModel(
|
100 |
+
encoder=encoder,
|
101 |
+
inference_network=(
|
102 |
+
bm([
|
103 |
+
CatBefore(filters=int(context_size[s] / 8)) if encoder_before[s] else Cat(
|
104 |
+
filters=context_size[s]),
|
105 |
+
# todo activation?
|
106 |
+
Res(filters=context_size[s], padding=padding)
|
107 |
+
] + ([drop(drop_inference)] if drop_inference else []),
|
108 |
+
name="inference")
|
109 |
+
) if s in scales else Pass(),
|
110 |
+
stem=Base(
|
111 |
+
bm(
|
112 |
+
# ok we choose by parameters anyway
|
113 |
+
[res(filters=latent[s], padding=padding, act=act)] + (
|
114 |
+
[drop(drop_latent)] if drop_latent else [])
|
115 |
+
),
|
116 |
+
name="stem")
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
encoder = bm([
|
120 |
+
Res(),
|
121 |
+
build_encoder(arch[4:], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
122 |
+
strides=strides),
|
123 |
+
short(encoder_global) if encoder_global else Flatten(),
|
124 |
+
dense(latent[s])
|
125 |
+
])
|
126 |
+
head = LatentHeadModel(
|
127 |
+
encoder=encoder,
|
128 |
+
inference_network=bm([
|
129 |
+
# todo Echeck Cat
|
130 |
+
CatDenseBefore(filters=int(context_size[s] / 8)) if encoder_before[
|
131 |
+
s] else CatDense(filters=context_size[s]),
|
132 |
+
# todo activation?
|
133 |
+
Res(model="dv2", filters=context_size[s], padding=padding)
|
134 |
+
] + ([dense_drop(drop_inference)] if drop_inference else []),
|
135 |
+
name="inference"),
|
136 |
+
stem=Base(
|
137 |
+
bm(
|
138 |
+
# ok we choose by parameters anyway
|
139 |
+
[res(model="dv2", units=latent[s], padding=padding, act=act)] + (
|
140 |
+
[dense_drop(drop_latent)] if drop_latent else [])
|
141 |
+
),
|
142 |
+
name="stem")
|
143 |
+
)
|
144 |
+
heads.append(head)
|
145 |
+
|
146 |
+
concat_input = [f"{LATENT}_{i}" for i, _ in enumerate(heads)] + [f"{INFERENCE}_{i}" for i, _ in enumerate(heads)]
|
147 |
+
concat_output = ["LATENTS", "INFERENCES"]
|
148 |
+
|
149 |
+
def head_concat(inputs):
|
150 |
+
latents = inputs[:len(heads)]
|
151 |
+
inferences = inputs[len(heads):]
|
152 |
+
return latents, inferences
|
153 |
+
|
154 |
+
head = ListModel([(h, (INPUTS if i == 0 else OUTPUT), [f"{LATENT}_{i}", f"{INFERENCE}_{i}", OUTPUT]) for i, h in
|
155 |
+
enumerate(heads)] + [
|
156 |
+
(head_concat, concat_input, concat_output)], out=concat_output)
|
157 |
+
# from rav_utils.raven import init_image
|
158 |
+
# a = init_image()
|
159 |
+
# head(a)
|
160 |
+
|
161 |
+
if model is None:
|
162 |
+
model = []
|
163 |
+
for i in scales:
|
164 |
+
trans_models = []
|
165 |
+
for t in range(trans_no):
|
166 |
+
trans_models.append(
|
167 |
+
bm(
|
168 |
+
[create_block(latent=latent[i], simpler=simpler, padding=padding, norm=trans_norm, act=res_act,
|
169 |
+
loop=loop, type_="dense" if i == 2 else "conv", block_=block_)] +
|
170 |
+
[Activation(trans_act)] + [
|
171 |
+
res(filters=latent[i],
|
172 |
+
padding=padding,
|
173 |
+
act=act,
|
174 |
+
name="body_out",
|
175 |
+
model="dv2" if i == 2 else "v2") for _ in
|
176 |
+
range(body_layers)] + ([Drop(drop_latent)] if drop_latent else []),
|
177 |
+
base_class=SubClassingModel)
|
178 |
+
)
|
179 |
+
trans_models = lu(trans_models)
|
180 |
+
if trans_no > 1:
|
181 |
+
trans_models = bm([
|
182 |
+
lambda x: [[x[0], x[1]], x[1]],
|
183 |
+
SoftBlock(
|
184 |
+
model=trans_models,
|
185 |
+
score_model=bm([
|
186 |
+
Flat2(filters=latent[i], units=256, res_no=2),
|
187 |
+
Dense(trans_no, trans_score_activation)
|
188 |
+
])
|
189 |
+
)
|
190 |
+
],
|
191 |
+
base_class=SubClassingModel
|
192 |
+
)
|
193 |
+
|
194 |
+
model.append(
|
195 |
+
TransModel(
|
196 |
+
body=trans_models,
|
197 |
+
dir_=dir_,
|
198 |
+
images_no=images_no
|
199 |
+
)
|
200 |
+
)
|
201 |
+
|
202 |
+
tail = []
|
203 |
+
for i, s in enumerate(scales):
|
204 |
+
flatting = lambda: Flat2(filters=latent[s + 1], base_class=tail_flatten, units=tail_units)
|
205 |
+
if s == 0:
|
206 |
+
if tail_flatten is None:
|
207 |
+
branch = bm([res(filters=latent[s], padding=padding),
|
208 |
+
conv(filters=latent[s], padding=padding),
|
209 |
+
BatchNormalization(),
|
210 |
+
conv(filters=latent[s], padding=padding),
|
211 |
+
Flatten()])
|
212 |
+
else:
|
213 |
+
branch = bm([down(base_class=tail_down), flatting()])
|
214 |
+
elif s == 1:
|
215 |
+
if tail_flatten is None:
|
216 |
+
branch = bm([res(filters=latent[s], padding=padding),
|
217 |
+
Flatten()])
|
218 |
+
else:
|
219 |
+
branch = flatting()
|
220 |
+
else:
|
221 |
+
branch = bm([tail_units] * 2, add_flatten=False)
|
222 |
+
tail.append(branch)
|
223 |
+
|
224 |
+
tail.append(
|
225 |
+
bm([dense(tail_units)] + ([dense_drop(drop_end)] if drop_end else []) + [Dense(output_size)], add_flatten=False,
|
226 |
+
name=TAIL))
|
227 |
+
class_input = []
|
228 |
+
|
229 |
+
return bt([
|
230 |
+
DictModel(head, in_=INPUTS, out=[LATENT, INFERENCE], name="Head"),
|
231 |
+
DictModel(FullTrans(model, scales=scales), in_=[LATENT, INFERENCE], out=TRANS, name="Body"),
|
232 |
+
DictModel(RavenClass(Merge(tail), scales=scales, no=8), in_=[LATENT] + class_input, out=CLASSIFICATION,
|
233 |
+
name="Classificator"),
|
234 |
+
DictModel(RavenClass(Merge(tail), scales=list(range(len(scales))), no=3), in_=[TRANS] + class_input,
|
235 |
+
out=OUTPUT, name="Classificator_trans"),
|
236 |
+
],
|
237 |
+
loss=RavenLoss(mode=loss_mode, classification=True, trans=True, lw=(1.0, loss_weight)),
|
238 |
+
loss_wrap=False
|
239 |
+
)
|
raven_utils/models/trans.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from ml_utils import lw
|
3 |
+
from models_utils import ops as K, SubClassingModel
|
4 |
+
from tensorflow.keras import Model
|
5 |
+
|
6 |
+
from models.body import create_dense_block
|
7 |
+
import raven_utils as rv
|
8 |
+
from raven_utils.const import HORIZONTAL, VERTICAL
|
9 |
+
|
10 |
+
|
11 |
+
class TransModel(Model):
    """Apply one shared body to three pairs of panels taken from a latent stack.

    Six panel indices are gathered along axis 1 of the latents and regrouped
    into three pairs; the body is called once per pair together with the
    (shared) inference input.

    Args:
        body: callable applied to ``[pair, inference]``. When falsy, a block
            from ``create_dense_block(latent=latent)`` is used.
        dir_: ``HORIZONTAL`` pairs the first two panels of every row,
            ``VERTICAL`` pairs the first two panels of every column
            (row-major 3x3 layout assumed — TODO confirm with the data loader).
        images_no: stored on the instance; not read in this class.
        latent: latent size forwarded to the default body and stored.
    """

    def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
        super().__init__()
        self.model = body or create_dense_block(latent=latent)
        if dir_ == VERTICAL:
            # Column pairs (0,3), (1,4), (2,5). The previous tuple
            # (0, 3, 1, 4, 3, 5) repeated panel 3 and never read panel 2.
            self.dir = (0, 3, 1, 4, 2, 5)
        else:
            # Row pairs (0,1), (3,4), (6,7).
            self.dir = (0, 1, 3, 4, 6, 7)
        self.images_no = images_no
        self.latent = latent

    def call(self, inputs):
        """Run the body on each of the three panel pairs.

        Args:
            inputs: sequence whose first item is the latent stack and whose
                second item is the inference input passed unchanged to every
                body call.

        Returns:
            ``K.tran`` applied to the stacked per-pair outputs (presumably a
            transpose that puts the batch axis first — verify in models_utils).
        """
        # latents = tnp.asarray(inputs[0])
        latents = inputs[0]
        inference = inputs[1]
        shape = tf.shape(latents)
        # Regroup the six gathered panels as (batch, 3 pairs, 2 panels, ...).
        new_shape = K.cat([[-1, 3, 2], shape[2:]])
        pairs = latents[:, self.dir].reshape(new_shape)
        res = tf.TensorArray(tf.float32, size=3)
        for i in range(3):
            res = res.write(i, self.model([pairs[:, i], inference]))
        result = K.tran(res.stack())
        return result
|
34 |
+
|
35 |
+
|
36 |
+
class TransModel2(Model):
    """Variant of the pair-wise transformer that slices the inference per pair.

    Six panel indices are gathered along axis 1 of the latents and regrouped
    into three pairs; for pair ``i`` the body receives
    ``[pair_i, inference[:, i]]``.

    Args:
        body: callable applied to each ``[pair, inference_slice]``. When
            falsy, a block from ``create_dense_block(latent=latent)`` is used.
        dir_: ``HORIZONTAL`` pairs the first two panels of every row,
            ``VERTICAL`` pairs the first two panels of every column
            (row-major 3x3 layout assumed — TODO confirm with the data loader).
        images_no: stored on the instance; not read in this class.
        latent: latent size forwarded to the default body and stored.
    """

    def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
        super().__init__()
        self.body = body or create_dense_block(latent=latent)
        if dir_ == VERTICAL:
            # Column pairs (0,3), (1,4), (2,5). The previous tuple
            # (0, 3, 1, 4, 3, 5) repeated panel 3 and never read panel 2.
            self.dir = (0, 3, 1, 4, 2, 5)
        else:
            # Row pairs (0,1), (3,4), (6,7).
            self.dir = (0, 1, 3, 4, 6, 7)
        self.images_no = images_no
        self.latent = latent

    def call(self, inputs):
        """Run the body on each panel pair with its matching inference slice.

        Args:
            inputs: sequence whose first item is the latent stack and whose
                second item is the inference tensor, indexed on axis 1 per pair.

        Returns:
            ``K.tran`` applied to the stacked per-pair outputs (presumably a
            transpose that puts the batch axis first — verify in models_utils).
        """
        # latents = tnp.asarray(inputs[0])
        latents = inputs[0]
        inference = inputs[1]
        shape = tf.shape(latents)
        # Regroup the six gathered panels as (batch, 3 pairs, 2 panels, ...).
        new_shape = K.cat([[-1, 3, 2], shape[2:]])
        pairs = latents[:, self.dir].reshape(new_shape)
        res = tf.TensorArray(tf.float32, size=3)
        # tf.range keeps the loop index a tensor, as in the original code.
        for i in tf.range(3):
            res = res.write(i, self.body([pairs[:, i], inference[:, i]]))
        result = K.tran(res.stack())
        return result
|
59 |
+
|
60 |
+
|
61 |
+
class FullTrans(SubClassingModel):
    """Dispatch per-scale latents and inferences to one sub-model per scale."""

    def __init__(self, model, scales, name=None):
        super().__init__(model=model, name=name)
        self.scales = scales

    def call(self, inputs):
        """Call the i-th sub-model on the latent/inference pair of scale ``s``.

        ``inputs[0]`` and ``inputs[1]`` are passed through ``lw`` first
        (presumably a list-wrapping helper — confirm in ml_utils) and then
        indexed by each entry of ``self.scales``.

        Returns:
            A one-element tuple holding the list of per-scale results.
        """
        latents = lw(inputs[0])
        inferences = lw(inputs[1])
        # todo merging inference?
        outputs = [
            self.model[position]([latents[scale], inferences[scale]])
            for position, scale in enumerate(self.scales)
        ]
        return (outputs,)
|
raven_utils/models/transformer.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras.layers import Lambda
|
3 |
+
from tensorflow.python.keras import Sequential
|
4 |
+
|
5 |
+
# from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
6 |
+
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight
|
7 |
+
import models_utils.ops as K
|
8 |
+
from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
|
9 |
+
from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
|
10 |
+
from models_utils.ops_core import IndexReshape
|
11 |
+
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
12 |
+
from models_utils.step import StepDict
|
13 |
+
|
14 |
+
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
15 |
+
# (1, 1, 1, self.last_index))
|
16 |
+
|
17 |
+
|
18 |
+
# from data_utils import ims
|
19 |
+
# for i in range(50):
|
20 |
+
# ims(res[i].numpy().swapaxes(0, 2))
|
21 |
+
# res[12].numpy()
|
22 |
+
# self.get_last(inputs).numpy()
|
23 |
+
# import tensorflow as tf
|
24 |
+
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
25 |
+
# from ml_utils import print_error
|
26 |
+
# ims(mask[0].numpy())
|
27 |
+
# print_error(lambda :ims(mask[0]))
|
28 |
+
# from models_utils import ops as K
|
29 |
+
|
30 |
+
|
31 |
+
# self.model.layers[0](x)
|
32 |
+
from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
33 |
+
|
34 |
+
|
35 |
+
def get_rav_trans(
        data,
        type_=9,
        no=4,
        extractor="ef",
        loss_mode=create_uniform_mask,
        output_size=10,
        loss_weight=1.0,
        out_layers=(1000, 1000, 1000),
        pos_emd="cat",
        base="seq",
        inverse_image=True,
        last="left",
        epsilon="greedy",
        epsilon_step=500,
        mask_fn=None,
        model=None,
        loss="multi",
        **kwargs):
    """Build a trainable Raven transformer wrapped with its loss.

    String options are resolved to callables/objects first, then one of two
    models is assembled depending on whether ``epsilon`` ends up truthy.

    Args:
        data: not read anywhere in this function.
        type_, no, extractor, output_size, pos_emd, base: forwarded unchanged
            to the underlying transformer constructor.
        loss_mode: forwarded as ``mode`` to the loss.
        loss_weight: first element of the ``lw`` pair given to the loss.
        out_layers: forwarded as ``out_layer``.
        inverse_image: when true, inputs are replaced by ``255 - x`` before
            reaching the transformer.
        last: "left", "mix", "empty" or "start"; any other value is forwarded
            as given.
        epsilon: "greedy", "soft", ``False`` (mapped to ``None``) or an object
            forwarded as given.
        epsilon_step: step for the epsilon schedule and ``add_step`` of
            ``StepDict``.
        mask_fn, model: only used on the non-epsilon path.
        loss: "single" selects ``SingleVTRavenLoss``; anything else selects
            ``VTRavenLoss`` (non-epsilon path only).
        **kwargs: forwarded to the transformer constructor.
    """
    # Resolve the `last` option name to the matching img_seq2 helper.
    if last == "left":
        last = take_left
    elif last == "mix":
        last = mix
    elif last == "empty":
        last = empty_last
    elif last == "start":
        last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])

    # Resolve the exploration schedule.
    if epsilon == "greedy":
        epsilon = EpsilonGreedy(step=epsilon_step)
    elif epsilon == "soft":
        epsilon = EpsilonSoft(step=epsilon_step)
    elif epsilon is False:
        epsilon = None

    if epsilon:
        # NOTE(review): TransRavenwithStep is not imported or defined anywhere
        # in this file, so this branch presumably raises NameError — and the
        # default epsilon="greedy" leads here. Confirm where the class lives.
        # NOTE(review): `model`, `mask_fn` and `loss` are ignored on this path.
        trans_raven = TransRavenwithStep(
            type_=type_,
            no=no,
            extractor=extractor,
            output_size=output_size,
            out_layer=out_layers,
            pos_emd=pos_emd,
            base=base,
            last=last,
            epsilon=epsilon,
            **kwargs
        )
        return StepDict(bt(
            DictModel(
                # The input is an (images, step) pair; only the images are inverted.
                Sequential([Lambda(lambda x: (255 - x[0], x[1])), trans_raven]) if inverse_image else trans_raven,
                in_=[INPUTS, "step"],
                name="Body"
            ),
            loss=VTRavenLoss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
            loss_wrap=False),
            add_step=epsilon_step,
        )

    trans_raven = img_sec_trans(
        type_=type_,
        no=no,
        extractor=extractor,
        model=model,
        output_size=output_size,
        out_layer=out_layers,
        pos_emd=pos_emd,
        base=base,
        last=last,
        epsilon=epsilon,
        mask_fn=mask_fn,
        **kwargs
    )
    # Select the loss class; it is instantiated in the return below.
    if loss == "single":
        loss = SingleVTRavenLoss
    else:
        loss = VTRavenLoss

    # return bt(
    #     DictModel(
    #         Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
    #         inputs=INPUTS,
    #         name="Body"
    #     ),
    #     loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
    #     loss_wrap=False
    # )

    return bt([
        DictModel(
            Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
            in_=INPUTS,
            name="Body"
        ),

    ],
        loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
        loss_wrap=False
    )
|
raven_utils/models/transformer_2.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras.layers import Lambda
|
5 |
+
from tensorflow.python.keras import Sequential
|
6 |
+
from models_utils import ops as K, SubClassing
|
7 |
+
from models_utils.models.transformer import aug
|
8 |
+
|
9 |
+
# from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
10 |
+
from data_utils import DataGenerator, LOSS, TARGET, IMAGES
|
11 |
+
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional_model, get_input_layer
|
12 |
+
import models_utils.ops as K
|
13 |
+
from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
|
14 |
+
from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
|
15 |
+
from models_utils.ops_core import IndexReshape
|
16 |
+
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
17 |
+
from models_utils.step import StepDict
|
18 |
+
|
19 |
+
from models_utils.models.transformer import aug
|
20 |
+
|
21 |
+
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
22 |
+
# (1, 1, 1, self.last_index))
|
23 |
+
|
24 |
+
|
25 |
+
# from data_utils import ims
|
26 |
+
# for i in range(50):
|
27 |
+
# ims(res[i].numpy().swapaxes(0, 2))
|
28 |
+
# res[12].numpy()
|
29 |
+
# self.get_last(inputs).numpy()
|
30 |
+
# import tensorflow as tf
|
31 |
+
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
32 |
+
# from ml_utils import print_error
|
33 |
+
# ims(mask[0].numpy())
|
34 |
+
# print_error(lambda :ims(mask[0]))
|
35 |
+
# from models_utils import ops as K
|
36 |
+
|
37 |
+
|
38 |
+
# self.model.layers[0](x)
|
39 |
+
from raven_utils.constant import INDEX, LABELS
|
40 |
+
from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
41 |
+
|
42 |
+
|
43 |
+
def get_matrix(inputs, index):
    """Append one indexed panel to the first eight panels of ``inputs``.

    The panel is picked with ``K.gather`` using column 0 of ``index`` and
    concatenated after ``inputs[:, :8]`` along axis 1.
    """
    context = inputs[:, :8]
    chosen = K.gather(inputs, index[:, 0])[:, None]
    return tf.concat([context, chosen], axis=1)
|
45 |
+
|
46 |
+
|
47 |
+
def get_images(inputs):
    """Build the panel matrix from an ``(images, index)`` pair via ``get_matrix``."""
    images = inputs[0]
    index = inputs[1]
    return get_matrix(images, index)
|
49 |
+
|
50 |
+
|
51 |
+
def random_last(inputs, max_=8):
    """Build the panel matrix with a drawn index per sample instead of a given one.

    One label per batch element is produced by ``K.init.label`` (bounded by
    ``max_``; presumably random — confirm in models_utils) and used as the
    index for ``get_matrix``.
    """
    images = inputs[0]
    batch = tf.shape(images)[0]
    index = K.init.label(max=max_, shape=[batch])[..., None]
    return get_matrix(images, index)
|
54 |
+
|
55 |
+
|
56 |
+
def get_images_no_answer(inputs):
    """Return the first nine entries along axis 1 of ``inputs[0]``."""
    images = inputs[0]
    return images[:, :9]
|
58 |
+
|
59 |
+
|
60 |
+
def repeat_last(inputs):
    """Return panels 0-7 of ``inputs[0]`` with panel 7 repeated as a ninth entry."""
    order = [*range(8), 7]
    images = inputs[0]
    return images[:, order]
|
62 |
+
|
63 |
+
|
64 |
+
def get_rav_trans(
        data,
        inverse_image=True,
        loss_mode=create_uniform_mask,
        loss_weight=1.0,
        loss="multi",
        number_loss=False,
        plw=None,
        pre="auto",
        augmentation=None,
        **kwargs):
    """Build the Raven image-sequence transformer with preprocessing and loss.

    Args:
        data: a ``DataGenerator`` (its first batch's 'inputs' and 'index' are
            used) or an already prepared value accepted by ``pre``; it is only
            used as a sample to build the functional model.
        inverse_image: when true, images are replaced by ``255 - x`` before
            reaching the transformer.
        loss_mode: forwarded as ``mode`` to the loss.
        loss_weight: a float is expanded to ``(loss_weight, 1.0)``; any other
            value is forwarded as given.
        loss: "single" selects ``SingleVTRavenLoss``; anything else selects
            ``VTRavenLoss``.
        number_loss, plw: forwarded to the loss.
        pre: name of the preprocessing step ("auto", "no_answer", "last",
            "images", "random_last", "noise", "batch_noise") or a callable
            used as given. "auto" requires ``kwargs['mask']`` and raises
            ``KeyError`` without it.
        augmentation: "transpose", "shuffle_col", "shuffle", a custom callable
            applied to both images and labels, or a falsy value for none.
        **kwargs: forwarded to ``img_sec_trans``.

    Raises:
        ValueError: if ``augmentation`` is a string that is not a known name.
    """
    if isinstance(data, DataGenerator):
        data = data[0]['inputs'], data[0]['index']

    # Resolve the preprocessing name; callables fall through untouched.
    if pre == "auto":
        pre = get_images if kwargs['mask'] == "random" else get_images_no_answer
    elif pre == "no_answer":
        pre = get_images_no_answer
    elif pre == "last":
        pre = repeat_last
    elif pre == "images":
        pre = get_images
    elif pre == "random_last":
        pre = random_last
    elif pre == "noise":
        pre = SubClassing([get_matrix, partial(aug.noise, max_=8)])
    elif pre == "batch_noise":
        pre = SubClassing([get_matrix, partial(aug.batch_noise, max_=8)])

    # Resolve the augmentation for images and the matching one for labels.
    augmentation_label = None
    if augmentation == "transpose":
        augmentation = aug.Transpose(axis=(0, 2, 1))
        augmentation_label = aug.Transpose(axis=(0, 2, 1))
    elif augmentation == "shuffle_col":
        augmentation = aug.shuffle_col
        augmentation_label = aug.shuffle_col
    elif augmentation == "shuffle":
        augmentation = aug.shuffle
        augmentation_label = aug.shuffle
    elif augmentation:
        # Previously a custom augmentation left `augmentation_label` unbound
        # and failed below; reuse the same callable for the labels instead.
        if isinstance(augmentation, str):
            raise ValueError(f"Unknown augmentation: {augmentation!r}")
        augmentation_label = augmentation

    if augmentation:
        augmentation = [
            DictModel(aug.ReshapeStatic(augmentation), IMAGES, IMAGES),
            DictModel(
                aug.PartialModel(
                    aug.ReshapeStatic(augmentation_label),
                    last_axis=9),
                LABELS, LABELS),
        ]
    else:
        augmentation = []

    # The preprocessed sample is only used to build the functional model.
    trans_raven = build_functional_model(
        img_sec_trans(**kwargs),
        pre(data)
    )
    if loss == "single":
        loss = SingleVTRavenLoss
    else:
        loss = VTRavenLoss
    if isinstance(loss_weight, float):
        loss_weight = (loss_weight, 1.0)

    return bt([
        DictModel(pre, [INPUTS, INDEX], IMAGES),
        *augmentation,
        DictModel(
            Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
            in_=IMAGES,
            name="Body"
        ),
    ],
        loss=loss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
        predict=LOSS,
        loss_wrap=False
    )
|
raven_utils/models/transformer_3.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from loguru import logger
|
4 |
+
from tensorflow.keras.layers import Lambda
|
5 |
+
from tensorflow.keras.layers import Activation
|
6 |
+
|
7 |
+
from grid_transformer import aug_trans
|
8 |
+
from raven_utils.models.loss_3 import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
9 |
+
from data_utils import get_shape, TakeDict
|
10 |
+
|
11 |
+
from data_utils import DataGenerator, LOSS, TARGET, IMAGES
|
12 |
+
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional, get_input_layer, Last, bm, \
|
13 |
+
add_end, AUGMENTATION
|
14 |
+
# from report.select_ import SelectModel2, SelectModel, SelectModel9
|
15 |
+
from experiment_utils.keras_model import load_weights as model_load_weights
|
16 |
+
|
17 |
+
|
18 |
+
def get_rav_trans(
        data,
        loss_mode=create_uniform_mask,
        loss_weight=2.0,
        number_loss=False,
        dry_run="auto",
        plw=None,
        **kwargs):
    """Build the ``aug_trans`` functional model and wrap it with ``VTRavenLoss``.

    Args:
        data: a ``DataGenerator`` (its first batch is used as sample input) or
            the sample input itself.
        loss_mode: forwarded as ``mode`` to the loss.
        loss_weight: a float is expanded to ``(loss_weight, 1.0)``; any other
            value is forwarded as given.
        number_loss, plw: forwarded to the loss.
        dry_run: forwarded to ``build_functional``.
        **kwargs: forwarded to ``build_functional``.
    """
    weights = (loss_weight, 1.0) if isinstance(loss_weight, float) else loss_weight
    sample = data[0] if isinstance(data, DataGenerator) else data

    network = build_functional(
        model=aug_trans,
        inputs_=sample,
        batch_=None,
        dry_run=dry_run,
        **kwargs
    )
    criterion = VTRavenLoss(
        mode=loss_mode,
        classification=True,
        number_loss=number_loss,
        lw=weights,
        plw=plw,
    )
    return bt(
        model=network,
        loss=criterion,
        model_wrap=False,
        predict=LOSS,
        loss_wrap=False
    )
|
50 |
+
|
51 |
+
|
52 |
+
def rav_select_model(
        data,
        load_weights=None,
        loss_weight=(0.01, 0.0),
        plw=5.0,
        result_metric="sparse_categorical_accuracy",
        select_type=2,
        select_out=0,
        additional_out=0,
        additional_copy=True,
        tail_out=(1000, 1000),
        **kwargs
):
    """Assemble an answer-selection model from two weight-loaded transformers.

    Two transformers are built with ``get_rav_trans`` on the slice
    ``TakeDict(data[0])[:, 8:]``, each with its own ``out_layers`` head, and
    both receive weights through ``model_load_weights``. When
    ``additional_out > 0`` extra trained layers are copied from further
    full-input models into those heads. The two bodies and a fresh tail are
    then handed to a ``SelectModel*`` class.

    Args:
        data: indexable whose first element is the sample batch.
        load_weights: forwarded to every ``model_load_weights`` call.
        loss_weight, plw: forwarded to ``get_rav_trans``.
        result_metric: forwarded as ``key`` to ``model_load_weights``.
        select_type: 2 -> ``SelectModel2``, 1 -> ``SelectModel``, anything
            else -> ``SelectModel9``; 2 also changes the second pooling layer.
        select_out: truthy gives the tail 8 outputs instead of 1; also
            forwarded to the select model.
        additional_out: number of extra layers sliced from a loaded model.
        additional_copy: when true, the second head takes its extra layers
            from a separately built and loaded model instead of reusing the
            first head's.
        tail_out: hidden layer sizes for the tail.
        **kwargs: forwarded to ``get_rav_trans``.

    NOTE(review): the nesting below was reconstructed from statement
    semantics; confirm it against the original file.
    """
    out_layers = Last()
    if additional_out > 0:
        # Build and load a full-input model just to borrow its trained layers.
        model3 = get_rav_trans(
            data,
            plw=plw,
            loss_weight=loss_weight,
            **kwargs
        )

        model_load_weights(
            model3,
            load_weights,
            # sample_data,
            None,
            template="weights_{epoch:02d}-{val_loss:.2f}",
            key=result_metric,
        )

        # The position of the layer group to copy depends on whether an
        # augmentation was configured.
        if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
            index = -1
        else:
            index = -2

        out = model3[0, index, :additional_out]
        logger.info(f"Additional out from: {model3[0, index]}.")

        if additional_out > 2:
            # presumably `out` is a list of layers here — TODO confirm
            out += [Activation("gelu")]
        out_layers = bm([out_layers] + out, add_flatten=False)
    model = get_rav_trans(
        TakeDict(data[0])[:, 8:],
        plw=plw,
        loss_weight=loss_weight,
        **{
            **kwargs,
            "out_layers": out_layers,
        }
        # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
    )
    # from data_utils.ops import Equal
    # o = []
    # for i in range(1, 3):
    #     for j in range(2):
    #         o.append(
    #             Equal(
    #                 # model[0,:,-2, i].variables[j],
    #                 model2[0, :, -2, i].variables[j],
    #                 # out_layers[i].variables[j]
    #                 second_pooling[i].variables[j]
    #             ).equal
    #         )
    # assert all(o)
    # model = get_rav_trans(
    #     # TakeDict(val_generator[0])[:, 8:],
    #     # TakeDict(val_generator[0])[:, 8:],
    #     val_generator[0],
    #     plw=p.plw,
    #     loss_weight=p.loss_weight,
    #     **{**as_dict(p.mp),
    #        # "out_layers": out_layers,
    #        }
    #     # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
    # )
    model_load_weights(model,
                       load_weights,
                       # sample_data,
                       None,
                       template="weights_{epoch:02d}-{val_loss:.2f}",
                       key=result_metric,
                       )
    # model.compile()
    # model.evaluate(val_generator.data[:1000])
    # model(TakeDict(val_generator[0])[:, 8:])
    trans_raven = model[0]
    # s = trans_raven(TakeDict(val_generator[0])[:, 8:])
    # Pooling for the second transformer: drop the final position for
    # select_type 2, otherwise use Last().
    if select_type == 2:
        second_pooling = Lambda(lambda x: x[:, :-1])
    else:
        second_pooling = Last()
    if additional_out > 0:
        if additional_copy:
            # Independent copy of the extra layers, taken from another
            # freshly built and loaded model.
            model4 = get_rav_trans(
                data,
                plw=plw,
                loss_weight=loss_weight,
                **kwargs
            )
            model_load_weights(model4,
                               load_weights,
                               # sample_data,
                               None,
                               template="weights_{epoch:02d}-{val_loss:.2f}",
                               key=result_metric,
                               )

            if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
                index = -1
            else:
                index = -2
            out2 = model4[0, index, :additional_out]
            logger.info(f"Additional out from: {model4[0, index]}.")

            if additional_out > 2:
                out2 += [Activation("gelu")]
        else:
            # Reuse the layers already attached to the first head.
            out2 = out

        second_pooling = bm([second_pooling] + out2, add_flatten=False)

    model2 = get_rav_trans(
        TakeDict(data[0])[:, 8:],
        plw=plw,
        loss_weight=loss_weight,
        **{
            **kwargs,
            "out_layers": second_pooling,
        }
        # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
    )
    model_load_weights(
        model2,
        load_weights,
        # sample_data,
        None,
        template="weights_{epoch:02d}-{val_loss:.2f}",
        key=result_metric,
    )
    # NOTE(review): both branches are identical; the select_type == 0 case is
    # marked "not working" by the original author.
    if select_type == 0:
        # not working
        trans_raven2 = model2[0]
    else:
        trans_raven2 = model2[0]
    tail = add_end(out_layers=tail_out, output_size=8 if select_out else 1)
    # trans_raven2.mask_fn = ImageMask(last=take_by_index)
    # NOTE(review): SelectModel2 / SelectModel / SelectModel9 are not imported
    # in this file (the `report.select_` import at the top is commented out),
    # so the lines below presumably raise NameError — confirm and restore it.
    if select_type == 2:
        select_model_class = SelectModel2
    elif select_type == 1:
        select_model_class = SelectModel
    else:
        select_model_class = SelectModel9
    select_model = select_model_class(trans_raven, model2=trans_raven2, tail=tail, select_out=select_out)
    return select_model
|
raven_utils/models/uitls_.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow.experimental.numpy as tnp
|
3 |
+
from tensorflow.keras import Model
|
4 |
+
import raven_utils as rv
|
5 |
+
|
6 |
+
|
7 |
+
class RangeMask(Model):
    """Look up a precomputed boolean range mask by group index.

    Row ``g`` of the mask is true at positions ``p`` with
    ``INDEX[g] <= p < INDEX[g + 1]``, where ``INDEX`` is ``rv.entity.INDEX``.
    """

    def __init__(self):
        super().__init__()
        boundaries = rv.entity.INDEX
        # One row of positions 0..INDEX[-1]-1 per group.
        positions = tf.tile(tf.range(boundaries[-1])[None], [rv.group.NO, 1])
        lower = boundaries[:-1][:, None]
        upper = boundaries[1:][:, None]
        inside = (lower <= positions) & (positions < upper)
        self.mask = tnp.array(inside)

    def call(self, inputs):
        """Return the mask rows selected by ``inputs``."""
        mask = self.mask
        return mask[inputs]
|
raven_utils/output.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import raven_utils.entity as entity
|
3 |
+
import raven_utils.properties as properties
|
4 |
+
import raven_utils.group as group
|
5 |
+
|
6 |
+
# Total width of one output vector: a per-entity property section, then a
# section of width entity.SUM, then a trailing section of width group.NO
# (section order follows from the slices below).
SIZE = entity.SUM * properties.SUM + group.NO + entity.SUM

# Combined width of the two trailing sections (slot + group).
SLOT_AND_GROUP = group.NO + entity.SUM

# Slices over a 2-D array whose last axis is laid out as described above.
PROPERTIES_SLICE = np.s_[:, :-SLOT_AND_GROUP]
SLOT_SLICE = np.s_[:, -SLOT_AND_GROUP:-group.NO]
GROUP_SLICE = np.s_[:, -group.NO:]

# The same three sections expressed as 1-D slices (no leading axis).
GROUP_SLICE_END = np.s_[-group.NO:]
SLOT_SLICE_END = np.s_[-SLOT_AND_GROUP:-group.NO]
PROPERTIES_SLICE_END = np.s_[:-SLOT_AND_GROUP]
|
raven_utils/params.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
from typing import Any, Tuple

from ml_utils import get_str_name
from grid_transformer.params import ImgSeqTransformerParameters
from raven_utils import output

from experiment_utils.parameters.nn_default import TP, EP
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
class SudokuParameters(ImgSeqTransformerParameters):
    """Overrides of the image-sequence transformer defaults for a Sudoku setup.

    The meaning of each field is defined by ``ImgSeqTransformerParameters``
    (grid_transformer); only the default values are changed here.
    """
    mask: str = "input"
    col: int = 3
    row: int = 3
    pooling: int = 81
    output_size: int = 9
    size: int = 384
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
class RavenTransParameters(ImgSeqTransformerParameters):
    """Overrides of the image-sequence transformer defaults for Raven matrices.

    The meaning of each field is defined by ``ImgSeqTransformerParameters``
    (grid_transformer); only the default values are set here.
    """
    mask: str = "last"
    last_index: int = 8
    col: int = 1
    row: int = 1
    # Width of the full Raven output vector, computed in raven_utils.output.
    output_size: int = output.SIZE
    # NOTE(review): annotated bool but defaults to the int 0.
    number_loss: bool = 0
    pre: str = "images"
    num_heads: int = 8


# Short alias for the default model-parameter class.
MP = RavenTransParameters
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
class RavenSelectTransParameters(RavenTransParameters):
    """Raven transformer parameters extended with answer-selection options.

    The ``select_*``, ``additional_*`` and ``tail_out`` field names match the
    keyword arguments of ``rav_select_model`` — presumably they are forwarded
    to it; verify against the caller.
    """
    select_type: int = 2
    select_out: int = 0
    additional_out: int = 0
    additional_copy: bool = True
    tail_out: Tuple = (1000, 1000)
    pre: str = "index"


# Short alias for the selection model-parameter class.
SMP = RavenSelectTransParameters
|
47 |
+
|
48 |
+
from raven_utils.config.models import AVAILABLE_MODELS
|
49 |
+
from experiment_utils.parameters.nn_clean import Parameters as BaseParameters
|
50 |
+
from raven_utils.config.constant import RAVEN, LABELS, INDEX, FEATURES, RAV_METRICS, IMP_RAV_METRICS, ACC_NO_GROUP, \
|
51 |
+
ACC_SAME
|
52 |
+
|
53 |
+
MODEL_NO = -1
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
class RavenParameters(BaseParameters):
    """Experiment parameters for training on the Raven dataset.

    Extends the generic experiment parameters with dataset file locations,
    metric selection, loss weights and the model parameters ``mp``.
    """
    dataset_name: str = RAVEN
    # Relative paths of the arrays to load; the two feature files depend on
    # the model selected by MODEL_NO in AVAILABLE_MODELS.
    data: Any = (
        f"{dataset_name}/train.npy",
        f"{dataset_name}/val.npy",
        f"{dataset_name}/train_labels.npy",
        f"{dataset_name}/val_labels.npy",
        f"{dataset_name}/train_target.npy",
        f"{dataset_name}/val_target.npy",
        f"arr/train_features_{AVAILABLE_MODELS[MODEL_NO]}.npy",
        f"arr/val_features_{AVAILABLE_MODELS[MODEL_NO]}.npy",
        f"{dataset_name}/val_index.npy"
    )

    filter_metrics: tuple = tuple(IMP_RAV_METRICS)
    result_metric: str = ACC_SAME

    lw: float = 0.0001  # Autoencoder
    loss_weight: float = 2.0
    # Annotation corrected from int: the default value is a float.
    plw: float = 5.0
    # A dataclass instance is a mutable (unhashable) default: it would be
    # shared by every RavenParameters object, and Python 3.11+ rejects it
    # with ValueError at class creation. Build a fresh one per instance.
    mp: RavenTransParameters = field(default_factory=RavenTransParameters)

    @property
    def experiment(self):
        """Return the experiment identifier used for this parameter set."""
        return "rav/best_test3"
|
92 |
+
|
93 |
+
|
94 |
+
@dataclass
class BaselineRavenParameters(RavenParameters):
    """Raven parameters for the baseline experiment, with a generated run name."""

    @property
    def experiment(self):
        """Return the experiment identifier for baseline runs."""
        # return "rav/best_test3"
        return "rav/baseline"
        # return "rav/trans_weight"

    @property
    def name(self):
        """Return a run name assembled from selected model/training parameters.

        The parts are joined with underscores; the first three contribute only
        their first character, and ``noise``/``augmentation`` contribute an
        empty string when falsy.
        """
        # return f"i{self.extractor}_{len(self.tail)}{self.tail[0]}_{self.type_}_{self.epsilon}_{self.last}_{self.epsilon_step}"
        return f"{get_str_name(self.mp.pre)[0]}_{str(self.plw)[0]}_{str(self.mp.number_loss)[0]}_{self.mp.extractor}_{self.mp.noise if self.mp.noise else ''}_{self.mp.augmentation if self.mp.augmentation else ''}_{self.mp.extractor_shape}_{self.mp.no}_{self.mp.num_heads}_{self.mp.size}_{self.mp.pos_emd}_{self.mp.ff_mul}_{self.tp.batch}"
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == '__main__':
    # Smoke check: build the default model parameters.
    # `PreRavenTransParameters` is not defined in this module and raised
    # NameError; RavenTransParameters is the closest existing class
    # (NOTE(review): confirm this is the intended one).
    params = RavenTransParameters()
|
raven_utils/properties.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import raven_utils as rv
|
2 |
+
from ml_utils import dict_from_list2, CalcDict
|
3 |
+
import raven_utils.entity as entity
|
4 |
+
|
5 |
+
# Names of the per-entity properties, in encoding order.
NAMES = ['Color', 'Size', 'Type']
# Size assigned to each property, aligned with NAMES.
RAW_SIZE = [10, 6, 5]
# Property name -> size mapping built from the two lists above.
SIZE = dict_from_list2(NAMES, RAW_SIZE)
ANGLE_SIZE = 7
# Number of properties.
NO = len(NAMES)

# Per-property sizes scaled by the total number of entity slots.
INDEX = (CalcDict(SIZE) * entity.SUM).to_dict()
# Combined size of all properties for a single entity.
SUM = sum(SIZE.values())
|
raven_utils/range_mask.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow.experimental.numpy as tnp
|
3 |
+
from tensorflow.keras import Model
|
4 |
+
import raven_utils as rv
|
5 |
+
|
6 |
+
|
7 |
+
class RangeMask(Model):
    """Keras model that returns precomputed boolean range masks.

    Row ``g`` of the mask is true for positions ``p`` with
    ``INDEX[g] <= p < INDEX[g + 1]``, where ``INDEX`` is
    ``rv.entity.INDEX``.
    """

    def __init__(self):
        super().__init__()
        bounds = rv.entity.INDEX
        lower = bounds[:-1][:, None]
        upper = bounds[1:][:, None]
        # One row of positions 0..INDEX[-1]-1 per group.
        positions = tf.tile(tf.range(bounds[-1])[None], [rv.group.NO, 1])
        inside = (lower <= positions) & (positions < upper)
        self.mask = tnp.array(inside)

    def call(self, inputs):
        """Index the mask with ``inputs`` and return the selected rows."""
        return self.mask[inputs]
|
raven_utils/render/__init__.py
ADDED
File without changes
|
raven_utils/render/const.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
|
4 |
+
# Maximum number of components in a RPM
|
5 |
+
MAX_COMPONENTS = 2
|
6 |
+
|
7 |
+
# Canvas parameters
|
8 |
+
IMAGE_SIZE = 160
|
9 |
+
CENTER = (IMAGE_SIZE / 2, IMAGE_SIZE / 2)
|
10 |
+
DEFAULT_RADIUS = IMAGE_SIZE / 4
|
11 |
+
DEFAULT_WIDTH = 2
|
12 |
+
|
13 |
+
# Attribute parameters
|
14 |
+
# Number
|
15 |
+
NUM_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
16 |
+
NUM_MIN = 0
|
17 |
+
NUM_MAX = len(NUM_VALUES) - 1
|
18 |
+
|
19 |
+
# Uniformity
|
20 |
+
UNI_VALUES = [False, False, False, True]
|
21 |
+
UNI_MIN = 0
|
22 |
+
UNI_MAX = len(UNI_VALUES) - 1
|
23 |
+
|
24 |
+
# Type
|
25 |
+
TYPE_VALUES = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
|
26 |
+
TYPE_MIN = 0
|
27 |
+
TYPE_MAX = len(TYPE_VALUES) - 1
|
28 |
+
|
29 |
+
# Size
|
30 |
+
SIZE_VALUES = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
31 |
+
SIZE_MIN = 0
|
32 |
+
SIZE_MAX = len(SIZE_VALUES) - 1
|
33 |
+
|
34 |
+
# Color
|
35 |
+
COLOR_VALUES = [255, 224, 196, 168, 140, 112, 84, 56, 28, 0]
|
36 |
+
COLOR_MIN = 0
|
37 |
+
COLOR_MAX = len(COLOR_VALUES) - 1
|
38 |
+
|
39 |
+
# Angle: self-rotation
|
40 |
+
ANGLE_VALUES = [-135, -90, -45, 0, 45, 90, 135, 180]
|
41 |
+
ANGLE_MIN = 0
|
42 |
+
ANGLE_MAX = len(ANGLE_VALUES) - 1
|
43 |
+
|
44 |
+
META_TARGET_FORMAT = ["Constant", "Progression", "Arithmetic", "Distribute_Three", "Number", "Position", "Type", "Size", "Color"]
|
45 |
+
META_STRUCTURE_FORMAT = ["Singleton", "Left_Right", "Up_Down", "Out_In", "Left", "Right", "Up", "Down", "Out", "In", "Grid", "Center_Single", "Distribute_Four", "Distribute_Nine", "Left_Center_Single", "Right_Center_Single", "Up_Center_Single", "Down_Center_Single", "Out_Center_Single", "In_Center_Single", "In_Distribute_Four"]
|
46 |
+
|
47 |
+
# Rule, Attr, Param
|
48 |
+
# The design encodes rule priority order: Number/Position always comes first
|
49 |
+
# Number and Position could not both be sampled
|
50 |
+
# Progression on Number: Number on each Panel +1/2 or -1/2
|
51 |
+
# Progression on Position: Entities on each Panel roll over the layout
|
52 |
+
# Arithmetic on Number: Number on the third Panel = Number on first +/- Number on second (1 for + and -1 for -)
|
53 |
+
# Arithmetic on Position: 1 for SET_UNION and -1 for SET_DIFF
|
54 |
+
# Distribute_Three on Number: Three numbers through each row
|
55 |
+
# Distribute_Three on Position: Three positions (same number) through each row
|
56 |
+
# Constant on Number/Position: Nothing changes
|
57 |
+
# Progression on Type: Type progression defined as the number of edges on each entity (Triangle, Square, Pentagon, Hexagon, Circle)
|
58 |
+
# Distribute_Three on Type: Three types through each row
|
59 |
+
# Constant on Type: Nothing changes
|
60 |
+
# Progression on Size: Size on each entity +1/2 or -1/2
|
61 |
+
# Arithmetic on Size: Size on the third Panel = Size on the first +/- Size on the second (1 for + and -1 for -)
|
62 |
+
# Distribute_Three on Size: Three sizes through each row
|
63 |
+
# Constant on Size: Nothing changes
|
64 |
+
# Progression on Color: Color +1/2 or -1/2
|
65 |
+
# Arithmetic on Color: Color on the third Panel = Color on the first +/- Color on the second (1 for + and -1 for -)
|
66 |
+
# Distribute_Three on Color: Three colors through each row
|
67 |
+
# Constant on Color: Nothing changes
|
68 |
+
# Note that all rules on Type, Size and Color enforce value consistency in a panel
|
69 |
+
RULE_ATTR = [[["Progression", "Number", [-2, -1, 1, 2]],
|
70 |
+
["Progression", "Position", [-2, -1, 1, 2]],
|
71 |
+
["Arithmetic", "Number", [1, -1]],
|
72 |
+
["Arithmetic", "Position", [1, -1]],
|
73 |
+
["Distribute_Three", "Number", None],
|
74 |
+
["Distribute_Three", "Position", None],
|
75 |
+
["Constant", "Number/Position", None]],
|
76 |
+
[["Progression", "Type", [-2, -1, 1, 2]],
|
77 |
+
["Distribute_Three", "Type", None],
|
78 |
+
["Constant", "Type", None]],
|
79 |
+
[["Progression", "Size", [-2, -1, 1, 2]],
|
80 |
+
["Arithmetic", "Size", [1, -1]],
|
81 |
+
["Distribute_Three", "Size", None],
|
82 |
+
["Constant", "Size", None]],
|
83 |
+
[["Progression", "Color", [-2, -1, 1, 2]],
|
84 |
+
["Arithmetic", "Color", [1, -1]],
|
85 |
+
["Distribute_Three", "Color", None],
|
86 |
+
["Constant", "Color", None]]]
|
raven_utils/render/rendering.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
#
|
8 |
+
# from AoT import Root
|
9 |
+
import raven_utils.decode
|
10 |
+
from raven_utils.render.const import CENTER, DEFAULT_WIDTH, IMAGE_SIZE
|
11 |
+
|
12 |
+
from data_utils import Bag
|
13 |
+
|
14 |
+
from raven_utils.render_ import COLOR_VALUES, SIZE_VALUES, TYPE_VALUES, ANGLE_VALUES, RENDER_POSITIONS
|
15 |
+
|
16 |
+
|
17 |
+
def imshow(array):
    """Open ``array`` as a PIL image in the default viewer."""
    Image.fromarray(array).show()
|
20 |
+
|
21 |
+
|
22 |
+
def imsave(array, filepath):
    """Convert ``array`` to a PIL image and write it to ``filepath``."""
    Image.fromarray(array).save(filepath)
|
25 |
+
|
26 |
+
|
27 |
+
def generate_matrix(array_list):
    """Tile up to nine panels, row-major, into a 3x3 image.

    Args:
        array_list: Sequence of at most nine ``IMAGE_SIZE`` square panels.

    Returns:
        A ``uint8`` array of shape ``(3 * IMAGE_SIZE, 3 * IMAGE_SIZE)`` with
        two-pixel black separator lines at 0.33 and 0.67 of each axis.
    """
    assert len(array_list) <= 9
    canvas = np.zeros((IMAGE_SIZE * 3, IMAGE_SIZE * 3), np.uint8)
    for position, panel in enumerate(array_list):
        row, col = divmod(position, 3)
        top, left = row * IMAGE_SIZE, col * IMAGE_SIZE
        canvas[top:top + IMAGE_SIZE, left:left + IMAGE_SIZE] = panel
    # Separator lines: rows and columns share the same offsets.
    for fraction in (0.33, 0.67):
        cut = int(fraction * IMAGE_SIZE * 3)
        canvas[cut - 1:cut + 1, :] = 0
        canvas[:, cut - 1:cut + 1] = 0
    return canvas
|
40 |
+
|
41 |
+
|
42 |
+
def generate_answers(array_list):
    """Tile up to eight answer panels, row-major, into a 2x4 image.

    Args:
        array_list: Sequence of at most eight ``IMAGE_SIZE`` square panels.

    Returns:
        A ``uint8`` array of shape ``(2 * IMAGE_SIZE, 4 * IMAGE_SIZE)`` with
        two-pixel black separator lines between the panels.
    """
    assert len(array_list) <= 8
    canvas = np.zeros((IMAGE_SIZE * 2, IMAGE_SIZE * 4), np.uint8)
    for position, panel in enumerate(array_list):
        row, col = divmod(position, 4)
        top, left = row * IMAGE_SIZE, col * IMAGE_SIZE
        canvas[top:top + IMAGE_SIZE, left:left + IMAGE_SIZE] = panel
    # One horizontal separator, three vertical ones.
    for fraction in (0.5,):
        cut = int(fraction * IMAGE_SIZE * 2)
        canvas[cut - 1:cut + 1, :] = 0
    for fraction in (0.25, 0.5, 0.75):
        cut = int(fraction * IMAGE_SIZE * 4)
        canvas[:, cut - 1:cut + 1] = 0
    return canvas
|
54 |
+
|
55 |
+
|
56 |
+
def generate_matrix_answer(array_list):
    """Tile up to eighteen panels, row-major, into a 6x3 image.

    Args:
        array_list: Sequence of at most eighteen ``IMAGE_SIZE`` square panels.

    Returns:
        A ``uint8`` array of shape ``(6 * IMAGE_SIZE, 3 * IMAGE_SIZE)`` with
        one-pixel black separator lines.
    """
    assert len(array_list) <= 18
    canvas = np.zeros((IMAGE_SIZE * 6, IMAGE_SIZE * 3), np.uint8)
    for position, panel in enumerate(array_list):
        row, col = divmod(position, 3)
        top, left = row * IMAGE_SIZE, col * IMAGE_SIZE
        canvas[top:top + IMAGE_SIZE, left:left + IMAGE_SIZE] = panel
    # Single-pixel separators (unlike the two-pixel ones used elsewhere).
    for fraction in (0.33, 0.67, 1.00, 1.33, 1.67):
        canvas[int(fraction * IMAGE_SIZE * 3), :] = 0
    for fraction in (0.33, 0.67):
        canvas[:, int(fraction * IMAGE_SIZE * 3)] = 0
    return canvas
|
69 |
+
|
70 |
+
|
71 |
+
def merge_matrix_answer(matrix, answer):
    """Stack the 3x3 problem matrix above the 2x4 answer grid.

    Args:
        matrix: Panels passed to ``generate_matrix``.
        answer: Panels passed to ``generate_answers``.

    Returns:
        A white ``uint8`` canvas of shape
        ``(5 * IMAGE_SIZE + 20, 4 * IMAGE_SIZE)`` with the matrix centred
        horizontally at the top and the answers along the bottom.
    """
    matrix_image = generate_matrix(matrix)
    answer_image = generate_answers(answer)
    canvas = np.full((IMAGE_SIZE * 5 + 20, IMAGE_SIZE * 4), 255, np.uint8)
    left, right = int(0.5 * IMAGE_SIZE), int(3.5 * IMAGE_SIZE)
    canvas[:IMAGE_SIZE * 3, left:right] = matrix_image
    canvas[-(IMAGE_SIZE * 2):, :] = answer_image
    return canvas
|
78 |
+
|
79 |
+
|
80 |
+
def render_panels(feature, target=True, angle=None):
    """Render every decoded panel description in ``feature`` to an image.

    Args:
        feature: Encoded targets, decoded with
            ``raven_utils.decode.decode_target_flat``.
        target: Unused by this function.
        angle: Forwarded to ``render_entity`` for every entity.

    Returns:
        The rendered panels stacked along a new first axis.
    """
    rendered = []
    decoded = Bag(raven_utils.decode.decode_target_flat(feature))
    for group, exist, color, size, type_ in decoded:
        structure_img = render_structure(group)
        background = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
        # Entities earlier in the list end up in the lower layers.
        for slot, present in enumerate(exist):
            if not present:
                continue
            # type_ is shifted by one because TYPE_VALUES[0] is "none".
            entity_img = render_entity(
                RENDER_POSITIONS[slot], color[slot], size[slot],
                type_[slot] + 1, angle=angle)
            background = layer_add(background, entity_img)
        background = layer_add(background, structure_img)
        # Invert: shapes were drawn bright on black, panels are dark on white.
        white = np.full((IMAGE_SIZE, IMAGE_SIZE), 255, np.uint8)
        rendered.append(white - background)
    return np.stack(rendered)
|
100 |
+
|
101 |
+
|
102 |
+
def render_structure(structure):
    """Return the divider overlay for a layout.

    Layout ``5`` gets a vertical line through the middle column and layout
    ``6`` a horizontal line through the middle row; any other value yields
    an empty overlay.
    """
    overlay = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
    middle = int(0.5 * IMAGE_SIZE)
    if structure == 5:
        overlay[:, middle] = 255.0
    elif structure == 6:
        overlay[middle, :] = 255.0
    return overlay
|
112 |
+
|
113 |
+
|
114 |
+
def render_entity(bbox, color, size, type_, angle=None):
    """Draw a single entity on a blank ``IMAGE_SIZE`` square canvas.

    Args:
        bbox: Planar position ``[x, y, w, h]`` or angular position
            ``[x, y, w, h, x_c, y_c, omega]``, as fractions of the image size.
        color: Index into ``COLOR_VALUES``.
        size: Index into ``SIZE_VALUES``.
        type_: Index into ``TYPE_VALUES``.
        angle: Index into ``ANGLE_VALUES``; when ``None`` an index in
            ``[0, 7)`` is drawn at random.

    Returns:
        The rendered ``uint8`` image after rotation.
    """
    # NOTE(review): the indentation of the original was lost; the
    # ANGLE_VALUES lookup is assumed to apply to an explicitly passed angle
    # as well as to the random one -- confirm against the repository.
    shade = COLOR_VALUES[color]
    ratio = SIZE_VALUES[size]
    kind = TYPE_VALUES[type_]
    if angle is None:
        angle = np.random.randint(0, 7, 1)[0]
    angle = ANGLE_VALUES[angle]
    img = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)

    # center is (column, row).
    center = (int(bbox[1] * IMAGE_SIZE), int(bbox[0] * IMAGE_SIZE))
    cx, cy = center

    def as_points(corners):
        # Vertex array in the (N, 1, 2) int32 layout passed to the helpers.
        return np.array(corners, np.int32).reshape((-1, 1, 2))

    if kind != "none":
        unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
        # Inverted because render_panels subtracts the result from white.
        fill = 255 - shade
        if kind == "square":
            dl = int(unit / 2 * np.sqrt(2) * ratio)
            draw_square(img, (cx - dl, cy - dl), (cx + dl, cy + dl),
                        fill, DEFAULT_WIDTH)
        elif kind == "circle":
            draw_circle(img, center, int(unit * ratio), fill, DEFAULT_WIDTH)
        elif kind == "triangle":
            dl = int(unit * ratio)
            wide, half = int(dl / 2.0 * np.sqrt(3)), int(dl / 2.0)
            pts = as_points([[cx, cy - dl],
                             [cx + wide, cy + half],
                             [cx - wide, cy + half]])
            draw_triangle(img, pts, fill, DEFAULT_WIDTH)
        elif kind == "pentagon":
            dl = int(unit * ratio)
            cos10, sin10 = int(dl * np.cos(np.pi / 10)), int(dl * np.sin(np.pi / 10))
            cos5, sin5 = int(dl * np.cos(np.pi / 5)), int(dl * np.sin(np.pi / 5))
            pts = as_points([[cx, cy - dl],
                             [cx - cos10, cy - sin10],
                             [cx - sin5, cy + cos5],
                             [cx + sin5, cy + cos5],
                             [cx + cos10, cy - sin10]])
            draw_pentagon(img, pts, fill, DEFAULT_WIDTH)
        elif kind == "hexagon":
            dl = int(unit * ratio)
            wide, half = int(dl / 2.0 * np.sqrt(3)), int(dl / 2.0)
            pts = as_points([[cx, cy - dl],
                             [cx - wide, cy - half],
                             [cx - wide, cy + half],
                             [cx, cy + dl],
                             [cx + wide, cy + half],
                             [cx + wide, cy - half]])
            draw_hexagon(img, pts, fill, DEFAULT_WIDTH)

    if len(bbox) > 4:
        # Angular position carries its own rotation centre and angle.
        angle = bbox[6]
        center = (int(bbox[5] * IMAGE_SIZE), int(bbox[4] * IMAGE_SIZE))
    return rotate(img, angle, center=center)
|
194 |
+
|
195 |
+
|
196 |
+
def shift(img, dx, dy):
    """Translate ``img`` by ``(dx, dy)`` pixels with an affine warp."""
    translation = np.array([[1, 0, dx], [0, 1, dy]], np.float32)
    return cv2.warpAffine(img, translation, (IMAGE_SIZE, IMAGE_SIZE),
                          flags=cv2.INTER_LINEAR)
|
200 |
+
|
201 |
+
|
202 |
+
def rotate(img, angle, center=CENTER):
    """Rotate ``img`` by ``angle`` degrees about ``center``, without scaling."""
    rotation = cv2.getRotationMatrix2D(center, angle, 1)
    return cv2.warpAffine(img, rotation, (IMAGE_SIZE, IMAGE_SIZE),
                          flags=cv2.INTER_LINEAR)
|
206 |
+
|
207 |
+
|
208 |
+
def scale(img, tx, ty, center=CENTER):
    """Scale ``img`` by ``tx`` and ``ty`` while keeping ``center`` fixed."""
    # The translation terms cancel the shift that scaling applies to center.
    x_offset = center[0] * (1 - tx)
    y_offset = center[1] * (1 - ty)
    scaling = np.array([[tx, 0, x_offset], [0, ty, y_offset]], np.float32)
    return cv2.warpAffine(img, scaling, (IMAGE_SIZE, IMAGE_SIZE),
                          flags=cv2.INTER_LINEAR)
|
212 |
+
|
213 |
+
|
214 |
+
def layer_add(lower_layer_np, higher_layer_np):
    """Superimpose ``higher_layer_np`` on ``lower_layer_np``.

    Wherever the higher layer is non-zero the lower layer is cleared first,
    so the higher layer's pixels win. ``lower_layer_np`` is modified in
    place; the sum is returned as a new array.
    """
    covered = higher_layer_np > 0
    lower_layer_np[covered] = 0
    merged = lower_layer_np + higher_layer_np
    return merged
|
220 |
+
|
221 |
+
|
222 |
+
# Draw primitives
|
223 |
+
def draw_triangle(img, pts, color, width):
    """Draw a triangle outline on ``img``, filled first when ``color`` != 0."""
    if color != 0:
        cv2.fillConvexPoly(img, pts, color)
    # The edge is always drawn at full intensity.
    cv2.polylines(img, [pts], True, 255, width)
|
233 |
+
|
234 |
+
|
235 |
+
def draw_square(img, pt1, pt2, color, width):
    """Draw a rectangle between ``pt1`` and ``pt2``, filled when ``color`` != 0."""
    if color != 0:
        # Thickness -1 fills the interior.
        cv2.rectangle(img, pt1, pt2, color, -1)
    # The edge is always drawn at full intensity.
    cv2.rectangle(img, pt1, pt2, 255, width)
|
257 |
+
|
258 |
+
|
259 |
+
def draw_pentagon(img, pts, color, width):
    """Draw a pentagon outline on ``img``, filled first when ``color`` != 0."""
    if color != 0:
        cv2.fillConvexPoly(img, pts, color)
    # The edge is always drawn at full intensity.
    cv2.polylines(img, [pts], True, 255, width)
|
269 |
+
|
270 |
+
|
271 |
+
def draw_hexagon(img, pts, color, width):
    """Draw a hexagon outline on ``img``, filled first when ``color`` != 0."""
    if color != 0:
        cv2.fillConvexPoly(img, pts, color)
    # The edge is always drawn at full intensity.
    cv2.polylines(img, [pts], True, 255, width)
|
281 |
+
|
282 |
+
|
283 |
+
def draw_circle(img, center, radius, color, width):
    """Draw a circle outline on ``img``, filled first when ``color`` != 0."""
    if color != 0:
        # Thickness -1 fills the interior.
        cv2.circle(img, center, radius, color, -1)
    # The edge is always drawn at full intensity.
    cv2.circle(img, center, radius, 255, width)
|
raven_utils/render_.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
COLOR_VALUES = [255, 224, 196, 168, 140, 112, 84, 56, 28, 0]
|
2 |
+
SIZE_VALUES = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
3 |
+
TYPE_VALUES = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
|
4 |
+
ANGLE_VALUES = [-135, -90, -45, 0, 45, 90, 135, 180]
|
5 |
+
RENDER_POSITIONS_GROUPED = [
|
6 |
+
[(0.5, 0.5, 1, 1)],
|
7 |
+
# ...
|
8 |
+
[(0.25, 0.25, 0.5, 0.5),
|
9 |
+
(0.25, 0.75, 0.5, 0.5),
|
10 |
+
(0.75, 0.25, 0.5, 0.5),
|
11 |
+
(0.75, 0.75, 0.5, 0.5)],
|
12 |
+
# ...
|
13 |
+
[(0.16, 0.16, 0.33, 0.33),
|
14 |
+
(0.16, 0.5, 0.33, 0.33),
|
15 |
+
(0.16, 0.83, 0.33, 0.33),
|
16 |
+
(0.5, 0.16, 0.33, 0.33),
|
17 |
+
(0.5, 0.5, 0.33, 0.33),
|
18 |
+
(0.5, 0.83, 0.33, 0.33),
|
19 |
+
(0.83, 0.16, 0.33, 0.33),
|
20 |
+
(0.83, 0.5, 0.33, 0.33),
|
21 |
+
(0.83, 0.83, 0.33, 0.33)],
|
22 |
+
# ...
|
23 |
+
[(0.5, 0.5, 1, 1)],
|
24 |
+
[(0.5, 0.5, 0.33, 0.33)],
|
25 |
+
# ...
|
26 |
+
[(0.5, 0.5, 1, 1)],
|
27 |
+
[(0.42, 0.42, 0.15, 0.15),
|
28 |
+
(0.42, 0.58, 0.15, 0.15),
|
29 |
+
(0.58, 0.42, 0.15, 0.15),
|
30 |
+
(0.58, 0.58, 0.15, 0.15)],
|
31 |
+
# ....
|
32 |
+
[(0.5, 0.25, 0.5, 0.5)],
|
33 |
+
[(0.5, 0.75, 0.5, 0.5)],
|
34 |
+
# ...
|
35 |
+
[(0.25, 0.5, 0.5, 0.5)],
|
36 |
+
[(0.75, 0.5, 0.5, 0.5)],
|
37 |
+
# ...
|
38 |
+
]
|
39 |
+
RENDER_POSITIONS = [pos_ for group_pos_ in RENDER_POSITIONS_GROUPED for pos_ in group_pos_]
|
40 |
+
MAPPING = {
|
41 |
+
"distribute_nine":
|
42 |
+
{0.16: 0,
|
43 |
+
0.5: 1,
|
44 |
+
0.83: 2},
|
45 |
+
"distribute_four":
|
46 |
+
{0.25: 0,
|
47 |
+
0.75: 1},
|
48 |
+
'in_distribute_four_out_center_single':
|
49 |
+
{0.42: 0,
|
50 |
+
0.58: 1}
|
51 |
+
}
|
52 |
+
MUL = {
|
53 |
+
"distribute_nine": 3,
|
54 |
+
"distribute_four": 2,
|
55 |
+
'in_distribute_four_out_center_single': 2
|
56 |
+
}
|
57 |
+
TYPES = ["triangle", "square", "pentagon", "hexagon", "circle"]
|
58 |
+
TYPES_NONE = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
|
59 |
+
SIZES = ["vs", "s", "m", "h", "vh", "e"]
|
60 |
+
SIZES_NAME = ["Very Small", "Small", "Medium", "High", "Very High", "Enormous"]
|
61 |
+
COLORS = ["vs", "s", "m", "h", "vh", "e"]
|
62 |
+
|
63 |
+
SAMPLE_TARGET = [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
64 |
+
0, 0, 0, 0, 9, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
65 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
66 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
67 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0,
|
68 |
+
0, 1, 3],
|
69 |
+
[1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
70 |
+
0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 2, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
71 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
72 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
73 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 2, 3, 0, 0, 0, 0,
|
74 |
+
0, 3, 3],
|
75 |
+
[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
76 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
77 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0,
|
78 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
79 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 2, 0, 0, 0, 0,
|
80 |
+
0, 0, 3],
|
81 |
+
[3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
|
82 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
83 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
84 |
+
0, 0, 0, 5, 2, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
85 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
|
86 |
+
3, 2, 1],
|
87 |
+
[4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1,
|
88 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
89 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
90 |
+
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3,
|
91 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 2, 1,
|
92 |
+
3, 0, 1],
|
93 |
+
[5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
94 |
+
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
95 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
96 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
97 |
+
0, 2, 0, 0, 7, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 3, 0, 0, 3, 3,
|
98 |
+
3, 1, 0],
|
99 |
+
[6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
100 |
+
0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
101 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
102 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
103 |
+
0, 0, 0, 0, 0, 0, 0, 6, 5, 0, 8, 5, 1, 0, 0, 0, 3, 2, 0, 0, 1, 0,
|
104 |
+
3, 3, 3]]
|
raven_utils/rules.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ml_utils import dict_from_list
|
2 |
+
# Label used when Number and Position are treated as one attribute.
COMBINE = "Number/Position"

# Attributes a rule can apply to.
ATTRIBUTES = ["Number", "Position", "Color", "Size", "Type"]
ATTRIBUTES_LEN = len(ATTRIBUTES)
ATTRIBUTES_INDEX = dict_from_list(ATTRIBUTES)

# Rule types.
TYPES = ["Constant", "Arithmetic", "Progression", "Distribute_Three"]
TYPES_INDEX = dict_from_list(TYPES)
TYPES_LEN = len(TYPES)
|
raven_utils/target.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import raven_utils.group as group
|
3 |
+
import raven_utils.entity as entity
|
4 |
+
import raven_utils.rules as rules
|
5 |
+
import raven_utils.properties as properties
|
6 |
+
|
7 |
+
|
8 |
+
# Entity offsets shifted by one; the first target column holds the group id
# (presumably -- confirm against the encoder).
ENTITY_INDEX = entity.INDEX + 1
ENTITY_DICT = dict(zip(group.NAMES, ENTITY_INDEX[:-1]))
NAMES_ORDER = dict(zip(group.NAMES, np.arange(len(group.NAMES))))
PROPERTIES_INDEXES = np.cumsum(np.array(list(entity.NO.values())) * properties.NO)
# NOTE(review): the original comment here read "+2 type and uniformity" while
# the code adds 1 -- confirm which of the two is intended.
INDEX = np.concatenate([[0], PROPERTIES_INDEXES]) + entity.SUM + 1

# Zero-based entity slots that belong to the second component of a layout.
_IN_OUT_FOUR = ENTITY_DICT["in_distribute_four_out_center_single"]
SECOND_LAYOUT = [
    ENTITY_DICT["in_center_single_out_center_single"],
    _IN_OUT_FOUR,
    _IN_OUT_FOUR + 1,
    _IN_OUT_FOUR + 2,
    ENTITY_DICT["left_center_single_right_center_single"],
    ENTITY_DICT["up_center_single_down_center_single"],
]

# Every remaining slot belongs to the first component.
FIRST_LAYOUT = list(set(range(entity.SUM)) - set(SECOND_LAYOUT))
LAYOUT_NO = 2

START_INDEX = dict(zip(group.NAMES, INDEX[:-1]))
END_INDEX = INDEX[-1]

RULES_ATTRIBUTES_ALL_LEN = rules.ATTRIBUTES_LEN * LAYOUT_NO
UNIFORMITY_NO = 2
UNIFORMITY_INDEX = END_INDEX + RULES_ATTRIBUTES_ALL_LEN

# Total length of one encoded target row.
SIZE = UNIFORMITY_INDEX + UNIFORMITY_NO
|
34 |
+
|
35 |
+
def take(target):
    """Return the second and third items of ``target`` as a pair."""
    second = target[1]
    third = target[2]
    return second, third
|
37 |
+
|
38 |
+
|
39 |
+
def create(images, index, pattern_index=(2, 5), full_index=False, arrange=np.arange, shape=lambda x: x.shape):
    """Pick two pattern panels and the indexed answer panel per sample.

    Args:
        images: Array indexed as ``images[sample, panel]``.
        index: Array whose first column selects the answer for each sample.
        pattern_index: Panel positions of the two pattern images.
        full_index: When false, ``index`` is offset by 8 before use.
        arrange: Callable producing the sample range (default ``np.arange``).
        shape: Callable returning the shape of ``index``.

    Returns:
        A list ``[first_pattern, second_pattern, answer]``.
    """
    first_pattern = images[:, pattern_index[0]]
    second_pattern = images[:, pattern_index[1]]
    samples = arrange(shape(index)[0])
    offset = 0 if full_index else 8
    answer = images[samples, offset + index[:, 0]]
    return [first_pattern, second_pattern, answer]
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def take_simple(target):
    """Return the first two items of ``target`` in swapped order."""
    second = target[1]
    first = target[0]
    return second, first
|
47 |
+
|
48 |
+
|
49 |
+
def create_simple(images, target, index=slice(None), pattern_index=(2, 5)):
    """Build ``[first_pattern, second_pattern, target]`` and index into it.

    Args:
        images: Array indexed as ``images[sample, panel]``.
        target: Object placed last in the list.
        index: Index or slice applied to the list (default: whole list).
        pattern_index: Panel positions of the two pattern images.
    """
    first_pattern = images[:, pattern_index[0]]
    second_pattern = images[:, pattern_index[1]]
    bundle = [first_pattern, second_pattern, target]
    return bundle[index]
|
raven_utils/uitls.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from itertools import product
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from funcy import identity
|
6 |
+
|
7 |
+
from data_utils import gather, DataGenerator, Data
|
8 |
+
from data_utils.sampling import DataSampler
|
9 |
+
from models_utils import init_image as def_init_image, INPUTS, TARGET
|
10 |
+
|
11 |
+
import raven_utils.group as group
|
12 |
+
|
13 |
+
from data_utils import ops as D
|
14 |
+
|
15 |
+
# Image initialiser from models_utils with the shape fixed to (16, 8, 80, 80, 1).
init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
|
16 |
+
|
17 |
+
|
18 |
+
def get_val_index(no=group.NO, base=3, add_end=False):
    """Return one index per group, spaced 2000 apart and offset by ``base``.

    Args:
        no: Number of groups (defaults to ``group.NO``).
        base: Offset added to every multiple of 2000.
        add_end: When true, append ``no * 2000`` as a final index.

    Returns:
        A 1-D integer array ``[base, 2000 + base, ...]``, with ``no * 2000``
        appended when ``add_end`` is set.
    """
    indexes = np.arange(no) * 2000 + base
    if add_end:
        # np.concatenate rejects zero-dimensional inputs, so the end marker
        # has to be wrapped in a one-element list.
        indexes = np.concatenate([indexes, [no * 2000]])
    return indexes
|
23 |
+
|
24 |
+
|
25 |
+
def get_matrix(inputs, index):
    """Append the answer selected by ``index`` to the first eight panels.

    Args:
        inputs: Array indexed as ``inputs[sample, panel, ...]``.
        index: Array whose first column picks the answer panel per sample.

    Returns:
        The first eight panels of each sample followed by the gathered
        panel, concatenated along axis 1.
    """
    context = inputs[:, :8]
    chosen = gather(inputs, index[:, 0])
    return np.concatenate([context, chosen[:, None]], axis=1)
|
27 |
+
|
28 |
+
|
29 |
+
def get_matrix_from_data(x):
    """Call ``get_matrix`` with the ``"inputs"`` and ``"index"`` entries of ``x``."""
    return get_matrix(x["inputs"], x["index"])
|
33 |
+
|
34 |
+
|
35 |
+
def get_data_class(data, batch_size=128):
    """Build the training and validation generators.

    Args:
        data: Sequence ``(train_inputs, val_inputs, train_target, val_target)``.
        batch_size: Batch size passed to both generators.

    Returns:
        A ``(train_generator, val_generator)`` pair. Both sample over every
        pair of indices along the first two axes of their inputs; the
        validation sampler is created with ``shuffle=False``.
    """

    def index_pairs(array):
        # All (i, j) combinations over the first two axes of ``array``.
        rows, cols = array.shape[0], array.shape[1]
        return np.array(list(product(np.arange(rows), np.arange(cols))))

    train_inputs, val_inputs, train_target, val_target = (
        data[0], data[1], data[2], data[3])

    train_generator = DataGenerator(
        {
            INPUTS: Data(train_inputs, identity),
            TARGET: Data(train_target, identity),
        },
        sampler=DataSampler(index_pairs(train_inputs)),
        batch=batch_size
    )
    val_generator = DataGenerator(
        {
            INPUTS: Data(val_inputs, identity),
            TARGET: Data(val_target, identity),
        },
        sampler=DataSampler(index_pairs(val_inputs), shuffle=False),
        batch=batch_size
    )
    return train_generator, val_generator
|
56 |
+
|
57 |
+
|
58 |
+
def compare_from_result(result, data):
    """Check, per sample, whether a prediction matches the true answer.

    Args:
        result: Mapping with ``'predict'`` and ``'predict_mask'`` entries.
        data: Wrapper whose ``data.data`` mapping holds ``'target'`` and
            ``'index'`` entries, each exposing a ``.data`` array.

    Returns:
        A boolean array that is true where ``rv.decode.compare`` holds along
        the whole last axis.
    """
    store = data.data.data
    answer = D.gather(store['target'].data, store['index'].data[:, 0])
    # Imported here rather than at module level, as in the original.
    import raven_utils as rv
    predict = result['predict']
    predict_mask = result['predict_mask']
    # Only as many answers as there are predictions are compared.
    matches = rv.decode.compare(answer[:len(predict)], predict, predict_mask)
    return np.all(matches, axis=-1)
|
saved_model/1/keras_metadata.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3065f1580247f096711cd61201a17a730a1e5a3d719f2c2778030dea78bb17b4
|
3 |
+
size 730275
|
saved_model/1/saved_model.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:48d10d74324a5e993ceacc0a4bffc1fcb232d7e2f708a2ebbeabd864650baeeb
|
3 |
+
size 12159312
|
saved_model/1/variables/variables.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3e3c44b273228c834166b40b8a062a53dce76cc21d4cce42f65df2edc53533a7
|
3 |
+
size 43002413
|
saved_model/1/variables/variables.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e69bbffa5d415625538b629762f1aaeeb355a83d676242110af3d633e31017dd
|
3 |
+
size 24958
|
utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from data_utils.image import draw_images
|
3 |
+
from ml_utils import il
|
4 |
+
|
5 |
+
import raven_utils as rv
|
6 |
+
from raven_utils.uitls import get_matrix
|
7 |
+
from tensorflow.keras.models import load_model
|
8 |
+
from raven_utils.draw import render_from_model
|
9 |
+
import models
|
10 |
+
import ast
|
11 |
+
|
12 |
+
|
13 |
+
def load_example(index=0):
    """Load or render one example matrix with its description.

    Args:
        index: Either a dataset index or the literal of a custom panel
            description; it is parsed with ``ast.literal_eval``.

    Returns:
        A ``(image, description)`` pair. The image is the first nine panels
        drawn three per row and tiled three times along the last axis.
    """
    index = ast.literal_eval(str(index))
    if il(index):
        # Custom description: render the panels directly.
        example = rv.draw.render_panels(np.array(index))
        desc = "Custom matrix"
    else:
        index = int(index) if index else 0
        rule_groups = rv.draw.extract_rules(models.properties[index])
        desc = "<br /><br />".join(["<br />".join(d) for d in rule_groups])
        # The stored answer index is offset by 8 (the context panels).
        example = get_matrix(models.data[index:index + 1],
                             models.indexes[index:index + 1, None] + 8)
    drawn = draw_images(example[:9], row=3)
    return np.tile(drawn, reps=(1, 1, 3)), desc
|
29 |
+
|
30 |
+
|
31 |
+
def load_model_(name):
    """Load a Keras model into ``models.model`` and return a status message.

    Args:
        name: ``"Transformer"`` selects the built-in checkpoint path; any
            other value is used as the path itself.
    """
    builtin = "/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09"
    path = builtin if name == "Transformer" else name
    models.model = load_model(path)
    return f"Success loading: {name}"
|
38 |
+
|
39 |
+
|
40 |
+
def run_nn(index=0):
    """Run the loaded model on one example and return the rendered output.

    Args:
        index: Either a dataset index or the literal of a custom panel
            description; it is parsed with ``ast.literal_eval``.

    Returns:
        The first image rendered from the model output, tiled three times
        along a new last axis.
    """
    index = ast.literal_eval(str(index))
    if il(index):
        panels = rv.draw.render_panels(np.array(index))
        # Pad with the first seven panels and add a batch axis.
        inputs = np.concatenate([panels, panels[:7]])[None]
    else:
        index = int(index) if index else int(models.START_IMAGE)
        inputs = models.data[index:index + 1]

    # Only 'inputs' carries real data; the other entries are zero fillers.
    batch = {
        'inputs': inputs,
        'index': np.zeros(shape=(1, 1), dtype="uint8"),
        'labels': np.zeros(shape=(1, 16, 113), dtype="int8"),
        'target': np.zeros(shape=(1, 16, 113), dtype="int8"),
    }
    rendered = render_from_model(batch, models.model)
    return np.tile(rendered[0, ..., None], reps=(1, 1, 3))
|
64 |
+
|
65 |
+
|
66 |
+
def next_(index=0):
    """Step to the following example.

    ``index`` is parsed with ``ast.literal_eval``; anything that is not an
    ``int`` restarts from ``models.START_IMAGE``.

    Returns:
        ``(new_index, image, description)``.
    """
    parsed = ast.literal_eval(str(index))
    current = parsed if isinstance(parsed, int) else models.START_IMAGE
    following = int(current) + 1
    return (following,) + load_example(following)
|
72 |
+
|
73 |
+
|
74 |
+
def prev_(index=0):
    """Step to the preceding example.

    ``index`` is parsed with ``ast.literal_eval``; anything that is not an
    ``int`` restarts from ``models.START_IMAGE``.

    Returns:
        ``(new_index, image, description)``.
    """
    parsed = ast.literal_eval(str(index))
    current = parsed if isinstance(parsed, int) else models.START_IMAGE
    preceding = int(current) - 1
    return (preceding,) + load_example(preceding)
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == '__main__':
    # Smoke run: draw example 5, then run the network on the same example.
    # run_nn parses its argument with ast.literal_eval, so it has to receive
    # the index (or a list literal), not the rendered image array.
    example_index = 5
    image, _ = load_example(example_index)
    run_nn(example_index)
|