darabos commited on
Commit
83caaa1
·
1 Parent(s): bbd029e

14 PyTorch ops.

Browse files
examples/Model definition CHANGED
@@ -8,45 +8,66 @@
8
  "targetHandle": "loss"
9
  },
10
  {
11
- "id": "Activation 1 MSE loss 2",
12
  "source": "Activation 1",
13
  "sourceHandle": "output",
14
- "target": "MSE loss 2",
15
- "targetHandle": "x"
16
  },
17
  {
18
- "id": "Input: tensor 3 MSE loss 2",
19
- "source": "Input: tensor 3",
20
- "sourceHandle": "x",
21
- "target": "MSE loss 2",
22
- "targetHandle": "y"
23
  },
24
  {
25
- "id": "Activation 1 Repeat 1",
26
- "source": "Activation 1",
27
  "sourceHandle": "output",
28
- "target": "Repeat 1",
29
- "targetHandle": "input"
30
  },
31
  {
32
  "id": "Input: tensor 1 Linear 1",
33
  "source": "Input: tensor 1",
34
- "sourceHandle": "x",
35
  "target": "Linear 1",
36
  "targetHandle": "x"
37
  },
38
  {
39
- "id": "Linear 1 Activation 1",
40
- "source": "Linear 1",
41
  "sourceHandle": "output",
42
- "target": "Activation 1",
43
- "targetHandle": "x"
44
  },
45
  {
46
- "id": "Repeat 1 Linear 1",
47
- "source": "Repeat 1",
48
  "sourceHandle": "output",
49
- "target": "Linear 1",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  "targetHandle": "x"
51
  }
52
  ],
@@ -108,8 +129,8 @@
108
  "height": 250.0,
109
  "id": "Optimizer 2",
110
  "position": {
111
- "x": 292.3983313429414,
112
- "y": -853.8015246037802
113
  },
114
  "type": "basic",
115
  "width": 232.0
@@ -155,10 +176,6 @@
155
  }
156
  }
157
  },
158
- "position": {
159
- "x": 344.0,
160
- "y": 384.0
161
- },
162
  "type": "basic"
163
  },
164
  "params": {
@@ -188,58 +205,8 @@
188
  "inputs": {},
189
  "name": "Input: tensor",
190
  "outputs": {
191
- "x": {
192
- "name": "x",
193
- "position": "top",
194
- "type": {
195
- "type": "tensor"
196
- }
197
- }
198
- },
199
- "params": {
200
- "name": {
201
- "default": null,
202
- "name": "name",
203
- "type": {
204
- "type": "None"
205
- }
206
- }
207
- },
208
- "position": {
209
- "x": 258.0,
210
- "y": 397.0
211
- },
212
- "type": "basic"
213
- },
214
- "params": {
215
- "name": "X"
216
- },
217
- "status": "planned",
218
- "title": "Input: tensor"
219
- },
220
- "dragHandle": ".bg-primary",
221
- "height": 200.0,
222
- "id": "Input: tensor 1",
223
- "position": {
224
- "x": 85.83561484252238,
225
- "y": 293.6278596776366
226
- },
227
- "type": "basic",
228
- "width": 200.0
229
- },
230
- {
231
- "data": {
232
- "__execution_delay": 0.0,
233
- "collapsed": null,
234
- "display": null,
235
- "error": null,
236
- "input_metadata": null,
237
- "meta": {
238
- "inputs": {},
239
- "name": "Input: tensor",
240
- "outputs": {
241
- "x": {
242
- "name": "x",
243
  "position": "top",
244
  "type": {
245
  "type": "tensor"
@@ -255,10 +222,6 @@
255
  }
256
  }
257
  },
258
- "position": {
259
- "x": 1169.0,
260
- "y": 340.0
261
- },
262
  "type": "basic"
263
  },
264
  "params": {
@@ -279,6 +242,8 @@
279
  },
280
  {
281
  "data": {
 
 
282
  "display": null,
283
  "error": null,
284
  "input_metadata": null,
@@ -310,10 +275,6 @@
310
  }
311
  },
312
  "params": {},
313
- "position": {
314
- "x": 937.0,
315
- "y": 270.0
316
- },
317
  "type": "basic"
318
  },
319
  "params": {},
@@ -324,8 +285,8 @@
324
  "height": 200.0,
325
  "id": "MSE loss 2",
326
  "position": {
327
- "x": 309.4422414664647,
328
- "y": -552.1056805642488
329
  },
330
  "type": "basic",
331
  "width": 200.0
@@ -373,10 +334,6 @@
373
  }
374
  }
375
  },
376
- "position": {
377
- "x": 487.0,
378
- "y": 443.0
379
- },
380
  "type": "basic"
381
  },
382
  "params": {
@@ -433,8 +390,8 @@
433
  }
434
  },
435
  "position": {
436
- "x": 359.0,
437
- "y": 310.0
438
  },
439
  "type": "basic"
440
  },
@@ -448,8 +405,219 @@
448
  "height": 200.0,
449
  "id": "Linear 1",
450
  "position": {
451
- "x": 88.83370222907377,
452
- "y": 48.642890099180136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  },
454
  "type": "basic",
455
  "width": 200.0
 
8
  "targetHandle": "loss"
9
  },
10
  {
11
+ "id": "Activation 1 Repeat 1",
12
  "source": "Activation 1",
13
  "sourceHandle": "output",
14
+ "target": "Repeat 1",
15
+ "targetHandle": "input"
16
  },
17
  {
18
+ "id": "Linear 1 Activation 1",
19
+ "source": "Linear 1",
20
+ "sourceHandle": "output",
21
+ "target": "Activation 1",
22
+ "targetHandle": "x"
23
  },
24
  {
25
+ "id": "Repeat 1 Linear 1",
26
+ "source": "Repeat 1",
27
  "sourceHandle": "output",
28
+ "target": "Linear 1",
29
+ "targetHandle": "x"
30
  },
31
  {
32
  "id": "Input: tensor 1 Linear 1",
33
  "source": "Input: tensor 1",
34
+ "sourceHandle": "output",
35
  "target": "Linear 1",
36
  "targetHandle": "x"
37
  },
38
  {
39
+ "id": "Constant vector 1 Add 1",
40
+ "source": "Constant vector 1",
41
  "sourceHandle": "output",
42
+ "target": "Add 1",
43
+ "targetHandle": "b"
44
  },
45
  {
46
+ "id": "Input: tensor 3 Add 1",
47
+ "source": "Input: tensor 3",
48
  "sourceHandle": "output",
49
+ "target": "Add 1",
50
+ "targetHandle": "a"
51
+ },
52
+ {
53
+ "id": "Add 1 MSE loss 2",
54
+ "source": "Add 1",
55
+ "sourceHandle": "output",
56
+ "target": "MSE loss 2",
57
+ "targetHandle": "y"
58
+ },
59
+ {
60
+ "id": "Activation 1 Output 1",
61
+ "source": "Activation 1",
62
+ "sourceHandle": "output",
63
+ "target": "Output 1",
64
+ "targetHandle": "x"
65
+ },
66
+ {
67
+ "id": "Output 1 MSE loss 2",
68
+ "source": "Output 1",
69
+ "sourceHandle": "x",
70
+ "target": "MSE loss 2",
71
  "targetHandle": "x"
72
  }
