import os
import streamlit as st
import requests
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
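# Runtime dependencies implied by the imports above; a requirements.txt for this
# Space would typically list: streamlit, requests, transformers, torch, numpy.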

# Page configuration
st.set_page_config(page_title="KRISSBERT UMLS Linker", layout="wide")
st.title("🧬 KRISSBERT + UMLS Entity Linker on Hugging Face Spaces")

# Environment variables
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
if not UMLS_API_KEY:
    st.error("❗ Please set the UMLS_API_KEY as a secret in your Space.")
    st.stop()

# UMLS API endpoints
TGT_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
SERVICE = "http://umlsks.nlm.nih.gov"
SEARCH_URL = "https://uts-ws.nlm.nih.gov/rest/search/current"
CONTENT_URL = "https://uts-ws.nlm.nih.gov/rest/content/current/"
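# Authentication flow (CAS tickets): the API key is exchanged once for a
# long-lived Ticket-Granting Ticket (TGT), and every REST request then needs
# its own single-use service ticket derived from that TGT. Newer UTS releases
# also accept the API key directly as an `apiKey` query parameter, but this
# app keeps the ticket flow.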

# Load KRISSBERT model
MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"
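# st.cache_resource keeps a single tokenizer/model instance in memory across
# Streamlit reruns and sessions, so the weights are downloaded and loaded only once.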
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.eval()
    return tokenizer, model

tokenizer, model = load_model()

# Functions for UMLS API authentication (CAS ticket flow)
@st.cache_data(ttl=3600)
def get_tgt(api_key):
    # The Ticket-Granting Ticket is long-lived (hours), so it is safe to cache.
    resp = requests.post(TGT_URL, data={"apikey": api_key})
    if resp.status_code == 201:
        return resp.headers.get("location")
    st.error("Failed to obtain TGT from UMLS API.")
    st.stop()

def get_st(tgt):
    # Service tickets are single-use: request a fresh one for every REST call,
    # so this function is deliberately not cached.
    resp = requests.post(tgt, data={"service": SERVICE})
    if resp.status_code == 200:
        return resp.text
    st.error("Failed to obtain service ticket from UMLS API.")
    st.stop()

# Text embedding: L2-normalized CLS-token vector. The leading underscore tells
# Streamlit not to hash the tokenizer/model arguments; st.cache_data is used
# because the function returns data (a NumPy array) rather than a shared resource.
@st.cache_data
def embed_text(text, _tokenizer, _model):
    inputs = _tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)
    emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    return emb / np.linalg.norm(emb)
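# Because embed_text() returns L2-normalized vectors, a plain dot product of two
# embeddings equals their cosine similarity (roughly in [-1, 1]), e.g.:
#   sim = float(np.dot(embed_text("metformin", tokenizer, model),
#                      embed_text("type 2 diabetes", tokenizer, model)))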

# UI: Input box and examples
st.markdown("Enter a biomedical sentence to link entities via UMLS API + KRISSBERT:")
examples = [
    "The patient was administered metformin for type 2 diabetes.",
    "ER crowding has become a widespread issue in hospitals.",
    "Tamoxifen is used in the treatment of ER-positive breast cancer."
]
selected = st.selectbox("🔍 Example queries", ["Choose..."] + examples)
sentence = st.text_area("📝 Sentence:", value=(selected if selected != "Choose..." else ""))

if st.button("🔗 Link Entities"):
    if not sentence.strip():
        st.warning("Please enter a sentence first.")
    else:
        with st.spinner("Querying UMLS API and ranking... 🧠"):
            # Authenticate
            tgt = get_tgt(UMLS_API_KEY)
            sticket = get_st(tgt)

            # UMLS search: retrieve candidate concepts for the raw sentence string
            params = {"string": sentence, "ticket": sticket}
            search_resp = requests.get(SEARCH_URL, params=params)
            search_resp.raise_for_status()
            results = search_resp.json().get("result", {}).get("results", [])

            candidates = []
            for res in results[:10]:
                cui = res.get("ui")
                name = res.get("name")
                # Skip the placeholder entry UMLS returns when nothing matches.
                if not cui or cui == "NONE":
                    continue
                # Definitions live under /CUI/{cui}/definitions; service tickets
                # are single-use, so each request gets its own fresh ticket.
                def_resp = requests.get(
                    f"{CONTENT_URL}CUI/{cui}/definitions",
                    params={"ticket": get_st(tgt)},
                )
                definition = ""
                if def_resp.status_code == 200:
                    defs = def_resp.json().get("result", [])
                    if defs:
                        definition = defs[0].get("value", "")
                candidates.append({"ui": cui, "name": name, "definition": definition})

            # Embed and score
            sent_emb = embed_text(sentence, tokenizer, model)
            for cand in candidates:
                cand_emb = embed_text(cand['name'], tokenizer, model)
                cand['score'] = float(np.dot(sent_emb, cand_emb))

            ranked = sorted(candidates, key=lambda x: x['score'], reverse=True)[:5]

            # Display
            st.success("Top UMLS candidates:")
            for item in ranked:
                st.markdown(f"**{item['name']}** (CUI: `{item['ui']}`) — score: {item['score']:.3f}")
                if item['definition']:
                    st.markdown(f"> {item['definition']}\n")
                st.markdown("---")
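# To run locally (a minimal sketch; assumes this file is saved as app.py and a
# valid UMLS/UTS API key is available):
#   export UMLS_API_KEY=<your-uts-api-key>
#   streamlit run app.py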