File size: 5,133 Bytes
c63c010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import shutil
import torch
import pdb
from model_soups_utils import average_two_model
from cases_collect import valid_results_collect
def remove_folder(path):
    """Recursively delete the directory at *path* if it exists.

    A missing directory is reported on stdout, not treated as an error;
    a status line is printed in either case.
    """
    if not os.path.isdir(path):  # nothing to do — report and bail out
        print(f"Directory '{path}' does not exist.")
        return
    shutil.rmtree(path)
    print(f"Directory '{path}' has been removed.")
def score_criteria(x):
    """Return the score field (index 1) of a (path, score, ...) record."""
    score = x[1]
    return score
def compare_criteria(x, y):
    """True when the challenger score *y* is at least the incumbent score *x*."""
    return y >= x
def find_best_combination(model_path,valid_data,test_examples,search_name,iteration=5,seed=True,task='nli'):
    """Greedy model-soup search over a pool of checkpoints.

    Starts from the single best-scoring checkpoint, then repeatedly tries
    averaging the current best with each remaining checkpoint (via
    ``average_two_model``), adopting an average whenever it matches or beats
    the current test accuracy. Progress (scores, consumed paths, current best
    path) is checkpointed with ``torch.save`` under ``search_name``-prefixed
    files so an interrupted search can resume.

    Args:
        model_path: base path string, or list of base paths; when ``seed`` is
            truthy each base path is expanded with suffixes ``_2020``..``_2029``.
        valid_data: unused here; kept for interface compatibility.
        test_examples: examples forwarded to ``valid_results_collect``.
        search_name: filename prefix for the search's checkpoint files.
        iteration: unused here; kept for interface compatibility.
        seed: when truthy, expand ``model_path`` with the seed suffixes;
            otherwise ``model_path`` is taken as the candidate list directly.
        task: task name forwarded to ``valid_results_collect``.

    Returns:
        ``(best_path, update_scores)`` — the final (possibly averaged) model
        path and the accuracy recorded after each improvement.
    """
    if seed:
        # NOTE: loop variable renamed so it no longer shadows the `seed` arg.
        seed_tags = [str(i) for i in range(2020, 2030)]
        if isinstance(model_path, list):
            paths = []
            for m_p in model_path:
                paths.extend('{0}_{1}'.format(m_p, tag) for tag in seed_tags)
        else:
            paths = ['{0}_{1}'.format(model_path, tag) for tag in seed_tags]
    else:
        paths = model_path
    try:
        # Resume a previously interrupted search from its checkpoint files.
        update_scores = torch.load('{0}_score.pt'.format(search_name))
        del_paths = torch.load('{0}_path.pt'.format(search_name))
        for path in del_paths:
            del paths[paths.index(path)]
        best_path = torch.load('{0}_best_path.pt'.format(search_name))
        best_score = update_scores[-1]
    except Exception:
        # No usable checkpoint (missing/corrupt files, stale path lists):
        # start fresh by scoring every candidate. Was a bare `except:`, which
        # also swallowed KeyboardInterrupt/SystemExit.
        del_paths = []
        update_scores = []
        path_count = []
        for path_id, path in enumerate(paths):
            print(0, path_id, len(paths))
            f_test, c_test = valid_results_collect(path, test_examples, task)
            # Score = accuracy on the test examples.
            path_count.append((path, len(c_test) / (len(f_test) + len(c_test))))
            # Progress print moved inside the loop (was dedented, so it only
            # reported the final candidate).
            print(path_count[-1][1])
        path_count.sort(key=lambda x: score_criteria(x), reverse=True)
        best_path = path_count[0][0]
        best_score = score_criteria(path_count[0])
        update_scores.append(best_score)
        # Re-check the winner; now respects `task` (was hard-coded to 'nli').
        f_test, c_test = valid_results_collect(best_path, test_examples, task)
        print(best_score, len(c_test) / (len(f_test) + len(c_test)))
        torch.save(update_scores, '{0}_score.pt'.format(search_name))
        del_paths.append(best_path)
        torch.save(del_paths, '{0}_path.pt'.format(search_name))
        del paths[paths.index(best_path)]
        torch.save(best_path, '{0}_best_path.pt'.format(search_name))
    while len(paths) > 0:
        path_count = []
        for path_id, path in enumerate(paths):
            print(len(update_scores), path_id, len(paths))
            average_path = "{0}_average".format(best_path + path.split('/')[-1])
            if not os.path.isdir(average_path):
                # Average is cached on disk; only recompute when absent.
                average_path = average_two_model(best_path, path, len(update_scores))
            # Evaluate the averaged model; now respects `task` (was 'nli').
            f_test, c_test = valid_results_collect(average_path, test_examples, task)
            if not path_count:
                path_count.append((path, len(c_test) / (len(f_test) + len(c_test)), average_path))
            else:
                score = len(c_test) / (len(f_test) + len(c_test))
                if score >= path_count[-1][1]:
                    path_count.append((path, score, average_path))
                else:
                    # Discard clearly-losing averages immediately to bound disk use.
                    remove_folder(average_path)
            print(path_count[-1][1])
        path_count.sort(key=lambda x: score_criteria(x), reverse=True)
        win_path = path_count[0][0]
        win_score = score_criteria(path_count[0])
        if compare_criteria(best_score, win_score):
            # Tie or improvement: adopt the averaged model as the new best.
            if len(del_paths) > 2:
                remove_folder(best_path)
            best_path = path_count[0][2]
            torch.save(best_path, '{0}_best_path.pt'.format(search_name))
            best_score = win_score
            print(best_score)
            del paths[paths.index(win_path)]
            del_paths.append(win_path)
            torch.save(del_paths, '{0}_path.pt'.format(search_name))
            update_scores.append(best_score)
            torch.save(update_scores, '{0}_score.pt'.format(search_name))
        else:
            # No candidate matches the current best: stop searching.
            # (Replaces the `while paths: paths.pop()` drain plus the no-op
            # `best_path = best_path`.)
            break
    return best_path, update_scores

        #ooa_failed_cases, im_failed_cases, correct_cases=process_nli_validation_batch(path, valid_data,seed=False, iteration=100)