darabos commited on
Commit
e9b29c4
·
unverified ·
2 Parent(s): 0eb0918 58cbe2a

Merge pull request #96 from biggraph/darabos-model-designer

Browse files
examples/Model definition ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "edges": [
3
+ {
4
+ "id": "Input: embedding 1 Linear 1",
5
+ "source": "Input: embedding 1",
6
+ "sourceHandle": "x",
7
+ "target": "Linear 1",
8
+ "targetHandle": "x"
9
+ },
10
+ {
11
+ "id": "Input: label 1 MSE loss 1",
12
+ "source": "Input: label 1",
13
+ "sourceHandle": "y",
14
+ "target": "MSE loss 1",
15
+ "targetHandle": "y"
16
+ },
17
+ {
18
+ "id": "Linear 1 Activation 2",
19
+ "source": "Linear 1",
20
+ "sourceHandle": "x",
21
+ "target": "Activation 2",
22
+ "targetHandle": "x"
23
+ },
24
+ {
25
+ "id": "Activation 2 MSE loss 1",
26
+ "source": "Activation 2",
27
+ "sourceHandle": "x",
28
+ "target": "MSE loss 1",
29
+ "targetHandle": "x"
30
+ },
31
+ {
32
+ "id": "MSE loss 1 Optimizer 2",
33
+ "source": "MSE loss 1",
34
+ "sourceHandle": "loss",
35
+ "target": "Optimizer 2",
36
+ "targetHandle": "loss"
37
+ }
38
+ ],
39
+ "env": "PyTorch model",
40
+ "nodes": [
41
+ {
42
+ "data": {
43
+ "display": null,
44
+ "error": null,
45
+ "meta": {
46
+ "inputs": {},
47
+ "name": "Input: embedding",
48
+ "outputs": {
49
+ "x": {
50
+ "name": "x",
51
+ "position": "top",
52
+ "type": {
53
+ "type": "tensor"
54
+ }
55
+ }
56
+ },
57
+ "params": {},
58
+ "type": "basic"
59
+ },
60
+ "params": {},
61
+ "status": "planned",
62
+ "title": "Input: embedding"
63
+ },
64
+ "dragHandle": ".bg-primary",
65
+ "height": 200.0,
66
+ "id": "Input: embedding 1",
67
+ "position": {
68
+ "x": 91.0,
69
+ "y": 266.0
70
+ },
71
+ "type": "basic",
72
+ "width": 200.0
73
+ },
74
+ {
75
+ "data": {
76
+ "display": null,
77
+ "error": null,
78
+ "meta": {
79
+ "inputs": {
80
+ "x": {
81
+ "name": "x",
82
+ "position": "bottom",
83
+ "type": {
84
+ "type": "tensor"
85
+ }
86
+ }
87
+ },
88
+ "name": "Linear",
89
+ "outputs": {
90
+ "x": {
91
+ "name": "x",
92
+ "position": "top",
93
+ "type": {
94
+ "type": "tensor"
95
+ }
96
+ }
97
+ },
98
+ "params": {
99
+ "output_dim": {
100
+ "default": "same",
101
+ "name": "output_dim",
102
+ "type": {
103
+ "type": "<class 'str'>"
104
+ }
105
+ }
106
+ },
107
+ "type": "basic"
108
+ },
109
+ "params": {
110
+ "output_dim": "same"
111
+ },
112
+ "status": "planned",
113
+ "title": "Linear"
114
+ },
115
+ "dragHandle": ".bg-primary",
116
+ "height": 200.0,
117
+ "id": "Linear 1",
118
+ "position": {
119
+ "x": 86.0,
120
+ "y": 33.0
121
+ },
122
+ "type": "basic",
123
+ "width": 200.0
124
+ },
125
+ {
126
+ "data": {
127
+ "display": null,
128
+ "error": null,
129
+ "meta": {
130
+ "inputs": {
131
+ "x": {
132
+ "name": "x",
133
+ "position": "bottom",
134
+ "type": {
135
+ "type": "tensor"
136
+ }
137
+ },
138
+ "y": {
139
+ "name": "y",
140
+ "position": "bottom",
141
+ "type": {
142
+ "type": "tensor"
143
+ }
144
+ }
145
+ },
146
+ "name": "MSE loss",
147
+ "outputs": {
148
+ "loss": {
149
+ "name": "loss",
150
+ "position": "top",
151
+ "type": {
152
+ "type": "tensor"
153
+ }
154
+ }
155
+ },
156
+ "params": {},
157
+ "type": "basic"
158
+ },
159
+ "params": {},
160
+ "status": "planned",
161
+ "title": "MSE loss"
162
+ },
163
+ "dragHandle": ".bg-primary",
164
+ "height": 200.0,
165
+ "id": "MSE loss 1",
166
+ "position": {
167
+ "x": 315.0,
168
+ "y": -510.0
169
+ },
170
+ "type": "basic",
171
+ "width": 200.0
172
+ },
173
+ {
174
+ "data": {
175
+ "display": null,
176
+ "error": null,
177
+ "meta": {
178
+ "inputs": {},
179
+ "name": "Input: label",
180
+ "outputs": {
181
+ "y": {
182
+ "name": "y",
183
+ "position": "top",
184
+ "type": {
185
+ "type": "tensor"
186
+ }
187
+ }
188
+ },
189
+ "params": {},
190
+ "type": "basic"
191
+ },
192
+ "params": {},
193
+ "status": "planned",
194
+ "title": "Input: label"
195
+ },
196
+ "dragHandle": ".bg-primary",
197
+ "height": 200.0,
198
+ "id": "Input: label 1",
199
+ "position": {
200
+ "x": 615.0,
201
+ "y": -165.0
202
+ },
203
+ "type": "basic",
204
+ "width": 200.0
205
+ },
206
+ {
207
+ "data": {
208
+ "__execution_delay": 0.0,
209
+ "collapsed": null,
210
+ "display": null,
211
+ "error": null,
212
+ "meta": {
213
+ "inputs": {
214
+ "x": {
215
+ "name": "x",
216
+ "position": "bottom",
217
+ "type": {
218
+ "type": "tensor"
219
+ }
220
+ }
221
+ },
222
+ "name": "Activation",
223
+ "outputs": {
224
+ "x": {
225
+ "name": "x",
226
+ "position": "top",
227
+ "type": {
228
+ "type": "tensor"
229
+ }
230
+ }
231
+ },
232
+ "params": {
233
+ "type": {
234
+ "default": "ReLU",
235
+ "name": "type",
236
+ "type": {
237
+ "enum": [
238
+ "ReLU",
239
+ "Leaky ReLU",
240
+ "Tanh",
241
+ "Mish"
242
+ ]
243
+ }
244
+ }
245
+ },
246
+ "position": {
247
+ "x": 419.0,
248
+ "y": 396.0
249
+ },
250
+ "type": "basic"
251
+ },
252
+ "params": {
253
+ "type": "Leaky ReLU"
254
+ },
255
+ "status": "planned",
256
+ "title": "Activation"
257
+ },
258
+ "dragHandle": ".bg-primary",
259
+ "height": 200.0,
260
+ "id": "Activation 2",
261
+ "position": {
262
+ "x": 93.61643829835265,
263
+ "y": -229.04087132886406
264
+ },
265
+ "type": "basic",
266
+ "width": 200.0
267
+ },
268
+ {
269
+ "data": {
270
+ "__execution_delay": 0.0,
271
+ "collapsed": null,
272
+ "display": null,
273
+ "error": null,
274
+ "meta": {
275
+ "inputs": {
276
+ "loss": {
277
+ "name": "loss",
278
+ "position": "bottom",
279
+ "type": {
280
+ "type": "tensor"
281
+ }
282
+ }
283
+ },
284
+ "name": "Optimizer",
285
+ "outputs": {},
286
+ "params": {
287
+ "lr": {
288
+ "default": 0.001,
289
+ "name": "lr",
290
+ "type": {
291
+ "type": "<class 'float'>"
292
+ }
293
+ },
294
+ "type": {
295
+ "default": "AdamW",
296
+ "name": "type",
297
+ "type": {
298
+ "enum": [
299
+ "AdamW",
300
+ "Adafactor",
301
+ "Adagrad",
302
+ "SGD",
303
+ "Lion",
304
+ "Paged AdamW",
305
+ "Galore AdamW"
306
+ ]
307
+ }
308
+ }
309
+ },
310
+ "position": {
311
+ "x": 526.0,
312
+ "y": 116.0
313
+ },
314
+ "type": "basic"
315
+ },
316
+ "params": {
317
+ "lr": "0.1",
318
+ "type": "SGD"
319
+ },
320
+ "status": "planned",
321
+ "title": "Optimizer"
322
+ },
323
+ "dragHandle": ".bg-primary",
324
+ "height": 200.0,
325
+ "id": "Optimizer 2",
326
+ "position": {
327
+ "x": 305.6132943499785,
328
+ "y": -804.0094318451224
329
+ },
330
+ "type": "basic",
331
+ "width": 200.0
332
+ }
333
+ ]
334
+ }
examples/Model use ADDED
@@ -0,0 +1,1404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "edges": [
3
+ {
4
+ "id": "Import Parquet 1 Train/test split 1",
5
+ "source": "Import Parquet 1",
6
+ "sourceHandle": "output",
7
+ "target": "Train/test split 1",
8
+ "targetHandle": "bundle"
9
+ },
10
+ {
11
+ "id": "Train/test split 1 Define model 1",
12
+ "source": "Train/test split 1",
13
+ "sourceHandle": "output",
14
+ "target": "Define model 1",
15
+ "targetHandle": "bundle"
16
+ },
17
+ {
18
+ "id": "Define model 1 Train model 2",
19
+ "source": "Define model 1",
20
+ "sourceHandle": "output",
21
+ "target": "Train model 2",
22
+ "targetHandle": "bundle"
23
+ },
24
+ {
25
+ "id": "Train model 2 Model inference 1",
26
+ "source": "Train model 2",
27
+ "sourceHandle": "output",
28
+ "target": "Model inference 1",
29
+ "targetHandle": "bundle"
30
+ },
31
+ {
32
+ "id": "Model inference 1 View tables 1",
33
+ "source": "Model inference 1",
34
+ "sourceHandle": "output",
35
+ "target": "View tables 1",
36
+ "targetHandle": "bundle"
37
+ }
38
+ ],
39
+ "env": "LynxKite Graph Analytics",
40
+ "nodes": [
41
+ {
42
+ "data": {
43
+ "__execution_delay": 0.0,
44
+ "collapsed": null,
45
+ "display": null,
46
+ "error": null,
47
+ "input_metadata": [
48
+ {
49
+ "dataframes": {
50
+ "df": {
51
+ "columns": [
52
+ "x",
53
+ "y"
54
+ ]
55
+ }
56
+ },
57
+ "other": {},
58
+ "relations": []
59
+ }
60
+ ],
61
+ "meta": {
62
+ "inputs": {
63
+ "bundle": {
64
+ "name": "bundle",
65
+ "position": "left",
66
+ "type": {
67
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
68
+ }
69
+ }
70
+ },
71
+ "name": "Train/test split",
72
+ "outputs": {
73
+ "output": {
74
+ "name": "output",
75
+ "position": "right",
76
+ "type": {
77
+ "type": "None"
78
+ }
79
+ }
80
+ },
81
+ "params": {
82
+ "table_name": {
83
+ "default": null,
84
+ "name": "table_name",
85
+ "type": {
86
+ "type": "<class 'str'>"
87
+ }
88
+ },
89
+ "test_ratio": {
90
+ "default": 0.1,
91
+ "name": "test_ratio",
92
+ "type": {
93
+ "type": "<class 'float'>"
94
+ }
95
+ }
96
+ },
97
+ "type": "basic"
98
+ },
99
+ "params": {
100
+ "table_name": "df",
101
+ "test_ratio": 0.1
102
+ },
103
+ "status": "done",
104
+ "title": "Train/test split"
105
+ },
106
+ "dragHandle": ".bg-primary",
107
+ "height": 282.0,
108
+ "id": "Train/test split 1",
109
+ "position": {
110
+ "x": 345.0,
111
+ "y": 139.0
112
+ },
113
+ "type": "basic",
114
+ "width": 259.0
115
+ },
116
+ {
117
+ "data": {
118
+ "__execution_delay": 0.0,
119
+ "collapsed": null,
120
+ "display": null,
121
+ "error": null,
122
+ "input_metadata": [],
123
+ "meta": {
124
+ "inputs": {},
125
+ "name": "Import Parquet",
126
+ "outputs": {
127
+ "output": {
128
+ "name": "output",
129
+ "position": "right",
130
+ "type": {
131
+ "type": "None"
132
+ }
133
+ }
134
+ },
135
+ "params": {
136
+ "filename": {
137
+ "default": null,
138
+ "name": "filename",
139
+ "type": {
140
+ "type": "<class 'str'>"
141
+ }
142
+ }
143
+ },
144
+ "type": "basic"
145
+ },
146
+ "params": {
147
+ "filename": "uploads/plus-one-dataset.parquet"
148
+ },
149
+ "status": "done",
150
+ "title": "Import Parquet"
151
+ },
152
+ "dragHandle": ".bg-primary",
153
+ "height": 403.0,
154
+ "id": "Import Parquet 1",
155
+ "position": {
156
+ "x": -166.0,
157
+ "y": 112.0
158
+ },
159
+ "type": "basic",
160
+ "width": 371.0
161
+ },
162
+ {
163
+ "data": {
164
+ "display": {
165
+ "dataframes": {
166
+ "df": {
167
+ "columns": [
168
+ "x",
169
+ "y"
170
+ ],
171
+ "data": [
172
+ [
173
+ "[0.52046251 0.45887971 0.72169858 0.29517919]",
174
+ "[1.52046251 1.45887971 1.72169852 1.29517913]"
175
+ ],
176
+ [
177
+ "[0.85706753 0.61447072 0.41741937 0.85147089]",
178
+ "[1.85706758 1.61447072 1.41741943 1.85147095]"
179
+ ],
180
+ [
181
+ "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
182
+ "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
183
+ ],
184
+ [
185
+ "[0.19409031 0.68692201 0.60667384 0.57829887]",
186
+ "[1.19409037 1.68692207 1.60667384 1.57829881]"
187
+ ],
188
+ [
189
+ "[0.76807946 0.98855817 0.08259124 0.01730657]",
190
+ "[1.76807952 1.98855817 1.0825913 1.01730657]"
191
+ ],
192
+ [
193
+ "[0.67269951 0.10478973 0.5584439 0.83605725]",
194
+ "[1.67269945 1.10478973 1.5584439 1.83605719]"
195
+ ],
196
+ [
197
+ "[0.18686318 0.49356437 0.51323432 0.75392658]",
198
+ "[1.18686318 1.49356437 1.51323438 1.75392652]"
199
+ ],
200
+ [
201
+ "[0.18149549 0.30520517 0.30946714 0.16786289]",
202
+ "[1.18149543 1.30520511 1.30946708 1.16786289]"
203
+ ],
204
+ [
205
+ "[4.27091718e-01 4.89909172e-01 6.92297399e-01 2.57611275e-04]",
206
+ "[1.42709172 1.48990917 1.69229746 1.00025761]"
207
+ ],
208
+ [
209
+ "[0.32225502 0.16999388 0.05823922 0.9628762 ]",
210
+ "[1.32225502 1.16999388 1.05823922 1.9628762 ]"
211
+ ],
212
+ [
213
+ "[0.50783676 0.04156506 0.21984279 0.8454656 ]",
214
+ "[1.50783682 1.04156506 1.21984279 1.84546566]"
215
+ ],
216
+ [
217
+ "[0.98324287 0.99464184 0.14008355 0.47651017]",
218
+ "[1.98324287 1.99464178 1.14008355 1.47651017]"
219
+ ],
220
+ [
221
+ "[0.11693293 0.49860179 0.55020827 0.88832849]",
222
+ "[1.11693287 1.49860179 1.55020833 1.88832855]"
223
+ ],
224
+ [
225
+ "[0.48959708 0.48549271 0.32688856 0.356677 ]",
226
+ "[1.48959708 1.48549271 1.32688856 1.35667706]"
227
+ ],
228
+ [
229
+ "[0.50272274 0.54912758 0.17663097 0.79070699]",
230
+ "[1.50272274 1.54912758 1.17663097 1.79070699]"
231
+ ],
232
+ [
233
+ "[0.04508126 0.76880038 0.80721325 0.62542385]",
234
+ "[1.04508126 1.76880038 1.80721331 1.62542391]"
235
+ ],
236
+ [
237
+ "[0.19908059 0.17570406 0.51475513 0.1893943 ]",
238
+ "[1.19908059 1.175704 1.51475513 1.18939424]"
239
+ ],
240
+ [
241
+ "[0.40167677 0.25953674 0.9407078 0.76308483]",
242
+ "[1.40167677 1.25953674 1.9407078 1.76308489]"
243
+ ],
244
+ [
245
+ "[0.2480728 0.21694398 0.63941365 0.57128876]",
246
+ "[1.24807286 1.21694398 1.6394136 1.57128882]"
247
+ ],
248
+ [
249
+ "[0.24388778 0.07268471 0.68350857 0.73431659]",
250
+ "[1.24388778 1.07268476 1.68350863 1.73431659]"
251
+ ],
252
+ [
253
+ "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
254
+ "[1.62569475 1.9881897 1.83639622 1.98288584]"
255
+ ],
256
+ [
257
+ "[0.56922203 0.98222166 0.76851749 0.28615737]",
258
+ "[1.56922197 1.9822216 1.76851749 1.28615737]"
259
+ ],
260
+ [
261
+ "[0.88776821 0.51636773 0.30333066 0.32230979]",
262
+ "[1.88776827 1.51636767 1.30333066 1.32230973]"
263
+ ],
264
+ [
265
+ "[0.90817457 0.89270043 0.38583666 0.66566533]",
266
+ "[1.90817451 1.89270043 1.3858366 1.66566539]"
267
+ ],
268
+ [
269
+ "[0.48507756 0.80808765 0.77162558 0.47834778]",
270
+ "[1.48507762 1.80808759 1.77162552 1.47834778]"
271
+ ],
272
+ [
273
+ "[0.68062544 0.98093534 0.14778823 0.53244978]",
274
+ "[1.68062544 1.98093534 1.14778829 1.53244972]"
275
+ ],
276
+ [
277
+ "[0.31518555 0.49643308 0.11509258 0.95458382]",
278
+ "[1.31518555 1.49643302 1.11509252 1.95458388]"
279
+ ],
280
+ [
281
+ "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
282
+ "[1.79121017 1.54161119 1.69369793 1.15207696]"
283
+ ],
284
+ [
285
+ "[0.79423058 0.07138705 0.061777 0.18766576]",
286
+ "[1.79423058 1.07138705 1.061777 1.1876657 ]"
287
+ ],
288
+ [
289
+ "[0.23942459 0.90487361 0.69337189 0.65089428]",
290
+ "[1.23942459 1.90487361 1.69337189 1.65089428]"
291
+ ],
292
+ [
293
+ "[0.94516498 0.08422136 0.5608117 0.07652664]",
294
+ "[1.94516492 1.08422136 1.56081176 1.07652664]"
295
+ ],
296
+ [
297
+ "[0.26661873 0.45946234 0.13510543 0.81294441]",
298
+ "[1.26661873 1.4594624 1.13510537 1.81294441]"
299
+ ],
300
+ [
301
+ "[0.30754459 0.77694583 0.09278506 0.38326019]",
302
+ "[1.30754459 1.77694583 1.09278512 1.38326025]"
303
+ ],
304
+ [
305
+ "[0.27845025 0.32472342 0.82203609 0.77107543]",
306
+ "[1.27845025 1.32472348 1.82203603 1.77107549]"
307
+ ],
308
+ [
309
+ "[0.4827103 0.10563457 0.98858833 0.82286644]",
310
+ "[1.48271036 1.10563457 1.98858833 1.82286644]"
311
+ ],
312
+ [
313
+ "[0.98033333 0.97656083 0.38939917 0.81491041]",
314
+ "[1.98033333 1.97656083 1.38939917 1.81491041]"
315
+ ],
316
+ [
317
+ "[0.74064726 0.4155122 0.09800029 0.49930882]",
318
+ "[1.74064732 1.4155122 1.09800029 1.49930882]"
319
+ ],
320
+ [
321
+ "[0.78956431 0.87284744 0.06880784 0.03455889]",
322
+ "[1.78956437 1.87284744 1.06880784 1.03455889]"
323
+ ],
324
+ [
325
+ "[0.94221359 0.57740951 0.98649532 0.40934443]",
326
+ "[1.94221354 1.57740951 1.98649526 1.40934443]"
327
+ ],
328
+ [
329
+ "[0.00497234 0.39319336 0.57054168 0.75150961]",
330
+ "[1.00497234 1.39319336 1.57054162 1.75150967]"
331
+ ],
332
+ [
333
+ "[0.44330525 0.09997386 0.89025736 0.90507984]",
334
+ "[1.44330525 1.09997392 1.89025736 1.90507984]"
335
+ ],
336
+ [
337
+ "[0.72290605 0.96945059 0.68354797 0.15270454]",
338
+ "[1.72290611 1.96945059 1.68354797 1.15270448]"
339
+ ],
340
+ [
341
+ "[0.75292218 0.81470108 0.49657214 0.56217098]",
342
+ "[1.75292218 1.81470108 1.49657214 1.56217098]"
343
+ ],
344
+ [
345
+ "[0.33480108 0.59181517 0.76198453 0.98062384]",
346
+ "[1.33480108 1.59181523 1.76198459 1.98062384]"
347
+ ],
348
+ [
349
+ "[0.52784437 0.54268694 0.12358981 0.72116476]",
350
+ "[1.52784443 1.54268694 1.12358975 1.7211647 ]"
351
+ ],
352
+ [
353
+ "[0.73217702 0.65233225 0.44077861 0.33837909]",
354
+ "[1.73217702 1.65233231 1.44077861 1.33837914]"
355
+ ],
356
+ [
357
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
358
+ "[1.34084332 1.73018837 1.54168713 1.91440833]"
359
+ ],
360
+ [
361
+ "[0.60110539 0.3618983 0.32342511 0.98672163]",
362
+ "[1.60110545 1.3618983 1.32342505 1.98672163]"
363
+ ],
364
+ [
365
+ "[0.77427191 0.21829212 0.12769502 0.74303615]",
366
+ "[1.77427197 1.21829212 1.12769508 1.74303615]"
367
+ ],
368
+ [
369
+ "[0.08107251 0.2602725 0.18861133 0.44833237]",
370
+ "[1.08107257 1.2602725 1.18861127 1.44833231]"
371
+ ],
372
+ [
373
+ "[0.59812403 0.78395379 0.0291847 0.81814629]",
374
+ "[1.59812403 1.78395379 1.0291847 1.81814623]"
375
+ ],
376
+ [
377
+ "[0.93488538 0.73882395 0.37345302 0.0274905 ]",
378
+ "[1.93488538 1.73882389 1.37345302 1.0274905 ]"
379
+ ],
380
+ [
381
+ "[0.30631393 0.48311198 0.87847513 0.67559886]",
382
+ "[1.30631399 1.48311198 1.87847519 1.67559886]"
383
+ ],
384
+ [
385
+ "[0.18720162 0.74115586 0.98626411 0.30355608]",
386
+ "[1.18720162 1.74115586 1.98626411 1.30355608]"
387
+ ],
388
+ [
389
+ "[0.85566247 0.83362883 0.48424995 0.25265992]",
390
+ "[1.85566247 1.83362889 1.48424995 1.25265992]"
391
+ ],
392
+ [
393
+ "[0.95928186 0.84273899 0.71514636 0.38619852]",
394
+ "[1.95928192 1.84273899 1.7151463 1.38619852]"
395
+ ],
396
+ [
397
+ "[0.32565445 0.90939188 0.07488042 0.13730896]",
398
+ "[1.32565451 1.90939188 1.07488036 1.13730896]"
399
+ ],
400
+ [
401
+ "[0.9829582 0.59269661 0.40120947 0.95487177]",
402
+ "[1.9829582 1.59269667 1.40120947 1.95487177]"
403
+ ],
404
+ [
405
+ "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
406
+ "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
407
+ ],
408
+ [
409
+ "[0.54914117 0.03810108 0.87531954 0.73044223]",
410
+ "[1.54914117 1.03810108 1.87531948 1.73044229]"
411
+ ],
412
+ [
413
+ "[0.67418337 0.79634351 0.23229051 0.71345252]",
414
+ "[1.67418337 1.79634356 1.23229051 1.71345258]"
415
+ ],
416
+ [
417
+ "[0.87285906 0.48354989 0.39394957 0.59456545]",
418
+ "[1.872859 1.48354983 1.39394951 1.59456539]"
419
+ ],
420
+ [
421
+ "[0.81788456 0.58174163 0.29376316 0.7971254 ]",
422
+ "[1.81788456 1.58174157 1.29376316 1.79712534]"
423
+ ],
424
+ [
425
+ "[0.94559073 0.65736622 0.25761551 0.48553199]",
426
+ "[1.94559073 1.65736628 1.25761557 1.48553205]"
427
+ ],
428
+ [
429
+ "[0.60075855 0.12234765 0.00614399 0.30560958]",
430
+ "[1.60075855 1.12234759 1.00614405 1.30560958]"
431
+ ],
432
+ [
433
+ "[0.39147133 0.29854035 0.84663737 0.58175623]",
434
+ "[1.39147139 1.29854035 1.84663737 1.58175623]"
435
+ ],
436
+ [
437
+ "[0.02162331 0.81861657 0.92468154 0.07808572]",
438
+ "[1.02162337 1.81861663 1.92468154 1.07808566]"
439
+ ],
440
+ [
441
+ "[0.02235305 0.52774918 0.7331115 0.84358269]",
442
+ "[1.02235305 1.52774918 1.7331115 1.84358263]"
443
+ ],
444
+ [
445
+ "[0.6080932 0.56563014 0.32107437 0.72599429]",
446
+ "[1.60809326 1.5656302 1.32107437 1.72599435]"
447
+ ],
448
+ [
449
+ "[0.67447788 0.6125319 0.98007888 0.65968603]",
450
+ "[1.67447782 1.6125319 1.98007894 1.65968609]"
451
+ ],
452
+ [
453
+ "[0.47963417 0.81818312 0.48720706 0.49339259]",
454
+ "[1.47963417 1.81818318 1.48720706 1.49339259]"
455
+ ],
456
+ [
457
+ "[0.9630242 0.76359051 0.24853623 0.76881069]",
458
+ "[1.96302414 1.76359057 1.24853623 1.76881075]"
459
+ ],
460
+ [
461
+ "[0.60609657 0.96257663 0.19292736 0.95702219]",
462
+ "[1.60609651 1.96257663 1.19292736 1.95702219]"
463
+ ],
464
+ [
465
+ "[0.80654246 0.08253473 0.74478531 0.71257162]",
466
+ "[1.8065424 1.08253479 1.74478531 1.71257162]"
467
+ ],
468
+ [
469
+ "[0.70167565 0.26930219 0.5660674 0.61194974]",
470
+ "[1.70167565 1.26930213 1.56606746 1.61194968]"
471
+ ],
472
+ [
473
+ "[0.76933283 0.86241865 0.44114518 0.65644735]",
474
+ "[1.76933289 1.86241865 1.44114518 1.65644741]"
475
+ ],
476
+ [
477
+ "[0.59492421 0.90274489 0.38069052 0.46101224]",
478
+ "[1.59492421 1.90274489 1.38069057 1.46101224]"
479
+ ],
480
+ [
481
+ "[0.15064228 0.03198934 0.25754827 0.51484001]",
482
+ "[1.15064228 1.03198934 1.25754833 1.51484001]"
483
+ ],
484
+ [
485
+ "[0.12024075 0.21342516 0.56858408 0.58644271]",
486
+ "[1.12024069 1.21342516 1.56858408 1.58644271]"
487
+ ],
488
+ [
489
+ "[0.91730917 0.22574073 0.09591609 0.33056474]",
490
+ "[1.91730917 1.22574067 1.09591603 1.33056474]"
491
+ ],
492
+ [
493
+ "[0.49691743 0.61873293 0.90698647 0.94486356]",
494
+ "[1.49691749 1.61873293 1.90698647 1.94486356]"
495
+ ],
496
+ [
497
+ "[0.6032477 0.83361369 0.18538666 0.19108021]",
498
+ "[1.60324764 1.83361363 1.18538666 1.19108021]"
499
+ ],
500
+ [
501
+ "[0.63235509 0.70352674 0.96188956 0.46240485]",
502
+ "[1.63235509 1.70352674 1.96188951 1.46240485]"
503
+ ],
504
+ [
505
+ "[0.37959969 0.42820001 0.10690689 0.96353984]",
506
+ "[1.37959969 1.42820001 1.10690689 1.96353984]"
507
+ ],
508
+ [
509
+ "[0.49607176 0.1922397 0.46640229 0.78321403]",
510
+ "[1.49607182 1.19223976 1.46640229 1.78321409]"
511
+ ],
512
+ [
513
+ "[0.40234613 0.54987347 0.49542785 0.54153186]",
514
+ "[1.40234613 1.54987347 1.49542785 1.5415318 ]"
515
+ ],
516
+ [
517
+ "[0.80893755 0.92237449 0.88346356 0.93164903]",
518
+ "[1.80893755 1.92237449 1.88346362 1.93164897]"
519
+ ],
520
+ [
521
+ "[0.12858278 0.09930819 0.83222693 0.72485673]",
522
+ "[1.12858272 1.09930825 1.83222699 1.72485673]"
523
+ ],
524
+ [
525
+ "[0.72470158 0.4940322 0.41027349 0.89364016]",
526
+ "[1.72470164 1.49403214 1.41027355 1.89364016]"
527
+ ],
528
+ [
529
+ "[0.47856545 0.46267092 0.6376707 0.84747767]",
530
+ "[1.47856545 1.46267092 1.63767076 1.84747767]"
531
+ ],
532
+ [
533
+ "[0.49584109 0.80599248 0.07096875 0.75872749]",
534
+ "[1.49584103 1.80599248 1.07096875 1.75872755]"
535
+ ],
536
+ [
537
+ "[0.43500566 0.66041756 0.80293626 0.96224713]",
538
+ "[1.43500566 1.66041756 1.80293632 1.96224713]"
539
+ ],
540
+ [
541
+ "[0.78397602 0.74223626 0.26603186 0.41664881]",
542
+ "[1.78397608 1.74223626 1.26603186 1.41664886]"
543
+ ],
544
+ [
545
+ "[0.28942841 0.05601001 0.33039129 0.27781558]",
546
+ "[1.28942847 1.05601001 1.33039129 1.27781558]"
547
+ ],
548
+ [
549
+ "[0.68094063 0.45189077 0.22661722 0.37354094]",
550
+ "[1.68094063 1.45189071 1.22661722 1.37354088]"
551
+ ],
552
+ [
553
+ "[0.43681622 0.74680805 0.83598751 0.12414402]",
554
+ "[1.43681622 1.74680805 1.83598757 1.12414408]"
555
+ ],
556
+ [
557
+ "[0.47870928 0.17129105 0.27300501 0.20634609]",
558
+ "[1.47870922 1.17129111 1.27300501 1.20634604]"
559
+ ],
560
+ [
561
+ "[0.72795159 0.79317838 0.27832931 0.96576637]",
562
+ "[1.72795153 1.79317832 1.27832937 1.96576643]"
563
+ ],
564
+ [
565
+ "[0.87608397 0.93200487 0.80169648 0.37758952]",
566
+ "[1.87608397 1.93200493 1.80169654 1.37758946]"
567
+ ],
568
+ [
569
+ "[0.68891573 0.25576538 0.96339929 0.503833 ]",
570
+ "[1.68891573 1.25576544 1.96339929 1.50383306]"
571
+ ]
572
+ ]
573
+ },
574
+ "df_test": {
575
+ "columns": [
576
+ "x",
577
+ "y",
578
+ "predicted"
579
+ ],
580
+ "data": [
581
+ [
582
+ "[0.49691743 0.61873293 0.90698647 0.94486356]",
583
+ "[1.49691749 1.61873293 1.90698647 1.94486356]",
584
+ "[1.4993021488189697, 1.6404846906661987, 1.923316240310669, 1.9422152042388916]"
585
+ ],
586
+ [
587
+ "[0.56922203 0.98222166 0.76851749 0.28615737]",
588
+ "[1.56922197 1.9822216 1.76851749 1.28615737]",
589
+ "[1.5835213661193848, 1.9884355068206787, 1.7694181203842163, 1.2917503118515015]"
590
+ ],
591
+ [
592
+ "[0.90817457 0.89270043 0.38583666 0.66566533]",
593
+ "[1.90817451 1.89270043 1.3858366 1.66566539]",
594
+ "[1.9053494930267334, 1.9083378314971924, 1.3998609781265259, 1.6636812686920166]"
595
+ ],
596
+ [
597
+ "[0.72795159 0.79317838 0.27832931 0.96576637]",
598
+ "[1.72795153 1.79317832 1.27832937 1.96576643]",
599
+ "[1.734963297843933, 1.8026459217071533, 1.2926064729690552, 1.9596911668777466]"
600
+ ],
601
+ [
602
+ "[0.04508126 0.76880038 0.80721325 0.62542385]",
603
+ "[1.04508126 1.76880038 1.80721331 1.62542391]",
604
+ "[1.0830243825912476, 1.7584562301635742, 1.8005754947662354, 1.6277496814727783]"
605
+ ],
606
+ [
607
+ "[0.6032477 0.83361369 0.18538666 0.19108021]",
608
+ "[1.60324764 1.83361363 1.18538666 1.19108021]",
609
+ "[1.6177492141723633, 1.8144152164459229, 1.1718573570251465, 1.1950569152832031]"
610
+ ],
611
+ [
612
+ "[0.15064228 0.03198934 0.25754827 0.51484001]",
613
+ "[1.15064228 1.03198934 1.25754833 1.51484001]",
614
+ "[1.1556042432785034, 0.9955940246582031, 1.2316606044769287, 1.5150485038757324]"
615
+ ],
616
+ [
617
+ "[0.48959708 0.48549271 0.32688856 0.356677 ]",
618
+ "[1.48959708 1.48549271 1.32688856 1.35667706]",
619
+ "[1.4930214881896973, 1.467790961265564, 1.3132573366165161, 1.3589863777160645]"
620
+ ],
621
+ [
622
+ "[0.08107251 0.2602725 0.18861133 0.44833237]",
623
+ "[1.08107257 1.2602725 1.18861127 1.44833231]",
624
+ "[1.102121114730835, 1.2180893421173096, 1.160165548324585, 1.4495322704315186]"
625
+ ],
626
+ [
627
+ "[0.68094063 0.45189077 0.22661722 0.37354094]",
628
+ "[1.68094063 1.45189071 1.22661722 1.37354088]",
629
+ "[1.6725687980651855, 1.4393560886383057, 1.2169336080551147, 1.3746893405914307]"
630
+ ]
631
+ ]
632
+ },
633
+ "df_train": {
634
+ "columns": [
635
+ "x",
636
+ "y"
637
+ ],
638
+ "data": [
639
+ [
640
+ "[0.52046251 0.45887971 0.72169858 0.29517919]",
641
+ "[1.52046251 1.45887971 1.72169852 1.29517913]"
642
+ ],
643
+ [
644
+ "[0.85706753 0.61447072 0.41741937 0.85147089]",
645
+ "[1.85706758 1.61447072 1.41741943 1.85147095]"
646
+ ],
647
+ [
648
+ "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
649
+ "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
650
+ ],
651
+ [
652
+ "[0.19409031 0.68692201 0.60667384 0.57829887]",
653
+ "[1.19409037 1.68692207 1.60667384 1.57829881]"
654
+ ],
655
+ [
656
+ "[0.76807946 0.98855817 0.08259124 0.01730657]",
657
+ "[1.76807952 1.98855817 1.0825913 1.01730657]"
658
+ ],
659
+ [
660
+ "[0.67269951 0.10478973 0.5584439 0.83605725]",
661
+ "[1.67269945 1.10478973 1.5584439 1.83605719]"
662
+ ],
663
+ [
664
+ "[0.18686318 0.49356437 0.51323432 0.75392658]",
665
+ "[1.18686318 1.49356437 1.51323438 1.75392652]"
666
+ ],
667
+ [
668
+ "[0.18149549 0.30520517 0.30946714 0.16786289]",
669
+ "[1.18149543 1.30520511 1.30946708 1.16786289]"
670
+ ],
671
+ [
672
+ "[4.27091718e-01 4.89909172e-01 6.92297399e-01 2.57611275e-04]",
673
+ "[1.42709172 1.48990917 1.69229746 1.00025761]"
674
+ ],
675
+ [
676
+ "[0.32225502 0.16999388 0.05823922 0.9628762 ]",
677
+ "[1.32225502 1.16999388 1.05823922 1.9628762 ]"
678
+ ],
679
+ [
680
+ "[0.50783676 0.04156506 0.21984279 0.8454656 ]",
681
+ "[1.50783682 1.04156506 1.21984279 1.84546566]"
682
+ ],
683
+ [
684
+ "[0.98324287 0.99464184 0.14008355 0.47651017]",
685
+ "[1.98324287 1.99464178 1.14008355 1.47651017]"
686
+ ],
687
+ [
688
+ "[0.11693293 0.49860179 0.55020827 0.88832849]",
689
+ "[1.11693287 1.49860179 1.55020833 1.88832855]"
690
+ ],
691
+ [
692
+ "[0.50272274 0.54912758 0.17663097 0.79070699]",
693
+ "[1.50272274 1.54912758 1.17663097 1.79070699]"
694
+ ],
695
+ [
696
+ "[0.19908059 0.17570406 0.51475513 0.1893943 ]",
697
+ "[1.19908059 1.175704 1.51475513 1.18939424]"
698
+ ],
699
+ [
700
+ "[0.40167677 0.25953674 0.9407078 0.76308483]",
701
+ "[1.40167677 1.25953674 1.9407078 1.76308489]"
702
+ ],
703
+ [
704
+ "[0.2480728 0.21694398 0.63941365 0.57128876]",
705
+ "[1.24807286 1.21694398 1.6394136 1.57128882]"
706
+ ],
707
+ [
708
+ "[0.24388778 0.07268471 0.68350857 0.73431659]",
709
+ "[1.24388778 1.07268476 1.68350863 1.73431659]"
710
+ ],
711
+ [
712
+ "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
713
+ "[1.62569475 1.9881897 1.83639622 1.98288584]"
714
+ ],
715
+ [
716
+ "[0.88776821 0.51636773 0.30333066 0.32230979]",
717
+ "[1.88776827 1.51636767 1.30333066 1.32230973]"
718
+ ],
719
+ [
720
+ "[0.48507756 0.80808765 0.77162558 0.47834778]",
721
+ "[1.48507762 1.80808759 1.77162552 1.47834778]"
722
+ ],
723
+ [
724
+ "[0.68062544 0.98093534 0.14778823 0.53244978]",
725
+ "[1.68062544 1.98093534 1.14778829 1.53244972]"
726
+ ],
727
+ [
728
+ "[0.31518555 0.49643308 0.11509258 0.95458382]",
729
+ "[1.31518555 1.49643302 1.11509252 1.95458388]"
730
+ ],
731
+ [
732
+ "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
733
+ "[1.79121017 1.54161119 1.69369793 1.15207696]"
734
+ ],
735
+ [
736
+ "[0.79423058 0.07138705 0.061777 0.18766576]",
737
+ "[1.79423058 1.07138705 1.061777 1.1876657 ]"
738
+ ],
739
+ [
740
+ "[0.23942459 0.90487361 0.69337189 0.65089428]",
741
+ "[1.23942459 1.90487361 1.69337189 1.65089428]"
742
+ ],
743
+ [
744
+ "[0.94516498 0.08422136 0.5608117 0.07652664]",
745
+ "[1.94516492 1.08422136 1.56081176 1.07652664]"
746
+ ],
747
+ [
748
+ "[0.26661873 0.45946234 0.13510543 0.81294441]",
749
+ "[1.26661873 1.4594624 1.13510537 1.81294441]"
750
+ ],
751
+ [
752
+ "[0.30754459 0.77694583 0.09278506 0.38326019]",
753
+ "[1.30754459 1.77694583 1.09278512 1.38326025]"
754
+ ],
755
+ [
756
+ "[0.27845025 0.32472342 0.82203609 0.77107543]",
757
+ "[1.27845025 1.32472348 1.82203603 1.77107549]"
758
+ ],
759
+ [
760
+ "[0.4827103 0.10563457 0.98858833 0.82286644]",
761
+ "[1.48271036 1.10563457 1.98858833 1.82286644]"
762
+ ],
763
+ [
764
+ "[0.98033333 0.97656083 0.38939917 0.81491041]",
765
+ "[1.98033333 1.97656083 1.38939917 1.81491041]"
766
+ ],
767
+ [
768
+ "[0.74064726 0.4155122 0.09800029 0.49930882]",
769
+ "[1.74064732 1.4155122 1.09800029 1.49930882]"
770
+ ],
771
+ [
772
+ "[0.78956431 0.87284744 0.06880784 0.03455889]",
773
+ "[1.78956437 1.87284744 1.06880784 1.03455889]"
774
+ ],
775
+ [
776
+ "[0.94221359 0.57740951 0.98649532 0.40934443]",
777
+ "[1.94221354 1.57740951 1.98649526 1.40934443]"
778
+ ],
779
+ [
780
+ "[0.00497234 0.39319336 0.57054168 0.75150961]",
781
+ "[1.00497234 1.39319336 1.57054162 1.75150967]"
782
+ ],
783
+ [
784
+ "[0.44330525 0.09997386 0.89025736 0.90507984]",
785
+ "[1.44330525 1.09997392 1.89025736 1.90507984]"
786
+ ],
787
+ [
788
+ "[0.72290605 0.96945059 0.68354797 0.15270454]",
789
+ "[1.72290611 1.96945059 1.68354797 1.15270448]"
790
+ ],
791
+ [
792
+ "[0.75292218 0.81470108 0.49657214 0.56217098]",
793
+ "[1.75292218 1.81470108 1.49657214 1.56217098]"
794
+ ],
795
+ [
796
+ "[0.33480108 0.59181517 0.76198453 0.98062384]",
797
+ "[1.33480108 1.59181523 1.76198459 1.98062384]"
798
+ ],
799
+ [
800
+ "[0.52784437 0.54268694 0.12358981 0.72116476]",
801
+ "[1.52784443 1.54268694 1.12358975 1.7211647 ]"
802
+ ],
803
+ [
804
+ "[0.73217702 0.65233225 0.44077861 0.33837909]",
805
+ "[1.73217702 1.65233231 1.44077861 1.33837914]"
806
+ ],
807
+ [
808
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
809
+ "[1.34084332 1.73018837 1.54168713 1.91440833]"
810
+ ],
811
+ [
812
+ "[0.60110539 0.3618983 0.32342511 0.98672163]",
813
+ "[1.60110545 1.3618983 1.32342505 1.98672163]"
814
+ ],
815
+ [
816
+ "[0.77427191 0.21829212 0.12769502 0.74303615]",
817
+ "[1.77427197 1.21829212 1.12769508 1.74303615]"
818
+ ],
819
+ [
820
+ "[0.59812403 0.78395379 0.0291847 0.81814629]",
821
+ "[1.59812403 1.78395379 1.0291847 1.81814623]"
822
+ ],
823
+ [
824
+ "[0.93488538 0.73882395 0.37345302 0.0274905 ]",
825
+ "[1.93488538 1.73882389 1.37345302 1.0274905 ]"
826
+ ],
827
+ [
828
+ "[0.30631393 0.48311198 0.87847513 0.67559886]",
829
+ "[1.30631399 1.48311198 1.87847519 1.67559886]"
830
+ ],
831
+ [
832
+ "[0.18720162 0.74115586 0.98626411 0.30355608]",
833
+ "[1.18720162 1.74115586 1.98626411 1.30355608]"
834
+ ],
835
+ [
836
+ "[0.85566247 0.83362883 0.48424995 0.25265992]",
837
+ "[1.85566247 1.83362889 1.48424995 1.25265992]"
838
+ ],
839
+ [
840
+ "[0.95928186 0.84273899 0.71514636 0.38619852]",
841
+ "[1.95928192 1.84273899 1.7151463 1.38619852]"
842
+ ],
843
+ [
844
+ "[0.32565445 0.90939188 0.07488042 0.13730896]",
845
+ "[1.32565451 1.90939188 1.07488036 1.13730896]"
846
+ ],
847
+ [
848
+ "[0.9829582 0.59269661 0.40120947 0.95487177]",
849
+ "[1.9829582 1.59269667 1.40120947 1.95487177]"
850
+ ],
851
+ [
852
+ "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
853
+ "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
854
+ ],
855
+ [
856
+ "[0.54914117 0.03810108 0.87531954 0.73044223]",
857
+ "[1.54914117 1.03810108 1.87531948 1.73044229]"
858
+ ],
859
+ [
860
+ "[0.67418337 0.79634351 0.23229051 0.71345252]",
861
+ "[1.67418337 1.79634356 1.23229051 1.71345258]"
862
+ ],
863
+ [
864
+ "[0.87285906 0.48354989 0.39394957 0.59456545]",
865
+ "[1.872859 1.48354983 1.39394951 1.59456539]"
866
+ ],
867
+ [
868
+ "[0.81788456 0.58174163 0.29376316 0.7971254 ]",
869
+ "[1.81788456 1.58174157 1.29376316 1.79712534]"
870
+ ],
871
+ [
872
+ "[0.94559073 0.65736622 0.25761551 0.48553199]",
873
+ "[1.94559073 1.65736628 1.25761557 1.48553205]"
874
+ ],
875
+ [
876
+ "[0.60075855 0.12234765 0.00614399 0.30560958]",
877
+ "[1.60075855 1.12234759 1.00614405 1.30560958]"
878
+ ],
879
+ [
880
+ "[0.39147133 0.29854035 0.84663737 0.58175623]",
881
+ "[1.39147139 1.29854035 1.84663737 1.58175623]"
882
+ ],
883
+ [
884
+ "[0.02162331 0.81861657 0.92468154 0.07808572]",
885
+ "[1.02162337 1.81861663 1.92468154 1.07808566]"
886
+ ],
887
+ [
888
+ "[0.02235305 0.52774918 0.7331115 0.84358269]",
889
+ "[1.02235305 1.52774918 1.7331115 1.84358263]"
890
+ ],
891
+ [
892
+ "[0.6080932 0.56563014 0.32107437 0.72599429]",
893
+ "[1.60809326 1.5656302 1.32107437 1.72599435]"
894
+ ],
895
+ [
896
+ "[0.67447788 0.6125319 0.98007888 0.65968603]",
897
+ "[1.67447782 1.6125319 1.98007894 1.65968609]"
898
+ ],
899
+ [
900
+ "[0.47963417 0.81818312 0.48720706 0.49339259]",
901
+ "[1.47963417 1.81818318 1.48720706 1.49339259]"
902
+ ],
903
+ [
904
+ "[0.9630242 0.76359051 0.24853623 0.76881069]",
905
+ "[1.96302414 1.76359057 1.24853623 1.76881075]"
906
+ ],
907
+ [
908
+ "[0.60609657 0.96257663 0.19292736 0.95702219]",
909
+ "[1.60609651 1.96257663 1.19292736 1.95702219]"
910
+ ],
911
+ [
912
+ "[0.80654246 0.08253473 0.74478531 0.71257162]",
913
+ "[1.8065424 1.08253479 1.74478531 1.71257162]"
914
+ ],
915
+ [
916
+ "[0.70167565 0.26930219 0.5660674 0.61194974]",
917
+ "[1.70167565 1.26930213 1.56606746 1.61194968]"
918
+ ],
919
+ [
920
+ "[0.76933283 0.86241865 0.44114518 0.65644735]",
921
+ "[1.76933289 1.86241865 1.44114518 1.65644741]"
922
+ ],
923
+ [
924
+ "[0.59492421 0.90274489 0.38069052 0.46101224]",
925
+ "[1.59492421 1.90274489 1.38069057 1.46101224]"
926
+ ],
927
+ [
928
+ "[0.12024075 0.21342516 0.56858408 0.58644271]",
929
+ "[1.12024069 1.21342516 1.56858408 1.58644271]"
930
+ ],
931
+ [
932
+ "[0.91730917 0.22574073 0.09591609 0.33056474]",
933
+ "[1.91730917 1.22574067 1.09591603 1.33056474]"
934
+ ],
935
+ [
936
+ "[0.63235509 0.70352674 0.96188956 0.46240485]",
937
+ "[1.63235509 1.70352674 1.96188951 1.46240485]"
938
+ ],
939
+ [
940
+ "[0.37959969 0.42820001 0.10690689 0.96353984]",
941
+ "[1.37959969 1.42820001 1.10690689 1.96353984]"
942
+ ],
943
+ [
944
+ "[0.49607176 0.1922397 0.46640229 0.78321403]",
945
+ "[1.49607182 1.19223976 1.46640229 1.78321409]"
946
+ ],
947
+ [
948
+ "[0.40234613 0.54987347 0.49542785 0.54153186]",
949
+ "[1.40234613 1.54987347 1.49542785 1.5415318 ]"
950
+ ],
951
+ [
952
+ "[0.80893755 0.92237449 0.88346356 0.93164903]",
953
+ "[1.80893755 1.92237449 1.88346362 1.93164897]"
954
+ ],
955
+ [
956
+ "[0.12858278 0.09930819 0.83222693 0.72485673]",
957
+ "[1.12858272 1.09930825 1.83222699 1.72485673]"
958
+ ],
959
+ [
960
+ "[0.72470158 0.4940322 0.41027349 0.89364016]",
961
+ "[1.72470164 1.49403214 1.41027355 1.89364016]"
962
+ ],
963
+ [
964
+ "[0.47856545 0.46267092 0.6376707 0.84747767]",
965
+ "[1.47856545 1.46267092 1.63767076 1.84747767]"
966
+ ],
967
+ [
968
+ "[0.49584109 0.80599248 0.07096875 0.75872749]",
969
+ "[1.49584103 1.80599248 1.07096875 1.75872755]"
970
+ ],
971
+ [
972
+ "[0.43500566 0.66041756 0.80293626 0.96224713]",
973
+ "[1.43500566 1.66041756 1.80293632 1.96224713]"
974
+ ],
975
+ [
976
+ "[0.78397602 0.74223626 0.26603186 0.41664881]",
977
+ "[1.78397608 1.74223626 1.26603186 1.41664886]"
978
+ ],
979
+ [
980
+ "[0.28942841 0.05601001 0.33039129 0.27781558]",
981
+ "[1.28942847 1.05601001 1.33039129 1.27781558]"
982
+ ],
983
+ [
984
+ "[0.43681622 0.74680805 0.83598751 0.12414402]",
985
+ "[1.43681622 1.74680805 1.83598757 1.12414408]"
986
+ ],
987
+ [
988
+ "[0.47870928 0.17129105 0.27300501 0.20634609]",
989
+ "[1.47870922 1.17129111 1.27300501 1.20634604]"
990
+ ],
991
+ [
992
+ "[0.87608397 0.93200487 0.80169648 0.37758952]",
993
+ "[1.87608397 1.93200493 1.80169654 1.37758946]"
994
+ ],
995
+ [
996
+ "[0.68891573 0.25576538 0.96339929 0.503833 ]",
997
+ "[1.68891573 1.25576544 1.96339929 1.50383306]"
998
+ ]
999
+ ]
1000
+ }
1001
+ },
1002
+ "other": {
1003
+ "model": "ModelConfig(model=Sequential(\n (0) - Linear(in_features=4, out_features=4, bias=True): Input__embedding_1_x -> Linear_1_x\n (1) - <function leaky_relu at 0x719e0ce23a60>: Linear_1_x -> Activation_2_x\n (2) - Identity(): Activation_2_x -> Activation_2_x\n), model_inputs=['Input__embedding_1_x'], model_outputs=['Activation_2_x'], loss_inputs=['Input__label_1_y', 'Activation_2_x'], loss=Sequential(\n (0) - <function mse_loss at 0x719e0ce2d580>: Activation_2_x, Input__label_1_y -> MSE_loss_1_loss\n (1) - Identity(): MSE_loss_1_loss -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
+ },
1005
+ "relations": []
1006
+ },
1007
+ "error": null,
1008
+ "input_metadata": [
1009
+ {
1010
+ "dataframes": {
1011
+ "df": {
1012
+ "columns": [
1013
+ "x",
1014
+ "y"
1015
+ ]
1016
+ },
1017
+ "df_test": {
1018
+ "columns": [
1019
+ "predicted",
1020
+ "x",
1021
+ "y"
1022
+ ]
1023
+ },
1024
+ "df_train": {
1025
+ "columns": [
1026
+ "x",
1027
+ "y"
1028
+ ]
1029
+ }
1030
+ },
1031
+ "other": {
1032
+ "model": {
1033
+ "model": {
1034
+ "inputs": [
1035
+ "Input__embedding_1_x"
1036
+ ],
1037
+ "loss_inputs": [
1038
+ "Input__label_1_y",
1039
+ "Activation_2_x"
1040
+ ],
1041
+ "outputs": [
1042
+ "Activation_2_x"
1043
+ ],
1044
+ "trained": true
1045
+ },
1046
+ "type": "model"
1047
+ }
1048
+ },
1049
+ "relations": []
1050
+ }
1051
+ ],
1052
+ "meta": {
1053
+ "inputs": {
1054
+ "bundle": {
1055
+ "name": "bundle",
1056
+ "position": "left",
1057
+ "type": {
1058
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
1059
+ }
1060
+ }
1061
+ },
1062
+ "name": "View tables",
1063
+ "outputs": {},
1064
+ "params": {
1065
+ "limit": {
1066
+ "default": 100.0,
1067
+ "name": "limit",
1068
+ "type": {
1069
+ "type": "<class 'int'>"
1070
+ }
1071
+ }
1072
+ },
1073
+ "type": "table_view"
1074
+ },
1075
+ "params": {
1076
+ "limit": 100.0
1077
+ },
1078
+ "status": "done",
1079
+ "title": "View tables"
1080
+ },
1081
+ "dragHandle": ".bg-primary",
1082
+ "height": 711.0,
1083
+ "id": "View tables 1",
1084
+ "position": {
1085
+ "x": 2900.3734843758352,
1086
+ "y": -128.41609608842867
1087
+ },
1088
+ "type": "table_view",
1089
+ "width": 836.0
1090
+ },
1091
+ {
1092
+ "data": {
1093
+ "__execution_delay": 0.0,
1094
+ "collapsed": null,
1095
+ "display": null,
1096
+ "error": null,
1097
+ "input_metadata": [
1098
+ {
1099
+ "dataframes": {
1100
+ "df": {
1101
+ "columns": [
1102
+ "x",
1103
+ "y"
1104
+ ]
1105
+ },
1106
+ "df_test": {
1107
+ "columns": [
1108
+ "x",
1109
+ "y"
1110
+ ]
1111
+ },
1112
+ "df_train": {
1113
+ "columns": [
1114
+ "x",
1115
+ "y"
1116
+ ]
1117
+ }
1118
+ },
1119
+ "other": {},
1120
+ "relations": []
1121
+ }
1122
+ ],
1123
+ "meta": {
1124
+ "inputs": {
1125
+ "bundle": {
1126
+ "name": "bundle",
1127
+ "position": "left",
1128
+ "type": {
1129
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
1130
+ }
1131
+ }
1132
+ },
1133
+ "name": "Define model",
1134
+ "outputs": {
1135
+ "output": {
1136
+ "name": "output",
1137
+ "position": "right",
1138
+ "type": {
1139
+ "type": "None"
1140
+ }
1141
+ }
1142
+ },
1143
+ "params": {
1144
+ "model_workspace": {
1145
+ "default": null,
1146
+ "name": "model_workspace",
1147
+ "type": {
1148
+ "type": "<class 'str'>"
1149
+ }
1150
+ },
1151
+ "save_as": {
1152
+ "default": "model",
1153
+ "name": "save_as",
1154
+ "type": {
1155
+ "type": "<class 'str'>"
1156
+ }
1157
+ }
1158
+ },
1159
+ "type": "basic"
1160
+ },
1161
+ "params": {
1162
+ "model_workspace": "Model definition",
1163
+ "save_as": "model"
1164
+ },
1165
+ "status": "done",
1166
+ "title": "Define model"
1167
+ },
1168
+ "dragHandle": ".bg-primary",
1169
+ "height": 537.0,
1170
+ "id": "Define model 1",
1171
+ "position": {
1172
+ "x": 795.0,
1173
+ "y": -45.0
1174
+ },
1175
+ "type": "basic",
1176
+ "width": 498.0
1177
+ },
1178
+ {
1179
+ "data": {
1180
+ "__execution_delay": 0.0,
1181
+ "collapsed": null,
1182
+ "display": null,
1183
+ "error": null,
1184
+ "input_metadata": [
1185
+ {
1186
+ "dataframes": {
1187
+ "df": {
1188
+ "columns": [
1189
+ "x",
1190
+ "y"
1191
+ ]
1192
+ },
1193
+ "df_test": {
1194
+ "columns": [
1195
+ "x",
1196
+ "y"
1197
+ ]
1198
+ },
1199
+ "df_train": {
1200
+ "columns": [
1201
+ "x",
1202
+ "y"
1203
+ ]
1204
+ }
1205
+ },
1206
+ "other": {
1207
+ "model": {
1208
+ "model": {
1209
+ "inputs": [
1210
+ "Input__embedding_1_x"
1211
+ ],
1212
+ "loss_inputs": [
1213
+ "Input__label_1_y",
1214
+ "Activation_2_x"
1215
+ ],
1216
+ "outputs": [
1217
+ "Activation_2_x"
1218
+ ],
1219
+ "trained": false
1220
+ },
1221
+ "type": "model"
1222
+ }
1223
+ },
1224
+ "relations": []
1225
+ }
1226
+ ],
1227
+ "meta": {
1228
+ "inputs": {
1229
+ "bundle": {
1230
+ "name": "bundle",
1231
+ "position": "left",
1232
+ "type": {
1233
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
1234
+ }
1235
+ }
1236
+ },
1237
+ "name": "Train model",
1238
+ "outputs": {
1239
+ "output": {
1240
+ "name": "output",
1241
+ "position": "right",
1242
+ "type": {
1243
+ "type": "None"
1244
+ }
1245
+ }
1246
+ },
1247
+ "params": {
1248
+ "epochs": {
1249
+ "default": 1.0,
1250
+ "name": "epochs",
1251
+ "type": {
1252
+ "type": "<class 'int'>"
1253
+ }
1254
+ },
1255
+ "input_mapping": {
1256
+ "default": null,
1257
+ "name": "input_mapping",
1258
+ "type": {
1259
+ "type": "<class 'lynxkite_graph_analytics.lynxkite_ops.ModelTrainingInputMapping'>"
1260
+ }
1261
+ },
1262
+ "model_name": {
1263
+ "default": "model",
1264
+ "name": "model_name",
1265
+ "type": {
1266
+ "type": "<class 'str'>"
1267
+ }
1268
+ }
1269
+ },
1270
+ "type": "basic"
1271
+ },
1272
+ "params": {
1273
+ "epochs": "1001",
1274
+ "input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"column\":\"x\",\"df\":\"df_train\"},\"Input__label_1_y\":{\"column\":\"y\",\"df\":\"df_train\"}}}",
1275
+ "model_name": "model"
1276
+ },
1277
+ "status": "done",
1278
+ "title": "Train model"
1279
+ },
1280
+ "dragHandle": ".bg-primary",
1281
+ "height": 604.0,
1282
+ "id": "Train model 2",
1283
+ "position": {
1284
+ "x": 1399.5245787239226,
1285
+ "y": -19.196202428593544
1286
+ },
1287
+ "type": "basic",
1288
+ "width": 586.0
1289
+ },
1290
+ {
1291
+ "data": {
1292
+ "__execution_delay": 0.0,
1293
+ "collapsed": null,
1294
+ "display": null,
1295
+ "error": null,
1296
+ "input_metadata": [
1297
+ {
1298
+ "dataframes": {
1299
+ "df": {
1300
+ "columns": [
1301
+ "x",
1302
+ "y"
1303
+ ]
1304
+ },
1305
+ "df_test": {
1306
+ "columns": [
1307
+ "predicted",
1308
+ "x",
1309
+ "y"
1310
+ ]
1311
+ },
1312
+ "df_train": {
1313
+ "columns": [
1314
+ "x",
1315
+ "y"
1316
+ ]
1317
+ }
1318
+ },
1319
+ "other": {
1320
+ "model": {
1321
+ "model": {
1322
+ "inputs": [
1323
+ "Input__embedding_1_x"
1324
+ ],
1325
+ "loss_inputs": [
1326
+ "Input__label_1_y",
1327
+ "Activation_2_x"
1328
+ ],
1329
+ "outputs": [
1330
+ "Activation_2_x"
1331
+ ],
1332
+ "trained": true
1333
+ },
1334
+ "type": "model"
1335
+ }
1336
+ },
1337
+ "relations": []
1338
+ }
1339
+ ],
1340
+ "meta": {
1341
+ "inputs": {
1342
+ "bundle": {
1343
+ "name": "bundle",
1344
+ "position": "left",
1345
+ "type": {
1346
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
1347
+ }
1348
+ }
1349
+ },
1350
+ "name": "Model inference",
1351
+ "outputs": {
1352
+ "output": {
1353
+ "name": "output",
1354
+ "position": "right",
1355
+ "type": {
1356
+ "type": "None"
1357
+ }
1358
+ }
1359
+ },
1360
+ "params": {
1361
+ "input_mapping": {
1362
+ "default": null,
1363
+ "name": "input_mapping",
1364
+ "type": {
1365
+ "type": "<class 'lynxkite_graph_analytics.lynxkite_ops.ModelInferenceInputMapping'>"
1366
+ }
1367
+ },
1368
+ "model_name": {
1369
+ "default": "model",
1370
+ "name": "model_name",
1371
+ "type": {
1372
+ "type": "<class 'str'>"
1373
+ }
1374
+ },
1375
+ "output_mapping": {
1376
+ "default": null,
1377
+ "name": "output_mapping",
1378
+ "type": {
1379
+ "type": "<class 'lynxkite_graph_analytics.lynxkite_ops.ModelOutputMapping'>"
1380
+ }
1381
+ }
1382
+ },
1383
+ "type": "basic"
1384
+ },
1385
+ "params": {
1386
+ "input_mapping": "{\"map\":{\"Input__embedding_1_x\":{\"column\":\"x\",\"df\":\"df_test\"}}}",
1387
+ "model_name": "model",
1388
+ "output_mapping": "{\"map\":{\"Activation_2_x\":{\"column\":\"predicted\",\"df\":\"df_test\"}}}"
1389
+ },
1390
+ "status": "done",
1391
+ "title": "Model inference"
1392
+ },
1393
+ "dragHandle": ".bg-primary",
1394
+ "height": 893.0,
1395
+ "id": "Model inference 1",
1396
+ "position": {
1397
+ "x": 2181.718373860645,
1398
+ "y": -69.44701793295484
1399
+ },
1400
+ "type": "basic",
1401
+ "width": 529.0
1402
+ }
1403
+ ]
1404
+ }
examples/ODE-GNN ADDED
@@ -0,0 +1,796 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "edges": [
3
+ {
4
+ "id": "Input: embedding 1 Graph conv 1",
5
+ "source": "Input: embedding 1",
6
+ "sourceHandle": "x",
7
+ "target": "Graph conv 1",
8
+ "targetHandle": "x"
9
+ },
10
+ {
11
+ "id": "Input: graph edges 1 Graph conv 1",
12
+ "source": "Input: graph edges 1",
13
+ "sourceHandle": "edges",
14
+ "target": "Graph conv 1",
15
+ "targetHandle": "edges"
16
+ },
17
+ {
18
+ "id": "Graph conv 1 Activation 1",
19
+ "source": "Graph conv 1",
20
+ "sourceHandle": "x",
21
+ "target": "Activation 1",
22
+ "targetHandle": "x"
23
+ },
24
+ {
25
+ "id": "Activation 1 Repeat 1",
26
+ "source": "Activation 1",
27
+ "sourceHandle": "x",
28
+ "target": "Repeat 1",
29
+ "targetHandle": "input"
30
+ },
31
+ {
32
+ "id": "Repeat 1 Graph conv 1",
33
+ "source": "Repeat 1",
34
+ "sourceHandle": "output",
35
+ "target": "Graph conv 1",
36
+ "targetHandle": "x"
37
+ },
38
+ {
39
+ "id": "Input: sequential 1 LSTM 1",
40
+ "source": "Input: sequential 1",
41
+ "sourceHandle": "y",
42
+ "target": "LSTM 1",
43
+ "targetHandle": "x"
44
+ },
45
+ {
46
+ "id": "Input: zeros 1 LSTM 1",
47
+ "source": "Input: zeros 1",
48
+ "sourceHandle": "x",
49
+ "target": "LSTM 1",
50
+ "targetHandle": "h"
51
+ },
52
+ {
53
+ "id": "Recurrent chain 1 LSTM 1",
54
+ "source": "Recurrent chain 1",
55
+ "sourceHandle": "output",
56
+ "target": "LSTM 1",
57
+ "targetHandle": "h"
58
+ },
59
+ {
60
+ "id": "LSTM 1 Recurrent chain 1",
61
+ "source": "LSTM 1",
62
+ "sourceHandle": "h",
63
+ "target": "Recurrent chain 1",
64
+ "targetHandle": "input"
65
+ },
66
+ {
67
+ "id": "Activation 1 Concatenate 1",
68
+ "source": "Activation 1",
69
+ "sourceHandle": "x",
70
+ "target": "Concatenate 1",
71
+ "targetHandle": "a"
72
+ },
73
+ {
74
+ "id": "LSTM 1 Concatenate 1",
75
+ "source": "LSTM 1",
76
+ "sourceHandle": "x",
77
+ "target": "Concatenate 1",
78
+ "targetHandle": "b"
79
+ },
80
+ {
81
+ "id": "Input: label 1 MSE loss 1",
82
+ "source": "Input: label 1",
83
+ "sourceHandle": "y",
84
+ "target": "MSE loss 1",
85
+ "targetHandle": "y"
86
+ },
87
+ {
88
+ "id": "MSE loss 1 Optimizer 1",
89
+ "source": "MSE loss 1",
90
+ "sourceHandle": "loss",
91
+ "target": "Optimizer 1",
92
+ "targetHandle": "loss"
93
+ },
94
+ {
95
+ "id": "Concatenate 1 Neural ODE 2",
96
+ "source": "Concatenate 1",
97
+ "sourceHandle": "x",
98
+ "target": "Neural ODE 2",
99
+ "targetHandle": "x"
100
+ },
101
+ {
102
+ "id": "Neural ODE 2 MSE loss 1",
103
+ "source": "Neural ODE 2",
104
+ "sourceHandle": "x",
105
+ "target": "MSE loss 1",
106
+ "targetHandle": "x"
107
+ }
108
+ ],
109
+ "env": "PyTorch model",
110
+ "nodes": [
111
+ {
112
+ "data": {
113
+ "display": null,
114
+ "error": null,
115
+ "meta": {
116
+ "inputs": {
117
+ "edges": {
118
+ "name": "edges",
119
+ "position": "bottom",
120
+ "type": {
121
+ "type": "tensor"
122
+ }
123
+ },
124
+ "x": {
125
+ "name": "x",
126
+ "position": "bottom",
127
+ "type": {
128
+ "type": "tensor"
129
+ }
130
+ }
131
+ },
132
+ "name": "Graph conv",
133
+ "outputs": {
134
+ "x": {
135
+ "name": "x",
136
+ "position": "top",
137
+ "type": {
138
+ "type": "tensor"
139
+ }
140
+ }
141
+ },
142
+ "params": {
143
+ "type": {
144
+ "default": "1",
145
+ "name": "type",
146
+ "type": {
147
+ "enum": [
148
+ "GCNConv",
149
+ "GATConv",
150
+ "GATv2Conv",
151
+ "SAGEConv"
152
+ ]
153
+ }
154
+ }
155
+ },
156
+ "type": "basic"
157
+ },
158
+ "params": {
159
+ "type": 1.0
160
+ },
161
+ "status": "planned",
162
+ "title": "Graph conv"
163
+ },
164
+ "dragHandle": ".bg-primary",
165
+ "height": 200.0,
166
+ "id": "Graph conv 1",
167
+ "position": {
168
+ "x": 350.98078368755864,
169
+ "y": 195.0
170
+ },
171
+ "type": "basic",
172
+ "width": 200.0
173
+ },
174
+ {
175
+ "data": {
176
+ "__execution_delay": 0.0,
177
+ "collapsed": null,
178
+ "display": null,
179
+ "error": null,
180
+ "meta": {
181
+ "inputs": {
182
+ "input": {
183
+ "name": "input",
184
+ "position": "top",
185
+ "type": {
186
+ "type": "tensor"
187
+ }
188
+ }
189
+ },
190
+ "name": "Repeat",
191
+ "outputs": {
192
+ "output": {
193
+ "name": "output",
194
+ "position": "bottom",
195
+ "type": {
196
+ "type": "tensor"
197
+ }
198
+ }
199
+ },
200
+ "params": {
201
+ "times": {
202
+ "default": 1.0,
203
+ "name": "times",
204
+ "type": {
205
+ "type": "<class 'int'>"
206
+ }
207
+ }
208
+ },
209
+ "type": "basic"
210
+ },
211
+ "params": {
212
+ "times": "5"
213
+ },
214
+ "status": "planned",
215
+ "title": "Repeat"
216
+ },
217
+ "dragHandle": ".bg-primary",
218
+ "height": 200.0,
219
+ "id": "Repeat 1",
220
+ "position": {
221
+ "x": -94.15168677219138,
222
+ "y": 14.525356969883305
223
+ },
224
+ "type": "basic",
225
+ "width": 200.0
226
+ },
227
+ {
228
+ "data": {
229
+ "__execution_delay": null,
230
+ "collapsed": true,
231
+ "display": null,
232
+ "error": null,
233
+ "meta": {
234
+ "inputs": {
235
+ "a": {
236
+ "name": "a",
237
+ "position": "bottom",
238
+ "type": {
239
+ "type": "tensor"
240
+ }
241
+ },
242
+ "b": {
243
+ "name": "b",
244
+ "position": "bottom",
245
+ "type": {
246
+ "type": "tensor"
247
+ }
248
+ }
249
+ },
250
+ "name": "Concatenate",
251
+ "outputs": {
252
+ "x": {
253
+ "name": "x",
254
+ "position": "top",
255
+ "type": {
256
+ "type": "tensor"
257
+ }
258
+ }
259
+ },
260
+ "params": {},
261
+ "type": "basic"
262
+ },
263
+ "params": {},
264
+ "status": "planned",
265
+ "title": "Concatenate"
266
+ },
267
+ "dragHandle": ".bg-primary",
268
+ "height": 200.0,
269
+ "id": "Concatenate 1",
270
+ "position": {
271
+ "x": 477.88148637482334,
272
+ "y": -372.62774030487003
273
+ },
274
+ "type": "basic",
275
+ "width": 200.0
276
+ },
277
+ {
278
+ "data": {
279
+ "__execution_delay": null,
280
+ "collapsed": true,
281
+ "display": null,
282
+ "error": null,
283
+ "meta": {
284
+ "inputs": {},
285
+ "name": "Input: graph edges",
286
+ "outputs": {
287
+ "edges": {
288
+ "name": "edges",
289
+ "position": "top",
290
+ "type": {
291
+ "type": "tensor"
292
+ }
293
+ }
294
+ },
295
+ "params": {},
296
+ "type": "basic"
297
+ },
298
+ "params": {},
299
+ "status": "planned",
300
+ "title": "Input: graph edges"
301
+ },
302
+ "dragHandle": ".bg-primary",
303
+ "height": 200.0,
304
+ "id": "Input: graph edges 1",
305
+ "position": {
306
+ "x": 515.6535517374441,
307
+ "y": 545.4709559884296
308
+ },
309
+ "type": "basic",
310
+ "width": 200.0
311
+ },
312
+ {
313
+ "data": {
314
+ "__execution_delay": null,
315
+ "collapsed": true,
316
+ "display": null,
317
+ "error": null,
318
+ "meta": {
319
+ "inputs": {},
320
+ "name": "Input: embedding",
321
+ "outputs": {
322
+ "x": {
323
+ "name": "x",
324
+ "position": "top",
325
+ "type": {
326
+ "type": "tensor"
327
+ }
328
+ }
329
+ },
330
+ "params": {},
331
+ "type": "basic"
332
+ },
333
+ "params": {},
334
+ "status": "planned",
335
+ "title": "Input: embedding"
336
+ },
337
+ "dragHandle": ".bg-primary",
338
+ "height": 200.0,
339
+ "id": "Input: embedding 1",
340
+ "position": {
341
+ "x": 246.6527948448857,
342
+ "y": 551.6313504198322
343
+ },
344
+ "type": "basic",
345
+ "width": 200.0
346
+ },
347
+ {
348
+ "data": {
349
+ "display": null,
350
+ "error": null,
351
+ "meta": {
352
+ "inputs": {
353
+ "x": {
354
+ "name": "x",
355
+ "position": "bottom",
356
+ "type": {
357
+ "type": "tensor"
358
+ }
359
+ }
360
+ },
361
+ "name": "Activation",
362
+ "outputs": {
363
+ "x": {
364
+ "name": "x",
365
+ "position": "top",
366
+ "type": {
367
+ "type": "tensor"
368
+ }
369
+ }
370
+ },
371
+ "params": {
372
+ "type": {
373
+ "default": "1",
374
+ "name": "type",
375
+ "type": {
376
+ "enum": [
377
+ "ReLU",
378
+ "LeakyReLU",
379
+ "Tanh",
380
+ "Mish"
381
+ ]
382
+ }
383
+ }
384
+ },
385
+ "type": "basic"
386
+ },
387
+ "params": {
388
+ "type": 1.0
389
+ },
390
+ "status": "planned",
391
+ "title": "Activation"
392
+ },
393
+ "dragHandle": ".bg-primary",
394
+ "height": 200.0,
395
+ "id": "Activation 1",
396
+ "position": {
397
+ "x": 354.3731834561054,
398
+ "y": -73.74768512965228
399
+ },
400
+ "type": "basic",
401
+ "width": 200.0
402
+ },
403
+ {
404
+ "data": {
405
+ "__execution_delay": null,
406
+ "collapsed": true,
407
+ "display": null,
408
+ "error": null,
409
+ "meta": {
410
+ "inputs": {
411
+ "h": {
412
+ "name": "h",
413
+ "position": "bottom",
414
+ "type": {
415
+ "type": "tensor"
416
+ }
417
+ },
418
+ "x": {
419
+ "name": "x",
420
+ "position": "bottom",
421
+ "type": {
422
+ "type": "tensor"
423
+ }
424
+ }
425
+ },
426
+ "name": "LSTM",
427
+ "outputs": {
428
+ "h": {
429
+ "name": "h",
430
+ "position": "top",
431
+ "type": {
432
+ "type": "tensor"
433
+ }
434
+ },
435
+ "x": {
436
+ "name": "x",
437
+ "position": "top",
438
+ "type": {
439
+ "type": "tensor"
440
+ }
441
+ }
442
+ },
443
+ "params": {},
444
+ "type": "basic"
445
+ },
446
+ "params": {},
447
+ "status": "planned",
448
+ "title": "LSTM"
449
+ },
450
+ "dragHandle": ".bg-primary",
451
+ "height": 200.0,
452
+ "id": "LSTM 1",
453
+ "position": {
454
+ "x": 960.0,
455
+ "y": 135.0
456
+ },
457
+ "type": "basic",
458
+ "width": 200.0
459
+ },
460
+ {
461
+ "data": {
462
+ "__execution_delay": null,
463
+ "collapsed": true,
464
+ "display": null,
465
+ "error": null,
466
+ "meta": {
467
+ "inputs": {},
468
+ "name": "Input: sequential",
469
+ "outputs": {
470
+ "y": {
471
+ "name": "y",
472
+ "position": "top",
473
+ "type": {
474
+ "type": "tensor"
475
+ }
476
+ }
477
+ },
478
+ "params": {},
479
+ "type": "basic"
480
+ },
481
+ "params": {},
482
+ "status": "planned",
483
+ "title": "Input: sequential"
484
+ },
485
+ "dragHandle": ".bg-primary",
486
+ "height": 200.0,
487
+ "id": "Input: sequential 1",
488
+ "position": {
489
+ "x": 1005.0,
490
+ "y": 510.0
491
+ },
492
+ "type": "basic",
493
+ "width": 200.0
494
+ },
495
+ {
496
+ "data": {
497
+ "__execution_delay": null,
498
+ "collapsed": true,
499
+ "display": null,
500
+ "error": null,
501
+ "meta": {
502
+ "inputs": {},
503
+ "name": "Input: zeros",
504
+ "outputs": {
505
+ "x": {
506
+ "name": "x",
507
+ "position": "top",
508
+ "type": {
509
+ "type": "tensor"
510
+ }
511
+ }
512
+ },
513
+ "params": {},
514
+ "type": "basic"
515
+ },
516
+ "params": {},
517
+ "status": "planned",
518
+ "title": "Input: zeros"
519
+ },
520
+ "dragHandle": ".bg-primary",
521
+ "height": 200.0,
522
+ "id": "Input: zeros 1",
523
+ "position": {
524
+ "x": 1290.0,
525
+ "y": 405.0
526
+ },
527
+ "type": "basic",
528
+ "width": 200.0
529
+ },
530
+ {
531
+ "data": {
532
+ "__execution_delay": null,
533
+ "collapsed": true,
534
+ "display": null,
535
+ "error": null,
536
+ "meta": {
537
+ "inputs": {
538
+ "input": {
539
+ "name": "input",
540
+ "position": "top",
541
+ "type": {
542
+ "type": "tensor"
543
+ }
544
+ }
545
+ },
546
+ "name": "Recurrent chain",
547
+ "outputs": {
548
+ "output": {
549
+ "name": "output",
550
+ "position": "bottom",
551
+ "type": {
552
+ "type": "tensor"
553
+ }
554
+ }
555
+ },
556
+ "params": {},
557
+ "type": "basic"
558
+ },
559
+ "params": {},
560
+ "status": "planned",
561
+ "title": "Recurrent chain"
562
+ },
563
+ "dragHandle": ".bg-primary",
564
+ "height": 200.0,
565
+ "id": "Recurrent chain 1",
566
+ "position": {
567
+ "x": 1224.6603040746108,
568
+ "y": 135.44839862151363
569
+ },
570
+ "type": "basic",
571
+ "width": 200.0
572
+ },
573
+ {
574
+ "data": {
575
+ "__execution_delay": null,
576
+ "collapsed": true,
577
+ "display": null,
578
+ "error": null,
579
+ "meta": {
580
+ "inputs": {
581
+ "x": {
582
+ "name": "x",
583
+ "position": "bottom",
584
+ "type": {
585
+ "type": "tensor"
586
+ }
587
+ },
588
+ "y": {
589
+ "name": "y",
590
+ "position": "bottom",
591
+ "type": {
592
+ "type": "tensor"
593
+ }
594
+ }
595
+ },
596
+ "name": "MSE loss",
597
+ "outputs": {
598
+ "loss": {
599
+ "name": "loss",
600
+ "position": "top",
601
+ "type": {
602
+ "type": "tensor"
603
+ }
604
+ }
605
+ },
606
+ "params": {},
607
+ "type": "basic"
608
+ },
609
+ "params": {},
610
+ "status": "planned",
611
+ "title": "MSE loss"
612
+ },
613
+ "dragHandle": ".bg-primary",
614
+ "height": 200.0,
615
+ "id": "MSE loss 1",
616
+ "position": {
617
+ "x": 915.0,
618
+ "y": -900.0
619
+ },
620
+ "type": "basic",
621
+ "width": 200.0
622
+ },
623
+ {
624
+ "data": {
625
+ "__execution_delay": null,
626
+ "collapsed": true,
627
+ "display": null,
628
+ "error": null,
629
+ "meta": {
630
+ "inputs": {},
631
+ "name": "Input: label",
632
+ "outputs": {
633
+ "y": {
634
+ "name": "y",
635
+ "position": "top",
636
+ "type": {
637
+ "type": "tensor"
638
+ }
639
+ }
640
+ },
641
+ "params": {},
642
+ "type": "basic"
643
+ },
644
+ "params": {},
645
+ "status": "planned",
646
+ "title": "Input: label"
647
+ },
648
+ "dragHandle": ".bg-primary",
649
+ "height": 200.0,
650
+ "id": "Input: label 1",
651
+ "position": {
652
+ "x": 1095.0,
653
+ "y": -450.0
654
+ },
655
+ "type": "basic",
656
+ "width": 200.0
657
+ },
658
+ {
659
+ "data": {
660
+ "display": null,
661
+ "error": null,
662
+ "meta": {
663
+ "inputs": {
664
+ "loss": {
665
+ "name": "loss",
666
+ "position": "bottom",
667
+ "type": {
668
+ "type": "tensor"
669
+ }
670
+ }
671
+ },
672
+ "name": "Optimizer",
673
+ "outputs": {},
674
+ "params": {
675
+ "lr": {
676
+ "default": 0.001,
677
+ "name": "lr",
678
+ "type": {
679
+ "type": "<class 'float'>"
680
+ }
681
+ },
682
+ "type": {
683
+ "default": "1",
684
+ "name": "type",
685
+ "type": {
686
+ "enum": [
687
+ "AdamW",
688
+ "Adafactor",
689
+ "Adagrad",
690
+ "SGD",
691
+ "Lion",
692
+ "Paged AdamW",
693
+ "Galore AdamW"
694
+ ]
695
+ }
696
+ }
697
+ },
698
+ "type": "basic"
699
+ },
700
+ "params": {
701
+ "lr": 0.001,
702
+ "type": 1.0
703
+ },
704
+ "status": "planned",
705
+ "title": "Optimizer"
706
+ },
707
+ "dragHandle": ".bg-primary",
708
+ "height": 247.0,
709
+ "id": "Optimizer 1",
710
+ "position": {
711
+ "x": 915.3430278730226,
712
+ "y": -1268.0577550022126
713
+ },
714
+ "type": "basic",
715
+ "width": 190.0
716
+ },
717
+ {
718
+ "data": {
719
+ "display": null,
720
+ "error": null,
721
+ "meta": {
722
+ "inputs": {
723
+ "x": {
724
+ "name": "x",
725
+ "position": "bottom",
726
+ "type": {
727
+ "type": "tensor"
728
+ }
729
+ }
730
+ },
731
+ "name": "Neural ODE",
732
+ "outputs": {
733
+ "x": {
734
+ "name": "x",
735
+ "position": "top",
736
+ "type": {
737
+ "type": "tensor"
738
+ }
739
+ }
740
+ },
741
+ "params": {
742
+ "absolute_tolerance": {
743
+ "default": null,
744
+ "name": "absolute_tolerance",
745
+ "type": {
746
+ "type": "None"
747
+ }
748
+ },
749
+ "method": {
750
+ "default": "1",
751
+ "name": "method",
752
+ "type": {
753
+ "enum": [
754
+ "dopri8",
755
+ "dopri5",
756
+ "bosh3",
757
+ "fehlberg2",
758
+ "adaptive_heun",
759
+ "euler",
760
+ "midpoint",
761
+ "rk4",
762
+ "explicit_adams",
763
+ "implicit_adams"
764
+ ]
765
+ }
766
+ },
767
+ "relative_tolerance": {
768
+ "default": null,
769
+ "name": "relative_tolerance",
770
+ "type": {
771
+ "type": "None"
772
+ }
773
+ }
774
+ },
775
+ "type": "basic"
776
+ },
777
+ "params": {
778
+ "absolute_tolerance": null,
779
+ "method": 1.0,
780
+ "relative_tolerance": null
781
+ },
782
+ "status": "planned",
783
+ "title": "Neural ODE"
784
+ },
785
+ "dragHandle": ".bg-primary",
786
+ "height": 200.0,
787
+ "id": "Neural ODE 2",
788
+ "position": {
789
+ "x": 342.3226409443945,
790
+ "y": -687.1882072175634
791
+ },
792
+ "type": "basic",
793
+ "width": 200.0
794
+ }
795
+ ]
796
+ }
examples/ODE-GNN experiment ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "edges": [
3
+ {
4
+ "id": "Import CSV 1 Train/test split 1",
5
+ "source": "Import CSV 1",
6
+ "sourceHandle": "output",
7
+ "target": "Train/test split 1",
8
+ "targetHandle": "bundle"
9
+ },
10
+ {
11
+ "id": "Train/test split 1 Create graph 1",
12
+ "source": "Train/test split 1",
13
+ "sourceHandle": "output",
14
+ "target": "Create graph 1",
15
+ "targetHandle": "bundle"
16
+ },
17
+ {
18
+ "id": "Biomedical foundation graph (PLACEHOLDER) 1 Create graph 1",
19
+ "source": "Biomedical foundation graph (PLACEHOLDER) 1",
20
+ "sourceHandle": "output",
21
+ "target": "Create graph 1",
22
+ "targetHandle": "bundle"
23
+ },
24
+ {
25
+ "id": "Define model 1 Create graph 1",
26
+ "source": "Define model 1",
27
+ "sourceHandle": "output",
28
+ "target": "Create graph 1",
29
+ "targetHandle": "bundle"
30
+ },
31
+ {
32
+ "id": "Create graph 1 Train model 1",
33
+ "source": "Create graph 1",
34
+ "sourceHandle": "output",
35
+ "target": "Train model 1",
36
+ "targetHandle": "bundle"
37
+ },
38
+ {
39
+ "id": "Train model 1 Model inference 1",
40
+ "source": "Train model 1",
41
+ "sourceHandle": "output",
42
+ "target": "Model inference 1",
43
+ "targetHandle": "bundle"
44
+ }
45
+ ],
46
+ "env": "LynxKite Graph Analytics",
47
+ "nodes": [
48
+ {
49
+ "data": {
50
+ "__execution_delay": 0.0,
51
+ "collapsed": null,
52
+ "display": null,
53
+ "error": null,
54
+ "meta": {
55
+ "inputs": {},
56
+ "name": "Biomedical foundation graph (PLACEHOLDER)",
57
+ "outputs": {
58
+ "output": {
59
+ "name": "output",
60
+ "position": "right",
61
+ "type": {
62
+ "type": "None"
63
+ }
64
+ }
65
+ },
66
+ "params": {
67
+ "filter_nodes": {
68
+ "default": null,
69
+ "name": "filter_nodes",
70
+ "type": {
71
+ "type": "<class 'str'>"
72
+ }
73
+ }
74
+ },
75
+ "type": "basic"
76
+ },
77
+ "params": {
78
+ "filter_nodes": "drug,gene,disease"
79
+ },
80
+ "status": "done",
81
+ "title": "Biomedical foundation graph (PLACEHOLDER)"
82
+ },
83
+ "dragHandle": ".bg-primary",
84
+ "height": 200.0,
85
+ "id": "Biomedical foundation graph (PLACEHOLDER) 1",
86
+ "position": {
87
+ "x": 230.1082040835347,
88
+ "y": 643.2454063689602
89
+ },
90
+ "type": "basic",
91
+ "width": 200.0
92
+ },
93
+ {
94
+ "data": {
95
+ "__execution_delay": null,
96
+ "collapsed": true,
97
+ "display": null,
98
+ "error": "Missing input: bundle",
99
+ "meta": {
100
+ "inputs": {
101
+ "bundle": {
102
+ "name": "bundle",
103
+ "position": "left",
104
+ "type": {
105
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
106
+ }
107
+ }
108
+ },
109
+ "name": "Train/test split",
110
+ "outputs": {
111
+ "output": {
112
+ "name": "output",
113
+ "position": "right",
114
+ "type": {
115
+ "type": "None"
116
+ }
117
+ }
118
+ },
119
+ "params": {
120
+ "table_name": {
121
+ "default": null,
122
+ "name": "table_name",
123
+ "type": {
124
+ "type": "<class 'str'>"
125
+ }
126
+ },
127
+ "test_ratio": {
128
+ "default": 0.1,
129
+ "name": "test_ratio",
130
+ "type": {
131
+ "type": "<class 'float'>"
132
+ }
133
+ }
134
+ },
135
+ "type": "basic"
136
+ },
137
+ "params": {
138
+ "table_name": null,
139
+ "test_ratio": 0.1
140
+ },
141
+ "status": "planned",
142
+ "title": "Train/test split"
143
+ },
144
+ "dragHandle": ".bg-primary",
145
+ "height": 200.0,
146
+ "id": "Train/test split 1",
147
+ "position": {
148
+ "x": 313.3745540124723,
149
+ "y": 412.5466021460861
150
+ },
151
+ "type": "basic",
152
+ "width": 200.0
153
+ },
154
+ {
155
+ "data": {
156
+ "display": null,
157
+ "error": "[Errno 2] No such file or directory: ''",
158
+ "meta": {
159
+ "inputs": {},
160
+ "name": "Import CSV",
161
+ "outputs": {
162
+ "output": {
163
+ "name": "output",
164
+ "position": "right",
165
+ "type": {
166
+ "type": "None"
167
+ }
168
+ }
169
+ },
170
+ "params": {
171
+ "columns": {
172
+ "default": "<from file>",
173
+ "name": "columns",
174
+ "type": {
175
+ "type": "<class 'str'>"
176
+ }
177
+ },
178
+ "filename": {
179
+ "default": null,
180
+ "name": "filename",
181
+ "type": {
182
+ "type": "<class 'str'>"
183
+ }
184
+ },
185
+ "separator": {
186
+ "default": "<auto>",
187
+ "name": "separator",
188
+ "type": {
189
+ "type": "<class 'str'>"
190
+ }
191
+ }
192
+ },
193
+ "type": "basic"
194
+ },
195
+ "params": {
196
+ "columns": "<from file>",
197
+ "filename": null,
198
+ "separator": "<auto>"
199
+ },
200
+ "status": "done",
201
+ "title": "Import CSV"
202
+ },
203
+ "dragHandle": ".bg-primary",
204
+ "height": 200.0,
205
+ "id": "Import CSV 1",
206
+ "position": {
207
+ "x": -2.1743215714344757,
208
+ "y": 346.06014722935214
209
+ },
210
+ "type": "basic",
211
+ "width": 200.0
212
+ },
213
+ {
214
+ "data": {
215
+ "__execution_delay": 0.0,
216
+ "collapsed": null,
217
+ "display": null,
218
+ "error": "Missing input: bundle",
219
+ "meta": {
220
+ "inputs": {
221
+ "bundle": {
222
+ "name": "bundle",
223
+ "position": "left",
224
+ "type": {
225
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
226
+ }
227
+ }
228
+ },
229
+ "name": "Model inference",
230
+ "outputs": {
231
+ "output": {
232
+ "name": "output",
233
+ "position": "right",
234
+ "type": {
235
+ "type": "None"
236
+ }
237
+ }
238
+ },
239
+ "params": {
240
+ "model_mapping": {
241
+ "default": null,
242
+ "name": "model_mapping",
243
+ "type": {
244
+ "type": "<class 'str'>"
245
+ }
246
+ },
247
+ "model_name": {
248
+ "default": null,
249
+ "name": "model_name",
250
+ "type": {
251
+ "type": "<class 'str'>"
252
+ }
253
+ },
254
+ "save_output_as": {
255
+ "default": "prediction",
256
+ "name": "save_output_as",
257
+ "type": {
258
+ "type": "<class 'str'>"
259
+ }
260
+ }
261
+ },
262
+ "type": "basic"
263
+ },
264
+ "params": {
265
+ "model_mapping": "input: data_test",
266
+ "model_name": "model",
267
+ "save_output_as": "prediction"
268
+ },
269
+ "status": "done",
270
+ "title": "Model inference"
271
+ },
272
+ "dragHandle": ".bg-primary",
273
+ "height": 339.0,
274
+ "id": "Model inference 1",
275
+ "position": {
276
+ "x": 1736.5697434242886,
277
+ "y": 357.0743204289906
278
+ },
279
+ "type": "basic",
280
+ "width": 281.0
281
+ },
282
+ {
283
+ "data": {
284
+ "__execution_delay": null,
285
+ "collapsed": true,
286
+ "display": null,
287
+ "error": "Missing input: bundle",
288
+ "meta": {
289
+ "inputs": {
290
+ "bundle": {
291
+ "name": "bundle",
292
+ "position": "left",
293
+ "type": {
294
+ "type": "list[lynxkite_graph_analytics.core.Bundle]"
295
+ }
296
+ }
297
+ },
298
+ "name": "Organize",
299
+ "outputs": {
300
+ "output": {
301
+ "name": "output",
302
+ "position": "right",
303
+ "type": {
304
+ "type": "None"
305
+ }
306
+ }
307
+ },
308
+ "params": {
309
+ "relations": {
310
+ "default": null,
311
+ "name": "relations",
312
+ "type": {
313
+ "type": "<class 'str'>"
314
+ }
315
+ }
316
+ },
317
+ "type": "graph_creation_view"
318
+ },
319
+ "params": {
320
+ "relations": null
321
+ },
322
+ "status": "planned",
323
+ "title": "Organize"
324
+ },
325
+ "dragHandle": ".bg-primary",
326
+ "height": 322.0,
327
+ "id": "Create graph 1",
328
+ "position": {
329
+ "x": 846.6882598271658,
330
+ "y": 480.6258932907771
331
+ },
332
+ "type": "graph_creation_view",
333
+ "width": 313.0
334
+ },
335
+ {
336
+ "data": {
337
+ "__execution_delay": 0.0,
338
+ "collapsed": null,
339
+ "display": null,
340
+ "error": null,
341
+ "meta": {
342
+ "inputs": {},
343
+ "name": "Define model",
344
+ "outputs": {
345
+ "output": {
346
+ "name": "output",
347
+ "position": "right",
348
+ "type": {
349
+ "type": "None"
350
+ }
351
+ }
352
+ },
353
+ "params": {
354
+ "model_workspace": {
355
+ "default": null,
356
+ "name": "model_workspace",
357
+ "type": {
358
+ "type": "<class 'str'>"
359
+ }
360
+ },
361
+ "save_as": {
362
+ "default": "model",
363
+ "name": "save_as",
364
+ "type": {
365
+ "type": "<class 'str'>"
366
+ }
367
+ }
368
+ },
369
+ "position": {
370
+ "x": 286.0,
371
+ "y": 208.0
372
+ },
373
+ "type": "basic"
374
+ },
375
+ "params": {
376
+ "model_workspace": "ODE-GNN",
377
+ "save_as": "model"
378
+ },
379
+ "status": "done",
380
+ "title": "Define model"
381
+ },
382
+ "dragHandle": ".bg-primary",
383
+ "height": 200.0,
384
+ "id": "Define model 1",
385
+ "position": {
386
+ "x": 311.976524267066,
387
+ "y": 146.99006795914332
388
+ },
389
+ "type": "basic",
390
+ "width": 200.0
391
+ },
392
+ {
393
+ "data": {
394
+ "__execution_delay": 0.0,
395
+ "collapsed": null,
396
+ "display": null,
397
+ "error": "Missing input: bundle",
398
+ "meta": {
399
+ "inputs": {
400
+ "bundle": {
401
+ "name": "bundle",
402
+ "position": "left",
403
+ "type": {
404
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
405
+ }
406
+ }
407
+ },
408
+ "name": "Train model",
409
+ "outputs": {
410
+ "output": {
411
+ "name": "output",
412
+ "position": "right",
413
+ "type": {
414
+ "type": "None"
415
+ }
416
+ }
417
+ },
418
+ "params": {
419
+ "epochs": {
420
+ "default": 1.0,
421
+ "name": "epochs",
422
+ "type": {
423
+ "type": "<class 'int'>"
424
+ }
425
+ },
426
+ "model_mapping": {
427
+ "default": null,
428
+ "name": "model_mapping",
429
+ "type": {
430
+ "type": "<class 'str'>"
431
+ }
432
+ },
433
+ "model_name": {
434
+ "default": null,
435
+ "name": "model_name",
436
+ "type": {
437
+ "type": "<class 'str'>"
438
+ }
439
+ }
440
+ },
441
+ "position": {
442
+ "x": 995.0,
443
+ "y": 350.0
444
+ },
445
+ "type": "basic"
446
+ },
447
+ "params": {
448
+ "epochs": 1.0,
449
+ "model_mapping": "input: data_train",
450
+ "model_name": "model"
451
+ },
452
+ "status": "planned",
453
+ "title": "Train model"
454
+ },
455
+ "dragHandle": ".bg-primary",
456
+ "height": 342.0,
457
+ "id": "Train model 1",
458
+ "position": {
459
+ "x": 1358.7213662492159,
460
+ "y": 352.03096133771896
461
+ },
462
+ "type": "basic",
463
+ "width": 296.0
464
+ }
465
+ ]
466
+ }
examples/uploads/plus-one-dataset.parquet ADDED
Binary file (7.54 kB). View file
 
