mgbam commited on
Commit
ef17fc5
Β·
verified Β·
1 Parent(s): d3006b4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import requests
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import torch
6
+ import numpy as np
7
+
8
+ # Page configuration
9
+ st.set_page_config(page_title="KRISSBERT UMLS Linker", layout="wide")
10
+ st.title("🧬 KRISSBERT + UMLS Entity Linker on Hugging Face Spaces")
11
+
12
+ # Environment variables
13
+ UMLS_API_KEY = os.getenv("UMLS_API_KEY")
14
+ if not UMLS_API_KEY:
15
+ st.error("❗ Please set the UMLS_API_KEY as a secret in your Space.")
16
+ st.stop()
17
+
18
+ # UMLS API endpoints
19
+ TGT_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
20
+ SERVICE = "http://umlsks.nlm.nih.gov"
21
+ SEARCH_URL = "https://uts-ws.nlm.nih.gov/rest/search/current"
22
+ CONTENT_URL = "https://uts-ws.nlm.nih.gov/rest/content/current/"
23
+
24
+ # Load KRISSBERT model
25
+ MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"
26
+ @st.cache_resource
27
+ def load_model():
28
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
29
+ model = AutoModel.from_pretrained(MODEL_NAME)
30
+ model.eval()
31
+ return tokenizer, model
32
+
33
+ tokenizer, model = load_model()
34
+
35
+ # Functions for UMLS API authentication
36
+ def get_tgt(api_key):
37
+ resp = requests.post(TGT_URL, data={"apikey": api_key})
38
+ if resp.status_code == 201:
39
+ return resp.headers.get('location')
40
+ else:
41
+ st.error("Failed to obtain TGT from UMLS API.")
42
+ st.stop()
43
+
44
+ @st.cache_data(ttl=3600)
45
+ def get_st(tgt):
46
+ resp = requests.post(tgt, data={"service": SERVICE})
47
+ if resp.status_code == 200:
48
+ return resp.text
49
+ else:
50
+ st.error("Failed to obtain service ticket from UMLS API.")
51
+ st.stop()
52
+
53
+ # Text embedding
54
+ @st.cache_resource
55
+ def embed_text(text, tokenizer, model):
56
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
57
+ with torch.no_grad():
58
+ outputs = model(**inputs)
59
+ emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
60
+ return emb / np.linalg.norm(emb)
61
+
62
+ # UI: Input box and examples
63
+ st.markdown("Enter a biomedical sentence to link entities via UMLS API + KRISSBERT:")
64
+ examples = [
65
+ "The patient was administered metformin for type 2 diabetes.",
66
+ "ER crowding has become a widespread issue in hospitals.",
67
+ "Tamoxifen is used in the treatment of ER-positive breast cancer."
68
+ ]
69
+ selected = st.selectbox("πŸ” Example queries", ["Choose..."] + examples)
70
+ sentence = st.text_area("πŸ“ Sentence:", value=(selected if selected != "Choose..." else ""))
71
+
72
+ if st.button("πŸ”— Link Entities"):
73
+ if not sentence.strip():
74
+ st.warning("Please enter a sentence first.")
75
+ else:
76
+ with st.spinner("Querying UMLS API and ranking... 🧠"):
77
+ # Authenticate
78
+ tgt = get_tgt(UMLS_API_KEY)
79
+ sticket = get_st(tgt)
80
+
81
+ # UMLS search for mentions
82
+ params = {"string": sentence, "ticket": sticket}
83
+ search_resp = requests.get(SEARCH_URL, params=params)
84
+ search_resp.raise_for_status()
85
+ results = search_resp.json().get("result", {}).get("results", [])
86
+
87
+ candidates = []
88
+ for res in results[:10]:
89
+ rui = res.get("ui")
90
+ name = res.get("name")
91
+ content_resp = requests.get(
92
+ f"{CONTENT_URL}{rui}", params={"ticket": sticket}
93
+ )
94
+ definition = content_resp.json().get("result", {}).get("definition", "") if content_resp.status_code == 200 else ""
95
+ candidates.append({"ui": rui, "name": name, "definition": definition})
96
+
97
+ # Embed and score
98
+ sent_emb = embed_text(sentence, tokenizer, model)
99
+ for cand in candidates:
100
+ cand_emb = embed_text(cand['name'], tokenizer, model)
101
+ cand['score'] = float(np.dot(sent_emb, cand_emb))
102
+
103
+ ranked = sorted(candidates, key=lambda x: x['score'], reverse=True)[:5]
104
+
105
+ # Display
106
+ st.success("Top UMLS candidates:")
107
+ for item in ranked:
108
+ st.markdown(f"**{item['name']}** (CUI: `{item['ui']}`) β€” score: {item['score']:.3f}")
109
+ if item['definition']:
110
+ st.markdown(f"> {item['definition']}\n")
111
+ st.markdown("---")