darabos commited on
Commit
17a0053
·
1 Parent(s): 69610f6

Repeat boxes working.

Browse files
examples/Model definition CHANGED
@@ -438,8 +438,8 @@
438
  "height": 200.0,
439
  "id": "MSE loss 2",
440
  "position": {
441
- "x": 690.0,
442
- "y": -480.0
443
  },
444
  "type": "basic",
445
  "width": 200.0
 
438
  "height": 200.0,
439
  "id": "MSE loss 2",
440
  "position": {
441
+ "x": 309.4422414664647,
442
+ "y": -552.1056805642488
443
  },
444
  "type": "basic",
445
  "width": 200.0
examples/Model use CHANGED
@@ -579,54 +579,54 @@
579
  ],
580
  "data": [
581
  [
582
- "[0.33480108 0.59181517 0.76198453 0.98062384]",
583
- "[1.33480108 1.59181523 1.76198459 1.98062384]",
584
- "[1.3419755697250366, 1.5946478843688965, 1.7717586755752563, 1.9897401332855225]"
585
  ],
586
  [
587
- "[0.91730917 0.22574073 0.09591609 0.33056474]",
588
- "[1.91730917 1.22574067 1.09591603 1.33056474]",
589
- "[1.900892972946167, 1.2247941493988037, 1.0862866640090942, 1.323314905166626]"
590
  ],
591
  [
592
- "[0.32565445 0.90939188 0.07488042 0.13730896]",
593
- "[1.32565451 1.90939188 1.07488036 1.13730896]",
594
- "[1.3460955619812012, 1.8960161209106445, 1.0530263185501099, 1.1075329780578613]"
595
  ],
596
  [
597
- "[0.87608397 0.93200487 0.80169648 0.37758952]",
598
- "[1.87608397 1.93200493 1.80169654 1.37758946]",
599
- "[1.87070894241333, 1.9386992454528809, 1.8151044845581055, 1.3952441215515137]"
600
  ],
601
  [
602
- "[0.39147133 0.29854035 0.84663737 0.58175623]",
603
- "[1.39147139 1.29854035 1.84663737 1.58175623]",
604
- "[1.3877646923065186, 1.2995290756225586, 1.847062587738037, 1.583693265914917]"
605
  ],
606
  [
607
- "[0.48507756 0.80808765 0.77162558 0.47834778]",
608
- "[1.48507762 1.80808759 1.77162552 1.47834778]",
609
- "[1.490919828414917, 1.8087174892425537, 1.7757861614227295, 1.4824031591415405]"
610
  ],
611
  [
612
- "[0.75292218 0.81470108 0.49657214 0.56217098]",
613
- "[1.75292218 1.81470108 1.49657214 1.56217098]",
614
- "[1.7527031898498535, 1.8176040649414062, 1.503413438796997, 1.570152759552002]"
615
  ],
616
  [
617
- "[0.11693293 0.49860179 0.55020827 0.88832849]",
618
- "[1.11693287 1.49860179 1.55020833 1.88832855]",
619
- "[1.1314976215362549, 1.4944026470184326, 1.546830177307129, 1.8803892135620117]"
620
  ],
621
  [
622
- "[0.19409031 0.68692201 0.60667384 0.57829887]",
623
- "[1.19409037 1.68692207 1.60667384 1.57829881]",
624
- "[1.2091591358184814, 1.6816589832305908, 1.6011345386505127, 1.5684995651245117]"
625
  ],
626
  [
627
- "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
628
- "[1.62569475 1.9881897 1.83639622 1.98288584]",
629
- "[1.6314740180969238, 1.996805191040039, 1.8592857122421265, 2.0075552463531494]"
630
  ]
631
  ]
632
  },
@@ -648,6 +648,10 @@
648
  "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
649
  "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
650
  ],
 
 
 
 
651
  [
652
  "[0.76807946 0.98855817 0.08259124 0.01730657]",
653
  "[1.76807952 1.98855817 1.0825913 1.01730657]"
@@ -681,12 +685,12 @@
681
  "[1.98324287 1.99464178 1.14008355 1.47651017]"
682
  ],
683
  [
684
- "[0.48959708 0.48549271 0.32688856 0.356677 ]",
685
- "[1.48959708 1.48549271 1.32688856 1.35667706]"
686
  ],
