darabos commited on
Commit
87f8b85
·
1 Parent(s): ae86f2c

Split repeat boxes into start/end boxes.

Browse files
examples/Model definition CHANGED
@@ -34,6 +34,20 @@
34
  "sourceHandle": "loss",
35
  "target": "Optimizer 2",
36
  "targetHandle": "loss"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  }
38
  ],
39
  "env": "PyTorch model",
@@ -42,6 +56,7 @@
42
  "data": {
43
  "display": null,
44
  "error": null,
 
45
  "meta": {
46
  "inputs": {},
47
  "name": "Input: embedding",
@@ -75,6 +90,7 @@
75
  "data": {
76
  "display": null,
77
  "error": null,
 
78
  "meta": {
79
  "inputs": {
80
  "x": {
@@ -126,6 +142,7 @@
126
  "data": {
127
  "display": null,
128
  "error": null,
 
129
  "meta": {
130
  "inputs": {
131
  "x": {
@@ -174,6 +191,7 @@
174
  "data": {
175
  "display": null,
176
  "error": null,
 
177
  "meta": {
178
  "inputs": {},
179
  "name": "Input: label",
@@ -209,6 +227,7 @@
209
  "collapsed": null,
210
  "display": null,
211
  "error": null,
 
212
  "meta": {
213
  "inputs": {
214
  "x": {
@@ -243,10 +262,6 @@
243
  }
244
  }
245
  },
246
- "position": {
247
- "x": 419.0,
248
- "y": 396.0
249
- },
250
  "type": "basic"
251
  },
252
  "params": {
@@ -271,6 +286,7 @@
271
  "collapsed": null,
272
  "display": null,
273
  "error": null,
 
274
  "meta": {
275
  "inputs": {
276
  "loss": {
@@ -307,10 +323,6 @@
307
  }
308
  }
309
  },
310
- "position": {
311
- "x": 526.0,
312
- "y": 116.0
313
- },
314
  "type": "basic"
315
  },
316
  "params": {
@@ -329,6 +341,72 @@
329
  },
330
  "type": "basic",
331
  "width": 200.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  }
333
  ]
334
  }
 
34
  "sourceHandle": "loss",
35
  "target": "Optimizer 2",
36
  "targetHandle": "loss"
37
+ },
38
+ {
39
+ "id": "Activation 2 Repeat 1",
40
+ "source": "Activation 2",
41
+ "sourceHandle": "x",
42
+ "target": "Repeat 1",
43
+ "targetHandle": "input"
44
+ },
45
+ {
46
+ "id": "Repeat 1 Linear 1",
47
+ "source": "Repeat 1",
48
+ "sourceHandle": "output",
49
+ "target": "Linear 1",
50
+ "targetHandle": "x"
51
  }
52
  ],
53
  "env": "PyTorch model",
 
56
  "data": {
57
  "display": null,
58
  "error": null,
59
+ "input_metadata": null,
60
  "meta": {
61
  "inputs": {},
62
  "name": "Input: embedding",
 
90
  "data": {
91
  "display": null,
92
  "error": null,
93
+ "input_metadata": null,
94
  "meta": {
95
  "inputs": {
96
  "x": {
 
142
  "data": {
143
  "display": null,
144
  "error": null,
145
+ "input_metadata": null,
146
  "meta": {
147
  "inputs": {
148
  "x": {
 
191
  "data": {
192
  "display": null,
193
  "error": null,
194
+ "input_metadata": null,
195
  "meta": {
196
  "inputs": {},
197
  "name": "Input: label",
 
227
  "collapsed": null,
228
  "display": null,
229
  "error": null,
230
+ "input_metadata": null,
231
  "meta": {
232
  "inputs": {
233
  "x": {
 
262
  }
263
  }
264
  },
 
 
 
 
265
  "type": "basic"
266
  },
267
  "params": {
 
286
  "collapsed": null,
287
  "display": null,
288
  "error": null,
289
+ "input_metadata": null,
290
  "meta": {
291
  "inputs": {
292
  "loss": {
 
323
  }
324
  }
325
  },
 
 
 
 
326
  "type": "basic"
327
  },
328
  "params": {
 
341
  },
342
  "type": "basic",
343
  "width": 200.0
344
+ },
345
+ {
346
+ "data": {
347
+ "__execution_delay": 0.0,
348
+ "collapsed": null,
349
+ "display": null,
350
+ "error": null,
351
+ "input_metadata": null,
352
+ "meta": {
353
+ "inputs": {
354
+ "input": {
355
+ "name": "input",
356
+ "position": "top",
357
+ "type": {
358
+ "type": "tensor"
359
+ }
360
+ }
361
+ },
362
+ "name": "Repeat",
363
+ "outputs": {
364
+ "output": {
365
+ "name": "output",
366
+ "position": "bottom",
367
+ "type": {
368
+ "type": "tensor"
369
+ }
370
+ }
371
+ },
372
+ "params": {
373
+ "same_weights": {
374
+ "default": true,
375
+ "name": "same_weights",
376
+ "type": {
377
+ "type": "<class 'bool'>"
378
+ }
379
+ },
380
+ "times": {
381
+ "default": 1.0,
382
+ "name": "times",
383
+ "type": {
384
+ "type": "<class 'int'>"
385
+ }
386
+ }
387
+ },
388
+ "position": {
389
+ "x": 386.0,
390
+ "y": 456.0
391
+ },
392
+ "type": "basic"
393
+ },
394
+ "params": {
395
+ "same_weights": false,
396
+ "times": "2"
397
+ },
398
+ "status": "planned",
399
+ "title": "Repeat"
400
+ },
401
+ "dragHandle": ".bg-primary",
402
+ "height": 200.0,
403
+ "id": "Repeat 1",
404
+ "position": {
405
+ "x": -180.0,
406
+ "y": -90.0
407
+ },
408
+ "type": "basic",
409
+ "width": 200.0
410
  }
411
  ]
412
  }
examples/Model use CHANGED
@@ -579,54 +579,54 @@
579
  ],
580
  "data": [
581
  [
582
- "[0.49691743 0.61873293 0.90698647 0.94486356]",
583
- "[1.49691749 1.61873293 1.90698647 1.94486356]",
584
- "[1.4993021488189697, 1.6404846906661987, 1.923316240310669, 1.9422152042388916]"
585
- ],
586
- [
587
- "[0.56922203 0.98222166 0.76851749 0.28615737]",
588
- "[1.56922197 1.9822216 1.76851749 1.28615737]",
589
- "[1.5835213661193848, 1.9884355068206787, 1.7694181203842163, 1.2917503118515015]"
590
  ],
591
  [
592
- "[0.90817457 0.89270043 0.38583666 0.66566533]",
593
- "[1.90817451 1.89270043 1.3858366 1.66566539]",
594
- "[1.9053494930267334, 1.9083378314971924, 1.3998609781265259, 1.6636812686920166]"
595
  ],
596
  [
597
- "[0.72795159 0.79317838 0.27832931 0.96576637]",
598
- "[1.72795153 1.79317832 1.27832937 1.96576643]",
599
- "[1.734963297843933, 1.8026459217071533, 1.2926064729690552, 1.9596911668777466]"
600
  ],
601
  [
602
- "[0.04508126 0.76880038 0.80721325 0.62542385]",
603
- "[1.04508126 1.76880038 1.80721331 1.62542391]",
604
- "[1.0830243825912476, 1.7584562301635742, 1.8005754947662354, 1.6277496814727783]"
605
  ],
606
  [
607
- "[0.6032477 0.83361369 0.18538666 0.19108021]",
608
- "[1.60324764 1.83361363 1.18538666 1.19108021]",
609
- "[1.6177492141723633, 1.8144152164459229, 1.1718573570251465, 1.1950569152832031]"
610
  ],
611
  [
612
- "[0.15064228 0.03198934 0.25754827 0.51484001]",
613
- "[1.15064228 1.03198934 1.25754833 1.51484001]",
614
- "[1.1556042432785034, 0.9955940246582031, 1.2316606044769287, 1.5150485038757324]"
615
  ],
616
  [
617
- "[0.48959708 0.48549271 0.32688856 0.356677 ]",
618
- "[1.48959708 1.48549271 1.32688856 1.35667706]",
619
- "[1.4930214881896973, 1.467790961265564, 1.3132573366165161, 1.3589863777160645]"
620
  ],
621
  [
622
- "[0.08107251 0.2602725 0.18861133 0.44833237]",
623
- "[1.08107257 1.2602725 1.18861127 1.44833231]",
624
- "[1.102121114730835, 1.2180893421173096, 1.160165548324585, 1.4495322704315186]"
625
  ],
626
  [
627
  "[0.68094063 0.45189077 0.22661722 0.37354094]",
628
  "[1.68094063 1.45189071 1.22661722 1.37354088]",
629
- "[1.6725687980651855, 1.4393560886383057, 1.2169336080551147, 1.3746893405914307]"
 
 
 
 
 
630
  ]
631
  ]
632
  },
@@ -644,10 +644,6 @@
644
  "[0.85706753 0.61447072 0.41741937 0.85147089]",
645
  "[1.85706758 1.61447072 1.41741943 1.85147095]"
646
  ],
647
- [
648
- "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
649
- "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
650
- ],
651
  [
652
  "[0.19409031 0.68692201 0.60667384 0.57829887]",
653
  "[1.19409037 1.68692207 1.60667384 1.57829881]"
@@ -688,10 +684,18 @@
688
  "[0.11693293 0.49860179 0.55020827 0.88832849]",
689
  "[1.11693287 1.49860179 1.55020833 1.88832855]"
690
  ],
 
 
 
 
691
  [
692
  "[0.50272274 0.54912758 0.17663097 0.79070699]",
693
  "[1.50272274 1.54912758 1.17663097 1.79070699]"
694
  ],
 
 
 
 
695
  [
696
  "[0.19908059 0.17570406 0.51475513 0.1893943 ]",
697
  "[1.19908059 1.175704 1.51475513 1.18939424]"
@@ -709,13 +713,17 @@
709
  "[1.24388778 1.07268476 1.68350863 1.73431659]"
710
  ],
711
  [
712
- "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
713
- "[1.62569475 1.9881897 1.83639622 1.98288584]"
714
  ],
715
  [
716
  "[0.88776821 0.51636773 0.30333066 0.32230979]",
717
  "[1.88776827 1.51636767 1.30333066 1.32230973]"
718
  ],
 
 
 
 
719
  [
720
  "[0.48507756 0.80808765 0.77162558 0.47834778]",
721
  "[1.48507762 1.80808759 1.77162552 1.47834778]"
@@ -728,10 +736,6 @@
728
  "[0.31518555 0.49643308 0.11509258 0.95458382]",
729
  "[1.31518555 1.49643302 1.11509252 1.95458388]"
730
  ],
731
- [
732
- "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
733
- "[1.79121017 1.54161119 1.69369793 1.15207696]"
734
- ],
735
  [
736
  "[0.79423058 0.07138705 0.061777 0.18766576]",
737
  "[1.79423058 1.07138705 1.061777 1.1876657 ]"
@@ -764,10 +768,6 @@
764
  "[0.98033333 0.97656083 0.38939917 0.81491041]",
765
  "[1.98033333 1.97656083 1.38939917 1.81491041]"
766
  ],
767
- [
768
- "[0.74064726 0.4155122 0.09800029 0.49930882]",
769
- "[1.74064732 1.4155122 1.09800029 1.49930882]"
770
- ],
771
  [
772
  "[0.78956431 0.87284744 0.06880784 0.03455889]",
773
  "[1.78956437 1.87284744 1.06880784 1.03455889]"
@@ -804,10 +804,6 @@
804
  "[0.73217702 0.65233225 0.44077861 0.33837909]",
805
  "[1.73217702 1.65233231 1.44077861 1.33837914]"
806
  ],
807
- [
808
- "[0.34084332 0.73018837 0.54168713 0.91440833]",
809
- "[1.34084332 1.73018837 1.54168713 1.91440833]"
810
- ],
811
  [
812
  "[0.60110539 0.3618983 0.32342511 0.98672163]",
813
  "[1.60110545 1.3618983 1.32342505 1.98672163]"
@@ -816,6 +812,10 @@
816
  "[0.77427191 0.21829212 0.12769502 0.74303615]",
817
  "[1.77427197 1.21829212 1.12769508 1.74303615]"
818
  ],
 
 
 
 
819
  [
820
  "[0.59812403 0.78395379 0.0291847 0.81814629]",
821
  "[1.59812403 1.78395379 1.0291847 1.81814623]"
@@ -856,10 +856,6 @@
856
  "[0.54914117 0.03810108 0.87531954 0.73044223]",
857
  "[1.54914117 1.03810108 1.87531948 1.73044229]"
858
  ],
859
- [
860
- "[0.67418337 0.79634351 0.23229051 0.71345252]",
861
- "[1.67418337 1.79634356 1.23229051 1.71345258]"
862
- ],
863
  [
864
  "[0.87285906 0.48354989 0.39394957 0.59456545]",
865
  "[1.872859 1.48354983 1.39394951 1.59456539]"
@@ -908,10 +904,6 @@
908
  "[0.60609657 0.96257663 0.19292736 0.95702219]",
909
  "[1.60609651 1.96257663 1.19292736 1.95702219]"
910
  ],
911
- [
912
- "[0.80654246 0.08253473 0.74478531 0.71257162]",
913
- "[1.8065424 1.08253479 1.74478531 1.71257162]"
914
- ],
915
  [
916
  "[0.70167565 0.26930219 0.5660674 0.61194974]",
917
  "[1.70167565 1.26930213 1.56606746 1.61194968]"
@@ -924,6 +916,10 @@
924
  "[0.59492421 0.90274489 0.38069052 0.46101224]",
925
  "[1.59492421 1.90274489 1.38069057 1.46101224]"
926
  ],
 
 
 
 
927
  [
928
  "[0.12024075 0.21342516 0.56858408 0.58644271]",
929
  "[1.12024069 1.21342516 1.56858408 1.58644271]"
@@ -933,8 +929,8 @@
933
  "[1.91730917 1.22574067 1.09591603 1.33056474]"
934
  ],
935
  [
936
- "[0.63235509 0.70352674 0.96188956 0.46240485]",
937
- "[1.63235509 1.70352674 1.96188951 1.46240485]"
938
  ],
939
  [
940
  "[0.37959969 0.42820001 0.10690689 0.96353984]",
@@ -988,6 +984,10 @@
988
  "[0.47870928 0.17129105 0.27300501 0.20634609]",
989
  "[1.47870922 1.17129111 1.27300501 1.20634604]"
990
  ],
 
 
 
 
991
  [
992
  "[0.87608397 0.93200487 0.80169648 0.37758952]",
993
  "[1.87608397 1.93200493 1.80169654 1.37758946]"
@@ -1000,7 +1000,7 @@
1000
  }
1001
  },
1002
  "other": {
1003
- "model": "ModelConfig(model=Sequential(\n (0) - Linear(in_features=4, out_features=4, bias=True): Input__embedding_1_x -> Linear_1_x\n (1) - <function leaky_relu at 0x719e0ce23a60>: Linear_1_x -> Activation_2_x\n (2) - Identity(): Activation_2_x -> Activation_2_x\n), model_inputs=['Input__embedding_1_x'], model_outputs=['Activation_2_x'], loss_inputs=['Input__label_1_y', 'Activation_2_x'], loss=Sequential(\n (0) - <function mse_loss at 0x719e0ce2d580>: Activation_2_x, Input__label_1_y -> MSE_loss_1_loss\n (1) - Identity(): MSE_loss_1_loss -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
  },
1005
  "relations": []
1006
  },
@@ -1036,10 +1036,10 @@
1036
  ],
1037
  "loss_inputs": [
1038
  "Input__label_1_y",
1039
- "Activation_2_x"
1040
  ],
1041
  "outputs": [
1042
- "Activation_2_x"
1043
  ],
1044
  "trained": true
1045
  },
@@ -1211,10 +1211,10 @@
1211
  ],
1212
  "loss_inputs": [
1213
  "Input__label_1_y",
1214
- "Activation_2_x"
1215
  ],
1216
  "outputs": [
1217
- "Activation_2_x"
1218
  ],
1219
  "trained": false
1220
  },
@@ -1270,7 +1270,7 @@
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
- "epochs": "1001",
1274
  "input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"column\":\"x\",\"df\":\"df_train\"},\"Input__label_1_y\":{\"column\":\"y\",\"df\":\"df_train\"}}}",
1275
  "model_name": "model"
1276
  },
@@ -1304,7 +1304,6 @@
1304
  },
1305
  "df_test": {
1306
  "columns": [
1307
- "predicted",
1308
  "x",
1309
  "y"
1310
  ]
@@ -1324,10 +1323,10 @@
1324
  ],
1325
  "loss_inputs": [
1326
  "Input__label_1_y",
1327
- "Activation_2_x"
1328
  ],
1329
  "outputs": [
1330
- "Activation_2_x"
1331
  ],
1332
  "trained": true
1333
  },
@@ -1383,9 +1382,9 @@
1383
  "type": "basic"
1384
  },
1385
  "params": {
1386
- "input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"column\":\"x\",\"df\":\"df_test\"}}}",
1387
  "model_name": "model",
1388
- "output_mapping": "{\"map\":{\"Activation_2_x\":{\"column\":\"predicted\",\"df\":\"df_test\"}}}"
1389
  },
1390
  "status": "done",
1391
  "title": "Model inference"
 
579
  ],
