Update app.py
Browse files
app.py
CHANGED
@@ -1,28 +1,22 @@
|
|
1 |
import os
|
|
|
2 |
import streamlit as st
|
3 |
-
import requests
|
4 |
from transformers import AutoTokenizer, AutoModel
|
5 |
import torch
|
6 |
import numpy as np
|
|
|
7 |
|
8 |
# Page configuration
|
9 |
-
st.set_page_config(page_title=
|
10 |
-
st.title(
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
|
18 |
-
#
|
19 |
-
TGT_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
|
20 |
-
SERVICE = "http://umlsks.nlm.nih.gov"
|
21 |
-
SEARCH_URL = "https://uts-ws.nlm.nih.gov/rest/search/current"
|
22 |
-
CONTENT_URL = "https://uts-ws.nlm.nih.gov/rest/content/current/"
|
23 |
-
|
24 |
-
# Load KRISSBERT model
|
25 |
-
MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"
|
26 |
@st.cache_resource
|
27 |
def load_model():
|
28 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
@@ -32,80 +26,58 @@ def load_model():
|
|
32 |
|
33 |
tokenizer, model = load_model()
|
34 |
|
35 |
-
#
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
st.stop()
|
43 |
|
44 |
-
|
45 |
-
def get_st(tgt):
|
46 |
-
resp = requests.post(tgt, data={"service": SERVICE})
|
47 |
-
if resp.status_code == 200:
|
48 |
-
return resp.text
|
49 |
-
else:
|
50 |
-
st.error("Failed to obtain service ticket from UMLS API.")
|
51 |
-
st.stop()
|
52 |
|
53 |
-
#
|
54 |
@st.cache_resource
|
55 |
def embed_text(text, _tokenizer, _model):
|
56 |
-
inputs = _tokenizer(text, return_tensors=
|
57 |
with torch.no_grad():
|
58 |
outputs = _model(**inputs)
|
59 |
emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
|
60 |
return emb / np.linalg.norm(emb)
|
61 |
|
62 |
-
# UI:
|
63 |
-
st.markdown(
|
64 |
examples = [
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
]
|
69 |
-
selected = st.selectbox(
|
70 |
-
sentence = st.text_area(
|
71 |
|
72 |
-
if st.button(
|
73 |
if not sentence.strip():
|
74 |
-
st.warning(
|
75 |
else:
|
76 |
-
with st.spinner(
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
)
|
94 |
-
definition = content_resp.json().get("result", {}).get("definition", "") if content_resp.status_code == 200 else ""
|
95 |
-
candidates.append({"ui": rui, "name": name, "definition": definition})
|
96 |
-
|
97 |
-
# Embed and score
|
98 |
-
sent_emb = embed_text(sentence, tokenizer, model)
|
99 |
-
for cand in candidates:
|
100 |
-
cand_emb = embed_text(cand['name'], tokenizer, model)
|
101 |
-
cand['score'] = float(np.dot(sent_emb, cand_emb))
|
102 |
-
|
103 |
-
ranked = sorted(candidates, key=lambda x: x['score'], reverse=True)[:5]
|
104 |
-
|
105 |
-
# Display
|
106 |
-
st.success("Top UMLS candidates:")
|
107 |
-
for item in ranked:
|
108 |
-
st.markdown(f"**{item['name']}** (CUI: `{item['ui']}`) β score: {item['score']:.3f}")
|
109 |
if item['definition']:
|
110 |
-
st.markdown(
|
111 |
-
st.markdown(
|
|
|
|
|
|
1 |
import os
|
2 |
+
import json
|
3 |
import streamlit as st
|
|
|
4 |
from transformers import AutoTokenizer, AutoModel
|
5 |
import torch
|
6 |
import numpy as np
|
7 |
+
import faiss
|
8 |
|
9 |
# Page configuration
|
10 |
+
st.set_page_config(page_title='KRISSBERT UMLS Linker', layout='wide')
|
11 |
+
st.title('𧬠KRISSBERT + UMLS Entity Linker (Local FAISS)')
|
12 |
|
13 |
+
# File paths
|
14 |
+
METADATA_PATH = 'umls_metadata.json'
|
15 |
+
EMBED_PATH = 'umls_embeddings.npy'
|
16 |
+
INDEX_PATH = 'umls_index.faiss'
|
17 |
+
MODEL_NAME = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'
|
18 |
|
19 |
+
# Load model & tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
@st.cache_resource
|
21 |
def load_model():
|
22 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
|
26 |
|
27 |
tokenizer, model = load_model()
|
28 |
|
29 |
+
# Load UMLS FAISS index + metadata
|
30 |
+
@st.cache_resource
|
31 |
+
def load_umls_index():
|
32 |
+
meta = json.load(open(METADATA_PATH, 'r'))
|
33 |
+
embeddings = np.load(EMBED_PATH)
|
34 |
+
index = faiss.read_index(INDEX_PATH)
|
35 |
+
return index, meta
|
|
|
36 |
|
37 |
+
faiss_index, umls_meta = load_umls_index()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
+
# Embed text
|
40 |
@st.cache_resource
|
41 |
def embed_text(text, _tokenizer, _model):
|
42 |
+
inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
|
43 |
with torch.no_grad():
|
44 |
outputs = _model(**inputs)
|
45 |
emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
|
46 |
return emb / np.linalg.norm(emb)
|
47 |
|
48 |
+
# UI: examples and input
|
49 |
+
st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')
|
50 |
examples = [
|
51 |
+
'The patient was administered metformin for type 2 diabetes.',
|
52 |
+
'ER crowding has become a widespread issue in hospitals.',
|
53 |
+
'Tamoxifen is used in the treatment of ER-positive breast cancer.'
|
54 |
]
|
55 |
+
selected = st.selectbox('π Example queries', ['Choose...'] + examples)
|
56 |
+
sentence = st.text_area('π Sentence:', value=(selected if selected != 'Choose...' else ''))
|
57 |
|
58 |
+
if st.button('π Link Entities'):
|
59 |
if not sentence.strip():
|
60 |
+
st.warning('Please enter a sentence first.')
|
61 |
else:
|
62 |
+
with st.spinner('Embedding sentence and searching FAISSβ¦'):
|
63 |
+
sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
|
64 |
+
distances, indices = faiss_index.search(sent_emb, 5)
|
65 |
+
results = []
|
66 |
+
for idx in indices[0]:
|
67 |
+
entry = umls_meta.get(str(idx), {})
|
68 |
+
results.append({
|
69 |
+
'cui': entry.get('cui', ''),
|
70 |
+
'name': entry.get('name', ''),
|
71 |
+
'definition': entry.get('definition', ''),
|
72 |
+
'source': entry.get('source', '')
|
73 |
+
})
|
74 |
+
# Display
|
75 |
+
if results:
|
76 |
+
st.success('Top UMLS candidates:')
|
77 |
+
for item in results:
|
78 |
+
st.markdown('**' + item['name'] + '** (CUI: `' + item['cui'] + '`)')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
if item['definition']:
|
80 |
+
st.markdown('> ' + item['definition'] + '\n')
|
81 |
+
st.markdown('_Source: ' + item['source'] + '_\n---')
|
82 |
+
else:
|
83 |
+
st.info('No matches found in UMLS index.')
|