# random_models_9 / greedy_search.py
# YiDuo1999's picture
# Rename greedy_search to greedy_search.py
# 7f55732 verified
import os
import shutil
import torch
import pdb
from model_soups_utils import average_two_model
from cases_collect import valid_results_collect
def remove_folder(path):
    """Delete the directory tree at *path* and report the outcome on stdout.

    When *path* is not an existing directory nothing is deleted and a
    message saying so is printed instead.
    """
    # Guard clause: nothing to delete, only report it.
    if not os.path.isdir(path):
        print(f"Directory '{path}' does not exist.")
        return
    shutil.rmtree(path)
    print(f"Directory '{path}' has been removed.")
def score_criteria(x):
    """Return the ranking key of a candidate record: its element at index 1."""
    ranking_value = x[1]
    return ranking_value
def compare_criteria(x, y):
    """Return whether *x* is less than or equal to *y*."""
    x_not_above_y = x <= y
    return x_not_above_y
def _accuracy(failed_cases, correct_cases):
    """Return the share of correct cases, or 0.0 when there are no cases at all."""
    total = len(failed_cases) + len(correct_cases)
    return len(correct_cases) / total if total else 0.0


def _load_search_state(search_name):
    """Load a previous run's checkpoint, or return None when none is usable.

    Returns a ``(update_scores, del_paths, best_path)`` tuple on success.
    """
    try:
        update_scores = torch.load('{0}_score.pt'.format(search_name))
        del_paths = torch.load('{0}_path.pt'.format(search_name))
        best_path = torch.load('{0}_best_path.pt'.format(search_name))
    except Exception as err:
        # Resuming is best-effort: a missing or unreadable checkpoint means
        # "start from scratch". Unlike a bare ``except`` this still lets
        # KeyboardInterrupt / SystemExit propagate, and says why it restarted.
        print("No usable checkpoint for '{0}' ({1}: {2}); starting a fresh search.".format(
            search_name, type(err).__name__, err))
        return None
    if not update_scores:
        # No recorded score means there is no best score to resume from.
        return None
    return update_scores, del_paths, best_path


def find_best_combination(model_path, valid_data, test_examples, search_name,
                          iteration=5, seed=True, task='nli'):
    """Greedily average candidate models while the evaluation score does not drop.

    The best single candidate is picked first. Then, round after round, the
    current best model is averaged with every remaining candidate (via
    ``average_two_model``); the best-scoring average becomes the new best
    model as long as its score is not lower than the current best score.
    Progress is checkpointed with ``torch.save`` to ``{search_name}_score.pt``,
    ``{search_name}_path.pt`` and ``{search_name}_best_path.pt``, and a later
    call with the same ``search_name`` resumes from those files.

    Args:
        model_path: A path prefix or a list of them when ``seed`` is true;
            otherwise the candidate paths themselves (a single string is
            treated as one candidate). The caller's list is never modified.
        valid_data: Unused by this function; kept for interface compatibility.
        test_examples: Examples forwarded to ``valid_results_collect``.
        search_name: Prefix of the checkpoint files.
        iteration: Unused by this function; kept for interface compatibility.
        seed: When true, each prefix is expanded to ``<prefix>_2020`` ...
            ``<prefix>_2029``.
        task: Task name forwarded to every ``valid_results_collect`` call.

    Returns:
        ``(best_path, update_scores)``: the path of the final best model and
        the list of best scores recorded after each accepted step.

    Raises:
        ValueError: If a fresh search is started with no candidate paths.
    """
    if seed:
        prefixes = [model_path] if isinstance(model_path, str) else list(model_path)
        suffixes = [str(year) for year in range(2020, 2030)]
        paths = ['{0}_{1}'.format(prefix, suffix)
                 for prefix in prefixes for suffix in suffixes]
    elif isinstance(model_path, str):
        paths = [model_path]
    else:
        # Copy so that consuming candidates below never empties the caller's list.
        paths = list(model_path)

    resumed = False
    state = _load_search_state(search_name)
    if state is not None:
        update_scores, del_paths, best_path = state
        remaining = list(paths)
        try:
            for used_path in del_paths:
                remaining.remove(used_path)
        except ValueError:
            # The checkpoint names a path that is not a current candidate, so
            # it belongs to another configuration. ``paths`` is still intact
            # because only the copy was touched.
            print("Checkpoint for '{0}' does not match the candidate paths; "
                  "starting a fresh search.".format(search_name))
        else:
            paths = remaining
            best_score = update_scores[-1]
            resumed = True

    if not resumed:
        if not paths:
            raise ValueError("find_best_combination needs at least one candidate model path")
        del_paths = []
        update_scores = []
        scored = []
        for path_id, path in enumerate(paths):
            print(0, path_id, len(paths))
            # The two results are presumably (failed, correct) cases; only
            # their lengths are used here.
            f_test, c_test = valid_results_collect(path, test_examples, task)
            scored.append((path, _accuracy(f_test, c_test)))
            print(scored[-1][1])
        # The sort is stable, so ties keep their evaluation order.
        scored.sort(key=score_criteria, reverse=True)
        best_path = scored[0][0]
        best_score = score_criteria(scored[0])
        update_scores.append(best_score)
        print(best_score)
        torch.save(update_scores, '{0}_score.pt'.format(search_name))
        del_paths.append(best_path)
        torch.save(del_paths, '{0}_path.pt'.format(search_name))
        paths.remove(best_path)
        torch.save(best_path, '{0}_best_path.pt'.format(search_name))

    while paths:
        candidates = []
        for path_id, path in enumerate(paths):
            print(len(update_scores), path_id, len(paths))
            # Reuse an averaged model left on disk by an earlier run.
            average_path = "{0}_average".format(best_path + path.split('/')[-1])
            if not os.path.isdir(average_path):
                # NOTE(review): the third argument is the number of accepted
                # steps so far, presumably the averaging weight of the
                # current best model. Confirm against average_two_model.
                average_path = average_two_model(best_path, path, len(update_scores))
            f_test, c_test = valid_results_collect(average_path, test_examples, task)
            score = _accuracy(f_test, c_test)
            if not candidates or score >= candidates[-1][1]:
                candidates.append((path, score, average_path))
            else:
                # Scores below the running best of this round are discarded.
                remove_folder(average_path)
            print(candidates[-1][1])
        # NOTE(review): averaged folders of candidates that were kept above
        # but did not win the round are left on disk, as before. Confirm
        # whether they should be removed as well.
        candidates.sort(key=score_criteria, reverse=True)
        win_path, win_score, win_average = candidates[0]
        if not compare_criteria(best_score, win_score):
            # The best average of this round is worse than the current best.
            break
        # NOTE(review): with this threshold the first averaged model is never
        # deleted. Confirm whether that is intended.
        if len(del_paths) > 2:
            remove_folder(best_path)
        best_path = win_average
        torch.save(best_path, '{0}_best_path.pt'.format(search_name))
        best_score = win_score
        print(best_score)
        paths.remove(win_path)
        del_paths.append(win_path)
        torch.save(del_paths, '{0}_path.pt'.format(search_name))
        update_scores.append(best_score)
        torch.save(update_scores, '{0}_score.pt'.format(search_name))
    return best_path, update_scores
#ooa_failed_cases, im_failed_cases, correct_cases=process_nli_validation_batch(path, valid_data,seed=False, iteration=100)