73
  ],
 
129
  "height": 250.0,
130
  "id": "Optimizer 2",
131
  "position": {
132
+ "x": 359.75221367487865,
133
+ "y": -1560.7604266065723
134
  },
135
  "type": "basic",
136
  "width": 232.0
 
176
  }
177
  }
178
  },
 
 
 
 
179
  "type": "basic"
180
  },
181
  "params": {
 
205
  "inputs": {},
206
  "name": "Input: tensor",
207
  "outputs": {
208
+ "output": {
209
+ "name": "output",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  "position": "top",
211
  "type": {
212
  "type": "tensor"
 
222
  }
223
  }
224
  },
 
 
 
 
225
  "type": "basic"
226
  },
227
  "params": {
 
242
  },
243
  {
244
  "data": {
245
+ "__execution_delay": null,
246
+ "collapsed": true,
247
  "display": null,
248
  "error": null,
249
  "input_metadata": null,
 
275
  }
276
  },
277
  "params": {},
 
 
 
 
278
  "type": "basic"
279
  },
280
  "params": {},
 
285
  "height": 200.0,
286
  "id": "MSE loss 2",
287
  "position": {
288
+ "x": 362.77557479979805,
289
+ "y": -1287.1056805642488
290
  },
291
  "type": "basic",
292
  "width": 200.0
 
334
  }
335
  }
336
  },
 
 
 
 
337
  "type": "basic"
338
  },
339
  "params": {
 
390
  }
391
  },
392
  "position": {
393
+ "x": 667.0,
394
+ "y": 432.0
395
  },
396
  "type": "basic"
397
  },
 
405
  "height": 200.0,
406
  "id": "Linear 1",
407
  "position": {
408
+ "x": 98.54861342271252,
409
+ "y": 14.121603973834155
410
+ },
411
+ "type": "basic",
412
+ "width": 200.0
413
+ },
414
+ {
415
+ "data": {
416
+ "__execution_delay": 0.0,
417
+ "collapsed": null,
418
+ "display": null,
419
+ "error": null,
420
+ "input_metadata": null,
421
+ "meta": {
422
+ "inputs": {},
423
+ "name": "Input: tensor",
424
+ "outputs": {
425
+ "output": {
426
+ "name": "output",
427
+ "position": "top",
428
+ "type": {
429
+ "type": "tensor"
430
+ }
431
+ }
432
+ },
433
+ "params": {
434
+ "name": {
435
+ "default": null,
436
+ "name": "name",
437
+ "type": {
438
+ "type": "None"
439
+ }
440
+ }
441
+ },
442
+ "position": {
443
+ "x": 675.0,
444
+ "y": 499.0
445
+ },
446
+ "type": "basic"
447
+ },
448
+ "params": {
449
+ "name": "X"
450
+ },
451
+ "status": "planned",
452
+ "title": "Input: tensor"
453
+ },
454
+ "dragHandle": ".bg-primary",
455
+ "height": 200.0,
456
+ "id": "Input: tensor 1",
457
+ "position": {
458
+ "x": 108.75735538875443,
459
+ "y": 331.53404347930933
460
+ },
461
+ "type": "basic",
462
+ "width": 200.0
463
+ },
464
+ {
465
+ "data": {
466
+ "__execution_delay": 0.0,
467
+ "collapsed": null,
468
+ "display": null,
469
+ "error": null,
470
+ "input_metadata": null,
471
+ "meta": {
472
+ "inputs": {},
473
+ "name": "Constant vector",
474
+ "outputs": {
475
+ "output": {
476
+ "name": "output",
477
+ "position": "top",
478
+ "type": {
479
+ "type": "None"
480
+ }
481
+ }
482
+ },
483
+ "params": {
484
+ "size": {
485
+ "default": 1.0,
486
+ "name": "size",
487
+ "type": {
488
+ "type": "<class 'int'>"
489
+ }
490
+ },
491
+ "value": {
492
+ "default": 0.0,
493
+ "name": "value",
494
+ "type": {
495
+ "type": "<class 'int'>"
496
+ }
497
+ }
498
+ },
499
+ "position": {
500
+ "x": 1061.0,
501
+ "y": 239.0
502
+ },
503
+ "type": "basic"
504
+ },
505
+ "params": {
506
+ "size": "1",
507
+ "value": "1"
508
+ },
509
+ "status": "planned",
510
+ "title": "Constant vector"
511
+ },
512
+ "dragHandle": ".bg-primary",
513
+ "height": 258.0,
514
+ "id": "Constant vector 1",
515
+ "position": {
516
+ "x": 983.1241140187901,
517
+ "y": -562.803650462906
518
+ },
519
+ "type": "basic",
520
+ "width": 238.0
521
+ },
522
+ {
523
+ "data": {
524
+ "__execution_delay": null,
525
+ "collapsed": true,
526
+ "display": null,
527
+ "error": null,
528
+ "input_metadata": null,
529
+ "meta": {
530
+ "inputs": {
531
+ "a": {
532
+ "name": "a",
533
+ "position": "bottom",
534
+ "type": {
535
+ "type": "<class 'inspect._empty'>"
536
+ }
537
+ },
538
+ "b": {
539
+ "name": "b",
540
+ "position": "bottom",
541
+ "type": {
542
+ "type": "<class 'inspect._empty'>"
543
+ }
544
+ }
545
+ },
546
+ "name": "Add",
547
+ "outputs": {
548
+ "output": {
549
+ "name": "output",
550
+ "position": "top",
551
+ "type": {
552
+ "type": "None"
553
+ }
554
+ }
555
+ },
556
+ "params": {},
557
+ "position": {
558
+ "x": 1077.0,
559
+ "y": 334.0
560
+ },
561
+ "type": "basic"
562
+ },
563
+ "params": {},
564
+ "status": "planned",
565
+ "title": "Add"
566
+ },
567
+ "dragHandle": ".bg-primary",
568
+ "height": 200.0,
569
+ "id": "Add 1",
570
+ "position": {
571
+ "x": 818.5444381090571,
572
+ "y": -955.5157374399466
573
+ },
574
+ "type": "basic",
575
+ "width": 200.0
576
+ },
577
+ {
578
+ "data": {
579
+ "__execution_delay": null,
580
+ "collapsed": true,
581
+ "display": null,
582
+ "error": null,
583
+ "input_metadata": null,
584
+ "meta": {
585
+ "inputs": {
586
+ "x": {
587
+ "name": "x",
588
+ "position": "bottom",
589
+ "type": {
590
+ "type": "tensor"
591
+ }
592
+ }
593
+ },
594
+ "name": "Output",
595
+ "outputs": {
596
+ "x": {
597
+ "name": "x",
598
+ "position": "top",
599
+ "type": {
600
+ "type": "tensor"
601
+ }
602
+ }
603
+ },
604
+ "params": {},
605
+ "position": {
606
+ "x": 544.0,
607
+ "y": 306.0
608
+ },
609
+ "type": "basic"
610
+ },
611
+ "params": {},
612
+ "status": "planned",
613
+ "title": "Output"
614
+ },
615
+ "dragHandle": ".bg-primary",
616
+ "height": 200.0,
617
+ "id": "Output 1",
618
+ "position": {
619
+ "x": 185.15239170944702,
620
+ "y": -733.1526319565451
621
  },
622
  "type": "basic",
623
  "width": 200.0
examples/Model use CHANGED
@@ -579,54 +579,54 @@
579
  ],