lynxkite-app/src/lynxkite_app/crdt.py CHANGED
@@ -90,6 +90,7 @@ last_ws_input = None
90
  def clean_input(ws_pyd):
91
  for node in ws_pyd.nodes:
92
  node.data.display = None
 
93
  node.data.error = None
94
  node.data.status = workspace.NodeStatus.done
95
  node.position.x = 0
@@ -168,9 +169,12 @@ def try_to_load_workspace(ws: pycrdt.Map, name: str):
168
  """
169
  if os.path.exists(name):
170
  ws_pyd = workspace.load(name)
171
- # We treat the display field as a black box, since it is a large
172
- # dictionary that is meant to change as a whole.
173
- crdt_update(ws, ws_pyd.model_dump(), non_collaborative_fields={"display"})
 
 
 
174
 
175
 
176
  last_known_versions = {}
 
90
  def clean_input(ws_pyd):
91
  for node in ws_pyd.nodes:
92
  node.data.display = None
93
+ node.data.input_metadata = None
94
  node.data.error = None
95
  node.data.status = workspace.NodeStatus.done
96
  node.position.x = 0
 
169
  """
170
  if os.path.exists(name):
171
  ws_pyd = workspace.load(name)
172
+ crdt_update(
173
+ ws,
174
+ ws_pyd.model_dump(),
175
+ # We treat some fields as black boxes. They are not edited on the frontend.
176
+ non_collaborative_fields={"display", "input_metadata"},
177
+ )
178
 
