|
import os |
|
import torch |
|
import logging |
|
from pathlib import Path |
|
from typing import List, Dict, Tuple |
|
from datasets import load_dataset |
|
from greedy_search import find_best_combination |
|
from cases_collect import valid_results_collect |
|
|
|
def setup_logger() -> logging.Logger:
    """Initialize root logging at INFO level and return a module logger.

    Returns:
        The logger named after this module.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)
    return logging.getLogger(__name__)
|
|
|
def get_model_paths(model_names: List[str], base_path: str = './') -> List[str]:
    """Build the on-disk directory path for each model name.

    Args:
        model_names: Bare model identifiers.
        base_path: Directory containing the model folders.

    Returns:
        One ``<base_path>/<name>_model`` path per input name, in order.
    """
    paths = []
    for name in model_names:
        paths.append(os.path.join(base_path, f"{name}_model"))
    return paths
|
|
|
def load_test_data(dataset_name: str = 'hippocrates/MedNLI_test') -> List[Dict]:
    """Fetch a dataset's test split and map each row to an Input/Output record.

    Args:
        dataset_name: Dataset identifier passed straight to ``load_dataset``.

    Returns:
        One ``{'Input': ..., 'Output': ...}`` dict per test-split example,
        taken from the row's ``query`` and ``answer`` fields.
    """
    dataset = load_dataset(dataset_name)
    examples = []
    for item in dataset['test']:
        examples.append({'Input': item['query'], 'Output': item['answer']})
    return examples
|
|
|
def calculate_accuracy(correct: List, failed: List) -> float:
    """Compute the fraction of cases that were correct.

    Args:
        correct: Cases judged correct.
        failed: Cases judged failed.

    Returns:
        ``len(correct) / (len(correct) + len(failed))``; 0.0 when both
        lists are empty (avoids ZeroDivisionError).
    """
    total = len(correct) + len(failed)
    if total == 0:
        return 0.0
    return len(correct) / total
|
|
|
def main():
    """Main execution function.

    Pipeline: build model paths, load the MedNLI test split, load cached
    validation data, greedy-search the best model combination, evaluate the
    winner on the test set, and persist a results dict to a ``.pt`` file.
    Any failure is logged with a traceback and re-raised.
    """
    logger = setup_logger()

    try:
        # Static run configuration. 'seed' and 'iteration' are forwarded
        # verbatim to find_best_combination — their exact semantics live in
        # greedy_search and are not visible from this file.
        config = {
            'search_name': 'randoms_model',
            'model_names': ['randoms_data_3k_model'],
            'base_path': './',
            'valid_data_path': 'nli_demo.pt',
            'seed': True,
            'iteration': 5
        }

        model_paths = get_model_paths(config['model_names'], config['base_path'])
        logger.info(f"Generated model paths: {model_paths}")

        logger.info("Loading test data...")
        test_examples = load_test_data()
        logger.info(f"Loaded {len(test_examples)} test examples")

        logger.info("Loading validation data...")
        try:
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # safe on trusted files; consider weights_only=True if the file
            # contains pure tensors. TODO confirm.
            valid_data = torch.load(config['valid_data_path'])
            logger.info(f"Loaded validation data from {config['valid_data_path']}")
        except Exception as e:
            # Re-raised here, then caught and logged again by the outer
            # handler — a load failure is logged twice by design of this nest.
            logger.error(f"Failed to load validation data: {str(e)}")
            raise

        logger.info("Finding best model combination...")
        # valid_data is passed twice — presumably once as the search set and
        # once as the scoring set; verify against find_best_combination's
        # signature before changing.
        best_path, update_scores = find_best_combination(
            model_paths,
            valid_data,
            valid_data,
            config['search_name'],
            iteration=config['iteration'],
            seed=config['seed']
        )
        logger.info(f"Best path found with scores: {update_scores}")

        logger.info("Evaluating on test set...")
        # 'nli' appears to be a task-type selector inside
        # valid_results_collect — confirm against that helper.
        failed_test, correct_test = valid_results_collect(
            best_path,
            test_examples,
            'nli'
        )

        accuracy = calculate_accuracy(correct_test, failed_test)
        logger.info(f"Test Accuracy: {accuracy:.4f}")

        # Bundle everything needed to inspect the run after the fact.
        results = {
            'best_path': best_path,
            'update_scores': update_scores,
            'test_accuracy': accuracy,
            'test_results': {
                'correct': len(correct_test),
                'failed': len(failed_test)
            }
        }

        # Saved next to the script, keyed by the search name.
        save_path = Path(f"results_{config['search_name']}.pt")
        torch.save(results, save_path)
        logger.info(f"Results saved to {save_path}")

    except Exception as e:
        # Top-level boundary: log with traceback, then propagate so the
        # process exits non-zero.
        logger.error(f"Error in main execution: {str(e)}", exc_info=True)
        raise
|
|
|
if __name__ == "__main__": |
|
main() |