687
  [
688
- "[0.50272274 0.54912758 0.17663097 0.79070699]",
689
- "[1.50272274 1.54912758 1.17663097 1.79070699]"
690
  ],
691
  [
692
  "[0.04508126 0.76880038 0.80721325 0.62542385]",
@@ -708,6 +712,10 @@
708
  "[0.24388778 0.07268471 0.68350857 0.73431659]",
709
  "[1.24388778 1.07268476 1.68350863 1.73431659]"
710
  ],
 
 
 
 
711
  [
712
  "[0.56922203 0.98222166 0.76851749 0.28615737]",
713
  "[1.56922197 1.9822216 1.76851749 1.28615737]"
@@ -720,6 +728,10 @@
720
  "[0.90817457 0.89270043 0.38583666 0.66566533]",
721
  "[1.90817451 1.89270043 1.3858366 1.66566539]"
722
  ],
 
 
 
 
723
  [
724
  "[0.68062544 0.98093534 0.14778823 0.53244978]",
725
  "[1.68062544 1.98093534 1.14778829 1.53244972]"
@@ -740,10 +752,6 @@
740
  "[0.23942459 0.90487361 0.69337189 0.65089428]",
741
  "[1.23942459 1.90487361 1.69337189 1.65089428]"
742
  ],
743
- [
744
- "[0.94516498 0.08422136 0.5608117 0.07652664]",
745
- "[1.94516492 1.08422136 1.56081176 1.07652664]"
746
- ],
747
  [
748
  "[0.26661873 0.45946234 0.13510543 0.81294441]",
749
  "[1.26661873 1.4594624 1.13510537 1.81294441]"
@@ -772,10 +780,6 @@
772
  "[0.78956431 0.87284744 0.06880784 0.03455889]",
773
  "[1.78956437 1.87284744 1.06880784 1.03455889]"
774
  ],
775
- [
776
- "[0.94221359 0.57740951 0.98649532 0.40934443]",
777
- "[1.94221354 1.57740951 1.98649526 1.40934443]"
778
- ],
779
  [
780
  "[0.00497234 0.39319336 0.57054168 0.75150961]",
781
  "[1.00497234 1.39319336 1.57054162 1.75150967]"
@@ -788,6 +792,14 @@
788
  "[0.72290605 0.96945059 0.68354797 0.15270454]",
789
  "[1.72290611 1.96945059 1.68354797 1.15270448]"
790
  ],
 
 
 
 
 
 
 
 
791
  [
792
  "[0.52784437 0.54268694 0.12358981 0.72116476]",
793
  "[1.52784443 1.54268694 1.12358975 1.7211647 ]"
@@ -796,10 +808,6 @@
796
  "[0.73217702 0.65233225 0.44077861 0.33837909]",
797
  "[1.73217702 1.65233231 1.44077861 1.33837914]"
798
  ],
799
- [
800
- "[0.34084332 0.73018837 0.54168713 0.91440833]",
801
- "[1.34084332 1.73018837 1.54168713 1.91440833]"
802
- ],
803
  [
804
  "[0.60110539 0.3618983 0.32342511 0.98672163]",
805
  "[1.60110545 1.3618983 1.32342505 1.98672163]"
@@ -836,6 +844,10 @@
836
  "[0.95928186 0.84273899 0.71514636 0.38619852]",
837
  "[1.95928192 1.84273899 1.7151463 1.38619852]"
838
  ],
 
 
 
 
839
  [
840
  "[0.9829582 0.59269661 0.40120947 0.95487177]",
841
  "[1.9829582 1.59269667 1.40120947 1.95487177]"
@@ -844,10 +856,6 @@
844
  "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
845
  "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
846
  ],
847
- [
848
- "[0.54914117 0.03810108 0.87531954 0.73044223]",
849
- "[1.54914117 1.03810108 1.87531948 1.73044229]"
850
- ],
851
  [
852
  "[0.67418337 0.79634351 0.23229051 0.71345252]",
853
  "[1.67418337 1.79634356 1.23229051 1.71345258]"
@@ -860,14 +868,14 @@
860
  "[0.81788456 0.58174163 0.29376316 0.7971254 ]",
861
  "[1.81788456 1.58174157 1.29376316 1.79712534]"
862
  ],
