Spaces:
Running
Running
Repeat boxes working.
Browse files
examples/Model definition
CHANGED
@@ -438,8 +438,8 @@
|
|
438 |
"height": 200.0,
|
439 |
"id": "MSE loss 2",
|
440 |
"position": {
|
441 |
-
"x":
|
442 |
-
"y": -
|
443 |
},
|
444 |
"type": "basic",
|
445 |
"width": 200.0
|
|
|
438 |
"height": 200.0,
|
439 |
"id": "MSE loss 2",
|
440 |
"position": {
|
441 |
+
"x": 309.4422414664647,
|
442 |
+
"y": -552.1056805642488
|
443 |
},
|
444 |
"type": "basic",
|
445 |
"width": 200.0
|
examples/Model use
CHANGED
@@ -579,54 +579,54 @@
|
|
579 |
],
|
580 |
"data": [
|
581 |
[
|
582 |
-
"[0.
|
583 |
-
"[1.
|
584 |
-
"[1.
|
585 |
],
|
586 |
[
|
587 |
-
"[0.
|
588 |
-
"[1.
|
589 |
-
"[1.
|
590 |
],
|
591 |
[
|
592 |
-
"[0.
|
593 |
-
"[1.
|
594 |
-
"[1.
|
595 |
],
|
596 |
[
|
597 |
-
"[0.
|
598 |
-
"[1.
|
599 |
-
"[1.
|
600 |
],
|
601 |
[
|
602 |
-
"[0.
|
603 |
-
"[1.
|
604 |
-
"[1.
|
605 |
],
|
606 |
[
|
607 |
-
"[0.
|
608 |
-
"[1.
|
609 |
-
"[1.
|
610 |
],
|
611 |
[
|
612 |
-
"[0.
|
613 |
-
"[1.
|
614 |
-
"[1.
|
615 |
],
|
616 |
[
|
617 |
-
"[0.
|
618 |
-
"[1.
|
619 |
-
"[1.
|
620 |
],
|
621 |
[
|
622 |
-
"[0.
|
623 |
-
"[1.
|
624 |
-
"[1.
|
625 |
],
|
626 |
[
|
627 |
-
"[0.
|
628 |
-
"[1.
|
629 |
-
"[1.
|
630 |
]
|
631 |
]
|
632 |
},
|
@@ -648,6 +648,10 @@
|
|
648 |
"[0.11560339 0.57495481 0.76535827 0.0391947 ]",
|
649 |
"[1.11560345 1.57495475 1.76535821 1.0391947 ]"
|
650 |
],
|
|
|
|
|
|
|
|
|
651 |
[
|
652 |
"[0.76807946 0.98855817 0.08259124 0.01730657]",
|
653 |
"[1.76807952 1.98855817 1.0825913 1.01730657]"
|
@@ -681,12 +685,12 @@
|
|
681 |
"[1.98324287 1.99464178 1.14008355 1.47651017]"
|
682 |
],
|
683 |
[
|
684 |
-
"[0.
|
685 |
-
"[1.
|
686 |
],
|
687 |
[
|
688 |
-
"[0.
|
689 |
-
"[1.
|
690 |
],
|
691 |
[
|
692 |
"[0.04508126 0.76880038 0.80721325 0.62542385]",
|
@@ -708,6 +712,10 @@
|
|
708 |
"[0.24388778 0.07268471 0.68350857 0.73431659]",
|
709 |
"[1.24388778 1.07268476 1.68350863 1.73431659]"
|
710 |
],
|
|
|
|
|
|
|
|
|
711 |
[
|
712 |
"[0.56922203 0.98222166 0.76851749 0.28615737]",
|
713 |
"[1.56922197 1.9822216 1.76851749 1.28615737]"
|
@@ -720,6 +728,10 @@
|
|
720 |
"[0.90817457 0.89270043 0.38583666 0.66566533]",
|
721 |
"[1.90817451 1.89270043 1.3858366 1.66566539]"
|
722 |
],
|
|
|
|
|
|
|
|
|
723 |
[
|
724 |
"[0.68062544 0.98093534 0.14778823 0.53244978]",
|
725 |
"[1.68062544 1.98093534 1.14778829 1.53244972]"
|
@@ -740,10 +752,6 @@
|
|
740 |
"[0.23942459 0.90487361 0.69337189 0.65089428]",
|
741 |
"[1.23942459 1.90487361 1.69337189 1.65089428]"
|
742 |
],
|
743 |
-
[
|
744 |
-
"[0.94516498 0.08422136 0.5608117 0.07652664]",
|
745 |
-
"[1.94516492 1.08422136 1.56081176 1.07652664]"
|
746 |
-
],
|
747 |
[
|
748 |
"[0.26661873 0.45946234 0.13510543 0.81294441]",
|
749 |
"[1.26661873 1.4594624 1.13510537 1.81294441]"
|
@@ -772,10 +780,6 @@
|
|
772 |
"[0.78956431 0.87284744 0.06880784 0.03455889]",
|
773 |
"[1.78956437 1.87284744 1.06880784 1.03455889]"
|
774 |
],
|
775 |
-
[
|
776 |
-
"[0.94221359 0.57740951 0.98649532 0.40934443]",
|
777 |
-
"[1.94221354 1.57740951 1.98649526 1.40934443]"
|
778 |
-
],
|
779 |
[
|
780 |
"[0.00497234 0.39319336 0.57054168 0.75150961]",
|
781 |
"[1.00497234 1.39319336 1.57054162 1.75150967]"
|
@@ -788,6 +792,14 @@
|
|
788 |
"[0.72290605 0.96945059 0.68354797 0.15270454]",
|
789 |
"[1.72290611 1.96945059 1.68354797 1.15270448]"
|
790 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
791 |
[
|
792 |
"[0.52784437 0.54268694 0.12358981 0.72116476]",
|
793 |
"[1.52784443 1.54268694 1.12358975 1.7211647 ]"
|
@@ -796,10 +808,6 @@
|
|
796 |
"[0.73217702 0.65233225 0.44077861 0.33837909]",
|
797 |
"[1.73217702 1.65233231 1.44077861 1.33837914]"
|
798 |
],
|
799 |
-
[
|
800 |
-
"[0.34084332 0.73018837 0.54168713 0.91440833]",
|
801 |
-
"[1.34084332 1.73018837 1.54168713 1.91440833]"
|
802 |
-
],
|
803 |
[
|
804 |
"[0.60110539 0.3618983 0.32342511 0.98672163]",
|
805 |
"[1.60110545 1.3618983 1.32342505 1.98672163]"
|
@@ -836,6 +844,10 @@
|
|
836 |
"[0.95928186 0.84273899 0.71514636 0.38619852]",
|
837 |
"[1.95928192 1.84273899 1.7151463 1.38619852]"
|
838 |
],
|
|
|
|
|
|
|
|
|
839 |
[
|
840 |
"[0.9829582 0.59269661 0.40120947 0.95487177]",
|
841 |
"[1.9829582 1.59269667 1.40120947 1.95487177]"
|
@@ -844,10 +856,6 @@
|
|
844 |
"[0.79905868 0.89367443 0.75429088 0.3190186 ]",
|
845 |
"[1.79905868 1.89367437 1.75429082 1.3190186 ]"
|
846 |
],
|
847 |
-
[
|
848 |
-
"[0.54914117 0.03810108 0.87531954 0.73044223]",
|
849 |
-
"[1.54914117 1.03810108 1.87531948 1.73044229]"
|
850 |
-
],
|
851 |
[
|
852 |
"[0.67418337 0.79634351 0.23229051 0.71345252]",
|
853 |
"[1.67418337 1.79634356 1.23229051 1.71345258]"
|
@@ -860,14 +868,14 @@
|
|
860 |
"[0.81788456 0.58174163 0.29376316 0.7971254 ]",
|
861 |
"[1.81788456 1.58174157 1.29376316 1.79712534]"
|
862 |
],
|
863 |
-
[
|
864 |
-
"[0.94559073 0.65736622 0.25761551 0.48553199]",
|
865 |
-
"[1.94559073 1.65736628 1.25761557 1.48553205]"
|
866 |
-
],
|
867 |
[
|
868 |
"[0.60075855 0.12234765 0.00614399 0.30560958]",
|
869 |
"[1.60075855 1.12234759 1.00614405 1.30560958]"
|
870 |
],
|
|
|
|
|
|
|
|
|
871 |
[
|
872 |
"[0.02162331 0.81861657 0.92468154 0.07808572]",
|
873 |
"[1.02162337 1.81861663 1.92468154 1.07808566]"
|
@@ -896,10 +904,6 @@
|
|
896 |
"[0.60609657 0.96257663 0.19292736 0.95702219]",
|
897 |
"[1.60609651 1.96257663 1.19292736 1.95702219]"
|
898 |
],
|
899 |
-
[
|
900 |
-
"[0.80654246 0.08253473 0.74478531 0.71257162]",
|
901 |
-
"[1.8065424 1.08253479 1.74478531 1.71257162]"
|
902 |
-
],
|
903 |
[
|
904 |
"[0.70167565 0.26930219 0.5660674 0.61194974]",
|
905 |
"[1.70167565 1.26930213 1.56606746 1.61194968]"
|
@@ -908,10 +912,6 @@
|
|
908 |
"[0.76933283 0.86241865 0.44114518 0.65644735]",
|
909 |
"[1.76933289 1.86241865 1.44114518 1.65644741]"
|
910 |
],
|
911 |
-
[
|
912 |
-
"[0.59492421 0.90274489 0.38069052 0.46101224]",
|
913 |
-
"[1.59492421 1.90274489 1.38069057 1.46101224]"
|
914 |
-
],
|
915 |
[
|
916 |
"[0.15064228 0.03198934 0.25754827 0.51484001]",
|
917 |
"[1.15064228 1.03198934 1.25754833 1.51484001]"
|
@@ -920,6 +920,10 @@
|
|
920 |
"[0.12024075 0.21342516 0.56858408 0.58644271]",
|
921 |
"[1.12024069 1.21342516 1.56858408 1.58644271]"
|
922 |
],
|
|
|
|
|
|
|
|
|
923 |
[
|
924 |
"[0.49691743 0.61873293 0.90698647 0.94486356]",
|
925 |
"[1.49691749 1.61873293 1.90698647 1.94486356]"
|
@@ -948,18 +952,10 @@
|
|
948 |
"[0.80893755 0.92237449 0.88346356 0.93164903]",
|
949 |
"[1.80893755 1.92237449 1.88346362 1.93164897]"
|
950 |
],
|
951 |
-
[
|
952 |
-
"[0.12858278 0.09930819 0.83222693 0.72485673]",
|
953 |
-
"[1.12858272 1.09930825 1.83222699 1.72485673]"
|
954 |
-
],
|
955 |
[
|
956 |
"[0.72470158 0.4940322 0.41027349 0.89364016]",
|
957 |
"[1.72470164 1.49403214 1.41027355 1.89364016]"
|
958 |
],
|
959 |
-
[
|
960 |
-
"[0.47856545 0.46267092 0.6376707 0.84747767]",
|
961 |
-
"[1.47856545 1.46267092 1.63767076 1.84747767]"
|
962 |
-
],
|
963 |
[
|
964 |
"[0.49584109 0.80599248 0.07096875 0.75872749]",
|
965 |
"[1.49584103 1.80599248 1.07096875 1.75872755]"
|
@@ -992,6 +988,10 @@
|
|
992 |
"[0.72795159 0.79317838 0.27832931 0.96576637]",
|
993 |
"[1.72795153 1.79317832 1.27832937 1.96576643]"
|
994 |
],
|
|
|
|
|
|
|
|
|
995 |
[
|
996 |
"[0.68891573 0.25576538 0.96339929 0.503833 ]",
|
997 |
"[1.68891573 1.25576544 1.96339929 1.50383306]"
|
@@ -1000,7 +1000,7 @@
|
|
1000 |
}
|
1001 |
},
|
1002 |
"other": {
|
1003 |
-
"model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at
|
1004 |
},
|
1005 |
"relations": []
|
1006 |
},
|
@@ -1035,8 +1035,8 @@
|
|
1035 |
"Input__tensor_1_x"
|
1036 |
],
|
1037 |
"loss_inputs": [
|
1038 |
-
"
|
1039 |
-
"
|
1040 |
],
|
1041 |
"outputs": [
|
1042 |
"END_Repeat_1_output"
|
@@ -1210,8 +1210,8 @@
|
|
1210 |
"Input__tensor_1_x"
|
1211 |
],
|
1212 |
"loss_inputs": [
|
1213 |
-
"
|
1214 |
-
"
|
1215 |
],
|
1216 |
"outputs": [
|
1217 |
"END_Repeat_1_output"
|
@@ -1270,7 +1270,7 @@
|
|
1270 |
"type": "basic"
|
1271 |
},
|
1272 |
"params": {
|
1273 |
-
"epochs": "
|
1274 |
"input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
|
1275 |
"model_name": "model"
|
1276 |
},
|
@@ -1322,8 +1322,8 @@
|
|
1322 |
"Input__tensor_1_x"
|
1323 |
],
|
1324 |
"loss_inputs": [
|
1325 |
-
"
|
1326 |
-
"
|
1327 |
],
|
1328 |
"outputs": [
|
1329 |
"END_Repeat_1_output"
|
|
|
579 |
],
|
580 |
"data": [
|
581 |
[
|
582 |
+
"[0.94559073 0.65736622 0.25761551 0.48553199]",
|
583 |
+
"[1.94559073 1.65736628 1.25761557 1.48553205]",
|
584 |
+
"[1.5948047637939453, 1.619612693786621, 1.5269112586975098, -0.008817584253847599]"
|
585 |
],
|
586 |
[
|
587 |
+
"[0.47856545 0.46267092 0.6376707 0.84747767]",
|
588 |
+
"[1.47856545 1.46267092 1.63767076 1.84747767]",
|
589 |
+
"[1.5928349494934082, 1.6176562309265137, 1.52553391456604, -0.008808750659227371]"
|
590 |
],
|
591 |
[
|
592 |
+
"[0.59492421 0.90274489 0.38069052 0.46101224]",
|
593 |
+
"[1.59492421 1.90274489 1.38069057 1.46101224]",
|
594 |
+
"[1.5592631101608276, 1.5841729640960693, 1.5020174980163574, -0.008660631254315376]"
|
595 |
],
|
596 |
[
|
597 |
+
"[0.12858278 0.09930819 0.83222693 0.72485673]",
|
598 |
+
"[1.12858272 1.09930825 1.83222699 1.72485673]",
|
599 |
+
"[1.469609260559082, 1.4953691959381104, 1.4395854473114014, -0.00825517252087593]"
|
600 |
],
|
601 |
[
|
602 |
+
"[0.94516498 0.08422136 0.5608117 0.07652664]",
|
603 |
+
"[1.94516492 1.08422136 1.56081176 1.07652664]",
|
604 |
+
"[1.5648787021636963, 1.5899150371551514, 1.5060429573059082, -0.008683123625814915]"
|
605 |
],
|
606 |
[
|
607 |
+
"[0.54914117 0.03810108 0.87531954 0.73044223]",
|
608 |
+
"[1.54914117 1.03810108 1.87531948 1.73044229]",
|
609 |
+
"[1.6262837648391724, 1.650843620300293, 1.548863410949707, -0.00895910244435072]"
|
610 |
],
|
611 |
[
|
612 |
+
"[0.94221359 0.57740951 0.98649532 0.40934443]",
|
613 |
+
"[1.94221354 1.57740951 1.98649526 1.40934443]",
|
614 |
+
"[1.8493703603744507, 1.8721930980682373, 1.704444169998169, -0.009961890056729317]"
|
615 |
],
|
616 |
[
|
617 |
+
"[0.80654246 0.08253473 0.74478531 0.71257162]",
|
618 |
+
"[1.8065424 1.08253479 1.74478531 1.71257162]",
|
619 |
+
"[1.672502040863037, 1.6967051029205322, 1.5810983180999756, -0.009166811592876911]"
|
620 |
],
|
621 |
[
|
622 |
+
"[0.50272274 0.54912758 0.17663097 0.79070699]",
|
623 |
+
"[1.50272274 1.54912758 1.17663097 1.79070699]",
|
624 |
+
"[1.4309396743774414, 1.4570224285125732, 1.4126286506652832, -0.008081027306616306]"
|
625 |
],
|
626 |
[
|
627 |
+
"[0.34084332 0.73018837 0.54168713 0.91440833]",
|
628 |
+
"[1.34084332 1.73018837 1.54168713 1.91440833]",
|
629 |
+
"[1.5581963062286377, 1.5832865238189697, 1.5013742446899414, -0.008653069846332073]"
|
630 |
]
|
631 |
]
|
632 |
},
|
|
|
648 |
"[0.11560339 0.57495481 0.76535827 0.0391947 ]",
|
649 |
"[1.11560345 1.57495475 1.76535821 1.0391947 ]"
|
650 |
],
|
651 |
+
[
|
652 |
+
"[0.19409031 0.68692201 0.60667384 0.57829887]",
|
653 |
+
"[1.19409037 1.68692207 1.60667384 1.57829881]"
|
654 |
+
],
|
655 |
[
|
656 |
"[0.76807946 0.98855817 0.08259124 0.01730657]",
|
657 |
"[1.76807952 1.98855817 1.0825913 1.01730657]"
|
|
|
685 |
"[1.98324287 1.99464178 1.14008355 1.47651017]"
|
686 |
],
|
687 |
[
|
688 |
+
"[0.11693293 0.49860179 0.55020827 0.88832849]",
|
689 |
+
"[1.11693287 1.49860179 1.55020833 1.88832855]"
|
690 |
],
|
691 |
[
|
692 |
+
"[0.48959708 0.48549271 0.32688856 0.356677 ]",
|
693 |
+
"[1.48959708 1.48549271 1.32688856 1.35667706]"
|
694 |
],
|
695 |
[
|
696 |
"[0.04508126 0.76880038 0.80721325 0.62542385]",
|
|
|
712 |
"[0.24388778 0.07268471 0.68350857 0.73431659]",
|
713 |
"[1.24388778 1.07268476 1.68350863 1.73431659]"
|
714 |
],
|
715 |
+
[
|
716 |
+
"[0.62569475 0.9881897 0.83639616 0.9828859 ]",
|
717 |
+
"[1.62569475 1.9881897 1.83639622 1.98288584]"
|
718 |
+
],
|
719 |
[
|
720 |
"[0.56922203 0.98222166 0.76851749 0.28615737]",
|
721 |
"[1.56922197 1.9822216 1.76851749 1.28615737]"
|
|
|
728 |
"[0.90817457 0.89270043 0.38583666 0.66566533]",
|
729 |
"[1.90817451 1.89270043 1.3858366 1.66566539]"
|
730 |
],
|
731 |
+
[
|
732 |
+
"[0.48507756 0.80808765 0.77162558 0.47834778]",
|
733 |
+
"[1.48507762 1.80808759 1.77162552 1.47834778]"
|
734 |
+
],
|
735 |
[
|
736 |
"[0.68062544 0.98093534 0.14778823 0.53244978]",
|
737 |
"[1.68062544 1.98093534 1.14778829 1.53244972]"
|
|
|
752 |
"[0.23942459 0.90487361 0.69337189 0.65089428]",
|
753 |
"[1.23942459 1.90487361 1.69337189 1.65089428]"
|
754 |
],
|
|
|
|
|
|
|
|
|
755 |
[
|
756 |
"[0.26661873 0.45946234 0.13510543 0.81294441]",
|
757 |
"[1.26661873 1.4594624 1.13510537 1.81294441]"
|
|
|
780 |
"[0.78956431 0.87284744 0.06880784 0.03455889]",
|
781 |
"[1.78956437 1.87284744 1.06880784 1.03455889]"
|
782 |
],
|
|
|
|
|
|
|
|
|
783 |
[
|
784 |
"[0.00497234 0.39319336 0.57054168 0.75150961]",
|
785 |
"[1.00497234 1.39319336 1.57054162 1.75150967]"
|
|
|
792 |
"[0.72290605 0.96945059 0.68354797 0.15270454]",
|
793 |
"[1.72290611 1.96945059 1.68354797 1.15270448]"
|
794 |
],
|
795 |
+
[
|
796 |
+
"[0.75292218 0.81470108 0.49657214 0.56217098]",
|
797 |
+
"[1.75292218 1.81470108 1.49657214 1.56217098]"
|
798 |
+
],
|
799 |
+
[
|
800 |
+
"[0.33480108 0.59181517 0.76198453 0.98062384]",
|
801 |
+
"[1.33480108 1.59181523 1.76198459 1.98062384]"
|
802 |
+
],
|
803 |
[
|
804 |
"[0.52784437 0.54268694 0.12358981 0.72116476]",
|
805 |
"[1.52784443 1.54268694 1.12358975 1.7211647 ]"
|
|
|
808 |
"[0.73217702 0.65233225 0.44077861 0.33837909]",
|
809 |
"[1.73217702 1.65233231 1.44077861 1.33837914]"
|
810 |
],
|
|
|
|
|
|
|
|
|
811 |
[
|
812 |
"[0.60110539 0.3618983 0.32342511 0.98672163]",
|
813 |
"[1.60110545 1.3618983 1.32342505 1.98672163]"
|
|
|
844 |
"[0.95928186 0.84273899 0.71514636 0.38619852]",
|
845 |
"[1.95928192 1.84273899 1.7151463 1.38619852]"
|
846 |
],
|
847 |
+
[
|
848 |
+
"[0.32565445 0.90939188 0.07488042 0.13730896]",
|
849 |
+
"[1.32565451 1.90939188 1.07488036 1.13730896]"
|
850 |
+
],
|
851 |
[
|
852 |
"[0.9829582 0.59269661 0.40120947 0.95487177]",
|
853 |
"[1.9829582 1.59269667 1.40120947 1.95487177]"
|
|
|
856 |
"[0.79905868 0.89367443 0.75429088 0.3190186 ]",
|
857 |
"[1.79905868 1.89367437 1.75429082 1.3190186 ]"
|
858 |
],
|
|
|
|
|
|
|
|
|
859 |
[
|
860 |
"[0.67418337 0.79634351 0.23229051 0.71345252]",
|
861 |
"[1.67418337 1.79634356 1.23229051 1.71345258]"
|
|
|
868 |
"[0.81788456 0.58174163 0.29376316 0.7971254 ]",
|
869 |
"[1.81788456 1.58174157 1.29376316 1.79712534]"
|
870 |
],
|
|
|
|
|
|
|
|
|
871 |
[
|
872 |
"[0.60075855 0.12234765 0.00614399 0.30560958]",
|
873 |
"[1.60075855 1.12234759 1.00614405 1.30560958]"
|
874 |
],
|
875 |
+
[
|
876 |
+
"[0.39147133 0.29854035 0.84663737 0.58175623]",
|
877 |
+
"[1.39147139 1.29854035 1.84663737 1.58175623]"
|
878 |
+
],
|
879 |
[
|
880 |
"[0.02162331 0.81861657 0.92468154 0.07808572]",
|
881 |
"[1.02162337 1.81861663 1.92468154 1.07808566]"
|
|
|
904 |
"[0.60609657 0.96257663 0.19292736 0.95702219]",
|
905 |
"[1.60609651 1.96257663 1.19292736 1.95702219]"
|
906 |
],
|
|
|
|
|
|
|
|
|
907 |
[
|
908 |
"[0.70167565 0.26930219 0.5660674 0.61194974]",
|
909 |
"[1.70167565 1.26930213 1.56606746 1.61194968]"
|
|
|
912 |
"[0.76933283 0.86241865 0.44114518 0.65644735]",
|
913 |
"[1.76933289 1.86241865 1.44114518 1.65644741]"
|
914 |
],
|
|
|
|
|
|
|
|
|
915 |
[
|
916 |
"[0.15064228 0.03198934 0.25754827 0.51484001]",
|
917 |
"[1.15064228 1.03198934 1.25754833 1.51484001]"
|
|
|
920 |
"[0.12024075 0.21342516 0.56858408 0.58644271]",
|
921 |
"[1.12024069 1.21342516 1.56858408 1.58644271]"
|
922 |
],
|
923 |
+
[
|
924 |
+
"[0.91730917 0.22574073 0.09591609 0.33056474]",
|
925 |
+
"[1.91730917 1.22574067 1.09591603 1.33056474]"
|
926 |
+
],
|
927 |
[
|
928 |
"[0.49691743 0.61873293 0.90698647 0.94486356]",
|
929 |
"[1.49691749 1.61873293 1.90698647 1.94486356]"
|
|
|
952 |
"[0.80893755 0.92237449 0.88346356 0.93164903]",
|
953 |
"[1.80893755 1.92237449 1.88346362 1.93164897]"
|
954 |
],
|
|
|
|
|
|
|
|
|
955 |
[
|
956 |
"[0.72470158 0.4940322 0.41027349 0.89364016]",
|
957 |
"[1.72470164 1.49403214 1.41027355 1.89364016]"
|
958 |
],
|
|
|
|
|
|
|
|
|
959 |
[
|
960 |
"[0.49584109 0.80599248 0.07096875 0.75872749]",
|
961 |
"[1.49584103 1.80599248 1.07096875 1.75872755]"
|
|
|
988 |
"[0.72795159 0.79317838 0.27832931 0.96576637]",
|
989 |
"[1.72795153 1.79317832 1.27832937 1.96576643]"
|
990 |
],
|
991 |
+
[
|
992 |
+
"[0.87608397 0.93200487 0.80169648 0.37758952]",
|
993 |
+
"[1.87608397 1.93200493 1.80169654 1.37758946]"
|
994 |
+
],
|
995 |
[
|
996 |
"[0.68891573 0.25576538 0.96339929 0.503833 ]",
|
997 |
"[1.68891573 1.25576544 1.96339929 1.50383306]"
|
|
|
1000 |
}
|
1001 |
},
|
1002 |
"other": {
|
1003 |
+
"model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> START_Repeat_1_output\n (4) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (5) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (6) - Identity(): Activation_1_output -> START_Repeat_1_output\n (7) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (8) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (9) - Identity(): Activation_1_output -> END_Repeat_1_output\n (10) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['Input__tensor_3_x', 'END_Repeat_1_output'], loss=Sequential(\n (0) - <function mse_loss at 0x75e660939d00>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_2_output\n (1) - Identity(): MSE_loss_2_output -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
|
1004 |
},
|
1005 |
"relations": []
|
1006 |
},
|
|
|
1035 |
"Input__tensor_1_x"
|
1036 |
],
|
1037 |
"loss_inputs": [
|
1038 |
+
"Input__tensor_3_x",
|
1039 |
+
"END_Repeat_1_output"
|
1040 |
],
|
1041 |
"outputs": [
|
1042 |
"END_Repeat_1_output"
|
|
|
1210 |
"Input__tensor_1_x"
|
1211 |
],
|
1212 |
"loss_inputs": [
|
1213 |
+
"Input__tensor_3_x",
|
1214 |
+
"END_Repeat_1_output"
|
1215 |
],
|
1216 |
"outputs": [
|
1217 |
"END_Repeat_1_output"
|
|
|
1270 |
"type": "basic"
|
1271 |
},
|
1272 |
"params": {
|
1273 |
+
"epochs": "110",
|
1274 |
"input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
|
1275 |
"model_name": "model"
|
1276 |
},
|
|
|
1322 |
"Input__tensor_1_x"
|
1323 |
],
|
1324 |
"loss_inputs": [
|
1325 |
+
"Input__tensor_3_x",
|
1326 |
+
"END_Repeat_1_output"
|
1327 |
],
|
1328 |
"outputs": [
|
1329 |
"END_Repeat_1_output"
|
lynxkite-core/src/lynxkite/core/ops.py
CHANGED
@@ -13,7 +13,7 @@ from typing_extensions import Annotated
|
|
13 |
if typing.TYPE_CHECKING:
|
14 |
from . import workspace
|
15 |
|
16 |
-
CATALOGS = {}
|
17 |
EXECUTORS = {}
|
18 |
|
19 |
typeof = type # We have some arguments called "type".
|
|
|
13 |
if typing.TYPE_CHECKING:
|
14 |
from . import workspace
|
15 |
|
16 |
+
CATALOGS: dict[str, dict[str, "Op"]] = {}
|
17 |
EXECUTORS = {}
|
18 |
|
19 |
typeof = type # We have some arguments called "type".
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py
CHANGED
@@ -214,6 +214,9 @@ class ModelConfig:
|
|
214 |
source_workspace: str | None = None
|
215 |
trained: bool = False
|
216 |
|
|
|
|
|
|
|
217 |
def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
218 |
model_inputs = [inputs[i] for i in self.model_inputs]
|
219 |
output = self.model(*model_inputs)
|
@@ -270,7 +273,7 @@ class ModelBuilder:
|
|
270 |
def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
|
271 |
self.catalog = ops.CATALOGS[ENV]
|
272 |
optimizers = []
|
273 |
-
self.nodes = {}
|
274 |
for node in ws.nodes:
|
275 |
self.nodes[node.id] = node
|
276 |
if node.data.title == "Optimizer":
|
@@ -279,8 +282,8 @@ class ModelBuilder:
|
|
279 |
assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
|
280 |
[self.optimizer] = optimizers
|
281 |
self.dependencies = {n.id: [] for n in ws.nodes}
|
282 |
-
self.in_edges = {}
|
283 |
-
self.out_edges = {}
|
284 |
repeats = []
|
285 |
for e in ws.edges:
|
286 |
if self.nodes[e.target].data.title == "Repeat":
|
@@ -367,21 +370,12 @@ class ModelBuilder:
|
|
367 |
t = node.data.title
|
368 |
op = self.catalog[t]
|
369 |
p = op.convert_params(node.data.params)
|
370 |
-
inputs = {}
|
371 |
-
for n in self.in_edges.get(node_id, []):
|
372 |
-
for b, h in self.in_edges[node_id][n]:
|
373 |
-
i = _to_id(b, h)
|
374 |
-
inputs[n] = i
|
375 |
-
outputs = {}
|
376 |
-
for out in self.out_edges.get(node_id, []):
|
377 |
-
i = _to_id(node_id, out)
|
378 |
-
outputs[out] = i
|
379 |
match t:
|
380 |
case "Repeat":
|
381 |
if node_id.startswith("END "):
|
382 |
repeat_id = node_id.removeprefix("END ")
|
383 |
start_id = f"START {repeat_id}"
|
384 |
-
|
385 |
after_start = self.all_downstream(start_id)
|
386 |
after_end = self.all_downstream(node_id)
|
387 |
before_end = self.all_upstream(node_id)
|
@@ -390,28 +384,64 @@ class ModelBuilder:
|
|
390 |
assert affected_nodes == repeated_nodes, (
|
391 |
f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
|
392 |
)
|
393 |
-
for
|
394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
395 |
case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
|
396 |
return
|
397 |
-
|
398 |
-
|
399 |
-
self.layers.append(layer)
|
400 |
|
401 |
-
def run_op(self,
|
402 |
"""Returns the layer produced by this op."""
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
if op.func != ops.no_op:
|
407 |
-
|
408 |
-
|
409 |
-
layer =
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
return layer
|
416 |
|
417 |
def build_model(self) -> ModelConfig:
|
@@ -462,7 +492,6 @@ class ModelBuilder:
|
|
462 |
p = op.convert_params(self.nodes[self.optimizer].data.params)
|
463 |
o = getattr(torch.optim, p["type"].name)
|
464 |
cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
|
465 |
-
print(cfg)
|
466 |
return ModelConfig(**cfg)
|
467 |
|
468 |
|
|
|
214 |
source_workspace: str | None = None
|
215 |
trained: bool = False
|
216 |
|
217 |
+
def num_parameters(self) -> int:
|
218 |
+
return sum(p.numel() for p in self.model.parameters())
|
219 |
+
|
220 |
def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
221 |
model_inputs = [inputs[i] for i in self.model_inputs]
|
222 |
output = self.model(*model_inputs)
|
|
|
273 |
def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
|
274 |
self.catalog = ops.CATALOGS[ENV]
|
275 |
optimizers = []
|
276 |
+
self.nodes: dict[str, workspace.WorkspaceNode] = {}
|
277 |
for node in ws.nodes:
|
278 |
self.nodes[node.id] = node
|
279 |
if node.data.title == "Optimizer":
|
|
|
282 |
assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
|
283 |
[self.optimizer] = optimizers
|
284 |
self.dependencies = {n.id: [] for n in ws.nodes}
|
285 |
+
self.in_edges: dict[str, dict[str, list[(str, str)]]] = {}
|
286 |
+
self.out_edges: dict[str, dict[str, list[(str, str)]]] = {}
|
287 |
repeats = []
|
288 |
for e in ws.edges:
|
289 |
if self.nodes[e.target].data.title == "Repeat":
|
|
|
370 |
t = node.data.title
|
371 |
op = self.catalog[t]
|
372 |
p = op.convert_params(node.data.params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
match t:
|
374 |
case "Repeat":
|
375 |
if node_id.startswith("END "):
|
376 |
repeat_id = node_id.removeprefix("END ")
|
377 |
start_id = f"START {repeat_id}"
|
378 |
+
[last_output] = self.in_edges[node_id]["input"]
|
379 |
after_start = self.all_downstream(start_id)
|
380 |
after_end = self.all_downstream(node_id)
|
381 |
before_end = self.all_upstream(node_id)
|
|
|
384 |
assert affected_nodes == repeated_nodes, (
|
385 |
f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
|
386 |
)
|
387 |
+
repeated_layers = [e for e in self.layers if e._origin_id in repeated_nodes]
|
388 |
+
for i in range(p["times"] - 1):
|
389 |
+
# Copy repeat section's output to repeat section's input.
|
390 |
+
self.layers.append(
|
391 |
+
self.empty_layer(
|
392 |
+
node_id,
|
393 |
+
inputs=[_to_id(*last_output)],
|
394 |
+
outputs=[_to_id(start_id, "output")],
|
395 |
+
)
|
396 |
+
)
|
397 |
+
# Repeat the layers in the section.
|
398 |
+
for layer in repeated_layers:
|
399 |
+
if p["same_weights"]:
|
400 |
+
self.layers.append(
|
401 |
+
Layer(
|
402 |
+
layer.module,
|
403 |
+
shapes=layer.shapes,
|
404 |
+
_origin_id=layer._origin_id,
|
405 |
+
_inputs=layer._inputs,
|
406 |
+
_outputs=layer._outputs,
|
407 |
+
)
|
408 |
+
)
|
409 |
+
else:
|
410 |
+
self.run_node(layer._origin_id)
|
411 |
+
self.layers.append(self.run_op(node_id, op, p))
|
412 |
case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
|
413 |
return
|
414 |
+
case _:
|
415 |
+
self.layers.append(self.run_op(node_id, op, p))
|
|
|
416 |
|
417 |
+
def run_op(self, node_id: str, op: ops.Op, params) -> Layer:
|
418 |
"""Returns the layer produced by this op."""
|
419 |
+
inputs = [_to_id(*i) for n in op.inputs for i in self.in_edges[node_id][n]]
|
420 |
+
outputs = [_to_id(node_id, n) for n in op.outputs]
|
421 |
+
layer = self.empty_layer(node_id, inputs, outputs)
|
422 |
if op.func != ops.no_op:
|
423 |
+
op_layer = op.func(*layer._inputs, **params)
|
424 |
+
layer.module = op_layer.module
|
425 |
+
layer.shapes = op_layer.shapes
|
426 |
+
for o in layer._outputs:
|
427 |
+
self.sizes[o._id] = o.shape
|
428 |
+
return layer
|
429 |
+
|
430 |
+
def empty_layer(self, id: str, inputs: list[str], outputs: list[str]) -> Layer:
|
431 |
+
"""Creates an identity layer. Assumes that outputs have the same shapes as inputs."""
|
432 |
+
layer_inputs = [TensorRef(i, shape=self.sizes.get(i, 1)) for i in inputs]
|
433 |
+
layer_outputs = []
|
434 |
+
for i, o in zip(inputs, outputs):
|
435 |
+
shape = self.sizes.get(i, 1)
|
436 |
+
layer_outputs.append(TensorRef(o, shape=shape))
|
437 |
+
self.sizes[o] = shape
|
438 |
+
layer = Layer(
|
439 |
+
torch.nn.Identity(),
|
440 |
+
shapes=[self.sizes[o._id] for o in layer_outputs],
|
441 |
+
_inputs=layer_inputs,
|
442 |
+
_outputs=layer_outputs,
|
443 |
+
_origin_id=id,
|
444 |
+
)
|
445 |
return layer
|
446 |
|
447 |
def build_model(self) -> ModelConfig:
|
|
|
492 |
p = op.convert_params(self.nodes[self.optimizer].data.params)
|
493 |
o = getattr(torch.optim, p["type"].name)
|
494 |
cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
|
|
|
495 |
return ModelConfig(**cfg)
|
496 |
|
497 |
|