Anni123 commited on
Commit
81d7292
1 Parent(s): 2ef16f8

Delete self_construction.py

Browse files
Files changed (1) hide show
  1. self_construction.py +0 -171
self_construction.py DELETED
@@ -1,171 +0,0 @@
1
- import random
2
- from sentence_transformers import SentenceTransformer
3
- from sklearn.cluster import KMeans
4
- from sklearn.decomposition import PCA
5
- import numpy as np
6
- import json
7
- import matplotlib.pyplot as plt
8
- import argparse
9
- from utils import fix_seed
10
-
11
-
12
- def parse_arguments():
13
- parser = argparse.ArgumentParser(description="Zero-shot-CoT")
14
- parser.add_argument(
15
- "--task", type=str, default="strategyqa",
16
- choices=["aqua", "gsm8k", "commonsensqa", "addsub", "multiarith", "strategyqa", "svamp", "singleeq", "coin_flip", "last_letters"], help="dataset used for experiment"
17
- )
18
- parser.add_argument(
19
- "--max_ra_len", type=int, default=5, help="maximum number of reasoning chains"
20
- )
21
- parser.add_argument(
22
- "--pred_file", type=str, default="log/multiarith_zero_shot_cot.log",
23
- help="use the reasoning chains generated by zero-shot-cot."
24
- )
25
- parser.add_argument(
26
- "--demo_save_dir", type=str, default="demos/multiarith", help="where to save the contructed demonstrations"
27
- )
28
- parser.add_argument("--random_seed", type=int, default=192, help="random seed")
29
- parser.add_argument(
30
- "--encoder", type=str, default="all-MiniLM-L6-v2", help="which sentence-transformer encoder for clustering"
31
- )
32
- parser.add_argument(
33
- "--sampling", type=str, default="center", help="whether to sample the cluster center first"
34
- )
35
- parser.add_argument(
36
- "--debug", type=bool, default=True, help="debug mode"
37
- )
38
- args = parser.parse_args()
39
- return args
40
-
41
- def main():
42
- args = parse_arguments()
43
- fix_seed(args.random_seed)
44
- encoder = SentenceTransformer(args.encoder)
45
-
46
- task = args.task
47
- pred_file = args.pred_file
48
- save_file = args.demo_save_dir
49
- max_ra_len = args.max_ra_len
50
- if task == "last_letters":
51
- max_ra_len = 7
52
- if task == "aqua" or task == "last_letters":
53
- num_clusters = 4
54
- elif task == "commonsensqa":
55
- num_clusters = 7
56
- elif task == "strategyqa":
57
- num_clusters = 6
58
- else:
59
- num_clusters = 8
60
-
61
- corpus = []
62
- question = []
63
- rationale = []
64
- gold_ans = []
65
- pred_ans = []
66
-
67
- with open(pred_file, "r", encoding="utf-8") as fp:
68
- answer_seg = ""
69
- for line in fp:
70
- if "Q: " in line:
71
- c_question = line.strip()
72
- if "A: " in line:
73
- answer_seg = line
74
- elif "Therefore" in line and "the answer" in line:
75
- c_rationale = answer_seg
76
-
77
- elif answer_seg != "":
78
- answer_seg += line
79
- if "pred_mode" in line:
80
- c_pred_ans = line.split(":")[1].strip()
81
- if "GT :" in line:
82
- c_gold_ans = line.split(":")[1].strip()
83
-
84
- c_rationale = c_rationale.replace("A: Let's think step by step.", "Let's think step by step.")
85
- c_question = c_question + "\nA:"
86
-
87
- corpus.append(c_question)
88
- question.append(c_question)
89
- rationale.append(c_rationale)
90
- pred_ans.append(c_pred_ans)
91
- if args.debug:
92
- gold_ans.append(c_gold_ans)
93
- answer_seg = ""
94
-
95
- corpus_embeddings = encoder.encode(corpus)
96
-
97
- # Perform kmean clustering
98
- clustering_model = KMeans(n_clusters=num_clusters, random_state=args.random_seed)
99
- clustering_model.fit(corpus_embeddings)
100
- cluster_assignment = clustering_model.labels_
101
-
102
- clustered_sentences = [[] for i in range(num_clusters)]
103
-
104
- dist = clustering_model.transform(corpus_embeddings)
105
- clustered_dists = [[] for i in range(num_clusters)]
106
- clustered_idx = [[] for i in range(num_clusters)]
107
- for sentence_id, cluster_id in enumerate(cluster_assignment):
108
- clustered_sentences[cluster_id].append(corpus[sentence_id])
109
- clustered_dists[cluster_id].append(dist[sentence_id][cluster_id])
110
- clustered_idx[cluster_id].append(sentence_id)
111
-
112
- demos = []
113
-
114
- for i in range(len(clustered_dists)):
115
- print("Cluster ", i+1)
116
- tmp = list(map(list, zip(range(len(clustered_dists[i])), clustered_dists[i])))
117
- top_min_dist = sorted(tmp, key=lambda x: x[1], reverse=False)
118
- if not args.sampling == "center":
119
- random.shuffle(top_min_dist)
120
- for element in top_min_dist:
121
- min_idx = element[0]
122
- c_rationale = rationale[clustered_idx[i][min_idx]].strip()
123
- c_pred_ans = pred_ans[clustered_idx[i][min_idx]].strip()
124
-
125
- if len(question[clustered_idx[i][min_idx]].strip().split()) <= 60 \
126
- and len(c_rationale.replace("\n\n", "\n").split("\n")) <= max_ra_len and c_rationale[-1] == "." and c_pred_ans != "":
127
- if args.task in ["gsm8k", "multiarith", "singleeq", "addsub", "svamp"]:
128
- if not (c_pred_ans.strip() in c_rationale.split(".")[-2] or c_pred_ans.strip() in c_rationale.split()[-10:]):
129
- continue
130
- c_question = question[clustered_idx[i][min_idx]]
131
- c_rationale = c_rationale.replace("\n\n", "\n").replace("\n", " ").strip()
132
- c_rationale = " ".join(c_rationale.split())
133
- if args.debug:
134
- c_gold_ans = gold_ans[clustered_idx[i][min_idx]]
135
- else:
136
- c_gold_ans = None
137
- demo_element = {
138
- "question": c_question,
139
- "rationale": c_rationale,
140
- "pred_ans": c_pred_ans,
141
- "gold_ans": c_gold_ans,
142
- }
143
- demos.append(demo_element)
144
- print(c_question)
145
- print(c_rationale)
146
- print(c_pred_ans)
147
- print(c_gold_ans)
148
- print("")
149
- break
150
-
151
- demos = {"demo": demos}
152
-
153
- with open(args.demo_save_dir, 'w', encoding="utf-8") as write_f:
154
- json.dump(demos, write_f, indent=4, ensure_ascii=False)
155
-
156
- y_km = clustering_model.fit_predict(corpus_embeddings)
157
- pca_model = PCA(n_components=2, random_state=args.random_seed)
158
- transformed = pca_model.fit_transform(corpus_embeddings)
159
- centers = pca_model.transform(clustering_model.cluster_centers_)
160
-
161
- plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=y_km, s=50, cmap=plt.cm.Paired, alpha=0.4)
162
- plt.scatter(centers[:, 0],centers[:, 1],
163
- s=250, marker='*', label='centroids',
164
- edgecolor='black',
165
- c=np.arange(0,num_clusters),cmap=plt.cm.Paired,)
166
- plt.xticks([])
167
- plt.yticks([])
168
- plt.savefig(save_file+".png", dpi=600)
169
-
170
- if __name__ == "__main__":
171
- main()