|
import os |
|
import streamlit as st |
|
import requests |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
import numpy as np |
|
|
|
|
|
# Page-level configuration; this is the first Streamlit call in the script.
st.set_page_config(page_title="KRISSBERT UMLS Linker", layout="wide")

# Title literal previously contained mis-encoded emoji bytes; restored to 🧬.
st.title("🧬 KRISSBERT + UMLS Entity Linker on Hugging Face Spaces")
|
|
|
|
|
# The UMLS API key is read from the environment (on a Space, secrets are
# exposed this way). Surrounding whitespace is stripped because a pasted
# secret with a trailing newline would otherwise be sent verbatim to UTS.
# Without a key no UTS request can be authenticated, so the run halts here.
UMLS_API_KEY = (os.getenv("UMLS_API_KEY") or "").strip()

if not UMLS_API_KEY:
    st.error("❌ Please set the UMLS_API_KEY as a secret in your Space.")
    st.stop()
|
|
|
|
|
# --- UMLS Terminology Services (UTS) endpoints -------------------------------
# CAS login endpoint: an API key POSTed here yields a Ticket-Granting Ticket,
# returned as the URL in the response's Location header.
TGT_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
# Service identifier sent when exchanging a TGT for a service ticket.
SERVICE = "http://umlsks.nlm.nih.gov"

# REST base shared by the search and content endpoints; removed from the module
# namespace once the public constants are built.
_uts_rest = "https://uts-ws.nlm.nih.gov/rest"
SEARCH_URL = _uts_rest + "/search/current"
CONTENT_URL = _uts_rest + "/content/current/"  # trailing slash: callers append a path
del _uts_rest
|
|
|
|
|
# Hugging Face Hub identifier of the KRISSBERT checkpoint used for embeddings.
MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"