580
  "data": [
581
  [
582
- "[0.19908059 0.17570406 0.51475513 0.1893943 ]",
583
- "[1.19908059 1.175704 1.51475513 1.18939424]",
584
- "[1.560641884803772, 1.5941988229751587, 1.5775359869003296, 1.4935821294784546]"
585
  ],
586
  [
587
- "[0.43681622 0.74680805 0.83598751 0.12414402]",
588
- "[1.43681622 1.74680805 1.83598757 1.12414408]",
589
- "[1.5766589641571045, 1.7117265462875366, 1.7645087242126465, 1.3384637832641602]"
590
  ],
591
  [
592
- "[0.9829582 0.59269661 0.40120947 0.95487177]",
593
- "[1.9829582 1.59269667 1.40120947 1.95487177]",
594
- "[1.5375217199325562, 1.4159281253814697, 1.2972962856292725, 1.7269455194473267]"
595
  ],
596
  [
597
- "[0.32565445 0.90939188 0.07488042 0.13730896]",
598
- "[1.32565451 1.90939188 1.07488036 1.13730896]",
599
- "[1.562728762626648, 1.6061222553253174, 1.597141146659851, 1.4772177934646606]"
600
  ],
601
  [
602
- "[0.31518555 0.49643308 0.11509258 0.95458382]",
603
- "[1.31518555 1.49643302 1.11509252 1.95458388]",
604
- "[1.528311848640442, 1.3380011320114136, 1.171952247619629, 1.8305948972702026]"
605
  ],
606
  [
607
- "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
608
- "[1.79905868 1.89367437 1.75429082 1.3190186 ]",
609
- "[1.5757312774658203, 1.7105278968811035, 1.7636661529541016, 1.3394038677215576]"
610
  ],
611
  [
612
- "[0.80893755 0.92237449 0.88346356 0.93164903]",
613
- "[1.80893755 1.92237449 1.88346362 1.93164897]",
614
- "[1.562132716178894, 1.6031286716461182, 1.593322992324829, 1.4810831546783447]"
615
  ],
616
  [
617
- "[0.26661873 0.45946234 0.13510543 0.81294441]",
618
- "[1.26661873 1.4594624 1.13510537 1.81294441]",
619
- "[1.533058762550354, 1.3753284215927124, 1.230975866317749, 1.7815138101577759]"
620
  ],
621
  [
622
- "[0.39147133 0.29854035 0.84663737 0.58175623]",
623
- "[1.39147139 1.29854035 1.84663737 1.58175623]",
624
- "[1.5607244968414307, 1.5942375659942627, 1.5779708623886108, 1.4935153722763062]"
625
  ],
626
  [
627
- "[0.34084332 0.73018837 0.54168713 0.91440833]",
628
- "[1.34084332 1.73018837 1.54168713 1.91440833]",
629
- "[1.5488454103469849, 1.4963982105255127, 1.422922968864441, 1.622254490852356]"
630
  ]
631
  ]
632
  },
@@ -688,10 +688,6 @@
688
  "[0.11693293 0.49860179 0.55020827 0.88832849]",
689
  "[1.11693287 1.49860179 1.55020833 1.88832855]"
690
  ],
691
- [
692
- "[0.48959708 0.48549271 0.32688856 0.356677 ]",
693
- "[1.48959708 1.48549271 1.32688856 1.35667706]"
694
- ],
695
  [
696
  "[0.50272274 0.54912758 0.17663097 0.79070699]",
697
  "[1.50272274 1.54912758 1.17663097 1.79070699]"
@@ -700,6 +696,10 @@
700
  "[0.04508126 0.76880038 0.80721325 0.62542385]",
701
  "[1.04508126 1.76880038 1.80721331 1.62542391]"
702
  ],
 
 
 
 
703
  [
704
  "[0.40167677 0.25953674 0.9407078 0.76308483]",
705
  "[1.40167677 1.25953674 1.9407078 1.76308489]"
@@ -716,10 +716,6 @@
716
  "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
717
  "[1.62569475 1.9881897 1.83639622 1.98288584]"
718
  ],
719
- [
720
- "[0.56922203 0.98222166 0.76851749 0.28615737]",
721
- "[1.56922197 1.9822216 1.76851749 1.28615737]"
722
- ],
723
  [
724
  "[0.88776821 0.51636773 0.30333066 0.32230979]",
725
  "[1.88776827 1.51636767 1.30333066 1.32230973]"
@@ -736,6 +732,10 @@
736
  "[0.68062544 0.98093534 0.14778823 0.53244978]",
737
  "[1.68062544 1.98093534 1.14778829 1.53244972]"
738
  ],
 
 
 
 
739
  [
740
  "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
741
  "[1.79121017 1.54161119 1.69369793 1.15207696]"
@@ -752,6 +752,10 @@
752
  "[0.94516498 0.08422136 0.5608117 0.07652664]",
753
  "[1.94516492 1.08422136 1.56081176 1.07652664]"
754
  ],
 
 
 
 
755
  [
756
  "[0.30754459 0.77694583 0.09278506 0.38326019]",
757
  "[1.30754459 1.77694583 1.09278512 1.38326025]"
@@ -764,14 +768,6 @@
764
  "[0.4827103 0.10563457 0.98858833 0.82286644]",
765
  "[1.48271036 1.10563457 1.98858833 1.82286644]"
766
  ],
767
- [
768
- "[0.98033333 0.97656083 0.38939917 0.81491041]",
769
- "[1.98033333 1.97656083 1.38939917 1.81491041]"
770
- ],
771
- [
772
- "[0.74064726 0.4155122 0.09800029 0.49930882]",
773
- "[1.74064732 1.4155122 1.09800029 1.49930882]"
774
- ],
775
  [
776
  "[0.78956431 0.87284744 0.06880784 0.03455889]",
777
  "[1.78956437 1.87284744 1.06880784 1.03455889]"
@@ -808,6 +804,10 @@
808
  "[0.73217702 0.65233225 0.44077861 0.33837909]",
809
  "[1.73217702 1.65233231 1.44077861 1.33837914]"
810
  ],
 
 
 
 
811
  [
812
  "[0.60110539 0.3618983 0.32342511 0.98672163]",
813
  "[1.60110545 1.3618983 1.32342505 1.98672163]"
@@ -836,14 +836,18 @@
836
  "[0.18720162 0.74115586 0.98626411 0.30355608]",
837
  "[1.18720162 1.74115586 1.98626411 1.30355608]"
838
  ],
839
- [
840
- "[0.85566247 0.83362883 0.48424995 0.25265992]",
841
- "[1.85566247 1.83362889 1.48424995 1.25265992]"
842
- ],
843
  [
844
  "[0.95928186 0.84273899 0.71514636 0.38619852]",
845
  "[1.95928192 1.84273899 1.7151463 1.38619852]"
846
  ],
 
 
 
 
 
 
 
 
847
  [
848
  "[0.54914117 0.03810108 0.87531954 0.73044223]",
849
  "[1.54914117 1.03810108 1.87531948 1.73044229]"
@@ -852,10 +856,6 @@
852
  "[0.67418337 0.79634351 0.23229051 0.71345252]",
853
  "[1.67418337 1.79634356 1.23229051 1.71345258]"
854
  ],