580
  "data": [
581
  [
582
+ "[0.67418337 0.79634351 0.23229051 0.71345252]",
583
+ "[1.67418337 1.79634356 1.23229051 1.71345258]",
584
+ "[1.2304607629776, 1.2535771131515503, 1.4764351844787598, 1.524468183517456]"
 
 
 
 
 
585
  ],
586
  [
587
+ "[0.6032477 0.83361369 0.18538666 0.19108021]",
588
+ "[1.60324764 1.83361363 1.18538666 1.19108021]",
589
+ "[1.0745662450790405, 1.3498694896697998, 1.524811029434204, 1.2336101531982422]"
590
  ],
591
  [
592
+ "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
593
+ "[1.62569475 1.9881897 1.83639622 1.98288584]",
594
+ "[1.3193368911743164, 1.691685438156128, 1.4516379833221436, 1.6486209630966187]"
595
  ],
596
  [
597
+ "[0.63235509 0.70352674 0.96188956 0.46240485]",
598
+ "[1.63235509 1.70352674 1.96188951 1.46240485]",
599
+ "[1.298614740371704, 1.7463937997817993, 1.4211854934692383, 1.2179659605026245]"
600
  ],
601
  [
602
+ "[0.74064726 0.4155122 0.09800029 0.49930882]",
603
+ "[1.74064732 1.4155122 1.09800029 1.49930882]",
604
+ "[1.2740169763565063, 1.0242725610733032, 1.413628339767456, 1.288135290145874]"
605
  ],
606
  [
607
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
608
+ "[1.34084332 1.73018837 1.54168713 1.91440833]",
609
+ "[1.0640922784805298, 1.3728828430175781, 1.1964272260665894, 1.5791122913360596]"
610
  ],
611
  [
612
+ "[0.80654246 0.08253473 0.74478531 0.71257162]",
613
+ "[1.8065424 1.08253479 1.74478531 1.71257162]",
614
+ "[1.521227240562439, 1.247929334640503, 1.2715831995010376, 1.1901893615722656]"
615
  ],
616
  [
617
+ "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
618
+ "[1.11560345 1.57495475 1.76535821 1.0391947 ]",
619
+ "[0.8016152381896973, 1.6247310638427734, 1.1153241395950317, 0.9691400527954102]"
620
  ],
621
  [
622
  "[0.68094063 0.45189077 0.22661722 0.37354094]",
623
  "[1.68094063 1.45189071 1.22661722 1.37354088]",
624
+ "[1.224153995513916, 1.1527836322784424, 1.4024348258972168, 1.204542636871338]"
625
+ ],
626
+ [
627
+ "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
628
+ "[1.79121017 1.54161119 1.69369793 1.15207696]",
629
+ "[1.3464093208312988, 1.5594894886016846, 1.5191831588745117, 1.0183898210525513]"
630
  ]
631
  ]
