darabos commited on
Commit
2c239c4
·
1 Parent(s): 87f8b85

Each op makes its own layer. #107

Browse files
examples/Model definition CHANGED
@@ -1,52 +1,52 @@
1
  {
2
  "edges": [
3
  {
4
- "id": "Input: embedding 1 Linear 1",
5
- "source": "Input: embedding 1",
6
- "sourceHandle": "x",
7
- "target": "Linear 1",
8
- "targetHandle": "x"
9
- },
10
- {
11
- "id": "Input: label 1 MSE loss 1",
12
- "source": "Input: label 1",
13
- "sourceHandle": "y",
14
- "target": "MSE loss 1",
15
- "targetHandle": "y"
16
  },
17
  {
18
- "id": "Linear 1 Activation 2",
19
- "source": "Linear 1",
20
- "sourceHandle": "x",
21
- "target": "Activation 2",
22
  "targetHandle": "x"
23
  },
24
  {
25
- "id": "Activation 2 MSE loss 1",
26
- "source": "Activation 2",
27
- "sourceHandle": "x",
28
  "target": "MSE loss 1",
29
  "targetHandle": "x"
30
  },
31
  {
32
- "id": "MSE loss 1 Optimizer 2",
33
- "source": "MSE loss 1",
34
- "sourceHandle": "loss",
35
- "target": "Optimizer 2",
36
- "targetHandle": "loss"
37
  },
38
  {
39
- "id": "Activation 2 Repeat 1",
40
- "source": "Activation 2",
41
- "sourceHandle": "x",
42
  "target": "Repeat 1",
43
  "targetHandle": "input"
44
  },
45
  {
46
- "id": "Repeat 1 Linear 1",
47
- "source": "Repeat 1",
48
- "sourceHandle": "output",
49
- "target": "Linear 1",
 
 
 
 
 
 
 
50
  "targetHandle": "x"
51
  }
52
  ],
@@ -58,11 +58,26 @@
58
  "error": null,
59
  "input_metadata": null,
60
  "meta": {
61
- "inputs": {},
62
- "name": "Input: embedding",
63
- "outputs": {
64
  "x": {
65
  "name": "x",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  "position": "top",
67
  "type": {
68
  "type": "tensor"
@@ -74,115 +89,138 @@
74
  },
75
  "params": {},
76
  "status": "planned",
77
- "title": "Input: embedding"
78
  },
79
  "dragHandle": ".bg-primary",
80
  "height": 200.0,
81
- "id": "Input: embedding 1",
82
  "position": {
83
- "x": 91.0,
84
- "y": 266.0
85
  },
86
  "type": "basic",
87
  "width": 200.0
88
  },
89
  {
90
  "data": {
 
 
91
  "display": null,
92
  "error": null,
93
  "input_metadata": null,
94
  "meta": {
95
  "inputs": {
96
- "x": {
97
- "name": "x",
98
  "position": "bottom",
99
  "type": {
100
  "type": "tensor"
101
  }
102
  }
103
  },
104
- "name": "Linear",
105
- "outputs": {
106
- "x": {
107
- "name": "x",
108
- "position": "top",
 
109
  "type": {
110
- "type": "tensor"
111
  }
112
- }
113
- },
114
- "params": {
115
- "output_dim": {
116
- "default": "same",
117
- "name": "output_dim",
118
  "type": {
119
- "type": "<class 'str'>"
 
 
 
 
 
 
 
 
120
  }
121
  }
122
  },
123
  "type": "basic"
124
  },
125
  "params": {
126
- "output_dim": "same"
 
127
  },
128
  "status": "planned",
129
- "title": "Linear"
130
  },
131
  "dragHandle": ".bg-primary",
132
- "height": 200.0,
133
- "id": "Linear 1",
134
  "position": {
135
- "x": 86.0,
136
- "y": 33.0
137
  },
138
  "type": "basic",
139
- "width": 200.0
140
  },
141
  {
142
  "data": {
 
 
143
  "display": null,
144
  "error": null,
145
  "input_metadata": null,
146
  "meta": {
147
  "inputs": {
148
- "x": {
149
- "name": "x",
150
- "position": "bottom",
151
  "type": {
152
  "type": "tensor"
153
  }
154
- },
155
- "y": {
156
- "name": "y",
 
 
 
157
  "position": "bottom",
158
  "type": {
159
  "type": "tensor"
160
  }
161
  }
162
  },
163
- "name": "MSE loss",
164
- "outputs": {
165
- "loss": {
166
- "name": "loss",
167
- "position": "top",
168
  "type": {
169
- "type": "tensor"
 
 
 
 
 
 
 
170
  }
171
  }
172
  },
173
- "params": {},
174
  "type": "basic"
175
  },
176
- "params": {},
 
 
 
177
  "status": "planned",
178
- "title": "MSE loss"
179
  },
180
  "dragHandle": ".bg-primary",
181
  "height": 200.0,
182
- "id": "MSE loss 1",
183
  "position": {
184
- "x": 315.0,
185
- "y": -510.0
186
  },
187
  "type": "basic",
188
  "width": 200.0
@@ -193,30 +231,48 @@
193
  "error": null,
194
  "input_metadata": null,
195
  "meta": {
196
- "inputs": {},
197
- "name": "Input: label",
 
 
 
 
 
 
 
 
198
  "outputs": {
199
- "y": {
200
- "name": "y",
201
  "position": "top",
202
  "type": {
203
- "type": "tensor"
 
 
 
 
 
 
 
 
 
204
  }
205
  }
206
  },
207
- "params": {},
208
  "type": "basic"
209
  },
210
- "params": {},
 
 
211
  "status": "planned",
212
- "title": "Input: label"
213
  },
214
  "dragHandle": ".bg-primary",
215
  "height": 200.0,
216
- "id": "Input: label 1",
217
  "position": {
218
- "x": 615.0,
219
- "y": -165.0
220
  },
221
  "type": "basic",
222
  "width": 200.0
@@ -234,17 +290,17 @@
234
  "name": "x",
235
  "position": "bottom",
236
  "type": {
237
- "type": "tensor"
238
  }
239
  }
240
  },
241
  "name": "Activation",
242
  "outputs": {
243
- "x": {
244
- "name": "x",
245
  "position": "top",
246
  "type": {
247
- "type": "tensor"
248
  }
249
  }
250
  },
@@ -255,27 +311,31 @@
255
  "type": {
256
  "enum": [
257
  "ReLU",
258
- "Leaky ReLU",
259
  "Tanh",
260
  "Mish"
261
  ]
262
  }