855
- [
856
- "[0.87285906 0.48354989 0.39394957 0.59456545]",
857
- "[1.872859 1.48354983 1.39394951 1.59456539]"
858
- ],
859
  [
860
  "[0.81788456 0.58174163 0.29376316 0.7971254 ]",
861
  "[1.81788456 1.58174157 1.29376316 1.79712534]"
@@ -868,6 +868,10 @@
868
  "[0.60075855 0.12234765 0.00614399 0.30560958]",
869
  "[1.60075855 1.12234759 1.00614405 1.30560958]"
870
  ],
 
 
 
 
871
  [
872
  "[0.02162331 0.81861657 0.92468154 0.07808572]",
873
  "[1.02162337 1.81861663 1.92468154 1.07808566]"
@@ -896,10 +900,6 @@
896
  "[0.60609657 0.96257663 0.19292736 0.95702219]",
897
  "[1.60609651 1.96257663 1.19292736 1.95702219]"
898
  ],
899
- [
900
- "[0.80654246 0.08253473 0.74478531 0.71257162]",
901
- "[1.8065424 1.08253479 1.74478531 1.71257162]"
902
- ],
903
  [
904
  "[0.70167565 0.26930219 0.5660674 0.61194974]",
905
  "[1.70167565 1.26930213 1.56606746 1.61194968]"
@@ -940,14 +940,14 @@
940
  "[0.37959969 0.42820001 0.10690689 0.96353984]",
941
  "[1.37959969 1.42820001 1.10690689 1.96353984]"
942
  ],
943
- [
944
- "[0.49607176 0.1922397 0.46640229 0.78321403]",
945
- "[1.49607182 1.19223976 1.46640229 1.78321409]"
946
- ],
947
  [
948
  "[0.40234613 0.54987347 0.49542785 0.54153186]",
949
  "[1.40234613 1.54987347 1.49542785 1.5415318 ]"
950
  ],
 
 
 
 
951
  [
952
  "[0.12858278 0.09930819 0.83222693 0.72485673]",
953
  "[1.12858272 1.09930825 1.83222699 1.72485673]"
@@ -980,6 +980,10 @@
980
  "[0.68094063 0.45189077 0.22661722 0.37354094]",
981
  "[1.68094063 1.45189071 1.22661722 1.37354088]"
982
  ],
 
 
 
 
983
  [
984
  "[0.47870928 0.17129105 0.27300501 0.20634609]",
985
  "[1.47870922 1.17129111 1.27300501 1.20634604]"
@@ -991,16 +995,12 @@
991
  [
992
  "[0.87608397 0.93200487 0.80169648 0.37758952]",
993
  "[1.87608397 1.93200493 1.80169654 1.37758946]"
994
- ],
995
- [
996
- "[0.68891573 0.25576538 0.96339929 0.503833 ]",
997
- "[1.68891573 1.25576544 1.96339929 1.50383306]"
998
  ]
999
  ]
1000
  }
1001
  },
1002
  "other": {
1003
- "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_1_output\n (2) - <function leaky_relu at 0x762d1f82c680>: Linear_1_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> START_Repeat_1_output\n (4) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_1_output\n (5) - <function leaky_relu at 0x762d1f82c680>: Linear_1_output -> Activation_1_output\n (6) - Identity(): Activation_1_output -> END_Repeat_1_output\n (7) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['END_Repeat_1_output', 'Input__tensor_3_x'], loss=Sequential(\n (0) - <function mse_loss at 0x762d1f82e160>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_2_output\n (1) - Identity(): MSE_loss_2_output -> loss\n), optimizer_parameters={'lr': 0.1, 'type': <OptionsFor_type.SGD: 4>}, optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace='Model definition', trained=True)"
1004
  },
1005
  "relations": []
1006
  },
@@ -1032,14 +1032,14 @@
1032
  "model": {
1033
  "model": {
1034
  "inputs": [
1035
- "Input__tensor_1_x"
1036
  ],
1037
  "loss_inputs": [
1038
- "END_Repeat_1_output",
1039
- "Input__tensor_3_x"
1040
  ],
1041
  "outputs": [
1042
- "END_Repeat_1_output"
1043
  ],
1044
  "trained": true
1045
  },
@@ -1207,14 +1207,14 @@
1207
  "model": {
1208
  "model": {
1209
  "inputs": [
1210
- "Input__tensor_1_x"
1211
  ],
1212
  "loss_inputs": [
1213
- "END_Repeat_1_output",
1214
- "Input__tensor_3_x"
1215
  ],
1216
  "outputs": [
1217
- "END_Repeat_1_output"
1218
  ],
1219
  "trained": false
1220
  },
@@ -1270,8 +1270,8 @@
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
- "epochs": "1003",
1274
- "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
1275
  "model_name": "model"
1276
  },
1277
  "status": "done",
@@ -1319,14 +1319,14 @@
1319
  "model": {
1320
  "model": {
1321
  "inputs": [
1322
- "Input__tensor_1_x"
1323
  ],
1324
  "loss_inputs": [
1325
- "END_Repeat_1_output",
1326
- "Input__tensor_3_x"
1327
  ],
1328
  "outputs": [
1329
- "END_Repeat_1_output"
1330
  ],
1331
  "trained": true
1332
  },
@@ -1382,9 +1382,9 @@
1382
  "type": "basic"
1383
  },
1384
  "params": {
1385
- "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_test\",\"column\":\"x\"}}}",
1386
  "model_name": "model",
1387
- "output_mapping": "{\"map\":{\"END_Repeat_1_output\":{\"df\":\"df_test\",\"column\":\"pred\"}}}"
1388
  },
1389
  "status": "done",
1390
  "title": "Model inference"
 
579
  ],
