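"""Smoke test for the persistent `healthcare_qa` Chroma collection.

Prints basic database statistics, shows random samples, and runs a sample
query against a BioBERT-based embedding function.
"""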
import os
import random
from pprint import pprint

import torch
from chromadb import Client, Settings, EmbeddingFunction
from sentence_transformers import SentenceTransformer

# Configuration
CHROMA_URI = "./Data/database"
EMBEDDING_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
VECTOR_DIM = 768  # embedding dimension of the BioBERT model

class BioEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a BioBERT SentenceTransformer."""

    def __init__(self):
        # Prefer GPU when available; fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        self.model.to(self.device)

    def __call__(self, input: list[str]) -> list[list[float]]:
        # Normalize embeddings so cosine similarity reduces to a dot product.
        embeddings = self.model.encode(
            input,
            normalize_embeddings=True,
            convert_to_numpy=True
        )
        return embeddings.tolist()
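
# The function can also be called directly for ad-hoc checks, e.g.
# BioEmbeddingFunction()(["aspirin dosage"]) returns a list containing one
# VECTOR_DIM-sized vector.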
def test_database():
    print("=" * 50)
    print("Starting database test")
    print("=" * 50)

    # Initialize the persistent client
    client = Client(Settings(
        persist_directory=CHROMA_URI,
        anonymized_telemetry=False,
        is_persistent=True
    ))
    embedding_func = BioEmbeddingFunction()
    collection = client.get_or_create_collection(
        name="healthcare_qa",
        embedding_function=embedding_func
    )
    # 1. Show basic information
    print("\n1. Database info:")
    print(f"Database location: {os.path.abspath(CHROMA_URI)}")
    # os.path.getsize() on a directory returns the size of the directory
    # entry itself, not of its contents, so sum the files under it instead.
    db_size = sum(
        os.path.getsize(os.path.join(root, name))
        for root, _, files in os.walk(CHROMA_URI)
        for name in files
    )
    print(f"Database size: {db_size / 1024 / 1024:.2f} MB")
    print(f"Total entries: {collection.count()}")
    print(f"Embedding model: {EMBEDDING_MODEL_NAME}")
    # 2. Show random samples
    print("\n2. Random samples:")
    total_items = collection.count()
    sample_size = min(2, total_items)
    # Assumes entries were inserted with sequential string IDs "0", "1", ...
    random_indices = random.sample(range(total_items), sample_size)
    results = collection.get(
        ids=[str(i) for i in random_indices],
        include=["documents", "metadatas"]
    )
    for i, (doc, metadata) in enumerate(zip(results['documents'], results['metadatas']), 1):
        print(f"\nSample {i}:")
        print("-" * 40)
        print("Document:")
        print(doc)
        print("\nMetadata:")
        pprint(metadata)
        print("-" * 40)
    # # 4. Test update
    # print("\n4. Testing update:")
    # # Pick two random entries (same sequential-ID assumption as above).
    # total_items = collection.count()
    # random_id = str(random.randint(0, total_items - 1))
    # random_id2 = str(random.randint(0, total_items - 1))
    # # Update only the cluster metadata of the two entries.
    # collection.update(
    #     ids=[random_id, random_id2],
    #     metadatas=[{"cluster": "new_cluster"}, {"cluster": "new_cluster2"}]
    # )
    # # Verify the update
    # results = collection.get(
    #     ids=[random_id],
    #     include=["documents", "metadatas"]
    # )
    # print(f"\nDocument after update: {results['documents'][0]}")
    # print(f"\nMetadata after update: {results['metadatas'][0]}")
    # 3. Test a simple query
    print("\n3. Testing query:")
    query = "diabetes"
    results = collection.query(
        query_texts=[query],
        n_results=1,
        include=["documents", "metadatas", "distances"]
    )
    print(f"\nResults for query '{query}':")
    for i, (doc, metadata, distance) in enumerate(zip(
        results['documents'][0],
        results['metadatas'][0],
        results['distances'][0]
    ), 1):
        print(f"\nResult {i}:")
        print("-" * 40)
        # 1 - distance is a cosine similarity only if the collection uses the
        # cosine space ({"hnsw:space": "cosine"}); Chroma defaults to l2.
        print(f"Similarity score: {1 - distance:.4f}")
        print("\nDocument:")
        print(doc)
        print("\nMetadata:")
        pprint(metadata)
        print("-" * 40)
if __name__ == "__main__":
    test_database()