"""Streamlit app: link entities in a biomedical sentence to UMLS concepts.

Candidates come from the UMLS UTS REST search API and are re-ranked by the
cosine similarity between KRISSBERT [CLS] embeddings of the sentence and of
each candidate's name.
"""

import os

import numpy as np
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel

# Page configuration
st.set_page_config(page_title="KRISSBERT UMLS Linker", layout="wide")
st.title("🧬 KRISSBERT + UMLS Entity Linker on Hugging Face Spaces")

# Environment variables
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
if not UMLS_API_KEY:
    st.error("❗ Please set the UMLS_API_KEY as a secret in your Space.")
    st.stop()

# UMLS API endpoints
TGT_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
SERVICE = "http://umlsks.nlm.nih.gov"
SEARCH_URL = "https://uts-ws.nlm.nih.gov/rest/search/current"
CONTENT_URL = "https://uts-ws.nlm.nih.gov/rest/content/current/"

# Seconds to wait for any single UMLS HTTP call, so a hung connection cannot
# freeze the app indefinitely.
REQUEST_TIMEOUT = 30
# Number of search hits considered, and number shown after ranking.
MAX_CANDIDATES = 10
TOP_K = 5

# Load KRISSBERT model
MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"


@st.cache_resource
def load_model():
    """Load the KRISSBERT tokenizer and model once and reuse them across reruns."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()


# --- UMLS API authentication (CAS ticket protocol) --------------------------
# NOTE(review): current UTS documentation also describes passing `apiKey`
# directly as a query parameter, which would make this ticket exchange
# unnecessary -- confirm before simplifying.


def get_tgt(api_key):
    """Exchange the UMLS API key for a ticket-granting ticket (TGT) URL.

    Shows an error and stops the Streamlit script on any non-201 reply.
    """
    resp = requests.post(TGT_URL, data={"apikey": api_key}, timeout=REQUEST_TIMEOUT)
    if resp.status_code != 201:
        st.error(f"Failed to obtain TGT from UMLS API (HTTP {resp.status_code}).")
        st.stop()
    return resp.headers.get("location")


def get_st(tgt):
    """Request a fresh service ticket from the TGT URL *tgt*.

    Service tickets are single-use, so this is intentionally not cached:
    every UMLS request must call it to get its own ticket. Shows an error
    and stops the Streamlit script on any non-200 reply.
    """
    resp = requests.post(tgt, data={"service": SERVICE}, timeout=REQUEST_TIMEOUT)
    if resp.status_code != 200:
        st.error(f"Failed to obtain service ticket from UMLS API (HTTP {resp.status_code}).")
        st.stop()
    return resp.text


def search_umls(query, tgt, limit=MAX_CANDIDATES):
    """Return up to *limit* ``{"ui", "name"}`` search hits for *query*.

    Hits without a name, and the placeholder hit UTS may return for "no
    match" (``ui == "NONE"``), are dropped. Raises ``requests.HTTPError``
    on a non-2xx reply.
    """
    params = {"string": query, "ticket": get_st(tgt)}
    resp = requests.get(SEARCH_URL, params=params, timeout=REQUEST_TIMEOUT)
    resp.raise_for_status()
    results = resp.json().get("result", {}).get("results", [])
    hits = [
        {"ui": r.get("ui"), "name": r.get("name")}
        for r in results
        if r.get("ui") and r.get("ui") != "NONE" and r.get("name")
    ]
    return hits[:limit]


def fetch_definition(cui, tgt):
    """Return the first definition UTS lists for concept *cui*, or "".

    A missing or failed definitions lookup is not fatal: the candidate is
    still shown, just without a definition.
    """
    resp = requests.get(
        f"{CONTENT_URL}CUI/{cui}/definitions",
        params={"ticket": get_st(tgt)},
        timeout=REQUEST_TIMEOUT,
    )
    if resp.status_code != 200:
        return ""
    # Presumably a list of {"rootSource": ..., "value": ...} entries -- take
    # the first non-empty value.
    for entry in resp.json().get("result") or []:
        if isinstance(entry, dict) and entry.get("value"):
            return entry["value"]
    return ""


# Text embedding.
# These functions are deliberately not wrapped in st.cache_*: Streamlit would
# have to hash the tokenizer and model arguments on every call, and would
# keep every user sentence's embedding alive indefinitely.
def embed_texts(texts, tokenizer, model):
    """Return L2-normalised [CLS] embeddings, one row per string in *texts*."""
    inputs = tokenizer(list(texts), return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embs = outputs.last_hidden_state[:, 0, :].cpu().numpy()
    norms = np.linalg.norm(embs, axis=1, keepdims=True)
    # Guard against a zero vector so normalisation never produces NaN.
    return embs / np.where(norms == 0, 1.0, norms)


def embed_text(text, tokenizer, model):
    """Return the L2-normalised [CLS] embedding of a single string."""
    return embed_texts([text], tokenizer, model)[0]


# UI: Input box and examples
st.markdown("Enter a biomedical sentence to link entities via UMLS API + KRISSBERT:")
examples = [
    "The patient was administered metformin for type 2 diabetes.",
    "ER crowding has become a widespread issue in hospitals.",
    "Tamoxifen is used in the treatment of ER-positive breast cancer.",
]
selected = st.selectbox("🔍 Example queries", ["Choose..."] + examples)
sentence = st.text_area("📝 Sentence:", value=(selected if selected != "Choose..." else ""))

if st.button("🔗 Link Entities"):
    if not sentence.strip():
        st.warning("Please enter a sentence first.")
    else:
        with st.spinner("Querying UMLS API and ranking... 🧠"):
            try:
                # Authenticate once per click; each request below gets its
                # own single-use service ticket from this TGT.
                tgt = get_tgt(UMLS_API_KEY)
                candidates = search_umls(sentence, tgt)
                for cand in candidates:
                    cand["definition"] = fetch_definition(cand["ui"], tgt)
            except requests.RequestException as err:
                st.error(f"UMLS API request failed: {err}")
                st.stop()

            # Embed and score: embeddings are unit-length, so the dot
            # product is the cosine similarity.
            if candidates:
                sent_emb = embed_text(sentence, tokenizer, model)
                name_embs = embed_texts([c["name"] for c in candidates], tokenizer, model)
                for cand, score in zip(candidates, name_embs @ sent_emb):
                    cand["score"] = float(score)
            ranked = sorted(candidates, key=lambda c: c["score"], reverse=True)[:TOP_K]

        # Display
        if not ranked:
            st.info("No UMLS candidates found for this sentence.")
        else:
            st.success("Top UMLS candidates:")
            for item in ranked:
                st.markdown(f"**{item['name']}** (CUI: `{item['ui']}`) — score: {item['score']:.3f}")
                if item["definition"]:
                    st.markdown(f"> {item['definition']}\n")
                st.markdown("---")