580
  "data": [
581
  [
582
+ "[0.56922203 0.98222166 0.76851749 0.28615737]",
583
+ "[1.56922197 1.9822216 1.76851749 1.28615737]",
584
+ "[2.5075035095214844, 3.0133981704711914, 2.698194980621338, 2.3600802421569824]"
585
  ],
586
  [
587
+ "[0.98033333 0.97656083 0.38939917 0.81491041]",
588
+ "[1.98033333 1.97656083 1.38939917 1.81491041]",
589
+ "[3.0080792903900146, 2.9657773971557617, 2.4899187088012695, 2.7300877571105957]"
590
  ],
591
  [
592
+ "[0.74064726 0.4155122 0.09800029 0.49930882]",
593
+ "[1.74064732 1.4155122 1.09800029 1.49930882]",
594
+ "[2.7571821212768555, 2.4066381454467773, 2.099902629852295, 2.491953134536743]"
595
  ],
596
  [
597
+ "[0.68891573 0.25576538 0.96339929 0.503833 ]",
598
+ "[1.68891573 1.25576544 1.96339929 1.50383306]",
599
+ "[2.4299726486206055, 2.4633498191833496, 2.669276237487793, 2.753784656524658]"
600
  ],
601
  [
602
+ "[0.87285906 0.48354989 0.39394957 0.59456545]",
603
+ "[1.872859 1.48354983 1.39394951 1.59456539]",
604
+ "[2.7922420501708984, 2.555819034576416, 2.3168320655822754, 2.6561851501464844]"
605
  ],
606
  [
607
+ "[0.32565445 0.90939188 0.07488042 0.13730896]",
608
+ "[1.32565451 1.90939188 1.07488036 1.13730896]",
609
+ "[2.5029656887054443, 2.744194984436035, 2.2190115451812744, 2.024731159210205]"
610
  ],
611
  [
612
+ "[0.49607176 0.1922397 0.46640229 0.78321403]",
613
+ "[1.49607182 1.19223976 1.46640229 1.78321409]",
614
+ "[2.4844861030578613, 2.2119431495666504, 2.451040267944336, 2.7876930236816406]"
615
  ],
616
  [
617
+ "[0.85566247 0.83362883 0.48424995 0.25265992]",
618
+ "[1.85566247 1.83362889 1.48424995 1.25265992]",
619
+ "[2.757845401763916, 2.9018521308898926, 2.371169328689575, 2.356513500213623]"
620
  ],
621
  [
622
+ "[0.48959708 0.48549271 0.32688856 0.356677 ]",
623
+ "[1.48959708 1.48549271 1.32688856 1.35667706]",
624
+ "[2.4994757175445557, 2.469498634338379, 2.3010873794555664, 2.3796985149383545]"
625
  ],
626
  [
627
+ "[0.80654246 0.08253473 0.74478531 0.71257162]",
628
+ "[1.8065424 1.08253479 1.74478531 1.71257162]",
629
+ "[2.5831751823425293, 2.2751128673553467, 2.5047459602355957, 2.9079394340515137]"
630
  ]
631
  ]
632
  },
 
688
  "[0.11693293 0.49860179 0.55020827 0.88832849]",
689
  "[1.11693287 1.49860179 1.55020833 1.88832855]"
690
  ],
 
 
 
 
691
  [
692
  "[0.50272274 0.54912758 0.17663097 0.79070699]",
693
  "[1.50272274 1.54912758 1.17663097 1.79070699]"
 
696
  "[0.04508126 0.76880038 0.80721325 0.62542385]",
697
  "[1.04508126 1.76880038 1.80721331 1.62542391]"
698
  ],
699
+ [
700
+ "[0.19908059 0.17570406 0.51475513 0.1893943 ]",
701
+ "[1.19908059 1.175704 1.51475513 1.18939424]"
702
+ ],
703
  [
704
  "[0.40167677 0.25953674 0.9407078 0.76308483]",
705
  "[1.40167677 1.25953674 1.9407078 1.76308489]"
 
716
  "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
717
  "[1.62569475 1.9881897 1.83639622 1.98288584]"
718
  ],
 
 
 
 
719
  [
720
  "[0.88776821 0.51636773 0.30333066 0.32230979]",
721
  "[1.88776827 1.51636767 1.30333066 1.32230973]"
 
732
  "[0.68062544 0.98093534 0.14778823 0.53244978]",
733
  "[1.68062544 1.98093534 1.14778829 1.53244972]"
734
  ],
735
+ [
736
+ "[0.31518555 0.49643308 0.11509258 0.95458382]",
737
+ "[1.31518555 1.49643302 1.11509252 1.95458388]"
738
+ ],
739
  [
740
  "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
741
  "[1.79121017 1.54161119 1.69369793 1.15207696]"
 
752
  "[0.94516498 0.08422136 0.5608117 0.07652664]",
753
  "[1.94516492 1.08422136 1.56081176 1.07652664]"
754
  ],
755
+ [
756
+ "[0.26661873 0.45946234 0.13510543 0.81294441]",
757
+ "[1.26661873 1.4594624 1.13510537 1.81294441]"
758
+ ],
759
  [
760
  "[0.30754459 0.77694583 0.09278506 0.38326019]",
761
  "[1.30754459 1.77694583 1.09278512 1.38326025]"
 
768
  "[0.4827103 0.10563457 0.98858833 0.82286644]",
769
  "[1.48271036 1.10563457 1.98858833 1.82286644]"
770
  ],
 
 
 
 
 
 
 
 
771
  [
772
  "[0.78956431 0.87284744 0.06880784 0.03455889]",
773
  "[1.78956437 1.87284744 1.06880784 1.03455889]"
 
804
  "[0.73217702 0.65233225 0.44077861 0.33837909]",
805
  "[1.73217702 1.65233231 1.44077861 1.33837914]"
806
  ],
807
+ [
808
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
809
+ "[1.34084332 1.73018837 1.54168713 1.91440833]"
810
+ ],
811
  [
812
  "[0.60110539 0.3618983 0.32342511 0.98672163]",
813
  "[1.60110545 1.3618983 1.32342505 1.98672163]"
 
836
  "[0.18720162 0.74115586 0.98626411 0.30355608]",
837
  "[1.18720162 1.74115586 1.98626411 1.30355608]"
838
  ],
 
 
 
 
839
  [
840
  "[0.95928186 0.84273899 0.71514636 0.38619852]",
841
  "[1.95928192 1.84273899 1.7151463 1.38619852]"
842
  ],
843
+ [
844
+ "[0.9829582 0.59269661 0.40120947 0.95487177]",
845
+ "[1.9829582 1.59269667 1.40120947 1.95487177]"
846
+ ],
847
+ [
848
+ "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
849
+ "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
850
+ ],
851
  [
852
  "[0.54914117 0.03810108 0.87531954 0.73044223]",
853
  "[1.54914117 1.03810108 1.87531948 1.73044229]"
 
856
  "[0.67418337 0.79634351 0.23229051 0.71345252]",
857
  "[1.67418337 1.79634356 1.23229051 1.71345258]"
858
  ],
 
 
 
 
859
  [
860
  "[0.81788456 0.58174163 0.29376316 0.7971254 ]",
861
  "[1.81788456 1.58174157 1.29376316 1.79712534]"
 
868
  "[0.60075855 0.12234765 0.00614399 0.30560958]",
869
  "[1.60075855 1.12234759 1.00614405 1.30560958]"
870
  ],
871
+ [
872
+ "[0.39147133 0.29854035 0.84663737 0.58175623]",
873
+ "[1.39147139 1.29854035 1.84663737 1.58175623]"
874
+ ],
875
  [
876
  "[0.02162331 0.81861657 0.92468154 0.07808572]",
877
  "[1.02162337 1.81861663 1.92468154 1.07808566]"
 
900
  "[0.60609657 0.96257663 0.19292736 0.95702219]",
901
  "[1.60609651 1.96257663 1.19292736 1.95702219]"
902
  ],
 
 
 
 
903
  [
904
  "[0.70167565 0.26930219 0.5660674 0.61194974]",
905
  "[1.70167565 1.26930213 1.56606746 1.61194968]"
 
940
  "[0.37959969 0.42820001 0.10690689 0.96353984]",
941
  "[1.37959969 1.42820001 1.10690689 1.96353984]"
942
  ],
 
 
 
 
943
  [
944
  "[0.40234613 0.54987347 0.49542785 0.54153186]",
945
  "[1.40234613 1.54987347 1.49542785 1.5415318 ]"
946
  ],