263
  }
264
  },
 
 
 
 
265
  "type": "basic"
266
  },
267
  "params": {
268
- "type": "Leaky ReLU"
269
  },
270
  "status": "planned",
271
  "title": "Activation"
272
  },
273
  "dragHandle": ".bg-primary",
274
  "height": 200.0,
275
- "id": "Activation 2",
276
  "position": {
277
- "x": 93.61643829835265,
278
- "y": -229.04087132886406
279
  },
280
  "type": "basic",
281
  "width": 200.0
@@ -288,56 +348,44 @@
288
  "error": null,
289
  "input_metadata": null,
290
  "meta": {
291
- "inputs": {
292
- "loss": {
293
- "name": "loss",
294
- "position": "bottom",
 
 
295
  "type": {
296
  "type": "tensor"
297
  }
298
  }
299
  },
300
- "name": "Optimizer",
301
- "outputs": {},
302
  "params": {
303
- "lr": {
304
- "default": 0.001,
305
- "name": "lr",
306
- "type": {
307
- "type": "<class 'float'>"
308
- }
309
- },
310
- "type": {
311
- "default": "AdamW",
312
- "name": "type",
313
  "type": {
314
- "enum": [
315
- "AdamW",
316
- "Adafactor",
317
- "Adagrad",
318
- "SGD",
319
- "Lion",
320
- "Paged AdamW",
321
- "Galore AdamW"
322
- ]
323
  }
324
  }
325
  },
 
 
 
 
326
  "type": "basic"
327
  },
328
  "params": {
329
- "lr": "0.1",
330
- "type": "SGD"
331
  },
332
  "status": "planned",
333
- "title": "Optimizer"
334
  },
335
  "dragHandle": ".bg-primary",
336
  "height": 200.0,
337
- "id": "Optimizer 2",
338
  "position": {
339
- "x": 305.6132943499785,
340
- "y": -804.0094318451224
341
  },
342
  "type": "basic",
343
  "width": 200.0
@@ -350,60 +398,44 @@
350
  "error": null,
351
  "input_metadata": null,
352
  "meta": {
353
- "inputs": {
354
- "input": {
355
- "name": "input",
356
- "position": "top",
357
- "type": {
358
- "type": "tensor"
359
- }
360
- }
361
- },
362
- "name": "Repeat",
363
  "outputs": {
364
- "output": {
365
- "name": "output",
366
- "position": "bottom",
367
  "type": {
368
  "type": "tensor"
369
  }
370
  }
371
  },
372
  "params": {
373
- "same_weights": {
374
- "default": true,
375
- "name": "same_weights",
376
- "type": {
377
- "type": "<class 'bool'>"
378
- }
379
- },
380
- "times": {
381
- "default": 1.0,
382
- "name": "times",
383
  "type": {
384
- "type": "<class 'int'>"
385
  }
386
  }
387
  },
388
  "position": {
389
- "x": 386.0,
390
- "y": 456.0
391
  },
392
  "type": "basic"
393
  },
394
  "params": {
395
- "same_weights": false,
396
- "times": "2"
397
  },
398
  "status": "planned",
399
- "title": "Repeat"
400
  },
401
  "dragHandle": ".bg-primary",
402
  "height": 200.0,
403
- "id": "Repeat 1",
404
  "position": {
405
- "x": -180.0,
406
- "y": -90.0
407
  },
408
  "type": "basic",
409
  "width": 200.0
 
