|
import numpy as np |
|
from sklearn.model_selection import StratifiedKFold |
|
from sklearn.metrics import classification_report, accuracy_score, f1_score |
|
from sklearn.ensemble import RandomForestClassifier |
|
import joblib |
|
from chromadb import Client, Settings |
|
import os |
|
import json |
|
from datetime import datetime |
|
import warnings |
|
# Silence every warning category process-wide (e.g. sklearn metric warnings,
# chromadb deprecation notices) so the console output stays readable.
warnings.simplefilter("ignore")
|
|
|
class TopicClassifier:
    """Random-forest topic classifier trained on embeddings stored in ChromaDB.

    Embeddings and their ``cluster`` metadata are read from the
    ``healthcare_qa`` collection; entries labelled ``"noise"`` (or carrying no
    label) are excluded from training and evaluation.
    """

    def __init__(self, chroma_uri: str = "./Data/database"):
        """Initialise the classifier and open the ChromaDB collection.

        Args:
            chroma_uri: Path of the persistent ChromaDB database directory.

        The ``healthcare_qa`` collection is opened with ``get_collection``
        (not created), so it is expected to exist already.
        """
        self.chroma_uri = chroma_uri
        self.client = Client(Settings(
            persist_directory=chroma_uri,
            anonymized_telemetry=False,
            is_persistent=True,
        ))
        self.collection = self.client.get_collection("healthcare_qa")
        self.model = None  # final model, set by train_and_evaluate()
        self.X = None      # embedding matrix, one row per labelled entry
        self.y = None      # integer topic labels aligned with self.X

    def load_data(self):
        """Load embeddings and cluster labels from the database.

        Each entry's ``cluster`` metadata is either ``"noise"`` or a string
        whose second ``_``-separated field is an integer (e.g. ``"cluster_3"``),
        which becomes the class label. Entries labelled ``"noise"``, without a
        ``cluster`` key, or without metadata at all are discarded.

        Populates ``self.X`` and ``self.y`` (same number of rows).
        """
        print("正在加载数据...")

        result = self.collection.get(include=["embeddings", "metadatas"])
        X = np.asarray(result["embeddings"])

        labels = []
        for metadata in result["metadatas"]:
            # ChromaDB may hand back None for entries stored without metadata;
            # treat those as unlabelled instead of crashing on `.get`.
            cluster = (metadata or {}).get("cluster", "noise")
            if cluster == "noise":
                labels.append(-1)
            else:
                labels.append(int(cluster.split("_")[1]))
        # Explicit dtype keeps y integral even when the collection is empty.
        y = np.asarray(labels, dtype=int)

        # Label -1 marks noise points, which are not a real topic: drop them.
        mask = y != -1
        self.X = X[mask]
        self.y = y[mask]

        print(f"数据加载完成,特征形状: {self.X.shape}")
        print(f"类别数量: {len(np.unique(self.y))}")

    def _build_model(self):
        """Return a fresh, unfitted RandomForestClassifier with fixed settings."""
        return RandomForestClassifier(
            n_estimators=100,
            max_depth=None,
            n_jobs=-1,
            random_state=42,
        )

    def train_and_evaluate(self, n_splits=5):
        """Evaluate with stratified k-fold CV, then fit the final model on all data.

        Data is loaded from the database first if it has not been loaded yet.
        Per-fold classification reports and the mean ± std of each metric are
        printed.

        Args:
            n_splits: Number of cross-validation folds.

        Returns:
            dict mapping ``'accuracy'``, ``'macro_f1'`` and ``'weighted_f1'``
            to lists holding one score per fold.

        Raises:
            ValueError: if there are no labelled samples to train on.
        """
        if self.X is None or self.y is None:
            self.load_data()

        # Fail early with a clear message instead of deep inside StratifiedKFold.
        if len(self.y) == 0:
            raise ValueError("没有可用于训练的带标签样本")

        print(f"\n开始{n_splits}折交叉验证...")

        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

        fold_scores = {
            'accuracy': [],
            'macro_f1': [],
            'weighted_f1': [],
        }

        for fold, (train_idx, val_idx) in enumerate(skf.split(self.X, self.y), 1):
            print(f"\n第 {fold} 折验证:")

            X_train, X_val = self.X[train_idx], self.X[val_idx]
            y_train, y_val = self.y[train_idx], self.y[val_idx]

            print("训练模型...")
            # Fold models exist only for evaluation. Keeping them out of
            # self.model prevents save_model() from persisting a model that was
            # trained on just the last fold's training split.
            fold_model = self._build_model()
            fold_model.fit(X_train, y_train)
            y_pred = fold_model.predict(X_val)

            # zero_division=0 yields the same value as the default ("warn")
            # for classes never predicted in a fold, minus the warning.
            fold_scores['accuracy'].append(accuracy_score(y_val, y_pred))
            fold_scores['macro_f1'].append(
                f1_score(y_val, y_pred, average='macro', zero_division=0))
            fold_scores['weighted_f1'].append(
                f1_score(y_val, y_pred, average='weighted', zero_division=0))

            print("\n分类报告:")
            print(classification_report(y_val, y_pred, zero_division=0))

        print("\n总体性能:")
        print(f"平均准确率: {np.mean(fold_scores['accuracy']):.4f} ± {np.std(fold_scores['accuracy']):.4f}")
        print(f"平均宏F1分数: {np.mean(fold_scores['macro_f1']):.4f} ± {np.std(fold_scores['macro_f1']):.4f}")
        print(f"平均加权F1分数: {np.mean(fold_scores['weighted_f1']):.4f} ± {np.std(fold_scores['weighted_f1']):.4f}")

        # Cross-validation only estimates performance; the model that gets
        # saved is refit on every labelled sample.
        print("\n使用全部数据训练最终模型...")
        self.model = self._build_model()
        self.model.fit(self.X, self.y)

        return fold_scores

    def save_model(self, model_dir: str = "./models"):
        """Persist the final trained model with joblib.

        Args:
            model_dir: Output directory; created if it does not exist.

        Returns:
            Path of the written ``topic_classifier_<YYYYmmdd_HHMMSS>.joblib`` file.

        Raises:
            ValueError: if no model has been trained yet.
        """
        if self.model is None:
            raise ValueError("模型尚未训练")

        os.makedirs(model_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(model_dir, f"topic_classifier_{timestamp}.joblib")

        joblib.dump(self.model, model_path)
        print(f"\n模型已保存到: {model_path}")
        return model_path
|
|
|
def main():
    """Entry point: build the classifier, cross-validate it, then save it."""
    topic_clf = TopicClassifier()
    topic_clf.train_and_evaluate()
    topic_clf.save_model()


if __name__ == "__main__":
    main()
|
|