179
 
180
  last_known_versions = {}
lynxkite-app/tests/test_main.py CHANGED
@@ -37,6 +37,7 @@ def test_save_and_load():
37
  "type": "basic",
38
  "data": {
39
  "display": None,
 
40
  "error": "Unknown operation.",
41
  "title": "Test node",
42
  "params": {"param1": "value"},
 
37
  "type": "basic",
38
  "data": {
39
  "display": None,
40
+ "input_metadata": None,
41
  "error": "Unknown operation.",
42
  "title": "Test node",
43
  "params": {"param1": "value"},
lynxkite-app/web/playwright.config.ts CHANGED
@@ -24,7 +24,7 @@ export default defineConfig({
24
  ],
25
  webServer: {
26
  command: "cd ../../examples && lynxkite",
27
- url: "http://127.0.0.1:8000",
28
  reuseExistingServer: false,
29
  },
30
  });
 
24
  ],
25
  webServer: {
26
  command: "cd ../../examples && lynxkite",
27
+ port: 8000,
28
  reuseExistingServer: false,
29
  },
30
  });
lynxkite-app/web/src/apiTypes.ts CHANGED
@@ -5,6 +5,8 @@
5
  /* Do not modify it by hand - just update the pydantic models and then re-run the script
6
  */
7
 
 
 
