darabos commited on
Commit
612802c
·
1 Parent(s): 0cdd509

Do not track tensor shapes. Much simpler!

Browse files
examples/Model definition CHANGED
@@ -1,12 +1,5 @@
1
  {
2
  "edges": [
3
- {
4
- "id": "Repeat 1 Linear 2",
5
- "source": "Repeat 1",
6
- "sourceHandle": "output",
7
- "target": "Linear 2",
8
- "targetHandle": "x"
9
- },
10
  {
11
  "id": "Linear 2 Activation 1",
12
  "source": "Linear 2",
@@ -14,17 +7,10 @@
14
  "target": "Activation 1",
15
  "targetHandle": "x"
16
  },
17
- {
18
- "id": "Activation 1 Repeat 1",
19
- "source": "Activation 1",
20
- "sourceHandle": "output",
21
- "target": "Repeat 1",
22
- "targetHandle": "input"
23
- },
24
  {
25
  "id": "Input: tensor 1 Linear 2",
26
  "source": "Input: tensor 1",
27
- "sourceHandle": "output",
28
  "target": "Linear 2",
29
  "targetHandle": "x"
30
  },
@@ -45,9 +31,23 @@
45
  {
46
  "id": "Input: tensor 3 MSE loss 2",
47
  "source": "Input: tensor 3",
48
- "sourceHandle": "output",
49
  "target": "MSE loss 2",
50
  "targetHandle": "y"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  }
52
  ],
53
  "env": "PyTorch model",
@@ -118,66 +118,6 @@
118
  "data": {
119
  "__execution_delay": 0.0,
120
  "collapsed": null,
121
- "display": null,
122
- "error": null,
123
- "input_metadata": null,
124
- "meta": {
125
- "inputs": {
126
- "input": {
127
- "name": "input",
128
- "position": "top",
129
- "type": {
130
- "type": "tensor"
131
- }
132
- }
133
- },
134
- "name": "Repeat",
135
- "outputs": {
136
- "output": {
137
- "name": "output",
138
- "position": "bottom",
139
- "type": {
140
- "type": "tensor"
141
- }
142
- }
143
- },
144
- "params": {
145
- "same_weights": {
146
- "default": false,
147
- "name": "same_weights",
148
- "type": {
149
- "type": "<class 'bool'>"
150
- }
151
- },
152
- "times": {
153
- "default": 1.0,
154
- "name": "times",
155
- "type": {
156
- "type": "<class 'int'>"
157
- }
158
- }
159
- },
160
- "type": "basic"
161
- },
162
- "params": {
163
- "same_weights": false,
164
- "times": "3"
165
- },
166
- "status": "planned",
167
- "title": "Repeat"
168
- },
169
- "dragHandle": ".bg-primary",
170
- "height": 200.0,
171
- "id": "Repeat 1",
172
- "position": {
173
- "x": -180.0,
174
- "y": -90.0
175
- },
176
- "type": "basic",
177
- "width": 200.0
178
- },
179
- {
180
- "data": {
181
  "display": null,
182
  "error": null,
183
  "input_metadata": null,
@@ -203,17 +143,17 @@
203
  },
204
  "params": {
205
  "output_dim": {
206
- "default": "same",
207
  "name": "output_dim",
208
  "type": {
209
- "type": "<class 'str'>"
210
  }
211
  }
212
  },
213
  "type": "basic"
214
  },
215
  "params": {
216
- "output_dim": "same"
217
  },
218
  "status": "planned",
219
  "title": "Linear"
 