@st.cache_resource
def load_model():
    """Load the KRISSBERT tokenizer and encoder for MODEL_NAME.

    Wrapped in ``st.cache_resource`` so the (large) model is loaded once and the
    same objects are reused across script reruns instead of reloading each time.

    Returns:
        ``(tokenizer, model)`` with the model switched to evaluation mode.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    # eval() returns the module itself; it disables training-only layers such as dropout.
    encoder = AutoModel.from_pretrained(MODEL_NAME).eval()
    return tok, encoder


tokenizer, model = load_model()
|
|
|
|
|
def get_tgt(api_key, timeout=30):
    """Obtain a UMLS Ticket-Granting Ticket (TGT) URL for *api_key*.

    POSTs the API key to ``TGT_URL``. On HTTP 201 the TGT URL is taken from the
    ``Location`` response header. Any failure (network error, timeout, non-201
    status, or a missing header) is shown in the app and halts the current
    script run via ``st.stop()``.

    Args:
        api_key: The UMLS account API key.
        timeout: Seconds to wait for the UTS login server; previously there was
            no timeout, so an unresponsive server could hang the app forever.

    Returns:
        The TGT URL string (used to request service tickets).
    """
    try:
        resp = requests.post(TGT_URL, data={"apikey": api_key}, timeout=timeout)
    except requests.RequestException as err:
        st.error(f"Failed to obtain TGT from UMLS API: {err}")
        st.stop()

    # A 201 without a Location header would otherwise return None silently.
    tgt = resp.headers.get("location") if resp.status_code == 201 else None
    if not tgt:
        st.error("Failed to obtain TGT from UMLS API.")
        st.stop()
    return tgt
|
|
|
def get_st(tgt, timeout=30):
    """Request a UMLS service ticket (ST) from the TGT URL *tgt*.

    Deliberately NOT cached. Per the UTS authentication documentation, a service
    ticket is valid for a single request and expires after a few minutes. The
    previous ``st.cache_data(ttl=3600)`` could therefore hand back an
    already-used ticket that UTS rejects.

    Failures (network error, timeout, non-200 status) are shown in the app and
    halt the current script run via ``st.stop()``.

    Args:
        tgt: TGT URL as returned by ``get_tgt``.
        timeout: Seconds to wait for the UTS server.

    Returns:
        The service-ticket string (the response body).
    """
    try:
        resp = requests.post(tgt, data={"service": SERVICE}, timeout=timeout)
    except requests.RequestException as err:
        st.error(f"Failed to obtain service ticket from UMLS API: {err}")
        st.stop()

    if resp.status_code != 200:
        st.error("Failed to obtain service ticket from UMLS API.")
        st.stop()
    return resp.text
|
|
|
|
|
@st.cache_data(max_entries=1024, show_spinner=False)
def _embed_text_cached(text, model_key, _tokenizer, _model):
    """Compute the L2-normalised first-token embedding of *text* (cached).

    The leading underscores on ``_tokenizer``/``_model`` tell Streamlit not to
    hash those arguments. Hashing them is either impossible or costs a pass over
    the whole model on every call. ``model_key`` stands in for the model's
    identity in the cache key instead. ``max_entries`` bounds memory use.
    """
    inputs = _tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)
    # Position 0 of the last layer (the [CLS] token for BERT-style tokenizers).
    # squeeze() drops the batch axis; assumes *text* is a single string.
    emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    norm = np.linalg.norm(emb)
    # Avoid 0/0 -> NaN for a degenerate all-zero vector.
    return emb / norm if norm > 0 else emb


def embed_text(text, tokenizer, model):
    """Return the unit-length first-token embedding of *text* under *model*.

    Results are memoised per (text, model) via ``_embed_text_cached``. Because the
    vectors are L2-normalised, the dot product of two results is their cosine
    similarity.

    Args:
        text: A single input string.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Hugging Face encoder whose output has ``last_hidden_state``.

    Returns:
        A 1-D numpy array (a fresh copy per call, so callers may mutate it).
    """
    # Identify the model without hashing its weights; fall back to object id.
    model_key = getattr(model, "name_or_path", "") or f"id-{id(model)}"
    return _embed_text_cached(text, model_key, tokenizer, model)
|
|
|
|
|
st.markdown("Enter a biomedical sentence to link entities via UMLS API + KRISSBERT:")

# Canned inputs. Two of them use the abbreviation "ER" in different contexts,
# which makes them handy for checking context-sensitive ranking.
examples = [
    "The patient was administered metformin for type 2 diabetes.",
    "ER crowding has become a widespread issue in hospitals.",
    "Tamoxifen is used in the treatment of ER-positive breast cancer.",
]

# Widget labels previously contained mis-encoded emoji bytes; replaced with
# readable emoji.
selected = st.selectbox("💡 Example queries", ["Choose..."] + examples)

# Pre-fill the text area with the chosen example; the "Choose..." placeholder
# maps to an empty box.
sentence = st.text_area("📝 Sentence:", value=(selected if selected != "Choose..." else ""))
|
|
|
def _new_service_ticket(tgt, timeout=30):
    """Request a fresh single-use UMLS service ticket from the TGT URL *tgt*.

    Fetched directly, rather than through ``get_st``, so that every UTS request
    below gets its own ticket regardless of any caching applied to ``get_st``.
    Per the UTS authentication docs, a service ticket is valid for one request.

    Raises:
        requests.RequestException: on network failure, timeout or non-2xx status.
    """
    resp = requests.post(tgt, data={"service": SERVICE}, timeout=timeout)
    resp.raise_for_status()
    return resp.text


def _search_umls(query, tgt, limit=10, timeout=30):
    """Return up to *limit* UTS search hits (dicts with a ``ui`` key) for *query*.

    Entries without a ``ui``, and the UTS "no results" placeholder
    (``ui == "NONE"``), are dropped.

    Raises:
        requests.RequestException: on network failure, timeout or non-2xx status.
        ValueError: if the response body is not valid JSON.
    """
    resp = requests.get(
        SEARCH_URL,
        params={"string": query, "ticket": _new_service_ticket(tgt, timeout)},
        timeout=timeout,
    )
    resp.raise_for_status()
    results = resp.json().get("result", {}).get("results", [])
    hits = [r for r in results if r.get("ui") and r.get("ui") != "NONE"]
    return hits[:limit]


def _fetch_definition(cui, tgt, timeout=30):
    """Return the first UMLS definition text for concept *cui*, or "" if none.

    Uses the ``CUI/{cui}/definitions`` content endpoint, whose ``result`` is a
    list of definition objects carrying a ``value``. The previous code requested
    ``content/current/{cui}`` and read a ``definition`` key, so it never found
    one. This is best effort: any network, HTTP or JSON failure yields "" so that
    a single missing definition does not abort the ranking.
    """
    try:
        resp = requests.get(
            f"{CONTENT_URL}CUI/{cui}/definitions",
            params={"ticket": _new_service_ticket(tgt, timeout)},
            timeout=timeout,
        )
        if resp.status_code != 200:
            return ""
        entries = resp.json().get("result", [])
    except (requests.RequestException, ValueError):
        return ""
    if not isinstance(entries, list):
        return ""
    return next(
        (e["value"] for e in entries if isinstance(e, dict) and e.get("value")),
        "",
    )


if st.button("🔗 Link Entities"):
    if not sentence.strip():
        st.warning("Please enter a sentence first.")
    else:
        with st.spinner("Querying UMLS API and ranking... 🧠"):
            tgt = get_tgt(UMLS_API_KEY)

            # NOTE(review): the whole sentence is sent as a single UTS search
            # string; no mention detection happens here, so long sentences may
            # return few or no hits.
            try:
                hits = _search_umls(sentence, tgt)
            except (requests.RequestException, ValueError) as err:
                st.error(f"UMLS search failed: {err}")
                st.stop()

            # Search hits carry the concept's CUI in "ui".
            candidates = [
                {
                    "ui": hit["ui"],
                    "name": hit.get("name") or "",
                    "definition": _fetch_definition(hit["ui"], tgt),
                }
                for hit in hits
            ]

            # embed_text returns L2-normalised vectors, so the dot product is
            # the cosine similarity between sentence and candidate name.
            sent_emb = embed_text(sentence, tokenizer, model)
            for cand in candidates:
                cand_emb = embed_text(cand["name"], tokenizer, model)
                cand["score"] = float(np.dot(sent_emb, cand_emb))

            ranked = sorted(candidates, key=lambda c: c["score"], reverse=True)[:5]

        if not ranked:
            st.warning("No UMLS candidates found for this sentence.")
        else:
            st.success("Top UMLS candidates:")
            for item in ranked:
                st.markdown(
                    f"**{item['name']}** (CUI: `{item['ui']}`) — score: {item['score']:.3f}"
                )
                if item["definition"]:
                    st.markdown(f"> {item['definition']}\n")
                st.markdown("---")
|
|