""" |
|
Build a FAISS index and metadata from a UMLS CSV. |
|
Outputs: |
|
- `umls_embeddings.npy` |
|
- `umls_index.faiss` |
|
- `umls_metadata.json` |
|
""" |

import argparse
import csv
import json

import faiss
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


def encode(texts, tokenizer, model, batch_size=32):
    """Encode texts into L2-normalized [CLS] embeddings, batch by batch."""
    embs = []
    model.eval()
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            outputs = model(**inputs)
        # Use the [CLS] token embedding as the concept representation.
        cls = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        # L2-normalize so that inner-product search is equivalent to cosine similarity.
        normed = cls / np.linalg.norm(cls, axis=1, keepdims=True)
        embs.append(normed)
    return np.vstack(embs)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Build a FAISS index and metadata from a UMLS CSV.')
    parser.add_argument('--input', required=True, help='CSV with columns: cui,name,definition,source')
    parser.add_argument('--out_meta', required=True, help='Output path for the metadata JSON')
    parser.add_argument('--out_emb', required=True, help='Output path for the embeddings .npy file')
    parser.add_argument('--out_idx', required=True, help='Output path for the FAISS index')
    parser.add_argument('--model', default='microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL',
                        help='Hugging Face model name or path for the encoder')
    args = parser.parse_args()

    # Load the tokenizer and encoder model.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)

    # Read concepts; keep the display name separate from the text that gets encoded.
    cuis, names, texts, defs, srcs = [], [], [], [], []
    with open(args.input, newline='') as f:
        for row in csv.DictReader(f):
            cuis.append(row['cui'])
            names.append(row['name'])
            definition = row.get('definition', '')
            # Encode "name - definition" when a definition exists, otherwise just the name.
            texts.append(row['name'] + (' - ' + definition if definition else ''))
            defs.append(definition)
            srcs.append(row.get('source', 'UMLS'))

    print(f'Encoding {len(texts)} concepts...')
    embeddings = encode(texts, tokenizer, model)

    # Build an exact inner-product index; since the embeddings are L2-normalized,
    # inner product equals cosine similarity.
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings.astype(np.float32))

    np.save(args.out_emb, embeddings)
    faiss.write_index(index, args.out_idx)

    # Map FAISS row ids (as strings) to concept metadata.
    meta = {str(i): {'cui': cuis[i], 'name': names[i], 'definition': defs[i], 'source': srcs[i]}
            for i in range(len(cuis))}
    with open(args.out_meta, 'w') as f:
        json.dump(meta, f, indent=2)
    print('Done.')
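
# A minimal sketch of how the generated artifacts could be queried afterwards
# (kept as a comment; the query string and k are illustrative, and `encode`,
# `tokenizer`, and `model` refer to the objects defined above):
#
#   index = faiss.read_index('umls_index.faiss')
#   with open('umls_metadata.json') as f:
#       meta = json.load(f)
#   query = encode(['myocardial infarction'], tokenizer, model)
#   scores, ids = index.search(query.astype(np.float32), 5)
#   for score, i in zip(scores[0], ids[0]):
#       print(f"{score:.3f}", meta[str(i)]['cui'], meta[str(i)]['name'])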