darabos committed
Commit 1353683 · 1 Parent(s): 3cc3a0a

Training and inference.

examples/Model definition ADDED
@@ -0,0 +1,334 @@
+{
+  "edges": [
+    {
+      "id": "Input: embedding 1 Linear 1",
+      "source": "Input: embedding 1",
+      "sourceHandle": "x",
+      "target": "Linear 1",
+      "targetHandle": "x"
+    },
+    {
+      "id": "Input: label 1 MSE loss 1",
+      "source": "Input: label 1",
+      "sourceHandle": "y",
+      "target": "MSE loss 1",
+      "targetHandle": "y"
+    },
+    {
+      "id": "Linear 1 Activation 2",
+      "source": "Linear 1",
+      "sourceHandle": "x",
+      "target": "Activation 2",
+      "targetHandle": "x"
+    },
+    {
+      "id": "Activation 2 MSE loss 1",
+      "source": "Activation 2",
+      "sourceHandle": "x",
+      "target": "MSE loss 1",
+      "targetHandle": "x"
+    },
+    {
+      "id": "MSE loss 1 Optimizer 2",
+      "source": "MSE loss 1",
+      "sourceHandle": "loss",
+      "target": "Optimizer 2",
+      "targetHandle": "loss"
+    }
+  ],
+  "env": "PyTorch model",
+  "nodes": [
+    {
+      "data": {
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {},
+          "name": "Input: embedding",
+          "outputs": {
+            "x": {
+              "name": "x",
+              "position": "top",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "params": {},
+          "type": "basic"
+        },
+        "params": {},
+        "status": "planned",
+        "title": "Input: embedding"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 200.0,
+      "id": "Input: embedding 1",
+      "position": {
+        "x": 91.0,
+        "y": 266.0
+      },
+      "type": "basic",
+      "width": 200.0
+    },
+    {
+      "data": {
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {
+            "x": {
+              "name": "x",
+              "position": "bottom",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "name": "Linear",
+          "outputs": {
+            "x": {
+              "name": "x",
+              "position": "top",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "params": {
+            "output_dim": {
+              "default": "same",
+              "name": "output_dim",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            }
+          },
+          "type": "basic"
+        },
+        "params": {
+          "output_dim": "same"
+        },
+        "status": "planned",
+        "title": "Linear"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 200.0,
+      "id": "Linear 1",
+      "position": {
+        "x": 86.0,
+        "y": 33.0
+      },
+      "type": "basic",
+      "width": 200.0
+    },
+    {
+      "data": {
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {
+            "x": {
+              "name": "x",
+              "position": "bottom",
+              "type": {
+                "type": "tensor"
+              }
+            },
+            "y": {
+              "name": "y",
+              "position": "bottom",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "name": "MSE loss",
+          "outputs": {
+            "loss": {
+              "name": "loss",
+              "position": "top",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "params": {},
+          "type": "basic"
+        },
+        "params": {},
+        "status": "planned",
+        "title": "MSE loss"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 200.0,
+      "id": "MSE loss 1",
+      "position": {
+        "x": 315.0,
+        "y": -510.0
+      },
+      "type": "basic",
+      "width": 200.0
+    },
+    {
+      "data": {
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {},
+          "name": "Input: label",
+          "outputs": {
+            "y": {
+              "name": "y",
+              "position": "top",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "params": {},
+          "type": "basic"
+        },
+        "params": {},
+        "status": "planned",
+        "title": "Input: label"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 200.0,
+      "id": "Input: label 1",
+      "position": {
+        "x": 615.0,
+        "y": -165.0
+      },
+      "type": "basic",
+      "width": 200.0
+    },
+    {
+      "data": {
+        "__execution_delay": 0.0,
+        "collapsed": null,
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {
+            "x": {
+              "name": "x",
+              "position": "bottom",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "name": "Activation",
+          "outputs": {
+            "x": {
+              "name": "x",
+              "position": "top",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "params": {
+            "type": {
+              "default": "ReLU",
+              "name": "type",
+              "type": {
+                "enum": [
+                  "ReLU",
+                  "Leaky ReLU",
+                  "Tanh",
+                  "Mish"
+                ]
+              }
+            }
+          },
+          "position": {
+            "x": 419.0,
+            "y": 396.0
+          },
+          "type": "basic"
+        },
+        "params": {
+          "type": "Leaky ReLU"
+        },
+        "status": "planned",
+        "title": "Activation"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 200.0,
+      "id": "Activation 2",
+      "position": {
+        "x": 93.61643829835265,
+        "y": -229.04087132886406
+      },
+      "type": "basic",
+      "width": 200.0
+    },
+    {
+      "data": {
+        "__execution_delay": 0.0,
+        "collapsed": null,
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {
+            "loss": {
+              "name": "loss",
+              "position": "bottom",
+              "type": {
+                "type": "tensor"
+              }
+            }
+          },
+          "name": "Optimizer",
+          "outputs": {},
+          "params": {
+            "lr": {
+              "default": 0.001,
+              "name": "lr",
+              "type": {
+                "type": "<class 'float'>"
+              }
+            },
+            "type": {
+              "default": "AdamW",
+              "name": "type",
+              "type": {
+                "enum": [
+                  "AdamW",
+                  "Adafactor",
+                  "Adagrad",
+                  "SGD",
+                  "Lion",
+                  "Paged AdamW",
+                  "Galore AdamW"
+                ]
+              }
+            }
+          },
+          "position": {
+            "x": 526.0,
+            "y": 116.0
+          },
+          "type": "basic"
+        },
+        "params": {
+          "lr": "0.1",
+          "type": "SGD"
+        },
+        "status": "planned",
+        "title": "Optimizer"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 200.0,
+      "id": "Optimizer 2",
+      "position": {
+        "x": 305.6132943499785,
+        "y": -804.0094318451224
+      },
+      "type": "basic",
+      "width": 200.0
+    }
+  ]
+}
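
Note: the workspace above wires Input: embedding -> Linear -> Activation (Leaky ReLU) -> MSE loss -> Optimizer (SGD, lr 0.1). A minimal plain-PyTorch sketch of the same graph, assuming output_dim "same" keeps the input width and that the "x" embedding is a single column:

    import torch

    dim = 1  # assumption: width of the "x" embedding
    model = torch.nn.Sequential(
        torch.nn.Linear(dim, dim),  # "Linear 1", output_dim "same"
        torch.nn.LeakyReLU(),       # "Activation 2", type "Leaky ReLU"
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # "Optimizer 2"

    x = torch.randn(16, dim)
    y = x + 1  # stand-in target; see the plus-one dataset below
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)  # "MSE loss 1"
    loss.backward()
    optimizer.step()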
examples/Model use ADDED
@@ -0,0 +1,288 @@
+{
+  "edges": [
+    {
+      "id": "Import Parquet 1 Train/test split 1",
+      "source": "Import Parquet 1",
+      "sourceHandle": "output",
+      "target": "Train/test split 1",
+      "targetHandle": "bundle"
+    },
+    {
+      "id": "Train/test split 1 Train model 3",
+      "source": "Train/test split 1",
+      "sourceHandle": "output",
+      "target": "Train model 3",
+      "targetHandle": "bundle"
+    },
+    {
+      "id": "Train model 3 Model inference 1",
+      "source": "Train model 3",
+      "sourceHandle": "output",
+      "target": "Model inference 1",
+      "targetHandle": "bundle"
+    }
+  ],
+  "env": "LynxKite Graph Analytics",
+  "nodes": [
+    {
+      "data": {
+        "__execution_delay": 0.0,
+        "collapsed": null,
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {
+            "bundle": {
+              "name": "bundle",
+              "position": "left",
+              "type": {
+                "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
+              }
+            }
+          },
+          "name": "Train/test split",
+          "outputs": {
+            "output": {
+              "name": "output",
+              "position": "right",
+              "type": {
+                "type": "None"
+              }
+            }
+          },
+          "params": {
+            "table_name": {
+              "default": null,
+              "name": "table_name",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            },
+            "test_ratio": {
+              "default": 0.1,
+              "name": "test_ratio",
+              "type": {
+                "type": "<class 'float'>"
+              }
+            }
+          },
+          "type": "basic"
+        },
+        "params": {
+          "table_name": "df",
+          "test_ratio": 0.1
+        },
+        "status": "done",
+        "title": "Train/test split"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 282.0,
+      "id": "Train/test split 1",
+      "position": {
+        "x": 345.0,
+        "y": 139.0
+      },
+      "type": "basic",
+      "width": 259.0
+    },
+    {
+      "data": {
+        "__execution_delay": 0.0,
+        "collapsed": null,
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {},
+          "name": "Import Parquet",
+          "outputs": {
+            "output": {
+              "name": "output",
+              "position": "right",
+              "type": {
+                "type": "None"
+              }
+            }
+          },
+          "params": {
+            "filename": {
+              "default": null,
+              "name": "filename",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            }
+          },
+          "type": "basic"
+        },
+        "params": {
+          "filename": "uploads/plus-one-dataset.parquet"
+        },
+        "status": "done",
+        "title": "Import Parquet"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 403.0,
+      "id": "Import Parquet 1",
+      "position": {
+        "x": -166.0,
+        "y": 112.0
+      },
+      "type": "basic",
+      "width": 371.0
+    },
+    {
+      "data": {
+        "__execution_delay": 0.0,
+        "collapsed": null,
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {
+            "bundle": {
+              "name": "bundle",
+              "position": "left",
+              "type": {
+                "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
+              }
+            }
+          },
+          "name": "Train model",
+          "outputs": {
+            "output": {
+              "name": "output",
+              "position": "right",
+              "type": {
+                "type": "None"
+              }
+            }
+          },
+          "params": {
+            "epochs": {
+              "default": 1.0,
+              "name": "epochs",
+              "type": {
+                "type": "<class 'int'>"
+              }
+            },
+            "input_mapping": {
+              "default": null,
+              "name": "input_mapping",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            },
+            "model_workspace": {
+              "default": null,
+              "name": "model_workspace",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            },
+            "save_as": {
+              "default": "model",
+              "name": "save_as",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            }
+          },
+          "position": {
+            "x": 675.0,
+            "y": 144.0
+          },
+          "type": "basic"
+        },
+        "params": {
+          "epochs": "1000",
+          "input_mapping": "{\"Input__embedding_1_x\": {\"df\": \"df_train\", \"column\": \"x\"}, \"Input__label_1_y\": {\"df\": \"df_train\", \"column\": \"y\" }}",
+          "model_workspace": "Model definition",
+          "save_as": "model"
+        },
+        "status": "done",
+        "title": "Train model"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 519.0,
+      "id": "Train model 3",
+      "position": {
+        "x": 687.3818749999999,
+        "y": -34.16902777777775
+      },
+      "type": "basic",
+      "width": 640.0
+    },
+    {
+      "data": {
+        "__execution_delay": 0.0,
+        "collapsed": null,
+        "display": null,
+        "error": null,
+        "meta": {
+          "inputs": {
+            "bundle": {
+              "name": "bundle",
+              "position": "left",
+              "type": {
+                "type": "<class 'lynxkite_graph_analytics.core.Bundle'>"
+              }
+            }
+          },
+          "name": "Model inference",
+          "outputs": {
+            "output": {
+              "name": "output",
+              "position": "right",
+              "type": {
+                "type": "None"
+              }
+            }
+          },
+          "params": {
+            "input_mapping": {
+              "default": "",
+              "name": "input_mapping",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            },
+            "model_name": {
+              "default": "model",
+              "name": "model_name",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            },
+            "output_mapping": {
+              "default": "",
+              "name": "output_mapping",
+              "type": {
+                "type": "<class 'str'>"
+              }
+            }
+          },
+          "position": {
+            "x": 506.0,
+            "y": 115.0
+          },
+          "type": "basic"
+        },
+        "params": {
+          "input_mapping": "{\"Input__embedding_1_x\": {\"df\": \"df_test\", \"column\": \"x\"}}",
+          "model_name": "model",
+          "output_mapping": "{\"Activation_2_x\": {\"df\": \"df_test\", \"column\": \"predicted\"}}"
+        },
+        "status": "done",
+        "title": "Model inference"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 429.0,
+      "id": "Model inference 1",
+      "position": {
+        "x": 1445.5664910683593,
+        "y": 12.075943590382515
+      },
+      "type": "basic",
+      "width": 410.0
+    }
+  ]
+}
examples/uploads/plus-one-dataset.parquet ADDED
Binary file (7.54 kB)
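
Note: the workspaces above expect this file to hold a table "df" with float columns "x" and "y"; the name suggests y = x + 1. A sketch that would regenerate such a file (the exact contents are an assumption):

    import pandas as pd

    df = pd.DataFrame({"x": [float(i) for i in range(1000)]})
    df["y"] = df["x"] + 1  # assumed relationship, based on the file name
    df.to_parquet("examples/uploads/plus-one-dataset.parquet")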
 
lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -61,7 +61,7 @@ class Parameter(BaseConfig):
     @staticmethod
     def options(name, options, default=None):
         e = enum.Enum(f"OptionsFor_{name}", options)
-        return Parameter.basic(name, e[default or options[0]], e)
+        return Parameter.basic(name, default or options[0], e)
 
     @staticmethod
     def collapsed(name, default, type=None):
@@ -154,9 +154,7 @@ class Op(BaseConfig):
 
     def __call__(self, *inputs, **params):
         # Convert parameters.
-        for p in params:
-            if p in self.params:
-                params[p] = _param_to_type(p, params[p], self.params[p].type)
+        params = self.convert_params(params)
         res = self.func(*inputs, **params)
         if not isinstance(res, Result):
             # Automatically wrap the result in a Result object, if it isn't already.
@@ -172,6 +170,15 @@
         res.display = res.output
         return res
 
+    def convert_params(self, params):
+        """Returns the parameters converted to the expected type."""
+        res = {}
+        for p in params:
+            res[p] = params[p]
+            if p in self.params:
+                res[p] = _param_to_type(p, params[p], self.params[p].type)
+        return res
+
 
 def op(env: str, name: str, *, view="basic", outputs=None, params=None):
     """Decorator for defining an operation."""
lynxkite-graph-analytics/src/lynxkite_graph_analytics/core.py CHANGED
@@ -42,7 +42,7 @@ class Bundle:
 
     dfs: dict[str, pd.DataFrame] = dataclasses.field(default_factory=dict)
     relations: list[RelationDefinition] = dataclasses.field(default_factory=list)
-    other: dict[str, typing.Any] = None
+    other: dict[str, typing.Any] = dataclasses.field(default_factory=dict)
 
     @classmethod
     def from_nx(cls, graph: nx.Graph):
@@ -102,7 +102,7 @@
         return Bundle(
             dfs=dict(self.dfs),
             relations=list(self.relations),
-            other=dict(self.other) if self.other else None,
+            other=dict(self.other),
         )
 
     def to_dict(self, limit: int = 100):
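
Note: switching "other" to default_factory=dict is what lets the copy above drop its None check. A dataclass cannot use a plain {} default (it would be shared across instances), but a factory gives each Bundle its own dict:

    import dataclasses

    @dataclasses.dataclass
    class Example:
        other: dict = dataclasses.field(default_factory=dict)  # fresh dict per instance

    a, b = Example(), Example()
    a.other["model"] = "trained"
    assert b.other == {}  # nothing leaks between instances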
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py CHANGED
@@ -2,10 +2,14 @@
 
 import enum
 import os
+import pathlib
 import fsspec
 from lynxkite.core import ops
 from collections import deque
-from . import core
+
+from tqdm import tqdm
+from . import core, pytorch_model_ops
+from lynxkite.core import workspace
 import grandcypher
 import joblib
 import matplotlib
@@ -344,10 +348,13 @@ def create_graph(bundle: core.Bundle, *, relations: str = None) -> core.Bundle:
     return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
 
 
-@op("Define model")
-def define_model(*, model_workspace: str, save_as: str = "model"):
-    """Reads a PyTorch model workspace and returns it as a model in a bundle."""
-    return None
+def load_ws(model_workspace: str):
+    cwd = pathlib.Path()
+    path = cwd / model_workspace
+    assert path.is_relative_to(cwd)
+    assert path.exists(), f"Workspace {path} does not exist"
+    ws = workspace.load(path)
+    return ws
 
 
 @op("Biomedical foundation graph (PLACEHOLDER)")
@@ -358,25 +365,54 @@ def biomedical_foundation_graph(*, filter_nodes: str):
 
 @op("Train model")
 def train_model(
-    bundle: core.Bundle, *, model_name: str, model_mapping: str, epochs: int = 1
+    bundle: core.Bundle,
+    *,
+    model_workspace: str,
+    input_mapping: str,
+    epochs: int = 1,
+    save_as: str = "model",
 ):
     """Trains the selected model on the selected dataset. Most training parameters are set in the model definition."""
-    return None
+    ws = load_ws(model_workspace)
+    input_mapping = json.loads(input_mapping)
+    inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
+    m = pytorch_model_ops.build_model(ws, inputs)
+    t = tqdm(range(epochs), desc="Training model")
+    for _ in t:
+        loss = m.train(inputs)
+        t.set_postfix({"loss": loss})
+    bundle = bundle.copy()
+    bundle.other[save_as] = m
+    return bundle
 
 
 @op("Model inference")
 def model_inference(
     bundle: core.Bundle,
     *,
-    model_name: str,
-    model_mapping: str,
-    save_output_as: str = "prediction",
+    model_name: str = "model",
+    input_mapping: str = "",
+    output_mapping: str = "",
 ):
     """Executes a trained model."""
-    return None
+    m = bundle.other[model_name]
+    input_mapping = json.loads(input_mapping)
+    output_mapping = json.loads(output_mapping)
+    inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
+    outputs = m.inference(inputs)
+    bundle = bundle.copy()
+    for k, v in output_mapping.items():
+        bundle.dfs[v["df"]][v["column"]] = outputs[k].detach().numpy().tolist()
+    return bundle
 
 
 @op("Train/test split")
 def train_test_split(bundle: core.Bundle, *, table_name: str, test_ratio: float = 0.1):
     """Splits a dataframe in the bundle into separate "_train" and "_test" dataframes."""
-    return None
+    df = bundle.dfs[table_name]
+    test = df.sample(frac=test_ratio)
+    train = df.drop(test.index)
+    bundle = bundle.copy()
+    bundle.dfs[f"{table_name}_train"] = train
+    bundle.dfs[f"{table_name}_test"] = test
+    return bundle
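
Note: input_mapping and output_mapping are JSON strings keyed by model-side IDs of the form <node id>_<handle>, with ":" and spaces turned into underscores (so node "Input: embedding 1", handle "x" becomes "Input__embedding_1_x"); each entry names the bundle dataframe and column to read or write. For reference, the values used in the "Model use" workspace above:

    import json

    input_mapping = json.dumps({
        "Input__embedding_1_x": {"df": "df_train", "column": "x"},
        "Input__label_1_y": {"df": "df_train", "column": "y"},
    })
    output_mapping = json.dumps({
        "Activation_2_x": {"df": "df_test", "column": "predicted"},
    })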
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -6,6 +6,7 @@ from lynxkite.core.ops import Parameter as P
 import torch
 import torch_geometric as pyg
 from dataclasses import dataclass
+from . import core
 
 ENV = "PyTorch model"
 
@@ -162,11 +163,18 @@ class ModelConfig:
         self.optimizer.step()
         return loss.item()
 
+    def copy(self):
+        """Returns a copy of the model."""
+        c = super().copy()
+        c.model = self.model.copy()
+        return c
+
 
 def build_model(
     ws: workspace.Workspace, inputs: dict[str, torch.Tensor]
 ) -> ModelConfig:
     """Builds the model described in the workspace."""
+    catalog = ops.CATALOGS[ENV]
     optimizers = []
     nodes = {}
     for node in ws.nodes:
@@ -197,7 +205,8 @@
     for node_id in ts.static_order():
         node = nodes[node_id]
         t = node.data.title
-        p = node.data.params
+        op = catalog[t]
+        p = op.convert_params(node.data.params)
        for b in dependencies[node_id]:
            if b in in_loss:
                in_loss.add(node_id)
@@ -216,7 +225,9 @@
                 [(ib, ih)] = edges[node_id, "x"]
                 i = _to_id(ib) + "_" + ih
                 used_inputs.add(i)
-                f = getattr(torch.nn.functional, p["type"].lower().replace(" ", "_"))
+                f = getattr(
+                    torch.nn.functional, p["type"].name.lower().replace(" ", "_")
+                )
                 ls.append((f, f"{i} -> {nid}_x"))
                 sizes[f"{nid}_x"] = sizes[i]
             case "MSE loss":
@@ -248,7 +259,18 @@
         f"loss should have no parameters: {list(cfg['loss'].parameters())}"
     )
     # Create optimizer.
-    p = nodes[optimizer].data.params
-    o = getattr(torch.optim, p["type"])
+    op = catalog["Optimizer"]
+    p = op.convert_params(nodes[optimizer].data.params)
+    o = getattr(torch.optim, p["type"].name)
     cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
     return ModelConfig(**cfg)
+
+
+def to_tensors(b: core.Bundle, m: dict[str, dict]) -> dict[str, torch.Tensor]:
+    """Converts a tensor to the correct type for PyTorch."""
+    tensors = {}
+    for k, v in m.items():
+        tensors[k] = torch.tensor(
+            b.dfs[v["df"]][v["column"]].to_list(), dtype=torch.float32
+        )
+    return tensors
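
Note: to_tensors is the bridge between bundles and models: each mapping entry selects a dataframe column from the bundle and becomes a float32 tensor keyed by the model-side input ID. A small usage sketch, with hypothetical bundle contents, as it would run inside this module:

    import pandas as pd
    import torch

    bundle = core.Bundle(dfs={"df_train": pd.DataFrame({"x": [1.0, 2.0], "y": [2.0, 3.0]})})
    mapping = {"Input__embedding_1_x": {"df": "df_train", "column": "x"}}
    tensors = to_tensors(bundle, mapping)
    assert tensors["Input__embedding_1_x"].dtype == torch.float32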