Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Add CONCURRENCY_LIMIT; Graph config change -> directed
Browse files
interactive/pyvis_graph.py
CHANGED
@@ -3,12 +3,12 @@ import json
|
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
from pyvis.network import Network
|
6 |
-
|
7 |
sys.path.append(".")
|
|
|
8 |
from src.benchmarks import get_semistructured_data
|
9 |
|
10 |
-
|
11 |
-
TITLE = "STaRK Knowledge Base Explorer"
|
12 |
BRAND_NAME = {
|
13 |
"amazon": "STaRK-Amazon",
|
14 |
"mag": "STaRK-MAG",
|
@@ -22,20 +22,16 @@ NODE_COLORS = [
|
|
22 |
"#00796B", # Teal
|
23 |
"#03A9F4", # Light Blue
|
24 |
"#CDDC39", # Lime
|
25 |
-
"#E91E63", # Pink
|
26 |
"#3F51B5", # Indigo
|
27 |
"#00BCD4", # Cyan
|
28 |
"#FFC107", # Amber
|
29 |
"#8BC34A", # Light Green
|
30 |
-
"#795548", # Brown
|
31 |
"#9E9E9E", # Grey
|
32 |
"#607D8B", # Blue Grey
|
33 |
"#FFEB3B", # Bright Yellow
|
34 |
"#E1F5FE", # Light Blue 50
|
35 |
"#F1F8E9", # Light Green 50
|
36 |
"#FFF3E0", # Orange 50
|
37 |
-
"#FCE4EC", # Pink 50
|
38 |
-
"#F3E5F5", # Purple 50
|
39 |
"#FFFDE7", # Yellow 50
|
40 |
"#E0F7FA", # Cyan 50
|
41 |
"#E8F5E9", # Green 50
|
@@ -90,10 +86,14 @@ def relabel(x, edge_index, batch, pos=None):
|
|
90 |
return x, edge_index, batch, pos
|
91 |
|
92 |
|
93 |
-
def generate_network(kb, node_id, max_nodes=10, num_hops=
|
94 |
max_nodes = int(max_nodes)
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
97 |
|
98 |
def get_one_hop(kb, node_id, max_nodes):
|
99 |
edge_index = kb.edge_index
|
@@ -137,7 +137,6 @@ def generate_network(kb, node_id, max_nodes=10, num_hops="1"):
|
|
137 |
node_ids, relabel_edge_index, _, _ = relabel(
|
138 |
torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
|
139 |
)
|
140 |
-
|
141 |
for idx, n_id in enumerate(node_ids):
|
142 |
if node_id == n_id:
|
143 |
net.add_node(
|
@@ -158,31 +157,45 @@ def generate_network(kb, node_id, max_nodes=10, num_hops="1"):
|
|
158 |
font={"align": "middle", "size": 10},
|
159 |
)
|
160 |
for idx in range(relabel_edge_index.size(-1)):
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
return net.get_network_data()
|
173 |
|
174 |
|
175 |
def get_text_html(kb, node_id):
|
176 |
text = kb.get_doc_info(node_id, add_rel=False, compact=False)
|
177 |
-
# need a text box, figure left, text right
|
178 |
-
text = text.replace("\n", "<br>").replace(" ", " ")
|
179 |
# add a title
|
|
|
180 |
text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
|
|
|
181 |
# show the text as what it is with empty space and can be scrolled
|
182 |
-
return f"""<
|
|
|
|
|
|
|
183 |
|
184 |
|
185 |
-
def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops=
|
186 |
network = generate_network(kb, node_id, max_nodes, num_hops)
|
187 |
|
188 |
nodes = network[0]
|
@@ -200,7 +213,7 @@ def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops="1"):
|
|
200 |
|
201 |
def main():
|
202 |
# kb = get_semistructured_data(DATASET_NAME)
|
203 |
-
kbs = {k: get_semistructured_data(k) for k in BRAND_NAME.keys()}
|
204 |
|
205 |
with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
|
206 |
gr.Markdown(f"# {TITLE}")
|
@@ -208,11 +221,14 @@ def main():
|
|
208 |
with gr.Tab(BRAND_NAME[name]):
|
209 |
with gr.Row():
|
210 |
entity_id = gr.Number(
|
211 |
-
label="Entity ID",
|
|
|
|
|
|
|
|
|
212 |
)
|
213 |
-
max_paths = gr.Slider(1, 200, 10, step=1, label="Max Paths")
|
214 |
num_hops = gr.Dropdown(
|
215 |
-
["1", "2", "inf"], value="
|
216 |
)
|
217 |
query_btn = gr.Button(
|
218 |
value="Show Graph",
|
@@ -232,7 +248,7 @@ def main():
|
|
232 |
),
|
233 |
inputs=[entity_id, max_paths, num_hops],
|
234 |
outputs=[graph_area, text_area],
|
235 |
-
api_name=f"{name}-fetch-graph"
|
236 |
)
|
237 |
|
238 |
# Hidden inputs for fetch just text
|
@@ -248,11 +264,12 @@ def main():
|
|
248 |
lambda e, kb=kb: get_text_html(kb, e),
|
249 |
inputs=[entity_for_text],
|
250 |
outputs=text_area,
|
251 |
-
api_name=f"{name}-fetch-text"
|
252 |
)
|
253 |
-
|
254 |
demo.launch(share=True)
|
255 |
|
256 |
|
257 |
if __name__ == "__main__":
|
258 |
-
|
|
|
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
from pyvis.network import Network
|
|
|
6 |
sys.path.append(".")
|
7 |
+
import re
|
8 |
from src.benchmarks import get_semistructured_data
|
9 |
|
10 |
+
CONCURRENCY_LIMIT = 1000
|
11 |
+
TITLE = "STaRK Semistructure Knowledge Base Explorer"
|
12 |
BRAND_NAME = {
|
13 |
"amazon": "STaRK-Amazon",
|
14 |
"mag": "STaRK-MAG",
|
|
|
22 |
"#00796B", # Teal
|
23 |
"#03A9F4", # Light Blue
|
24 |
"#CDDC39", # Lime
|
|
|
25 |
"#3F51B5", # Indigo
|
26 |
"#00BCD4", # Cyan
|
27 |
"#FFC107", # Amber
|
28 |
"#8BC34A", # Light Green
|
|
|
29 |
"#9E9E9E", # Grey
|
30 |
"#607D8B", # Blue Grey
|
31 |
"#FFEB3B", # Bright Yellow
|
32 |
"#E1F5FE", # Light Blue 50
|
33 |
"#F1F8E9", # Light Green 50
|
34 |
"#FFF3E0", # Orange 50
|
|
|
|
|
35 |
"#FFFDE7", # Yellow 50
|
36 |
"#E0F7FA", # Cyan 50
|
37 |
"#E8F5E9", # Green 50
|
|
|
86 |
return x, edge_index, batch, pos
|
87 |
|
88 |
|
89 |
+
def generate_network(kb, node_id, max_nodes=10, num_hops='2'):
|
90 |
max_nodes = int(max_nodes)
|
91 |
+
if 'gene/protein' in kb.node_type_dict.values():
|
92 |
+
indirected = True
|
93 |
+
net = Network(directed=False)
|
94 |
+
else:
|
95 |
+
indirected = False
|
96 |
+
net = Network()
|
97 |
|
98 |
def get_one_hop(kb, node_id, max_nodes):
|
99 |
edge_index = kb.edge_index
|
|
|
137 |
node_ids, relabel_edge_index, _, _ = relabel(
|
138 |
torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
|
139 |
)
|
|
|
140 |
for idx, n_id in enumerate(node_ids):
|
141 |
if node_id == n_id:
|
142 |
net.add_node(
|
|
|
157 |
font={"align": "middle", "size": 10},
|
158 |
)
|
159 |
for idx in range(relabel_edge_index.size(-1)):
|
160 |
+
if indirected:
|
161 |
+
net.add_edge(
|
162 |
+
relabel_edge_index[0][idx].item(),
|
163 |
+
relabel_edge_index[1][idx].item(),
|
164 |
+
color=EDGE_COLORS[edge_types[idx].item()],
|
165 |
+
label=kb.edge_type_dict[edge_types[idx].item()]
|
166 |
+
.replace('___', " ")
|
167 |
+
.replace('_', " "),
|
168 |
+
width=1,
|
169 |
+
font={"align": "middle", "size": 10})
|
170 |
+
else:
|
171 |
+
net.add_edge(
|
172 |
+
relabel_edge_index[0][idx].item(),
|
173 |
+
relabel_edge_index[1][idx].item(),
|
174 |
+
color=EDGE_COLORS[edge_types[idx].item()],
|
175 |
+
label=kb.edge_type_dict[edge_types[idx].item()]
|
176 |
+
.replace('___', " ")
|
177 |
+
.replace('_', " "),
|
178 |
+
width=1,
|
179 |
+
font={"align": "middle", "size": 10},
|
180 |
+
arrows="to",
|
181 |
+
arrowStrikethrough=False)
|
182 |
return net.get_network_data()
|
183 |
|
184 |
|
185 |
def get_text_html(kb, node_id):
    """Render the textual description of one entity as a scrollable HTML panel.

    Args:
        kb: knowledge-base object exposing
            ``get_doc_info(node_id, add_rel=..., compact=...)`` returning a string.
        node_id: identifier of the entity; it is also interpolated into the heading.

    Returns:
        str: an HTML fragment made of the MathJax loader ``<script>`` followed by a
        fixed-height, vertically scrollable ``<div>`` holding the entity text.
    """
    text = kb.get_doc_info(node_id, add_rel=False, compact=False)
    # Keep the line breaks and spacing of the raw text once it is rendered as HTML.
    # NOTE(review): the replacement is spelled as the explicit entity so that it
    # cannot degrade into a no-op space-for-space swap; confirm this matches the
    # literal used in the repository.
    text = text.replace("\n", "<br>").replace(" ", "&nbsp;")
    # add a title
    text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
    # Rewrite $...$ spans into the backslash-parenthesis inline-math delimiters.
    text = re.sub(r"\$([^$]+)\$", r"\\(\1\\)", text)
    # The former polyfill.io <script> is intentionally not emitted: that domain
    # was compromised, and the ES6 polyfill is not needed by current browsers.
    # show the text as what it is with empty space and can be scrolled
    return f"""<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<div style="width: 100%; height: 600px; overflow-x: hidden; overflow-y: scroll; overflow-wrap: break-word; hyphens: auto; padding: 10px; margin: 0 auto; border: 1px solid #ccc; line-height: 1.5;
font-family: SF Pro Text, SF Pro Icons, Helvetica Neue, Helvetica, Arial, sans-serif;">{text}</div>"""
|
196 |
|
197 |
|
198 |
+
def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops=1):
|
199 |
network = generate_network(kb, node_id, max_nodes, num_hops)
|
200 |
|
201 |
nodes = network[0]
|
|
|
213 |
|
214 |
def main():
|
215 |
# kb = get_semistructured_data(DATASET_NAME)
|
216 |
+
kbs = {k: get_semistructured_data(k, indirected=False) for k in BRAND_NAME.keys()}
|
217 |
|
218 |
with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
|
219 |
gr.Markdown(f"# {TITLE}")
|
|
|
221 |
with gr.Tab(BRAND_NAME[name]):
|
222 |
with gr.Row():
|
223 |
entity_id = gr.Number(
|
224 |
+
label="Entity ID",
|
225 |
+
elem_id=f"{name}-entity-id-input"
|
226 |
+
)
|
227 |
+
max_paths = gr.Slider(
|
228 |
+
1, 200, 10, step=1, label="Max Number of Paths"
|
229 |
)
|
|
|
230 |
num_hops = gr.Dropdown(
|
231 |
+
["1", "2", "inf"], value="2", label="Number of Hops"
|
232 |
)
|
233 |
query_btn = gr.Button(
|
234 |
value="Show Graph",
|
|
|
248 |
),
|
249 |
inputs=[entity_id, max_paths, num_hops],
|
250 |
outputs=[graph_area, text_area],
|
251 |
+
api_name=f"{name}-fetch-graph"
|
252 |
)
|
253 |
|
254 |
# Hidden inputs for fetch just text
|
|
|
264 |
lambda e, kb=kb: get_text_html(kb, e),
|
265 |
inputs=[entity_for_text],
|
266 |
outputs=text_area,
|
267 |
+
api_name=f"{name}-fetch-text"
|
268 |
)
|
269 |
+
demo.queue(max_size=2*CONCURRENCY_LIMIT, default_concurrency_limit=CONCURRENCY_LIMIT)
|
270 |
demo.launch(share=True)
|
271 |
|
272 |
|
273 |
if __name__ == "__main__":
|
274 |
+
|
275 |
+
main()
|
src/benchmarks/get_semistruct.py
CHANGED
@@ -2,21 +2,22 @@ import os.path as osp
|
|
2 |
from src.benchmarks.semistruct import AmazonSemiStruct, PrimeKGSemiStruct, MagSemiStruct
|
3 |
|
4 |
|
5 |
-
def get_semistructured_data(name, root='data/', download_processed=True):
|
6 |
data_root = osp.join(root, name)
|
7 |
if name == 'amazon':
|
8 |
categories = ['Sports_and_Outdoors']
|
9 |
kb = AmazonSemiStruct(root=data_root,
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
if name == 'primekg':
|
16 |
kb = PrimeKGSemiStruct(root=data_root,
|
17 |
-
|
|
|
18 |
|
19 |
if name == 'mag':
|
20 |
kb = MagSemiStruct(root=data_root,
|
21 |
-
|
22 |
return kb
|
|
|
2 |
from src.benchmarks.semistruct import AmazonSemiStruct, PrimeKGSemiStruct, MagSemiStruct
|
3 |
|
4 |
|
5 |
+
def get_semistructured_data(name, root='data/', download_processed=True, **kwargs):
    """Build the semi-structured knowledge base for a named dataset.

    Args:
        name (str): dataset key; one of 'amazon', 'primekg' or 'mag'.
        root (str): parent directory; the dataset lives under ``root/name``.
        download_processed (bool): forwarded to the knowledge-base constructor.
        **kwargs: extra keyword arguments forwarded unchanged to the
            knowledge-base constructor of the selected dataset.

    Returns:
        The constructed knowledge-base instance for ``name``.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    data_root = osp.join(root, name)
    if name == 'amazon':
        categories = ['Sports_and_Outdoors']
        return AmazonSemiStruct(root=data_root,
                                categories=categories,
                                meta_link_types=['brand'],
                                download_processed=download_processed,
                                **kwargs)
    if name == 'primekg':
        return PrimeKGSemiStruct(root=data_root,
                                 download_processed=download_processed,
                                 **kwargs)
    if name == 'mag':
        # Forward **kwargs here too, so options such as ``indirected`` are not
        # silently dropped for this dataset only.
        return MagSemiStruct(root=data_root,
                             download_processed=download_processed,
                             **kwargs)
    # Fail with a clear message instead of falling through to an unbound local.
    raise ValueError(
        f"Unknown dataset name {name!r}; expected 'amazon', 'primekg' or 'mag'."
    )
|
src/benchmarks/semistruct/amazon.py
CHANGED
@@ -63,8 +63,8 @@ class AmazonSemiStruct(SemiStructureKB):
|
|
63 |
categories: list,
|
64 |
meta_link_types=['brand'],
|
65 |
max_entries=25,
|
66 |
-
|
67 |
-
|
68 |
'''
|
69 |
Args:
|
70 |
root (str): root directory to store the data
|
@@ -108,7 +108,7 @@ class AmazonSemiStruct(SemiStructureKB):
|
|
108 |
if meta_link_types:
|
109 |
# customize the graph by adding meta links
|
110 |
processed_data = self.post_process(processed_data, meta_link_types=meta_link_types, cache_path=cache_path)
|
111 |
-
super(AmazonSemiStruct, self).__init__(**processed_data,
|
112 |
|
113 |
def __getitem__(self, idx):
|
114 |
idx = int(idx)
|
|
|
63 |
categories: list,
|
64 |
meta_link_types=['brand'],
|
65 |
max_entries=25,
|
66 |
+
download_processed=True,
|
67 |
+
**kwargs):
|
68 |
'''
|
69 |
Args:
|
70 |
root (str): root directory to store the data
|
|
|
108 |
if meta_link_types:
|
109 |
# customize the graph by adding meta links
|
110 |
processed_data = self.post_process(processed_data, meta_link_types=meta_link_types, cache_path=cache_path)
|
111 |
+
super(AmazonSemiStruct, self).__init__(**processed_data, **kwargs)
|
112 |
|
113 |
def __getitem__(self, idx):
|
114 |
idx = int(idx)
|
src/benchmarks/semistruct/mag.py
CHANGED
@@ -40,7 +40,7 @@ class MagSemiStruct(SemiStructureKB):
|
|
40 |
ogbn_papers100M_url = 'https://snap.stanford.edu/ogb/data/misc/ogbn_papers100M/paperinfo.zip'
|
41 |
mag_mapping_url = 'https://zenodo.org/records/2628216/files'
|
42 |
|
43 |
-
def __init__(self, root, download_processed=True):
|
44 |
'''
|
45 |
Args:
|
46 |
root (str): root directory to store the dataset folder
|
@@ -88,7 +88,7 @@ class MagSemiStruct(SemiStructureKB):
|
|
88 |
processed_data = self._process_raw()
|
89 |
processed_data.update({'node_type_dict': self.node_type_dict,
|
90 |
'edge_type_dict': self.edge_type_dict})
|
91 |
-
super(MagSemiStruct, self).__init__(**processed_data)
|
92 |
|
93 |
def load_edge(self, edge_type):
|
94 |
edge_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge.csv.gz")
|
|
|
40 |
ogbn_papers100M_url = 'https://snap.stanford.edu/ogb/data/misc/ogbn_papers100M/paperinfo.zip'
|
41 |
mag_mapping_url = 'https://zenodo.org/records/2628216/files'
|
42 |
|
43 |
+
def __init__(self, root, download_processed=True, **kwargs):
|
44 |
'''
|
45 |
Args:
|
46 |
root (str): root directory to store the dataset folder
|
|
|
88 |
processed_data = self._process_raw()
|
89 |
processed_data.update({'node_type_dict': self.node_type_dict,
|
90 |
'edge_type_dict': self.edge_type_dict})
|
91 |
+
super(MagSemiStruct, self).__init__(**processed_data, **kwargs)
|
92 |
|
93 |
def load_edge(self, edge_type):
|
94 |
edge_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge.csv.gz")
|
src/benchmarks/semistruct/primekg.py
CHANGED
@@ -30,7 +30,7 @@ class PrimeKGSemiStruct(SemiStructureKB):
|
|
30 |
candidate_types = NODE_TYPES
|
31 |
raw_data_url = 'https://drive.google.com/uc?id=1d__3yP6YZYjKWR2F9fGg-y1rW7-HJPpr'
|
32 |
|
33 |
-
def __init__(self, root, download_processed=True):
|
34 |
'''
|
35 |
Args:
|
36 |
root (str): root directory to store the dataset folder
|
@@ -61,7 +61,7 @@ class PrimeKGSemiStruct(SemiStructureKB):
|
|
61 |
print(f'Loaded from {self.processed_data_dir}!')
|
62 |
else:
|
63 |
processed_data = self._process_raw()
|
64 |
-
super(PrimeKGSemiStruct, self).__init__(**processed_data)
|
65 |
|
66 |
self.node_info = clean_dict(self.node_info)
|
67 |
self.node_attr_dict = {}
|
|
|
30 |
candidate_types = NODE_TYPES
|
31 |
raw_data_url = 'https://drive.google.com/uc?id=1d__3yP6YZYjKWR2F9fGg-y1rW7-HJPpr'
|
32 |
|
33 |
+
def __init__(self, root, download_processed=True, **kwargs):
|
34 |
'''
|
35 |
Args:
|
36 |
root (str): root directory to store the dataset folder
|
|
|
61 |
print(f'Loaded from {self.processed_data_dir}!')
|
62 |
else:
|
63 |
processed_data = self._process_raw()
|
64 |
+
super(PrimeKGSemiStruct, self).__init__(**processed_data, **kwargs)
|
65 |
|
66 |
self.node_info = clean_dict(self.node_info)
|
67 |
self.node_attr_dict = {}
|