shirwu commited on
Commit
c77efb7
·
1 Parent(s): 0c3992e

Add CONCURRENCY_LIMIT; Graph config change -> directed

Browse files
interactive/pyvis_graph.py CHANGED
@@ -3,12 +3,12 @@ import json
3
  import torch
4
  import gradio as gr
5
  from pyvis.network import Network
6
-
7
  sys.path.append(".")
 
8
  from src.benchmarks import get_semistructured_data
9
 
10
-
11
- TITLE = "STaRK Knowledge Base Explorer"
12
  BRAND_NAME = {
13
  "amazon": "STaRK-Amazon",
14
  "mag": "STaRK-MAG",
@@ -22,20 +22,16 @@ NODE_COLORS = [
22
  "#00796B", # Teal
23
  "#03A9F4", # Light Blue
24
  "#CDDC39", # Lime
25
- "#E91E63", # Pink
26
  "#3F51B5", # Indigo
27
  "#00BCD4", # Cyan
28
  "#FFC107", # Amber
29
  "#8BC34A", # Light Green
30
- "#795548", # Brown
31
  "#9E9E9E", # Grey
32
  "#607D8B", # Blue Grey
33
  "#FFEB3B", # Bright Yellow
34
  "#E1F5FE", # Light Blue 50
35
  "#F1F8E9", # Light Green 50
36
  "#FFF3E0", # Orange 50
37
- "#FCE4EC", # Pink 50
38
- "#F3E5F5", # Purple 50
39
  "#FFFDE7", # Yellow 50
40
  "#E0F7FA", # Cyan 50
41
  "#E8F5E9", # Green 50
@@ -90,10 +86,14 @@ def relabel(x, edge_index, batch, pos=None):
90
  return x, edge_index, batch, pos
91
 
92
 
93
- def generate_network(kb, node_id, max_nodes=10, num_hops="1"):
94
  max_nodes = int(max_nodes)
95
-
96
- net = Network()
 
 
 
 
97
 
98
  def get_one_hop(kb, node_id, max_nodes):
99
  edge_index = kb.edge_index
@@ -137,7 +137,6 @@ def generate_network(kb, node_id, max_nodes=10, num_hops="1"):
137
  node_ids, relabel_edge_index, _, _ = relabel(
138
  torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
139
  )
140
-
141
  for idx, n_id in enumerate(node_ids):
142
  if node_id == n_id:
143
  net.add_node(
@@ -158,31 +157,45 @@ def generate_network(kb, node_id, max_nodes=10, num_hops="1"):
158
  font={"align": "middle", "size": 10},
159
  )
160
  for idx in range(relabel_edge_index.size(-1)):
161
- net.add_edge(
162
- relabel_edge_index[0][idx].item(),
163
- relabel_edge_index[1][idx].item(),
164
- color=EDGE_COLORS[edge_types[idx].item()],
165
- label=kb.edge_type_dict[edge_types[idx].item()]
166
- .replace("___", " ")
167
- .replace("_", " "),
168
- width=1,
169
- font={"align": "middle", "size": 10},
170
- )
171
-
 
 
 
 
 
 
 
 
 
 
 
172
  return net.get_network_data()
173
 
174
 
175
  def get_text_html(kb, node_id):
176
  text = kb.get_doc_info(node_id, add_rel=False, compact=False)
177
- # need a text box, figure left, text right
178
- text = text.replace("\n", "<br>").replace(" ", "&nbsp;")
179
  # add a title
 
180
  text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
 
181
  # show the text as what it is with empty space and can be scrolled
182
- return f"""<div style="width: 100%; height: 600px; overflow-x: hidden; overflow-y: scroll; overflow-wrap: break-word; padding: 10px; margin: 0 auto; border: 1px solid #ccc;">{text}</div>"""
 
 
 
183
 
184
 
185
- def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops="1"):
186
  network = generate_network(kb, node_id, max_nodes, num_hops)
187
 
188
  nodes = network[0]
@@ -200,7 +213,7 @@ def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops="1"):
200
 
201
  def main():
202
  # kb = get_semistructured_data(DATASET_NAME)
203
- kbs = {k: get_semistructured_data(k) for k in BRAND_NAME.keys()}
204
 
205
  with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
206
  gr.Markdown(f"# {TITLE}")
@@ -208,11 +221,14 @@ def main():
208
  with gr.Tab(BRAND_NAME[name]):
209
  with gr.Row():
210
  entity_id = gr.Number(
211
- label="Entity ID", elem_id=f"{name}-entity-id-input"
 
 
 
 
212
  )
