File size: 4,089 Bytes
3cc3a0a
51eb575
3cc3a0a
 
 
 
612802c
3cc3a0a
 
0cdd509
 
3cc3a0a
 
 
 
0cdd509
3cc3a0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51eb575
0cdd509
 
 
51eb575
0cdd509
 
 
 
 
 
3cc3a0a
 
51eb575
3cc3a0a
83caaa1
612802c
0cdd509
83caaa1
0cdd509
3cc3a0a
 
 
 
83caaa1
0cdd509
83caaa1
 
0cdd509
 
3cc3a0a
 
 
 
51eb575
3cc3a0a
83caaa1
3cc3a0a
83caaa1
 
3cc3a0a
 
 
0cdd509
 
 
51eb575
0cdd509
83caaa1
612802c
0cdd509
83caaa1
0cdd509
 
 
 
 
 
83caaa1
0cdd509
83caaa1
 
0cdd509
 
 
 
 
 
 
 
51eb575
83caaa1
 
0cdd509
 
51eb575
83caaa1
 
0cdd509
 
51eb575
83caaa1
 
0cdd509
 
3cc3a0a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from lynxkite.core import workspace
from lynxkite_graph_analytics.pytorch import pytorch_core
import torch
import pytest


def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str]]):
    ws = workspace.Workspace(env=env)
    for id, data in nodes.items():
        title = data["title"]
        del data["title"]
        ws.nodes.append(
            workspace.WorkspaceNode(
                id=id,
                type="basic",
                data=workspace.WorkspaceNodeData(title=title, params=data),
                position=workspace.Position(
                    x=data.get("x", 0),
                    y=data.get("y", 0),
                ),
            )
        )
    ws.edges = [
        workspace.WorkspaceEdge(
            id=f"{source}->{target}",
            source=source.split(":")[0],
            target=target.split(":")[0],
            sourceHandle=source.split(":")[1],
            targetHandle=target.split(":")[1],
        )
        for source, target in edges
    ]
    return ws


def summarize_layers(m: pytorch_core.ModelConfig) -> str:
    return "".join(str(e)[0] for e in m.model)


def summarize_connections(m: pytorch_core.ModelConfig) -> str:
    return " ".join(
        "".join(n[0] for n in c.param_names) + "->" + "".join(n[0] for n in c.return_names)
        for c in m.model._children
    )


async def test_build_model():
    ws = make_ws(
        pytorch_core.ENV,
        {
            "input": {"title": "Input: tensor"},
            "lin": {"title": "Linear", "output_dim": 4},
            "act": {"title": "Activation", "type": "Leaky_ReLU"},
            "output": {"title": "Output"},
            "label": {"title": "Input: tensor"},
            "loss": {"title": "MSE loss"},
            "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
        },
        [
            ("input:output", "lin:x"),
            ("lin:output", "act:x"),
            ("act:output", "output:x"),
            ("output:x", "loss:x"),
            ("label:output", "loss:y"),
            ("loss:output", "optim:loss"),
        ],
    )
    x = torch.rand(100, 4)
    y = x + 1
    m = pytorch_core.build_model(ws)
    for i in range(1000):
        loss = m.train({"input_output": x, "label_output": y})
    assert loss < 0.1
    o = m.inference({"input_output": x[:1]})
    error = torch.nn.functional.mse_loss(o["output_x"], x[:1] + 1)
    assert error < 0.1


async def test_build_model_with_repeat():
    def repeated_ws(times):
        return make_ws(
            pytorch_core.ENV,
            {
                "input": {"title": "Input: tensor"},
                "lin": {"title": "Linear", "output_dim": 8},
                "act": {"title": "Activation", "type": "Leaky_ReLU"},
                "output": {"title": "Output"},
                "label": {"title": "Input: tensor"},
                "loss": {"title": "MSE loss"},
                "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
                "repeat": {"title": "Repeat", "times": times, "same_weights": False},
            },
            [
                ("input:output", "lin:x"),
                ("lin:output", "act:x"),
                ("act:output", "output:x"),
                ("output:x", "loss:x"),
                ("label:output", "loss:y"),
                ("loss:output", "optim:loss"),
                ("repeat:output", "lin:x"),
                ("act:output", "repeat:input"),
            ],
        )

    # 1 repetition
    m = pytorch_core.build_model(repeated_ws(1))
    assert summarize_layers(m) == "IL<III"
    assert summarize_connections(m) == "i->S S->l l->a a->E E->o o->o"

    # 2 repetitions
    m = pytorch_core.build_model(repeated_ws(2))
    assert summarize_layers(m) == "IL<IL<III"
    assert summarize_connections(m) == "i->S S->l l->a a->S S->l l->a a->E E->o o->o"

    # 3 repetitions
    m = pytorch_core.build_model(repeated_ws(3))
    assert summarize_layers(m) == "IL<IL<IL<III"
    assert summarize_connections(m) == "i->S S->l l->a a->S S->l l->a a->S S->l l->a a->E E->o o->o"


if __name__ == "__main__":
    pytest.main()