863
- [
864
- "[0.94559073 0.65736622 0.25761551 0.48553199]",
865
- "[1.94559073 1.65736628 1.25761557 1.48553205]"
866
- ],
867
  [
868
  "[0.60075855 0.12234765 0.00614399 0.30560958]",
869
  "[1.60075855 1.12234759 1.00614405 1.30560958]"
870
  ],
 
 
 
 
871
  [
872
  "[0.02162331 0.81861657 0.92468154 0.07808572]",
873
  "[1.02162337 1.81861663 1.92468154 1.07808566]"
@@ -896,10 +904,6 @@
896
  "[0.60609657 0.96257663 0.19292736 0.95702219]",
897
  "[1.60609651 1.96257663 1.19292736 1.95702219]"
898
  ],
899
- [
900
- "[0.80654246 0.08253473 0.74478531 0.71257162]",
901
- "[1.8065424 1.08253479 1.74478531 1.71257162]"
902
- ],
903
  [
904
  "[0.70167565 0.26930219 0.5660674 0.61194974]",
905
  "[1.70167565 1.26930213 1.56606746 1.61194968]"
@@ -908,10 +912,6 @@
908
  "[0.76933283 0.86241865 0.44114518 0.65644735]",
909
  "[1.76933289 1.86241865 1.44114518 1.65644741]"
910
  ],
911
- [
912
- "[0.59492421 0.90274489 0.38069052 0.46101224]",
913
- "[1.59492421 1.90274489 1.38069057 1.46101224]"
914
- ],
915
  [
916
  "[0.15064228 0.03198934 0.25754827 0.51484001]",
917
  "[1.15064228 1.03198934 1.25754833 1.51484001]"
@@ -920,6 +920,10 @@
920
  "[0.12024075 0.21342516 0.56858408 0.58644271]",
921
  "[1.12024069 1.21342516 1.56858408 1.58644271]"
922
  ],
 
 
 
 
923
  [
924
  "[0.49691743 0.61873293 0.90698647 0.94486356]",
925
  "[1.49691749 1.61873293 1.90698647 1.94486356]"
@@ -948,18 +952,10 @@
948
  "[0.80893755 0.92237449 0.88346356 0.93164903]",
949
  "[1.80893755 1.92237449 1.88346362 1.93164897]"
950
  ],
951
- [
952
- "[0.12858278 0.09930819 0.83222693 0.72485673]",
953
- "[1.12858272 1.09930825 1.83222699 1.72485673]"
954
- ],
955
  [
956
  "[0.72470158 0.4940322 0.41027349 0.89364016]",
957
  "[1.72470164 1.49403214 1.41027355 1.89364016]"
958
  ],
959
- [
960
- "[0.47856545 0.46267092 0.6376707 0.84747767]",
961
- "[1.47856545 1.46267092 1.63767076 1.84747767]"
962
- ],
963
  [
964
  "[0.49584109 0.80599248 0.07096875 0.75872749]",
965
  "[1.49584103 1.80599248 1.07096875 1.75872755]"
@@ -992,6 +988,10 @@
992
  "[0.72795159 0.79317838 0.27832931 0.96576637]",
993
  "[1.72795153 1.79317832 1.27832937 1.96576643]"
994
  ],
 
 
 
 
995
  [
996
  "[0.68891573 0.25576538 0.96339929 0.503833 ]",
997
  "[1.68891573 1.25576544 1.96339929 1.50383306]"
@@ -1000,7 +1000,7 @@
1000
  }
1001
  },
1002
  "other": {
1003
- "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at 0x759ed4f2c360>: Linear_2_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> END_Repeat_1_output\n (4) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['END_Repeat_1_output', 'Input__tensor_3_x'], loss=Sequential(\n (0) - <function mse_loss at 0x759ed4f2de40>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_2_output\n (1) - Identity(): MSE_loss_2_output -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
  },
1005
  "relations": []
1006
  },
@@ -1035,8 +1035,8 @@
1035
  "Input__tensor_1_x"
1036
  ],
1037
  "loss_inputs": [
1038
- "END_Repeat_1_output",
1039
- "Input__tensor_3_x"
1040
  ],
1041
  "outputs": [
1042
  "END_Repeat_1_output"
@@ -1210,8 +1210,8 @@
1210
  "Input__tensor_1_x"
1211
  ],