1
  {
2
  "edges": [
3
  {
4
+ "id": "MSE loss 1 Optimizer 2",
5
+ "source": "MSE loss 1",
6
+ "sourceHandle": "loss",
7
+ "target": "Optimizer 2",
8
+ "targetHandle": "loss"
 
 
 
 
 
 
 
9
  },
10
  {
11
+ "id": "Repeat 1 Linear 2",
12
+ "source": "Repeat 1",
13
+ "sourceHandle": "output",
14
+ "target": "Linear 2",
15
  "targetHandle": "x"
16
  },
17
  {
18
+ "id": "Activation 1 MSE loss 1",
19
+ "source": "Activation 1",
20
+ "sourceHandle": "output",
21
  "target": "MSE loss 1",
22
  "targetHandle": "x"
23
  },
24
  {
25
+ "id": "Linear 2 Activation 1",
26
+ "source": "Linear 2",
27
+ "sourceHandle": "output",
28
+ "target": "Activation 1",
29
+ "targetHandle": "x"
30
  },
31
  {
32
+ "id": "Activation 1 Repeat 1",
33
+ "source": "Activation 1",
34
+ "sourceHandle": "output",
35
  "target": "Repeat 1",
36
  "targetHandle": "input"
37
  },
38
  {
39
+ "id": "Input: tensor 3 MSE loss 1",
40
+ "source": "Input: tensor 3",
41
+ "sourceHandle": "x",
42
+ "target": "MSE loss 1",
43
+ "targetHandle": "y"
44
+ },
45
+ {
46
+ "id": "Input: tensor 1 Linear 2",
47
+ "source": "Input: tensor 1",
48
+ "sourceHandle": "x",
49
+ "target": "Linear 2",
50
  "targetHandle": "x"
51
  }
52
  ],
 
58
  "error": null,
59
  "input_metadata": null,
60
  "meta": {
61
+ "inputs": {
 
 
62
  "x": {
63
  "name": "x",
64
+ "position": "bottom",
65
+ "type": {
66
+ "type": "tensor"
67
+ }
68
+ },
69
+ "y": {
70
+ "name": "y",
71
+ "position": "bottom",
72
+ "type": {
73
+ "type": "tensor"
74
+ }
75
+ }
76
+ },
77
+ "name": "MSE loss",
78
+ "outputs": {
79
+ "loss": {
80
+ "name": "loss",
81
  "position": "top",
82
  "type": {
83
  "type": "tensor"
 
89
  },
90
  "params": {},
91
  "status": "planned",
92
+ "title": "MSE loss"
93
  },
94
  "dragHandle": ".bg-primary",
95
  "height": 200.0,
96
+ "id": "MSE loss 1",
97
  "position": {
98
+ "x": 315.0,
99
+ "y": -510.0
100
  },
101
  "type": "basic",
102
  "width": 200.0
103
  },
104
  {
105
  "data": {
106
+ "__execution_delay": 0.0,
107
+ "collapsed": null,
108
  "display": null,
109
  "error": null,
110
  "input_metadata": null,
111
  "meta": {
112
  "inputs": {
113
+ "loss": {
114
+ "name": "loss",
115
  "position": "bottom",
116
  "type": {
117
  "type": "tensor"
118
  }
119
  }
120
  },
121
+ "name": "Optimizer",
122
+ "outputs": {},
123
+ "params": {
124
+ "lr": {
125
+ "default": 0.001,
126
+ "name": "lr",
127
  "type": {
128
+ "type": "<class 'float'>"
129
  }
130
+ },
131
+ "type": {
132
+ "default": "AdamW",
133
+ "name": "type",
 
 
134
  "type": {
135
+ "enum": [
136
+ "AdamW",
137
+ "Adafactor",
138
+ "Adagrad",
139
+ "SGD",
140
+ "Lion",
141
+ "Paged AdamW",
142
+ "Galore AdamW"
143
+ ]
144
  }
145
  }
146
  },
147
  "type": "basic"
148
  },
149
  "params": {
150
+ "lr": "0.1",
151
+ "type": "SGD"
152
  },
153
  "status": "planned",
154
+ "title": "Optimizer"
155
  },
156
  "dragHandle": ".bg-primary",
157
+ "height": 250.0,
158
+ "id": "Optimizer 2",
159
  "position": {
160
+ "x": 292.3983313429414,
161
+ "y": -853.8015246037802
162
  },
163
  "type": "basic",
164
+ "width": 232.0
165
  },
166
  {
167
  "data": {
168
+ "__execution_delay": 0.0,
169
+ "collapsed": null,
170
  "display": null,
171
  "error": null,
172
  "input_metadata": null,
173
  "meta": {
174
  "inputs": {
175
+ "input": {
176
+ "name": "input",
177
+ "position": "top",
178
  "type": {
179
  "type": "tensor"
180
  }
181
+ }
182
+ },
183
+ "name": "Repeat",
184
+ "outputs": {
185
+ "output": {
186
+ "name": "output",
187
  "position": "bottom",
188
  "type": {
189
  "type": "tensor"
190
  }
191
  }
192
  },
193
+ "params": {
194
+ "same_weights": {
195
+ "default": false,
196
+ "name": "same_weights",
 
197
  "type": {
198
+ "type": "<class 'bool'>"
199
+ }
200
+ },
201
+ "times": {
202
+ "default": 1.0,
203
+ "name": "times",
204
+ "type": {
205
+ "type": "<class 'int'>"
206
  }
207
  }
208
  },
 
209
  "type": "basic"
210
  },
211
+ "params": {
212
+ "same_weights": false,
213
+ "times": "3"
214
+ },
215
  "status": "planned",
216
+ "title": "Repeat"
217
  },
218
  "dragHandle": ".bg-primary",
219
  "height": 200.0,
220
+ "id": "Repeat 1",
221
  "position": {
222
+ "x": -180.0,
223
+ "y": -90.0
224
  },
225
  "type": "basic",
226
  "width": 200.0
 
231
  "error": null,
232
  "input_metadata": null,
233
  "meta": {
234
+ "inputs": {
235
+ "x": {
236
+ "name": "x",
237
+ "position": "bottom",
238
+ "type": {
239
+ "type": "<class 'inspect._empty'>"
240
+ }
241
+ }
242
+ },
243
+ "name": "Linear",
244
  "outputs": {
245
+ "output": {
246
+ "name": "output",
247
  "position": "top",
248
  "type": {
249
+ "type": "None"
250
+ }
251
+ }
252
+ },
253
+ "params": {
254
+ "output_dim": {
255
+ "default": "same",
256
+ "name": "output_dim",
257
+ "type": {
258
+ "type": "<class 'str'>"
259
  }
260
  }
261
  },
 
262
  "type": "basic"
263
  },
264
+ "params": {
265
+ "output_dim": "same"
266
+ },
267
  "status": "planned",
268
+ "title": "Linear"
269
  },
270
  "dragHandle": ".bg-primary",
271
  "height": 200.0,
272
+ "id": "Linear 2",
273
  "position": {
274
+ "x": 92.32755761444682,
275
+ "y": 20.626371289630676
276
  },
277
  "type": "basic",
278
  "width": 200.0
 
290
  "name": "x",
291
  "position": "bottom",
292
  "type": {
293
+ "type": "<class 'inspect._empty'>"
294
  }
295
  }
296
  },
297
  "name": "Activation",
298
  "outputs": {
299
+ "output": {
300
+ "name": "output",
301
  "position": "top",
302
  "type": {
303
+ "type": "None"
304
  }
305
  }
306
  },
 
311
  "type": {
312
  "enum": [
313
  "ReLU",
314
+ "Leaky_ReLU",
315
  "Tanh",
316
  "Mish"
317
  ]
318
  }
319
  }
320
  },
321
+ "position": {
322
+ "x": 344.0,
323
+ "y": 384.0
324
+ },
325
  "type": "basic"
326
  },
327
  "params": {
328
+ "type": "Leaky_ReLU"
329
  },
330
  "status": "planned",
331
  "title": "Activation"
332
  },
333
  "dragHandle": ".bg-primary",
334
  "height": 200.0,
335
+ "id": "Activation 1",
336
  "position": {
337
+ "x": 99.77615018185415,
338
+ "y": -249.43925929074078
339
  },
340
  "type": "basic",
341
  "width": 200.0
 
348
  "error": null,
349
  "input_metadata": null,
350
  "meta": {
351
+ "inputs": {},
352
+ "name": "Input: tensor",
353
+ "outputs": {
354
+ "x": {
355
+ "name": "x",
356
+ "position": "top",
357
  "type": {
358
  "type": "tensor"
359
  }
360
  }
361
  },
 
 
362
  "params": {
363
+ "name": {
364
+ "default": null,
365
+ "name": "name",
 
 
 
 
 
 
 
366
  "type": {
367
+ "type": "None"
 
 
 
 
 
 
 
 
368
  }
369
  }
370
  },
371
+ "position": {
372
+ "x": 258.0,
373
+ "y": 397.0
374
+ },
375
  "type": "basic"
376
  },
377
  "params": {
378
+ "name": "X"
 
379
  },
380
  "status": "planned",
381
+ "title": "Input: tensor"
382
  },
383
  "dragHandle": ".bg-primary",
384
  "height": 200.0,
385
+ "id": "Input: tensor 1",
386
  "position": {
387
+ "x": 97.60681762952905,
388
+ "y": 293.6278596776366
389
  },