632
  },
 
644
  "[0.85706753 0.61447072 0.41741937 0.85147089]",
645
  "[1.85706758 1.61447072 1.41741943 1.85147095]"
646
  ],
 
 
 
 
647
  [
648
  "[0.19409031 0.68692201 0.60667384 0.57829887]",
649
  "[1.19409037 1.68692207 1.60667384 1.57829881]"
 
684
  "[0.11693293 0.49860179 0.55020827 0.88832849]",
685
  "[1.11693287 1.49860179 1.55020833 1.88832855]"
686
  ],
687
+ [
688
+ "[0.48959708 0.48549271 0.32688856 0.356677 ]",
689
+ "[1.48959708 1.48549271 1.32688856 1.35667706]"
690
+ ],
691
  [
692
  "[0.50272274 0.54912758 0.17663097 0.79070699]",
693
  "[1.50272274 1.54912758 1.17663097 1.79070699]"
694
  ],
695
+ [
696
+ "[0.04508126 0.76880038 0.80721325 0.62542385]",
697
+ "[1.04508126 1.76880038 1.80721331 1.62542391]"
698
+ ],
699
  [
700
  "[0.19908059 0.17570406 0.51475513 0.1893943 ]",
701
  "[1.19908059 1.175704 1.51475513 1.18939424]"
 
713
  "[1.24388778 1.07268476 1.68350863 1.73431659]"
714
  ],
