darabos commited on
Commit
b2852ba
·
1 Parent(s): 1353683

A UI for mapping model bindings.

Browse files
examples/Model use CHANGED
@@ -15,8 +15,22 @@
15
  "targetHandle": "bundle"
16
  },
17
  {
18
- "id": "Train model 3 Model inference 1",
19
- "source": "Train model 3",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  "sourceHandle": "output",
21
  "target": "Model inference 1",
22
  "targetHandle": "bundle"
@@ -28,7 +42,30 @@
28
  "data": {
29
  "__execution_delay": 0.0,
30
  "collapsed": null,
31
- "display": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  "error": null,
33
  "meta": {
34
  "inputs": {
@@ -134,7 +171,46 @@
134
  "data": {
135
  "__execution_delay": 0.0,
136
  "collapsed": null,
137
- "display": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  "error": null,
139
  "meta": {
140
  "inputs": {
@@ -194,7 +270,7 @@
194
  },
195
  "params": {
196
  "epochs": "1000",
197
- "input_mapping": "{\"Input__embedding_1_x\": {\"df\": \"df_train\", \"column\": \"x\"}, \"Input__label_1_y\": {\"df\": \"df_train\", \"column\": \"y\" }}",
198
  "model_workspace": "Model definition",
199
  "save_as": "model"
200
  },
@@ -205,8 +281,8 @@
205
  "height": 519.0,
206
  "id": "Train model 3",
207
  "position": {
208
- "x": 687.3818749999999,
209
- "y": -34.16902777777775
210
  },
211
  "type": "basic",
212
  "width": 640.0
@@ -216,7 +292,7 @@
216
  "__execution_delay": 0.0,
217
  "collapsed": null,
218
  "display": null,
219
- "error": null,
220
  "meta": {
221
  "inputs": {
222
  "bundle": {
@@ -267,9 +343,9 @@
267
  "type": "basic"
268
  },
269
  "params": {
270
- "input_mapping": "{\"Input__embedding_1_x\": {\"df\": \"df_test\", \"column\": \"x\"}}",
271
  "model_name": "model",
272
- "output_mapping": "{\"Activation_2_x\": {\"df\": \"df_test\", \"column\": \"predicted\"}}"
273
  },
274
  "status": "done",
275
  "title": "Model inference"
@@ -283,6 +359,1014 @@
283
  },
284
  "type": "basic",
285
  "width": 410.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  }
287
  ]
288
  }
 
15
  "targetHandle": "bundle"
16
  },
17
  {
18
+ "id": "Model inference 1 View tables 1",
19
+ "source": "Model inference 1",
20
+ "sourceHandle": "output",
21
+ "target": "View tables 1",
22
+ "targetHandle": "bundle"
23
+ },
24
+ {
25
+ "id": "Train/test split 1 Train model 1",
26
+ "source": "Train/test split 1",
27
+ "sourceHandle": "output",
28
+ "target": "Train model 1",
29
+ "targetHandle": "bundle"
30
+ },
31
+ {
32
+ "id": "Train model 1 Model inference 1",
33
+ "source": "Train model 1",
34
  "sourceHandle": "output",
35
  "target": "Model inference 1",
36
  "targetHandle": "bundle"
 
42
  "data": {
43
  "__execution_delay": 0.0,
44
  "collapsed": null,
45
+ "display": {
46
+ "dataframes": {
47
+ "df": {
48
+ "columns": [
49
+ "x",
50
+ "y"
51
+ ]
52
+ },
53
+ "df_test": {
54
+ "columns": [
55
+ "x",
56
+ "y"
57
+ ]
58
+ },
59
+ "df_train": {
60
+ "columns": [
61
+ "x",
62
+ "y"
63
+ ]
64
+ }
65
+ },
66
+ "other": {},
67
+ "relations": []
68
+ },
69
  "error": null,
70
  "meta": {
71
  "inputs": {
 
171
  "data": {
172
  "__execution_delay": 0.0,
173
  "collapsed": null,
174
+ "display": {
175
+ "dataframes": {
176
+ "df": {
177
+ "columns": [
178
+ "x",
179
+ "y"
180
+ ]
181
+ },
182
+ "df_test": {
183
+ "columns": [
184
+ "x",
185
+ "y"
186
+ ]
187
+ },
188
+ "df_train": {
189
+ "columns": [
190
+ "x",
191
+ "y"
192
+ ]
193
+ }
194
+ },
195
+ "other": {
196
+ "model": {
197
+ "model": {
198
+ "inputs": [
199
+ "Input__embedding_1_x"
200
+ ],
201
+ "loss_inputs": [
202
+ "Activation_2_x",
203
+ "Input__label_1_y"
204
+ ],
205
+ "outputs": [
206
+ "Activation_2_x"
207
+ ]
208
+ },
209
+ "type": "model"
210
+ }
211
+ },
212
+ "relations": []
213
+ },
214
  "error": null,
215
  "meta": {
216
  "inputs": {
 
270
  },
271
  "params": {
272
  "epochs": "1000",
273
+ "input_mapping": "{\"map\": {\"Input__embedding_1_x\": {\"df\": \"df_train\", \"column\": \"x\"}, \"Input__label_1_y\": {\"df\": \"df_train\", \"column\": \"y\" }}}",
274
  "model_workspace": "Model definition",
275
  "save_as": "model"
276
  },
 
281
  "height": 519.0,
282
  "id": "Train model 3",
283
  "position": {
284
+ "x": 722.5912720951791,
285
+ "y": -784.0614755260641
286
  },
287
  "type": "basic",
288
  "width": 640.0
 
292
  "__execution_delay": 0.0,
293
  "collapsed": null,
294
  "display": null,
295
+ "error": "'Input__embedding_1_x'",
296
  "meta": {
297
  "inputs": {
298
  "bundle": {
 
343
  "type": "basic"
344
  },
345
  "params": {
346
+ "input_mapping": "{\"map\": {\"Input__embedding_1_x\": {\"df\": \"df_test\", \"column\": \"x\"}}}",
347
  "model_name": "model",
348
+ "output_mapping": "{\"map\": {\"Activation_2_x\": {\"df\": \"df_test\", \"column\": \"predicted\"}}}"
349
  },
350
  "status": "done",
351
  "title": "Model inference"
 
359
  },
360
  "type": "basic",
361
  "width": 410.0
362
+ },
363
+ {
364
+ "data": {
365
+ "display": {
366
+ "dataframes": {
367
+ "df": {
368
+ "columns": [
369
+ "x",
370
+ "y"
371
+ ],
372
+ "data": [
373
+ [
374
+ "[0.52046251 0.45887971 0.72169858 0.29517919]",
375
+ "[1.52046251 1.45887971 1.72169852 1.29517913]"
376
+ ],
377
+ [
378
+ "[0.85706753 0.61447072 0.41741937 0.85147089]",
379
+ "[1.85706758 1.61447072 1.41741943 1.85147095]"
380
+ ],
381
+ [
382
+ "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
383
+ "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
384
+ ],
385
+ [
386
+ "[0.19409031 0.68692201 0.60667384 0.57829887]",
387
+ "[1.19409037 1.68692207 1.60667384 1.57829881]"
388
+ ],
389
+ [
390
+ "[0.76807946 0.98855817 0.08259124 0.01730657]",
391
+ "[1.76807952 1.98855817 1.0825913 1.01730657]"
392
+ ],
393
+ [
394
+ "[0.67269951 0.10478973 0.5584439 0.83605725]",
395
+ "[1.67269945 1.10478973 1.5584439 1.83605719]"
396
+ ],
397
+ [
398
+ "[0.18686318 0.49356437 0.51323432 0.75392658]",
399
+ "[1.18686318 1.49356437 1.51323438 1.75392652]"
400
+ ],
401
+ [
402
+ "[0.18149549 0.30520517 0.30946714 0.16786289]",
403
+ "[1.18149543 1.30520511 1.30946708 1.16786289]"
404
+ ],
405
+ [
406
+ "[4.27091718e-01 4.89909172e-01 6.92297399e-01 2.57611275e-04]",
407
+ "[1.42709172 1.48990917 1.69229746 1.00025761]"
408
+ ],
409
+ [
410
+ "[0.32225502 0.16999388 0.05823922 0.9628762 ]",
411
+ "[1.32225502 1.16999388 1.05823922 1.9628762 ]"
412
+ ],
413
+ [
414
+ "[0.50783676 0.04156506 0.21984279 0.8454656 ]",
415
+ "[1.50783682 1.04156506 1.21984279 1.84546566]"
416
+ ],
417
+ [
418
+ "[0.98324287 0.99464184 0.14008355 0.47651017]",
419
+ "[1.98324287 1.99464178 1.14008355 1.47651017]"
420
+ ],
421
+ [
422
+ "[0.11693293 0.49860179 0.55020827 0.88832849]",
423
+ "[1.11693287 1.49860179 1.55020833 1.88832855]"
424
+ ],
425
+ [
426
+ "[0.48959708 0.48549271 0.32688856 0.356677 ]",
427
+ "[1.48959708 1.48549271 1.32688856 1.35667706]"
428
+ ],
429
+ [
430
+ "[0.50272274 0.54912758 0.17663097 0.79070699]",
431
+ "[1.50272274 1.54912758 1.17663097 1.79070699]"
432
+ ],
433
+ [
434
+ "[0.04508126 0.76880038 0.80721325 0.62542385]",
435
+ "[1.04508126 1.76880038 1.80721331 1.62542391]"
436
+ ],
437
+ [
438
+ "[0.19908059 0.17570406 0.51475513 0.1893943 ]",
439
+ "[1.19908059 1.175704 1.51475513 1.18939424]"
440
+ ],
441
+ [
442
+ "[0.40167677 0.25953674 0.9407078 0.76308483]",
443
+ "[1.40167677 1.25953674 1.9407078 1.76308489]"
444
+ ],
445
+ [
446
+ "[0.2480728 0.21694398 0.63941365 0.57128876]",
447
+ "[1.24807286 1.21694398 1.6394136 1.57128882]"
448
+ ],
449
+ [
450
+ "[0.24388778 0.07268471 0.68350857 0.73431659]",
451
+ "[1.24388778 1.07268476 1.68350863 1.73431659]"
452
+ ],
453
+ [
454
+ "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
455
+ "[1.62569475 1.9881897 1.83639622 1.98288584]"
456
+ ],
457
+ [
458
+ "[0.56922203 0.98222166 0.76851749 0.28615737]",
459
+ "[1.56922197 1.9822216 1.76851749 1.28615737]"
460
+ ],
461
+ [
462
+ "[0.88776821 0.51636773 0.30333066 0.32230979]",
463
+ "[1.88776827 1.51636767 1.30333066 1.32230973]"
464
+ ],
465
+ [
466
+ "[0.90817457 0.89270043 0.38583666 0.66566533]",
467
+ "[1.90817451 1.89270043 1.3858366 1.66566539]"
468
+ ],
469
+ [
470
+ "[0.48507756 0.80808765 0.77162558 0.47834778]",
471
+ "[1.48507762 1.80808759 1.77162552 1.47834778]"
472
+ ],
473
+ [
474
+ "[0.68062544 0.98093534 0.14778823 0.53244978]",
475
+ "[1.68062544 1.98093534 1.14778829 1.53244972]"
476
+ ],
477
+ [
478
+ "[0.31518555 0.49643308 0.11509258 0.95458382]",
479
+ "[1.31518555 1.49643302 1.11509252 1.95458388]"
480
+ ],
481
+ [
482
+ "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
483
+ "[1.79121017 1.54161119 1.69369793 1.15207696]"
484
+ ],
485
+ [
486
+ "[0.79423058 0.07138705 0.061777 0.18766576]",
487
+ "[1.79423058 1.07138705 1.061777 1.1876657 ]"
488
+ ],
489
+ [
490
+ "[0.23942459 0.90487361 0.69337189 0.65089428]",
491
+ "[1.23942459 1.90487361 1.69337189 1.65089428]"
492
+ ],
493
+ [
494
+ "[0.94516498 0.08422136 0.5608117 0.07652664]",
495
+ "[1.94516492 1.08422136 1.56081176 1.07652664]"
496
+ ],
497
+ [
498
+ "[0.26661873 0.45946234 0.13510543 0.81294441]",
499
+ "[1.26661873 1.4594624 1.13510537 1.81294441]"
500
+ ],
501
+ [
502
+ "[0.30754459 0.77694583 0.09278506 0.38326019]",
503
+ "[1.30754459 1.77694583 1.09278512 1.38326025]"
504
+ ],
505
+ [
506
+ "[0.27845025 0.32472342 0.82203609 0.77107543]",
507
+ "[1.27845025 1.32472348 1.82203603 1.77107549]"
508
+ ],
509
+ [
510
+ "[0.4827103 0.10563457 0.98858833 0.82286644]",
511
+ "[1.48271036 1.10563457 1.98858833 1.82286644]"
512
+ ],
513
+ [
514
+ "[0.98033333 0.97656083 0.38939917 0.81491041]",
515
+ "[1.98033333 1.97656083 1.38939917 1.81491041]"
516
+ ],
517
+ [
518
+ "[0.74064726 0.4155122 0.09800029 0.49930882]",
519
+ "[1.74064732 1.4155122 1.09800029 1.49930882]"
520
+ ],
521
+ [
522
+ "[0.78956431 0.87284744 0.06880784 0.03455889]",
523
+ "[1.78956437 1.87284744 1.06880784 1.03455889]"
524
+ ],
525
+ [
526
+ "[0.94221359 0.57740951 0.98649532 0.40934443]",
527
+ "[1.94221354 1.57740951 1.98649526 1.40934443]"
528
+ ],
529
+ [
530
+ "[0.00497234 0.39319336 0.57054168 0.75150961]",
531
+ "[1.00497234 1.39319336 1.57054162 1.75150967]"
532
+ ],
533
+ [
534
+ "[0.44330525 0.09997386 0.89025736 0.90507984]",
535
+ "[1.44330525 1.09997392 1.89025736 1.90507984]"
536
+ ],
537
+ [
538
+ "[0.72290605 0.96945059 0.68354797 0.15270454]",
539
+ "[1.72290611 1.96945059 1.68354797 1.15270448]"
540
+ ],
541
+ [
542
+ "[0.75292218 0.81470108 0.49657214 0.56217098]",
543
+ "[1.75292218 1.81470108 1.49657214 1.56217098]"
544
+ ],
545
+ [
546
+ "[0.33480108 0.59181517 0.76198453 0.98062384]",
547
+ "[1.33480108 1.59181523 1.76198459 1.98062384]"
548
+ ],
549
+ [
550
+ "[0.52784437 0.54268694 0.12358981 0.72116476]",
551
+ "[1.52784443 1.54268694 1.12358975 1.7211647 ]"
552
+ ],
553
+ [
554
+ "[0.73217702 0.65233225 0.44077861 0.33837909]",
555
+ "[1.73217702 1.65233231 1.44077861 1.33837914]"
556
+ ],
557
+ [
558
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
559
+ "[1.34084332 1.73018837 1.54168713 1.91440833]"
560
+ ],
561
+ [
562
+ "[0.60110539 0.3618983 0.32342511 0.98672163]",
563
+ "[1.60110545 1.3618983 1.32342505 1.98672163]"
564
+ ],
565
+ [
566
+ "[0.77427191 0.21829212 0.12769502 0.74303615]",
567
+ "[1.77427197 1.21829212 1.12769508 1.74303615]"
568
+ ],
569
+ [
570
+ "[0.08107251 0.2602725 0.18861133 0.44833237]",
571
+ "[1.08107257 1.2602725 1.18861127 1.44833231]"
572
+ ],
573
+ [
574
+ "[0.59812403 0.78395379 0.0291847 0.81814629]",
575
+ "[1.59812403 1.78395379 1.0291847 1.81814623]"
576
+ ],
577
+ [
578
+ "[0.93488538 0.73882395 0.37345302 0.0274905 ]",
579
+ "[1.93488538 1.73882389 1.37345302 1.0274905 ]"
580
+ ],
581
+ [
582
+ "[0.30631393 0.48311198 0.87847513 0.67559886]",
583
+ "[1.30631399 1.48311198 1.87847519 1.67559886]"
584
+ ],
585
+ [
586
+ "[0.18720162 0.74115586 0.98626411 0.30355608]",
587
+ "[1.18720162 1.74115586 1.98626411 1.30355608]"
588
+ ],
589
+ [
590
+ "[0.85566247 0.83362883 0.48424995 0.25265992]",
591
+ "[1.85566247 1.83362889 1.48424995 1.25265992]"
592
+ ],
593
+ [
594
+ "[0.95928186 0.84273899 0.71514636 0.38619852]",
595
+ "[1.95928192 1.84273899 1.7151463 1.38619852]"
596
+ ],
597
+ [
598
+ "[0.32565445 0.90939188 0.07488042 0.13730896]",
599
+ "[1.32565451 1.90939188 1.07488036 1.13730896]"
600
+ ],
601
+ [
602
+ "[0.9829582 0.59269661 0.40120947 0.95487177]",
603
+ "[1.9829582 1.59269667 1.40120947 1.95487177]"
604
+ ],
605
+ [
606
+ "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
607
+ "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
608
+ ],
609
+ [
610
+ "[0.54914117 0.03810108 0.87531954 0.73044223]",
611
+ "[1.54914117 1.03810108 1.87531948 1.73044229]"
612
+ ],
613
+ [
614
+ "[0.67418337 0.79634351 0.23229051 0.71345252]",
615
+ "[1.67418337 1.79634356 1.23229051 1.71345258]"
616
+ ],
617
+ [
618
+ "[0.87285906 0.48354989 0.39394957 0.59456545]",
619
+ "[1.872859 1.48354983 1.39394951 1.59456539]"
620
+ ],
621
+ [
622
+ "[0.81788456 0.58174163 0.29376316 0.7971254 ]",
623
+ "[1.81788456 1.58174157 1.29376316 1.79712534]"
624
+ ],
625
+ [
626
+ "[0.94559073 0.65736622 0.25761551 0.48553199]",
627
+ "[1.94559073 1.65736628 1.25761557 1.48553205]"
628
+ ],
629
+ [
630
+ "[0.60075855 0.12234765 0.00614399 0.30560958]",
631
+ "[1.60075855 1.12234759 1.00614405 1.30560958]"
632
+ ],
633
+ [
634
+ "[0.39147133 0.29854035 0.84663737 0.58175623]",
635
+ "[1.39147139 1.29854035 1.84663737 1.58175623]"
636
+ ],
637
+ [
638
+ "[0.02162331 0.81861657 0.92468154 0.07808572]",
639
+ "[1.02162337 1.81861663 1.92468154 1.07808566]"
640
+ ],
641
+ [
642
+ "[0.02235305 0.52774918 0.7331115 0.84358269]",
643
+ "[1.02235305 1.52774918 1.7331115 1.84358263]"
644
+ ],
645
+ [
646
+ "[0.6080932 0.56563014 0.32107437 0.72599429]",
647
+ "[1.60809326 1.5656302 1.32107437 1.72599435]"
648
+ ],
649
+ [
650
+ "[0.67447788 0.6125319 0.98007888 0.65968603]",
651
+ "[1.67447782 1.6125319 1.98007894 1.65968609]"
652
+ ],
653
+ [
654
+ "[0.47963417 0.81818312 0.48720706 0.49339259]",
655
+ "[1.47963417 1.81818318 1.48720706 1.49339259]"
656
+ ],
657
+ [
658
+ "[0.9630242 0.76359051 0.24853623 0.76881069]",
659
+ "[1.96302414 1.76359057 1.24853623 1.76881075]"
660
+ ],
661
+ [
662
+ "[0.60609657 0.96257663 0.19292736 0.95702219]",
663
+ "[1.60609651 1.96257663 1.19292736 1.95702219]"
664
+ ],
665
+ [
666
+ "[0.80654246 0.08253473 0.74478531 0.71257162]",
667
+ "[1.8065424 1.08253479 1.74478531 1.71257162]"
668
+ ],
669
+ [
670
+ "[0.70167565 0.26930219 0.5660674 0.61194974]",
671
+ "[1.70167565 1.26930213 1.56606746 1.61194968]"
672
+ ],
673
+ [
674
+ "[0.76933283 0.86241865 0.44114518 0.65644735]",
675
+ "[1.76933289 1.86241865 1.44114518 1.65644741]"
676
+ ],
677
+ [
678
+ "[0.59492421 0.90274489 0.38069052 0.46101224]",
679
+ "[1.59492421 1.90274489 1.38069057 1.46101224]"
680
+ ],
681
+ [
682
+ "[0.15064228 0.03198934 0.25754827 0.51484001]",
683
+ "[1.15064228 1.03198934 1.25754833 1.51484001]"
684
+ ],
685
+ [
686
+ "[0.12024075 0.21342516 0.56858408 0.58644271]",
687
+ "[1.12024069 1.21342516 1.56858408 1.58644271]"
688
+ ],
689
+ [
690
+ "[0.91730917 0.22574073 0.09591609 0.33056474]",
691
+ "[1.91730917 1.22574067 1.09591603 1.33056474]"
692
+ ],
693
+ [
694
+ "[0.49691743 0.61873293 0.90698647 0.94486356]",
695
+ "[1.49691749 1.61873293 1.90698647 1.94486356]"
696
+ ],
697
+ [
698
+ "[0.6032477 0.83361369 0.18538666 0.19108021]",
699
+ "[1.60324764 1.83361363 1.18538666 1.19108021]"
700
+ ],
701
+ [
702
+ "[0.63235509 0.70352674 0.96188956 0.46240485]",
703
+ "[1.63235509 1.70352674 1.96188951 1.46240485]"
704
+ ],
705
+ [
706
+ "[0.37959969 0.42820001 0.10690689 0.96353984]",
707
+ "[1.37959969 1.42820001 1.10690689 1.96353984]"
708
+ ],
709
+ [
710
+ "[0.49607176 0.1922397 0.46640229 0.78321403]",
711
+ "[1.49607182 1.19223976 1.46640229 1.78321409]"
712
+ ],
713
+ [
714
+ "[0.40234613 0.54987347 0.49542785 0.54153186]",
715
+ "[1.40234613 1.54987347 1.49542785 1.5415318 ]"
716
+ ],
717
+ [
718
+ "[0.80893755 0.92237449 0.88346356 0.93164903]",
719
+ "[1.80893755 1.92237449 1.88346362 1.93164897]"
720
+ ],
721
+ [
722
+ "[0.12858278 0.09930819 0.83222693 0.72485673]",
723
+ "[1.12858272 1.09930825 1.83222699 1.72485673]"
724
+ ],
725
+ [
726
+ "[0.72470158 0.4940322 0.41027349 0.89364016]",
727
+ "[1.72470164 1.49403214 1.41027355 1.89364016]"
728
+ ],
729
+ [
730
+ "[0.47856545 0.46267092 0.6376707 0.84747767]",
731
+ "[1.47856545 1.46267092 1.63767076 1.84747767]"
732
+ ],
733
+ [
734
+ "[0.49584109 0.80599248 0.07096875 0.75872749]",
735
+ "[1.49584103 1.80599248 1.07096875 1.75872755]"
736
+ ],
737
+ [
738
+ "[0.43500566 0.66041756 0.80293626 0.96224713]",
739
+ "[1.43500566 1.66041756 1.80293632 1.96224713]"
740
+ ],
741
+ [
742
+ "[0.78397602 0.74223626 0.26603186 0.41664881]",
743
+ "[1.78397608 1.74223626 1.26603186 1.41664886]"
744
+ ],
745
+ [
746
+ "[0.28942841 0.05601001 0.33039129 0.27781558]",
747
+ "[1.28942847 1.05601001 1.33039129 1.27781558]"
748
+ ],
749
+ [
750
+ "[0.68094063 0.45189077 0.22661722 0.37354094]",
751
+ "[1.68094063 1.45189071 1.22661722 1.37354088]"
752
+ ],
753
+ [
754
+ "[0.43681622 0.74680805 0.83598751 0.12414402]",
755
+ "[1.43681622 1.74680805 1.83598757 1.12414408]"
756
+ ],
757
+ [
758
+ "[0.47870928 0.17129105 0.27300501 0.20634609]",
759
+ "[1.47870922 1.17129111 1.27300501 1.20634604]"
760
+ ],
761
+ [
762
+ "[0.72795159 0.79317838 0.27832931 0.96576637]",
763
+ "[1.72795153 1.79317832 1.27832937 1.96576643]"
764
+ ],
765
+ [
766
+ "[0.87608397 0.93200487 0.80169648 0.37758952]",
767
+ "[1.87608397 1.93200493 1.80169654 1.37758946]"
768
+ ],
769
+ [
770
+ "[0.68891573 0.25576538 0.96339929 0.503833 ]",
771
+ "[1.68891573 1.25576544 1.96339929 1.50383306]"
772
+ ]
773
+ ]
774
+ },
775
+ "df_test": {
776
+ "columns": [
777
+ "x",
778
+ "y",
779
+ "predicted"
780
+ ],
781
+ "data": [
782
+ [
783
+ "[0.52046251 0.45887971 0.72169858 0.29517919]",
784
+ "[1.52046251 1.45887971 1.72169852 1.29517913]",
785
+ "[1.5168578624725342, 1.450861930847168, 1.7133464813232422, 1.3041404485702515]"
786
+ ],
787
+ [
788
+ "[0.78956431 0.87284744 0.06880784 0.03455889]",
789
+ "[1.78956437 1.87284744 1.06880784 1.03455889]",
790
+ "[1.7899272441864014, 1.829580307006836, 1.0702992677688599, 1.0265709161758423]"
791
+ ],
792
+ [
793
+ "[0.49607176 0.1922397 0.46640229 0.78321403]",
794
+ "[1.49607182 1.19223976 1.46640229 1.78321409]",
795
+ "[1.4901000261306763, 1.193819284439087, 1.4632138013839722, 1.7822779417037964]"
796
+ ],
797
+ [
798
+ "[0.49691743 0.61873293 0.90698647 0.94486356]",
799
+ "[1.49691749 1.61873293 1.90698647 1.94486356]",
800
+ "[1.4999868869781494, 1.6656270027160645, 1.9074199199676514, 1.9556759595870972]"
801
+ ],
802
+ [
803
+ "[0.59812403 0.78395379 0.0291847 0.81814629]",
804
+ "[1.59812403 1.78395379 1.0291847 1.81814623]",
805
+ "[1.6044235229492188, 1.7707669734954834, 1.0426081418991089, 1.7988944053649902]"
806
+ ],
807
+ [
808
+ "[0.67447788 0.6125319 0.98007888 0.65968603]",
809
+ "[1.67447782 1.6125319 1.98007894 1.65968609]",
810
+ "[1.6721093654632568, 1.6624714136123657, 1.9726766347885132, 1.6813924312591553]"
811
+ ],
812
+ [
813
+ "[0.18720162 0.74115586 0.98626411 0.30355608]",
814
+ "[1.18720162 1.74115586 1.98626411 1.30355608]",
815
+ "[1.1961991786956787, 1.723442792892456, 1.9852817058563232, 1.3066248893737793]"
816
+ ],
817
+ [
818
+ "[0.74064726 0.4155122 0.09800029 0.49930882]",
819
+ "[1.74064732 1.4155122 1.09800029 1.49930882]",
820
+ "[1.7340764999389648, 1.3968157768249512, 1.0968588590621948, 1.493086814880371]"
821
+ ],
822
+ [
823
+ "[0.70167565 0.26930219 0.5660674 0.61194974]",
824
+ "[1.70167565 1.26930213 1.56606746 1.61194968]",
825
+ "[1.691997766494751, 1.2865687608718872, 1.5571787357330322, 1.622729778289795]"
826
+ ],
827
+ [
828
+ "[0.90817457 0.89270043 0.38583666 0.66566533]",
829
+ "[1.90817451 1.89270043 1.3858366 1.66566539]",
830
+ "[1.9086859226226807, 1.924757719039917, 1.3887461423873901, 1.6714670658111572]"
831
+ ]
832
+ ]
833
+ },
834
+ "df_train": {
835
+ "columns": [
836
+ "x",
837
+ "y"
838
+ ],
839
+ "data": [
840
+ [
841
+ "[0.85706753 0.61447072 0.41741937 0.85147089]",
842
+ "[1.85706758 1.61447072 1.41741943 1.85147095]"
843
+ ],
844
+ [
845
+ "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
846
+ "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
847
+ ],
848
+ [
849
+ "[0.19409031 0.68692201 0.60667384 0.57829887]",
850
+ "[1.19409037 1.68692207 1.60667384 1.57829881]"
851
+ ],
852
+ [
853
+ "[0.76807946 0.98855817 0.08259124 0.01730657]",
854
+ "[1.76807952 1.98855817 1.0825913 1.01730657]"
855
+ ],
856
+ [
857
+ "[0.67269951 0.10478973 0.5584439 0.83605725]",
858
+ "[1.67269945 1.10478973 1.5584439 1.83605719]"
859
+ ],
860
+ [
861
+ "[0.18686318 0.49356437 0.51323432 0.75392658]",
862
+ "[1.18686318 1.49356437 1.51323438 1.75392652]"
863
+ ],
864
+ [
865
+ "[0.18149549 0.30520517 0.30946714 0.16786289]",
866
+ "[1.18149543 1.30520511 1.30946708 1.16786289]"
867
+ ],
868
+ [
869
+ "[4.27091718e-01 4.89909172e-01 6.92297399e-01 2.57611275e-04]",
870
+ "[1.42709172 1.48990917 1.69229746 1.00025761]"
871
+ ],
872
+ [
873
+ "[0.32225502 0.16999388 0.05823922 0.9628762 ]",
874
+ "[1.32225502 1.16999388 1.05823922 1.9628762 ]"
875
+ ],
876
+ [
877
+ "[0.50783676 0.04156506 0.21984279 0.8454656 ]",
878
+ "[1.50783682 1.04156506 1.21984279 1.84546566]"
879
+ ],
880
+ [
881
+ "[0.98324287 0.99464184 0.14008355 0.47651017]",
882
+ "[1.98324287 1.99464178 1.14008355 1.47651017]"
883
+ ],
884
+ [
885
+ "[0.11693293 0.49860179 0.55020827 0.88832849]",
886
+ "[1.11693287 1.49860179 1.55020833 1.88832855]"
887
+ ],
888
+ [
889
+ "[0.48959708 0.48549271 0.32688856 0.356677 ]",
890
+ "[1.48959708 1.48549271 1.32688856 1.35667706]"
891
+ ],
892
+ [
893
+ "[0.50272274 0.54912758 0.17663097 0.79070699]",
894
+ "[1.50272274 1.54912758 1.17663097 1.79070699]"
895
+ ],
896
+ [
897
+ "[0.04508126 0.76880038 0.80721325 0.62542385]",
898
+ "[1.04508126 1.76880038 1.80721331 1.62542391]"
899
+ ],
900
+ [
901
+ "[0.19908059 0.17570406 0.51475513 0.1893943 ]",
902
+ "[1.19908059 1.175704 1.51475513 1.18939424]"
903
+ ],
904
+ [
905
+ "[0.40167677 0.25953674 0.9407078 0.76308483]",
906
+ "[1.40167677 1.25953674 1.9407078 1.76308489]"
907
+ ],
908
+ [
909
+ "[0.2480728 0.21694398 0.63941365 0.57128876]",
910
+ "[1.24807286 1.21694398 1.6394136 1.57128882]"
911
+ ],
912
+ [
913
+ "[0.24388778 0.07268471 0.68350857 0.73431659]",
914
+ "[1.24388778 1.07268476 1.68350863 1.73431659]"
915
+ ],
916
+ [
917
+ "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
918
+ "[1.62569475 1.9881897 1.83639622 1.98288584]"
919
+ ],
920
+ [
921
+ "[0.56922203 0.98222166 0.76851749 0.28615737]",
922
+ "[1.56922197 1.9822216 1.76851749 1.28615737]"
923
+ ],
924
+ [
925
+ "[0.88776821 0.51636773 0.30333066 0.32230979]",
926
+ "[1.88776827 1.51636767 1.30333066 1.32230973]"
927
+ ],
928
+ [
929
+ "[0.48507756 0.80808765 0.77162558 0.47834778]",
930
+ "[1.48507762 1.80808759 1.77162552 1.47834778]"
931
+ ],
932
+ [
933
+ "[0.68062544 0.98093534 0.14778823 0.53244978]",
934
+ "[1.68062544 1.98093534 1.14778829 1.53244972]"
935
+ ],
936
+ [
937
+ "[0.31518555 0.49643308 0.11509258 0.95458382]",
938
+ "[1.31518555 1.49643302 1.11509252 1.95458388]"
939
+ ],
940
+ [
941
+ "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
942
+ "[1.79121017 1.54161119 1.69369793 1.15207696]"
943
+ ],
944
+ [
945
+ "[0.79423058 0.07138705 0.061777 0.18766576]",
946
+ "[1.79423058 1.07138705 1.061777 1.1876657 ]"
947
+ ],
948
+ [
949
+ "[0.23942459 0.90487361 0.69337189 0.65089428]",
950
+ "[1.23942459 1.90487361 1.69337189 1.65089428]"
951
+ ],
952
+ [
953
+ "[0.94516498 0.08422136 0.5608117 0.07652664]",
954
+ "[1.94516492 1.08422136 1.56081176 1.07652664]"
955
+ ],
956
+ [
957
+ "[0.26661873 0.45946234 0.13510543 0.81294441]",
958
+ "[1.26661873 1.4594624 1.13510537 1.81294441]"
959
+ ],
960
+ [
961
+ "[0.30754459 0.77694583 0.09278506 0.38326019]",
962
+ "[1.30754459 1.77694583 1.09278512 1.38326025]"
963
+ ],
964
+ [
965
+ "[0.27845025 0.32472342 0.82203609 0.77107543]",
966
+ "[1.27845025 1.32472348 1.82203603 1.77107549]"
967
+ ],
968
+ [
969
+ "[0.4827103 0.10563457 0.98858833 0.82286644]",
970
+ "[1.48271036 1.10563457 1.98858833 1.82286644]"
971
+ ],
972
+ [
973
+ "[0.98033333 0.97656083 0.38939917 0.81491041]",
974
+ "[1.98033333 1.97656083 1.38939917 1.81491041]"
975
+ ],
976
+ [
977
+ "[0.94221359 0.57740951 0.98649532 0.40934443]",
978
+ "[1.94221354 1.57740951 1.98649526 1.40934443]"
979
+ ],
980
+ [
981
+ "[0.00497234 0.39319336 0.57054168 0.75150961]",
982
+ "[1.00497234 1.39319336 1.57054162 1.75150967]"
983
+ ],
984
+ [
985
+ "[0.44330525 0.09997386 0.89025736 0.90507984]",
986
+ "[1.44330525 1.09997392 1.89025736 1.90507984]"
987
+ ],
988
+ [
989
+ "[0.72290605 0.96945059 0.68354797 0.15270454]",
990
+ "[1.72290611 1.96945059 1.68354797 1.15270448]"
991
+ ],
992
+ [
993
+ "[0.75292218 0.81470108 0.49657214 0.56217098]",
994
+ "[1.75292218 1.81470108 1.49657214 1.56217098]"
995
+ ],
996
+ [
997
+ "[0.33480108 0.59181517 0.76198453 0.98062384]",
998
+ "[1.33480108 1.59181523 1.76198459 1.98062384]"
999
+ ],
1000
+ [
1001
+ "[0.52784437 0.54268694 0.12358981 0.72116476]",
1002
+ "[1.52784443 1.54268694 1.12358975 1.7211647 ]"
1003
+ ],
1004
+ [
1005
+ "[0.73217702 0.65233225 0.44077861 0.33837909]",
1006
+ "[1.73217702 1.65233231 1.44077861 1.33837914]"
1007
+ ],
1008
+ [
1009
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
1010
+ "[1.34084332 1.73018837 1.54168713 1.91440833]"
1011
+ ],
1012
+ [
1013
+ "[0.60110539 0.3618983 0.32342511 0.98672163]",
1014
+ "[1.60110545 1.3618983 1.32342505 1.98672163]"
1015
+ ],
1016
+ [
1017
+ "[0.77427191 0.21829212 0.12769502 0.74303615]",
1018
+ "[1.77427197 1.21829212 1.12769508 1.74303615]"
1019
+ ],
1020
+ [
1021
+ "[0.08107251 0.2602725 0.18861133 0.44833237]",
1022
+ "[1.08107257 1.2602725 1.18861127 1.44833231]"
1023
+ ],
1024
+ [
1025
+ "[0.93488538 0.73882395 0.37345302 0.0274905 ]",
1026
+ "[1.93488538 1.73882389 1.37345302 1.0274905 ]"
1027
+ ],
1028
+ [
1029
+ "[0.30631393 0.48311198 0.87847513 0.67559886]",
1030
+ "[1.30631399 1.48311198 1.87847519 1.67559886]"
1031
+ ],
1032
+ [
1033
+ "[0.85566247 0.83362883 0.48424995 0.25265992]",
1034
+ "[1.85566247 1.83362889 1.48424995 1.25265992]"
1035
+ ],
1036
+ [
1037
+ "[0.95928186 0.84273899 0.71514636 0.38619852]",
1038
+ "[1.95928192 1.84273899 1.7151463 1.38619852]"
1039
+ ],
1040
+ [
1041
+ "[0.32565445 0.90939188 0.07488042 0.13730896]",
1042
+ "[1.32565451 1.90939188 1.07488036 1.13730896]"
1043
+ ],
1044
+ [
1045
+ "[0.9829582 0.59269661 0.40120947 0.95487177]",
1046
+ "[1.9829582 1.59269667 1.40120947 1.95487177]"
1047
+ ],
1048
+ [
1049
+ "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
1050
+ "[1.79905868 1.89367437 1.75429082 1.3190186 ]"
1051
+ ],
1052
+ [
1053
+ "[0.54914117 0.03810108 0.87531954 0.73044223]",
1054
+ "[1.54914117 1.03810108 1.87531948 1.73044229]"
1055
+ ],
1056
+ [
1057
+ "[0.67418337 0.79634351 0.23229051 0.71345252]",
1058
+ "[1.67418337 1.79634356 1.23229051 1.71345258]"
1059
+ ],
1060
+ [
1061
+ "[0.87285906 0.48354989 0.39394957 0.59456545]",
1062
+ "[1.872859 1.48354983 1.39394951 1.59456539]"
1063
+ ],
1064
+ [
1065
+ "[0.81788456 0.58174163 0.29376316 0.7971254 ]",
1066
+ "[1.81788456 1.58174157 1.29376316 1.79712534]"
1067
+ ],
1068
+ [
1069
+ "[0.94559073 0.65736622 0.25761551 0.48553199]",
1070
+ "[1.94559073 1.65736628 1.25761557 1.48553205]"
1071
+ ],
1072
+ [
1073
+ "[0.60075855 0.12234765 0.00614399 0.30560958]",
1074
+ "[1.60075855 1.12234759 1.00614405 1.30560958]"
1075
+ ],
1076
+ [
1077
+ "[0.39147133 0.29854035 0.84663737 0.58175623]",
1078
+ "[1.39147139 1.29854035 1.84663737 1.58175623]"
1079
+ ],
1080
+ [
1081
+ "[0.02162331 0.81861657 0.92468154 0.07808572]",
1082
+ "[1.02162337 1.81861663 1.92468154 1.07808566]"
1083
+ ],
1084
+ [
1085
+ "[0.02235305 0.52774918 0.7331115 0.84358269]",
1086
+ "[1.02235305 1.52774918 1.7331115 1.84358263]"
1087
+ ],
1088
+ [
1089
+ "[0.6080932 0.56563014 0.32107437 0.72599429]",
1090
+ "[1.60809326 1.5656302 1.32107437 1.72599435]"
1091
+ ],
1092
+ [
1093
+ "[0.47963417 0.81818312 0.48720706 0.49339259]",
1094
+ "[1.47963417 1.81818318 1.48720706 1.49339259]"
1095
+ ],
1096
+ [
1097
+ "[0.9630242 0.76359051 0.24853623 0.76881069]",
1098
+ "[1.96302414 1.76359057 1.24853623 1.76881075]"
1099
+ ],
1100
+ [
1101
+ "[0.60609657 0.96257663 0.19292736 0.95702219]",
1102
+ "[1.60609651 1.96257663 1.19292736 1.95702219]"
1103
+ ],
1104
+ [
1105
+ "[0.80654246 0.08253473 0.74478531 0.71257162]",
1106
+ "[1.8065424 1.08253479 1.74478531 1.71257162]"
1107
+ ],
1108
+ [
1109
+ "[0.76933283 0.86241865 0.44114518 0.65644735]",
1110
+ "[1.76933289 1.86241865 1.44114518 1.65644741]"
1111
+ ],
1112
+ [
1113
+ "[0.59492421 0.90274489 0.38069052 0.46101224]",
1114
+ "[1.59492421 1.90274489 1.38069057 1.46101224]"
1115
+ ],
1116
+ [
1117
+ "[0.15064228 0.03198934 0.25754827 0.51484001]",
1118
+ "[1.15064228 1.03198934 1.25754833 1.51484001]"
1119
+ ],
1120
+ [
1121
+ "[0.12024075 0.21342516 0.56858408 0.58644271]",
1122
+ "[1.12024069 1.21342516 1.56858408 1.58644271]"
1123
+ ],
1124
+ [
1125
+ "[0.91730917 0.22574073 0.09591609 0.33056474]",
1126
+ "[1.91730917 1.22574067 1.09591603 1.33056474]"
1127
+ ],
1128
+ [
1129
+ "[0.6032477 0.83361369 0.18538666 0.19108021]",
1130
+ "[1.60324764 1.83361363 1.18538666 1.19108021]"
1131
+ ],
1132
+ [
1133
+ "[0.63235509 0.70352674 0.96188956 0.46240485]",
1134
+ "[1.63235509 1.70352674 1.96188951 1.46240485]"
1135
+ ],
1136
+ [
1137
+ "[0.37959969 0.42820001 0.10690689 0.96353984]",
1138
+ "[1.37959969 1.42820001 1.10690689 1.96353984]"
1139
+ ],
1140
+ [
1141
+ "[0.40234613 0.54987347 0.49542785 0.54153186]",
1142
+ "[1.40234613 1.54987347 1.49542785 1.5415318 ]"
1143
+ ],
1144
+ [
1145
+ "[0.80893755 0.92237449 0.88346356 0.93164903]",
1146
+ "[1.80893755 1.92237449 1.88346362 1.93164897]"
1147
+ ],
1148
+ [
1149
+ "[0.12858278 0.09930819 0.83222693 0.72485673]",
1150
+ "[1.12858272 1.09930825 1.83222699 1.72485673]"
1151
+ ],
1152
+ [
1153
+ "[0.72470158 0.4940322 0.41027349 0.89364016]",
1154
+ "[1.72470164 1.49403214 1.41027355 1.89364016]"
1155
+ ],
1156
+ [
1157
+ "[0.47856545 0.46267092 0.6376707 0.84747767]",
1158
+ "[1.47856545 1.46267092 1.63767076 1.84747767]"
1159
+ ],
1160
+ [
1161
+ "[0.49584109 0.80599248 0.07096875 0.75872749]",
1162
+ "[1.49584103 1.80599248 1.07096875 1.75872755]"
1163
+ ],
1164
+ [
1165
+ "[0.43500566 0.66041756 0.80293626 0.96224713]",
1166
+ "[1.43500566 1.66041756 1.80293632 1.96224713]"
1167
+ ],
1168
+ [
1169
+ "[0.78397602 0.74223626 0.26603186 0.41664881]",
1170
+ "[1.78397608 1.74223626 1.26603186 1.41664886]"
1171
+ ],
1172
+ [
1173
+ "[0.28942841 0.05601001 0.33039129 0.27781558]",
1174
+ "[1.28942847 1.05601001 1.33039129 1.27781558]"
1175
+ ],
1176
+ [
1177
+ "[0.68094063 0.45189077 0.22661722 0.37354094]",
1178
+ "[1.68094063 1.45189071 1.22661722 1.37354088]"
1179
+ ],
1180
+ [
1181
+ "[0.43681622 0.74680805 0.83598751 0.12414402]",
1182
+ "[1.43681622 1.74680805 1.83598757 1.12414408]"
1183
+ ],
1184
+ [
1185
+ "[0.47870928 0.17129105 0.27300501 0.20634609]",
1186
+ "[1.47870922 1.17129111 1.27300501 1.20634604]"
1187
+ ],
1188
+ [
1189
+ "[0.72795159 0.79317838 0.27832931 0.96576637]",
1190
+ "[1.72795153 1.79317832 1.27832937 1.96576643]"
1191
+ ],
1192
+ [
1193
+ "[0.87608397 0.93200487 0.80169648 0.37758952]",
1194
+ "[1.87608397 1.93200493 1.80169654 1.37758946]"
1195
+ ],
1196
+ [
1197
+ "[0.68891573 0.25576538 0.96339929 0.503833 ]",
1198
+ "[1.68891573 1.25576544 1.96339929 1.50383306]"
1199
+ ]
1200
+ ]
1201
+ }
1202
+ },
1203
+ "other": {
1204
+ "model": "ModelConfig(model=Sequential(\n (0) - Linear(in_features=4, out_features=4, bias=True): Input__embedding_1_x -> Linear_1_x\n (1) - <function leaky_relu at 0x710fc8f0fba0>: Linear_1_x -> Activation_2_x\n (2) - Identity(): Activation_2_x -> Activation_2_x\n), model_inputs=['Input__embedding_1_x'], model_outputs=['Activation_2_x'], loss_inputs=['Activation_2_x', 'Input__label_1_y'], loss=Sequential(\n (0) - <function mse_loss at 0x710fc8f316c0>: Activation_2_x, Input__label_1_y -> MSE_loss_1_loss\n (1) - Identity(): MSE_loss_1_loss -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n))"
1205
+ },
1206
+ "relations": []
1207
+ },
1208
+ "error": null,
1209
+ "meta": {
1210
+ "inputs": {
1211
+ "bundle": {
1212
+ "name": "bundle",
1213
+ "position": "left",
1214
+ "type": {
1215
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
1216
+ }
1217
+ }
1218
+ },
1219
+ "name": "View tables",
1220
+ "outputs": {},
1221
+ "params": {
1222
+ "limit": {
1223
+ "default": 100.0,
1224
+ "name": "limit",
1225
+ "type": {
1226
+ "type": "<class 'int'>"
1227
+ }
1228
+ }
1229
+ },
1230
+ "position": {
1231
+ "x": 471.0,
1232
+ "y": 424.0
1233
+ },
1234
+ "type": "table_view"
1235
+ },
1236
+ "params": {
1237
+ "limit": 100.0
1238
+ },
1239
+ "status": "planned",
1240
+ "title": "View tables"
1241
+ },
1242
+ "dragHandle": ".bg-primary",
1243
+ "height": 600.0,
1244
+ "id": "View tables 1",
1245
+ "position": {
1246
+ "x": 2017.4630208327735,
1247
+ "y": -223.54449620081252
1248
+ },
1249
+ "type": "table_view",
1250
+ "width": 582.0
1251
+ },
1252
+ {
1253
+ "data": {
1254
+ "__execution_delay": 0.0,
1255
+ "collapsed": null,
1256
+ "display": {
1257
+ "dataframes": {
1258
+ "df": {
1259
+ "columns": [
1260
+ "x",
1261
+ "y"
1262
+ ]
1263
+ },
1264
+ "df_test": {
1265
+ "columns": [
1266
+ "x",
1267
+ "y"
1268
+ ]
1269
+ },
1270
+ "df_train": {
1271
+ "columns": [
1272
+ "x",
1273
+ "y"
1274
+ ]
1275
+ }
1276
+ },
1277
+ "other": {
1278
+ "model": {
1279
+ "model": {
1280
+ "inputs": [],
1281
+ "loss_inputs": [
1282
+ "Activation_2_x",
1283
+ "Input__label_1_y"
1284
+ ],
1285
+ "outputs": [
1286
+ "Activation_2_x",
1287
+ "Input__label_1_y"
1288
+ ]
1289
+ },
1290
+ "type": "model"
1291
+ }
1292
+ },
1293
+ "relations": []
1294
+ },
1295
+ "error": "Mapping is unset.",
1296
+ "meta": {
1297
+ "inputs": {
1298
+ "bundle": {
1299
+ "name": "bundle",
1300
+ "position": "left",
1301
+ "type": {
1302
+ "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
1303
+ }
1304
+ }
1305
+ },
1306
+ "name": "Train model",
1307
+ "outputs": {
1308
+ "output": {
1309
+ "name": "output",
1310
+ "position": "right",
1311
+ "type": {
1312
+ "type": "None"
1313
+ }
1314
+ }
1315
+ },
1316
+ "params": {
1317
+ "epochs": {
1318
+ "default": 1.0,
1319
+ "name": "epochs",
1320
+ "type": {
1321
+ "type": "<class 'int'>"
1322
+ }
1323
+ },
1324
+ "input_mapping": {
1325
+ "default": null,
1326
+ "name": "input_mapping",
1327
+ "type": {
1328
+ "type": "<class 'lynxkite_graph_analytics.pytorch_model_ops.ModelMapping'>"
1329
+ }
1330
+ },
1331
+ "model_workspace": {
1332
+ "default": null,
1333
+ "name": "model_workspace",
1334
+ "type": {
1335
+ "type": "<class 'str'>"
1336
+ }
1337
+ },
1338
+ "save_as": {
1339
+ "default": "model",
1340
+ "name": "save_as",
1341
+ "type": {
1342
+ "type": "<class 'str'>"
1343
+ }
1344
+ }
1345
+ },
1346
+ "position": {
1347
+ "x": 723.0,
1348
+ "y": 370.0
1349
+ },
1350
+ "type": "basic"
1351
+ },
1352
+ "params": {
1353
+ "epochs": "2",
1354
+ "input_mapping": "{\"map\":{\"Activation_2_x\":{\"df\":\"df_train\"},\"Input__label_1_y\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
1355
+ "model_workspace": "Model definition",
1356
+ "save_as": "model"
1357
+ },
1358
+ "status": "done",
1359
+ "title": "Train model"
1360
+ },
1361
+ "dragHandle": ".bg-primary",
1362
+ "height": 473.0,
1363
+ "id": "Train model 1",
1364
+ "position": {
1365
+ "x": 712.1212754578014,
1366
+ "y": 42.33722689912529
1367
+ },
1368
+ "type": "basic",
1369
+ "width": 577.0
1370
  }
1371
  ]
1372
  }
lynxkite-app/web/src/index.css CHANGED
@@ -256,6 +256,14 @@ body {
256
  cursor: pointer;
257
  }
258
  }
 
 
 
 
 
 
 
 
259
  }
260
 
261
  .params-expander {
 
256
  cursor: pointer;
257
  }
258
  }
259
+
260
+ .model-mapping-param {
261
+ border: 1px solid var(--fallback-bc, oklch(var(--bc) / 0.2));
262
+ border-collapse: separate;
263
+ border-radius: 5px;
264
+ padding: 5px 10px;
265
+ width: 100%;
266
+ }
267
  }
268
 
269
  .params-expander {
lynxkite-app/web/src/workspace/nodes/NodeParameter.tsx CHANGED
@@ -1,15 +1,127 @@
1
- const BOOLEAN = "<class 'bool'>";
 
2
 
 
 
 
3
  function ParamName({ name }: { name: string }) {
4
  return (
5
  <span className="param-name bg-base-200">{name.replace(/_/g, " ")}</span>
6
  );
7
  }
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  interface NodeParameterProps {
10
  name: string;
11
  value: any;
12
  meta: any;
 
13
  onChange: (value: any, options?: { delay: number }) => void;
14
  }
15
 
@@ -17,6 +129,7 @@ export default function NodeParameter({
17
  name,
18
  value,
19
  meta,
 
20
  onChange,
21
  }: NodeParameterProps) {
22
  return (
@@ -65,6 +178,11 @@ export default function NodeParameter({
65
  {name.replace(/_/g, " ")}
66
  </label>
67
  </div>
 
 
 
 
 
68
  ) : (
69
  <>
70
  <ParamName name={name} />
 
1
+ // @ts-ignore
2
+ import ArrowsHorizontal from "~icons/tabler/arrows-horizontal.jsx";
3
 
4
+ const BOOLEAN = "<class 'bool'>";
5
+ const MODEL_MAPPING =
6
+ "<class 'lynxkite_graph_analytics.pytorch_model_ops.ModelMapping'>";
7
  function ParamName({ name }: { name: string }) {
8
  return (
9
  <span className="param-name bg-base-200">{name.replace(/_/g, " ")}</span>
10
  );
11
  }
12
 
13
+ function getModelBindings(data: any): string[] {
14
+ function bindingsOfModel(m: any): string[] {
15
+ return [...m.inputs, ...m.outputs, ...m.loss_inputs];
16
+ }
17
+ const bindings = new Set<string>();
18
+ const other = data?.display?.other ?? data?.display?.value?.other ?? {};
19
+ for (const e of Object.values(other) as any[]) {
20
+ if (e.type === "model") {
21
+ for (const b of bindingsOfModel(e.model)) {
22
+ bindings.add(b);
23
+ }
24
+ }
25
+ }
26
+ const list = [...bindings];
27
+ list.sort();
28
+ return list;
29
+ }
30
+
31
+ function parseJsonOrEmpty(json: string): object {
32
+ try {
33
+ const j = JSON.parse(json);
34
+ if (typeof j === "object") {
35
+ return j;
36
+ }
37
+ } catch (e) {}
38
+ return {};
39
+ }
40
+
41
+ function ModelMapping({ value, onChange, data }: any) {
42
+ const v: any = parseJsonOrEmpty(value);
43
+ v.map ??= {};
44
+ const dfs =
45
+ data?.display?.dataframes ?? data?.display?.value?.dataframes ?? {};
46
+ const bindings = getModelBindings(data);
47
+ return (
48
+ <table className="model-mapping-param">
49
+ <tbody>
50
+ <tr>
51
+ <td>mm</td>
52
+ </tr>
53
+ {bindings.length > 0 ? (
54
+ bindings.map((binding: string) => (
55
+ <tr key={binding}>
56
+ <td>{binding}</td>
57
+ <td>
58
+ <ArrowsHorizontal />
59
+ </td>
60
+ <td>
61
+ <select
62
+ className="select select-ghost"
63
+ value={v.map?.[binding]?.df}
64
+ onChange={(evt) => {
65
+ const df = evt.currentTarget.value;
66
+ if (df === "unbound") {
67
+ const map = { ...v.map, [binding]: undefined };
68
+ onChange(JSON.stringify({ map }));
69
+ } else {
70
+ const columnSpec = {
71
+ column: dfs[df][0],
72
+ ...(v.map?.[binding] || {}),
73
+ df,
74
+ };
75
+ const map = { ...v.map, [binding]: columnSpec };
76
+ onChange(JSON.stringify({ map }));
77
+ }
78
+ }}
79
+ >
80
+ <option key="unbound" value="unbound">
81
+ unbound
82
+ </option>
83
+ {Object.keys(dfs).map((df: string) => (
84
+ <option key={df} value={df}>
85
+ {df}
86
+ </option>
87
+ ))}
88
+ </select>
89
+ </td>
90
+ <td>
91
+ <select
92
+ className="select select-ghost"
93
+ value={v.map?.[binding]?.column}
94
+ onChange={(evt) => {
95
+ const column = evt.currentTarget.value;
96
+ const columnSpec = { ...(v.map?.[binding] || {}), column };
97
+ const map = { ...v.map, [binding]: columnSpec };
98
+ onChange(JSON.stringify({ map }));
99
+ }}
100
+ >
101
+ {dfs[v.map?.[binding]?.df]?.columns.map((col: string) => (
102
+ <option key={col} value={col}>
103
+ {col}
104
+ </option>
105
+ ))}
106
+ </select>
107
+ </td>
108
+ </tr>
109
+ ))
110
+ ) : (
111
+ <tr>
112
+ <td>no bindings</td>
113
+ </tr>
114
+ )}
115
+ </tbody>
116
+ </table>
117
+ );
118
+ }
119
+
120
  interface NodeParameterProps {
121
  name: string;
122
  value: any;
123
  meta: any;
124
+ data: any;
125
  onChange: (value: any, options?: { delay: number }) => void;
126
  }
127
 
 
129
  name,
130
  value,
131
  meta,
132
+ data,
133
  onChange,
134
  }: NodeParameterProps) {
135
  return (
 
178
  {name.replace(/_/g, " ")}
179
  </label>
180
  </div>
181
+ ) : meta?.type?.type === MODEL_MAPPING ? (
182
+ <>
183
+ <ParamName name={name} />
184
+ <ModelMapping value={value} data={data} onChange={onChange} />
185
+ </>
186
  ) : (
187
  <>
188
  <ParamName name={name} />
lynxkite-app/web/src/workspace/nodes/NodeWithParams.tsx CHANGED
@@ -62,6 +62,7 @@ function NodeWithParams(props: any) {
62
  name={name}
63
  key={name}
64
  value={value}
 
65
  meta={metaParams?.[name]}
66
  onChange={(value: any, opts?: UpdateOptions) =>
67
  setParam(name, value, opts || {})
 
62
  name={name}
63
  key={name}
64
  value={value}
65
+ data={props.data}
66
  meta={metaParams?.[name]}
67
  onChange={(value: any, opts?: UpdateOptions) =>
68
  setParam(name, value, opts || {})
lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -112,6 +112,13 @@ class Result:
112
  display: ReadOnlyJSON | None = None
113
  error: str | None = None
114
 
 
 
 
 
 
 
 
115
 
116
  MULTI_INPUT = Input(name="multi", type="*")
117
 
@@ -140,6 +147,11 @@ def _param_to_type(name, value, type):
140
  return None if value == "" else _param_to_type(name, value, type)
141
  case (type, types.NoneType):
142
  return None if value == "" else _param_to_type(name, value, type)
 
 
 
 
 
143
  return value
144
 
145
 
@@ -174,9 +186,10 @@ class Op(BaseConfig):
174
  """Returns the parameters converted to the expected type."""
175
  res = {}
176
  for p in params:
177
- res[p] = params[p]
178
  if p in self.params:
179
  res[p] = _param_to_type(p, params[p], self.params[p].type)
 
 
180
  return res
181
 
182
 
 
112
  display: ReadOnlyJSON | None = None
113
  error: str | None = None
114
 
115
+ def default_display(self) -> ReadOnlyJSON | None:
116
+ """Automatically extracts basic data from the output."""
117
+ if hasattr(self.output, "default_display"):
118
+ return self.output.default_display()
119
+ else:
120
+ return None
121
+
122
 
123
  MULTI_INPUT = Input(name="multi", type="*")
124
 
 
147
  return None if value == "" else _param_to_type(name, value, type)
148
  case (type, types.NoneType):
149
  return None if value == "" else _param_to_type(name, value, type)
150
+ if issubclass(type, pydantic.BaseModel):
151
+ try:
152
+ return type.model_validate_json(value)
153
+ except pydantic.ValidationError:
154
+ return None
155
  return value
156
 
157
 
 
186
  """Returns the parameters converted to the expected type."""
187
  res = {}
188
  for p in params:
 
189
  if p in self.params:
190
  res[p] = _param_to_type(p, params[p], self.params[p].type)
191
+ else:
192
+ res[p] = params[p]
193
  return res
194
 
195
 
lynxkite-core/src/lynxkite/core/workspace.py CHANGED
@@ -58,13 +58,13 @@ class WorkspaceNode(BaseConfig):
58
 
59
  def publish_result(self, result: ops.Result):
60
  """Sends the result to the frontend. Call this in an executor when the result is available."""
61
- self.data.display = result.display
62
  self.data.error = result.error
63
  self.data.status = NodeStatus.done
64
  if hasattr(self, "_crdt"):
65
  with self._crdt.doc.transaction():
66
- self._crdt["data"]["display"] = result.display
67
- self._crdt["data"]["error"] = result.error
68
  self._crdt["data"]["status"] = NodeStatus.done
69
 
70
  def publish_error(self, error: Exception | str | None):
 
58
 
59
  def publish_result(self, result: ops.Result):
60
  """Sends the result to the frontend. Call this in an executor when the result is available."""
61
+ self.data.display = result.display or result.default_display()
62
  self.data.error = result.error
63
  self.data.status = NodeStatus.done
64
  if hasattr(self, "_crdt"):
65
  with self._crdt.doc.transaction():
66
+ self._crdt["data"]["display"] = self.data.display
67
+ self._crdt["data"]["error"] = self.data.error
68
  self._crdt["data"]["status"] = NodeStatus.done
69
 
70
  def publish_error(self, error: Exception | str | None):
lynxkite-graph-analytics/src/lynxkite_graph_analytics/core.py CHANGED
@@ -106,6 +106,7 @@ class Bundle:
106
  )
107
 
108
  def to_dict(self, limit: int = 100):
 
109
  return {
110
  "dataframes": {
111
  name: {
@@ -115,7 +116,23 @@ class Bundle:
115
  for name, df in self.dfs.items()
116
  },
117
  "relations": [dataclasses.asdict(relation) for relation in self.relations],
118
- "other": self.other,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  }
120
 
121
 
 
106
  )
107
 
108
  def to_dict(self, limit: int = 100):
109
+ """JSON-serializable representation of the bundle, including some data."""
110
  return {
111
  "dataframes": {
112
  name: {
 
116
  for name, df in self.dfs.items()
117
  },
118
  "relations": [dataclasses.asdict(relation) for relation in self.relations],
119
+ "other": {k: str(v) for k, v in self.other.items()},
120
+ }
121
+
122
+ def default_display(self):
123
+ """JSON-serializable information about the bundle, metadata only."""
124
+ return {
125
+ "dataframes": {
126
+ name: {
127
+ "columns": sorted(str(c) for c in df.columns),
128
+ }
129
+ for name, df in self.dfs.items()
130
+ },
131
+ "relations": [dataclasses.asdict(relation) for relation in self.relations],
132
+ "other": {
133
+ k: getattr(v, "default_display", lambda: {})()
134
+ for k, v in self.other.items()
135
+ },
136
  }
137
 
138
 
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py CHANGED
@@ -368,21 +368,26 @@ def train_model(
368
  bundle: core.Bundle,
369
  *,
370
  model_workspace: str,
371
- input_mapping: str,
372
  epochs: int = 1,
373
  save_as: str = "model",
374
  ):
375
  """Trains the selected model on the selected dataset. Most training parameters are set in the model definition."""
 
 
376
  ws = load_ws(model_workspace)
377
- input_mapping = json.loads(input_mapping)
378
- inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
 
379
  m = pytorch_model_ops.build_model(ws, inputs)
 
 
 
 
380
  t = tqdm(range(epochs), desc="Training model")
381
  for _ in t:
382
  loss = m.train(inputs)
383
  t.set_postfix({"loss": loss})
384
- bundle = bundle.copy()
385
- bundle.other[save_as] = m
386
  return bundle
387
 
388
 
@@ -391,18 +396,18 @@ def model_inference(
391
  bundle: core.Bundle,
392
  *,
393
  model_name: str = "model",
394
- input_mapping: str = "",
395
- output_mapping: str = "",
396
  ):
397
  """Executes a trained model."""
 
 
398
  m = bundle.other[model_name]
399
- input_mapping = json.loads(input_mapping)
400
- output_mapping = json.loads(output_mapping)
401
  inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
402
  outputs = m.inference(inputs)
403
  bundle = bundle.copy()
404
- for k, v in output_mapping.items():
405
- bundle.dfs[v["df"]][v["column"]] = outputs[k].detach().numpy().tolist()
406
  return bundle
407
 
408
 
 
368
  bundle: core.Bundle,
369
  *,
370
  model_workspace: str,
371
+ input_mapping: pytorch_model_ops.ModelMapping,
372
  epochs: int = 1,
373
  save_as: str = "model",
374
  ):
375
  """Trains the selected model on the selected dataset. Most training parameters are set in the model definition."""
376
+ assert model_workspace, "Model workspace is unset."
377
+ print(f"input_mapping: {input_mapping}")
378
  ws = load_ws(model_workspace)
379
+ inputs = (
380
+ pytorch_model_ops.to_tensors(bundle, input_mapping) if input_mapping else {}
381
+ )
382
  m = pytorch_model_ops.build_model(ws, inputs)
383
+ bundle = bundle.copy()
384
+ bundle.other[save_as] = m
385
+ if input_mapping is None:
386
+ return ops.Result(bundle, error="Mapping is unset.")
387
  t = tqdm(range(epochs), desc="Training model")
388
  for _ in t:
389
  loss = m.train(inputs)
390
  t.set_postfix({"loss": loss})
 
 
391
  return bundle
392
 
393
 
 
396
  bundle: core.Bundle,
397
  *,
398
  model_name: str = "model",
399
+ input_mapping: pytorch_model_ops.ModelMapping,
400
+ output_mapping: pytorch_model_ops.ModelMapping,
401
  ):
402
  """Executes a trained model."""
403
+ if input_mapping is None or output_mapping is None:
404
+ return ops.Result(bundle, error="Mapping is unset.")
405
  m = bundle.other[model_name]
 
 
406
  inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
407
  outputs = m.inference(inputs)
408
  bundle = bundle.copy()
409
+ for k, v in output_mapping.map.items():
410
+ bundle.dfs[v.df][v.column] = outputs[k].detach().numpy().tolist()
411
  return bundle
412
 
413
 
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -1,6 +1,8 @@
1
  """Boxes for defining PyTorch models."""
2
 
3
  import graphlib
 
 
4
  from lynxkite.core import ops, workspace
5
  from lynxkite.core.ops import Parameter as P
6
  import torch
@@ -128,6 +130,15 @@ def _to_id(s: str) -> str:
128
  return "".join(c if c.isalnum() else "_" for c in s)
129
 
130
 
 
 
 
 
 
 
 
 
 
131
  @dataclass
132
  class ModelConfig:
133
  model: torch.nn.Module
@@ -169,6 +180,16 @@ class ModelConfig:
169
  c.model = self.model.copy()
170
  return c
171
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  def build_model(
174
  ws: workspace.Workspace, inputs: dict[str, torch.Tensor]
@@ -241,9 +262,9 @@ def build_model(
241
  loss_layers.append(
242
  (torch.nn.functional.mse_loss, f"{xi}, {yi} -> {nid}_loss")
243
  )
244
- cfg["model_inputs"] = used_inputs & inputs.keys()
245
- cfg["model_outputs"] = loss_inputs - inputs.keys()
246
- cfg["loss_inputs"] = loss_inputs
247
  # Make sure the trained output is output from the last model layer.
248
  outputs = ", ".join(cfg["model_outputs"])
249
  layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
@@ -266,11 +287,9 @@ def build_model(
266
  return ModelConfig(**cfg)
267
 
268
 
269
- def to_tensors(b: core.Bundle, m: dict[str, dict]) -> dict[str, torch.Tensor]:
270
  """Converts a tensor to the correct type for PyTorch."""
271
  tensors = {}
272
- for k, v in m.items():
273
- tensors[k] = torch.tensor(
274
- b.dfs[v["df"]][v["column"]].to_list(), dtype=torch.float32
275
- )
276
  return tensors
 
1
  """Boxes for defining PyTorch models."""
2
 
3
  import graphlib
4
+
5
+ import pydantic
6
  from lynxkite.core import ops, workspace
7
  from lynxkite.core.ops import Parameter as P
8
  import torch
 
130
  return "".join(c if c.isalnum() else "_" for c in s)
131
 
132
 
133
+ class ColumnSpec(pydantic.BaseModel):
134
+ df: str
135
+ column: str
136
+
137
+
138
+ class ModelMapping(pydantic.BaseModel):
139
+ map: dict[str, ColumnSpec]
140
+
141
+
142
  @dataclass
143
  class ModelConfig:
144
  model: torch.nn.Module
 
180
  c.model = self.model.copy()
181
  return c
182
 
183
+ def default_display(self):
184
+ return {
185
+ "type": "model",
186
+ "model": {
187
+ "inputs": self.model_inputs,
188
+ "outputs": self.model_outputs,
189
+ "loss_inputs": self.loss_inputs,
190
+ },
191
+ }
192
+
193
 
194
  def build_model(
195
  ws: workspace.Workspace, inputs: dict[str, torch.Tensor]
 
262
  loss_layers.append(
263
  (torch.nn.functional.mse_loss, f"{xi}, {yi} -> {nid}_loss")
264
  )
265
+ cfg["model_inputs"] = list(used_inputs & inputs.keys())
266
+ cfg["model_outputs"] = list(loss_inputs - inputs.keys())
267
+ cfg["loss_inputs"] = list(loss_inputs)
268
  # Make sure the trained output is output from the last model layer.
269
  outputs = ", ".join(cfg["model_outputs"])
270
  layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
 
287
  return ModelConfig(**cfg)
288
 
289
 
290
+ def to_tensors(b: core.Bundle, m: ModelMapping) -> dict[str, torch.Tensor]:
291
  """Converts a tensor to the correct type for PyTorch."""
292
  tensors = {}
293
+ for k, v in m.map.items():
294
+ tensors[k] = torch.tensor(b.dfs[v.df][v.column].to_list(), dtype=torch.float32)
 
 
295
  return tensors