947
+ [
948
+ "[0.80893755 0.92237449 0.88346356 0.93164903]",
949
+ "[1.80893755 1.92237449 1.88346362 1.93164897]"
950
+ ],
951
  [
952
  "[0.12858278 0.09930819 0.83222693 0.72485673]",
953
  "[1.12858272 1.09930825 1.83222699 1.72485673]"
 
980
  "[0.68094063 0.45189077 0.22661722 0.37354094]",
981
  "[1.68094063 1.45189071 1.22661722 1.37354088]"
982
  ],
983
+ [
984
+ "[0.43681622 0.74680805 0.83598751 0.12414402]",
985
+ "[1.43681622 1.74680805 1.83598757 1.12414408]"
986
+ ],
987
  [
988
  "[0.47870928 0.17129105 0.27300501 0.20634609]",
989
  "[1.47870922 1.17129111 1.27300501 1.20634604]"
 
995
  [
996
  "[0.87608397 0.93200487 0.80169648 0.37758952]",
997
  "[1.87608397 1.93200493 1.80169654 1.37758946]"
 
 
 
 
998
  ]
999
  ]
1000
  }
1001
  },
1002
  "other": {
1003
+ "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_output -> START_Repeat_1_output\n (1) - Linear(4, 4, bias=True): START_Repeat_1_output -> Linear_1_output\n (2) - <function leaky_relu at 0x7b4a644509a0>: Linear_1_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> START_Repeat_1_output\n (4) - Linear(4, 4, bias=True): START_Repeat_1_output -> Linear_1_output\n (5) - <function leaky_relu at 0x7b4a644509a0>: Linear_1_output -> Activation_1_output\n (6) - Identity(): Activation_1_output -> END_Repeat_1_output\n (7) - Identity(): END_Repeat_1_output -> Output_1_x\n (8) - Identity(): Output_1_x -> Output_1_x\n), model_inputs=['Input__tensor_1_output'], model_outputs=['Output_1_x'], loss_inputs=['Input__tensor_3_output', 'Output_1_x'], loss=Sequential(\n (0) - <function constant_vector.<locals>.<lambda> at 0x7b4975350e00>: nothing -> Constant_vector_1_output\n (1) - <built-in method add of type object at 0x7b4a5859ef00>: Input__tensor_3_output, Constant_vector_1_output -> Add_1_output\n (2) - <function mse_loss at 0x7b4a64452480>: Output_1_x, Add_1_output -> MSE_loss_2_output\n (3) - Identity(): MSE_loss_2_output -> loss\n), optimizer_parameters={'lr': 0.1, 'type': <OptionsFor_type.SGD: 4>}, optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace='Model definition', trained=True)"
1004
  },
1005
  "relations": []
1006
  },
 