715
  [
716
+ "[0.56922203 0.98222166 0.76851749 0.28615737]",
717
+ "[1.56922197 1.9822216 1.76851749 1.28615737]"
718
  ],
719
  [
720
  "[0.88776821 0.51636773 0.30333066 0.32230979]",
721
  "[1.88776827 1.51636767 1.30333066 1.32230973]"
722
  ],
723
+ [
724
+ "[0.90817457 0.89270043 0.38583666 0.66566533]",
725
+ "[1.90817451 1.89270043 1.3858366 1.66566539]"
726
+ ],
727
  [
728
  "[0.48507756 0.80808765 0.77162558 0.47834778]",
729
  "[1.48507762 1.80808759 1.77162552 1.47834778]"
 
736
  "[0.31518555 0.49643308 0.11509258 0.95458382]",
737
  "[1.31518555 1.49643302 1.11509252 1.95458388]"
738
  ],
 
 
 
 
739
  [
740
  "[0.79423058 0.07138705 0.061777 0.18766576]",
741
  "[1.79423058 1.07138705 1.061777 1.1876657 ]"
 
768
  "[0.98033333 0.97656083 0.38939917 0.81491041]",
769
  "[1.98033333 1.97656083 1.38939917 1.81491041]"
770
  ],
 
 
 
 
