File size: 3,531 Bytes
beb1728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Script to build a FAISS index from UMLS concept metadata.
Produces:
 - `umls_embeddings.npy`: normalized vectors for each concept
 - `umls_index.faiss`: FAISS index for fast similarity search
 - `umls_metadata.json`: mapping from index position to concept metadata

Usage:
 python build_umls_faiss_index.py \
   --input concepts.csv \
   --output_meta backend/umls_metadata.json \
   --output_emb backend/umls_embeddings.npy \
   --output_idx backend/umls_index.faiss
"""
import argparse
import csv
import json
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
import torch

def encode_concepts(model, tokenizer, texts, batch_size=32, device='cpu'):
    """Encode *texts* into L2-normalized first-token ([CLS]) embeddings.

    Texts are tokenized and run through *model* in batches of *batch_size*
    on *device*. The hidden state of the first token of
    ``last_hidden_state`` is taken as each text's embedding, then scaled to
    unit length so that inner product equals cosine similarity.

    Args:
        model: Encoder whose output exposes ``last_hidden_state``
            (e.g. a Hugging Face ``AutoModel``).
        tokenizer: Matching tokenizer; called with ``return_tensors='pt'``
            and the result moved to *device*.
        texts: Sequence of strings to encode.
        batch_size: Number of texts per forward pass (must be positive).
        device: Torch device that the model and inputs are moved to.

    Returns:
        A C-contiguous ``float32`` array with one row per text. Rows have
        unit L2 norm; an all-zero embedding is left as zeros instead of
        becoming NaN.

    Raises:
        ValueError: if *texts* is empty or *batch_size* is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if len(texts) == 0:
        # np.vstack on an empty list fails with a cryptic message.
        raise ValueError("encode_concepts() needs at least one text to encode")

    model.to(device)
    batches = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # First-token hidden state per text. `.float()` upcasts half/bfloat16
        # outputs: NumPy has no bfloat16 and FAISS requires float32.
        cls_emb = outputs.last_hidden_state[:, 0, :].float().cpu().numpy()
        norms = np.linalg.norm(cls_emb, axis=1, keepdims=True)
        # Clamp the divisor so a zero vector stays zero rather than NaN,
        # which would corrupt inner-product search results.
        np.maximum(norms, np.finfo(np.float32).tiny, out=norms)
        batches.append(cls_emb / norms)
    return np.ascontiguousarray(np.vstack(batches), dtype=np.float32)

def _read_concepts(path):
    """Read the concept CSV at *path* into four parallel lists.

    Returns ``(cuis, names, definitions, sources)`` in file order. An empty
    or missing ``definition`` becomes ``''``; an empty or missing ``source``
    becomes ``'UMLS'``.
    """
    cuis, names, defs, sources = [], [], [], []
    with open(path, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            cuis.append(row['cui'])
            names.append(row['name'])
            # DictReader yields None for fields absent from short rows and ''
            # for blank cells, so `or` (not a .get() default) normalizes both.
            defs.append(row.get('definition') or '')
            sources.append(row.get('source') or 'UMLS')
    return cuis, names, defs, sources


def _concept_text(name, definition):
    """Return the text embedded for a concept: ``"name - definition"`` or just ``name``."""
    return f"{name} - {definition}" if definition else name


def main():
    """Command-line entry point: embed UMLS concepts and write index artifacts.

    Reads the CSV given by ``--input`` and embeds each concept's
    "name - definition" text with ``--model``. Writes:
      * the embedding matrix to ``--output_emb`` (via ``np.save``),
      * a flat inner-product FAISS index to ``--output_idx``,
      * to ``--output_meta``, a JSON object mapping each row position (as a
        string) to that concept's ``cui``, ``name``, ``definition`` and
        ``source``.
    """
    parser = argparse.ArgumentParser(description="Build FAISS index for UMLS concepts.")
    parser.add_argument('--input', required=True,
                        help='CSV file with columns: cui,name,definition,source')
    parser.add_argument('--output_meta', required=True,
                        help='JSON metadata output path')
    parser.add_argument('--output_emb', required=True,
                        help='NumPy embeddings output path')
    parser.add_argument('--output_idx', required=True,
                        help='FAISS index output path')
    parser.add_argument('--model',
                        default='microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL',
                        help='Hugging Face model name')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Number of concepts encoded per forward pass')
    parser.add_argument('--device', default='cpu',
                        help="Torch device to run the model on (e.g. 'cpu', 'cuda')")
    args = parser.parse_args()

    # Read the CSV before loading the (large) model so bad input fails fast.
    cuis, names, defs, sources = _read_concepts(args.input)
    if not cuis:
        parser.error(f"no concepts found in {args.input}")

    # Load model & tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model)
    model.eval()

    # The embedded text includes the definition; the metadata keeps the bare name.
    texts = [_concept_text(name, definition) for name, definition in zip(names, defs)]
    print(f"Encoding {len(texts)} concepts...")
    embeddings = encode_concepts(model, tokenizer, texts,
                                 batch_size=args.batch_size, device=args.device)
    # FAISS only accepts C-contiguous float32 arrays.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

    # Flat inner-product index: with unit-norm vectors this is cosine similarity.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    # Save outputs
    np.save(args.output_emb, embeddings)
    faiss.write_index(index, args.output_idx)

    # Row position in the index -> concept metadata.
    metadata = {
        str(idx): {
            'cui': cui,
            'name': name,
            'definition': definition,
            'source': source,
        }
        for idx, (cui, name, definition, source)
        in enumerate(zip(cuis, names, defs, sources))
    }
    with open(args.output_meta, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    print("FAISS index, embeddings, and metadata saved.")

if __name__ == "__main__":
    # Build the index only when run as a script, not when imported.
    main()