Jakub Kwiatkowski commited on
Commit
e986ee1
·
1 Parent(s): 9396266

Add model.

Browse files
Files changed (48) hide show
  1. main.py +49 -0
  2. models.py +13 -0
  3. raven_utils/__init__.py +10 -0
  4. raven_utils/config/__init__.py +0 -0
  5. raven_utils/config/constant.py +54 -0
  6. raven_utils/config/models.py +9 -0
  7. raven_utils/const.py +2 -0
  8. raven_utils/constant.py +53 -0
  9. raven_utils/data.py +46 -0
  10. raven_utils/decode.py +100 -0
  11. raven_utils/depricated/__init__.py +0 -0
  12. raven_utils/depricated/old_raven.py +490 -0
  13. raven_utils/draw.py +174 -0
  14. raven_utils/entity.py +6 -0
  15. raven_utils/group.py +11 -0
  16. raven_utils/inference.py +15 -0
  17. raven_utils/models/__init__.py +0 -0
  18. raven_utils/models/attn.py +187 -0
  19. raven_utils/models/attn2.py +187 -0
  20. raven_utils/models/augment.py +0 -0
  21. raven_utils/models/body.py +276 -0
  22. raven_utils/models/class_.py +31 -0
  23. raven_utils/models/head.py +159 -0
  24. raven_utils/models/loss.py +630 -0
  25. raven_utils/models/loss_3.py +638 -0
  26. raven_utils/models/multi_transformer.py +274 -0
  27. raven_utils/models/raven.py +239 -0
  28. raven_utils/models/trans.py +74 -0
  29. raven_utils/models/transformer.py +133 -0
  30. raven_utils/models/transformer_2.py +146 -0
  31. raven_utils/models/transformer_3.py +206 -0
  32. raven_utils/models/uitls_.py +16 -0
  33. raven_utils/output.py +16 -0
  34. raven_utils/params.py +110 -0
  35. raven_utils/properties.py +16 -0
  36. raven_utils/range_mask.py +16 -0
  37. raven_utils/render/__init__.py +0 -0
  38. raven_utils/render/const.py +86 -0
  39. raven_utils/render/rendering.py +304 -0
  40. raven_utils/render_.py +104 -0
  41. raven_utils/rules.py +21 -0
  42. raven_utils/target.py +50 -0
  43. raven_utils/uitls.py +64 -0
  44. saved_model/1/keras_metadata.pb +3 -0
  45. saved_model/1/saved_model.pb +3 -0
  46. saved_model/1/variables/variables.data-00000-of-00001 +3 -0
  47. saved_model/1/variables/variables.index +3 -0
  48. utils.py +84 -0