390
  "type": "basic",
391
  "width": 200.0
 
398
  "error": null,
399
  "input_metadata": null,
400
  "meta": {
401
+ "inputs": {},
402
+ "name": "Input: tensor",
 
 
 
 
 
 
 
 
403
  "outputs": {
404
+ "x": {
405
+ "name": "x",
406
+ "position": "top",
407
  "type": {
408
  "type": "tensor"
409
  }
410
  }
411
  },
412
  "params": {
413
+ "name": {
414
+ "default": null,
415
+ "name": "name",
 
 
 
 
 
 
 
416
  "type": {
417
+ "type": "None"
418
  }
419
  }
420
  },
421
  "position": {
422
+ "x": 1169.0,
423
+ "y": 340.0
424
  },
425
  "type": "basic"
426
  },
427
  "params": {
428
+ "name": "Y"
 
429
  },
430
  "status": "planned",
431
+ "title": "Input: tensor"
432
  },
433
  "dragHandle": ".bg-primary",
434
  "height": 200.0,
435
+ "id": "Input: tensor 3",
436
  "position": {
437
+ "x": 862.4359094222825,
438
+ "y": -290.0677203273021
439
  },
440
  "type": "basic",
441
  "width": 200.0
examples/Model use CHANGED
@@ -579,54 +579,54 @@
579
  ],
580
  "data": [
581
  [
582
- "[0.67418337 0.79634351 0.23229051 0.71345252]",
583
- "[1.67418337 1.79634356 1.23229051 1.71345258]",
584
- "[1.2304607629776, 1.2535771131515503, 1.4764351844787598, 1.524468183517456]"
585
  ],
586
  [
587
- "[0.6032477 0.83361369 0.18538666 0.19108021]",
588
- "[1.60324764 1.83361363 1.18538666 1.19108021]",
589
- "[1.0745662450790405, 1.3498694896697998, 1.524811029434204, 1.2336101531982422]"
590
  ],
591
  [
592
- "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
593
- "[1.62569475 1.9881897 1.83639622 1.98288584]",
594
- "[1.3193368911743164, 1.691685438156128, 1.4516379833221436, 1.6486209630966187]"
595
  ],
596
  [
597
- "[0.63235509 0.70352674 0.96188956 0.46240485]",
598
- "[1.63235509 1.70352674 1.96188951 1.46240485]",
599
- "[1.298614740371704, 1.7463937997817993, 1.4211854934692383, 1.2179659605026245]"
600
  ],
601
  [
602
- "[0.74064726 0.4155122 0.09800029 0.49930882]",
603
- "[1.74064732 1.4155122 1.09800029 1.49930882]",
604
- "[1.2740169763565063, 1.0242725610733032, 1.413628339767456, 1.288135290145874]"
605
  ],
606
  [
607
- "[0.34084332 0.73018837 0.54168713 0.91440833]",
608
- "[1.34084332 1.73018837 1.54168713 1.91440833]",
609
- "[1.0640922784805298, 1.3728828430175781, 1.1964272260665894, 1.5791122913360596]"
610
  ],
611
  [
612
- "[0.80654246 0.08253473 0.74478531 0.71257162]",
613
- "[1.8065424 1.08253479 1.74478531 1.71257162]",
614
- "[1.521227240562439, 1.247929334640503, 1.2715831995010376, 1.1901893615722656]"
615
  ],
616
  [
617
- "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
618
- "[1.11560345 1.57495475 1.76535821 1.0391947 ]",
619
- "[0.8016152381896973, 1.6247310638427734, 1.1153241395950317, 0.9691400527954102]"
620
  ],
621
  [
622
- "[0.68094063 0.45189077 0.22661722 0.37354094]",
623
- "[1.68094063 1.45189071 1.22661722 1.37354088]",
624
- "[1.224153995513916, 1.1527836322784424, 1.4024348258972168, 1.204542636871338]"
625
  ],
626
  [
627
- "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
628
- "[1.79121017 1.54161119 1.69369793 1.15207696]",
629
- "[1.3464093208312988, 1.5594894886016846, 1.5191831588745117, 1.0183898210525513]"
630
  ]
631
  ]
632
  },
@@ -644,6 +644,10 @@
644
  "[0.85706753 0.61447072 0.41741937 0.85147089]",
645
  "[1.85706758 1.61447072 1.41741943 1.85147095]"
646
  ],
 
 
 
 
647
  [
648
  "[0.19409031 0.68692201 0.60667384 0.57829887]",
649
  "[1.19409037 1.68692207 1.60667384 1.57829881]"
@@ -712,6 +716,10 @@
712
  "[0.24388778 0.07268471 0.68350857 0.73431659]",
713
  "[1.24388778 1.07268476 1.68350863 1.73431659]"
714
  ],
 
 
 
 
715
  [
716
  "[0.56922203 0.98222166 0.76851749 0.28615737]",
717
  "[1.56922197 1.9822216 1.76851749 1.28615737]"
@@ -733,8 +741,8 @@
733
  "[1.68062544 1.98093534 1.14778829 1.53244972]"
734
  ],
735
  [
736
- "[0.31518555 0.49643308 0.11509258 0.95458382]",
737
- "[1.31518555 1.49643302 1.11509252 1.95458388]"
738
  ],
739
  [
740
  "[0.79423058 0.07138705 0.061777 0.18766576]",
@@ -769,16 +777,12 @@
769
  "[1.98033333 1.97656083 1.38939917 1.81491041]"
770
  ],
771
  [
772
- "[0.78956431 0.87284744 0.06880784 0.03455889]",
773
- "[1.78956437 1.87284744 1.06880784 1.03455889]"
774
- ],
775
- [
776
- "[0.94221359 0.57740951 0.98649532 0.40934443]",
777
- "[1.94221354 1.57740951 1.98649526 1.40934443]"
778
  ],
779
  [
780
- "[0.00497234 0.39319336 0.57054168 0.75150961]",
781
- "[1.00497234 1.39319336 1.57054162 1.75150967]"
782
  ],
783
  [
784
  "[0.44330525 0.09997386 0.89025736 0.90507984]",
@@ -832,10 +836,6 @@
832
  "[0.18720162 0.74115586 0.98626411 0.30355608]",
833
  "[1.18720162 1.74115586 1.98626411 1.30355608]"
834
  ],
