darabos commited on
Commit
010d5ee
·
2 Parent(s): 48da2eb ab6c18c

Merge pull request #184 from biggraph/darabos-ode-gnn-v2

Browse files
lynxkite-app/web/src/workspace/nodes/LynxKiteNode.tsx CHANGED
@@ -54,6 +54,7 @@ function getHandles(inputs: any[], outputs: any[]) {
54
  }
55
 
56
  const OP_COLORS: { [key: string]: string } = {
 
57
  pink: "oklch(75% 0.2 0)",
58
  orange: "oklch(75% 0.2 55)",
59
  green: "oklch(75% 0.2 150)",
 
54
  }
55
 
56
  const OP_COLORS: { [key: string]: string } = {
57
+ gray: "oklch(95% 0 0)",
58
  pink: "oklch(75% 0.2 0)",
59
  orange: "oklch(75% 0.2 55)",
60
  green: "oklch(75% 0.2 150)",
lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -240,7 +240,7 @@ def op(
240
  """Decorator for defining an operation."""
241
 
242
  def decorator(func):
243
- doc = get_doc(func)
244
  sig = inspect.signature(func)
245
  _view = view
246
  if view == "matplotlib":
@@ -436,16 +436,17 @@ def run_user_script(script_path: pathlib.Path):
436
  spec.loader.exec_module(module)
437
 
438
 
439
- def get_doc(func):
 
440
  """Griffe is an optional dependency. When available, we returned the parsed docstring."""
441
  try:
442
  import griffe
443
  except ImportError:
444
- return func.__doc__
445
- if func.__doc__ is None:
446
  return None
447
- if "----" in func.__doc__:
448
- doc = griffe.Docstring(func.__doc__).parse("numpy")
449
  else:
450
- doc = griffe.Docstring(func.__doc__).parse("google")
451
  return json.loads(json.dumps(doc, cls=griffe.JSONEncoder))
 
240
  """Decorator for defining an operation."""
241
 
242
  def decorator(func):
243
+ doc = parse_doc(func.__doc__)
244
  sig = inspect.signature(func)
245
  _view = view
246
  if view == "matplotlib":
 
436
  spec.loader.exec_module(module)
437
 
438
 
439
+ @functools.cache
440
+ def parse_doc(doc):
441
  """Griffe is an optional dependency. When available, we returned the parsed docstring."""
442
  try:
443
  import griffe
444
  except ImportError:
445
+ return doc
446
+ if doc is None:
447
  return None
448
+ if "----" in doc:
449
+ doc = griffe.Docstring(doc).parse("numpy")
450
  else:
451
+ doc = griffe.Docstring(doc).parse("google")
452
  return json.loads(json.dumps(doc, cls=griffe.JSONEncoder))
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/pytorch_ops.py CHANGED
@@ -6,24 +6,23 @@ from lynxkite.core.ops import Parameter as P
6
  import torch
7
  from .pytorch_core import op, reg, ENV
8
 
9
- reg("Input: tensor", outputs=["output"], params=[P.basic("name")])
10
- reg("Input: graph edges", outputs=["edges"])
11
- reg("Input: sequential", outputs=["y"], params=[P.basic("name")])
12
- reg("Output", inputs=["x"], outputs=["x"], params=[P.basic("name")])
13
 
14
 
15
  @op("LSTM", weights=True)
16
  def lstm(x, *, input_size=1024, hidden_size=1024, dropout=0.0):
17
- return torch.nn.LSTM(input_size, hidden_size, dropout=0.0)
18
 
19
 
20
  reg(
21
- "Neural ODE",
22
  color="blue",
23
- inputs=["x"],
 
24
  params=[
25
- P.basic("relative_tolerance"),
26
- P.basic("absolute_tolerance"),
27
  P.options(
28
  "method",
29
  [
@@ -39,6 +38,11 @@ reg(
39
  "implicit_adams",
40
  ],
41
  ),
 
 
 
 
 
42
  ],
43
  )
44
 
@@ -66,6 +70,13 @@ def linear(x, *, output_dim=1024):
66
  return pyg_nn.Linear(-1, output_dim)
67
 
68
 
 
 
 
 
 
 
 
69
  class ActivationTypes(str, enum.Enum):
70
  ReLU = "ReLU"
71
  Leaky_ReLU = "Leaky ReLU"
@@ -93,11 +104,39 @@ def softmax(x, *, dim=1):
93
  return torch.nn.Softmax(dim=dim)
94
 
95
 
 
 
 
 
 
96
  @op("Concatenate")
97
  def concatenate(a, b):
98
  return lambda a, b: torch.concatenate(*torch.broadcast_tensors(a, b))
99
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  reg(
102
  "Graph conv",
103
  color="blue",
@@ -105,6 +144,15 @@ reg(
105
  outputs=["x"],
106
  params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])],
107
  )
 
 
 
 
 
 
 
 
 
108
 
109
  reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
110
  reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
@@ -125,7 +173,7 @@ reg(
125
  "Galore AdamW",
126
  ],
127
  ),
