File size: 3,049 Bytes
ef17fc5
afa884d
ef17fc5
 
 
 
afa884d
ef17fc5
 
afa884d
 
ef17fc5
8a63984
afa884d
 
 
 
ef17fc5
8a63984
ef17fc5
 
 
 
 
 
 
 
 
8a63984
afa884d
 
 
 
 
 
ef17fc5
afa884d
ef17fc5
8a63984
ef17fc5
72145e5
afa884d
ef17fc5
72145e5
ef17fc5
 
 
8a63984
afa884d
ef17fc5
afa884d
 
 
ef17fc5
afa884d
 
ef17fc5
afa884d
ef17fc5
afa884d
ef17fc5
afa884d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a63984
ef17fc5
8a63984
 
afa884d
8a63984
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import faiss

# Paths & model name. These are plain assignments with no Streamlit calls, so
# defining them first still leaves set_page_config as the first st.* call.
MODEL_NAME = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'
METADATA_PATH = 'umls_metadata.json'
EMBED_PATH = 'umls_embeddings.npy'
INDEX_PATH = 'umls_index.faiss'

# Page configuration: kept as the first st.* call, as Streamlit's
# set_page_config documentation has historically required.
st.set_page_config(layout='wide', page_title='KRISSBERT UMLS Linker')
st.title('🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)')

# 1️⃣ Load model & tokenizer
@st.cache_resource
def load_model():
    """Load the KRISSBERT tokenizer and encoder a single time.

    ``st.cache_resource`` keeps the returned objects alive across reruns, so
    the download and initialisation are not repeated on every interaction.

    Returns:
        A ``(tokenizer, model)`` pair; the model is in eval mode.
    """
    krissbert_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    krissbert_model = AutoModel.from_pretrained(MODEL_NAME)
    # Module.eval() switches off training-only behaviour (e.g. dropout) and
    # returns the module itself, so it can be returned directly.
    return krissbert_tokenizer, krissbert_model.eval()


tokenizer, model = load_model()

# 2️⃣ Load UMLS FAISS index & metadata
@st.cache_resource
def load_umls_index():
    """Load the prebuilt FAISS index and its concept metadata once.

    Cached with ``st.cache_resource`` so the files are read from disk only on
    the first call rather than on every rerun.

    The previous version also ran ``np.load(EMBED_PATH)`` and discarded the
    result. That only cost load time and memory, because searching goes
    through the FAISS index alone, so it is no longer loaded here.

    Returns:
        ``(index, meta)``: the FAISS index read from ``INDEX_PATH`` and the
        parsed JSON from ``METADATA_PATH``.
        NOTE(review): the search code looks entries up with
        ``meta.get(str(row_id))``, so meta is presumably a dict keyed by
        stringified FAISS row ids. Confirm against the build script.
    """
    # The context manager guarantees the handle is closed. UTF-8 is explicit
    # because JSON is UTF-8 by spec and the platform default encoding may
    # differ; concept names/definitions may contain non-ASCII text.
    with open(METADATA_PATH, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)
    index = faiss.read_index(INDEX_PATH)
    return index, meta


faiss_index, umls_meta = load_umls_index()

# 3️⃣ Embed text (prefix underscores to avoid caching errors)
@st.cache_data(max_entries=1024)
def embed_text(text, _tokenizer, _model):
    """Return the L2-normalised first-token embedding of ``text``.

    Uses ``st.cache_data`` instead of ``st.cache_resource`` because the result
    is per-input data:
    - each caller gets its own copy rather than a shared mutable array;
    - ``max_entries`` stops arbitrary user input from growing the cache
      without limit.

    Parameters whose names start with an underscore are left out of the cache
    key. That is why the unhashable tokenizer and model are passed as
    ``_tokenizer`` and ``_model``.

    Args:
        text: A single input sentence.
        _tokenizer: Tokenizer returned by ``load_model``.
        _model: Encoder returned by ``load_model``.

    Returns:
        A float32 numpy vector. It has unit length unless the embedding is
        all zeros; in that case it is returned unscaled, because dividing by
        a zero norm would produce NaNs.
    """
    inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)
    # Take the first token ([CLS] for BERT-style tokenizers) and drop the
    # batch axis. .float() guarantees float32, which FAISS search requires,
    # even if the model was loaded in reduced precision.
    emb = outputs.last_hidden_state[:, 0, :].squeeze(0).float().cpu().numpy()
    norm = np.linalg.norm(emb)
    return emb / norm if norm > 0 else emb

# UI: example picker plus a free-text sentence box.
st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')
examples = [
    'The patient was administered metformin for type 2 diabetes.',
    'ER crowding has become a widespread issue in hospitals.',
    'Tamoxifen is used in the treatment of ER-positive breast cancer.',
]
# The leading option is a sentinel meaning "no example selected".
_NO_EXAMPLE = 'Choose...'
selected = st.selectbox('🔍 Example queries', [_NO_EXAMPLE, *examples])
# Pre-fill the text area with the chosen example, otherwise start empty.
_prefill = '' if selected == _NO_EXAMPLE else selected
sentence = st.text_area('📝 Sentence:', value=_prefill)

if st.button('🔗 Link Entities'):
    if not sentence.strip():
        st.warning('Please enter a sentence first.')
    else:
        with st.spinner('Embedding sentence and searching FAISS…'):
            sent_emb = embed_text(sentence, tokenizer, model).reshape(1, -1)
            # search() returns (scores, row_ids); only the row ids are used.
            _scores, indices = faiss_index.search(sent_emb, 5)
            results = []
            for idx in indices[0]:
                # FAISS pads with -1 when fewer than k neighbours exist. Those
                # slots, and ids absent from the metadata, are not real hits
                # and would otherwise render as blank candidates (and keep
                # the "No matches" branch below from ever being reached).
                if idx < 0:
                    continue
                entry = umls_meta.get(str(idx))
                if not entry:
                    continue
                results.append({
                    'cui': entry.get('cui', ''),
                    'name': entry.get('name', ''),
                    'definition': entry.get('definition', ''),
                    'source': entry.get('source', ''),
                })
        # Display
        if results:
            st.success('Top UMLS candidates:')
            for item in results:
                st.markdown(f"**{item['name']}** (CUI: `{item['cui']}`)")
                if item['definition']:
                    st.markdown(f"> {item['definition']}\n")
                # A blank line is needed before '---'. Directly under a text
                # line, '---' is a setext heading underline and would turn the
                # source line into a heading instead of drawing a rule.
                st.markdown(f"_Source: {item['source']}_\n\n---")
        else:
            st.info('No matches found in UMLS index.')