mgbam commited on
Commit
afa884d
Β·
verified Β·
1 Parent(s): 09f0258

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -78
app.py CHANGED
@@ -1,28 +1,22 @@
1
  import os
 
2
  import streamlit as st
3
- import requests
4
  from transformers import AutoTokenizer, AutoModel
5
  import torch
6
  import numpy as np
 
7
 
8
  # Page configuration
9
- st.set_page_config(page_title="KRISSBERT UMLS Linker", layout="wide")
10
- st.title("🧬 KRISSBERT + UMLS Entity Linker on Hugging Face Spaces")
11
 
12
- # Environment variables
13
- UMLS_API_KEY = os.getenv("UMLS_API_KEY")
14
- if not UMLS_API_KEY:
15
- st.error("❗ Please set the UMLS_API_KEY as a secret in your Space.")
16
- st.stop()
17
 
18
- # UMLS API endpoints
19
- TGT_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
20
- SERVICE = "http://umlsks.nlm.nih.gov"
21
- SEARCH_URL = "https://uts-ws.nlm.nih.gov/rest/search/current"
22
- CONTENT_URL = "https://uts-ws.nlm.nih.gov/rest/content/current/"
23
-
24
- # Load KRISSBERT model
25
- MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"
26
  @st.cache_resource
27
  def load_model():
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -32,80 +26,58 @@ def load_model():
32
 
33
  tokenizer, model = load_model()
34
 
35
- # Functions for UMLS API authentication
36
- def get_tgt(api_key):
37
- resp = requests.post(TGT_URL, data={"apikey": api_key})
38
- if resp.status_code == 201:
39
- return resp.headers.get('location')
40
- else:
41
- st.error("Failed to obtain TGT from UMLS API.")
42
- st.stop()
43
 
44
- @st.cache_data(ttl=3600)
45
- def get_st(tgt):
46
- resp = requests.post(tgt, data={"service": SERVICE})
47
- if resp.status_code == 200:
48
- return resp.text
49
- else:
50
- st.error("Failed to obtain service ticket from UMLS API.")
51
- st.stop()
52
 
53
- # Text embedding (tokenizer and model are unhashable, prefix with underscore)
54
  @st.cache_resource
55
  def embed_text(text, _tokenizer, _model):
56
- inputs = _tokenizer(text, return_tensors="pt", truncation=True, padding=True)
57
  with torch.no_grad():
58
  outputs = _model(**inputs)
59
  emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
60
  return emb / np.linalg.norm(emb)
61
 
62
- # UI: Input box and examples
63
- st.markdown("Enter a biomedical sentence to link entities via UMLS API + KRISSBERT:")
64
  examples = [
65
- "The patient was administered metformin for type 2 diabetes.",
66
- "ER crowding has become a widespread issue in hospitals.",
67
- "Tamoxifen is used in the treatment of ER-positive breast cancer."
68
  ]
69
- selected = st.selectbox("πŸ” Example queries", ["Choose..."] + examples)
70
- sentence = st.text_area("πŸ“ Sentence:", value=(selected if selected != "Choose..." else ""))
71
 
72
- if st.button("πŸ”— Link Entities"):
73
  if not sentence.strip():
74
- st.warning("Please enter a sentence first.")
75
  else:
76
- with st.spinner("Querying UMLS API and ranking... 🧠"):
77
- # Authenticate
78
- tgt = get_tgt(UMLS_API_KEY)
79
- sticket = get_st(tgt)
80
-
81
- # UMLS search for mentions
82
- params = {"string": sentence, "ticket": sticket}
83
- search_resp = requests.get(SEARCH_URL, params=params)
84
- search_resp.raise_for_status()
85
- results = search_resp.json().get("result", {}).get("results", [])
86
-
87
- candidates = []
88
- for res in results[:10]:
89
- rui = res.get("ui")
90
- name = res.get("name")
91
- content_resp = requests.get(
92
- f"{CONTENT_URL}{rui}", params={"ticket": sticket}
93
- )
94
- definition = content_resp.json().get("result", {}).get("definition", "") if content_resp.status_code == 200 else ""
95
- candidates.append({"ui": rui, "name": name, "definition": definition})
96
-
97
- # Embed and score
98
- sent_emb = embed_text(sentence, tokenizer, model)
99
- for cand in candidates:
100
- cand_emb = embed_text(cand['name'], tokenizer, model)
101
- cand['score'] = float(np.dot(sent_emb, cand_emb))
102
-
103
- ranked = sorted(candidates, key=lambda x: x['score'], reverse=True)[:5]
104
-
105
- # Display
106
- st.success("Top UMLS candidates:")
107
- for item in ranked:
108
- st.markdown(f"**{item['name']}** (CUI: `{item['ui']}`) β€” score: {item['score']:.3f}")
109
  if item['definition']:
110
- st.markdown(f"> {item['definition']}\n")
111
- st.markdown("---")
 
 
 
1
  import os
2
+ import json
3
  import streamlit as st
 
4
  from transformers import AutoTokenizer, AutoModel
5
  import torch
6
  import numpy as np
7
+ import faiss
8
 
9
  # Page configuration
10
+ st.set_page_config(page_title='KRISSBERT UMLS Linker', layout='wide')
11
+ st.title('🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)')
12
 
13
+ # File paths
14
+ METADATA_PATH = 'umls_metadata.json'
15
+ EMBED_PATH = 'umls_embeddings.npy'
16
+ INDEX_PATH = 'umls_index.faiss'
17
+ MODEL_NAME = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'
18
 
19
+ # Load model & tokenizer
 
 
 
 
 
 
 
20
  @st.cache_resource
21
  def load_model():
22
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
26
 
27
  tokenizer, model = load_model()
28
 
29
+ # Load UMLS FAISS index + metadata
30
+ @st.cache_resource
31
+ def load_umls_index():
32
+ meta = json.load(open(METADATA_PATH, 'r'))
33
+ embeddings = np.load(EMBED_PATH)
34
+ index = faiss.read_index(INDEX_PATH)
35
+ return index, meta
 
36
 
37
+ faiss_index, umls_meta = load_umls_index()
 
 
 
 
 
 
 
38
 
39
+ # Embed text
40
  @st.cache_resource
41
  def embed_text(text, _tokenizer, _model):
42
+ inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
43
  with torch.no_grad():
44
  outputs = _model(**inputs)
45
  emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
46
  return emb / np.linalg.norm(emb)
47
 
48
+ # UI: examples and input
49
+ st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')
50
  examples = [
51
+ 'The patient was administered metformin for type 2 diabetes.',
52
+ 'ER crowding has become a widespread issue in hospitals.',
53
+ 'Tamoxifen is used in the treatment of ER-positive breast cancer.'
54
  ]
55
+ selected = st.selectbox('πŸ” Example queries', ['Choose...'] + examples)
56
+ sentence = st.text_area('πŸ“ Sentence:', value=(selected if selected != 'Choose...' else ''))
57
 
58
+ if st.button('πŸ”— Link Entities'):
59
  if not sentence.strip():
60
+ st.warning('Please enter a sentence first.')
61
  else:
62
+ with st.spinner('Embedding sentence and searching FAISS…'):
63
+ sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
64
+ distances, indices = faiss_index.search(sent_emb, 5)
65
+ results = []
66
+ for idx in indices[0]:
67
+ entry = umls_meta.get(str(idx), {})
68
+ results.append({
69
+ 'cui': entry.get('cui', ''),
70
+ 'name': entry.get('name', ''),
71
+ 'definition': entry.get('definition', ''),
72
+ 'source': entry.get('source', '')
73
+ })
74
+ # Display
75
+ if results:
76
+ st.success('Top UMLS candidates:')
77
+ for item in results:
78
+ st.markdown('**' + item['name'] + '** (CUI: `' + item['cui'] + '`)')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if item['definition']:
80
+ st.markdown('> ' + item['definition'] + '\n')
81
+ st.markdown('_Source: ' + item['source'] + '_\n---')
82
+ else:
83
+ st.info('No matches found in UMLS index.')