mgbam commited on
Commit
d5e0cb0
Β·
verified Β·
1 Parent(s): b7556e4

Update mcp/knowledge_graph.py

Browse files
Files changed (1) hide show
  1. mcp/knowledge_graph.py +127 -49
mcp/knowledge_graph.py CHANGED
@@ -1,62 +1,140 @@
1
- from streamlit_agraph import Node, Edge, Config
2
- import re, itertools
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from typing import List, Tuple
4
 
5
- _GREEN = "#00b894"
6
- _ORANGE = "#d35400"
7
- _BLUE = "#0984e3"
 
 
 
 
 
8
 
9
- def _safe(node_like):
10
- """Return empty dict if node_like is an Exception."""
11
- return node_like if isinstance(node_like, dict) else {}
 
 
 
 
 
12
 
 
 
13
  def build_agraph(
14
- papers: List[dict],
15
- umls: List[dict],
16
- drug_safety: List[dict],
17
  ) -> Tuple[List[Node], List[Edge], Config]:
18
- nodes, edges = [], []
19
-
20
- # UMLS concept nodes
21
- for c in filter(bool, map(_safe, umls)):
22
- cui, name = c.get("cui"), c.get("name", "")
23
- if cui and name:
24
- nid = f"cui:{cui}"
25
- nodes.append(Node(id=nid, label=name, color=_GREEN, size=25))
26
-
27
- # Drug nodes
28
- def _drug_name(d: dict) -> str:
29
- return (
30
- d.get("drug_name")
31
- or d.get("patient", {}).get("drug", "")
32
- or d.get("medicinalproduct", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  )
34
 
35
- for idx, rec in enumerate(itertools.chain.from_iterable(
36
- [r if isinstance(r, list) else [r] for r in drug_safety])):
37
- dn = _drug_name(rec) or f"drug_{idx}"
38
- did = f"drug_{idx}"
39
- nodes.append(Node(id=did, label=dn, color=_ORANGE, size=25))
40
-
41
- # Paper nodes
42
- for i, p in enumerate(papers, 1):
43
- pid = f"paper_{i}"
44
- nodes.append(Node(id=pid, label=f"P{i}", tooltip=p["title"], color=_BLUE, size=15))
45
-
46
- txt = f"{p['title']} {p['summary']}".lower()
47
- # link ↔ concepts
48
- for c in filter(bool, map(_safe, umls)):
49
- if (name := c.get("name", "")).lower() in txt and c.get("cui"):
50
- edges.append(Edge(source=pid, target=f"cui:{c['cui']}", label="mentions"))
51
- # link ↔ drugs
52
- for n in nodes:
53
- if n.id.startswith("drug_") and n.label.lower() in txt:
54
- edges.append(Edge(source=pid, target=n.id, label="mentions"))
55
 
56
  cfg = Config(
57
- width="100%", height="600px", directed=False,
58
- node={"labelProperty": "label"},
59
- nodeHighlightBehavior=True, highlightColor="#f1c40f",
 
 
60
  collapsible=True,
 
61
  )
62
  return nodes, edges, cfg
 
1
+ # mcp/knowledge_graph.py
2
+ """
3
+ Build agraph-compatible nodes + edges for the MedGenesis UI.
4
+
5
+ Robustness notes
6
+ ----------------
7
+ * Accepts *any* iterable for ``papers``, ``umls``, ``drug_safety``.
8
+ * Silently skips items that are **not** dictionaries or have missing keys.
9
+ * Normalises drug-safety payloads that may arrive as dict **or** list.
10
+ * Always casts labels to string – avoids ``None.lower()`` errors.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import re
16
  from typing import List, Tuple
17
 
18
+ from streamlit_agraph import Node, Edge, Config
19
+
20
+
21
+ # ── helpers -----------------------------------------------------------------
22
+ def _safe_str(x) -> str:
23
+ """Return UTF-8 string or empty string."""
24
+ return str(x) if x is not None else ""
25
+
26
 
27
+ def _uniquify(nodes: List[Node]) -> List[Node]:
28
+ """Remove duplicate node-ids (keep first)."""
29
+ seen, out = set(), []
30
+ for n in nodes:
31
+ if n.id not in seen:
32
+ out.append(n)
33
+ seen.add(n.id)
34
+ return out
35
 
36
+
37
+ # ── public builder ----------------------------------------------------------
38
  def build_agraph(
39
+ papers: list,
40
+ umls: list,
41
+ drug_safety: list,
42
  ) -> Tuple[List[Node], List[Edge], Config]:
43
+ """
44
+ Parameters
45
+ ----------
46
+ papers : List[dict]
47
+ Must contain keys ``title``, ``summary``.
48
+ umls : List[dict]
49
+ Dicts with at least ``name`` and ``cui``.
50
+ drug_safety : List[dict | list]
51
+ OpenFDA records – could be one dict or list of dicts.
52
+
53
+ Returns
54
+ -------
55
+ nodes, edges, cfg : tuple
56
+ Ready for ``streamlit_agraph.agraph``.
57
+ """
58
+
59
+ nodes: List[Node] = []
60
+ edges: List[Edge] = []
61
+
62
+ # ── UMLS concepts -------------------------------------------------------
63
+ for c in umls:
64
+ if not isinstance(c, dict):
65
+ continue
66
+ cui = _safe_str(c.get("cui")).strip()
67
+ name = _safe_str(c.get("name")).strip()
68
+ if not (cui and name):
69
+ continue
70
+ nodes.append(
71
+ Node(id=f"concept_{cui}", label=name, size=28, color="#00b894")
72
+ )
73
+
74
+ # ── Drug safety --------------------------------------------------------
75
+ drug_nodes: List[Tuple[str, str]] = []
76
+ for idx, rec in enumerate(drug_safety):
77
+ if not rec:
78
+ continue
79
+ recs = rec if isinstance(rec, list) else [rec]
80
+ for j, r in enumerate(recs):
81
+ if not isinstance(r, dict):
82
+ continue
83
+ dn = (
84
+ r.get("drug_name")
85
+ or r.get("patient", {}).get("drug")
86
+ or r.get("medicinalproduct")
87
+ )
88
+ dn = _safe_str(dn).strip() or f"drug_{idx}_{j}"
89
+ did = f"drug_{idx}_{j}"
90
+ drug_nodes.append((did, dn))
91
+ nodes.append(Node(id=did, label=dn, size=25, color="#d35400"))
92
+
93
+ # ── Papers & edges ------------------------------------------------------
94
+ for p_idx, p in enumerate(papers):
95
+ if not isinstance(p, dict):
96
+ continue
97
+ pid = f"paper_{p_idx}"
98
+ title = _safe_str(p.get("title"))
99
+ summary = _safe_str(p.get("summary"))
100
+ nodes.append(
101
+ Node(
102
+ id=pid,
103
+ label=f"P{p_idx + 1}",
104
+ tooltip=title,
105
+ size=16,
106
+ color="#0984e3",
107
+ )
108
  )
109
 
110
+ text_blob = f"{title} {summary}".lower()
111
+
112
+ # β†’ concept edges
113
+ for c in umls:
114
+ if not isinstance(c, dict):
115
+ continue
116
+ name = _safe_str(c.get("name")).lower()
117
+ cui = _safe_str(c.get("cui"))
118
+ if name and cui and name in text_blob:
119
+ edges.append(
120
+ Edge(source=pid, target=f"concept_{cui}", label="mentions")
121
+ )
122
+
123
+ # β†’ drug edges
124
+ for did, dn in drug_nodes:
125
+ if dn.lower() in text_blob:
126
+ edges.append(Edge(source=pid, target=did, label="mentions"))
127
+
128
+ # ── deduplicate & config ------------------------------------------------
129
+ nodes = _uniquify(nodes)
130
 
131
  cfg = Config(
132
+ width="100%",
133
+ height="600px",
134
+ directed=False,
135
+ nodeHighlightBehavior=True,
136
+ highlightColor="#f1c40f",
137
  collapsible=True,
138
+ node={"labelProperty": "label"},
139
  )
140
  return nodes, edges, cfg