1
  {
2
  "edges": [
 
 
 
 
 
 
 
3
  {
4
  "id": "Linear 2 Activation 1",
5
  "source": "Linear 2",
 
7
  "target": "Activation 1",
8
  "targetHandle": "x"
9
  },
 
 
 
 
 
 
 
10
  {
11
  "id": "Input: tensor 1 Linear 2",
12
  "source": "Input: tensor 1",
13
+ "sourceHandle": "x",
14
  "target": "Linear 2",
15
  "targetHandle": "x"
16
  },
 
31
  {
32
  "id": "Input: tensor 3 MSE loss 2",
33
  "source": "Input: tensor 3",
34
+ "sourceHandle": "x",
35
  "target": "MSE loss 2",
36
  "targetHandle": "y"
37
+ },
38
+ {
39
+ "id": "Activation 1 Repeat 1",
40
+ "source": "Activation 1",
41
+ "sourceHandle": "output",
42
+ "target": "Repeat 1",
43
+ "targetHandle": "input"
44
+ },
45
+ {
46
+ "id": "Repeat 1 Linear 2",
47
+ "source": "Repeat 1",
48
+ "sourceHandle": "output",
49
+ "target": "Linear 2",
50
+ "targetHandle": "x"
51
  }
52
  ],
53
  "env": "PyTorch model",
 
118
  "data": {
119
  "__execution_delay": 0.0,
120
  "collapsed": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  "display": null,
122
  "error": null,
123
  "input_metadata": null,
 
143
  },
144
  "params": {
145
  "output_dim": {
146
+ "default": "",
147
  "name": "output_dim",
148
  "type": {
149
+ "type": "<class 'int'>"
150
  }
151
  }
152
  },
153
  "type": "basic"
154
  },
155
  "params": {
156
+ "output_dim": "4"
157
  },
158
  "status": "planned",