771
  [
772
  "[0.78956431 0.87284744 0.06880784 0.03455889]",
773
  "[1.78956437 1.87284744 1.06880784 1.03455889]"
 
804
  "[0.73217702 0.65233225 0.44077861 0.33837909]",
805
  "[1.73217702 1.65233231 1.44077861 1.33837914]"
806
  ],
 
 
 
 
807
  [
808
  "[0.60110539 0.3618983 0.32342511 0.98672163]",
809
  "[1.60110545 1.3618983 1.32342505 1.98672163]"
 
812
  "[0.77427191 0.21829212 0.12769502 0.74303615]",
813
  "[1.77427197 1.21829212 1.12769508 1.74303615]"
814
  ],
815
+ [
816
+ "[0.08107251 0.2602725 0.18861133 0.44833237]",
817
+ "[1.08107257 1.2602725 1.18861127 1.44833231]"
818
+ ],
819
  [
820
  "[0.59812403 0.78395379 0.0291847 0.81814629]",
821
  "[1.59812403 1.78395379 1.0291847 1.81814623]"
 
856
  "[0.54914117 0.03810108 0.87531954 0.73044223]",
857
  "[1.54914117 1.03810108 1.87531948 1.73044229]"
858
  ],
 
 
 
 
859
  [
860
  "[0.87285906 0.48354989 0.39394957 0.59456545]",
861
  "[1.872859 1.48354983 1.39394951 1.59456539]"
 
904
  "[0.60609657 0.96257663 0.19292736 0.95702219]",
905
  "[1.60609651 1.96257663 1.19292736 1.95702219]"
906
  ],
 
 
 
 
