Spaces:
Build error
Build error
Jakub Kwiatkowski
committed on
Commit
·
9502bdf
1
Parent(s):
86045f3
Refactor hf/raven.
Browse files- models.py +1 -0
- raven_utils/depricated/__init__.py +0 -0
- raven_utils/depricated/old_raven.py +0 -490
- raven_utils/models/__init__.py +0 -0
- raven_utils/models/attn.py +0 -187
- raven_utils/models/attn2.py +0 -187
- raven_utils/models/augment.py +0 -0
- raven_utils/models/body.py +0 -276
- raven_utils/models/class_.py +0 -31
- raven_utils/models/head.py +0 -159
- raven_utils/models/loss.py +0 -630
- raven_utils/models/loss_3.py +0 -638
- raven_utils/models/multi_transformer.py +0 -274
- raven_utils/models/raven.py +0 -239
- raven_utils/models/trans.py +0 -74
- raven_utils/models/transformer.py +0 -133
- raven_utils/models/transformer_2.py +0 -146
- raven_utils/models/transformer_3.py +0 -206
- raven_utils/models/uitls_.py +0 -16
models.py
CHANGED
@@ -11,3 +11,4 @@ indexes = nload("/home/jkwiatkowski/all/dataset/arr/val_target.npy")
|
|
11 |
|
12 |
folders = DataSetFromFolder("/home/jkwiatkowski/all/dataset/arr/RAVEN-10000-release/RAVEN-10000", file_type="dir")
|
13 |
properties = DataSetFromFolder(folders[:], file_type="xml", extension="val")
|
|
|
|
11 |
|
12 |
folders = DataSetFromFolder("/home/jkwiatkowski/all/dataset/arr/RAVEN-10000-release/RAVEN-10000", file_type="dir")
|
13 |
properties = DataSetFromFolder(folders[:], file_type="xml", extension="val")
|
14 |
+
|
raven_utils/depricated/__init__.py
DELETED
File without changes
|
raven_utils/depricated/old_raven.py
DELETED
@@ -1,490 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
from data_utils import take, EXIST, COR
|
5 |
-
from data_utils.image import draw_images, add_text
|
6 |
-
from data_utils.op import np_split
|
7 |
-
from ml_utils import lu, dict_from_list2, filter_keys, none
|
8 |
-
from data_utils import ops as K
|
9 |
-
|
10 |
-
from config.constant import PROPERTY, TARGET, INPUTS
|
11 |
-
# from raven_utils.render.rendering import render_panels
|
12 |
-
|
13 |
-
RENDER_POSITIONS = [
|
14 |
-
[(0.5, 0.5, 1, 1)],
|
15 |
-
# ...
|
16 |
-
[(0.25, 0.25, 0.5, 0.5),
|
17 |
-
(0.25, 0.75, 0.5, 0.5),
|
18 |
-
(0.75, 0.25, 0.5, 0.5),
|
19 |
-
(0.75, 0.75, 0.5, 0.5)],
|
20 |
-
# ...
|
21 |
-
[(0.16, 0.16, 0.33, 0.33),
|
22 |
-
(0.16, 0.5, 0.33, 0.33),
|
23 |
-
(0.16, 0.83, 0.33, 0.33),
|
24 |
-
(0.5, 0.16, 0.33, 0.33),
|
25 |
-
(0.5, 0.5, 0.33, 0.33),
|
26 |
-
(0.5, 0.83, 0.33, 0.33),
|
27 |
-
(0.83, 0.16, 0.33, 0.33),
|
28 |
-
(0.83, 0.5, 0.33, 0.33),
|
29 |
-
(0.83, 0.83, 0.33, 0.33)],
|
30 |
-
# ...
|
31 |
-
[(0.5, 0.25, 0.5, 0.5)],
|
32 |
-
[(0.5, 0.75, 0.5, 0.5)],
|
33 |
-
# ...
|
34 |
-
[(0.25, 0.5, 0.5, 0.5)],
|
35 |
-
[(0.75, 0.5, 0.5, 0.5)],
|
36 |
-
# ...
|
37 |
-
[(0.5, 0.5, 1, 1)],
|
38 |
-
[(0.5, 0.5, 0.33, 0.33)],
|
39 |
-
# ...
|
40 |
-
[(0.5, 0.5, 1, 1)],
|
41 |
-
[(0.42, 0.42, 0.15, 0.15),
|
42 |
-
(0.42, 0.58, 0.15, 0.15),
|
43 |
-
(0.58, 0.42, 0.15, 0.15),
|
44 |
-
(0.58, 0.58, 0.15, 0.15)],
|
45 |
-
# ...
|
46 |
-
|
47 |
-
]
|
48 |
-
|
49 |
-
HORIZONTAL = "horizontal"
|
50 |
-
VERTICAL = "vertical"
|
51 |
-
|
52 |
-
NAMES = ['center_single',
|
53 |
-
'distribute_four',
|
54 |
-
'distribute_nine',
|
55 |
-
'in_center_single_out_center_single',
|
56 |
-
'in_distribute_four_out_center_single',
|
57 |
-
'left_center_single_right_center_single',
|
58 |
-
'up_center_single_down_center_single']
|
59 |
-
|
60 |
-
PROPERTIES_NAMES = [
|
61 |
-
'Color',
|
62 |
-
'Size',
|
63 |
-
'Type',
|
64 |
-
|
65 |
-
]
|
66 |
-
PROPERTIES = dict_from_list2(PROPERTIES_NAMES, [10, 6, 5])
|
67 |
-
ANGLE_MAX = 7
|
68 |
-
|
69 |
-
PROPERTIES_NO = len(PROPERTIES)
|
70 |
-
|
71 |
-
RULES_COMBINE = "Number/Position"
|
72 |
-
|
73 |
-
RULES_ATTRIBUTES = [
|
74 |
-
"Number",
|
75 |
-
"Position",
|
76 |
-
"Color",
|
77 |
-
"Size",
|
78 |
-
"Type"
|
79 |
-
]
|
80 |
-
RULES_ATTRIBUTES_LEN = len(RULES_ATTRIBUTES)
|
81 |
-
|
82 |
-
RULES_ATTRIBUTES_INDEX = dict_from_list2(RULES_ATTRIBUTES)
|
83 |
-
|
84 |
-
# Names of the rule types, with a lookup built by dict_from_list2 and their count.
RULES_TYPES = [
    "Constant",
    "Arithmetic",
    "Progression",
    "Distribute_Three",
]
RULES_TYPES_INDEX = dict_from_list2(RULES_TYPES)
# Count the rule types themselves; the original took len(RULES_ATTRIBUTES),
# which is the length of a different list.
RULES_TYPES_LEN = len(RULES_TYPES)
|
92 |
-
|
93 |
-
GROUPS_NO = len(NAMES)
|
94 |
-
ENTITY_NO = dict(zip(NAMES, [1, 4, 9, 2, 5, 2, 2]))
|
95 |
-
ENTITY_SUM = sum(list(ENTITY_NO.values()))
|
96 |
-
ENTITY_INDEX = np.concatenate([[0], np.cumsum(list(ENTITY_NO.values()))])
|
97 |
-
ENTITY_INDEX_TARGET = ENTITY_INDEX + 1
|
98 |
-
ENTITY_DICT = dict(zip(NAMES, ENTITY_INDEX_TARGET[:-1]))
|
99 |
-
NAMES_ORDER = dict(zip(NAMES, np.arange(len(NAMES))))
|
100 |
-
PROPERTIES_INDEXES = np.cumsum(np.array(list(ENTITY_NO.values())) * len(PROPERTIES))
|
101 |
-
INDEX = np.concatenate([[0], PROPERTIES_INDEXES]) + ENTITY_SUM + 1 # +2 type and uniformity
|
102 |
-
|
103 |
-
SECOND_LAYOUT = [i - 1 for i in [
|
104 |
-
ENTITY_DICT["in_center_single_out_center_single"] + 1,
|
105 |
-
ENTITY_DICT["in_distribute_four_out_center_single"] + 1,
|
106 |
-
ENTITY_DICT["in_distribute_four_out_center_single"] + 2,
|
107 |
-
ENTITY_DICT["in_distribute_four_out_center_single"] + 3,
|
108 |
-
ENTITY_DICT["left_center_single_right_center_single"] + 1,
|
109 |
-
ENTITY_DICT["up_center_single_down_center_single"] + 1
|
110 |
-
]]
|
111 |
-
|
112 |
-
FIRST_LAYOUT = list(set(range(ENTITY_SUM)) - set(SECOND_LAYOUT))
|
113 |
-
LAYOUT_NO = 2
|
114 |
-
|
115 |
-
START_INDEX = dict(zip(NAMES, INDEX[:-1]))
|
116 |
-
END_INDEX = INDEX[-1]
|
117 |
-
|
118 |
-
RULES_ATTRIBUTES_ALL_LEN = RULES_ATTRIBUTES_LEN * LAYOUT_NO
|
119 |
-
UNIFORMITY_NO = 2
|
120 |
-
UNIFORMITY_INDEX = END_INDEX + RULES_ATTRIBUTES_ALL_LEN
|
121 |
-
|
122 |
-
FEATURE_NO = UNIFORMITY_INDEX + UNIFORMITY_NO
|
123 |
-
MAPPING = {
|
124 |
-
"distribute_nine":
|
125 |
-
{0.16: 0,
|
126 |
-
0.5: 1,
|
127 |
-
0.83: 2},
|
128 |
-
"distribute_four":
|
129 |
-
{0.25: 0,
|
130 |
-
0.75: 1},
|
131 |
-
'in_distribute_four_out_center_single':
|
132 |
-
{0.42: 0,
|
133 |
-
0.58: 1}
|
134 |
-
}
|
135 |
-
MUL = {
|
136 |
-
"distribute_nine": 3,
|
137 |
-
"distribute_four": 2,
|
138 |
-
'in_distribute_four_out_center_single': 2
|
139 |
-
}
|
140 |
-
|
141 |
-
# SIZES = np.linspace(0.4, 0.9, 6)
|
142 |
-
TYPES = ["triangle", "square", "pentagon", "hexagon", "circle"]
|
143 |
-
# TYPES = ["triangle", "square", "pentagon", "circle", "circle"]
|
144 |
-
SIZES = ["vs", "s", "m", "h", "vh", "e"]
|
145 |
-
COLORS = ["vs", "s", "m", "h", "vh", "e"]
|
146 |
-
# TYPES = ["", "", "circle", "hexagon", "square"]
|
147 |
-
|
148 |
-
ENTITY_PROPERTIES_VALUES = list(PROPERTIES.values())
|
149 |
-
ENTITY_PROPERTIES_KEYS = list(PROPERTIES.keys())
|
150 |
-
ENTITY_PROPERTIES_NO = len(PROPERTIES)
|
151 |
-
INDEX = dict(zip(PROPERTIES, np.array(ENTITY_PROPERTIES_VALUES) * ENTITY_SUM))
|
152 |
-
ENTITY_PROPERTIES_SUM = sum(list(PROPERTIES.values()))
|
153 |
-
|
154 |
-
OUTPUT_SIZE = ENTITY_SUM * ENTITY_PROPERTIES_SUM + GROUPS_NO + ENTITY_SUM
|
155 |
-
|
156 |
-
SLOT_AND_GROUP = ENTITY_SUM + GROUPS_NO
|
157 |
-
|
158 |
-
OUTPUT_GROUP_SLICE = np.s_[:, -GROUPS_NO:]
|
159 |
-
OUTPUT_SLOT_SLICE = np.s_[:, -SLOT_AND_GROUP:-GROUPS_NO]
|
160 |
-
OUTPUT_PROPERTIES_SLICE = np.s_[:, :-SLOT_AND_GROUP]
|
161 |
-
|
162 |
-
OUTPUT_GROUP_SLICE_END = np.s_[-GROUPS_NO:]
|
163 |
-
OUTPUT_SLOT_SLICE_END = np.s_[-SLOT_AND_GROUP:-GROUPS_NO]
|
164 |
-
OUTPUT_PROPERTIES_SLICE_END = np.s_[:-SLOT_AND_GROUP]
|
165 |
-
|
166 |
-
# Transformation
|
167 |
-
# constant
|
168 |
-
# progression -2, -1,1 ,2
|
169 |
-
# arithmetic -/+ Position set arithmetic
|
170 |
-
# distribute three
|
171 |
-
|
172 |
-
# todo
|
173 |
-
SLOTS_GROUPS = GROUPS_NO
|
174 |
-
|
175 |
-
SLOT_TRANSFORMATION_NO = 4
|
176 |
-
PROPERTY_TRANSFORMATION_NO = 8
|
177 |
-
PROPERTIES_TRANSFORMATION_NO = PROPERTY_TRANSFORMATION_NO * PROPERTIES_NO
|
178 |
-
PROPERTIES_TRANSFORMATION_SIZE = PROPERTIES_TRANSFORMATION_NO * ENTITY_SUM
|
179 |
-
|
180 |
-
SLOT_TRANSFORMATION_SIZE = PROPERTY_TRANSFORMATION_NO * SLOTS_GROUPS
|
181 |
-
INFERENCE_SIZE = SLOT_TRANSFORMATION_SIZE + PROPERTIES_TRANSFORMATION_SIZE
|
182 |
-
|
183 |
-
INFERENCE_SLOT_SLICE = np.s_[:, :SLOT_TRANSFORMATION_SIZE]
|
184 |
-
INFERENCE_PROPERTIES_SLICE = np.s_[:, -PROPERTIES_TRANSFORMATION_SIZE:]
|
185 |
-
from operator import add
|
186 |
-
|
187 |
-
|
188 |
-
# todo Refactor
|
189 |
-
# Maybe properties should be on same level as rest.
|
190 |
-
def decode_output(output, split_fn=np_split):
    """Cut a flat output into its group, slot and per-property parts.

    The last axis is sliced with the module-level ``*_SLICE_END`` slices:
    the trailing ``GROUPS_NO`` entries form the group part, the
    ``ENTITY_SUM`` entries before them form the slot part, and everything
    in front of those is the properties part.

    Args:
        output: Object indexable as ``output[..., slice]``.
        split_fn: Callable invoked as ``split_fn(array, sizes, axis=-1)``.

    Returns:
        Tuple ``(group_output, slot_output, properties_output_splited)``
        where the last item is whatever ``split_fn`` returns.
    """
    group_output = output[..., OUTPUT_GROUP_SLICE_END]
    slot_output = output[..., OUTPUT_SLOT_SLICE_END]
    properties_output = output[..., OUTPUT_PROPERTIES_SLICE_END]
    # The original read ``rv.properties.INDEX`` but ``rv`` is never imported
    # in this module. The module-level ``INDEX`` dict holds one width per
    # property (classes * ENTITY_SUM), and those widths add up to the size of
    # the properties part.
    # NOTE(review): confirm this matches the table ``rv.properties`` provided.
    property_widths = list(INDEX.values())
    properties_output_splited = split_fn(properties_output, property_widths, axis=-1)
    return group_output, slot_output, properties_output_splited
|
196 |
-
|
197 |
-
|
198 |
-
def decode_inference(inference, reshape=np.reshape):
    """Split an inference array into reshaped slot and property parts.

    Args:
        inference: Array indexed with ``INFERENCE_SLOT_SLICE`` and
            ``INFERENCE_PROPERTIES_SLICE`` (both two-axis slices).
        reshape: Callable invoked as ``reshape(array, shape)``.

    Returns:
        Tuple ``(slot_part, properties_part)`` reshaped to
        ``[-1, SLOTS_GROUPS, PROPERTY_TRANSFORMATION_NO]`` and
        ``[-1, PROPERTIES_NO, ENTITY_SUM, PROPERTY_TRANSFORMATION_NO]``.
    """
    slot_shape = [-1, SLOTS_GROUPS, PROPERTY_TRANSFORMATION_NO]
    properties_shape = [-1, PROPERTIES_NO, ENTITY_SUM, PROPERTY_TRANSFORMATION_NO]
    slot_part = reshape(inference[INFERENCE_SLOT_SLICE], slot_shape)
    properties_part = reshape(inference[INFERENCE_PROPERTIES_SLICE], properties_shape)
    return slot_part, properties_part
|
203 |
-
|
204 |
-
|
205 |
-
def decode_output_reshape(output, split_fn=np_split):
    """Decode ``output`` and reshape every property chunk per entity.

    Calls ``decode_output`` and reshapes the i-th property chunk to
    ``(-1, ENTITY_SUM, ENTITY_PROPERTIES_VALUES[i])``.

    Returns:
        The first two items of ``decode_output``'s result followed by the
        reshaped property chunks, concatenated into one sequence.
    """
    decoded = decode_output(output, split_fn=split_fn)
    reshaped = tuple(
        chunk.reshape((-1, ENTITY_SUM, ENTITY_PROPERTIES_VALUES[position]))
        for position, chunk in enumerate(decoded[2])
    )
    return decoded[:2] + reshaped
|
212 |
-
|
213 |
-
|
214 |
-
def take_target(target):
    """Return items 1 and 2 of ``target`` as a pair."""
    second = target[1]
    third = target[2]
    return second, third
|
216 |
-
|
217 |
-
|
218 |
-
def create_target(images, index, pattern_index=(2, 5), full_index=False, arrange=np.arange, shape=lambda x: x.shape):
    """Select two fixed columns and one per-row column from ``images``.

    Args:
        images: Array indexed on its first two axes.
        index: Array whose first column gives a per-row position.
        pattern_index: Pair of positions on axis 1 to pick for every row.
        full_index: When false, 8 is added to the per-row position.
        arrange: Callable producing ``0..n-1`` for the row count.
        shape: Callable returning the shape of ``index``.

    Returns:
        List ``[images[:, pattern_index[0]], images[:, pattern_index[1]],
        images[row, offset + index[row, 0]]]``.
    """
    offset = 0 if full_index else 8
    first_pattern = images[:, pattern_index[0]]
    second_pattern = images[:, pattern_index[1]]
    rows = arrange(shape(index)[0])
    chosen = images[rows, offset + index[:, 0]]
    return [first_pattern, second_pattern, chosen]
|
221 |
-
|
222 |
-
|
223 |
-
def take_target_simple(target):
    """Return items 1 and 0 of ``target``, in that order."""
    first = target[0]
    second = target[1]
    return second, first
|
225 |
-
|
226 |
-
|
227 |
-
def create_target_simple(images, target, index=slice(None), pattern_index=(2, 5)):
    """Build ``[pattern_a, pattern_b, target]`` and index into that list.

    Args:
        images: Array indexed as ``images[:, k]``.
        target: Object placed as the third list item unchanged.
        index: Index or slice applied to the three-item list; the default
            slice returns the whole list.
        pattern_index: Pair of positions on axis 1 of ``images``.
    """
    first_pattern = images[:, pattern_index[0]]
    second_pattern = images[:, pattern_index[1]]
    parts = [first_pattern, second_pattern, target]
    return parts[index]
|
229 |
-
|
230 |
-
|
231 |
-
def decode_output_result(output, split_fn=np_split, arg_max=np.argmax):
    """Decode ``output`` and take the arg-max of every part except the slots.

    The item at position 1 of ``decode_output_reshape``'s result is passed
    through unchanged; every other item is reduced with
    ``arg_max(item, axis=-1)``.

    Returns:
        Tuple with one entry per decoded part.
    """
    decoded = decode_output_reshape(output, split_fn=split_fn)
    return tuple(
        part if position == 1 else arg_max(part, axis=-1)
        for position, part in enumerate(decoded)
    )
|
240 |
-
|
241 |
-
|
242 |
-
def decode_target(target):
    """Split a target array into group, slot and interleaved property parts.

    Layout read from the last axis: position 0 is the group, positions
    ``1 .. ENTITY_SUM`` are the slots, and positions up to ``END_INDEX``
    hold the properties interleaved with stride ``PROPERTIES_NO``.

    Returns:
        Tuple ``(target_group, target_slot, target_properties_splited)``
        where the last item is a list with one strided view per property.
    """
    # The original sliced with ``INDEX[0]``, but ``INDEX`` is rebound further
    # down the module to a dict keyed by property name, so that lookup raises
    # KeyError at call time. The intended boundary is the first entry of the
    # earlier array, i.e. ``ENTITY_SUM + 1``.
    properties_start = ENTITY_SUM + 1
    target_group = target[..., 0]
    target_slot = target[..., 1:properties_start]
    target_properties = target[..., properties_start:END_INDEX]
    target_properties_splited = [
        target_properties[..., offset::PROPERTIES_NO]
        for offset in range(PROPERTIES_NO)
    ]
    return target_group, target_slot, target_properties_splited
|
252 |
-
|
253 |
-
|
254 |
-
def decode_target_flat(target):
    """Return ``decode_target``'s result flattened to a five-item tuple.

    The tuple is ``(group, slot, property_0, property_1, property_2)``.
    """
    group, slot, properties = decode_target(target)[:3]
    return group, slot, properties[0], properties[1], properties[2]
|
257 |
-
|
258 |
-
|
259 |
-
def draw_board(images, target=None, predict=None, image=None, desc=None, layout=None, break_=20):
    """Render one board: the first eight panels plus the remaining panels.

    Args:
        images: Panels of one board; ``images[:8]`` and ``images[8:]`` are
            drawn as two separate grids.
        target: Index into ``images``; ``target - 8`` is used for the border.
        predict: Optional index of the predicted panel.
        image: Extra panel appended to the first grid, the string
            ``"target"`` to append the target panel, or ``None``.
        desc: Optional text added to the finished board.
        layout: ``1`` selects the bordered layout with a blank strip on top.
        break_: Height of that blank strip; falsy disables it.

    Returns:
        The board as returned by ``draw_images`` / ``add_text``.
    """
    # Compare against the sentinel only when ``image`` really is a string;
    # ``array != "target"`` is ambiguous on NumPy arrays.
    use_target = isinstance(image, str) and image == "target"
    if not use_target and predict is not None:
        image = images[predict:predict + 1]
    elif use_target:
        # The original tested ``images is None`` here, which could only ever
        # crash, and left the string in ``image`` to fail on ``.shape`` below.
        image = images[target:target + 1] if target is not None else None
    # NOTE(review): ``target`` must not be None here; ``target - 8`` fails otherwise.
    border = [{COR: target - 8, EXIST: (1, 3)}] + [{COR: p, EXIST: (0, 2)} for p in none(predict)]

    if image is not None:
        extra = image[None] if len(image.shape) == 3 else image
        context = np.concatenate([images[:8], extra])
    else:
        context = images[:8]
    boards = [draw_images(context)]

    if layout == 1:
        answers = draw_images(images[8:], column=4, border=border)
        if break_:
            spacer = np.zeros([break_, answers.shape[1], 1])
            answers = np.concatenate([spacer, answers], axis=0)
        boards.append(answers)
    else:
        candidates = np.concatenate([images[8:], predict]) if predict is not None else images[8:]
        boards.append(draw_images(candidates, column=4, border=target - 8))

    full_board = draw_images(boards, grid=False)
    if desc:
        full_board = add_text(full_board, desc)
    return full_board
|
283 |
-
|
284 |
-
|
285 |
-
def draw_boards(images, target=None, predict=None, image=None, desc=None, no=1, layout=None):
    """Render one board per entry of ``images`` via ``draw_board``.

    Args:
        images: Iterable of per-board panel collections.
        target: Optional per-board targets; ``target[i][0]`` is forwarded.
        predict: Optional per-board predictions.
        image: Optional per-board extra panels.
        desc: Optional per-board descriptions.
        no: Unused; kept for interface compatibility.
        layout: Forwarded to ``draw_board``.

    Returns:
        List of rendered boards.
    """
    boards = []
    # The loop variable must not be called ``image``: the original shadowed
    # the ``image`` parameter, so ``image[i]`` indexed the current board's
    # panels instead of the caller's per-board extra panel.
    for i, panels in enumerate(images):
        boards.append(draw_board(
            panels,
            target[i][0] if target is not None else None,
            predict[i] if predict is not None else None,
            image[i] if image is not None else None,
            desc[i] if desc is not None else None,
            layout=layout,
        ))
    return boards