159
  "title": "Linear"
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py CHANGED
@@ -347,7 +347,7 @@ def define_model(
347
  assert model_workspace, "Model workspace is unset."
348
  ws = load_ws(model_workspace)
349
  # Build the model without inputs, to get its interface.
350
- m = pytorch_model_ops.build_model(ws, {})
351
  m.source_workspace = model_workspace
352
  bundle = bundle.copy()
353
  bundle.other[save_as] = m
@@ -379,10 +379,6 @@ def train_model(
379
  """Trains the selected model on the selected dataset. Most training parameters are set in the model definition."""
380
  m = bundle.other[model_name].copy()
381
  inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
382
- if not m.trained and m.source_workspace:
383
- # Rebuild the model for the correct inputs.
384
- ws = load_ws(m.source_workspace)
385
- m = pytorch_model_ops.build_model(ws, inputs)
386
  t = tqdm(range(epochs), desc="Training model")
387
  for _ in t:
388
  loss = m.train(inputs)
 
347
  assert model_workspace, "Model workspace is unset."
348
  ws = load_ws(model_workspace)
349
  # Build the model without inputs, to get its interface.
350
+ m = pytorch_model_ops.build_model(ws)
351
  m.source_workspace = model_workspace
352
  bundle = bundle.copy()
353
  bundle.other[save_as] = m
 
379
  """Trains the selected model on the selected dataset. Most training parameters are set in the model definition."""
380
  m = bundle.other[model_name].copy()
381
  inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
 
 
 
 
382
  t = tqdm(range(epochs), desc="Training model")
383
  for _ in t:
384
  loss = m.train(inputs)
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -8,7 +8,7 @@ import pydantic
8
  from lynxkite.core import ops, workspace
9
  from lynxkite.core.ops import Parameter as P
10
  import torch
11
- import torch_geometric as pyg
12
  import dataclasses
13
  from . import core
14
 
@@ -78,12 +78,8 @@ reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
78
 
79
 
80
  @op("Linear")
81
- def linear(x, *, output_dim="same"):
82
- if output_dim == "same":
83
- oshape = x.shape
84
- else:
85
- oshape = tuple(*x.shape[:-1], int(output_dim))
86
- return Layer(torch.nn.Linear(x.shape, oshape), shape=oshape)
87
 
88
 
89
  class ActivationTypes(enum.Enum):
@@ -95,13 +91,12 @@ class ActivationTypes(enum.Enum):
95
 
96
  @op("Activation")
97
  def activation(x, *, type: ActivationTypes = ActivationTypes.ReLU):
98
- f = getattr(torch.nn.functional, type.name.lower().replace(" ", "_"))
99
- return Layer(f, shape=x.shape)
100
 
101
 
102
  @op("MSE loss")
103
  def mse_loss(x, y):
104
- return Layer(torch.nn.functional.mse_loss, shape=[1])
105
 
106
 
107
  reg("Softmax", inputs=["x"])
@@ -163,34 +158,18 @@ def _to_id(*strings: str) -> str:
163
  return "_".join("".join(c if c.isalnum() else "_" for c in s) for s in strings)
164
 
165
 
166
- @dataclasses.dataclass
167
- class TensorRef:
168
- """Ops get their inputs like this. They have to return a Layer made for this input."""
169
-
170
- _id: str
171
- shape: tuple[int, ...]
172
-
173
-
174
  @dataclasses.dataclass
175
  class Layer:
176
- """Return this from an op. Must include a module and the shapes of the outputs."""
177
 
178
  module: torch.nn.Module
179
- shapes: list[tuple[int, ...]] | None = None # One for each output.
180
- shape: dataclasses.InitVar[tuple[int, ...] | None] = None # Convenience for single output.
181
- # Set by ModelBuilder.
182
- _origin_id: str | None = None
183
- _inputs: list[TensorRef] | None = None
184
- _outputs: list[TensorRef] | None = None
185
-
186
- def __post_init__(self, shape):
187
- assert not self.shapes or not shape, "Cannot set both shapes and shape."
188
- if shape:
189
- self.shapes = [shape]
190
-
191
- def _for_sequential(self):
192
- inputs = ", ".join(i._id for i in self._inputs)
193
- outputs = ", ".join(o._id for o in self._outputs)
194
  return self.module, f"{inputs} -> {outputs}"
195
 
196
 
@@ -261,16 +240,16 @@ class ModelConfig:
261
  }
262
 
263
 
264
- def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> ModelConfig:
265
  """Builds the model described in the workspace."""
266
- builder = ModelBuilder(ws, inputs)
267
  return builder.build_model()
268
 
269
 
270
  class ModelBuilder:
271
  """The state shared between methods that are used to build the model."""
272
 
273
- def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
274
  self.catalog = ops.CATALOGS[ENV]
275
  optimizers = []
276
  self.nodes: dict[str, workspace.WorkspaceNode] = {}
@@ -287,8 +266,8 @@ class ModelBuilder:
287
  assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
288
  [self.optimizer] = optimizers
289
  self.dependencies = {n: [] for n in self.nodes}
290
- self.in_edges: dict[str, dict[str, list[(str, str)]]] = {n: {} for n in self.nodes}
291
- self.out_edges: dict[str, dict[str, list[(str, str)]]] = {n: {} for n in self.nodes}
292
  for e in ws.edges:
293
  self.dependencies[e.target].append(e.source)
294
  self.in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
@@ -342,10 +321,7 @@ class ModelBuilder:
342
  for k, v in self.dependencies.items():
343
  for i in v:
344
  self.inv_dependencies[i].append(k)
345
- self.sizes = {}
346
- for k, i in inputs.items():
347
- self.sizes[k] = i.shape[-1]
348
- self.layers = []
349
  # Clean up disconnected nodes.
350
  disconnected = set()
351
  for node_id in self.nodes:
@@ -396,13 +372,14 @@ class ModelBuilder:
396
  assert affected_nodes == repeated_nodes, (
397
  f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
398
  )
399
- repeated_layers = [e for e in self.layers if e._origin_id in repeated_nodes]
400
  assert p["times"] >= 1, f"Cannot repeat {repeat_id} {p['times']} times."
401
  for i in range(p["times"] - 1):
402
  # Copy repeat section's output to repeat section's input.
403
  self.layers.append(
404
- self.empty_layer(
405
- node_id,
 
406
  inputs=[_to_id(*last_output)],
407
  outputs=[_to_id(start_id, "output")],
408
  )
@@ -410,17 +387,9 @@ class ModelBuilder:
410
  # Repeat the layers in the section.
411
  for layer in repeated_layers:
412
  if p["same_weights"]:
413
- self.layers.append(
414
- Layer(
415
- layer.module,
416
- shapes=layer.shapes,
417
- _origin_id=layer._origin_id,
418
- _inputs=layer._inputs,
419
- _outputs=layer._outputs,
420
- )
421
- )
422
  else:
423
- self.run_node(layer._origin_id)
424
  self.layers.append(self.run_op(node_id, op, p))
425
  case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
426
  return
@@ -431,31 +400,11 @@ class ModelBuilder:
431
  """Returns the layer produced by this op."""
432
  inputs = [_to_id(*i) for n in op.inputs for i in self.in_edges[node_id][n]]
433
  outputs = [_to_id(node_id, n) for n in op.outputs]
434
- layer = self.empty_layer(node_id, inputs, outputs)
435
- if op.func != ops.no_op:
436
- op_layer = op.func(*layer._inputs, **params)
437
- layer.module = op_layer.module
438
- layer.shapes = op_layer.shapes
439
- for o in layer._outputs:
440
- self.sizes[o._id] = o.shape
441
- return layer
442
-
443
- def empty_layer(self, id: str, inputs: list[str], outputs: list[str]) -> Layer:
444
- """Creates an identity layer. Assumes that outputs have the same shapes as inputs."""
445
- layer_inputs = [TensorRef(i, shape=self.sizes.get(i, 1)) for i in inputs]
446
- layer_outputs = []
447
- for i, o in zip(inputs, outputs):
448
- shape = self.sizes.get(i, 1)
449
- layer_outputs.append(TensorRef(o, shape=shape))
450
- self.sizes[o] = shape
451
- layer = Layer(
452
- torch.nn.Identity(),
453
- shapes=[self.sizes[o._id] for o in layer_outputs],
454
- _inputs=layer_inputs,
455
- _outputs=layer_outputs,
456
- _origin_id=id,
457
- )
458
- return layer
459
 
460
  def build_model(self) -> ModelConfig:
461
  # Walk the graph in topological order.
@@ -474,16 +423,16 @@ class ModelBuilder:
474
  layers = []
475
  loss_layers = []
476
  for layer in self.layers:
477
- if layer._origin_id in loss_nodes:
478
  loss_layers.append(layer)
479
  else:
480
  layers.append(layer)
481
- used_in_model = set(input._id for layer in layers for input in layer._inputs)
482
- used_in_loss = set(input._id for layer in loss_layers for input in layer._inputs)
483
- made_in_model = set(output._id for layer in layers for output in layer._outputs)
484
- made_in_loss = set(output._id for layer in loss_layers for output in layer._outputs)
485
- layers = [layer._for_sequential() for layer in layers]
486
- loss_layers = [layer._for_sequential() for layer in loss_layers]
487
  cfg = {}
488
  cfg["model_inputs"] = list(used_in_model - made_in_model)
489
  cfg["model_outputs"] = list(made_in_model & used_in_loss)
@@ -492,13 +441,13 @@ class ModelBuilder:
492
  outputs = ", ".join(cfg["model_outputs"])
493
  layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
494
  # Create model.
495
- cfg["model"] = pyg.nn.Sequential(", ".join(cfg["model_inputs"]), layers)
496
  # Make sure the loss is output from the last loss layer.
497
  [(lossb, lossh)] = self.in_edges[self.optimizer]["loss"]
498
  lossi = _to_id(lossb, lossh)
499
  loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
500
  # Create loss function.
501
- cfg["loss"] = pyg.nn.Sequential(", ".join(cfg["loss_inputs"]), loss_layers)
502
  assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
503
  # Create optimizer.
504
  op = self.catalog["Optimizer"]
 
8
  from lynxkite.core import ops, workspace
9
  from lynxkite.core.ops import Parameter as P
10
  import torch
11
+ import torch_geometric.nn as pyg_nn
12
  import dataclasses
13
  from . import core
14
 
 
78
 
79
 
80
  @op("Linear")
81
+ def linear(x, *, output_dim=1024):
82
+ return pyg_nn.Linear(-1, output_dim)
 
 
 
 
83
 
84
 
85
  class ActivationTypes(enum.Enum):
 
91
 
92
  @op("Activation")
93
  def activation(x, *, type: ActivationTypes = ActivationTypes.ReLU):
94
+ return getattr(torch.nn.functional, type.name.lower().replace(" ", "_"))
 
95
 
96
 
97
  @op("MSE loss")
98
  def mse_loss(x, y):
99
+ return torch.nn.functional.mse_loss
100
 
101
 
102
  reg("Softmax", inputs=["x"])
 
158
  return "_".join("".join(c if c.isalnum() else "_" for c in s) for s in strings)
159
 
160
 
 
 
 
 
 
 
 
 
161
  @dataclasses.dataclass
162
  class Layer:
163
+ """Temporary data structure used by ModelBuilder."""
164
 
165
  module: torch.nn.Module
166
+ origin_id: str
167
+ inputs: list[str]
168
+ outputs: list[str]
169
+
170
+ def for_sequential(self):
171
+ inputs = ", ".join(self.inputs)
172
+ outputs = ", ".join(self.outputs)
 
 
 
 
 
 
 
 
173
  return self.module, f"{inputs} -> {outputs}"
174
 
175
 
 
240
  }
241
 
242
 
243
+ def build_model(ws: workspace.Workspace) -> ModelConfig:
244
  """Builds the model described in the workspace."""
245
+ builder = ModelBuilder(ws)
246
  return builder.build_model()
247
 
248
 
249
  class ModelBuilder:
250
  """The state shared between methods that are used to build the model."""
251
 
252
+ def __init__(self, ws: workspace.Workspace):
253
  self.catalog = ops.CATALOGS[ENV]
254
  optimizers = []
255
  self.nodes: dict[str, workspace.WorkspaceNode] = {}
 
266
  assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
267
  [self.optimizer] = optimizers
268
  self.dependencies = {n: [] for n in self.nodes}
269
+ self.in_edges: dict[str, dict[str, list[tuple[str, str]]]] = {n: {} for n in self.nodes}
270
+ self.out_edges: dict[str, dict[str, list[tuple[str, str]]]] = {n: {} for n in self.nodes}
271
  for e in ws.edges:
272
  self.dependencies[e.target].append(e.source)
273
  self.in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
 
321
  for k, v in self.dependencies.items():
322
  for i in v:
323
  self.inv_dependencies[i].append(k)
324
+ self.layers: list[Layer] = []
 
 
 
325
  # Clean up disconnected nodes.
326
  disconnected = set()
327
  for node_id in self.nodes:
 
372
  assert affected_nodes == repeated_nodes, (
373
  f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
374
  )
375
+ repeated_layers = [e for e in self.layers if e.origin_id in repeated_nodes]
376
  assert p["times"] >= 1, f"Cannot repeat {repeat_id} {p['times']} times."
377
  for i in range(p["times"] - 1):
378
  # Copy repeat section's output to repeat section's input.
379
  self.layers.append(
380
+ Layer(
381
+ torch.nn.Identity(),
382
+ origin_id=node_id,
383
  inputs=[_to_id(*last_output)],
384
  outputs=[_to_id(start_id, "output")],
385
  )
 
387
  # Repeat the layers in the section.
388
  for layer in repeated_layers:
389
  if p["same_weights"]:
390
+ self.layers.append(layer)
 
 
 
 
 
 
 
 
391
  else:
392
+ self.run_node(layer.origin_id)
393
  self.layers.append(self.run_op(node_id, op, p))
394
  case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
395
  return
 
400
  """Returns the layer produced by this op."""
401
  inputs = [_to_id(*i) for n in op.inputs for i in self.in_edges[node_id][n]]
402
  outputs = [_to_id(node_id, n) for n in op.outputs]
403
+ if op.func == ops.no_op:
404
+ module = torch.nn.Identity()
405
+ else:
406
+ module = op.func(*inputs, **params)
407
+ return Layer(module, node_id, inputs, outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
 
409
  def build_model(self) -> ModelConfig:
410
  # Walk the graph in topological order.
 
423
  layers = []
424
  loss_layers = []
425
  for layer in self.layers:
426
+ if layer.origin_id in loss_nodes:
427
  loss_layers.append(layer)
428
  else:
429
  layers.append(layer)
430
+ used_in_model = set(input for layer in layers for input in layer.inputs)
431
+ used_in_loss = set(input for layer in loss_layers for input in layer.inputs)
432
+ made_in_model = set(output for layer in layers for output in layer.outputs)
433
+ made_in_loss = set(output for layer in loss_layers for output in layer.outputs)
434
+ layers = [layer.for_sequential() for layer in layers]
435
+ loss_layers = [layer.for_sequential() for layer in loss_layers]
436
  cfg = {}
437
  cfg["model_inputs"] = list(used_in_model - made_in_model)
438
  cfg["model_outputs"] = list(made_in_model & used_in_loss)
 
441
  outputs = ", ".join(cfg["model_outputs"])
442
  layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
443
  # Create model.
444
+ cfg["model"] = pyg_nn.Sequential(", ".join(cfg["model_inputs"]), layers)
445
  # Make sure the loss is output from the last loss layer.
446
  [(lossb, lossh)] = self.in_edges[self.optimizer]["loss"]
447
  lossi = _to_id(lossb, lossh)
448
  loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
449
  # Create loss function.
450
+ cfg["loss"] = pyg_nn.Sequential(", ".join(cfg["loss_inputs"]), loss_layers)
451
  assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
452
  # Create optimizer.
453
  op = self.catalog["Optimizer"]
lynxkite-graph-analytics/tests/test_pytorch_model_ops.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  import pytest
5
 
6
 
7
- def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str, str, str]]):
8
  ws = workspace.Workspace(env=env)
