"""
Script to build a FAISS index from UMLS concept metadata.

Produces:

- `umls_embeddings.npy`: normalized vectors for each concept
- `umls_index.faiss`: FAISS index for fast similarity search
- `umls_metadata.json`: mapping from index position to concept metadata

Usage:

    python build_umls_faiss_index.py \
        --input concepts.csv \
        --output_meta backend/umls_metadata.json \
        --output_emb backend/umls_embeddings.npy \
        --output_idx backend/umls_index.faiss
"""

import argparse
import csv
import json

import faiss
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


def encode_concepts(model, tokenizer, texts, batch_size=32, device='cpu'):
    """Encode texts in batches; return L2-normalized [CLS] embeddings."""
    embeddings = []
    model.to(device)
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Use the [CLS] token as the sentence-level representation.
        cls_emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()

        # L2-normalize so inner-product search ranks by cosine similarity.
        norms = np.linalg.norm(cls_emb, axis=1, keepdims=True)
        embeddings.append(cls_emb / norms)
    return np.vstack(embeddings)
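
# Note: the function above uses [CLS] pooling; mean pooling over non-padding
# tokens is a common alternative. A minimal sketch, assuming the same
# `inputs`/`outputs` as inside the loop (illustrative, not part of the
# original script):
#
#   mask = inputs['attention_mask'].unsqueeze(-1).float()
#   mean_emb = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)
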

def main():
    parser = argparse.ArgumentParser(description="Build FAISS index for UMLS concepts.")
    parser.add_argument('--input', required=True,
                        help='CSV file with columns: cui,name,definition,source')
    parser.add_argument('--output_meta', required=True,
                        help='JSON metadata output path')
    parser.add_argument('--output_emb', required=True,
                        help='NumPy embeddings output path')
    parser.add_argument('--output_idx', required=True,
                        help='FAISS index output path')
    parser.add_argument('--model',
                        default='microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL',
                        help='Hugging Face model name')
    args = parser.parse_args()

    # Load the encoder once and put it in eval mode; use a GPU when
    # available, since encode_concepts defaults to CPU otherwise.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)
    model.eval()

    # Read concepts. The text that gets embedded is the concept name,
    # optionally followed by its definition for extra context; the plain
    # name is kept separately so the metadata file stores it unmodified.
    cuis, names, texts, defs, sources = [], [], [], [], []
    with open(args.input, newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            cuis.append(row['cui'])
            names.append(row['name'])
            text = row['name']
            if row.get('definition'):
                text += ' - ' + row['definition']
            texts.append(text)
            defs.append(row.get('definition', ''))
            sources.append(row.get('source', 'UMLS'))

    print(f"Encoding {len(texts)} concepts...")
    embeddings = encode_concepts(model, tokenizer, texts, device=device)

    # Build an exact inner-product index; with normalized vectors this
    # ranks by cosine similarity. FAISS requires float32 input.
    embeddings = embeddings.astype('float32')
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
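
    # For very large concept sets, an approximate index can be swapped in
    # for the exact IndexFlatIP. A minimal sketch (illustrative parameters,
    # not part of the original script):
    #
    #   quantizer = faiss.IndexFlatIP(dim)
    #   index = faiss.IndexIVFFlat(quantizer, dim, 1024, faiss.METRIC_INNER_PRODUCT)
    #   index.train(embeddings)
    #   index.add(embeddings)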

    np.save(args.output_emb, embeddings)
    faiss.write_index(index, args.output_idx)

    # Map FAISS row position -> concept metadata (JSON object keys must be
    # strings, hence str(idx)).
    metadata = {}
    for idx, cui in enumerate(cuis):
        metadata[str(idx)] = {
            'cui': cui,
            'name': names[idx],
            'definition': defs[idx],
            'source': sources[idx]
        }
    with open(args.output_meta, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print("FAISS index, embeddings, and metadata saved.")


if __name__ == '__main__':
    main()
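
# --- Query-side sketch (illustrative, not part of the build script) ---
# Loading the artifacts and running a top-k search; the query must be
# encoded with the same model and normalized the same way, so inner-product
# scores are cosine similarities. Paths match the Usage example above.
#
#   index = faiss.read_index('backend/umls_index.faiss')
#   with open('backend/umls_metadata.json', encoding='utf-8') as f:
#       metadata = json.load(f)
#   query = encode_concepts(model, tokenizer, ['myocardial infarction'])
#   scores, ids = index.search(query.astype('float32'), 5)
#   for score, i in zip(scores[0], ids[0]):
#       print(f"{score:.3f}", metadata[str(i)]['cui'], metadata[str(i)]['name'])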