907
  [
908
  "[0.70167565 0.26930219 0.5660674 0.61194974]",
909
  "[1.70167565 1.26930213 1.56606746 1.61194968]"
 
916
  "[0.59492421 0.90274489 0.38069052 0.46101224]",
917
  "[1.59492421 1.90274489 1.38069057 1.46101224]"
918
  ],
919
+ [
920
+ "[0.15064228 0.03198934 0.25754827 0.51484001]",
921
+ "[1.15064228 1.03198934 1.25754833 1.51484001]"
922
+ ],
923
  [
924
  "[0.12024075 0.21342516 0.56858408 0.58644271]",
925
  "[1.12024069 1.21342516 1.56858408 1.58644271]"
 
929
  "[1.91730917 1.22574067 1.09591603 1.33056474]"
930
  ],
931
  [
932
+ "[0.49691743 0.61873293 0.90698647 0.94486356]",
933
+ "[1.49691749 1.61873293 1.90698647 1.94486356]"
934
  ],
935
  [
936
  "[0.37959969 0.42820001 0.10690689 0.96353984]",
 
984
  "[0.47870928 0.17129105 0.27300501 0.20634609]",
985
  "[1.47870922 1.17129111 1.27300501 1.20634604]"
986
  ],
987
+ [
988
+ "[0.72795159 0.79317838 0.27832931 0.96576637]",
989
+ "[1.72795153 1.79317832 1.27832937 1.96576643]"
990
+ ],
991
  [
992
  "[0.87608397 0.93200487 0.80169648 0.37758952]",
993
  "[1.87608397 1.93200493 1.80169654 1.37758946]"
 
1000
  }
1001
  },
1002
  "other": {
1003
+ "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__embedding_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_1_x\n (2) - <function leaky_relu at 0x7e6e1cf2c220>: Linear_1_x -> Activation_2_x\n (3) - Identity(): Activation_2_x -> END_Repeat_1_output\n (4) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__embedding_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['Input__label_1_y', 'END_Repeat_1_output'], loss=Sequential(\n (0) - <function mse_loss at 0x7e6e1cf2dd00>: END_Repeat_1_output, Input__label_1_y -> MSE_loss_1_loss\n (1) - Identity(): MSE_loss_1_loss -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
  },
1005
  "relations": []
1006
  },
 
1036
  ],
1037
  "loss_inputs": [
1038
  "Input__label_1_y",
1039
+ "END_Repeat_1_output"
1040
  ],
1041
  "outputs": [
1042
+ "END_Repeat_1_output"
1043
  ],
1044
  "trained": true
1045
  },
 
1211
  ],
1212
  "loss_inputs": [
1213
  "Input__label_1_y",
1214
+ "END_Repeat_1_output"
1215
  ],
1216
  "outputs": [
1217
+ "END_Repeat_1_output"
1218
  ],
1219
  "trained": false
1220
  },
 
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
+ "epochs": "15",
1274
  "input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"column\":\"x\",\"df\":\"df_train\"},\"Input__label_1_y\":{\"column\":\"y\",\"df\":\"df_train\"}}}",
1275
  "model_name": "model"
1276
  },
 
1304
  },
1305
  "df_test": {
1306
  "columns": [
 
1307
  "x",
1308
  "y"
1309
  ]
 
1323
  ],