9
  for id, data in nodes.items():
10
  title = data["title"]
@@ -49,7 +49,7 @@ async def test_build_model():
49
  pytorch_model_ops.ENV,
50
  {
51
  "emb": {"title": "Input: tensor"},
52
- "lin": {"title": "Linear", "output_dim": "same"},
53
  "act": {"title": "Activation", "type": "Leaky_ReLU"},
54
  "label": {"title": "Input: tensor"},
55
  "loss": {"title": "MSE loss"},
@@ -65,7 +65,7 @@ async def test_build_model():
65
  )
66
  x = torch.rand(100, 4)
67
  y = x + 1
68
- m = pytorch_model_ops.build_model(ws, {"emb_output": x, "label_output": y})
69
  for i in range(1000):
70
  loss = m.train({"emb_output": x, "label_output": y})
71
  assert loss < 0.1
@@ -80,7 +80,7 @@ async def test_build_model_with_repeat():
80
  pytorch_model_ops.ENV,
81
  {
82
  "emb": {"title": "Input: tensor"},
83
- "lin": {"title": "Linear", "output_dim": "same"},
84
  "act": {"title": "Activation", "type": "Leaky_ReLU"},
85
  "label": {"title": "Input: tensor"},
86
  "loss": {"title": "MSE loss"},
@@ -99,17 +99,17 @@ async def test_build_model_with_repeat():
99
  )