128
- P.basic("lr", 0.001),
129
  ],
130
  color="green",
131
  )
 
6
  import torch
7
  from .pytorch_core import op, reg, ENV
8
 
9
+ reg("Input: tensor", outputs=["output"], params=[P.basic("name")], color="gray")
10
+ reg("Input: graph edges", outputs=["edges"], params=[P.basic("name")], color="gray")
11
+ reg("Input: sequential", outputs=["y"], params=[P.basic("name")], color="gray")
12
+ reg("Output", inputs=["x"], outputs=["x"], params=[P.basic("name")], color="gray")
13
 
14
 
15
  @op("LSTM", weights=True)
16
  def lstm(x, *, input_size=1024, hidden_size=1024, dropout=0.0):
17
+ return torch.nn.LSTM(input_size, hidden_size, dropout=dropout)
18
 
19
 
20
  reg(
21
+ "Neural ODE with MLP",
22
  color="blue",
23
+ inputs=["x", "y0", "t"],
24
+ outputs=["y"],
25
  params=[
 
 
26
  P.options(
27
  "method",
28
  [
 
38
  "implicit_adams",
39
  ],
40
  ),
41
+ P.basic("relative_tolerance"),
42
+ P.basic("absolute_tolerance"),
43
+ P.basic("mlp_layers"),
44
+ P.basic("mlp_hidden_size"),
45
+ P.options("mlp_activation", ["ReLU", "Tanh", "Sigmoid"]),
46
  ],
47
  )
48
 
 
70
  return pyg_nn.Linear(-1, output_dim)
71
 
72
 
73
+ @op("Mean pool")
74
+ def mean_pool(x):
75
+ import torch_geometric.nn as pyg_nn
76
+
77
+ return pyg_nn.global_mean_pool
78
+
79
+
80
  class ActivationTypes(str, enum.Enum):
81
  ReLU = "ReLU"
82
  Leaky_ReLU = "Leaky ReLU"
 
104
  return torch.nn.Softmax(dim=dim)
105
 
106
 
107
+ @op("Embedding", weights=True)
108
+ def embedding(x, *, num_embeddings: int, embedding_dim: int):
109
+ return torch.nn.Embedding(num_embeddings, embedding_dim)
110
+
111
+
112
  @op("Concatenate")
113
  def concatenate(a, b):
114
  return lambda a, b: torch.concatenate(*torch.broadcast_tensors(a, b))
115
 
116
 
117
+ reg(
118
+ "Pick element by index",
119
+ inputs=["x", "index"],
120
+ outputs=["x_i"],
121
+ )
122
+ reg(
123
+ "Pick element by constant",
124
+ inputs=["x"],
125
+ outputs=["x_i"],
126
+ params=[ops.Parameter.basic("index", "0")],
127
+ )
128
+ reg(
129
+ "Take first n",
130
+ inputs=["x"],
131
+ outputs=["x"],
132
+ params=[ops.Parameter.basic("n", 1, int)],
133
+ )
134
+ reg(
135
+ "Drop first n",
136
+ inputs=["x"],
137
+ outputs=["x"],
138
+ params=[ops.Parameter.basic("n", 1, int)],
139
+ )
140
  reg(
141
  "Graph conv",
142
  color="blue",
 
144
  outputs=["x"],
145
  params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])],
146
  )
147
+ reg(
148
+ "Heterogeneous graph conv",
149
+ inputs=["node_embeddings", "edge_modules"],
150
+ outputs=["x"],
151
+ params=[
152
+ ops.Parameter.basic("node_embeddings_order"),
153
+ ops.Parameter.basic("edge_modules_order"),
154
+ ],
155
+ )
156
 
157
  reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
158
  reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
 
173
  "Galore AdamW",
174
  ],
175
  ),
176
+ P.basic("lr", 0.0001),
177
  ],
178
  color="green",
179
  )