|
import os |
|
import shutil |
|
import torch |
|
import pdb |
|
from model_soups_utils import average_two_model |
|
from cases_collect import valid_results_collect |
|
def remove_folder(path):
    """Recursively delete the directory at ``path`` and report the outcome.

    Anything that is not an existing directory (a missing path, or a regular
    file) is left alone and reported as a non-existent directory.
    """
    if not os.path.isdir(path):
        print(f"Directory '{path}' does not exist.")
        return
    shutil.rmtree(path)
    print(f"Directory '{path}' has been removed.")
|
def score_criteria(x):
    """Sort key for candidate records: the score held at index 1."""
    score = x[1]
    return score
|
def compare_criteria(x, y):
    """Return True when ``y`` is at least as good as ``x`` (ties accepted)."""
    accept = y >= x
    return accept
|
def _candidate_paths(model_path, seed):
    """Return a fresh list of candidate model directories.

    With ``seed`` truthy, each base path in ``model_path`` (a single path or a
    list of paths) is expanded to ``<base>_2020`` ... ``<base>_2029``.
    Otherwise ``model_path`` is taken as the candidate list itself. A new list
    is always returned so the caller's argument is never mutated.
    """
    if not seed:
        return list(model_path)
    bases = model_path if isinstance(model_path, list) else [model_path]
    return ['{0}_{1}'.format(base, s) for base in bases for s in range(2020, 2030)]


def _accuracy(model_dir, test_examples, task):
    """Score ``model_dir`` on ``test_examples`` for ``task``.

    ``valid_results_collect`` presumably returns (failed, correct) example
    collections; the score is ``len(correct) / (len(failed) + len(correct))``.
    """
    f_test, c_test = valid_results_collect(model_dir, test_examples, task)
    return len(c_test) / (len(f_test) + len(c_test))


def _load_checkpoint(search_name, candidates):
    """Try to resume a previous search saved under ``search_name``.

    Returns ``(update_scores, del_paths, best_path, remaining_paths)``, or
    ``None`` when the checkpoint files are missing, unreadable, empty, or
    reference paths that are not among ``candidates``.
    """
    try:
        update_scores = torch.load('{0}_score.pt'.format(search_name))
        del_paths = torch.load('{0}_path.pt'.format(search_name))
        best_path = torch.load('{0}_best_path.pt'.format(search_name))
    except FileNotFoundError:
        return None
    except Exception as err:  # corrupt/incompatible checkpoint: start over
        print("Could not load checkpoint '{0}': {1}".format(search_name, err))
        return None
    if not update_scores:
        return None
    remaining = list(candidates)
    for path in del_paths:
        if path not in remaining:
            print("Checkpoint '{0}' does not match the candidate paths; "
                  "starting a fresh search.".format(search_name))
            return None
        remaining.remove(path)
    return update_scores, del_paths, best_path, remaining


def find_best_combination(model_path, valid_data, test_examples, search_name,
                          iteration=5, seed=True, task='nli'):
    """Greedily build a model soup by repeated pairwise weight averaging.

    Round 0 picks the single best candidate model. Each later round averages
    the current best with every remaining candidate (``average_two_model``),
    scores each average, and keeps the best one if it scores at least as well
    as the current best (``compare_criteria``). The search stops when no
    average helps or all candidates are used. Progress is checkpointed to
    ``<search_name>_score.pt``, ``<search_name>_path.pt`` and
    ``<search_name>_best_path.pt`` so an interrupted search can resume.

    Args:
        model_path: Base path (or list of base paths) when ``seed`` is truthy,
            otherwise the list of candidate model directories. Not modified.
        valid_data: Unused by this function; kept for interface compatibility.
        test_examples: Examples passed to ``valid_results_collect`` for scoring.
        search_name: Prefix for the checkpoint files.
        iteration: Unused by this function; kept for interface compatibility.
        seed: If truthy, expand each base path with seed suffixes 2020-2029.
        task: Task name passed to ``valid_results_collect`` for every score.

    Returns:
        ``(best_path, update_scores)``: the directory of the final soup (or of
        the best single model if no merge helped) and the list of accepted
        best scores, one per round.

    Raises:
        ValueError: if there are no candidate model paths to start from.

    Side effects: creates averaged-model directories via ``average_two_model``
    and deletes the ones that are not (or are no longer) the current best.
    """
    candidates = _candidate_paths(model_path, seed)
    originals = set(candidates)

    state = _load_checkpoint(search_name, candidates)
    if state is not None:
        update_scores, del_paths, best_path, paths = state
        best_score = update_scores[-1]
    else:
        paths = list(candidates)
        if not paths:
            raise ValueError('find_best_combination: no candidate model paths')
        # Round 0: rank single models; stable sort keeps the earliest on ties.
        scored = []
        for path_id, path in enumerate(paths):
            print(0, path_id, len(paths))
            scored.append((path, _accuracy(path, test_examples, task)))
            print(scored[-1][1])
        scored.sort(key=score_criteria, reverse=True)
        best_path = scored[0][0]
        best_score = score_criteria(scored[0])
        update_scores = [best_score]
        # Sanity check: re-evaluate the chosen model and print both scores.
        print(best_score, _accuracy(best_path, test_examples, task))
        torch.save(update_scores, '{0}_score.pt'.format(search_name))
        del_paths = [best_path]
        torch.save(del_paths, '{0}_path.pt'.format(search_name))
        paths.remove(best_path)
        torch.save(best_path, '{0}_best_path.pt'.format(search_name))

    while paths:
        round_id = len(update_scores)
        # Running winner of this round; the earliest candidate with the max
        # score wins, and at most one candidate average is kept on disk.
        win_path = win_score = win_average = None
        for path_id, path in enumerate(paths):
            print(round_id, path_id, len(paths))
            # Same naming as earlier runs so existing averages are reused.
            # NOTE(review): candidates sharing a basename collide on this name.
            average_path = '{0}_average'.format(best_path + path.split('/')[-1])
            if not os.path.isdir(average_path):
                average_path = average_two_model(best_path, path, round_id)
            score = _accuracy(average_path, test_examples, task)
            if win_path is None or score > win_score:
                if win_average is not None and win_average != average_path:
                    remove_folder(win_average)
                win_path, win_score, win_average = path, score, average_path
            elif average_path != win_average:
                remove_folder(average_path)
            print(win_score)

        if compare_criteria(best_score, win_score):
            previous_best = best_path
            best_path = win_average
            # NOTE: the three checkpoint files are written separately, so a
            # crash between saves can leave them mutually inconsistent.
            torch.save(best_path, '{0}_best_path.pt'.format(search_name))
            best_score = win_score
            print(best_score)
            paths.remove(win_path)
            del_paths.append(win_path)
            torch.save(del_paths, '{0}_path.pt'.format(search_name))
            update_scores.append(best_score)
            torch.save(update_scores, '{0}_score.pt'.format(search_name))
            # Delete the superseded soup only after the checkpoint points at
            # the new one; never delete an original candidate model.
            if previous_best not in originals and previous_best != best_path:
                remove_folder(previous_best)
        else:
            # No average improves on the current best: discard the unused
            # winning average and stop.
            remove_folder(win_average)
            break

    return best_path, update_scores
|
|
|
|