1032
  "model": {
1033
  "model": {
1034
  "inputs": [
1035
+ "Input__tensor_1_output"
1036
  ],
1037
  "loss_inputs": [
1038
+ "Input__tensor_3_output",
1039
+ "Output_1_x"
1040
  ],
1041
  "outputs": [
1042
+ "Output_1_x"
1043
  ],
1044
  "trained": true
1045
  },
 
1207
  "model": {
1208
  "model": {
1209
  "inputs": [
1210
+ "Input__tensor_1_output"
1211
  ],
1212
  "loss_inputs": [
1213
+ "Input__tensor_3_output",
1214
+ "Output_1_x"
1215
  ],
1216
  "outputs": [
1217
+ "Output_1_x"
1218
  ],
1219
  "trained": false
1220
  },
 
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
+ "epochs": "1001",
1274
+ "input_mapping": "{\"map\":{\"Input__tensor_1_output\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_output\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
1275
  "model_name": "model"
1276
  },
1277
  "status": "done",
 
1319
  "model": {
1320
  "model": {
1321
  "inputs": [
1322
+ "Input__tensor_1_output"
1323
  ],
1324
  "loss_inputs": [
1325
+ "Input__tensor_3_output",
1326
+ "Output_1_x"
1327
  ],
1328
  "outputs": [
1329
+ "Output_1_x"
1330
  ],
1331
  "trained": true
1332
  },
 
1382
  "type": "basic"
1383
  },
1384
  "params": {
1385
+ "input_mapping": "{\"map\":{\"Input__tensor_1_output\":{\"df\":\"df_test\",\"column\":\"x\"}}}",
1386
  "model_name": "model",
1387
+ "output_mapping": "{\"map\":{\"Output_1_x\":{\"df\":\"df_test\",\"column\":\"pred\"}}}"
1388
  },
1389
  "status": "done",
1390
  "title": "Model inference"
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/core.py CHANGED
@@ -1,4 +1,4 @@
1
- """Boxes for defining PyTorch models."""
2
 
3
  import copy
4
  import graphlib
@@ -55,7 +55,10 @@ class Layer:
55
  outputs: list[str]
56
 
57
  def for_sequential(self):
58
- inputs = ", ".join(self.inputs)
 
 
 
59
  outputs = ", ".join(self.outputs)
60
  return self.module, f"{inputs} -> {outputs}"
61
 
@@ -88,8 +91,7 @@ class ModelConfig:
88
  return sum(p.numel() for p in self.model.parameters())
89
 
90
  def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
91
- model_inputs = [inputs[i] for i in self.model_inputs]
92
- output = self.model(*model_inputs)
93
  if not isinstance(output, tuple):
94
  output = (output,)
95
  values = {k: v for k, v in zip(self.model_outputs, output)}
@@ -107,8 +109,7 @@ class ModelConfig:
107
  self.optimizer.zero_grad()
108
  values = self._forward(inputs)
109
  values.update(inputs)
110
- loss_inputs = [values[i] for i in self.loss_inputs]
111
- loss = self.loss(*loss_inputs)
112
  loss.backward()
113
  self.optimizer.step()
114
  return loss.item()
@@ -204,20 +205,20 @@ class ModelBuilder:
204
  # repeat <- last -> real_output
205
  # ...becomes...
206
  # last -> end -> real_output
207
- last, lasth = self.in_edges[repeat]["input"][0]
208
- [(real_output, real_outputh)] = [
209
- k for k in self.out_edges[last][lasth] if k != (repeat, "input")
210
- ]
211
  del self.dependencies[repeat]
212
  self.dependencies[end_id] = [last]
213
- self.dependencies[real_output].append(end_id)
214
  self.out_edges[last][lasth] = [(end_id, "input")]
215
  self.in_edges[end_id] = {"input": [(last, lasth)]}
216
- self.out_edges[end_id] = {"output": [(real_output, real_outputh)]}
217
- self.in_edges[real_output][real_outputh] = [
218
- k if k != (last, lasth) else (end_id, "output")
219
- for k in self.in_edges[real_output][real_outputh]
220
- ]
 
 
 
221
  self.inv_dependencies = {n: [] for n in self.nodes}
222
  for k, v in self.dependencies.items():
223
  for i in v:
@@ -316,18 +317,19 @@ class ModelBuilder:
316
 
317
  def get_config(self) -> ModelConfig:
318
  # Split the design into model and loss.
319
- loss_nodes = set()
320
  for node_id in self.nodes:
321
- if "loss" in self.nodes[node_id].data.title:
322
- loss_nodes.add(node_id)
323
- loss_nodes |= self.all_downstream(node_id)
 
324
  layers = []
325
  loss_layers = []
326
  for layer in self.layers:
327
- if layer.origin_id in loss_nodes:
328
- loss_layers.append(layer)
329
- else:
330
  layers.append(layer)
 
 
331
  used_in_model = set(input for layer in layers for input in layer.inputs)
332
  used_in_loss = set(input for layer in loss_layers for input in layer.inputs)
333
  made_in_model = set(output for layer in layers for output in layer.outputs)
 
1
+ """Infrastructure for defining PyTorch models."""
2
 
3
  import copy
4
  import graphlib
 
55
  outputs: list[str]
56
 
57
  def for_sequential(self):
58
+ """The layer signature for pyg.nn.Sequential."""
59
+ # "nothing" is used as a bogus input if an operation has no inputs.
60
+ # The module in turn needs to take one argument, but it will always be None.
61
+ inputs = ", ".join(self.inputs) or "nothing"
62
  outputs = ", ".join(self.outputs)
63
  return self.module, f"{inputs} -> {outputs}"
64
 
 
91
  return sum(p.numel() for p in self.model.parameters())
92
 
93
  def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
94
+ output = self.model(nothing=None, **inputs)
 
95
  if not isinstance(output, tuple):
96
  output = (output,)
97
  values = {k: v for k, v in zip(self.model_outputs, output)}
 
109
  self.optimizer.zero_grad()
110
  values = self._forward(inputs)
111
  values.update(inputs)
112
+ loss = self.loss(nothing=None, **values)
 
113
  loss.backward()
114
  self.optimizer.step()
115
  return loss.item()
 
205
  # repeat <- last -> real_output
206
  # ...becomes...
207
  # last -> end -> real_output
208
+ [(last, lasth)] = self.in_edges[repeat]["input"]
 
 
 
209
  del self.dependencies[repeat]
210
  self.dependencies[end_id] = [last]
211
+ real_edges = [e for e in self.out_edges[last][lasth] if e != (repeat, "input")]
212
  self.out_edges[last][lasth] = [(end_id, "input")]
213
  self.in_edges[end_id] = {"input": [(last, lasth)]}
214
+ self.out_edges[end_id] = {"output": []} # Populated below.
215
+ for real_output, real_outputh in real_edges:
216
+ self.dependencies[real_output].append(end_id)
217
+ self.in_edges[real_output][real_outputh] = [
218
+ k if k != (last, lasth) else (end_id, "output")
219
+ for k in self.in_edges[real_output][real_outputh]
220
+ ]
221
+ self.out_edges[end_id]["output"].append((real_output, real_outputh))
222
  self.inv_dependencies = {n: [] for n in self.nodes}
223
  for k, v in self.dependencies.items():
224
  for i in v:
 
317
 
318
  def get_config(self) -> ModelConfig:
319
  # Split the design into model and loss.
320
+ model_nodes = set()
321
  for node_id in self.nodes:
322
+ if self.nodes[node_id].data.title == "Output":
323
+ model_nodes.add(node_id)
324
+ model_nodes |= self.all_upstream(node_id)
325
+ assert model_nodes, "The model definition must have at least one Output node."
326
  layers = []
327
  loss_layers = []
328
  for layer in self.layers:
329
+ if layer.origin_id in model_nodes:
 
 
330
  layers.append(layer)
331
+ else:
332
+ loss_layers.append(layer)
333
  used_in_model = set(input for layer in layers for input in layer.inputs)
334
  used_in_loss = set(input for layer in loss_layers for input in layer.inputs)
335
  made_in_model = set(output for layer in layers for output in layer.outputs)
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/ops.py CHANGED
@@ -10,8 +10,14 @@ from .core import op, reg, ENV
10
  reg("Input: tensor", outputs=["output"], params=[P.basic("name")])
11
  reg("Input: graph edges", outputs=["edges"])
12
  reg("Input: sequential", outputs=["y"])
 
 
 
 
 
 
 
13
 
14
- reg("LSTM", inputs=["x", "h"], outputs=["x", "h"])
15
  reg(
16
  "Neural ODE",
17
  inputs=["x"],
@@ -37,9 +43,20 @@ reg(
37
  )
38
 
39
 
40
- reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
41
- reg("LayerNorm", inputs=["x"])
42
- reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  @op("Linear")
@@ -64,17 +81,28 @@ def mse_loss(x, y):
64
  return torch.nn.functional.mse_loss
65
 
66
 
67
- reg("Softmax", inputs=["x"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  reg(
69
  "Graph conv",
70
  inputs=["x", "edges"],
71
  outputs=["x"],
72
  params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])],
73
  )
74
- reg("Concatenate", inputs=["a", "b"], outputs=["x"])
75
- reg("Add", inputs=["a", "b"], outputs=["x"])
76
- reg("Subtract", inputs=["a", "b"], outputs=["x"])
77
- reg("Multiply", inputs=["a", "b"], outputs=["x"])
78
  reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
79
  reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
80
  reg(
@@ -116,3 +144,40 @@ ops.register_passive_op(
116
  outputs=[ops.Output(name="output", position="bottom", type="tensor")],
117
  params=[],
118
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  reg("Input: tensor", outputs=["output"], params=[P.basic("name")])
11
  reg("Input: graph edges", outputs=["edges"])
12
  reg("Input: sequential", outputs=["y"])
13
+ reg("Output", inputs=["x"], outputs=["x"], params=[P.basic("name")])
14
+
15
+
16
+ @op("LSTM")
17
+ def lstm(x, *, input_size=1024, hidden_size=1024, dropout=0.0):
18
+ return torch.nn.LSTM(input_size, hidden_size, dropout=0.0)
19
+
20
 
 
21
  reg(
22
  "Neural ODE",
23
  inputs=["x"],
 
43
  )
44
 
45
 
46
+ @op("Attention", outputs=["outputs", "weights"])
47
+ def attention(query, key, value, *, embed_dim=1024, num_heads=1, dropout=0.0):
48
+ return torch.nn.MultiHeadAttention(embed_dim, num_heads, dropout=dropout, need_weights=True)
49
+
50
+
51
+ @op("LayerNorm", outputs=["outputs", "weights"])
52
+ def layernorm(x, *, normalized_shape=""):
53
+ normalized_shape = [int(s.strip()) for s in normalized_shape.split(",")]
54
+ return torch.nn.LayerNorm(normalized_shape)
55
+
56
+
57
+ @op("Dropout", outputs=["outputs", "weights"])
58
+ def dropout(x, *, p=0.0):
59
+ return torch.nn.Dropout(p)
60
 
61
 
62
  @op("Linear")
 
81
  return torch.nn.functional.mse_loss
82
 
83
 
84
+ @op("Constant vector")
85
+ def constant_vector(*, value=0, size=1):
86
+ return lambda _: torch.full((size,), value)
87
+
88
+
89
+ @op("Softmax")
90
+ def softmax(x, *, dim=1):
91
+ return torch.nn.Softmax(dim=dim)
92
+
93
+
94
+ @op("Concatenate")
95
+ def concatenate(a, b):
96
+ return lambda a, b: torch.concatenate(*torch.broadcast_tensors(a, b))
97
+
98
+
99
  reg(
100
  "Graph conv",
101
  inputs=["x", "edges"],
102
  outputs=["x"],
103
  params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])],
104
  )
105
+
 
 
 
106
  reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
107
  reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
108
  reg(
 
144
  outputs=[ops.Output(name="output", position="bottom", type="tensor")],
145
  params=[],
146
  )