|
293 |
-
|
294 |
-
|
295 |
-
def draw_raven(generator, predict=None, no=1, add_target_desc=True, indexes=None, types=TYPES,
               layout=1):
    """Draw boards for samples of ``generator.data`` with their rule text.

    Args:
        generator: Object whose ``data[indexes]`` yields a mapping with the
            ``PROPERTY``, ``TARGET``, ``INPUTS`` and ``"index"`` entries.
        predict: A model (per ``is_model``), an object with ``.shape``, or
            ``None``.
        no: Number of samples when ``indexes`` is not given.
        add_target_desc: Unused here.
        indexes: Sample indexes; defaults to ``val_sample(no)``.
        types: Unused here.
        layout: Forwarded to ``draw_boards``.

    Returns:
        ``lu`` applied to the list of ``(board, rule_lines)`` pairs.
    """
    if indexes is None:
        indexes = val_sample(no)
    data = generator.data[indexes]
    if is_model(predict):
        model_inputs = filter_keys(data, PROPERTY, reverse=True)
        # tmp change
        pro = predict(model_inputs)['predict']
        print(pro)
        # NOTE(review): render_panels is not imported in this module (only a
        # commented-out import exists) -- confirm where it should come from.
        predict = render_panels(pro, target=False)
    target = data[TARGET]
    target_index = data["index"]
    images = data[INPUTS]

    if not hasattr(predict, "shape"):
        image = K.gather(images, target_index[:, 0])
        predict_index = None
        predict = None
    else:
        rank = len(predict.shape)
        if rank > 3:
            # images
            image = predict
            # todo create index and output based on image
            predict = None
            predict_index = None
        elif rank == 3:
            image = render_panels(predict, target=False)
            # Create index based on predict.
            predict_index = None
        else:
            image = images[predict]
            predict_index = predict
            predict = target

    # One list of "attr - name" lines per sample, with a blank line after
    # every rule group.
    all_rules = []
    for properties in data[PROPERTY]:
        rules = []
        for rule_group in properties.findAll("Rule_Group"):
            for rule in rule_group.findAll("Rule"):
                rules.append(f"{rule['attr']} - {rule['name']}")
            rules.append("")
        all_rules.append(rules)

    target_desc = get_desc(target)
    if predict is None:
        predict_desc = [None] * len(target_desc)
    elif predict.shape[-1] == OUTPUT_SIZE:
        predict_desc = decode_output_result(predict)
    else:
        predict_desc = get_desc(predict)

    for rules, predicted, expected in zip(all_rules, predict_desc, target_desc):
        if predicted is None:
            predicted = [None] * len(expected)
        for p, t in zip(predicted, expected):
            rules.append(" ".join([str(i) for i in t]))
            if p is not None:
                rules.append(" ".join([str(i) for i in p]))
                rules.append("")

    boards = draw_boards(images, target=target_index, predict=predict_index, image=image, desc=None, no=no,
                         layout=layout)
    return lu(list(zip(boards, all_rules)))
|
373 |
-
|
374 |
-
|
375 |
-
def val_sample(no=GROUPS_NO, base=3):
    """Return ``no`` indexes spaced 2000 apart, starting at ``base``."""
    return base + 2000 * np.arange(no)
|
378 |
-
|
379 |
-
|
380 |
-
def get_desc(target, exist=None, types=TYPES, sizes=SIZES):
    """Describe the existing figures of every target row.

    Args:
        target: Array accepted by ``decode_target``.
        exist: Optional existence mask; defaults to the decoded slot part.
        types: Lookup used for the third property of each figure.
        sizes: Lookup used for the second property of each figure.

    Returns:
        A list with one list per split chunk (the chunk after the last
        cumulative boundary is dropped); every figure becomes
        ``[first_property, sizes[second], types[third]]``.
    """
    decoded = decode_target(target)
    if exist is None:
        exist = decoded[1]
    mask = np.array(exist, dtype=bool)
    taken = np.stack(take(decoded[2], mask)).T

    boundaries = np.cumsum(np.sum(exist, axis=-1))
    per_row = np.split(taken, boundaries)[:-1]
    return [
        [[figure[0], sizes[figure[1]], types[figure[2]]] for figure in row]
        for row in per_row
    ]
|
397 |
-
|
398 |
-
|
399 |
-
# def get
|
400 |
-
|
401 |
-
|
402 |
-
def get_description(inputs, predict, pro, no, types=TYPES, sizes=SIZES):
    """Describe predicted and target figures for the existing entities.

    Args:
        inputs: Mapping providing the target array under ``TARGET``.
        predict: Unused; kept for interface compatibility.
        pro: Predicted properties, reshaped to ``(rows, -1, PROPERTIES_NO)``.
        no: Unused; kept for interface compatibility.
        types: Lookup used for the third property of each figure.
        sizes: Lookup used for the second property of each figure.

    Returns:
        Tuple ``(pro_result_full, target_result_full)``: two lists with one
        list of ``[first_property, sizes[second], types[third]]`` per row.
    """
    target = inputs[TARGET]
    target_exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
    target_rest = target[:, ENTITY_SUM + 1:ENTITY_SUM + 1 + ENTITY_SUM * PROPERTIES_NO]
    pro_reshaped = np.reshape(pro, (pro.shape[0], -1, PROPERTIES_NO))
    target_reshaped = np.reshape(target_rest, (target_rest.shape[0], -1, PROPERTIES_NO))

    # Keep only entities flagged as existing, then cut the flat result back
    # into per-row chunks at the cumulative figure counts.
    boundaries = np.cumsum(np.sum(target_exist, axis=-1))
    pro_div = np.split(pro_reshaped[target_exist], boundaries)[:-1]
    target_div = np.split(target_reshaped[target_exist], boundaries)[:-1]

    def describe(figures):
        # One readable triple per figure.
        return [[f[0], sizes[f[1]], types[f[2]]] for f in figures]

    pro_result_full = []
    target_result_full = []
    # zip keeps the original pairing: both lists stop at the shorter input.
    for pd, td in zip(pro_div, target_div):
        pro_result_full.append(describe(pd))
        target_result_full.append(describe(td))
    return pro_result_full, target_result_full
|
432 |
-
|
433 |
-
|
434 |
-
def get_properties(target, types=TYPES, sizes=SIZES):
    """Describe the existing figures of every row of ``target``.

    Columns ``1 .. ENTITY_SUM`` are read as the existence mask and the next
    ``ENTITY_SUM * PROPERTIES_NO`` columns as the per-entity properties.

    Returns:
        List with one list of ``[first_property, sizes[second],
        types[third]]`` per row.
    """
    properties_start = ENTITY_SUM + 1
    properties_end = properties_start + ENTITY_SUM * PROPERTIES_NO
    exists = np.asarray(target[:, 1:properties_start], dtype="bool")
    flat = target[:, properties_start:properties_end]
    per_entity = np.reshape(flat, (flat.shape[0], -1, PROPERTIES_NO))

    boundaries = np.cumsum(np.sum(exists, axis=-1))
    per_row = np.split(per_entity[exists], boundaries)[:-1]
    return [
        [[figure[0], sizes[figure[1]], types[figure[2]]] for figure in row]
        for row in per_row
    ]
|
448 |
-
|
449 |
-
|
450 |
-
def desc_properties(target, decode_fn=None, types=TYPES, sizes=SIZES):
    """Decode ``target`` and describe the items from position 2 onwards.

    Args:
        target: Array passed to ``decode_fn``.
        decode_fn: Decoder; when omitted, ``decode_output_result`` is used
            if ``target.shape[1] == OUTPUT_SIZE`` and ``decode_target``
            otherwise.
        types: Lookup used for the third element of each item.
        sizes: Lookup used for the second element of each item.

    Returns:
        List with one list of ``[first, sizes[second], types[third]]`` per
        decoded part.
    """
    if decode_fn is None:
        decode_fn = decode_output_result if target.shape[1] == OUTPUT_SIZE else decode_target

    decoded_parts = decode_fn(target)[2:]
    return [
        [[item[0], sizes[item[1]], types[item[2]]] for item in part]
        for part in decoded_parts
    ]
|
465 |
-
|
466 |
-
|
467 |
-
def get_pro(t, types=TYPES, sizes=SIZES):
    """Return ``[int(t[0]), sizes[t[1]], types[t[2]]]`` for one figure."""
    first = int(t[0])
    size_name = sizes[t[1]]
    type_name = types[t[2]]
    return [first, size_name, type_name]
|
469 |
-
|
470 |
-
|
471 |
-
def get_pro2(td, types=TYPES, sizes=SIZES):
    """Return one ``[int(first), sizes[second], types[third]]`` per figure in ``td``."""
    return [[int(figure[0]), sizes[figure[1]], types[figure[2]]] for figure in td]
|
476 |
-
|
477 |
-
|
478 |
-
def get_pro3(target_div, types=TYPES, sizes=SIZES):
    """Describe every figure of every row in ``target_div``.

    ``target_div`` must provide ``to_list()``; each row it yields is turned
    into a list of ``[int(first), sizes[second], types[third]]``.
    """
    return [
        [[int(figure[0]), sizes[figure[1]], types[figure[2]]] for figure in row]
        for row in target_div.to_list()
    ]
