# Chatbot / Back-end / models / test_db.py
# Uploaded to the Hugging Face Hub via huggingface_hub (revision 5dde370, verified).
import json
from chromadb import Client, Settings, EmbeddingFunction
from pprint import pprint
import random
import os
from sentence_transformers import SentenceTransformer
import torch
# Configuration parameters.
CHROMA_URI = "./Data/database"  # persistence directory for the on-disk ChromaDB store
EMBEDDING_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"  # sentence-transformers model id
VECTOR_DIM = 768 # BioBERT embedding dimensionality (not referenced elsewhere in this file; kept for documentation)
class BioEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a BioBERT sentence-transformer.

    Loads ``EMBEDDING_MODEL_NAME`` once at construction time and produces
    one L2-normalized embedding vector per input string.
    """

    def __init__(self):
        # Prefer GPU when available; fall back to CPU otherwise.
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_ok else "cpu")
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        self.model.to(self.device)

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Embed each text in *input*, returning one vector per text."""
        vectors = self.model.encode(
            input,
            normalize_embeddings=True,
            convert_to_numpy=True,
        )
        return vectors.tolist()
def test_database():
    """Smoke-test the persisted ChromaDB collection.

    Prints basic stats about the on-disk database, shows up to two random
    samples, and runs a single similarity query against the
    ``healthcare_qa`` collection. Intended to be run as a script; all
    output goes to stdout.
    """
    print("="*50)
    print("开始测试数据库")
    print("="*50)

    # Initialize a persistent client rooted at CHROMA_URI.
    client = Client(Settings(
        persist_directory=CHROMA_URI,
        anonymized_telemetry=False,
        is_persistent=True
    ))
    embedding_func = BioEmbeddingFunction()
    collection = client.get_or_create_collection(
        name="healthcare_qa",
        embedding_function=embedding_func
    )

    # 1. Basic database info.
    print("\n1. 数据库基本信息:")
    print(f"数据库位置: {os.path.abspath(CHROMA_URI)}")
    # BUGFIX: os.path.getsize() on a directory returns the size of the
    # directory entry itself (typically a few KB), not the data it holds.
    # Sum the sizes of all files under the persistence directory instead.
    db_size_bytes = sum(
        os.path.getsize(os.path.join(root, fname))
        for root, _dirs, files in os.walk(CHROMA_URI)
        for fname in files
    )
    print(f"数据库大小: {db_size_bytes / 1024 / 1024:.2f} MB")
    print(f"总条目数: {collection.count()} 条")
    print(f"使用的嵌入模型: {EMBEDDING_MODEL_NAME}")

    # 2. Show a couple of random samples (handles the empty-collection
    # case: sample_size becomes 0 and the loop below is skipped).
    print("\n2. 随机样本展示:")
    total_items = collection.count()
    sample_size = min(2, total_items)
    # NOTE(review): assumes documents were inserted with string ids
    # "0".."N-1" — verify against the ingestion script.
    random_indices = random.sample(range(total_items), sample_size)
    results = collection.get(
        ids=[str(i) for i in random_indices],
        include=["documents", "metadatas"]
    )
    for i, (doc, metadata) in enumerate(zip(results['documents'], results['metadatas']), 1):
        print(f"\n样本 {i}:")
        print("-" * 40)
        print("文档内容:")
        print(doc)
        print("\n元数据:")
        pprint(metadata)
        print("-" * 40)

    # 3. Simple similarity query. (A large commented-out "update"
    # experiment previously lived between sections 2 and 3; removed as
    # dead code.)
    print("\n3. 测试查询功能:")
    query = "diabetes"
    results = collection.query(
        query_texts=[query],
        n_results=1,
        include=["documents", "metadatas", "distances"]
    )
    print(f"\n使用查询词 '{query}' 的结果:")
    for i, (doc, metadata, distance) in enumerate(zip(
        results['documents'][0],
        results['metadatas'][0],
        results['distances'][0]
    ), 1):
        print(f"\n结果 {i}:")
        print("-" * 40)
        # NOTE(review): "1 - distance" is a similarity only if the
        # collection uses cosine distance; Chroma defaults to L2 unless
        # configured otherwise — confirm the collection's metric.
        print(f"相似度得分: {1 - distance:.4f}")
        print("\n文档内容:")
        print(doc)
        print("\n元数据:")
        pprint(metadata)
        print("-" * 40)
if __name__ == "__main__":
test_database()