import os
import json

import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import faiss

# Page configuration
st.set_page_config(page_title='KRISSBERT UMLS Linker', layout='wide')
st.title('🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)')

# Paths & model name
METADATA_PATH = 'umls_metadata.json'
EMBED_PATH = 'umls_embeddings.npy'  # not read by this app; only the FAISS index is searched
INDEX_PATH = 'umls_index.faiss'
MODEL_NAME = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'

# Number of nearest neighbours requested from the FAISS index per query.
TOP_K = 5


# 1️⃣ Load model & tokenizer
@st.cache_resource
def load_model():
    """Load the tokenizer and encoder named by MODEL_NAME.

    Cached with ``st.cache_resource`` so the objects are created once and
    reused across Streamlit reruns.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()


# 2️⃣ Load UMLS FAISS index & metadata
@st.cache_resource
def load_umls_index():
    """Load the FAISS index and the JSON metadata that describes its rows.

    Returns:
        A ``(index, meta)`` tuple. ``meta`` is whatever METADATA_PATH holds;
        the UI below looks entries up with ``meta.get(str(row_id))``.
    """
    # Context manager so the file handle is closed deterministically.
    with open(METADATA_PATH, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)
    # The raw embedding matrix (EMBED_PATH) is intentionally not loaded:
    # nothing here reads it, and searching only needs the FAISS index.
    index = faiss.read_index(INDEX_PATH)
    return index, meta


faiss_index, umls_meta = load_umls_index()


# 3️⃣ Embed text (prefix underscores to avoid caching errors)
@st.cache_data(max_entries=512)
def embed_text(text, _tokenizer, _model):
    """Return a unit-length embedding for ``text``.

    The result is plain data keyed on ``text``, so it is cached with
    ``st.cache_data`` (bounded by ``max_entries``) rather than as a shared
    resource. The underscore-prefixed parameters are excluded from
    Streamlit's cache key.

    Args:
        text: Input sentence to embed.
        _tokenizer: Tokenizer callable producing PyTorch tensors.
        _model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        A NumPy array holding the hidden state at the first token position,
        divided by its L2 norm. An all-zero vector is returned unchanged.
    """
    inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)
    emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    norm = np.linalg.norm(emb)
    if norm == 0:
        # Avoid dividing by zero and handing NaNs to the index.
        return emb
    return emb / norm


def _collect_candidates(row_ids, meta):
    """Map FAISS row ids to display records, dropping ids with no metadata.

    Args:
        row_ids: Iterable of integer labels returned by the index search.
        meta: Mapping from ``str(row_id)`` to a metadata dict.

    Returns:
        A list of dicts with the keys ``cui``, ``name``, ``definition`` and
        ``source``; fields absent from the metadata default to ``''``.
    """
    candidates = []
    for row_id in row_ids:
        # FAISS reports -1 when it has fewer neighbours than requested.
        if row_id < 0:
            continue
        entry = meta.get(str(row_id))
        if not entry:
            continue
        candidates.append({
            'cui': entry.get('cui', ''),
            'name': entry.get('name', ''),
            'definition': entry.get('definition', ''),
            'source': entry.get('source', ''),
        })
    return candidates


# UI: examples + input
st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')
examples = [
    'The patient was administered metformin for type 2 diabetes.',
    'ER crowding has become a widespread issue in hospitals.',
    'Tamoxifen is used in the treatment of ER-positive breast cancer.',
]
selected = st.selectbox('🔍 Example queries', ['Choose...'] + examples)
sentence = st.text_area('📝 Sentence:', value=(selected if selected != 'Choose...' else ''))

if st.button('🔗 Link Entities'):
    if not sentence.strip():
        st.warning('Please enter a sentence first.')
    else:
        with st.spinner('Embedding sentence and searching FAISS…'):
            sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
            # FAISS expects a C-contiguous float32 query matrix.
            query = np.ascontiguousarray(sent_emb, dtype=np.float32)
            _distances, indices = faiss_index.search(query, TOP_K)
            results = _collect_candidates(indices[0], umls_meta)

        # Display
        if results:
            st.success('Top UMLS candidates:')
            for item in results:
                st.markdown(f"**{item['name']}** (CUI: `{item['cui']}`)")
                if item['definition']:
                    st.markdown(f"> {item['definition']}\n")
                st.markdown(f"_Source: {item['source']}_\n---")
        else:
            st.info('No matches found in UMLS index.')