|
486 |
-
|
487 |
-
|
488 |
-
from models_utils import init_image as def_init_image, is_model
|
489 |
-
|
490 |
-
init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/__init__.py
DELETED
File without changes
|
raven_utils/models/attn.py
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
from __future__ import print_function
|
2 |
-
|
3 |
-
import tensorflow as tf
|
4 |
-
from tensorflow.keras import backend as K
|
5 |
-
from tensorflow.keras.layers import LSTMCell
|
6 |
-
from tensorflow.keras.models import Model
|
7 |
-
from tensorflow.keras.layers import Conv2D, Dense
|
8 |
-
from tensorflow.keras.losses import mse
|
9 |
-
from tensorflow.keras.models import clone_model
|
10 |
-
from tensorflow.layers.base import InputSpec, Layer
|
11 |
-
|
12 |
-
from models.dense import create_conv_model
|
13 |
-
from models.utils import broadcast
|
14 |
-
|
15 |
-
|
16 |
-
class ReflectionPadding2D(Layer):
    """Layer that reflect-pads axes 1 and 2 of a four-axis input.

    ``padding`` is unpacked in ``call`` as ``(w_pad, h_pad)``: axis 1 is
    padded by ``h_pad`` on each side and axis 2 by ``w_pad``.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        """Return the padded shape (written for "channels_last" input)."""
        w_pad, h_pad = self.padding
        # Mirror ``call``: axis 1 grows by h_pad and axis 2 by w_pad. The
        # original used the two entries the other way round, which disagreed
        # with ``call`` whenever the paddings differ.
        return (s[0], s[1] + 2 * h_pad, s[2] + 2 * w_pad, s[3])

    def call(self, x, mask=None):
        """Reflect-pad ``x``; ``mask`` is accepted but not used."""
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
29 |
-
|
30 |
-
|
31 |
-
class Conv2Ref(Layer):
    """Layer that reflect-pads axes 1 and 2 of a four-axis input.

    ``padding`` is unpacked in ``call`` as ``(w_pad, h_pad)``: axis 1 is
    padded by ``h_pad`` on each side and axis 2 by ``w_pad``.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        # The original called super(ReflectionPadding2D, self), which raises
        # TypeError because Conv2Ref does not derive from that class.
        super(Conv2Ref, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        """Return the padded shape (written for "channels_last" input)."""
        w_pad, h_pad = self.padding
        # Mirror ``call``: axis 1 grows by h_pad and axis 2 by w_pad.
        return (s[0], s[1] + 2 * h_pad, s[2] + 2 * w_pad, s[3])

    def call(self, x, mask=None):
        """Reflect-pad ``x``; ``mask`` is accepted but not used."""
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
44 |
-
|
45 |
-
|
46 |
-
class SegmentationNetwork(Model):
    """Two-convolution residual block: relu -> conv -> relu -> conv, plus input."""

    def __init__(self, filters=64, kernels=(3, 3)):
        # The original called super(RecAE, self), which raises TypeError
        # because this class does not derive from RecAE.
        super(SegmentationNetwork, self).__init__()
        # ``SAME`` was an undefined name in this module; Keras accepts the
        # string "same", which keeps the spatial size so the residual sum in
        # ``call`` lines up.
        self.conv_1 = Conv2D(filters, kernels, padding="same")
        self.conv_2 = Conv2D(filters, kernels, padding="same")

    def call(self, inputs):
        """Apply the block and add ``inputs`` back onto the result."""
        x = K.relu(inputs)
        x = self.conv_1(x)
        x = K.relu(x)
        x = self.conv_2(x)
        return x + inputs
|
59 |
-
|
60 |
-
|
61 |
-
class QueryNetwork(Model):
    """Two-layer dense residual block: relu -> dense -> relu -> dense, plus input."""

    def __init__(self, units=64):
        # The original called super(RecAE, self), which raises TypeError
        # because this class does not derive from RecAE.
        super(QueryNetwork, self).__init__()
        # Attribute names are kept as in the original although these are
        # Dense layers.
        self.conv_1 = Dense(units)
        self.conv_2 = Dense(units)

    def call(self, inputs):
        """Apply the block and add ``inputs`` back onto the result."""
        x = K.relu(inputs)
        x = self.conv_1(x)
        x = K.relu(x)
        x = self.conv_2(x)
        return x + inputs
|
74 |
-
|
75 |
-
|
76 |
-
class RecAE(Model):
    """Recurrent auto-encoder that rebuilds its input from four masked passes.

    Each step derives an attention mask from a query produced by the control
    LSTM cell, encodes and decodes the masked features, and adds the decoded
    image, weighted by the resized mask, to the reconstruction.
    """

    def __init__(self, head, bottle, decoder):
        super(RecAE, self).__init__()
        self.head = head
        self.bottle = bottle
        self.base = clone_model(bottle)
        self.decoder = decoder
        self.segmentation_network = SegmentationNetwork()
        self.query_network = QueryNetwork()
        self.control = LSTMCell(64)
        self.memory = LSTMCell(64)

    def call(self, inputs):
        """Return ``(full_image, masks)`` and register the reconstruction loss.

        ``masks`` holds one resized mask per step.
        """
        feature = self.head(inputs)
        segmentation = self.segmentation_network(feature)
        control_base = self.base(feature)
        batch = K.shape(inputs)[0]
        # The same random tensor is used for both entries of each cell state.
        h_c = [tf.random.normal([batch, self.control.units])] * 2
        # Sized from the memory cell it feeds (the original read
        # ``self.control.units``; both cells are built with 64 units).
        h_m = [tf.random.normal([batch, self.memory.units])] * 2
        shape = K.shape(feature)[:-1]
        full_image = tf.zeros(K.shape(inputs))
        masks = []
        # ``scope`` is the share of each position not yet claimed by a mask.
        scope = tf.ones(shape)[..., tf.newaxis]
        for _ in range(4):
            _, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
            query = self.query_network(h_c[0])
            attention = K.sigmoid(image_attention(segmentation, query))
            mask = attention * scope
            scope = scope - mask
            latent = self.bottle(feature * mask)
            decoded = self.decoder(latent)
            big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
            full_image += K.sigmoid(decoded) * big_mask
            _, h_m = self.memory(latent, h_m)
            masks.append(big_mask)
        self.add_loss(K.mean(mse(inputs, full_image)))
        return full_image, masks
|
122 |
-
|
123 |
-
|
124 |
-
# def image_attention(image, query, scale=True):
|
125 |
-
@tf.function
def image_attention(image, query):
    """Return the scaled dot product of ``query`` with every position of ``image``.

    ``query`` gets two new axes after its first, is multiplied with
    ``image`` and summed over the last axis (kept as size one); the result
    is divided by the square root of ``image``'s last dimension.
    """
    expanded_query = query[:, tf.newaxis, tf.newaxis, :]
    scores = K.sum(expanded_query * image, axis=-1, keepdims=True)
    scale = tf.sqrt(tf.cast(K.shape(image)[-1], dtype=float))
    scores /= scale
    return scores
|
131 |
-
|
132 |
-
|
133 |
-
class RecAE_2(Model):
    """Variant of RecAE whose masks come from a conv model over features and query.

    The last of the four steps takes whatever scope is left as its mask.
    Per-step losses are added on the mask's total and on its difference to
    the masks collected so far.
    """

    def __init__(self, head, bottle, decoder):
        super(RecAE_2, self).__init__()
        self.head = head
        self.bottle = bottle
        # self.base = clone_model(bottle)
        self.base = self.bottle
        self.decoder = decoder
        self.segmentation_network = create_conv_model((64, 64, 1))
        self.control = LSTMCell(64)
        self.memory = LSTMCell(64)

    def call(self, inputs):
        """Return ``(full_image, big_masks)`` and register the losses.

        ``big_masks`` holds one resized mask per step.
        """
        steps = 4
        feature = self.head(inputs)
        control_base = self.base(feature)
        batch = K.shape(inputs)[0]
        # The same random tensor is used for both entries of each cell state.
        h_c = [tf.random.normal([batch, self.control.units])] * 2
        h_m = [tf.random.normal([batch, self.control.units])] * 2
        shape = K.shape(feature)[:-1]
        full_image = tf.zeros(K.shape(inputs))
        big_masks = []
        masks = []
        # ``scope`` is the share of each position not yet claimed by a mask.
        scope = tf.ones(shape)[..., tf.newaxis]
        for i in range(steps):
            if i == steps - 1:
                # Final step: the mask is the remaining scope.
                mask = scope
            else:
                _, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
                query = broadcast(h_c[0], feature.shape[1:])
                log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
                attention = K.sigmoid(log_attention)
                mask = attention * scope
                scope = scope - mask
            # NOTE(review): the scraped source lost its indentation; this
            # append is assumed to sit at loop level -- confirm against the
            # original file.
            masks.append(mask)
            latent = self.bottle(feature * mask)
            decoded = self.decoder(latent)
            # Loss on how far the mask's total is from 1/steps of its element
            # count (``total`` replaces a local that shadowed builtin ``sum``).
            total = K.sum(tf.ones(K.shape(mask)))
            self.add_loss(K.abs((total / steps) - K.sum(mask)) / total)
            # Negative MSE against every collected mask, including this one.
            for m in masks:
                self.add_loss(K.mean(-mse(m, mask)))

            big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
            full_image += K.sigmoid(decoded) * big_mask
            _, h_m = self.memory(latent, h_m)
            big_masks.append(big_mask)
        self.add_loss(K.mean(mse(inputs, full_image)))
        return full_image, big_masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/attn2.py
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
from __future__ import print_function
|
2 |
-
|
3 |
-
import tensorflow as tf
|
4 |
-
from tensorflow.keras import backend as K
|
5 |
-
from tensorflow.keras.layers import LSTMCell
|
6 |
-
from tensorflow.keras.models import Model
|
7 |
-
from tensorflow.keras.layers import Conv2D, Dense
|
8 |
-
from tensorflow.keras.losses import mse
|
9 |
-
from tensorflow.keras.models import clone_model
|
10 |
-
from tensorflow.layers.base import InputSpec, Layer
|
11 |
-
|
12 |
-
from models.dense import create_conv_model
|
13 |
-
from models.utils import broadcast
|
14 |
-
|
15 |
-
|
16 |
-
class ReflectionPadding2D(Layer):
    """Layer that reflect-pads axes 1 and 2 of a four-axis input.

    ``padding`` is unpacked in ``call`` as ``(w_pad, h_pad)``: axis 1 is
    padded by ``h_pad`` on each side and axis 2 by ``w_pad``.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        """Return the padded shape (written for "channels_last" input)."""
        w_pad, h_pad = self.padding
        # Mirror ``call``: axis 1 grows by h_pad and axis 2 by w_pad. The
        # original used the two entries the other way round, which disagreed
        # with ``call`` whenever the paddings differ.
        return (s[0], s[1] + 2 * h_pad, s[2] + 2 * w_pad, s[3])

    def call(self, x, mask=None):
        """Reflect-pad ``x``; ``mask`` is accepted but not used."""
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
29 |
-
|
30 |
-
|
31 |
-
class Conv2Ref(Layer):
    """Reflect-pad the two middle axes of a 4-D input.

    ``padding`` is unpacked as ``(w_pad, h_pad)``. ``call`` pads axis 1 by
    ``h_pad`` and axis 2 by ``w_pad`` on both sides with ``tf.pad`` in
    ``'REFLECT'`` mode.

    NOTE(review): despite its name this class performs no convolution; it
    only pads. Confirm whether a convolution was meant to follow.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # The original called ``super(ReflectionPadding2D, self)``, which
        # raises TypeError because Conv2Ref is not a subclass of that class.
        super(Conv2Ref, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Return the shape produced by ``call`` for input shape ``s``.

        Assumes the "channels_last" configuration. Unknown (``None``)
        dimensions are propagated as ``None``.
        """
        w_pad, h_pad = self.padding
        # Must mirror ``call``: axis 1 grows by h_pad, axis 2 by w_pad.
        rows = None if s[1] is None else s[1] + 2 * h_pad
        cols = None if s[2] is None else s[2] + 2 * w_pad
        return (s[0], rows, cols, s[3])

    def call(self, x, mask=None):
        """Pad ``x`` by reflection; ``mask`` is accepted but not used."""
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
44 |
-
|
45 |
-
|
46 |
-
class SegmentationNetwork(Model):
    """Residual block of two ReLU-preactivated convolutions.

    ``call`` returns ``conv_2(relu(conv_1(relu(inputs)))) + inputs``, so the
    convolution output must have the same shape as the input. That requires
    ``filters`` to equal the number of input channels.
    """

    def __init__(self, filters=64, kernels=(3, 3)):
        # The original called ``super(RecAE, self)``, which raises TypeError
        # because this class is not a subclass of RecAE.
        super(SegmentationNetwork, self).__init__()
        # "same" padding keeps the spatial size unchanged; with the default
        # "valid" padding the residual addition in ``call`` could not match
        # the input's spatial dimensions.
        self.conv_1 = Conv2D(filters, kernels, padding="same")
        self.conv_2 = Conv2D(filters, kernels, padding="same")

    def call(self, inputs):
        """Apply the two convolutions and add the skip connection."""
        x = K.relu(inputs)
        x = self.conv_1(x)
        x = K.relu(x)
        x = self.conv_2(x)
        return x + inputs
|
59 |
-
|
60 |
-
|
61 |
-
class QueryNetwork(Model):
    """Residual block of two ReLU-preactivated dense layers.

    ``call`` returns ``dense_2(relu(dense_1(relu(inputs)))) + inputs``, so
    ``units`` must equal the size of the input's last axis for the residual
    addition to be valid.
    """

    def __init__(self, units=64):
        # The original called ``super(RecAE, self)``, which raises TypeError
        # because this class is not a subclass of RecAE.
        super(QueryNetwork, self).__init__()
        # Attribute names are kept as ``conv_*`` for compatibility even
        # though both layers are Dense.
        self.conv_1 = Dense(units)
        self.conv_2 = Dense(units)

    def call(self, inputs):
        """Apply the two dense layers and add the skip connection."""
        x = K.relu(inputs)
        x = self.conv_1(x)
        x = K.relu(x)
        x = self.conv_2(x)
        return x + inputs
|
74 |
-
|
75 |
-
|
76 |
-
class RecAE(Model):
    """Recurrent auto-encoder that reconstructs its input in ten masked steps.

    At every step a control LSTM cell produces a query, the query is scored
    against a segmentation feature map to obtain a mask, the masked features
    are encoded and decoded, and the decoded image (weighted by the up-sampled
    mask) is added to the running reconstruction.
    """

    def __init__(self, head, bottle, decoder):
        """Store the sub-models; ``base`` is a structural clone of ``bottle``."""
        super(RecAE, self).__init__()
        self.head = head
        self.bottle = bottle
        self.base = clone_model(bottle)
        self.decoder = decoder
        self.segmentation_network = SegmentationNetwork()
        self.query_network = QueryNetwork()
        self.control = LSTMCell(64)
        self.memory = LSTMCell(64)

    def call(self, inputs):
        """Return ``(full_image, masks)`` and register the reconstruction loss.

        ``full_image`` is the sum over steps of ``sigmoid(decoded) * mask``
        and ``masks`` is the list of per-step masks resized to the spatial
        size of ``inputs``. The mean of ``mse(inputs, full_image)`` is added
        through ``add_loss``.
        """
        feature = self.head(inputs)
        segmentation = self.segmentation_network(feature)
        control_base = self.base(feature)

        batch = K.shape(inputs)[0]
        # Both LSTM states start from random noise; the same tensor is used
        # for the hidden and the cell part of each state.
        h_c = [tf.random.normal([batch, self.control.units])] * 2
        # Size the memory state from the memory cell itself (the original
        # used the control cell's width; both are 64 here).
        h_m = [tf.random.normal([batch, self.memory.units])] * 2

        shape = K.shape(feature)[:-1]
        full_image = tf.zeros(K.shape(inputs))
        masks = []
        # ``scope`` is the part of the feature map not yet claimed by a mask.
        scope = tf.ones(shape)[..., tf.newaxis]
        for _ in range(10):
            _, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
            query = self.query_network(h_c[0])
            log_attention = image_attention(segmentation, query)
            # NOTE(review): ``image_attention`` keeps a trailing axis of size
            # one and softmax defaults to the last axis, so this appears to
            # yield all ones — confirm the intended normalisation axis.
            attention = K.softmax(log_attention)
            mask = attention * scope
            scope = scope - mask
            latent = self.bottle(feature * mask)
            decoded = self.decoder(latent)
            big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
            full_image += K.sigmoid(decoded) * big_mask
            _, h_m = self.memory(latent, h_m)
            masks.append(big_mask)
        self.add_loss(K.mean(mse(inputs, full_image)))
        return full_image, masks
|
122 |
-
|
123 |
-
|
124 |
-
# def image_attention(image, query, scale=True):
|
125 |
-
@tf.function
def image_attention(image, query):
    """Return scaled dot-product scores between ``query`` and ``image``.

    ``query`` is expanded with two new axes after the first one, multiplied
    with ``image`` and summed over the last axis (kept as a size-one axis).
    The result is divided by the square root of the size of ``image``'s last
    axis.
    """
    depth = tf.cast(K.shape(image)[-1], dtype=float)
    expanded_query = query[:, tf.newaxis, tf.newaxis, :]
    scores = K.sum(expanded_query * image, axis=-1, keepdims=True)
    return scores / tf.sqrt(depth)
|
131 |
-
|
132 |
-
|
133 |
-
class RecAE_2(Model):
    """Recurrent auto-encoder that reconstructs its input in four masked steps.

    The first three steps predict a sigmoid attention map from the features
    concatenated with a broadcast control state; the last step takes whatever
    scope is left. Each step's masked features are encoded, decoded and added
    (weighted by the up-sampled mask) to the running reconstruction.
    """

    def __init__(self, head, bottle, decoder):
        """Store the sub-models; ``base`` shares the ``bottle`` instance."""
        super(RecAE_2, self).__init__()
        self.head = head
        self.bottle = bottle
        # self.base = clone_model(bottle)
        self.base = self.bottle
        self.decoder = decoder
        self.segmentation_network = create_conv_model((64, 64, 1))
        self.control = LSTMCell(64)
        self.memory = LSTMCell(64)

    def call(self, inputs):
        """Return ``(full_image, big_masks)`` and register the losses.

        Added losses: per step, the normalised deviation of the mask sum from
        an equal share of the mask area and the negative MSE against the
        stored masks; once at the end, the mean reconstruction MSE.
        """
        steps = 4  # number of masks; the final one is the remaining scope
        feature = self.head(inputs)
        control_base = self.base(feature)

        batch = K.shape(inputs)[0]
        # Both LSTM states start from random noise; the same tensor is used
        # for the hidden and the cell part of each state.
        h_c = [tf.random.normal([batch, self.control.units])] * 2
        h_m = [tf.random.normal([batch, self.memory.units])] * 2

        shape = K.shape(feature)[:-1]
        full_image = tf.zeros(K.shape(inputs))
        big_masks = []
        masks = []
        scope = tf.ones(shape)[..., tf.newaxis]
        for i in range(steps):
            if i == steps - 1:
                mask = scope
            else:
                _, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
                query = broadcast(h_c[0], feature.shape[1:])
                log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
                attention = K.sigmoid(log_attention)
                mask = attention * scope
                scope = scope - mask
                # NOTE(review): confirm these two updates belong inside the
                # ``else`` branch; placing them after it gives the same total
                # loss and the same returned values.
                masks.append(mask)
            latent = self.bottle(feature * mask)
            decoded = self.decoder(latent)

            # Renamed from ``sum`` so the builtin is no longer shadowed.
            mask_area = K.sum(tf.ones(K.shape(mask)))
            self.add_loss(K.abs((mask_area / steps) - K.sum(mask)) / mask_area)
            # Negative MSE: rewards the current mask for differing from the
            # masks stored so far.
            for m in masks:
                self.add_loss(K.mean(-mse(m, mask)))

            big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
            full_image += K.sigmoid(decoded) * big_mask
            _, h_m = self.memory(latent, h_m)
            big_masks.append(big_mask)
        self.add_loss(K.mean(mse(inputs, full_image)))
        return full_image, big_masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/augment.py
DELETED
File without changes
|
raven_utils/models/body.py
DELETED
@@ -1,276 +0,0 @@
|
|
1 |
-
import itertools
|
2 |
-
|
3 |
-
import tensorflow as tf
|
4 |
-
from ml_utils import self_product, lw
|
5 |
-
|
6 |
-
from models_utils import DictModel, ListModel, Flat, bm, Base, Cat, Res, Flat2, conv, KERNEL_SIZE, FILTERS, SAME, \
|
7 |
-
Get, SM, bs, RELU, ACTIVATION, dense, bd, HardBlock, MaxBlock
|
8 |
-
import models_utils.ops as K
|
9 |
-
from models_utils import Merge, SoftBlock
|
10 |
-
from models_utils.build import build_multi_dense, build_multi_conv, build_conv_model, build_encoder
|
11 |
-
from tensorflow.keras.layers import Lambda, Dense
|
12 |
-
from tensorflow.keras.layers import Conv2D
|
13 |
-
|
14 |
-
from config.constant import MEMORY, CONTROL, LATENT, MERGE, CONCAT, INFERENCE, FLATTEN
|
15 |
-
from models_utils.config import config
|
16 |
-
|
17 |
-
|
18 |
-
class RavRes(Res):
    """Residual wrapper whose skip path is a slice of the first input.

    The skip connection takes ``inputs[0]`` from index ``latent`` onwards
    along its last axis and adds it to the wrapped model's output.
    """

    def __init__(self, model="v2", latent=256, act=RELU):
        # ``act`` is accepted for signature parity with the sibling classes
        # but is not used here.
        super().__init__(model=model)
        self.latent = latent

    def call(self, inputs):
        """Return ``model(inputs)`` plus the sliced first input."""
        transformed = self.model(inputs)
        skip = inputs[0][:, ..., self.latent:]
        return transformed + skip
|
25 |
-
|
26 |
-
|
27 |
-
# not working
|
28 |
-
class RavResConv(Res):
    """Residual wrapper whose skip path is a 1x1 convolution of the first input.

    Marked "not working" by the original author.
    """

    def __init__(self, model="v2", latent=256, act=RELU):
        super().__init__(model=model)
        self.latent = latent
        # Projection applied to the skip connection.
        self.conv = conv(latent, (1, 1), activation=act)

    def call(self, inputs):
        """Return ``model(inputs)`` plus the projected first input."""
        transformed = self.model(inputs)
        skip = self.conv(inputs[0])
        return transformed + skip
|
36 |
-
|
37 |
-
|
38 |
-
class RavResDense(Res):
    """Residual wrapper whose skip path is a dense projection of the first input."""

    def __init__(self, model="v2", latent=256, act=config.DEF_DENSE.activation):
        super().__init__(model=model)
        self.latent = latent
        # Kept under the name ``conv`` although it is a dense layer.
        self.conv = dense(latent, activation=act)

    def call(self, inputs):
        """Return ``model(inputs)`` plus the projected first input."""
        transformed = self.model(inputs)
        skip = self.conv(inputs[0])
        return transformed + skip
|
46 |
-
|
47 |
-
|
48 |
-
def create_dense_block(latent=256, loop=1):
    """Build the dense reasoning block as a ``ListModel`` of ``DictModel`` cells.

    The cells read ``LATENT`` and ``INFERENCE`` and the list model outputs
    ``MEMORY``. The cell list is repeated ``loop`` times by list
    multiplication, so every repetition refers to the same cell objects.
    """
    transform = SoftBlock(build_multi_dense(latent), add_identity=None,
                          score_activation=tf.sigmoid)
    soft_block = Res(transform, latent=latent)

    # Each entry is (callable, input key(s), output key).
    stages = [
        (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
        (None, CONCAT, MEMORY),
        (Dense(latent), CONCAT, MERGE),
        (Merge(latent), [INFERENCE, MERGE], CONTROL),
        (soft_block, [MEMORY, CONTROL], MEMORY),
    ]
    pipeline = [DictModel(*stage) for stage in stages]
    return ListModel(pipeline * loop, [LATENT, INFERENCE], MEMORY)
|
60 |
-
|
61 |
-
|
62 |
-
def build_multi_conv(filters=32, end_filters=64, padding="same", mul=1, norm=None, **kwargs):
    """Build a list of convolutional encoders with varied kernel sequences.

    Every encoder starts with a 1x1 convolution of ``filters`` channels and
    ends with a 1x1 convolution of ``end_filters`` channels without
    activation. In between sit the kernel sizes of one "plan"; plans are
    derived from ``self_product`` of three base kernels plus fixed stacks of
    3x3 kernels and two empty plans. The plan list is repeated ``mul`` times.
    Extra ``kwargs`` are merged into every middle layer's specification.
    """
    kernel_shapes = [(1, 3), (3, 1), (3, 3)]
    combos = list(self_product(kernel_shapes))
    extended = [c + c[0:1] for c in combos]
    doubled = [c + c for c in combos]
    square = (3, 3)
    stacked = [[square], [square, square], [square, square, square]] * 2
    empty = [[], []]
    kernel_plans = list(itertools.chain(combos, extended, doubled, stacked, empty))

    # The first and last layer specs are shared by every architecture.
    first_layer = {
        FILTERS: filters,
        KERNEL_SIZE: (1, 1),
    }
    last_layer = {
        FILTERS: end_filters,
        KERNEL_SIZE: (1, 1),
        ACTIVATION: None,
    }

    architectures = []
    for plan in kernel_plans:
        middle = [{FILTERS: filters, KERNEL_SIZE: kernel, **kwargs} for kernel in plan]
        architectures.append([first_layer] + middle + [last_layer])
    architectures = architectures * mul

    add_norm = norm if norm else None
    order = (1, 0) if norm else None
    encoders = []
    for idx, arch in enumerate(architectures):
        encoders.append(
            build_encoder(arch, add_norm=add_norm, padding=padding, name=f"b{idx}", order=order))
    return encoders
|
95 |
-
|
96 |
-
|
97 |
-
def create_block(latent=256, simpler=0, loop=1, padding=SAME, norm=None, trans_div=2, act="pass", type_="conv",
                 block_=SoftBlock, max_k=16,
                 **kwargs):
    """Build the reasoning block as a ``ListModel`` of ``DictModel`` cells.

    The returned model reads ``LATENT`` and ``INFERENCE`` and outputs
    ``MEMORY``. ``type_`` selects dense or convolutional sub-layers, ``act``
    selects the residual wrapper class, ``block_`` selects the transformation
    block class and ``simpler == 1`` selects a shorter first stage. One group
    of cells is appended per element of ``lw(loop)``. Extra ``kwargs`` are
    forwarded to ``build_multi_conv``.
    """
    trans_size = int(latent / trans_div)
    is_dense = type_ == "dense"

    # Residual wrapper used around the transformation block.
    if act == "pass":
        res_class = RavRes
    elif is_dense:
        res_class = RavResDense
    else:
        res_class = RavResConv

    # Factories for the layers that are created once per loop iteration.
    if is_dense:
        def build_res():
            return Res(model="dv2")

        def build_reduction():
            return dense(latent)

        def build_flatten():
            return bd([latent] * 2)
    else:
        def build_res():
            return Res(padding=padding)

        def build_reduction():
            return bm([conv(trans_size if simpler else latent, 1, padding=padding), "BN"])

        def build_flatten():
            return Flat2(filters=trans_size, padding=padding, units=64)

    if simpler == 1:
        cells = [
            (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT, "concatenation"),
            (build_reduction(), CONCAT, MERGE, "Start_resnet_block"),
            (K.cat, [INFERENCE, MERGE], CONTROL, "concatenation"),
        ]
    else:
        cells = [
            (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
            (build_reduction(), CONCAT, MEMORY),
            (build_reduction(), INFERENCE, CONTROL),
        ]

    for i, l in enumerate(lw(loop)):
        # NOTE(review): the extent of this guard is an assumption — confirm
        # whether only the control layers, or also the transformation block
        # below, should be rebuilt only when ``l`` is truthy.
        if l:
            concat = K.cat
            control_reduction = build_reduction()
            control_res = build_res()
            control_flatten = build_flatten()

        first_simple = i == 0 and simpler == 1
        if first_simple:
            rest_params = {"latent": latent, "act": act}
        else:
            rest_params = {"latent": 0}

        if block_ == SoftBlock:
            block_params = {}
        else:
            block_params = {"trans_output_shape": latent}
        if block_ == MaxBlock:
            block_params['max_k'] = max_k

        if is_dense:
            transformations = build_multi_dense(latent)
        else:
            transformations = build_multi_conv(trans_size, end_filters=latent,
                                               norm=norm, padding=padding,
                                               **kwargs)
        # todo change name
        soft_block = res_class(
            block_(
                transformations,
                add_identity=None,
                score_activation=tf.sigmoid,
                **block_params
            ),
            **rest_params)

        if first_simple:
            cells.extend([
                (control_reduction, CONTROL, CONTROL, "Reduction"),
                (control_res, CONTROL, CONTROL, "Control_resnet_block"),
                (control_flatten, CONTROL, FLATTEN, "Weights"),
                (soft_block, [CONCAT, FLATTEN], MEMORY, "Transformation"),
            ])
            continue

        if l:
            memory_res = build_res()
        cells.extend([
            (memory_res, MEMORY, MEMORY, "Memory_resnet_block"),
            (concat, [CONTROL, MEMORY], CONTROL, "concatenation"),
            (control_reduction, CONTROL, CONTROL, "Reduction"),
            (control_res, CONTROL, CONTROL, "Control_resnet_block"),
            (control_flatten, CONTROL, FLATTEN, "Weights"),
            (soft_block, [MEMORY, FLATTEN], MEMORY, "Transformation"),
        ])

    return ListModel([DictModel(*cell) for cell in cells], [LATENT, INFERENCE], MEMORY, debug_=False)
|
206 |
-
|
207 |
-
#
|
208 |
-
#
|
209 |
-
# def test(x):
|
210 |
-
# np.zeros(4)
|
211 |
-
# self_product((1, 3))
|
212 |
-
#
|
213 |
-
#
|
214 |
-
# list(itertools.product())
|
215 |
-
# u.layers[0].layers[-1].model.layers[1]
|
216 |
-
|
217 |
-
# class RecurrentBodyDict(Model):
|
218 |
-
# # def __init__(self, start=None, cell=None, output_network=None, output_activation="tanh", latent=64, loop_no=5):
|
219 |
-
# def __init__(self, start=None, cell=None, output_network=None, output_activation=None, latent=64, loop_no=5):
|
220 |
-
# super().__init__()
|
221 |
-
# self.start = sm(start, lambda: SubClassingModel([StartLSTMControl(latent), StartLSTMMemory(latent)]),
|
222 |
-
# latent=latent)
|
223 |
-
# self.cell = sm(cell, lambda: SubClassingModel([LSTMControl(latent), LSTMMemory(latent)]), latent=latent)
|
224 |
-
# self.output_network = sm(output_network, lf(take_memory_states))
|
225 |
-
# self.loop_no = loop_no
|
226 |
-
# # tmp
|
227 |
-
# self.activation = Activation(output_activation)
|
228 |
-
#
|
229 |
-
# def call(self, inputs):
|
230 |
-
# outputs = []
|
231 |
-
# for j in range(3):
|
232 |
-
# outputs.append(self.start({"latent": inputs[0][j], "inference": inputs[1]}))
|
233 |
-
# for i in range(self.loop_no):
|
234 |
-
# for j in range(3):
|
235 |
-
# outputs[j] = self.cell(outputs[j])
|
236 |
-
#
|
237 |
-
# return self.activation(self.output_network(outputs))
|
238 |
-
#
|
239 |
-
#
|
240 |
-
# class RecurrentBodySimpleMix4Dict(RecurrentBodyDict):
|
241 |
-
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
242 |
-
# super().__init__(
|
243 |
-
# start=SubClassingModel(
|
244 |
-
# [ConcatCell(), DenseCell(latent), InfMergeCell(latent),
|
245 |
-
# WeigthCell(latent, layer_no=np.repeat([1, 2, 3, 4, 5, 6, 7, 8], 4),
|
246 |
-
# add_identity=Lambda(lambda x: x[:, latent:]))]),
|
247 |
-
# cell=False,
|
248 |
-
# output_network=output_network, loop_no=0)
|
249 |
-
# class RecurrentBodySimpleMix4Conv(RecurrentBodyDict):
|
250 |
-
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
251 |
-
# super().__init__(
|
252 |
-
# start=SubClassingModel(
|
253 |
-
# [ConcatCell(), ConvCell(latent), ReduceCell(latent), InfMergeCell(latent),
|
254 |
-
# ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
|
255 |
-
# WeigthCell(latent,
|
256 |
-
# transformation_network=[build_conv_model2([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
|
257 |
-
# range(1, 5) for _ in range(1)],
|
258 |
-
# add_identity=Lambda(lambda x: x[:, ..., latent:]))
|
259 |
-
# ]),
|
260 |
-
# cell=False,
|
261 |
-
# output_network=output_network, loop_no=0)
|
262 |
-
#
|
263 |
-
#
|
264 |
-
# class RecurrentBodySimpleMix4Conv2(RecurrentBodyDict):
|
265 |
-
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
266 |
-
# super().__init__(
|
267 |
-
# start=SubClassingModel(
|
268 |
-
# [ConcatCell(), ConvCell(latent), ReduceCell2(latent), InfMergeCell(latent),
|
269 |
-
# ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
|
270 |
-
# WeigthCell(latent,
|
271 |
-
# transformation_network=[bc([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
|
272 |
-
# range(1, 5) for _ in range(1)],
|
273 |
-
# add_identity=Lambda(lambda x: x[:, ..., latent:]))
|
274 |
-
# ]),
|
275 |
-
# cell=False,
|
276 |
-
# output_network=output_network, loop_no=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/class_.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
from ml_utils import lw
|
2 |
-
from models_utils import SubClassingModel, ops as K, Base
|
3 |
-
import tensorflow as tf
|
4 |
-
|
5 |
-
|
6 |
-
class Merge(SubClassingModel):
    """Apply one sub-model per input, concatenate, then apply the last sub-model."""

    def call(self, inputs):
        """Return ``model[-1]`` applied to the concatenated branch outputs.

        Branch ``k`` (every entry of ``self.model`` except the last) is
        applied to ``inputs[k]``; the results are joined along the last axis.
        """
        branches = self.model[:-1]
        branch_outputs = [branch(inputs[idx]) for idx, branch in enumerate(branches)]
        # todo why K.cat not working
        joined = tf.concat(branch_outputs, axis=-1)
        return self.model[-1](joined)
|
14 |
-
|
15 |
-
|
16 |
-
class RavenClass(Base):
    """Run the wrapped model once per panel index over selected inputs.

    ``scales`` lists the indices of the inputs to use and ``no`` is the
    number of panel indices to iterate over.
    """

    def __init__(self, model, scales=None, no=3, name=None):
        super().__init__(model=model, name=name)
        self.scales = scales
        self.no = no

    def call(self, inputs):
        """Return a one-element list holding the ``no`` model outputs.

        For panel index ``i`` every selected input with more than two
        dimensions is sliced as ``input[:, i]``; lower-rank inputs are passed
        through unchanged.
        """
        inputs = lw(inputs)
        # ``scales=None`` used to fail when iterated; fall back to every
        # input (assumes ``lw`` returns a sized sequence — TODO confirm).
        scales = range(len(inputs)) if self.scales is None else self.scales
        class_res = []
        for i in range(self.no):
            # The original passed the whole ``inputs`` list in the fallback
            # branch instead of the selected element.
            d = [inputs[s][:, i] if inputs[s].ndim > 2 else inputs[s] for s in scales]
            class_res.append(self.model(d))
        # return tf.stack(class_res,axis=1)
        return [class_res]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/head.py
DELETED
@@ -1,159 +0,0 @@
|
|
1 |
-
import tensorflow as tf
|
2 |
-
from ml_utils import set_default
|
3 |
-
from models_utils import build_dense_model, bm, ActivationModel, sm, large_conv_dense_encoder, Pass
|
4 |
-
from models_utils import res
|
5 |
-
from tensorflow.keras import Model
|
6 |
-
from models_utils import ops as K
|
7 |
-
from tensorflow.keras.layers import Dense, Conv2D, Flatten
|
8 |
-
from keras.backend import batch_flatten
|
9 |
-
|
10 |
-
|
11 |
-
# todo Refactoring
|
12 |
-
class HeadModel(Model):
    """Hold the encoder, inference network and stem shared by head models.

    This class only stores configuration; subclasses implement ``call``.
    """

    def __init__(self, encoder=None, inference_network=None, output_size=64, inference_output_size=None,
                 inference_activation="relu", stem=None, images_no=8, inference_image_no=None):
        """Create default sub-models for every falsy argument.

        ``inference_output_size`` falls back to ``output_size`` and
        ``inference_image_no`` falls back to ``images_no`` when ``None``.
        """
        super().__init__()

        if not encoder:
            encoder = bm([large_conv_dense_encoder(), Dense(output_size)])
        self.encoder = encoder

        if not inference_output_size:
            inference_output_size = output_size
        if not inference_network:
            # NOTE(review): 1028 may be a typo for 1024 — confirm.
            dense_stack = build_dense_model([1028, 512, 512, inference_output_size],
                                            last_activation=inference_activation)
            inference_network = bm([K.flat, dense_stack])
        self.inference_network = inference_network

        if not stem:
            stem = Pass()
        self.stem = stem

        self.images_no = images_no
        if inference_image_no is None:
            inference_image_no = self.images_no
        self.inference_image_no = inference_image_no
|
28 |
-
|
29 |
-
|
30 |
-
class LatentHeadModel(HeadModel):
    """Head that encodes each image and derives an inference vector."""

    def call(self, inputs):
        """Return ``[latents, inference, encoded]``.

        The first ``images_no`` entries along axis 1 are encoded through
        ``K.map_batch``; the inference network sees the first
        ``inference_image_no`` encodings and the stem sees all of them.
        """
        selected = inputs[:, :self.images_no]
        encoded = K.map_batch(selected, self.encoder)
        inference = self.inference_network(encoded[:, :self.inference_image_no])
        latents = self.stem(encoded)
        return [latents, inference, encoded]
|
36 |
-
|
37 |
-
|
38 |
-
# # todo use map_batch
|
39 |
-
# class HeadBatch(Model):
|
40 |
-
# def __init__(self, encoder=None, output_size=64):
|
41 |
-
# super().__init__()
|
42 |
-
# self.encoder = sm(encoder, bm([large_conv_dense_encoder(), Dense(output_size)], False))
|
43 |
-
#
|
44 |
-
# def call(self, inputs):
|
45 |
-
# shape = tf.shape(inputs)
|
46 |
-
# latents = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
|
47 |
-
# latents = K.reshape(latents, tf.concat([[-1, shape[1]], latents.shape[1:]], axis=-1))
|
48 |
-
# return latents
|
49 |
-
|
50 |
-
|
51 |
-
# Not working
|
52 |
-
class DuoHeadModel(HeadModel):
    """Head that takes latents and inference input from two encoder activations.

    Marked "not working" by the original author; this version only fixes the
    indexing error described in ``call``.
    """

    def __init__(self, encoder=None, inference_network=None, images_no=8, filters=-4):
        super().__init__(encoder=encoder, inference_network=inference_network, images_no=images_no)
        # Wrap the encoder so that it also exposes intermediate activations.
        self.encoder = ActivationModel(self.encoder, filters=filters, include_input=False)

    def call(self, inputs):
        """Return ``[latents, inference]`` from the reversed encoder outputs."""
        shape = inputs.shape
        flat_images = K.reshape(inputs, shape=[-1] + list(shape[2:]))
        # ``reversed`` returns an iterator, which cannot be indexed; the
        # original subscripted it directly and raised TypeError.
        result = list(reversed(self.encoder(flat_images)))
        latents = K.reshape(result[0], [-1, self.images_no] + [result[0].shape[-1]])
        inference = self.inference_network(K.flat(result[1]))
        return [latents, inference]
|
63 |
-
|
64 |
-
|
65 |
-
class MultiHeadModel(Model):
    """Head that exposes three encoder activations and a merged inference vector."""

    def __init__(self, encoder=None, images_no=8, filters=(1, 3, 6)):
        super().__init__()
        self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
        self.merge = MergeSacles()
        self.images_no = images_no

    def call(self, inputs):
        """Return ``[latents, inference]``.

        ``inputs`` is flattened over its first two axes before encoding; each
        encoder output is then reshaped to put ``images_no`` back on axis 1.
        The first two latents have the image axis folded into the last axis,
        the third is flattened over its last two axes, and all three are
        passed to ``self.merge``.
        """

        def fold_images_into_channels(tensor):
            # Move the image axis next to the last axis and merge the two.
            moved = tf.transpose(tensor, (0, 2, 3, 1, 4))
            dims = tf.shape(moved)
            target = tf.concat([[-1], dims[1:3], [dims[-2] * dims[-1]]], axis=-1)
            return tf.reshape(moved, target)

        def merge_last_two_axes(tensor):
            dims = tf.shape(tensor)
            return tf.reshape(tensor, tf.concat([[-1], [dims[-2] * dims[-1]]], axis=-1))

        in_shape = tf.shape(inputs)
        flat_images = tf.reshape(inputs, shape=tf.concat([[-1], in_shape[2:]], axis=-1))
        results = self.encoder(flat_images)

        latents = []
        for result in results:
            grouped = tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)
            latents.append(tf.reshape(result, shape=grouped))

        l1 = fold_images_into_channels(latents[0])
        l2 = fold_images_into_channels(latents[1])
        l3 = merge_last_two_axes(latents[2])

        inference = self.merge([l1, l2, l3])
        return [latents, inference]
|
95 |
-
|
96 |
-
|
97 |
-
class MergeSacles(Model):
    """Reduce three inputs to 256 features each and concatenate them."""

    def __init__(self):
        super().__init__()
        self.inf_1 = bm([Conv2D(64, 1, activation="relu"), res(64),
                         # The original referenced ``SAME``, which this module
                         # never imports; the Keras literal is used instead
                         # (presumably the same value — confirm).
                         Conv2D(64, 3, strides=2, padding="same", activation="relu"),
                         res(64),
                         Flatten(),
                         Dense(256, "relu")])
        self.inf_2 = bm([Conv2D(128, 1, activation="relu"),
                         res(128),
                         Flatten(),
                         Dense(256, "relu")])
        self.inf_3 = Dense(256, "relu")

    def call(self, inputs):
        """Return the three branch outputs concatenated along axis 1."""
        il1 = self.inf_1(inputs[0])
        il2 = self.inf_2(inputs[1])
        il3 = self.inf_3(inputs[2])
        inference = tf.concat([il1, il2, il3], axis=1)
        return inference
|
117 |
-
|
118 |
-
|
119 |
-
class MultiHeadModel2(Model):
    """Head that exposes two encoder activations and a merged inference vector."""

    def __init__(self, encoder=None, images_no=8, filters=(3, 6)):
        super().__init__()
        self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
        self.merge = MergeSacles2()
        self.images_no = images_no

    def call(self, inputs):
        """Return ``[latents, inference]``.

        ``inputs`` is flattened over its first two axes before encoding; each
        encoder output is then reshaped to put ``images_no`` back on axis 1.
        The first latent has the image axis folded into the last axis, the
        second is flattened over its last two axes, and both are passed to
        ``self.merge``.
        """
        in_shape = tf.shape(inputs)
        flat_images = tf.reshape(inputs, shape=tf.concat([[-1], in_shape[2:]], axis=-1))
        results = self.encoder(flat_images)

        latents = []
        for result in results:
            grouped = tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)
            latents.append(tf.reshape(result, shape=grouped))

        # Move the image axis next to the last axis and merge the two.
        spatial = tf.transpose(latents[0], (0, 2, 3, 1, 4))
        spatial_dims = tf.shape(spatial)
        spatial = tf.reshape(
            spatial,
            tf.concat([[-1], spatial_dims[1:3], [spatial_dims[-2] * spatial_dims[-1]]], axis=-1))

        vector = latents[1]
        vector_dims = tf.shape(vector)
        vector = tf.reshape(vector, tf.concat([[-1], [vector_dims[-2] * vector_dims[-1]]], axis=-1))

        inference = self.merge([spatial, vector])
        return [latents, inference]
|
144 |
-
|
145 |
-
|
146 |
-
class MergeSacles2(Model):
    """Reduce two inputs to 256 features each and concatenate them."""

    def __init__(self):
        super().__init__()
        conv_branch = [
            Conv2D(128, 1, activation="relu"),
            res(128),
            Flatten(),
            Dense(256, "relu"),
        ]
        self.inf_1 = bm(conv_branch)
        self.inf_2 = Dense(256, "relu")

    def call(self, inputs):
        """Return the two branch outputs concatenated along axis 1."""
        first = self.inf_1(inputs[0])
        second = self.inf_2(inputs[1])
        return tf.concat([first, second], axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/loss.py
DELETED
@@ -1,630 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
|
3 |
-
import tensorflow as tf
|
4 |
-
import tensorflow.experimental.numpy as tnp
|
5 |
-
from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
|
6 |
-
from models_utils import SubClassingModel
|
7 |
-
from models_utils.models.utils import interleave
|
8 |
-
from models_utils.op import reshape
|
9 |
-
from tensorflow.keras import Model
|
10 |
-
# from tensorflow.keras import backend as K
|
11 |
-
from tensorflow.keras.layers import Lambda
|
12 |
-
from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
|
13 |
-
from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
|
14 |
-
import models_utils.ops as K
|
15 |
-
|
16 |
-
import raven_utils.decode
|
17 |
-
import raven_utils as rv
|
18 |
-
from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
|
19 |
-
SLOT, \
|
20 |
-
PROPERTIES, ACC, GROUP, NUMBER, MASK
|
21 |
-
from raven_utils.models.uitls_ import RangeMask
|
22 |
-
from raven_utils.const import VERTICAL, HORIZONTAL
|
23 |
-
|
24 |
-
|
25 |
-
def get_properties_mask(target):
    """Return a boolean mask marking the positive entries of the rule columns.

    The columns are those of ``target`` from ``rv.target.END_INDEX`` up to
    (not including) ``rv.target.UNIFORMITY_INDEX``.
    """
    first = rv.target.END_INDEX
    last = rv.target.UNIFORMITY_INDEX
    rule_columns = target[:, first:last]
    return rule_columns > 0
|
27 |
-
|
28 |
-
|
29 |
-
def create_change_mask(target):
    """Return one mask per attribute, built from the properties mask."""
    changed = get_properties_mask(target)
    masks = []
    for position, _attribute in enumerate(rv.rules.ATTRIBUTES):
        masks.append(create_mask(changed, position))
    return masks
|
32 |
-
|
33 |
-
|
34 |
-
def create_uniform_mask(target):
    """Return one mask per attribute, combining uniformity and properties masks.

    For each of the two columns starting at ``rv.target.UNIFORMITY_INDEX`` a
    flag "value equals 3" is tiled ``rv.rules.ATTRIBUTES_LEN`` times; the two
    tiled flags are concatenated and OR-ed with the properties mask.
    """

    def uniformity_flag(offset):
        column = target[:, rv.target.UNIFORMITY_INDEX + offset, None]
        return tf.tile(column == 3, [1, rv.rules.ATTRIBUTES_LEN])

    uniform = tf.concat([uniformity_flag(0), uniformity_flag(1)], axis=-1)
    combined = uniform | get_properties_mask(target)
    masks = []
    for position, _attribute in enumerate(rv.rules.ATTRIBUTES):
        masks.append(create_mask(combined, position))
    return masks
|
38 |
-
|
39 |
-
|
40 |
-
def create_all_mask(target):
    """Return one all-True mask per attribute.

    Each mask has shape ``(batch, rv.entity.SUM)``, where ``batch`` is the
    size of ``target``'s first axis.
    """
    masks = []
    for _attribute in rv.rules.ATTRIBUTES:
        mask_shape = tf.stack([tf.shape(target)[0], rv.entity.SUM])
        masks.append(tf.cast(tf.ones(mask_shape), dtype=tf.bool))
    return masks
|
44 |
-
|
45 |
-
|
46 |
-
class BaselineClassificationLossModel(Model):
    """Compute the classification loss and similarity metric for the baseline."""

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
        super().__init__()
        self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
        self.loss_fn = ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                       group_loss=group_loss)
        self.metric_fn = SimilarityRaven(mode=mode)

    def call(self, inputs):
        """Return ``[loss, metric]``.

        The loss is computed from ``inputs[0][0]`` and ``inputs[1]``; the
        metric from ``inputs[0][2]``, ``inputs[3][0]`` and the columns of
        ``inputs[0][1]`` from index 8 onwards.
        """
        data = inputs[0]
        prediction = inputs[1]
        loss = self.loss_fn([data[0], prediction])
        metric = self.metric_fn([data[2], inputs[3][0], data[1][:, 8:]])
        return [loss, metric]
|
60 |
-
|
61 |
-
|
62 |
-
class RavenLoss(Model):
    """Collect the losses and metrics for the RAVEN model outputs.

    ``trans`` enables the losses on the three entries of ``OUTPUT`` and the
    similarity metric; ``classification`` enables a per-panel classification
    loss. ``lw`` holds the weights for the main and the classification loss.
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
                 classification=False, trans=True, anneal=False):
        super().__init__()
        if anneal:
            # The original evaluated a bare ``self.weight_scheduler``, an
            # attribute this class never defines, so ``anneal=True`` failed
            # with an opaque AttributeError. Fail with an explicit message.
            raise NotImplementedError("RavenLoss: weight annealing (anneal=True) is not implemented")
        self.classification = classification
        self.trans = trans
        self.predict_fn = DictModel(SubClassingModel([lambda x: x[-1], PredictModel()]), in_=OUTPUT,
                                    out=[PREDICT, MASK], name="pred")
        if self.trans:
            # NOTE(review): here the weight goes to ClassRavenModel, while
            # the classification loss below passes it to add_loss — confirm
            # which is intended.
            self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                                    group_loss=group_loss, enable_metrics=False, lw=lw[0]),
                                    name="main_loss")
            self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                                      group_loss=group_loss), name="add_loss")
            self.metric_fn = SimilarityRaven(mode=mode)
        if self.classification:
            self.loss_fn_3 = add_loss(
                ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss, enable_metrics="c" if self.trans else True), lw=lw[1],
                name="class_loss")

    def call(self, inputs):
        """Return ``inputs`` extended with a ``LOSS`` entry listing all results.

        With ``trans``: losses for ``labels[:, 2]``/``output[0]``,
        ``labels[:, 5]``/``output[1]`` and ``target``/``output[2]``, followed
        by the similarity metric. With ``classification``: one loss for each
        of the first eight label columns against ``inputs[CLASSIFICATION]``.
        """
        losses = []
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]

        if self.trans:
            losses.append(self.loss_fn([labels[:, 2], output[0]]))
            losses.append(self.loss_fn([labels[:, 5], output[1]]))
            losses.append(self.loss_fn_2([target, output[2]]))
            losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
        if self.classification:
            for i in range(8):
                losses.append(self.loss_fn_3([labels[:, i], inputs[CLASSIFICATION][i]]))
        return {**inputs, LOSS: losses}