213
- max_paths = gr.Slider(1, 200, 10, step=1, label="Max Paths")
214
  num_hops = gr.Dropdown(
215
- ["1", "2", "inf"], value="1", label="Number of Hops"
216
  )
217
  query_btn = gr.Button(
218
  value="Show Graph",
@@ -232,7 +248,7 @@ def main():
232
  ),
233
  inputs=[entity_id, max_paths, num_hops],
234
  outputs=[graph_area, text_area],
235
- api_name=f"{name}-fetch-graph",
236
  )
237
 
238
  # Hidden inputs for fetch just text
@@ -248,11 +264,12 @@ def main():
248
  lambda e, kb=kb: get_text_html(kb, e),
249
  inputs=[entity_for_text],
250
  outputs=text_area,
251
- api_name=f"{name}-fetch-text",
252
  )
253
-
254
  demo.launch(share=True)
255
 
256
 
257
  if __name__ == "__main__":
258
- main()
 
 
3
  import torch
4
  import gradio as gr
5
  from pyvis.network import Network
 
6
  sys.path.append(".")
7
+ import re
8
  from src.benchmarks import get_semistructured_data
9
 
10
+ CONCURRENCY_LIMIT = 1000
11
+ TITLE = "STaRK Semistructure Knowledge Base Explorer"
12
  BRAND_NAME = {
13
  "amazon": "STaRK-Amazon",
14
  "mag": "STaRK-MAG",
 
22
  "#00796B", # Teal
23
  "#03A9F4", # Light Blue
24
  "#CDDC39", # Lime
 
25
  "#3F51B5", # Indigo
26
  "#00BCD4", # Cyan
27
  "#FFC107", # Amber
28
  "#8BC34A", # Light Green
 
29
  "#9E9E9E", # Grey
30
  "#607D8B", # Blue Grey
31
  "#FFEB3B", # Bright Yellow
32
  "#E1F5FE", # Light Blue 50
33
  "#F1F8E9", # Light Green 50
34
  "#FFF3E0", # Orange 50
 
 
35
  "#FFFDE7", # Yellow 50
36
  "#E0F7FA", # Cyan 50
37
  "#E8F5E9", # Green 50
 
86
  return x, edge_index, batch, pos
87
 
88
 
89
+ def generate_network(kb, node_id, max_nodes=10, num_hops='2'):
90
  max_nodes = int(max_nodes)
91
+ if 'gene/protein' in kb.node_type_dict.values():
92
+ indirected = True
93
+ net = Network(directed=False)
94
+ else:
95
+ indirected = False
96
+ net = Network()
97
 
98
  def get_one_hop(kb, node_id, max_nodes):
99
  edge_index = kb.edge_index
 
137
  node_ids, relabel_edge_index, _, _ = relabel(
138
  torch.arange(kb.num_nodes()), edge_index, batch=torch.zeros(kb.num_nodes())
139
  )
 
140
  for idx, n_id in enumerate(node_ids):
141
  if node_id == n_id:
142
  net.add_node(
 
157
  font={"align": "middle", "size": 10},
158
  )
159
  for idx in range(relabel_edge_index.size(-1)):
160
+ if indirected:
161
+ net.add_edge(
162
+ relabel_edge_index[0][idx].item(),
163
+ relabel_edge_index[1][idx].item(),
164
+ color=EDGE_COLORS[edge_types[idx].item()],
165
+ label=kb.edge_type_dict[edge_types[idx].item()]
166
+ .replace('___', " ")
167
+ .replace('_', " "),
168
+ width=1,
169
+ font={"align": "middle", "size": 10})
170
+ else:
171
+ net.add_edge(
172
+ relabel_edge_index[0][idx].item(),
173
+ relabel_edge_index[1][idx].item(),
174
+ color=EDGE_COLORS[edge_types[idx].item()],
175
+ label=kb.edge_type_dict[edge_types[idx].item()]
176
+ .replace('___', " ")
177
+ .replace('_', " "),
178
+ width=1,
179
+ font={"align": "middle", "size": 10},
180
+ arrows="to",
181
+ arrowStrikethrough=False)
182
  return net.get_network_data()
183
 
184
 
185
  def get_text_html(kb, node_id):
186
  text = kb.get_doc_info(node_id, add_rel=False, compact=False)
 
 
187
  # add a title
188
+ text = text.replace("\n", "<br>").replace(" ", "&nbsp;")
189
  text = f"<h3>Textual Info of Entity {node_id}:</h3>{text}"
190
+ text = re.sub(r"\$([^$]+)\$", r"\\(\1\\)", text)
191
  # show the text as what it is with empty space and can be scrolled
192
+ return f"""<script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
193
+ <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
194
+ <div style="width: 100%; height: 600px; overflow-x: hidden; overflow-y: scroll; overflow-wrap: break-word; hyphens: auto; padding: 10px; margin: 0 auto; border: 1px solid #ccc; line-height: 1.5;
195
+ font-family: SF Pro Text, SF Pro Icons, Helvetica Neue, Helvetica, Arial, sans-serif;">{text}</div>"""
196
 
197
 
198
+ def get_subgraph_html(kb, kb_name, node_id, max_nodes=10, num_hops=1):
199
  network = generate_network(kb, node_id, max_nodes, num_hops)
200
 
201
  nodes = network[0]
 
213
 
214
  def main():
215
  # kb = get_semistructured_data(DATASET_NAME)
216
+ kbs = {k: get_semistructured_data(k, indirected=False) for k in BRAND_NAME.keys()}
217
 
218
  with gr.Blocks(head=VISJS_HEAD, title=TITLE) as demo:
219
  gr.Markdown(f"# {TITLE}")
 
221
  with gr.Tab(BRAND_NAME[name]):
222
  with gr.Row():
223
  entity_id = gr.Number(
224
+ label="Entity ID",
225
+ elem_id=f"{name}-entity-id-input"
226
+ )
227
+ max_paths = gr.Slider(
228
+ 1, 200, 10, step=1, label="Max Number of Paths"
229
  )
 
230
  num_hops = gr.Dropdown(
231
+ ["1", "2", "inf"], value="2", label="Number of Hops"
232
  )
233
  query_btn = gr.Button(
234
  value="Show Graph",
 
248
  ),
249
  inputs=[entity_id, max_paths, num_hops],
250
  outputs=[graph_area, text_area],
251
+ api_name=f"{name}-fetch-graph"
252
  )