8
  export interface DirectoryEntry {
9
  name: string;
10
  type: string;
@@ -40,8 +42,9 @@ export interface WorkspaceNodeData {
40
  [k: string]: unknown;
41
  };
42
  display?: unknown;
 
43
  error?: string | null;
44
- in_progress?: boolean;
45
  [k: string]: unknown;
46
  }
47
  export interface Position {
 
5
  /* Do not modify it by hand - just update the pydantic models and then re-run the script
6
  */
7
 
8
+ export type NodeStatus = "planned" | "active" | "done";
9
+
10
  export interface DirectoryEntry {
11
  name: string;
12
  type: string;
 
42
  [k: string]: unknown;
43
  };
44
  display?: unknown;
45
+ input_metadata?: unknown;
46
  error?: string | null;
47
+ status?: NodeStatus;
48
  [k: string]: unknown;
49
  }
50
  export interface Position {
lynxkite-app/web/src/index.css CHANGED
@@ -256,6 +256,14 @@ body {
256
  cursor: pointer;
257
  }
258
  }
 
 
 
 
 
 
 
 
259
  }
260
 
261
  .params-expander {
 
256
  cursor: pointer;
257
  }
258
  }
259
+
260
+ .model-mapping-param {
261
+ border: 1px solid var(--fallback-bc, oklch(var(--bc) / 0.2));
262
+ border-collapse: separate;
263
+ border-radius: 5px;
264
+ padding: 5px 10px;
265
+ width: 100%;
266
+ }
267
  }
