Spaces:
Running
Running
Split repeat boxes into start/end boxes.
Browse files
examples/Model definition
CHANGED
@@ -34,6 +34,20 @@
|
|
34 |
"sourceHandle": "loss",
|
35 |
"target": "Optimizer 2",
|
36 |
"targetHandle": "loss"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
}
|
38 |
],
|
39 |
"env": "PyTorch model",
|
@@ -42,6 +56,7 @@
|
|
42 |
"data": {
|
43 |
"display": null,
|
44 |
"error": null,
|
|
|
45 |
"meta": {
|
46 |
"inputs": {},
|
47 |
"name": "Input: embedding",
|
@@ -75,6 +90,7 @@
|
|
75 |
"data": {
|
76 |
"display": null,
|
77 |
"error": null,
|
|
|
78 |
"meta": {
|
79 |
"inputs": {
|
80 |
"x": {
|
@@ -126,6 +142,7 @@
|
|
126 |
"data": {
|
127 |
"display": null,
|
128 |
"error": null,
|
|
|
129 |
"meta": {
|
130 |
"inputs": {
|
131 |
"x": {
|
@@ -174,6 +191,7 @@
|
|
174 |
"data": {
|
175 |
"display": null,
|
176 |
"error": null,
|
|
|
177 |
"meta": {
|
178 |
"inputs": {},
|
179 |
"name": "Input: label",
|
@@ -209,6 +227,7 @@
|
|
209 |
"collapsed": null,
|
210 |
"display": null,
|
211 |
"error": null,
|
|
|
212 |
"meta": {
|
213 |
"inputs": {
|
214 |
"x": {
|
@@ -243,10 +262,6 @@
|
|
243 |
}
|
244 |
}
|
245 |
},
|
246 |
-
"position": {
|
247 |
-
"x": 419.0,
|
248 |
-
"y": 396.0
|
249 |
-
},
|
250 |
"type": "basic"
|
251 |
},
|
252 |
"params": {
|
@@ -271,6 +286,7 @@
|
|
271 |
"collapsed": null,
|
272 |
"display": null,
|
273 |
"error": null,
|
|
|
274 |
"meta": {
|
275 |
"inputs": {
|
276 |
"loss": {
|
@@ -307,10 +323,6 @@
|
|
307 |
}
|
308 |
}
|
309 |
},
|
310 |
-
"position": {
|
311 |
-
"x": 526.0,
|
312 |
-
"y": 116.0
|
313 |
-
},
|
314 |
"type": "basic"
|
315 |
},
|
316 |
"params": {
|
@@ -329,6 +341,72 @@
|
|
329 |
},
|
330 |
"type": "basic",
|
331 |
"width": 200.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
}
|
333 |
]
|
334 |
}
|
|
|
34 |
"sourceHandle": "loss",
|
35 |
"target": "Optimizer 2",
|
36 |
"targetHandle": "loss"
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"id": "Activation 2 Repeat 1",
|
40 |
+
"source": "Activation 2",
|
41 |
+
"sourceHandle": "x",
|
42 |
+
"target": "Repeat 1",
|
43 |
+
"targetHandle": "input"
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"id": "Repeat 1 Linear 1",
|
47 |
+
"source": "Repeat 1",
|
48 |
+
"sourceHandle": "output",
|
49 |
+
"target": "Linear 1",
|
50 |
+
"targetHandle": "x"
|
51 |
}
|
52 |
],
|
53 |
"env": "PyTorch model",
|
|
|
56 |
"data": {
|
57 |
"display": null,
|
58 |
"error": null,
|
59 |
+
"input_metadata": null,
|
60 |
"meta": {
|
61 |
"inputs": {},
|
62 |
"name": "Input: embedding",
|
|
|
90 |
"data": {
|
91 |
"display": null,
|
92 |
"error": null,
|
93 |
+
"input_metadata": null,
|
94 |
"meta": {
|
95 |
"inputs": {
|
96 |
"x": {
|
|
|
142 |
"data": {
|
143 |
"display": null,
|
144 |
"error": null,
|
145 |
+
"input_metadata": null,
|
146 |
"meta": {
|
147 |
"inputs": {
|
148 |
"x": {
|
|
|
191 |
"data": {
|
192 |
"display": null,
|
193 |
"error": null,
|
194 |
+
"input_metadata": null,
|
195 |
"meta": {
|
196 |
"inputs": {},
|
197 |
"name": "Input: label",
|
|
|
227 |
"collapsed": null,
|
228 |
"display": null,
|
229 |
"error": null,
|
230 |
+
"input_metadata": null,
|
231 |
"meta": {
|
232 |
"inputs": {
|
233 |
"x": {
|
|
|
262 |
}
|
263 |
}
|
264 |
},
|
|
|
|
|
|
|
|
|
265 |
"type": "basic"
|
266 |
},
|
267 |
"params": {
|
|
|
286 |
"collapsed": null,
|
287 |
"display": null,
|
288 |
"error": null,
|
289 |
+
"input_metadata": null,
|
290 |
"meta": {
|
291 |
"inputs": {
|
292 |
"loss": {
|
|
|
323 |
}
|
324 |
}
|
325 |
},
|
|
|
|
|
|
|
|
|
326 |
"type": "basic"
|
327 |
},
|
328 |
"params": {
|
|
|
341 |
},
|
342 |
"type": "basic",
|
343 |
"width": 200.0
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"data": {
|
347 |
+
"__execution_delay": 0.0,
|
348 |
+
"collapsed": null,
|
349 |
+
"display": null,
|
350 |
+
"error": null,
|
351 |
+
"input_metadata": null,
|
352 |
+
"meta": {
|
353 |
+
"inputs": {
|
354 |
+
"input": {
|
355 |
+
"name": "input",
|
356 |
+
"position": "top",
|
357 |
+
"type": {
|
358 |
+
"type": "tensor"
|
359 |
+
}
|
360 |
+
}
|
361 |
+
},
|
362 |
+
"name": "Repeat",
|
363 |
+
"outputs": {
|
364 |
+
"output": {
|
365 |
+
"name": "output",
|
366 |
+
"position": "bottom",
|
367 |
+
"type": {
|
368 |
+
"type": "tensor"
|
369 |
+
}
|
370 |
+
}
|
371 |
+
},
|
372 |
+
"params": {
|
373 |
+
"same_weights": {
|
374 |
+
"default": true,
|
375 |
+
"name": "same_weights",
|
376 |
+
"type": {
|
377 |
+
"type": "<class 'bool'>"
|
378 |
+
}
|
379 |
+
},
|
380 |
+
"times": {
|
381 |
+
"default": 1.0,
|
382 |
+
"name": "times",
|
383 |
+
"type": {
|
384 |
+
"type": "<class 'int'>"
|
385 |
+
}
|
386 |
+
}
|
387 |
+
},
|
388 |
+
"position": {
|
389 |
+
"x": 386.0,
|
390 |
+
"y": 456.0
|
391 |
+
},
|
392 |
+
"type": "basic"
|
393 |
+
},
|
394 |
+
"params": {
|
395 |
+
"same_weights": false,
|
396 |
+
"times": "2"
|
397 |
+
},
|
398 |
+
"status": "planned",
|
399 |
+
"title": "Repeat"
|
400 |
+
},
|
401 |
+
"dragHandle": ".bg-primary",
|
402 |
+
"height": 200.0,
|
403 |
+
"id": "Repeat 1",
|
404 |
+
"position": {
|
405 |
+
"x": -180.0,
|
406 |
+
"y": -90.0
|
407 |
+
},
|
408 |
+
"type": "basic",
|
409 |
+
"width": 200.0
|
410 |
}
|
411 |
]
|
412 |
}
|
examples/Model use
CHANGED
@@ -579,54 +579,54 @@
|
|
579 |
],
|
580 |
"data": [
|
581 |
[
|
582 |
-
"[0.
|
583 |
-
"[1.
|
584 |
-
"[1.
|
585 |
-
],
|
586 |
-
[
|
587 |
-
"[0.56922203 0.98222166 0.76851749 0.28615737]",
|
588 |
-
"[1.56922197 1.9822216 1.76851749 1.28615737]",
|
589 |
-
"[1.5835213661193848, 1.9884355068206787, 1.7694181203842163, 1.2917503118515015]"
|
590 |
],
|
591 |
[
|
592 |
-
"[0.
|
593 |
-
"[1.
|
594 |
-
"[1.
|
595 |
],
|
596 |
[
|
597 |
-
"[0.
|
598 |
-
"[1.
|
599 |
-
"[1.
|
600 |
],
|
601 |
[
|
602 |
-
"[0.
|
603 |
-
"[1.
|
604 |
-
"[1.
|
605 |
],
|
606 |
[
|
607 |
-
"[0.
|
608 |
-
"[1.
|
609 |
-
"[1.
|
610 |
],
|
611 |
[
|
612 |
-
"[0.
|
613 |
-
"[1.
|
614 |
-
"[1.
|
615 |
],
|
616 |
[
|
617 |
-
"[0.
|
618 |
-
"[1.
|
619 |
-
"[1.
|
620 |
],
|
621 |
[
|
622 |
-
"[0.
|
623 |
-
"[1.
|
624 |
-
"[
|
625 |
],
|
626 |
[
|
627 |
"[0.68094063 0.45189077 0.22661722 0.37354094]",
|
628 |
"[1.68094063 1.45189071 1.22661722 1.37354088]",
|
629 |
-
"[1.
|
|
|
|
|
|
|
|
|
|
|
630 |
]
|
631 |
]
|
632 |
},
|
@@ -644,10 +644,6 @@
|
|
644 |
"[0.85706753 0.61447072 0.41741937 0.85147089]",
|
645 |
"[1.85706758 1.61447072 1.41741943 1.85147095]"
|
646 |
],
|
647 |
-
[
|
648 |
-
"[0.11560339 0.57495481 0.76535827 0.0391947 ]",
|
649 |
-
"[1.11560345 1.57495475 1.76535821 1.0391947 ]"
|
650 |
-
],
|
651 |
[
|
652 |
"[0.19409031 0.68692201 0.60667384 0.57829887]",
|
653 |
"[1.19409037 1.68692207 1.60667384 1.57829881]"
|
@@ -688,10 +684,18 @@
|
|
688 |
"[0.11693293 0.49860179 0.55020827 0.88832849]",
|
689 |
"[1.11693287 1.49860179 1.55020833 1.88832855]"
|
690 |
],
|
|
|
|
|
|
|
|
|
691 |
[
|
692 |
"[0.50272274 0.54912758 0.17663097 0.79070699]",
|
693 |
"[1.50272274 1.54912758 1.17663097 1.79070699]"
|
694 |
],
|
|
|
|
|
|
|
|
|
695 |
[
|
696 |
"[0.19908059 0.17570406 0.51475513 0.1893943 ]",
|
697 |
"[1.19908059 1.175704 1.51475513 1.18939424]"
|
@@ -709,13 +713,17 @@
|
|
709 |
"[1.24388778 1.07268476 1.68350863 1.73431659]"
|
710 |
],
|
711 |
[
|
712 |
-
"[0.
|
713 |
-
"[1.
|
714 |
],
|
715 |
[
|
716 |
"[0.88776821 0.51636773 0.30333066 0.32230979]",
|
717 |
"[1.88776827 1.51636767 1.30333066 1.32230973]"
|
718 |
],
|
|
|
|
|
|
|
|
|
719 |
[
|
720 |
"[0.48507756 0.80808765 0.77162558 0.47834778]",
|
721 |
"[1.48507762 1.80808759 1.77162552 1.47834778]"
|
@@ -728,10 +736,6 @@
|
|
728 |
"[0.31518555 0.49643308 0.11509258 0.95458382]",
|
729 |
"[1.31518555 1.49643302 1.11509252 1.95458388]"
|
730 |
],
|
731 |
-
[
|
732 |
-
"[0.79121011 0.54161114 0.69369799 0.1520769 ]",
|
733 |
-
"[1.79121017 1.54161119 1.69369793 1.15207696]"
|
734 |
-
],
|
735 |
[
|
736 |
"[0.79423058 0.07138705 0.061777 0.18766576]",
|
737 |
"[1.79423058 1.07138705 1.061777 1.1876657 ]"
|
@@ -764,10 +768,6 @@
|
|
764 |
"[0.98033333 0.97656083 0.38939917 0.81491041]",
|
765 |
"[1.98033333 1.97656083 1.38939917 1.81491041]"
|
766 |
],
|
767 |
-
[
|
768 |
-
"[0.74064726 0.4155122 0.09800029 0.49930882]",
|
769 |
-
"[1.74064732 1.4155122 1.09800029 1.49930882]"
|
770 |
-
],
|
771 |
[
|
772 |
"[0.78956431 0.87284744 0.06880784 0.03455889]",
|
773 |
"[1.78956437 1.87284744 1.06880784 1.03455889]"
|
@@ -804,10 +804,6 @@
|
|
804 |
"[0.73217702 0.65233225 0.44077861 0.33837909]",
|
805 |
"[1.73217702 1.65233231 1.44077861 1.33837914]"
|
806 |
],
|
807 |
-
[
|
808 |
-
"[0.34084332 0.73018837 0.54168713 0.91440833]",
|
809 |
-
"[1.34084332 1.73018837 1.54168713 1.91440833]"
|
810 |
-
],
|
811 |
[
|
812 |
"[0.60110539 0.3618983 0.32342511 0.98672163]",
|
813 |
"[1.60110545 1.3618983 1.32342505 1.98672163]"
|
@@ -816,6 +812,10 @@
|
|
816 |
"[0.77427191 0.21829212 0.12769502 0.74303615]",
|
817 |
"[1.77427197 1.21829212 1.12769508 1.74303615]"
|
818 |
],
|
|
|
|
|
|
|
|
|
819 |
[
|
820 |
"[0.59812403 0.78395379 0.0291847 0.81814629]",
|
821 |
"[1.59812403 1.78395379 1.0291847 1.81814623]"
|
@@ -856,10 +856,6 @@
|
|
856 |
"[0.54914117 0.03810108 0.87531954 0.73044223]",
|
857 |
"[1.54914117 1.03810108 1.87531948 1.73044229]"
|
858 |
],
|
859 |
-
[
|
860 |
-
"[0.67418337 0.79634351 0.23229051 0.71345252]",
|
861 |
-
"[1.67418337 1.79634356 1.23229051 1.71345258]"
|
862 |
-
],
|
863 |
[
|
864 |
"[0.87285906 0.48354989 0.39394957 0.59456545]",
|
865 |
"[1.872859 1.48354983 1.39394951 1.59456539]"
|
@@ -908,10 +904,6 @@
|
|
908 |
"[0.60609657 0.96257663 0.19292736 0.95702219]",
|
909 |
"[1.60609651 1.96257663 1.19292736 1.95702219]"
|
910 |
],
|
911 |
-
[
|
912 |
-
"[0.80654246 0.08253473 0.74478531 0.71257162]",
|
913 |
-
"[1.8065424 1.08253479 1.74478531 1.71257162]"
|
914 |
-
],
|
915 |
[
|
916 |
"[0.70167565 0.26930219 0.5660674 0.61194974]",
|
917 |
"[1.70167565 1.26930213 1.56606746 1.61194968]"
|
@@ -924,6 +916,10 @@
|
|
924 |
"[0.59492421 0.90274489 0.38069052 0.46101224]",
|
925 |
"[1.59492421 1.90274489 1.38069057 1.46101224]"
|
926 |
],
|
|
|
|
|
|
|
|
|
927 |
[
|
928 |
"[0.12024075 0.21342516 0.56858408 0.58644271]",
|
929 |
"[1.12024069 1.21342516 1.56858408 1.58644271]"
|
@@ -933,8 +929,8 @@
|
|
933 |
"[1.91730917 1.22574067 1.09591603 1.33056474]"
|
934 |
],
|
935 |
[
|
936 |
-
"[0.
|
937 |
-
"[1.
|
938 |
],
|
939 |
[
|
940 |
"[0.37959969 0.42820001 0.10690689 0.96353984]",
|
@@ -988,6 +984,10 @@
|
|
988 |
"[0.47870928 0.17129105 0.27300501 0.20634609]",
|
989 |
"[1.47870922 1.17129111 1.27300501 1.20634604]"
|
990 |
],
|
|
|
|
|
|
|
|
|
991 |
[
|
992 |
"[0.87608397 0.93200487 0.80169648 0.37758952]",
|
993 |
"[1.87608397 1.93200493 1.80169654 1.37758946]"
|
@@ -1000,7 +1000,7 @@
|
|
1000 |
}
|
1001 |
},
|
1002 |
"other": {
|
1003 |
-
"model": "ModelConfig(model=Sequential(\n (0) - Linear(in_features=4, out_features=4, bias=True):
|
1004 |
},
|
1005 |
"relations": []
|
1006 |
},
|
@@ -1036,10 +1036,10 @@
|
|
1036 |
],
|
1037 |
"loss_inputs": [
|
1038 |
"Input__label_1_y",
|
1039 |
-
"
|
1040 |
],
|
1041 |
"outputs": [
|
1042 |
-
"
|
1043 |
],
|
1044 |
"trained": true
|
1045 |
},
|
@@ -1211,10 +1211,10 @@
|
|
1211 |
],
|
1212 |
"loss_inputs": [
|
1213 |
"Input__label_1_y",
|
1214 |
-
"
|
1215 |
],
|
1216 |
"outputs": [
|
1217 |
-
"
|
1218 |
],
|
1219 |
"trained": false
|
1220 |
},
|
@@ -1270,7 +1270,7 @@
|
|
1270 |
"type": "basic"
|
1271 |
},
|
1272 |
"params": {
|
1273 |
-
"epochs": "
|
1274 |
"input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"column\":\"x\",\"df\":\"df_train\"},\"Input__label_1_y\":{\"column\":\"y\",\"df\":\"df_train\"}}}",
|
1275 |
"model_name": "model"
|
1276 |
},
|
@@ -1304,7 +1304,6 @@
|
|
1304 |
},
|
1305 |
"df_test": {
|
1306 |
"columns": [
|
1307 |
-
"predicted",
|
1308 |
"x",
|
1309 |
"y"
|
1310 |
]
|
@@ -1324,10 +1323,10 @@
|
|
1324 |
],
|
1325 |
"loss_inputs": [
|
1326 |
"Input__label_1_y",
|
1327 |
-
"
|
1328 |
],
|
1329 |
"outputs": [
|
1330 |
-
"
|
1331 |
],
|
1332 |
"trained": true
|
1333 |
},
|
@@ -1383,9 +1382,9 @@
|
|
1383 |
"type": "basic"
|
1384 |
},
|
1385 |
"params": {
|
1386 |
-
"input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"
|
1387 |
"model_name": "model",
|
1388 |
-
"output_mapping": "{\"map\":{\"
|
1389 |
},
|
1390 |
"status": "done",
|
1391 |
"title": "Model inference"
|
|
|
579 |
],
|
580 |
"data": [
|
581 |
[
|
582 |
+
"[0.67418337 0.79634351 0.23229051 0.71345252]",
|
583 |
+
"[1.67418337 1.79634356 1.23229051 1.71345258]",
|
584 |
+
"[1.2304607629776, 1.2535771131515503, 1.4764351844787598, 1.524468183517456]"
|
|
|
|
|
|
|
|
|
|
|
585 |
],
|
586 |
[
|
587 |
+
"[0.6032477 0.83361369 0.18538666 0.19108021]",
|
588 |
+
"[1.60324764 1.83361363 1.18538666 1.19108021]",
|
589 |
+
"[1.0745662450790405, 1.3498694896697998, 1.524811029434204, 1.2336101531982422]"
|
590 |
],
|
591 |
[
|
592 |
+
"[0.62569475 0.9881897 0.83639616 0.9828859 ]",
|
593 |
+
"[1.62569475 1.9881897 1.83639622 1.98288584]",
|
594 |
+
"[1.3193368911743164, 1.691685438156128, 1.4516379833221436, 1.6486209630966187]"
|
595 |
],
|
596 |
[
|
597 |
+
"[0.63235509 0.70352674 0.96188956 0.46240485]",
|
598 |
+
"[1.63235509 1.70352674 1.96188951 1.46240485]",
|
599 |
+
"[1.298614740371704, 1.7463937997817993, 1.4211854934692383, 1.2179659605026245]"
|
600 |
],
|
601 |
[
|
602 |
+
"[0.74064726 0.4155122 0.09800029 0.49930882]",
|
603 |
+
"[1.74064732 1.4155122 1.09800029 1.49930882]",
|
604 |
+
"[1.2740169763565063, 1.0242725610733032, 1.413628339767456, 1.288135290145874]"
|
605 |
],
|
606 |
[
|
607 |
+
"[0.34084332 0.73018837 0.54168713 0.91440833]",
|
608 |
+
"[1.34084332 1.73018837 1.54168713 1.91440833]",
|
609 |
+
"[1.0640922784805298, 1.3728828430175781, 1.1964272260665894, 1.5791122913360596]"
|
610 |
],
|
611 |
[
|
612 |
+
"[0.80654246 0.08253473 0.74478531 0.71257162]",
|
613 |
+
"[1.8065424 1.08253479 1.74478531 1.71257162]",
|
614 |
+
"[1.521227240562439, 1.247929334640503, 1.2715831995010376, 1.1901893615722656]"
|
615 |
],
|
616 |
[
|
617 |
+
"[0.11560339 0.57495481 0.76535827 0.0391947 ]",
|
618 |
+
"[1.11560345 1.57495475 1.76535821 1.0391947 ]",
|
619 |
+
"[0.8016152381896973, 1.6247310638427734, 1.1153241395950317, 0.9691400527954102]"
|
620 |
],
|
621 |
[
|
622 |
"[0.68094063 0.45189077 0.22661722 0.37354094]",
|
623 |
"[1.68094063 1.45189071 1.22661722 1.37354088]",
|
624 |
+
"[1.224153995513916, 1.1527836322784424, 1.4024348258972168, 1.204542636871338]"
|
625 |
+
],
|
626 |
+
[
|
627 |
+
"[0.79121011 0.54161114 0.69369799 0.1520769 ]",
|
628 |
+
"[1.79121017 1.54161119 1.69369793 1.15207696]",
|
629 |
+
"[1.3464093208312988, 1.5594894886016846, 1.5191831588745117, 1.0183898210525513]"
|
630 |
]
|
631 |
]
|
632 |
},
|
|
|
644 |
"[0.85706753 0.61447072 0.41741937 0.85147089]",
|
645 |
"[1.85706758 1.61447072 1.41741943 1.85147095]"
|
646 |
],
|
|
|
|
|
|
|
|
|
647 |
[
|
648 |
"[0.19409031 0.68692201 0.60667384 0.57829887]",
|
649 |
"[1.19409037 1.68692207 1.60667384 1.57829881]"
|
|
|
684 |
"[0.11693293 0.49860179 0.55020827 0.88832849]",
|
685 |
"[1.11693287 1.49860179 1.55020833 1.88832855]"
|
686 |
],
|
687 |
+
[
|
688 |
+
"[0.48959708 0.48549271 0.32688856 0.356677 ]",
|
689 |
+
"[1.48959708 1.48549271 1.32688856 1.35667706]"
|
690 |
+
],
|
691 |
[
|
692 |
"[0.50272274 0.54912758 0.17663097 0.79070699]",
|
693 |
"[1.50272274 1.54912758 1.17663097 1.79070699]"
|
694 |
],
|
695 |
+
[
|
696 |
+
"[0.04508126 0.76880038 0.80721325 0.62542385]",
|
697 |
+
"[1.04508126 1.76880038 1.80721331 1.62542391]"
|
698 |
+
],
|
699 |
[
|
700 |
"[0.19908059 0.17570406 0.51475513 0.1893943 ]",
|
701 |
"[1.19908059 1.175704 1.51475513 1.18939424]"
|
|
|
713 |
"[1.24388778 1.07268476 1.68350863 1.73431659]"
|
714 |
],
|
715 |
[
|
716 |
+
"[0.56922203 0.98222166 0.76851749 0.28615737]",
|
717 |
+
"[1.56922197 1.9822216 1.76851749 1.28615737]"
|
718 |
],
|
719 |
[
|
720 |
"[0.88776821 0.51636773 0.30333066 0.32230979]",
|
721 |
"[1.88776827 1.51636767 1.30333066 1.32230973]"
|
722 |
],
|
723 |
+
[
|
724 |
+
"[0.90817457 0.89270043 0.38583666 0.66566533]",
|
725 |
+
"[1.90817451 1.89270043 1.3858366 1.66566539]"
|
726 |
+
],
|
727 |
[
|
728 |
"[0.48507756 0.80808765 0.77162558 0.47834778]",
|
729 |
"[1.48507762 1.80808759 1.77162552 1.47834778]"
|
|
|
736 |
"[0.31518555 0.49643308 0.11509258 0.95458382]",
|
737 |
"[1.31518555 1.49643302 1.11509252 1.95458388]"
|
738 |
],
|
|
|
|
|
|
|
|
|
739 |
[
|
740 |
"[0.79423058 0.07138705 0.061777 0.18766576]",
|
741 |
"[1.79423058 1.07138705 1.061777 1.1876657 ]"
|
|
|
768 |
"[0.98033333 0.97656083 0.38939917 0.81491041]",
|
769 |
"[1.98033333 1.97656083 1.38939917 1.81491041]"
|
770 |
],
|
|
|
|
|
|
|
|
|
771 |
[
|
772 |
"[0.78956431 0.87284744 0.06880784 0.03455889]",
|
773 |
"[1.78956437 1.87284744 1.06880784 1.03455889]"
|
|
|
804 |
"[0.73217702 0.65233225 0.44077861 0.33837909]",
|
805 |
"[1.73217702 1.65233231 1.44077861 1.33837914]"
|
806 |
],
|
|
|
|
|
|
|
|
|
807 |
[
|
808 |
"[0.60110539 0.3618983 0.32342511 0.98672163]",
|
809 |
"[1.60110545 1.3618983 1.32342505 1.98672163]"
|
|
|
812 |
"[0.77427191 0.21829212 0.12769502 0.74303615]",
|
813 |
"[1.77427197 1.21829212 1.12769508 1.74303615]"
|
814 |
],
|
815 |
+
[
|
816 |
+
"[0.08107251 0.2602725 0.18861133 0.44833237]",
|
817 |
+
"[1.08107257 1.2602725 1.18861127 1.44833231]"
|
818 |
+
],
|
819 |
[
|
820 |
"[0.59812403 0.78395379 0.0291847 0.81814629]",
|
821 |
"[1.59812403 1.78395379 1.0291847 1.81814623]"
|
|
|
856 |
"[0.54914117 0.03810108 0.87531954 0.73044223]",
|
857 |
"[1.54914117 1.03810108 1.87531948 1.73044229]"
|
858 |
],
|
|
|
|
|
|
|
|
|
859 |
[
|
860 |
"[0.87285906 0.48354989 0.39394957 0.59456545]",
|
861 |
"[1.872859 1.48354983 1.39394951 1.59456539]"
|
|
|
904 |
"[0.60609657 0.96257663 0.19292736 0.95702219]",
|
905 |
"[1.60609651 1.96257663 1.19292736 1.95702219]"
|
906 |
],
|
|
|
|
|
|
|
|
|
907 |
[
|
908 |
"[0.70167565 0.26930219 0.5660674 0.61194974]",
|
909 |
"[1.70167565 1.26930213 1.56606746 1.61194968]"
|
|
|
916 |
"[0.59492421 0.90274489 0.38069052 0.46101224]",
|
917 |
"[1.59492421 1.90274489 1.38069057 1.46101224]"
|
918 |
],
|
919 |
+
[
|
920 |
+
"[0.15064228 0.03198934 0.25754827 0.51484001]",
|
921 |
+
"[1.15064228 1.03198934 1.25754833 1.51484001]"
|
922 |
+
],
|
923 |
[
|
924 |
"[0.12024075 0.21342516 0.56858408 0.58644271]",
|
925 |
"[1.12024069 1.21342516 1.56858408 1.58644271]"
|
|
|
929 |
"[1.91730917 1.22574067 1.09591603 1.33056474]"
|
930 |
],
|
931 |
[
|
932 |
+
"[0.49691743 0.61873293 0.90698647 0.94486356]",
|
933 |
+
"[1.49691749 1.61873293 1.90698647 1.94486356]"
|
934 |
],
|
935 |
[
|
936 |
"[0.37959969 0.42820001 0.10690689 0.96353984]",
|
|
|
984 |
"[0.47870928 0.17129105 0.27300501 0.20634609]",
|
985 |
"[1.47870922 1.17129111 1.27300501 1.20634604]"
|
986 |
],
|
987 |
+
[
|
988 |
+
"[0.72795159 0.79317838 0.27832931 0.96576637]",
|
989 |
+
"[1.72795153 1.79317832 1.27832937 1.96576643]"
|
990 |
+
],
|
991 |
[
|
992 |
"[0.87608397 0.93200487 0.80169648 0.37758952]",
|
993 |
"[1.87608397 1.93200493 1.80169654 1.37758946]"
|
|
|
1000 |
}
|
1001 |
},
|
1002 |
"other": {
|
1003 |
+
"model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__embedding_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_1_x\n (2) - <function leaky_relu at 0x7e6e1cf2c220>: Linear_1_x -> Activation_2_x\n (3) - Identity(): Activation_2_x -> END_Repeat_1_output\n (4) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__embedding_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['Input__label_1_y', 'END_Repeat_1_output'], loss=Sequential(\n (0) - <function mse_loss at 0x7e6e1cf2dd00>: END_Repeat_1_output, Input__label_1_y -> MSE_loss_1_loss\n (1) - Identity(): MSE_loss_1_loss -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
|
1004 |
},
|
1005 |
"relations": []
|
1006 |
},
|
|
|
1036 |
],
|
1037 |
"loss_inputs": [
|
1038 |
"Input__label_1_y",
|
1039 |
+
"END_Repeat_1_output"
|
1040 |
],
|
1041 |
"outputs": [
|
1042 |
+
"END_Repeat_1_output"
|
1043 |
],
|
1044 |
"trained": true
|
1045 |
},
|
|
|
1211 |
],
|
1212 |
"loss_inputs": [
|
1213 |
"Input__label_1_y",
|
1214 |
+
"END_Repeat_1_output"
|
1215 |
],
|
1216 |
"outputs": [
|
1217 |
+
"END_Repeat_1_output"
|
1218 |
],
|
1219 |
"trained": false
|
1220 |
},
|
|
|
1270 |
"type": "basic"
|
1271 |
},
|
1272 |
"params": {
|
1273 |
+
"epochs": "15",
|
1274 |
"input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"column\":\"x\",\"df\":\"df_train\"},\"Input__label_1_y\":{\"column\":\"y\",\"df\":\"df_train\"}}}",
|
1275 |
"model_name": "model"
|
1276 |
},
|
|
|
1304 |
},
|
1305 |
"df_test": {
|
1306 |
"columns": [
|
|
|
1307 |
"x",
|
1308 |
"y"
|
1309 |
]
|
|
|
1323 |
],
|
1324 |
"loss_inputs": [
|
1325 |
"Input__label_1_y",
|
1326 |
+
"END_Repeat_1_output"
|
1327 |
],
|
1328 |
"outputs": [
|
1329 |
+
"END_Repeat_1_output"
|
1330 |
],
|
1331 |
"trained": true
|
1332 |
},
|
|
|
1382 |
"type": "basic"
|
1383 |
},
|
1384 |
"params": {
|
1385 |
+
"input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"df\":\"df_test\",\"column\":\"x\"}}}",
|
1386 |
"model_name": "model",
|
1387 |
+
"output_mapping": "{\"map\":{\"END_Repeat_1_output\":{\"df\":\"df_test\",\"column\":\"predicted\"}}}"
|
1388 |
},
|
1389 |
"status": "done",
|
1390 |
"title": "Model inference"
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py
CHANGED
@@ -110,7 +110,7 @@ ops.register_passive_op(
|
|
110 |
outputs=[ops.Output(name="output", position="bottom", type="tensor")],
|
111 |
params=[
|
112 |
ops.Parameter.basic("times", 1, int),
|
113 |
-
ops.Parameter.basic("same_weights",
|
114 |
],
|
115 |
)
|
116 |
|
@@ -207,8 +207,10 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
|
|
207 |
dependencies = {n.id: [] for n in ws.nodes}
|
208 |
in_edges = {}
|
209 |
out_edges = {}
|
210 |
-
|
211 |
for e in ws.edges:
|
|
|
|
|
212 |
dependencies[e.target].append(e.source)
|
213 |
in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
|
214 |
(e.source, e.sourceHandle)
|
@@ -216,34 +218,78 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
|
|
216 |
out_edges.setdefault(e.source, {}).setdefault(e.sourceHandle, []).append(
|
217 |
(e.target, e.targetHandle)
|
218 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
sizes = {}
|
220 |
for k, i in inputs.items():
|
221 |
sizes[k] = i.shape[-1]
|
222 |
ts = graphlib.TopologicalSorter(dependencies)
|
223 |
layers = []
|
224 |
loss_layers = []
|
225 |
-
|
226 |
cfg = {}
|
227 |
used_in_model = set()
|
228 |
made_in_model = set()
|
229 |
used_in_loss = set()
|
230 |
made_in_loss = set()
|
231 |
for node_id in ts.static_order():
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
233 |
t = node.data.title
|
234 |
op = catalog[t]
|
235 |
p = op.convert_params(node.data.params)
|
236 |
for b in dependencies[node_id]:
|
237 |
-
|
238 |
-
in_loss.add(node_id)
|
239 |
if "loss" in t:
|
240 |
-
|
241 |
inputs = {}
|
242 |
for n in in_edges.get(node_id, []):
|
243 |
for b, h in in_edges[node_id][n]:
|
244 |
i = _to_id(b, h)
|
245 |
inputs[n] = i
|
246 |
-
if
|
247 |
used_in_loss.add(i)
|
248 |
else:
|
249 |
used_in_model.add(i)
|
@@ -252,13 +298,13 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
|
|
252 |
i = _to_id(node_id, out)
|
253 |
outputs[out] = i
|
254 |
if inputs: # Nodes with no inputs are input nodes. Their outputs are not "made" by us.
|
255 |
-
if
|
256 |
made_in_loss.add(i)
|
257 |
else:
|
258 |
made_in_model.add(i)
|
259 |
inputs = types.SimpleNamespace(**inputs)
|
260 |
outputs = types.SimpleNamespace(**outputs)
|
261 |
-
ls = loss_layers if
|
262 |
match t:
|
263 |
case "Linear":
|
264 |
isize = sizes.get(inputs.x, 1)
|
@@ -276,6 +322,19 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
|
|
276 |
f"{inputs.x}, {inputs.y} -> {outputs.loss}",
|
277 |
)
|
278 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
cfg["model_inputs"] = list(used_in_model - made_in_model)
|
280 |
cfg["model_outputs"] = list(made_in_model & used_in_loss)
|
281 |
cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
|
|
|
110 |
outputs=[ops.Output(name="output", position="bottom", type="tensor")],
|
111 |
params=[
|
112 |
ops.Parameter.basic("times", 1, int),
|
113 |
+
ops.Parameter.basic("same_weights", False, bool),
|
114 |
],
|
115 |
)
|
116 |
|
|
|
207 |
dependencies = {n.id: [] for n in ws.nodes}
|
208 |
in_edges = {}
|
209 |
out_edges = {}
|
210 |
+
repeats = []
|
211 |
for e in ws.edges:
|
212 |
+
if nodes[e.target].data.title == "Repeat":
|
213 |
+
repeats.append(e.target)
|
214 |
dependencies[e.target].append(e.source)
|
215 |
in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
|
216 |
(e.source, e.sourceHandle)
|
|
|
218 |
out_edges.setdefault(e.source, {}).setdefault(e.sourceHandle, []).append(
|
219 |
(e.target, e.targetHandle)
|
220 |
)
|
221 |
+
# Split repeat boxes into start and end, and insert them into the flow.
|
222 |
+
# TODO: Think about recursive repeats.
|
223 |
+
for repeat in repeats:
|
224 |
+
start_id = f"START {repeat}"
|
225 |
+
end_id = f"END {repeat}"
|
226 |
+
# repeat -> first <- real_input
|
227 |
+
# ...becomes...
|
228 |
+
# real_input -> start -> first
|
229 |
+
first, firsth = out_edges[repeat]["output"][0]
|
230 |
+
[(real_input, real_inputh)] = [
|
231 |
+
k for k in in_edges[first][firsth] if k != (repeat, "output")
|
232 |
+
]
|
233 |
+
dependencies[first].remove(repeat)
|
234 |
+
dependencies[first].append(start_id)
|
235 |
+
dependencies[start_id] = [real_input]
|
236 |
+
out_edges[real_input][real_inputh] = [
|
237 |
+
k if k != (first, firsth) else (start_id, "input")
|
238 |
+
for k in out_edges[real_input][real_inputh]
|
239 |
+
]
|
240 |
+
in_edges[start_id] = {"input": [(real_input, real_inputh)]}
|
241 |
+
out_edges[start_id] = {"output": [(first, firsth)]}
|
242 |
+
in_edges[first][firsth] = [(start_id, "output")]
|
243 |
+
# repeat <- last -> real_output
|
244 |
+
# ...becomes...
|
245 |
+
# last -> end -> real_output
|
246 |
+
last, lasth = in_edges[repeat]["input"][0]
|
247 |
+
[(real_output, real_outputh)] = [
|
248 |
+
k for k in out_edges[last][lasth] if k != (repeat, "input")
|
249 |
+
]
|
250 |
+
del dependencies[repeat]
|
251 |
+
dependencies[end_id] = [last]
|
252 |
+
dependencies[real_output].append(end_id)
|
253 |
+
out_edges[last][lasth] = [(end_id, "input")]
|
254 |
+
in_edges[end_id] = {"input": [(last, lasth)]}
|
255 |
+
out_edges[end_id] = {"output": [(real_output, real_outputh)]}
|
256 |
+
in_edges[real_output][real_outputh] = [
|
257 |
+
k if k != (last, lasth) else (end_id, "output")
|
258 |
+
for k in in_edges[real_output][real_outputh]
|
259 |
+
]
|
260 |
+
# Walk the graph in topological order.
|
261 |
sizes = {}
|
262 |
for k, i in inputs.items():
|
263 |
sizes[k] = i.shape[-1]
|
264 |
ts = graphlib.TopologicalSorter(dependencies)
|
265 |
layers = []
|
266 |
loss_layers = []
|
267 |
+
regions: dict[str, set[str]] = {node_id: set() for node_id in dependencies}
|
268 |
cfg = {}
|
269 |
used_in_model = set()
|
270 |
made_in_model = set()
|
271 |
used_in_loss = set()
|
272 |
made_in_loss = set()
|
273 |
for node_id in ts.static_order():
|
274 |
+
if node_id.startswith("START "):
|
275 |
+
node = nodes[node_id.removeprefix("START ")]
|
276 |
+
elif node_id.startswith("END "):
|
277 |
+
node = nodes[node_id.removeprefix("END ")]
|
278 |
+
else:
|
279 |
+
node = nodes[node_id]
|
280 |
t = node.data.title
|
281 |
op = catalog[t]
|
282 |
p = op.convert_params(node.data.params)
|
283 |
for b in dependencies[node_id]:
|
284 |
+
regions[node_id] |= regions[b]
|
|
|
285 |
if "loss" in t:
|
286 |
+
regions[node_id].add("loss")
|
287 |
inputs = {}
|
288 |
for n in in_edges.get(node_id, []):
|
289 |
for b, h in in_edges[node_id][n]:
|
290 |
i = _to_id(b, h)
|
291 |
inputs[n] = i
|
292 |
+
if "loss" in regions[node_id]:
|
293 |
used_in_loss.add(i)
|
294 |
else:
|
295 |
used_in_model.add(i)
|
|
|
298 |
i = _to_id(node_id, out)
|
299 |
outputs[out] = i
|
300 |
if inputs: # Nodes with no inputs are input nodes. Their outputs are not "made" by us.
|
301 |
+
if "loss" in regions[node_id]:
|
302 |
made_in_loss.add(i)
|
303 |
else:
|
304 |
made_in_model.add(i)
|
305 |
inputs = types.SimpleNamespace(**inputs)
|
306 |
outputs = types.SimpleNamespace(**outputs)
|
307 |
+
ls = loss_layers if "loss" in regions[node_id] else layers
|
308 |
match t:
|
309 |
case "Linear":
|
310 |
isize = sizes.get(inputs.x, 1)
|
|
|
322 |
f"{inputs.x}, {inputs.y} -> {outputs.loss}",
|
323 |
)
|
324 |
)
|
325 |
+
case "Repeat":
|
326 |
+
ls.append((torch.nn.Identity(), f"{inputs.input} -> {outputs.output}"))
|
327 |
+
sizes[outputs.output] = sizes.get(inputs.input, 1)
|
328 |
+
if node_id.startswith("START "):
|
329 |
+
regions[node_id].add(("repeat", node_id.removeprefix("START ")))
|
330 |
+
else:
|
331 |
+
repeat_id = node_id.removeprefix("END ")
|
332 |
+
print(f"repeat {repeat_id} ending")
|
333 |
+
regions[node_id].remove(("repeat", repeat_id))
|
334 |
+
for n in nodes:
|
335 |
+
r = regions.get(n, set())
|
336 |
+
if ("repeat", repeat_id) in r:
|
337 |
+
print(f"repeating {n}")
|
338 |
cfg["model_inputs"] = list(used_in_model - made_in_model)
|
339 |
cfg["model_outputs"] = list(made_in_model & used_in_loss)
|
340 |
cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
|