mgbam committed (verified)
Commit cad86c8 · 1 Parent(s): cb52b2e

Create utils/build_umls_faiss_index.py

Files changed (1)
  1. utils/build_umls_faiss_index.py +68 -0
utils/build_umls_faiss_index.py ADDED
@@ -0,0 +1,68 @@
+ #!/usr/bin/env python3
+ """
+ Build a FAISS index and metadata from a UMLS CSV.
+ Outputs:
+ - `umls_embeddings.npy`
+ - `umls_index.faiss`
+ - `umls_metadata.json`
+ """
+ import argparse, csv, json
+ import numpy as np
+ import faiss
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+
+ def encode(texts, tokenizer, model, batch_size=32):
+     """Encode texts in batches and return L2-normalised [CLS] embeddings."""
+     embs = []
+     model.eval()
+     for i in range(0, len(texts), batch_size):
+         batch = texts[i:i + batch_size]
+         inputs = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
+         with torch.no_grad():
+             outputs = model(**inputs)
+         cls = outputs.last_hidden_state[:, 0, :].cpu().numpy()
+         # Normalise so inner product in FAISS equals cosine similarity.
+         normed = cls / np.linalg.norm(cls, axis=1, keepdims=True)
+         embs.append(normed)
+     return np.vstack(embs).astype('float32')
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input', required=True, help='CSV with cui,name,definition,source')
+     parser.add_argument('--out_meta', required=True)
+     parser.add_argument('--out_emb', required=True)
+     parser.add_argument('--out_idx', required=True)
+     parser.add_argument('--model', default='microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL')
+     args = parser.parse_args()
+
+     # Load tokenizer and encoder
+     tokenizer = AutoTokenizer.from_pretrained(args.model)
+     model = AutoModel.from_pretrained(args.model)
+
+     cuis, names, texts, defs, srcs = [], [], [], [], []
+     with open(args.input, newline='') as f:
+         for row in csv.DictReader(f):
+             definition = row.get('definition', '')
+             cuis.append(row['cui'])
+             names.append(row['name'])
+             texts.append(row['name'] + (' - ' + definition if definition else ''))
+             defs.append(definition)
+             srcs.append(row.get('source', 'UMLS'))
+
+     print(f'Encoding {len(texts)} concepts...')
+     embeddings = encode(texts, tokenizer, model)
+
+     # Build a flat inner-product FAISS index
+     dim = embeddings.shape[1]
+     index = faiss.IndexFlatIP(dim)
+     index.add(embeddings)
+
+     # Save embeddings, index, and row-aligned metadata
+     np.save(args.out_emb, embeddings)
+     faiss.write_index(index, args.out_idx)
+     meta = {str(i): {'cui': cuis[i], 'name': names[i], 'definition': defs[i], 'source': srcs[i]}
+             for i in range(len(cuis))}
+     with open(args.out_meta, 'w') as f:
+         json.dump(meta, f, indent=2)
+     print('Done.')
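
For reference, the script would be invoked along the lines of `python utils/build_umls_faiss_index.py --input umls_concepts.csv --out_emb umls_embeddings.npy --out_idx umls_index.faiss --out_meta umls_metadata.json` (the input file name here is illustrative). Below is a minimal retrieval sketch, not part of this commit, showing how the saved artifacts are meant to be used at query time: the query is encoded with the same model and the same L2 normalisation, searched against the inner-product index, and the returned row ids are mapped back through the metadata JSON. The file names and the example query are assumptions taken from the docstring's output names.

# Retrieval sketch (assumed usage, not part of the commit above).
import json
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

model_name = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

# Artifacts written by build_umls_faiss_index.py (example output names).
index = faiss.read_index('umls_index.faiss')
with open('umls_metadata.json') as f:
    meta = json.load(f)

query = 'myocardial infarction'  # illustrative query string
inputs = tokenizer([query], padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
    cls = model(**inputs).last_hidden_state[:, 0, :].cpu().numpy()
# Same normalisation as at index-build time, so inner product == cosine.
vec = (cls / np.linalg.norm(cls, axis=1, keepdims=True)).astype('float32')

scores, ids = index.search(vec, 5)  # top-5 nearest concepts
for score, idx in zip(scores[0], ids[0]):
    entry = meta[str(idx)]
    print(f"{entry['cui']}  {entry['name']}  ({score:.3f})")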