import os
import warnings

import numpy as np
import cupy as cp
from chromadb import Client, Settings
from cuml.cluster import KMeans as cuKMeans
from sklearn.decomposition import PCA
from tqdm import tqdm

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

class TopicClusterer:

    def __init__(self, chroma_uri: str = "./Data/database"):
        """Initialize the clusterer.

        Args:
            chroma_uri: Path to the ChromaDB database.
        """
        self.chroma_uri = chroma_uri
        self.client = Client(Settings(
            persist_directory=chroma_uri,
            anonymized_telemetry=False,
            is_persistent=True
        ))

        # Expected embedding dimensionality.
        self.vector_dim = 768

        try:
            self.collection = self.client.get_collection("healthcare_qa")
        except Exception as e:
            # Fail loudly: nothing downstream can run without the collection,
            # and continuing would leave self.collection undefined.
            raise RuntimeError(f"Collection 'healthcare_qa' does not exist: {e}") from e

        self.embeddings = None
        self.reduced_embeddings = None
        self.labels = None
        self.document_ids = None

    def load_embeddings(self, use_cache: bool = False) -> np.ndarray:
        """Load embeddings from the on-disk cache or from the database.

        Args:
            use_cache: If True, try the cached .npy file before querying the database.
        """
        embeddings_cache_file = '/home/dyvm6xra/dyvm6xrauser11/workspace/projects/HKU/Chatbot/Data/Embeddings/embeddings_703df19c43bd6565563071b97e7172ce.npy'

        if use_cache and os.path.exists(embeddings_cache_file):
            print("Found cached embeddings, loading...")
            try:
                self.embeddings = np.load(embeddings_cache_file)
                # Note: the cache stores no ids, so sequential ids are
                # synthesized here; they may not match the database ids.
                self.document_ids = [str(i) for i in range(len(self.embeddings))]
                print(f"Loaded from cache, shape: {self.embeddings.shape}")
                return self.embeddings
            except Exception as e:
                print(f"Failed to load cache: {e}; reloading from the database")

        print("Loading embeddings...")
        print(f"Collection size: {self.collection.count()}")
        result = self.collection.get(include=["embeddings"])
        self.embeddings = np.array(result["embeddings"])
        self.document_ids = result["ids"]

        print(f"Loading complete, shape: {self.embeddings.shape}")
        return self.embeddings
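
    # The loader above expects a pre-built .npy cache, but nothing in this
    # module writes one. Below is a minimal sketch of the missing writer,
    # assuming that caching the raw embedding matrix verbatim is sufficient;
    # the name save_embeddings_cache is not part of the original pipeline.
    def save_embeddings_cache(self, cache_file: str) -> None:
        """Persist the loaded embeddings so later runs can skip the database."""
        if self.embeddings is None:
            raise ValueError("Load embeddings before caching them")
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        np.save(cache_file, self.embeddings)
        print(f"Embeddings cached to: {cache_file}")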

    def reduce_dimensions(self, n_components: int = 2) -> np.ndarray:
        """Reduce dimensionality with PCA.

        Args:
            n_components: Target dimensionality.
        """
        if self.embeddings is None:
            self.load_embeddings()

        print("Reducing dimensionality with PCA...")
        reducer = PCA(
            n_components=n_components,
            random_state=42,
            svd_solver='randomized'
        )
        self.reduced_embeddings = reducer.fit_transform(self.embeddings)
        cumulative_variance = np.cumsum(reducer.explained_variance_ratio_)
        print(f"PCA cumulative explained variance ratio: {cumulative_variance[-1]:.4f}")
        print(f"Reduction complete, shape: {self.reduced_embeddings.shape}")

        # Cache the reduced embeddings under Data/Embeddings, alongside the
        # raw-embedding cache.
        cache_dir = os.path.join(os.path.dirname(self.chroma_uri), 'Embeddings')
        os.makedirs(cache_dir, exist_ok=True)
        cache_file = os.path.join(cache_dir, f'pca_reduced_{n_components}d.npy')
        np.save(cache_file, self.reduced_embeddings)
        print(f"Reduced embeddings cached to: {cache_file}")

        return self.reduced_embeddings
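
    # Reloading the cached reduction is the symmetric operation; a minimal
    # sketch, assuming the cache layout written above. load_reduced_cache is
    # an illustrative helper, not part of the original code.
    def load_reduced_cache(self, n_components: int = 2) -> np.ndarray:
        """Load a previously cached PCA reduction instead of recomputing it."""
        cache_dir = os.path.join(os.path.dirname(self.chroma_uri), 'Embeddings')
        cache_file = os.path.join(cache_dir, f'pca_reduced_{n_components}d.npy')
        self.reduced_embeddings = np.load(cache_file)
        return self.reduced_embeddings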

    def cluster_kmeans(self, n_clusters: int = 4) -> np.ndarray:
        """Cluster the reduced embeddings with GPU-accelerated KMeans.

        Args:
            n_clusters: Number of clusters.
        """
        print("Clustering with GPU-accelerated KMeans...")

        if self.reduced_embeddings is None:
            self.reduce_dimensions()

        # Move the data onto the GPU for cuML.
        data_gpu = cp.array(self.reduced_embeddings)

        kmeans = cuKMeans(
            n_clusters=n_clusters,
            random_state=42,
            n_init=10,
            max_iter=300,
            verbose=1
        )
        kmeans.fit(data_gpu)
        self.labels = cp.asnumpy(kmeans.labels_)

        # Report the cluster size distribution.
        unique_labels = np.unique(self.labels)
        print(f"Found {len(unique_labels)} clusters")
        for label in unique_labels:
            count = np.sum(self.labels == label)
            percentage = count / len(self.labels) * 100
            print(f"Cluster {label}: {count} samples ({percentage:.2f}%)")

        return self.labels
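
    # cuML and CuPy require an NVIDIA GPU. The sketch below is a CPU fallback
    # using scikit-learn's KMeans with the same parameters; cluster_kmeans_cpu
    # is an illustrative addition, not part of the original pipeline.
    def cluster_kmeans_cpu(self, n_clusters: int = 4) -> np.ndarray:
        """CPU KMeans fallback for machines without a CUDA device."""
        from sklearn.cluster import KMeans

        if self.reduced_embeddings is None:
            self.reduce_dimensions()
        kmeans = KMeans(n_clusters=n_clusters, random_state=42,
                        n_init=10, max_iter=300)
        self.labels = kmeans.fit_predict(self.reduced_embeddings)
        return self.labels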

    def update_database(self) -> None:
        """Write the cluster assignments back to the database."""
        if self.labels is None or self.document_ids is None:
            raise ValueError("Run clustering before updating the database")

        print("Updating the database...")

        label_strings = [f"cluster_{label}" for label in self.labels]

        batch_size = 500
        total_docs = len(self.document_ids)

        for i in tqdm(range(0, total_docs, batch_size), desc="Batch-updating database"):
            batch_end = min(i + batch_size, total_docs)
            batch_ids = self.document_ids[i:batch_end]
            batch_labels = label_strings[i:batch_end]

            # Attach each document's cluster label as metadata; the key
            # "cluster" is an assumed convention.
            self.collection.update(
                ids=batch_ids,
                metadatas=[{"cluster": label} for label in batch_labels]
            )

        print("Database update complete")


def main():
    clusterer = TopicClusterer()

    # Full pipeline: load vectors, reduce with PCA, cluster, write back.
    clusterer.load_embeddings()
    clusterer.reduce_dimensions(n_components=2)
    clusterer.cluster_kmeans(n_clusters=4)
    clusterer.update_database()


if __name__ == "__main__":
    main()