1212
  "loss_inputs": [
1213
- "END_Repeat_1_output",
1214
- "Input__tensor_3_x"
1215
  ],
1216
  "outputs": [
1217
  "END_Repeat_1_output"
@@ -1270,7 +1270,7 @@
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
- "epochs": "1500",
1274
  "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
1275
  "model_name": "model"
1276
  },
@@ -1322,8 +1322,8 @@
1322
  "Input__tensor_1_x"
1323
  ],
1324
  "loss_inputs": [
1325
- "END_Repeat_1_output",
1326
- "Input__tensor_3_x"
1327
  ],
1328
  "outputs": [
1329
  "END_Repeat_1_output"
 
579
  ],
580
  "data": [
581
  [
582
+ "[0.94559073 0.65736622 0.25761551 0.48553199]",
583
+ "[1.94559073 1.65736628 1.25761557 1.48553205]",
584
+ "[1.5948047637939453, 1.619612693786621, 1.5269112586975098, -0.008817584253847599]"
585
  ],
586
  [
587
+ "[0.47856545 0.46267092 0.6376707 0.84747767]",
588
+ "[1.47856545 1.46267092 1.63767076 1.84747767]",
589
+ "[1.5928349494934082, 1.6176562309265137, 1.52553391456604, -0.008808750659227371]"
590
  ],
591
  [
592
+ "[0.59492421 0.90274489 0.38069052 0.46101224]",
593
+ "[1.59492421 1.90274489 1.38069057 1.46101224]",
594
+ "[1.5592631101608276, 1.5841729640960693, 1.5020174980163574, -0.008660631254315376]"
595
  ],
596
  [
597
+ "[0.12858278 0.09930819 0.83222693 0.72485673]",
598
+ "[1.12858272 1.09930825 1.83222699 1.72485673]",
599
+ "[1.469609260559082, 1.4953691959381104, 1.4395854473114014, -0.00825517252087593]"
600
  ],
601
  [
602
+ "[0.94516498 0.08422136 0.5608117 0.07652664]",
603
+ "[1.94516492 1.08422136 1.56081176 1.07652664]",
604
+ "[1.5648787021636963, 1.5899150371551514, 1.5060429573059082, -0.008683123625814915]"
605
  ],
606
  [
607
+ "[0.54914117 0.03810108 0.87531954 0.73044223]",
608
+ "[1.54914117 1.03810108 1.87531948 1.73044229]",
609
+ "[1.6262837648391724, 1.650843620300293, 1.548863410949707, -0.00895910244435072]"
610
  ],
611
  [
612
+ "[0.94221359 0.57740951 0.98649532 0.40934443]",
613
+ "[1.94221354 1.57740951 1.98649526 1.40934443]",
614
+ "[1.8493703603744507, 1.8721930980682373, 1.704444169998169, -0.009961890056729317]"
615
  ],
616
  [
617
+ "[0.80654246 0.08253473 0.74478531 0.71257162]",
618
+ "[1.8065424 1.08253479 1.74478531 1.71257162]",
619
+ "[1.672502040863037, 1.6967051029205322, 1.5810983180999756, -0.009166811592876911]"
620
  ],
621
  [
622
+ "[0.50272274 0.54912758 0.17663097 0.79070699]",
623
+ "[1.50272274 1.54912758 1.17663097 1.79070699]",
624
+ "[1.4309396743774414, 1.4570224285125732, 1.4126286506652832, -0.008081027306616306]"
625
  ],
626
  [
627
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
628
+ "[1.34084332 1.73018837 1.54168713 1.91440833]",
629
+ "[1.5581963062286377, 1.5832865238189697, 1.5013742446899414, -0.008653069846332073]"
630
  ]
631
  ]
632
  },
 
648
  "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
649
  "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
650
  ],
651
+ [
652
+ "[0.19409031 0.68692201 0.60667384 0.57829887]",
653
+ "[1.19409037 1.68692207 1.60667384 1.57829881]"
654
+ ],
655
  [
656
  "[0.76807946 0.98855817 0.08259124 0.01730657]",
657
  "[1.76807952 1.98855817 1.0825913 1.01730657]"
 
685
  "[1.98324287 1.99464178 1.14008355 1.47651017]"
686
  ],
687
  [
688
+ "[0.11693293 0.49860179 0.55020827 0.88832849]",
689
+ "[1.11693287 1.49860179 1.55020833 1.88832855]"
690
  ],