253
 
254
  # Hidden inputs for fetch just text
 
264
  lambda e, kb=kb: get_text_html(kb, e),
265
  inputs=[entity_for_text],
266
  outputs=text_area,
267
+ api_name=f"{name}-fetch-text"
268
  )
269
+ demo.queue(max_size=2*CONCURRENCY_LIMIT, default_concurrency_limit=CONCURRENCY_LIMIT)
270
  demo.launch(share=True)
271
 
272
 
273
  if __name__ == "__main__":
274
+
275
+ main()
src/benchmarks/get_semistruct.py CHANGED
@@ -2,21 +2,22 @@ import os.path as osp
2
  from src.benchmarks.semistruct import AmazonSemiStruct, PrimeKGSemiStruct, MagSemiStruct
3
 
4
 
5
- def get_semistructured_data(name, root='data/', download_processed=True):
6
  data_root = osp.join(root, name)
7
  if name == 'amazon':
8
  categories = ['Sports_and_Outdoors']
9
  kb = AmazonSemiStruct(root=data_root,
10
- categories=categories,
11
- meta_link_types=['brand'],
12
- indirected=True,
13
- download_processed=download_processed
14
- )
15
  if name == 'primekg':
16
  kb = PrimeKGSemiStruct(root=data_root,
17
- download_processed=download_processed)
 
18
 
19
  if name == 'mag':
20
  kb = MagSemiStruct(root=data_root,
21
- download_processed=download_processed)
22
  return kb
 
2
  from src.benchmarks.semistruct import AmazonSemiStruct, PrimeKGSemiStruct, MagSemiStruct
3
 
4
 
5
+ def get_semistructured_data(name, root='data/', download_processed=True, **kwargs):
6
  data_root = osp.join(root, name)
7
  if name == 'amazon':
8
  categories = ['Sports_and_Outdoors']
9
  kb = AmazonSemiStruct(root=data_root,
10
+ categories=categories,
11
+ meta_link_types=['brand'],
12
+ download_processed=download_processed,
13
+ **kwargs
14
+ )
15
  if name == 'primekg':
16
  kb = PrimeKGSemiStruct(root=data_root,
17
+ download_processed=download_processed,
18
+ **kwargs)
19
 
20
  if name == 'mag':
21
  kb = MagSemiStruct(root=data_root,
22
+ download_processed=download_processed)
23
  return kb
src/benchmarks/semistruct/amazon.py CHANGED
@@ -63,8 +63,8 @@ class AmazonSemiStruct(SemiStructureKB):
63
  categories: list,
