File size: 2,187 Bytes
cad86c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
#!/usr/bin/env python3
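"""
Build a FAISS index and metadata from a UMLS CSV.

Outputs (paths are supplied on the command line):
- `--out_emb`: NumPy array of L2-normalized concept embeddings (e.g. `umls_embeddings.npy`)
- `--out_idx`: FAISS inner-product index over those embeddings (e.g. `umls_index.faiss`)
- `--out_meta`: JSON metadata keyed by index row id (e.g. `umls_metadata.json`)
"""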
import argparse, csv, json
import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModel
def _model_device(model):
    """Return the device of *model*'s first parameter, or None if it cannot be determined."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        return None


def encode(texts, tokenizer, model, batch_size=32):
    """Embed *texts* with *model* and return L2-normalized position-0 vectors.

    Args:
        texts: Sequence of strings to embed; processed in slices of *batch_size*.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) that takes a
            list of strings and returns a mapping of PyTorch tensors.
        model: Encoder whose output exposes ``last_hidden_state``; the hidden state
            at sequence position 0 (the [CLS] token for BERT-style models) is used
            as each text's embedding.
        batch_size: Number of texts per forward pass.

    Returns:
        A float32 array of shape ``(len(texts), hidden_dim)`` whose rows have unit
        L2 norm. An all-zero row stays all-zero instead of becoming NaN.

    Raises:
        ValueError: If *texts* is empty or *batch_size* is not positive.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    if len(texts) == 0:
        raise ValueError('encode() needs at least one text')

    model.eval()
    device = _model_device(model)
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
            if device is not None:
                # Inputs must live on the same device as the weights (e.g. a GPU model).
                inputs = {key: value.to(device) for key, value in inputs.items()}
            outputs = model(**inputs)
            # .float() upcasts fp16/bf16 (bf16 cannot be converted to numpy directly);
            # FAISS also requires float32.
            cls = outputs.last_hidden_state[:, 0, :].float().cpu().numpy()
            norms = np.linalg.norm(cls, axis=1, keepdims=True)
            # Clamp the divisor so degenerate zero vectors do not become NaN.
            chunks.append(cls / np.maximum(norms, 1e-12))
    return np.vstack(chunks).astype(np.float32, copy=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True, help='CSV with cui,name,definition,source')
parser.add_argument('--out_meta', required=True)
parser.add_argument('--out_emb', required=True)
parser.add_argument('--out_idx', required=True)
parser.add_argument('--model', default='microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL')
args = parser.parse_args()
# Load
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = AutoModel.from_pretrained(args.model)
cuis, texts, defs, srcs = [], [], [], []
with open(args.input) as f:
for row in csv.DictReader(f):
cuis.append(row['cui'])
texts.append(row['name'] + (' - ' + row.get('definition','')))
defs.append(row.get('definition',''))
srcs.append(row.get('source','UMLS'))
print(f'Encoding {len(texts)} concepts...')
embeddings = encode(texts, tokenizer, model)
# Build FAISS
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings)
# Save
np.save(args.out_emb, embeddings)
faiss.write_index(index, args.out_idx)
meta = {str(i): {'cui': cuis[i], 'name': texts[i], 'definition': defs[i], 'source': srcs[i]}
for i in range(len(cuis))}
json.dump(meta, open(args.out_meta, 'w'), indent=2)
print('Done.') |