import json
import os
import time

import torch
from chromadb import Client, EmbeddingFunction, Settings
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

CHROMA_URI = "./Data/database"
EMBEDDING_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
VECTOR_DIM = 768
EMBEDDINGS_DIR = "./Data/Embeddings"
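
# EMBEDDING_MODEL_NAME points to a BioBERT-based sentence-transformers model
# tuned on biomedical NLI/STS data; it emits 768-dimensional sentence
# embeddings, matching VECTOR_DIM.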



class BioEmbeddingFunction(EmbeddingFunction):
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        self.model.to(self.device)
        # Record the embedding dimensionality reported by the model,
        # falling back to VECTOR_DIM if the model does not report one.
        self.dimensionality = self.model.get_sentence_embedding_dimension()
        if self.dimensionality is None:
            self.dimensionality = VECTOR_DIM
    
    def __call__(self, input: list[str]) -> list[list[float]]:
        embeddings = self.model.encode(
            input,
            normalize_embeddings=True,
            convert_to_numpy=True
        )
        return embeddings.tolist()

if __name__ == "__main__":
    embedding_func = BioEmbeddingFunction()
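
    # Quick sanity check (not in the original script): the embedding function
    # should return one vector per input, each with the expected dimensionality.
    sample = embedding_func(["What is hypertension?"])
    assert len(sample) == 1 and len(sample[0]) == embedding_func.dimensionality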

    # Create the output directories
    os.makedirs(CHROMA_URI, exist_ok=True)
    os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
    
    # Load data
    print("\n[1/5] Loading data files...")
    loading_start = time.time()
    with open("./Data/Processed/keywords/keyword_index.json") as f:
        keyword_index = json.load(f)
    with open("./Data/Processed/cleaned_qa/qa_database.json") as f:
        qa_database = json.load(f)
    print(f"Data files loaded in {time.time() - loading_start:.2f}s")

    # Build the QA index mapping
    print("\n[2/5] Processing document data...")
    documents = []
    metadatas = []

    print("Building QA index mapping...")
    qa_map = {qa["id"]: qa for qa in qa_database}
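
    # Expected QA record shape (illustrative; field names inferred from the
    # lookups below):
    #   {"id": ..., "question": "...", "answer": "...", "keywords": ["...", ...]}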
    
    # Use tqdm to show document-processing progress
    total_items = sum(len(item_ids) for item_ids in keyword_index.values())
    with tqdm(total=total_items, desc="Processing documents") as pbar:
        for source, item_ids in keyword_index.items():
            for item_id in item_ids:
                qa = qa_map.get(item_id)
                if not qa:
                    pbar.update(1)
                    continue

                combined_text = f"Question: {qa['question']}\nAnswer: {qa['answer']}\nKeywords: {', '.join(qa.get('keywords', []))}"

                metadata = {
                    "source": source,
                    "item_id": item_id,
                    "keywords": ", ".join(qa.get("keywords", [])),
                    "type": "qa"
                }

                documents.append(combined_text)
                metadatas.append(metadata)
                pbar.update(1)

    print("\n[3/5] Initializing the vector database...")
    client = Client(
        Settings(
            persist_directory=CHROMA_URI,
            anonymized_telemetry=False,
            is_persistent=True
        )
    )
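
    # Note (library-surface assumption): newer Chroma releases also expose
    # chromadb.PersistentClient(path=...), a shorthand for this Settings-based
    # persistent client.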

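    # HNSW tuning is passed through collection metadata: "hnsw:space" selects
    # the cosine distance metric, construction_ef/search_ef trade build time
    # and query latency for recall, and M controls graph connectivity.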
    collection = client.get_or_create_collection(
        name="healthcare_qa",
        embedding_function=embedding_func,
        metadata={
            "hnsw:space": "cosine",
            "hnsw:construction_ef": 200,
            "hnsw:search_ef": 128,
            "hnsw:M": 64,
        }
    )

    # Persist the data in batches
    PERSIST_BATCH_SIZE = 40000  # kept below Chroma's maximum batch size of 41,666
    total_records = len(documents)
    
    print("\n[4/5] 开始持久化数据到向量数据库...")
    
    with tqdm(total=total_records, desc="Persisting") as pbar:
        for i in range(0, total_records, PERSIST_BATCH_SIZE):
            end_idx = min(i + PERSIST_BATCH_SIZE, total_records)
            
            # Slice out the current batch
            batch_ids = [str(j) for j in range(i, end_idx)]
            batch_documents = documents[i:end_idx]
            batch_metadatas = metadatas[i:end_idx]

            # Upsert the batch into the persistent collection
            collection.upsert(
                ids=batch_ids,
                documents=batch_documents,
                metadatas=batch_metadatas
            )
            
            pbar.update(end_idx - i)
    
    print("\n[5/5] 完成数据处理和持久化!")
    print(f"总共处理了 {total_records} 条记录")
    print(f"向量维度: {embedding_func.dimensionality}")