|
100 |
-
|
101 |
-
|
102 |
-
class VTRavenLoss(Model):
    """Loss/metric wrapper for the VT model operating on dict-style inputs.

    ``lw[0]`` weights the main loss on the last panel, ``lw[1]`` the optional
    per-panel classification loss (only built when ``classification`` is true).
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
                 classification=False, trans=True, anneal=False, plw=None):
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute access with no assignment; presumably a
            # placeholder for a weight scheduler — confirm intent.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
                                    out=[PREDICT, MASK], name="pred")
        self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                                group_loss=group_loss, plw=plw), lw=lw[0], name="add_loss")
        self.metric_fn = SimilarityRaven(mode=mode)
        if self.classification:
            self.loss_fn_2 = add_loss(
                ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")

    def call(self, inputs):
        """Return ``inputs`` extended with a ``LOSS`` entry holding the collected losses."""
        losses = []
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]

        # ``loss_fn_2`` only exists when classification is enabled; without this
        # guard the default configuration failed with AttributeError.
        if self.classification:
            for i in range(9):
                losses.append(self.loss_fn_2([labels[:, i], output[:, i]]))
        losses.append(self.loss_fn([target, output[:, 8]]))
        losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
        return {**inputs, LOSS: losses}
|
131 |
-
|
132 |
-
|
133 |
-
class SingleVTRavenLoss(Model):
    """Loss/metric wrapper for a model that emits a single output panel."""

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
                 classification=False, trans=True, anneal=False):
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute access with no assignment; presumably a
            # placeholder for a weight scheduler — confirm intent.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
        main_loss = ClassRavenModel(
            mode=mode,
            number_loss=number_loss,
            slot_loss=slot_loss,
            group_loss=group_loss,
        )
        self.loss_fn = add_loss(main_loss, lw=lw[0], name="add_loss")
        self.metric_fn = SimilarityRaven(mode=mode)

    def call(self, inputs):
        """Return ``inputs`` extended with a ``LOSS`` entry: ``[main loss, similarity metrics]``."""
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]

        main = self.loss_fn([target, output])
        similarity = self.metric_fn([inputs[INDEX], inputs[PREDICT], labels])
        return {**inputs, LOSS: [main, similarity]}
|
155 |
-
|
156 |
-
|
157 |
-
class ClassRavenModel(Model):
    """Computes the group, number, slot and property losses for one target/output pair.

    ``plw`` holds six per-term weights in the order used by ``call``:
    group, number, slot, then one per property output. ``lw`` scales every term.

    NOTE(review): block nesting was reconstructed from a listing whose
    indentation was lost — confirm against the original file.
    """

    def __init__(self, mode=create_all_mask,plw=None, number_loss=False, slot_loss=True, group_loss=True, enable_metrics=True,
                 lw=1.0):
        super().__init__()
        self.number_loss = number_loss
        self.group_loss = group_loss
        self.enable_metrics = enable_metrics
        self.slot_loss = slot_loss
        self.predict_fn = PredictModel()
        self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
        if self.slot_loss:
            self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
        if self.enable_metrics:
            # From here on ``enable_metrics`` is a string prefix ("" or "<name>_");
            # ``call`` treats "is a str" as "metrics are on".
            if isinstance(self.enable_metrics, str):
                self.enable_metrics = f"{self.enable_metrics}_"
            else:
                self.enable_metrics = ""
            self.metric_fn = [
                SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{property_}")
                for property_ in rv.properties.NAMES
            ]
            if self.group_loss:
                self.metric_fn_group = SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{GROUP}")
            if self.slot_loss:
                self.metric_fn_2 = BinaryAccuracy(name=f"{self.enable_metrics}{ACC}_{SLOT}")
        self.range_mask = RangeMask()
        self.mode = mode
        self.lw = lw

        default_plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
        if not plw:
            plw = default_plw
        elif isinstance(plw, (int, float)):
            # A scalar overrides only the number-loss weight.
            plw = [default_plw[0], plw, *default_plw[2:]]
        self.plw = plw

    def call(self, inputs):
        """Return the list of weighted loss terms for ``inputs = [target, output]``."""
        losses = []
        target = inputs[0]
        output = inputs[1]
        metrics_on = isinstance(self.enable_metrics, str)
        acc_name = f"{self.enable_metrics}{ACC}"
        acc_no_group_name = f"{self.enable_metrics}{ACC}_NO_{GROUP}"

        target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
        group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)

        # --- group ---
        if self.group_loss:
            losses.append(self.lw * self.plw[0] * self.loss_fn(target_group, group_output))

            if metrics_on:
                group_metric = self.metric_fn_group(target_group, group_output)
                self.add_metric(group_metric)
                self.add_metric(tf.reduce_sum(group_metric), acc_name)

        # Per-attribute masks selected by ``mode`` and the mask derived from the target group.
        full_properties_musks = self.mode(target)
        range_mask = self.range_mask(target_group)

        if self.slot_loss:
            # --- number ---
            number_mask = tf.cast(range_mask & full_properties_musks[0], tf.float32)
            target_number = tf.reduce_sum(tf.cast(target_slot, "float32") * number_mask, axis=-1)
            hard_slots = tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32")
            output_number = tf.reduce_sum(hard_slots * number_mask, axis=-1)

            if self.number_loss:
                scale = 1 / 9
                if self.number_loss == 2:
                    # Mode 2 uses the soft (sigmoid) count instead of the thresholded one.
                    output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
                else:
                    output_number_2 = output_number
                number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale,
                                                          output_number_2 * scale)
                losses.append(number_loss)

            if metrics_on:
                number_acc = tf.reduce_mean(
                    tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
                self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
                self.add_metric(tf.reduce_sum(number_acc), acc_name)
                self.add_metric(tf.reduce_sum(number_acc), acc_no_group_name)

            # --- position/slot ---
            slot_mask = range_mask & full_properties_musks[1]

            if tf.reduce_any(slot_mask):
                target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
                output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
                loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
                    self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
                if metrics_on:
                    acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
                    self.add_metric(acc_slot)
                    self.add_metric(tf.reduce_sum(acc_slot), acc_name)
                    self.add_metric(tf.reduce_sum(acc_slot), acc_no_group_name)
            else:
                loss_slot = 0.0
                # NOTE(review): never read afterwards; presumably kept so both
                # branches assign the same names — confirm before removing.
                acc_slot = -1.0

            losses.append(loss_slot)

        # --- properties ---
        for i, out in enumerate(outputs):
            out_reshaped = tf.reshape(out, (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i]))
            # Only entities present in the target and selected by the attribute mask contribute.
            properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]

            if tf.reduce_any(properties_mask):
                out_masked = tf.boolean_mask(out_reshaped, properties_mask)
                out_target = tf.boolean_mask(target_all[i], properties_mask)
                loss = self.lw * self.plw[3+i] * self.loss_fn(out_target, out_masked)
                if metrics_on:
                    metric = self.metric_fn[i](out_target, out_masked)
                    self.add_metric(metric)
                    self.add_metric(tf.reduce_sum(metric), acc_name)
                    self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
                    self.add_metric(tf.reduce_sum(metric), acc_no_group_name)
            else:
                loss = 0.0
                # NOTE(review): never read afterwards; see the slot branch above.
                metric = -1.0

            losses.append(loss)
        return losses
|
296 |
-
|
297 |
-
|
298 |
-
class FullMask(Model):
    """Produces the slot, per-property and number masks for a target tensor."""

    def __init__(self, mode=create_uniform_mask):
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    def call(self, inputs):
        """Return ``[slot_mask, properties_mask, number_mask]`` for the target ``inputs``."""
        target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
        full_properties_musks = self.mode(inputs)
        range_mask = self.range_mask(target_group)

        number_mask = range_mask & full_properties_musks[0]
        slot_mask = range_mask & full_properties_musks[1]

        # Property masks are restricted to entities present in the target.
        present = [tf.cast(target_slot, "bool") & property_mask
                   for property_mask in full_properties_musks[2:]]
        return [slot_mask, present, number_mask]
|
316 |
-
|
317 |
-
|
318 |
-
def create_mask(rules, i):
    """Expand rule columns ``i`` and ``i + 5`` of ``rules`` into a (batch, ``rv.entity.SUM``) mask.

    Column ``i`` is written to the rows listed in ``rv.target.FIRST_LAYOUT`` and
    column ``i + 5`` to those in ``rv.target.SECOND_LAYOUT``; the scatter is done
    entity-major and transposed at the end.
    """
    first_layout = rv.target.FIRST_LAYOUT
    second_layout = rv.target.SECOND_LAYOUT

    first_values = tf.tile(rules[:, i][None], [len(first_layout), 1])
    second_values = tf.tile(rules[:, i + 5][None], [len(second_layout), 1])
    batch = tf.shape(rules)[0]

    scattered = tf.scatter_nd(tnp.array(first_layout)[:, None], first_values, shape=(rv.entity.SUM, batch))
    scattered = tf.tensor_scatter_nd_update(scattered, tnp.array(second_layout)[:, None], second_values)
    return tf.transpose(scattered)
|
325 |
-
|
326 |
-
|
327 |
-
# class PredictModel(Model):
|
328 |
-
# def __init__(self):
|
329 |
-
# super().__init__()
|
330 |
-
# self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
|
331 |
-
# self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
332 |
-
# self.range_mask = RangeMask()
|
333 |
-
#
|
334 |
-
# # self.predict_fn = partial(tf.argmax, axis=-1)
|
335 |
-
#
|
336 |
-
# def call(self, inputs):
|
337 |
-
# group_output = inputs[rv.OUTPUT_GROUP_SLICE]
|
338 |
-
# group_loss = self.predict_fn(group_output)[:, None]
|
339 |
-
#
|
340 |
-
# output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
|
341 |
-
# range_mask = self.range_mask(group_loss[:, 0])
|
342 |
-
# loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
|
343 |
-
#
|
344 |
-
# properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
|
345 |
-
# properties = []
|
346 |
-
# outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
|
347 |
-
# for i, out in enumerate(outputs):
|
348 |
-
# shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
349 |
-
# out_reshaped = tf.reshape(out, shape)
|
350 |
-
# properties.append(self.predict_fn(out_reshaped))
|
351 |
-
# number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
|
352 |
-
#
|
353 |
-
# result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
|
354 |
-
#
|
355 |
-
# return [result, range_mask, range_mask, range_mask, range_mask]
|
356 |
-
|
357 |
-
class PredictModel(Model):
    """Turns raw model output into a flat prediction vector plus a range mask."""

    def __init__(self):
        super().__init__()
        self.predict_fn = Predict()
        self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
        self.range_mask = RangeMask()

    def call(self, inputs):
        """Return ``[result, range_mask]``.

        ``result`` concatenates, along the last axis: the predicted group, the
        thresholded slots (as int64), the interleaved property predictions and
        the slot count.
        """
        decoded = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
        group_output, output_slot, *properties = decoded
        number_loss = K.int64(K.sum(output_slot))

        pieces = [
            group_output[:, None],
            tf.cast(output_slot, dtype=tf.int64),
            interleave(properties),
            number_loss[:, None],
        ]
        result = tf.concat(pieces, axis=-1)

        range_mask = self.range_mask(group_output)
        return [result, range_mask]
|
375 |
-
# return [result, range_mask, range_mask, range_mask, range_mask]
|
376 |
-
|
377 |
-
|
378 |
-
# todo change slices
|
379 |
-
class PredictModelMasked(Model):
    """Prediction head that masks slot logits by the group's range mask.

    The input layout read by ``call`` is: slot logits in the first
    ``rv.ENTITY_SUM`` columns, group logits in the last ``rv.GROUPS_NO``
    columns, property logits in between.
    """

    def __init__(self):
        super().__init__()
        self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
        # ``call`` reads ``predict_fn_2``; the layer was previously stored only as
        # ``loss_fn_2``, which made ``call`` fail with AttributeError.
        self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
        # Old attribute name kept as an alias for backward compatibility.
        self.loss_fn_2 = self.predict_fn_2
        self.range_mask = RangeMask()

    def call(self, inputs):
        """Return the concatenated group, slot, interleaved property and number predictions."""
        group_output = inputs[:, -rv.GROUPS_NO:]
        group_loss = self.predict_fn(group_output)[:, None]

        output_slot = inputs[:, :rv.ENTITY_SUM]
        range_mask = self.range_mask(group_loss[:, 0])
        # Cast so the product is valid whatever dtype RangeMask returns
        # (presumably bool — TODO confirm); a no-op when dtypes already match.
        masked_slot = output_slot * tf.cast(range_mask, output_slot.dtype)
        loss_slot = tf.cast(self.predict_fn_2(masked_slot), dtype=tf.int64)

        properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]

        properties = []
        outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
        for i, out in enumerate(outputs):
            shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
            out_reshaped = tf.reshape(out, shape)
            # Zero the logits of entities predicted absent; cast the int64 slot
            # mask to the logits' dtype so the multiplication is well-typed.
            out_masked = out_reshaped * tf.cast(loss_slot[..., None], out_reshaped.dtype)
            properties.append(self.predict_fn(out_masked))
        number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)

        result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)

        return result
|
411 |
-
|
412 |
-
|
413 |
-
def final_predict_mask(x, mask):
    """Select, via a ragged boolean mask, the property triples of ``x[0]``.

    The columns from ``rv.INDEX[0]`` up to (excluding) the last one are
    reshaped with ``[-1, 3]`` before masking.
    """
    property_columns = x[0][:, rv.INDEX[0]:-1]
    triples = reshape(property_columns, [-1, 3])
    return tf.ragged.boolean_mask(triples, mask)
|
416 |
-
|
417 |
-
|
418 |
-
def final_predict(x, mode=False):
    """Return the masked property predictions of ``x``.

    Args:
        x: sequence whose first item is the prediction tensor and whose second
            item is a mask (used when ``mode`` is true).
        mode: when true use ``x[1]`` as the mask; otherwise derive it from the
            slot columns ``1:rv.INDEX[0]`` of the prediction.
    """
    m = x[1] if mode else tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
    # ``final_predict_mask`` indexes ``x[0]`` itself, so it must receive the
    # whole sequence; passing ``x[0]`` unwrapped the prediction twice.
    return final_predict_mask(x, m)
|
421 |
-
|
422 |
-
|
423 |
-
def final_predict_2(x):
    """Ragged-mask ``x[0]`` with ``x[1]`` tiled four times, keeping the first and last columns."""
    batch = tf.shape(x[0])[0]
    always_on = tf.cast(tf.ones(batch), tf.bool)[:, None]
    tiled = tf.tile(x[1], [1, 4])
    mask = tf.concat([always_on, tiled, always_on], axis=-1)
    return tf.ragged.boolean_mask(x[0], mask)
|
427 |
-
|
428 |
-
|
429 |
-
class PredictModelOld(Model):
    """Legacy prediction head; returns interleaved int8 argmax results of the property outputs."""

    def call(self, inputs):
        output = inputs[-2]
        rest_output = output[:, :-rv.GROUPS_NO]

        # NOTE(review): the ``-3`` axis/shape values below look unusual for this
        # layout (other heads in this file use ``-1``) — confirm before relying on this class.
        split_sizes = list(rv.ENTITY_PROPERTIES_INDEX.values())
        chunks = tf.split(rest_output, split_sizes, axis=-3)

        result_all = []
        for i, chunk in enumerate(chunks):
            reshaped = tf.reshape(chunk, (-3, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i]))
            result_all.append(tf.cast(tf.argmax(reshaped, axis=-3), dtype="int8"))

        return interleave(result_all)
|
447 |
-
|
448 |
-
|
449 |
-
def get_matches(diff, target_index):
    """Compare the lowest-difference candidates with ``target_index``.

    Returns a tuple ``(matches, more_matches, once_matches)``:
    ``matches`` — whether the target is among the candidates tied for the
    smallest summed difference; ``more_matches`` — whether more than one
    candidate is tied; ``once_matches`` — ``K.sum`` over the rows that match
    without a tie.
    """
    totals = K.sum(diff)
    order = tf.argsort(totals, axis=-1)
    ordered = tf.sort(totals)

    # Candidates whose total equals the smallest total of their row.
    tied = ordered[:, 0, None] == ordered
    tied_indices = tf.where(tied, order, -1 * tf.ones_like(order))

    matched_index = tied_indices == target_index
    # setting shape needed for TensorFlow graph
    matched_index.set_shape(tied_indices.shape)

    matches = K.any(matched_index)
    more_matches = K.sum(tied) > 1
    once_matches = K.sum(matches & tf.math.logical_not(more_matches))
    return matches, more_matches, once_matches
|
462 |
-
|
463 |
-
|
464 |
-
class SimilarityRaven(Model):
    """Similarity metrics between a prediction and the candidate answers.

    ``call`` expects ``[INDEX, PREDICT, LABELS]`` and registers three metrics:
    ``ACC_SAME``, ``ACC_CHOOSE_UPPER`` and ``ACC_CHOOSE_LOWER``.
    """

    def __init__(self, mode=create_all_mask, number_loss=False):
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    def call(self, inputs):
        """Return ``[acc_same, acc_choose_upper, acc_choose_lower]``."""
        # Answers start at position 8 of the labels, so the index is shifted to match.
        target_index = inputs[0] - 8
        predict = inputs[1]
        answers = inputs[2][:, 8:]
        batch = tf.shape(predict)[0]

        target = K.gather(answers, target_index[:, 0])
        target_group = target[:, 0]

        # Only the columns between the group and ``END_INDEX`` are compared.
        end = rv.target.END_INDEX
        target_comp = target[:, 1:end]
        predict_comp = predict[:, 1:end]
        answers_comp = answers[:, :, 1:end]

        full_properties_musks = self.mode(target)
        fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])

        range_mask = self.range_mask(target_group)
        full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)

        final_mask = fpm & full_range_mask

        target_masked = target_comp * final_mask
        predict_masked = predict_comp * final_mask
        answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])

        metrics = []

        acc_same = K.mean(K.all(target_masked == predict_masked))
        self.add_metric(acc_same, ACC_SAME)
        metrics.append(acc_same)

        diff = tf.abs(predict_masked[:, None] - answers_masked)
        mismatches = tf.cast(diff != 0, dtype=tf.int32)

        # First phase: rank answers by the number of differing entries.
        matches, more_matches, once_matches = get_matches(mismatches, target_index)

        # Second phase: rows that matched but were tied are re-ranked on the raw differences.
        second_phase_mask = (more_matches & matches)
        diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
        target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)

        matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
        matches_2_no = K.sum(matches_2)

        acc_choose_upper = (once_matches + matches_2_no) / batch
        self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
        metrics.append(acc_choose_upper)

        acc_choose_lower = (once_matches + once_matches_2) / batch
        self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
        metrics.append(acc_choose_lower)

        return metrics
|
526 |
-
|
527 |
-
|
528 |
-
class SimilarityRaven2(Model):
    """Variant of the similarity metrics that also returns the masked mismatch count.

    ``call`` expects ``[INDEX, PREDICT, LABELS]`` and registers ``ACC_SAME``,
    ``ACC_CHOOSE_UPPER`` and ``ACC_CHOOSE_LOWER``.
    """

    def __init__(self, mode=create_all_mask, number_loss=False):
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    def call(self, inputs):
        """Return ``[acc_same, acc_choose_upper, acc_choose_lower, mismatch_sum]``."""
        # Answers start at position 8 of the labels, so the index is shifted to match.
        target_index = inputs[0] - 8
        predict = inputs[1]
        answers = inputs[2][:, 8:]
        batch = tf.shape(predict)[0]

        target = K.gather(answers, target_index[:, 0])
        target_group = target[:, 0]

        # Only the columns between the group and ``END_INDEX`` are compared.
        end = rv.target.END_INDEX
        target_comp = target[:, 1:end]
        predict_comp = predict[:, 1:end]
        answers_comp = answers[:, :, 1:end]

        full_properties_musks = self.mode(target)
        fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])

        range_mask = self.range_mask(target_group)
        full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)

        final_mask = fpm & full_range_mask

        target_masked = target_comp * final_mask
        predict_masked = predict_comp * final_mask
        answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])

        metrics = []

        acc_same = K.mean(K.all(target_masked == predict_masked))
        self.add_metric(acc_same, ACC_SAME)
        metrics.append(acc_same)

        diff = tf.abs(predict_masked[:, None] - answers_masked)
        mismatches = tf.cast(diff != 0, dtype=tf.int32)

        # First phase: rank answers by the number of differing entries.
        matches, more_matches, once_matches = get_matches(mismatches, target_index)

        # Second phase: rows that matched but were tied are re-ranked on the raw differences.
        second_phase_mask = (more_matches & matches)
        diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
        target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)

        matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
        matches_2_no = K.sum(matches_2)

        acc_choose_upper = (once_matches + matches_2_no) / batch
        self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
        metrics.append(acc_choose_upper)

        acc_choose_lower = (once_matches + once_matches_2) / batch
        self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
        metrics.append(acc_choose_lower)

        # Extra output of this variant: reduction over the masked target/prediction mismatches.
        metrics.append(K.sum(target_masked != predict_masked))

        return metrics
|
592 |
-
|
593 |
-
|
594 |
-
class LatentLossModel(Model):
    """MSE loss between selected latent vectors and the model output, plus a latent accuracy metric."""

    def __init__(self, dir_=HORIZONTAL):
        super().__init__()
        self.metric_fn = Accuracy(name="acc_latent")
        # Latent positions used as targets: (6, 7) for VERTICAL, (2, 5) otherwise.
        self.dir = (6, 7) if dir_ == VERTICAL else (2, 5)

    def call(self, inputs):
        """Add the latent loss and accuracy metric to the model and return the loss."""
        target_image = tf.reshape(inputs[0][2], [-1])
        output = inputs[1]
        latents = tnp.asarray(inputs[2])

        context_latents = latents[:, self.dir]
        # Answers sit after the first 8 positions, hence the ``+ 8`` offset.
        rows = tf.range(latents.shape[0])
        answer_latent = latents[rows, target_image + 8][:, None]
        target_hor = tf.concat([context_latents, answer_latent], axis=1)

        # The latents act as fixed targets: no gradient flows into them.
        loss_hor = mse(K.stop_gradient(target_hor), output)
        self.add_loss(loss_hor)

        self.add_metric(self.metric_fn(inputs[3], target_image))

        return loss_hor
|
623 |
-
|
624 |
-
|
625 |
-
class PredRav(Model):
    """Picks the answer whose summed absolute difference to the last output is smallest."""

    def call(self, inputs):
        last_output = inputs[0][:, -1]
        answers = inputs[1][:, 8:]
        distances = tf.reduce_sum(tf.abs(last_output[:, None] - answers), axis=-1)
        return tf.argmin(distances, axis=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/loss_3.py
DELETED
@@ -1,638 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
|
3 |
-
import tensorflow as tf
|
4 |
-
import tensorflow.experimental.numpy as tnp
|
5 |
-
from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
|
6 |
-
from models_utils import SubClassingModel
|
7 |
-
from models_utils.models.utils import interleave
|
8 |
-
from models_utils.op import reshape
|
9 |
-
from tensorflow.keras import Model
|
10 |
-
# from tensorflow.keras import backend as K
|
11 |
-
from tensorflow.keras.layers import Lambda
|
12 |
-
from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
|
13 |
-
from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
|
14 |
-
import models_utils.ops as K
|
15 |
-
|
16 |
-
import raven_utils.decode
|
17 |
-
import raven_utils as rv
|
18 |
-
from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
|
19 |
-
SLOT, \
|
20 |
-
PROPERTIES, ACC, GROUP, NUMBER, MASK
|
21 |
-
from raven_utils.models.uitls_ import RangeMask
|
22 |
-
from raven_utils.const import VERTICAL, HORIZONTAL
|
23 |
-
|
24 |
-
|
25 |
-
def get_properties_mask(target):
    """Return a boolean mask marking which rule columns of ``target`` are positive.

    The inspected columns are ``rv.target.END_INDEX`` (inclusive) up to
    ``rv.target.UNIFORMITY_INDEX`` (exclusive) along the second axis.
    """
    rule_columns = target[:, rv.target.END_INDEX:rv.target.UNIFORMITY_INDEX]
    return rule_columns > 0
|
27 |
-
|
28 |
-
|
29 |
-
def create_change_mask(target):
    """Build one entity mask per entry of ``rv.rules.ATTRIBUTES`` from the positive rule columns."""
    changed = get_properties_mask(target)
    masks = []
    for attribute_index, _ in enumerate(rv.rules.ATTRIBUTES):
        masks.append(create_mask(changed, attribute_index))
    return masks
|
32 |
-
|
33 |
-
|
34 |
-
def create_uniform_mask(target):
    """Build one entity mask per attribute, combining uniformity flags with the rule mask.

    A uniformity column (``rv.target.UNIFORMITY_INDEX`` and the one after it)
    equal to 3 switches on all ``rv.rules.ATTRIBUTES_LEN`` columns of its half;
    the result is OR-ed with :func:`get_properties_mask`.
    """

    def _uniformity_columns(offset):
        # Keep the column axis (``None``) so the flag can be tiled across attributes.
        flag = target[:, rv.target.UNIFORMITY_INDEX + offset, None] == 3
        return tf.tile(flag, [1, rv.rules.ATTRIBUTES_LEN])

    uniform = tf.concat([_uniformity_columns(0), _uniformity_columns(1)], axis=-1)
    properties_mask = uniform | get_properties_mask(target)

    masks = []
    for attribute_index, _ in enumerate(rv.rules.ATTRIBUTES):
        masks.append(create_mask(properties_mask, attribute_index))
    return masks
|
38 |
-
|
39 |
-
|
40 |
-
def create_all_mask(target):
    """Return, per entry of ``rv.rules.ATTRIBUTES``, an all-True mask of shape (batch, ``rv.entity.SUM``).

    The batch size is read dynamically from the first axis of ``target``.
    """
    masks = []
    for _ in rv.rules.ATTRIBUTES:
        mask_shape = tf.stack([tf.shape(target)[0], rv.entity.SUM])
        masks.append(tf.cast(tf.ones(mask_shape), dtype=tf.bool))
    return masks
|
44 |
-
|
45 |
-
|
46 |
-
class BaselineClassificationLossModel(Model):
    """Loss/metric wrapper for the baseline classifier operating on positional inputs."""

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
        super().__init__()
        take_first = lambda x: x[0]
        self.predict_fn = SubClassingModel([take_first, PredictModel()])
        self.loss_fn = ClassRavenModel(
            mode=mode,
            number_loss=number_loss,
            slot_loss=slot_loss,
            group_loss=group_loss,
        )
        self.metric_fn = SimilarityRaven(mode=mode)

    def call(self, inputs):
        """Return ``[classification losses, similarity metrics]`` for ``inputs``.

        ``inputs[0]`` supplies the target (``[0]``), labels (``[1]``) and index
        (``[2]``); ``inputs[1]`` is the raw output and ``inputs[3][0]`` the prediction.
        """
        batch = inputs[0]
        output = inputs[1]
        class_losses = self.loss_fn([batch[0], output])
        similarity = self.metric_fn([batch[2], inputs[3][0], batch[1][:, 8:]])
        return [class_losses, similarity]
|
60 |
-
|
61 |
-
|
62 |
-
class RavenLoss(Model):
    """Collects RAVEN losses and similarity metrics for dict-style inputs.

    NOTE(review): the nesting of the ``if self.trans`` bodies was reconstructed
    from a listing whose indentation was lost — confirm against the original file.
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
                 classification=False, trans=True, anneal=False):
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute access with no assignment; presumably a
            # placeholder for a weight scheduler — confirm intent.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans

        take_last = lambda x: x[-1]
        self.predict_fn = DictModel(
            SubClassingModel([take_last, PredictModel()]),
            in_=OUTPUT,
            out=[PREDICT, MASK],
            name="pred",
        )

        if self.trans:
            self.loss_fn = add_loss(
                ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss, enable_metrics=False, lw=lw[0]),
                name="main_loss",
            )
            self.loss_fn_2 = add_loss(
                ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss),
                name="add_loss",
            )
            self.metric_fn = SimilarityRaven(mode=mode)

        if self.classification:
            # Metrics get the "c" prefix only when the transformer losses are also active.
            metrics_flag = "c" if self.trans else True
            self.loss_fn_3 = add_loss(
                ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss, enable_metrics=metrics_flag),
                lw=lw[1],
                name="class_loss",
            )

    def call(self, inputs):
        """Return ``inputs`` extended with a ``LOSS`` entry holding the collected losses."""
        collected = []
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]

        if self.trans:
            # Panels 2 and 5 are compared with the first two outputs, the target with the third.
            for panel, out in ((2, output[0]), (5, output[1])):
                collected.append(self.loss_fn([labels[:, panel], out]))
            collected.append(self.loss_fn_2([target, output[2]]))
            collected.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))

        if self.classification:
            per_panel = inputs[CLASSIFICATION]
            for panel in range(8):
                collected.append(self.loss_fn_3([labels[:, panel], per_panel[panel]]))

        return {**inputs, LOSS: collected}