691
  [
692
+ "[0.48959708 0.48549271 0.32688856 0.356677 ]",
693
+ "[1.48959708 1.48549271 1.32688856 1.35667706]"
694
  ],
695
  [
696
  "[0.04508126 0.76880038 0.80721325 0.62542385]",
 
712
  "[0.24388778 0.07268471 0.68350857 0.73431659]",
713
  "[1.24388778 1.07268476 1.68350863 1.73431659]"
714
  ],
715
+ [
716
+ "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
717
+ "[1.62569475 1.9881897 1.83639622 1.98288584]"
718
+ ],
719
  [
720
  "[0.56922203 0.98222166 0.76851749 0.28615737]",
721
  "[1.56922197 1.9822216 1.76851749 1.28615737]"
 
728
  "[0.90817457 0.89270043 0.38583666 0.66566533]",
729
  "[1.90817451 1.89270043 1.3858366 1.66566539]"
730
  ],
731
+ [
732
+ "[0.48507756 0.80808765 0.77162558 0.47834778]",
733
+ "[1.48507762 1.80808759 1.77162552 1.47834778]"
734
+ ],
735
  [
736
  "[0.68062544 0.98093534 0.14778823 0.53244978]",
737
  "[1.68062544 1.98093534 1.14778829 1.53244972]"
 
752
  "[0.23942459 0.90487361 0.69337189 0.65089428]",
753
  "[1.23942459 1.90487361 1.69337189 1.65089428]"
754
  ],
 
 
 
 
755
  [
756
  "[0.26661873 0.45946234 0.13510543 0.81294441]",
757
  "[1.26661873 1.4594624 1.13510537 1.81294441]"
 
780
  "[0.78956431 0.87284744 0.06880784 0.03455889]",
781
  "[1.78956437 1.87284744 1.06880784 1.03455889]"
782
  ],
 
 
 
 
783
  [
784
  "[0.00497234 0.39319336 0.57054168 0.75150961]",
785
  "[1.00497234 1.39319336 1.57054162 1.75150967]"
 
792
  "[0.72290605 0.96945059 0.68354797 0.15270454]",
793
  "[1.72290611 1.96945059 1.68354797 1.15270448]"
794
  ],
795
+ [
796
+ "[0.75292218 0.81470108 0.49657214 0.56217098]",
797
+ "[1.75292218 1.81470108 1.49657214 1.56217098]"
798
+ ],
799
+ [
800
+ "[0.33480108 0.59181517 0.76198453 0.98062384]",
801
+ "[1.33480108 1.59181523 1.76198459 1.98062384]"
802
+ ],
803
  [
804
  "[0.52784437 0.54268694 0.12358981 0.72116476]",
805
  "[1.52784443 1.54268694 1.12358975 1.7211647 ]"
 
808
  "[0.73217702 0.65233225 0.44077861 0.33837909]",
809
  "[1.73217702 1.65233231 1.44077861 1.33837914]"
810
  ],
 
 
 
 
811
  [
812
  "[0.60110539 0.3618983 0.32342511 0.98672163]",
813
  "[1.60110545 1.3618983 1.32342505 1.98672163]"
 
844
  "[0.95928186 0.84273899 0.71514636 0.38619852]",
845
  "[1.95928192 1.84273899 1.7151463 1.38619852]"
846
  ],
847
+ [
848
+ "[0.32565445 0.90939188 0.07488042 0.13730896]",
849
+ "[1.32565451 1.90939188 1.07488036 1.13730896]"
850
+ ],
851
  [
852
  "[0.9829582 0.59269661 0.40120947 0.95487177]",
853
  "[1.9829582 1.59269667 1.40120947 1.95487177]"
 
856
  "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
857
  "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
858
  ],
 
 
 
 
859
  [
860
  "[0.67418337 0.79634351 0.23229051 0.71345252]",
861
  "[1.67418337 1.79634356 1.23229051 1.71345258]"
 
868
  "[0.81788456 0.58174163 0.29376316 0.7971254 ]",
869
  "[1.81788456 1.58174157 1.29376316 1.79712534]"
870
  ],
 
 
 
 
871
  [
872
  "[0.60075855 0.12234765 0.00614399 0.30560958]",
873
  "[1.60075855 1.12234759 1.00614405 1.30560958]"
874
  ],
