|
import os |
|
import json |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
import numpy as np |
|
import faiss |
|
|
|
|
|
# Page chrome. set_page_config is kept as the very first Streamlit call,
# as Streamlit's documentation requires, followed by the app heading.
st.set_page_config(layout='wide', page_title='KRISSBERT UMLS Linker')
st.title('🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)')
|
|
|
|
|
# Locally stored UMLS artifacts. These are relative paths, so they resolve
# against the working directory the app is launched from.
METADATA_PATH: str = 'umls_metadata.json'
EMBED_PATH: str = 'umls_embeddings.npy'
INDEX_PATH: str = 'umls_index.faiss'

# Hugging Face model id, used to load both the tokenizer and the encoder.
MODEL_NAME: str = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the KRISSBERT tokenizer and encoder for MODEL_NAME.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    the same objects are reused on every rerun.

    Returns:
        A ``(tokenizer, model)`` pair; the model is switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    # nn.Module.eval() returns the module itself, so it can be chained here.
    mdl = AutoModel.from_pretrained(MODEL_NAME).eval()
    return tok, mdl


tokenizer, model = load_model()
|
|
|
|
|
@st.cache_resource
def load_umls_index():
    """Load the FAISS index and the UMLS concept metadata from local disk.

    Returns:
        ``(index, meta)``: the FAISS index read from INDEX_PATH and the JSON
        document parsed from METADATA_PATH. The UI looks candidates up with
        ``meta.get(str(row_id))``, so ``meta`` is presumably a dict keyed by
        stringified FAISS row ids -- verify against the index-building script.
    """
    # Context manager closes the handle. Explicit UTF-8 avoids locale-dependent
    # decoding of non-ASCII concept names and definitions.
    with open(METADATA_PATH, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)

    # NOTE: the raw embedding matrix (EMBED_PATH) used to be np.load-ed here
    # but was never used; querying only needs the FAISS index, so the
    # (potentially very large) array is no longer pulled into memory.
    index = faiss.read_index(INDEX_PATH)
    return index, meta


faiss_index, umls_meta = load_umls_index()
|
|
|
|
|
# cache_data (not cache_resource): results are per-input data, not a shared
# global resource. max_entries bounds memory; before, every distinct sentence
# stayed cached forever. Leading-underscore params are excluded from hashing.
@st.cache_data(max_entries=256)
def embed_text(text, _tokenizer, _model):
    """Encode a single string into an L2-normalised first-token embedding.

    Args:
        text: The sentence to embed (a single ``str``).
        _tokenizer: Hugging Face tokenizer matching ``_model``.
        _model: Hugging Face encoder returning ``last_hidden_state``.

    Returns:
        A 1-D contiguous float32 numpy vector, unit-length unless the raw
        embedding is all zeros (then returned unchanged instead of NaN).
    """
    inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # Keep inputs on the same device as the model weights.
    inputs = inputs.to(_model.device)

    with torch.no_grad():
        outputs = _model(**inputs)

    # First token of the first (only) sequence -- the [CLS] position for
    # BERT-style tokenizers. Explicit indexing instead of .squeeze() so the
    # result is always 1-D.
    emb = outputs.last_hidden_state[0, 0, :].cpu().numpy()

    norm = np.linalg.norm(emb)
    if norm > 0:
        emb = emb / norm

    # FAISS search requires C-contiguous float32 queries.
    return np.ascontiguousarray(emb, dtype=np.float32)
|
|
|
|
|
st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')

examples = [
    'The patient was administered metformin for type 2 diabetes.',
    'ER crowding has become a widespread issue in hospitals.',
    'Tamoxifen is used in the treatment of ER-positive breast cancer.',
]

selected = st.selectbox('🔍 Example queries', ['Choose...'] + examples)

# Pre-fill the text area with the chosen example; the placeholder choice
# leaves it empty so the user can type their own sentence.
sentence = st.text_area('📝 Sentence:', value=(selected if selected != 'Choose...' else ''))

# Number of nearest neighbours requested from the FAISS index.
TOP_K = 5

if st.button('🔗 Link Entities'):
    if not sentence.strip():
        st.warning('Please enter a sentence first.')
    else:
        with st.spinner('Embedding sentence and searching FAISS…'):
            # FAISS takes a 2-D (n_queries, dim) query matrix.
            sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
            scores, indices = faiss_index.search(sent_emb, TOP_K)

        results = []
        n_missing = 0
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with id -1 when fewer than TOP_K neighbours exist;
            # those slots are not real hits.
            if idx < 0:
                continue
            entry = umls_meta.get(str(idx))
            if not entry:
                # Row id without a metadata record: count it instead of
                # rendering an empty candidate.
                n_missing += 1
                continue
            results.append({
                'cui': entry.get('cui', ''),
                'name': entry.get('name', ''),
                'definition': entry.get('definition', ''),
                'source': entry.get('source', ''),
                'score': float(score),
            })

        if results:
            st.success('Top UMLS candidates:')
            for item in results:
                st.markdown(f"**{item['name']}** (CUI: `{item['cui']}`) — score: {item['score']:.4f}")
                if item['definition']:
                    # Prefix every line so multi-line definitions stay inside the blockquote.
                    quoted = '\n'.join(f'> {line}' for line in item['definition'].splitlines())
                    st.markdown(quoted)
                if item['source']:
                    st.markdown(f"_Source: {item['source']}_")
                # Separate call on purpose: "text\n---" in one Markdown string is
                # parsed as a setext heading, not as text followed by a rule.
                st.markdown('---')
        else:
            st.info('No matches found in UMLS index.')

        if n_missing:
            st.caption(f'{n_missing} FAISS hit(s) had no entry in {METADATA_PATH} and were skipped.')