"""Boxes for defining PyTorch models.""" import graphlib from lynxkite.core import ops, workspace from lynxkite.core.ops import Parameter as P import torch import torch_geometric as pyg from dataclasses import dataclass ENV = "PyTorch model" def reg(name, inputs=[], outputs=None, params=[]): if outputs is None: outputs = inputs return ops.register_passive_op( ENV, name, inputs=[ ops.Input(name=name, position="bottom", type="tensor") for name in inputs ], outputs=[ ops.Output(name=name, position="top", type="tensor") for name in outputs ], params=params, ) reg("Input: embedding", outputs=["x"]) reg("Input: graph edges", outputs=["edges"]) reg("Input: label", outputs=["y"]) reg("Input: positive sample", outputs=["x_pos"]) reg("Input: negative sample", outputs=["x_neg"]) reg("Input: sequential", outputs=["y"]) reg("Input: zeros", outputs=["x"]) reg("LSTM", inputs=["x", "h"], outputs=["x", "h"]) reg( "Neural ODE", inputs=["x"], params=[ P.basic("relative_tolerance"), P.basic("absolute_tolerance"), P.options( "method", [ "dopri8", "dopri5", "bosh3", "fehlberg2", "adaptive_heun", "euler", "midpoint", "rk4", "explicit_adams", "implicit_adams", ], ), ], ) reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"]) reg("LayerNorm", inputs=["x"]) reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)]) reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")]) reg("Softmax", inputs=["x"]) reg( "Graph conv", inputs=["x", "edges"], outputs=["x"], params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])], ) reg( "Activation", inputs=["x"], params=[P.options("type", ["ReLU", "Leaky ReLU", "Tanh", "Mish"])], ) reg("Concatenate", inputs=["a", "b"], outputs=["x"]) reg("Add", inputs=["a", "b"], outputs=["x"]) reg("Subtract", inputs=["a", "b"], outputs=["x"]) reg("Multiply", inputs=["a", "b"], outputs=["x"]) reg("MSE loss", inputs=["x", "y"], outputs=["loss"]) reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"]) reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"]) reg( "Optimizer", inputs=["loss"], outputs=[], params=[ P.options( "type", [ "AdamW", "Adafactor", "Adagrad", "SGD", "Lion", "Paged AdamW", "Galore AdamW", ], ), P.basic("lr", 0.001), ], ) ops.register_passive_op( ENV, "Repeat", inputs=[ops.Input(name="input", position="top", type="tensor")], outputs=[ops.Output(name="output", position="bottom", type="tensor")], params=[ ops.Parameter.basic("times", 1, int), ops.Parameter.basic("same_weights", True, bool), ], ) ops.register_passive_op( ENV, "Recurrent chain", inputs=[ops.Input(name="input", position="top", type="tensor")], outputs=[ops.Output(name="output", position="bottom", type="tensor")], params=[], ) def _to_id(s: str) -> str: """Replaces all non-alphanumeric characters with underscores.""" return "".join(c if c.isalnum() else "_" for c in s) @dataclass class ModelConfig: model: torch.nn.Module model_inputs: list[str] model_outputs: list[str] loss_inputs: list[str] loss: torch.nn.Module optimizer: torch.optim.Optimizer def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: model_inputs = [inputs[i] for i in self.model_inputs] output = self.model(*model_inputs) if not isinstance(output, tuple): output = (output,) values = {k: v for k, v in zip(self.model_outputs, output)} return values def inference(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: # TODO: Do multiple batches. 
        self.model.eval()
        return self._forward(inputs)

    def train(self, inputs: dict[str, torch.Tensor]) -> float:
        """Train the model for one epoch. Returns the loss."""
        # TODO: Do multiple batches.
        self.model.train()
        self.optimizer.zero_grad()
        values = self._forward(inputs)
        values.update(inputs)
        loss_inputs = [values[i] for i in self.loss_inputs]
        loss = self.loss(*loss_inputs)
        loss.backward()
        self.optimizer.step()
        return loss.item()


def build_model(
    ws: workspace.Workspace, inputs: dict[str, torch.Tensor]
) -> ModelConfig:
    """Builds the model described in the workspace."""
    optimizers = []
    nodes = {}
    for node in ws.nodes:
        nodes[node.id] = node
        if node.data.title == "Optimizer":
            optimizers.append(node.id)
    assert optimizers, "No optimizer found."
    assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
    [optimizer] = optimizers
    dependencies = {n.id: [] for n in ws.nodes}
    edges = {}
    # TODO: Dissolve repeat boxes here.
    for e in ws.edges:
        dependencies[e.target].append(e.source)
        edges.setdefault((e.target, e.targetHandle), []).append(
            (e.source, e.sourceHandle)
        )
    sizes = {}
    for k, i in inputs.items():
        sizes[k] = i.shape[-1]
    ts = graphlib.TopologicalSorter(dependencies)
    layers = []
    loss_layers = []
    in_loss = set()
    cfg = {}
    loss_inputs = set()
    used_inputs = set()
    for node_id in ts.static_order():
        node = nodes[node_id]
        t = node.data.title
        p = node.data.params
        for b in dependencies[node_id]:
            if b in in_loss:
                in_loss.add(node_id)
        ls = loss_layers if node_id in in_loss else layers
        nid = _to_id(node_id)
        match t:
            case "Linear":
                [(ib, ih)] = edges[node_id, "x"]
                i = _to_id(ib) + "_" + ih
                used_inputs.add(i)
                isize = sizes[i]
                osize = isize if p["output_dim"] == "same" else int(p["output_dim"])
                ls.append((torch.nn.Linear(isize, osize), f"{i} -> {nid}_x"))
                sizes[f"{nid}_x"] = osize
            case "Activation":
                [(ib, ih)] = edges[node_id, "x"]
                i = _to_id(ib) + "_" + ih
                used_inputs.add(i)
                f = getattr(torch.nn.functional, p["type"].lower().replace(" ", "_"))
                ls.append((f, f"{i} -> {nid}_x"))
                sizes[f"{nid}_x"] = sizes[i]
            case "MSE loss":
                [(xb, xh)] = edges[node_id, "x"]
                xi = _to_id(xb) + "_" + xh
                [(yb, yh)] = edges[node_id, "y"]
                yi = _to_id(yb) + "_" + yh
                loss_inputs.add(xi)
                loss_inputs.add(yi)
                in_loss.add(node_id)
                loss_layers.append(
                    (torch.nn.functional.mse_loss, f"{xi}, {yi} -> {nid}_loss")
                )
    cfg["model_inputs"] = used_inputs & inputs.keys()
    cfg["model_outputs"] = loss_inputs - inputs.keys()
    cfg["loss_inputs"] = loss_inputs
    # Make sure the trained output is output from the last model layer.
    outputs = ", ".join(cfg["model_outputs"])
    layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
    # Create model.
    cfg["model"] = pyg.nn.Sequential(", ".join(used_inputs & inputs.keys()), layers)
    # Make sure the loss is output from the last loss layer.
    [(lossb, lossh)] = edges[optimizer, "loss"]
    lossi = _to_id(lossb) + "_" + lossh
    loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
    # Create loss function.
    cfg["loss"] = pyg.nn.Sequential(", ".join(loss_inputs), loss_layers)
    assert not list(cfg["loss"].parameters()), (
        f"loss should have no parameters: {list(cfg['loss'].parameters())}"
    )
    # Create optimizer.
    p = nodes[optimizer].data.params
    o = getattr(torch.optim, p["type"])
    cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
    return ModelConfig(**cfg)
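

# Illustrative usage sketch (kept in comments so nothing runs on import).
# It assumes a `workspace.Workspace` named `ws` wired as
#     Input: embedding -> Linear -> MSE loss -> Optimizer,
# with an "Input: label" box feeding the loss box's "y" input. Tensor keys
# follow the "<node id>_<output handle>" scheme that `_to_id()` and
# `build_model()` produce above; the node ids here are hypothetical.
#
#     x = torch.rand(100, 4)
#     y = x * 2  # an easily learnable target
#     inputs = {"Input__embedding_1_x": x, "Input__label_1_y": y}
#     cfg = build_model(ws, inputs)
#     for epoch in range(100):
#         loss = cfg.train(inputs)  # one full-batch epoch, returns the loss
#     predictions = cfg.inference({"Input__embedding_1_x": x})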