main.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from utils import load_example, run_nn, load_model_, next_, prev_
4
+
5
+ demo = gr.Blocks()
6
+ import models
7
+
8
+ with demo:
9
+ headline = gr.Markdown("## Raven resolver ")
10
+ markdown = gr.Markdown("Below we show all 9 images from raven matrix. "
11
+ "Model gets 8 images and predicts the properties of last one. "
12
+ "Based on this properties the answer image is render in the right panel. <br />"
13
+ "Note that angle rotation is only used as a noise. "
14
+ "There are not rules applied to angle property, so angle rotation of final output do not need to be the same as in example. "
15
+ "Additionally there are cases that other properties could be used as noise.")
16
+ with gr.Row():
17
+ with gr.Column():
18
+ with gr.Row():
19
+ text = gr.Textbox(models.START_IMAGE,
20
+ label="Write the example number from validation dataset (0, 14,000). You can also paste here matrix representation from generator.")
21
+ with gr.Row():
22
+ prev = gr.Button("Prev")
23
+ show = gr.Button("Show")
24
+ next = gr.Button("Next")
25
+ # button = gr.Button("Run")
26
+ with gr.Row():
27
+ image = gr.Image(value=load_example(models.START_IMAGE)[0], label="Raven matrix")
28
+ desc = gr.Markdown(value=load_example(models.START_IMAGE)[1])
29
+
30
+ with gr.Column():
31
+ with gr.Row():
32
+ output = gr.Image(label="Generated image", shape=(200, 200))
33
+ with gr.Row():
34
+ button = gr.Button("Run")
35
+
36
+ # text.change(load_example, inputs=text, outputs=[image, desc])
37
+ show.click(load_example, inputs=text, outputs=[image, desc])
38
+ # button.click(run_nn, inputs=image, outputs=output)
39
+ button.click(run_nn, inputs=text, outputs=output)
40
+
41
+ # next.click(next_, inputs=text, outputs=text)
42
+ # next.click(load_example, inputs=text, outputs=[image, desc])
43
+ next.click(next_, inputs=text, outputs=[text, image, desc])
44
+
45
+ # prev.click(prev_, inputs=text, outputs=text)
46
+ # prev.click(load_example, inputs=text, outputs=[image, desc])
47
+ prev.click(prev_, inputs=text, outputs=[text, image, desc])
48
+
49
+ demo.launch(debug=True)
models.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ START_IMAGE = 12000
4
+
5
+ from tensorflow.keras.models import load_model
6
+ model = load_model("saved_model/1")
7
+
8
+ from data_utils import nload, ims, DataSetFromFolder
9
+ data = nload("/home/jkwiatkowski/all/dataset/arr/val.npy")
10
+ indexes = nload("/home/jkwiatkowski/all/dataset/arr/val_target.npy")
11
+
12
+ folders = DataSetFromFolder("/home/jkwiatkowski/all/dataset/arr/RAVEN-10000-release/RAVEN-10000", file_type="dir")
13
+ properties = DataSetFromFolder(folders[:], file_type="xml", extension="val")
raven_utils/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import raven_utils.group as group
2
+ import raven_utils.entity as entity
3
+ import raven_utils.properties as properties
4
+ import raven_utils.target as target
5
+ import raven_utils.rules as rules
6
+ import raven_utils.output as output
7
+ import raven_utils.inference as inference
8
+ import raven_utils.decode as decode
9
+ import raven_utils.render_ as render_
10
+ import raven_utils.draw as draw
raven_utils/config/__init__.py ADDED
File without changes
raven_utils/config/constant.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ RAVEN = "arr"
2
+ RAVEN_BIG = "arrb"
3
+ INDEX = "index"
4
+ LABELS = "labels"
5
+ TARGET_LABELS = "target_labels"
6
+ FEATURES = "features"
7
+ ACC_SAME = "acc_same"
8
+ ACC_CHOOSE_LOWER = "acc_choose_lower"
9
+ ACC_CHOOSE_UPPER = "acc_choose_upper"
10
+ ACC_NO_GROUP = "acc_NO_group"
11
+ CLASSIFICATION = "classification"
12
+ INFERENCE = "inference"
13
+ # PROPERTIES = "properties"
14
+ PROPERTY = "property"
15
+ MEMORY = "memory"
16
+ CONTROL = "control"
17
+ LATENT = "latent"
18
+ TARGET = "target"
19
+ INPUTS = "inputs"
20
+ RES = "res"
21
+ RESULT = "result"
22
+ MERGE = "merge"
23
+ MEMORY_STATE = "memory_state"
24
+ CONTROL_STATE = "control_state"
25
+ CONCAT = "concat"
26
+ FLATTEN = "flatten"
27
+ CROSS_ENTROPY = "cross_entropy"
28
+ SLOT = "slot"
29
+ PROPERTIES = "properties"
30
+ ACC = "acc"
31
+ GROUP = 'group'
32
+ NUMBER = 'number'
33
+ TRANS = 'trans'
34
+ TAIL = "tail"
35
+ MASK = "mask"
36
+
37
+ RAV_METRICS = [
38
+ ACC_NO_GROUP,
39
+ ACC_SAME,
40
+ ACC_CHOOSE_UPPER,
41
+ ACC_CHOOSE_LOWER,
42
+ "acc",
43
+ "c_acc_NO_group",
44
+ "c_acc",
45
+ "loss",
46
+ ]
47
+
48
+ IMP_RAV_METRICS = [
49
+ ACC_NO_GROUP,
50
+ ACC_SAME,
51
+ ACC_CHOOSE_UPPER,
52
+ ACC_CHOOSE_LOWER,
53
+ ACC,
54
+ ]
raven_utils/config/models.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+ AVAILABLE_MODELS = [
3
+ "197-0.31",
4
+ "53-0.48",
5
+ "74-0.50",
6
+ "21-0.48",
7
+ "10-0.52",
8
+ "179-0.50"
9
+ ]
raven_utils/const.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ HORIZONTAL = "horizontal"
2
+ VERTICAL = "vertical"
raven_utils/constant.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ RAVEN = "arr"
2
+ RAVEN_BIG = "arrb"
3
+ INDEX = "index"
4
+ LABELS = "labels"
5
+ TARGET_LABELS = "target_labels"
6
+ FEATURES = "features"
7
+ ACC_SAME = "acc_same"
8
+ ACC_CHOOSE_LOWER = "acc_choose_lower"
9
+ ACC_CHOOSE_UPPER = "acc_choose_upper"
10
+ ACC_NO_GROUP = "acc_NO_group"
11
+ CLASSIFICATION = "classification"
12
+ INFERENCE = "inference"
13
+ # PROPERTIES = "properties"
14
+ PROPERTY = "property"
15
+ MEMORY = "memory"
16
+ CONTROL = "control"
17
+ LATENT = "latent"
18
+ TARGET = "target"
19
+ INPUTS = "inputs"
20
+ RES = "res"
21
+ RESULT = "result"
22
+ MERGE = "merge"
23
+ MEMORY_STATE = "memory_state"
24
+ CONTROL_STATE = "control_state"
25
+ CONCAT = "concat"
26
+ FLATTEN = "flatten"
27
+ CROSS_ENTROPY = "cross_entropy"
28
+ SLOT = "slot"
29
+ PROPERTIES = "properties"
30
+ ACC = "acc"
31
+ GROUP = 'group'
32
+ NUMBER = 'number'
33
+ TRANS = 'trans'
34
+ TAIL = "tail"
35
+ MASK = "mask"
36
+
37
+ RAV_METRICS = [
38
+ ACC_NO_GROUP,
39
+ ACC_SAME,
40
+ ACC_CHOOSE_UPPER,
41
+ ACC_CHOOSE_LOWER,
42
+ "acc",
43
+ "c_acc_NO_group",
44
+ "c_acc",
45
+ "loss",
46
+ ]
47
+
48
+ IMP_RAV_METRICS = [
49
+ ACC_NO_GROUP,
50
+ ACC_SAME,
51
+ ACC_CHOOSE_UPPER,
52
+ ACC_CHOOSE_LOWER,
53
+ ]
raven_utils/data.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import tensorflow as tf
4
+
5
+ from models_utils import INPUTS, TARGET
6
+
7
+ from raven_utils.config.constant import RAVEN, LABELS, INDEX, FEATURES, RAV_METRICS, IMP_RAV_METRICS, ACC_NO_GROUP
8
+
9
+
10
+ from typing import Any
11
+
12
+ from data_utils import pre, Data, gather, vec, resize
13
+ from data_utils.data_generator import DataGenerator
14
+ from funcy import identity
15
+
16
+
17
+ def get_data(data, batch_size, steps=None, val_steps=None):
18
+ if val_steps is None:
19
+ val_steps = steps
20
+ fn = identity
21
+ train_target_index = data[4] + 8
22
+ train_generator = DataGenerator({
23
+ INPUTS: Data(data[0], fn),
24
+ TARGET: Data(data[2], identity),
25
+ LABELS: Data(data[2], identity),
26
+ INDEX: train_target_index[:, None],
27
+ # FEATURES: data[6]
28
+ },
29
+ batch=batch_size,
30
+ steps=steps
31
+ )
32
+ val_target_index = data[5] + 8
33
+ val_data = {
34
+ INPUTS: Data(data[1], fn),
35
+ TARGET: Data(data[3], identity),
36
+ LABELS: Data(data[3], identity),
37
+ INDEX: val_target_index[:, None],
38
+ # FEATURES: data[7]
39
+ }
40
+ val_generator = DataGenerator(
41
+ val_data,
42
+ batch=batch_size,
43
+ sampler="val",
44
+ steps=val_steps
45
+ )
46
+ return train_generator, val_generator
raven_utils/decode.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from data_utils import np_split
3
+ from ml_utils import lw
4
+ from models_utils.ops import ibin
5
+
6
+ import raven_utils as rv
7
+
8
+
9
+ def output(x, split_fn=np_split, predict_fn_1=np.argmax, predict_fn_2=ibin):
10
+ res = output_divide(x, split_fn=split_fn)
11
+ res = output_predict(res, predict_fn_1=predict_fn_1, predict_fn_2=predict_fn_2)
12
+ return (res[0], res[1]) + tuple(output_properties(res[2], predict_fn=predict_fn_1))
13
+
14
+
15
+ def output_divide(output, split_fn=np_split):
16
+ group_output = output[..., rv.output.GROUP_SLICE_END]
17
+ slot_output = output[..., rv.output.SLOT_SLICE_END]
18
+ properties_output = output[..., rv.output.PROPERTIES_SLICE_END]
19
+ properties_output_splited = split_fn(properties_output, list(rv.properties.INDEX.values()), axis=-1)
20
+ return group_output, slot_output, properties_output_splited
21
+
22
+
23
+ def output_predict(output, predict_fn_1=np.argmax, predict_fn_2=ibin):
24
+ return predict_fn_1(output[0]), predict_fn_2(output[1]), output[2]
25
+
26
+
27
+ def output_properties(x, predict_fn=np.argmax):
28
+ out_reshaped = []
29
+ for i, out in enumerate(x):
30
+ shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
31
+ out_reshaped.append(predict_fn(out.reshape(shape)))
32
+ return out_reshaped
33
+
34
+
35
+ def output_result(output, split_fn=np_split, arg_max=np.argmax):
36
+ result = output_properties(output, predict_fn=split_fn)
37
+ res = []
38
+ for i, r in enumerate(result):
39
+ if i == 1:
40
+ res.append(r)
41
+ else:
42
+ res.append(arg_max(r, axis=-1))
43
+ return tuple(res)
44
+
45
+
46
+ def decode_inference(inference, reshape=np.reshape):
47
+ return reshape(inference[rv.inference.SLOT_SLICE],
48
+ [-1, rv.group.NO, rv.inference.PROPERTY_TRANSFORMATION_NO]), reshape(
49
+ inference[rv.inference.PROPERTIES_SLICE],
50
+ [-1, rv.properties.NO, rv.entity.SUM, rv.inference.PROPERTY_TRANSFORMATION_NO])
51
+
52
+
53
+ def decode_target(target):
54
+ target_group = target[..., 0]
55
+ target_slot = target[..., 1:rv.target.INDEX[0]]
56
+ target_properties = target[..., rv.target.INDEX[0]:rv.target.END_INDEX]
57
+ target_properties_splited = [
58
+ target_properties[..., ::rv.properties.NO],
59
+ target_properties[..., 1::rv.properties.NO],
60
+ target_properties[..., 2::rv.properties.NO]
61
+ ]
62
+ return target_group, target_slot, target_properties_splited
63
+
64
+
65
+
66
+
67
+ def decode_target_flat(target):
68
+ t = decode_target(target)
69
+ return t[0], t[1], t[2][0], t[2][1], t[2][2]
70
+
71
+
72
+ def demask(target, mask=None, group=None, zeroes=None):
73
+ if mask is None:
74
+ if group is None:
75
+ group = target[0]
76
+ # todo Use numpy range Mask
77
+ from models.uitls_ import RangeMask
78
+ mask = RangeMask()(group).numpy()
79
+ if zeroes is None:
80
+ return np.concatenate([t[mask] for t in lw(target[1:])])
81
+ return np.concatenate([target[0][None]] + [t * mask for t in lw(target[1:])],axis=-1)
82
+
83
+
84
+ def target_mask(mask,right=1):
85
+ shape = mask.shape
86
+ return np.concatenate([np.ones([shape[0], 1]) ,mask, np.repeat(mask,3,axis=1), np.ones([shape[0], right])],axis=1)
87
+
88
+
89
+ def get_full_range_mask(mask):
90
+ return np.concatenate([mask, np.repeat(mask, 3, axis=-1)], axis=-1)
91
+
92
+ def compare(target, predict, mask):
93
+ target_comp = target[:, 1:rv.target.END_INDEX]
94
+ predict_comp = predict[:, 1:rv.target.END_INDEX]
95
+
96
+ mask = get_full_range_mask(mask)
97
+
98
+ target_masked = target_comp * mask
99
+ predict_masked = predict_comp * mask
100
+ return target_masked == predict_masked
raven_utils/depricated/__init__.py ADDED
File without changes
raven_utils/depricated/old_raven.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ from data_utils import take, EXIST, COR
5
+ from data_utils.image import draw_images, add_text
6
+ from data_utils.op import np_split
7
+ from ml_utils import lu, dict_from_list2, filter_keys, none
8
+ from data_utils import ops as K
9
+
10
+ from config.constant import PROPERTY, TARGET, INPUTS
11
+ # from raven_utils.render.rendering import render_panels
12
+
13
+ RENDER_POSITIONS = [
14
+ [(0.5, 0.5, 1, 1)],
15
+ # ...
16
+ [(0.25, 0.25, 0.5, 0.5),
17
+ (0.25, 0.75, 0.5, 0.5),
18
+ (0.75, 0.25, 0.5, 0.5),
19
+ (0.75, 0.75, 0.5, 0.5)],
20
+ # ...
21
+ [(0.16, 0.16, 0.33, 0.33),
22
+ (0.16, 0.5, 0.33, 0.33),
23
+ (0.16, 0.83, 0.33, 0.33),
24
+ (0.5, 0.16, 0.33, 0.33),
25
+ (0.5, 0.5, 0.33, 0.33),
26
+ (0.5, 0.83, 0.33, 0.33),
27
+ (0.83, 0.16, 0.33, 0.33),
28
+ (0.83, 0.5, 0.33, 0.33),
29
+ (0.83, 0.83, 0.33, 0.33)],
30
+ # ...
31
+ [(0.5, 0.25, 0.5, 0.5)],
32
+ [(0.5, 0.75, 0.5, 0.5)],
33
+ # ...
34
+ [(0.25, 0.5, 0.5, 0.5)],
35
+ [(0.75, 0.5, 0.5, 0.5)],
36
+ # ...
37
+ [(0.5, 0.5, 1, 1)],
38
+ [(0.5, 0.5, 0.33, 0.33)],
39
+ # ...
40
+ [(0.5, 0.5, 1, 1)],
41
+ [(0.42, 0.42, 0.15, 0.15),
42
+ (0.42, 0.58, 0.15, 0.15),
43
+ (0.58, 0.42, 0.15, 0.15),
44
+ (0.58, 0.58, 0.15, 0.15)],
45
+ # ...
46
+
47
+ ]
48
+
49
+ HORIZONTAL = "horizontal"
50
+ VERTICAL = "vertical"
51
+
52
+ NAMES = ['center_single',
53
+ 'distribute_four',
54
+ 'distribute_nine',
55
+ 'in_center_single_out_center_single',
56
+ 'in_distribute_four_out_center_single',
57
+ 'left_center_single_right_center_single',
58
+ 'up_center_single_down_center_single']
59
+
60
+ PROPERTIES_NAMES = [
61
+ 'Color',
62
+ 'Size',
63
+ 'Type',
64
+
65
+ ]
66
+ PROPERTIES = dict_from_list2(PROPERTIES_NAMES, [10, 6, 5])
67
+ ANGLE_MAX = 7
68
+
69
+ PROPERTIES_NO = len(PROPERTIES)
70
+
71
+ RULES_COMBINE = "Number/Position"
72
+
73
+ RULES_ATTRIBUTES = [
74
+ "Number",
75
+ "Position",
76
+ "Color",
77
+ "Size",
78
+ "Type"
79
+ ]
80
+ RULES_ATTRIBUTES_LEN = len(RULES_ATTRIBUTES)
81
+
82
+ RULES_ATTRIBUTES_INDEX = dict_from_list2(RULES_ATTRIBUTES)
83
+
84
+ RULES_TYPES = [
85
+ "Constant",
86
+ "Arithmetic",
87
+ "Progression",
88
+ "Distribute_Three"
89
+ ]
90
+ RULES_TYPES_INDEX = dict_from_list2(RULES_TYPES)
91
+ RULES_TYPES_LEN = len(RULES_ATTRIBUTES)
92
+
93
+ GROUPS_NO = len(NAMES)
94
+ ENTITY_NO = dict(zip(NAMES, [1, 4, 9, 2, 5, 2, 2]))
95
+ ENTITY_SUM = sum(list(ENTITY_NO.values()))
96
+ ENTITY_INDEX = np.concatenate([[0], np.cumsum(list(ENTITY_NO.values()))])
97
+ ENTITY_INDEX_TARGET = ENTITY_INDEX + 1
98
+ ENTITY_DICT = dict(zip(NAMES, ENTITY_INDEX_TARGET[:-1]))
99
+ NAMES_ORDER = dict(zip(NAMES, np.arange(len(NAMES))))
100
+ PROPERTIES_INDEXES = np.cumsum(np.array(list(ENTITY_NO.values())) * len(PROPERTIES))
101
+ INDEX = np.concatenate([[0], PROPERTIES_INDEXES]) + ENTITY_SUM + 1 # +2 type and uniformity
102
+
103
+ SECOND_LAYOUT = [i - 1 for i in [
104
+ ENTITY_DICT["in_center_single_out_center_single"] + 1,
105
+ ENTITY_DICT["in_distribute_four_out_center_single"] + 1,
106
+ ENTITY_DICT["in_distribute_four_out_center_single"] + 2,
107
+ ENTITY_DICT["in_distribute_four_out_center_single"] + 3,
108
+ ENTITY_DICT["left_center_single_right_center_single"] + 1,
109
+ ENTITY_DICT["up_center_single_down_center_single"] + 1
110
+ ]]
111
+
112
+ FIRST_LAYOUT = list(set(range(ENTITY_SUM)) - set(SECOND_LAYOUT))
113
+ LAYOUT_NO = 2
114
+
115
+ START_INDEX = dict(zip(NAMES, INDEX[:-1]))
116
+ END_INDEX = INDEX[-1]
117
+
118
+ RULES_ATTRIBUTES_ALL_LEN = RULES_ATTRIBUTES_LEN * LAYOUT_NO
119
+ UNIFORMITY_NO = 2
120
+ UNIFORMITY_INDEX = END_INDEX + RULES_ATTRIBUTES_ALL_LEN
121
+
122
+ FEATURE_NO = UNIFORMITY_INDEX + UNIFORMITY_NO
123
+ MAPPING = {
124
+ "distribute_nine":
125
+ {0.16: 0,
126
+ 0.5: 1,
127
+ 0.83: 2},
128
+ "distribute_four":
129
+ {0.25: 0,
130
+ 0.75: 1},
131
+ 'in_distribute_four_out_center_single':
132
+ {0.42: 0,
133
+ 0.58: 1}
134
+ }
135
+ MUL = {
136
+ "distribute_nine": 3,
137
+ "distribute_four": 2,
138
+ 'in_distribute_four_out_center_single': 2
139
+ }
140
+
141
+ # SIZES = np.linspace(0.4, 0.9, 6)
142
+ TYPES = ["triangle", "square", "pentagon", "hexagon", "circle"]
143
+ # TYPES = ["triangle", "square", "pentagon", "circle", "circle"]
144
+ SIZES = ["vs", "s", "m", "h", "vh", "e"]
145
+ COLORS = ["vs", "s", "m", "h", "vh", "e"]
146
+ # TYPES = ["", "", "circle", "hexagon", "square"]
147
+
148
+ ENTITY_PROPERTIES_VALUES = list(PROPERTIES.values())
149
+ ENTITY_PROPERTIES_KEYS = list(PROPERTIES.keys())
150
+ ENTITY_PROPERTIES_NO = len(PROPERTIES)
151
+ INDEX = dict(zip(PROPERTIES, np.array(ENTITY_PROPERTIES_VALUES) * ENTITY_SUM))
152
+ ENTITY_PROPERTIES_SUM = sum(list(PROPERTIES.values()))
153
+
154
+ OUTPUT_SIZE = ENTITY_SUM * ENTITY_PROPERTIES_SUM + GROUPS_NO + ENTITY_SUM
155
+
156
+ SLOT_AND_GROUP = ENTITY_SUM + GROUPS_NO
157
+
158
+ OUTPUT_GROUP_SLICE = np.s_[:, -GROUPS_NO:]
159
+ OUTPUT_SLOT_SLICE = np.s_[:, -SLOT_AND_GROUP:-GROUPS_NO]
160
+ OUTPUT_PROPERTIES_SLICE = np.s_[:, :-SLOT_AND_GROUP]
161
+
162
+ OUTPUT_GROUP_SLICE_END = np.s_[-GROUPS_NO:]
163
+ OUTPUT_SLOT_SLICE_END = np.s_[-SLOT_AND_GROUP:-GROUPS_NO]
164
+ OUTPUT_PROPERTIES_SLICE_END = np.s_[:-SLOT_AND_GROUP]
165
+
166
+ # Transformation
167
+ # constant
168
+ # progression -2, -1,1 ,2
169
+ # arithmetic -/+ Position set arithmetic
170
+ # distribute three
171
+
172
+ # todo
173
+ SLOTS_GROUPS = GROUPS_NO
174
+
175
+ SLOT_TRANSFORMATION_NO = 4
176
+ PROPERTY_TRANSFORMATION_NO = 8
177
+ PROPERTIES_TRANSFORMATION_NO = PROPERTY_TRANSFORMATION_NO * PROPERTIES_NO
178
+ PROPERTIES_TRANSFORMATION_SIZE = PROPERTIES_TRANSFORMATION_NO * ENTITY_SUM
179
+
180
+ SLOT_TRANSFORMATION_SIZE = PROPERTY_TRANSFORMATION_NO * SLOTS_GROUPS
181
+ INFERENCE_SIZE = SLOT_TRANSFORMATION_SIZE + PROPERTIES_TRANSFORMATION_SIZE
182
+
183
+ INFERENCE_SLOT_SLICE = np.s_[:, :SLOT_TRANSFORMATION_SIZE]
184
+ INFERENCE_PROPERTIES_SLICE = np.s_[:, -PROPERTIES_TRANSFORMATION_SIZE:]
185
+ from operator import add
186
+
187
+
188
+ # todo Refactor
189
+ # Maybe properties should be on same level as rest.
190
+ def decode_output(output, split_fn=np_split):
191
+ group_output = output[..., OUTPUT_GROUP_SLICE_END]
192
+ slot_output = output[..., OUTPUT_SLOT_SLICE_END]
193
+ properties_output = output[..., OUTPUT_PROPERTIES_SLICE_END]
194
+ properties_output_splited = split_fn(properties_output, list(rv.properties.INDEX.values()), axis=-1)
195
+ return group_output, slot_output, properties_output_splited
196
+
197
+
198
+ def decode_inference(inference, reshape=np.reshape):
199
+ return reshape(inference[INFERENCE_SLOT_SLICE],
200
+ [-1, SLOTS_GROUPS, PROPERTY_TRANSFORMATION_NO]), reshape(
201
+ inference[INFERENCE_PROPERTIES_SLICE],
202
+ [-1, PROPERTIES_NO, ENTITY_SUM, PROPERTY_TRANSFORMATION_NO])
203
+
204
+
205
+ def decode_output_reshape(output, split_fn=np_split):
206
+ result = decode_output(output, split_fn=split_fn)
207
+ out_reshaped = []
208
+ for i, out in enumerate(result[2]):
209
+ shape = (-1, ENTITY_SUM, ENTITY_PROPERTIES_VALUES[i])
210
+ out_reshaped.append(out.reshape(shape))
211
+ return result[:2] + tuple(out_reshaped)
212
+
213
+
214
+ def take_target(target):
215
+ return target[1], target[2]
216
+
217
+
218
+ def create_target(images, index, pattern_index=(2, 5), full_index=False, arrange=np.arange, shape=lambda x: x.shape):
219
+ return [images[:, pattern_index[0]], images[:, pattern_index[1]],
220
+ images[arrange(shape(index)[0]), (0 if full_index else 8) + index[:, 0]]]
221
+
222
+
223
+ def take_target_simple(target):
224
+ return target[1], target[0]
225
+
226
+
227
+ def create_target_simple(images, target, index=slice(None), pattern_index=(2, 5)):
228
+ return [images[:, pattern_index[0]], images[:, pattern_index[1]], target][index]
229
+
230
+
231
+ def decode_output_result(output, split_fn=np_split, arg_max=np.argmax):
232
+ result = decode_output_reshape(output, split_fn=split_fn)
233
+ res = []
234
+ for i, r in enumerate(result):
235
+ if i == 1:
236
+ res.append(r)
237
+ else:
238
+ res.append(arg_max(r, axis=-1))
239
+ return tuple(res)
240
+
241
+
242
+ def decode_target(target):
243
+ target_group = target[..., 0]
244
+ target_slot = target[..., 1:INDEX[0]]
245
+ target_properties = target[..., INDEX[0]:END_INDEX]
246
+ target_properties_splited = [
247
+ target_properties[..., ::PROPERTIES_NO],
248
+ target_properties[..., 1::PROPERTIES_NO],
249
+ target_properties[..., 2::PROPERTIES_NO]
250
+ ]
251
+ return target_group, target_slot, target_properties_splited
252
+
253
+
254
+ def decode_target_flat(target):
255
+ t = decode_target(target)
256
+ return t[0], t[1], t[2][0], t[2][1], t[2][2]
257
+
258
+
259
+ def draw_board(images, target=None, predict=None,image=None, desc=None, layout=None, break_=20):
260
+ if image != "target" and predict is not None:
261
+ image = images[predict:predict + 1]
262
+ elif images is None and target is not None:
263
+ image = images[target:target + 1]
264
+ # image = False to not draw anything
265
+ border = [{COR: target - 8, EXIST: (1, 3)}] + [{COR: p, EXIST: (0, 2)} for p in none(predict)]
266
+
267
+ boards = []
268
+ boards.append(draw_images(np.concatenate([images[:8], image[None] if len(image.shape)==3 else image]) if image is not None else images[:8]))
269
+ if layout == 1:
270
+ i = draw_images(images[8:], column=4, border=border)
271
+ if break_:
272
+ i = np.concatenate([np.zeros([ break_, i.shape[1],1]),i ],axis=0)
273
+ boards.append(i)
274
+
275
+ else:
276
+ boards.append(
277
+ draw_images(np.concatenate([images[8:], predict]) if predict is not None else images[8:], column=4,
278
+ border=target - 8))
279
+ full_board = draw_images(boards, grid=False)
280
+ if desc:
281
+ full_board = add_text(full_board, desc)
282
+ return full_board
283
+
284
+
285
+ def draw_boards(images, target=None, predict=None, image=None, desc=None, no=1, layout=None):
286
+ boards = []
287
+ for i, image in enumerate(images):
288
+ boards.append(draw_board(image, target[i][0] if target is not None else None,
289
+ predict[i] if predict is not None else None,
290
+ image[i] if image is not None else None,
291
+ desc[i] if desc is not None else None, layout=layout))
292
+ return boards
293
+
294
+
295
+ def draw_raven(generator, predict=None, no=1, add_target_desc=True, indexes=None, types=TYPES,
296
+ layout=1):
297
+ if indexes is None:
298
+ indexes = val_sample(no)
299
+ data = generator.data[indexes]
300
+ if is_model(predict):
301
+ d = filter_keys(data, PROPERTY,reverse=True)
302
+ # tmp change
303
+ pro = predict(d)['predict']
304
+ print(pro)
305
+ predict = render_panels(pro, target=False)
306
+ # if target is not None:
307
+ target = data[TARGET]
308
+ target_index = data["index"]
309
+ images = data[INPUTS]
310
+
311
+ if hasattr(predict, "shape"):
312
+ if len(predict.shape) > 3:
313
+ # iamges
314
+ image = predict
315
+ # todo create index and output based on image
316
+ predict = None
317
+ predict_index = None
318
+ elif len(predict.shape) == 3:
319
+ image = render_panels(predict, target=False)
320
+ # Create index based on predict.
321
+ predict_index = None
322
+ else:
323
+ image = images[predict]
324
+ predict_index = predict
325
+ predict = target
326
+ else:
327
+ image = K.gather(images, target_index[:, 0])
328
+ predict_index = None
329
+ predict = None
330
+
331
+ # elif not(hasattr(target,"shape") and len(target.shape) > 3):
332
+ # if hasattr(target,"shape") and target.shape[-1] == OUTPUT_SIZE:
333
+ # pro = target
334
+ # predict = render_panels(pro)
335
+ # elif hasattr(target,"shape") and target.shape[-1] == FEATURE_NO:
336
+ # # pro = target
337
+ # pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
338
+ # else:
339
+ # pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
340
+ # # predict = [None] * no
341
+ # predict = render_panels(data[TARGET])
342
+
343
+ all_rules = []
344
+ for d in data[PROPERTY]:
345
+ rules = []
346
+ for j, rule_group in enumerate(d.findAll("Rule_Group")):
347
+ # rules_all.append(rule_group['id'])
348
+ for j, rule in enumerate(rule_group.findAll("Rule")):
349
+ rules.append(f"{rule['attr']} - {rule['name']}")
350
+ rules.append("")
351
+ all_rules.append(rules)
352
+ target_desc = get_desc(target)
353
+ if predict is not None:
354
+ predict_desc = decode_output_result(predict) if predict.shape[-1] == OUTPUT_SIZE else get_desc(predict)
355
+ else:
356
+ predict_desc = [None] * len(target_desc)
357
+ for a, po, to in zip(all_rules, predict_desc, target_desc):
358
+ # fl(predict_desc[-1])
359
+ if po is None:
360
+ po = [None] * len(to)
361
+ for p, t in zip(po, to):
362
+ a.extend(
363
+ [" ".join([str(i) for i in t])] + (
364
+ [" ".join([str(i) for i in p]), ""] if p is not None else []
365
+ )
366
+ )
367
+ # a.extend([""] + [] + [""] + [" ".join(fl(p))])
368
+
369
+ # image = draw_boards(data[INPUTS],target=data["index"], predict=predict[:no], desc=all_rules, no=no,layer=layer)
370
+ image = draw_boards(images, target=target_index, predict=predict_index, image=image, desc=None, no=no,
371
+ layout=layout)
372
+ return lu([(i, j) for i, j in zip(image, all_rules)])
373
+
374
+
375
+ def val_sample(no=GROUPS_NO, base=3):
376
+ indexes = np.arange(no) * 2000 + base
377
+ return indexes
378
+
379
+
380
+ def get_desc(target, exist=None, types=TYPES, sizes=SIZES):
381
+ decoded = decode_target(target)
382
+ exist = decoded[1] if exist is None else exist
383
+ taken = np.stack(take(decoded[2], np.array(exist, dtype=bool))).T
384
+
385
+ figures_no = np.sum(exist, axis=-1)
386
+ desc = np.split(taken, np.cumsum(figures_no))[:-1]
387
+ # figures_no = np.sum(exist, axis=-1)
388
+ # div = np.split(desc, np.cumsum(figures_no))[:-1]
389
+ result = []
390
+ for pd in desc:
391
+ r = []
392
+ for p in pd:
393
+ r.append([p[0], sizes[p[1]], types[p[2]]])
394
+ result.append(r)
395
+
396
+ return result
397
+
398
+
399
+ # def get
400
+
401
+
402
+ def get_description(inputs, predict, pro, no, types=TYPES, sizes=SIZES):
403
+ # target = inputs[1][2][:no]
404
+ target = inputs[TARGET]
405
+ target_group = target[:, 0]
406
+ target_exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
407
+ target_rest = target[:, ENTITY_SUM + 1:ENTITY_SUM + 1 + ENTITY_SUM * PROPERTIES_NO]
408
+ pro_reshaped = np.reshape(pro, (pro.shape[0], -1, PROPERTIES_NO))
409
+ target_reshaped = np.reshape(target_rest, (target_rest.shape[0], -1, PROPERTIES_NO))
410
+
411
+ # mask = np.repeat(target_exist, [4] * ENTITY_SUM, axis=-1)
412
+ # masked_result = np.repeat(target_exist, [4] * ENTITY_SUM, axis=-1)
413
+ pro_res = pro_reshaped[target_exist]
414
+ target_res = target_reshaped[target_exist]
415
+ figures_no = np.sum(target_exist, axis=-1)
416
+
417
+ pro_div = np.split(pro_res, np.cumsum(figures_no))[:-1]
418
+ target_div = np.split(target_res, np.cumsum(figures_no))[:-1]
419
+ pro_result_full = []
420
+ target_result_full = []
421
+ for pd, td in zip(pro_div, target_div):
422
+ pro_result = []
423
+ target_result = []
424
+ for p in pd:
425
+ pro_result.append([p[0], sizes[p[1]], types[p[2]]])
426
+ for t in td:
427
+ target_result.append([t[0], sizes[t[1]], types[t[2]]])
428
+ pro_result_full.append(pro_result)
429
+ target_result_full.append(target_result)
430
+
431
+ return pro_result_full, target_result_full
432
+
433
+
434
+ def get_properties(target, types=TYPES, sizes=SIZES):
435
+ target_exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
436
+ target_rest = target[:, ENTITY_SUM + 1:ENTITY_SUM + 1 + ENTITY_SUM * PROPERTIES_NO]
437
+ target_reshaped = np.reshape(target_rest, (target_rest.shape[0], -1, PROPERTIES_NO))
438
+ target_res = target_reshaped[target_exist]
439
+ figures_no = np.sum(target_exist, axis=-1)
440
+ target_div = np.split(target_res, np.cumsum(figures_no))[:-1]
441
+ target_result_full = []
442
+ for td in target_div:
443
+ target_result = []
444
+ for t in td:
445
+ target_result.append([t[0], sizes[t[1]], types[t[2]]])
446
+ target_result_full.append(target_result)
447
+ return target_result_full
448
+
449
+
450
+ def desc_properties(target, decode_fn=None, types=TYPES, sizes=SIZES):
451
+ if decode_fn is None:
452
+ if target.shape[1] == OUTPUT_SIZE:
453
+ decode_fn = decode_output_result
454
+ else:
455
+ decode_fn = decode_target
456
+
457
+ target_div = decode_fn(target)[2:]
458
+ target_result_full = []
459
+ for td in target_div:
460
+ target_result = []
461
+ for t in td:
462
+ target_result.append([t[0], sizes[t[1]], types[t[2]]])
463
+ target_result_full.append(target_result)
464
+ return target_result_full
465
+
466
+
467
+ def get_pro(t, types=TYPES, sizes=SIZES):
468
+ return [int(t[0]), sizes[t[1]], types[t[2]]]
469
+
470
+
471
+ def get_pro2(td, types=TYPES, sizes=SIZES):
472
+ target_result = []
473
+ for t in td:
474
+ target_result.append([int(t[0]), sizes[t[1]], types[t[2]]])
475
+ return target_result
476
+
477
+
478
+ def get_pro3(target_div, types=TYPES, sizes=SIZES):
479
+ target_result_full = []
480
+ for td in target_div.to_list():
481
+ target_result = []
482
+ for t in td:
483
+ target_result.append([int(t[0]), sizes[t[1]], types[t[2]]])
484
+ target_result_full.append(target_result)
485
+ return target_result_full
486
+
487
+
488
+ from models_utils import init_image as def_init_image, is_model
489
+
490
+ init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
raven_utils/draw.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from data_utils import take, EXIST, COR
3
+ from data_utils.image import draw_images, add_text
4
+ from funcy import identity
5
+ from ml_utils import none, filter_keys, lu
6
+ from models_utils import is_model
7
+ from models_utils import ops as K
8
+
9
+ from raven_utils.constant import PROPERTY, TARGET, INPUTS
10
+ from raven_utils.decode import decode_target, target_mask
11
+ from raven_utils.render.rendering import render_panels
12
+ from raven_utils.render_ import TYPES, SIZES
13
+ from raven_utils.uitls import get_val_index
14
+
15
+
16
+ def draw_board(images, target=None, predict=None, image=None, desc=None, layout=None, break_=20):
17
+ if image != "target" and predict is not None:
18
+ image = images[predict:predict + 1]
19
+ elif images is None and target is not None:
20
+ image = images[target:target + 1]
21
+ # image = False to not draw anything
22
+ border = [{COR: target - 8, EXIST: list(range(4)) if predict is None else (1, 3)}] + [{COR: p, EXIST: (0, 2)} for p
23
+ in none(predict)]
24
+
25
+ boards = []
26
+ boards.append(draw_images(
27
+ np.concatenate([images[:8], image[None] if len(image.shape) == 3 else image]) if image is not None else images[
28
+ :8]))
29
+ if layout == 1:
30
+ i = draw_images(images[8:], column=4, border=border)
31
+ if break_:
32
+ i = np.concatenate([np.zeros([break_, i.shape[1], 1]), i], axis=0)
33
+ boards.append(i)
34
+
35
+ else:
36
+ boards.append(
37
+ draw_images(np.concatenate([images[8:], predict]) if predict is not None else images[8:], column=4,
38
+ border=target - 8))
39
+ full_board = draw_images(boards, grid=False)
40
+ if desc:
41
+ full_board = add_text(full_board, desc)
42
+ return full_board
43
+
44
+
45
+ def draw_boards(images, target=None, predict=None,image=None, desc=None, layout=None):
46
+ boards = []
47
+ for i, im in enumerate(images):
48
+ boards.append(draw_board(im, target[i][0] if target is not None else None,
49
+ predict[i] if predict is not None else None,
50
+ image[i] if image is not None else None,
51
+ desc[i] if desc is not None else None, layout=layout))
52
+ return boards
53
+
54
+
55
+ def draw_from_generator(generator, predict=None, no=1, indexes=None, layout=1):
56
+ data,_ = val_sample(generator, no, indexes)
57
+ return draw_raven(data, predict=predict, pre_fn=generator.data.data["inputs"].fn, layout=layout)
58
+
59
+
60
+ def val_sample(generator, no=1, indexes=None):
61
+ if indexes is None:
62
+ indexes = get_val_index(base=no)
63
+ data = generator.data[indexes]
64
+ return data, indexes
65
+
66
+ def render_from_model(data,predict,pre_fn=identity):
67
+ data = filter_keys(data, PROPERTY, reverse=True)
68
+ if is_model(predict):
69
+ predict = predict(data)
70
+ pro = np.array(target_mask(predict['predict_mask'].numpy()) * predict["predict"].numpy(), dtype=np.int8)
71
+ return pre_fn(render_panels(pro, target=False)[None])[0]
72
+
73
+ def draw_raven(data, predict=None, pre_fn=identity, layout=1):
74
+ if is_model(predict):
75
+ d = filter_keys(data, PROPERTY, reverse=True)
76
+ # tmp change
77
+ res = predict(d)
78
+ pro = np.array(target_mask(res['mask'].numpy()) * res["predict"].numpy(),dtype=np.int8)
79
+ predict = pre_fn(render_panels(pro, target=False)[None])[0]
80
+ # from data_utils import ims
81
+ # ims(1 - predict[0])
82
+ # if target is not None:
83
+ target = data[TARGET]
84
+ target_index = data["index"]
85
+ images = data[INPUTS]
86
+ # np.equal(res['predict'], pro[:,:102]).sum()
87
+
88
+ if hasattr(predict, "shape"):
89
+ if len(predict.shape) > 3:
90
+ # iamges
91
+ image = predict
92
+ # todo create index and output based on image
93
+ predict = None
94
+ predict_index = None
95
+ elif len(predict.shape) == 3:
96
+ image = render_panels(predict, target=False)
97
+ # Create index based on predict.
98
+ predict_index = None
99
+ else:
100
+ image = images[predict]
101
+ predict_index = predict
102
+ predict = target
103
+ else:
104
+ image = K.gather(images, target_index[:, 0])
105
+ predict_index = None
106
+ predict = None
107
+
108
+ # elif not(hasattr(target,"shape") and len(target.shape) > 3):
109
+ # if hasattr(target,"shape") and target.shape[-1] == OUTPUT_SIZE:
110
+ # pro = target
111
+ # predict = render_panels(pro)
112
+ # elif hasattr(target,"shape") and target.shape[-1] == FEATURE_NO:
113
+ # # pro = target
114
+ # pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
115
+ # else:
116
+ # pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
117
+ # # predict = [None] * no
118
+ # predict = render_panels(data[TARGET])
119
+
120
+ image = draw_boards(images, target=target_index, predict=predict_index,image=image, desc=None,
121
+ layout=layout)
122
+
123
+ all_rules = extract_rules(data[PROPERTY])
124
+ target_desc = get_desc(target)
125
+ if predict is not None:
126
+ predict_desc = get_desc(predict)
127
+ else:
128
+ predict_desc = [None] * len(target_desc)
129
+ for a, po, to in zip(all_rules, predict_desc, target_desc):
130
+ # fl(predict_desc[-1])
131
+ if po is None:
132
+ po = [None] * len(to)
133
+ for p, t in zip(po, to):
134
+ a.extend(
135
+ [" ".join([str(i) for i in t])] + (
136
+ [" ".join([str(i) for i in p]), ""] if p is not None else []
137
+ )
138
+ )
139
+ # a.extend([""] + [] + [""] + [" ".join(fl(p))])
140
+
141
+ # image = draw_boards(data[INPUTS],target=data["index"], predict=predict[:no], desc=all_rules, no=no,layer=layer)
142
+ return lu([(i, j) for i, j in zip(image, all_rules)])
143
+
144
+
145
+ def extract_rules(data):
146
+ all_rules = []
147
+ for d in data:
148
+ rules = []
149
+ for j, rule_group in enumerate(d.findAll("Rule_Group")):
150
+ # rules_all.append(rule_group['id'])
151
+ for j, rule in enumerate(rule_group.findAll("Rule")):
152
+ rules.append(f"{rule['attr']} - {rule['name']}")
153
+ rules.append("")
154
+ all_rules.append(rules)
155
+ return all_rules
156
+
157
+
158
+ def get_desc(target, exist=None, types=TYPES, sizes=SIZES):
159
+ decoded = decode_target(target)
160
+ exist = decoded[1] if exist is None else exist
161
+ taken = np.stack(take(decoded[2], np.array(exist, dtype=bool))).T
162
+
163
+ figures_no = np.sum(exist, axis=-1)
164
+ desc = np.split(taken, np.cumsum(figures_no))[:-1]
165
+ # figures_no = np.sum(exist, axis=-1)
166
+ # div = np.split(desc, np.cumsum(figures_no))[:-1]
167
+ result = []
168
+ for pd in desc:
169
+ r = []
170
+ for p in pd:
171
+ r.append([p[0], sizes[p[1]], types[p[2]]])
172
+ result.append(r)
173
+
174
+ return result
raven_utils/entity.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import raven_utils.group as group
2
+ import numpy as np
3
+ NO = dict(zip(group.NAMES, [1, 4, 9, 2, 5, 2, 2]))
4
+ SUM = sum(list(NO.values()))
5
+
6
+ INDEX = np.concatenate([[0], np.cumsum(list(NO.values()))])
raven_utils/group.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ NAMES = ['center_single',
4
+ 'distribute_four',
5
+ 'distribute_nine',
6
+ 'in_center_single_out_center_single',
7
+ 'in_distribute_four_out_center_single',
8
+ 'left_center_single_right_center_single',
9
+ 'up_center_single_down_center_single']
10
+
11
+ NO = len(NAMES)
raven_utils/inference.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import raven_utils.properties as properties
3
+ import raven_utils.group as group
4
+
5
+
6
+ SLOT_TRANSFORMATION_NO = 4
7
+ PROPERTY_TRANSFORMATION_NO = 8
8
+ PROPERTIES_TRANSFORMATION_NO = PROPERTY_TRANSFORMATION_NO * properties.NO
9
+ PROPERTIES_TRANSFORMATION_SIZE = PROPERTIES_TRANSFORMATION_NO * group.NO
10
+
11
+ SLOT_TRANSFORMATION_SIZE = PROPERTY_TRANSFORMATION_NO * group.NO
12
+ SIZE = SLOT_TRANSFORMATION_SIZE + PROPERTIES_TRANSFORMATION_SIZE
13
+
14
+ SLOT_SLICE = np.s_[:, :SLOT_TRANSFORMATION_SIZE]
15
+ PROPERTIES_SLICE = np.s_[:, -PROPERTIES_TRANSFORMATION_SIZE:]
raven_utils/models/__init__.py ADDED
File without changes
raven_utils/models/attn.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras.layers import LSTMCell
6
+ from tensorflow.keras.models import Model
7
+ from tensorflow.keras.layers import Conv2D, Dense
8
+ from tensorflow.keras.losses import mse
9
+ from tensorflow.keras.models import clone_model
10
+ from tensorflow.layers.base import InputSpec, Layer
11
+
12
+ from models.dense import create_conv_model
13
+ from models.utils import broadcast
14
+
15
+
16
+ class ReflectionPadding2D(Layer):
17
+ def __init__(self, padding=(1, 1), **kwargs):
18
+ self.padding = tuple(padding)
19
+ self.input_spec = [InputSpec(ndim=4)]
20
+ super(ReflectionPadding2D, self).__init__(**kwargs)
21
+
22
+ def compute_output_shape(self, s):
23
+ """ If you are using "channels_last" configuration"""
24
+ return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
25
+
26
+ def call(self, x, mask=None):
27
+ w_pad, h_pad = self.padding
28
+ return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
29
+
30
+
31
+ class Conv2Ref(Layer):
32
+ def __init__(self, padding=(1, 1), **kwargs):
33
+ self.padding = tuple(padding)
34
+ self.input_spec = [InputSpec(ndim=4)]
35
+ super(ReflectionPadding2D, self).__init__(**kwargs)
36
+
37
+ def compute_output_shape(self, s):
38
+ """ If you are using "channels_last" configuration"""
39
+ return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
40
+
41
+ def call(self, x, mask=None):
42
+ w_pad, h_pad = self.padding
43
+ return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
44
+
45
+
46
+ class SegmentationNetwork(Model):
47
+
48
+ def __init__(self, filters=64, kernels=(3, 3)):
49
+ super(RecAE, self).__init__()
50
+ self.conv_1 = Conv2D(filters, kernels, padding=SAME)
51
+ self.conv_2 = Conv2D(filters, kernels, padding=SAME)
52
+
53
+ def call(self, inputs):
54
+ x = K.relu(inputs)
55
+ x = self.conv_1(x)
56
+ x = K.relu(x)
57
+ x = self.conv_2(x)
58
+ return x + inputs
59
+
60
+
61
+ class QueryNetwork(Model):
62
+
63
+ def __init__(self, units=64):
64
+ super(RecAE, self).__init__()
65
+ self.conv_1 = Dense(units)
66
+ self.conv_2 = Dense(units)
67
+
68
+ def call(self, inputs):
69
+ x = K.relu(inputs)
70
+ x = self.conv_1(x)
71
+ x = K.relu(x)
72
+ x = self.conv_2(x)
73
+ return x + inputs
74
+
75
+
76
+ class RecAE(Model):
77
+
78
+ def __init__(self, head, bottle, decoder):
79
+ super(RecAE, self).__init__()
80
+ self.head = head
81
+ self.bottle = bottle
82
+ self.base = clone_model(bottle)
83
+ self.decoder = decoder
84
+ self.segmentation_network = SegmentationNetwork()
85
+ self.query_network = QueryNetwork()
86
+ self.control = LSTMCell(64)
87
+ self.memory = LSTMCell(64)
88
+
89
+ def call(self, inputs):
90
+ feature = self.head(inputs)
91
+ segmentation = self.segmentation_network(feature)
92
+ control_base = self.base(feature)
93
+ h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
94
+ h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
95
+ shape = K.shape(feature)[:-1]
96
+ full_attention = tf.zeros(shape)[..., tf.newaxis]
97
+ full_image = tf.zeros(K.shape(inputs))
98
+ masks = []
99
+ ff = tf.zeros(K.shape(inputs))
100
+ scope = tf.ones(shape)[..., tf.newaxis]
101
+ for i in range(4):
102
+ r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
103
+ query = self.query_network(h_c[0])
104
+ log_attention = image_attention(segmentation, query)
105
+ attention = K.sigmoid(log_attention)
106
+ mask = attention * scope
107
+ scope = scope - mask
108
+ im = feature * mask
109
+ # im = feature
110
+ latent = self.bottle(im)
111
+ decoded = self.decoder(latent)
112
+ # self.add_loss(K.mean(-mse(full_attention, attention)))
113
+ # self.add_loss(K.mean(-mse(tf.ones(attention.shape), attention)))
114
+ full_attention += attention
115
+ big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
116
+ ff += K.sigmoid(decoded)
117
+ full_image += K.sigmoid(decoded) * big_mask
118
+ r_m, h_m = self.memory(latent, h_m)
119
+ masks.append(big_mask)
120
+ self.add_loss(K.mean(mse(inputs, full_image)))
121
+ return full_image, masks
122
+
123
+
124
+ # def image_attention(image, query, scale=True):
125
+ @tf.function
126
+ def image_attention(image, query):
127
+ log_attention = K.sum(query[:, tf.newaxis, tf.newaxis, :] * image, axis=-1, keepdims=True)
128
+ # if scale is not None:
129
+ log_attention /= tf.sqrt(tf.cast(K.shape(image)[-1], dtype=float))
130
+ return log_attention
131
+
132
+
133
+ class RecAE_2(Model):
134
+
135
+ def __init__(self, head, bottle, decoder):
136
+ super(RecAE_2, self).__init__()
137
+ self.head = head
138
+ self.bottle = bottle
139
+ # self.base = clone_model(bottle)
140
+ self.base = self.bottle
141
+ self.decoder = decoder
142
+ self.segmentation_network = create_conv_model((64, 64, 1))
143
+ self.control = LSTMCell(64)
144
+ self.memory = LSTMCell(64)
145
+
146
+ def call(self, inputs):
147
+ feature = self.head(inputs)
148
+ control_base = self.base(feature)
149
+ h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
150
+ h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
151
+ shape = K.shape(feature)[:-1]
152
+ full_attention = tf.zeros(shape)[..., tf.newaxis]
153
+ full_image = tf.zeros(K.shape(inputs))
154
+ big_masks = []
155
+ masks = []
156
+ ff = tf.zeros(K.shape(inputs))
157
+ scope = tf.ones(shape)[..., tf.newaxis]
158
+ for i in range(4):
159
+ if i ==3:
160
+ mask = scope
161
+ else:
162
+ r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
163
+ query = broadcast(h_c[0], feature.shape[1:])
164
+ log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
165
+ attention = K.sigmoid(log_attention)
166
+ mask = attention * scope
167
+ scope = scope - mask
168
+ masks.append(mask)
169
+ im = feature * mask
170
+ # im = feature
171
+ latent = self.bottle(im)
172
+ decoded = self.decoder(latent)
173
+ # self.add_loss(K.mean(-mse(scope, mask)))
174
+ sum = K.sum(tf.ones(K.shape(mask)))
175
+ self.add_loss(K.abs((sum/4)-K.sum(mask))/sum)
176
+ # self.add_loss(K.mean(-mse(tf.zeros(K.shape(mask)), mask)))
177
+ for m in masks:
178
+ self.add_loss(K.mean(-mse(m,mask)))
179
+
180
+ full_attention += mask
181
+ big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
182
+ ff += K.sigmoid(decoded)
183
+ full_image += K.sigmoid(decoded) * big_mask
184
+ r_m, h_m = self.memory(latent, h_m)
185
+ big_masks.append(big_mask)
186
+ self.add_loss(K.mean(mse(inputs, full_image)))
187
+ return full_image, big_masks
raven_utils/models/attn2.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras.layers import LSTMCell
6
+ from tensorflow.keras.models import Model
7
+ from tensorflow.keras.layers import Conv2D, Dense
8
+ from tensorflow.keras.losses import mse
9
+ from tensorflow.keras.models import clone_model
10
+ from tensorflow.layers.base import InputSpec, Layer
11
+
12
+ from models.dense import create_conv_model
13
+ from models.utils import broadcast
14
+
15
+
16
+ class ReflectionPadding2D(Layer):
17
+ def __init__(self, padding=(1, 1), **kwargs):
18
+ self.padding = tuple(padding)
19
+ self.input_spec = [InputSpec(ndim=4)]
20
+ super(ReflectionPadding2D, self).__init__(**kwargs)
21
+
22
+ def compute_output_shape(self, s):
23
+ """ If you are using "channels_last" configuration"""
24
+ return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
25
+
26
+ def call(self, x, mask=None):
27
+ w_pad, h_pad = self.padding
28
+ return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
29
+
30
+
31
+ class Conv2Ref(Layer):
32
+ def __init__(self, padding=(1, 1), **kwargs):
33
+ self.padding = tuple(padding)
34
+ self.input_spec = [InputSpec(ndim=4)]
35
+ super(ReflectionPadding2D, self).__init__(**kwargs)
36
+
37
+ def compute_output_shape(self, s):
38
+ """ If you are using "channels_last" configuration"""
39
+ return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
40
+
41
+ def call(self, x, mask=None):
42
+ w_pad, h_pad = self.padding
43
+ return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
44
+
45
+
46
+ class SegmentationNetwork(Model):
47
+
48
+ def __init__(self, filters=64, kernels=(3, 3)):
49
+ super(RecAE, self).__init__()
50
+ self.conv_1 = Conv2D(filters, kernels)
51
+ self.conv_2 = Conv2D(filters, kernels)
52
+
53
+ def call(self, inputs):
54
+ x = K.relu(inputs)
55
+ x = self.conv_1(x)
56
+ x = K.relu(x)
57
+ x = self.conv_2(x)
58
+ return x + inputs
59
+
60
+
61
+ class QueryNetwork(Model):
62
+
63
+ def __init__(self, units=64):
64
+ super(RecAE, self).__init__()
65
+ self.conv_1 = Dense(units)
66
+ self.conv_2 = Dense(units)
67
+
68
+ def call(self, inputs):
69
+ x = K.relu(inputs)
70
+ x = self.conv_1(x)
71
+ x = K.relu(x)
72
+ x = self.conv_2(x)
73
+ return x + inputs
74
+
75
+
76
+ class RecAE(Model):
77
+
78
+ def __init__(self, head, bottle, decoder):
79
+ super(RecAE, self).__init__()
80
+ self.head = head
81
+ self.bottle = bottle
82
+ self.base = clone_model(bottle)
83
+ self.decoder = decoder
84
+ self.segmentation_network = SegmentationNetwork()
85
+ self.query_network = QueryNetwork()
86
+ self.control = LSTMCell(64)
87
+ self.memory = LSTMCell(64)
88
+
89
+ def call(self, inputs):
90
+ feature = self.head(inputs)
91
+ segmentation = self.segmentation_network(feature)
92
+ control_base = self.base(feature)
93
+ h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
94
+ h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
95
+ shape = K.shape(feature)[:-1]
96
+ full_attention = tf.zeros(shape)[..., tf.newaxis]
97
+ full_image = tf.zeros(K.shape(inputs))
98
+ masks = []
99
+ ff = tf.zeros(K.shape(inputs))
100
+ scope = tf.ones(shape)[..., tf.newaxis]
101
+ for i in range(10):
102
+ r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
103
+ query = self.query_network(h_c[0])
104
+ log_attention = image_attention(segmentation, query)
105
+ attention = K.softmax(log_attention)
106
+ mask = attention * scope
107
+ scope = scope - mask
108
+ im = feature * mask
109
+ # im = feature
110
+ latent = self.bottle(im)
111
+ decoded = self.decoder(latent)
112
+ # self.add_loss(K.mean(-mse(full_attention, attention)))
113
+ # self.add_loss(K.mean(-mse(tf.ones(attention.shape), attention)))
114
+ full_attention += attention
115
+ big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
116
+ ff += K.sigmoid(decoded)
117
+ full_image += K.sigmoid(decoded) * big_mask
118
+ r_m, h_m = self.memory(latent, h_m)
119
+ masks.append(big_mask)
120
+ self.add_loss(K.mean(mse(inputs, full_image)))
121
+ return full_image, masks
122
+
123
+
124
+ # def image_attention(image, query, scale=True):
125
+ @tf.function
126
+ def image_attention(image, query):
127
+ log_attention = K.sum(query[:, tf.newaxis, tf.newaxis, :] * image, axis=-1, keepdims=True)
128
+ # if scale is not None:
129
+ log_attention /= tf.sqrt(tf.cast(K.shape(image)[-1], dtype=float))
130
+ return log_attention
131
+
132
+
133
+ class RecAE_2(Model):
134
+
135
+ def __init__(self, head, bottle, decoder):
136
+ super(RecAE_2, self).__init__()
137
+ self.head = head
138
+ self.bottle = bottle
139
+ # self.base = clone_model(bottle)
140
+ self.base = self.bottle
141
+ self.decoder = decoder
142
+ self.segmentation_network = create_conv_model((64, 64, 1))
143
+ self.control = LSTMCell(64)
144
+ self.memory = LSTMCell(64)
145
+
146
+ def call(self, inputs):
147
+ feature = self.head(inputs)
148
+ control_base = self.base(feature)
149
+ h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
150
+ h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
151
+ shape = K.shape(feature)[:-1]
152
+ full_attention = tf.zeros(shape)[..., tf.newaxis]
153
+ full_image = tf.zeros(K.shape(inputs))
154
+ big_masks = []
155
+ masks = []
156
+ ff = tf.zeros(K.shape(inputs))
157
+ scope = tf.ones(shape)[..., tf.newaxis]
158
+ for i in range(4):
159
+ if i ==3:
160
+ mask = scope
161
+ else:
162
+ r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
163
+ query = broadcast(h_c[0], feature.shape[1:])
164
+ log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
165
+ attention = K.sigmoid(log_attention)
166
+ mask = attention * scope
167
+ scope = scope - mask
168
+ masks.append(mask)
169
+ im = feature * mask
170
+ # im = feature
171
+ latent = self.bottle(im)
172
+ decoded = self.decoder(latent)
173
+ # self.add_loss(K.mean(-mse(scope, mask)))
174
+ sum = K.sum(tf.ones(K.shape(mask)))
175
+ self.add_loss(K.abs((sum/4)-K.sum(mask))/sum)
176
+ # self.add_loss(K.mean(-mse(tf.zeros(K.shape(mask)), mask)))
177
+ for m in masks:
178
+ self.add_loss(K.mean(-mse(m,mask)))
179
+
180
+ full_attention += mask
181
+ big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
182
+ ff += K.sigmoid(decoded)
183
+ full_image += K.sigmoid(decoded) * big_mask
184
+ r_m, h_m = self.memory(latent, h_m)
185
+ big_masks.append(big_mask)
186
+ self.add_loss(K.mean(mse(inputs, full_image)))
187
+ return full_image, big_masks
raven_utils/models/augment.py ADDED
File without changes
raven_utils/models/body.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+
3
+ import tensorflow as tf
4
+ from ml_utils import self_product, lw
5
+
6
+ from models_utils import DictModel, ListModel, Flat, bm, Base, Cat, Res, Flat2, conv, KERNEL_SIZE, FILTERS, SAME, \
7
+ Get, SM, bs, RELU, ACTIVATION, dense, bd, HardBlock, MaxBlock
8
+ import models_utils.ops as K
9
+ from models_utils import Merge, SoftBlock
10
+ from models_utils.build import build_multi_dense, build_multi_conv, build_conv_model, build_encoder
11
+ from tensorflow.keras.layers import Lambda, Dense
12
+ from tensorflow.keras.layers import Conv2D
13
+
14
+ from config.constant import MEMORY, CONTROL, LATENT, MERGE, CONCAT, INFERENCE, FLATTEN
15
+ from models_utils.config import config
16
+
17
+
18
+ class RavRes(Res):
19
+ def __init__(self, model="v2", latent=256, act=RELU):
20
+ super().__init__(model=model)
21
+ self.latent = latent
22
+
23
+ def call(self, inputs):
24
+ return self.model(inputs) + inputs[0][:, ..., self.latent:]
25
+
26
+
27
+ # not working
28
+ class RavResConv(Res):
29
+ def __init__(self, model="v2", latent=256, act=RELU):
30
+ super().__init__(model=model)
31
+ self.latent = latent
32
+ self.conv = conv(latent, (1, 1), activation=act)
33
+
34
+ def call(self, inputs):
35
+ return self.model(inputs) + self.conv(inputs[0])
36
+
37
+
38
+ class RavResDense(Res):
39
+ def __init__(self, model="v2", latent=256, act=config.DEF_DENSE.activation):
40
+ super().__init__(model=model)
41
+ self.latent = latent
42
+ self.conv = dense(latent, activation=act)
43
+
44
+ def call(self, inputs):
45
+ return self.model(inputs) + self.conv(inputs[0])
46
+
47
+
48
+ def create_dense_block(latent=256, loop=1):
49
+ soft_block = Res(SoftBlock(build_multi_dense(latent), add_identity=None,
50
+ score_activation=tf.sigmoid), latent=latent)
51
+ cells = [
52
+ (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
53
+ (None, CONCAT, MEMORY),
54
+ (Dense(latent), CONCAT, MERGE),
55
+ (Merge(latent), [INFERENCE, MERGE], CONTROL),
56
+ (soft_block, [MEMORY, CONTROL], MEMORY)
57
+ ]
58
+
59
+ return ListModel([DictModel(*cell) for cell in cells] * loop, [LATENT, INFERENCE], MEMORY)
60
+
61
+
62
+ def build_multi_conv(filters=32, end_filters=64, padding="same",mul=1, norm=None, **kwargs):
63
+ base = [(1, 3), (3, 1), (3, 3)]
64
+ block = list(self_product(base))
65
+ block2 = [b + b[0:1] for b in block]
66
+ block3 = [b + b for b in block]
67
+ block4 = ([[(3, 3)]] + [[(3, 3), (3, 3)]] + [[(3, 3), (3, 3), (3, 3)]]) * 2
68
+ block5 = [[], []]
69
+ all_blocks = [s for b in [block, block2, block3, block4, block5] for s in b]
70
+ start = {
71
+ FILTERS: filters,
72
+ KERNEL_SIZE: (1, 1)
73
+ }
74
+
75
+ end = {
76
+ FILTERS: end_filters,
77
+ KERNEL_SIZE: (1, 1),
78
+ ACTIVATION: None
79
+ }
80
+
81
+ all_arch = []
82
+ for ab in all_blocks:
83
+ arch = [{
84
+ FILTERS: filters,
85
+ KERNEL_SIZE: a,
86
+ **kwargs
87
+ } for a in ab]
88
+ all_arch.append([start] + arch + [end])
89
+
90
+ all_arch = all_arch * mul
91
+
92
+ return [
93
+ build_encoder(a, add_norm=norm if norm else None, padding=padding, name=f"b{i}", order=(1, 0) if norm else None)
94
+ for i, a in enumerate(all_arch)]
95
+
96
+
97
+ def create_block(latent=256, simpler=0, loop=1, padding=SAME, norm=None, trans_div=2, act="pass", type_="conv",
98
+ block_=SoftBlock,max_k=16,
99
+ **kwargs):
100
+ trans_size = int(latent / trans_div)
101
+ # if block_ == HardBlock:
102
+ # mul = 2
103
+ # elif block_ == MaxBlock:
104
+ # mul = int(38/max_k)
105
+ # else:
106
+ # mul = 1
107
+
108
+ if act == "pass":
109
+ res_class = RavRes
110
+ else:
111
+ if type_ == "dense":
112
+ res_class = RavResDense
113
+ else:
114
+ res_class = RavResConv
115
+
116
+ if type_ == "dense":
117
+ build_res = lambda: Res(model="dv2")
118
+ # build_reduction = lambda: bm([dense(latent), "IN"])
119
+ build_reduction = lambda: dense(latent)
120
+ build_flatten = lambda: bd([latent] * 2)
121
+ else:
122
+ build_res = lambda: Res(padding=padding)
123
+ build_reduction = lambda: bm([conv(trans_size if simpler else latent, 1, padding=padding), "BN"])
124
+ # build_reduction = lambda: bm([conv(latent, 1, padding=padding), "BN"])
125
+ # build_reduction = lambda: bm([conv(trans_size, 1, padding=padding), "BN"])
126
+ # build_reduction = lambda: conv(trans_size, 1, padding=padding)
127
+ # build_flatten = lambda: Flat2(filters=trans_size,res_no=2, padding=padding, units=64)
128
+ build_flatten = lambda: Flat2(filters=trans_size,padding=padding, units=64)
129
+
130
+ if simpler == 1:
131
+ cells = [
132
+ (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT,"concatenation"),
133
+ # (None, CONCAT, MEMORY),
134
+ (build_reduction(), CONCAT, MERGE,"Start_resnet_block"),
135
+ # (Get(), INFERENCE, INFERENCE),
136
+ (K.cat, [INFERENCE, MERGE], CONTROL,"concatenation"),
137
+ ]
138
+ else:
139
+ cells = [
140
+ (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
141
+ (build_reduction(), CONCAT, MEMORY),
142
+ (build_reduction(), INFERENCE, CONTROL),
143
+ ]
144
+ for i, l in enumerate(lw(loop)):
145
+ if l:
146
+ concat = K.cat
147
+ control_reduction = build_reduction()
148
+ control_res = build_res()
149
+ control_flatten = build_flatten()
150
+ if i == 0 and simpler == 1:
151
+ rest_params = {
152
+ "latent": latent,
153
+ "act": act
154
+ }
155
+ else:
156
+ rest_params = {
157
+ "latent": 0
158
+ }
159
+
160
+
161
+ if block_ == SoftBlock:
162
+ block_params = {
163
+ }
164
+ else:
165
+ block_params = {
166
+ "trans_output_shape": latent
167
+ }
168
+ if block_ == MaxBlock:
169
+ block_params['max_k'] = max_k
170
+
171
+
172
+ # todo change name
173
+ soft_block = res_class(
174
+ block_(
175
+ build_multi_dense(latent) if type_ == "dense" else build_multi_conv(trans_size, end_filters=latent,
176
+ norm=norm, padding=padding,
177
+ **kwargs),
178
+ add_identity=None,
179
+ score_activation=tf.sigmoid,
180
+ **block_params
181
+
182
+ ),
183
+ **rest_params)
184
+
185
+ if i == 0 and simpler == 1:
186
+ cells.extend([
187
+ (control_reduction, CONTROL, CONTROL,"Reduction"),
188
+ (control_res, CONTROL, CONTROL,"Control_resnet_block"),
189
+ (control_flatten, CONTROL, FLATTEN,"Weights"),
190
+ (soft_block, [CONCAT, FLATTEN], MEMORY,"Transformation"),
191
+ # (soft_block, [MEMORY, FLATTEN], MEMORY,"Transformation"),
192
+ ])
193
+ else:
194
+ if l:
195
+ memory_res = build_res()
196
+
197
+ cells.extend([
198
+ (memory_res, MEMORY, MEMORY,"Memory_resnet_block"),
199
+ (concat, [CONTROL, MEMORY], CONTROL,"concatenation"),
200
+ (control_reduction, CONTROL, CONTROL,"Reduction"),
201
+ (control_res, CONTROL, CONTROL,"Control_resnet_block"),
202
+ (control_flatten, CONTROL, FLATTEN,"Weights"),
203
+ (soft_block, [MEMORY, FLATTEN], MEMORY, "Transformation"),
204
+ ])
205
+ return ListModel([DictModel(*cell) for cell in cells], [LATENT, INFERENCE], MEMORY, debug_=False)
206
+
207
+ #
208
+ #
209
+ # def test(x):
210
+ # np.zeros(4)
211
+ # self_product((1, 3))
212
+ #
213
+ #
214
+ # list(itertools.product())
215
+ # u.layers[0].layers[-1].model.layers[1]
216
+
217
+ # class RecurrentBodyDict(Model):
218
+ # # def __init__(self, start=None, cell=None, output_network=None, output_activation="tanh", latent=64, loop_no=5):
219
+ # def __init__(self, start=None, cell=None, output_network=None, output_activation=None, latent=64, loop_no=5):
220
+ # super().__init__()
221
+ # self.start = sm(start, lambda: SubClassingModel([StartLSTMControl(latent), StartLSTMMemory(latent)]),
222
+ # latent=latent)
223
+ # self.cell = sm(cell, lambda: SubClassingModel([LSTMControl(latent), LSTMMemory(latent)]), latent=latent)
224
+ # self.output_network = sm(output_network, lf(take_memory_states))
225
+ # self.loop_no = loop_no
226
+ # # tmp
227
+ # self.activation = Activation(output_activation)
228
+ #
229
+ # def call(self, inputs):
230
+ # outputs = []
231
+ # for j in range(3):
232
+ # outputs.append(self.start({"latent": inputs[0][j], "inference": inputs[1]}))
233
+ # for i in range(self.loop_no):
234
+ # for j in range(3):
235
+ # outputs[j] = self.cell(outputs[j])
236
+ #
237
+ # return self.activation(self.output_network(outputs))
238
+ #
239
+ #
240
+ # class RecurrentBodySimpleMix4Dict(RecurrentBodyDict):
241
+ # def __init__(self, latent=64, output_network=None, loop_no=5):
242
+ # super().__init__(
243
+ # start=SubClassingModel(
244
+ # [ConcatCell(), DenseCell(latent), InfMergeCell(latent),
245
+ # WeigthCell(latent, layer_no=np.repeat([1, 2, 3, 4, 5, 6, 7, 8], 4),
246
+ # add_identity=Lambda(lambda x: x[:, latent:]))]),
247
+ # cell=False,
248
+ # output_network=output_network, loop_no=0)
249
+ # class RecurrentBodySimpleMix4Conv(RecurrentBodyDict):
250
+ # def __init__(self, latent=64, output_network=None, loop_no=5):
251
+ # super().__init__(
252
+ # start=SubClassingModel(
253
+ # [ConcatCell(), ConvCell(latent), ReduceCell(latent), InfMergeCell(latent),
254
+ # ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
255
+ # WeigthCell(latent,
256
+ # transformation_network=[build_conv_model2([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
257
+ # range(1, 5) for _ in range(1)],
258
+ # add_identity=Lambda(lambda x: x[:, ..., latent:]))
259
+ # ]),
260
+ # cell=False,
261
+ # output_network=output_network, loop_no=0)
262
+ #
263
+ #
264
+ # class RecurrentBodySimpleMix4Conv2(RecurrentBodyDict):
265
+ # def __init__(self, latent=64, output_network=None, loop_no=5):
266
+ # super().__init__(
267
+ # start=SubClassingModel(
268
+ # [ConcatCell(), ConvCell(latent), ReduceCell2(latent), InfMergeCell(latent),
269
+ # ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
270
+ # WeigthCell(latent,
271
+ # transformation_network=[bc([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
272
+ # range(1, 5) for _ in range(1)],
273
+ # add_identity=Lambda(lambda x: x[:, ..., latent:]))
274
+ # ]),
275
+ # cell=False,
276
+ # output_network=output_network, loop_no=0)
raven_utils/models/class_.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ml_utils import lw
2
+ from models_utils import SubClassingModel, ops as K, Base
3
+ import tensorflow as tf
4
+
5
+
6
+ class Merge(SubClassingModel):
7
+ def call(self, inputs):
8
+ results = []
9
+ for i, model in enumerate(self.model[:-1]):
10
+ results.append(model(inputs[i]))
11
+ # todo why K.cat not working
12
+ results = self.model[-1](tf.concat(results, axis=-1))
13
+ return results
14
+
15
+
16
+ class RavenClass(Base):
17
+ def __init__(self, model, scales=None, no=3, name=None):
18
+ super().__init__(model=model, name=name)
19
+ self.scales = scales
20
+ self.no = no
21
+
22
+ def call(self, inputs):
23
+ inputs = lw(inputs)
24
+ class_res = []
25
+ # for i in range(inputs[0].shape[1]):
26
+ for i in range(self.no):
27
+ # d = [r[:, i] if r.ndim == 5 else r for r in inputs]
28
+ d = [inputs[s][:, i] if inputs[s].ndim > 2 else inputs for s in self.scales]
29
+ class_res.append(self.model(d))
30
+ # return tf.stack(class_res,axis=1)
31
+ return [class_res]
raven_utils/models/head.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from ml_utils import set_default
3
+ from models_utils import build_dense_model, bm, ActivationModel, sm, large_conv_dense_encoder, Pass
4
+ from models_utils import res
5
+ from tensorflow.keras import Model
6
+ from models_utils import ops as K
7
+ from tensorflow.keras.layers import Dense, Conv2D, Flatten
8
+ from keras.backend import batch_flatten
9
+
10
+
11
+ # todo Refactoring
12
+ class HeadModel(Model):
13
+ def __init__(self, encoder=None, inference_network=None, output_size=64, inference_output_size=None,
14
+ inference_activation="relu", stem=None, images_no=8, inference_image_no=None):
15
+ super().__init__()
16
+ # self.encoder = sm(encoder, bm([en.large_conv_dense_encoder(), Dense(output_size)], False))
17
+ self.encoder = encoder or bm([large_conv_dense_encoder(), Dense(output_size)])
18
+ # self.head = head or HeadBatch(encoder=encoder, output_size=output_size)
19
+ inference_output_size = inference_output_size or output_size
20
+ self.inference_network = inference_network or bm([
21
+ K.flat,
22
+ build_dense_model([1028, 512, 512, inference_output_size],
23
+ last_activation=inference_activation)]
24
+ )
25
+ self.stem = stem or Pass()
26
+ self.images_no = images_no
27
+ self.inference_image_no = self.images_no if inference_image_no is None else inference_image_no
28
+
29
+
30
+ class LatentHeadModel(HeadModel):
31
+ def call(self, inputs):
32
+ result = K.map_batch(inputs[:, :self.images_no], self.encoder)
33
+ inference = self.inference_network(result[:, :self.inference_image_no])
34
+ latents = self.stem(result)
35
+ return [latents, inference,result]
36
+
37
+
38
+ # # todo use map_batch
39
+ # class HeadBatch(Model):
40
+ # def __init__(self, encoder=None, output_size=64):
41
+ # super().__init__()
42
+ # self.encoder = sm(encoder, bm([large_conv_dense_encoder(), Dense(output_size)], False))
43
+ #
44
+ # def call(self, inputs):
45
+ # shape = tf.shape(inputs)
46
+ # latents = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
47
+ # latents = K.reshape(latents, tf.concat([[-1, shape[1]], latents.shape[1:]], axis=-1))
48
+ # return latents
49
+
50
+
51
+ # Not working
52
+ class DuoHeadModel(HeadModel):
53
+ def __init__(self, encoder=None, inference_network=None, images_no=8, filters=-4):
54
+ super().__init__(encoder=encoder, inference_network=inference_network, images_no=images_no)
55
+ self.encoder = ActivationModel(self.encoder, filters=filters, include_input=False)
56
+
57
+ def call(self, inputs):
58
+ shape = inputs.shape
59
+ result = reversed(self.encoder(K.reshape(inputs, shape=[-1] + list(shape[2:]))))
60
+ latents = K.reshape(result[0], [-1, self.images_no] + [result[0].shape[-1]])
61
+ inference = self.inference_network(K.flat(result[1]))
62
+ return [latents, inference]
63
+
64
+
65
+ class MultiHeadModel(Model):
66
+ def __init__(self, encoder=None, images_no=8, filters=(1, 3, 6)):
67
+ super().__init__()
68
+ self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
69
+ self.merge = MergeSacles()
70
+ self.images_no = images_no
71
+
72
+ def call(self, inputs):
73
+ shape = tf.shape(inputs)
74
+ results = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
75
+ latents = [tf.reshape(result, shape=tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)) for result
76
+ in results]
77
+
78
+ l1 = tf.transpose(latents[0], (0, 2, 3, 1, 4))
79
+ # l1 = tf.reshape(l1, tuple(list(l1.shape[:3]) + [l1.shape[-2] * l1.shape[-1]]))
80
+ shape = tf.shape(l1)
81
+ l1 = tf.reshape(l1, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
82
+
83
+ l2 = tf.transpose(latents[1], (0, 2, 3, 1, 4))
84
+ # l2 = tf.reshape(l2, [-1] + list(l2.shape[1:3]) + [l2.shape[-2] * l2.shape[-1]])
85
+ shape = tf.shape(l2)
86
+ l2 = tf.reshape(l2, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
87
+
88
+ l3 = latents[2]
89
+ shape = tf.shape(l3)
90
+ # l3 = tf.reshape(l3, [-1] + [l3.shape[-2] * l3.shape[-1]])
91
+ l3 = tf.reshape(l3, tf.concat([[-1], [shape[-2] * shape[-1]]], axis=-1))
92
+
93
+ inference = self.merge([l1, l2, l3])
94
+ return [latents, inference]
95
+
96
+
97
+ class MergeSacles(Model):
98
+ def __init__(self):
99
+ super().__init__()
100
+ self.inf_1 = bm([Conv2D(64, 1, activation="relu"), res(64),
101
+ Conv2D(64, 3, strides=2, padding=SAME, activation="relu"),
102
+ res(64),
103
+ Flatten(),
104
+ Dense(256, "relu")])
105
+ self.inf_2 = bm([Conv2D(128, 1, activation="relu"),
106
+ res(128),
107
+ Flatten(),
108
+ Dense(256, "relu")])
109
+ self.inf_3 = Dense(256, "relu")
110
+
111
+ def call(self, inputs):
112
+ il1 = self.inf_1(inputs[0])
113
+ il2 = self.inf_2(inputs[1])
114
+ il3 = self.inf_3(inputs[2])
115
+ inference = tf.concat([il1, il2, il3], axis=1)
116
+ return inference
117
+
118
+
119
+ class MultiHeadModel2(Model):
120
+ def __init__(self, encoder=None, images_no=8, filters=(3, 6)):
121
+ super().__init__()
122
+ self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
123
+ self.merge = MergeSacles2()
124
+ self.images_no = images_no
125
+
126
+ def call(self, inputs):
127
+ shape = tf.shape(inputs)
128
+ results = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
129
+ latents = [tf.reshape(result, shape=tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)) for result
130
+ in results]
131
+
132
+ l2 = tf.transpose(latents[0], (0, 2, 3, 1, 4))
133
+ # l2 = tf.reshape(l2, [-1] + list(l2.shape[1:3]) + [l2.shape[-2] * l2.shape[-1]])
134
+ shape = tf.shape(l2)
135
+ l2 = tf.reshape(l2, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
136
+
137
+ l3 = latents[1]
138
+ shape = tf.shape(l3)
139
+ # l3 = tf.reshape(l3, [-1] + [l3.shape[-2] * l3.shape[-1]])
140
+ l3 = tf.reshape(l3, tf.concat([[-1], [shape[-2] * shape[-1]]], axis=-1))
141
+
142
+ inference = self.merge([l2, l3])
143
+ return [latents, inference]
144
+
145
+
146
+ class MergeSacles2(Model):
147
+ def __init__(self):
148
+ super().__init__()
149
+ self.inf_1 = bm([Conv2D(128, 1, activation="relu"),
150
+ res(128),
151
+ Flatten(),
152
+ Dense(256, "relu")])
153
+ self.inf_2 = Dense(256, "relu")
154
+
155
+ def call(self, inputs):
156
+ il1 = self.inf_1(inputs[0])
157
+ il2 = self.inf_2(inputs[1])
158
+ inference = tf.concat([il1, il2], axis=1)
159
+ return inference
raven_utils/models/loss.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import tensorflow as tf
4
+ import tensorflow.experimental.numpy as tnp
5
+ from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
6
+ from models_utils import SubClassingModel
7
+ from models_utils.models.utils import interleave
8
+ from models_utils.op import reshape
9
+ from tensorflow.keras import Model
10
+ # from tensorflow.keras import backend as K
11
+ from tensorflow.keras.layers import Lambda
12
+ from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
13
+ from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
14
+ import models_utils.ops as K
15
+
16
+ import raven_utils.decode
17
+ import raven_utils as rv
18
+ from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
19
+ SLOT, \
20
+ PROPERTIES, ACC, GROUP, NUMBER, MASK
21
+ from raven_utils.models.uitls_ import RangeMask
22
+ from raven_utils.const import VERTICAL, HORIZONTAL
23
+
24
+
25
+ def get_properties_mask(target):
26
+ return target[:, rv.target.END_INDEX:rv.target.UNIFORMITY_INDEX] > 0
27
+
28
+
29
+ def create_change_mask(target):
30
+ properties_mask = get_properties_mask(target)
31
+ return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
32
+
33
+
34
+ def create_uniform_mask(target):
35
+ u_mask = lambda i: tf.tile(target[:, rv.target.UNIFORMITY_INDEX + i, None] == 3, [1, rv.rules.ATTRIBUTES_LEN])
36
+ properties_mask = tf.concat([u_mask(0), u_mask(1)], axis=-1) | get_properties_mask(target)
37
+ return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
38
+
39
+
40
+ def create_all_mask(target):
41
+ return [
42
+ tf.cast(tf.ones(tf.stack([tf.shape(target)[0], rv.entity.SUM])), dtype=tf.bool) for i, _ in
43
+ enumerate(rv.rules.ATTRIBUTES)]
44
+
45
+
46
+ class BaselineClassificationLossModel(Model):
47
+ def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
48
+ super().__init__()
49
+ self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
50
+ self.loss_fn = ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
51
+ group_loss=group_loss)
52
+ self.metric_fn = SimilarityRaven(mode=mode)
53
+
54
+ def call(self, inputs):
55
+ losses = []
56
+ output = inputs[1]
57
+ losses.append(self.loss_fn([inputs[0][0], output]))
58
+ losses.append(self.metric_fn([inputs[0][2], inputs[3][0], inputs[0][1][:, 8:]]))
59
+ return losses
60
+
61
+
62
+ class RavenLoss(Model):
63
+ def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
64
+ classification=False, trans=True, anneal=False):
65
+ super().__init__()
66
+ if anneal:
67
+ self.weight_scheduler
68
+ self.classification = classification
69
+ self.trans = trans
70
+ self.predict_fn = DictModel(SubClassingModel([lambda x: x[-1], PredictModel()]), in_=OUTPUT,
71
+ out=[PREDICT, MASK], name="pred")
72
+ if self.trans:
73
+ self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
74
+ group_loss=group_loss, enable_metrics=False, lw=lw[0]),
75
+ name="main_loss")
76
+ self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
77
+ group_loss=group_loss), name="add_loss")
78
+ self.metric_fn = SimilarityRaven(mode=mode)
79
+ if self.classification:
80
+ self.loss_fn_3 = add_loss(
81
+ ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
82
+ group_loss=group_loss, enable_metrics="c" if self.trans else True), lw=lw[1],
83
+ name="class_loss")
84
+
85
+ def call(self, inputs):
86
+ losses = []
87
+ output = inputs[OUTPUT]
88
+ target = inputs[TARGET]
89
+ labels = inputs[LABELS]
90
+
91
+ if self.trans:
92
+ losses.append(self.loss_fn([labels[:, 2], output[0]]))
93
+ losses.append(self.loss_fn([labels[:, 5], output[1]]))
94
+ losses.append(self.loss_fn_2([target, output[2]]))
95
+ losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
96
+ if self.classification:
97
+ for i in range(8):
98
+ losses.append(self.loss_fn_3([labels[:, i], inputs[CLASSIFICATION][i]]))
99
+ return {**inputs, LOSS: losses}
100
+
101
+
102
+ class VTRavenLoss(Model):
103
+ def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
104
+ classification=False, trans=True, anneal=False, plw=None):
105
+ super().__init__()
106
+ if anneal:
107
+ self.weight_scheduler
108
+ self.classification = classification
109
+ self.trans = trans
110
+ self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
111
+ out=[PREDICT, MASK], name="pred")
112
+ self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
113
+ group_loss=group_loss, plw=plw), lw=lw[0] , name="add_loss")
114
+ self.metric_fn = SimilarityRaven(mode=mode)
115
+ if self.classification:
116
+ self.loss_fn_2 = add_loss(
117
+ ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
118
+ group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")
119
+
120
+ def call(self, inputs):
121
+ losses = []
122
+ output = inputs[OUTPUT]
123
+ target = inputs[TARGET]
124
+ labels = inputs[LABELS]
125
+
126
+ for i in range(9):
127
+ losses.append(self.loss_fn_2([labels[:, i], output[:, i]]))
128
+ losses.append(self.loss_fn([target, output[:, 8]]))
129
+ losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
130
+ return {**inputs, LOSS: losses}
131
+
132
+
133
+ class SingleVTRavenLoss(Model):
134
+ def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
135
+ classification=False, trans=True, anneal=False):
136
+ super().__init__()
137
+ if anneal:
138
+ self.weight_scheduler
139
+ self.classification = classification
140
+ self.trans = trans
141
+ self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
142
+ self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
143
+ group_loss=group_loss), lw=lw[0], name="add_loss")
144
+ self.metric_fn = SimilarityRaven(mode=mode)
145
+
146
+ def call(self, inputs):
147
+ losses = []
148
+ output = inputs[OUTPUT]
149
+ target = inputs[TARGET]
150
+ labels = inputs[LABELS]
151
+
152
+ losses.append(self.loss_fn([target, output]))
153
+ losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
154
+ return {**inputs, LOSS: losses}
155
+
156
+
157
+ class ClassRavenModel(Model):
158
+ def __init__(self, mode=create_all_mask,plw=None, number_loss=False, slot_loss=True, group_loss=True, enable_metrics=True,
159
+ lw=1.0):
160
+ super().__init__()
161
+ self.number_loss = number_loss
162
+ self.group_loss = group_loss
163
+ self.enable_metrics = enable_metrics
164
+ self.slot_loss = slot_loss
165
+ self.predict_fn = PredictModel()
166
+ self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
167
+ if self.slot_loss:
168
+ self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
169
+ if self.enable_metrics:
170
+ self.enable_metrics = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
171
+ self.metric_fn = [
172
+ SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{property_}") for property_ in
173
+ rv.properties.NAMES]
174
+ if self.group_loss:
175
+ self.metric_fn_group = SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{GROUP}")
176
+ if self.slot_loss:
177
+ self.metric_fn_2 = BinaryAccuracy(name=f"{self.enable_metrics}{ACC}_{SLOT}")
178
+ self.range_mask = RangeMask()
179
+ self.mode = mode
180
+ self.lw = lw
181
+ if not plw:
182
+ plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
183
+ elif isinstance(plw, int) or isinstance(plw, float):
184
+ plw = [1., plw, 2.83426987, 0.85212836, 1.096005, 1.21943385]
185
+ # plw = [plw] * 6
186
+ self.plw = plw
187
+
188
+ # self.predict_fn = partial(tf.argmax, axis=-1)
189
+
190
+ def call(self, inputs):
191
+ losses = []
192
+ metrics = {}
193
+ target = inputs[0]
194
+ output = inputs[1]
195
+
196
+ target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
197
+
198
+ group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)
199
+
200
+ # group
201
+ if self.group_loss:
202
+ group_loss = self.lw * self.plw[0] * self.loss_fn(target_group, group_output)
203
+ losses.append(group_loss)
204
+
205
+ if isinstance(self.enable_metrics, str):
206
+ group_metric = self.metric_fn_group(target_group, group_output)
207
+ # metrics[GROUP] = group_metric
208
+ self.add_metric(group_metric)
209
+ self.add_metric(tf.reduce_sum(group_metric), f"{self.enable_metrics}{ACC}")
210
+
211
+ # setting uniformity mask
212
+ full_properties_musks = self.mode(target)
213
+
214
+ range_mask = self.range_mask(target_group)
215
+
216
+ if self.slot_loss:
217
+ # number
218
+ number_mask = range_mask & full_properties_musks[0]
219
+ number_mask = tf.cast(number_mask, tf.float32)
220
+ target_number = tf.reduce_sum(
221
+ tf.cast(target_slot, "float32") * number_mask, axis=-1)
222
+ output_number = tf.reduce_sum(
223
+ tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)
224
+
225
+ # output_number = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
226
+ if self.number_loss:
227
+ scale = 1 / 9
228
+ if self.number_loss == 2:
229
+ output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
230
+ else:
231
+ output_number_2 = output_number
232
+ number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale, output_number_2 * scale)
233
+ losses.append(number_loss)
234
+
235
+ # metrics[NUMBER] = number_acc
236
+
237
+ if isinstance(self.enable_metrics, str):
238
+ number_acc = tf.reduce_mean(
239
+ tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
240
+ self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
241
+ self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}")
242
+ self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
243
+
244
+ # position/slot
245
+ slot_mask = range_mask & full_properties_musks[1]
246
+ # tf.boolean_mask(target_slot,slot_mask)
247
+
248
+ if tf.reduce_any(slot_mask):
249
+ # if tf.reduce_mean(tf.cast(slot_mask, dtype=tf.int32)) > 0:
250
+ target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
251
+ output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
252
+ loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
253
+ self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
254
+ if isinstance(self.enable_metrics, str):
255
+ acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
256
+ self.add_metric(acc_slot)
257
+ self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}")
258
+ self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
259
+ else:
260
+ loss_slot = 0.0
261
+ acc_slot = -1.0
262
+
263
+ losses.append(loss_slot)
264
+ # metrics[SLOT] = acc_slot
265
+ # if loss_slot != 0:
266
+
267
+ # if tf.reduce_any(slot_mask):
268
+
269
+ # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_{NUMBER}")
270
+ # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}")
271
+ # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_NO_{GROUP}")
272
+
273
+ # properties
274
+ for i, out in enumerate(outputs):
275
+ shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
276
+ out_reshaped = tf.reshape(out, shape)
277
+ properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]
278
+
279
+ if tf.reduce_any(properties_mask):
280
+ out_masked = tf.boolean_mask(out_reshaped, properties_mask)
281
+ out_target = tf.boolean_mask(target_all[i], properties_mask)
282
+ loss = self.lw * self.plw[3+i] * self.loss_fn(out_target, out_masked)
283
+ if isinstance(self.enable_metrics, str):
284
+ metric = self.metric_fn[i](out_target, out_masked)
285
+ self.add_metric(metric)
286
+ # self.add_metric(metric, f"{self.enable_metrics}{ACC}")
287
+ self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}")
288
+ self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
289
+ self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
290
+ else:
291
+ loss = 0.0
292
+ metric = -1.0
293
+
294
+ losses.append(loss)
295
+ return losses
296
+
297
+
298
+ class FullMask(Model):
299
+ def __init__(self, mode=create_uniform_mask):
300
+ super().__init__()
301
+ self.range_mask = RangeMask()
302
+ self.mode = mode
303
+
304
+ def call(self, inputs):
305
+ target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
306
+ full_properties_musks = self.mode(inputs)
307
+ range_mask = self.range_mask(target_group)
308
+
309
+ number_mask = range_mask & full_properties_musks[0]
310
+
311
+ slot_mask = range_mask & full_properties_musks[1]
312
+ properties_mask = []
313
+ for property_mask in full_properties_musks[2:]:
314
+ properties_mask.append(tf.cast(target_slot, "bool") & property_mask)
315
+ return [slot_mask, properties_mask, number_mask]
316
+
317
+
318
+ def create_mask(rules, i):
319
+ mask_1 = tf.tile(rules[:, i][None], [len(rv.target.FIRST_LAYOUT), 1])
320
+ mask_2 = tf.tile(rules[:, i + 5][None], [len(rv.target.SECOND_LAYOUT), 1])
321
+ shape = tf.shape(rules)
322
+ full_mask_1 = tf.scatter_nd(tnp.array(rv.target.FIRST_LAYOUT)[:, None], mask_1, shape=(rv.entity.SUM, shape[0]))
323
+ full_mask_2 = tf.tensor_scatter_nd_update(full_mask_1, tnp.array(rv.target.SECOND_LAYOUT)[:, None], mask_2)
324
+ return tf.transpose(full_mask_2)
325
+
326
+
327
+ # class PredictModel(Model):
328
+ # def __init__(self):
329
+ # super().__init__()
330
+ # self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
331
+ # self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
332
+ # self.range_mask = RangeMask()
333
+ #
334
+ # # self.predict_fn = partial(tf.argmax, axis=-1)
335
+ #
336
+ # def call(self, inputs):
337
+ # group_output = inputs[rv.OUTPUT_GROUP_SLICE]
338
+ # group_loss = self.predict_fn(group_output)[:, None]
339
+ #
340
+ # output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
341
+ # range_mask = self.range_mask(group_loss[:, 0])
342
+ # loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
343
+ #
344
+ # properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
345
+ # properties = []
346
+ # outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
347
+ # for i, out in enumerate(outputs):
348
+ # shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
349
+ # out_reshaped = tf.reshape(out, shape)
350
+ # properties.append(self.predict_fn(out_reshaped))
351
+ # number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
352
+ #
353
+ # result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
354
+ #
355
+ # return [result, range_mask, range_mask, range_mask, range_mask]
356
+
357
+ class PredictModel(Model):
358
+ def __init__(self):
359
+ super().__init__()
360
+ self.predict_fn = Predict()
361
+ self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
362
+ self.range_mask = RangeMask()
363
+
364
+ # self.predict_fn = partial(tf.argmax, axis=-1)
365
+
366
+ def call(self, inputs):
367
+ group_output, output_slot, *properties = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
368
+ number_loss = K.int64(K.sum(output_slot))
369
+ result = tf.concat(
370
+ [group_output[:, None], tf.cast(output_slot, dtype=tf.int64), interleave(properties), number_loss[:, None]],
371
+ axis=-1)
372
+
373
+ range_mask = self.range_mask(group_output)
374
+ return [result, range_mask]
375
+ # return [result, range_mask, range_mask, range_mask, range_mask]
376
+
377
+
378
+ # todo change slices
379
+ class PredictModelMasked(Model):
380
+ def __init__(self):
381
+ super().__init__()
382
+ self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
383
+ self.loss_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
384
+ self.range_mask = RangeMask()
385
+
386
+ # self.predict_fn = partial(tf.argmax, axis=-1)
387
+
388
+ def call(self, inputs):
389
+ group_output = inputs[:, -rv.GROUPS_NO:]
390
+ group_loss = self.predict_fn(group_output)[:, None]
391
+
392
+ output_slot = inputs[:, :rv.ENTITY_SUM]
393
+ range_mask = self.range_mask(group_loss[:, 0])
394
+ loss_slot = tf.cast(self.predict_fn_2(output_slot * range_mask), dtype=tf.int64)
395
+
396
+ properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]
397
+
398
+ properties = []
399
+ outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
400
+ for i, out in enumerate(outputs):
401
+ shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
402
+ out_reshaped = tf.reshape(out, shape)
403
+ out_masked = out_reshaped * loss_slot[..., None]
404
+ properties.append(self.predict_fn(out_masked))
405
+ # out_masked[0].numpy()
406
+ number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
407
+
408
+ result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
409
+
410
+ return result
411
+
412
+
413
+ def final_predict_mask(x, mask):
414
+ r = reshape(x[0][:, rv.INDEX[0]:-1], [-1, 3])
415
+ return tf.ragged.boolean_mask(r, mask)
416
+
417
+
418
+ def final_predict(x, mode=False):
419
+ m = x[1] if mode else tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
420
+ return final_predict_mask(x[0], m)
421
+
422
+
423
+ def final_predict_2(x):
424
+ ones = tf.cast(tf.ones(tf.shape(x[0])[0]), tf.bool)[:, None]
425
+ mask = tf.concat([ones, tf.tile(x[1], [1, 4]), ones], axis=-1)
426
+ return tf.ragged.boolean_mask(x[0], mask)
427
+
428
+
429
+ class PredictModelOld(Model):
430
+
431
+ def call(self, inputs):
432
+ output = inputs[-2]
433
+
434
+ rest_output = output[:, :-rv.GROUPS_NO]
435
+
436
+ result_all = []
437
+ outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-3)
438
+ for i, out in enumerate(outputs):
439
+ shape = (-3, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
440
+ out_reshaped = tf.reshape(out, shape)
441
+
442
+ result = tf.cast(tf.argmax(out_reshaped, axis=-3), dtype="int8")
443
+ result_all.append(result)
444
+
445
+ result_all = interleave(result_all)
446
+ return result_all
447
+
448
+
449
+ def get_matches(diff, target_index):
450
+ diff_sum = K.sum(diff)
451
+ db_argsort = tf.argsort(diff_sum, axis=-1)
452
+ db_sorted = tf.sort(diff_sum)
453
+ db_mask = db_sorted[:, 0, None] == db_sorted
454
+ db_same = tf.where(db_mask, db_argsort, -1 * tf.ones_like(db_argsort))
455
+ matched_index = db_same == target_index
456
+ # setting shape needed for TensorFlow graph
457
+ matched_index.set_shape(db_same.shape)
458
+ matches = K.any(matched_index)
459
+ more_matches = K.sum(db_mask) > 1
460
+ once_matches = K.sum(matches & tf.math.logical_not(more_matches))
461
+ return matches, more_matches, once_matches
462
+
463
+
464
+ class SimilarityRaven(Model):
465
+ def __init__(self, mode=create_all_mask, number_loss=False):
466
+ super().__init__()
467
+ self.range_mask = RangeMask()
468
+ self.mode = mode
469
+
470
+ # self.predict_fn = partial(tf.argmax, axis=-1)
471
+
472
+ # INDEX, PREDICT, LABELS
473
+ def call(self, inputs):
474
+ metrics = []
475
+ target_index = inputs[0] - 8
476
+ predict = inputs[1]
477
+ answers = inputs[2][:, 8:]
478
+ shape = tf.shape(predict)
479
+
480
+ target = K.gather(answers, target_index[:, 0])
481
+
482
+ target_group = target[:, 0]
483
+
484
+ # comp_slice = np.
485
+ target_comp = target[:, 1:rv.target.END_INDEX]
486
+ predict_comp = predict[:, 1:rv.target.END_INDEX]
487
+ answers_comp = answers[:, :, 1:rv.target.END_INDEX]
488
+
489
+ full_properties_musks = self.mode(target)
490
+ fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
491
+
492
+ range_mask = self.range_mask(target_group)
493
+ full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
494
+
495
+ final_mask = fpm & full_range_mask
496
+
497
+ target_masked = target_comp * final_mask
498
+ predict_masked = predict_comp * final_mask
499
+ answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
500
+
501
+ acc_same = K.mean(K.all(target_masked == predict_masked))
502
+ self.add_metric(acc_same, ACC_SAME)
503
+ metrics.append(acc_same)
504
+
505
+ diff = tf.abs(predict_masked[:, None] - answers_masked)
506
+ diff_bool = diff != 0
507
+
508
+ matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
509
+
510
+ second_phase_mask = (more_matches & matches)
511
+ diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
512
+ target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
513
+
514
+ matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
515
+ matches_2_no = K.sum(matches_2)
516
+
517
+ acc_choose_upper = (once_matches + matches_2_no) / shape[0]
518
+ self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
519
+ metrics.append(acc_choose_upper)
520
+
521
+ acc_choose_lower = (once_matches + once_matches_2) / shape[0]
522
+ self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
523
+ metrics.append(acc_choose_lower)
524
+
525
+ return metrics
526
+
527
+
528
+ class SimilarityRaven2(Model):
529
+ def __init__(self, mode=create_all_mask, number_loss=False):
530
+ super().__init__()
531
+ self.range_mask = RangeMask()
532
+ self.mode = mode
533
+
534
+ # self.predict_fn = partial(tf.argmax, axis=-1)
535
+
536
+ # INDEX, PREDICT, LABELS
537
+ def call(self, inputs):
538
+ metrics = []
539
+ target_index = inputs[0] - 8
540
+ predict = inputs[1]
541
+ answers = inputs[2][:, 8:]
542
+ shape = tf.shape(predict)
543
+
544
+ target = K.gather(answers, target_index[:, 0])
545
+
546
+ target_group = target[:, 0]
547
+
548
+ # comp_slice = np.
549
+ target_comp = target[:, 1:rv.target.END_INDEX]
550
+ predict_comp = predict[:, 1:rv.target.END_INDEX]
551
+ answers_comp = answers[:, :, 1:rv.target.END_INDEX]
552
+
553
+ full_properties_musks = self.mode(target)
554
+ fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
555
+
556
+ range_mask = self.range_mask(target_group)
557
+ full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
558
+
559
+ final_mask = fpm & full_range_mask
560
+
561
+ target_masked = target_comp * final_mask
562
+ predict_masked = predict_comp * final_mask
563
+ answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
564
+
565
+ acc_same = K.mean(K.all(target_masked == predict_masked))
566
+ self.add_metric(acc_same, ACC_SAME)
567
+ metrics.append(acc_same)
568
+
569
+ diff = tf.abs(predict_masked[:, None] - answers_masked)
570
+ diff_bool = diff != 0
571
+
572
+ matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
573
+
574
+ second_phase_mask = (more_matches & matches)
575
+ diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
576
+ target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
577
+
578
+ matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
579
+ matches_2_no = K.sum(matches_2)
580
+
581
+ acc_choose_upper = (once_matches + matches_2_no) / shape[0]
582
+ self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
583
+ metrics.append(acc_choose_upper)
584
+
585
+ acc_choose_lower = (once_matches + once_matches_2) / shape[0]
586
+ self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
587
+ metrics.append(acc_choose_lower)
588
+
589
+ metrics.append(K.sum(target_masked != predict_masked))
590
+
591
+ return metrics
592
+
593
+
594
+ class LatentLossModel(Model):
595
+ def __init__(self, dir_=HORIZONTAL):
596
+ super().__init__()
597
+ # self.sum_metrics = []
598
+ # for i in range(8):
599
+ # self.sum_metrics.append(Sum(name=f"no_{i}"))
600
+ self.metric_fn = Accuracy(name="acc_latent")
601
+ if dir_ == VERTICAL:
602
+ self.dir = (6, 7)
603
+ else:
604
+ self.dir = (2, 5)
605
+
606
+ def call(self, inputs):
607
+ target_image = tf.reshape(inputs[0][2], [-1])
608
+ output = inputs[1]
609
+ latents = tnp.asarray(inputs[2])
610
+
611
+ target_hor = tf.concat([
612
+ latents[:, self.dir],
613
+ latents[tf.range(latents.shape[0]), target_image + 8][:, None]
614
+ ],
615
+ axis=1)
616
+
617
+ loss_hor = mse(K.stop_gradient(target_hor), output)
618
+ self.add_loss(loss_hor)
619
+
620
+ self.add_metric(self.metric_fn(inputs[3], target_image))
621
+
622
+ return loss_hor
623
+
624
+
625
+ class PredRav(Model):
626
+
627
+ def call(self, inputs):
628
+ output = inputs[0][:, -1]
629
+ answers = inputs[1][:, 8:]
630
+ return tf.argmin(tf.reduce_sum(tf.abs(output[:, None] - answers), axis=-1), axis=-1)
raven_utils/models/loss_3.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import tensorflow as tf
4
+ import tensorflow.experimental.numpy as tnp
5
+ from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
6
+ from models_utils import SubClassingModel
7
+ from models_utils.models.utils import interleave
8
+ from models_utils.op import reshape
9
+ from tensorflow.keras import Model
10
+ # from tensorflow.keras import backend as K
11
+ from tensorflow.keras.layers import Lambda
12
+ from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
13
+ from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
14
+ import models_utils.ops as K
15
+
16
+ import raven_utils.decode
17
+ import raven_utils as rv
18
+ from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
19
+ SLOT, \
20
+ PROPERTIES, ACC, GROUP, NUMBER, MASK
21
+ from raven_utils.models.uitls_ import RangeMask
22
+ from raven_utils.const import VERTICAL, HORIZONTAL
23
+
24
+
25
+ def get_properties_mask(target):
26
+ return target[:, rv.target.END_INDEX:rv.target.UNIFORMITY_INDEX] > 0
27
+
28
+
29
+ def create_change_mask(target):
30
+ properties_mask = get_properties_mask(target)
31
+ return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
32
+
33
+
34
+ def create_uniform_mask(target):
35
+ u_mask = lambda i: tf.tile(target[:, rv.target.UNIFORMITY_INDEX + i, None] == 3, [1, rv.rules.ATTRIBUTES_LEN])
36
+ properties_mask = tf.concat([u_mask(0), u_mask(1)], axis=-1) | get_properties_mask(target)
37
+ return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
38
+
39
+
40
+ def create_all_mask(target):
41
+ return [
42
+ tf.cast(tf.ones(tf.stack([tf.shape(target)[0], rv.entity.SUM])), dtype=tf.bool) for i, _ in
43
+ enumerate(rv.rules.ATTRIBUTES)]
44
+
45
+
46
+ class BaselineClassificationLossModel(Model):
47
+ def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
48
+ super().__init__()
49
+ self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
50
+ self.loss_fn = ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
51
+ group_loss=group_loss)
52
+ self.metric_fn = SimilarityRaven(mode=mode)
53
+
54
+ def call(self, inputs):
55
+ losses = []
56
+ output = inputs[1]
57
+ losses.append(self.loss_fn([inputs[0][0], output]))
58
+ losses.append(self.metric_fn([inputs[0][2], inputs[3][0], inputs[0][1][:, 8:]]))
59
+ return losses
60
+
61
+
62
+ class RavenLoss(Model):
63
+ def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
64
+ classification=False, trans=True, anneal=False):
65
+ super().__init__()
66
+ if anneal:
67
+ self.weight_scheduler
68
+ self.classification = classification
69
+ self.trans = trans
70
+ self.predict_fn = DictModel(SubClassingModel([lambda x: x[-1], PredictModel()]), in_=OUTPUT,
71
+ out=[PREDICT, MASK], name="pred")
72
+ if self.trans:
73
+ self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
74
+ group_loss=group_loss, enable_metrics=False, lw=lw[0]),
75
+ name="main_loss")
76
+ self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
77
+ group_loss=group_loss), name="add_loss")
78
+ self.metric_fn = SimilarityRaven(mode=mode)
79
+ if self.classification:
80
+ self.loss_fn_3 = add_loss(
81
+ ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
82
+ group_loss=group_loss, enable_metrics="c" if self.trans else True), lw=lw[1],
83
+ name="class_loss")
84
+
85
+ def call(self, inputs):
86
+ losses = []
87
+ output = inputs[OUTPUT]
88
+ target = inputs[TARGET]
89
+ labels = inputs[LABELS]
90
+
91
+ if self.trans:
92
+ losses.append(self.loss_fn([labels[:, 2], output[0]]))
93
+ losses.append(self.loss_fn([labels[:, 5], output[1]]))
94
+ losses.append(self.loss_fn_2([target, output[2]]))
95
+ losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
96
+ if self.classification:
97
+ for i in range(8):
98
+ losses.append(self.loss_fn_3([labels[:, i], inputs[CLASSIFICATION][i]]))
99
+ return {**inputs, LOSS: losses}
100
+
101
+
102
+ class VTRavenLoss(Model):
103
+ def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(2.0, 1.0),
104
+ classification=False, trans=True, anneal=False, plw=None):
105
+ super().__init__()
106
+ if anneal:
107
+ self.weight_scheduler
108
+ self.classification = classification
109
+ self.trans = trans
110
+ self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
111
+ out=[PREDICT, "predict_mask"], name="pred")
112
+ self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
113
+ group_loss=group_loss, plw=plw), lw=lw[0], name="add_loss")
114
+ self.metric_fn = SimilarityRaven(mode=mode)
115
+ if self.classification:
116
+ self.loss_fn_2 = add_loss(
117
+ ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
118
+ group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")
119
+
120
+ def call(self, inputs):
121
+ losses = []
122
+ output = inputs[OUTPUT]
123
+ target = inputs[TARGET]
124
+ labels = inputs[LABELS]
125
+ mask = inputs[MASK]
126
+
127
+ target_masked = target[mask]
128
+ output_masked = output[mask]
129
+ losses.append(self.loss_fn([target_masked, output_masked]))
130
+
131
+ target_unmasked = target[~mask]
132
+ output_unmasked = output[~mask]
133
+ losses.append(self.loss_fn_2([target_unmasked, output_unmasked]))
134
+
135
+ losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
136
+ return {**inputs, LOSS: losses}
137
+
138
+
139
+ class SingleVTRavenLoss(Model):
140
+ def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
141
+ classification=False, trans=True, anneal=False):
142
+ super().__init__()
143
+ if anneal:
144
+ self.weight_scheduler
145
+ self.classification = classification
146
+ self.trans = trans
147
+ self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
148
+ self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
149
+ group_loss=group_loss), lw=lw[0], name="add_loss")
150
+ self.metric_fn = SimilarityRaven(mode=mode)
151
+
152
+ def call(self, inputs):
153
+ losses = []
154
+ output = inputs[OUTPUT]
155
+ target = inputs[TARGET]
156
+ labels = inputs[LABELS]
157
+
158
+ losses.append(self.loss_fn([target, output]))
159
+ losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
160
+ return {**inputs, LOSS: losses}
161
+
162
+
163
+ class ClassRavenModel(Model):
164
+ def __init__(self, mode=create_all_mask, plw=None, number_loss=False, slot_loss=True, group_loss=True,
165
+ enable_metrics=True,
166
+ lw=1.0):
167
+ super().__init__()
168
+ self.number_loss = number_loss
169
+ self.group_loss = group_loss
170
+ self.enable_metrics = enable_metrics
171
+ self.slot_loss = slot_loss
172
+ self.predict_fn = PredictModel()
173
+ self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
174
+ if self.slot_loss:
175
+ self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
176
+ if self.enable_metrics:
177
+ self.enable_metrics = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
178
+ self.metric_fn = [
179
+ SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{property_}") for property_ in
180
+ rv.properties.NAMES]
181
+ if self.group_loss:
182
+ self.metric_fn_group = SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{GROUP}")
183
+ if self.slot_loss:
184
+ self.metric_fn_2 = BinaryAccuracy(name=f"{self.enable_metrics}{ACC}_{SLOT}")
185
+ self.range_mask = RangeMask()
186
+ self.mode = mode
187
+ self.lw = lw
188
+ if not plw:
189
+ plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
190
+ elif isinstance(plw, int) or isinstance(plw, float):
191
+ plw = [1., plw, 2.83426987, 0.85212836, 1.096005, 1.21943385]
192
+ # plw = [plw] * 6
193
+ self.plw = plw
194
+
195
+ # self.predict_fn = partial(tf.argmax, axis=-1)
196
+
197
+ def call(self, inputs):
198
+ losses = []
199
+ metrics = {}
200
+ target = inputs[0]
201
+ output = inputs[1]
202
+
203
+ target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
204
+
205
+ group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)
206
+
207
+ # group
208
+ if self.group_loss:
209
+ group_loss = self.lw * self.plw[0] * self.loss_fn(target_group, group_output)
210
+ losses.append(group_loss)
211
+
212
+ if isinstance(self.enable_metrics, str):
213
+ group_metric = self.metric_fn_group(target_group, group_output)
214
+ # metrics[GROUP] = group_metric
215
+ self.add_metric(group_metric)
216
+ self.add_metric(tf.reduce_sum(group_metric), f"{self.enable_metrics}{ACC}")
217
+
218
+ # setting uniformity mask
219
+ full_properties_musks = self.mode(target)
220
+
221
+ range_mask = self.range_mask(target_group)
222
+
223
+ if self.slot_loss:
224
+ # number
225
+ number_mask = range_mask & full_properties_musks[0]
226
+ number_mask = tf.cast(number_mask, tf.float32)
227
+ target_number = tf.reduce_sum(
228
+ tf.cast(target_slot, "float32") * number_mask, axis=-1)
229
+ output_number = tf.reduce_sum(
230
+ tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)
231
+
232
+ # output_number = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
233
+ if self.number_loss:
234
+ scale = 1 / 9
235
+ if self.number_loss == 2:
236
+ output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
237
+ else:
238
+ output_number_2 = output_number
239
+ number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale,
240
+ output_number_2 * scale)
241
+ losses.append(number_loss)
242
+
243
+ # metrics[NUMBER] = number_acc
244
+
245
+ if isinstance(self.enable_metrics, str):
246
+ number_acc = tf.reduce_mean(
247
+ tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
248
+ self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
249
+ self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}")
250
+ self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
251
+
252
+ # position/slot
253
+ slot_mask = range_mask & full_properties_musks[1]
254
+ # tf.boolean_mask(target_slot,slot_mask)
255
+
256
+ if tf.reduce_any(slot_mask):
257
+ # if tf.reduce_mean(tf.cast(slot_mask, dtype=tf.int32)) > 0:
258
+ target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
259
+ output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
260
+ loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
261
+ self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
262
+ if isinstance(self.enable_metrics, str):
263
+ acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
264
+ self.add_metric(acc_slot)
265
+ self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}")
266
+ self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
267
+ else:
268
+ loss_slot = 0.0
269
+ acc_slot = -1.0
270
+
271
+ losses.append(loss_slot)
272
+ # metrics[SLOT] = acc_slot
273
+ # if loss_slot != 0:
274
+
275
+ # if tf.reduce_any(slot_mask):
276
+
277
+ # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_{NUMBER}")
278
+ # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}")
279
+ # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_NO_{GROUP}")
280
+
281
+ # properties
282
+ for i, out in enumerate(outputs):
283
+ shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
284
+ out_reshaped = tf.reshape(out, shape)
285
+ properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]
286
+
287
+ if tf.reduce_any(properties_mask):
288
+ out_masked = tf.boolean_mask(out_reshaped, properties_mask)
289
+ out_target = tf.boolean_mask(target_all[i], properties_mask)
290
+ loss = self.lw * self.plw[3 + i] * self.loss_fn(out_target, out_masked)
291
+ if isinstance(self.enable_metrics, str):
292
+ metric = self.metric_fn[i](out_target, out_masked)
293
+ self.add_metric(metric)
294
+ # self.add_metric(metric, f"{self.enable_metrics}{ACC}")
295
+ self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}")
296
+ self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
297
+ self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
298
+ else:
299
+ loss = 0.0
300
+ metric = -1.0
301
+
302
+ losses.append(loss)
303
+ return losses
304
+
305
+
306
+ class FullMask(Model):
307
+ def __init__(self, mode=create_uniform_mask):
308
+ super().__init__()
309
+ self.range_mask = RangeMask()
310
+ self.mode = mode
311
+
312
+ def call(self, inputs):
313
+ target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
314
+ full_properties_musks = self.mode(inputs)
315
+ range_mask = self.range_mask(target_group)
316
+
317
+ number_mask = range_mask & full_properties_musks[0]
318
+
319
+ slot_mask = range_mask & full_properties_musks[1]
320
+ properties_mask = []
321
+ for property_mask in full_properties_musks[2:]:
322
+ properties_mask.append(tf.cast(target_slot, "bool") & property_mask)
323
+ return [slot_mask, properties_mask, number_mask]
324
+
325
+
326
+ def create_mask(rules, i):
327
+ mask_1 = tf.tile(rules[:, i][None], [len(rv.target.FIRST_LAYOUT), 1])
328
+ mask_2 = tf.tile(rules[:, i + 5][None], [len(rv.target.SECOND_LAYOUT), 1])
329
+ shape = tf.shape(rules)
330
+ full_mask_1 = tf.scatter_nd(tnp.array(rv.target.FIRST_LAYOUT)[:, None], mask_1, shape=(rv.entity.SUM, shape[0]))
331
+ full_mask_2 = tf.tensor_scatter_nd_update(full_mask_1, tnp.array(rv.target.SECOND_LAYOUT)[:, None], mask_2)
332
+ return tf.transpose(full_mask_2)
333
+
334
+
335
+ # class PredictModel(Model):
336
+ # def __init__(self):
337
+ # super().__init__()
338
+ # self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
339
+ # self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
340
+ # self.range_mask = RangeMask()
341
+ #
342
+ # # self.predict_fn = partial(tf.argmax, axis=-1)
343
+ #
344
+ # def call(self, inputs):
345
+ # group_output = inputs[rv.OUTPUT_GROUP_SLICE]
346
+ # group_loss = self.predict_fn(group_output)[:, None]
347
+ #
348
+ # output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
349
+ # range_mask = self.range_mask(group_loss[:, 0])
350
+ # loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
351
+ #
352
+ # properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
353
+ # properties = []
354
+ # outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
355
+ # for i, out in enumerate(outputs):
356
+ # shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
357
+ # out_reshaped = tf.reshape(out, shape)
358
+ # properties.append(self.predict_fn(out_reshaped))
359
+ # number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
360
+ #
361
+ # result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
362
+ #
363
+ # return [result, range_mask, range_mask, range_mask, range_mask]
364
+
365
+ class PredictModel(Model):
366
+ def __init__(self):
367
+ super().__init__()
368
+ self.predict_fn = Predict()
369
+ self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
370
+ self.range_mask = RangeMask()
371
+
372
+ # self.predict_fn = partial(tf.argmax, axis=-1)
373
+
374
+ def call(self, inputs):
375
+ group_output, output_slot, *properties = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
376
+ number_loss = K.int64(K.sum(output_slot))
377
+ result = tf.concat(
378
+ [group_output[:, None], tf.cast(output_slot, dtype=tf.int64), interleave(properties), number_loss[:, None]],
379
+ axis=-1)
380
+
381
+ range_mask = self.range_mask(group_output)
382
+ return [result, range_mask]
383
+ # return [result, range_mask, range_mask, range_mask, range_mask]
384
+
385
+
386
+ # todo change slices
387
+ class PredictModelMasked(Model):
388
+ def __init__(self):
389
+ super().__init__()
390
+ self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
391
+ self.loss_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
392
+ self.range_mask = RangeMask()
393
+
394
+ # self.predict_fn = partial(tf.argmax, axis=-1)
395
+
396
+ def call(self, inputs):
397
+ group_output = inputs[:, -rv.GROUPS_NO:]
398
+ group_loss = self.predict_fn(group_output)[:, None]
399
+
400
+ output_slot = inputs[:, :rv.ENTITY_SUM]
401
+ range_mask = self.range_mask(group_loss[:, 0])
402
+ loss_slot = tf.cast(self.predict_fn_2(output_slot * range_mask), dtype=tf.int64)
403
+
404
+ properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]
405
+
406
+ properties = []
407
+ outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
408
+ for i, out in enumerate(outputs):
409
+ shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
410
+ out_reshaped = tf.reshape(out, shape)
411
+ out_masked = out_reshaped * loss_slot[..., None]
412
+ properties.append(self.predict_fn(out_masked))
413
+ # out_masked[0].numpy()
414
+ number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
415
+
416
+ result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
417
+
418
+ return result
419
+
420
+
421
+ def final_predict_mask(x, mask):
422
+ r = reshape(x[0][:, rv.INDEX[0]:-1], [-1, 3])
423
+ return tf.ragged.boolean_mask(r, mask)
424
+
425
+
426
+ def final_predict(x, mode=False):
427
+ m = x[1] if mode else tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
428
+ return final_predict_mask(x[0], m)
429
+
430
+
431
+ def final_predict_2(x):
432
+ ones = tf.cast(tf.ones(tf.shape(x[0])[0]), tf.bool)[:, None]
433
+ mask = tf.concat([ones, tf.tile(x[1], [1, 4]), ones], axis=-1)
434
+ return tf.ragged.boolean_mask(x[0], mask)
435
+
436
+
437
+ class PredictModelOld(Model):
438
+
439
+ def call(self, inputs):
440
+ output = inputs[-2]
441
+
442
+ rest_output = output[:, :-rv.GROUPS_NO]
443
+
444
+ result_all = []
445
+ outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-3)
446
+ for i, out in enumerate(outputs):
447
+ shape = (-3, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
448
+ out_reshaped = tf.reshape(out, shape)
449
+
450
+ result = tf.cast(tf.argmax(out_reshaped, axis=-3), dtype="int8")
451
+ result_all.append(result)
452
+
453
+ result_all = interleave(result_all)
454
+ return result_all
455
+
456
+
457
+ def get_matches(diff, target_index):
458
+ diff_sum = K.sum(diff)
459
+ db_argsort = tf.argsort(diff_sum, axis=-1)
460
+ db_sorted = tf.sort(diff_sum)
461
+ db_mask = db_sorted[:, 0, None] == db_sorted
462
+ db_same = tf.where(db_mask, db_argsort, -1 * tf.ones_like(db_argsort))
463
+ matched_index = db_same == target_index
464
+ # setting shape needed for TensorFlow graph
465
+ matched_index.set_shape(db_same.shape)
466
+ matches = K.any(matched_index)
467
+ more_matches = K.sum(db_mask) > 1
468
+ once_matches = K.sum(matches & tf.math.logical_not(more_matches))
469
+ return matches, more_matches, once_matches
470
+
471
+
472
+ class SimilarityRaven(Model):
473
+ def __init__(self, mode=create_all_mask, number_loss=False):
474
+ super().__init__()
475
+ self.range_mask = RangeMask()
476
+ self.mode = mode
477
+
478
+ # self.predict_fn = partial(tf.argmax, axis=-1)
479
+
480
+ # INDEX, PREDICT, LABELS
481
+ def call(self, inputs):
482
+ metrics = []
483
+ target_index = inputs[0] - 8
484
+ predict = inputs[1]
485
+ answers = inputs[2][:, 8:]
486
+ shape = tf.shape(predict)
487
+
488
+ target = K.gather(answers, target_index[:, 0])
489
+
490
+ target_group = target[:, 0]
491
+
492
+ # comp_slice = np.
493
+ target_comp = target[:, 1:rv.target.END_INDEX]
494
+ predict_comp = predict[:, 1:rv.target.END_INDEX]
495
+ answers_comp = answers[:, :, 1:rv.target.END_INDEX]
496
+
497
+ full_properties_musks = self.mode(target)
498
+ fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
499
+
500
+ range_mask = self.range_mask(target_group)
501
+ full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
502
+
503
+ final_mask = fpm & full_range_mask
504
+
505
+ target_masked = target_comp * final_mask
506
+ predict_masked = predict_comp * final_mask
507
+ answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
508
+
509
+ acc_same = K.mean(K.all(target_masked == predict_masked))
510
+ self.add_metric(acc_same, ACC_SAME)
511
+ metrics.append(acc_same)
512
+
513
+ diff = tf.abs(predict_masked[:, None] - answers_masked)
514
+ diff_bool = diff != 0
515
+
516
+ matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
517
+
518
+ second_phase_mask = (more_matches & matches)
519
+ diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
520
+ target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
521
+
522
+ matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
523
+ matches_2_no = K.sum(matches_2)
524
+
525
+ acc_choose_upper = (once_matches + matches_2_no) / shape[0]
526
+ self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
527
+ metrics.append(acc_choose_upper)
528
+
529
+ acc_choose_lower = (once_matches + once_matches_2) / shape[0]
530
+ self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
531
+ metrics.append(acc_choose_lower)
532
+
533
+ return metrics
534
+
535
+
536
+ class SimilarityRaven2(Model):
537
+ def __init__(self, mode=create_all_mask, number_loss=False):
538
+ super().__init__()
539
+ self.range_mask = RangeMask()
540
+ self.mode = mode
541
+
542
+ # self.predict_fn = partial(tf.argmax, axis=-1)
543
+
544
+ # INDEX, PREDICT, LABELS
545
+ def call(self, inputs):
546
+ metrics = []
547
+ target_index = inputs[0] - 8
548
+ predict = inputs[1]
549
+ answers = inputs[2][:, 8:]
550
+ shape = tf.shape(predict)
551
+
552
+ target = K.gather(answers, target_index[:, 0])
553
+
554
+ target_group = target[:, 0]
555
+
556
+ # comp_slice = np.
557
+ target_comp = target[:, 1:rv.target.END_INDEX]
558
+ predict_comp = predict[:, 1:rv.target.END_INDEX]
559
+ answers_comp = answers[:, :, 1:rv.target.END_INDEX]
560
+
561
+ full_properties_musks = self.mode(target)
562
+ fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
563
+
564
+ range_mask = self.range_mask(target_group)
565
+ full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
566
+
567
+ final_mask = fpm & full_range_mask
568
+
569
+ target_masked = target_comp * final_mask
570
+ predict_masked = predict_comp * final_mask
571
+ answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
572
+
573
+ acc_same = K.mean(K.all(target_masked == predict_masked))
574
+ self.add_metric(acc_same, ACC_SAME)
575
+ metrics.append(acc_same)
576
+
577
+ diff = tf.abs(predict_masked[:, None] - answers_masked)
578
+ diff_bool = diff != 0
579
+
580
+ matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
581
+
582
+ second_phase_mask = (more_matches & matches)
583
+ diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
584
+ target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
585
+
586
+ matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
587
+ matches_2_no = K.sum(matches_2)
588
+
589
+ acc_choose_upper = (once_matches + matches_2_no) / shape[0]
590
+ self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
591
+ metrics.append(acc_choose_upper)
592
+
593
+ acc_choose_lower = (once_matches + once_matches_2) / shape[0]
594
+ self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
595
+ metrics.append(acc_choose_lower)
596
+
597
+ metrics.append(K.sum(target_masked != predict_masked))
598
+
599
+ return metrics
600
+
601
+
602
+ class LatentLossModel(Model):
603
+ def __init__(self, dir_=HORIZONTAL):
604
+ super().__init__()
605
+ # self.sum_metrics = []
606
+ # for i in range(8):
607
+ # self.sum_metrics.append(Sum(name=f"no_{i}"))
608
+ self.metric_fn = Accuracy(name="acc_latent")
609
+ if dir_ == VERTICAL:
610
+ self.dir = (6, 7)
611
+ else:
612
+ self.dir = (2, 5)
613
+
614
+ def call(self, inputs):
615
+ target_image = tf.reshape(inputs[0][2], [-1])
616
+ output = inputs[1]
617
+ latents = tnp.asarray(inputs[2])
618
+
619
+ target_hor = tf.concat([
620
+ latents[:, self.dir],
621
+ latents[tf.range(latents.shape[0]), target_image + 8][:, None]
622
+ ],
623
+ axis=1)
624
+
625
+ loss_hor = mse(K.stop_gradient(target_hor), output)
626
+ self.add_loss(loss_hor)
627
+
628
+ self.add_metric(self.metric_fn(inputs[3], target_image))
629
+
630
+ return loss_hor
631
+
632
+
633
+ class PredRav(Model):
634
+
635
+ def call(self, inputs):
636
+ output = inputs[0][:, -1]
637
+ answers = inputs[1][:, 8:]
638
+ return tf.argmin(tf.reduce_sum(tf.abs(output[:, None] - answers), axis=-1), axis=-1)
raven_utils/models/multi_transformer.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from functools import partial
3
+ from tensorflow.keras.layers import Lambda
4
+ from tensorflow.keras.layers import Dense
5
+ from tensorflow.keras import Input, Model
6
+ from tensorflow.python.keras import Sequential
7
+
8
+ from config.constant import TRANS
9
+ from ml_utils import filter_init
10
+ from models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
11
+ from models_utils import pmodel, DictModel, bt, INPUTS, bm, OUTPUT, LATENTS, transformer, BatchModel, get_extractor, \
12
+ build_seq_model, BUILD, build_train_list, InitialWeight
13
+ from models_utils import SumPositionEmbedding, TransformerBlock, CatPositionEmbedding, transformer, BatchInitialWeight
14
+ import models_utils.ops as K
15
+ from models_utils.image import inverse_fn
16
+ from models_utils.ops_core import IndexReshape
17
+ from models_utils.random_ import EpsilonGreedy, EpsilonSoft
18
+ from models_utils.step import StepDict
19
+
20
+
21
+ def init_weights(shape, dtype=None):
22
+ return tf.cast(K.var.image(shape=shape, pre=True), dtype=tf.float32)
23
+
24
+
25
+ def conversion(x, max_=45):
26
+ shape = tf.shape(x)
27
+ return tf.reshape(x[:, :max_], tf.stack([shape[0], 9, -1]))
28
+
29
+
30
+ def take_left(x):
31
+ return x[..., 7:8]
32
+
33
+
34
+ def take_by_index(x, i=8):
35
+ return x[..., i:i + 1]
36
+
37
+
38
+ def mix(x):
39
+ return (x[..., 7:8] + x[..., 5:6]) / 2
40
+
41
+
42
+ def empty_last(x):
43
+ return tf.zeros_like(x[..., 7:8])
44
+
45
+
46
+ class Conversion(Model):
47
+ def __init__(self):
48
+ super().__init__()
49
+ self.model = IndexReshape((0, "9", None))
50
+
51
+ def call(self, inputs):
52
+ return self.model(inputs[:, :45])
53
+
54
+
55
+ class RandomImageMask(Model):
56
+ def __init__(self, last, last_index=9):
57
+ super().__init__()
58
+ self.get_last = last
59
+ self.last_index = last_index
60
+
61
+ def call(self, inputs):
62
+ shape = tf.shape(inputs)
63
+ indexes = tf.random.uniform(shape=shape[0:1], maxval=self.last_index, dtype=tf.int32)
64
+ mask = tf.one_hot(indexes, self.last_index)[:, None, None]
65
+
66
+ return (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
67
+ (1, 1, 1, self.last_index))
68
+
69
+
70
+ # res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
71
+ # (1, 1, 1, self.last_index))
72
+
73
+
74
+ # from data_utils import ims
75
+ # for i in range(50):
76
+ # ims(res[i].numpy().swapaxes(0, 2))
77
+ # res[12].numpy()
78
+ # self.get_last(inputs).numpy()
79
+ # import tensorflow as tf
80
+ # tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
81
+ # from ml_utils import print_error
82
+ # ims(mask[0].numpy())
83
+ # print_error(lambda :ims(mask[0]))
84
+ # from models_utils import ops as K
85
+
86
+
87
+ class ImageMask(Model):
88
+ def __init__(self, last, index=8, last_index=9):
89
+ super().__init__()
90
+ self.get_last = last
91
+ self.index = index
92
+ self.last_index = last_index
93
+
94
+ def call(self, inputs):
95
+ return tf.concat([inputs[..., :8], self.get_last(inputs)], axis=-1)
96
+
97
+
98
+ class CreateGrid(Model):
99
+ def __init__(self,
100
+ no=4,
101
+ extractor="ef",
102
+ type_=3,
103
+ base="seq",
104
+ last=take_left,
105
+ epsilon=None,
106
+ pooling=None,
107
+ mask_fn=None,
108
+ model=None,
109
+ **kwargs
110
+ ):
111
+ super().__init__()
112
+ self.type_ = type_
113
+ if type_ == 9:
114
+ self.start_shape = 75
115
+ data = (224, 224, 3)
116
+ conv = lambda: Conversion()
117
+ else:
118
+ self.start_shape = 84
119
+ data = (84, 84, 3)
120
+ extractor = BUILD[base]([
121
+ BatchModel(get_extractor(data=data, model=extractor)),
122
+ lambda x: tf.transpose(x, (1, 0, 2, 3, 4))
123
+ # lambda x: tf.tile(x[:, :224, :224], (1, 1, 1, 3))
124
+ ])
125
+ conv = lambda: conversion
126
+
127
+ self.epsilon = epsilon
128
+ if mask_fn == "random":
129
+ mask_fn = RandomImageMask(last=last)
130
+ elif mask_fn is None:
131
+ mask_fn = ImageMask(last=last)
132
+
133
+ self.mask_fn = mask_fn
134
+
135
+
136
+ def call(self, inputs):
137
+ transposed = tf.image.resize(tf.transpose(inputs, (0, 2, 3, 1)), (self.start_shape, self.start_shape))
138
+ re = self.mask_fn(transposed)
139
+
140
+ # re = tf.concat([transposed[..., :8], self.get_last(transposed)], axis=-1)
141
+ if self.type_ == 9:
142
+ x = tf.transpose(re, [0, 3, 1, 2])[..., None]
143
+ x = K.create_image_grid(x, 3, 3)
144
+ x = x[:, :224, :224]
145
+ x = tf.tile(x, [1, 1, 1, 3])
146
+ else:
147
+
148
+ x = tf.stack([
149
+ re[..., :3],
150
+ re[..., 3:6],
151
+ re[..., 6:9],
152
+ ])
153
+ return self.model(x)
154
+
155
+
156
+ # self.model.layers[0](x)
157
+
158
+
159
+ def grid_transformer(
160
+ *args,
161
+ type_=9,
162
+ no=4,
163
+ extractor="ef",
164
+ loss_mode=create_uniform_mask,
165
+ output_size=10,
166
+ loss_weight=1.0,
167
+ out_layers=(1000, 1000, 1000),
168
+ pos_emd="cat",
169
+ base="seq",
170
+ inverse_image=True,
171
+ last="left",
172
+ mask_fn=None,
173
+ model=None,
174
+ trans=None,
175
+ **kwargs):
176
+
177
+ if last == "left":
178
+ last = take_left
179
+ elif last == "mix":
180
+ last = mix
181
+ elif last == "empty":
182
+ last = empty_last
183
+ elif last == "start":
184
+ last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
185
+
186
+ create_grid = CreateGrid(
187
+ type_=type_,
188
+ no=no,
189
+ extractor=extractor,
190
+ model=model,
191
+ output_size=output_size,
192
+ out_layer=out_layers,
193
+ pos_emd=pos_emd,
194
+ base=base,
195
+ last=last,
196
+ mask_fn=mask_fn,
197
+ **kwargs
198
+ )
199
+
200
+ if model is None:
201
+ trans = transformer(
202
+ extractor=extractor,
203
+ pos_emd=pos_emd,
204
+ data=data,
205
+ output_size=output_size,
206
+ out_layers=out_layer,
207
+ pooling=conv,
208
+ no=no,
209
+ base=base,
210
+ **kwargs
211
+ # **as_dict(p.trans)
212
+ )
213
+ else:
214
+ trans = trans
215
+
216
+
217
+
218
+ def get_rav_trans(
219
+ *args,
220
+ type_=9,
221
+ no=4,
222
+ extractor="ef",
223
+ loss_mode=create_uniform_mask,
224
+ output_size=10,
225
+ loss_weight=1.0,
226
+ out_layers=(1000, 1000, 1000),
227
+ pos_emd="cat",
228
+ base="seq",
229
+ inverse_image=True,
230
+ last="left",
231
+ epsilon="greedy",
232
+ epsilon_step=500,
233
+ mask_fn=None,
234
+ model=None,
235
+ loss="multi",
236
+ **kwargs):
237
+ if last == "left":
238
+ last = take_left
239
+ elif last == "mix":
240
+ last = mix
241
+ elif last == "empty":
242
+ last = empty_last
243
+ elif last == "start":
244
+ last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
245
+
246
+ trans_raven = CreateGrid(
247
+ type_=type_,
248
+ no=no,
249
+ extractor=extractor,
250
+ model=model,
251
+ output_size=output_size,
252
+ out_layer=out_layers,
253
+ pos_emd=pos_emd,
254
+ base=base,
255
+ last=last,
256
+ epsilon=epsilon,
257
+ mask_fn=mask_fn,
258
+ **kwargs
259
+ )
260
+
261
+ if loss == "single":
262
+ loss = SingleVTRavenLoss
263
+ else:
264
+ loss = VTRavenLoss
265
+
266
+ return bt(
267
+ DictModel(
268
+ Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
269
+ in_=INPUTS,
270
+ name="Body"
271
+ ),
272
+ loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
273
+ loss_wrap=False
274
+ )
raven_utils/models/raven.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ml_utils import lw, lu
2
+ from models_utils import bm, Base, res, bt, DictModel, dense_drop, drop, build_encoder, MODEL_ARCH, ListModel, short, \
3
+ dense, Flatten, Cat, CatDenseBefore, \
4
+ CatDense, CatBefore, Drop, Flat2, down, Pass, conv, Flat, Get, bs, Res, SoftBlock
5
+ from models_utils import SubClassingModel
6
+ from models_utils.config.constants import *
7
+ import numpy as np
8
+
9
+ from config.constant import *
10
+ from tensorflow.keras.layers import Dense, Activation, BatchNormalization
11
+ import tensorflow as tf
12
+
13
+ import raven_utils as rv
14
+
15
+ from models.body import create_block
16
+ from models.class_ import Merge, RavenClass
17
+ from models.head import LatentHeadModel
18
+
19
+ from models.loss import RavenLoss
20
+ from models.trans import TransModel, FullTrans
21
+ from raven_utils.const import HORIZONTAL
22
+
23
+
24
+ def raven_model(scales,
25
+ out_layers,
26
+ latent=(64, 128, 256),
27
+ output_size=None,
28
+ padding=SAME,
29
+ body_layers=1,
30
+ encoder=None,
31
+ loop=1,
32
+ model=None,
33
+ act=None,
34
+ simpler=0,
35
+ loss_mode=None,
36
+ loss_weight=0.3,
37
+ dir_=HORIZONTAL,
38
+ global_context=False,
39
+ images_no=8,
40
+ context_mul=2,
41
+ res_act="pass",
42
+ drop_latent=0,
43
+ drop_inference=0,
44
+ drop_end=0,
45
+ ga=False,
46
+ trans_norm=None,
47
+ trans_act="relu",
48
+ arch=HEAD3,
49
+ encoder_norm=False,
50
+ encoder_pool=False,
51
+ encoder_global="GM",
52
+ encoder_before=False,
53
+ tail_units=256,
54
+ tail_flatten=None,
55
+ # for now by default
56
+ tail_down="MP",
57
+ trans_no=1,
58
+ trans_score_activation=tf.nn.softmax,
59
+ block_=SoftBlock,
60
+ **kwargs):
61
+ if isinstance(latent, int):
62
+ latent = (latent, 128, 256)
63
+ scales = lw(scales)
64
+
65
+ context_size = np.array(latent) * context_mul
66
+ # context_size = latent[scales] * context_mul
67
+
68
+ # if scales == 2:
69
+ # arch = HEAD
70
+ # elif scales == 1:
71
+ # arch = HEAD2
72
+ # else:
73
+ # arch = VERY2
74
+
75
+ if encoder_pool:
76
+ strides = (1, 1)
77
+ else:
78
+ strides = (2, 2)
79
+ if not isinstance(encoder_before, tuple):
80
+ encoder_before = [encoder_before] * 3
81
+
82
+ # if trans == 1:
83
+ # trans_model = TransModel2
84
+ # else:
85
+ # trans_model = TransModel
86
+
87
+ # if scales == 3:
88
+ # head = MultiHeadModel(encoder=encoder)
89
+ arch = MODEL_ARCH[arch]
90
+ heads = []
91
+ for s in list(range(0, max(scales) + 1)):
92
+ if s in (0, 1):
93
+ if s == 0:
94
+ encoder = build_encoder(arch[:3], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
95
+ strides=strides)
96
+ else:
97
+ encoder = build_encoder(arch[3:4], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
98
+ strides=strides)
99
+ head = LatentHeadModel(
100
+ encoder=encoder,
101
+ inference_network=(
102
+ bm([
103
+ CatBefore(filters=int(context_size[s] / 8)) if encoder_before[s] else Cat(
104
+ filters=context_size[s]),
105
+ # todo activation?
106
+ Res(filters=context_size[s], padding=padding)
107
+ ] + ([drop(drop_inference)] if drop_inference else []),
108
+ name="inference")
109
+ ) if s in scales else Pass(),
110
+ stem=Base(
111
+ bm(
112
+ # ok we choose by parameters anyway
113
+ [res(filters=latent[s], padding=padding, act=act)] + (
114
+ [drop(drop_latent)] if drop_latent else [])
115
+ ),
116
+ name="stem")
117
+ )
118
+ else:
119
+ encoder = bm([
120
+ Res(),
121
+ build_encoder(arch[4:], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
122
+ strides=strides),
123
+ short(encoder_global) if encoder_global else Flatten(),
124
+ dense(latent[s])
125
+ ])
126
+ head = LatentHeadModel(
127
+ encoder=encoder,
128
+ inference_network=bm([
129
+ # todo Echeck Cat
130
+ CatDenseBefore(filters=int(context_size[s] / 8)) if encoder_before[
131
+ s] else CatDense(filters=context_size[s]),
132
+ # todo activation?
133
+ Res(model="dv2", filters=context_size[s], padding=padding)
134
+ ] + ([dense_drop(drop_inference)] if drop_inference else []),
135
+ name="inference"),
136
+ stem=Base(
137
+ bm(
138
+ # ok we choose by parameters anyway
139
+ [res(model="dv2", units=latent[s], padding=padding, act=act)] + (
140
+ [dense_drop(drop_latent)] if drop_latent else [])
141
+ ),
142
+ name="stem")
143
+ )
144
+ heads.append(head)
145
+
146
+ concat_input = [f"{LATENT}_{i}" for i, _ in enumerate(heads)] + [f"{INFERENCE}_{i}" for i, _ in enumerate(heads)]
147
+ concat_output = ["LATENTS", "INFERENCES"]
148
+
149
+ def head_concat(inputs):
150
+ latents = inputs[:len(heads)]
151
+ inferences = inputs[len(heads):]
152
+ return latents, inferences
153
+
154
+ head = ListModel([(h, (INPUTS if i == 0 else OUTPUT), [f"{LATENT}_{i}", f"{INFERENCE}_{i}", OUTPUT]) for i, h in
155
+ enumerate(heads)] + [
156
+ (head_concat, concat_input, concat_output)], out=concat_output)
157
+ # from rav_utils.raven import init_image
158
+ # a = init_image()
159
+ # head(a)
160
+
161
+ if model is None:
162
+ model = []
163
+ for i in scales:
164
+ trans_models = []
165
+ for t in range(trans_no):
166
+ trans_models.append(
167
+ bm(
168
+ [create_block(latent=latent[i], simpler=simpler, padding=padding, norm=trans_norm, act=res_act,
169
+ loop=loop, type_="dense" if i == 2 else "conv", block_=block_)] +
170
+ [Activation(trans_act)] + [
171
+ res(filters=latent[i],
172
+ padding=padding,
173
+ act=act,
174
+ name="body_out",
175
+ model="dv2" if i == 2 else "v2") for _ in
176
+ range(body_layers)] + ([Drop(drop_latent)] if drop_latent else []),
177
+ base_class=SubClassingModel)
178
+ )
179
+ trans_models = lu(trans_models)
180
+ if trans_no > 1:
181
+ trans_models = bm([
182
+ lambda x: [[x[0], x[1]], x[1]],
183
+ SoftBlock(
184
+ model=trans_models,
185
+ score_model=bm([
186
+ Flat2(filters=latent[i], units=256, res_no=2),
187
+ Dense(trans_no, trans_score_activation)
188
+ ])
189
+ )
190
+ ],
191
+ base_class=SubClassingModel
192
+ )
193
+
194
+ model.append(
195
+ TransModel(
196
+ body=trans_models,
197
+ dir_=dir_,
198
+ images_no=images_no
199
+ )
200
+ )
201
+
202
+ tail = []
203
+ for i, s in enumerate(scales):
204
+ flatting = lambda: Flat2(filters=latent[s + 1], base_class=tail_flatten, units=tail_units)
205
+ if s == 0:
206
+ if tail_flatten is None:
207
+ branch = bm([res(filters=latent[s], padding=padding),
208
+ conv(filters=latent[s], padding=padding),
209
+ BatchNormalization(),
210
+ conv(filters=latent[s], padding=padding),
211
+ Flatten()])
212
+ else:
213
+ branch = bm([down(base_class=tail_down), flatting()])
214
+ elif s == 1:
215
+ if tail_flatten is None:
216
+ branch = bm([res(filters=latent[s], padding=padding),
217
+ Flatten()])
218
+ else:
219
+ branch = flatting()
220
+ else:
221
+ branch = bm([tail_units] * 2, add_flatten=False)
222
+ tail.append(branch)
223
+
224
+ tail.append(
225
+ bm([dense(tail_units)] + ([dense_drop(drop_end)] if drop_end else []) + [Dense(output_size)], add_flatten=False,
226
+ name=TAIL))
227
+ class_input = []
228
+
229
+ return bt([
230
+ DictModel(head, in_=INPUTS, out=[LATENT, INFERENCE], name="Head"),
231
+ DictModel(FullTrans(model, scales=scales), in_=[LATENT, INFERENCE], out=TRANS, name="Body"),
232
+ DictModel(RavenClass(Merge(tail), scales=scales, no=8), in_=[LATENT] + class_input, out=CLASSIFICATION,
233
+ name="Classificator"),
234
+ DictModel(RavenClass(Merge(tail), scales=list(range(len(scales))), no=3), in_=[TRANS] + class_input,
235
+ out=OUTPUT, name="Classificator_trans"),
236
+ ],
237
+ loss=RavenLoss(mode=loss_mode, classification=True, trans=True, lw=(1.0, loss_weight)),
238
+ loss_wrap=False
239
+ )
raven_utils/models/trans.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from ml_utils import lw
3
+ from models_utils import ops as K, SubClassingModel
4
+ from tensorflow.keras import Model
5
+
6
+ from models.body import create_dense_block
7
+ import raven_utils as rv
8
+ from raven_utils.const import HORIZONTAL, VERTICAL
9
+
10
+
11
+ class TransModel(Model):
12
+ def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
13
+ super().__init__()
14
+ self.model = body or create_dense_block(latent=latent)
15
+ if dir_ == VERTICAL:
16
+ self.dir = (0, 3, 1, 4, 3, 5)
17
+ else:
18
+ self.dir = (0, 1, 3, 4, 6, 7)
19
+ self.images_no = images_no
20
+ self.latent = latent
21
+
22
+ def call(self, inputs):
23
+ # latents = tnp.asarray(inputs[0])
24
+ latents = inputs[0]
25
+ inference = inputs[1]
26
+ shape = tf.shape(latents)
27
+ new_shape = K.cat([[-1, 3, 2], shape[2:]])
28
+ horizontal = latents[:, self.dir].reshape(new_shape)
29
+ res = tf.TensorArray(tf.float32, size=3)
30
+ for i in range(3):
31
+ res = res.write(i, self.model([horizontal[:, i], inference]))
32
+ result = K.tran(res.stack())
33
+ return result
34
+
35
+
36
+ class TransModel2(Model):
37
+ def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
38
+ super().__init__()
39
+ self.body = body or create_dense_block(latent=latent)
40
+ if dir_ == VERTICAL:
41
+ self.dir = (0, 3, 1, 4, 3, 5)
42
+ else:
43
+ self.dir = (0, 1, 3, 4, 6, 7)
44
+ self.images_no = images_no
45
+ self.latent = latent
46
+
47
+ def call(self, inputs):
48
+ # latents = tnp.asarray(inputs[0])
49
+ latents = inputs[0]
50
+ inference = inputs[1]
51
+ shape = tf.shape(latents)
52
+ new_shape = K.cat([[-1, 3, 2], shape[2:]])
53
+ horizontal = latents[:, self.dir].reshape(new_shape)
54
+ res = tf.TensorArray(tf.float32, size=3)
55
+ for i in tf.range(3):
56
+ res = res.write(i, self.body([horizontal[:, i], inference[:,i]]))
57
+ result = K.tran(res.stack())
58
+ return result
59
+
60
+
61
+ class FullTrans(SubClassingModel):
62
+ def __init__(self, model,scales,name=None):
63
+ super().__init__(model=model,name=name)
64
+ self.scales = scales
65
+
66
+ def call(self, inputs):
67
+ latent = lw(inputs[0])
68
+ inference = lw(inputs[1])
69
+ results = []
70
+ # todo merging inference?
71
+ for i,s in enumerate(self.scales):
72
+ # results.append(model([latent[::-1][i], inference]))
73
+ results.append(self.model[i]([latent[s], inference[s]]))
74
+ return results,
raven_utils/models/transformer.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.layers import Lambda
3
+ from tensorflow.python.keras import Sequential
4
+
5
+ # from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
6
+ from models_utils import DictModel, bt, INPUTS, BatchInitialWeight
7
+ import models_utils.ops as K
8
+ from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
9
+ from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
10
+ from models_utils.ops_core import IndexReshape
11
+ from models_utils.random_ import EpsilonGreedy, EpsilonSoft
12
+ from models_utils.step import StepDict
13
+
14
+ # res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
15
+ # (1, 1, 1, self.last_index))
16
+
17
+
18
+ # from data_utils import ims
19
+ # for i in range(50):
20
+ # ims(res[i].numpy().swapaxes(0, 2))
21
+ # res[12].numpy()
22
+ # self.get_last(inputs).numpy()
23
+ # import tensorflow as tf
24
+ # tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
25
+ # from ml_utils import print_error
26
+ # ims(mask[0].numpy())
27
+ # print_error(lambda :ims(mask[0]))
28
+ # from models_utils import ops as K
29
+
30
+
31
+ # self.model.layers[0](x)
32
+ from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
33
+
34
+
35
+ def get_rav_trans(
36
+ data,
37
+ type_=9,
38
+ no=4,
39
+ extractor="ef",
40
+ loss_mode=create_uniform_mask,
41
+ output_size=10,
42
+ loss_weight=1.0,
43
+ out_layers=(1000, 1000, 1000),
44
+ pos_emd="cat",
45
+ base="seq",
46
+ inverse_image=True,
47
+ last="left",
48
+ epsilon="greedy",
49
+ epsilon_step=500,
50
+ mask_fn=None,
51
+ model=None,
52
+ loss="multi",
53
+ **kwargs):
54
+ if last == "left":
55
+ last = take_left
56
+ elif last == "mix":
57
+ last = mix
58
+ elif last == "empty":
59
+ last = empty_last
60
+ elif last == "start":
61
+ last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
62
+
63
+ if epsilon == "greedy":
64
+ epsilon = EpsilonGreedy(step=epsilon_step)
65
+ elif epsilon == "soft":
66
+ epsilon = EpsilonSoft(step=epsilon_step)
67
+ elif epsilon is False:
68
+ epsilon = None
69
+
70
+ if epsilon:
71
+ trans_raven = TransRavenwithStep(
72
+ type_=type_,
73
+ no=no,
74
+ extractor=extractor,
75
+ output_size=output_size,
76
+ out_layer=out_layers,
77
+ pos_emd=pos_emd,
78
+ base=base,
79
+ last=last,
80
+ epsilon=epsilon,
81
+ **kwargs
82
+ )
83
+ return StepDict(bt(
84
+ DictModel(
85
+ Sequential([Lambda(lambda x: (255 - x[0], x[1])), trans_raven]) if inverse_image else trans_raven,
86
+ in_=[INPUTS, "step"],
87
+ name="Body"
88
+ ),
89
+ loss=VTRavenLoss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
90
+ loss_wrap=False),
91
+ add_step=epsilon_step,
92
+ )
93
+
94
+ trans_raven = img_sec_trans(
95
+ type_=type_,
96
+ no=no,
97
+ extractor=extractor,
98
+ model=model,
99
+ output_size=output_size,
100
+ out_layer=out_layers,
101
+ pos_emd=pos_emd,
102
+ base=base,
103
+ last=last,
104
+ epsilon=epsilon,
105
+ mask_fn=mask_fn,
106
+ **kwargs
107
+ )
108
+ if loss == "single":
109
+ loss = SingleVTRavenLoss
110
+ else:
111
+ loss = VTRavenLoss
112
+
113
+ # return bt(
114
+ # DictModel(
115
+ # Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
116
+ # inputs=INPUTS,
117
+ # name="Body"
118
+ # ),
119
+ # loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
120
+ # loss_wrap=False
121
+ # )
122
+
123
+ return bt([
124
+ DictModel(
125
+ Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
126
+ in_=INPUTS,
127
+ name="Body"
128
+ ),
129
+
130
+ ],
131
+ loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
132
+ loss_wrap=False
133
+ )
raven_utils/models/transformer_2.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras.layers import Lambda
5
+ from tensorflow.python.keras import Sequential
6
+ from models_utils import ops as K, SubClassing
7
+ from models_utils.models.transformer import aug
8
+
9
+ # from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
10
+ from data_utils import DataGenerator, LOSS, TARGET, IMAGES
11
+ from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional_model, get_input_layer
12
+ import models_utils.ops as K
13
+ from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
14
+ from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
15
+ from models_utils.ops_core import IndexReshape
16
+ from models_utils.random_ import EpsilonGreedy, EpsilonSoft
17
+ from models_utils.step import StepDict
18
+
19
+ from models_utils.models.transformer import aug
20
+
21
+ # res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
22
+ # (1, 1, 1, self.last_index))
23
+
24
+
25
+ # from data_utils import ims
26
+ # for i in range(50):
27
+ # ims(res[i].numpy().swapaxes(0, 2))
28
+ # res[12].numpy()
29
+ # self.get_last(inputs).numpy()
30
+ # import tensorflow as tf
31
+ # tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
32
+ # from ml_utils import print_error
33
+ # ims(mask[0].numpy())
34
+ # print_error(lambda :ims(mask[0]))
35
+ # from models_utils import ops as K
36
+
37
+
38
+ # self.model.layers[0](x)
39
+ from raven_utils.constant import INDEX, LABELS
40
+ from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
41
+
42
+
43
+ def get_matrix(inputs, index):
44
+ return tf.concat([inputs[:, :8], K.gather(inputs, index[:, 0])[:, None]], axis=1)
45
+
46
+
47
+ def get_images(inputs):
48
+ return get_matrix(inputs[0], inputs[1])
49
+
50
+
51
+ def random_last(inputs, max_=8):
52
+ index = K.init.label(max=max_, shape=[tf.shape(inputs[0])[0]])[..., None]
53
+ return get_matrix(inputs[0], index)
54
+
55
+
56
+ def get_images_no_answer(inputs):
57
+ return inputs[0][:, :9]
58
+
59
+
60
+ def repeat_last(inputs):
61
+ return inputs[0][:, list(range(8)) + [7]]
62
+
63
+
64
+ def get_rav_trans(
65
+ data,
66
+ inverse_image=True,
67
+ loss_mode=create_uniform_mask,
68
+ loss_weight=1.0,
69
+ loss="multi",
70
+ number_loss=False,
71
+ plw=None,
72
+ pre="auto",
73
+ augmentation=None,
74
+ **kwargs):
75
+ if isinstance(data, DataGenerator):
76
+ data = data[0]['inputs'], data[0]['index']
77
+ # u = img_sec_trans(**kwargs)(get_images(data) if kwargs['mask'] == "random" else get_images_no_answer(data))
78
+ # u.shape
79
+ from keras import Model
80
+ if pre == "auto":
81
+ pre = get_images if kwargs['mask'] == "random" else get_images_no_answer
82
+ elif pre == "no_answer":
83
+ pre = get_images_no_answer
84
+ elif pre == "last":
85
+ pre = repeat_last
86
+ elif pre == "images":
87
+ pre = get_images
88
+ elif pre == "random_last":
89
+ pre = random_last
90
+ elif pre == "noise":
91
+ pre = SubClassing([get_matrix, partial(aug.noise, max_=8)])
92
+ elif pre == "batch_noise":
93
+ pre = SubClassing([get_matrix, partial(aug.batch_noise, max_=8)])
94
+
95
+ if augmentation == "transpose":
96
+ augmentation = aug.Transpose(axis=(0, 2, 1))
97
+ augmentation_label = aug.Transpose(axis=(0, 2, 1))
98
+ elif augmentation == "shuffle_col":
99
+ augmentation = aug.shuffle_col
100
+ augmentation_label = aug.shuffle_col
101
+ elif augmentation == "shuffle":
102
+ augmentation = aug.shuffle
103
+ augmentation_label = aug.shuffle
104
+ if augmentation:
105
+ augmentation = [
106
+ # DictModel(augmentation, IMAGES, IMAGES),
107
+ # DictModel(aug.reshape_static(pre(data),augmentation), IMAGES, IMAGES),
108
+ DictModel(aug.ReshapeStatic(augmentation), IMAGES, IMAGES),
109
+ DictModel(
110
+ aug.PartialModel(
111
+ aug.ReshapeStatic(augmentation_label),
112
+ last_axis=9)
113
+ , LABELS, LABELS)
114
+ ]
115
+ else:
116
+ augmentation = []
117
+
118
+ trans_raven = build_functional_model(
119
+ img_sec_trans(**kwargs),
120
+ # get_images(data) if kwargs['mask'] == "random" else get_images_no_answer(data)
121
+ pre(data)
122
+ # data[0]
123
+ )
124
+ if loss == "single":
125
+ loss = SingleVTRavenLoss
126
+ else:
127
+ loss = VTRavenLoss
128
+ if isinstance(loss_weight, float):
129
+ loss_weight = (loss_weight, 1.0)
130
+
131
+ return bt([
132
+ # DictModel(get_images if kwargs['mask'] == "random" else get_images_no_answer, [INPUTS, INDEX], IMAGES),
133
+ DictModel(pre, [INPUTS, INDEX], IMAGES),
134
+ *augmentation,
135
+ DictModel(
136
+ Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
137
+ in_=IMAGES,
138
+ # inputs=INPUTS,
139
+ name="Body"
140
+ ),
141
+
142
+ ],
143
+ loss=loss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
144
+ predict=LOSS,
145
+ loss_wrap=False
146
+ )
raven_utils/models/transformer_3.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from loguru import logger
4
+ from tensorflow.keras.layers import Lambda
5
+ from tensorflow.keras.layers import Activation
6
+
7
+ from grid_transformer import aug_trans
8
+ from raven_utils.models.loss_3 import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
9
+ from data_utils import get_shape, TakeDict
10
+
11
+ from data_utils import DataGenerator, LOSS, TARGET, IMAGES
12
+ from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional, get_input_layer, Last, bm, \
13
+ add_end, AUGMENTATION
14
+ # from report.select_ import SelectModel2, SelectModel, SelectModel9
15
+ from experiment_utils.keras_model import load_weights as model_load_weights
16
+
17
+
18
+ def get_rav_trans(
19
+ data,
20
+ loss_mode=create_uniform_mask,
21
+ loss_weight=2.0,
22
+ number_loss=False,
23
+ dry_run="auto",
24
+ plw=None,
25
+ **kwargs):
26
+ if isinstance(loss_weight, float):
27
+ loss_weight = (loss_weight, 1.0)
28
+
29
+ # seq_trans(**kwargs)(data[0])
30
+ # trans_raven = build_functional_model2(
31
+ # seq_trans(**kwargs),
32
+ # data[0],
33
+ # batch=None
34
+ # )
35
+ trans_raven = build_functional(
36
+ model=aug_trans,
37
+ inputs_=data[0] if isinstance(data, DataGenerator) else data,
38
+ batch_=None,
39
+ dry_run=dry_run,
40
+ **kwargs
41
+ )
42
+
43
+ return bt(
44
+ model=trans_raven,
45
+ loss=VTRavenLoss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
46
+ model_wrap=False,
47
+ predict=LOSS,
48
+ loss_wrap=False
49
+ )
50
+
51
+
52
+ def rav_select_model(
53
+ data,
54
+ load_weights=None,
55
+ loss_weight=(0.01, 0.0),
56
+ plw=5.0,
57
+ result_metric="sparse_categorical_accuracy",
58
+ select_type=2,
59
+ select_out=0,
60
+ additional_out=0,
61
+ additional_copy=True,
62
+ tail_out=(1000, 1000),
63
+ **kwargs
64
+ ):
65
+ out_layers = Last()
66
+ if additional_out > 0:
67
+ model3 = get_rav_trans(
68
+ data,
69
+ plw=plw,
70
+ loss_weight=loss_weight,
71
+ **kwargs
72
+ )
73
+
74
+ model_load_weights(
75
+ model3,
76
+ load_weights,
77
+ # sample_data,
78
+ None,
79
+ template="weights_{epoch:02d}-{val_loss:.2f}",
80
+ key=result_metric,
81
+ )
82
+
83
+ if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
84
+ index = -1
85
+ else:
86
+ index = -2
87
+
88
+ out = model3[0, index, :additional_out]
89
+ logger.info(f"Additional out from: {model3[0, index]}.")
90
+
91
+ if additional_out > 2:
92
+ out += [Activation("gelu")]
93
+ out_layers = bm([out_layers] + out, add_flatten=False)
94
+ model = get_rav_trans(
95
+ TakeDict(data[0])[:, 8:],
96
+ plw=plw,
97
+ loss_weight=loss_weight,
98
+ **{
99
+ **kwargs,
100
+ "out_layers": out_layers,
101
+ }
102
+ # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
103
+ )
104
+ # from data_utils.ops import Equal
105
+ # o = []
106
+ # for i in range(1, 3):
107
+ # for j in range(2):
108
+ # o.append(
109
+ # Equal(
110
+ # # model[0,:,-2, i].variables[j],
111
+ # model2[0, :, -2, i].variables[j],
112
+ # # out_layers[i].variables[j]
113
+ # second_pooling[i].variables[j]
114
+ # ).equal
115
+ # )
116
+ # assert all(o)
117
+ # model = get_rav_trans(
118
+ # # TakeDict(val_generator[0])[:, 8:],
119
+ # # TakeDict(val_generator[0])[:, 8:],
120
+ # val_generator[0],
121
+ # plw=p.plw,
122
+ # loss_weight=p.loss_weight,
123
+ # **{**as_dict(p.mp),
124
+ # # "out_layers": out_layers,
125
+ # }
126
+ # # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
127
+ # )
128
+ model_load_weights(model,
129
+ load_weights,
130
+ # sample_data,
131
+ None,
132
+ template="weights_{epoch:02d}-{val_loss:.2f}",
133
+ key=result_metric,
134
+ )
135
+ # model.compile()
136
+ # model.evaluate(val_generator.data[:1000])
137
+ # model(TakeDict(val_generator[0])[:, 8:])
138
+ trans_raven = model[0]
139
+ # s = trans_raven(TakeDict(val_generator[0])[:, 8:])
140
+ if select_type == 2:
141
+ second_pooling = Lambda(lambda x: x[:, :-1])
142
+ else:
143
+ second_pooling = Last()
144
+ if additional_out > 0:
145
+ if additional_copy:
146
+ model4 = get_rav_trans(
147
+ data,
148
+ plw=plw,
149
+ loss_weight=loss_weight,
150
+ **kwargs
151
+ )
152
+ model_load_weights(model4,
153
+ load_weights,
154
+ # sample_data,
155
+ None,
156
+ template="weights_{epoch:02d}-{val_loss:.2f}",
157
+ key=result_metric,
158
+ )
159
+
160
+ if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
161
+ index = -1
162
+ else:
163
+ index = -2
164
+ out2 = model4[0, index, :additional_out]
165
+ logger.info(f"Additional out from: {model4[0, index]}.")
166
+
167
+ if additional_out > 2:
168
+ out2 += [Activation("gelu")]
169
+ else:
170
+ out2 = out
171
+
172
+ second_pooling = bm([second_pooling] + out2, add_flatten=False)
173
+
174
+ model2 = get_rav_trans(
175
+ TakeDict(data[0])[:, 8:],
176
+ plw=plw,
177
+ loss_weight=loss_weight,
178
+ **{
179
+ **kwargs,
180
+ "out_layers": second_pooling,
181
+ }
182
+ # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
183
+ )
184
+ model_load_weights(
185
+ model2,
186
+ load_weights,
187
+ # sample_data,
188
+ None,
189
+ template="weights_{epoch:02d}-{val_loss:.2f}",
190
+ key=result_metric,
191
+ )
192
+ if select_type == 0:
193
+ # not working
194
+ trans_raven2 = model2[0]
195
+ else:
196
+ trans_raven2 = model2[0]
197
+ tail = add_end(out_layers=tail_out, output_size=8 if select_out else 1)
198
+ # trans_raven2.mask_fn = ImageMask(last=take_by_index)
199
+ if select_type == 2:
200
+ select_model_class = SelectModel2
201
+ elif select_type == 1:
202
+ select_model_class = SelectModel
203
+ else:
204
+ select_model_class = SelectModel9
205
+ select_model = select_model_class(trans_raven, model2=trans_raven2, tail=tail, select_out=select_out)
206
+ return select_model
raven_utils/models/uitls_.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow.experimental.numpy as tnp
3
+ from tensorflow.keras import Model
4
+ import raven_utils as rv
5
+
6
+
7
+ class RangeMask(Model):
8
+ def __init__(self):
9
+ super().__init__()
10
+ ranges = tf.tile(tf.range(rv.entity.INDEX[-1])[None], [rv.group.NO, 1])
11
+ start_index = rv.entity.INDEX[:-1][:, None]
12
+ end_index = rv.entity.INDEX[1:][:, None]
13
+ self.mask = tnp.array((start_index <= ranges) & (ranges < end_index))
14
+
15
+ def call(self, inputs):
16
+ return self.mask[inputs]
raven_utils/output.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import raven_utils.entity as entity
3
+ import raven_utils.properties as properties
4
+ import raven_utils.group as group
5
+
6
+ SIZE = entity.SUM * properties.SUM + group.NO + entity.SUM
7
+
8
+ SLOT_AND_GROUP = group.NO + entity.SUM
9
+
10
+ PROPERTIES_SLICE = np.s_[:, :-SLOT_AND_GROUP]
11
+ SLOT_SLICE = np.s_[:, -SLOT_AND_GROUP:-group.NO]
12
+ GROUP_SLICE = np.s_[:, -group.NO:]
13
+
14
+ GROUP_SLICE_END = np.s_[-group.NO:]
15
+ SLOT_SLICE_END = np.s_[-SLOT_AND_GROUP:-group.NO]
16
+ PROPERTIES_SLICE_END = np.s_[:-SLOT_AND_GROUP]
raven_utils/params.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Tuple
3
+
4
+ from ml_utils import get_str_name
5
+ from grid_transformer.params import ImgSeqTransformerParameters
6
+ from raven_utils import output
7
+
8
+ from experiment_utils.parameters.nn_default import TP, EP
9
+
10
+
11
+ @dataclass
12
+ class SudokuParameters(ImgSeqTransformerParameters):
13
+ mask: str = "input"
14
+ col: int = 3
15
+ row: int = 3
16
+ pooling: int = 81
17
+ output_size: int = 9
18
+ size: int = 384
19
+
20
+
21
+ @dataclass
22
+ class RavenTransParameters(ImgSeqTransformerParameters):
23
+ mask: str = "last"
24
+ last_index: int = 8
25
+ col: int = 1
26
+ row: int = 1
27
+ output_size: int = output.SIZE
28
+ number_loss: bool = 0
29
+ pre: str = "images"
30
+ num_heads: int = 8
31
+
32
+
33
+ MP = RavenTransParameters
34
+
35
+
36
+ @dataclass
37
+ class RavenSelectTransParameters(RavenTransParameters):
38
+ select_type: int = 2
39
+ select_out: int = 0
40
+ additional_out: int = 0
41
+ additional_copy: bool = True
42
+ tail_out: Tuple = (1000, 1000)
43
+ pre: str = "index"
44
+
45
+
46
+ SMP = RavenSelectTransParameters
47
+
48
+ from raven_utils.config.models import AVAILABLE_MODELS
49
+ from experiment_utils.parameters.nn_clean import Parameters as BaseParameters
50
+ from raven_utils.config.constant import RAVEN, LABELS, INDEX, FEATURES, RAV_METRICS, IMP_RAV_METRICS, ACC_NO_GROUP, \
51
+ ACC_SAME
52
+
53
+ MODEL_NO = -1
54
+
55
+
56
+ @dataclass
57
+ class RavenParameters(BaseParameters):
58
+ dataset_name: str = RAVEN
59
+ data: Any = (
60
+ f"{dataset_name}/train.npy",
61
+ f"{dataset_name}/val.npy",
62
+ f"{dataset_name}/train_labels.npy",
63
+ f"{dataset_name}/val_labels.npy",
64
+ f"{dataset_name}/train_target.npy",
65
+ f"{dataset_name}/val_target.npy",
66
+ f"arr/train_features_{AVAILABLE_MODELS[MODEL_NO]}.npy",
67
+ f"arr/val_features_{AVAILABLE_MODELS[MODEL_NO]}.npy",
68
+ f"{dataset_name}/val_index.npy"
69
+ # DataParameters2()
70
+ )
71
+
72
+ # core_metrics: tuple = tuple(RAV_METRICS)
73
+ filter_metrics: tuple = tuple(IMP_RAV_METRICS)
74
+ # result_metric: str = ACC_NO_GROUP
75
+ result_metric: str = ACC_SAME
76
+
77
+ lw: float = 0.0001 # Autoencoder
78
+ loss_weight: float = 2.0
79
+ plw: int = 5.0
80
+ mp: RavenTransParameters = RavenTransParameters()
81
+
82
+ @property
83
+ def experiment(self):
84
+ # return "rav/trans"
85
+ return "rav/best_test3"
86
+ # return "rav/trans_weight"
87
+
88
+ # @property
89
+ # def name(self):
90
+ # # return f"i{self.extractor}_{len(self.tail)}{self.tail[0]}_{self.type_}_{self.epsilon}_{self.last}_{self.epsilon_step}"
91
+ # return f"{get_str_name(self.mp.pre)[0]}_{str(self.plw)[0]}_{str(self.mp.number_loss)[0]}_{self.mp.extractor}_{self.mp.noise if self.mp.noise else ''}_{self.mp.augmentation if self.mp.augmentation else ''}_{self.mp.extractor_shape}_{self.mp.no}_{self.mp.num_heads}_{self.mp.size}_{self.mp.pos_emd}_{self.mp.ff_mul}_{self.tp.batch}"
92
+
93
+
94
+ @dataclass
95
+ class BaselineRavenParameters(RavenParameters):
96
+
97
+ @property
98
+ def experiment(self):
99
+ # return "rav/best_test3"
100
+ return "rav/baseline"
101
+ # return "rav/trans_weight"
102
+
103
+ @property
104
+ def name(self):
105
+ # return f"i{self.extractor}_{len(self.tail)}{self.tail[0]}_{self.type_}_{self.epsilon}_{self.last}_{self.epsilon_step}"
106
+ return f"{get_str_name(self.mp.pre)[0]}_{str(self.plw)[0]}_{str(self.mp.number_loss)[0]}_{self.mp.extractor}_{self.mp.noise if self.mp.noise else ''}_{self.mp.augmentation if self.mp.augmentation else ''}_{self.mp.extractor_shape}_{self.mp.no}_{self.mp.num_heads}_{self.mp.size}_{self.mp.pos_emd}_{self.mp.ff_mul}_{self.tp.batch}"
107
+
108
+
109
+ if __name__ == '__main__':
110
+ params = PreRavenTransParameters()
raven_utils/properties.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import raven_utils as rv
2
+ from ml_utils import dict_from_list2, CalcDict
3
+ import raven_utils.entity as entity
4
+
5
+ NAMES = [
6
+ 'Color',
7
+ 'Size',
8
+ 'Type',
9
+ ]
10
+ RAW_SIZE = [10, 6, 5]
11
+ SIZE = dict_from_list2(NAMES, RAW_SIZE)
12
+ ANGLE_SIZE = 7
13
+ NO = len(NAMES)
14
+
15
+ INDEX = (CalcDict(SIZE) * entity.SUM).to_dict()
16
+ SUM = sum(list(SIZE.values()))
raven_utils/range_mask.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow.experimental.numpy as tnp
3
+ from tensorflow.keras import Model
4
+ import raven_utils as rv
5
+
6
+
7
+ class RangeMask(Model):
8
+ def __init__(self):
9
+ super().__init__()
10
+ ranges = tf.tile(tf.range(rv.entity.INDEX[-1])[None], [rv.group.NO, 1])
11
+ start_index = rv.entity.INDEX[:-1][:, None]
12
+ end_index = rv.entity.INDEX[1:][:, None]
13
+ self.mask = tnp.array((start_index <= ranges) & (ranges < end_index))
14
+
15
+ def call(self, inputs):
16
+ return self.mask[inputs]
raven_utils/render/__init__.py ADDED
File without changes
raven_utils/render/const.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+
4
+ # Maximum number of components in a RPM
5
+ MAX_COMPONENTS = 2
6
+
7
+ # Canvas parameters
8
+ IMAGE_SIZE = 160
9
+ CENTER = (IMAGE_SIZE / 2, IMAGE_SIZE / 2)
10
+ DEFAULT_RADIUS = IMAGE_SIZE / 4
11
+ DEFAULT_WIDTH = 2
12
+
13
+ # Attribute parameters
14
+ # Number
15
+ NUM_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
16
+ NUM_MIN = 0
17
+ NUM_MAX = len(NUM_VALUES) - 1
18
+
19
+ # Uniformity
20
+ UNI_VALUES = [False, False, False, True]
21
+ UNI_MIN = 0
22
+ UNI_MAX = len(UNI_VALUES) - 1
23
+
24
+ # Type
25
+ TYPE_VALUES = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
26
+ TYPE_MIN = 0
27
+ TYPE_MAX = len(TYPE_VALUES) - 1
28
+
29
+ # Size
30
+ SIZE_VALUES = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
31
+ SIZE_MIN = 0
32
+ SIZE_MAX = len(SIZE_VALUES) - 1
33
+
34
+ # Color
35
+ COLOR_VALUES = [255, 224, 196, 168, 140, 112, 84, 56, 28, 0]
36
+ COLOR_MIN = 0
37
+ COLOR_MAX = len(COLOR_VALUES) - 1
38
+
39
+ # Angle: self-rotation
40
+ ANGLE_VALUES = [-135, -90, -45, 0, 45, 90, 135, 180]
41
+ ANGLE_MIN = 0
42
+ ANGLE_MAX = len(ANGLE_VALUES) - 1
43
+
44
+ META_TARGET_FORMAT = ["Constant", "Progression", "Arithmetic", "Distribute_Three", "Number", "Position", "Type", "Size", "Color"]
45
+ META_STRUCTURE_FORMAT = ["Singleton", "Left_Right", "Up_Down", "Out_In", "Left", "Right", "Up", "Down", "Out", "In", "Grid", "Center_Single", "Distribute_Four", "Distribute_Nine", "Left_Center_Single", "Right_Center_Single", "Up_Center_Single", "Down_Center_Single", "Out_Center_Single", "In_Center_Single", "In_Distribute_Four"]
46
+
47
+ # Rule, Attr, Param
48
+ # The design encodes rule priority order: Number/Position always comes first
49
+ # Number and Position could not both be sampled
50
+ # Progression on Number: Number on each Panel +1/2 or -1/2
51
+ # Progression on Position: Entities on each Panel roll over the layout
52
+ # Arithmetic on Number: Numeber on the third Panel = Number on first +/- Number on second (1 for + and -1 for -)
53
+ # Arithmetic on Position: 1 for SET_UNION and -1 for SET_DIFF
54
+ # Distribute_Three on Number: Three numbers through each row
55
+ # Distribute_Three on Position: Three positions (same number) through each row
56
+ # Constant on Number/Position: Nothing changes
57
+ # Progression on Type: Type progression defined as the number of edges on each entity (Triangle, Square, Pentagon, Hexagon, Circle)
58
+ # Distribute_Three on Type: Three types through each row
59
+ # Constant on Type: Nothing changes
60
+ # Progression on Size: Size on each entity +1/2 or -1/2
61
+ # Arithmetic on Size: Size on the third Panel = Size on the first +/- Size on the second (1 for + and -1 for -)
62
+ # Distribute_Three on Size: Three sizes through each row
63
+ # Constant on Size: Nothing changes
64
+ # Progression on Color: Color +1/2 or -1/2
65
+ # Arithmetic on Color: Color on the third Panel = Color on the first +/- Color on the second (1 for + and -1 for -)
66
+ # Distribute_Three on Color: Three colors through each row
67
+ # Constant on Color: Nothing changes
68
+ # Note that all rules on Type, Size and Color enforce value consistency in a panel
69
+ RULE_ATTR = [[["Progression", "Number", [-2, -1, 1, 2]],
70
+ ["Progression", "Position", [-2, -1, 1, 2]],
71
+ ["Arithmetic", "Number", [1, -1]],
72
+ ["Arithmetic", "Position", [1, -1]],
73
+ ["Distribute_Three", "Number", None],
74
+ ["Distribute_Three", "Position", None],
75
+ ["Constant", "Number/Position", None]],
76
+ [["Progression", "Type", [-2, -1, 1, 2]],
77
+ ["Distribute_Three", "Type", None],
78
+ ["Constant", "Type", None]],
79
+ [["Progression", "Size", [-2, -1, 1, 2]],
80
+ ["Arithmetic", "Size", [1, -1]],
81
+ ["Distribute_Three", "Size", None],
82
+ ["Constant", "Size", None]],
83
+ [["Progression", "Color", [-2, -1, 1, 2]],
84
+ ["Arithmetic", "Color", [1, -1]],
85
+ ["Distribute_Three", "Color", None],
86
+ ["Constant", "Color", None]]]
raven_utils/render/rendering.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ #
8
+ # from AoT import Root
9
+ import raven_utils.decode
10
+ from raven_utils.render.const import CENTER, DEFAULT_WIDTH, IMAGE_SIZE
11
+
12
+ from data_utils import Bag
13
+
14
+ from raven_utils.render_ import COLOR_VALUES, SIZE_VALUES, TYPE_VALUES, ANGLE_VALUES, RENDER_POSITIONS
15
+
16
+
17
+ def imshow(array):
18
+ image = Image.fromarray(array)
19
+ image.show()
20
+
21
+
22
+ def imsave(array, filepath):
23
+ image = Image.fromarray(array)
24
+ image.save(filepath)
25
+
26
+
27
+ def generate_matrix(array_list):
28
+ # row-major array_list
29
+ assert len(array_list) <= 9
30
+ img_grid = np.zeros((IMAGE_SIZE * 3, IMAGE_SIZE * 3), np.uint8)
31
+ for idx in range(len(array_list)):
32
+ i, j = divmod(idx, 3)
33
+ img_grid[i * IMAGE_SIZE:(i + 1) * IMAGE_SIZE, j * IMAGE_SIZE:(j + 1) * IMAGE_SIZE] = array_list[idx]
34
+ # draw grid
35
+ for x in [0.33, 0.67]:
36
+ img_grid[int(x * IMAGE_SIZE * 3) - 1:int(x * IMAGE_SIZE * 3) + 1, :] = 0
37
+ for y in [0.33, 0.67]:
38
+ img_grid[:, int(y * IMAGE_SIZE * 3) - 1:int(y * IMAGE_SIZE * 3) + 1] = 0
39
+ return img_grid
40
+
41
+
42
+ def generate_answers(array_list):
43
+ assert len(array_list) <= 8
44
+ img_grid = np.zeros((IMAGE_SIZE * 2, IMAGE_SIZE * 4), np.uint8)
45
+ for idx in range(len(array_list)):
46
+ i, j = divmod(idx, 4)
47
+ img_grid[i * IMAGE_SIZE:(i + 1) * IMAGE_SIZE, j * IMAGE_SIZE:(j + 1) * IMAGE_SIZE] = array_list[idx]
48
+ # draw grid
49
+ for x in [0.5]:
50
+ img_grid[int(x * IMAGE_SIZE * 2) - 1:int(x * IMAGE_SIZE * 2) + 1, :] = 0
51
+ for y in [0.25, 0.5, 0.75]:
52
+ img_grid[:, int(y * IMAGE_SIZE * 4) - 1:int(y * IMAGE_SIZE * 4) + 1] = 0
53
+ return img_grid
54
+
55
+
56
+ def generate_matrix_answer(array_list):
57
+ # row-major array_list
58
+ assert len(array_list) <= 18
59
+ img_grid = np.zeros((IMAGE_SIZE * 6, IMAGE_SIZE * 3), np.uint8)
60
+ for idx in range(len(array_list)):
61
+ i, j = divmod(idx, 3)
62
+ img_grid[i * IMAGE_SIZE:(i + 1) * IMAGE_SIZE, j * IMAGE_SIZE:(j + 1) * IMAGE_SIZE] = array_list[idx]
63
+ # draw grid
64
+ for x in [0.33, 0.67, 1.00, 1.33, 1.67]:
65
+ img_grid[int(x * IMAGE_SIZE * 3), :] = 0
66
+ for y in [0.33, 0.67]:
67
+ img_grid[:, int(y * IMAGE_SIZE * 3)] = 0
68
+ return img_grid
69
+
70
+
71
+ def merge_matrix_answer(matrix, answer):
72
+ matrix_image = generate_matrix(matrix)
73
+ answer_image = generate_answers(answer)
74
+ img_grid = np.ones((IMAGE_SIZE * 5 + 20, IMAGE_SIZE * 4), np.uint8) * 255
75
+ img_grid[:IMAGE_SIZE * 3, int(0.5 * IMAGE_SIZE):int(3.5 * IMAGE_SIZE)] = matrix_image
76
+ img_grid[-(IMAGE_SIZE * 2):, :] = answer_image
77
+ return img_grid
78
+
79
+
80
+ def render_panels(feature, target=True,angle=None):
81
+ # Decompose the panel into a structure and its entities
82
+ # root
83
+ # rv.decode_output(root)
84
+ # rv.decode_output_reshape(root)
85
+ # decoded =
86
+ # panel = decoded[0]
87
+ panels = []
88
+ for group, exist, color, size, type_ in Bag(raven_utils.decode.decode_target_flat(feature)):
89
+ canvas = np.ones((IMAGE_SIZE, IMAGE_SIZE), np.uint8) * 255
90
+ structure_img = render_structure(group)
91
+ background = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
92
+ # note left components entities are in the lower layer
93
+ for i, entity in enumerate(exist):
94
+ if entity:
95
+ entity_img = render_entity(RENDER_POSITIONS[i], color[i], size[i], type_[i] + 1, angle=angle)
96
+ background = layer_add(background, entity_img)
97
+ background = layer_add(background, structure_img)
98
+ panels.append(canvas - background)
99
+ return np.stack(panels)
100
+
101
+
102
+ def render_structure(structure):
103
+ if structure == 5:
104
+ ret = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
105
+ ret[:, int(0.5 * IMAGE_SIZE)] = 255.0
106
+ elif structure == 6:
107
+ ret = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
108
+ ret[int(0.5 * IMAGE_SIZE), :] = 255.0
109
+ else:
110
+ ret = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
111
+ return ret
112
+
113
+
114
+ def render_entity(bbox, color, size, type_, angle=None):
115
+ color = COLOR_VALUES[color]
116
+ size = SIZE_VALUES[size]
117
+ type_ = TYPE_VALUES[type_]
118
+ if angle is None:
119
+ angle = np.random.randint(0, 7, 1)[0]
120
+ angle = ANGLE_VALUES[angle]
121
+ img = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
122
+
123
+ # planar position: [x, y, w, h]
124
+ # angular position: [x, y, w, h, x_c, y_c, omega]
125
+ # center: (columns, rows)
126
+ center = (int(bbox[1] * IMAGE_SIZE), int(bbox[0] * IMAGE_SIZE))
127
+ if type_ == "triangle":
128
+ unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
129
+ dl = int(unit * size)
130
+ pts = np.array([[center[0], center[1] - dl],
131
+ [center[0] + int(dl / 2.0 * np.sqrt(3)), center[1] + int(dl / 2.0)],
132
+ [center[0] - int(dl / 2.0 * np.sqrt(3)), center[1] + int(dl / 2.0)]],
133
+ np.int32)
134
+ pts = pts.reshape((-1, 1, 2))
135
+ color = 255 - color
136
+ width = DEFAULT_WIDTH
137
+ draw_triangle(img, pts, color, width)
138
+ elif type_ == "square":
139
+ unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
140
+ dl = int(unit / 2 * np.sqrt(2) * size)
141
+ pt1 = (center[0] - dl, center[1] - dl)
142
+ pt2 = (center[0] + dl, center[1] + dl)
143
+ color = 255 - color
144
+ width = DEFAULT_WIDTH
145
+ draw_square(img, pt1, pt2, color, width)
146
+ elif type_ == "pentagon":
147
+ unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
148
+ dl = int(unit * size)
149
+ pts = np.array([[center[0], center[1] - dl],
150
+ [center[0] - int(dl * np.cos(np.pi / 10)), center[1] - int(dl * np.sin(np.pi / 10))],
151
+ [center[0] - int(dl * np.sin(np.pi / 5)), center[1] + int(dl * np.cos(np.pi / 5))],
152
+ [center[0] + int(dl * np.sin(np.pi / 5)), center[1] + int(dl * np.cos(np.pi / 5))],
153
+ [center[0] + int(dl * np.cos(np.pi / 10)), center[1] - int(dl * np.sin(np.pi / 10))]],
154
+ np.int32)
155
+ pts = pts.reshape((-1, 1, 2))
156
+ color = 255 - color
157
+ width = DEFAULT_WIDTH
158
+ draw_pentagon(img, pts, color, width)
159
+ elif type_ == "hexagon":
160
+ unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
161
+ dl = int(unit * size)
162
+ pts = np.array([[center[0], center[1] - dl],
163
+ [center[0] - int(dl / 2.0 * np.sqrt(3)), center[1] - int(dl / 2.0)],
164
+ [center[0] - int(dl / 2.0 * np.sqrt(3)), center[1] + int(dl / 2.0)],
165
+ [center[0], center[1] + dl],
166
+ [center[0] + int(dl / 2.0 * np.sqrt(3)), center[1] + int(dl / 2.0)],
167
+ [center[0] + int(dl / 2.0 * np.sqrt(3)), center[1] - int(dl / 2.0)]],
168
+ np.int32)
169
+ pts = pts.reshape((-1, 1, 2))
170
+ color = 255 - color
171
+ width = DEFAULT_WIDTH
172
+ draw_hexagon(img, pts, color, width)
173
+ elif type_ == "circle":
174
+ # Minus because of the way we show the image. See: render_panel's return
175
+ color = 255 - color
176
+ unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
177
+ radius = int(unit * size)
178
+ width = DEFAULT_WIDTH
179
+ draw_circle(img, center, radius, color, width)
180
+ elif type_ == "none":
181
+ pass
182
+ # angular
183
+ if len(bbox) > 4:
184
+ # [x, y, w, h, x_c, y_c, omega]
185
+ angle = bbox[6]
186
+ center = (int(bbox[5] * IMAGE_SIZE), int(bbox[4] * IMAGE_SIZE))
187
+ img = rotate(img, angle, center=center)
188
+ # planar
189
+ else:
190
+ img = rotate(img, angle, center=center)
191
+ # img = shift(img, *entity_position)
192
+
193
+ return img
194
+
195
+
196
+ def shift(img, dx, dy):
197
+ M = np.array([[1, 0, dx], [0, 1, dy]], np.float32)
198
+ img = cv2.warpAffine(img, M, (IMAGE_SIZE, IMAGE_SIZE), flags=cv2.INTER_LINEAR)
199
+ return img
200
+
201
+
202
+ def rotate(img, angle, center=CENTER):
203
+ M = cv2.getRotationMatrix2D(center, angle, 1)
204
+ img = cv2.warpAffine(img, M, (IMAGE_SIZE, IMAGE_SIZE), flags=cv2.INTER_LINEAR)
205
+ return img
206
+
207
+
208
+ def scale(img, tx, ty, center=CENTER):
209
+ M = np.array([[tx, 0, center[0] * (1 - tx)], [0, ty, center[1] * (1 - ty)]], np.float32)
210
+ img = cv2.warpAffine(img, M, (IMAGE_SIZE, IMAGE_SIZE), flags=cv2.INTER_LINEAR)
211
+ return img
212
+
213
+
214
+ def layer_add(lower_layer_np, higher_layer_np):
215
+ # higher_layer_np is superimposed on lower_layer_np
216
+ # new_np = lower_layer_np.copy()
217
+ # lower_layer_np is modified
218
+ lower_layer_np[higher_layer_np > 0] = 0
219
+ return lower_layer_np + higher_layer_np
220
+
221
+
222
+ # Draw primitives
223
+ def draw_triangle(img, pts, color, width):
224
+ # if filled
225
+ if color != 0:
226
+ # fill the interior
227
+ cv2.fillConvexPoly(img, pts, color)
228
+ # draw the edge
229
+ cv2.polylines(img, [pts], True, 255, width)
230
+ # if not filled
231
+ else:
232
+ cv2.polylines(img, [pts], True, 255, width)
233
+
234
+
235
+ def draw_square(img, pt1, pt2, color, width):
236
+ # if filled
237
+ if color != 0:
238
+ # fill the interior
239
+ cv2.rectangle(img,
240
+ pt1,
241
+ pt2,
242
+ color,
243
+ -1)
244
+ # draw the edge
245
+ cv2.rectangle(img,
246
+ pt1,
247
+ pt2,
248
+ 255,
249
+ width)
250
+ # if not filled
251
+ else:
252
+ cv2.rectangle(img,
253
+ pt1,
254
+ pt2,
255
+ 255,
256
+ width)
257
+
258
+
259
+ def draw_pentagon(img, pts, color, width):
260
+ # if filled
261
+ if color != 0:
262
+ # fill the interior
263
+ cv2.fillConvexPoly(img, pts, color)
264
+ # draw the edge
265
+ cv2.polylines(img, [pts], True, 255, width)
266
+ # if not filled
267
+ else:
268
+ cv2.polylines(img, [pts], True, 255, width)
269
+
270
+
271
+ def draw_hexagon(img, pts, color, width):
272
+ # if filled
273
+ if color != 0:
274
+ # fill the interior
275
+ cv2.fillConvexPoly(img, pts, color)
276
+ # draw the edge
277
+ cv2.polylines(img, [pts], True, 255, width)
278
+ # if not filled
279
+ else:
280
+ cv2.polylines(img, [pts], True, 255, width)
281
+
282
+
283
+ def draw_circle(img, center, radius, color, width):
284
+ # if filled
285
+ if color != 0:
286
+ # fill the interior
287
+ cv2.circle(img,
288
+ center,
289
+ radius,
290
+ color,
291
+ -1)
292
+ # draw the edge
293
+ cv2.circle(img,
294
+ center,
295
+ radius,
296
+ 255,
297
+ width)
298
+ # if not filled
299
+ else:
300
+ cv2.circle(img,
301
+ center,
302
+ radius,
303
+ 255,
304
+ width)
raven_utils/render_.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ COLOR_VALUES = [255, 224, 196, 168, 140, 112, 84, 56, 28, 0]
2
+ SIZE_VALUES = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
3
+ TYPE_VALUES = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
4
+ ANGLE_VALUES = [-135, -90, -45, 0, 45, 90, 135, 180]
5
+ RENDER_POSITIONS_GROUPED = [
6
+ [(0.5, 0.5, 1, 1)],
7
+ # ...
8
+ [(0.25, 0.25, 0.5, 0.5),
9
+ (0.25, 0.75, 0.5, 0.5),
10
+ (0.75, 0.25, 0.5, 0.5),
11
+ (0.75, 0.75, 0.5, 0.5)],
12
+ # ...
13
+ [(0.16, 0.16, 0.33, 0.33),
14
+ (0.16, 0.5, 0.33, 0.33),
15
+ (0.16, 0.83, 0.33, 0.33),
16
+ (0.5, 0.16, 0.33, 0.33),
17
+ (0.5, 0.5, 0.33, 0.33),
18
+ (0.5, 0.83, 0.33, 0.33),
19
+ (0.83, 0.16, 0.33, 0.33),
20
+ (0.83, 0.5, 0.33, 0.33),
21
+ (0.83, 0.83, 0.33, 0.33)],
22
+ # ...
23
+ [(0.5, 0.5, 1, 1)],
24
+ [(0.5, 0.5, 0.33, 0.33)],
25
+ # ...
26
+ [(0.5, 0.5, 1, 1)],
27
+ [(0.42, 0.42, 0.15, 0.15),
28
+ (0.42, 0.58, 0.15, 0.15),
29
+ (0.58, 0.42, 0.15, 0.15),
30
+ (0.58, 0.58, 0.15, 0.15)],
31
+ # ....
32
+ [(0.5, 0.25, 0.5, 0.5)],
33
+ [(0.5, 0.75, 0.5, 0.5)],
34
+ # ...
35
+ [(0.25, 0.5, 0.5, 0.5)],
36
+ [(0.75, 0.5, 0.5, 0.5)],
37
+ # ...
38
+ ]
39
+ RENDER_POSITIONS = [pos_ for group_pos_ in RENDER_POSITIONS_GROUPED for pos_ in group_pos_]
40
+ MAPPING = {
41
+ "distribute_nine":
42
+ {0.16: 0,
43
+ 0.5: 1,
44
+ 0.83: 2},
45
+ "distribute_four":
46
+ {0.25: 0,
47
+ 0.75: 1},
48
+ 'in_distribute_four_out_center_single':
49
+ {0.42: 0,
50
+ 0.58: 1}
51
+ }
52
+ MUL = {
53
+ "distribute_nine": 3,
54
+ "distribute_four": 2,
55
+ 'in_distribute_four_out_center_single': 2
56
+ }
57
+ TYPES = ["triangle", "square", "pentagon", "hexagon", "circle"]
58
+ TYPES_NONE = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
59
+ SIZES = ["vs", "s", "m", "h", "vh", "e"]
60
+ SIZES_NAME = ["Very Small", "Small", "Medium", "High", "Very High", "Enormous"]
61
+ COLORS = ["vs", "s", "m", "h", "vh", "e"]
62
+
63
+ SAMPLE_TARGET = [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
64
+ 0, 0, 0, 0, 9, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
65
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
66
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
67
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0,
68
+ 0, 1, 3],
69
+ [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
70
+ 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 2, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,
71
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
72
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
73
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 2, 3, 0, 0, 0, 0,
74
+ 0, 3, 3],
75
+ [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
76
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
77
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0,
78
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
79
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 2, 0, 0, 0, 0,
80
+ 0, 0, 3],
81
+ [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
82
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
83
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
84
+ 0, 0, 0, 5, 2, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
85
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
86
+ 3, 2, 1],
87
+ [4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1,
88
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
89
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
90
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3,
91
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 2, 1,
92
+ 3, 0, 1],
93
+ [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
94
+ 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
95
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
96
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
97
+ 0, 2, 0, 0, 7, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 3, 0, 0, 3, 3,
98
+ 3, 1, 0],
99
+ [6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
100
+ 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
101
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
102
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
103
+ 0, 0, 0, 0, 0, 0, 0, 6, 5, 0, 8, 5, 1, 0, 0, 0, 3, 2, 0, 0, 1, 0,
104
+ 3, 3, 3]]
raven_utils/rules.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ml_utils import dict_from_list
2
+ COMBINE = "Number/Position"
3
+
4
+ ATTRIBUTES = [
5
+ "Number",
6
+ "Position",
7
+ "Color",
8
+ "Size",
9
+ "Type"
10
+ ]
11
+ ATTRIBUTES_LEN = len(ATTRIBUTES)
12
+ ATTRIBUTES_INDEX = dict_from_list(ATTRIBUTES)
13
+
14
+ TYPES = [
15
+ "Constant",
16
+ "Arithmetic",
17
+ "Progression",
18
+ "Distribute_Three"
19
+ ]
20
+ TYPES_INDEX = dict_from_list(TYPES)
21
+ TYPES_LEN = len(TYPES)
raven_utils/target.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import raven_utils.group as group
3
+ import raven_utils.entity as entity
4
+ import raven_utils.rules as rules
5
+ import raven_utils.properties as properties
6
+
7
+
8
+ ENTITY_INDEX = entity.INDEX + 1
9
+ ENTITY_DICT = dict(zip(group.NAMES, ENTITY_INDEX[:-1]))
10
+ NAMES_ORDER = dict(zip(group.NAMES, np.arange(len(group.NAMES))))
11
+ PROPERTIES_INDEXES = np.cumsum(np.array(list(entity.NO.values())) * properties.NO)
12
+ INDEX = np.concatenate([[0], PROPERTIES_INDEXES]) + entity.SUM + 1 # +2 type and uniformity
13
+
14
+ SECOND_LAYOUT = [i - 1 for i in [
15
+ ENTITY_DICT["in_center_single_out_center_single"] + 1,
16
+ ENTITY_DICT["in_distribute_four_out_center_single"] + 1,
17
+ ENTITY_DICT["in_distribute_four_out_center_single"] + 2,
18
+ ENTITY_DICT["in_distribute_four_out_center_single"] + 3,
19
+ ENTITY_DICT["left_center_single_right_center_single"] + 1,
20
+ ENTITY_DICT["up_center_single_down_center_single"] + 1
21
+ ]]
22
+
23
+ FIRST_LAYOUT = list(set(range(entity.SUM)) - set(SECOND_LAYOUT))
24
+ LAYOUT_NO = 2
25
+
26
+ START_INDEX = dict(zip(group.NAMES, INDEX[:-1]))
27
+ END_INDEX = INDEX[-1]
28
+
29
+ RULES_ATTRIBUTES_ALL_LEN = rules.ATTRIBUTES_LEN * LAYOUT_NO
30
+ UNIFORMITY_NO = 2
31
+ UNIFORMITY_INDEX = END_INDEX + RULES_ATTRIBUTES_ALL_LEN
32
+
33
+ SIZE = UNIFORMITY_INDEX + UNIFORMITY_NO
34
+
35
+ def take(target):
36
+ return target[1], target[2]
37
+
38
+
39
+ def create(images, index, pattern_index=(2, 5), full_index=False, arrange=np.arange, shape=lambda x: x.shape):
40
+ return [images[:, pattern_index[0]], images[:, pattern_index[1]],
41
+ images[arrange(shape(index)[0]), (0 if full_index else 8) + index[:, 0]]]
42
+
43
+
44
+
45
+ def take_simple(target):
46
+ return target[1], target[0]
47
+
48
+
49
+ def create_simple(images, target, index=slice(None), pattern_index=(2, 5)):
50
+ return [images[:, pattern_index[0]], images[:, pattern_index[1]], target][index]
raven_utils/uitls.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from itertools import product
3
+
4
+ import numpy as np
5
+ from funcy import identity
6
+
7
+ from data_utils import gather, DataGenerator, Data
8
+ from data_utils.sampling import DataSampler
9
+ from models_utils import init_image as def_init_image, INPUTS, TARGET
10
+
11
+ import raven_utils.group as group
12
+
13
+ from data_utils import ops as D
14
+
15
+ init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
16
+
17
+
18
+ def get_val_index(no=group.NO, base=3,add_end=False):
19
+ indexes = np.arange(no) * 2000 + base
20
+ if add_end:
21
+ indexes = np.concatenate([indexes, no*2000])
22
+ return indexes
23
+
24
+
25
+ def get_matrix(inputs, index):
26
+ return np.concatenate([inputs[:, :8], gather(inputs, index[:, 0])[:, None]], axis=1)
27
+
28
+
29
+ def get_matrix_from_data(x):
30
+ inputs = x["inputs"]
31
+ index = x["index"]
32
+ return get_matrix(inputs, index)
33
+
34
+
35
+ def get_data_class(data, batch_size=128):
36
+ fn = identity
37
+ shape = data[0].shape
38
+ train_generator = DataGenerator(
39
+ {
40
+ INPUTS: Data(data[0], fn),
41
+ TARGET: Data(data[2], fn),
42
+ },
43
+ sampler=DataSampler(np.array(list(product(np.arange(shape[0]), np.arange(shape[1]))))),
44
+ batch=batch_size
45
+ )
46
+ shape = data[1].shape
47
+ val_generator = DataGenerator(
48
+ {
49
+ INPUTS: Data(data[1], fn),
50
+ TARGET: Data(data[3], fn),
51
+ },
52
+ sampler=DataSampler(np.array(list(product(np.arange(shape[0]), np.arange(shape[1])))), shuffle=False),
53
+ batch=batch_size
54
+ )
55
+ return train_generator, val_generator
56
+
57
+
58
+ def compare_from_result(result, data):
59
+ data = data.data.data
60
+ answer = D.gather(data['target'].data, data['index'].data[:, 0])
61
+ import raven_utils as rv
62
+ predict = result['predict']
63
+ predict_mask = result['predict_mask']
64
+ return np.all(rv.decode.compare(answer[:len(predict)], predict, predict_mask), axis=-1)
saved_model/1/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3065f1580247f096711cd61201a17a730a1e5a3d719f2c2778030dea78bb17b4
3
+ size 730275
saved_model/1/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48d10d74324a5e993ceacc0a4bffc1fcb232d7e2f708a2ebbeabd864650baeeb
3
+ size 12159312
saved_model/1/variables/variables.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e3c44b273228c834166b40b8a062a53dce76cc21d4cce42f65df2edc53533a7
3
+ size 43002413
saved_model/1/variables/variables.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e69bbffa5d415625538b629762f1aaeeb355a83d676242110af3d633e31017dd
3
+ size 24958
utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from data_utils.image import draw_images
3
+ from ml_utils import il
4
+
5
+ import raven_utils as rv
6
+ from raven_utils.uitls import get_matrix
7
+ from tensorflow.keras.models import load_model
8
+ from raven_utils.draw import render_from_model
9
+ import models
10
+ import ast
11
+
12
+
13
+ def load_example(index=0):
14
+ index = ast.literal_eval(str(index))
15
+ if il(index):
16
+ example = rv.draw.render_panels(np.array(index))
17
+ desc = "Custom matrix"
18
+ else:
19
+ if not index:
20
+ index = 0
21
+ index = int(index)
22
+
23
+ desc = rv.draw.extract_rules(models.properties[index])
24
+ desc = "<br /><br />".join(["<br />".join(d) for d in desc])
25
+
26
+ example = get_matrix(models.data[index:index + 1], models.indexes[index:index + 1, None] + 8)
27
+ result = np.tile(draw_images(example[:9], row=3), reps=(1, 1, 3))
28
+ return result, desc
29
+
30
+
31
+ def load_model_(name):
32
+ if name == "Transformer":
33
+ path = "/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09"
34
+ else:
35
+ path = name
36
+ models.model = load_model(path)
37
+ return f"Success loading: {name}"
38
+
39
+
40
+ def run_nn(index=0):
41
+ index = ast.literal_eval(str(index))
42
+ if il(index):
43
+ data = rv.draw.render_panels(np.array(index))
44
+ data = np.concatenate([data, data[:7]])[None]
45
+ else:
46
+ if not index:
47
+ index = models.START_IMAGE
48
+ index = int(index)
49
+ data = models.data[index:index + 1]
50
+
51
+ # model = load_model("/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09")
52
+ data = {
53
+ 'inputs': data,
54
+ 'index': np.zeros(shape=(1, 1), dtype="uint8"),
55
+ 'labels': np.zeros(shape=(1, 16, 113), dtype="int8"),
56
+ 'target': np.zeros(shape=(1, 16, 113), dtype="int8"),
57
+ # 'features': np.zeros(shape=(1, 16, 64), dtype="float32")
58
+ }
59
+ res = np.tile(render_from_model(data, models.model)[0, ..., None], reps=(1, 1, 3))
60
+
61
+ # res = model({'inputs': data[0:1]})
62
+
63
+ return res
64
+
65
+
66
+ def next_(index=0):
67
+ index = ast.literal_eval(str(index))
68
+ if not isinstance(index, int):
69
+ index = models.START_IMAGE
70
+ index = int(index) + 1
71
+ return (index,) + load_example(index)
72
+
73
+
74
+ def prev_(index=0):
75
+ index = ast.literal_eval(str(index))
76
+ if not isinstance(index, int):
77
+ index = models.START_IMAGE
78
+ index = int(index) - 1
79
+ return (index,) + load_example(index)
80
+
81
+
82
+ if __name__ == '__main__':
83
+ image, _ = load_example(5)
84
+ run_nn(image)