147
+
148
+
149
+ def _set_handle_positions(op):
150
+ op: ops.Op = op.__op__
151
+ for v in op.outputs.values():
152
+ v.position = "top"
153
+ for v in op.inputs.values():
154
+ v.position = "bottom"
155
+
156
+
157
+ def _register_simple_pytorch_layer(func):
158
+ op = ops.op(ENV, func.__name__.title())(lambda input: func)
159
+ _set_handle_positions(op)
160
+
161
+
162
+ def _register_two_tensor_function(func):
163
+ op = ops.op(ENV, func.__name__.title())(lambda a, b: func)
164
+ _set_handle_positions(op)
165
+
166
+
167
+ SIMPLE_FUNCTIONS = [
168
+ torch.sin,
169
+ torch.cos,
170
+ torch.log,
171
+ torch.exp,
172
+ ]
173
+ TWO_TENSOR_FUNCTIONS = [
174
+ torch.multiply,
175
+ torch.add,
176
+ torch.subtract,
177
+ ]
178
+
179
+
180
+ for f in SIMPLE_FUNCTIONS:
181
+ _register_simple_pytorch_layer(f)
182
+ for f in TWO_TENSOR_FUNCTIONS:
183
+ _register_two_tensor_function(f)
lynxkite-graph-analytics/tests/test_pytorch_model_ops.py CHANGED
@@ -48,17 +48,19 @@ async def test_build_model():
48
  ws = make_ws(
49
  pytorch.core.ENV,
50
  {
51
- "emb": {"title": "Input: tensor"},
52
  "lin": {"title": "Linear", "output_dim": 4},
53
  "act": {"title": "Activation", "type": "Leaky_ReLU"},
 
54
  "label": {"title": "Input: tensor"},
55
  "loss": {"title": "MSE loss"},
56
  "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
57
  },
58
  [
59
- ("emb:output", "lin:x"),
60
  ("lin:output", "act:x"),
61
- ("act:output", "loss:x"),
 
62
  ("label:output", "loss:y"),
63
  ("loss:output", "optim:loss"),
64
  ],
@@ -67,10 +69,10 @@ async def test_build_model():
67
  y = x + 1
68
  m = pytorch.core.build_model(ws)
69
  for i in range(1000):
70
- loss = m.train({"emb_output": x, "label_output": y})
71
  assert loss < 0.1
72
- o = m.inference({"emb_output": x[:1]})
73
- error = torch.nn.functional.mse_loss(o["act_output"], x[:1] + 1)
74
  assert error < 0.1
75
 
76
 
@@ -79,18 +81,20 @@ async def test_build_model_with_repeat():
79
  return make_ws(
80
  pytorch.core.ENV,
81
  {
82
- "emb": {"title": "Input: tensor"},
83
  "lin": {"title": "Linear", "output_dim": 8},
84
  "act": {"title": "Activation", "type": "Leaky_ReLU"},
 
85
  "label": {"title": "Input: tensor"},
86
  "loss": {"title": "MSE loss"},
87
  "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
88
  "repeat": {"title": "Repeat", "times": times, "same_weights": False},
89
  },
90
  [
91
- ("emb:output", "lin:x"),
92
  ("lin:output", "act:x"),
93
- ("act:output", "loss:x"),
 
94
  ("label:output", "loss:y"),
95
  ("loss:output", "optim:loss"),
96
  ("repeat:output", "lin:x"),
@@ -100,18 +104,18 @@ async def test_build_model_with_repeat():
100
 
101
  # 1 repetition
102
  m = pytorch.core.build_model(repeated_ws(1))
103
- assert summarize_layers(m) == "IL<II"
104
- assert summarize_connections(m) == "e->S S->l l->a a->E E->E"
105
 
106
  # 2 repetitions
107
  m = pytorch.core.build_model(repeated_ws(2))
108
- assert summarize_layers(m) == "IL<IL<II"
109
- assert summarize_connections(m) == "e->S S->l l->a a->S S->l l->a a->E E->E"
110
 
111
  # 3 repetitions
112
  m = pytorch.core.build_model(repeated_ws(3))
113
- assert summarize_layers(m) == "IL<IL<IL<II"
114
- assert summarize_connections(m) == "e->S S->l l->a a->S S->l l->a a->S S->l l->a a->E E->E"
115
 
116
 
117
  if __name__ == "__main__":
 
48
  ws = make_ws(
49
  pytorch.core.ENV,
50
  {
51
+ "input": {"title": "Input: tensor"},
52
  "lin": {"title": "Linear", "output_dim": 4},
53
  "act": {"title": "Activation", "type": "Leaky_ReLU"},
54
+ "output": {"title": "Output"},
55
  "label": {"title": "Input: tensor"},
56
  "loss": {"title": "MSE loss"},
57
  "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
58
  },
59
  [
60
+ ("input:output", "lin:x"),
61
  ("lin:output", "act:x"),
62
+ ("act:output", "output:x"),
63
+ ("output:x", "loss:x"),
64
  ("label:output", "loss:y"),
65
  ("loss:output", "optim:loss"),
66
  ],
 
69
  y = x + 1
70
  m = pytorch.core.build_model(ws)
71
  for i in range(1000):
72
+ loss = m.train({"input_output": x, "label_output": y})
73
  assert loss < 0.1
74
+ o = m.inference({"input_output": x[:1]})
75
+ error = torch.nn.functional.mse_loss(o["output_x"], x[:1] + 1)
76
  assert error < 0.1
77
 
78
 
 
81
  return make_ws(
82
  pytorch.core.ENV,
83
  {
84
+ "input": {"title": "Input: tensor"},
85
  "lin": {"title": "Linear", "output_dim": 8},
86
  "act": {"title": "Activation", "type": "Leaky_ReLU"},
87
+ "output": {"title": "Output"},
88
  "label": {"title": "Input: tensor"},
89
  "loss": {"title": "MSE loss"},
90
  "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
91
  "repeat": {"title": "Repeat", "times": times, "same_weights": False},
92
  },
93
  [
94
+ ("input:output", "lin:x"),
95
  ("lin:output", "act:x"),
96
+ ("act:output", "output:x"),
97
+ ("output:x", "loss:x"),
98
  ("label:output", "loss:y"),
99
  ("loss:output", "optim:loss"),
100
  ("repeat:output", "lin:x"),
 
104
 
105
  # 1 repetition
106
  m = pytorch.core.build_model(repeated_ws(1))
107
+ assert summarize_layers(m) == "IL<III"
108
+ assert summarize_connections(m) == "i->S S->l l->a a->E E->o o->o"
109
 
110
  # 2 repetitions
111
  m = pytorch.core.build_model(repeated_ws(2))
112
+ assert summarize_layers(m) == "IL<IL<III"
113
+ assert summarize_connections(m) == "i->S S->l l->a a->S S->l l->a a->E E->o o->o"
114
 
115
  # 3 repetitions
116
  m = pytorch.core.build_model(repeated_ws(3))
117
+ assert summarize_layers(m) == "IL<IL<IL<III"
118
+ assert summarize_connections(m) == "i->S S->l l->a a->S S->l l->a a->S S->l l->a a->E E->o o->o"
119
 
120
 
121
  if __name__ == "__main__":