|
100 |
-
|
101 |
-
|
102 |
-
class VTRavenLoss(Model):
    """Loss/metric wrapper for the VT model that splits the loss by ``inputs[MASK]``.

    Positions selected by the mask go through the main loss (weight ``lw[0]``);
    the remaining positions go through the optional classification loss
    (weight ``lw[1]``), which is only built when ``classification`` is true.
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(2.0, 1.0),
                 classification=False, trans=True, anneal=False, plw=None):
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute access with no assignment; presumably a
            # placeholder for a weight scheduler — confirm intent.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
                                    out=[PREDICT, "predict_mask"], name="pred")
        self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                                group_loss=group_loss, plw=plw), lw=lw[0], name="add_loss")
        self.metric_fn = SimilarityRaven(mode=mode)
        if self.classification:
            self.loss_fn_2 = add_loss(
                ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
                                group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")

    def call(self, inputs):
        """Return ``inputs`` extended with a ``LOSS`` entry holding the collected losses."""
        losses = []
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]
        mask = inputs[MASK]

        target_masked = target[mask]
        output_masked = output[mask]
        losses.append(self.loss_fn([target_masked, output_masked]))

        # ``loss_fn_2`` only exists when classification is enabled; without this
        # guard the default configuration failed with AttributeError.
        if self.classification:
            target_unmasked = target[~mask]
            output_unmasked = output[~mask]
            losses.append(self.loss_fn_2([target_unmasked, output_unmasked]))

        losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
        return {**inputs, LOSS: losses}
|
137 |
-
|
138 |
-
|
139 |
-
class SingleVTRavenLoss(Model):
    """Single-target variant of the vision-transformer RAVEN loss.

    ``call`` scores the whole ``OUTPUT`` against ``TARGET`` with one loss model
    and adds the similarity metric; the result is stored under ``LOSS``.

    Args:
        mode: Mask factory forwarded to the loss and similarity sub-models.
        number_loss, slot_loss, group_loss: Forwarded to ``ClassRavenModel``.
        lw: Loss weights; only ``lw[0]`` is used here.
        classification, trans: Stored on the instance; not read by this class.
        anneal: See the review note in ``__init__``.
    """

    def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
                 classification=False, trans=True, anneal=False):
        super().__init__()
        if anneal:
            # NOTE(review): bare attribute access with no assignment visible in this
            # class; presumably an unfinished stub that raises AttributeError unless a
            # base class provides `weight_scheduler` - confirm before using anneal=True.
            self.weight_scheduler
        self.classification = classification
        self.trans = trans
        self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
        class_loss = ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
                                     group_loss=group_loss)
        self.loss_fn = add_loss(class_loss, lw=lw[0], name="add_loss")
        self.metric_fn = SimilarityRaven(mode=mode)

    def call(self, inputs):
        """Return ``inputs`` extended with ``LOSS``: ``[loss, similarity metrics]``."""
        output = inputs[OUTPUT]
        target = inputs[TARGET]
        labels = inputs[LABELS]

        main_loss = self.loss_fn([target, output])
        similarity = self.metric_fn([inputs[INDEX], inputs[PREDICT], labels])
        return {**inputs, LOSS: [main_loss, similarity]}
|
161 |
-
|
162 |
-
|
163 |
-
class ClassRavenModel(Model):
    """Weighted classification losses (and optional accuracies) for one panel.

    ``call`` receives ``[target, output]``, decodes both with
    ``raven_utils.decode`` and returns a list of loss terms: the group loss,
    optionally the number loss, the slot loss and one loss per property.

    Args:
        mode: Callable building the per-property masks from the target.
        plw: Per-term loss weights (group, number, slot, then one per property).
            ``None``/falsy selects the built-in defaults; a single number only
            replaces the number-loss weight.
        number_loss: Falsy disables the number loss; ``2`` uses the soft
            (sigmoid) count, any other truthy value the thresholded count.
        slot_loss: Enable the number/slot section.
        group_loss: Enable the group loss.
        enable_metrics: Falsy disables metrics; a string is used as metric-name
            prefix; any other truthy value enables metrics without a prefix.
        lw: Global weight multiplied into every loss term.
    """

    def __init__(self, mode=create_all_mask, plw=None, number_loss=False, slot_loss=True, group_loss=True,
                 enable_metrics=True,
                 lw=1.0):
        super().__init__()
        self.number_loss = number_loss
        self.group_loss = group_loss
        self.enable_metrics = enable_metrics
        self.slot_loss = slot_loss
        self.predict_fn = PredictModel()
        self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
        if self.slot_loss:
            self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
        if self.enable_metrics:
            # From here on `enable_metrics` is a string (the metric-name prefix);
            # `call` relies on that to decide whether metrics are active.
            prefix = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
            self.enable_metrics = prefix
            self.metric_fn = [
                SparseCategoricalAccuracy(name=f"{prefix}{ACC}_{property_}")
                for property_ in rv.properties.NAMES
            ]
            # NOTE(review): nesting reconstructed - the group/slot metrics are assumed
            # to be created only when metrics are enabled.
            if self.group_loss:
                self.metric_fn_group = SparseCategoricalAccuracy(name=f"{prefix}{ACC}_{GROUP}")
            if self.slot_loss:
                self.metric_fn_2 = BinaryAccuracy(name=f"{prefix}{ACC}_{SLOT}")
        self.range_mask = RangeMask()
        self.mode = mode
        self.lw = lw

        default_plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
        if not plw:
            plw = default_plw
        elif isinstance(plw, (int, float)):
            # A scalar only overrides the number-loss weight.
            plw = [default_plw[0], plw, *default_plw[2:]]
        self.plw = plw

    def call(self, inputs):
        """Return the list of weighted loss terms for ``inputs = [target, output]``."""
        target = inputs[0]
        output = inputs[1]
        with_metrics = isinstance(self.enable_metrics, str)
        prefix = self.enable_metrics
        losses = []

        target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
        group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)

        # --- group ---------------------------------------------------------
        if self.group_loss:
            losses.append(self.lw * self.plw[0] * self.loss_fn(target_group, group_output))

            if with_metrics:
                group_metric = self.metric_fn_group(target_group, group_output)
                self.add_metric(group_metric)
                self.add_metric(tf.reduce_sum(group_metric), f"{prefix}{ACC}")

        # Masks selecting which entries take part in each loss term.
        full_properties_musks = self.mode(target)
        range_mask = self.range_mask(target_group)

        if self.slot_loss:
            # --- number ----------------------------------------------------
            number_mask = tf.cast(range_mask & full_properties_musks[0], tf.float32)
            target_number = tf.reduce_sum(
                tf.cast(target_slot, "float32") * number_mask, axis=-1)
            output_number = tf.reduce_sum(
                tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)

            if self.number_loss:
                scale = 1 / 9
                if self.number_loss == 2:
                    # Soft count: sum of slot probabilities.
                    output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
                else:
                    output_number_2 = output_number
                number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale,
                                                          output_number_2 * scale)
                losses.append(number_loss)

            # NOTE(review): nesting reconstructed - the number accuracy is assumed to be
            # reported whenever metrics are on, independent of `number_loss`.
            if with_metrics:
                number_acc = tf.reduce_mean(
                    tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
                self.add_metric(tf.reduce_sum(number_acc), f"{prefix}{ACC}_{NUMBER}")
                self.add_metric(tf.reduce_sum(number_acc), f"{prefix}{ACC}")
                self.add_metric(tf.reduce_sum(number_acc), f"{prefix}{ACC}_NO_{GROUP}")

            # --- position / slot --------------------------------------------
            slot_mask = range_mask & full_properties_musks[1]

            if tf.reduce_any(slot_mask):
                target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
                output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
                loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
                    self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
                if with_metrics:
                    acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
                    self.add_metric(acc_slot)
                    self.add_metric(tf.reduce_sum(acc_slot), f"{prefix}{ACC}")
                    self.add_metric(tf.reduce_sum(acc_slot), f"{prefix}{ACC}_NO_{GROUP}")
            else:
                # Nothing selected by the mask: contribute a zero loss.
                loss_slot = 0.0
                acc_slot = -1.0

            losses.append(loss_slot)

        # --- properties ------------------------------------------------------
        for i, out in enumerate(outputs):
            out_reshaped = tf.reshape(out, (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i]))
            properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]

            if tf.reduce_any(properties_mask):
                out_masked = tf.boolean_mask(out_reshaped, properties_mask)
                out_target = tf.boolean_mask(target_all[i], properties_mask)
                loss = self.lw * self.plw[3 + i] * self.loss_fn(out_target, out_masked)
                if with_metrics:
                    metric = self.metric_fn[i](out_target, out_masked)
                    self.add_metric(metric)
                    self.add_metric(tf.reduce_sum(metric), f"{prefix}{ACC}")
                    self.add_metric(tf.reduce_sum(metric), f"{prefix}{ACC}_{PROPERTIES}")
                    self.add_metric(tf.reduce_sum(metric), f"{prefix}{ACC}_NO_{GROUP}")
            else:
                loss = 0.0
                metric = -1.0

            losses.append(loss)
        return losses
|
304 |
-
|
305 |
-
|
306 |
-
class FullMask(Model):
    """Build the slot, per-property and number masks for a target encoding.

    Args:
        mode: Callable returning the list of rule masks for the target; index 0
            is combined into the number mask, index 1 into the slot mask and the
            remaining entries into the property masks.
    """

    def __init__(self, mode=create_uniform_mask):
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    def call(self, inputs):
        """Return ``[slot_mask, properties_mask, number_mask]`` for ``inputs``."""
        target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
        rule_masks = self.mode(inputs)
        in_range = self.range_mask(target_group)

        number_mask = in_range & rule_masks[0]
        slot_mask = in_range & rule_masks[1]
        # A property only counts where the target actually has an entity.
        properties_mask = [tf.cast(target_slot, "bool") & rule_mask for rule_mask in rule_masks[2:]]
        return [slot_mask, properties_mask, number_mask]
|
324 |
-
|
325 |
-
|
326 |
-
def create_mask(rules, i):
    """Spread column ``i`` (and ``i + 5``) of ``rules`` over the entity slots.

    Column ``i`` is written to the slots listed in ``rv.target.FIRST_LAYOUT``
    and column ``i + 5`` to those in ``rv.target.SECOND_LAYOUT``; the result is
    transposed so the first axis matches the first axis of ``rules``.
    """
    first_slots = tnp.array(rv.target.FIRST_LAYOUT)[:, None]
    second_slots = tnp.array(rv.target.SECOND_LAYOUT)[:, None]

    # One copy of the rule column per slot of the corresponding layout.
    first_values = tf.tile(rules[:, i][None], [len(rv.target.FIRST_LAYOUT), 1])
    second_values = tf.tile(rules[:, i + 5][None], [len(rv.target.SECOND_LAYOUT), 1])

    rows = tf.shape(rules)[0]
    scattered = tf.scatter_nd(first_slots, first_values, shape=(rv.entity.SUM, rows))
    scattered = tf.tensor_scatter_nd_update(scattered, second_slots, second_values)
    return tf.transpose(scattered)
|
333 |
-
|
334 |
-
|
335 |
-
# class PredictModel(Model):
|
336 |
-
# def __init__(self):
|
337 |
-
# super().__init__()
|
338 |
-
# self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
|
339 |
-
# self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
340 |
-
# self.range_mask = RangeMask()
|
341 |
-
#
|
342 |
-
# # self.predict_fn = partial(tf.argmax, axis=-1)
|
343 |
-
#
|
344 |
-
# def call(self, inputs):
|
345 |
-
# group_output = inputs[rv.OUTPUT_GROUP_SLICE]
|
346 |
-
# group_loss = self.predict_fn(group_output)[:, None]
|
347 |
-
#
|
348 |
-
# output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
|
349 |
-
# range_mask = self.range_mask(group_loss[:, 0])
|
350 |
-
# loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
|
351 |
-
#
|
352 |
-
# properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
|
353 |
-
# properties = []
|
354 |
-
# outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
|
355 |
-
# for i, out in enumerate(outputs):
|
356 |
-
# shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
357 |
-
# out_reshaped = tf.reshape(out, shape)
|
358 |
-
# properties.append(self.predict_fn(out_reshaped))
|
359 |
-
# number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
|
360 |
-
#
|
361 |
-
# result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
|
362 |
-
#
|
363 |
-
# return [result, range_mask, range_mask, range_mask, range_mask]
|
364 |
-
|
365 |
-
class PredictModel(Model):
    """Turn raw model output into a discrete prediction vector plus range mask."""

    def __init__(self):
        super().__init__()
        self.predict_fn = Predict()
        # Slots count as present when their sigmoid probability exceeds 0.5.
        self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
        self.range_mask = RangeMask()

    def call(self, inputs):
        """Return ``[result, range_mask]``.

        ``result`` concatenates, along the last axis: the predicted group, the
        slot predictions as int64, the interleaved property predictions and the
        slot count.
        """
        decoded = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
        group_output, output_slot, *properties = decoded
        number_loss = K.int64(K.sum(output_slot))

        pieces = [
            group_output[:, None],
            tf.cast(output_slot, dtype=tf.int64),
            interleave(properties),
            number_loss[:, None],
        ]
        result = tf.concat(pieces, axis=-1)
        return [result, self.range_mask(group_output)]
|
383 |
-
# return [result, range_mask, range_mask, range_mask, range_mask]
|
384 |
-
|
385 |
-
|
386 |
-
# todo change slices
|
387 |
-
class PredictModelMasked(Model):
    """Decode raw output into discrete predictions, masking by range and slot.

    The input is sliced as ``[slots | properties | group]`` along the last
    axis using ``rv.ENTITY_SUM`` and ``rv.GROUPS_NO``.
    """

    def __init__(self):
        super().__init__()
        self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
        # `call` reads `predict_fn_2`; the layer used to be stored only as
        # `loss_fn_2`, which made every call fail with AttributeError.
        self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
        # Kept as an alias for code that still refers to the old attribute name.
        self.loss_fn_2 = self.predict_fn_2
        self.range_mask = RangeMask()

    def call(self, inputs):
        """Return the concatenated ``[group, slots, properties, number]`` prediction."""
        group_output = inputs[:, -rv.GROUPS_NO:]
        group_loss = self.predict_fn(group_output)[:, None]

        output_slot = inputs[:, :rv.ENTITY_SUM]
        range_mask = self.range_mask(group_loss[:, 0])
        # Cast the mask to the logits' dtype: the mask is combined with `&`
        # elsewhere in this file, so it is presumably not a float tensor.
        masked_slot = output_slot * tf.cast(range_mask, output_slot.dtype)
        loss_slot = tf.cast(self.predict_fn_2(masked_slot), dtype=tf.int64)

        properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]

        properties = []
        outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
        for i, out in enumerate(outputs):
            shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
            out_reshaped = tf.reshape(out, shape)
            # Zero the logits of empty slots; the int64 slot prediction is cast
            # so the product has matching dtypes.
            out_masked = out_reshaped * tf.cast(loss_slot[..., None], out_reshaped.dtype)
            properties.append(self.predict_fn(out_masked))
        number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)

        result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)

        return result
|
419 |
-
|
420 |
-
|
421 |
-
def final_predict_mask(x, mask):
    """Reshape the property columns of ``x[0]`` into triples and keep masked rows.

    The columns from ``rv.INDEX[0]`` up to (excluding) the last one are reshaped
    to ``[-1, 3]`` and filtered with ``mask`` into a ragged tensor.
    """
    property_columns = x[0][:, rv.INDEX[0]:-1]
    triples = reshape(property_columns, [-1, 3])
    return tf.ragged.boolean_mask(triples, mask)
|
424 |
-
|
425 |
-
|
426 |
-
def final_predict(x, mode=False):
    """Select the property triples of a prediction using a slot mask.

    With ``mode`` truthy the mask is ``x[1]``; otherwise it is derived from the
    slot columns ``1:rv.INDEX[0]`` of ``x[0]`` cast to bool.
    """
    if mode:
        mask = x[1]
    else:
        mask = tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
    # NOTE(review): `final_predict_mask` indexes its first argument with `[0]` again,
    # so passing `x[0]` here looks like a double index - confirm against callers.
    return final_predict_mask(x[0], mask)
|
429 |
-
|
430 |
-
|
431 |
-
def final_predict_2(x):
    """Filter ``x[0]`` with ``x[1]`` tiled four times, keeping first and last column.

    The column mask is ``[True, x[1] repeated 4x, True]`` per row; the result is
    a ragged tensor.
    """
    rows = tf.shape(x[0])[0]
    always_keep = tf.cast(tf.ones(rows), tf.bool)[:, None]
    column_mask = tf.concat([always_keep, tf.tile(x[1], [1, 4]), always_keep], axis=-1)
    return tf.ragged.boolean_mask(x[0], column_mask)
|
435 |
-
|
436 |
-
|
437 |
-
class PredictModelOld(Model):
    """Legacy decoder: argmax over every property group of the raw output."""

    def call(self, inputs):
        """Return the interleaved int8 property predictions.

        The raw output is taken from ``inputs[-2]``; the trailing
        ``rv.GROUPS_NO`` columns are dropped before decoding.
        """
        output = inputs[-2]

        rest_output = output[:, :-rv.GROUPS_NO]

        result_all = []
        # The split axis, the reshape wildcard and the argmax axis were all `-3`.
        # `tf.reshape` only accepts `-1` as the inferred dimension, so the call
        # could never run; `-1` matches the other decoders in this file.
        outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
        for i, out in enumerate(outputs):
            shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
            out_reshaped = tf.reshape(out, shape)

            result = tf.cast(tf.argmax(out_reshaped, axis=-1), dtype="int8")
            result_all.append(result)

        result_all = interleave(result_all)
        return result_all
|
455 |
-
|
456 |
-
|
457 |
-
def get_matches(diff, target_index):
    """Check whether the target answer is among the best-scoring candidates.

    Returns:
        A tuple ``(matches, more_matches, once_matches)``:
        ``matches`` - the target index is among the candidates tied for the
        smallest summed difference; ``more_matches`` - more than one candidate
        shares that smallest value; ``once_matches`` - sum of the cases that
        match without a tie.
    """
    totals = K.sum(diff)
    order = tf.argsort(totals, axis=-1)
    ordered_totals = tf.sort(totals)

    # Candidates whose total equals the smallest total of their row.
    is_best = ordered_totals[:, 0, None] == ordered_totals
    best_indices = tf.where(is_best, order, -1 * tf.ones_like(order))

    hits = best_indices == target_index
    # setting shape needed for TensorFlow graph
    hits.set_shape(best_indices.shape)

    matches = K.any(hits)
    more_matches = K.sum(is_best) > 1
    once_matches = K.sum(matches & tf.math.logical_not(more_matches))
    return matches, more_matches, once_matches
|
470 |
-
|
471 |
-
|
472 |
-
class SimilarityRaven(Model):
    """Answer-selection metrics computed from the predicted panel encoding.

    Args:
        mode: Callable building the rule masks from the target encoding.
        number_loss: Accepted for signature compatibility; not used.
    """

    def __init__(self, mode=create_all_mask, number_loss=False):
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    # INDEX, PREDICT, LABELS
    def call(self, inputs):
        """Return ``[acc_same, acc_choose_upper, acc_choose_lower]``.

        ``inputs`` is ``[index, predict, labels]``; the candidate answers are
        ``labels[:, 8:]`` and the index is shifted by 8 to address them.
        """
        target_index = inputs[0] - 8
        predict = inputs[1]
        answers = inputs[2][:, 8:]

        target = K.gather(answers, target_index[:, 0])
        target_group = target[:, 0]

        # Compare everything between the group column and END_INDEX.
        end = rv.target.END_INDEX
        target_comp = target[:, 1:end]
        predict_comp = predict[:, 1:end]
        answers_comp = answers[:, :, 1:end]

        full_properties_musks = self.mode(target)
        fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])

        range_mask = self.range_mask(target_group)
        full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)

        final_mask = fpm & full_range_mask

        target_masked = target_comp * final_mask
        predict_masked = predict_comp * final_mask
        answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])

        metrics = []

        # Exact agreement between prediction and target on the masked entries.
        acc_same = K.mean(K.all(target_masked == predict_masked))
        self.add_metric(acc_same, ACC_SAME)
        metrics.append(acc_same)

        # First phase: count mismatching entries per candidate answer.
        diff = tf.abs(predict_masked[:, None] - answers_masked)
        mismatched = tf.cast(diff != 0, dtype=tf.int32)
        matches, more_matches, once_matches = get_matches(mismatched, target_index)

        # Second phase: break ties using the absolute differences themselves.
        second_phase_mask = (more_matches & matches)
        diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
        target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)

        matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
        matches_2_no = K.sum(matches_2)

        batch = tf.shape(predict)[0]

        acc_choose_upper = (once_matches + matches_2_no) / batch
        self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
        metrics.append(acc_choose_upper)

        acc_choose_lower = (once_matches + once_matches_2) / batch
        self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
        metrics.append(acc_choose_lower)

        return metrics
|
534 |
-
|
535 |
-
|
536 |
-
class SimilarityRaven2(Model):
    """Answer-selection metrics plus the count of mismatching entries.

    Args:
        mode: Callable building the rule masks from the target encoding.
        number_loss: Accepted for signature compatibility; not used.
    """

    def __init__(self, mode=create_all_mask, number_loss=False):
        super().__init__()
        self.range_mask = RangeMask()
        self.mode = mode

    # INDEX, PREDICT, LABELS
    def call(self, inputs):
        """Return ``[acc_same, acc_choose_upper, acc_choose_lower, mismatches]``.

        ``inputs`` is ``[index, predict, labels]``; the candidate answers are
        ``labels[:, 8:]`` and the index is shifted by 8 to address them.
        """
        target_index = inputs[0] - 8
        predict = inputs[1]
        answers = inputs[2][:, 8:]

        target = K.gather(answers, target_index[:, 0])
        target_group = target[:, 0]

        # Compare everything between the group column and END_INDEX.
        end = rv.target.END_INDEX
        target_comp = target[:, 1:end]
        predict_comp = predict[:, 1:end]
        answers_comp = answers[:, :, 1:end]

        full_properties_musks = self.mode(target)
        fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])

        range_mask = self.range_mask(target_group)
        full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)

        final_mask = fpm & full_range_mask

        target_masked = target_comp * final_mask
        predict_masked = predict_comp * final_mask
        answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])

        metrics = []

        # Exact agreement between prediction and target on the masked entries.
        acc_same = K.mean(K.all(target_masked == predict_masked))
        self.add_metric(acc_same, ACC_SAME)
        metrics.append(acc_same)

        # First phase: count mismatching entries per candidate answer.
        diff = tf.abs(predict_masked[:, None] - answers_masked)
        mismatched = tf.cast(diff != 0, dtype=tf.int32)
        matches, more_matches, once_matches = get_matches(mismatched, target_index)

        # Second phase: break ties using the absolute differences themselves.
        second_phase_mask = (more_matches & matches)
        diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
        target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)

        matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
        matches_2_no = K.sum(matches_2)

        batch = tf.shape(predict)[0]

        acc_choose_upper = (once_matches + matches_2_no) / batch
        self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
        metrics.append(acc_choose_upper)

        acc_choose_lower = (once_matches + once_matches_2) / batch
        self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
        metrics.append(acc_choose_lower)

        # Extra term compared to SimilarityRaven: masked entries that differ.
        metrics.append(K.sum(target_masked != predict_masked))

        return metrics
|
600 |
-
|
601 |
-
|
602 |
-
class LatentLossModel(Model):
    """MSE between predicted latents and the latents of selected panels.

    Args:
        dir_: ``VERTICAL`` selects panels ``(6, 7)``; any other value selects
            panels ``(2, 5)``.
    """

    def __init__(self, dir_=HORIZONTAL):
        super().__init__()
        self.metric_fn = Accuracy(name="acc_latent")
        self.dir = (6, 7) if dir_ == VERTICAL else (2, 5)

    def call(self, inputs):
        """Register the latent loss and accuracy metric; return the loss."""
        answer_index = tf.reshape(inputs[0][2], [-1])
        predicted = inputs[1]
        latents = tnp.asarray(inputs[2])

        # Latents of the two direction panels followed by the latent of the
        # correct answer (answers start at panel 8).
        direction_latents = latents[:, self.dir]
        rows = tf.range(latents.shape[0])
        answer_latent = latents[rows, answer_index + 8][:, None]
        target_hor = tf.concat([direction_latents, answer_latent], axis=1)

        loss_hor = mse(K.stop_gradient(target_hor), predicted)
        self.add_loss(loss_hor)

        self.add_metric(self.metric_fn(inputs[3], answer_index))

        return loss_hor
|
631 |
-
|
632 |
-
|
633 |
-
class PredRav(Model):
    """Pick the candidate answer closest (L1 distance) to the last output."""

    def call(self, inputs):
        """Return the argmin index over the candidates ``inputs[1][:, 8:]``."""
        predicted = inputs[0][:, -1]
        candidates = inputs[1][:, 8:]
        distance = tf.reduce_sum(tf.abs(predicted[:, None] - candidates), axis=-1)
        return tf.argmin(distance, axis=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/multi_transformer.py
DELETED
@@ -1,274 +0,0 @@
|
|
1 |
-
import tensorflow as tf
|
2 |
-
from functools import partial
|
3 |
-
from tensorflow.keras.layers import Lambda
|
4 |
-
from tensorflow.keras.layers import Dense
|
5 |
-
from tensorflow.keras import Input, Model
|
6 |
-
from tensorflow.python.keras import Sequential
|
7 |
-
|
8 |
-
from config.constant import TRANS
|
9 |
-
from ml_utils import filter_init
|
10 |
-
from models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
11 |
-
from models_utils import pmodel, DictModel, bt, INPUTS, bm, OUTPUT, LATENTS, transformer, BatchModel, get_extractor, \
|
12 |
-
build_seq_model, BUILD, build_train_list, InitialWeight
|
13 |
-
from models_utils import SumPositionEmbedding, TransformerBlock, CatPositionEmbedding, transformer, BatchInitialWeight
|
14 |
-
import models_utils.ops as K
|
15 |
-
from models_utils.image import inverse_fn
|
16 |
-
from models_utils.ops_core import IndexReshape
|
17 |
-
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
18 |
-
from models_utils.step import StepDict
|
19 |
-
|
20 |
-
|
21 |
-
def init_weights(shape, dtype=None):
    """Initializer returning ``K.var.image(shape, pre=True)`` cast to float32.

    ``dtype`` is accepted to satisfy the initializer call signature but is
    ignored; the result is always float32.
    """
    initial = K.var.image(shape=shape, pre=True)
    return tf.cast(initial, dtype=tf.float32)
|
23 |
-
|
24 |
-
|
25 |
-
def conversion(x, max_=45):
    """Keep the first ``max_`` entries along axis 1 and reshape to ``(batch, 9, -1)``."""
    batch = tf.shape(x)[0]
    kept = x[:, :max_]
    return tf.reshape(kept, tf.stack([batch, 9, -1]))
|
28 |
-
|
29 |
-
|
30 |
-
def take_left(x):
    """Return entry 7 of the last axis, keeping that axis with length 1."""
    return x[..., slice(7, 8)]
|
32 |
-
|
33 |
-
|
34 |
-
def take_by_index(x, i=8):
    """Return entry ``i`` of the last axis, keeping that axis with length 1."""
    return x[..., slice(i, i + 1)]
|
36 |
-
|
37 |
-
|
38 |
-
def mix(x):
    """Return the mean of entries 7 and 5 of the last axis (axis kept, length 1)."""
    seventh = x[..., 7:8]
    fifth = x[..., 5:6]
    return (seventh + fifth) / 2
|
40 |
-
|
41 |
-
|
42 |
-
def empty_last(x):
    """Return zeros shaped like entry 7 of the last axis (axis kept, length 1)."""
    template = x[..., 7:8]
    return tf.zeros_like(template)
|
44 |
-
|
45 |
-
|
46 |
-
class Conversion(Model):
    """Model wrapper: keep the first 45 entries of axis 1, then index-reshape."""

    def __init__(self):
        super().__init__()
        # NOTE(review): presumably reshapes to (batch, 9, rest); the exact meaning of
        # the spec tuple is defined by `IndexReshape` - confirm there.
        self.model = IndexReshape((0, "9", None))

    def call(self, inputs):
        """Apply the reshape model to ``inputs[:, :45]``."""
        kept = inputs[:, :45]
        return self.model(kept)
|
53 |
-
|
54 |
-
|
55 |
-
class RandomImageMask(Model):
    """Replace one randomly chosen channel per sample with ``last(inputs)``.

    Args:
        last: Callable producing the replacement channel from the inputs.
        last_index: Number of leading channels considered (and the exclusive
            upper bound of the random channel index).
    """

    def __init__(self, last, last_index=9):
        super().__init__()
        self.get_last = last
        self.last_index = last_index

    def call(self, inputs):
        """Return the first ``last_index`` channels with one channel replaced."""
        batch = tf.shape(inputs)[0:1]
        chosen = tf.random.uniform(shape=batch, maxval=self.last_index, dtype=tf.int32)
        # One-hot over the channel axis, broadcast over the two middle axes.
        mask = tf.one_hot(chosen, self.last_index)[:, None, None]

        kept = (1 - mask) * inputs[..., :self.last_index]
        replacement = mask * tf.tile(self.get_last(inputs), (1, 1, 1, self.last_index))
        return kept + replacement
|
68 |
-
|
69 |
-
|
70 |
-
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
71 |
-
# (1, 1, 1, self.last_index))
|
72 |
-
|
73 |
-
|
74 |
-
# from data_utils import ims
|
75 |
-
# for i in range(50):
|
76 |
-
# ims(res[i].numpy().swapaxes(0, 2))
|
77 |
-
# res[12].numpy()
|
78 |
-
# self.get_last(inputs).numpy()
|
79 |
-
# import tensorflow as tf
|
80 |
-
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
81 |
-
# from ml_utils import print_error
|
82 |
-
# ims(mask[0].numpy())
|
83 |
-
# print_error(lambda :ims(mask[0]))
|
84 |
-
# from models_utils import ops as K
|
85 |
-
|
86 |
-
|
87 |
-
class ImageMask(Model):
    """Replace one channel of the input with the output of ``last``.

    Args:
        last: Callable producing the replacement channel from the inputs.
        index: Position (last axis) of the channel that is replaced.
        last_index: Number of leading channels that are considered; channels
            at or beyond it are dropped.
    """

    def __init__(self, last, index=8, last_index=9):
        super().__init__()
        self.get_last = last
        self.index = index
        self.last_index = last_index

    def call(self, inputs):
        """Return the channels before ``index``, the replacement, then the rest.

        ``index`` and ``last_index`` used to be ignored in favour of a
        hard-coded 8; with the defaults the trailing slice is empty, so the
        result is identical to the previous behaviour.
        """
        return tf.concat(
            [
                inputs[..., :self.index],
                self.get_last(inputs),
                inputs[..., self.index + 1:self.last_index],
            ],
            axis=-1,
        )
|
96 |
-
|
97 |
-
|
98 |
-
class CreateGrid(Model):
    """Resize, mask and arrange the input panels, then run the wrapped model.

    Args:
        no: Accepted for signature compatibility; not used here.
        extractor: Extractor name forwarded to ``get_extractor`` when
            ``type_ != 9``.
        type_: ``9`` arranges the panels as one 3x3 image grid; any other value
            stacks them as three groups of three channels.
        base: Key into ``BUILD`` used to assemble the extractor pipeline.
        last: Callable producing the replacement for the masked panel.
        epsilon: Stored on the instance; not read by this class.
        pooling: Accepted for signature compatibility; not used here.
        mask_fn: ``"random"`` for ``RandomImageMask``, ``None`` for
            ``ImageMask``, or a ready-made masking callable.
        model: The model applied to the arranged panels in ``call``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 no=4,
                 extractor="ef",
                 type_=3,
                 base="seq",
                 last=take_left,
                 epsilon=None,
                 pooling=None,
                 mask_fn=None,
                 model=None,
                 **kwargs
                 ):
        super().__init__()
        self.type_ = type_
        if type_ == 9:
            self.start_shape = 75
        else:
            self.start_shape = 84
            data = (84, 84, 3)
            # NOTE(review): this pipeline is built but never stored or used; it is
            # kept because constructing it may have side effects - confirm whether
            # it was meant to become `self.model`.
            extractor = BUILD[base]([
                BatchModel(get_extractor(data=data, model=extractor)),
                lambda x: tf.transpose(x, (1, 0, 2, 3, 4))
            ])

        self.epsilon = epsilon
        if mask_fn == "random":
            mask_fn = RandomImageMask(last=last)
        elif mask_fn is None:
            mask_fn = ImageMask(last=last)

        self.mask_fn = mask_fn
        # `call` applies `self.model`, but the attribute was never assigned and the
        # `model` argument was dropped, so every call failed with AttributeError.
        self.model = model

    def call(self, inputs):
        """Return ``self.model`` applied to the arranged, masked panels.

        Raises:
            ValueError: If the instance was created without ``model``.
        """
        if self.model is None:
            raise ValueError("CreateGrid was created without `model`; nothing to apply in call()")

        # Move the panel axis last and resize to the working resolution.
        transposed = tf.image.resize(tf.transpose(inputs, (0, 2, 3, 1)), (self.start_shape, self.start_shape))
        re = self.mask_fn(transposed)

        if self.type_ == 9:
            # One 3x3 grid image, cropped to 224x224 and repeated over 3 channels.
            x = tf.transpose(re, [0, 3, 1, 2])[..., None]
            x = K.create_image_grid(x, 3, 3)
            x = x[:, :224, :224]
            x = tf.tile(x, [1, 1, 1, 3])
        else:
            # Three groups of three panels, each group used as a 3-channel image.
            x = tf.stack([
                re[..., :3],
                re[..., 3:6],
                re[..., 6:9],
            ])
        return self.model(x)
