Jakub Kwiatkowski commited on
Commit
9502bdf
·
1 Parent(s): 86045f3

Refactor hf/raven.

Browse files
models.py CHANGED
@@ -11,3 +11,4 @@ indexes = nload("/home/jkwiatkowski/all/dataset/arr/val_target.npy")
11
 
12
  folders = DataSetFromFolder("/home/jkwiatkowski/all/dataset/arr/RAVEN-10000-release/RAVEN-10000", file_type="dir")
13
  properties = DataSetFromFolder(folders[:], file_type="xml", extension="val")
 
 
11
 
12
  folders = DataSetFromFolder("/home/jkwiatkowski/all/dataset/arr/RAVEN-10000-release/RAVEN-10000", file_type="dir")
13
  properties = DataSetFromFolder(folders[:], file_type="xml", extension="val")
14
+
raven_utils/depricated/__init__.py DELETED
File without changes
raven_utils/depricated/old_raven.py DELETED
@@ -1,490 +0,0 @@
1
- from functools import partial
2
-
3
- import numpy as np
4
- from data_utils import take, EXIST, COR
5
- from data_utils.image import draw_images, add_text
6
- from data_utils.op import np_split
7
- from ml_utils import lu, dict_from_list2, filter_keys, none
8
- from data_utils import ops as K
9
-
10
- from config.constant import PROPERTY, TARGET, INPUTS
11
- # from raven_utils.render.rendering import render_panels
12
-
13
- RENDER_POSITIONS = [
14
- [(0.5, 0.5, 1, 1)],
15
- # ...
16
- [(0.25, 0.25, 0.5, 0.5),
17
- (0.25, 0.75, 0.5, 0.5),
18
- (0.75, 0.25, 0.5, 0.5),
19
- (0.75, 0.75, 0.5, 0.5)],
20
- # ...
21
- [(0.16, 0.16, 0.33, 0.33),
22
- (0.16, 0.5, 0.33, 0.33),
23
- (0.16, 0.83, 0.33, 0.33),
24
- (0.5, 0.16, 0.33, 0.33),
25
- (0.5, 0.5, 0.33, 0.33),
26
- (0.5, 0.83, 0.33, 0.33),
27
- (0.83, 0.16, 0.33, 0.33),
28
- (0.83, 0.5, 0.33, 0.33),
29
- (0.83, 0.83, 0.33, 0.33)],
30
- # ...
31
- [(0.5, 0.25, 0.5, 0.5)],
32
- [(0.5, 0.75, 0.5, 0.5)],
33
- # ...
34
- [(0.25, 0.5, 0.5, 0.5)],
35
- [(0.75, 0.5, 0.5, 0.5)],
36
- # ...
37
- [(0.5, 0.5, 1, 1)],
38
- [(0.5, 0.5, 0.33, 0.33)],
39
- # ...
40
- [(0.5, 0.5, 1, 1)],
41
- [(0.42, 0.42, 0.15, 0.15),
42
- (0.42, 0.58, 0.15, 0.15),
43
- (0.58, 0.42, 0.15, 0.15),
44
- (0.58, 0.58, 0.15, 0.15)],
45
- # ...
46
-
47
- ]
48
-
49
- HORIZONTAL = "horizontal"
50
- VERTICAL = "vertical"
51
-
52
- NAMES = ['center_single',
53
- 'distribute_four',
54
- 'distribute_nine',
55
- 'in_center_single_out_center_single',
56
- 'in_distribute_four_out_center_single',
57
- 'left_center_single_right_center_single',
58
- 'up_center_single_down_center_single']
59
-
60
- PROPERTIES_NAMES = [
61
- 'Color',
62
- 'Size',
63
- 'Type',
64
-
65
- ]
66
- PROPERTIES = dict_from_list2(PROPERTIES_NAMES, [10, 6, 5])
67
- ANGLE_MAX = 7
68
-
69
- PROPERTIES_NO = len(PROPERTIES)
70
-
71
- RULES_COMBINE = "Number/Position"
72
-
73
- RULES_ATTRIBUTES = [
74
- "Number",
75
- "Position",
76
- "Color",
77
- "Size",
78
- "Type"
79
- ]
80
- RULES_ATTRIBUTES_LEN = len(RULES_ATTRIBUTES)
81
-
82
- RULES_ATTRIBUTES_INDEX = dict_from_list2(RULES_ATTRIBUTES)
83
-
84
- RULES_TYPES = [
85
- "Constant",
86
- "Arithmetic",
87
- "Progression",
88
- "Distribute_Three"
89
- ]
90
- RULES_TYPES_INDEX = dict_from_list2(RULES_TYPES)
91
- RULES_TYPES_LEN = len(RULES_ATTRIBUTES)
92
-
93
- GROUPS_NO = len(NAMES)
94
- ENTITY_NO = dict(zip(NAMES, [1, 4, 9, 2, 5, 2, 2]))
95
- ENTITY_SUM = sum(list(ENTITY_NO.values()))
96
- ENTITY_INDEX = np.concatenate([[0], np.cumsum(list(ENTITY_NO.values()))])
97
- ENTITY_INDEX_TARGET = ENTITY_INDEX + 1
98
- ENTITY_DICT = dict(zip(NAMES, ENTITY_INDEX_TARGET[:-1]))
99
- NAMES_ORDER = dict(zip(NAMES, np.arange(len(NAMES))))
100
- PROPERTIES_INDEXES = np.cumsum(np.array(list(ENTITY_NO.values())) * len(PROPERTIES))
101
- INDEX = np.concatenate([[0], PROPERTIES_INDEXES]) + ENTITY_SUM + 1 # +2 type and uniformity
102
-
103
- SECOND_LAYOUT = [i - 1 for i in [
104
- ENTITY_DICT["in_center_single_out_center_single"] + 1,
105
- ENTITY_DICT["in_distribute_four_out_center_single"] + 1,
106
- ENTITY_DICT["in_distribute_four_out_center_single"] + 2,
107
- ENTITY_DICT["in_distribute_four_out_center_single"] + 3,
108
- ENTITY_DICT["left_center_single_right_center_single"] + 1,
109
- ENTITY_DICT["up_center_single_down_center_single"] + 1
110
- ]]
111
-
112
- FIRST_LAYOUT = list(set(range(ENTITY_SUM)) - set(SECOND_LAYOUT))
113
- LAYOUT_NO = 2
114
-
115
- START_INDEX = dict(zip(NAMES, INDEX[:-1]))
116
- END_INDEX = INDEX[-1]
117
-
118
- RULES_ATTRIBUTES_ALL_LEN = RULES_ATTRIBUTES_LEN * LAYOUT_NO
119
- UNIFORMITY_NO = 2
120
- UNIFORMITY_INDEX = END_INDEX + RULES_ATTRIBUTES_ALL_LEN
121
-
122
- FEATURE_NO = UNIFORMITY_INDEX + UNIFORMITY_NO
123
- MAPPING = {
124
- "distribute_nine":
125
- {0.16: 0,
126
- 0.5: 1,
127
- 0.83: 2},
128
- "distribute_four":
129
- {0.25: 0,
130
- 0.75: 1},
131
- 'in_distribute_four_out_center_single':
132
- {0.42: 0,
133
- 0.58: 1}
134
- }
135
- MUL = {
136
- "distribute_nine": 3,
137
- "distribute_four": 2,
138
- 'in_distribute_four_out_center_single': 2
139
- }
140
-
141
- # SIZES = np.linspace(0.4, 0.9, 6)
142
- TYPES = ["triangle", "square", "pentagon", "hexagon", "circle"]
143
- # TYPES = ["triangle", "square", "pentagon", "circle", "circle"]
144
- SIZES = ["vs", "s", "m", "h", "vh", "e"]
145
- COLORS = ["vs", "s", "m", "h", "vh", "e"]
146
- # TYPES = ["", "", "circle", "hexagon", "square"]
147
-
148
- ENTITY_PROPERTIES_VALUES = list(PROPERTIES.values())
149
- ENTITY_PROPERTIES_KEYS = list(PROPERTIES.keys())
150
- ENTITY_PROPERTIES_NO = len(PROPERTIES)
151
- INDEX = dict(zip(PROPERTIES, np.array(ENTITY_PROPERTIES_VALUES) * ENTITY_SUM))
152
- ENTITY_PROPERTIES_SUM = sum(list(PROPERTIES.values()))
153
-
154
- OUTPUT_SIZE = ENTITY_SUM * ENTITY_PROPERTIES_SUM + GROUPS_NO + ENTITY_SUM
155
-
156
- SLOT_AND_GROUP = ENTITY_SUM + GROUPS_NO
157
-
158
- OUTPUT_GROUP_SLICE = np.s_[:, -GROUPS_NO:]
159
- OUTPUT_SLOT_SLICE = np.s_[:, -SLOT_AND_GROUP:-GROUPS_NO]
160
- OUTPUT_PROPERTIES_SLICE = np.s_[:, :-SLOT_AND_GROUP]
161
-
162
- OUTPUT_GROUP_SLICE_END = np.s_[-GROUPS_NO:]
163
- OUTPUT_SLOT_SLICE_END = np.s_[-SLOT_AND_GROUP:-GROUPS_NO]
164
- OUTPUT_PROPERTIES_SLICE_END = np.s_[:-SLOT_AND_GROUP]
165
-
166
- # Transformation
167
- # constant
168
- # progression -2, -1,1 ,2
169
- # arithmetic -/+ Position set arithmetic
170
- # distribute three
171
-
172
- # todo
173
- SLOTS_GROUPS = GROUPS_NO
174
-
175
- SLOT_TRANSFORMATION_NO = 4
176
- PROPERTY_TRANSFORMATION_NO = 8
177
- PROPERTIES_TRANSFORMATION_NO = PROPERTY_TRANSFORMATION_NO * PROPERTIES_NO
178
- PROPERTIES_TRANSFORMATION_SIZE = PROPERTIES_TRANSFORMATION_NO * ENTITY_SUM
179
-
180
- SLOT_TRANSFORMATION_SIZE = PROPERTY_TRANSFORMATION_NO * SLOTS_GROUPS
181
- INFERENCE_SIZE = SLOT_TRANSFORMATION_SIZE + PROPERTIES_TRANSFORMATION_SIZE
182
-
183
- INFERENCE_SLOT_SLICE = np.s_[:, :SLOT_TRANSFORMATION_SIZE]
184
- INFERENCE_PROPERTIES_SLICE = np.s_[:, -PROPERTIES_TRANSFORMATION_SIZE:]
185
- from operator import add
186
-
187
-
188
- # todo Refactor
189
- # Maybe properties should be on same level as rest.
190
- def decode_output(output, split_fn=np_split):
191
- group_output = output[..., OUTPUT_GROUP_SLICE_END]
192
- slot_output = output[..., OUTPUT_SLOT_SLICE_END]
193
- properties_output = output[..., OUTPUT_PROPERTIES_SLICE_END]
194
- properties_output_splited = split_fn(properties_output, list(rv.properties.INDEX.values()), axis=-1)
195
- return group_output, slot_output, properties_output_splited
196
-
197
-
198
- def decode_inference(inference, reshape=np.reshape):
199
- return reshape(inference[INFERENCE_SLOT_SLICE],
200
- [-1, SLOTS_GROUPS, PROPERTY_TRANSFORMATION_NO]), reshape(
201
- inference[INFERENCE_PROPERTIES_SLICE],
202
- [-1, PROPERTIES_NO, ENTITY_SUM, PROPERTY_TRANSFORMATION_NO])
203
-
204
-
205
- def decode_output_reshape(output, split_fn=np_split):
206
- result = decode_output(output, split_fn=split_fn)
207
- out_reshaped = []
208
- for i, out in enumerate(result[2]):
209
- shape = (-1, ENTITY_SUM, ENTITY_PROPERTIES_VALUES[i])
210
- out_reshaped.append(out.reshape(shape))
211
- return result[:2] + tuple(out_reshaped)
212
-
213
-
214
- def take_target(target):
215
- return target[1], target[2]
216
-
217
-
218
- def create_target(images, index, pattern_index=(2, 5), full_index=False, arrange=np.arange, shape=lambda x: x.shape):
219
- return [images[:, pattern_index[0]], images[:, pattern_index[1]],
220
- images[arrange(shape(index)[0]), (0 if full_index else 8) + index[:, 0]]]
221
-
222
-
223
- def take_target_simple(target):
224
- return target[1], target[0]
225
-
226
-
227
- def create_target_simple(images, target, index=slice(None), pattern_index=(2, 5)):
228
- return [images[:, pattern_index[0]], images[:, pattern_index[1]], target][index]
229
-
230
-
231
- def decode_output_result(output, split_fn=np_split, arg_max=np.argmax):
232
- result = decode_output_reshape(output, split_fn=split_fn)
233
- res = []
234
- for i, r in enumerate(result):
235
- if i == 1:
236
- res.append(r)
237
- else:
238
- res.append(arg_max(r, axis=-1))
239
- return tuple(res)
240
-
241
-
242
- def decode_target(target):
243
- target_group = target[..., 0]
244
- target_slot = target[..., 1:INDEX[0]]
245
- target_properties = target[..., INDEX[0]:END_INDEX]
246
- target_properties_splited = [
247
- target_properties[..., ::PROPERTIES_NO],
248
- target_properties[..., 1::PROPERTIES_NO],
249
- target_properties[..., 2::PROPERTIES_NO]
250
- ]
251
- return target_group, target_slot, target_properties_splited
252
-
253
-
254
- def decode_target_flat(target):
255
- t = decode_target(target)
256
- return t[0], t[1], t[2][0], t[2][1], t[2][2]
257
-
258
-
259
- def draw_board(images, target=None, predict=None,image=None, desc=None, layout=None, break_=20):
260
- if image != "target" and predict is not None:
261
- image = images[predict:predict + 1]
262
- elif images is None and target is not None:
263
- image = images[target:target + 1]
264
- # image = False to not draw anything
265
- border = [{COR: target - 8, EXIST: (1, 3)}] + [{COR: p, EXIST: (0, 2)} for p in none(predict)]
266
-
267
- boards = []
268
- boards.append(draw_images(np.concatenate([images[:8], image[None] if len(image.shape)==3 else image]) if image is not None else images[:8]))
269
- if layout == 1:
270
- i = draw_images(images[8:], column=4, border=border)
271
- if break_:
272
- i = np.concatenate([np.zeros([ break_, i.shape[1],1]),i ],axis=0)
273
- boards.append(i)
274
-
275
- else:
276
- boards.append(
277
- draw_images(np.concatenate([images[8:], predict]) if predict is not None else images[8:], column=4,
278
- border=target - 8))
279
- full_board = draw_images(boards, grid=False)
280
- if desc:
281
- full_board = add_text(full_board, desc)
282
- return full_board
283
-
284
-
285
- def draw_boards(images, target=None, predict=None, image=None, desc=None, no=1, layout=None):
286
- boards = []
287
- for i, image in enumerate(images):
288
- boards.append(draw_board(image, target[i][0] if target is not None else None,
289
- predict[i] if predict is not None else None,
290
- image[i] if image is not None else None,
291
- desc[i] if desc is not None else None, layout=layout))
292
- return boards
293
-
294
-
295
- def draw_raven(generator, predict=None, no=1, add_target_desc=True, indexes=None, types=TYPES,
296
- layout=1):
297
- if indexes is None:
298
- indexes = val_sample(no)
299
- data = generator.data[indexes]
300
- if is_model(predict):
301
- d = filter_keys(data, PROPERTY,reverse=True)
302
- # tmp change
303
- pro = predict(d)['predict']
304
- print(pro)
305
- predict = render_panels(pro, target=False)
306
- # if target is not None:
307
- target = data[TARGET]
308
- target_index = data["index"]
309
- images = data[INPUTS]
310
-
311
- if hasattr(predict, "shape"):
312
- if len(predict.shape) > 3:
313
- # iamges
314
- image = predict
315
- # todo create index and output based on image
316
- predict = None
317
- predict_index = None
318
- elif len(predict.shape) == 3:
319
- image = render_panels(predict, target=False)
320
- # Create index based on predict.
321
- predict_index = None
322
- else:
323
- image = images[predict]
324
- predict_index = predict
325
- predict = target
326
- else:
327
- image = K.gather(images, target_index[:, 0])
328
- predict_index = None
329
- predict = None
330
-
331
- # elif not(hasattr(target,"shape") and len(target.shape) > 3):
332
- # if hasattr(target,"shape") and target.shape[-1] == OUTPUT_SIZE:
333
- # pro = target
334
- # predict = render_panels(pro)
335
- # elif hasattr(target,"shape") and target.shape[-1] == FEATURE_NO:
336
- # # pro = target
337
- # pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
338
- # else:
339
- # pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
340
- # # predict = [None] * no
341
- # predict = render_panels(data[TARGET])
342
-
343
- all_rules = []
344
- for d in data[PROPERTY]:
345
- rules = []
346
- for j, rule_group in enumerate(d.findAll("Rule_Group")):
347
- # rules_all.append(rule_group['id'])
348
- for j, rule in enumerate(rule_group.findAll("Rule")):
349
- rules.append(f"{rule['attr']} - {rule['name']}")
350
- rules.append("")
351
- all_rules.append(rules)
352
- target_desc = get_desc(target)
353
- if predict is not None:
354
- predict_desc = decode_output_result(predict) if predict.shape[-1] == OUTPUT_SIZE else get_desc(predict)
355
- else:
356
- predict_desc = [None] * len(target_desc)
357
- for a, po, to in zip(all_rules, predict_desc, target_desc):
358
- # fl(predict_desc[-1])
359
- if po is None:
360
- po = [None] * len(to)
361
- for p, t in zip(po, to):
362
- a.extend(
363
- [" ".join([str(i) for i in t])] + (
364
- [" ".join([str(i) for i in p]), ""] if p is not None else []
365
- )
366
- )
367
- # a.extend([""] + [] + [""] + [" ".join(fl(p))])
368
-
369
- # image = draw_boards(data[INPUTS],target=data["index"], predict=predict[:no], desc=all_rules, no=no,layer=layer)
370
- image = draw_boards(images, target=target_index, predict=predict_index, image=image, desc=None, no=no,
371
- layout=layout)
372
- return lu([(i, j) for i, j in zip(image, all_rules)])
373
-
374
-
375
- def val_sample(no=GROUPS_NO, base=3):
376
- indexes = np.arange(no) * 2000 + base
377
- return indexes
378
-
379
-
380
- def get_desc(target, exist=None, types=TYPES, sizes=SIZES):
381
- decoded = decode_target(target)
382
- exist = decoded[1] if exist is None else exist
383
- taken = np.stack(take(decoded[2], np.array(exist, dtype=bool))).T
384
-
385
- figures_no = np.sum(exist, axis=-1)
386
- desc = np.split(taken, np.cumsum(figures_no))[:-1]
387
- # figures_no = np.sum(exist, axis=-1)
388
- # div = np.split(desc, np.cumsum(figures_no))[:-1]
389
- result = []
390
- for pd in desc:
391
- r = []
392
- for p in pd:
393
- r.append([p[0], sizes[p[1]], types[p[2]]])
394
- result.append(r)
395
-
396
- return result
397
-
398
-
399
- # def get
400
-
401
-
402
- def get_description(inputs, predict, pro, no, types=TYPES, sizes=SIZES):
403
- # target = inputs[1][2][:no]
404
- target = inputs[TARGET]
405
- target_group = target[:, 0]
406
- target_exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
407
- target_rest = target[:, ENTITY_SUM + 1:ENTITY_SUM + 1 + ENTITY_SUM * PROPERTIES_NO]
408
- pro_reshaped = np.reshape(pro, (pro.shape[0], -1, PROPERTIES_NO))
409
- target_reshaped = np.reshape(target_rest, (target_rest.shape[0], -1, PROPERTIES_NO))
410
-
411
- # mask = np.repeat(target_exist, [4] * ENTITY_SUM, axis=-1)
412
- # masked_result = np.repeat(target_exist, [4] * ENTITY_SUM, axis=-1)
413
- pro_res = pro_reshaped[target_exist]
414
- target_res = target_reshaped[target_exist]
415
- figures_no = np.sum(target_exist, axis=-1)
416
-
417
- pro_div = np.split(pro_res, np.cumsum(figures_no))[:-1]
418
- target_div = np.split(target_res, np.cumsum(figures_no))[:-1]
419
- pro_result_full = []
420
- target_result_full = []
421
- for pd, td in zip(pro_div, target_div):
422
- pro_result = []
423
- target_result = []
424
- for p in pd:
425
- pro_result.append([p[0], sizes[p[1]], types[p[2]]])
426
- for t in td:
427
- target_result.append([t[0], sizes[t[1]], types[t[2]]])
428
- pro_result_full.append(pro_result)
429
- target_result_full.append(target_result)
430
-
431
- return pro_result_full, target_result_full
432
-
433
-
434
- def get_properties(target, types=TYPES, sizes=SIZES):
435
- target_exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
436
- target_rest = target[:, ENTITY_SUM + 1:ENTITY_SUM + 1 + ENTITY_SUM * PROPERTIES_NO]
437
- target_reshaped = np.reshape(target_rest, (target_rest.shape[0], -1, PROPERTIES_NO))
438
- target_res = target_reshaped[target_exist]
439
- figures_no = np.sum(target_exist, axis=-1)
440
- target_div = np.split(target_res, np.cumsum(figures_no))[:-1]
441
- target_result_full = []
442
- for td in target_div:
443
- target_result = []
444
- for t in td:
445
- target_result.append([t[0], sizes[t[1]], types[t[2]]])
446
- target_result_full.append(target_result)
447
- return target_result_full
448
-
449
-
450
- def desc_properties(target, decode_fn=None, types=TYPES, sizes=SIZES):
451
- if decode_fn is None:
452
- if target.shape[1] == OUTPUT_SIZE:
453
- decode_fn = decode_output_result
454
- else:
455
- decode_fn = decode_target
456
-
457
- target_div = decode_fn(target)[2:]
458
- target_result_full = []
459
- for td in target_div:
460
- target_result = []
461
- for t in td:
462
- target_result.append([t[0], sizes[t[1]], types[t[2]]])
463
- target_result_full.append(target_result)
464
- return target_result_full
465
-
466
-
467
- def get_pro(t, types=TYPES, sizes=SIZES):
468
- return [int(t[0]), sizes[t[1]], types[t[2]]]
469
-
470
-
471
- def get_pro2(td, types=TYPES, sizes=SIZES):
472
- target_result = []
473
- for t in td:
474
- target_result.append([int(t[0]), sizes[t[1]], types[t[2]]])
475
- return target_result
476
-
477
-
478
- def get_pro3(target_div, types=TYPES, sizes=SIZES):
479
- target_result_full = []
480
- for td in target_div.to_list():
481
- target_result = []
482
- for t in td:
483
- target_result.append([int(t[0]), sizes[t[1]], types[t[2]]])
484
- target_result_full.append(target_result)
485
- return target_result_full
486
-
487
-
488
- from models_utils import init_image as def_init_image, is_model
489
-
490
- init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/__init__.py DELETED
File without changes
raven_utils/models/attn.py DELETED
@@ -1,187 +0,0 @@
1
- from __future__ import print_function
2
-
3
- import tensorflow as tf
4
- from tensorflow.keras import backend as K
5
- from tensorflow.keras.layers import LSTMCell
6
- from tensorflow.keras.models import Model
7
- from tensorflow.keras.layers import Conv2D, Dense
8
- from tensorflow.keras.losses import mse
9
- from tensorflow.keras.models import clone_model
10
- from tensorflow.layers.base import InputSpec, Layer
11
-
12
- from models.dense import create_conv_model
13
- from models.utils import broadcast
14
-
15
-
16
- class ReflectionPadding2D(Layer):
17
- def __init__(self, padding=(1, 1), **kwargs):
18
- self.padding = tuple(padding)
19
- self.input_spec = [InputSpec(ndim=4)]
20
- super(ReflectionPadding2D, self).__init__(**kwargs)
21
-
22
- def compute_output_shape(self, s):
23
- """ If you are using "channels_last" configuration"""
24
- return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
25
-
26
- def call(self, x, mask=None):
27
- w_pad, h_pad = self.padding
28
- return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
29
-
30
-
31
- class Conv2Ref(Layer):
32
- def __init__(self, padding=(1, 1), **kwargs):
33
- self.padding = tuple(padding)
34
- self.input_spec = [InputSpec(ndim=4)]
35
- super(ReflectionPadding2D, self).__init__(**kwargs)
36
-
37
- def compute_output_shape(self, s):
38
- """ If you are using "channels_last" configuration"""
39
- return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
40
-
41
- def call(self, x, mask=None):
42
- w_pad, h_pad = self.padding
43
- return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
44
-
45
-
46
- class SegmentationNetwork(Model):
47
-
48
- def __init__(self, filters=64, kernels=(3, 3)):
49
- super(RecAE, self).__init__()
50
- self.conv_1 = Conv2D(filters, kernels, padding=SAME)
51
- self.conv_2 = Conv2D(filters, kernels, padding=SAME)
52
-
53
- def call(self, inputs):
54
- x = K.relu(inputs)
55
- x = self.conv_1(x)
56
- x = K.relu(x)
57
- x = self.conv_2(x)
58
- return x + inputs
59
-
60
-
61
- class QueryNetwork(Model):
62
-
63
- def __init__(self, units=64):
64
- super(RecAE, self).__init__()
65
- self.conv_1 = Dense(units)
66
- self.conv_2 = Dense(units)
67
-
68
- def call(self, inputs):
69
- x = K.relu(inputs)
70
- x = self.conv_1(x)
71
- x = K.relu(x)
72
- x = self.conv_2(x)
73
- return x + inputs
74
-
75
-
76
- class RecAE(Model):
77
-
78
- def __init__(self, head, bottle, decoder):
79
- super(RecAE, self).__init__()
80
- self.head = head
81
- self.bottle = bottle
82
- self.base = clone_model(bottle)
83
- self.decoder = decoder
84
- self.segmentation_network = SegmentationNetwork()
85
- self.query_network = QueryNetwork()
86
- self.control = LSTMCell(64)
87
- self.memory = LSTMCell(64)
88
-
89
- def call(self, inputs):
90
- feature = self.head(inputs)
91
- segmentation = self.segmentation_network(feature)
92
- control_base = self.base(feature)
93
- h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
94
- h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
95
- shape = K.shape(feature)[:-1]
96
- full_attention = tf.zeros(shape)[..., tf.newaxis]
97
- full_image = tf.zeros(K.shape(inputs))
98
- masks = []
99
- ff = tf.zeros(K.shape(inputs))
100
- scope = tf.ones(shape)[..., tf.newaxis]
101
- for i in range(4):
102
- r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
103
- query = self.query_network(h_c[0])
104
- log_attention = image_attention(segmentation, query)
105
- attention = K.sigmoid(log_attention)
106
- mask = attention * scope
107
- scope = scope - mask
108
- im = feature * mask
109
- # im = feature
110
- latent = self.bottle(im)
111
- decoded = self.decoder(latent)
112
- # self.add_loss(K.mean(-mse(full_attention, attention)))
113
- # self.add_loss(K.mean(-mse(tf.ones(attention.shape), attention)))
114
- full_attention += attention
115
- big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
116
- ff += K.sigmoid(decoded)
117
- full_image += K.sigmoid(decoded) * big_mask
118
- r_m, h_m = self.memory(latent, h_m)
119
- masks.append(big_mask)
120
- self.add_loss(K.mean(mse(inputs, full_image)))
121
- return full_image, masks
122
-
123
-
124
- # def image_attention(image, query, scale=True):
125
- @tf.function
126
- def image_attention(image, query):
127
- log_attention = K.sum(query[:, tf.newaxis, tf.newaxis, :] * image, axis=-1, keepdims=True)
128
- # if scale is not None:
129
- log_attention /= tf.sqrt(tf.cast(K.shape(image)[-1], dtype=float))
130
- return log_attention
131
-
132
-
133
- class RecAE_2(Model):
134
-
135
- def __init__(self, head, bottle, decoder):
136
- super(RecAE_2, self).__init__()
137
- self.head = head
138
- self.bottle = bottle
139
- # self.base = clone_model(bottle)
140
- self.base = self.bottle
141
- self.decoder = decoder
142
- self.segmentation_network = create_conv_model((64, 64, 1))
143
- self.control = LSTMCell(64)
144
- self.memory = LSTMCell(64)
145
-
146
- def call(self, inputs):
147
- feature = self.head(inputs)
148
- control_base = self.base(feature)
149
- h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
150
- h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
151
- shape = K.shape(feature)[:-1]
152
- full_attention = tf.zeros(shape)[..., tf.newaxis]
153
- full_image = tf.zeros(K.shape(inputs))
154
- big_masks = []
155
- masks = []
156
- ff = tf.zeros(K.shape(inputs))
157
- scope = tf.ones(shape)[..., tf.newaxis]
158
- for i in range(4):
159
- if i ==3:
160
- mask = scope
161
- else:
162
- r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
163
- query = broadcast(h_c[0], feature.shape[1:])
164
- log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
165
- attention = K.sigmoid(log_attention)
166
- mask = attention * scope
167
- scope = scope - mask
168
- masks.append(mask)
169
- im = feature * mask
170
- # im = feature
171
- latent = self.bottle(im)
172
- decoded = self.decoder(latent)
173
- # self.add_loss(K.mean(-mse(scope, mask)))
174
- sum = K.sum(tf.ones(K.shape(mask)))
175
- self.add_loss(K.abs((sum/4)-K.sum(mask))/sum)
176
- # self.add_loss(K.mean(-mse(tf.zeros(K.shape(mask)), mask)))
177
- for m in masks:
178
- self.add_loss(K.mean(-mse(m,mask)))
179
-
180
- full_attention += mask
181
- big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
182
- ff += K.sigmoid(decoded)
183
- full_image += K.sigmoid(decoded) * big_mask
184
- r_m, h_m = self.memory(latent, h_m)
185
- big_masks.append(big_mask)
186
- self.add_loss(K.mean(mse(inputs, full_image)))
187
- return full_image, big_masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/attn2.py DELETED
@@ -1,187 +0,0 @@
1
- from __future__ import print_function
2
-
3
- import tensorflow as tf
4
- from tensorflow.keras import backend as K
5
- from tensorflow.keras.layers import LSTMCell
6
- from tensorflow.keras.models import Model
7
- from tensorflow.keras.layers import Conv2D, Dense
8
- from tensorflow.keras.losses import mse
9
- from tensorflow.keras.models import clone_model
10
- from tensorflow.layers.base import InputSpec, Layer
11
-
12
- from models.dense import create_conv_model
13
- from models.utils import broadcast
14
-
15
-
16
- class ReflectionPadding2D(Layer):
17
- def __init__(self, padding=(1, 1), **kwargs):
18
- self.padding = tuple(padding)
19
- self.input_spec = [InputSpec(ndim=4)]
20
- super(ReflectionPadding2D, self).__init__(**kwargs)
21
-
22
- def compute_output_shape(self, s):
23
- """ If you are using "channels_last" configuration"""
24
- return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
25
-
26
- def call(self, x, mask=None):
27
- w_pad, h_pad = self.padding
28
- return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
29
-
30
-
31
- class Conv2Ref(Layer):
32
- def __init__(self, padding=(1, 1), **kwargs):
33
- self.padding = tuple(padding)
34
- self.input_spec = [InputSpec(ndim=4)]
35
- super(ReflectionPadding2D, self).__init__(**kwargs)
36
-
37
- def compute_output_shape(self, s):
38
- """ If you are using "channels_last" configuration"""
39
- return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
40
-
41
- def call(self, x, mask=None):
42
- w_pad, h_pad = self.padding
43
- return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
44
-
45
-
46
- class SegmentationNetwork(Model):
47
-
48
- def __init__(self, filters=64, kernels=(3, 3)):
49
- super(RecAE, self).__init__()
50
- self.conv_1 = Conv2D(filters, kernels)
51
- self.conv_2 = Conv2D(filters, kernels)
52
-
53
- def call(self, inputs):
54
- x = K.relu(inputs)
55
- x = self.conv_1(x)
56
- x = K.relu(x)
57
- x = self.conv_2(x)
58
- return x + inputs
59
-
60
-
61
- class QueryNetwork(Model):
62
-
63
- def __init__(self, units=64):
64
- super(RecAE, self).__init__()
65
- self.conv_1 = Dense(units)
66
- self.conv_2 = Dense(units)
67
-
68
- def call(self, inputs):
69
- x = K.relu(inputs)
70
- x = self.conv_1(x)
71
- x = K.relu(x)
72
- x = self.conv_2(x)
73
- return x + inputs
74
-
75
-
76
- class RecAE(Model):
77
-
78
- def __init__(self, head, bottle, decoder):
79
- super(RecAE, self).__init__()
80
- self.head = head
81
- self.bottle = bottle
82
- self.base = clone_model(bottle)
83
- self.decoder = decoder
84
- self.segmentation_network = SegmentationNetwork()
85
- self.query_network = QueryNetwork()
86
- self.control = LSTMCell(64)
87
- self.memory = LSTMCell(64)
88
-
89
- def call(self, inputs):
90
- feature = self.head(inputs)
91
- segmentation = self.segmentation_network(feature)
92
- control_base = self.base(feature)
93
- h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
94
- h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
95
- shape = K.shape(feature)[:-1]
96
- full_attention = tf.zeros(shape)[..., tf.newaxis]
97
- full_image = tf.zeros(K.shape(inputs))
98
- masks = []
99
- ff = tf.zeros(K.shape(inputs))
100
- scope = tf.ones(shape)[..., tf.newaxis]
101
- for i in range(10):
102
- r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
103
- query = self.query_network(h_c[0])
104
- log_attention = image_attention(segmentation, query)
105
- attention = K.softmax(log_attention)
106
- mask = attention * scope
107
- scope = scope - mask
108
- im = feature * mask
109
- # im = feature
110
- latent = self.bottle(im)
111
- decoded = self.decoder(latent)
112
- # self.add_loss(K.mean(-mse(full_attention, attention)))
113
- # self.add_loss(K.mean(-mse(tf.ones(attention.shape), attention)))
114
- full_attention += attention
115
- big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
116
- ff += K.sigmoid(decoded)
117
- full_image += K.sigmoid(decoded) * big_mask
118
- r_m, h_m = self.memory(latent, h_m)
119
- masks.append(big_mask)
120
- self.add_loss(K.mean(mse(inputs, full_image)))
121
- return full_image, masks
122
-
123
-
124
- # def image_attention(image, query, scale=True):
125
- @tf.function
126
- def image_attention(image, query):
127
- log_attention = K.sum(query[:, tf.newaxis, tf.newaxis, :] * image, axis=-1, keepdims=True)
128
- # if scale is not None:
129
- log_attention /= tf.sqrt(tf.cast(K.shape(image)[-1], dtype=float))
130
- return log_attention
131
-
132
-
133
- class RecAE_2(Model):
134
-
135
- def __init__(self, head, bottle, decoder):
136
- super(RecAE_2, self).__init__()
137
- self.head = head
138
- self.bottle = bottle
139
- # self.base = clone_model(bottle)
140
- self.base = self.bottle
141
- self.decoder = decoder
142
- self.segmentation_network = create_conv_model((64, 64, 1))
143
- self.control = LSTMCell(64)
144
- self.memory = LSTMCell(64)
145
-
146
- def call(self, inputs):
147
- feature = self.head(inputs)
148
- control_base = self.base(feature)
149
- h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
150
- h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
151
- shape = K.shape(feature)[:-1]
152
- full_attention = tf.zeros(shape)[..., tf.newaxis]
153
- full_image = tf.zeros(K.shape(inputs))
154
- big_masks = []
155
- masks = []
156
- ff = tf.zeros(K.shape(inputs))
157
- scope = tf.ones(shape)[..., tf.newaxis]
158
- for i in range(4):
159
- if i ==3:
160
- mask = scope
161
- else:
162
- r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
163
- query = broadcast(h_c[0], feature.shape[1:])
164
- log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
165
- attention = K.sigmoid(log_attention)
166
- mask = attention * scope
167
- scope = scope - mask
168
- masks.append(mask)
169
- im = feature * mask
170
- # im = feature
171
- latent = self.bottle(im)
172
- decoded = self.decoder(latent)
173
- # self.add_loss(K.mean(-mse(scope, mask)))
174
- sum = K.sum(tf.ones(K.shape(mask)))
175
- self.add_loss(K.abs((sum/4)-K.sum(mask))/sum)
176
- # self.add_loss(K.mean(-mse(tf.zeros(K.shape(mask)), mask)))
177
- for m in masks:
178
- self.add_loss(K.mean(-mse(m,mask)))
179
-
180
- full_attention += mask
181
- big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
182
- ff += K.sigmoid(decoded)
183
- full_image += K.sigmoid(decoded) * big_mask
184
- r_m, h_m = self.memory(latent, h_m)
185
- big_masks.append(big_mask)
186
- self.add_loss(K.mean(mse(inputs, full_image)))
187
- return full_image, big_masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/augment.py DELETED
File without changes
raven_utils/models/body.py DELETED
@@ -1,276 +0,0 @@
1
- import itertools
2
-
3
- import tensorflow as tf
4
- from ml_utils import self_product, lw
5
-
6
- from models_utils import DictModel, ListModel, Flat, bm, Base, Cat, Res, Flat2, conv, KERNEL_SIZE, FILTERS, SAME, \
7
- Get, SM, bs, RELU, ACTIVATION, dense, bd, HardBlock, MaxBlock
8
- import models_utils.ops as K
9
- from models_utils import Merge, SoftBlock
10
- from models_utils.build import build_multi_dense, build_multi_conv, build_conv_model, build_encoder
11
- from tensorflow.keras.layers import Lambda, Dense
12
- from tensorflow.keras.layers import Conv2D
13
-
14
- from config.constant import MEMORY, CONTROL, LATENT, MERGE, CONCAT, INFERENCE, FLATTEN
15
- from models_utils.config import config
16
-
17
-
18
- class RavRes(Res):
19
- def __init__(self, model="v2", latent=256, act=RELU):
20
- super().__init__(model=model)
21
- self.latent = latent
22
-
23
- def call(self, inputs):
24
- return self.model(inputs) + inputs[0][:, ..., self.latent:]
25
-
26
-
27
- # not working
28
- class RavResConv(Res):
29
- def __init__(self, model="v2", latent=256, act=RELU):
30
- super().__init__(model=model)
31
- self.latent = latent
32
- self.conv = conv(latent, (1, 1), activation=act)
33
-
34
- def call(self, inputs):
35
- return self.model(inputs) + self.conv(inputs[0])
36
-
37
-
38
- class RavResDense(Res):
39
- def __init__(self, model="v2", latent=256, act=config.DEF_DENSE.activation):
40
- super().__init__(model=model)
41
- self.latent = latent
42
- self.conv = dense(latent, activation=act)
43
-
44
- def call(self, inputs):
45
- return self.model(inputs) + self.conv(inputs[0])
46
-
47
-
48
- def create_dense_block(latent=256, loop=1):
49
- soft_block = Res(SoftBlock(build_multi_dense(latent), add_identity=None,
50
- score_activation=tf.sigmoid), latent=latent)
51
- cells = [
52
- (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
53
- (None, CONCAT, MEMORY),
54
- (Dense(latent), CONCAT, MERGE),
55
- (Merge(latent), [INFERENCE, MERGE], CONTROL),
56
- (soft_block, [MEMORY, CONTROL], MEMORY)
57
- ]
58
-
59
- return ListModel([DictModel(*cell) for cell in cells] * loop, [LATENT, INFERENCE], MEMORY)
60
-
61
-
62
- def build_multi_conv(filters=32, end_filters=64, padding="same",mul=1, norm=None, **kwargs):
63
- base = [(1, 3), (3, 1), (3, 3)]
64
- block = list(self_product(base))
65
- block2 = [b + b[0:1] for b in block]
66
- block3 = [b + b for b in block]
67
- block4 = ([[(3, 3)]] + [[(3, 3), (3, 3)]] + [[(3, 3), (3, 3), (3, 3)]]) * 2
68
- block5 = [[], []]
69
- all_blocks = [s for b in [block, block2, block3, block4, block5] for s in b]
70
- start = {
71
- FILTERS: filters,
72
- KERNEL_SIZE: (1, 1)
73
- }
74
-
75
- end = {
76
- FILTERS: end_filters,
77
- KERNEL_SIZE: (1, 1),
78
- ACTIVATION: None
79
- }
80
-
81
- all_arch = []
82
- for ab in all_blocks:
83
- arch = [{
84
- FILTERS: filters,
85
- KERNEL_SIZE: a,
86
- **kwargs
87
- } for a in ab]
88
- all_arch.append([start] + arch + [end])
89
-
90
- all_arch = all_arch * mul
91
-
92
- return [
93
- build_encoder(a, add_norm=norm if norm else None, padding=padding, name=f"b{i}", order=(1, 0) if norm else None)
94
- for i, a in enumerate(all_arch)]
95
-
96
-
97
- def create_block(latent=256, simpler=0, loop=1, padding=SAME, norm=None, trans_div=2, act="pass", type_="conv",
98
- block_=SoftBlock,max_k=16,
99
- **kwargs):
100
- trans_size = int(latent / trans_div)
101
- # if block_ == HardBlock:
102
- # mul = 2
103
- # elif block_ == MaxBlock:
104
- # mul = int(38/max_k)
105
- # else:
106
- # mul = 1
107
-
108
- if act == "pass":
109
- res_class = RavRes
110
- else:
111
- if type_ == "dense":
112
- res_class = RavResDense
113
- else:
114
- res_class = RavResConv
115
-
116
- if type_ == "dense":
117
- build_res = lambda: Res(model="dv2")
118
- # build_reduction = lambda: bm([dense(latent), "IN"])
119
- build_reduction = lambda: dense(latent)
120
- build_flatten = lambda: bd([latent] * 2)
121
- else:
122
- build_res = lambda: Res(padding=padding)
123
- build_reduction = lambda: bm([conv(trans_size if simpler else latent, 1, padding=padding), "BN"])
124
- # build_reduction = lambda: bm([conv(latent, 1, padding=padding), "BN"])
125
- # build_reduction = lambda: bm([conv(trans_size, 1, padding=padding), "BN"])
126
- # build_reduction = lambda: conv(trans_size, 1, padding=padding)
127
- # build_flatten = lambda: Flat2(filters=trans_size,res_no=2, padding=padding, units=64)
128
- build_flatten = lambda: Flat2(filters=trans_size,padding=padding, units=64)
129
-
130
- if simpler == 1:
131
- cells = [
132
- (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT,"concatenation"),
133
- # (None, CONCAT, MEMORY),
134
- (build_reduction(), CONCAT, MERGE,"Start_resnet_block"),
135
- # (Get(), INFERENCE, INFERENCE),
136
- (K.cat, [INFERENCE, MERGE], CONTROL,"concatenation"),
137
- ]
138
- else:
139
- cells = [
140
- (lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
141
- (build_reduction(), CONCAT, MEMORY),
142
- (build_reduction(), INFERENCE, CONTROL),
143
- ]
144
- for i, l in enumerate(lw(loop)):
145
- if l:
146
- concat = K.cat
147
- control_reduction = build_reduction()
148
- control_res = build_res()
149
- control_flatten = build_flatten()
150
- if i == 0 and simpler == 1:
151
- rest_params = {
152
- "latent": latent,
153
- "act": act
154
- }
155
- else:
156
- rest_params = {
157
- "latent": 0
158
- }
159
-
160
-
161
- if block_ == SoftBlock:
162
- block_params = {
163
- }
164
- else:
165
- block_params = {
166
- "trans_output_shape": latent
167
- }
168
- if block_ == MaxBlock:
169
- block_params['max_k'] = max_k
170
-
171
-
172
- # todo change name
173
- soft_block = res_class(
174
- block_(
175
- build_multi_dense(latent) if type_ == "dense" else build_multi_conv(trans_size, end_filters=latent,
176
- norm=norm, padding=padding,
177
- **kwargs),
178
- add_identity=None,
179
- score_activation=tf.sigmoid,
180
- **block_params
181
-
182
- ),
183
- **rest_params)
184
-
185
- if i == 0 and simpler == 1:
186
- cells.extend([
187
- (control_reduction, CONTROL, CONTROL,"Reduction"),
188
- (control_res, CONTROL, CONTROL,"Control_resnet_block"),
189
- (control_flatten, CONTROL, FLATTEN,"Weights"),
190
- (soft_block, [CONCAT, FLATTEN], MEMORY,"Transformation"),
191
- # (soft_block, [MEMORY, FLATTEN], MEMORY,"Transformation"),
192
- ])
193
- else:
194
- if l:
195
- memory_res = build_res()
196
-
197
- cells.extend([
198
- (memory_res, MEMORY, MEMORY,"Memory_resnet_block"),
199
- (concat, [CONTROL, MEMORY], CONTROL,"concatenation"),
200
- (control_reduction, CONTROL, CONTROL,"Reduction"),
201
- (control_res, CONTROL, CONTROL,"Control_resnet_block"),
202
- (control_flatten, CONTROL, FLATTEN,"Weights"),
203
- (soft_block, [MEMORY, FLATTEN], MEMORY, "Transformation"),
204
- ])
205
- return ListModel([DictModel(*cell) for cell in cells], [LATENT, INFERENCE], MEMORY, debug_=False)
206
-
207
- #
208
- #
209
- # def test(x):
210
- # np.zeros(4)
211
- # self_product((1, 3))
212
- #
213
- #
214
- # list(itertools.product())
215
- # u.layers[0].layers[-1].model.layers[1]
216
-
217
- # class RecurrentBodyDict(Model):
218
- # # def __init__(self, start=None, cell=None, output_network=None, output_activation="tanh", latent=64, loop_no=5):
219
- # def __init__(self, start=None, cell=None, output_network=None, output_activation=None, latent=64, loop_no=5):
220
- # super().__init__()
221
- # self.start = sm(start, lambda: SubClassingModel([StartLSTMControl(latent), StartLSTMMemory(latent)]),
222
- # latent=latent)
223
- # self.cell = sm(cell, lambda: SubClassingModel([LSTMControl(latent), LSTMMemory(latent)]), latent=latent)
224
- # self.output_network = sm(output_network, lf(take_memory_states))
225
- # self.loop_no = loop_no
226
- # # tmp
227
- # self.activation = Activation(output_activation)
228
- #
229
- # def call(self, inputs):
230
- # outputs = []
231
- # for j in range(3):
232
- # outputs.append(self.start({"latent": inputs[0][j], "inference": inputs[1]}))
233
- # for i in range(self.loop_no):
234
- # for j in range(3):
235
- # outputs[j] = self.cell(outputs[j])
236
- #
237
- # return self.activation(self.output_network(outputs))
238
- #
239
- #
240
- # class RecurrentBodySimpleMix4Dict(RecurrentBodyDict):
241
- # def __init__(self, latent=64, output_network=None, loop_no=5):
242
- # super().__init__(
243
- # start=SubClassingModel(
244
- # [ConcatCell(), DenseCell(latent), InfMergeCell(latent),
245
- # WeigthCell(latent, layer_no=np.repeat([1, 2, 3, 4, 5, 6, 7, 8], 4),
246
- # add_identity=Lambda(lambda x: x[:, latent:]))]),
247
- # cell=False,
248
- # output_network=output_network, loop_no=0)
249
- # class RecurrentBodySimpleMix4Conv(RecurrentBodyDict):
250
- # def __init__(self, latent=64, output_network=None, loop_no=5):
251
- # super().__init__(
252
- # start=SubClassingModel(
253
- # [ConcatCell(), ConvCell(latent), ReduceCell(latent), InfMergeCell(latent),
254
- # ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
255
- # WeigthCell(latent,
256
- # transformation_network=[build_conv_model2([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
257
- # range(1, 5) for _ in range(1)],
258
- # add_identity=Lambda(lambda x: x[:, ..., latent:]))
259
- # ]),
260
- # cell=False,
261
- # output_network=output_network, loop_no=0)
262
- #
263
- #
264
- # class RecurrentBodySimpleMix4Conv2(RecurrentBodyDict):
265
- # def __init__(self, latent=64, output_network=None, loop_no=5):
266
- # super().__init__(
267
- # start=SubClassingModel(
268
- # [ConcatCell(), ConvCell(latent), ReduceCell2(latent), InfMergeCell(latent),
269
- # ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
270
- # WeigthCell(latent,
271
- # transformation_network=[bc([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
272
- # range(1, 5) for _ in range(1)],
273
- # add_identity=Lambda(lambda x: x[:, ..., latent:]))
274
- # ]),
275
- # cell=False,
276
- # output_network=output_network, loop_no=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/class_.py DELETED
@@ -1,31 +0,0 @@
1
- from ml_utils import lw
2
- from models_utils import SubClassingModel, ops as K, Base
3
- import tensorflow as tf
4
-
5
-
6
- class Merge(SubClassingModel):
7
- def call(self, inputs):
8
- results = []
9
- for i, model in enumerate(self.model[:-1]):
10
- results.append(model(inputs[i]))
11
- # todo why K.cat not working
12
- results = self.model[-1](tf.concat(results, axis=-1))
13
- return results
14
-
15
-
16
- class RavenClass(Base):
17
- def __init__(self, model, scales=None, no=3, name=None):
18
- super().__init__(model=model, name=name)
19
- self.scales = scales
20
- self.no = no
21
-
22
- def call(self, inputs):
23
- inputs = lw(inputs)
24
- class_res = []
25
- # for i in range(inputs[0].shape[1]):
26
- for i in range(self.no):
27
- # d = [r[:, i] if r.ndim == 5 else r for r in inputs]
28
- d = [inputs[s][:, i] if inputs[s].ndim > 2 else inputs for s in self.scales]
29
- class_res.append(self.model(d))
30
- # return tf.stack(class_res,axis=1)
31
- return [class_res]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/head.py DELETED
@@ -1,159 +0,0 @@
1
- import tensorflow as tf
2
- from ml_utils import set_default
3
- from models_utils import build_dense_model, bm, ActivationModel, sm, large_conv_dense_encoder, Pass
4
- from models_utils import res
5
- from tensorflow.keras import Model
6
- from models_utils import ops as K
7
- from tensorflow.keras.layers import Dense, Conv2D, Flatten
8
- from keras.backend import batch_flatten
9
-
10
-
11
- # todo Refactoring
12
- class HeadModel(Model):
13
- def __init__(self, encoder=None, inference_network=None, output_size=64, inference_output_size=None,
14
- inference_activation="relu", stem=None, images_no=8, inference_image_no=None):
15
- super().__init__()
16
- # self.encoder = sm(encoder, bm([en.large_conv_dense_encoder(), Dense(output_size)], False))
17
- self.encoder = encoder or bm([large_conv_dense_encoder(), Dense(output_size)])
18
- # self.head = head or HeadBatch(encoder=encoder, output_size=output_size)
19
- inference_output_size = inference_output_size or output_size
20
- self.inference_network = inference_network or bm([
21
- K.flat,
22
- build_dense_model([1028, 512, 512, inference_output_size],
23
- last_activation=inference_activation)]
24
- )
25
- self.stem = stem or Pass()
26
- self.images_no = images_no
27
- self.inference_image_no = self.images_no if inference_image_no is None else inference_image_no
28
-
29
-
30
- class LatentHeadModel(HeadModel):
31
- def call(self, inputs):
32
- result = K.map_batch(inputs[:, :self.images_no], self.encoder)
33
- inference = self.inference_network(result[:, :self.inference_image_no])
34
- latents = self.stem(result)
35
- return [latents, inference,result]
36
-
37
-
38
- # # todo use map_batch
39
- # class HeadBatch(Model):
40
- # def __init__(self, encoder=None, output_size=64):
41
- # super().__init__()
42
- # self.encoder = sm(encoder, bm([large_conv_dense_encoder(), Dense(output_size)], False))
43
- #
44
- # def call(self, inputs):
45
- # shape = tf.shape(inputs)
46
- # latents = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
47
- # latents = K.reshape(latents, tf.concat([[-1, shape[1]], latents.shape[1:]], axis=-1))
48
- # return latents
49
-
50
-
51
- # Not working
52
- class DuoHeadModel(HeadModel):
53
- def __init__(self, encoder=None, inference_network=None, images_no=8, filters=-4):
54
- super().__init__(encoder=encoder, inference_network=inference_network, images_no=images_no)
55
- self.encoder = ActivationModel(self.encoder, filters=filters, include_input=False)
56
-
57
- def call(self, inputs):
58
- shape = inputs.shape
59
- result = reversed(self.encoder(K.reshape(inputs, shape=[-1] + list(shape[2:]))))
60
- latents = K.reshape(result[0], [-1, self.images_no] + [result[0].shape[-1]])
61
- inference = self.inference_network(K.flat(result[1]))
62
- return [latents, inference]
63
-
64
-
65
- class MultiHeadModel(Model):
66
- def __init__(self, encoder=None, images_no=8, filters=(1, 3, 6)):
67
- super().__init__()
68
- self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
69
- self.merge = MergeSacles()
70
- self.images_no = images_no
71
-
72
- def call(self, inputs):
73
- shape = tf.shape(inputs)
74
- results = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
75
- latents = [tf.reshape(result, shape=tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)) for result
76
- in results]
77
-
78
- l1 = tf.transpose(latents[0], (0, 2, 3, 1, 4))
79
- # l1 = tf.reshape(l1, tuple(list(l1.shape[:3]) + [l1.shape[-2] * l1.shape[-1]]))
80
- shape = tf.shape(l1)
81
- l1 = tf.reshape(l1, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
82
-
83
- l2 = tf.transpose(latents[1], (0, 2, 3, 1, 4))
84
- # l2 = tf.reshape(l2, [-1] + list(l2.shape[1:3]) + [l2.shape[-2] * l2.shape[-1]])
85
- shape = tf.shape(l2)
86
- l2 = tf.reshape(l2, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
87
-
88
- l3 = latents[2]
89
- shape = tf.shape(l3)
90
- # l3 = tf.reshape(l3, [-1] + [l3.shape[-2] * l3.shape[-1]])
91
- l3 = tf.reshape(l3, tf.concat([[-1], [shape[-2] * shape[-1]]], axis=-1))
92
-
93
- inference = self.merge([l1, l2, l3])
94
- return [latents, inference]
95
-
96
-
97
- class MergeSacles(Model):
98
- def __init__(self):
99
- super().__init__()
100
- self.inf_1 = bm([Conv2D(64, 1, activation="relu"), res(64),
101
- Conv2D(64, 3, strides=2, padding=SAME, activation="relu"),
102
- res(64),
103
- Flatten(),
104
- Dense(256, "relu")])
105
- self.inf_2 = bm([Conv2D(128, 1, activation="relu"),
106
- res(128),
107
- Flatten(),
108
- Dense(256, "relu")])
109
- self.inf_3 = Dense(256, "relu")
110
-
111
- def call(self, inputs):
112
- il1 = self.inf_1(inputs[0])
113
- il2 = self.inf_2(inputs[1])
114
- il3 = self.inf_3(inputs[2])
115
- inference = tf.concat([il1, il2, il3], axis=1)
116
- return inference
117
-
118
-
119
- class MultiHeadModel2(Model):
120
- def __init__(self, encoder=None, images_no=8, filters=(3, 6)):
121
- super().__init__()
122
- self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
123
- self.merge = MergeSacles2()
124
- self.images_no = images_no
125
-
126
- def call(self, inputs):
127
- shape = tf.shape(inputs)
128
- results = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
129
- latents = [tf.reshape(result, shape=tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)) for result
130
- in results]
131
-
132
- l2 = tf.transpose(latents[0], (0, 2, 3, 1, 4))
133
- # l2 = tf.reshape(l2, [-1] + list(l2.shape[1:3]) + [l2.shape[-2] * l2.shape[-1]])
134
- shape = tf.shape(l2)
135
- l2 = tf.reshape(l2, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
136
-
137
- l3 = latents[1]
138
- shape = tf.shape(l3)
139
- # l3 = tf.reshape(l3, [-1] + [l3.shape[-2] * l3.shape[-1]])
140
- l3 = tf.reshape(l3, tf.concat([[-1], [shape[-2] * shape[-1]]], axis=-1))
141
-
142
- inference = self.merge([l2, l3])
143
- return [latents, inference]
144
-
145
-
146
- class MergeSacles2(Model):
147
- def __init__(self):
148
- super().__init__()
149
- self.inf_1 = bm([Conv2D(128, 1, activation="relu"),
150
- res(128),
151
- Flatten(),
152
- Dense(256, "relu")])
153
- self.inf_2 = Dense(256, "relu")
154
-
155
- def call(self, inputs):
156
- il1 = self.inf_1(inputs[0])
157
- il2 = self.inf_2(inputs[1])
158
- inference = tf.concat([il1, il2], axis=1)
159
- return inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/loss.py DELETED
@@ -1,630 +0,0 @@
1
- from functools import partial
2
-
3
- import tensorflow as tf
4
- import tensorflow.experimental.numpy as tnp
5
- from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
6
- from models_utils import SubClassingModel
7
- from models_utils.models.utils import interleave
8
- from models_utils.op import reshape
9
- from tensorflow.keras import Model
10
- # from tensorflow.keras import backend as K
11
- from tensorflow.keras.layers import Lambda
12
- from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
13
- from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
14
- import models_utils.ops as K
15
-
16
- import raven_utils.decode
17
- import raven_utils as rv
18
- from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
19
- SLOT, \
20
- PROPERTIES, ACC, GROUP, NUMBER, MASK
21
- from raven_utils.models.uitls_ import RangeMask
22
- from raven_utils.const import VERTICAL, HORIZONTAL
23
-
24
-
25
- def get_properties_mask(target):
26
- return target[:, rv.target.END_INDEX:rv.target.UNIFORMITY_INDEX] > 0
27
-
28
-
29
- def create_change_mask(target):
30
- properties_mask = get_properties_mask(target)
31
- return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
32
-
33
-
34
- def create_uniform_mask(target):
35
- u_mask = lambda i: tf.tile(target[:, rv.target.UNIFORMITY_INDEX + i, None] == 3, [1, rv.rules.ATTRIBUTES_LEN])
36
- properties_mask = tf.concat([u_mask(0), u_mask(1)], axis=-1) | get_properties_mask(target)
37
- return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
38
-
39
-
40
- def create_all_mask(target):
41
- return [
42
- tf.cast(tf.ones(tf.stack([tf.shape(target)[0], rv.entity.SUM])), dtype=tf.bool) for i, _ in
43
- enumerate(rv.rules.ATTRIBUTES)]
44
-
45
-
46
- class BaselineClassificationLossModel(Model):
47
- def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
48
- super().__init__()
49
- self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
50
- self.loss_fn = ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
51
- group_loss=group_loss)
52
- self.metric_fn = SimilarityRaven(mode=mode)
53
-
54
- def call(self, inputs):
55
- losses = []
56
- output = inputs[1]
57
- losses.append(self.loss_fn([inputs[0][0], output]))
58
- losses.append(self.metric_fn([inputs[0][2], inputs[3][0], inputs[0][1][:, 8:]]))
59
- return losses
60
-
61
-
62
- class RavenLoss(Model):
63
- def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
64
- classification=False, trans=True, anneal=False):
65
- super().__init__()
66
- if anneal:
67
- self.weight_scheduler
68
- self.classification = classification
69
- self.trans = trans
70
- self.predict_fn = DictModel(SubClassingModel([lambda x: x[-1], PredictModel()]), in_=OUTPUT,
71
- out=[PREDICT, MASK], name="pred")
72
- if self.trans:
73
- self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
74
- group_loss=group_loss, enable_metrics=False, lw=lw[0]),
75
- name="main_loss")
76
- self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
77
- group_loss=group_loss), name="add_loss")
78
- self.metric_fn = SimilarityRaven(mode=mode)
79
- if self.classification:
80
- self.loss_fn_3 = add_loss(
81
- ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
82
- group_loss=group_loss, enable_metrics="c" if self.trans else True), lw=lw[1],
83
- name="class_loss")
84
-
85
- def call(self, inputs):
86
- losses = []
87
- output = inputs[OUTPUT]
88
- target = inputs[TARGET]
89
- labels = inputs[LABELS]
90
-
91
- if self.trans:
92
- losses.append(self.loss_fn([labels[:, 2], output[0]]))
93
- losses.append(self.loss_fn([labels[:, 5], output[1]]))
94
- losses.append(self.loss_fn_2([target, output[2]]))
95
- losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
96
- if self.classification:
97
- for i in range(8):
98
- losses.append(self.loss_fn_3([labels[:, i], inputs[CLASSIFICATION][i]]))
99
- return {**inputs, LOSS: losses}
100
-
101
-
102
- class VTRavenLoss(Model):
103
- def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
104
- classification=False, trans=True, anneal=False, plw=None):
105
- super().__init__()
106
- if anneal:
107
- self.weight_scheduler
108
- self.classification = classification
109
- self.trans = trans
110
- self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
111
- out=[PREDICT, MASK], name="pred")
112
- self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
113
- group_loss=group_loss, plw=plw), lw=lw[0] , name="add_loss")
114
- self.metric_fn = SimilarityRaven(mode=mode)
115
- if self.classification:
116
- self.loss_fn_2 = add_loss(
117
- ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
118
- group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")
119
-
120
- def call(self, inputs):
121
- losses = []
122
- output = inputs[OUTPUT]
123
- target = inputs[TARGET]
124
- labels = inputs[LABELS]
125
-
126
- for i in range(9):
127
- losses.append(self.loss_fn_2([labels[:, i], output[:, i]]))
128
- losses.append(self.loss_fn([target, output[:, 8]]))
129
- losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
130
- return {**inputs, LOSS: losses}
131
-
132
-
133
- class SingleVTRavenLoss(Model):
134
- def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
135
- classification=False, trans=True, anneal=False):
136
- super().__init__()
137
- if anneal:
138
- self.weight_scheduler
139
- self.classification = classification
140
- self.trans = trans
141
- self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
142
- self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
143
- group_loss=group_loss), lw=lw[0], name="add_loss")
144
- self.metric_fn = SimilarityRaven(mode=mode)
145
-
146
- def call(self, inputs):
147
- losses = []
148
- output = inputs[OUTPUT]
149
- target = inputs[TARGET]
150
- labels = inputs[LABELS]
151
-
152
- losses.append(self.loss_fn([target, output]))
153
- losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
154
- return {**inputs, LOSS: losses}
155
-
156
-
157
- class ClassRavenModel(Model):
158
- def __init__(self, mode=create_all_mask,plw=None, number_loss=False, slot_loss=True, group_loss=True, enable_metrics=True,
159
- lw=1.0):
160
- super().__init__()
161
- self.number_loss = number_loss
162
- self.group_loss = group_loss
163
- self.enable_metrics = enable_metrics
164
- self.slot_loss = slot_loss
165
- self.predict_fn = PredictModel()
166
- self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
167
- if self.slot_loss:
168
- self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
169
- if self.enable_metrics:
170
- self.enable_metrics = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
171
- self.metric_fn = [
172
- SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{property_}") for property_ in
173
- rv.properties.NAMES]
174
- if self.group_loss:
175
- self.metric_fn_group = SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{GROUP}")
176
- if self.slot_loss:
177
- self.metric_fn_2 = BinaryAccuracy(name=f"{self.enable_metrics}{ACC}_{SLOT}")
178
- self.range_mask = RangeMask()
179
- self.mode = mode
180
- self.lw = lw
181
- if not plw:
182
- plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
183
- elif isinstance(plw, int) or isinstance(plw, float):
184
- plw = [1., plw, 2.83426987, 0.85212836, 1.096005, 1.21943385]
185
- # plw = [plw] * 6
186
- self.plw = plw
187
-
188
- # self.predict_fn = partial(tf.argmax, axis=-1)
189
-
190
- def call(self, inputs):
191
- losses = []
192
- metrics = {}
193
- target = inputs[0]
194
- output = inputs[1]
195
-
196
- target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
197
-
198
- group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)
199
-
200
- # group
201
- if self.group_loss:
202
- group_loss = self.lw * self.plw[0] * self.loss_fn(target_group, group_output)
203
- losses.append(group_loss)
204
-
205
- if isinstance(self.enable_metrics, str):
206
- group_metric = self.metric_fn_group(target_group, group_output)
207
- # metrics[GROUP] = group_metric
208
- self.add_metric(group_metric)
209
- self.add_metric(tf.reduce_sum(group_metric), f"{self.enable_metrics}{ACC}")
210
-
211
- # setting uniformity mask
212
- full_properties_musks = self.mode(target)
213
-
214
- range_mask = self.range_mask(target_group)
215
-
216
- if self.slot_loss:
217
- # number
218
- number_mask = range_mask & full_properties_musks[0]
219
- number_mask = tf.cast(number_mask, tf.float32)
220
- target_number = tf.reduce_sum(
221
- tf.cast(target_slot, "float32") * number_mask, axis=-1)
222
- output_number = tf.reduce_sum(
223
- tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)
224
-
225
- # output_number = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
226
- if self.number_loss:
227
- scale = 1 / 9
228
- if self.number_loss == 2:
229
- output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
230
- else:
231
- output_number_2 = output_number
232
- number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale, output_number_2 * scale)
233
- losses.append(number_loss)
234
-
235
- # metrics[NUMBER] = number_acc
236
-
237
- if isinstance(self.enable_metrics, str):
238
- number_acc = tf.reduce_mean(
239
- tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
240
- self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
241
- self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}")
242
- self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
243
-
244
- # position/slot
245
- slot_mask = range_mask & full_properties_musks[1]
246
- # tf.boolean_mask(target_slot,slot_mask)
247
-
248
- if tf.reduce_any(slot_mask):
249
- # if tf.reduce_mean(tf.cast(slot_mask, dtype=tf.int32)) > 0:
250
- target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
251
- output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
252
- loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
253
- self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
254
- if isinstance(self.enable_metrics, str):
255
- acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
256
- self.add_metric(acc_slot)
257
- self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}")
258
- self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
259
- else:
260
- loss_slot = 0.0
261
- acc_slot = -1.0
262
-
263
- losses.append(loss_slot)
264
- # metrics[SLOT] = acc_slot
265
- # if loss_slot != 0:
266
-
267
- # if tf.reduce_any(slot_mask):
268
-
269
- # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_{NUMBER}")
270
- # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}")
271
- # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_NO_{GROUP}")
272
-
273
- # properties
274
- for i, out in enumerate(outputs):
275
- shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
276
- out_reshaped = tf.reshape(out, shape)
277
- properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]
278
-
279
- if tf.reduce_any(properties_mask):
280
- out_masked = tf.boolean_mask(out_reshaped, properties_mask)
281
- out_target = tf.boolean_mask(target_all[i], properties_mask)
282
- loss = self.lw * self.plw[3+i] * self.loss_fn(out_target, out_masked)
283
- if isinstance(self.enable_metrics, str):
284
- metric = self.metric_fn[i](out_target, out_masked)
285
- self.add_metric(metric)
286
- # self.add_metric(metric, f"{self.enable_metrics}{ACC}")
287
- self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}")
288
- self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
289
- self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
290
- else:
291
- loss = 0.0
292
- metric = -1.0
293
-
294
- losses.append(loss)
295
- return losses
296
-
297
-
298
- class FullMask(Model):
299
- def __init__(self, mode=create_uniform_mask):
300
- super().__init__()
301
- self.range_mask = RangeMask()
302
- self.mode = mode
303
-
304
- def call(self, inputs):
305
- target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
306
- full_properties_musks = self.mode(inputs)
307
- range_mask = self.range_mask(target_group)
308
-
309
- number_mask = range_mask & full_properties_musks[0]
310
-
311
- slot_mask = range_mask & full_properties_musks[1]
312
- properties_mask = []
313
- for property_mask in full_properties_musks[2:]:
314
- properties_mask.append(tf.cast(target_slot, "bool") & property_mask)
315
- return [slot_mask, properties_mask, number_mask]
316
-
317
-
318
- def create_mask(rules, i):
319
- mask_1 = tf.tile(rules[:, i][None], [len(rv.target.FIRST_LAYOUT), 1])
320
- mask_2 = tf.tile(rules[:, i + 5][None], [len(rv.target.SECOND_LAYOUT), 1])
321
- shape = tf.shape(rules)
322
- full_mask_1 = tf.scatter_nd(tnp.array(rv.target.FIRST_LAYOUT)[:, None], mask_1, shape=(rv.entity.SUM, shape[0]))
323
- full_mask_2 = tf.tensor_scatter_nd_update(full_mask_1, tnp.array(rv.target.SECOND_LAYOUT)[:, None], mask_2)
324
- return tf.transpose(full_mask_2)
325
-
326
-
327
- # class PredictModel(Model):
328
- # def __init__(self):
329
- # super().__init__()
330
- # self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
331
- # self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
332
- # self.range_mask = RangeMask()
333
- #
334
- # # self.predict_fn = partial(tf.argmax, axis=-1)
335
- #
336
- # def call(self, inputs):
337
- # group_output = inputs[rv.OUTPUT_GROUP_SLICE]
338
- # group_loss = self.predict_fn(group_output)[:, None]
339
- #
340
- # output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
341
- # range_mask = self.range_mask(group_loss[:, 0])
342
- # loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
343
- #
344
- # properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
345
- # properties = []
346
- # outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
347
- # for i, out in enumerate(outputs):
348
- # shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
349
- # out_reshaped = tf.reshape(out, shape)
350
- # properties.append(self.predict_fn(out_reshaped))
351
- # number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
352
- #
353
- # result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
354
- #
355
- # return [result, range_mask, range_mask, range_mask, range_mask]
356
-
357
- class PredictModel(Model):
358
- def __init__(self):
359
- super().__init__()
360
- self.predict_fn = Predict()
361
- self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
362
- self.range_mask = RangeMask()
363
-
364
- # self.predict_fn = partial(tf.argmax, axis=-1)
365
-
366
- def call(self, inputs):
367
- group_output, output_slot, *properties = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
368
- number_loss = K.int64(K.sum(output_slot))
369
- result = tf.concat(
370
- [group_output[:, None], tf.cast(output_slot, dtype=tf.int64), interleave(properties), number_loss[:, None]],
371
- axis=-1)
372
-
373
- range_mask = self.range_mask(group_output)
374
- return [result, range_mask]
375
- # return [result, range_mask, range_mask, range_mask, range_mask]
376
-
377
-
378
- # todo change slices
379
- class PredictModelMasked(Model):
380
- def __init__(self):
381
- super().__init__()
382
- self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
383
- self.loss_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
384
- self.range_mask = RangeMask()
385
-
386
- # self.predict_fn = partial(tf.argmax, axis=-1)
387
-
388
- def call(self, inputs):
389
- group_output = inputs[:, -rv.GROUPS_NO:]
390
- group_loss = self.predict_fn(group_output)[:, None]
391
-
392
- output_slot = inputs[:, :rv.ENTITY_SUM]
393
- range_mask = self.range_mask(group_loss[:, 0])
394
- loss_slot = tf.cast(self.predict_fn_2(output_slot * range_mask), dtype=tf.int64)
395
-
396
- properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]
397
-
398
- properties = []
399
- outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
400
- for i, out in enumerate(outputs):
401
- shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
402
- out_reshaped = tf.reshape(out, shape)
403
- out_masked = out_reshaped * loss_slot[..., None]
404
- properties.append(self.predict_fn(out_masked))
405
- # out_masked[0].numpy()
406
- number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
407
-
408
- result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
409
-
410
- return result
411
-
412
-
413
- def final_predict_mask(x, mask):
414
- r = reshape(x[0][:, rv.INDEX[0]:-1], [-1, 3])
415
- return tf.ragged.boolean_mask(r, mask)
416
-
417
-
418
- def final_predict(x, mode=False):
419
- m = x[1] if mode else tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
420
- return final_predict_mask(x[0], m)
421
-
422
-
423
- def final_predict_2(x):
424
- ones = tf.cast(tf.ones(tf.shape(x[0])[0]), tf.bool)[:, None]
425
- mask = tf.concat([ones, tf.tile(x[1], [1, 4]), ones], axis=-1)
426
- return tf.ragged.boolean_mask(x[0], mask)
427
-
428
-
429
- class PredictModelOld(Model):
430
-
431
- def call(self, inputs):
432
- output = inputs[-2]
433
-
434
- rest_output = output[:, :-rv.GROUPS_NO]
435
-
436
- result_all = []
437
- outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-3)
438
- for i, out in enumerate(outputs):
439
- shape = (-3, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
440
- out_reshaped = tf.reshape(out, shape)
441
-
442
- result = tf.cast(tf.argmax(out_reshaped, axis=-3), dtype="int8")
443
- result_all.append(result)
444
-
445
- result_all = interleave(result_all)
446
- return result_all
447
-
448
-
449
- def get_matches(diff, target_index):
450
- diff_sum = K.sum(diff)
451
- db_argsort = tf.argsort(diff_sum, axis=-1)
452
- db_sorted = tf.sort(diff_sum)
453
- db_mask = db_sorted[:, 0, None] == db_sorted
454
- db_same = tf.where(db_mask, db_argsort, -1 * tf.ones_like(db_argsort))
455
- matched_index = db_same == target_index
456
- # setting shape needed for TensorFlow graph
457
- matched_index.set_shape(db_same.shape)
458
- matches = K.any(matched_index)
459
- more_matches = K.sum(db_mask) > 1
460
- once_matches = K.sum(matches & tf.math.logical_not(more_matches))
461
- return matches, more_matches, once_matches
462
-
463
-
464
- class SimilarityRaven(Model):
465
- def __init__(self, mode=create_all_mask, number_loss=False):
466
- super().__init__()
467
- self.range_mask = RangeMask()
468
- self.mode = mode
469
-
470
- # self.predict_fn = partial(tf.argmax, axis=-1)
471
-
472
- # INDEX, PREDICT, LABELS
473
- def call(self, inputs):
474
- metrics = []
475
- target_index = inputs[0] - 8
476
- predict = inputs[1]
477
- answers = inputs[2][:, 8:]
478
- shape = tf.shape(predict)
479
-
480
- target = K.gather(answers, target_index[:, 0])
481
-
482
- target_group = target[:, 0]
483
-
484
- # comp_slice = np.
485
- target_comp = target[:, 1:rv.target.END_INDEX]
486
- predict_comp = predict[:, 1:rv.target.END_INDEX]
487
- answers_comp = answers[:, :, 1:rv.target.END_INDEX]
488
-
489
- full_properties_musks = self.mode(target)
490
- fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
491
-
492
- range_mask = self.range_mask(target_group)
493
- full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
494
-
495
- final_mask = fpm & full_range_mask
496
-
497
- target_masked = target_comp * final_mask
498
- predict_masked = predict_comp * final_mask
499
- answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
500
-
501
- acc_same = K.mean(K.all(target_masked == predict_masked))
502
- self.add_metric(acc_same, ACC_SAME)
503
- metrics.append(acc_same)
504
-
505
- diff = tf.abs(predict_masked[:, None] - answers_masked)
506
- diff_bool = diff != 0
507
-
508
- matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
509
-
510
- second_phase_mask = (more_matches & matches)
511
- diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
512
- target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
513
-
514
- matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
515
- matches_2_no = K.sum(matches_2)
516
-
517
- acc_choose_upper = (once_matches + matches_2_no) / shape[0]
518
- self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
519
- metrics.append(acc_choose_upper)
520
-
521
- acc_choose_lower = (once_matches + once_matches_2) / shape[0]
522
- self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
523
- metrics.append(acc_choose_lower)
524
-
525
- return metrics
526
-
527
-
528
- class SimilarityRaven2(Model):
529
- def __init__(self, mode=create_all_mask, number_loss=False):
530
- super().__init__()
531
- self.range_mask = RangeMask()
532
- self.mode = mode
533
-
534
- # self.predict_fn = partial(tf.argmax, axis=-1)
535
-
536
- # INDEX, PREDICT, LABELS
537
- def call(self, inputs):
538
- metrics = []
539
- target_index = inputs[0] - 8
540
- predict = inputs[1]
541
- answers = inputs[2][:, 8:]
542
- shape = tf.shape(predict)
543
-
544
- target = K.gather(answers, target_index[:, 0])
545
-
546
- target_group = target[:, 0]
547
-
548
- # comp_slice = np.
549
- target_comp = target[:, 1:rv.target.END_INDEX]
550
- predict_comp = predict[:, 1:rv.target.END_INDEX]
551
- answers_comp = answers[:, :, 1:rv.target.END_INDEX]
552
-
553
- full_properties_musks = self.mode(target)
554
- fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
555
-
556
- range_mask = self.range_mask(target_group)
557
- full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
558
-
559
- final_mask = fpm & full_range_mask
560
-
561
- target_masked = target_comp * final_mask
562
- predict_masked = predict_comp * final_mask
563
- answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
564
-
565
- acc_same = K.mean(K.all(target_masked == predict_masked))
566
- self.add_metric(acc_same, ACC_SAME)
567
- metrics.append(acc_same)
568
-
569
- diff = tf.abs(predict_masked[:, None] - answers_masked)
570
- diff_bool = diff != 0
571
-
572
- matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
573
-
574
- second_phase_mask = (more_matches & matches)
575
- diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
576
- target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
577
-
578
- matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
579
- matches_2_no = K.sum(matches_2)
580
-
581
- acc_choose_upper = (once_matches + matches_2_no) / shape[0]
582
- self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
583
- metrics.append(acc_choose_upper)
584
-
585
- acc_choose_lower = (once_matches + once_matches_2) / shape[0]
586
- self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
587
- metrics.append(acc_choose_lower)
588
-
589
- metrics.append(K.sum(target_masked != predict_masked))
590
-
591
- return metrics
592
-
593
-
594
- class LatentLossModel(Model):
595
- def __init__(self, dir_=HORIZONTAL):
596
- super().__init__()
597
- # self.sum_metrics = []
598
- # for i in range(8):
599
- # self.sum_metrics.append(Sum(name=f"no_{i}"))
600
- self.metric_fn = Accuracy(name="acc_latent")
601
- if dir_ == VERTICAL:
602
- self.dir = (6, 7)
603
- else:
604
- self.dir = (2, 5)
605
-
606
- def call(self, inputs):
607
- target_image = tf.reshape(inputs[0][2], [-1])
608
- output = inputs[1]
609
- latents = tnp.asarray(inputs[2])
610
-
611
- target_hor = tf.concat([
612
- latents[:, self.dir],
613
- latents[tf.range(latents.shape[0]), target_image + 8][:, None]
614
- ],
615
- axis=1)
616
-
617
- loss_hor = mse(K.stop_gradient(target_hor), output)
618
- self.add_loss(loss_hor)
619
-
620
- self.add_metric(self.metric_fn(inputs[3], target_image))
621
-
622
- return loss_hor
623
-
624
-
625
- class PredRav(Model):
626
-
627
- def call(self, inputs):
628
- output = inputs[0][:, -1]
629
- answers = inputs[1][:, 8:]
630
- return tf.argmin(tf.reduce_sum(tf.abs(output[:, None] - answers), axis=-1), axis=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/loss_3.py DELETED
@@ -1,638 +0,0 @@
1
- from functools import partial
2
-
3
- import tensorflow as tf
4
- import tensorflow.experimental.numpy as tnp
5
- from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
6
- from models_utils import SubClassingModel
7
- from models_utils.models.utils import interleave
8
- from models_utils.op import reshape
9
- from tensorflow.keras import Model
10
- # from tensorflow.keras import backend as K
11
- from tensorflow.keras.layers import Lambda
12
- from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
13
- from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
14
- import models_utils.ops as K
15
-
16
- import raven_utils.decode
17
- import raven_utils as rv
18
- from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
19
- SLOT, \
20
- PROPERTIES, ACC, GROUP, NUMBER, MASK
21
- from raven_utils.models.uitls_ import RangeMask
22
- from raven_utils.const import VERTICAL, HORIZONTAL
23
-
24
-
25
- def get_properties_mask(target):
26
- return target[:, rv.target.END_INDEX:rv.target.UNIFORMITY_INDEX] > 0
27
-
28
-
29
- def create_change_mask(target):
30
- properties_mask = get_properties_mask(target)
31
- return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
32
-
33
-
34
- def create_uniform_mask(target):
35
- u_mask = lambda i: tf.tile(target[:, rv.target.UNIFORMITY_INDEX + i, None] == 3, [1, rv.rules.ATTRIBUTES_LEN])
36
- properties_mask = tf.concat([u_mask(0), u_mask(1)], axis=-1) | get_properties_mask(target)
37
- return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
38
-
39
-
40
- def create_all_mask(target):
41
- return [
42
- tf.cast(tf.ones(tf.stack([tf.shape(target)[0], rv.entity.SUM])), dtype=tf.bool) for i, _ in
43
- enumerate(rv.rules.ATTRIBUTES)]
44
-
45
-
46
- class BaselineClassificationLossModel(Model):
47
- def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
48
- super().__init__()
49
- self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
50
- self.loss_fn = ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
51
- group_loss=group_loss)
52
- self.metric_fn = SimilarityRaven(mode=mode)
53
-
54
- def call(self, inputs):
55
- losses = []
56
- output = inputs[1]
57
- losses.append(self.loss_fn([inputs[0][0], output]))
58
- losses.append(self.metric_fn([inputs[0][2], inputs[3][0], inputs[0][1][:, 8:]]))
59
- return losses
60
-
61
-
62
- class RavenLoss(Model):
63
- def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
64
- classification=False, trans=True, anneal=False):
65
- super().__init__()
66
- if anneal:
67
- self.weight_scheduler
68
- self.classification = classification
69
- self.trans = trans
70
- self.predict_fn = DictModel(SubClassingModel([lambda x: x[-1], PredictModel()]), in_=OUTPUT,
71
- out=[PREDICT, MASK], name="pred")
72
- if self.trans:
73
- self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
74
- group_loss=group_loss, enable_metrics=False, lw=lw[0]),
75
- name="main_loss")
76
- self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
77
- group_loss=group_loss), name="add_loss")
78
- self.metric_fn = SimilarityRaven(mode=mode)
79
- if self.classification:
80
- self.loss_fn_3 = add_loss(
81
- ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
82
- group_loss=group_loss, enable_metrics="c" if self.trans else True), lw=lw[1],
83
- name="class_loss")
84
-
85
- def call(self, inputs):
86
- losses = []
87
- output = inputs[OUTPUT]
88
- target = inputs[TARGET]
89
- labels = inputs[LABELS]
90
-
91
- if self.trans:
92
- losses.append(self.loss_fn([labels[:, 2], output[0]]))
93
- losses.append(self.loss_fn([labels[:, 5], output[1]]))
94
- losses.append(self.loss_fn_2([target, output[2]]))
95
- losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
96
- if self.classification:
97
- for i in range(8):
98
- losses.append(self.loss_fn_3([labels[:, i], inputs[CLASSIFICATION][i]]))
99
- return {**inputs, LOSS: losses}
100
-
101
-
102
- class VTRavenLoss(Model):
103
- def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(2.0, 1.0),
104
- classification=False, trans=True, anneal=False, plw=None):
105
- super().__init__()
106
- if anneal:
107
- self.weight_scheduler
108
- self.classification = classification
109
- self.trans = trans
110
- self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
111
- out=[PREDICT, "predict_mask"], name="pred")
112
- self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
113
- group_loss=group_loss, plw=plw), lw=lw[0], name="add_loss")
114
- self.metric_fn = SimilarityRaven(mode=mode)
115
- if self.classification:
116
- self.loss_fn_2 = add_loss(
117
- ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
118
- group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")
119
-
120
- def call(self, inputs):
121
- losses = []
122
- output = inputs[OUTPUT]
123
- target = inputs[TARGET]
124
- labels = inputs[LABELS]
125
- mask = inputs[MASK]
126
-
127
- target_masked = target[mask]
128
- output_masked = output[mask]
129
- losses.append(self.loss_fn([target_masked, output_masked]))
130
-
131
- target_unmasked = target[~mask]
132
- output_unmasked = output[~mask]
133
- losses.append(self.loss_fn_2([target_unmasked, output_unmasked]))
134
-
135
- losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
136
- return {**inputs, LOSS: losses}
137
-
138
-
139
- class SingleVTRavenLoss(Model):
140
- def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
141
- classification=False, trans=True, anneal=False):
142
- super().__init__()
143
- if anneal:
144
- self.weight_scheduler
145
- self.classification = classification
146
- self.trans = trans
147
- self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
148
- self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
149
- group_loss=group_loss), lw=lw[0], name="add_loss")
150
- self.metric_fn = SimilarityRaven(mode=mode)
151
-
152
- def call(self, inputs):
153
- losses = []
154
- output = inputs[OUTPUT]
155
- target = inputs[TARGET]
156
- labels = inputs[LABELS]
157
-
158
- losses.append(self.loss_fn([target, output]))
159
- losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
160
- return {**inputs, LOSS: losses}
161
-
162
-
163
- class ClassRavenModel(Model):
164
- def __init__(self, mode=create_all_mask, plw=None, number_loss=False, slot_loss=True, group_loss=True,
165
- enable_metrics=True,
166
- lw=1.0):
167
- super().__init__()
168
- self.number_loss = number_loss
169
- self.group_loss = group_loss
170
- self.enable_metrics = enable_metrics
171
- self.slot_loss = slot_loss
172
- self.predict_fn = PredictModel()
173
- self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
174
- if self.slot_loss:
175
- self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
176
- if self.enable_metrics:
177
- self.enable_metrics = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
178
- self.metric_fn = [
179
- SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{property_}") for property_ in
180
- rv.properties.NAMES]
181
- if self.group_loss:
182
- self.metric_fn_group = SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{GROUP}")
183
- if self.slot_loss:
184
- self.metric_fn_2 = BinaryAccuracy(name=f"{self.enable_metrics}{ACC}_{SLOT}")
185
- self.range_mask = RangeMask()
186
- self.mode = mode
187
- self.lw = lw
188
- if not plw:
189
- plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
190
- elif isinstance(plw, int) or isinstance(plw, float):
191
- plw = [1., plw, 2.83426987, 0.85212836, 1.096005, 1.21943385]
192
- # plw = [plw] * 6
193
- self.plw = plw
194
-
195
- # self.predict_fn = partial(tf.argmax, axis=-1)
196
-
197
- def call(self, inputs):
198
- losses = []
199
- metrics = {}
200
- target = inputs[0]
201
- output = inputs[1]
202
-
203
- target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
204
-
205
- group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)
206
-
207
- # group
208
- if self.group_loss:
209
- group_loss = self.lw * self.plw[0] * self.loss_fn(target_group, group_output)
210
- losses.append(group_loss)
211
-
212
- if isinstance(self.enable_metrics, str):
213
- group_metric = self.metric_fn_group(target_group, group_output)
214
- # metrics[GROUP] = group_metric
215
- self.add_metric(group_metric)
216
- self.add_metric(tf.reduce_sum(group_metric), f"{self.enable_metrics}{ACC}")
217
-
218
- # setting uniformity mask
219
- full_properties_musks = self.mode(target)
220
-
221
- range_mask = self.range_mask(target_group)
222
-
223
- if self.slot_loss:
224
- # number
225
- number_mask = range_mask & full_properties_musks[0]
226
- number_mask = tf.cast(number_mask, tf.float32)
227
- target_number = tf.reduce_sum(
228
- tf.cast(target_slot, "float32") * number_mask, axis=-1)
229
- output_number = tf.reduce_sum(
230
- tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)
231
-
232
- # output_number = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
233
- if self.number_loss:
234
- scale = 1 / 9
235
- if self.number_loss == 2:
236
- output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
237
- else:
238
- output_number_2 = output_number
239
- number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale,
240
- output_number_2 * scale)
241
- losses.append(number_loss)
242
-
243
- # metrics[NUMBER] = number_acc
244
-
245
- if isinstance(self.enable_metrics, str):
246
- number_acc = tf.reduce_mean(
247
- tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
248
- self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
249
- self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}")
250
- self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
251
-
252
- # position/slot
253
- slot_mask = range_mask & full_properties_musks[1]
254
- # tf.boolean_mask(target_slot,slot_mask)
255
-
256
- if tf.reduce_any(slot_mask):
257
- # if tf.reduce_mean(tf.cast(slot_mask, dtype=tf.int32)) > 0:
258
- target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
259
- output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
260
- loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
261
- self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
262
- if isinstance(self.enable_metrics, str):
263
- acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
264
- self.add_metric(acc_slot)
265
- self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}")
266
- self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
267
- else:
268
- loss_slot = 0.0
269
- acc_slot = -1.0
270
-
271
- losses.append(loss_slot)
272
- # metrics[SLOT] = acc_slot
273
- # if loss_slot != 0:
274
-
275
- # if tf.reduce_any(slot_mask):
276
-
277
- # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_{NUMBER}")
278
- # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}")
279
- # self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_NO_{GROUP}")
280
-
281
- # properties
282
- for i, out in enumerate(outputs):
283
- shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
284
- out_reshaped = tf.reshape(out, shape)
285
- properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]
286
-
287
- if tf.reduce_any(properties_mask):
288
- out_masked = tf.boolean_mask(out_reshaped, properties_mask)
289
- out_target = tf.boolean_mask(target_all[i], properties_mask)
290
- loss = self.lw * self.plw[3 + i] * self.loss_fn(out_target, out_masked)
291
- if isinstance(self.enable_metrics, str):
292
- metric = self.metric_fn[i](out_target, out_masked)
293
- self.add_metric(metric)
294
- # self.add_metric(metric, f"{self.enable_metrics}{ACC}")
295
- self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}")
296
- self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
297
- self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
298
- else:
299
- loss = 0.0
300
- metric = -1.0
301
-
302
- losses.append(loss)
303
- return losses
304
-
305
-
306
- class FullMask(Model):
307
- def __init__(self, mode=create_uniform_mask):
308
- super().__init__()
309
- self.range_mask = RangeMask()
310
- self.mode = mode
311
-
312
- def call(self, inputs):
313
- target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
314
- full_properties_musks = self.mode(inputs)
315
- range_mask = self.range_mask(target_group)
316
-
317
- number_mask = range_mask & full_properties_musks[0]
318
-
319
- slot_mask = range_mask & full_properties_musks[1]
320
- properties_mask = []
321
- for property_mask in full_properties_musks[2:]:
322
- properties_mask.append(tf.cast(target_slot, "bool") & property_mask)
323
- return [slot_mask, properties_mask, number_mask]
324
-
325
-
326
- def create_mask(rules, i):
327
- mask_1 = tf.tile(rules[:, i][None], [len(rv.target.FIRST_LAYOUT), 1])
328
- mask_2 = tf.tile(rules[:, i + 5][None], [len(rv.target.SECOND_LAYOUT), 1])
329
- shape = tf.shape(rules)
330
- full_mask_1 = tf.scatter_nd(tnp.array(rv.target.FIRST_LAYOUT)[:, None], mask_1, shape=(rv.entity.SUM, shape[0]))
331
- full_mask_2 = tf.tensor_scatter_nd_update(full_mask_1, tnp.array(rv.target.SECOND_LAYOUT)[:, None], mask_2)
332
- return tf.transpose(full_mask_2)
333
-
334
-
335
- # class PredictModel(Model):
336
- # def __init__(self):
337
- # super().__init__()
338
- # self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
339
- # self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
340
- # self.range_mask = RangeMask()
341
- #
342
- # # self.predict_fn = partial(tf.argmax, axis=-1)
343
- #
344
- # def call(self, inputs):
345
- # group_output = inputs[rv.OUTPUT_GROUP_SLICE]
346
- # group_loss = self.predict_fn(group_output)[:, None]
347
- #
348
- # output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
349
- # range_mask = self.range_mask(group_loss[:, 0])
350
- # loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
351
- #
352
- # properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
353
- # properties = []
354
- # outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
355
- # for i, out in enumerate(outputs):
356
- # shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
357
- # out_reshaped = tf.reshape(out, shape)
358
- # properties.append(self.predict_fn(out_reshaped))
359
- # number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
360
- #
361
- # result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
362
- #
363
- # return [result, range_mask, range_mask, range_mask, range_mask]
364
-
365
- class PredictModel(Model):
366
- def __init__(self):
367
- super().__init__()
368
- self.predict_fn = Predict()
369
- self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
370
- self.range_mask = RangeMask()
371
-
372
- # self.predict_fn = partial(tf.argmax, axis=-1)
373
-
374
- def call(self, inputs):
375
- group_output, output_slot, *properties = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
376
- number_loss = K.int64(K.sum(output_slot))
377
- result = tf.concat(
378
- [group_output[:, None], tf.cast(output_slot, dtype=tf.int64), interleave(properties), number_loss[:, None]],
379
- axis=-1)
380
-
381
- range_mask = self.range_mask(group_output)
382
- return [result, range_mask]
383
- # return [result, range_mask, range_mask, range_mask, range_mask]
384
-
385
-
386
- # todo change slices
387
- class PredictModelMasked(Model):
388
- def __init__(self):
389
- super().__init__()
390
- self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
391
- self.loss_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
392
- self.range_mask = RangeMask()
393
-
394
- # self.predict_fn = partial(tf.argmax, axis=-1)
395
-
396
- def call(self, inputs):
397
- group_output = inputs[:, -rv.GROUPS_NO:]
398
- group_loss = self.predict_fn(group_output)[:, None]
399
-
400
- output_slot = inputs[:, :rv.ENTITY_SUM]
401
- range_mask = self.range_mask(group_loss[:, 0])
402
- loss_slot = tf.cast(self.predict_fn_2(output_slot * range_mask), dtype=tf.int64)
403
-
404
- properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]
405
-
406
- properties = []
407
- outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
408
- for i, out in enumerate(outputs):
409
- shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
410
- out_reshaped = tf.reshape(out, shape)
411
- out_masked = out_reshaped * loss_slot[..., None]
412
- properties.append(self.predict_fn(out_masked))
413
- # out_masked[0].numpy()
414
- number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
415
-
416
- result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
417
-
418
- return result
419
-
420
-
421
- def final_predict_mask(x, mask):
422
- r = reshape(x[0][:, rv.INDEX[0]:-1], [-1, 3])
423
- return tf.ragged.boolean_mask(r, mask)
424
-
425
-
426
- def final_predict(x, mode=False):
427
- m = x[1] if mode else tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
428
- return final_predict_mask(x[0], m)
429
-
430
-
431
- def final_predict_2(x):
432
- ones = tf.cast(tf.ones(tf.shape(x[0])[0]), tf.bool)[:, None]
433
- mask = tf.concat([ones, tf.tile(x[1], [1, 4]), ones], axis=-1)
434
- return tf.ragged.boolean_mask(x[0], mask)
435
-
436
-
437
- class PredictModelOld(Model):
438
-
439
- def call(self, inputs):
440
- output = inputs[-2]
441
-
442
- rest_output = output[:, :-rv.GROUPS_NO]
443
-
444
- result_all = []
445
- outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-3)
446
- for i, out in enumerate(outputs):
447
- shape = (-3, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
448
- out_reshaped = tf.reshape(out, shape)
449
-
450
- result = tf.cast(tf.argmax(out_reshaped, axis=-3), dtype="int8")
451
- result_all.append(result)
452
-
453
- result_all = interleave(result_all)
454
- return result_all
455
-
456
-
457
- def get_matches(diff, target_index):
458
- diff_sum = K.sum(diff)
459
- db_argsort = tf.argsort(diff_sum, axis=-1)
460
- db_sorted = tf.sort(diff_sum)
461
- db_mask = db_sorted[:, 0, None] == db_sorted
462
- db_same = tf.where(db_mask, db_argsort, -1 * tf.ones_like(db_argsort))
463
- matched_index = db_same == target_index
464
- # setting shape needed for TensorFlow graph
465
- matched_index.set_shape(db_same.shape)
466
- matches = K.any(matched_index)
467
- more_matches = K.sum(db_mask) > 1
468
- once_matches = K.sum(matches & tf.math.logical_not(more_matches))
469
- return matches, more_matches, once_matches
470
-
471
-
472
- class SimilarityRaven(Model):
473
- def __init__(self, mode=create_all_mask, number_loss=False):
474
- super().__init__()
475
- self.range_mask = RangeMask()
476
- self.mode = mode
477
-
478
- # self.predict_fn = partial(tf.argmax, axis=-1)
479
-
480
- # INDEX, PREDICT, LABELS
481
- def call(self, inputs):
482
- metrics = []
483
- target_index = inputs[0] - 8
484
- predict = inputs[1]
485
- answers = inputs[2][:, 8:]
486
- shape = tf.shape(predict)
487
-
488
- target = K.gather(answers, target_index[:, 0])
489
-
490
- target_group = target[:, 0]
491
-
492
- # comp_slice = np.
493
- target_comp = target[:, 1:rv.target.END_INDEX]
494
- predict_comp = predict[:, 1:rv.target.END_INDEX]
495
- answers_comp = answers[:, :, 1:rv.target.END_INDEX]
496
-
497
- full_properties_musks = self.mode(target)
498
- fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
499
-
500
- range_mask = self.range_mask(target_group)
501
- full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
502
-
503
- final_mask = fpm & full_range_mask
504
-
505
- target_masked = target_comp * final_mask
506
- predict_masked = predict_comp * final_mask
507
- answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
508
-
509
- acc_same = K.mean(K.all(target_masked == predict_masked))
510
- self.add_metric(acc_same, ACC_SAME)
511
- metrics.append(acc_same)
512
-
513
- diff = tf.abs(predict_masked[:, None] - answers_masked)
514
- diff_bool = diff != 0
515
-
516
- matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
517
-
518
- second_phase_mask = (more_matches & matches)
519
- diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
520
- target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
521
-
522
- matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
523
- matches_2_no = K.sum(matches_2)
524
-
525
- acc_choose_upper = (once_matches + matches_2_no) / shape[0]
526
- self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
527
- metrics.append(acc_choose_upper)
528
-
529
- acc_choose_lower = (once_matches + once_matches_2) / shape[0]
530
- self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
531
- metrics.append(acc_choose_lower)
532
-
533
- return metrics
534
-
535
-
536
- class SimilarityRaven2(Model):
537
- def __init__(self, mode=create_all_mask, number_loss=False):
538
- super().__init__()
539
- self.range_mask = RangeMask()
540
- self.mode = mode
541
-
542
- # self.predict_fn = partial(tf.argmax, axis=-1)
543
-
544
- # INDEX, PREDICT, LABELS
545
- def call(self, inputs):
546
- metrics = []
547
- target_index = inputs[0] - 8
548
- predict = inputs[1]
549
- answers = inputs[2][:, 8:]
550
- shape = tf.shape(predict)
551
-
552
- target = K.gather(answers, target_index[:, 0])
553
-
554
- target_group = target[:, 0]
555
-
556
- # comp_slice = np.
557
- target_comp = target[:, 1:rv.target.END_INDEX]
558
- predict_comp = predict[:, 1:rv.target.END_INDEX]
559
- answers_comp = answers[:, :, 1:rv.target.END_INDEX]
560
-
561
- full_properties_musks = self.mode(target)
562
- fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
563
-
564
- range_mask = self.range_mask(target_group)
565
- full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
566
-
567
- final_mask = fpm & full_range_mask
568
-
569
- target_masked = target_comp * final_mask
570
- predict_masked = predict_comp * final_mask
571
- answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
572
-
573
- acc_same = K.mean(K.all(target_masked == predict_masked))
574
- self.add_metric(acc_same, ACC_SAME)
575
- metrics.append(acc_same)
576
-
577
- diff = tf.abs(predict_masked[:, None] - answers_masked)
578
- diff_bool = diff != 0
579
-
580
- matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
581
-
582
- second_phase_mask = (more_matches & matches)
583
- diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
584
- target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
585
-
586
- matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
587
- matches_2_no = K.sum(matches_2)
588
-
589
- acc_choose_upper = (once_matches + matches_2_no) / shape[0]
590
- self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
591
- metrics.append(acc_choose_upper)
592
-
593
- acc_choose_lower = (once_matches + once_matches_2) / shape[0]
594
- self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
595
- metrics.append(acc_choose_lower)
596
-
597
- metrics.append(K.sum(target_masked != predict_masked))
598
-
599
- return metrics
600
-
601
-
602
- class LatentLossModel(Model):
603
- def __init__(self, dir_=HORIZONTAL):
604
- super().__init__()
605
- # self.sum_metrics = []
606
- # for i in range(8):
607
- # self.sum_metrics.append(Sum(name=f"no_{i}"))
608
- self.metric_fn = Accuracy(name="acc_latent")
609
- if dir_ == VERTICAL:
610
- self.dir = (6, 7)
611
- else:
612
- self.dir = (2, 5)
613
-
614
- def call(self, inputs):
615
- target_image = tf.reshape(inputs[0][2], [-1])
616
- output = inputs[1]
617
- latents = tnp.asarray(inputs[2])
618
-
619
- target_hor = tf.concat([
620
- latents[:, self.dir],
621
- latents[tf.range(latents.shape[0]), target_image + 8][:, None]
622
- ],
623
- axis=1)
624
-
625
- loss_hor = mse(K.stop_gradient(target_hor), output)
626
- self.add_loss(loss_hor)
627
-
628
- self.add_metric(self.metric_fn(inputs[3], target_image))
629
-
630
- return loss_hor
631
-
632
-
633
- class PredRav(Model):
634
-
635
- def call(self, inputs):
636
- output = inputs[0][:, -1]
637
- answers = inputs[1][:, 8:]
638
- return tf.argmin(tf.reduce_sum(tf.abs(output[:, None] - answers), axis=-1), axis=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/multi_transformer.py DELETED
@@ -1,274 +0,0 @@
1
- import tensorflow as tf
2
- from functools import partial
3
- from tensorflow.keras.layers import Lambda
4
- from tensorflow.keras.layers import Dense
5
- from tensorflow.keras import Input, Model
6
- from tensorflow.python.keras import Sequential
7
-
8
- from config.constant import TRANS
9
- from ml_utils import filter_init
10
- from models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
11
- from models_utils import pmodel, DictModel, bt, INPUTS, bm, OUTPUT, LATENTS, transformer, BatchModel, get_extractor, \
12
- build_seq_model, BUILD, build_train_list, InitialWeight
13
- from models_utils import SumPositionEmbedding, TransformerBlock, CatPositionEmbedding, transformer, BatchInitialWeight
14
- import models_utils.ops as K
15
- from models_utils.image import inverse_fn
16
- from models_utils.ops_core import IndexReshape
17
- from models_utils.random_ import EpsilonGreedy, EpsilonSoft
18
- from models_utils.step import StepDict
19
-
20
-
21
- def init_weights(shape, dtype=None):
22
- return tf.cast(K.var.image(shape=shape, pre=True), dtype=tf.float32)
23
-
24
-
25
- def conversion(x, max_=45):
26
- shape = tf.shape(x)
27
- return tf.reshape(x[:, :max_], tf.stack([shape[0], 9, -1]))
28
-
29
-
30
- def take_left(x):
31
- return x[..., 7:8]
32
-
33
-
34
- def take_by_index(x, i=8):
35
- return x[..., i:i + 1]
36
-
37
-
38
- def mix(x):
39
- return (x[..., 7:8] + x[..., 5:6]) / 2
40
-
41
-
42
- def empty_last(x):
43
- return tf.zeros_like(x[..., 7:8])
44
-
45
-
46
- class Conversion(Model):
47
- def __init__(self):
48
- super().__init__()
49
- self.model = IndexReshape((0, "9", None))
50
-
51
- def call(self, inputs):
52
- return self.model(inputs[:, :45])
53
-
54
-
55
- class RandomImageMask(Model):
56
- def __init__(self, last, last_index=9):
57
- super().__init__()
58
- self.get_last = last
59
- self.last_index = last_index
60
-
61
- def call(self, inputs):
62
- shape = tf.shape(inputs)
63
- indexes = tf.random.uniform(shape=shape[0:1], maxval=self.last_index, dtype=tf.int32)
64
- mask = tf.one_hot(indexes, self.last_index)[:, None, None]
65
-
66
- return (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
67
- (1, 1, 1, self.last_index))
68
-
69
-
70
- # res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
71
- # (1, 1, 1, self.last_index))
72
-
73
-
74
- # from data_utils import ims
75
- # for i in range(50):
76
- # ims(res[i].numpy().swapaxes(0, 2))
77
- # res[12].numpy()
78
- # self.get_last(inputs).numpy()
79
- # import tensorflow as tf
80
- # tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
81
- # from ml_utils import print_error
82
- # ims(mask[0].numpy())
83
- # print_error(lambda :ims(mask[0]))
84
- # from models_utils import ops as K
85
-
86
-
87
- class ImageMask(Model):
88
- def __init__(self, last, index=8, last_index=9):
89
- super().__init__()
90
- self.get_last = last
91
- self.index = index
92
- self.last_index = last_index
93
-
94
- def call(self, inputs):
95
- return tf.concat([inputs[..., :8], self.get_last(inputs)], axis=-1)
96
-
97
-
98
- class CreateGrid(Model):
99
- def __init__(self,
100
- no=4,
101
- extractor="ef",
102
- type_=3,
103
- base="seq",
104
- last=take_left,
105
- epsilon=None,
106
- pooling=None,
107
- mask_fn=None,
108
- model=None,
109
- **kwargs
110
- ):
111
- super().__init__()
112
- self.type_ = type_
113
- if type_ == 9:
114
- self.start_shape = 75
115
- data = (224, 224, 3)
116
- conv = lambda: Conversion()
117
- else:
118
- self.start_shape = 84
119
- data = (84, 84, 3)
120
- extractor = BUILD[base]([
121
- BatchModel(get_extractor(data=data, model=extractor)),
122
- lambda x: tf.transpose(x, (1, 0, 2, 3, 4))
123
- # lambda x: tf.tile(x[:, :224, :224], (1, 1, 1, 3))
124
- ])
125
- conv = lambda: conversion
126
-
127
- self.epsilon = epsilon
128
- if mask_fn == "random":
129
- mask_fn = RandomImageMask(last=last)
130
- elif mask_fn is None:
131
- mask_fn = ImageMask(last=last)
132
-
133
- self.mask_fn = mask_fn
134
-
135
-
136
- def call(self, inputs):
137
- transposed = tf.image.resize(tf.transpose(inputs, (0, 2, 3, 1)), (self.start_shape, self.start_shape))
138
- re = self.mask_fn(transposed)
139
-
140
- # re = tf.concat([transposed[..., :8], self.get_last(transposed)], axis=-1)
141
- if self.type_ == 9:
142
- x = tf.transpose(re, [0, 3, 1, 2])[..., None]
143
- x = K.create_image_grid(x, 3, 3)
144
- x = x[:, :224, :224]
145
- x = tf.tile(x, [1, 1, 1, 3])
146
- else:
147
-
148
- x = tf.stack([
149
- re[..., :3],
150
- re[..., 3:6],
151
- re[..., 6:9],
152
- ])
153
- return self.model(x)
154
-
155
-
156
- # self.model.layers[0](x)
157
-
158
-
159
- def grid_transformer(
160
- *args,
161
- type_=9,
162
- no=4,
163
- extractor="ef",
164
- loss_mode=create_uniform_mask,
165
- output_size=10,
166
- loss_weight=1.0,
167
- out_layers=(1000, 1000, 1000),
168
- pos_emd="cat",
169
- base="seq",
170
- inverse_image=True,
171
- last="left",
172
- mask_fn=None,
173
- model=None,
174
- trans=None,
175
- **kwargs):
176
-
177
- if last == "left":
178
- last = take_left
179
- elif last == "mix":
180
- last = mix
181
- elif last == "empty":
182
- last = empty_last
183
- elif last == "start":
184
- last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
185
-
186
- create_grid = CreateGrid(
187
- type_=type_,
188
- no=no,
189
- extractor=extractor,
190
- model=model,
191
- output_size=output_size,
192
- out_layer=out_layers,
193
- pos_emd=pos_emd,
194
- base=base,
195
- last=last,
196
- mask_fn=mask_fn,
197
- **kwargs
198
- )
199
-
200
- if model is None:
201
- trans = transformer(
202
- extractor=extractor,
203
- pos_emd=pos_emd,
204
- data=data,
205
- output_size=output_size,
206
- out_layers=out_layer,
207
- pooling=conv,
208
- no=no,
209
- base=base,
210
- **kwargs
211
- # **as_dict(p.trans)
212
- )
213
- else:
214
- trans = trans
215
-
216
-
217
-
218
- def get_rav_trans(
219
- *args,
220
- type_=9,
221
- no=4,
222
- extractor="ef",
223
- loss_mode=create_uniform_mask,
224
- output_size=10,
225
- loss_weight=1.0,
226
- out_layers=(1000, 1000, 1000),
227
- pos_emd="cat",
228
- base="seq",
229
- inverse_image=True,
230
- last="left",
231
- epsilon="greedy",
232
- epsilon_step=500,
233
- mask_fn=None,
234
- model=None,
235
- loss="multi",
236
- **kwargs):
237
- if last == "left":
238
- last = take_left
239
- elif last == "mix":
240
- last = mix
241
- elif last == "empty":
242
- last = empty_last
243
- elif last == "start":
244
- last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
245
-
246
- trans_raven = CreateGrid(
247
- type_=type_,
248
- no=no,
249
- extractor=extractor,
250
- model=model,
251
- output_size=output_size,
252
- out_layer=out_layers,
253
- pos_emd=pos_emd,
254
- base=base,
255
- last=last,
256
- epsilon=epsilon,
257
- mask_fn=mask_fn,
258
- **kwargs
259
- )
260
-
261
- if loss == "single":
262
- loss = SingleVTRavenLoss
263
- else:
264
- loss = VTRavenLoss
265
-
266
- return bt(
267
- DictModel(
268
- Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
269
- in_=INPUTS,
270
- name="Body"
271
- ),
272
- loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
273
- loss_wrap=False
274
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/raven.py DELETED
@@ -1,239 +0,0 @@
1
- from ml_utils import lw, lu
2
- from models_utils import bm, Base, res, bt, DictModel, dense_drop, drop, build_encoder, MODEL_ARCH, ListModel, short, \
3
- dense, Flatten, Cat, CatDenseBefore, \
4
- CatDense, CatBefore, Drop, Flat2, down, Pass, conv, Flat, Get, bs, Res, SoftBlock
5
- from models_utils import SubClassingModel
6
- from models_utils.config.constants import *
7
- import numpy as np
8
-
9
- from config.constant import *
10
- from tensorflow.keras.layers import Dense, Activation, BatchNormalization
11
- import tensorflow as tf
12
-
13
- import raven_utils as rv
14
-
15
- from models.body import create_block
16
- from models.class_ import Merge, RavenClass
17
- from models.head import LatentHeadModel
18
-
19
- from models.loss import RavenLoss
20
- from models.trans import TransModel, FullTrans
21
- from raven_utils.const import HORIZONTAL
22
-
23
-
24
- def raven_model(scales,
25
- out_layers,
26
- latent=(64, 128, 256),
27
- output_size=None,
28
- padding=SAME,
29
- body_layers=1,
30
- encoder=None,
31
- loop=1,
32
- model=None,
33
- act=None,
34
- simpler=0,
35
- loss_mode=None,
36
- loss_weight=0.3,
37
- dir_=HORIZONTAL,
38
- global_context=False,
39
- images_no=8,
40
- context_mul=2,
41
- res_act="pass",
42
- drop_latent=0,
43
- drop_inference=0,
44
- drop_end=0,
45
- ga=False,
46
- trans_norm=None,
47
- trans_act="relu",
48
- arch=HEAD3,
49
- encoder_norm=False,
50
- encoder_pool=False,
51
- encoder_global="GM",
52
- encoder_before=False,
53
- tail_units=256,
54
- tail_flatten=None,
55
- # for now by default
56
- tail_down="MP",
57
- trans_no=1,
58
- trans_score_activation=tf.nn.softmax,
59
- block_=SoftBlock,
60
- **kwargs):
61
- if isinstance(latent, int):
62
- latent = (latent, 128, 256)
63
- scales = lw(scales)
64
-
65
- context_size = np.array(latent) * context_mul
66
- # context_size = latent[scales] * context_mul
67
-
68
- # if scales == 2:
69
- # arch = HEAD
70
- # elif scales == 1:
71
- # arch = HEAD2
72
- # else:
73
- # arch = VERY2
74
-
75
- if encoder_pool:
76
- strides = (1, 1)
77
- else:
78
- strides = (2, 2)
79
- if not isinstance(encoder_before, tuple):
80
- encoder_before = [encoder_before] * 3
81
-
82
- # if trans == 1:
83
- # trans_model = TransModel2
84
- # else:
85
- # trans_model = TransModel
86
-
87
- # if scales == 3:
88
- # head = MultiHeadModel(encoder=encoder)
89
- arch = MODEL_ARCH[arch]
90
- heads = []
91
- for s in list(range(0, max(scales) + 1)):
92
- if s in (0, 1):
93
- if s == 0:
94
- encoder = build_encoder(arch[:3], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
95
- strides=strides)
96
- else:
97
- encoder = build_encoder(arch[3:4], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
98
- strides=strides)
99
- head = LatentHeadModel(
100
- encoder=encoder,
101
- inference_network=(
102
- bm([
103
- CatBefore(filters=int(context_size[s] / 8)) if encoder_before[s] else Cat(
104
- filters=context_size[s]),
105
- # todo activation?
106
- Res(filters=context_size[s], padding=padding)
107
- ] + ([drop(drop_inference)] if drop_inference else []),
108
- name="inference")
109
- ) if s in scales else Pass(),
110
- stem=Base(
111
- bm(
112
- # ok we choose by parameters anyway
113
- [res(filters=latent[s], padding=padding, act=act)] + (
114
- [drop(drop_latent)] if drop_latent else [])
115
- ),
116
- name="stem")
117
- )
118
- else:
119
- encoder = bm([
120
- Res(),
121
- build_encoder(arch[4:], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
122
- strides=strides),
123
- short(encoder_global) if encoder_global else Flatten(),
124
- dense(latent[s])
125
- ])
126
- head = LatentHeadModel(
127
- encoder=encoder,
128
- inference_network=bm([
129
- # todo Echeck Cat
130
- CatDenseBefore(filters=int(context_size[s] / 8)) if encoder_before[
131
- s] else CatDense(filters=context_size[s]),
132
- # todo activation?
133
- Res(model="dv2", filters=context_size[s], padding=padding)
134
- ] + ([dense_drop(drop_inference)] if drop_inference else []),
135
- name="inference"),
136
- stem=Base(
137
- bm(
138
- # ok we choose by parameters anyway
139
- [res(model="dv2", units=latent[s], padding=padding, act=act)] + (
140
- [dense_drop(drop_latent)] if drop_latent else [])
141
- ),
142
- name="stem")
143
- )
144
- heads.append(head)
145
-
146
- concat_input = [f"{LATENT}_{i}" for i, _ in enumerate(heads)] + [f"{INFERENCE}_{i}" for i, _ in enumerate(heads)]
147
- concat_output = ["LATENTS", "INFERENCES"]
148
-
149
- def head_concat(inputs):
150
- latents = inputs[:len(heads)]
151
- inferences = inputs[len(heads):]
152
- return latents, inferences
153
-
154
- head = ListModel([(h, (INPUTS if i == 0 else OUTPUT), [f"{LATENT}_{i}", f"{INFERENCE}_{i}", OUTPUT]) for i, h in
155
- enumerate(heads)] + [
156
- (head_concat, concat_input, concat_output)], out=concat_output)
157
- # from rav_utils.raven import init_image
158
- # a = init_image()
159
- # head(a)
160
-
161
- if model is None:
162
- model = []
163
- for i in scales:
164
- trans_models = []
165
- for t in range(trans_no):
166
- trans_models.append(
167
- bm(
168
- [create_block(latent=latent[i], simpler=simpler, padding=padding, norm=trans_norm, act=res_act,
169
- loop=loop, type_="dense" if i == 2 else "conv", block_=block_)] +
170
- [Activation(trans_act)] + [
171
- res(filters=latent[i],
172
- padding=padding,
173
- act=act,
174
- name="body_out",
175
- model="dv2" if i == 2 else "v2") for _ in
176
- range(body_layers)] + ([Drop(drop_latent)] if drop_latent else []),
177
- base_class=SubClassingModel)
178
- )
179
- trans_models = lu(trans_models)
180
- if trans_no > 1:
181
- trans_models = bm([
182
- lambda x: [[x[0], x[1]], x[1]],
183
- SoftBlock(
184
- model=trans_models,
185
- score_model=bm([
186
- Flat2(filters=latent[i], units=256, res_no=2),
187
- Dense(trans_no, trans_score_activation)
188
- ])
189
- )
190
- ],
191
- base_class=SubClassingModel
192
- )
193
-
194
- model.append(
195
- TransModel(
196
- body=trans_models,
197
- dir_=dir_,
198
- images_no=images_no
199
- )
200
- )
201
-
202
- tail = []
203
- for i, s in enumerate(scales):
204
- flatting = lambda: Flat2(filters=latent[s + 1], base_class=tail_flatten, units=tail_units)
205
- if s == 0:
206
- if tail_flatten is None:
207
- branch = bm([res(filters=latent[s], padding=padding),
208
- conv(filters=latent[s], padding=padding),
209
- BatchNormalization(),
210
- conv(filters=latent[s], padding=padding),
211
- Flatten()])
212
- else:
213
- branch = bm([down(base_class=tail_down), flatting()])
214
- elif s == 1:
215
- if tail_flatten is None:
216
- branch = bm([res(filters=latent[s], padding=padding),
217
- Flatten()])
218
- else:
219
- branch = flatting()
220
- else:
221
- branch = bm([tail_units] * 2, add_flatten=False)
222
- tail.append(branch)
223
-
224
- tail.append(
225
- bm([dense(tail_units)] + ([dense_drop(drop_end)] if drop_end else []) + [Dense(output_size)], add_flatten=False,
226
- name=TAIL))
227
- class_input = []
228
-
229
- return bt([
230
- DictModel(head, in_=INPUTS, out=[LATENT, INFERENCE], name="Head"),
231
- DictModel(FullTrans(model, scales=scales), in_=[LATENT, INFERENCE], out=TRANS, name="Body"),
232
- DictModel(RavenClass(Merge(tail), scales=scales, no=8), in_=[LATENT] + class_input, out=CLASSIFICATION,
233
- name="Classificator"),
234
- DictModel(RavenClass(Merge(tail), scales=list(range(len(scales))), no=3), in_=[TRANS] + class_input,
235
- out=OUTPUT, name="Classificator_trans"),
236
- ],
237
- loss=RavenLoss(mode=loss_mode, classification=True, trans=True, lw=(1.0, loss_weight)),
238
- loss_wrap=False
239
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/trans.py DELETED
@@ -1,74 +0,0 @@
1
- import tensorflow as tf
2
- from ml_utils import lw
3
- from models_utils import ops as K, SubClassingModel
4
- from tensorflow.keras import Model
5
-
6
- from models.body import create_dense_block
7
- import raven_utils as rv
8
- from raven_utils.const import HORIZONTAL, VERTICAL
9
-
10
-
11
- class TransModel(Model):
12
- def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
13
- super().__init__()
14
- self.model = body or create_dense_block(latent=latent)
15
- if dir_ == VERTICAL:
16
- self.dir = (0, 3, 1, 4, 3, 5)
17
- else:
18
- self.dir = (0, 1, 3, 4, 6, 7)
19
- self.images_no = images_no
20
- self.latent = latent
21
-
22
- def call(self, inputs):
23
- # latents = tnp.asarray(inputs[0])
24
- latents = inputs[0]
25
- inference = inputs[1]
26
- shape = tf.shape(latents)
27
- new_shape = K.cat([[-1, 3, 2], shape[2:]])
28
- horizontal = latents[:, self.dir].reshape(new_shape)
29
- res = tf.TensorArray(tf.float32, size=3)
30
- for i in range(3):
31
- res = res.write(i, self.model([horizontal[:, i], inference]))
32
- result = K.tran(res.stack())
33
- return result
34
-
35
-
36
- class TransModel2(Model):
37
- def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
38
- super().__init__()
39
- self.body = body or create_dense_block(latent=latent)
40
- if dir_ == VERTICAL:
41
- self.dir = (0, 3, 1, 4, 3, 5)
42
- else:
43
- self.dir = (0, 1, 3, 4, 6, 7)
44
- self.images_no = images_no
45
- self.latent = latent
46
-
47
- def call(self, inputs):
48
- # latents = tnp.asarray(inputs[0])
49
- latents = inputs[0]
50
- inference = inputs[1]
51
- shape = tf.shape(latents)
52
- new_shape = K.cat([[-1, 3, 2], shape[2:]])
53
- horizontal = latents[:, self.dir].reshape(new_shape)
54
- res = tf.TensorArray(tf.float32, size=3)
55
- for i in tf.range(3):
56
- res = res.write(i, self.body([horizontal[:, i], inference[:,i]]))
57
- result = K.tran(res.stack())
58
- return result
59
-
60
-
61
- class FullTrans(SubClassingModel):
62
- def __init__(self, model,scales,name=None):
63
- super().__init__(model=model,name=name)
64
- self.scales = scales
65
-
66
- def call(self, inputs):
67
- latent = lw(inputs[0])
68
- inference = lw(inputs[1])
69
- results = []
70
- # todo merging inference?
71
- for i,s in enumerate(self.scales):
72
- # results.append(model([latent[::-1][i], inference]))
73
- results.append(self.model[i]([latent[s], inference[s]]))
74
- return results,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/transformer.py DELETED
@@ -1,133 +0,0 @@
1
- import tensorflow as tf
2
- from tensorflow.keras.layers import Lambda
3
- from tensorflow.python.keras import Sequential
4
-
5
- # from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
6
- from models_utils import DictModel, bt, INPUTS, BatchInitialWeight
7
- import models_utils.ops as K
8
- from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
9
- from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
10
- from models_utils.ops_core import IndexReshape
11
- from models_utils.random_ import EpsilonGreedy, EpsilonSoft
12
- from models_utils.step import StepDict
13
-
14
- # res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
15
- # (1, 1, 1, self.last_index))
16
-
17
-
18
- # from data_utils import ims
19
- # for i in range(50):
20
- # ims(res[i].numpy().swapaxes(0, 2))
21
- # res[12].numpy()
22
- # self.get_last(inputs).numpy()
23
- # import tensorflow as tf
24
- # tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
25
- # from ml_utils import print_error
26
- # ims(mask[0].numpy())
27
- # print_error(lambda :ims(mask[0]))
28
- # from models_utils import ops as K
29
-
30
-
31
- # self.model.layers[0](x)
32
- from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
33
-
34
-
35
- def get_rav_trans(
36
- data,
37
- type_=9,
38
- no=4,
39
- extractor="ef",
40
- loss_mode=create_uniform_mask,
41
- output_size=10,
42
- loss_weight=1.0,
43
- out_layers=(1000, 1000, 1000),
44
- pos_emd="cat",
45
- base="seq",
46
- inverse_image=True,
47
- last="left",
48
- epsilon="greedy",
49
- epsilon_step=500,
50
- mask_fn=None,
51
- model=None,
52
- loss="multi",
53
- **kwargs):
54
- if last == "left":
55
- last = take_left
56
- elif last == "mix":
57
- last = mix
58
- elif last == "empty":
59
- last = empty_last
60
- elif last == "start":
61
- last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
62
-
63
- if epsilon == "greedy":
64
- epsilon = EpsilonGreedy(step=epsilon_step)
65
- elif epsilon == "soft":
66
- epsilon = EpsilonSoft(step=epsilon_step)
67
- elif epsilon is False:
68
- epsilon = None
69
-
70
- if epsilon:
71
- trans_raven = TransRavenwithStep(
72
- type_=type_,
73
- no=no,
74
- extractor=extractor,
75
- output_size=output_size,
76
- out_layer=out_layers,
77
- pos_emd=pos_emd,
78
- base=base,
79
- last=last,
80
- epsilon=epsilon,
81
- **kwargs
82
- )
83
- return StepDict(bt(
84
- DictModel(
85
- Sequential([Lambda(lambda x: (255 - x[0], x[1])), trans_raven]) if inverse_image else trans_raven,
86
- in_=[INPUTS, "step"],
87
- name="Body"
88
- ),
89
- loss=VTRavenLoss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
90
- loss_wrap=False),
91
- add_step=epsilon_step,
92
- )
93
-
94
- trans_raven = img_sec_trans(
95
- type_=type_,
96
- no=no,
97
- extractor=extractor,
98
- model=model,
99
- output_size=output_size,
100
- out_layer=out_layers,
101
- pos_emd=pos_emd,
102
- base=base,
103
- last=last,
104
- epsilon=epsilon,
105
- mask_fn=mask_fn,
106
- **kwargs
107
- )
108
- if loss == "single":
109
- loss = SingleVTRavenLoss
110
- else:
111
- loss = VTRavenLoss
112
-
113
- # return bt(
114
- # DictModel(
115
- # Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
116
- # inputs=INPUTS,
117
- # name="Body"
118
- # ),
119
- # loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
120
- # loss_wrap=False
121
- # )
122
-
123
- return bt([
124
- DictModel(
125
- Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
126
- in_=INPUTS,
127
- name="Body"
128
- ),
129
-
130
- ],
131
- loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
132
- loss_wrap=False
133
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/transformer_2.py DELETED
@@ -1,146 +0,0 @@
1
- from functools import partial
2
-
3
- import tensorflow as tf
4
- from tensorflow.keras.layers import Lambda
5
- from tensorflow.python.keras import Sequential
6
- from models_utils import ops as K, SubClassing
7
- from models_utils.models.transformer import aug
8
-
9
- # from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
10
- from data_utils import DataGenerator, LOSS, TARGET, IMAGES
11
- from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional_model, get_input_layer
12
- import models_utils.ops as K
13
- from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
14
- from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
15
- from models_utils.ops_core import IndexReshape
16
- from models_utils.random_ import EpsilonGreedy, EpsilonSoft
17
- from models_utils.step import StepDict
18
-
19
- from models_utils.models.transformer import aug
20
-
21
- # res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
22
- # (1, 1, 1, self.last_index))
23
-
24
-
25
- # from data_utils import ims
26
- # for i in range(50):
27
- # ims(res[i].numpy().swapaxes(0, 2))
28
- # res[12].numpy()
29
- # self.get_last(inputs).numpy()
30
- # import tensorflow as tf
31
- # tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
32
- # from ml_utils import print_error
33
- # ims(mask[0].numpy())
34
- # print_error(lambda :ims(mask[0]))
35
- # from models_utils import ops as K
36
-
37
-
38
- # self.model.layers[0](x)
39
- from raven_utils.constant import INDEX, LABELS
40
- from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
41
-
42
-
43
- def get_matrix(inputs, index):
44
- return tf.concat([inputs[:, :8], K.gather(inputs, index[:, 0])[:, None]], axis=1)
45
-
46
-
47
- def get_images(inputs):
48
- return get_matrix(inputs[0], inputs[1])
49
-
50
-
51
- def random_last(inputs, max_=8):
52
- index = K.init.label(max=max_, shape=[tf.shape(inputs[0])[0]])[..., None]
53
- return get_matrix(inputs[0], index)
54
-
55
-
56
- def get_images_no_answer(inputs):
57
- return inputs[0][:, :9]
58
-
59
-
60
- def repeat_last(inputs):
61
- return inputs[0][:, list(range(8)) + [7]]
62
-
63
-
64
- def get_rav_trans(
65
- data,
66
- inverse_image=True,
67
- loss_mode=create_uniform_mask,
68
- loss_weight=1.0,
69
- loss="multi",
70
- number_loss=False,
71
- plw=None,
72
- pre="auto",
73
- augmentation=None,
74
- **kwargs):
75
- if isinstance(data, DataGenerator):
76
- data = data[0]['inputs'], data[0]['index']
77
- # u = img_sec_trans(**kwargs)(get_images(data) if kwargs['mask'] == "random" else get_images_no_answer(data))
78
- # u.shape
79
- from keras import Model
80
- if pre == "auto":
81
- pre = get_images if kwargs['mask'] == "random" else get_images_no_answer
82
- elif pre == "no_answer":
83
- pre = get_images_no_answer
84
- elif pre == "last":
85
- pre = repeat_last
86
- elif pre == "images":
87
- pre = get_images
88
- elif pre == "random_last":
89
- pre = random_last
90
- elif pre == "noise":
91
- pre = SubClassing([get_matrix, partial(aug.noise, max_=8)])
92
- elif pre == "batch_noise":
93
- pre = SubClassing([get_matrix, partial(aug.batch_noise, max_=8)])
94
-
95
- if augmentation == "transpose":
96
- augmentation = aug.Transpose(axis=(0, 2, 1))
97
- augmentation_label = aug.Transpose(axis=(0, 2, 1))
98
- elif augmentation == "shuffle_col":
99
- augmentation = aug.shuffle_col
100
- augmentation_label = aug.shuffle_col
101
- elif augmentation == "shuffle":
102
- augmentation = aug.shuffle
103
- augmentation_label = aug.shuffle
104
- if augmentation:
105
- augmentation = [
106
- # DictModel(augmentation, IMAGES, IMAGES),
107
- # DictModel(aug.reshape_static(pre(data),augmentation), IMAGES, IMAGES),
108
- DictModel(aug.ReshapeStatic(augmentation), IMAGES, IMAGES),
109
- DictModel(
110
- aug.PartialModel(
111
- aug.ReshapeStatic(augmentation_label),
112
- last_axis=9)
113
- , LABELS, LABELS)
114
- ]
115
- else:
116
- augmentation = []
117
-
118
- trans_raven = build_functional_model(
119
- img_sec_trans(**kwargs),
120
- # get_images(data) if kwargs['mask'] == "random" else get_images_no_answer(data)
121
- pre(data)
122
- # data[0]
123
- )
124
- if loss == "single":
125
- loss = SingleVTRavenLoss
126
- else:
127
- loss = VTRavenLoss
128
- if isinstance(loss_weight, float):
129
- loss_weight = (loss_weight, 1.0)
130
-
131
- return bt([
132
- # DictModel(get_images if kwargs['mask'] == "random" else get_images_no_answer, [INPUTS, INDEX], IMAGES),
133
- DictModel(pre, [INPUTS, INDEX], IMAGES),
134
- *augmentation,
135
- DictModel(
136
- Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
137
- in_=IMAGES,
138
- # inputs=INPUTS,
139
- name="Body"
140
- ),
141
-
142
- ],
143
- loss=loss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
144
- predict=LOSS,
145
- loss_wrap=False
146
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/transformer_3.py DELETED
@@ -1,206 +0,0 @@
1
- import logging
2
-
3
- from loguru import logger
4
- from tensorflow.keras.layers import Lambda
5
- from tensorflow.keras.layers import Activation
6
-
7
- from grid_transformer import aug_trans
8
- from raven_utils.models.loss_3 import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
9
- from data_utils import get_shape, TakeDict
10
-
11
- from data_utils import DataGenerator, LOSS, TARGET, IMAGES
12
- from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional, get_input_layer, Last, bm, \
13
- add_end, AUGMENTATION
14
- # from report.select_ import SelectModel2, SelectModel, SelectModel9
15
- from experiment_utils.keras_model import load_weights as model_load_weights
16
-
17
-
18
- def get_rav_trans(
19
- data,
20
- loss_mode=create_uniform_mask,
21
- loss_weight=2.0,
22
- number_loss=False,
23
- dry_run="auto",
24
- plw=None,
25
- **kwargs):
26
- if isinstance(loss_weight, float):
27
- loss_weight = (loss_weight, 1.0)
28
-
29
- # seq_trans(**kwargs)(data[0])
30
- # trans_raven = build_functional_model2(
31
- # seq_trans(**kwargs),
32
- # data[0],
33
- # batch=None
34
- # )
35
- trans_raven = build_functional(
36
- model=aug_trans,
37
- inputs_=data[0] if isinstance(data, DataGenerator) else data,
38
- batch_=None,
39
- dry_run=dry_run,
40
- **kwargs
41
- )
42
-
43
- return bt(
44
- model=trans_raven,
45
- loss=VTRavenLoss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
46
- model_wrap=False,
47
- predict=LOSS,
48
- loss_wrap=False
49
- )
50
-
51
-
52
- def rav_select_model(
53
- data,
54
- load_weights=None,
55
- loss_weight=(0.01, 0.0),
56
- plw=5.0,
57
- result_metric="sparse_categorical_accuracy",
58
- select_type=2,
59
- select_out=0,
60
- additional_out=0,
61
- additional_copy=True,
62
- tail_out=(1000, 1000),
63
- **kwargs
64
- ):
65
- out_layers = Last()
66
- if additional_out > 0:
67
- model3 = get_rav_trans(
68
- data,
69
- plw=plw,
70
- loss_weight=loss_weight,
71
- **kwargs
72
- )
73
-
74
- model_load_weights(
75
- model3,
76
- load_weights,
77
- # sample_data,
78
- None,
79
- template="weights_{epoch:02d}-{val_loss:.2f}",
80
- key=result_metric,
81
- )
82
-
83
- if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
84
- index = -1
85
- else:
86
- index = -2
87
-
88
- out = model3[0, index, :additional_out]
89
- logger.info(f"Additional out from: {model3[0, index]}.")
90
-
91
- if additional_out > 2:
92
- out += [Activation("gelu")]
93
- out_layers = bm([out_layers] + out, add_flatten=False)
94
- model = get_rav_trans(
95
- TakeDict(data[0])[:, 8:],
96
- plw=plw,
97
- loss_weight=loss_weight,
98
- **{
99
- **kwargs,
100
- "out_layers": out_layers,
101
- }
102
- # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
103
- )
104
- # from data_utils.ops import Equal
105
- # o = []
106
- # for i in range(1, 3):
107
- # for j in range(2):
108
- # o.append(
109
- # Equal(
110
- # # model[0,:,-2, i].variables[j],
111
- # model2[0, :, -2, i].variables[j],
112
- # # out_layers[i].variables[j]
113
- # second_pooling[i].variables[j]
114
- # ).equal
115
- # )
116
- # assert all(o)
117
- # model = get_rav_trans(
118
- # # TakeDict(val_generator[0])[:, 8:],
119
- # # TakeDict(val_generator[0])[:, 8:],
120
- # val_generator[0],
121
- # plw=p.plw,
122
- # loss_weight=p.loss_weight,
123
- # **{**as_dict(p.mp),
124
- # # "out_layers": out_layers,
125
- # }
126
- # # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
127
- # )
128
- model_load_weights(model,
129
- load_weights,
130
- # sample_data,
131
- None,
132
- template="weights_{epoch:02d}-{val_loss:.2f}",
133
- key=result_metric,
134
- )
135
- # model.compile()
136
- # model.evaluate(val_generator.data[:1000])
137
- # model(TakeDict(val_generator[0])[:, 8:])
138
- trans_raven = model[0]
139
- # s = trans_raven(TakeDict(val_generator[0])[:, 8:])
140
- if select_type == 2:
141
- second_pooling = Lambda(lambda x: x[:, :-1])
142
- else:
143
- second_pooling = Last()
144
- if additional_out > 0:
145
- if additional_copy:
146
- model4 = get_rav_trans(
147
- data,
148
- plw=plw,
149
- loss_weight=loss_weight,
150
- **kwargs
151
- )
152
- model_load_weights(model4,
153
- load_weights,
154
- # sample_data,
155
- None,
156
- template="weights_{epoch:02d}-{val_loss:.2f}",
157
- key=result_metric,
158
- )
159
-
160
- if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
161
- index = -1
162
- else:
163
- index = -2
164
- out2 = model4[0, index, :additional_out]
165
- logger.info(f"Additional out from: {model4[0, index]}.")
166
-
167
- if additional_out > 2:
168
- out2 += [Activation("gelu")]
169
- else:
170
- out2 = out
171
-
172
- second_pooling = bm([second_pooling] + out2, add_flatten=False)
173
-
174
- model2 = get_rav_trans(
175
- TakeDict(data[0])[:, 8:],
176
- plw=plw,
177
- loss_weight=loss_weight,
178
- **{
179
- **kwargs,
180
- "out_layers": second_pooling,
181
- }
182
- # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
183
- )
184
- model_load_weights(
185
- model2,
186
- load_weights,
187
- # sample_data,
188
- None,
189
- template="weights_{epoch:02d}-{val_loss:.2f}",
190
- key=result_metric,
191
- )
192
- if select_type == 0:
193
- # not working
194
- trans_raven2 = model2[0]
195
- else:
196
- trans_raven2 = model2[0]
197
- tail = add_end(out_layers=tail_out, output_size=8 if select_out else 1)
198
- # trans_raven2.mask_fn = ImageMask(last=take_by_index)
199
- if select_type == 2:
200
- select_model_class = SelectModel2
201
- elif select_type == 1:
202
- select_model_class = SelectModel
203
- else:
204
- select_model_class = SelectModel9
205
- select_model = select_model_class(trans_raven, model2=trans_raven2, tail=tail, select_out=select_out)
206
- return select_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
raven_utils/models/uitls_.py DELETED
@@ -1,16 +0,0 @@
1
- import tensorflow as tf
2
- import tensorflow.experimental.numpy as tnp
3
- from tensorflow.keras import Model
4
- import raven_utils as rv
5
-
6
-
7
- class RangeMask(Model):
8
- def __init__(self):
9
- super().__init__()
10
- ranges = tf.tile(tf.range(rv.entity.INDEX[-1])[None], [rv.group.NO, 1])
11
- start_index = rv.entity.INDEX[:-1][:, None]
12
- end_index = rv.entity.INDEX[1:][:, None]
13
- self.mask = tnp.array((start_index <= ranges) & (ranges < end_index))
14
-
15
- def call(self, inputs):
16
- return self.mask[inputs]