mgbam committed · verified
Commit beb1728 · 1 Parent(s): ef17fc5

Update utils/build_umls_faiss_index.py

Files changed (1)
  1. utils/build_umls_faiss_index.py +103 -0
utils/build_umls_faiss_index.py CHANGED
@@ -0,0 +1,103 @@
"""
Script to build a FAISS index from UMLS concept metadata.

Produces:
- `umls_embeddings.npy`: normalized vectors for each concept
- `umls_index.faiss`: FAISS index for fast similarity search
- `umls_metadata.json`: mapping from index position to concept metadata

Usage:
    python build_umls_faiss_index.py \
        --input concepts.csv \
        --output_meta backend/umls_metadata.json \
        --output_emb backend/umls_embeddings.npy \
        --output_idx backend/umls_index.faiss
"""
import argparse
import csv
import json

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel


def encode_concepts(model, tokenizer, texts, batch_size=32, device='cpu'):
    """Encode texts into L2-normalized [CLS] embeddings, in batches."""
    embeddings = []
    model.to(device)
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        cls_emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        # L2-normalize so inner-product search is equivalent to cosine
        # similarity; the clip guards against a zero-norm vector.
        norms = np.linalg.norm(cls_emb, axis=1, keepdims=True)
        embeddings.append(cls_emb / np.clip(norms, 1e-12, None))
    return np.vstack(embeddings)


def main():
    parser = argparse.ArgumentParser(description="Build FAISS index for UMLS concepts.")
    parser.add_argument('--input', required=True,
                        help='CSV file with columns: cui,name,definition,source')
    parser.add_argument('--output_meta', required=True,
                        help='JSON metadata output path')
    parser.add_argument('--output_emb', required=True,
                        help='NumPy embeddings output path')
    parser.add_argument('--output_idx', required=True,
                        help='FAISS index output path')
    parser.add_argument('--model',
                        default='microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL',
                        help='Hugging Face model name')
    args = parser.parse_args()

    # Load model & tokenizer; use a GPU when one is available
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)
    model.eval()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Read concepts CSV. Encode the name plus definition (when present)
    # for a richer embedding, but keep the plain name for the metadata.
    cuis, names, texts, defs, sources = [], [], [], [], []
    with open(args.input, newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            cuis.append(row['cui'])
            names.append(row['name'])
            text = row['name']
            if row.get('definition'):
                text += ' - ' + row['definition']
            texts.append(text)
            defs.append(row.get('definition', ''))
            sources.append(row.get('source', 'UMLS'))

    # Encode all concept texts
    print(f"Encoding {len(texts)} concepts...")
    embeddings = encode_concepts(model, tokenizer, texts, device=device)

    # Build FAISS index (inner-product search over normalized vectors)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings.astype(np.float32))  # FAISS expects float32

    # Save outputs
    np.save(args.output_emb, embeddings)
    faiss.write_index(index, args.output_idx)

    # Build metadata mapping: stringified index position -> concept record
    metadata = {
        str(idx): {
            'cui': cui,
            'name': names[idx],
            'definition': defs[idx],
            'source': sources[idx],
        }
        for idx, cui in enumerate(cuis)
    }
    with open(args.output_meta, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print("FAISS index, embeddings, and metadata saved.")


if __name__ == '__main__':
    main()
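
For context, here is a minimal query-side sketch (not part of this commit) of how the three artifacts might be consumed at runtime. It assumes the same KRISSBERT checkpoint and the output paths from the usage example above; the `search_umls` helper and its `top_k` parameter are illustrative names, not an existing API.

# Hypothetical query-side sketch: load the artifacts written by the build
# script and run a cosine-similarity lookup against the FAISS index.
import json

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

MODEL = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'  # same checkpoint as the build script
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL)
model.eval()

index = faiss.read_index('backend/umls_index.faiss')
with open('backend/umls_metadata.json', encoding='utf-8') as f:
    metadata = json.load(f)

def search_umls(query, top_k=5):
    """Embed `query` exactly like the build script (normalized [CLS]) and search."""
    inputs = tokenizer([query], padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        cls = model(**inputs).last_hidden_state[:, 0, :].numpy()
    cls /= np.linalg.norm(cls, axis=1, keepdims=True)
    scores, ids = index.search(cls.astype(np.float32), top_k)
    # Metadata keys are stringified index positions, mirroring the build script.
    return [(metadata[str(i)], float(s)) for i, s in zip(ids[0], scores[0])]

print(search_umls('myocardial infarction'))

Because the stored vectors and the query vector are both L2-normalized, the inner-product scores returned by `IndexFlatIP` are cosine similarities in [-1, 1].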