UMLS / app.py
mgbam's picture
Create app.py
ef17fc5 verified
raw
history blame
4.15 kB
import os
import streamlit as st
import requests
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
# Page configuration: wide layout plus the app title shown at the top.
st.set_page_config(page_title="KRISSBERT UMLS Linker", layout="wide")
st.title("🧬 KRISSBERT + UMLS Entity Linker on Hugging Face Spaces")

# The UMLS API key is read from the environment (configured as a Space secret).
# Nothing below can talk to UMLS without it, so report the problem and halt
# the script run instead of continuing.
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
if UMLS_API_KEY is None or UMLS_API_KEY == "":
    st.error("❗ Please set the UMLS_API_KEY as a secret in your Space.")
    st.stop()
# UMLS API endpoints
_UTS_REST_ROOT = "https://uts-ws.nlm.nih.gov/rest"

# The API key is POSTed here to obtain a ticket-granting ticket (TGT).
TGT_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
# Value of the "service" form field sent when requesting a service ticket.
SERVICE = "http://umlsks.nlm.nih.gov"
# Search endpoint and base path for content lookups.
SEARCH_URL = _UTS_REST_ROOT + "/search/current"
CONTENT_URL = _UTS_REST_ROOT + "/content/current/"
# KRISSBERT checkpoint used to embed the sentence and the candidate names.
MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"


