darabos commited on
Commit
ae86f2c
·
1 Parent(s): 7657f6c

Reliable model mapping frontend.

Browse files
lynxkite-app/web/src/workspace/nodes/NodeParameter.tsx CHANGED
@@ -1,4 +1,5 @@
1
  // @ts-ignore
 
2
  import ArrowsHorizontal from "~icons/tabler/arrows-horizontal.jsx";
3
 
4
  const BOOLEAN = "<class 'bool'>";
@@ -14,13 +15,16 @@ function ParamName({ name }: { name: string }) {
14
  function Input({
15
  value,
16
  onChange,
 
17
  }: {
18
  value: string;
19
  onChange: (value: string, options?: { delay: number }) => void;
 
20
  }) {
21
  return (
22
  <input
23
  className="input input-bordered w-full"
 
24
  value={value || ""}
25
  onChange={(evt) => onChange(evt.currentTarget.value, { delay: 2 })}
26
  onBlur={(evt) => onChange(evt.currentTarget.value, { delay: 0 })}
@@ -29,6 +33,13 @@ function Input({
29
  );
30
  }
31
 
 
 
 
 
 
 
 
32
  function getModelBindings(
33
  data: any,
34
  variant: "training input" | "inference input" | "output",
@@ -71,6 +82,10 @@ function parseJsonOrEmpty(json: string): object {
71
  }
72
 
73
  function ModelMapping({ value, onChange, data, variant }: any) {
 
 
 
 
74
  const v: any = parseJsonOrEmpty(value);
75
  v.map ??= {};
76
  const dfs: { [df: string]: string[] } = {};
@@ -84,6 +99,17 @@ function ModelMapping({ value, onChange, data, variant }: any) {
84
  }
85
  }
86
  const bindings = getModelBindings(data, variant);
 
 
 
 
 
 
 
 
 
 
 
87
  return (
88
  <table className="model-mapping-param">
89
  <tbody>
@@ -98,21 +124,10 @@ function ModelMapping({ value, onChange, data, variant }: any) {
98
  <select
99
  className="select select-ghost"
100
  value={v.map?.[binding]?.df}
101
- onChange={(evt) => {
102
- const df = evt.currentTarget.value;
103
- if (df === "") {
104
- const map = { ...v.map, [binding]: undefined };
105
- onChange(JSON.stringify({ map }));
106
- } else {
107
- const columnSpec = {
108
- column: dfs[df][0],
109
- ...(v.map?.[binding] || {}),
110
- df,
111
- };
112
- const map = { ...v.map, [binding]: columnSpec };
113
- onChange(JSON.stringify({ map }));
114
- }
115
  }}
 
116
  >
117
  <option key="" value="" />
118
  {Object.keys(dfs).map((df: string) => (
@@ -125,13 +140,16 @@ function ModelMapping({ value, onChange, data, variant }: any) {
125
  <td>
126
  {variant === "output" ? (
127
  <Input
 
 
 
128
  value={v.map?.[binding]?.column}
129
  onChange={(column, options) => {
130
- const columnSpec = {
131
- ...(v.map?.[binding] || {}),
132
- column,
133
- };
134
- const map = { ...v.map, [binding]: columnSpec };
135
  onChange(JSON.stringify({ map }), options);
136
  }}
137
  />
@@ -139,16 +157,12 @@ function ModelMapping({ value, onChange, data, variant }: any) {
139
  <select
140
  className="select select-ghost"
141
  value={v.map?.[binding]?.column}
142
- onChange={(evt) => {
143
- const column = evt.currentTarget.value;
144
- const columnSpec = {
145
- ...(v.map?.[binding] || {}),
146
- column,
147
- };
148
- const map = { ...v.map, [binding]: columnSpec };
149
- onChange(JSON.stringify({ map }));
150
  }}
 
151
  >
 
152
  {dfs[v.map?.[binding]?.df]?.map((col: string) => (
153
  <option key={col} value={col}>
154
  {col}
 
1
  // @ts-ignore
2
+ import { useRef } from "react";
3
  import ArrowsHorizontal from "~icons/tabler/arrows-horizontal.jsx";
4
 
5
  const BOOLEAN = "<class 'bool'>";
 
15
  function Input({
16
  value,
17
  onChange,
18
+ inputRef,
19
  }: {
20
  value: string;
21
  onChange: (value: string, options?: { delay: number }) => void;
22
+ inputRef?: React.Ref<HTMLInputElement>;
23
  }) {
24
  return (
25
  <input
26
  className="input input-bordered w-full"
27
+ ref={inputRef}
28
  value={value || ""}
29
  onChange={(evt) => onChange(evt.currentTarget.value, { delay: 2 })}
30
  onBlur={(evt) => onChange(evt.currentTarget.value, { delay: 0 })}
 
33
  );
34
  }
35
 
36
+ type Bindings = {
37
+ [key: string]: {
38
+ df: string;
39
+ column: string;
40
+ };
41
+ };
42
+
43
  function getModelBindings(
44
  data: any,
45
  variant: "training input" | "inference input" | "output",
 
82
  }
83
 
84
  function ModelMapping({ value, onChange, data, variant }: any) {
85
+ const dfsRef = useRef({} as { [binding: string]: HTMLSelectElement | null });
86
+ const columnsRef = useRef(
87
+ {} as { [binding: string]: HTMLSelectElement | HTMLInputElement | null },
88
+ );
89
  const v: any = parseJsonOrEmpty(value);
90
  v.map ??= {};
91
  const dfs: { [df: string]: string[] } = {};
 
99
  }
100
  }
101
  const bindings = getModelBindings(data, variant);
102
+ function getMap() {
103
+ const map: Bindings = {};
104
+ for (const binding of bindings) {
105
+ const df = dfsRef.current[binding]?.value ?? "";
106
+ const column = columnsRef.current[binding]?.value ?? "";
107
+ if (df.length || column.length) {
108
+ map[binding] = { df, column };
109
+ }
110
+ }
111
+ return map;
112
+ }
113
  return (
114
  <table className="model-mapping-param">
115
  <tbody>
 
124
  <select
125
  className="select select-ghost"
126
  value={v.map?.[binding]?.df}
127
+ ref={(el) => {
128
+ dfsRef.current[binding] = el;
 
 
 
 
 
 
 
 
 
 
 
 
129
  }}
130
+ onChange={() => onChange(JSON.stringify({ map: getMap() }))}
131
  >
132
  <option key="" value="" />
133
  {Object.keys(dfs).map((df: string) => (
 
140
  <td>
141
  {variant === "output" ? (
142
  <Input
143
+ inputRef={(el) => {
144
+ columnsRef.current[binding] = el;
145
+ }}
146
  value={v.map?.[binding]?.column}
147
  onChange={(column, options) => {
148
+ const map = getMap();
149
+ // At this point the <input> has not been updated yet. We use the value from the event.
150
+ const df = dfsRef.current[binding]?.value ?? "";
151
+ map[binding] ??= { df, column };
152
+ map[binding].column = column;
153
  onChange(JSON.stringify({ map }), options);
154
  }}
155
  />
 
157
  <select
158
  className="select select-ghost"
159
  value={v.map?.[binding]?.column}
160
+ ref={(el) => {
161
+ columnsRef.current[binding] = el;
 
 
 
 
 
 
162
  }}
163
+ onChange={() => onChange(JSON.stringify({ map: getMap() }))}
164
  >
165
+ <option key="" value="" />
166
  {dfs[v.map?.[binding]?.df]?.map((col: string) => (
167
  <option key={col} value={col}>
168
  {col}
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py CHANGED
@@ -409,7 +409,13 @@ def model_inference(
409
  inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
410
  outputs = m.inference(inputs)
411
  bundle = bundle.copy()
 
412
  for k, v in output_mapping.map.items():
 
 
 
 
 
413
  bundle.dfs[v.df][v.column] = outputs[k].detach().numpy().tolist()
414
  return bundle
415
 
 
409
  inputs = pytorch_model_ops.to_tensors(bundle, input_mapping)
410
  outputs = m.inference(inputs)
411
  bundle = bundle.copy()
412
+ copied = set()
413
  for k, v in output_mapping.map.items():
414
+ if not v.df or not v.column:
415
+ continue
416
+ if v.df not in copied:
417
+ bundle.dfs[v.df] = bundle.dfs[v.df].copy()
418
+ copied.add(v.df)
419
  bundle.dfs[v.df][v.column] = outputs[k].detach().numpy().tolist()
420
  return bundle
421