"""Embed a healthcare QA dataset with a BioBERT sentence-transformer and
persist it into a local Chroma vector database."""

import json
import os
import time

import torch
from chromadb import Client, EmbeddingFunction, Settings
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

CHROMA_URI = "./Data/database"
EMBEDDING_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
VECTOR_DIM = 768  # fallback dimension if the model does not report one
EMBEDDINGS_DIR = "./Data/Embeddings"


class BioEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a BioBERT sentence-transformer."""

    def __init__(self):
        # Use the GPU when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        self.model.to(self.device)

        # get_sentence_embedding_dimension() can return None; fall back to the
        # configured default in that case.
        self.dimensionality = self.model.get_sentence_embedding_dimension()
        if self.dimensionality is None:
            self.dimensionality = VECTOR_DIM

    def __call__(self, input: list[str]) -> list[list[float]]:
        # L2-normalized vectors, so cosine similarity reduces to a dot product.
        embeddings = self.model.encode(
            input,
            normalize_embeddings=True,
            convert_to_numpy=True,
        )
        return embeddings.tolist()
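
# Chroma calls the embedding function with a batch of document strings and
# expects one vector per string, e.g. (values elided):
#   BioEmbeddingFunction()(["text a", "text b"]) -> [[0.01, ...], [0.03, ...]]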


if __name__ == "__main__":
    embedding_func = BioEmbeddingFunction()

    os.makedirs(CHROMA_URI, exist_ok=True)
    os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
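
    # Quick sanity check (illustrative addition, not part of the original
    # pipeline): one probe sentence should yield a single vector whose length
    # matches the reported dimensionality.
    probe = embedding_func(["sample biomedical sentence"])
    assert len(probe) == 1 and len(probe[0]) == embedding_func.dimensionality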

    print("\n[1/5] Loading data files...")
    loading_start = time.time()
    with open("./Data/Processed/keywords/keyword_index.json") as f:
        keyword_index = json.load(f)
    with open("./Data/Processed/cleaned_qa/qa_database.json") as f:
        qa_database = json.load(f)
    print(f"Data files loaded in {time.time() - loading_start:.1f}s")
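
    # Input shapes, as inferred from the processing loop below:
    #   keyword_index: {source: [item_id, ...]}
    #   qa_database:   [{"id": ..., "question": ..., "answer": ..., "keywords": [...]}, ...]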

    print("\n[2/5] Processing document data...")
    documents = []
    metadatas = []

    print("Building QA index map...")
    qa_map = {qa["id"]: qa for qa in qa_database}

    total_items = sum(len(item_ids) for item_ids in keyword_index.values())
    with tqdm(total=total_items, desc="Processing documents") as pbar:
        for source, item_ids in keyword_index.items():
            for item_id in item_ids:
                qa = qa_map.get(item_id)
                if not qa:
                    # Skip index entries with no matching QA record.
                    pbar.update(1)
                    continue

                # Fold question, answer, and keywords into one document so all
                # three contribute to the embedding.
                combined_text = (
                    f"Question: {qa['question']}\n"
                    f"Answer: {qa['answer']}\n"
                    f"Keywords: {', '.join(qa.get('keywords', []))}"
                )

                metadata = {
                    "source": source,
                    "item_id": item_id,
                    "keywords": ", ".join(qa.get("keywords", [])),
                    "type": "qa",
                }

                documents.append(combined_text)
                metadatas.append(metadata)
                pbar.update(1)

    print("\n[3/5] Creating the vector database collection...")
    client = Client(
        Settings(
            persist_directory=CHROMA_URI,
            anonymized_telemetry=False,
            is_persistent=True,
        )
    )

    collection = client.get_or_create_collection(
        name="healthcare_qa",
        embedding_function=embedding_func,
        metadata={
            "hnsw:space": "cosine",       # matches the normalized embeddings
            "hnsw:construction_ef": 200,  # candidate-list size at build time
            "hnsw:search_ef": 128,        # candidate-list size at query time
            "hnsw:M": 64,                 # graph degree; higher = better recall, more memory
        },
    )

    # Upsert in large batches rather than one call per record.
    PERSIST_BATCH_SIZE = 40000
    total_records = len(documents)

    print("\n[4/5] Persisting data to the vector database...")

    with tqdm(total=total_records, desc="Persisting") as pbar:
        for i in range(0, total_records, PERSIST_BATCH_SIZE):
            end_idx = min(i + PERSIST_BATCH_SIZE, total_records)

            # IDs are positional indices over the assembled document list; the
            # original QA item id is preserved in the metadata.
            batch_ids = [str(j) for j in range(i, end_idx)]
            batch_documents = documents[i:end_idx]
            batch_metadatas = metadatas[i:end_idx]

            collection.upsert(
                ids=batch_ids,
                documents=batch_documents,
                metadatas=batch_metadatas,
            )

            pbar.update(end_idx - i)

    print("\n[5/5] Processing and persistence complete!")
    print(f"Processed {total_records} records in total")
    print(f"Vector dimension: {embedding_func.dimensionality}")
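
    # Optional smoke test (illustrative; the probe string below is made up,
    # not drawn from the dataset). With cosine space, smaller distances mean
    # more similar documents.
    results = collection.query(query_texts=["What is hypertension?"], n_results=3)
    print(f"Sample query returned {len(results['ids'][0])} hits")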