# UMLS / app.py
# Uploaded by mgbam; commit 8a63984 (verified): "Update app.py"
# File size: 3.05 kB
import os
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import faiss
# --- Streamlit page setup ---
st.set_page_config(
    layout="wide",
    page_title="KRISSBERT UMLS Linker",
)
st.title("🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)")

# --- Hugging Face checkpoint name and on-disk artifact locations ---
MODEL_NAME = "microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL"
METADATA_PATH = "umls_metadata.json"  # JSON metadata read by load_umls_index()
EMBED_PATH = "umls_embeddings.npy"    # NumPy array file
INDEX_PATH = "umls_index.faiss"       # serialized FAISS index
# 1️⃣ Model + tokenizer (built once, then served from Streamlit's resource cache)
@st.cache_resource
def load_model():
    """Instantiate the tokenizer and encoder for the checkpoint ``MODEL_NAME``.

    Returns:
        tuple: ``(tokenizer, model)``; the model has been switched to
        evaluation mode via ``.eval()``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    # ``torch.nn.Module.eval()`` returns the module itself, so it can be chained.
    encoder = AutoModel.from_pretrained(MODEL_NAME).eval()
    return tok, encoder


tokenizer, model = load_model()
# 2️⃣ Load UMLS FAISS index & metadata
@st.cache_resource
def load_umls_index():
    """Load the prebuilt FAISS index and its companion metadata from disk.

    Returns:
        tuple: ``(index, meta)`` where ``index`` is the object returned by
        ``faiss.read_index(INDEX_PATH)`` and ``meta`` is the parsed JSON
        content of ``METADATA_PATH``.

    Raises:
        OSError: If the metadata file cannot be opened.
        json.JSONDecodeError: If the metadata file is not valid JSON.
    """
    # A context manager closes the handle deterministically; explicit UTF-8
    # keeps decoding independent of the host's locale settings.
    with open(METADATA_PATH, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)

    # NOTE: the raw embedding matrix at EMBED_PATH is deliberately not loaded.
    # It was previously read into an unused local, costing memory and startup
    # time; searching only needs the FAISS index itself.
    index = faiss.read_index(INDEX_PATH)
    return index, meta


faiss_index, umls_meta = load_umls_index()
# 3️⃣ Embed text (underscore-prefixed args are skipped by Streamlit's cache hashing)
@st.cache_data
def embed_text(text, _tokenizer, _model):
    """Return a unit-length float32 embedding for ``text``.

    The embedding is the hidden state at token position 0 of the model's
    ``last_hidden_state`` (the ``[CLS]`` slot for BERT-style tokenizers),
    squeezed to drop singleton dimensions and L2-normalised.

    Cached with ``st.cache_data`` because the result is plain array data,
    not a shared resource; the cache key is ``text`` only.

    Args:
        text: Input handed to ``_tokenizer``.
        _tokenizer: Callable tokenizer producing PyTorch tensors.
        _model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        numpy.ndarray: float32 array; unit L2 norm unless the raw embedding
        is all zeros, in which case it is returned unchanged.
    """
    inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)

    vec = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    # FAISS only accepts float32 query vectors; make that explicit here.
    vec = vec.astype(np.float32, copy=False)

    norm = float(np.linalg.norm(vec))
    if norm == 0.0:
        # Avoid dividing by zero, which would fill the vector with NaN.
        return vec
    return vec / norm
# UI: examples + input
st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')
examples = [
    'The patient was administered metformin for type 2 diabetes.',
    'ER crowding has become a widespread issue in hospitals.',
    'Tamoxifen is used in the treatment of ER-positive breast cancer.',
]
selected = st.selectbox('🔍 Example queries', ['Choose...'] + examples)
# Picking an example pre-fills the text area; the placeholder leaves it empty.
sentence = st.text_area('📝 Sentence:', value=(selected if selected != 'Choose...' else ''))

# Number of nearest neighbours requested from the FAISS index.
TOP_K = 5

if st.button('🔗 Link Entities'):
    if not sentence.strip():
        st.warning('Please enter a sentence first.')
    else:
        with st.spinner('Embedding sentence and searching FAISS…'):
            # FAISS expects a 2-D (n_queries, dim) array, hence the reshape.
            sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
            _distances, indices = faiss_index.search(sent_emb, TOP_K)

        results = []
        for idx in indices[0]:
            # FAISS pads the result with -1 when fewer than TOP_K neighbours
            # exist; those slots are not real hits.
            if idx < 0:
                continue
            # Metadata is looked up by the stringified index position.
            entry = umls_meta.get(str(idx))
            if not entry:
                # No metadata for this vector: skip it rather than render an
                # empty card, so the "no matches" branch below is reachable.
                continue
            results.append({
                field: entry.get(field, '')
                for field in ('cui', 'name', 'definition', 'source')
            })

        # Display
        if results:
            st.success('Top UMLS candidates:')
            for item in results:
                st.markdown(f"**{item['name']}** (CUI: `{item['cui']}`)")
                if item['definition']:
                    st.markdown(f"> {item['definition']}\n")
                st.markdown(f"_Source: {item['source']}_\n---")
        else:
            st.info('No matches found in UMLS index.')