@st.cache_resource
def load_model():
    """Load the tokenizer and encoder for ``MODEL_NAME``.

    Decorated with ``st.cache_resource`` so the loaded objects are reused
    rather than reloaded on each script run.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been switched to
        evaluation mode.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    encoder = AutoModel.from_pretrained(MODEL_NAME).eval()
    return tok, encoder


tokenizer, model = load_model()
# Functions for UMLS API authentication
def get_tgt(api_key, timeout=30):
    """Request a ticket-granting ticket (TGT) from the UMLS login service.

    Args:
        api_key: UMLS API key, sent as the ``apikey`` form field.
        timeout: Seconds to wait for the login service before giving up;
            without it a stalled connection would hang the app indefinitely.

    Returns:
        The TGT URL taken from the ``location`` header of the 201 response.

    If the request fails, the status is not 201, or the header is missing,
    an error is shown in the UI and the script run is halted via ``st.stop()``.
    """
    try:
        resp = requests.post(TGT_URL, data={"apikey": api_key}, timeout=timeout)
    except requests.RequestException:
        resp = None

    tgt_url = None
    if resp is not None and resp.status_code == 201:
        tgt_url = resp.headers.get("location")

    # A 201 without a location header is as unusable as an outright failure:
    # the caller would otherwise end up POSTing to ``None``.
    if not tgt_url:
        st.error("Failed to obtain TGT from UMLS API.")
        st.stop()
    return tgt_url
def get_st(tgt, timeout=30):
    """Request a fresh service ticket for the given ticket-granting ticket.

    This function is deliberately not cached: UMLS service tickets are valid
    for a single request and expire after a few minutes, so a memoised ticket
    would be rejected on reuse.

    Args:
        tgt: TGT URL as returned by ``get_tgt``; the ticket request is POSTed
            to it.
        timeout: Seconds to wait for the login service before giving up.

    Returns:
        The service ticket string (the body of the 200 response).

    If the request fails or the status is not 200, an error is shown in the
    UI and the script run is halted via ``st.stop()``.
    """
    try:
        resp = requests.post(tgt, data={"service": SERVICE}, timeout=timeout)
    except requests.RequestException:
        resp = None

    if resp is None or resp.status_code != 200:
        st.error("Failed to obtain service ticket from UMLS API.")
        st.stop()
    return resp.text
# Text embedding
@st.cache_data(max_entries=1024)
def _embed_text_cached(text, model_key, _tokenizer, _model):
    """Compute (and memoise) the unit-length embedding of *text*.

    The cache key is ``(text, model_key)``. The underscore-prefixed arguments
    are excluded from Streamlit's argument hashing, which avoids hashing the
    whole tokenizer/model on every call.
    """
    inputs = _tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)
    # Vector at sequence position 0 of the last hidden state
    # (the [CLS] position for BERT-style tokenizers).
    emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    return emb / np.linalg.norm(emb)


def embed_text(text, tokenizer, model):
    """Return the L2-normalised embedding of *text* as a NumPy array.

    Args:
        text: String to embed.
        tokenizer: Tokenizer called with ``return_tensors="pt"``,
            truncation and padding enabled.
        model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        The position-0 vector of the last hidden state, divided by its
        Euclidean norm.

    Results are memoised per (text, tokenizer, model) using cheap identity
    keys, and the cache is bounded so free-form user input cannot grow it
    without limit.
    """
    return _embed_text_cached(text, (id(tokenizer), id(model)), tokenizer, model)
# UI: input box, example sentences, and the link-and-rank action.
REQUEST_TIMEOUT_S = 30  # seconds to wait on each UMLS HTTP request


def _fetch_definition(cui, tgt):
    """Return the first definition text UMLS has for *cui*, or "" if none.

    A failed or empty lookup is treated as "no definition" so one bad
    candidate does not abort the whole ranking.
    """
    try:
        resp = requests.get(
            f"{CONTENT_URL}CUI/{cui}/definitions",
            # Ask for a ticket per request: a service ticket is single-use.
            params={"ticket": get_st(tgt)},
            timeout=REQUEST_TIMEOUT_S,
        )
    except requests.RequestException:
        return ""
    if resp.status_code != 200:
        return ""
    try:
        entries = resp.json().get("result", [])
    except ValueError:
        return ""
    if isinstance(entries, list):
        for entry in entries:
            if isinstance(entry, dict) and entry.get("value"):
                return entry["value"]
    return ""


st.markdown("Enter a biomedical sentence to link entities via UMLS API + KRISSBERT:")
examples = [
    "The patient was administered metformin for type 2 diabetes.",
    "ER crowding has become a widespread issue in hospitals.",
    "Tamoxifen is used in the treatment of ER-positive breast cancer.",
]
selected = st.selectbox("🔍 Example queries", ["Choose..."] + examples)
sentence = st.text_area("📝 Sentence:", value=(selected if selected != "Choose..." else ""))

if st.button("🔗 Link Entities"):
    if not sentence.strip():
        st.warning("Please enter a sentence first.")
    else:
        with st.spinner("Querying UMLS API and ranking... 🧠"):
            # Authenticate once; individual requests each get their own ticket.
            tgt = get_tgt(UMLS_API_KEY)

            # UMLS search for mentions
            search_resp = requests.get(
                SEARCH_URL,
                params={"string": sentence, "ticket": get_st(tgt)},
                timeout=REQUEST_TIMEOUT_S,
            )
            search_resp.raise_for_status()
            results = search_resp.json().get("result", {}).get("results", [])

            candidates = []
            for res in results[:10]:
                cui = res.get("ui")
                name = res.get("name")
                # An empty search comes back as a placeholder entry with
                # ui == "NONE"; it is not a real concept, so skip it.
                if not cui or cui == "NONE" or not name:
                    continue
                candidates.append(
                    {"ui": cui, "name": name, "definition": _fetch_definition(cui, tgt)}
                )

            # Embed and score: both vectors are unit length, so the dot
            # product is their cosine similarity.
            sent_emb = embed_text(sentence, tokenizer, model)
            for cand in candidates:
                cand_emb = embed_text(cand["name"], tokenizer, model)
                cand["score"] = float(np.dot(sent_emb, cand_emb))
            ranked = sorted(candidates, key=lambda c: c["score"], reverse=True)[:5]

        # Display
        if not ranked:
            st.info("No UMLS candidates were found for this sentence.")
        else:
            st.success("Top UMLS candidates:")
            for item in ranked:
                st.markdown(
                    f"**{item['name']}** (CUI: `{item['ui']}`) — score: {item['score']:.3f}"
                )
                if item["definition"]:
                    st.markdown(f"> {item['definition']}\n")
                st.markdown("---")