835
- [
836
- "[0.85566247 0.83362883 0.48424995 0.25265992]",
837
- "[1.85566247 1.83362889 1.48424995 1.25265992]"
838
- ],
839
  [
840
  "[0.95928186 0.84273899 0.71514636 0.38619852]",
841
  "[1.95928192 1.84273899 1.7151463 1.38619852]"
@@ -844,10 +844,6 @@
844
  "[0.32565445 0.90939188 0.07488042 0.13730896]",
845
  "[1.32565451 1.90939188 1.07488036 1.13730896]"
846
  ],
847
- [
848
- "[0.9829582 0.59269661 0.40120947 0.95487177]",
849
- "[1.9829582 1.59269667 1.40120947 1.95487177]"
850
- ],
851
  [
852
  "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
853
  "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
@@ -856,6 +852,10 @@
856
  "[0.54914117 0.03810108 0.87531954 0.73044223]",
857
  "[1.54914117 1.03810108 1.87531948 1.73044229]"
858
  ],
 
 
 
 
859
  [
860
  "[0.87285906 0.48354989 0.39394957 0.59456545]",
861
  "[1.872859 1.48354983 1.39394951 1.59456539]"
@@ -876,14 +876,6 @@
876
  "[0.39147133 0.29854035 0.84663737 0.58175623]",
877
  "[1.39147139 1.29854035 1.84663737 1.58175623]"
878
  ],
879
- [
880
- "[0.02162331 0.81861657 0.92468154 0.07808572]",
881
- "[1.02162337 1.81861663 1.92468154 1.07808566]"
882
- ],
883
- [
884
- "[0.02235305 0.52774918 0.7331115 0.84358269]",
885
- "[1.02235305 1.52774918 1.7331115 1.84358263]"
886
- ],
887
  [
888
  "[0.6080932 0.56563014 0.32107437 0.72599429]",
889
  "[1.60809326 1.5656302 1.32107437 1.72599435]"
@@ -904,6 +896,10 @@
904
  "[0.60609657 0.96257663 0.19292736 0.95702219]",
905
  "[1.60609651 1.96257663 1.19292736 1.95702219]"
906
  ],
 
 
 
 
907
  [
908
  "[0.70167565 0.26930219 0.5660674 0.61194974]",
909
  "[1.70167565 1.26930213 1.56606746 1.61194968]"
@@ -912,10 +908,6 @@
912
  "[0.76933283 0.86241865 0.44114518 0.65644735]",
913
  "[1.76933289 1.86241865 1.44114518 1.65644741]"
914
  ],
915
- [
916
- "[0.59492421 0.90274489 0.38069052 0.46101224]",
917
- "[1.59492421 1.90274489 1.38069057 1.46101224]"
918
- ],
919
  [
920
  "[0.15064228 0.03198934 0.25754827 0.51484001]",
921
  "[1.15064228 1.03198934 1.25754833 1.51484001]"
@@ -932,6 +924,14 @@
932
  "[0.49691743 0.61873293 0.90698647 0.94486356]",
933
  "[1.49691749 1.61873293 1.90698647 1.94486356]"
934
  ],
 
 
 
 
 
 
 
 
935
  [
936
  "[0.37959969 0.42820001 0.10690689 0.96353984]",
937
  "[1.37959969 1.42820001 1.10690689 1.96353984]"
@@ -960,10 +960,6 @@
960
  "[0.47856545 0.46267092 0.6376707 0.84747767]",
961
  "[1.47856545 1.46267092 1.63767076 1.84747767]"
962
  ],
963
- [
964
- "[0.49584109 0.80599248 0.07096875 0.75872749]",
965
- "[1.49584103 1.80599248 1.07096875 1.75872755]"
966
- ],
967
  [
968
  "[0.43500566 0.66041756 0.80293626 0.96224713]",
969
  "[1.43500566 1.66041756 1.80293632 1.96224713]"
@@ -976,6 +972,10 @@
976
  "[0.28942841 0.05601001 0.33039129 0.27781558]",
977
  "[1.28942847 1.05601001 1.33039129 1.27781558]"
978
  ],
 
 
 
 
979
  [
980
  "[0.43681622 0.74680805 0.83598751 0.12414402]",
981
  "[1.43681622 1.74680805 1.83598757 1.12414408]"
@@ -1000,7 +1000,7 @@
1000
  }
1001
  },
1002
  "other": {
1003
- "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__embedding_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_1_x\n (2) - <function leaky_relu at 0x7e6e1cf2c220>: Linear_1_x -> Activation_2_x\n (3) - Identity(): Activation_2_x -> END_Repeat_1_output\n (4) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__embedding_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['Input__label_1_y', 'END_Repeat_1_output'], loss=Sequential(\n (0) - <function mse_loss at 0x7e6e1cf2dd00>: END_Repeat_1_output, Input__label_1_y -> MSE_loss_1_loss\n (1) - Identity(): MSE_loss_1_loss -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
  },
1005
  "relations": []
1006
  },