100
 
101
  # 1 repetition
102
- m = pytorch_model_ops.build_model(repeated_ws(1), {})
103
  assert summarize_layers(m) == "IL<II"
104
  assert summarize_connections(m) == "e->S S->l l->a a->E E->E"
105
 
106
  # 2 repetitions
107
- m = pytorch_model_ops.build_model(repeated_ws(2), {})
108
  assert summarize_layers(m) == "IL<IL<II"
109
  assert summarize_connections(m) == "e->S S->l l->a a->S S->l l->a a->E E->E"
110
 
111
  # 3 repetitions
112
- m = pytorch_model_ops.build_model(repeated_ws(3), {})
113
  assert summarize_layers(m) == "IL<IL<IL<II"
114
  assert summarize_connections(m) == "e->S S->l l->a a->S S->l l->a a->S S->l l->a a->E E->E"
115
 
 
4
  import pytest
5
 
6
 
7
+ def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str]]):
8
  ws = workspace.Workspace(env=env)
9
  for id, data in nodes.items():
10
  title = data["title"]
 
49
  pytorch_model_ops.ENV,
50
  {
51
  "emb": {"title": "Input: tensor"},
52
+ "lin": {"title": "Linear", "output_dim": 4},
53
  "act": {"title": "Activation", "type": "Leaky_ReLU"},
54
  "label": {"title": "Input: tensor"},
55
  "loss": {"title": "MSE loss"},
 
65
  )
