Spaces:
Sleeping
Sleeping
YanBoChen
commited on
Commit
Β·
f3ac7d9
1
Parent(s):
8942859
Add comprehensive tests for chunk quality analysis and embedding validation
Browse files- Introduced a new test suite for chunk quality analysis, covering chunk length distribution, chunking method comparison, token vs character analysis, and recommendations generation.
- Enhanced embedding validation tests with detailed logging and checks for embedding dimensions, self-retrieval accuracy, and cross-dataset search functionality.
- tests/test_chunk_quality_analysis.py +333 -0
- tests/test_embedding_and_index.py +96 -24
- tests/test_embedding_validation.py +99 -15
tests/test_chunk_quality_analysis.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chunk Quality Analysis Tests
|
| 3 |
+
|
| 4 |
+
This module analyzes chunk quality and identifies issues with chunk length differences
|
| 5 |
+
between emergency and treatment data processing methods.
|
| 6 |
+
|
| 7 |
+
Author: OnCall.ai Team
|
| 8 |
+
Date: 2025-07-28
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
import json
|
| 13 |
+
import numpy as np
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import List, Dict, Tuple
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
# Setup logging
|
| 19 |
+
logging.basicConfig(
|
| 20 |
+
level=logging.INFO,
|
| 21 |
+
format='%(levelname)s:%(name)s:%(message)s'
|
| 22 |
+
)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
# Add src to python path
|
| 26 |
+
current_dir = Path(__file__).parent.resolve()
|
| 27 |
+
project_root = current_dir.parent
|
| 28 |
+
sys.path.append(str(project_root / "src"))
|
| 29 |
+
|
| 30 |
+
from data_processing import DataProcessor
|
| 31 |
+
|
| 32 |
+
class TestChunkQualityAnalysis:
|
| 33 |
+
|
| 34 |
+
def setup_class(self):
|
| 35 |
+
"""Initialize test environment"""
|
| 36 |
+
print("\n=== Phase 1: Setting up Chunk Quality Analysis ===")
|
| 37 |
+
self.base_dir = Path(__file__).parent.parent.resolve()
|
| 38 |
+
self.models_dir = self.base_dir / "models"
|
| 39 |
+
self.embeddings_dir = self.models_dir / "embeddings"
|
| 40 |
+
|
| 41 |
+
print(f"β’ Base directory: {self.base_dir}")
|
| 42 |
+
print(f"β’ Models directory: {self.models_dir}")
|
| 43 |
+
|
| 44 |
+
# Initialize processor
|
| 45 |
+
self.processor = DataProcessor(base_dir=str(self.base_dir))
|
| 46 |
+
print("β’ DataProcessor initialized")
|
| 47 |
+
|
| 48 |
+
def test_chunk_length_analysis(self):
|
| 49 |
+
"""Detailed analysis of chunk length distribution"""
|
| 50 |
+
print("\n=== Phase 2: Chunk Length Distribution Analysis ===")
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
# Load chunk data
|
| 54 |
+
print("β’ Loading chunk data...")
|
| 55 |
+
with open(self.embeddings_dir / "emergency_chunks.json", 'r') as f:
|
| 56 |
+
emergency_chunks = json.load(f)
|
| 57 |
+
with open(self.embeddings_dir / "treatment_chunks.json", 'r') as f:
|
| 58 |
+
treatment_chunks = json.load(f)
|
| 59 |
+
|
| 60 |
+
# Analyze emergency chunks
|
| 61 |
+
em_lengths = [len(chunk['text']) for chunk in emergency_chunks]
|
| 62 |
+
em_token_counts = [chunk.get('token_count', 0) for chunk in emergency_chunks]
|
| 63 |
+
|
| 64 |
+
print(f"\nπ Emergency Chunks Analysis:")
|
| 65 |
+
print(f"β’ Total chunks: {len(em_lengths):,}")
|
| 66 |
+
print(f"β’ Min length: {min(em_lengths)} chars")
|
| 67 |
+
print(f"β’ Max length: {max(em_lengths)} chars")
|
| 68 |
+
print(f"β’ Average length: {sum(em_lengths)/len(em_lengths):.2f} chars")
|
| 69 |
+
print(f"β’ Median length: {sorted(em_lengths)[len(em_lengths)//2]} chars")
|
| 70 |
+
|
| 71 |
+
if any(em_token_counts):
|
| 72 |
+
avg_tokens = sum(em_token_counts)/len(em_token_counts)
|
| 73 |
+
print(f"β’ Average tokens: {avg_tokens:.2f}")
|
| 74 |
+
print(f"β’ Chars per token ratio: {(sum(em_lengths)/len(em_lengths)) / avg_tokens:.2f}")
|
| 75 |
+
|
| 76 |
+
# Analyze treatment chunks
|
| 77 |
+
tr_lengths = [len(chunk['text']) for chunk in treatment_chunks]
|
| 78 |
+
|
| 79 |
+
print(f"\nπ Treatment Chunks Analysis:")
|
| 80 |
+
print(f"β’ Total chunks: {len(tr_lengths):,}")
|
| 81 |
+
print(f"β’ Min length: {min(tr_lengths)} chars")
|
| 82 |
+
print(f"β’ Max length: {max(tr_lengths)} chars")
|
| 83 |
+
print(f"β’ Average length: {sum(tr_lengths)/len(tr_lengths):.2f} chars")
|
| 84 |
+
print(f"β’ Median length: {sorted(tr_lengths)[len(tr_lengths)//2]} chars")
|
| 85 |
+
|
| 86 |
+
# Length distribution comparison
|
| 87 |
+
em_avg = sum(em_lengths)/len(em_lengths)
|
| 88 |
+
tr_avg = sum(tr_lengths)/len(tr_lengths)
|
| 89 |
+
ratio = em_avg / tr_avg
|
| 90 |
+
|
| 91 |
+
print(f"\nπ Length Distribution Comparison:")
|
| 92 |
+
print(f"β’ Emergency average: {em_avg:.0f} chars")
|
| 93 |
+
print(f"β’ Treatment average: {tr_avg:.0f} chars")
|
| 94 |
+
print(f"β’ Ratio (Emergency/Treatment): {ratio:.1f}x")
|
| 95 |
+
|
| 96 |
+
# Length distribution buckets
|
| 97 |
+
print(f"\nπ Length Distribution Buckets:")
|
| 98 |
+
buckets = [0, 100, 250, 500, 1000, 2000, 5000]
|
| 99 |
+
|
| 100 |
+
for i in range(len(buckets)-1):
|
| 101 |
+
em_count = sum(1 for l in em_lengths if buckets[i] <= l < buckets[i+1])
|
| 102 |
+
tr_count = sum(1 for l in tr_lengths if buckets[i] <= l < buckets[i+1])
|
| 103 |
+
print(f"β’ {buckets[i]}-{buckets[i+1]} chars: Emergency={em_count}, Treatment={tr_count}")
|
| 104 |
+
|
| 105 |
+
# Flag potential issues
|
| 106 |
+
if ratio > 5.0:
|
| 107 |
+
print(f"\nβ οΈ WARNING: Emergency chunks are {ratio:.1f}x longer than treatment chunks!")
|
| 108 |
+
print(" This suggests different chunking strategies are being used.")
|
| 109 |
+
|
| 110 |
+
print("β
Chunk length analysis completed")
|
| 111 |
+
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(f"β Error in chunk length analysis: {str(e)}")
|
| 114 |
+
raise
|
| 115 |
+
|
| 116 |
+
def test_chunking_method_comparison(self):
|
| 117 |
+
"""Compare the two chunking methods on the same data"""
|
| 118 |
+
print("\n=== Phase 3: Chunking Method Comparison ===")
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
# Load data
|
| 122 |
+
print("β’ Loading dataset for comparison...")
|
| 123 |
+
self.processor.load_filtered_data()
|
| 124 |
+
|
| 125 |
+
# Test on multiple samples for better analysis
|
| 126 |
+
sample_size = 5
|
| 127 |
+
samples = self.processor.treatment_data.head(sample_size)
|
| 128 |
+
|
| 129 |
+
method1_results = [] # keyword_centered_chunks
|
| 130 |
+
method2_results = [] # dual_keyword_chunks
|
| 131 |
+
|
| 132 |
+
print(f"β’ Testing {sample_size} samples with both methods...")
|
| 133 |
+
|
| 134 |
+
for idx, row in samples.iterrows():
|
| 135 |
+
if not row.get('clean_text') or not row.get('treatment_matched'):
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
text_length = len(row['clean_text'])
|
| 139 |
+
emergency_kw = row.get('matched', '')
|
| 140 |
+
treatment_kw = row['treatment_matched']
|
| 141 |
+
|
| 142 |
+
# Method 1: keyword_centered_chunks (Emergency method)
|
| 143 |
+
chunks1 = self.processor.create_keyword_centered_chunks(
|
| 144 |
+
text=row['clean_text'],
|
| 145 |
+
matched_keywords=emergency_kw,
|
| 146 |
+
chunk_size=256,
|
| 147 |
+
doc_id=f"test_{idx}"
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Method 2: dual_keyword_chunks (Treatment method)
|
| 151 |
+
chunks2 = self.processor.create_dual_keyword_chunks(
|
| 152 |
+
text=row['clean_text'],
|
| 153 |
+
emergency_keywords=emergency_kw,
|
| 154 |
+
treatment_keywords=treatment_kw,
|
| 155 |
+
chunk_size=256,
|
| 156 |
+
doc_id=f"test_{idx}"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Collect results
|
| 160 |
+
if chunks1:
|
| 161 |
+
avg_len1 = sum(len(c['text']) for c in chunks1) / len(chunks1)
|
| 162 |
+
method1_results.append({
|
| 163 |
+
'doc_id': idx,
|
| 164 |
+
'chunks_count': len(chunks1),
|
| 165 |
+
'avg_length': avg_len1,
|
| 166 |
+
'text_length': text_length
|
| 167 |
+
})
|
| 168 |
+
|
| 169 |
+
if chunks2:
|
| 170 |
+
avg_len2 = sum(len(c['text']) for c in chunks2) / len(chunks2)
|
| 171 |
+
method2_results.append({
|
| 172 |
+
'doc_id': idx,
|
| 173 |
+
'chunks_count': len(chunks2),
|
| 174 |
+
'avg_length': avg_len2,
|
| 175 |
+
'text_length': text_length
|
| 176 |
+
})
|
| 177 |
+
|
| 178 |
+
# Analysis results
|
| 179 |
+
print(f"\nπ Method Comparison Results:")
|
| 180 |
+
|
| 181 |
+
if method1_results:
|
| 182 |
+
avg_chunks1 = sum(r['chunks_count'] for r in method1_results) / len(method1_results)
|
| 183 |
+
avg_len1 = sum(r['avg_length'] for r in method1_results) / len(method1_results)
|
| 184 |
+
print(f"\nπΉ Keyword-Centered Method (Emergency):")
|
| 185 |
+
print(f"β’ Average chunks per document: {avg_chunks1:.1f}")
|
| 186 |
+
print(f"β’ Average chunk length: {avg_len1:.0f} chars")
|
| 187 |
+
|
| 188 |
+
if method2_results:
|
| 189 |
+
avg_chunks2 = sum(r['chunks_count'] for r in method2_results) / len(method2_results)
|
| 190 |
+
avg_len2 = sum(r['avg_length'] for r in method2_results) / len(method2_results)
|
| 191 |
+
print(f"\nπΉ Dual-Keyword Method (Treatment):")
|
| 192 |
+
print(f"β’ Average chunks per document: {avg_chunks2:.1f}")
|
| 193 |
+
print(f"β’ Average chunk length: {avg_len2:.0f} chars")
|
| 194 |
+
|
| 195 |
+
if method1_results:
|
| 196 |
+
ratio = avg_len1 / avg_len2
|
| 197 |
+
print(f"\nπ Length Ratio: {ratio:.1f}x (Method1 / Method2)")
|
| 198 |
+
|
| 199 |
+
print("β
Chunking method comparison completed")
|
| 200 |
+
|
| 201 |
+
except Exception as e:
|
| 202 |
+
print(f"β Error in method comparison: {str(e)}")
|
| 203 |
+
raise
|
| 204 |
+
|
| 205 |
+
def test_token_vs_character_analysis(self):
|
| 206 |
+
"""Analyze token vs character differences in chunking"""
|
| 207 |
+
print("\n=== Phase 4: Token vs Character Analysis ===")
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
# Load model for tokenization
|
| 211 |
+
print("β’ Loading embedding model for tokenization...")
|
| 212 |
+
self.processor.load_embedding_model()
|
| 213 |
+
|
| 214 |
+
# Test sample texts
|
| 215 |
+
test_texts = [
|
| 216 |
+
"Patient presents with acute chest pain and shortness of breath.",
|
| 217 |
+
"Emergency treatment for myocardial infarction includes immediate medication.",
|
| 218 |
+
"The patient's vital signs show tachycardia and hypotension requiring intervention."
|
| 219 |
+
]
|
| 220 |
+
|
| 221 |
+
print(f"\nπ Token vs Character Analysis:")
|
| 222 |
+
|
| 223 |
+
total_chars = 0
|
| 224 |
+
total_tokens = 0
|
| 225 |
+
|
| 226 |
+
for i, text in enumerate(test_texts, 1):
|
| 227 |
+
char_count = len(text)
|
| 228 |
+
token_count = len(self.processor.tokenizer.tokenize(text))
|
| 229 |
+
ratio = char_count / token_count if token_count > 0 else 0
|
| 230 |
+
|
| 231 |
+
print(f"\nSample {i}:")
|
| 232 |
+
print(f"β’ Text: {text[:50]}...")
|
| 233 |
+
print(f"β’ Characters: {char_count}")
|
| 234 |
+
print(f"β’ Tokens: {token_count}")
|
| 235 |
+
print(f"β’ Chars/Token ratio: {ratio:.2f}")
|
| 236 |
+
|
| 237 |
+
total_chars += char_count
|
| 238 |
+
total_tokens += token_count
|
| 239 |
+
|
| 240 |
+
overall_ratio = total_chars / total_tokens
|
| 241 |
+
print(f"\nπ Overall Character/Token Ratio: {overall_ratio:.2f}")
|
| 242 |
+
|
| 243 |
+
# Estimate chunk sizes
|
| 244 |
+
target_tokens = 256
|
| 245 |
+
estimated_chars = target_tokens * overall_ratio
|
| 246 |
+
|
| 247 |
+
print(f"\nπ Chunk Size Estimates:")
|
| 248 |
+
print(f"β’ Target tokens: {target_tokens}")
|
| 249 |
+
print(f"β’ Estimated characters: {estimated_chars:.0f}")
|
| 250 |
+
print(f"β’ Current emergency avg: 1842 chars ({1842/overall_ratio:.0f} estimated tokens)")
|
| 251 |
+
print(f"β’ Current treatment avg: 250 chars ({250/overall_ratio:.0f} estimated tokens)")
|
| 252 |
+
|
| 253 |
+
# Recommendations
|
| 254 |
+
print(f"\nπ‘ Recommendations:")
|
| 255 |
+
if 1842/overall_ratio > 512:
|
| 256 |
+
print("β οΈ Emergency chunks may exceed model's 512 token limit!")
|
| 257 |
+
if 250/overall_ratio < 64:
|
| 258 |
+
print("β οΈ Treatment chunks may be too short for meaningful context!")
|
| 259 |
+
|
| 260 |
+
print("β
Token vs character analysis completed")
|
| 261 |
+
|
| 262 |
+
except Exception as e:
|
| 263 |
+
print(f"β Error in token analysis: {str(e)}")
|
| 264 |
+
raise
|
| 265 |
+
|
| 266 |
+
def test_generate_recommendations(self):
|
| 267 |
+
"""Generate recommendations based on analysis"""
|
| 268 |
+
print("\n=== Phase 5: Generating Recommendations ===")
|
| 269 |
+
|
| 270 |
+
recommendations = []
|
| 271 |
+
|
| 272 |
+
# Based on the known chunk length difference
|
| 273 |
+
recommendations.append({
|
| 274 |
+
'issue': 'Inconsistent chunk lengths',
|
| 275 |
+
'description': 'Emergency chunks (1842 chars) are 7x longer than treatment chunks (250 chars)',
|
| 276 |
+
'recommendation': 'Standardize both methods to use token-based chunking with consistent parameters',
|
| 277 |
+
'priority': 'HIGH'
|
| 278 |
+
})
|
| 279 |
+
|
| 280 |
+
recommendations.append({
|
| 281 |
+
'issue': 'Different chunking strategies',
|
| 282 |
+
'description': 'Emergency uses keyword-centered (token-based), Treatment uses dual-keyword (character-based)',
|
| 283 |
+
'recommendation': 'Update dual_keyword_chunks to use tokenizer for consistent token-based chunking',
|
| 284 |
+
'priority': 'HIGH'
|
| 285 |
+
})
|
| 286 |
+
|
| 287 |
+
recommendations.append({
|
| 288 |
+
'issue': 'Potential token limit overflow',
|
| 289 |
+
'description': 'Large chunks may exceed PubMedBERT 512 token limit',
|
| 290 |
+
'recommendation': 'Implement strict token-based chunking with overlap to prevent overflow',
|
| 291 |
+
'priority': 'MEDIUM'
|
| 292 |
+
})
|
| 293 |
+
|
| 294 |
+
print(f"\nπ Analysis Recommendations:")
|
| 295 |
+
for i, rec in enumerate(recommendations, 1):
|
| 296 |
+
print(f"\n{i}. {rec['issue']} [{rec['priority']}]")
|
| 297 |
+
print(f" Problem: {rec['description']}")
|
| 298 |
+
print(f" Solution: {rec['recommendation']}")
|
| 299 |
+
|
| 300 |
+
print("\nβ
Recommendations generated")
|
| 301 |
+
return recommendations
|
| 302 |
+
|
| 303 |
+
def main():
|
| 304 |
+
"""Run all chunk quality analysis tests"""
|
| 305 |
+
print("\n" + "="*60)
|
| 306 |
+
print("CHUNK QUALITY ANALYSIS TEST SUITE")
|
| 307 |
+
print("="*60)
|
| 308 |
+
|
| 309 |
+
test = TestChunkQualityAnalysis()
|
| 310 |
+
test.setup_class()
|
| 311 |
+
|
| 312 |
+
try:
|
| 313 |
+
test.test_chunk_length_analysis()
|
| 314 |
+
test.test_chunking_method_comparison()
|
| 315 |
+
test.test_token_vs_character_analysis()
|
| 316 |
+
recommendations = test.test_generate_recommendations()
|
| 317 |
+
|
| 318 |
+
print("\n" + "="*60)
|
| 319 |
+
print("π ALL CHUNK QUALITY TESTS COMPLETED SUCCESSFULLY!")
|
| 320 |
+
print("="*60)
|
| 321 |
+
print(f"\nKey Finding: Chunk length inconsistency detected!")
|
| 322 |
+
print(f"Emergency: ~1842 chars, Treatment: ~250 chars (7x difference)")
|
| 323 |
+
print(f"Recommendation: Standardize to token-based chunking")
|
| 324 |
+
print("="*60)
|
| 325 |
+
|
| 326 |
+
except Exception as e:
|
| 327 |
+
print("\n" + "="*60)
|
| 328 |
+
print("β CHUNK QUALITY TESTS FAILED!")
|
| 329 |
+
print(f"Error: {str(e)}")
|
| 330 |
+
print("="*60)
|
| 331 |
+
|
| 332 |
+
if __name__ == "__main__":
|
| 333 |
+
main()
|
tests/test_embedding_and_index.py
CHANGED
|
@@ -1,29 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from annoy import AnnoyIndex
|
| 3 |
import pytest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from data_processing import DataProcessor
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic embedding and index validation tests
|
| 3 |
+
"""
|
| 4 |
+
# 2025-07-28
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
#
|
| 8 |
+
|
| 9 |
import numpy as np
|
| 10 |
from annoy import AnnoyIndex
|
| 11 |
import pytest
|
| 12 |
+
|
| 13 |
+
print("\n=== Phase 1: Initializing Test Environment ===")
|
| 14 |
+
# add src to python path
|
| 15 |
+
current_dir = Path(__file__).parent.resolve()
|
| 16 |
+
project_root = current_dir.parent
|
| 17 |
+
sys.path.append(str(project_root / "src"))
|
| 18 |
+
|
| 19 |
+
print(f"β’ Current directory: {current_dir}")
|
| 20 |
+
print(f"β’ Project root: {project_root}")
|
| 21 |
+
print(f"β’ Python path: {sys.path}")
|
| 22 |
+
|
| 23 |
from data_processing import DataProcessor
|
| 24 |
|
| 25 |
+
|
| 26 |
+
class TestEmbeddingAndIndex:
|
| 27 |
+
def setup_class(self):
|
| 28 |
+
"""εε§ε測試ι‘"""
|
| 29 |
+
print("\n=== Phase 2: Setting up TestEmbeddingAndIndex ===")
|
| 30 |
+
self.base_dir = Path(__file__).parent.parent.resolve()
|
| 31 |
+
print(f"β’ Base directory: {self.base_dir}")
|
| 32 |
+
self.processor = DataProcessor(base_dir=str(self.base_dir))
|
| 33 |
+
print("β’ DataProcessor initialized")
|
| 34 |
+
|
| 35 |
+
def test_embedding_dimensions(self):
|
| 36 |
+
print("\n=== Phase 3: Testing Embedding Dimensions ===")
|
| 37 |
+
print("β’ Loading emergency embeddings...")
|
| 38 |
+
# load emergency embeddings
|
| 39 |
+
emb = np.load(self.processor.models_dir / "embeddings" / "emergency_embeddings.npy")
|
| 40 |
+
expected_dim = self.processor.embedding_dim
|
| 41 |
+
|
| 42 |
+
print(f"β’ Loaded embedding shape: {emb.shape}")
|
| 43 |
+
print(f"β’ Expected dimension: {expected_dim}")
|
| 44 |
+
|
| 45 |
+
assert emb.ndim == 2, f"Expected 2D array, got {emb.ndim}D"
|
| 46 |
+
assert emb.shape[1] == expected_dim, (
|
| 47 |
+
f"Expected embedding dimension {expected_dim}, got {emb.shape[1]}"
|
| 48 |
+
)
|
| 49 |
+
print("β
Embedding dimensions test passed")
|
| 50 |
+
|
| 51 |
+
def test_annoy_search(self):
|
| 52 |
+
print("\n=== Phase 4: Testing Annoy Search ===")
|
| 53 |
+
print("β’ Loading embeddings...")
|
| 54 |
+
# load embeddings
|
| 55 |
+
emb = np.load(self.processor.models_dir / "embeddings" / "emergency_embeddings.npy")
|
| 56 |
+
print(f"β’ Loaded embeddings shape: {emb.shape}")
|
| 57 |
+
|
| 58 |
+
print("β’ Loading Annoy index...")
|
| 59 |
+
# load Annoy index
|
| 60 |
+
idx = AnnoyIndex(self.processor.embedding_dim, 'angular')
|
| 61 |
+
index_path = self.processor.models_dir / "indices" / "annoy" / "emergency_index.ann"
|
| 62 |
+
print(f"β’ Index path: {index_path}")
|
| 63 |
+
idx.load(str(index_path))
|
| 64 |
+
|
| 65 |
+
print("β’ Performing sample query...")
|
| 66 |
+
# perform a sample query
|
| 67 |
+
query_vec = emb[0]
|
| 68 |
+
ids, distances = idx.get_nns_by_vector(query_vec, 5, include_distances=True)
|
| 69 |
+
|
| 70 |
+
print(f"β’ Search results:")
|
| 71 |
+
print(f" - Found IDs: {ids}")
|
| 72 |
+
print(f" - Distances: {[f'{d:.4f}' for d in distances]}")
|
| 73 |
+
|
| 74 |
+
assert len(ids) == 5, f"Expected 5 results, got {len(ids)}"
|
| 75 |
+
assert all(0 <= d <= 2 for d in distances), "Invalid distance values"
|
| 76 |
+
print("β
Annoy search test passed")
|
| 77 |
+
|
| 78 |
+
def main():
|
| 79 |
+
"""Run tests manually"""
|
| 80 |
+
print("\n" + "="*50)
|
| 81 |
+
print("Starting Embedding and Index Tests")
|
| 82 |
+
print("="*50)
|
| 83 |
+
|
| 84 |
+
test = TestEmbeddingAndIndex()
|
| 85 |
+
test.setup_class() # ζεεε§ε
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
test.test_embedding_dimensions()
|
| 89 |
+
test.test_annoy_search()
|
| 90 |
+
print("\n" + "="*50)
|
| 91 |
+
print("π All tests completed successfully!")
|
| 92 |
+
print("="*50)
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
print("\n" + "="*50)
|
| 96 |
+
print("β Tests failed!")
|
| 97 |
+
print(f"Error: {str(e)}")
|
| 98 |
+
print("="*50)
|
| 99 |
+
|
| 100 |
+
if __name__ == "__main__":
|
| 101 |
+
main()
|
tests/test_embedding_validation.py
CHANGED
|
@@ -7,14 +7,27 @@ import numpy as np
|
|
| 7 |
import json
|
| 8 |
import logging
|
| 9 |
import os
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Tuple, List, Optional
|
| 12 |
from annoy import AnnoyIndex
|
| 13 |
from sentence_transformers import SentenceTransformer
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
class TestEmbeddingValidation:
|
| 16 |
def setup_class(self):
|
| 17 |
"""Initialize test environment with necessary data and models."""
|
|
|
|
|
|
|
| 18 |
# Setup logging
|
| 19 |
logging.basicConfig(
|
| 20 |
level=logging.DEBUG,
|
|
@@ -24,43 +37,57 @@ class TestEmbeddingValidation:
|
|
| 24 |
self.logger = logging.getLogger(__name__)
|
| 25 |
|
| 26 |
# Define base paths
|
| 27 |
-
self.project_root = Path(
|
| 28 |
self.models_dir = self.project_root / "models"
|
| 29 |
self.embeddings_dir = self.models_dir / "embeddings"
|
| 30 |
self.indices_dir = self.models_dir / "indices" / "annoy"
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
self.logger.info(f"Project root: {self.project_root}")
|
| 33 |
self.logger.info(f"Models directory: {self.models_dir}")
|
| 34 |
self.logger.info(f"Embeddings directory: {self.embeddings_dir}")
|
| 35 |
|
| 36 |
try:
|
| 37 |
# Check directory existence
|
|
|
|
| 38 |
if not self.embeddings_dir.exists():
|
| 39 |
raise FileNotFoundError(f"Embeddings directory not found at: {self.embeddings_dir}")
|
| 40 |
if not self.indices_dir.exists():
|
| 41 |
raise FileNotFoundError(f"Indices directory not found at: {self.indices_dir}")
|
| 42 |
|
| 43 |
# Load embeddings
|
|
|
|
| 44 |
self.emergency_emb = np.load(self.embeddings_dir / "emergency_embeddings.npy")
|
| 45 |
self.treatment_emb = np.load(self.embeddings_dir / "treatment_embeddings.npy")
|
| 46 |
|
| 47 |
# Load chunks
|
|
|
|
| 48 |
with open(self.embeddings_dir / "emergency_chunks.json", 'r') as f:
|
| 49 |
self.emergency_chunks = json.load(f)
|
| 50 |
with open(self.embeddings_dir / "treatment_chunks.json", 'r') as f:
|
| 51 |
self.treatment_chunks = json.load(f)
|
| 52 |
|
| 53 |
# Initialize model
|
|
|
|
| 54 |
self.model = SentenceTransformer("NeuML/pubmedbert-base-embeddings")
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
self.logger.info("Test environment initialized successfully")
|
| 57 |
self.logger.info(f"Emergency embeddings shape: {self.emergency_emb.shape}")
|
| 58 |
self.logger.info(f"Treatment embeddings shape: {self.treatment_emb.shape}")
|
| 59 |
|
| 60 |
except FileNotFoundError as e:
|
|
|
|
| 61 |
self.logger.error(f"File not found: {e}")
|
| 62 |
raise
|
| 63 |
except Exception as e:
|
|
|
|
| 64 |
self.logger.error(f"Error during initialization: {e}")
|
| 65 |
raise
|
| 66 |
|
|
@@ -84,20 +111,28 @@ class TestEmbeddingValidation:
|
|
| 84 |
|
| 85 |
def test_embedding_dimensions(self):
|
| 86 |
"""Test embedding dimensions and data quality."""
|
|
|
|
| 87 |
self.logger.info("\n=== Embedding Validation Report ===")
|
| 88 |
|
| 89 |
try:
|
| 90 |
# Basic dimension checks
|
|
|
|
| 91 |
assert self.emergency_emb.shape[1] == 768, "Emergency embedding dimension should be 768"
|
| 92 |
assert self.treatment_emb.shape[1] == 768, "Treatment embedding dimension should be 768"
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# Count verification
|
|
|
|
| 95 |
assert len(self.emergency_chunks) == self.emergency_emb.shape[0], \
|
| 96 |
"Emergency chunks count mismatch"
|
| 97 |
assert len(self.treatment_chunks) == self.treatment_emb.shape[0], \
|
| 98 |
"Treatment chunks count mismatch"
|
|
|
|
|
|
|
| 99 |
|
| 100 |
# Data quality checks
|
|
|
|
| 101 |
for name, emb in [("Emergency", self.emergency_emb),
|
| 102 |
("Treatment", self.treatment_emb)]:
|
| 103 |
# Check for NaN and Inf
|
|
@@ -105,25 +140,35 @@ class TestEmbeddingValidation:
|
|
| 105 |
assert not np.isinf(emb).any(), f"{name} contains Inf values"
|
| 106 |
|
| 107 |
# Value distribution analysis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
self.logger.info(f"\n{name} Embeddings Statistics:")
|
| 109 |
self.logger.info(f"- Range: {np.min(emb):.3f} to {np.max(emb):.3f}")
|
| 110 |
self.logger.info(f"- Mean: {np.mean(emb):.3f}")
|
| 111 |
self.logger.info(f"- Std: {np.std(emb):.3f}")
|
| 112 |
|
|
|
|
| 113 |
self.logger.info("\nβ
All embedding validations passed")
|
| 114 |
|
| 115 |
except AssertionError as e:
|
|
|
|
| 116 |
self.logger.error(f"Validation failed: {str(e)}")
|
| 117 |
raise
|
| 118 |
|
| 119 |
def test_multiple_known_item_search(self):
|
| 120 |
"""Test ANNOY search with multiple random samples."""
|
|
|
|
| 121 |
self.logger.info("\n=== Multiple Known-Item Search Test ===")
|
| 122 |
|
|
|
|
| 123 |
emergency_index = AnnoyIndex(768, 'angular')
|
| 124 |
emergency_index.load(str(self.indices_dir / "emergency_index.ann"))
|
| 125 |
|
| 126 |
# Test 20 random samples
|
|
|
|
| 127 |
test_indices = np.random.choice(
|
| 128 |
self.emergency_emb.shape[0],
|
| 129 |
size=20,
|
|
@@ -131,36 +176,45 @@ class TestEmbeddingValidation:
|
|
| 131 |
)
|
| 132 |
|
| 133 |
success_count = 0
|
| 134 |
-
for
|
|
|
|
| 135 |
try:
|
| 136 |
test_emb = self.emergency_emb[test_idx]
|
| 137 |
indices, distances = self._safe_search(emergency_index, test_emb)
|
| 138 |
|
| 139 |
if indices is None:
|
|
|
|
| 140 |
continue
|
| 141 |
|
| 142 |
# Verify self-retrieval
|
| 143 |
assert indices[0] == test_idx, f"Self-retrieval failed for index {test_idx}"
|
| 144 |
assert distances[0] < 0.0001, f"Self-distance too large for index {test_idx}"
|
| 145 |
success_count += 1
|
|
|
|
| 146 |
|
| 147 |
except AssertionError as e:
|
|
|
|
| 148 |
self.logger.warning(f"Test failed for index {test_idx}: {str(e)}")
|
| 149 |
|
|
|
|
| 150 |
self.logger.info(f"\nβ
{success_count}/20 self-retrieval tests passed")
|
| 151 |
assert success_count >= 18, "Less than 90% of self-retrieval tests passed"
|
|
|
|
| 152 |
|
| 153 |
def test_balanced_cross_dataset_search(self):
|
| 154 |
"""Test search across both emergency and treatment datasets."""
|
|
|
|
| 155 |
self.logger.info("\n=== Balanced Cross-Dataset Search Test ===")
|
| 156 |
|
| 157 |
# Initialize indices
|
|
|
|
| 158 |
emergency_index = AnnoyIndex(768, 'angular')
|
| 159 |
treatment_index = AnnoyIndex(768, 'angular')
|
| 160 |
|
| 161 |
try:
|
| 162 |
emergency_index.load(str(self.indices_dir / "emergency_index.ann"))
|
| 163 |
treatment_index.load(str(self.indices_dir / "treatment_index.ann"))
|
|
|
|
| 164 |
|
| 165 |
# Test queries
|
| 166 |
test_queries = [
|
|
@@ -169,45 +223,75 @@ class TestEmbeddingValidation:
|
|
| 169 |
"What are the emergency procedures for anaphylactic shock?"
|
| 170 |
]
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
| 174 |
|
| 175 |
# Generate query vector
|
|
|
|
| 176 |
query_emb = self.model.encode([query])[0]
|
| 177 |
|
| 178 |
# Get top-5 results from each dataset
|
|
|
|
| 179 |
e_indices, e_distances = self._safe_search(emergency_index, query_emb, k=5)
|
| 180 |
t_indices, t_distances = self._safe_search(treatment_index, query_emb, k=5)
|
| 181 |
|
| 182 |
if None in [e_indices, e_distances, t_indices, t_distances]:
|
|
|
|
| 183 |
self.logger.error("Search failed for one or both datasets")
|
| 184 |
continue
|
| 185 |
|
| 186 |
# Print first sentence of each result
|
| 187 |
-
print("\
|
| 188 |
for i, (idx, dist) in enumerate(zip(e_indices, e_distances), 1):
|
| 189 |
text = self.emergency_chunks[idx]['text']
|
| 190 |
first_sentence = text.split('.')[0] + '.'
|
| 191 |
-
print(f"
|
| 192 |
-
print(first_sentence)
|
| 193 |
|
| 194 |
-
print("\
|
| 195 |
for i, (idx, dist) in enumerate(zip(t_indices, t_distances), 1):
|
| 196 |
text = self.treatment_chunks[idx]['text']
|
| 197 |
first_sentence = text.split('.')[0] + '.'
|
| 198 |
-
print(f"
|
| 199 |
-
|
|
|
|
| 200 |
|
| 201 |
except Exception as e:
|
|
|
|
| 202 |
self.logger.error(f"Test failed: {str(e)}")
|
| 203 |
raise
|
| 204 |
else:
|
|
|
|
| 205 |
self.logger.info("\nβ
Cross-dataset search test completed")
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
test = TestEmbeddingValidation()
|
| 210 |
test.setup_class()
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import json
|
| 8 |
import logging
|
| 9 |
import os
|
| 10 |
+
import sys
|
| 11 |
from pathlib import Path
|
| 12 |
from typing import Tuple, List, Optional
|
| 13 |
from annoy import AnnoyIndex
|
| 14 |
from sentence_transformers import SentenceTransformer
|
| 15 |
|
| 16 |
+
print("\n=== Phase 1: Initializing Test Environment ===")
|
| 17 |
+
# Add src to python path
|
| 18 |
+
current_dir = Path(__file__).parent.resolve()
|
| 19 |
+
project_root = current_dir.parent
|
| 20 |
+
sys.path.append(str(project_root / "src"))
|
| 21 |
+
|
| 22 |
+
print(f"β’ Current directory: {current_dir}")
|
| 23 |
+
print(f"β’ Project root: {project_root}")
|
| 24 |
+
print(f"β’ Python path added: {project_root / 'src'}")
|
| 25 |
+
|
| 26 |
class TestEmbeddingValidation:
|
| 27 |
def setup_class(self):
|
| 28 |
"""Initialize test environment with necessary data and models."""
|
| 29 |
+
print("\n=== Phase 2: Setting up Test Environment ===")
|
| 30 |
+
|
| 31 |
# Setup logging
|
| 32 |
logging.basicConfig(
|
| 33 |
level=logging.DEBUG,
|
|
|
|
| 37 |
self.logger = logging.getLogger(__name__)
|
| 38 |
|
| 39 |
# Define base paths
|
| 40 |
+
self.project_root = Path(__file__).parent.parent.resolve()
|
| 41 |
self.models_dir = self.project_root / "models"
|
| 42 |
self.embeddings_dir = self.models_dir / "embeddings"
|
| 43 |
self.indices_dir = self.models_dir / "indices" / "annoy"
|
| 44 |
|
| 45 |
+
print(f"β’ Project root: {self.project_root}")
|
| 46 |
+
print(f"β’ Models directory: {self.models_dir}")
|
| 47 |
+
print(f"β’ Embeddings directory: {self.embeddings_dir}")
|
| 48 |
+
|
| 49 |
self.logger.info(f"Project root: {self.project_root}")
|
| 50 |
self.logger.info(f"Models directory: {self.models_dir}")
|
| 51 |
self.logger.info(f"Embeddings directory: {self.embeddings_dir}")
|
| 52 |
|
| 53 |
try:
|
| 54 |
# Check directory existence
|
| 55 |
+
print("β’ Checking directory existence...")
|
| 56 |
if not self.embeddings_dir.exists():
|
| 57 |
raise FileNotFoundError(f"Embeddings directory not found at: {self.embeddings_dir}")
|
| 58 |
if not self.indices_dir.exists():
|
| 59 |
raise FileNotFoundError(f"Indices directory not found at: {self.indices_dir}")
|
| 60 |
|
| 61 |
# Load embeddings
|
| 62 |
+
print("β’ Loading embeddings...")
|
| 63 |
self.emergency_emb = np.load(self.embeddings_dir / "emergency_embeddings.npy")
|
| 64 |
self.treatment_emb = np.load(self.embeddings_dir / "treatment_embeddings.npy")
|
| 65 |
|
| 66 |
# Load chunks
|
| 67 |
+
print("β’ Loading chunk metadata...")
|
| 68 |
with open(self.embeddings_dir / "emergency_chunks.json", 'r') as f:
|
| 69 |
self.emergency_chunks = json.load(f)
|
| 70 |
with open(self.embeddings_dir / "treatment_chunks.json", 'r') as f:
|
| 71 |
self.treatment_chunks = json.load(f)
|
| 72 |
|
| 73 |
# Initialize model
|
| 74 |
+
print("β’ Loading PubMedBERT model...")
|
| 75 |
self.model = SentenceTransformer("NeuML/pubmedbert-base-embeddings")
|
| 76 |
|
| 77 |
+
print(f"β’ Emergency embeddings shape: {self.emergency_emb.shape}")
|
| 78 |
+
print(f"β’ Treatment embeddings shape: {self.treatment_emb.shape}")
|
| 79 |
+
print("β
Test environment initialized successfully")
|
| 80 |
+
|
| 81 |
self.logger.info("Test environment initialized successfully")
|
| 82 |
self.logger.info(f"Emergency embeddings shape: {self.emergency_emb.shape}")
|
| 83 |
self.logger.info(f"Treatment embeddings shape: {self.treatment_emb.shape}")
|
| 84 |
|
| 85 |
except FileNotFoundError as e:
|
| 86 |
+
print(f"β File not found: {e}")
|
| 87 |
self.logger.error(f"File not found: {e}")
|
| 88 |
raise
|
| 89 |
except Exception as e:
|
| 90 |
+
print(f"β Error during initialization: {e}")
|
| 91 |
self.logger.error(f"Error during initialization: {e}")
|
| 92 |
raise
|
| 93 |
|
|
|
|
| 111 |
|
| 112 |
def test_embedding_dimensions(self):
|
| 113 |
"""Test embedding dimensions and data quality."""
|
| 114 |
+
print("\n=== Phase 3: Embedding Validation ===")
|
| 115 |
self.logger.info("\n=== Embedding Validation Report ===")
|
| 116 |
|
| 117 |
try:
|
| 118 |
# Basic dimension checks
|
| 119 |
+
print("β’ Checking embedding dimensions...")
|
| 120 |
assert self.emergency_emb.shape[1] == 768, "Emergency embedding dimension should be 768"
|
| 121 |
assert self.treatment_emb.shape[1] == 768, "Treatment embedding dimension should be 768"
|
| 122 |
+
print(f"β Emergency dimensions: {self.emergency_emb.shape}")
|
| 123 |
+
print(f"β Treatment dimensions: {self.treatment_emb.shape}")
|
| 124 |
|
| 125 |
# Count verification
|
| 126 |
+
print("β’ Verifying chunk count consistency...")
|
| 127 |
assert len(self.emergency_chunks) == self.emergency_emb.shape[0], \
|
| 128 |
"Emergency chunks count mismatch"
|
| 129 |
assert len(self.treatment_chunks) == self.treatment_emb.shape[0], \
|
| 130 |
"Treatment chunks count mismatch"
|
| 131 |
+
print(f"β Emergency: {len(self.emergency_chunks)} chunks = {self.emergency_emb.shape[0]} embeddings")
|
| 132 |
+
print(f"β Treatment: {len(self.treatment_chunks)} chunks = {self.treatment_emb.shape[0]} embeddings")
|
| 133 |
|
| 134 |
# Data quality checks
|
| 135 |
+
print("β’ Performing data quality checks...")
|
| 136 |
for name, emb in [("Emergency", self.emergency_emb),
|
| 137 |
("Treatment", self.treatment_emb)]:
|
| 138 |
# Check for NaN and Inf
|
|
|
|
| 140 |
assert not np.isinf(emb).any(), f"{name} contains Inf values"
|
| 141 |
|
| 142 |
# Value distribution analysis
|
| 143 |
+
print(f"\nπ {name} Embeddings Statistics:")
|
| 144 |
+
print(f"β’ Range: {np.min(emb):.3f} to {np.max(emb):.3f}")
|
| 145 |
+
print(f"β’ Mean: {np.mean(emb):.3f}")
|
| 146 |
+
print(f"β’ Std: {np.std(emb):.3f}")
|
| 147 |
+
|
| 148 |
self.logger.info(f"\n{name} Embeddings Statistics:")
|
| 149 |
self.logger.info(f"- Range: {np.min(emb):.3f} to {np.max(emb):.3f}")
|
| 150 |
self.logger.info(f"- Mean: {np.mean(emb):.3f}")
|
| 151 |
self.logger.info(f"- Std: {np.std(emb):.3f}")
|
| 152 |
|
| 153 |
+
print("\nβ
All embedding validations passed")
|
| 154 |
self.logger.info("\nβ
All embedding validations passed")
|
| 155 |
|
| 156 |
except AssertionError as e:
|
| 157 |
+
print(f"β Validation failed: {str(e)}")
|
| 158 |
self.logger.error(f"Validation failed: {str(e)}")
|
| 159 |
raise
|
| 160 |
|
| 161 |
def test_multiple_known_item_search(self):
|
| 162 |
"""Test ANNOY search with multiple random samples."""
|
| 163 |
+
print("\n=== Phase 4: Multiple Known-Item Search Test ===")
|
| 164 |
self.logger.info("\n=== Multiple Known-Item Search Test ===")
|
| 165 |
|
| 166 |
+
print("β’ Loading emergency index...")
|
| 167 |
emergency_index = AnnoyIndex(768, 'angular')
|
| 168 |
emergency_index.load(str(self.indices_dir / "emergency_index.ann"))
|
| 169 |
|
| 170 |
# Test 20 random samples
|
| 171 |
+
print("β’ Selecting 20 random samples for self-retrieval test...")
|
| 172 |
test_indices = np.random.choice(
|
| 173 |
self.emergency_emb.shape[0],
|
| 174 |
size=20,
|
|
|
|
| 176 |
)
|
| 177 |
|
| 178 |
success_count = 0
|
| 179 |
+
print("β’ Testing self-retrieval for each sample...")
|
| 180 |
+
for i, test_idx in enumerate(test_indices, 1):
|
| 181 |
try:
|
| 182 |
test_emb = self.emergency_emb[test_idx]
|
| 183 |
indices, distances = self._safe_search(emergency_index, test_emb)
|
| 184 |
|
| 185 |
if indices is None:
|
| 186 |
+
print(f" {i}/20: β Search failed for index {test_idx}")
|
| 187 |
continue
|
| 188 |
|
| 189 |
# Verify self-retrieval
|
| 190 |
assert indices[0] == test_idx, f"Self-retrieval failed for index {test_idx}"
|
| 191 |
assert distances[0] < 0.0001, f"Self-distance too large for index {test_idx}"
|
| 192 |
success_count += 1
|
| 193 |
+
print(f" {i}/20: β Index {test_idx} (distance: {distances[0]:.6f})")
|
| 194 |
|
| 195 |
except AssertionError as e:
|
| 196 |
+
print(f" {i}/20: β Index {test_idx} failed: {str(e)}")
|
| 197 |
self.logger.warning(f"Test failed for index {test_idx}: {str(e)}")
|
| 198 |
|
| 199 |
+
print(f"\nπ Self-Retrieval Results: {success_count}/20 tests passed ({success_count/20*100:.1f}%)")
|
| 200 |
self.logger.info(f"\nβ
{success_count}/20 self-retrieval tests passed")
|
| 201 |
assert success_count >= 18, "Less than 90% of self-retrieval tests passed"
|
| 202 |
+
print("β
Multiple known-item search test passed")
|
| 203 |
|
| 204 |
def test_balanced_cross_dataset_search(self):
|
| 205 |
"""Test search across both emergency and treatment datasets."""
|
| 206 |
+
print("\n=== Phase 5: Cross-Dataset Search Test ===")
|
| 207 |
self.logger.info("\n=== Balanced Cross-Dataset Search Test ===")
|
| 208 |
|
| 209 |
# Initialize indices
|
| 210 |
+
print("β’ Loading ANNOY indices...")
|
| 211 |
emergency_index = AnnoyIndex(768, 'angular')
|
| 212 |
treatment_index = AnnoyIndex(768, 'angular')
|
| 213 |
|
| 214 |
try:
|
| 215 |
emergency_index.load(str(self.indices_dir / "emergency_index.ann"))
|
| 216 |
treatment_index.load(str(self.indices_dir / "treatment_index.ann"))
|
| 217 |
+
print("β Emergency and treatment indices loaded")
|
| 218 |
|
| 219 |
# Test queries
|
| 220 |
test_queries = [
|
|
|
|
| 223 |
"What are the emergency procedures for anaphylactic shock?"
|
| 224 |
]
|
| 225 |
|
| 226 |
+
print(f"β’ Testing {len(test_queries)} medical queries...")
|
| 227 |
+
|
| 228 |
+
for query_num, query in enumerate(test_queries, 1):
|
| 229 |
+
print(f"\nπ Query {query_num}/3: {query}")
|
| 230 |
|
| 231 |
# Generate query vector
|
| 232 |
+
print("β’ Generating query embedding...")
|
| 233 |
query_emb = self.model.encode([query])[0]
|
| 234 |
|
| 235 |
# Get top-5 results from each dataset
|
| 236 |
+
print("β’ Searching both datasets...")
|
| 237 |
e_indices, e_distances = self._safe_search(emergency_index, query_emb, k=5)
|
| 238 |
t_indices, t_distances = self._safe_search(treatment_index, query_emb, k=5)
|
| 239 |
|
| 240 |
if None in [e_indices, e_distances, t_indices, t_distances]:
|
| 241 |
+
print("β Search failed for one or both datasets")
|
| 242 |
self.logger.error("Search failed for one or both datasets")
|
| 243 |
continue
|
| 244 |
|
| 245 |
# Print first sentence of each result
|
| 246 |
+
print(f"\nπ Emergency Dataset Results:")
|
| 247 |
for i, (idx, dist) in enumerate(zip(e_indices, e_distances), 1):
|
| 248 |
text = self.emergency_chunks[idx]['text']
|
| 249 |
first_sentence = text.split('.')[0] + '.'
|
| 250 |
+
print(f" E-{i} (distance: {dist:.3f}): {first_sentence[:80]}...")
|
|
|
|
| 251 |
|
| 252 |
+
print(f"\nπ Treatment Dataset Results:")
|
| 253 |
for i, (idx, dist) in enumerate(zip(t_indices, t_distances), 1):
|
| 254 |
text = self.treatment_chunks[idx]['text']
|
| 255 |
first_sentence = text.split('.')[0] + '.'
|
| 256 |
+
print(f" T-{i} (distance: {dist:.3f}): {first_sentence[:80]}...")
|
| 257 |
+
|
| 258 |
+
print("β Query completed")
|
| 259 |
|
| 260 |
except Exception as e:
|
| 261 |
+
print(f"β Test failed: {str(e)}")
|
| 262 |
self.logger.error(f"Test failed: {str(e)}")
|
| 263 |
raise
|
| 264 |
else:
|
| 265 |
+
print("\nβ
Cross-dataset search test completed")
|
| 266 |
self.logger.info("\nβ
Cross-dataset search test completed")
|
| 267 |
|
| 268 |
+
def main():
|
| 269 |
+
"""Run all embedding validation tests"""
|
| 270 |
+
print("\n" + "="*60)
|
| 271 |
+
print("COMPREHENSIVE EMBEDDING VALIDATION TEST SUITE")
|
| 272 |
+
print("="*60)
|
| 273 |
+
|
| 274 |
test = TestEmbeddingValidation()
|
| 275 |
test.setup_class()
|
| 276 |
+
|
| 277 |
+
try:
|
| 278 |
+
test.test_embedding_dimensions()
|
| 279 |
+
test.test_multiple_known_item_search()
|
| 280 |
+
test.test_balanced_cross_dataset_search()
|
| 281 |
+
|
| 282 |
+
print("\n" + "="*60)
|
| 283 |
+
print("π ALL EMBEDDING VALIDATION TESTS COMPLETED SUCCESSFULLY!")
|
| 284 |
+
print("="*60)
|
| 285 |
+
print("β
Embedding dimensions validated")
|
| 286 |
+
print("β
Self-retrieval accuracy confirmed")
|
| 287 |
+
print("β
Cross-dataset search functionality verified")
|
| 288 |
+
print("="*60)
|
| 289 |
+
|
| 290 |
+
except Exception as e:
|
| 291 |
+
print("\n" + "="*60)
|
| 292 |
+
print("β EMBEDDING VALIDATION TESTS FAILED!")
|
| 293 |
+
print(f"Error: {str(e)}")
|
| 294 |
+
print("="*60)
|
| 295 |
+
|
| 296 |
+
if __name__ == "__main__":
|
| 297 |
+
main()
|