shawntan committed on
Commit 22d3401 · 2 Parent(s): 1fcd3ff 4a89382

Merge branch 'main' of https://huggingface.co/YiDuo1999/random_models_9

Files changed (4)
  1. cases_collect.py +2 -72
  2. greedy_search.py +97 -0
  3. model_soups_utils.py +25 -0
  4. run_model_soups.py +111 -0
cases_collect.py CHANGED
@@ -21,12 +21,8 @@ def valid_results_collect(model_path,valid_data,task):
     torch.cuda.ipc_collect()
     # multiprocessing.set_start_method('spawn')
     trained_model=LLM(model=model_path,gpu_memory_utilization=0.95)
-
     start_t=time.time()
-    if task=='sql':
-        failed_cases,correct_cases=sql_evaluation(trained_model,valid_data)
-    elif task=='nli':
-        failed_cases,correct_cases=nli_evaluation(trained_model,valid_data)
+    failed_cases,correct_cases=nli_evaluation(trained_model,valid_data)
     del trained_model
     end_t=time.time()
     print('time',start_t-end_t)
@@ -34,9 +30,6 @@ def valid_results_collect(model_path,valid_data,task):
     torch.cuda.empty_cache()
     torch.cuda.ipc_collect()
     torch.cuda.synchronize()
-    #torch.cuda.synchronize()
-    #torch.cuda.empty_cache()
-    #torch.cuda.synchronize()
     time.sleep(10)
     return failed_cases,correct_cases
 def extract_answer_prediction_nli(predicted_output):
@@ -58,7 +51,6 @@ def process_batch(data_batch,trained_model,failed_cases,correct_cases):
     batch_prompts = [data['Input'] for data in data_batch]
     outputs = trained_model.generate(batch_prompts, sampling_params)
 
-    results = []
     labels=['entailment','contradiction','neutral']
     for data, output in zip(data_batch, outputs):
         # pdb.set_trace()
@@ -70,9 +62,6 @@ def process_batch(data_batch,trained_model,failed_cases,correct_cases):
         # pdb.set_trace()
 
         predicted_res=predicted_output
-        # print(label,predicted_output) # if 'contradiction #label_transform(data['Output'])
-        # pdb.set_trace()
-        # print(predicted_res,label,'\n')
         non_labels = [lbl for lbl in labels if lbl != label]
         if label not in predicted_res or any(non_label in predicted_res for non_label in non_labels):
             failed_cases.append((data['Input'],predicted_res,label,data))
@@ -80,69 +69,10 @@ def process_batch(data_batch,trained_model,failed_cases,correct_cases):
         correct_cases.append((data['Input'],predicted_res,label,data))
     return failed_cases,correct_cases
 def nli_evaluation(trained_model,valid_data):
-    id=0
     failed_cases=[]
     correct_cases=[]
     batch_size=500
     batched_data = [valid_data[i:i+batch_size] for i in range(0, len(valid_data), batch_size)]
     for batch in batched_data:
         failed_cases,correct_cases=process_batch(batch,trained_model,failed_cases,correct_cases)
-
-    #for data in valid_data:
-    #    prompt=data['Input']
-    #    output=trained_model.generate(prompt, sampling_params)
-    #    predicted_output=output[0].outputs[0].text
-    #    predicted_res=extract_answer_prediction_nli(predicted_output) #try:
-    #    #    predicted_res=extract_answer(predicted_output.split('final')[-1].split('is')[1].split('.')[0])
-    #    #except:
-    #    #    predicted_res=extract_answer(predicted_output.split('is')[-1])
-    #    label=extract_answer(data['Output'].split('is')[-1])
-    #    print(label,predicted_res)
-    #    if not predicted_res:
-    #        pdb.set_trace()
-    #        predicted_res=''
-    #    # if 'contradiction' #label_transform(data['Output'])
-    #    # pdb.set_trace()
-    #    if label not in predicted_res:
-    #        failed_cases.append((id,prompt,predicted_res,label,data))
-    #    else:
-    #        correct_cases.append((id,prompt,predicted_res,label,data))
-    #    id+=1
-    #id,prompt,prior_pred+predicted_sql,valid_data[id],ground_truth,predicted_res,ground_truth_res
-    return failed_cases,correct_cases
-def sql_evaluation(trained_model,valid_data):
-    id=0
-    failed_cases=[]
-    correct_cases=[]
-    for triple in valid_data:
-
-        db_id,prompt,ground_truth=triple
-        prompt=prompt.replace('SELECT','')
-        db_path='/dccstor/obsidian_llm/yiduo/AgentBench/DAMO-ConvAI/bird/data/train/train_databases/{0}/{0}.sqlite'.format(db_id)
-        prompt+=' To generate the SQL query to'
-        conn = sqlite3.connect(db_path)
-        output=trained_model.generate(prompt, sampling_params)
-        predicted_sql = output[0].outputs[0].text
-        #pdb.set_trace()
-        prior_pred=predicted_sql.split('final SQL')[0]
-        try:
-            predicted_sql = predicted_sql.split('final SQL')[1].strip()
-        except:
-            predicted_sql = 'SELECT'+predicted_sql.split('SELECT')[1]
-        predicted_sql=predicted_sql.split(';')[0]
-        predicted_sql=predicted_sql[predicted_sql.find('SELECT'):]
-        cursor = conn.cursor()
-        # pdb.set_trace()
-        try:
-            cursor.execute(predicted_sql)
-            predicted_res = cursor.fetchall()
-            cursor.execute(ground_truth)
-            ground_truth_res = cursor.fetchall()
-            #print('results',predicted_res,'truth',ground_truth_res,'\n')
-            if set(predicted_res) != set(ground_truth_res):
-                failed_cases.append((id,prompt,prior_pred+predicted_sql,valid_data[id],ground_truth,predicted_res,ground_truth_res))
-            else:
-                correct_cases.append((id,prompt,prior_pred+predicted_sql,valid_data[id],ground_truth,predicted_res,ground_truth_res))
-        except Exception as e:
-            failed_cases.append((id,prompt,predicted_sql,valid_data[id],ground_truth,str(Exception)+str(e)))
-    return failed_cases,correct_cases
+    return failed_cases,correct_cases
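Note: after this change valid_results_collect always runs the NLI evaluation path. A minimal usage sketch (the seed-suffixed checkpoint directory name is an assumption for illustration; nli_demo.pt is the validation file used by run_model_soups.py below):

import torch
from cases_collect import valid_results_collect

valid_data = torch.load('nli_demo.pt')  # list of {'Input': ..., 'Output': ...} records
failed, correct = valid_results_collect('./randoms_data_3k_model_2020', valid_data, 'nli')
print('accuracy', len(correct) / (len(correct) + len(failed)))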
greedy_search.py ADDED
@@ -0,0 +1,97 @@
+import os
+import shutil
+import torch
+import pdb
+from model_soups_utils import average_two_model
+from cases_collect import valid_results_collect
+def remove_folder(path):
+    if os.path.isdir(path):  # Check if the directory exists
+        shutil.rmtree(path)
+        print(f"Directory '{path}' has been removed.")
+    else:
+        print(f"Directory '{path}' does not exist.")
+def score_criteria(x):
+    return x[1]
+def compare_criteria(x,y):
+    return x<=y
+def find_best_combination(model_path,valid_data,test_examples,search_name,iteration=5,seed=True,task='nli'):
+    if seed:
+        if isinstance(model_path,list):
+            paths=[]
+            for m_p in model_path:
+                paths.extend([m_p+'_{0}'.format(seed) for seed in [str(i) for i in range(2020,2030)]])
+        else:
+            paths=[model_path+'_{0}'.format(seed) for seed in [str(i) for i in range(2020,2030)]]
+    else:
+        paths=model_path
+    try:
+        update_scores=torch.load('{0}_score.pt'.format(search_name))
+        del_paths=torch.load('{0}_path.pt'.format(search_name))
+        for path in del_paths:
+            del paths[paths.index(path)]
+        best_path=torch.load('{0}_best_path.pt'.format(search_name))
+        best_score=update_scores[-1]
+    except:
+        del_paths=[]
+        update_scores=[]
+        path_count=[]
+        for path_id,path in enumerate(paths):
+            print(0,path_id,len(paths))
+            f_test,c_test=valid_results_collect(path,test_examples,task)
+            path_count.append((path,len(c_test)/(len(f_test)+len(c_test))))
+            print(path_count[-1][1])
+        path_count.sort(key=lambda x:score_criteria(x),reverse=True)
+        best_path=path_count[0][0]
+        best_score=score_criteria(path_count[0])
+        update_scores.append(best_score)
+        f_test,c_test=valid_results_collect(best_path,test_examples,'nli')
+        print(best_score,len(c_test)/(len(f_test)+len(c_test)))
+        torch.save(update_scores,'{0}_score.pt'.format(search_name))
+        del_paths.append(best_path)
+        torch.save(del_paths,'{0}_path.pt'.format(search_name))
+        del paths[paths.index(best_path)]
+        torch.save(best_path,'{0}_best_path.pt'.format(search_name))
+    while len(paths)>0:
+        path_count=[]
+        for path_id,path in enumerate(paths):
+            print(len(update_scores),path_id,len(paths))
+            average_path="{0}_average".format(best_path+path.split('/')[-1])
+            if not os.path.isdir(average_path):
+                average_path=average_two_model(best_path,path,len(update_scores))
+            f_test,c_test=valid_results_collect(average_path,test_examples,'nli')
+            if not path_count:
+                path_count.append((path,len(c_test)/(len(f_test)+len(c_test)),average_path))
+            else:
+                score=len(c_test)/(len(f_test)+len(c_test))
+                if score>=path_count[-1][1]:
+                    path_count.append((path,score,average_path))
+                else:
+                    remove_folder(average_path)
+            print(path_count[-1][1])
+        path_count.sort(key=lambda x:score_criteria(x),reverse=True)
+        win_path=path_count[0][0]
+        win_score=score_criteria(path_count[0])
+        #del paths[paths.index(win_path)]
+        if compare_criteria(best_score,win_score):
+            if len(del_paths)>2:
+                remove_folder(best_path)
+            best_path=path_count[0][2]
+            torch.save(best_path,'{0}_best_path.pt'.format(search_name))
+            best_score=win_score
+            #f_test,c_test=valid_results_collect(best_path,test_examples,args.task)
+            print(best_score)
+            del paths[paths.index(win_path)]
+            del_paths.append(win_path)
+            torch.save(del_paths,'{0}_path.pt'.format(search_name))
+            #pdb.set_trace()
+            update_scores.append(best_score)
+            torch.save(update_scores,'{0}_score.pt'.format(search_name))
+        else:
+            while paths:
+                paths.pop()
+            best_path=best_path
+            #break
+            #update_scores.append(best_score)
+    return best_path,update_scores
+
+    #ooa_failed_cases, im_failed_cases, correct_cases=process_nli_validation_batch(path, valid_data,seed=False, iteration=100)
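find_best_combination greedily builds a model soup: it first scores each candidate checkpoint on its own and keeps the best one, then repeatedly averages the current soup with each remaining checkpoint (via average_two_model) and only adopts a merge when its validation score is at least as good as the current best, checkpointing progress to {search_name}_score.pt, {search_name}_path.pt and {search_name}_best_path.pt so an interrupted search can resume. A minimal invocation sketch, mirroring run_model_soups.py below (the checkpoint paths are assumptions for illustration):

import torch
from greedy_search import find_best_combination

valid_data = torch.load('nli_demo.pt')  # reused here as both the search and evaluation split
best_path, scores = find_best_combination(
    ['./randoms_data_3k_model'],  # with seed=True this expands to ./randoms_data_3k_model_2020 ... _2029
    valid_data, valid_data, 'randoms_model',
    iteration=5, seed=True, task='nli')
print(best_path, scores)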
model_soups_utils.py ADDED
@@ -0,0 +1,25 @@
+import torch
+from transformers import AutoModelForCausalLM,AutoTokenizer
+from transformers import LlamaTokenizer
+from vllm import LLM, SamplingParams
+
+def average_two_model(model_path_1,model_path_2,update_num,base_path='/dccstor/obsidian_llm/yiduo/h100_data/llama-3-8b'):
+
+    # Path to save the averaged model and tokenizer
+    averaged_model_path = "{0}".format(model_path_1+model_path_2.split('/')[-1]).replace('00','').replace('random','').replace('naive_3k','').replace('shuffle','').replace('average','')
+    # Load and average the state dicts for each model
+    models=[]
+    model_paths=[model_path_1,model_path_2]
+    for model_path in model_paths:
+        models.append(AutoModelForCausalLM.from_pretrained(model_path))
+    avg_state_dict = {}
+    for key in models[0].state_dict().keys():
+        avg_state_dict[key] = (update_num/(update_num+1))*models[0].state_dict()[key]+(1.0/(update_num+1))*models[1].state_dict()[key]
+    base_model = AutoModelForCausalLM.from_pretrained(base_path)  # Load the base model configuration
+    base_model.load_state_dict(avg_state_dict)
+    base_model.save_pretrained(averaged_model_path)  # Save the averaged model
+    # Load the tokenizer (assuming all models used the same tokenizer)
+    # If needed, adjust the tokenizer path to match the base LLaMA tokenizer used
+    tokenizer = AutoTokenizer.from_pretrained(model_path_1)
+    tokenizer.save_pretrained(averaged_model_path)
+    return averaged_model_path
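average_two_model folds one new checkpoint into the running soup with weights update_num/(update_num+1) and 1/(update_num+1), so after n merges the result equals the uniform average of all n+1 checkpoints. A small sketch of that identity on toy tensors (toy values only, not part of this commit):

import torch

checkpoints = [torch.full((2, 2), float(v)) for v in (1.0, 2.0, 6.0)]  # stand-ins for parameter tensors
soup = checkpoints[0]
for update_num, new in enumerate(checkpoints[1:], start=1):
    soup = (update_num / (update_num + 1)) * soup + (1.0 / (update_num + 1)) * new
assert torch.allclose(soup, torch.stack(checkpoints).mean(dim=0))  # uniform model-soup average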
run_model_soups.py ADDED
@@ -0,0 +1,111 @@
+import os
+import torch
+import logging
+from pathlib import Path
+from typing import List, Dict, Tuple
+from datasets import load_dataset
+from greedy_search import find_best_combination
+from cases_collect import valid_results_collect
+
+def setup_logger() -> logging.Logger:
+    """Configure and return logger."""
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s'
+    )
+    return logging.getLogger(__name__)
+
+def get_model_paths(model_names: List[str], base_path: str = './') -> List[str]:
+    """Generate model paths from names."""
+    return [os.path.join(base_path, f"{name}_model") for name in model_names]
+
+def load_test_data(dataset_name: str = 'hippocrates/MedNLI_test') -> List[Dict]:
+    """Load and prepare test dataset."""
+    dataset = load_dataset(dataset_name)
+    return [
+        {'Input': item['query'], 'Output': item['answer']}
+        for item in dataset['test']
+    ]
+
+def calculate_accuracy(correct: List, failed: List) -> float:
+    """Calculate accuracy from correct and failed cases."""
+    total = len(correct) + len(failed)
+    return len(correct) / total if total > 0 else 0.0
+
+def main():
+    """Main execution function."""
+    logger = setup_logger()
+
+    try:
+        # Configuration
+        config = {
+            'search_name': 'randoms_model',
+            'model_names': ['randoms_data_3k_model'],
+            'base_path': './',
+            'valid_data_path': 'nli_demo.pt',
+            'seed': True,
+            'iteration': 5
+        }
+
+        # Generate model paths
+        model_paths = get_model_paths(config['model_names'], config['base_path'])
+        logger.info(f"Generated model paths: {model_paths}")
+
+        # Load datasets
+        logger.info("Loading test data...")
+        test_examples = load_test_data()
+        logger.info(f"Loaded {len(test_examples)} test examples")
+
+        logger.info("Loading validation data...")
+        try:
+            valid_data = torch.load(config['valid_data_path'])
+            logger.info(f"Loaded validation data from {config['valid_data_path']}")
+        except Exception as e:
+            logger.error(f"Failed to load validation data: {str(e)}")
+            raise
+
+        # Find best combination
+        logger.info("Finding best model combination...")
+        best_path, update_scores = find_best_combination(
+            model_paths,
+            valid_data,
+            valid_data,
+            config['search_name'],
+            iteration=config['iteration'],
+            seed=config['seed']
+        )
+        logger.info(f"Best path found with scores: {update_scores}")
+
+        # Evaluate on test set
+        logger.info("Evaluating on test set...")
+        failed_test, correct_test = valid_results_collect(
+            best_path,
+            test_examples,
+            'nli'
+        )
+
+        # Calculate and log accuracy
+        accuracy = calculate_accuracy(correct_test, failed_test)
+        logger.info(f"Test Accuracy: {accuracy:.4f}")
+
+        # Save results
+        results = {
+            'best_path': best_path,
+            'update_scores': update_scores,
+            'test_accuracy': accuracy,
+            'test_results': {
+                'correct': len(correct_test),
+                'failed': len(failed_test)
+            }
+        }
+
+        save_path = Path(f"results_{config['search_name']}.pt")
+        torch.save(results, save_path)
+        logger.info(f"Results saved to {save_path}")
+
+    except Exception as e:
+        logger.error(f"Error in main execution: {str(e)}", exc_info=True)
+        raise
+
+if __name__ == "__main__":
+    main()
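With the configuration above, a run leaves its search state and final results on disk, which can be inspected (or reused to resume a later search) roughly like this; a minimal sketch assuming search_name='randoms_model' as configured:

import torch

for name in ('randoms_model_score.pt', 'randoms_model_path.pt',
             'randoms_model_best_path.pt', 'results_randoms_model.pt'):
    try:
        print(name, torch.load(name))  # scores, merged-in paths, current best soup, final results dict
    except FileNotFoundError:
        print(name, 'not written yet')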