Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
import requests
|
4 |
+
from transformers import AutoTokenizer, AutoModel
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
# Page configuration
|
9 |
+
st.set_page_config(page_title="KRISSBERT UMLS Linker", layout="wide")
|
10 |
+
st.title("𧬠KRISSBERT + UMLS Entity Linker on Hugging Face Spaces")
|
11 |
+
|
12 |
+
# Environment variables
|
13 |
+
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
|
14 |
+
if not UMLS_API_KEY:
|
15 |
+
st.error("β Please set the UMLS_API_KEY as a secret in your Space.")
|
16 |
+
st.stop()
|
17 |
+
|
18 |
+
# UMLS API endpoints
|
19 |
+
TGT_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
|
20 |
+
SERVICE = "http://umlsks.nlm.nih.gov"
|
21 |
+
SEARCH_URL = "https://uts-ws.nlm.nih.gov/rest/search/current"
|
22 |
+
CONTENT_URL = "https://uts-ws.nlm.nih.gov/rest/content/current/"
|
23 |
+
|
24 |
+
# Load KRISSBERT model
|
25 |
+
MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"
|
26 |
+
@st.cache_resource
|
27 |
+
def load_model():
|
28 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
29 |
+
model = AutoModel.from_pretrained(MODEL_NAME)
|
30 |
+
model.eval()
|
31 |
+
return tokenizer, model
|
32 |
+
|
33 |
+
tokenizer, model = load_model()
|
34 |
+
|
35 |
+
# Functions for UMLS API authentication
|
36 |
+
def get_tgt(api_key):
|
37 |
+
resp = requests.post(TGT_URL, data={"apikey": api_key})
|
38 |
+
if resp.status_code == 201:
|
39 |
+
return resp.headers.get('location')
|
40 |
+
else:
|
41 |
+
st.error("Failed to obtain TGT from UMLS API.")
|
42 |
+
st.stop()
|
43 |
+
|
44 |
+
@st.cache_data(ttl=3600)
|
45 |
+
def get_st(tgt):
|
46 |
+
resp = requests.post(tgt, data={"service": SERVICE})
|
47 |
+
if resp.status_code == 200:
|
48 |
+
return resp.text
|
49 |
+
else:
|
50 |
+
st.error("Failed to obtain service ticket from UMLS API.")
|
51 |
+
st.stop()
|
52 |
+
|
53 |
+
# Text embedding
|
54 |
+
@st.cache_resource
|
55 |
+
def embed_text(text, tokenizer, model):
|
56 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
57 |
+
with torch.no_grad():
|
58 |
+
outputs = model(**inputs)
|
59 |
+
emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
|
60 |
+
return emb / np.linalg.norm(emb)
|
61 |
+
|
62 |
+
# UI: Input box and examples
|
63 |
+
st.markdown("Enter a biomedical sentence to link entities via UMLS API + KRISSBERT:")
|
64 |
+
examples = [
|
65 |
+
"The patient was administered metformin for type 2 diabetes.",
|
66 |
+
"ER crowding has become a widespread issue in hospitals.",
|
67 |
+
"Tamoxifen is used in the treatment of ER-positive breast cancer."
|
68 |
+
]
|
69 |
+
selected = st.selectbox("π Example queries", ["Choose..."] + examples)
|
70 |
+
sentence = st.text_area("π Sentence:", value=(selected if selected != "Choose..." else ""))
|
71 |
+
|
72 |
+
if st.button("π Link Entities"):
|
73 |
+
if not sentence.strip():
|
74 |
+
st.warning("Please enter a sentence first.")
|
75 |
+
else:
|
76 |
+
with st.spinner("Querying UMLS API and ranking... π§ "):
|
77 |
+
# Authenticate
|
78 |
+
tgt = get_tgt(UMLS_API_KEY)
|
79 |
+
sticket = get_st(tgt)
|
80 |
+
|
81 |
+
# UMLS search for mentions
|
82 |
+
params = {"string": sentence, "ticket": sticket}
|
83 |
+
search_resp = requests.get(SEARCH_URL, params=params)
|
84 |
+
search_resp.raise_for_status()
|
85 |
+
results = search_resp.json().get("result", {}).get("results", [])
|
86 |
+
|
87 |
+
candidates = []
|
88 |
+
for res in results[:10]:
|
89 |
+
rui = res.get("ui")
|
90 |
+
name = res.get("name")
|
91 |
+
content_resp = requests.get(
|
92 |
+
f"{CONTENT_URL}{rui}", params={"ticket": sticket}
|
93 |
+
)
|
94 |
+
definition = content_resp.json().get("result", {}).get("definition", "") if content_resp.status_code == 200 else ""
|
95 |
+
candidates.append({"ui": rui, "name": name, "definition": definition})
|
96 |
+
|
97 |
+
# Embed and score
|
98 |
+
sent_emb = embed_text(sentence, tokenizer, model)
|
99 |
+
for cand in candidates:
|
100 |
+
cand_emb = embed_text(cand['name'], tokenizer, model)
|
101 |
+
cand['score'] = float(np.dot(sent_emb, cand_emb))
|
102 |
+
|
103 |
+
ranked = sorted(candidates, key=lambda x: x['score'], reverse=True)[:5]
|
104 |
+
|
105 |
+
# Display
|
106 |
+
st.success("Top UMLS candidates:")
|
107 |
+
for item in ranked:
|
108 |
+
st.markdown(f"**{item['name']}** (CUI: `{item['ui']}`) β score: {item['score']:.3f}")
|
109 |
+
if item['definition']:
|
110 |
+
st.markdown(f"> {item['definition']}\n")
|
111 |
+
st.markdown("---")
|