64
  meta_link_types=['brand'],
65
  max_entries=25,
66
- indirected=True,
67
- download_processed=True):
68
  '''
69
  Args:
70
  root (str): root directory to store the data
@@ -108,7 +108,7 @@ class AmazonSemiStruct(SemiStructureKB):
108
  if meta_link_types:
109
  # customize the graph by adding meta links
110
  processed_data = self.post_process(processed_data, meta_link_types=meta_link_types, cache_path=cache_path)
111
- super(AmazonSemiStruct, self).__init__(**processed_data, indirected=indirected)
112
 
113
  def __getitem__(self, idx):
114
  idx = int(idx)
 
63
  categories: list,
64
  meta_link_types=['brand'],
65
  max_entries=25,
66
+ download_processed=True,
67
+ **kwargs):
68
  '''
69
  Args:
70
  root (str): root directory to store the data
 
108
  if meta_link_types:
109
  # customize the graph by adding meta links
110
  processed_data = self.post_process(processed_data, meta_link_types=meta_link_types, cache_path=cache_path)
111
+ super(AmazonSemiStruct, self).__init__(**processed_data, **kwargs)
112
 
113
  def __getitem__(self, idx):
114
  idx = int(idx)
src/benchmarks/semistruct/mag.py CHANGED
@@ -40,7 +40,7 @@ class MagSemiStruct(SemiStructureKB):
40
  ogbn_papers100M_url = 'https://snap.stanford.edu/ogb/data/misc/ogbn_papers100M/paperinfo.zip'
41
  mag_mapping_url = 'https://zenodo.org/records/2628216/files'
42
 
43
- def __init__(self, root, download_processed=True):
44
  '''
45
  Args:
46
  root (str): root directory to store the dataset folder
@@ -88,7 +88,7 @@ class MagSemiStruct(SemiStructureKB):
88
  processed_data = self._process_raw()
89
  processed_data.update({'node_type_dict': self.node_type_dict,
90
  'edge_type_dict': self.edge_type_dict})
91
- super(MagSemiStruct, self).__init__(**processed_data)
92
 
93
  def load_edge(self, edge_type):
94
  edge_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge.csv.gz")
 
40
  ogbn_papers100M_url = 'https://snap.stanford.edu/ogb/data/misc/ogbn_papers100M/paperinfo.zip'
41
  mag_mapping_url = 'https://zenodo.org/records/2628216/files'
42
 
43
+ def __init__(self, root, download_processed=True, **kwargs):
44
  '''
45
  Args:
46
  root (str): root directory to store the dataset folder
 
88
  processed_data = self._process_raw()
89
  processed_data.update({'node_type_dict': self.node_type_dict,
90
  'edge_type_dict': self.edge_type_dict})
91
+ super(MagSemiStruct, self).__init__(**processed_data, **kwargs)
92
 
93
  def load_edge(self, edge_type):
94
  edge_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge.csv.gz")
src/benchmarks/semistruct/primekg.py CHANGED
@@ -30,7 +30,7 @@ class PrimeKGSemiStruct(SemiStructureKB):
30
  candidate_types = NODE_TYPES
31
  raw_data_url = 'https://drive.google.com/uc?id=1d__3yP6YZYjKWR2F9fGg-y1rW7-HJPpr'
32
 
33
- def __init__(self, root, download_processed=True):
34
  '''
35
  Args:
36
  root (str): root directory to store the dataset folder
@@ -61,7 +61,7 @@ class PrimeKGSemiStruct(SemiStructureKB):
61
  print(f'Loaded from {self.processed_data_dir}!')
62
  else:
63
  processed_data = self._process_raw()
64
- super(PrimeKGSemiStruct, self).__init__(**processed_data)
65
 
66
  self.node_info = clean_dict(self.node_info)
67
  self.node_attr_dict = {}
 
30
  candidate_types = NODE_TYPES
31
  raw_data_url = 'https://drive.google.com/uc?id=1d__3yP6YZYjKWR2F9fGg-y1rW7-HJPpr'
32
 
33
+ def __init__(self, root, download_processed=True, **kwargs):
34
  '''
35
  Args:
36
  root (str): root directory to store the dataset folder
 
61
  print(f'Loaded from {self.processed_data_dir}!')
62
  else:
63
  processed_data = self._process_raw()
64
+ super(PrimeKGSemiStruct, self).__init__(**processed_data, **kwargs)
65
 
66
  self.node_info = clean_dict(self.node_info)
67
  self.node_attr_dict = {}