1324
  "loss_inputs": [
1325
  "Input__label_1_y",
1326
+ "END_Repeat_1_output"
1327
  ],
1328
  "outputs": [
1329
+ "END_Repeat_1_output"
1330
  ],
1331
  "trained": true
1332
  },
 
1382
  "type": "basic"
1383
  },
1384
  "params": {
1385
+ "input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"df\":\"df_test\",\"column\":\"x\"}}}",
1386
  "model_name": "model",
1387
+ "output_mapping": "{\"map\":{\"END_Repeat_1_output\":{\"df\":\"df_test\",\"column\":\"predicted\"}}}"
1388
  },
1389
  "status": "done",
1390
  "title": "Model inference"
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -110,7 +110,7 @@ ops.register_passive_op(
110
  outputs=[ops.Output(name="output", position="bottom", type="tensor")],
111
  params=[
112
  ops.Parameter.basic("times", 1, int),
113
- ops.Parameter.basic("same_weights", True, bool),
114
  ],
115
  )
116
 
@@ -207,8 +207,10 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
207
  dependencies = {n.id: [] for n in ws.nodes}
208
  in_edges = {}
209
  out_edges = {}
210
- # TODO: Dissolve repeat boxes here.
211
  for e in ws.edges:
 
 
212
  dependencies[e.target].append(e.source)
213
  in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
214
  (e.source, e.sourceHandle)
@@ -216,34 +218,78 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
216
  out_edges.setdefault(e.source, {}).setdefault(e.sourceHandle, []).append(
217
  (e.target, e.targetHandle)
218
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  sizes = {}
220
  for k, i in inputs.items():
221
  sizes[k] = i.shape[-1]
222
  ts = graphlib.TopologicalSorter(dependencies)
223
  layers = []
224
  loss_layers = []
225
- in_loss = set()
226
  cfg = {}
227
  used_in_model = set()
228
  made_in_model = set()
229
  used_in_loss = set()
230
  made_in_loss = set()
231
  for node_id in ts.static_order():
232
- node = nodes[node_id]
 
 
 
 
 
233
  t = node.data.title
234
  op = catalog[t]
235
  p = op.convert_params(node.data.params)
236
  for b in dependencies[node_id]:
237
- if b in in_loss:
238
- in_loss.add(node_id)
239
  if "loss" in t:
240
- in_loss.add(node_id)
241
  inputs = {}
242
  for n in in_edges.get(node_id, []):
243
  for b, h in in_edges[node_id][n]:
244
  i = _to_id(b, h)
245
  inputs[n] = i
246
- if node_id in in_loss:
247
  used_in_loss.add(i)
248
  else:
249
  used_in_model.add(i)
@@ -252,13 +298,13 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
252
  i = _to_id(node_id, out)
253
  outputs[out] = i
254
  if inputs: # Nodes with no inputs are input nodes. Their outputs are not "made" by us.
255
- if node_id in in_loss:
256
  made_in_loss.add(i)
257
  else:
258
  made_in_model.add(i)
259
  inputs = types.SimpleNamespace(**inputs)
260
  outputs = types.SimpleNamespace(**outputs)
261
- ls = loss_layers if node_id in in_loss else layers
262
  match t:
263
  case "Linear":
264
  isize = sizes.get(inputs.x, 1)
@@ -276,6 +322,19 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
276
  f"{inputs.x}, {inputs.y} -> {outputs.loss}",
277
  )
278
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  cfg["model_inputs"] = list(used_in_model - made_in_model)
280
  cfg["model_outputs"] = list(made_in_model & used_in_loss)
281
  cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
 
110
  outputs=[ops.Output(name="output", position="bottom", type="tensor")],
111
  params=[
112
  ops.Parameter.basic("times", 1, int),
113
+ ops.Parameter.basic("same_weights", False, bool),
114
  ],
115
  )
116
 
 
207
  dependencies = {n.id: [] for n in ws.nodes}
208
  in_edges = {}
209
  out_edges = {}
210
+ repeats = []
211
  for e in ws.edges:
212
+ if nodes[e.target].data.title == "Repeat":
213
+ repeats.append(e.target)
214
  dependencies[e.target].append(e.source)
215
  in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
216
  (e.source, e.sourceHandle)
 
218
  out_edges.setdefault(e.source, {}).setdefault(e.sourceHandle, []).append(
219
  (e.target, e.targetHandle)
220
  )