66
  x = torch.rand(100, 4)
67
  y = x + 1
68
+ m = pytorch_model_ops.build_model(ws)
69
  for i in range(1000):
70
  loss = m.train({"emb_output": x, "label_output": y})
71
  assert loss < 0.1
 
80
  pytorch_model_ops.ENV,
81
  {
82
  "emb": {"title": "Input: tensor"},
83
+ "lin": {"title": "Linear", "output_dim": 8},
84
  "act": {"title": "Activation", "type": "Leaky_ReLU"},
85
  "label": {"title": "Input: tensor"},
86
  "loss": {"title": "MSE loss"},
 
99
  )
100
 
101
  # 1 repetition
102
+ m = pytorch_model_ops.build_model(repeated_ws(1))
103
  assert summarize_layers(m) == "IL<II"
104
  assert summarize_connections(m) == "e->S S->l l->a a->E E->E"
105
 
106
  # 2 repetitions
107
+ m = pytorch_model_ops.build_model(repeated_ws(2))
108
  assert summarize_layers(m) == "IL<IL<II"
109
  assert summarize_connections(m) == "e->S S->l l->a a->S S->l l->a a->E E->E"
110
 
111
  # 3 repetitions
112
+ m = pytorch_model_ops.build_model(repeated_ws(3))
113
  assert summarize_layers(m) == "IL<IL<IL<II"
114
  assert summarize_connections(m) == "e->S S->l l->a a->S S->l l->a a->S S->l l->a a->E E->E"
115