875
+ [
876
+ "[0.39147133 0.29854035 0.84663737 0.58175623]",
877
+ "[1.39147139 1.29854035 1.84663737 1.58175623]"
878
+ ],
879
  [
880
  "[0.02162331 0.81861657 0.92468154 0.07808572]",
881
  "[1.02162337 1.81861663 1.92468154 1.07808566]"
 
904
  "[0.60609657 0.96257663 0.19292736 0.95702219]",
905
  "[1.60609651 1.96257663 1.19292736 1.95702219]"
906
  ],
 
 
 
 
907
  [
908
  "[0.70167565 0.26930219 0.5660674 0.61194974]",
909
  "[1.70167565 1.26930213 1.56606746 1.61194968]"
 
912
  "[0.76933283 0.86241865 0.44114518 0.65644735]",
913
  "[1.76933289 1.86241865 1.44114518 1.65644741]"
914
  ],
 
 
 
 
915
  [
916
  "[0.15064228 0.03198934 0.25754827 0.51484001]",
917
  "[1.15064228 1.03198934 1.25754833 1.51484001]"
 
920
  "[0.12024075 0.21342516 0.56858408 0.58644271]",
921
  "[1.12024069 1.21342516 1.56858408 1.58644271]"
922
  ],
923
+ [
924
+ "[0.91730917 0.22574073 0.09591609 0.33056474]",
925
+ "[1.91730917 1.22574067 1.09591603 1.33056474]"
926
+ ],
927
  [
928
  "[0.49691743 0.61873293 0.90698647 0.94486356]",
929
  "[1.49691749 1.61873293 1.90698647 1.94486356]"
 
952
  "[0.80893755 0.92237449 0.88346356 0.93164903]",
953
  "[1.80893755 1.92237449 1.88346362 1.93164897]"
954
  ],
 
 
 
 
955
  [
956
  "[0.72470158 0.4940322 0.41027349 0.89364016]",
957
  "[1.72470164 1.49403214 1.41027355 1.89364016]"
958
  ],
 
 
 
 
959
  [
960
  "[0.49584109 0.80599248 0.07096875 0.75872749]",
961
  "[1.49584103 1.80599248 1.07096875 1.75872755]"
 
988
  "[0.72795159 0.79317838 0.27832931 0.96576637]",
989
  "[1.72795153 1.79317832 1.27832937 1.96576643]"
990
  ],
991
+ [
992
+ "[0.87608397 0.93200487 0.80169648 0.37758952]",
993
+ "[1.87608397 1.93200493 1.80169654 1.37758946]"
994
+ ],
995
  [
996
  "[0.68891573 0.25576538 0.96339929 0.503833 ]",
997
  "[1.68891573 1.25576544 1.96339929 1.50383306]"
 
1000
  }
1001
  },
1002
  "other": {
1003
+ "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> START_Repeat_1_output\n (4) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (5) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (6) - Identity(): Activation_1_output -> START_Repeat_1_output\n (7) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (8) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (9) - Identity(): Activation_1_output -> END_Repeat_1_output\n (10) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['Input__tensor_3_x', 'END_Repeat_1_output'], loss=Sequential(\n (0) - <function mse_loss at 0x75e660939d00>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_2_output\n (1) - Identity(): MSE_loss_2_output -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
  },
1005
  "relations": []
1006
  },
 
1035
  "Input__tensor_1_x"
1036
  ],
1037
  "loss_inputs": [
1038
+ "Input__tensor_3_x",
1039
+ "END_Repeat_1_output"
1040
  ],
1041
  "outputs": [
1042
  "END_Repeat_1_output"
 
1210
  "Input__tensor_1_x"
1211
  ],
1212
  "loss_inputs": [
1213
+ "Input__tensor_3_x",
1214
+ "END_Repeat_1_output"
1215
  ],
1216
  "outputs": [
1217
  "END_Repeat_1_output"
 
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
+ "epochs": "110",
1274
  "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
1275
  "model_name": "model"
1276
  },
 
1322
  "Input__tensor_1_x"
1323
  ],
1324
  "loss_inputs": [
1325
+ "Input__tensor_3_x",
1326
+ "END_Repeat_1_output"
1327
  ],