@@ -1032,11 +1032,11 @@
1032
  "model": {
1033
  "model": {
1034
  "inputs": [
1035
- "Input__embedding_1_x"
1036
  ],
1037
  "loss_inputs": [
1038
- "Input__label_1_y",
1039
- "END_Repeat_1_output"
1040
  ],
1041
  "outputs": [
1042
  "END_Repeat_1_output"
@@ -1207,11 +1207,11 @@
1207
  "model": {
1208
  "model": {
1209
  "inputs": [
1210
- "Input__embedding_1_x"
1211
  ],
1212
  "loss_inputs": [
1213
- "Input__label_1_y",
1214
- "END_Repeat_1_output"
1215
  ],
1216
  "outputs": [
1217
  "END_Repeat_1_output"
@@ -1270,8 +1270,8 @@
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
- "epochs": "15",
1274
- "input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"column\":\"x\",\"df\":\"df_train\"},\"Input__label_1_y\":{\"column\":\"y\",\"df\":\"df_train\"}}}",
1275
  "model_name": "model"
1276
  },
1277
  "status": "done",
@@ -1319,11 +1319,11 @@
1319
  "model": {
1320
  "model": {
1321
  "inputs": [
1322
- "Input__embedding_1_x"
1323
  ],
1324
  "loss_inputs": [
1325
- "Input__label_1_y",
1326
- "END_Repeat_1_output"
1327
  ],
1328
  "outputs": [
1329
  "END_Repeat_1_output"
@@ -1382,7 +1382,7 @@
1382
  "type": "basic"
1383
  },
1384
  "params": {
1385
- "input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"df\":\"df_test\",\"column\":\"x\"}}}",
1386
  "model_name": "model",
1387
  "output_mapping": "{\"map\":{\"END_Repeat_1_output\":{\"df\":\"df_test\",\"column\":\"predicted\"}}}"
1388
  },
 
579
  ],
580
  "data": [
581
  [
582
+ "[0.31518555 0.49643308 0.11509258 0.95458382]",
583
+ "[1.31518555 1.49643302 1.11509252 1.95458388]",
584
+ "[1.3819222450256348, -0.005686390213668346, 1.3793643712997437, 1.581865906715393]"
585
  ],
586
  [
587
+ "[0.02162331 0.81861657 0.92468154 0.07808572]",
588
+ "[1.02162337 1.81861663 1.92468154 1.07808566]",
589
+ "[1.312654972076416, -0.00689137727022171, 1.4941580295562744, 1.243792176246643]"
590
  ],
591
  [
592
+ "[0.94221359 0.57740951 0.98649532 0.40934443]",
593
+ "[1.94221354 1.57740951 1.98649526 1.40934443]",
594
+ "[1.9255921840667725, -0.008701151236891747, 1.751355767250061, 1.79597806930542]"
595
  ],
596
  [
597
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
598
+ "[1.34084332 1.73018837 1.54168713 1.91440833]",
599
+ "[1.6509568691253662, -0.007272087037563324, 1.5942981243133545, 1.81572687625885]"
600
  ],
601
  [
602
+ "[0.85566247 0.83362883 0.48424995 0.25265992]",
603
+ "[1.85566247 1.83362889 1.48424995 1.25265992]",
604
+ "[1.7482354640960693, -0.0063837491907179356, 1.4504402875900269, 1.5329445600509644]"
605
  ],
606
  [
607
+ "[0.02235305 0.52774918 0.7331115 0.84358269]",
608
+ "[1.02235305 1.52774918 1.7331115 1.84358263]",
609
+ "[1.3979142904281616, -0.007555779069662094, 1.6136289834976196, 1.6417407989501953]"
610
  ],
611
  [
612
+ "[0.9829582 0.59269661 0.40120947 0.95487177]",
613
+ "[1.9829582 1.59269667 1.40120947 1.95487177]",
614
+ "[1.9523842334747314, -0.00748100271448493, 1.6264307498931885, 1.9942888021469116]"
615
  ],
616
  [
617
+ "[0.49584109 0.80599248 0.07096875 0.75872749]",
618
+ "[1.49584103 1.80599248 1.07096875 1.75872755]",
619
+ "[1.5513110160827637, -0.005337317008525133, 1.3384482860565186, 1.5973539352416992]"
620
  ],
621
  [
622
+ "[0.00497234 0.39319336 0.57054168 0.75150961]",
623
+ "[1.00497234 1.39319336 1.57054162 1.75150967]",
624
+ "[1.2277441024780273, -0.0067505668848752975, 1.4969637393951416, 1.4524610042572021]"
625
  ],
626
  [
627
+ "[0.59492421 0.90274489 0.38069052 0.46101224]",
628
+ "[1.59492421 1.90274489 1.38069057 1.46101224]",
629
+ "[1.6593225002288818, -0.006088308058679104, 1.4240546226501465, 1.570335865020752]"
630
  ]
631
  ]
632
  },
 
644
  "[0.85706753 0.61447072 0.41741937 0.85147089]",
645
  "[1.85706758 1.61447072 1.41741943 1.85147095]"
646
  ],
647
+ [
648
+ "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
649
+ "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
650
+ ],
651
  [
652
  "[0.19409031 0.68692201 0.60667384 0.57829887]",
653
  "[1.19409037 1.68692207 1.60667384 1.57829881]"
 
716
  "[0.24388778 0.07268471 0.68350857 0.73431659]",
717
  "[1.24388778 1.07268476 1.68350863 1.73431659]"
718
  ],
719
+ [
720
+ "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
721
+ "[1.62569475 1.9881897 1.83639622 1.98288584]"
722
+ ],
723
  [
724
  "[0.56922203 0.98222166 0.76851749 0.28615737]",
725
  "[1.56922197 1.9822216 1.76851749 1.28615737]"
 
741
  "[1.68062544 1.98093534 1.14778829 1.53244972]"
742
  ],
743
  [
744
+ "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
745
+ "[1.79121017 1.54161119 1.69369793 1.15207696]"
746
  ],
747
  [
748
  "[0.79423058 0.07138705 0.061777 0.18766576]",
 
777
  "[1.98033333 1.97656083 1.38939917 1.81491041]"
778
  ],
779
  [
780
+ "[0.74064726 0.4155122 0.09800029 0.49930882]",
781
+ "[1.74064732 1.4155122 1.09800029 1.49930882]"
 
 
 
 
782
  ],
783
  [
784
+ "[0.78956431 0.87284744 0.06880784 0.03455889]",
785
+ "[1.78956437 1.87284744 1.06880784 1.03455889]"
786
  ],
787
  [
788
  "[0.44330525 0.09997386 0.89025736 0.90507984]",
 
836
  "[0.18720162 0.74115586 0.98626411 0.30355608]",
837
  "[1.18720162 1.74115586 1.98626411 1.30355608]"
838
  ],
 
 
 
 
839
  [
840
  "[0.95928186 0.84273899 0.71514636 0.38619852]",
841
  "[1.95928192 1.84273899 1.7151463 1.38619852]"
 
844
  "[0.32565445 0.90939188 0.07488042 0.13730896]",
845
  "[1.32565451 1.90939188 1.07488036 1.13730896]"
846
  ],
 
 
 
 
847
  [
848
  "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
849
  "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
 
852
  "[0.54914117 0.03810108 0.87531954 0.73044223]",
853
  "[1.54914117 1.03810108 1.87531948 1.73044229]"
854
  ],
855
+ [
856
+ "[0.67418337 0.79634351 0.23229051 0.71345252]",
857
+ "[1.67418337 1.79634356 1.23229051 1.71345258]"
858
+ ],
859
  [
860
  "[0.87285906 0.48354989 0.39394957 0.59456545]",
861
  "[1.872859 1.48354983 1.39394951 1.59456539]"
 
876
  "[0.39147133 0.29854035 0.84663737 0.58175623]",
877
  "[1.39147139 1.29854035 1.84663737 1.58175623]"
878
  ],
 
 
 
 
 
 
 
 
879
  [
880
  "[0.6080932 0.56563014 0.32107437 0.72599429]",
881
  "[1.60809326 1.5656302 1.32107437 1.72599435]"
 
896
  "[0.60609657 0.96257663 0.19292736 0.95702219]",
897
  "[1.60609651 1.96257663 1.19292736 1.95702219]"
898
  ],
899
+ [
900
+ "[0.80654246 0.08253473 0.74478531 0.71257162]",
901
+ "[1.8065424 1.08253479 1.74478531 1.71257162]"
902
+ ],
903
  [
904
  "[0.70167565 0.26930219 0.5660674 0.61194974]",
905
  "[1.70167565 1.26930213 1.56606746 1.61194968]"
 
908
  "[0.76933283 0.86241865 0.44114518 0.65644735]",
909
  "[1.76933289 1.86241865 1.44114518 1.65644741]"
910
  ],
 
 
 
 
911
  [
912
  "[0.15064228 0.03198934 0.25754827 0.51484001]",
913
  "[1.15064228 1.03198934 1.25754833 1.51484001]"
 
924
  "[0.49691743 0.61873293 0.90698647 0.94486356]",
925
  "[1.49691749 1.61873293 1.90698647 1.94486356]"
926
  ],
927
+ [
928
+ "[0.6032477 0.83361369 0.18538666 0.19108021]",
929
+ "[1.60324764 1.83361363 1.18538666 1.19108021]"
930
+ ],
931
+ [
932
+ "[0.63235509 0.70352674 0.96188956 0.46240485]",
933
+ "[1.63235509 1.70352674 1.96188951 1.46240485]"
934
+ ],
935
  [
936
  "[0.37959969 0.42820001 0.10690689 0.96353984]",
937
  "[1.37959969 1.42820001 1.10690689 1.96353984]"
 
960
  "[0.47856545 0.46267092 0.6376707 0.84747767]",
961
  "[1.47856545 1.46267092 1.63767076 1.84747767]"
962
  ],
 
 
 
 
963
  [
964
  "[0.43500566 0.66041756 0.80293626 0.96224713]",
965
  "[1.43500566 1.66041756 1.80293632 1.96224713]"
 
972
  "[0.28942841 0.05601001 0.33039129 0.27781558]",
973
  "[1.28942847 1.05601001 1.33039129 1.27781558]"
974
  ],
975
+ [
976
+ "[0.68094063 0.45189077 0.22661722 0.37354094]",
977
+ "[1.68094063 1.45189071 1.22661722 1.37354088]"
978
+ ],
979
  [
980
  "[0.43681622 0.74680805 0.83598751 0.12414402]",
981
  "[1.43681622 1.74680805 1.83598757 1.12414408]"
 
1000
  }
1001
  },
1002
  "other": {
1003
+ "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at 0x759513340220>: Linear_2_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> END_Repeat_1_output\n (4) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['END_Repeat_1_output', 'Input__tensor_3_x'], loss=Sequential(\n (0) - <function mse_loss at 0x759513341d00>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_1_loss\n (1) - Identity(): MSE_loss_1_loss -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
  },
1005
  "relations": []
1006
  },
 
1032
  "model": {
1033
  "model": {
1034
  "inputs": [
1035
+ "Input__tensor_1_x"
1036
  ],
1037
  "loss_inputs": [
1038
+ "END_Repeat_1_output",
1039
+ "Input__tensor_3_x"
1040
  ],
1041
  "outputs": [
1042
  "END_Repeat_1_output"
 
1207
  "model": {
1208
  "model": {
1209
  "inputs": [
1210
+ "Input__tensor_1_x"
1211
  ],
1212
  "loss_inputs": [
1213
+ "END_Repeat_1_output",
1214
+ "Input__tensor_3_x"
1215
  ],
1216
  "outputs": [
1217
  "END_Repeat_1_output"
 
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
+ "epochs": "150",
1274
+ "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
1275
  "model_name": "model"
1276
  },
1277
  "status": "done",
 
1319
  "model": {
1320
  "model": {
1321
  "inputs": [
1322
+ "Input__tensor_1_x"
1323
  ],
1324
  "loss_inputs": [
1325
+ "END_Repeat_1_output",
1326
+ "Input__tensor_3_x"
1327
  ],
1328
  "outputs": [
1329
  "END_Repeat_1_output"
 
1382
  "type": "basic"
1383
  },
1384
  "params": {
1385
+ "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_test\",\"column\":\"x\"}}}",
1386
  "model_name": "model",
1387
  "output_mapping": "{\"map\":{\"END_Repeat_1_output\":{\"df\":\"df_test\",\"column\":\"predicted\"}}}"
1388
  },
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -1,6 +1,7 @@
1
  """Boxes for defining PyTorch models."""
2
 
3
  import copy
 
4
  import graphlib
5
  import types
6
 
@@ -15,6 +16,21 @@ from . import core
15
  ENV = "PyTorch model"
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def reg(name, inputs=[], outputs=None, params=[]):
19
  if outputs is None:
20
  outputs = inputs
@@ -27,13 +43,9 @@ def reg(name, inputs=[], outputs=None, params=[]):
27
  )
28
 
29
 
30
- reg("Input: embedding", outputs=["x"])
31
  reg("Input: graph edges", outputs=["edges"])
32
- reg("Input: label", outputs=["y"])
33
- reg("Input: positive sample", outputs=["x_pos"])
34
- reg("Input: negative sample", outputs=["x_neg"])
35
  reg("Input: sequential", outputs=["y"])
36
- reg("Input: zeros", outputs=["x"])
37
 
38
  reg("LSTM", inputs=["x", "h"], outputs=["x", "h"])
39
  reg(
@@ -59,10 +71,35 @@ reg(
59
  ),
60
  ],
61
  )
 
 
62
  reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
63
  reg("LayerNorm", inputs=["x"])
64
  reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
65
- reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  reg("Softmax", inputs=["x"])
67
  reg(
68
  "Graph conv",
@@ -70,11 +107,6 @@ reg(
70
  outputs=["x"],
71
  params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])],
72
  )
73
- reg(
74
- "Activation",
75
- inputs=["x"],
76
- params=[P.options("type", ["ReLU", "Leaky ReLU", "Tanh", "Mish"])],
77
- )
78
  reg("Concatenate", inputs=["a", "b"], outputs=["x"])
79
  reg("Add", inputs=["a", "b"], outputs=["x"])
80
  reg("Subtract", inputs=["a", "b"], outputs=["x"])
@@ -128,6 +160,28 @@ def _to_id(*strings: str) -> str:
128
  return "_".join("".join(c if c.isalnum() else "_" for c in s) for s in strings)
129
 
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  class ColumnSpec(pydantic.BaseModel):
132
  df: str
133
  column: str
@@ -306,15 +360,6 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
306
  outputs = types.SimpleNamespace(**outputs)
307
  ls = loss_layers if "loss" in regions[node_id] else layers
308
  match t:
309
- case "Linear":
310
- isize = sizes.get(inputs.x, 1)
311
- osize = isize if p["output_dim"] == "same" else int(p["output_dim"])
312
- ls.append((torch.nn.Linear(isize, osize), f"{inputs.x} -> {outputs.x}"))
313
- sizes[outputs.x] = osize
314
- case "Activation":
315
- f = getattr(torch.nn.functional, p["type"].name.lower().replace(" ", "_"))
316
- ls.append((f, f"{inputs.x} -> {outputs.x}"))
317
- sizes[outputs.x] = sizes.get(inputs.x, 1)
318
  case "MSE loss":
319
  ls.append(
320
  (
@@ -335,6 +380,25 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
335
  r = regions.get(n, set())
336
  if ("repeat", repeat_id) in r:
337
  print(f"repeating {n}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  cfg["model_inputs"] = list(used_in_model - made_in_model)
339
  cfg["model_outputs"] = list(made_in_model & used_in_loss)
340
  cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
 
1
  """Boxes for defining PyTorch models."""
2
 
3
  import copy
4
+ import enum
5
  import graphlib
6
  import types
7
 
 
16
  ENV = "PyTorch model"
17
 
18
 
19
+ def op(name, **kwargs):
20
+ _op = ops.op(ENV, name, **kwargs)
21
+
22
+ def decorator(func):
23
+ _op(func)
24
+ op = func.__op__
25
+ for p in op.inputs.values():
26
+ p.position = "bottom"
27
+ for p in op.outputs.values():
28
+ p.position = "top"
29
+ return func
30
+
31
+ return decorator
32
+
33
+
34
  def reg(name, inputs=[], outputs=None, params=[]):
35
  if outputs is None:
36
  outputs = inputs
 
43
  )
44
 
45
 
46
+ reg("Input: tensor", outputs=["x"], params=[P.basic("name")])
47
  reg("Input: graph edges", outputs=["edges"])
 
 
 
48
  reg("Input: sequential", outputs=["y"])
 
49
 
50
  reg("LSTM", inputs=["x", "h"], outputs=["x", "h"])
51
  reg(
 
71
  ),
72
  ],
73
  )
74
+
75
+
76
  reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
77
  reg("LayerNorm", inputs=["x"])
78
  reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
79
+
80
+
81
+ @op("Linear")
82
+ def linear(x, *, output_dim="same"):
83
+ if output_dim == "same":
84
+ oshape = x.shape
85
+ else:
86
+ oshape = tuple(*x.shape[:-1], int(output_dim))
87
+ return Layer(torch.nn.Linear(x.shape, oshape), shape=oshape)
88
+
89
+
90
+ class ActivationTypes(enum.Enum):
91
+ ReLU = "ReLU"
92
+ Leaky_ReLU = "Leaky ReLU"
93
+ Tanh = "Tanh"
94
+ Mish = "Mish"
95
+
96
+
97
+ @op("Activation")
98
+ def activation(x, *, type: ActivationTypes = ActivationTypes.ReLU):
99
+ f = getattr(torch.nn.functional, type.name.lower().replace(" ", "_"))
100
+ return Layer(f, shape=x.shape)
101
+
102
+
103
  reg("Softmax", inputs=["x"])
104
  reg(
105
  "Graph conv",
 
107
  outputs=["x"],
108
  params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])],
109
  )
 
 
 
 
 
110
  reg("Concatenate", inputs=["a", "b"], outputs=["x"])
111
  reg("Add", inputs=["a", "b"], outputs=["x"])
112
  reg("Subtract", inputs=["a", "b"], outputs=["x"])
 
160
  return "_".join("".join(c if c.isalnum() else "_" for c in s) for s in strings)
161
 
162
 
163
+ @dataclasses.dataclass
164
+ class OpInput:
165
+ """Ops get their inputs like this. They have to return a Layer made for this input."""
166
+
167
+ id: str
168
+ shape: tuple[int, ...]
169
+
170
+
171
+ @dataclasses.dataclass
172
+ class Layer:
173
+ """Return this from an op. Must include a module and the shapes of the outputs."""
174
+
175
+ module: torch.nn.Module
176
+ shapes: list[tuple[int, ...]] | None = None # One for each output.
177
+ shape: dataclasses.InitVar[tuple[int, ...] | None] = None # Convenience for single output.
178
+
179
+ def __post_init__(self, shape):
180
+ assert not self.shapes or not shape, "Cannot set both shapes and shape."
181
+ if shape:
182
+ self.shapes = [shape]
183
+
184
+
185
  class ColumnSpec(pydantic.BaseModel):
186
  df: str
187
  column: str
 
360
  outputs = types.SimpleNamespace(**outputs)
361
  ls = loss_layers if "loss" in regions[node_id] else layers
362
  match t:
 
 
 
 
 
 
 
 
 
363
  case "MSE loss":
364
  ls.append(
365
  (
 
380
  r = regions.get(n, set())
381
  if ("repeat", repeat_id) in r:
382
  print(f"repeating {n}")
383
+ case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
384
+ pass
385
+ case _:
386
+ op_inputs = []
387
+ for i in op.inputs.keys():
388
+ id = getattr(inputs, i)
389
+ op_inputs.append(OpInput(id, shape=sizes.get(id, 1)))
390
+ if op.func != ops.no_op:
391
+ layer = op.func(*op_inputs, **p)
392
+ else:
393
+ layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
394
+ input_ids = ", ".join(i.id for i in op_inputs)
395
+ output_ids = []
396
+ for o, shape in zip(op.outputs.keys(), layer.shapes):
397
+ id = getattr(outputs, o)
398
+ output_ids.append(id)
399
+ sizes[id] = shape
400
+ output_ids = ", ".join(output_ids)
401
+ ls.append((layer.module, f"{input_ids} -> {output_ids}"))
402
  cfg["model_inputs"] = list(used_in_model - made_in_model)
403
  cfg["model_outputs"] = list(made_in_model & used_in_loss)
404
  cfg["loss_inputs"] = list(used_in_loss - made_in_loss)