YiDuo1999 commited on
Commit
e95bc70
·
verified ·
1 Parent(s): b051745

Create run_model_soups.py

Browse files
Files changed (1) hide show
  1. run_model_soups.py +111 -0
run_model_soups.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import List, Dict, Tuple
6
+ from datasets import load_dataset
7
+ from greedy_search import find_best_combination
8
+ from cases_collect import valid_results_collect
9
+
10
+ def setup_logger() -> logging.Logger:
11
+ """Configure and return logger."""
12
+ logging.basicConfig(
13
+ level=logging.INFO,
14
+ format='%(asctime)s - %(levelname)s - %(message)s'
15
+ )
16
+ return logging.getLogger(__name__)
17
+
18
+ def get_model_paths(model_names: List[str], base_path: str = './') -> List[str]:
19
+ """Generate model paths from names."""
20
+ return [os.path.join(base_path, f"{name}_model") for name in model_names]
21
+
22
+ def load_test_data(dataset_name: str = 'hippocrates/MedNLI_test') -> List[Dict]:
23
+ """Load and prepare test dataset."""
24
+ dataset = load_dataset(dataset_name)
25
+ return [
26
+ {'Input': item['query'], 'Output': item['answer']}
27
+ for item in dataset['test']
28
+ ]
29
+
30
+ def calculate_accuracy(correct: List, failed: List) -> float:
31
+ """Calculate accuracy from correct and failed cases."""
32
+ total = len(correct) + len(failed)
33
+ return len(correct) / total if total > 0 else 0.0
34
+
35
+ def main():
36
+ """Main execution function."""
37
+ logger = setup_logger()
38
+
39
+ try:
40
+ # Configuration
41
+ config = {
42
+ 'search_name': 'randoms_model',
43
+ 'model_names': ['randoms_data_3k_model'],
44
+ 'base_path': './',
45
+ 'valid_data_path': 'nli_demo.pt',
46
+ 'seed': True,
47
+ 'iteration': 5
48
+ }
49
+
50
+ # Generate model paths
51
+ model_paths = get_model_paths(config['model_names'], config['base_path'])
52
+ logger.info(f"Generated model paths: {model_paths}")
53
+
54
+ # Load datasets
55
+ logger.info("Loading test data...")
56
+ test_examples = load_test_data()
57
+ logger.info(f"Loaded {len(test_examples)} test examples")
58
+
59
+ logger.info("Loading validation data...")
60
+ try:
61
+ valid_data = torch.load(config['valid_data_path'])
62
+ logger.info(f"Loaded validation data from {config['valid_data_path']}")
63
+ except Exception as e:
64
+ logger.error(f"Failed to load validation data: {str(e)}")
65
+ raise
66
+
67
+ # Find best combination
68
+ logger.info("Finding best model combination...")
69
+ best_path, update_scores = find_best_combination(
70
+ model_paths,
71
+ valid_data,
72
+ valid_data,
73
+ config['search_name'],
74
+ iteration=config['iteration'],
75
+ seed=config['seed']
76
+ )
77
+ logger.info(f"Best path found with scores: {update_scores}")
78
+
79
+ # Evaluate on test set
80
+ logger.info("Evaluating on test set...")
81
+ failed_test, correct_test = valid_results_collect(
82
+ best_path,
83
+ test_examples,
84
+ 'nli'
85
+ )
86
+
87
+ # Calculate and log accuracy
88
+ accuracy = calculate_accuracy(correct_test, failed_test)
89
+ logger.info(f"Test Accuracy: {accuracy:.4f}")
90
+
91
+ # Save results
92
+ results = {
93
+ 'best_path': best_path,
94
+ 'update_scores': update_scores,
95
+ 'test_accuracy': accuracy,
96
+ 'test_results': {
97
+ 'correct': len(correct_test),
98
+ 'failed': len(failed_test)
99
+ }
100
+ }
101
+
102
+ save_path = Path(f"results_{config['search_name']}.pt")
103
+ torch.save(results, save_path)
104
+ logger.info(f"Results saved to {save_path}")
105
+
106
+ except Exception as e:
107
+ logger.error(f"Error in main execution: {str(e)}", exc_info=True)
108
+ raise
109
+
110
+ if __name__ == "__main__":
111
+ main()