221
+ # Split repeat boxes into start and end, and insert them into the flow.
222
+ # TODO: Think about recursive repeats.
223
+ for repeat in repeats:
224
+ start_id = f"START {repeat}"
225
+ end_id = f"END {repeat}"
226
+ # repeat -> first <- real_input
227
+ # ...becomes...
228
+ # real_input -> start -> first
229
+ first, firsth = out_edges[repeat]["output"][0]
230
+ [(real_input, real_inputh)] = [
231
+ k for k in in_edges[first][firsth] if k != (repeat, "output")
232
+ ]
233
+ dependencies[first].remove(repeat)
234
+ dependencies[first].append(start_id)
235
+ dependencies[start_id] = [real_input]
236
+ out_edges[real_input][real_inputh] = [
237
+ k if k != (first, firsth) else (start_id, "input")
238
+ for k in out_edges[real_input][real_inputh]
239
+ ]
240
+ in_edges[start_id] = {"input": [(real_input, real_inputh)]}
241
+ out_edges[start_id] = {"output": [(first, firsth)]}
242
+ in_edges[first][firsth] = [(start_id, "output")]
243
+ # repeat <- last -> real_output
244
+ # ...becomes...
245
+ # last -> end -> real_output
246
+ last, lasth = in_edges[repeat]["input"][0]
247
+ [(real_output, real_outputh)] = [
248
+ k for k in out_edges[last][lasth] if k != (repeat, "input")
249
+ ]
250
+ del dependencies[repeat]
251
+ dependencies[end_id] = [last]
252
+ dependencies[real_output].append(end_id)
253
+ out_edges[last][lasth] = [(end_id, "input")]
254
+ in_edges[end_id] = {"input": [(last, lasth)]}
255
+ out_edges[end_id] = {"output": [(real_output, real_outputh)]}
256
+ in_edges[real_output][real_outputh] = [
257
+ k if k != (last, lasth) else (end_id, "output")
258
+ for k in in_edges[real_output][real_outputh]
259
+ ]
260
+ # Walk the graph in topological order.
261
  sizes = {}
262
  for k, i in inputs.items():
263
  sizes[k] = i.shape[-1]
264
  ts = graphlib.TopologicalSorter(dependencies)
265
  layers = []
266
  loss_layers = []
267
+ regions: dict[str, set[str]] = {node_id: set() for node_id in dependencies}
268
  cfg = {}
269
  used_in_model = set()
270
  made_in_model = set()
271
  used_in_loss = set()
272
  made_in_loss = set()
273
  for node_id in ts.static_order():
274
+ if node_id.startswith("START "):
275
+ node = nodes[node_id.removeprefix("START ")]
276
+ elif node_id.startswith("END "):
277
+ node = nodes[node_id.removeprefix("END ")]
278
+ else:
279
+ node = nodes[node_id]
280
  t = node.data.title
281
  op = catalog[t]
282
  p = op.convert_params(node.data.params)
283
  for b in dependencies[node_id]:
284
+ regions[node_id] |= regions[b]
 
285
  if "loss" in t:
286
+ regions[node_id].add("loss")
287
  inputs = {}
288
  for n in in_edges.get(node_id, []):
289
  for b, h in in_edges[node_id][n]:
290
  i = _to_id(b, h)
291
  inputs[n] = i
292
+ if "loss" in regions[node_id]:
293
  used_in_loss.add(i)
294
  else:
295
  used_in_model.add(i)
 
298
  i = _to_id(node_id, out)
299
  outputs[out] = i
300
  if inputs: # Nodes with no inputs are input nodes. Their outputs are not "made" by us.
301
+ if "loss" in regions[node_id]:
302
  made_in_loss.add(i)
303
  else:
304
  made_in_model.add(i)
305
  inputs = types.SimpleNamespace(**inputs)
306
  outputs = types.SimpleNamespace(**outputs)
307
+ ls = loss_layers if "loss" in regions[node_id] else layers
308
  match t:
309
  case "Linear":
310
  isize = sizes.get(inputs.x, 1)
 
322
  f"{inputs.x}, {inputs.y} -> {outputs.loss}",
323
  )
324
  )
325
+ case "Repeat":
326
+ ls.append((torch.nn.Identity(), f"{inputs.input} -> {outputs.output}"))
327
+ sizes[outputs.output] = sizes.get(inputs.input, 1)
328
+ if node_id.startswith("START "):
329
+ regions[node_id].add(("repeat", node_id.removeprefix("START ")))
330
+ else:
331
+ repeat_id = node_id.removeprefix("END ")
332
+ print(f"repeat {repeat_id} ending")
333
+ regions[node_id].remove(("repeat", repeat_id))
334
+ for n in nodes:
335
+ r = regions.get(n, set())
336
+ if ("repeat", repeat_id) in r:
337
+ print(f"repeating {n}")
338
  cfg["model_inputs"] = list(used_in_model - made_in_model)
339
  cfg["model_outputs"] = list(made_in_model & used_in_loss)
340
  cfg["loss_inputs"] = list(used_in_loss - made_in_loss)