File size: 4,720 Bytes
5dde370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.ensemble import RandomForestClassifier
import joblib
from chromadb import Client, Settings
import os
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

class TopicClassifier:
    """Random-forest classifier that predicts cluster labels from stored embeddings.

    Embeddings and their ``cluster`` metadata are read from the
    ``healthcare_qa`` collection of a ChromaDB database.
    """

    def __init__(self, chroma_uri: str = "./Data/database"):
        """Initialise the classifier and open the ChromaDB collection.

        Args:
            chroma_uri: Path of the ChromaDB persistence directory.
        """
        self.chroma_uri = chroma_uri
        self.client = Client(Settings(
            persist_directory=chroma_uri,
            anonymized_telemetry=False,
            is_persistent=True
        ))
        self.collection = self.client.get_collection("healthcare_qa")
        self.model = None  # fitted estimator, set by train_and_evaluate
        self.X = None      # feature matrix, populated by load_data
        self.y = None      # integer cluster labels, populated by load_data

    @staticmethod
    def _parse_cluster_label(cluster) -> int:
        """Convert a ``cluster`` metadata value into an integer label.

        ``"cluster_3"`` -> ``3``; ``"noise"`` or ``None`` -> ``-1`` (noise).
        Integer values are passed through unchanged.
        """
        if cluster is None:
            return -1
        if isinstance(cluster, (int, np.integer)):
            return int(cluster)
        if cluster == "noise":
            return -1
        # Use the suffix after the LAST underscore so prefixes that contain
        # underscores themselves still parse ("cluster_3" -> "3").
        return int(str(cluster).rsplit("_", 1)[-1])

    def load_data(self):
        """Load embeddings and cluster labels from the collection.

        Records labelled as noise (label ``-1``) are discarded. Records whose
        metadata entry is missing, or lacks a ``cluster`` key, count as noise.
        """
        print("正在加载数据...")

        result = self.collection.get(include=["embeddings", "metadatas"])
        X = np.array(result["embeddings"])

        # Guard against None metadata entries (records stored without metadata).
        y = np.array(
            [
                self._parse_cluster_label((metadata or {}).get("cluster", "noise"))
                for metadata in result["metadatas"]
            ],
            dtype=int,
        )

        # Drop noise points: they carry no topic label to learn from.
        mask = y != -1
        self.X = X[mask]
        self.y = y[mask]

        print(f"数据加载完成,特征形状: {self.X.shape}")
        print(f"类别数量: {len(np.unique(self.y))}")

    def _build_model(self):
        """Return an unfitted random forest with the hyper-parameters used everywhere."""
        return RandomForestClassifier(
            n_estimators=100,
            max_depth=None,
            n_jobs=-1,
            random_state=42
        )

    def train_and_evaluate(self, n_splits=5, refit=True):
        """Evaluate with stratified k-fold cross-validation, then fit the final model.

        Loads data first if ``load_data`` has not been called yet.

        Args:
            n_splits: Number of cross-validation folds.
            refit: If True (default), refit ``self.model`` on the full dataset
                after cross-validation, so ``save_model`` stores a model that
                has seen every labelled sample. If False, ``self.model`` is the
                model trained in the last fold (the previous behaviour).

        Returns:
            dict mapping ``'accuracy'``, ``'macro_f1'`` and ``'weighted_f1'``
            to the list of per-fold scores.
        """
        if self.X is None or self.y is None:
            self.load_data()

        print(f"\n开始{n_splits}折交叉验证...")

        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

        fold_scores = {
            'accuracy': [],
            'macro_f1': [],
            'weighted_f1': []
        }

        fold_model = None
        for fold, (train_idx, val_idx) in enumerate(skf.split(self.X, self.y), 1):
            print(f"\n第 {fold} 折验证:")

            X_train, X_val = self.X[train_idx], self.X[val_idx]
            y_train, y_val = self.y[train_idx], self.y[val_idx]

            print("训练模型...")
            # Fold models stay local so self.model never holds a model
            # trained on only part of the data unless refit=False.
            fold_model = self._build_model()
            fold_model.fit(X_train, y_train)

            y_pred = fold_model.predict(X_val)

            fold_scores['accuracy'].append(accuracy_score(y_val, y_pred))
            fold_scores['macro_f1'].append(f1_score(y_val, y_pred, average='macro'))
            fold_scores['weighted_f1'].append(f1_score(y_val, y_pred, average='weighted'))

            print("\n分类报告:")
            print(classification_report(y_val, y_pred))

        print("\n总体性能:")
        print(f"平均准确率: {np.mean(fold_scores['accuracy']):.4f} ± {np.std(fold_scores['accuracy']):.4f}")
        print(f"平均宏F1分数: {np.mean(fold_scores['macro_f1']):.4f} ± {np.std(fold_scores['macro_f1']):.4f}")
        print(f"平均加权F1分数: {np.mean(fold_scores['weighted_f1']):.4f} ± {np.std(fold_scores['weighted_f1']):.4f}")

        if refit:
            # CV only estimates generalisation; the deployable model should
            # be trained on all labelled samples.
            print("\n使用全部数据训练最终模型...")
            self.model = self._build_model().fit(self.X, self.y)
        else:
            self.model = fold_model

        return fold_scores

    def save_model(self, model_dir: str = "./models"):
        """Serialise the trained model with joblib to a timestamped file.

        Args:
            model_dir: Output directory; created if it does not exist.

        Returns:
            Path of the written ``.joblib`` file.

        Raises:
            ValueError: If no model has been trained yet.
        """
        if self.model is None:
            raise ValueError("模型尚未训练")

        os.makedirs(model_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(model_dir, f"topic_classifier_{timestamp}.joblib")

        joblib.dump(self.model, model_path)
        print(f"\n模型已保存到: {model_path}")
        return model_path

def main():
    """Example run: cross-validate on the default database, then save the model."""
    topic_clf = TopicClassifier()
    topic_clf.train_and_evaluate()
    topic_clf.save_model()


if __name__ == "__main__":
    main()