File size: 3,689 Bytes
e95bc70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import torch
import logging
from pathlib import Path
from typing import List, Dict, Tuple
from datasets import load_dataset
from greedy_search import find_best_combination
from cases_collect import valid_results_collect

def setup_logger() -> logging.Logger:
    """Set up logging and hand back this module's logger.

    Calls ``logging.basicConfig`` with level ``INFO`` and a
    ``time - level - message`` record layout, then returns the logger
    registered under the current module name (``__name__``).
    """
    record_layout = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=record_layout)
    module_logger = logging.getLogger(__name__)
    return module_logger

def get_model_paths(model_names: List[str], base_path: str = './') -> List[str]:
    """Build one path per model name.

    Each name receives the suffix ``_model`` and is joined onto
    ``base_path`` with ``os.path.join``. The returned list keeps the order
    of ``model_names``; an empty input yields an empty list.
    """
    paths: List[str] = []
    for model_name in model_names:
        directory_name = f"{model_name}_model"
        paths.append(os.path.join(base_path, directory_name))
    return paths

def load_test_data(dataset_name: str = 'hippocrates/MedNLI_test') -> List[Dict]:
    """Load a dataset's ``test`` split and reshape its rows.

    The dataset is fetched with ``load_dataset(dataset_name)``. Every item
    of its ``'test'`` split becomes a dict with two keys: ``'Input'``
    (taken from the item's ``'query'`` field) and ``'Output'`` (taken from
    its ``'answer'`` field).
    """
    test_split = load_dataset(dataset_name)['test']
    examples: List[Dict] = []
    for row in test_split:
        examples.append({'Input': row['query'], 'Output': row['answer']})
    return examples

def calculate_accuracy(correct: List, failed: List) -> float:
    """Return the fraction of cases that were correct.

    The denominator is ``len(correct) + len(failed)``. When both lists are
    empty there is nothing to divide by, so ``0.0`` is returned instead.
    """
    num_correct = len(correct)
    total = num_correct + len(failed)
    if total > 0:
        return num_correct / total
    return 0.0

def main():
    """Run the search-then-evaluate pipeline from start to finish.

    Steps, in order: build model paths from the hard-coded settings, load
    the test examples and the validation data, call
    ``find_best_combination``, evaluate the returned path with
    ``valid_results_collect``, and store a summary dict via ``torch.save``
    in ``results_<search_name>.pt``.

    Any exception raised along the way is logged with its traceback and
    then re-raised to the caller.
    """
    logger = setup_logger()

    try:
        # Hard-coded run settings for this script.
        config = {
            'search_name': 'randoms_model',
            'model_names': ['randoms_data_3k_model'],
            'base_path': './',
            'valid_data_path': 'nli_demo.pt',
            'seed': True,
            'iteration': 5
        }

        # Candidate model locations derived from the configured names.
        model_paths = get_model_paths(
            config['model_names'],
            config['base_path'],
        )
        logger.info(f"Generated model paths: {model_paths}")

        # Test examples come from load_test_data's default dataset.
        logger.info("Loading test data...")
        test_examples = load_test_data()
        logger.info(f"Loaded {len(test_examples)} test examples")

        logger.info("Loading validation data...")
        try:
            validation_set = torch.load(config['valid_data_path'])
            logger.info(f"Loaded validation data from {config['valid_data_path']}")
        except Exception as e:
            # Log the load-specific message here; the outer handler still
            # sees the exception because it is re-raised.
            logger.error(f"Failed to load validation data: {str(e)}")
            raise

        # The same validation object fills both positional data arguments;
        # what each slot means is defined by find_best_combination.
        logger.info("Finding best model combination...")
        search_outcome = find_best_combination(
            model_paths,
            validation_set,
            validation_set,
            config['search_name'],
            iteration=config['iteration'],
            seed=config['seed']
        )
        best_path, update_scores = search_outcome
        logger.info(f"Best path found with scores: {update_scores}")

        # valid_results_collect returns the failed cases first, then the
        # correct ones.
        logger.info("Evaluating on test set...")
        misses, hits = valid_results_collect(best_path, test_examples, 'nli')

        accuracy = calculate_accuracy(hits, misses)
        logger.info(f"Test Accuracy: {accuracy:.4f}")

        # Only the counts are persisted, not the individual cases.
        outcome_counts = {'correct': len(hits), 'failed': len(misses)}
        summary = {
            'best_path': best_path,
            'update_scores': update_scores,
            'test_accuracy': accuracy,
            'test_results': outcome_counts,
        }

        save_path = Path(f"results_{config['search_name']}.pt")
        torch.save(summary, save_path)
        logger.info(f"Results saved to {save_path}")

    except Exception as e:
        logger.error(f"Error in main execution: {str(e)}", exc_info=True)
        raise

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()