|
154 |
-
|
155 |
-
|
156 |
-
# self.model.layers[0](x)
|
157 |
-
|
158 |
-
|
159 |
-
def grid_transformer(
        *args,
        type_=9,
        no=4,
        extractor="ef",
        loss_mode=create_uniform_mask,
        output_size=10,
        loss_weight=1.0,
        out_layers=(1000, 1000, 1000),
        pos_emd="cat",
        base="seq",
        inverse_image=True,
        last="left",
        mask_fn=None,
        model=None,
        trans=None,
        **kwargs):
    """Build a ``CreateGrid`` front end and, without ``model``, a transformer.

    ``last`` may be one of ``"left"``, ``"mix"``, ``"empty"``, ``"start"`` or a
    ready-made callable, which is passed through unchanged.

    NOTE(review): this function looks unfinished - it returns nothing, and the
    ``model is None`` branch references ``data``, ``out_layer`` and ``conv``,
    none of which are defined in this function (NameError unless they exist at
    module level). Confirm the intended wiring before relying on it.
    """
    if last == "start":
        last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
    elif last == "left":
        last = take_left
    elif last == "mix":
        last = mix
    elif last == "empty":
        last = empty_last

    grid_options = dict(
        type_=type_,
        no=no,
        extractor=extractor,
        model=model,
        output_size=output_size,
        out_layer=out_layers,
        pos_emd=pos_emd,
        base=base,
        last=last,
        mask_fn=mask_fn,
    )
    create_grid = CreateGrid(**grid_options, **kwargs)

    if model is None:
        trans = transformer(
            extractor=extractor,
            pos_emd=pos_emd,
            data=data,
            output_size=output_size,
            out_layers=out_layer,
            pooling=conv,
            no=no,
            base=base,
            **kwargs
        )
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
def get_rav_trans(
|
219 |
-
*args,
|
220 |
-
type_=9,
|
221 |
-
no=4,
|
222 |
-
extractor="ef",
|
223 |
-
loss_mode=create_uniform_mask,
|
224 |
-
output_size=10,
|
225 |
-
loss_weight=1.0,
|
226 |
-
out_layers=(1000, 1000, 1000),
|
227 |
-
pos_emd="cat",
|
228 |
-
base="seq",
|
229 |
-
inverse_image=True,
|
230 |
-
last="left",
|
231 |
-
epsilon="greedy",
|
232 |
-
epsilon_step=500,
|
233 |
-
mask_fn=None,
|
234 |
-
model=None,
|
235 |
-
loss="multi",
|
236 |
-
**kwargs):
|
237 |
-
if last == "left":
|
238 |
-
last = take_left
|
239 |
-
elif last == "mix":
|
240 |
-
last = mix
|
241 |
-
elif last == "empty":
|
242 |
-
last = empty_last
|
243 |
-
elif last == "start":
|
244 |
-
last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
|
245 |
-
|
246 |
-
trans_raven = CreateGrid(
|
247 |
-
type_=type_,
|
248 |
-
no=no,
|
249 |
-
extractor=extractor,
|
250 |
-
model=model,
|
251 |
-
output_size=output_size,
|
252 |
-
out_layer=out_layers,
|
253 |
-
pos_emd=pos_emd,
|
254 |
-
base=base,
|
255 |
-
last=last,
|
256 |
-
epsilon=epsilon,
|
257 |
-
mask_fn=mask_fn,
|
258 |
-
**kwargs
|
259 |
-
)
|
260 |
-
|
261 |
-
if loss == "single":
|
262 |
-
loss = SingleVTRavenLoss
|
263 |
-
else:
|
264 |
-
loss = VTRavenLoss
|
265 |
-
|
266 |
-
return bt(
|
267 |
-
DictModel(
|
268 |
-
Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
|
269 |
-
in_=INPUTS,
|
270 |
-
name="Body"
|
271 |
-
),
|
272 |
-
loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
|
273 |
-
loss_wrap=False
|
274 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/raven.py
DELETED
@@ -1,239 +0,0 @@
|
|
1 |
-
from ml_utils import lw, lu
|
2 |
-
from models_utils import bm, Base, res, bt, DictModel, dense_drop, drop, build_encoder, MODEL_ARCH, ListModel, short, \
|
3 |
-
dense, Flatten, Cat, CatDenseBefore, \
|
4 |
-
CatDense, CatBefore, Drop, Flat2, down, Pass, conv, Flat, Get, bs, Res, SoftBlock
|
5 |
-
from models_utils import SubClassingModel
|
6 |
-
from models_utils.config.constants import *
|
7 |
-
import numpy as np
|
8 |
-
|
9 |
-
from config.constant import *
|
10 |
-
from tensorflow.keras.layers import Dense, Activation, BatchNormalization
|
11 |
-
import tensorflow as tf
|
12 |
-
|
13 |
-
import raven_utils as rv
|
14 |
-
|
15 |
-
from models.body import create_block
|
16 |
-
from models.class_ import Merge, RavenClass
|
17 |
-
from models.head import LatentHeadModel
|
18 |
-
|
19 |
-
from models.loss import RavenLoss
|
20 |
-
from models.trans import TransModel, FullTrans
|
21 |
-
from raven_utils.const import HORIZONTAL
|
22 |
-
|
23 |
-
|
24 |
-
def raven_model(scales,
|
25 |
-
out_layers,
|
26 |
-
latent=(64, 128, 256),
|
27 |
-
output_size=None,
|
28 |
-
padding=SAME,
|
29 |
-
body_layers=1,
|
30 |
-
encoder=None,
|
31 |
-
loop=1,
|
32 |
-
model=None,
|
33 |
-
act=None,
|
34 |
-
simpler=0,
|
35 |
-
loss_mode=None,
|
36 |
-
loss_weight=0.3,
|
37 |
-
dir_=HORIZONTAL,
|
38 |
-
global_context=False,
|
39 |
-
images_no=8,
|
40 |
-
context_mul=2,
|
41 |
-
res_act="pass",
|
42 |
-
drop_latent=0,
|
43 |
-
drop_inference=0,
|
44 |
-
drop_end=0,
|
45 |
-
ga=False,
|
46 |
-
trans_norm=None,
|
47 |
-
trans_act="relu",
|
48 |
-
arch=HEAD3,
|
49 |
-
encoder_norm=False,
|
50 |
-
encoder_pool=False,
|
51 |
-
encoder_global="GM",
|
52 |
-
encoder_before=False,
|
53 |
-
tail_units=256,
|
54 |
-
tail_flatten=None,
|
55 |
-
# for now by default
|
56 |
-
tail_down="MP",
|
57 |
-
trans_no=1,
|
58 |
-
trans_score_activation=tf.nn.softmax,
|
59 |
-
block_=SoftBlock,
|
60 |
-
**kwargs):
|
61 |
-
if isinstance(latent, int):
|
62 |
-
latent = (latent, 128, 256)
|
63 |
-
scales = lw(scales)
|
64 |
-
|
65 |
-
context_size = np.array(latent) * context_mul
|
66 |
-
# context_size = latent[scales] * context_mul
|
67 |
-
|
68 |
-
# if scales == 2:
|
69 |
-
# arch = HEAD
|
70 |
-
# elif scales == 1:
|
71 |
-
# arch = HEAD2
|
72 |
-
# else:
|
73 |
-
# arch = VERY2
|
74 |
-
|
75 |
-
if encoder_pool:
|
76 |
-
strides = (1, 1)
|
77 |
-
else:
|
78 |
-
strides = (2, 2)
|
79 |
-
if not isinstance(encoder_before, tuple):
|
80 |
-
encoder_before = [encoder_before] * 3
|
81 |
-
|
82 |
-
# if trans == 1:
|
83 |
-
# trans_model = TransModel2
|
84 |
-
# else:
|
85 |
-
# trans_model = TransModel
|
86 |
-
|
87 |
-
# if scales == 3:
|
88 |
-
# head = MultiHeadModel(encoder=encoder)
|
89 |
-
arch = MODEL_ARCH[arch]
|
90 |
-
heads = []
|
91 |
-
for s in list(range(0, max(scales) + 1)):
|
92 |
-
if s in (0, 1):
|
93 |
-
if s == 0:
|
94 |
-
encoder = build_encoder(arch[:3], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
95 |
-
strides=strides)
|
96 |
-
else:
|
97 |
-
encoder = build_encoder(arch[3:4], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
98 |
-
strides=strides)
|
99 |
-
head = LatentHeadModel(
|
100 |
-
encoder=encoder,
|
101 |
-
inference_network=(
|
102 |
-
bm([
|
103 |
-
CatBefore(filters=int(context_size[s] / 8)) if encoder_before[s] else Cat(
|
104 |
-
filters=context_size[s]),
|
105 |
-
# todo activation?
|
106 |
-
Res(filters=context_size[s], padding=padding)
|
107 |
-
] + ([drop(drop_inference)] if drop_inference else []),
|
108 |
-
name="inference")
|
109 |
-
) if s in scales else Pass(),
|
110 |
-
stem=Base(
|
111 |
-
bm(
|
112 |
-
# ok we choose by parameters anyway
|
113 |
-
[res(filters=latent[s], padding=padding, act=act)] + (
|
114 |
-
[drop(drop_latent)] if drop_latent else [])
|
115 |
-
),
|
116 |
-
name="stem")
|
117 |
-
)
|
118 |
-
else:
|
119 |
-
encoder = bm([
|
120 |
-
Res(),
|
121 |
-
build_encoder(arch[4:], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
122 |
-
strides=strides),
|
123 |
-
short(encoder_global) if encoder_global else Flatten(),
|
124 |
-
dense(latent[s])
|
125 |
-
])
|
126 |
-
head = LatentHeadModel(
|
127 |
-
encoder=encoder,
|
128 |
-
inference_network=bm([
|
129 |
-
# todo Echeck Cat
|
130 |
-
CatDenseBefore(filters=int(context_size[s] / 8)) if encoder_before[
|
131 |
-
s] else CatDense(filters=context_size[s]),
|
132 |
-
# todo activation?
|
133 |
-
Res(model="dv2", filters=context_size[s], padding=padding)
|
134 |
-
] + ([dense_drop(drop_inference)] if drop_inference else []),
|
135 |
-
name="inference"),
|
136 |
-
stem=Base(
|
137 |
-
bm(
|
138 |
-
# ok we choose by parameters anyway
|
139 |
-
[res(model="dv2", units=latent[s], padding=padding, act=act)] + (
|
140 |
-
[dense_drop(drop_latent)] if drop_latent else [])
|
141 |
-
),
|
142 |
-
name="stem")
|
143 |
-
)
|
144 |
-
heads.append(head)
|
145 |
-
|
146 |
-
concat_input = [f"{LATENT}_{i}" for i, _ in enumerate(heads)] + [f"{INFERENCE}_{i}" for i, _ in enumerate(heads)]
|
147 |
-
concat_output = ["LATENTS", "INFERENCES"]
|
148 |
-
|
149 |
-
def head_concat(inputs):
|
150 |
-
latents = inputs[:len(heads)]
|
151 |
-
inferences = inputs[len(heads):]
|
152 |
-
return latents, inferences
|
153 |
-
|
154 |
-
head = ListModel([(h, (INPUTS if i == 0 else OUTPUT), [f"{LATENT}_{i}", f"{INFERENCE}_{i}", OUTPUT]) for i, h in
|
155 |
-
enumerate(heads)] + [
|
156 |
-
(head_concat, concat_input, concat_output)], out=concat_output)
|
157 |
-
# from rav_utils.raven import init_image
|
158 |
-
# a = init_image()
|
159 |
-
# head(a)
|
160 |
-
|
161 |
-
if model is None:
|
162 |
-
model = []
|
163 |
-
for i in scales:
|
164 |
-
trans_models = []
|
165 |
-
for t in range(trans_no):
|
166 |
-
trans_models.append(
|
167 |
-
bm(
|
168 |
-
[create_block(latent=latent[i], simpler=simpler, padding=padding, norm=trans_norm, act=res_act,
|
169 |
-
loop=loop, type_="dense" if i == 2 else "conv", block_=block_)] +
|
170 |
-
[Activation(trans_act)] + [
|
171 |
-
res(filters=latent[i],
|
172 |
-
padding=padding,
|
173 |
-
act=act,
|
174 |
-
name="body_out",
|
175 |
-
model="dv2" if i == 2 else "v2") for _ in
|
176 |
-
range(body_layers)] + ([Drop(drop_latent)] if drop_latent else []),
|
177 |
-
base_class=SubClassingModel)
|
178 |
-
)
|
179 |
-
trans_models = lu(trans_models)
|
180 |
-
if trans_no > 1:
|
181 |
-
trans_models = bm([
|
182 |
-
lambda x: [[x[0], x[1]], x[1]],
|
183 |
-
SoftBlock(
|
184 |
-
model=trans_models,
|
185 |
-
score_model=bm([
|
186 |
-
Flat2(filters=latent[i], units=256, res_no=2),
|
187 |
-
Dense(trans_no, trans_score_activation)
|
188 |
-
])
|
189 |
-
)
|
190 |
-
],
|
191 |
-
base_class=SubClassingModel
|
192 |
-
)
|
193 |
-
|
194 |
-
model.append(
|
195 |
-
TransModel(
|
196 |
-
body=trans_models,
|
197 |
-
dir_=dir_,
|
198 |
-
images_no=images_no
|
199 |
-
)
|
200 |
-
)
|
201 |
-
|
202 |
-
tail = []
|
203 |
-
for i, s in enumerate(scales):
|
204 |
-
flatting = lambda: Flat2(filters=latent[s + 1], base_class=tail_flatten, units=tail_units)
|
205 |
-
if s == 0:
|
206 |
-
if tail_flatten is None:
|
207 |
-
branch = bm([res(filters=latent[s], padding=padding),
|
208 |
-
conv(filters=latent[s], padding=padding),
|
209 |
-
BatchNormalization(),
|
210 |
-
conv(filters=latent[s], padding=padding),
|
211 |
-
Flatten()])
|
212 |
-
else:
|
213 |
-
branch = bm([down(base_class=tail_down), flatting()])
|
214 |
-
elif s == 1:
|
215 |
-
if tail_flatten is None:
|
216 |
-
branch = bm([res(filters=latent[s], padding=padding),
|
217 |
-
Flatten()])
|
218 |
-
else:
|
219 |
-
branch = flatting()
|
220 |
-
else:
|
221 |
-
branch = bm([tail_units] * 2, add_flatten=False)
|
222 |
-
tail.append(branch)
|
223 |
-
|
224 |
-
tail.append(
|
225 |
-
bm([dense(tail_units)] + ([dense_drop(drop_end)] if drop_end else []) + [Dense(output_size)], add_flatten=False,
|
226 |
-
name=TAIL))
|
227 |
-
class_input = []
|
228 |
-
|
229 |
-
return bt([
|
230 |
-
DictModel(head, in_=INPUTS, out=[LATENT, INFERENCE], name="Head"),
|
231 |
-
DictModel(FullTrans(model, scales=scales), in_=[LATENT, INFERENCE], out=TRANS, name="Body"),
|
232 |
-
DictModel(RavenClass(Merge(tail), scales=scales, no=8), in_=[LATENT] + class_input, out=CLASSIFICATION,
|
233 |
-
name="Classificator"),
|
234 |
-
DictModel(RavenClass(Merge(tail), scales=list(range(len(scales))), no=3), in_=[TRANS] + class_input,
|
235 |
-
out=OUTPUT, name="Classificator_trans"),
|
236 |
-
],
|
237 |
-
loss=RavenLoss(mode=loss_mode, classification=True, trans=True, lw=(1.0, loss_weight)),
|
238 |
-
loss_wrap=False
|
239 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/trans.py
DELETED
@@ -1,74 +0,0 @@
|
|
1 |
-
import tensorflow as tf
|
2 |
-
from ml_utils import lw
|
3 |
-
from models_utils import ops as K, SubClassingModel
|
4 |
-
from tensorflow.keras import Model
|
5 |
-
|
6 |
-
from models.body import create_dense_block
|
7 |
-
import raven_utils as rv
|
8 |
-
from raven_utils.const import HORIZONTAL, VERTICAL
|
9 |
-
|
10 |
-
|
11 |
-
class TransModel(Model):
|
12 |
-
def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
|
13 |
-
super().__init__()
|
14 |
-
self.model = body or create_dense_block(latent=latent)
|
15 |
-
if dir_ == VERTICAL:
|
16 |
-
self.dir = (0, 3, 1, 4, 3, 5)
|
17 |
-
else:
|
18 |
-
self.dir = (0, 1, 3, 4, 6, 7)
|
19 |
-
self.images_no = images_no
|
20 |
-
self.latent = latent
|
21 |
-
|
22 |
-
def call(self, inputs):
|
23 |
-
# latents = tnp.asarray(inputs[0])
|
24 |
-
latents = inputs[0]
|
25 |
-
inference = inputs[1]
|
26 |
-
shape = tf.shape(latents)
|
27 |
-
new_shape = K.cat([[-1, 3, 2], shape[2:]])
|
28 |
-
horizontal = latents[:, self.dir].reshape(new_shape)
|
29 |
-
res = tf.TensorArray(tf.float32, size=3)
|
30 |
-
for i in range(3):
|
31 |
-
res = res.write(i, self.model([horizontal[:, i], inference]))
|
32 |
-
result = K.tran(res.stack())
|
33 |
-
return result
|
34 |
-
|
35 |
-
|
36 |
-
class TransModel2(Model):
|
37 |
-
def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
|
38 |
-
super().__init__()
|
39 |
-
self.body = body or create_dense_block(latent=latent)
|
40 |
-
if dir_ == VERTICAL:
|
41 |
-
self.dir = (0, 3, 1, 4, 3, 5)
|
42 |
-
else:
|
43 |
-
self.dir = (0, 1, 3, 4, 6, 7)
|
44 |
-
self.images_no = images_no
|
45 |
-
self.latent = latent
|
46 |
-
|
47 |
-
def call(self, inputs):
|
48 |
-
# latents = tnp.asarray(inputs[0])
|
49 |
-
latents = inputs[0]
|
50 |
-
inference = inputs[1]
|
51 |
-
shape = tf.shape(latents)
|
52 |
-
new_shape = K.cat([[-1, 3, 2], shape[2:]])
|
53 |
-
horizontal = latents[:, self.dir].reshape(new_shape)
|
54 |
-
res = tf.TensorArray(tf.float32, size=3)
|
55 |
-
for i in tf.range(3):
|
56 |
-
res = res.write(i, self.body([horizontal[:, i], inference[:,i]]))
|
57 |
-
result = K.tran(res.stack())
|
58 |
-
return result
|
59 |
-
|
60 |
-
|
61 |
-
class FullTrans(SubClassingModel):
|
62 |
-
def __init__(self, model,scales,name=None):
|
63 |
-
super().__init__(model=model,name=name)
|
64 |
-
self.scales = scales
|
65 |
-
|
66 |
-
def call(self, inputs):
|
67 |
-
latent = lw(inputs[0])
|
68 |
-
inference = lw(inputs[1])
|
69 |
-
results = []
|
70 |
-
# todo merging inference?
|
71 |
-
for i,s in enumerate(self.scales):
|
72 |
-
# results.append(model([latent[::-1][i], inference]))
|
73 |
-
results.append(self.model[i]([latent[s], inference[s]]))
|
74 |
-
return results,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/transformer.py
DELETED
@@ -1,133 +0,0 @@
|
|
1 |
-
import tensorflow as tf
|
2 |
-
from tensorflow.keras.layers import Lambda
|
3 |
-
from tensorflow.python.keras import Sequential
|
4 |
-
|
5 |
-
# from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
6 |
-
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight
|
7 |
-
import models_utils.ops as K
|
8 |
-
from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
|
9 |
-
from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
|
10 |
-
from models_utils.ops_core import IndexReshape
|
11 |
-
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
12 |
-
from models_utils.step import StepDict
|
13 |
-
|
14 |
-
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
15 |
-
# (1, 1, 1, self.last_index))
|
16 |
-
|
17 |
-
|
18 |
-
# from data_utils import ims
|
19 |
-
# for i in range(50):
|
20 |
-
# ims(res[i].numpy().swapaxes(0, 2))
|
21 |
-
# res[12].numpy()
|
22 |
-
# self.get_last(inputs).numpy()
|
23 |
-
# import tensorflow as tf
|
24 |
-
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
25 |
-
# from ml_utils import print_error
|
26 |
-
# ims(mask[0].numpy())
|
27 |
-
# print_error(lambda :ims(mask[0]))
|
28 |
-
# from models_utils import ops as K
|
29 |
-
|
30 |
-
|
31 |
-
# self.model.layers[0](x)
|
32 |
-
from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
33 |
-
|
34 |
-
|
35 |
-
def get_rav_trans(
|
36 |
-
data,
|
37 |
-
type_=9,
|
38 |
-
no=4,
|
39 |
-
extractor="ef",
|
40 |
-
loss_mode=create_uniform_mask,
|
41 |
-
output_size=10,
|
42 |
-
loss_weight=1.0,
|
43 |
-
out_layers=(1000, 1000, 1000),
|
44 |
-
pos_emd="cat",
|
45 |
-
base="seq",
|
46 |
-
inverse_image=True,
|
47 |
-
last="left",
|
48 |
-
epsilon="greedy",
|
49 |
-
epsilon_step=500,
|
50 |
-
mask_fn=None,
|
51 |
-
model=None,
|
52 |
-
loss="multi",
|
53 |
-
**kwargs):
|
54 |
-
if last == "left":
|
55 |
-
last = take_left
|
56 |
-
elif last == "mix":
|
57 |
-
last = mix
|
58 |
-
elif last == "empty":
|
59 |
-
last = empty_last
|
60 |
-
elif last == "start":
|
61 |
-
last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
|
62 |
-
|
63 |
-
if epsilon == "greedy":
|
64 |
-
epsilon = EpsilonGreedy(step=epsilon_step)
|
65 |
-
elif epsilon == "soft":
|
66 |
-
epsilon = EpsilonSoft(step=epsilon_step)
|
67 |
-
elif epsilon is False:
|
68 |
-
epsilon = None
|
69 |
-
|
70 |
-
if epsilon:
|
71 |
-
trans_raven = TransRavenwithStep(
|
72 |
-
type_=type_,
|
73 |
-
no=no,
|
74 |
-
extractor=extractor,
|
75 |
-
output_size=output_size,
|
76 |
-
out_layer=out_layers,
|
77 |
-
pos_emd=pos_emd,
|
78 |
-
base=base,
|
79 |
-
last=last,
|
80 |
-
epsilon=epsilon,
|
81 |
-
**kwargs
|
82 |
-
)
|
83 |
-
return StepDict(bt(
|
84 |
-
DictModel(
|
85 |
-
Sequential([Lambda(lambda x: (255 - x[0], x[1])), trans_raven]) if inverse_image else trans_raven,
|
86 |
-
in_=[INPUTS, "step"],
|
87 |
-
name="Body"
|
88 |
-
),
|
89 |
-
loss=VTRavenLoss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
|
90 |
-
loss_wrap=False),
|
91 |
-
add_step=epsilon_step,
|
92 |
-
)
|
93 |
-
|
94 |
-
trans_raven = img_sec_trans(
|
95 |
-
type_=type_,
|
96 |
-
no=no,
|
97 |
-
extractor=extractor,
|
98 |
-
model=model,
|
99 |
-
output_size=output_size,
|
100 |
-
out_layer=out_layers,
|
101 |
-
pos_emd=pos_emd,
|
102 |
-
base=base,
|
103 |
-
last=last,
|
104 |
-
epsilon=epsilon,
|
105 |
-
mask_fn=mask_fn,
|
106 |
-
**kwargs
|
107 |
-
)
|
108 |
-
if loss == "single":
|
109 |
-
loss = SingleVTRavenLoss
|
110 |
-
else:
|
111 |
-
loss = VTRavenLoss
|
112 |
-
|
113 |
-
# return bt(
|
114 |
-
# DictModel(
|
115 |
-
# Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
|
116 |
-
# inputs=INPUTS,
|
117 |
-
# name="Body"
|
118 |
-
# ),
|
119 |
-
# loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
|
120 |
-
# loss_wrap=False
|
121 |
-
# )
|
122 |
-
|
123 |
-
return bt([
|
124 |
-
DictModel(
|
125 |
-
Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
|
126 |
-
in_=INPUTS,
|
127 |
-
name="Body"
|
128 |
-
),
|
129 |
-
|
130 |
-
],
|
131 |
-
loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
|
132 |
-
loss_wrap=False
|
133 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/transformer_2.py
DELETED
@@ -1,146 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
|
3 |
-
import tensorflow as tf
|
4 |
-
from tensorflow.keras.layers import Lambda
|
5 |
-
from tensorflow.python.keras import Sequential
|
6 |
-
from models_utils import ops as K, SubClassing
|
7 |
-
from models_utils.models.transformer import aug
|
8 |
-
|
9 |
-
# from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
10 |
-
from data_utils import DataGenerator, LOSS, TARGET, IMAGES
|
11 |
-
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional_model, get_input_layer
|
12 |
-
import models_utils.ops as K
|
13 |
-
from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
|
14 |
-
from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
|
15 |
-
from models_utils.ops_core import IndexReshape
|
16 |
-
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
17 |
-
from models_utils.step import StepDict
|
18 |
-
|
19 |
-
from models_utils.models.transformer import aug
|
20 |
-
|
21 |
-
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
22 |
-
# (1, 1, 1, self.last_index))
|
23 |
-
|
24 |
-
|
25 |
-
# from data_utils import ims
|
26 |
-
# for i in range(50):
|
27 |
-
# ims(res[i].numpy().swapaxes(0, 2))
|
28 |
-
# res[12].numpy()
|
29 |
-
# self.get_last(inputs).numpy()
|
30 |
-
# import tensorflow as tf
|
31 |
-
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
32 |
-
# from ml_utils import print_error
|
33 |
-
# ims(mask[0].numpy())
|
34 |
-
# print_error(lambda :ims(mask[0]))
|
35 |
-
# from models_utils import ops as K
|
36 |
-
|
37 |
-
|
38 |
-
# self.model.layers[0](x)
|
39 |
-
from raven_utils.constant import INDEX, LABELS
|
40 |
-
from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
41 |
-
|
42 |
-
|
43 |
-
def get_matrix(inputs, index):
|
44 |
-
return tf.concat([inputs[:, :8], K.gather(inputs, index[:, 0])[:, None]], axis=1)
|
45 |
-
|
46 |
-
|
47 |
-
def get_images(inputs):
|
48 |
-
return get_matrix(inputs[0], inputs[1])
|
49 |
-
|
50 |
-
|
51 |
-
def random_last(inputs, max_=8):
|
52 |
-
index = K.init.label(max=max_, shape=[tf.shape(inputs[0])[0]])[..., None]
|
53 |
-
return get_matrix(inputs[0], index)
|
54 |
-
|
55 |
-
|
56 |
-
def get_images_no_answer(inputs):
|
57 |
-
return inputs[0][:, :9]
|
58 |
-
|
59 |
-
|
60 |
-
def repeat_last(inputs):
|
61 |
-
return inputs[0][:, list(range(8)) + [7]]
|
62 |
-
|
63 |
-
|
64 |
-
def get_rav_trans(
|
65 |
-
data,
|
66 |
-
inverse_image=True,
|
67 |
-
loss_mode=create_uniform_mask,
|
68 |
-
loss_weight=1.0,
|
69 |
-
loss="multi",
|
70 |
-
number_loss=False,
|
71 |
-
plw=None,
|
72 |
-
pre="auto",
|
73 |
-
augmentation=None,
|
74 |
-
**kwargs):
|
75 |
-
if isinstance(data, DataGenerator):
|
76 |
-
data = data[0]['inputs'], data[0]['index']
|
77 |
-
# u = img_sec_trans(**kwargs)(get_images(data) if kwargs['mask'] == "random" else get_images_no_answer(data))
|
78 |
-
# u.shape
|
79 |
-
from keras import Model
|
80 |
-
if pre == "auto":
|
81 |
-
pre = get_images if kwargs['mask'] == "random" else get_images_no_answer
|
82 |
-
elif pre == "no_answer":
|
83 |
-
pre = get_images_no_answer
|
84 |
-
elif pre == "last":
|
85 |
-
pre = repeat_last
|
86 |
-
elif pre == "images":
|
87 |
-
pre = get_images
|
88 |
-
elif pre == "random_last":
|
89 |
-
pre = random_last
|
90 |
-
elif pre == "noise":
|
91 |
-
pre = SubClassing([get_matrix, partial(aug.noise, max_=8)])
|
92 |
-
elif pre == "batch_noise":
|
93 |
-
pre = SubClassing([get_matrix, partial(aug.batch_noise, max_=8)])
|
94 |
-
|
95 |
-
if augmentation == "transpose":
|
96 |
-
augmentation = aug.Transpose(axis=(0, 2, 1))
|
97 |
-
augmentation_label = aug.Transpose(axis=(0, 2, 1))
|
98 |
-
elif augmentation == "shuffle_col":
|
99 |
-
augmentation = aug.shuffle_col
|
100 |
-
augmentation_label = aug.shuffle_col
|
101 |
-
elif augmentation == "shuffle":
|
102 |
-
augmentation = aug.shuffle
|
103 |
-
augmentation_label = aug.shuffle
|
104 |
-
if augmentation:
|
105 |
-
augmentation = [
|
106 |
-
# DictModel(augmentation, IMAGES, IMAGES),
|
107 |
-
# DictModel(aug.reshape_static(pre(data),augmentation), IMAGES, IMAGES),
|
108 |
-
DictModel(aug.ReshapeStatic(augmentation), IMAGES, IMAGES),
|
109 |
-
DictModel(
|
110 |
-
aug.PartialModel(
|
111 |
-
aug.ReshapeStatic(augmentation_label),
|
112 |
-
last_axis=9)
|
113 |
-
, LABELS, LABELS)
|
114 |
-
]
|
115 |
-
else:
|
116 |
-
augmentation = []
|
117 |
-
|
118 |
-
trans_raven = build_functional_model(
|
119 |
-
img_sec_trans(**kwargs),
|
120 |
-
# get_images(data) if kwargs['mask'] == "random" else get_images_no_answer(data)
|
121 |
-
pre(data)
|
122 |
-
# data[0]
|
123 |
-
)
|
124 |
-
if loss == "single":
|
125 |
-
loss = SingleVTRavenLoss
|
126 |
-
else:
|
127 |
-
loss = VTRavenLoss
|
128 |
-
if isinstance(loss_weight, float):
|
129 |
-
loss_weight = (loss_weight, 1.0)
|
130 |
-
|
131 |
-
return bt([
|
132 |
-
# DictModel(get_images if kwargs['mask'] == "random" else get_images_no_answer, [INPUTS, INDEX], IMAGES),
|
133 |
-
DictModel(pre, [INPUTS, INDEX], IMAGES),
|
134 |
-
*augmentation,
|
135 |
-
DictModel(
|
136 |
-
Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
|
137 |
-
in_=IMAGES,
|
138 |
-
# inputs=INPUTS,
|
139 |
-
name="Body"
|
140 |
-
),
|
141 |
-
|
142 |
-
],
|
143 |
-
loss=loss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
|
144 |
-
predict=LOSS,
|
145 |
-
loss_wrap=False
|
146 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/transformer_3.py
DELETED
@@ -1,206 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
|
3 |
-
from loguru import logger
|
4 |
-
from tensorflow.keras.layers import Lambda
|
5 |
-
from tensorflow.keras.layers import Activation
|
6 |
-
|
7 |
-
from grid_transformer import aug_trans
|
8 |
-
from raven_utils.models.loss_3 import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
9 |
-
from data_utils import get_shape, TakeDict
|
10 |
-
|
11 |
-
from data_utils import DataGenerator, LOSS, TARGET, IMAGES
|
12 |
-
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional, get_input_layer, Last, bm, \
|
13 |
-
add_end, AUGMENTATION
|
14 |
-
# from report.select_ import SelectModel2, SelectModel, SelectModel9
|
15 |
-
from experiment_utils.keras_model import load_weights as model_load_weights
|
16 |
-
|
17 |
-
|
18 |
-
def get_rav_trans(
|
19 |
-
data,
|
20 |
-
loss_mode=create_uniform_mask,
|
21 |
-
loss_weight=2.0,
|
22 |
-
number_loss=False,
|
23 |
-
dry_run="auto",
|
24 |
-
plw=None,
|
25 |
-
**kwargs):
|
26 |
-
if isinstance(loss_weight, float):
|
27 |
-
loss_weight = (loss_weight, 1.0)
|
28 |
-
|
29 |
-
# seq_trans(**kwargs)(data[0])
|
30 |
-
# trans_raven = build_functional_model2(
|
31 |
-
# seq_trans(**kwargs),
|
32 |
-
# data[0],
|
33 |
-
# batch=None
|
34 |
-
# )
|
35 |
-
trans_raven = build_functional(
|
36 |
-
model=aug_trans,
|
37 |
-
inputs_=data[0] if isinstance(data, DataGenerator) else data,
|
38 |
-
batch_=None,
|
39 |
-
dry_run=dry_run,
|
40 |
-
**kwargs
|
41 |
-
)
|
42 |
-
|
43 |
-
return bt(
|
44 |
-
model=trans_raven,
|
45 |
-
loss=VTRavenLoss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
|
46 |
-
model_wrap=False,
|
47 |
-
predict=LOSS,
|
48 |
-
loss_wrap=False
|
49 |
-
)
|
50 |
-
|
51 |
-
|
52 |
-
def rav_select_model(
|
53 |
-
data,
|
54 |
-
load_weights=None,
|
55 |
-
loss_weight=(0.01, 0.0),
|
56 |
-
plw=5.0,
|
57 |
-
result_metric="sparse_categorical_accuracy",
|
58 |
-
select_type=2,
|
59 |
-
select_out=0,
|
60 |
-
additional_out=0,
|
61 |
-
additional_copy=True,
|
62 |
-
tail_out=(1000, 1000),
|
63 |
-
**kwargs
|
64 |
-
):
|
65 |
-
out_layers = Last()
|
66 |
-
if additional_out > 0:
|
67 |
-
model3 = get_rav_trans(
|
68 |
-
data,
|
69 |
-
plw=plw,
|
70 |
-
loss_weight=loss_weight,
|
71 |
-
**kwargs
|
72 |
-
)
|
73 |
-
|
74 |
-
model_load_weights(
|
75 |
-
model3,
|
76 |
-
load_weights,
|
77 |
-
# sample_data,
|
78 |
-
None,
|
79 |
-
template="weights_{epoch:02d}-{val_loss:.2f}",
|
80 |
-
key=result_metric,
|
81 |
-
)
|
82 |
-
|
83 |
-
if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
|
84 |
-
index = -1
|
85 |
-
else:
|
86 |
-
index = -2
|
87 |
-
|
88 |
-
out = model3[0, index, :additional_out]
|
89 |
-
logger.info(f"Additional out from: {model3[0, index]}.")
|
90 |
-
|
91 |
-
if additional_out > 2:
|
92 |
-
out += [Activation("gelu")]
|
93 |
-
out_layers = bm([out_layers] + out, add_flatten=False)
|
94 |
-
model = get_rav_trans(
|
95 |
-
TakeDict(data[0])[:, 8:],
|
96 |
-
plw=plw,
|
97 |
-
loss_weight=loss_weight,
|
98 |
-
**{
|
99 |
-
**kwargs,
|
100 |
-
"out_layers": out_layers,
|
101 |
-
}
|
102 |
-
# **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
|
103 |
-
)
|
104 |
-
# from data_utils.ops import Equal
|
105 |
-
# o = []
|
106 |
-
# for i in range(1, 3):
|
107 |
-
# for j in range(2):
|
108 |
-
# o.append(
|
109 |
-
# Equal(
|
110 |
-
# # model[0,:,-2, i].variables[j],
|
111 |
-
# model2[0, :, -2, i].variables[j],
|
112 |
-
# # out_layers[i].variables[j]
|
113 |
-
# second_pooling[i].variables[j]
|
114 |
-
# ).equal
|
115 |
-
# )
|
116 |
-
# assert all(o)
|
117 |
-
# model = get_rav_trans(
|
118 |
-
# # TakeDict(val_generator[0])[:, 8:],
|
119 |
-
# # TakeDict(val_generator[0])[:, 8:],
|
120 |
-
# val_generator[0],
|
121 |
-
# plw=p.plw,
|
122 |
-
# loss_weight=p.loss_weight,
|
123 |
-
# **{**as_dict(p.mp),
|
124 |
-
# # "out_layers": out_layers,
|
125 |
-
# }
|
126 |
-
# # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
|
127 |
-
# )
|
128 |
-
model_load_weights(model,
|
129 |
-
load_weights,
|
130 |
-
# sample_data,
|
131 |
-
None,
|
132 |
-
template="weights_{epoch:02d}-{val_loss:.2f}",
|
133 |
-
key=result_metric,
|
134 |
-
)
|
135 |
-
# model.compile()
|
136 |
-
# model.evaluate(val_generator.data[:1000])
|
137 |
-
# model(TakeDict(val_generator[0])[:, 8:])
|
138 |
-
trans_raven = model[0]
|
139 |
-
# s = trans_raven(TakeDict(val_generator[0])[:, 8:])
|
140 |
-
if select_type == 2:
|
141 |
-
second_pooling = Lambda(lambda x: x[:, :-1])
|
142 |
-
else:
|
143 |
-
second_pooling = Last()
|
144 |
-
if additional_out > 0:
|
145 |
-
if additional_copy:
|
146 |
-
model4 = get_rav_trans(
|
147 |
-
data,
|
148 |
-
plw=plw,
|
149 |
-
loss_weight=loss_weight,
|
150 |
-
**kwargs
|
151 |
-
)
|
152 |
-
model_load_weights(model4,
|
153 |
-
load_weights,
|
154 |
-
# sample_data,
|
155 |
-
None,
|
156 |
-
template="weights_{epoch:02d}-{val_loss:.2f}",
|
157 |
-
key=result_metric,
|
158 |
-
)
|
159 |
-
|
160 |
-
if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
|
161 |
-
index = -1
|
162 |
-
else:
|
163 |
-
index = -2
|
164 |
-
out2 = model4[0, index, :additional_out]
|
165 |
-
logger.info(f"Additional out from: {model4[0, index]}.")
|
166 |
-
|
167 |
-
if additional_out > 2:
|
168 |
-
out2 += [Activation("gelu")]
|
169 |
-
else:
|
170 |
-
out2 = out
|
171 |
-
|
172 |
-
second_pooling = bm([second_pooling] + out2, add_flatten=False)
|
173 |
-
|
174 |
-
model2 = get_rav_trans(
|
175 |
-
TakeDict(data[0])[:, 8:],
|
176 |
-
plw=plw,
|
177 |
-
loss_weight=loss_weight,
|
178 |
-
**{
|
179 |
-
**kwargs,
|
180 |
-
"out_layers": second_pooling,
|
181 |
-
}
|
182 |
-
# **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
|
183 |
-
)
|
184 |
-
model_load_weights(
|
185 |
-
model2,
|
186 |
-
load_weights,
|
187 |
-
# sample_data,
|
188 |
-
None,
|
189 |
-
template="weights_{epoch:02d}-{val_loss:.2f}",
|
190 |
-
key=result_metric,
|
191 |
-
)
|
192 |
-
if select_type == 0:
|
193 |
-
# not working
|
194 |
-
trans_raven2 = model2[0]
|
195 |
-
else:
|
196 |
-
trans_raven2 = model2[0]
|
197 |
-
tail = add_end(out_layers=tail_out, output_size=8 if select_out else 1)
|
198 |
-
# trans_raven2.mask_fn = ImageMask(last=take_by_index)
|
199 |
-
if select_type == 2:
|
200 |
-
select_model_class = SelectModel2
|
201 |
-
elif select_type == 1:
|
202 |
-
select_model_class = SelectModel
|
203 |
-
else:
|
204 |
-
select_model_class = SelectModel9
|
205 |
-
select_model = select_model_class(trans_raven, model2=trans_raven2, tail=tail, select_out=select_out)
|
206 |
-
return select_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raven_utils/models/uitls_.py
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
import tensorflow as tf
|
2 |
-
import tensorflow.experimental.numpy as tnp
|
3 |
-
from tensorflow.keras import Model
|
4 |
-
import raven_utils as rv
|
5 |
-
|
6 |
-
|
7 |
-
class RangeMask(Model):
    """Keras model that serves rows of a precomputed boolean range mask.

    The mask is built once in ``__init__`` from ``rv.entity.INDEX`` and
    ``rv.group.NO``.  Row ``i`` is True at every position ``p`` satisfying
    ``INDEX[i] <= p < INDEX[i + 1]``, where positions run over
    ``range(INDEX[-1])``.

    NOTE(review): the comparison broadcasts a ``(len(INDEX) - 1, 1)`` column
    against a ``(rv.group.NO, INDEX[-1])`` grid, so this presumably relies on
    ``rv.group.NO == len(rv.entity.INDEX) - 1`` -- confirm against
    ``raven_utils``.
    """

    def __init__(self):
        """Precompute the boolean mask table and store it in ``self.mask``."""
        super().__init__()
        boundaries = rv.entity.INDEX

        # One row of positions 0 .. INDEX[-1] - 1, repeated rv.group.NO times.
        positions = tf.range(boundaries[-1])[None]
        ranges = tf.tile(positions, [rv.group.NO, 1])

        # Consecutive boundaries as column vectors so they compare against
        # every position of a row at once.
        start_index = boundaries[:-1][:, None]
        end_index = boundaries[1:][:, None]

        at_or_after_start = start_index <= ranges
        before_end = ranges < end_index
        self.mask = tnp.array(at_or_after_start & before_end)

    def call(self, inputs):
        """Return the mask rows selected by indexing ``self.mask`` with ``inputs``."""
        return self.mask[inputs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|