268
 
269
  .params-expander {
lynxkite-app/web/src/workspace/Workspace.tsx CHANGED
@@ -383,6 +383,8 @@ function LynxKiteFlow() {
383
  proOptions={{ hideAttribution: true }}
384
  maxZoom={1}
385
  minZoom={0.3}
 
 
386
  defaultEdgeOptions={{
387
  markerEnd: {
388
  type: MarkerType.ArrowClosed,
 
383
  proOptions={{ hideAttribution: true }}
384
  maxZoom={1}
385
  minZoom={0.3}
386
+ zoomOnScroll={false}
387
+ preventScrolling={false}
388
  defaultEdgeOptions={{
389
  markerEnd: {
390
  type: MarkerType.ArrowClosed,
lynxkite-app/web/src/workspace/nodes/NodeGroupParameter.tsx CHANGED
@@ -24,6 +24,7 @@ interface GroupsType {
24
  interface NodeGroupParameterProps {
25
  meta: { selector: SelectorType; groups: GroupsType };
26
  value: any;
 
27
  setParam: (name: string, value: any, options?: { delay: number }) => void;
28
  deleteParam: (name: string, options?: { delay: number }) => void;
29
  }
@@ -31,6 +32,7 @@ interface NodeGroupParameterProps {
31
  export default function NodeGroupParameter({
32
  meta,
33
  value,
 
34
  setParam,
35
  deleteParam,
36
  }: NodeGroupParameterProps) {
@@ -60,6 +62,7 @@ export default function NodeGroupParameter({
60
  name={selector.name}
61
  key={selector.name}
62
  value={selectedValue}
 
63
  meta={selector}
64
  onChange={handleSelectorChange}
65
  />
 
24
  interface NodeGroupParameterProps {
25
  meta: { selector: SelectorType; groups: GroupsType };
26
  value: any;
27
+ data: any;
28
  setParam: (name: string, value: any, options?: { delay: number }) => void;
29
  deleteParam: (name: string, options?: { delay: number }) => void;
30
  }
 
32
  export default function NodeGroupParameter({
33
  meta,
34
  value,
35
+ data,
36
  setParam,
37
  deleteParam,
38
  }: NodeGroupParameterProps) {
 
62
  name={selector.name}
63
  key={selector.name}
64
  value={selectedValue}
65
+ data={data}
66
  meta={selector}
67
  onChange={handleSelectorChange}
68
  />
lynxkite-app/web/src/workspace/nodes/NodeParameter.tsx CHANGED
@@ -1,15 +1,187 @@
1
- const BOOLEAN = "<class 'bool'>";
 
2
 
 
 
 
 
 
 
 
3
  function ParamName({ name }: { name: string }) {
4
  return (
5
  <span className="param-name bg-base-200">{name.replace(/_/g, " ")}</span>
6
  );
7
  }
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  interface NodeParameterProps {
10
  name: string;
11
  value: any;
12
  meta: any;
 
13
  onChange: (value: any, options?: { delay: number }) => void;
14
  }
15
 
@@ -17,6 +189,7 @@ export default function NodeParameter({
17
  name,
18
  value,
19
  meta,
 
20
  onChange,
21
  }: NodeParameterProps) {
22
  return (
@@ -56,29 +229,50 @@ export default function NodeParameter({
56
  ) : meta?.type?.type === BOOLEAN ? (
57
  <div className="form-control">
58
  <label className="label cursor-pointer">
 
59
  <input
60
  className="checkbox"
61
  type="checkbox"
62
  checked={value}
63
  onChange={(evt) => onChange(evt.currentTarget.checked)}
64
  />
65
- {name.replace(/_/g, " ")}
66
  </label>
67
  </div>
68
- ) : (
69
  <>
70
  <ParamName name={name} />
71
- <input
72
- className="input input-bordered w-full"
73
- value={value || ""}
74
- onChange={(evt) => onChange(evt.currentTarget.value, { delay: 2 })}
75
- onBlur={(evt) => onChange(evt.currentTarget.value, { delay: 0 })}
76
- onKeyDown={(evt) =>
77
- evt.code === "Enter" &&
78
- onChange(evt.currentTarget.value, { delay: 0 })
79
- }
 
 
 
 
 
 
80
  />
81
  </>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  )}
83
  </label>
84
  );
 
1
+ // @ts-ignore
2
+ import ArrowsHorizontal from "~icons/tabler/arrows-horizontal.jsx";
3
 
4
+ const BOOLEAN = "<class 'bool'>";
5
+ const MODEL_TRAINING_INPUT_MAPPING =
6
+ "<class 'lynxkite_graph_analytics.lynxkite_ops.ModelTrainingInputMapping'>";
7
+ const MODEL_INFERENCE_INPUT_MAPPING =
8
+ "<class 'lynxkite_graph_analytics.lynxkite_ops.ModelInferenceInputMapping'>";
9
+ const MODEL_OUTPUT_MAPPING =
10
+ "<class 'lynxkite_graph_analytics.lynxkite_ops.ModelOutputMapping'>";
11
  function ParamName({ name }: { name: string }) {
12
  return (
13
  <span className="param-name bg-base-200">{name.replace(/_/g, " ")}</span>
14
  );
15
  }
16
 
17
+ function Input({
18
+ value,
19
+ onChange,
20
+ }: {
21
+ value: string;
22
+ onChange: (value: string, options?: { delay: number }) => void;
23
+ }) {
24
+ return (
25
+ <input
26
+ className="input input-bordered w-full"
27
+ value={value || ""}
28
+ onChange={(evt) => onChange(evt.currentTarget.value, { delay: 2 })}
29
+ onBlur={(evt) => onChange(evt.currentTarget.value, { delay: 0 })}
30
+ onKeyDown={(evt) =>
31
+ evt.code === "Enter" && onChange(evt.currentTarget.value, { delay: 0 })
32
+ }
33
+ />
34
+ );
35
+ }
36
+
37
+ function getModelBindings(
38
+ data: any,
39
+ variant: "training input" | "inference input" | "output",
40
+ ): string[] {
41
+ function bindingsOfModel(m: any): string[] {
42
+ switch (variant) {
43
+ case "training input":
44
+ return [
45
+ ...m.inputs,
46
+ ...m.loss_inputs.filter((i: string) => !m.outputs.includes(i)),
47
+ ];
48
+ case "inference input":
49
+ return m.inputs;
50
+ case "output":
51
+ return m.outputs;
52
+ }
53
+ }
54
+ const bindings = new Set<string>();
55
+ const inputs = data?.input_metadata?.value ?? data?.input_metadata ?? [];
56
+ for (const input of inputs) {
57
+ const other = input.other ?? {};
58
+ for (const e of Object.values(other) as any[]) {
59
+ if (e.type === "model") {
60
+ for (const b of bindingsOfModel(e.model)) {
61
+ bindings.add(b);
62
+ }
63
+ }
64
+ }
65
+ }
66
+ const list = [...bindings];
67
+ list.sort();
68
+ return list;
69
+ }
70
+
71
+ function parseJsonOrEmpty(json: string): object {
72
+ try {
73
+ const j = JSON.parse(json);
74
+ if (j !== null && typeof j === "object") {
75
+ return j;
76
+ }
77
+ } catch (e) {}
78
+ return {};
79
+ }
80
+
81
+ function ModelMapping({ value, onChange, data, variant }: any) {
82
+ const v: any = parseJsonOrEmpty(value);
83
+ v.map ??= {};
84
+ const dfs: { [df: string]: string[] } = {};
85
+ const inputs = data?.input_metadata?.value ?? data?.input_metadata ?? [];
86
+ for (const input of inputs) {
87
+ const dataframes = input.dataframes as {
88
+ [df: string]: { columns: string[] };
89
+ };
90
+ for (const [df, { columns }] of Object.entries(dataframes)) {
91
+ dfs[df] = columns;
92
+ }
93
+ }
94
+ const bindings = getModelBindings(data, variant);
95
+ return (
96
+ <table className="model-mapping-param">
97
+ <tbody>
98
+ {bindings.length > 0 ? (
99
+ bindings.map((binding: string) => (
100
+ <tr key={binding}>
101
+ <td>{binding}</td>
102
+ <td>
103
+ <ArrowsHorizontal />
104
+ </td>
105
+ <td>
106
+ <select
107
+ className="select select-ghost"
108
+ value={v.map?.[binding]?.df}
109
+ onChange={(evt) => {
110
+ const df = evt.currentTarget.value;
111
+ if (df === "") {
112
+ const map = { ...v.map, [binding]: undefined };
113
+ onChange(JSON.stringify({ map }));
114
+ } else {
115
+ const columnSpec = {
116
+ column: dfs[df][0],
117
+ ...(v.map?.[binding] || {}),
118
+ df,
119
+ };
120
+ const map = { ...v.map, [binding]: columnSpec };
121
+ onChange(JSON.stringify({ map }));
122
+ }
123
+ }}
124
+ >
125
+ <option key="" value="" />
126
+ {Object.keys(dfs).map((df: string) => (
127
+ <option key={df} value={df}>
128
+ {df}
129
+ </option>
130
+ ))}
131
+ </select>
132
+ </td>
133
+ <td>
134
+ {variant === "output" ? (
135
+ <Input
136
+ value={v.map?.[binding]?.column}
137
+ onChange={(column, options) => {
138
+ const columnSpec = {
139
+ ...(v.map?.[binding] || {}),
140
+ column,
141
+ };
142
+ const map = { ...v.map, [binding]: columnSpec };
143
+ onChange(JSON.stringify({ map }), options);
144
+ }}
145
+ />
146
+ ) : (
147
+ <select
148
+ className="select select-ghost"
149
+ value={v.map?.[binding]?.column}
150
+ onChange={(evt) => {
151
+ const column = evt.currentTarget.value;
152
+ const columnSpec = {
153
+ ...(v.map?.[binding] || {}),
154
+ column,
155
+ };
156
+ const map = { ...v.map, [binding]: columnSpec };
157
+ onChange(JSON.stringify({ map }));
158
+ }}
159
+ >
160
+ {dfs[v.map?.[binding]?.df]?.map((col: string) => (
161
+ <option key={col} value={col}>
162
+ {col}
163
+ </option>
164
+ ))}
165
+ </select>
166
+ )}
167
+ </td>
168
+ </tr>
169
+ ))
170
+ ) : (
171
+ <tr>
172
+ <td>no bindings</td>
173
+ </tr>
174
+ )}
175
+ </tbody>
176
+ </table>
177
+ );
178
+ }
179
+
180
  interface NodeParameterProps {
181
  name: string;
182
  value: any;
183
  meta: any;
184
+ data: any;
185
  onChange: (value: any, options?: { delay: number }) => void;
186
  }
187
 
 
189
  name,
190
  value,
191
  meta,
192
+ data,
193
  onChange,
194
  }: NodeParameterProps) {
195
  return (
 
229
  ) : meta?.type?.type === BOOLEAN ? (
230
  <div className="form-control">
231
  <label className="label cursor-pointer">
232
+ {name.replace(/_/g, " ")}
233
  <input
234
  className="checkbox"
235
  type="checkbox"
236
  checked={value}
237
  onChange={(evt) => onChange(evt.currentTarget.checked)}
238
  />
 
239
  </label>
240
  </div>
241
+ ) : meta?.type?.type === MODEL_TRAINING_INPUT_MAPPING ? (
242
  <>
243
  <ParamName name={name} />
244
+ <ModelMapping
245
+ value={value}
246
+ data={data}
247
+ variant="training input"
248
+ onChange={onChange}
249
+ />
250
+ </>
251
+ ) : meta?.type?.type === MODEL_INFERENCE_INPUT_MAPPING ? (
252
+ <>
253
+ <ParamName name={name} />
254
+ <ModelMapping
255
+ value={value}
256
+ data={data}
257
+ variant="inference input"
258
+ onChange={onChange}
259
  />
260
  </>
261
+ ) : meta?.type?.type === MODEL_OUTPUT_MAPPING ? (
262
+ <>
263
+ <ParamName name={name} />
264
+ <ModelMapping
265
+ value={value}
266
+ data={data}
267
+ variant="output"
268
+ onChange={onChange}
269
+ />
270
+ </>
271
+ ) : (
272
+ <>
273
+ <ParamName name={name} />
274
+ <Input value={value} onChange={onChange} />
275
+ </>
276
  )}
277
  </label>
278
  );
lynxkite-app/web/src/workspace/nodes/NodeWithParams.tsx CHANGED
@@ -49,6 +49,7 @@ function NodeWithParams(props: any) {
49
  <NodeGroupParameter
50
  key={name}
51
  value={value}
 
52
  meta={metaParams?.[name]}
53
  setParam={(name: string, value: any, opts?: UpdateOptions) =>
54
  setParam(name, value, opts || {})
@@ -62,6 +63,7 @@ function NodeWithParams(props: any) {
62
  name={name}
63
  key={name}
64
  value={value}
 
65
  meta={metaParams?.[name]}
66
  onChange={(value: any, opts?: UpdateOptions) =>
67
  setParam(name, value, opts || {})
 
49
  <NodeGroupParameter
50
  key={name}
51
  value={value}
52
+ data={props.data}
53
  meta={metaParams?.[name]}
54
  setParam={(name: string, value: any, opts?: UpdateOptions) =>
55
  setParam(name, value, opts || {})
 
63
  name={name}
64
  key={name}
65
  value={value}
66
+ data={props.data}
67
  meta={metaParams?.[name]}
68
  onChange={(value: any, opts?: UpdateOptions) =>
69
  setParam(name, value, opts || {})
lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -61,7 +61,7 @@ class Parameter(BaseConfig):
61
  @staticmethod
62
  def options(name, options, default=None):
63
  e = enum.Enum(f"OptionsFor_{name}", options)
64
- return Parameter.basic(name, e[default or options[0]], e)
65
 
66
  @staticmethod
67
  def collapsed(name, default, type=None):
@@ -106,11 +106,13 @@ class Result:
106
  The `output` attribute is what will be used as input for other operations.
107
  The `display` attribute is used to send data to display on the UI. The value has to be
108
  JSON-serializable.
 
109
  """
110
 
111
  output: typing.Any = None
112
  display: ReadOnlyJSON | None = None
113
  error: str | None = None
 
114
 
115
 
116
  MULTI_INPUT = Input(name="multi", type="*")
@@ -140,6 +142,11 @@ def _param_to_type(name, value, type):
140
  return None if value == "" else _param_to_type(name, value, type)
141
  case (type, types.NoneType):
142
  return None if value == "" else _param_to_type(name, value, type)
 
 
 
 
 
143
  return value
144
 
145
 
@@ -154,9 +161,7 @@ class Op(BaseConfig):
154
 
155
  def __call__(self, *inputs, **params):
156
  # Convert parameters.
157
- for p in params:
158
- if p in self.params:
159
- params[p] = _param_to_type(p, params[p], self.params[p].type)
160
  res = self.func(*inputs, **params)
161
  if not isinstance(res, Result):
162
  # Automatically wrap the result in a Result object, if it isn't already.
@@ -172,6 +177,16 @@ class Op(BaseConfig):
172
  res.display = res.output
173
  return res
174
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  def op(env: str, name: str, *, view="basic", outputs=None, params=None):
177
  """Decorator for defining an operation."""
 
61
  @staticmethod
62
  def options(name, options, default=None):
63
  e = enum.Enum(f"OptionsFor_{name}", options)
64
+ return Parameter.basic(name, default or options[0], e)
65
 
66
  @staticmethod
67
  def collapsed(name, default, type=None):
 
106
  The `output` attribute is what will be used as input for other operations.
107
  The `display` attribute is used to send data to display on the UI. The value has to be
108
  JSON-serializable.
109
+ `input_metadata` is a list of JSON objects describing each input.
110
  """
111
 
112
  output: typing.Any = None
113
  display: ReadOnlyJSON | None = None
114
  error: str | None = None
115
+ input_metadata: ReadOnlyJSON | None = None
116
 
117
 
118
  MULTI_INPUT = Input(name="multi", type="*")
 
142
  return None if value == "" else _param_to_type(name, value, type)
143
  case (type, types.NoneType):
144
  return None if value == "" else _param_to_type(name, value, type)
145
+ if isinstance(type, typeof) and issubclass(type, pydantic.BaseModel):
146
+ try:
147
+ return type.model_validate_json(value)
148
+ except pydantic.ValidationError:
149
+ return None
150
  return value
151
 
152
 
 
161
 
162
  def __call__(self, *inputs, **params):
163
  # Convert parameters.
164
+ params = self.convert_params(params)
 
 
165
  res = self.func(*inputs, **params)
166
  if not isinstance(res, Result):
167
  # Automatically wrap the result in a Result object, if it isn't already.
 
177
  res.display = res.output
178
  return res
179
 
180
+ def convert_params(self, params):
181
+ """Returns the parameters converted to the expected type."""
182
+ res = {}
183
+ for p in params:
184
+ if p in self.params:
185
+ res[p] = _param_to_type(p, params[p], self.params[p].type)
186
+ else:
187
+ res[p] = params[p]
188
+ return res
189
+
190
 
191
  def op(env: str, name: str, *, view="basic", outputs=None, params=None):
192
  """Decorator for defining an operation."""
lynxkite-core/src/lynxkite/core/workspace.py CHANGED
@@ -32,6 +32,7 @@ class WorkspaceNodeData(BaseConfig):
32
  title: str
33
  params: dict
34
  display: Optional[object] = None
 
35
  error: Optional[str] = None
36
  status: NodeStatus = NodeStatus.done
37
  # Also contains a "meta" field when going out.
@@ -59,12 +60,14 @@ class WorkspaceNode(BaseConfig):
59
  def publish_result(self, result: ops.Result):
60
  """Sends the result to the frontend. Call this in an executor when the result is available."""
61
  self.data.display = result.display
 
62
  self.data.error = result.error
63
  self.data.status = NodeStatus.done
64
  if hasattr(self, "_crdt"):
65
  with self._crdt.doc.transaction():
66
- self._crdt["data"]["display"] = result.display
67
- self._crdt["data"]["error"] = result.error
 
68
  self._crdt["data"]["status"] = NodeStatus.done
69
 
70
  def publish_error(self, error: Exception | str | None):
 
32
  title: str
33
  params: dict
34
  display: Optional[object] = None
35
+ input_metadata: Optional[object] = None
36
  error: Optional[str] = None
37
  status: NodeStatus = NodeStatus.done
38
  # Also contains a "meta" field when going out.
 
60
  def publish_result(self, result: ops.Result):
61
  """Sends the result to the frontend. Call this in an executor when the result is available."""
62
  self.data.display = result.display
63
+ self.data.input_metadata = result.input_metadata
64
  self.data.error = result.error
65
  self.data.status = NodeStatus.done
66
  if hasattr(self, "_crdt"):
67
  with self._crdt.doc.transaction():
68
+ self._crdt["data"]["display"] = self.data.display
69
+ self._crdt["data"]["input_metadata"] = self.data.input_metadata
70
+ self._crdt["data"]["error"] = self.data.error
71
  self._crdt["data"]["status"] = NodeStatus.done
72
 
73
  def publish_error(self, error: Exception | str | None):
lynxkite-graph-analytics/pyproject.toml CHANGED
@@ -14,6 +14,8 @@ dependencies = [
14
  "osmnx>=2.0.1",
15
  "pandas>=2.2.3",
16
  "polars[gpu]>=1.14.0",
 
 
17
  ]
18
 
19
  [project.optional-dependencies]
 
14
  "osmnx>=2.0.1",
15
  "pandas>=2.2.3",
16
  "polars[gpu]>=1.14.0",
17
+ "torch>=2.6.0",
18
+ "torch-geometric>=2.6.1",
19
  ]
20
 
21
  [project.optional-dependencies]
lynxkite-graph-analytics/src/lynxkite_graph_analytics/core.py CHANGED
@@ -42,7 +42,7 @@ class Bundle:
42
 
43
  dfs: dict[str, pd.DataFrame] = dataclasses.field(default_factory=dict)
44
  relations: list[RelationDefinition] = dataclasses.field(default_factory=list)
45
- other: dict[str, typing.Any] = None
46
 
47
  @classmethod
48
  def from_nx(cls, graph: nx.Graph):
@@ -102,10 +102,11 @@ class Bundle:
102
  return Bundle(
103
  dfs=dict(self.dfs),
104
  relations=list(self.relations),
105
- other=dict(self.other) if self.other else None,
106
  )
107
 
108
  def to_dict(self, limit: int = 100):
 
109
  return {
110
  "dataframes": {
111
  name: {
@@ -115,7 +116,22 @@ class Bundle:
115
  for name, df in self.dfs.items()
116
  },
117
  "relations": [dataclasses.asdict(relation) for relation in self.relations],
118
- "other": self.other,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  }
120
 
121
 
@@ -137,8 +153,15 @@ def nx_node_attribute_func(name):
137
 
138
  def disambiguate_edges(ws: workspace.Workspace):
139
  """If an input plug is connected to multiple edges, keep only the last edge."""
 
 
140
  seen = set()
141
  for edge in reversed(ws.edges):
 
 
 
 
 
142
  if (edge.target, edge.targetHandle) in seen:
143
  i = ws.edges.index(edge)
144
  del ws.edges[i]
@@ -174,13 +197,14 @@ def _execute_node(node, ws, catalog, outputs):
174
  node.publish_error("Operation not found in catalog")
175
  return
176
  node.publish_started()
 
177
  input_map = {
178
  edge.targetHandle: outputs[edge.source]
179
  for edge in ws.edges
180
  if edge.target == node.id
181
  }
 
182
  try:
183
- # Convert inputs types to match operation signature.
184
  inputs = []
185
  for p in op.inputs.values():
186
  if p.name not in input_map:
@@ -194,16 +218,30 @@ def _execute_node(node, ws, catalog, outputs):
194
  elif p.type == Bundle and isinstance(x, pd.DataFrame):
195
  x = Bundle.from_df(x)
196
  inputs.append(x)
197
- result = op(*inputs, **params)
198
  except Exception as e:
199
  if os.environ.get("LYNXKITE_LOG_OP_ERRORS"):
200
  traceback.print_exc()
201
  node.publish_error(e)
202
  return
203
- outputs[node.id] = result.output
 
 
 
 
 
 
 
 
 
204
  node.publish_result(result)
205
 
206
 
 
 
 
 
 
 
207
  def df_for_frontend(df: pd.DataFrame, limit: int) -> pd.DataFrame:
208
  """Returns a DataFrame with values that are safe to send to the frontend."""
209
  df = df[:limit]
 
42
 
43
  dfs: dict[str, pd.DataFrame] = dataclasses.field(default_factory=dict)
44
  relations: list[RelationDefinition] = dataclasses.field(default_factory=list)
45
+ other: dict[str, typing.Any] = dataclasses.field(default_factory=dict)
46
 
47
  @classmethod
48
  def from_nx(cls, graph: nx.Graph):
 
102
  return Bundle(
103
  dfs=dict(self.dfs),
104
  relations=list(self.relations),
105
+ other=dict(self.other),
106
  )
107
 
108
  def to_dict(self, limit: int = 100):
109
+ """JSON-serializable representation of the bundle, including some data."""
110
  return {
111
  "dataframes": {
112
  name: {
 
116
  for name, df in self.dfs.items()
117
  },
118
  "relations": [dataclasses.asdict(relation) for relation in self.relations],
119
+ "other": {k: str(v) for k, v in self.other.items()},
120
+ }
121
+
122
+ def metadata(self):
123
+ """JSON-serializable information about the bundle, metadata only."""
124
+ return {
125
+ "dataframes": {
126
+ name: {
127
+ "columns": sorted(str(c) for c in df.columns),
128
+ }
129
+ for name, df in self.dfs.items()
130
+ },
131
+ "relations": [dataclasses.asdict(relation) for relation in self.relations],
132
+ "other": {
133
+ k: getattr(v, "metadata", lambda: {})() for k, v in self.other.items()
134
+ },
135
  }
136
 
137
 
 
153
 
154
  def disambiguate_edges(ws: workspace.Workspace):
155
  """If an input plug is connected to multiple edges, keep only the last edge."""
156
+ catalog = ops.CATALOGS[ws.env]
157
+ nodes = {node.id: node for node in ws.nodes}
158
  seen = set()
159
  for edge in reversed(ws.edges):
160
+ dst_node = nodes[edge.target]
161
+ op = catalog.get(dst_node.data.title)
162
+ if op.inputs[edge.targetHandle].type == list[Bundle]:
163
+ # Takes multiple bundles as an input. No need to disambiguate.
164
+ continue
165
  if (edge.target, edge.targetHandle) in seen:
166
  i = ws.edges.index(edge)
167
  del ws.edges[i]
 
197
  node.publish_error("Operation not found in catalog")
198
  return
199
  node.publish_started()
200
+ # TODO: Handle multi-inputs.
201
  input_map = {
202
  edge.targetHandle: outputs[edge.source]
203
  for edge in ws.edges
204
  if edge.target == node.id
205
  }
206
+ # Convert inputs types to match operation signature.
207
  try:
 
208
  inputs = []
209
  for p in op.inputs.values():
210
  if p.name not in input_map:
 
218
  elif p.type == Bundle and isinstance(x, pd.DataFrame):
219
  x = Bundle.from_df(x)
220
  inputs.append(x)
 
221
  except Exception as e:
222
  if os.environ.get("LYNXKITE_LOG_OP_ERRORS"):
223
  traceback.print_exc()
224
  node.publish_error(e)
225
  return
226
+ # Execute op.
227
+ try:
228
+ result = op(*inputs, **params)
229
+ except Exception as e:
230
+ if os.environ.get("LYNXKITE_LOG_OP_ERRORS"):
231
+ traceback.print_exc()
232
+ result = ops.Result(error=str(e))
233
+ result.input_metadata = [_get_metadata(i) for i in inputs]
234
+ if result.output is not None:
235
+ outputs[node.id] = result.output
236
  node.publish_result(result)
237
 
238
 
239
+ def _get_metadata(x):
240
+ if hasattr(x, "metadata"):
241
+ return x.metadata()
242
+ return {}
243
+
244
+
245
  def df_for_frontend(df: pd.DataFrame, limit: int) -> pd.DataFrame:
246
  """Returns a DataFrame with values that are safe to send to the frontend."""
247
  df = df[:limit]
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py CHANGED
@@ -2,10 +2,14 @@
2
 
3
  import enum
4
  import os
 
5
  import fsspec
6
  from lynxkite.core import ops
7
  from collections import deque
8
- from . import core
 
 
 
9
  import grandcypher
10
  import joblib
11
  import matplotlib
@@ -165,11 +169,11 @@ def cypher(bundle: core.Bundle, *, query: ops.LongStr, save_as: str = "result"):
165
  return bundle
166
 
167
 
168
- @op("Organize bundle")
169
- def organize_bundle(bundle: core.Bundle, *, code: ops.LongStr):
170
  """Lets you rename/copy/delete DataFrames, and modify relations.
171
 
172
- TODO: Use a declarative solution instead of Python code. Add UI.
173
  """
174
  bundle = bundle.copy()
175
  exec(code, globals(), {"bundle": bundle})
@@ -342,3 +346,108 @@ def create_graph(bundle: core.Bundle, *, relations: str = None) -> core.Bundle:
342
  core.RelationDefinition(**r) for r in json.loads(relations).values()
343
  ]
344
  return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import enum
4
  import os
5
+ import pathlib
6
  import fsspec
7
  from lynxkite.core import ops
8
  from collections import deque
9
+
10
+ from tqdm import tqdm
11
+ from . import core, pytorch_model_ops
12
+ from lynxkite.core import workspace
13
  import grandcypher
14
  import joblib
15
  import matplotlib
 
169
  return bundle
170
 
171
 
172
+ @op("Organize")
173
+ def organize(bundle: list[core.Bundle], *, code: ops.LongStr) -> core.Bundle:
174
  """Lets you rename/copy/delete DataFrames, and modify relations.
175
 
176
+ TODO: Merge this with "Create graph".
177
  """
178
  bundle = bundle.copy()
179
  exec(code, globals(), {"bundle": bundle})
 
346
  core.RelationDefinition(**r) for r in json.loads(relations).values()
347
  ]
348
  return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
349
+
350
+
351
+ def load_ws(model_workspace: str):
352
+ cwd = pathlib.Path()
353
+ path = cwd / model_workspace
354
+ assert path.is_relative_to(cwd)
355
+ assert path.exists(), f"Workspace {path} does not exist"
356
+ ws = workspace.load(path)
357
+ return ws
358
+
359
+
360
+ @op("Biomedical foundation graph (PLACEHOLDER)")
361
+ def biomedical_foundation_graph(*, filter_nodes: str):
362
+ """Loads the gigantic Lynx-maintained knowledge graph. Includes drugs, diseases, genes, proteins, etc."""
363
+ return None
364
+
365
+
366
+ @op("Define model")
367
+ def define_model(
368
+ bundle: core.Bundle,
369
+ *,
370
+ model_workspace: str,
371
+ save_as: str = "model",
372
+ ):
373
+ """Trains the selected model on the selected dataset. Most training parameters are set in the model definition."""
374
+ assert model_workspace, "Model workspace is unset."
375
+ ws = load_ws(model_workspace)
376
+ # Build the model without inputs, to get its interface.
377
+ m = pytorch_model_ops.build_model(ws, {})
378
+ m.source_workspace = model_workspace
379
+ bundle = bundle.copy()
380
+ bundle.other[save_as] = m
381
+ return bundle
382
+
383
+
384
+ # These contain the same mapping, but they get different UIs.
385
+ # For inputs, you select existing columns. For outputs, you can create new columns.
386
+ class ModelInferenceInputMapping(pytorch_model_ops.ModelMapping):
387
+ pass
388
+
389
+
390
+ class ModelTrainingInputMapping(pytorch_model_ops.ModelMapping):
391
+ pass
392
+
393
+
394
+ class ModelOutputMapping(pytorch_model_ops.ModelMapping):
395
+ pass
396
+
397
+
398
+ @op("Train model")
399
+ def train_model(
400
+ bundle: core.Bundle,
401
+ *,
402
+ model_name: str = "model",
403
+ input_mapping: ModelTrainingInputMapping,
404
+ epochs: int = 1,
405
+ ):
406
+ """Trains the selected model on the selected dataset. Most training parameters are set in the model definition."""
407
+ m = bundle.other[model_name].copy()
408
+ inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
409
+ if not m.trained and m.source_workspace:
410
+ # Rebuild the model for the correct inputs.
411
+ ws = load_ws(m.source_workspace)
412
+ m = pytorch_model_ops.build_model(ws, inputs)
413
+ t = tqdm(range(epochs), desc="Training model")
414
+ for _ in t:
415
+ loss = m.train(inputs)
416
+ t.set_postfix({"loss": loss})
417
+ m.trained = True
418
+ bundle = bundle.copy()
419
+ bundle.other[model_name] = m
420
+ return bundle
421
+
422
+
423
+ @op("Model inference")
424
+ def model_inference(
425
+ bundle: core.Bundle,
426
+ *,
427
+ model_name: str = "model",
428
+ input_mapping: ModelInferenceInputMapping,
429
+ output_mapping: ModelOutputMapping,
430
+ ):
431
+ """Executes a trained model."""
432
+ if input_mapping is None or output_mapping is None:
433
+ return ops.Result(bundle, error="Mapping is unset.")
434
+ m = bundle.other[model_name]
435
+ assert m.trained, "The model is not trained."
436
+ inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
437
+ outputs = m.inference(inputs)
438
+ bundle = bundle.copy()
439
+ for k, v in output_mapping.map.items():
440
+ bundle.dfs[v.df][v.column] = outputs[k].detach().numpy().tolist()
441
+ return bundle
442
+
443
+
444
+ @op("Train/test split")
445
+ def train_test_split(bundle: core.Bundle, *, table_name: str, test_ratio: float = 0.1):
446
+ """Splits a dataframe in the bundle into separate "_train" and "_test" dataframes."""
447
+ df = bundle.dfs[table_name]
448
+ test = df.sample(frac=test_ratio)
449
+ train = df.drop(test.index)
450
+ bundle = bundle.copy()
451
+ bundle.dfs[f"{table_name}_train"] = train
452
+ bundle.dfs[f"{table_name}_test"] = test
453
+ return bundle
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -1,7 +1,16 @@
1
  """Boxes for defining PyTorch models."""
2
 
3
- from lynxkite.core import ops
 
 
 
 
 
4
  from lynxkite.core.ops import Parameter as P
 
 
 
 
5
 
6
  ENV = "PyTorch model"
7
 
@@ -22,16 +31,43 @@ def reg(name, inputs=[], outputs=None, params=[]):
22
  )
23
 
24
 
25
- reg("Input: features", outputs=["x"])
26
  reg("Input: graph edges", outputs=["edges"])
27
  reg("Input: label", outputs=["y"])
28
  reg("Input: positive sample", outputs=["x_pos"])
29
  reg("Input: negative sample", outputs=["x_neg"])
 
 
30
 
31
- reg("Attention", inputs=["q", "k", "v"], outputs=["x"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  reg("LayerNorm", inputs=["x"])
33
  reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
34
  reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")])
 
35
  reg(
36
  "Graph conv",
37
  inputs=["x", "edges"],
@@ -41,10 +77,15 @@ reg(
41
  reg(
42
  "Activation",
43
  inputs=["x"],
44
- params=[P.options("type", ["ReLU", "LeakyReLU", "Tanh", "Mish"])],
45
  )
46
- reg("Supervised loss", inputs=["x", "y"], outputs=["loss"])
47
- reg("Triplet loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
 
 
 
 
 
48
  reg(
49
  "Optimizer",
50
  inputs=["loss"],
@@ -71,5 +112,211 @@ ops.register_passive_op(
71
  "Repeat",
72
  inputs=[ops.Input(name="input", position="top", type="tensor")],
73
  outputs=[ops.Output(name="output", position="bottom", type="tensor")],
74
- params=[ops.Parameter.basic("times", 1, int)],
 
 
 
 
 
 
 
 
 
 
 
75
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Boxes for defining PyTorch models."""
2
 
3
+ import copy
4
+ import graphlib
5
+ import types
6
+
7
+ import pydantic
8
+ from lynxkite.core import ops, workspace
9
  from lynxkite.core.ops import Parameter as P
10
+ import torch
11
+ import torch_geometric as pyg
12
+ import dataclasses
13
+ from . import core
14
 
15
  ENV = "PyTorch model"
16
 
 
31
  )
32
 
33
 
34
+ reg("Input: embedding", outputs=["x"])
35
  reg("Input: graph edges", outputs=["edges"])
36
  reg("Input: label", outputs=["y"])
37
  reg("Input: positive sample", outputs=["x_pos"])
38
  reg("Input: negative sample", outputs=["x_neg"])
39
+ reg("Input: sequential", outputs=["y"])
40
+ reg("Input: zeros", outputs=["x"])
41
 
42
+ reg("LSTM", inputs=["x", "h"], outputs=["x", "h"])
43
+ reg(
44
+ "Neural ODE",
45
+ inputs=["x"],
46
+ params=[
47
+ P.basic("relative_tolerance"),
48
+ P.basic("absolute_tolerance"),
49
+ P.options(
50
+ "method",
51
+ [
52
+ "dopri8",
53
+ "dopri5",
54
+ "bosh3",
55
+ "fehlberg2",
56
+ "adaptive_heun",
57
+ "euler",
58
+ "midpoint",
59
+ "rk4",
60
+ "explicit_adams",
61
+ "implicit_adams",
62
+ ],
63
+ ),
64
+ ],
65
+ )
66
+ reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
67
  reg("LayerNorm", inputs=["x"])
68
  reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
69
  reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")])
70
+ reg("Softmax", inputs=["x"])
71
  reg(
72
  "Graph conv",
73
  inputs=["x", "edges"],
 
77
  reg(
78
  "Activation",
79
  inputs=["x"],
80
+ params=[P.options("type", ["ReLU", "Leaky ReLU", "Tanh", "Mish"])],
81
  )
82
+ reg("Concatenate", inputs=["a", "b"], outputs=["x"])
83
+ reg("Add", inputs=["a", "b"], outputs=["x"])
84
+ reg("Subtract", inputs=["a", "b"], outputs=["x"])
85
+ reg("Multiply", inputs=["a", "b"], outputs=["x"])
86
+ reg("MSE loss", inputs=["x", "y"], outputs=["loss"])
87
+ reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
88
+ reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
89
  reg(
90
  "Optimizer",
91
  inputs=["loss"],
 
112
  "Repeat",
113
  inputs=[ops.Input(name="input", position="top", type="tensor")],
114
  outputs=[ops.Output(name="output", position="bottom", type="tensor")],
115
+ params=[
116
+ ops.Parameter.basic("times", 1, int),
117
+ ops.Parameter.basic("same_weights", True, bool),
118
+ ],
119
+ )
120
+
121
+ ops.register_passive_op(
122
+ ENV,
123
+ "Recurrent chain",
124
+ inputs=[ops.Input(name="input", position="top", type="tensor")],
125
+ outputs=[ops.Output(name="output", position="bottom", type="tensor")],
126
+ params=[],
127
  )
128
+
129
+
130
+ def _to_id(*strings: str) -> str:
131
+ """Replaces all non-alphanumeric characters with underscores."""
132
+ return "_".join("".join(c if c.isalnum() else "_" for c in s) for s in strings)
133
+
134
+
135
+ class ColumnSpec(pydantic.BaseModel):
136
+ df: str
137
+ column: str
138
+
139
+
140
+ class ModelMapping(pydantic.BaseModel):
141
+ map: dict[str, ColumnSpec]
142
+
143
+
144
+ @dataclasses.dataclass
145
+ class ModelConfig:
146
+ model: torch.nn.Module
147
+ model_inputs: list[str]
148
+ model_outputs: list[str]
149
+ loss_inputs: list[str]
150
+ loss: torch.nn.Module
151
+ optimizer: torch.optim.Optimizer
152
+ source_workspace: str | None = None
153
+ trained: bool = False
154
+
155
+ def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
156
+ model_inputs = [inputs[i] for i in self.model_inputs]
157
+ output = self.model(*model_inputs)
158
+ if not isinstance(output, tuple):
159
+ output = (output,)
160
+ values = {k: v for k, v in zip(self.model_outputs, output)}
161
+ return values
162
+
163
+ def inference(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
164
+ # TODO: Do multiple batches.
165
+ self.model.eval()
166
+ return self._forward(inputs)
167
+
168
+ def train(self, inputs: dict[str, torch.Tensor]) -> float:
169
+ """Train the model for one epoch. Returns the loss."""
170
+ # TODO: Do multiple batches.
171
+ self.model.train()
172
+ self.optimizer.zero_grad()
173
+ values = self._forward(inputs)
174
+ values.update(inputs)
175
+ loss_inputs = [values[i] for i in self.loss_inputs]
176
+ loss = self.loss(*loss_inputs)
177
+ loss.backward()
178
+ self.optimizer.step()
179
+ return loss.item()
180
+
181
+ def copy(self):
182
+ """Returns a copy of the model."""
183
+ c = dataclasses.replace(self)
184
+ c.model = copy.deepcopy(self.model)
185
+ return c
186
+
187
+ def metadata(self):
188
+ return {
189
+ "type": "model",
190
+ "model": {
191
+ "inputs": self.model_inputs,
192
+ "outputs": self.model_outputs,
193
+ "loss_inputs": self.loss_inputs,
194
+ "trained": self.trained,
195
+ },
196
+ }
197
+
198
+
199
+ def build_model(
200
+ ws: workspace.Workspace, inputs: dict[str, torch.Tensor]
201
+ ) -> ModelConfig:
202
+ """Builds the model described in the workspace."""
203
+ catalog = ops.CATALOGS[ENV]
204
+ optimizers = []
205
+ nodes = {}
206
+ for node in ws.nodes:
207
+ nodes[node.id] = node
208
+ if node.data.title == "Optimizer":
209
+ optimizers.append(node.id)
210
+ assert optimizers, "No optimizer found."
211
+ assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
212
+ [optimizer] = optimizers
213
+ dependencies = {n.id: [] for n in ws.nodes}
214
+ in_edges = {}
215
+ out_edges = {}
216
+ # TODO: Dissolve repeat boxes here.
217
+ for e in ws.edges:
218
+ dependencies[e.target].append(e.source)
219
+ in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
220
+ (e.source, e.sourceHandle)
221
+ )
222
+ out_edges.setdefault(e.source, {}).setdefault(e.sourceHandle, []).append(
223
+ (e.target, e.targetHandle)
224
+ )
225
+ sizes = {}
226
+ for k, i in inputs.items():
227
+ sizes[k] = i.shape[-1]
228
+ ts = graphlib.TopologicalSorter(dependencies)
229
+ layers = []
230
+ loss_layers = []
231
+ in_loss = set()
232
+ cfg = {}
233
+ used_in_model = set()
234
+ made_in_model = set()
235
+ used_in_loss = set()
236
+ made_in_loss = set()
237
+ for node_id in ts.static_order():
238
+ node = nodes[node_id]
239
+ t = node.data.title
240
+ op = catalog[t]
241
+ p = op.convert_params(node.data.params)
242
+ for b in dependencies[node_id]:
243
+ if b in in_loss:
244
+ in_loss.add(node_id)
245
+ if "loss" in t:
246
+ in_loss.add(node_id)
247
+ inputs = {}
248
+ for n in in_edges.get(node_id, []):
249
+ for b, h in in_edges[node_id][n]:
250
+ i = _to_id(b, h)
251
+ inputs[n] = i
252
+ if node_id in in_loss:
253
+ used_in_loss.add(i)
254
+ else:
255
+ used_in_model.add(i)
256
+ outputs = {}
257
+ for out in out_edges.get(node_id, []):
258
+ i = _to_id(node_id, out)
259
+ outputs[out] = i
260
+ if inputs: # Nodes with no inputs are input nodes. Their outputs are not "made" by us.
261
+ if node_id in in_loss:
262
+ made_in_loss.add(i)
263
+ else:
264
+ made_in_model.add(i)
265
+ inputs = types.SimpleNamespace(**inputs)
266
+ outputs = types.SimpleNamespace(**outputs)
267
+ ls = loss_layers if node_id in in_loss else layers
268
+ match t:
269
+ case "Linear":
270
+ isize = sizes.get(inputs.x, 1)
271
+ osize = isize if p["output_dim"] == "same" else int(p["output_dim"])
272
+ ls.append((torch.nn.Linear(isize, osize), f"{inputs.x} -> {outputs.x}"))
273
+ sizes[outputs.x] = osize
274
+ case "Activation":
275
+ f = getattr(
276
+ torch.nn.functional, p["type"].name.lower().replace(" ", "_")
277
+ )
278
+ ls.append((f, f"{inputs.x} -> {outputs.x}"))
279
+ sizes[outputs.x] = sizes.get(inputs.x, 1)
280
+ case "MSE loss":
281
+ ls.append(
282
+ (
283
+ torch.nn.functional.mse_loss,
284
+ f"{inputs.x}, {inputs.y} -> {outputs.loss}",
285
+ )
286
+ )
287
+ cfg["model_inputs"] = list(used_in_model - made_in_model)
288
+ cfg["model_outputs"] = list(made_in_model & used_in_loss)
289
+ cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
290
+ # Make sure the trained output is output from the last model layer.
291
+ outputs = ", ".join(cfg["model_outputs"])
292
+ layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
293
+ # Create model.
294
+ cfg["model"] = pyg.nn.Sequential(", ".join(cfg["model_inputs"]), layers)
295
+ # Make sure the loss is output from the last loss layer.
296
+ [(lossb, lossh)] = in_edges[optimizer]["loss"]
297
+ lossi = _to_id(lossb, lossh)
298
+ loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
299
+ # Create loss function.
300
+ cfg["loss"] = pyg.nn.Sequential(", ".join(cfg["loss_inputs"]), loss_layers)
301
+ assert not list(cfg["loss"].parameters()), (
302
+ f"loss should have no parameters: {list(cfg['loss'].parameters())}"
303
+ )
304
+ # Create optimizer.
305
+ op = catalog["Optimizer"]
306
+ p = op.convert_params(nodes[optimizer].data.params)
307
+ o = getattr(torch.optim, p["type"].name)
308
+ cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
309
+ return ModelConfig(**cfg)
310
+
311
+
312
+ def to_tensors(b: core.Bundle, m: ModelMapping | None) -> dict[str, torch.Tensor]:
313
+ """Converts a tensor to the correct type for PyTorch. Ignores missing mappings."""
314
+ if m is None:
315
+ return {}
316
+ tensors = {}
317
+ for k, v in m.map.items():
318
+ if v.df in b.dfs and v.column in b.dfs[v.df]:
319
+ tensors[k] = torch.tensor(
320
+ b.dfs[v.df][v.column].to_list(), dtype=torch.float32
321
+ )
322
+ return tensors
lynxkite-graph-analytics/tests/test_pytorch_model_ops.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lynxkite.core import workspace
2
+ from lynxkite_graph_analytics import pytorch_model_ops
3
+ import torch
4
+ import pytest
5
+
6
+
7
+ def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str, str, str]]):
8
+ ws = workspace.Workspace(env=env)
9
+ for id, data in nodes.items():
10
+ ws.nodes.append(
11
+ workspace.WorkspaceNode(
12
+ id=id,
13
+ type="basic",
14
+ data=workspace.WorkspaceNodeData(title=data["title"], params=data),
15
+ position=workspace.Position(
16
+ x=data.get("x", 0),
17
+ y=data.get("y", 0),
18
+ ),
19
+ )
20
+ )
21
+ ws.edges = [
22
+ workspace.WorkspaceEdge(
23
+ id=f"{source}->{target}",
24
+ source=source.split(":")[0],
25
+ target=target.split(":")[0],
26
+ sourceHandle=source.split(":")[1],
27
+ targetHandle=target.split(":")[1],
28
+ )
29
+ for source, target in edges
30
+ ]
31
+ return ws
32
+
33
+
34
+ async def test_build_model():
35
+ ws = make_ws(
36
+ pytorch_model_ops.ENV,
37
+ {
38
+ "emb": {"title": "Input: embedding"},
39
+ "lin": {"title": "Linear", "output_dim": "same"},
40
+ "act": {"title": "Activation", "type": "Leaky ReLU"},
41
+ "label": {"title": "Input: label"},
42
+ "loss": {"title": "MSE loss"},
43
+ "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
44
+ },
45
+ [
46
+ ("emb:x", "lin:x"),
47
+ ("lin:x", "act:x"),
48
+ ("act:x", "loss:x"),
49
+ ("label:y", "loss:y"),
50
+ ("loss:loss", "optim:loss"),
51
+ ],
52
+ )
53
+ x = torch.rand(100, 4)
54
+ y = x + 1
55
+ m = pytorch_model_ops.build_model(ws, {"emb_x": x, "label_y": y})
56
+ for i in range(1000):
57
+ loss = m.train({"emb_x": x, "label_y": y})
58
+ assert loss < 0.1
59
+ o = m.inference({"emb_x": x[:1]})
60
+ error = torch.nn.functional.mse_loss(o["act_x"], x[:1] + 1)
61
+ assert error < 0.1
62
+
63
+
64
+ if __name__ == "__main__":
65
+ pytest.main()
lynxkite-graph-analytics/uv.lock CHANGED
The diff for this file is too large to render. See raw diff