1328
  "outputs": [
1329
  "END_Repeat_1_output"
lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -13,7 +13,7 @@ from typing_extensions import Annotated
13
  if typing.TYPE_CHECKING:
14
  from . import workspace
15
 
16
- CATALOGS = {}
17
  EXECUTORS = {}
18
 
19
  typeof = type # We have some arguments called "type".
 
13
  if typing.TYPE_CHECKING:
14
  from . import workspace
15
 
16
+ CATALOGS: dict[str, dict[str, "Op"]] = {}
17
  EXECUTORS = {}
18
 
19
  typeof = type # We have some arguments called "type".
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -214,6 +214,9 @@ class ModelConfig:
214
  source_workspace: str | None = None
215
  trained: bool = False
216
 
 
 
 
217
  def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
218
  model_inputs = [inputs[i] for i in self.model_inputs]
219
  output = self.model(*model_inputs)
@@ -270,7 +273,7 @@ class ModelBuilder:
270
  def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
271
  self.catalog = ops.CATALOGS[ENV]
272
  optimizers = []
273
- self.nodes = {}
274
  for node in ws.nodes:
275
  self.nodes[node.id] = node
276
  if node.data.title == "Optimizer":
@@ -279,8 +282,8 @@ class ModelBuilder:
279
  assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
280
  [self.optimizer] = optimizers
281
  self.dependencies = {n.id: [] for n in ws.nodes}
282
- self.in_edges = {}
283
- self.out_edges = {}
284
  repeats = []
285
  for e in ws.edges:
286
  if self.nodes[e.target].data.title == "Repeat":
@@ -367,21 +370,12 @@ class ModelBuilder:
367
  t = node.data.title
368
  op = self.catalog[t]
369
  p = op.convert_params(node.data.params)
370
- inputs = {}
371
- for n in self.in_edges.get(node_id, []):
372
- for b, h in self.in_edges[node_id][n]:
373
- i = _to_id(b, h)
374
- inputs[n] = i
375
- outputs = {}
376
- for out in self.out_edges.get(node_id, []):
377
- i = _to_id(node_id, out)
378
- outputs[out] = i
379
  match t:
380
  case "Repeat":
381
  if node_id.startswith("END "):
382
  repeat_id = node_id.removeprefix("END ")
383
  start_id = f"START {repeat_id}"
384
- print(f"repeat {repeat_id} ending")
385
  after_start = self.all_downstream(start_id)
386
  after_end = self.all_downstream(node_id)
387
  before_end = self.all_upstream(node_id)
@@ -390,28 +384,64 @@ class ModelBuilder:
390
  assert affected_nodes == repeated_nodes, (
391
  f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
392
  )
393
- for n in repeated_nodes:
394
- print(f"repeating {n}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
396
  return
397
- layer = self.run_op(op, p, inputs, outputs)
398
- layer._origin_id = node_id
399
- self.layers.append(layer)
400
 
401
- def run_op(self, op, params, inputs: dict[str, str], outputs: dict[str, str]) -> Layer:
402
  """Returns the layer produced by this op."""
403
- op_inputs = [
404
- TensorRef(inputs[i], shape=self.sizes.get(inputs[i], 1)) for i in op.inputs.keys()
405
- ]
406
  if op.func != ops.no_op:
407
- layer = op.func(*op_inputs, **params)
408
- else:
409
- layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
410
- layer._inputs = op_inputs
411
- layer._outputs = []
412
- for o, shape in zip(op.outputs.keys(), layer.shapes):
413
- layer._outputs.append(TensorRef(outputs[o], shape=shape))
414
- self.sizes[outputs[o]] = shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  return layer
416
 
417
  def build_model(self) -> ModelConfig:
@@ -462,7 +492,6 @@ class ModelBuilder:
462
  p = op.convert_params(self.nodes[self.optimizer].data.params)
463
  o = getattr(torch.optim, p["type"].name)
464
  cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
465
- print(cfg)
466
  return ModelConfig(**cfg)
467
 
468
 
 
214
  source_workspace: str | None = None
215
  trained: bool = False
216
 
217
+ def num_parameters(self) -> int:
218
+ return sum(p.numel() for p in self.model.parameters())
219
+
220
  def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
221
  model_inputs = [inputs[i] for i in self.model_inputs]
222
  output = self.model(*model_inputs)
 
273
  def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
274
  self.catalog = ops.CATALOGS[ENV]
275
  optimizers = []
276
+ self.nodes: dict[str, workspace.WorkspaceNode] = {}
277
  for node in ws.nodes:
278
  self.nodes[node.id] = node
279
  if node.data.title == "Optimizer":
 
282
  assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
283
  [self.optimizer] = optimizers
284
  self.dependencies = {n.id: [] for n in ws.nodes}
285
+ self.in_edges: dict[str, dict[str, list[(str, str)]]] = {}
286
+ self.out_edges: dict[str, dict[str, list[(str, str)]]] = {}
287
  repeats = []
288
  for e in ws.edges:
289
  if self.nodes[e.target].data.title == "Repeat":
 
370
  t = node.data.title
371
  op = self.catalog[t]
372
  p = op.convert_params(node.data.params)
 
 
 
 
 
 
 
 
 
373
  match t:
374
  case "Repeat":
375
  if node_id.startswith("END "):
376
  repeat_id = node_id.removeprefix("END ")
377
  start_id = f"START {repeat_id}"
378
+ [last_output] = self.in_edges[node_id]["input"]
379
  after_start = self.all_downstream(start_id)
380
  after_end = self.all_downstream(node_id)
381
  before_end = self.all_upstream(node_id)
 
384
  assert affected_nodes == repeated_nodes, (
385
  f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
386
  )
387
+ repeated_layers = [e for e in self.layers if e._origin_id in repeated_nodes]
388
+ for i in range(p["times"] - 1):
389
+ # Copy repeat section's output to repeat section's input.
390
+ self.layers.append(
391
+ self.empty_layer(
392
+ node_id,
393
+ inputs=[_to_id(*last_output)],
394
+ outputs=[_to_id(start_id, "output")],
395
+ )
396
+ )
397
+ # Repeat the layers in the section.
398
+ for layer in repeated_layers:
399
+ if p["same_weights"]:
400
+ self.layers.append(
401
+ Layer(
402
+ layer.module,
403
+ shapes=layer.shapes,
404
+ _origin_id=layer._origin_id,
405
+ _inputs=layer._inputs,
406
+ _outputs=layer._outputs,
407
+ )
408
+ )
409
+ else:
410
+ self.run_node(layer._origin_id)
411
+ self.layers.append(self.run_op(node_id, op, p))
412
  case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
413
  return
414
+ case _:
415
+ self.layers.append(self.run_op(node_id, op, p))
 
416
 
417
+ def run_op(self, node_id: str, op: ops.Op, params) -> Layer:
418
  """Returns the layer produced by this op."""
419
+ inputs = [_to_id(*i) for n in op.inputs for i in self.in_edges[node_id][n]]
420
+ outputs = [_to_id(node_id, n) for n in op.outputs]
421
+ layer = self.empty_layer(node_id, inputs, outputs)
422
  if op.func != ops.no_op:
423
+ op_layer = op.func(*layer._inputs, **params)
424
+ layer.module = op_layer.module
425
+ layer.shapes = op_layer.shapes
426
+ for o in layer._outputs:
427
+ self.sizes[o._id] = o.shape
428
+ return layer
429
+
430
+ def empty_layer(self, id: str, inputs: list[str], outputs: list[str]) -> Layer:
431
+ """Creates an identity layer. Assumes that outputs have the same shapes as inputs."""
432
+ layer_inputs = [TensorRef(i, shape=self.sizes.get(i, 1)) for i in inputs]
433
+ layer_outputs = []
434
+ for i, o in zip(inputs, outputs):
435
+ shape = self.sizes.get(i, 1)
436
+ layer_outputs.append(TensorRef(o, shape=shape))
437
+ self.sizes[o] = shape
438
+ layer = Layer(
439
+ torch.nn.Identity(),
440
+ shapes=[self.sizes[o._id] for o in layer_outputs],
441
+ _inputs=layer_inputs,
442
+ _outputs=layer_outputs,
443
+ _origin_id=id,
444
+ )
445
  return layer
446
 
447
  def build_model(self) -> ModelConfig:
 
492
  p = op.convert_params(self.nodes[self.optimizer].data.params)
493
  o = getattr(torch.optim, p["type"].name)
494
  cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
 
495
  return ModelConfig(**cfg)
496
 
497