#!/usr/bin/env python3
"""
Build a FAISS index and metadata from a UMLS CSV.
Outputs:
 - `umls_embeddings.npy`
 - `umls_index.faiss`
 - `umls_metadata.json`
"""
import argparse, csv, json
import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModel

def encode(texts, tokenizer, model, batch_size=32):
    """Encode texts into L2-normalized [CLS] embeddings.

    Args:
        texts: list of strings to embed.
        tokenizer: HuggingFace-style tokenizer; called with
            padding/truncation and `return_tensors='pt'`.
        model: model whose output exposes `last_hidden_state`
            of shape (batch, seq, dim).
        batch_size: number of texts per forward pass.

    Returns:
        np.ndarray of shape (len(texts), dim) with unit-norm rows,
        suitable for inner-product (cosine) search. Returns an empty
        (0, 0) array when `texts` is empty.
    """
    if not texts:
        # np.vstack([]) raises ValueError; return an explicit empty result.
        return np.empty((0, 0), dtype=np.float32)
    embs = []
    model.eval()  # disable dropout etc. for deterministic inference
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            outputs = model(**inputs)
        # [CLS] token (position 0) as the sentence representation.
        cls = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        # Guard against zero vectors: clamping the norm avoids
        # division-by-zero / NaN rows that the old code produced.
        norms = np.linalg.norm(cls, axis=1, keepdims=True)
        np.clip(norms, np.finfo(cls.dtype).tiny, None, out=norms)
        embs.append(cls / norms)
    return np.vstack(embs)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Build a FAISS index and metadata from a UMLS CSV.')
    parser.add_argument('--input', required=True, help='CSV with cui,name,definition,source')
    parser.add_argument('--out_meta', required=True, help='Path for umls_metadata.json')
    parser.add_argument('--out_emb', required=True, help='Path for umls_embeddings.npy')
    parser.add_argument('--out_idx', required=True, help='Path for umls_index.faiss')
    parser.add_argument('--model', default='microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL')
    args = parser.parse_args()

    # Load encoder model + tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)

    cuis, names, texts, defs, srcs = [], [], [], [], []
    # newline='' is the csv-module convention; UTF-8 for UMLS strings.
    with open(args.input, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            cuis.append(row['cui'])
            name = row['name']
            definition = row.get('definition', '')
            names.append(name)
            # BUG FIX: only join with ' - ' when a definition exists; the
            # old code unconditionally appended the separator, feeding a
            # dangling "name - " string to the encoder for concepts
            # without a definition.
            texts.append(f'{name} - {definition}' if definition else name)
            defs.append(definition)
            srcs.append(row.get('source', 'UMLS'))

    print(f'Encoding {len(texts)} concepts...')
    embeddings = encode(texts, tokenizer, model)

    # FAISS requires contiguous float32; inner product over unit-norm
    # rows (see encode) is cosine similarity.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

    # Persist artifacts.
    np.save(args.out_emb, embeddings)
    faiss.write_index(index, args.out_idx)

    # BUG FIX: 'name' previously stored the concatenated "name - definition"
    # text instead of the concept name. Also use `with` instead of the
    # leaking json.dump(meta, open(...)) handle.
    meta = {str(i): {'cui': cuis[i], 'name': names[i],
                     'definition': defs[i], 'source': srcs[i]}
            for i in range(len(cuis))}
    with open(args.out_meta, 'w', encoding='utf-8') as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    print('Done.')