r"""Script to build a FAISS index from UMLS concept metadata.

Produces:
  - ``umls_embeddings.npy``: L2-normalized vectors for each concept
  - ``umls_index.faiss``: FAISS inner-product index for fast similarity search
  - ``umls_metadata.json``: mapping from index position to concept metadata

Usage:
    python build_umls_faiss_index.py \
        --input concepts.csv \
        --output_meta backend/umls_metadata.json \
        --output_emb backend/umls_embeddings.npy \
        --output_idx backend/umls_index.faiss
"""

import argparse
import csv
import json
from pathlib import Path

import numpy as np
import torch


def encode_concepts(model, tokenizer, texts, batch_size=32, device='cpu'):
    """Encode ``texts`` into L2-normalized first-token (CLS) embeddings.

    Args:
        model: Hugging Face encoder; its output must expose ``last_hidden_state``.
        tokenizer: Tokenizer paired with ``model``; called with
            ``return_tensors='pt'`` and its result moved to ``device``.
        texts: Sequence of strings to encode.
        batch_size: Number of texts per forward pass.
        device: Torch device the model and inputs are moved to.

    Returns:
        A C-contiguous float32 array of shape ``(len(texts), hidden_dim)``.
        Every non-zero row has unit L2 norm; an all-zero row stays zero
        rather than turning into NaN.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    if not texts:
        raise ValueError("encode_concepts() needs at least one text to encode")

    model.to(device)
    batches = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # First token of the last hidden layer is used as the text vector.
        cls_emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        norms = np.linalg.norm(cls_emb, axis=1, keepdims=True)
        # Clamp the norm so a zero vector doesn't produce 0/0 = NaN in the index.
        batches.append(cls_emb / np.maximum(norms, 1e-12))

    # FAISS only accepts C-contiguous float32 input (e.g. fp16 models would break it).
    return np.ascontiguousarray(np.vstack(batches), dtype=np.float32)


def _read_concepts(path):
    """Read the concept CSV at ``path`` (columns: cui, name, definition, source).

    Returns:
        ``(records, texts)``: ``records`` is a list of metadata dicts with keys
        ``cui``, ``name``, ``definition`` and ``source``; ``texts[i]`` is the
        string embedded for ``records[i]`` -- ``"name - definition"`` when a
        definition is present, otherwise just the name.
    """
    records, texts = [], []
    with open(path, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            name = row['name']
            # DictReader yields None for fields missing from a short row.
            definition = row.get('definition') or ''
            records.append({
                'cui': row['cui'],
                # The plain concept name; the embedded text is kept separately.
                'name': name,
                'definition': definition,
                # Empty or missing source cells fall back to 'UMLS'.
                'source': row.get('source') or 'UMLS',
            })
            texts.append(name + ' - ' + definition if definition else name)
    return records, texts


def main():
    """Command-line entry point: embed every concept and write index, vectors and metadata."""
    parser = argparse.ArgumentParser(description="Build FAISS index for UMLS concepts.")
    parser.add_argument('--input', required=True,
                        help='CSV file with columns: cui,name,definition,source')
    parser.add_argument('--output_meta', required=True, help='JSON metadata output path')
    parser.add_argument('--output_emb', required=True, help='NumPy embeddings output path')
    parser.add_argument('--output_idx', required=True, help='FAISS index output path')
    parser.add_argument('--model', default='microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL',
                        help='Hugging Face model name')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Number of concepts encoded per forward pass')
    parser.add_argument('--device', default='cpu',
                        help="Torch device used for encoding, e.g. 'cpu' or 'cuda'")
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')

    # Heavy dependencies are imported here so `--help` and importing
    # encode_concepts from other code don't require faiss/transformers.
    import faiss
    from transformers import AutoModel, AutoTokenizer

    # Read the CSV first so an empty input fails before the model is loaded.
    records, texts = _read_concepts(args.input)
    if not records:
        parser.error(f'no concepts found in {args.input}')

    # Load model & tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)
    model.eval()

    print(f"Encoding {len(texts)} concepts...")
    embeddings = encode_concepts(model, tokenizer, texts,
                                 batch_size=args.batch_size, device=args.device)

    # Non-zero rows are unit-normalized, so inner product equals cosine similarity.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    # np.save / write_index / open() all fail if the parent directory is missing.
    for out_path in (args.output_emb, args.output_idx, args.output_meta):
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    np.save(args.output_emb, embeddings)
    faiss.write_index(index, args.output_idx)

    # Keys are the row positions in the FAISS index, as strings (JSON keys).
    metadata = {str(idx): record for idx, record in enumerate(records)}
    with open(args.output_meta, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print("FAISS index, embeddings, and metadata saved.")


if __name__ == '__main__':
    main()