Anni123 committed on
Commit
bb74d0b
·
1 Parent(s): 611e1d3

Delete datatype_sampling.py

Browse files
Files changed (1) hide show
  1. datatype_sampling.py +0 -223
datatype_sampling.py DELETED
@@ -1,223 +0,0 @@
1
- import json
2
- import random
3
- from llm_utils import *
4
-
5
# Maps each enabled dataset key to the location of its question file.
# Iteration order matters: datasets are sampled in this order, so it also
# fixes the order in which the shared random generator is consumed.
# Disabled (kept out of the mapping): "aqua", "bigbench_date",
# "object_tracking".
DATA_PATHS = {
    # arithmetic word problems
    "addsub": "./dataset/AddSub/AddSub.json",
    # symbolic reasoning
    "coin_flip": "./dataset/coin_flip/coin_flip.json",
    # multiple-choice commonsense (key spelling is relied on elsewhere)
    "commonsensqa": "./dataset/CommonsenseQA/dev_rand_split.jsonl",
    "gsm8k": "./dataset/grade-school-math/test.jsonl",
    "last_letters": "./dataset/last_letters/last_letters.json",
    "multiarith": "./dataset/MultiArith/MultiArith.json",
    "strategyqa": "./dataset/StrategyQA/task.json",
    "singleeq": "./dataset/SingleEq/questions.json",
    "svamp": "./dataset/SVAMP/SVAMP.json",
}
19
-
20
-
21
- # https://review-of-my-life.blogspot.com/2017/11/python-dict-shuffle.html
22
def shuffleDict(d):
    """Return a new dict with the items of ``d`` in a random key order.

    The key list is shuffled three times in a row with the module-level
    ``random`` generator, so the amount of random state consumed per call
    stays the same for seeded runs. ``d`` itself is not modified.
    """
    order = list(d.keys())
    for _ in range(3):
        random.shuffle(order)
    return {key: d[key] for key in order}
32
-
33
-
34
def _read_json(path):
    """Load the single JSON document stored at ``path``."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _read_jsonl(path):
    """Decode one JSON value from every non-blank line of ``path``."""
    decoder = json.JSONDecoder()
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing empty line) would make
            # raw_decode raise, so they are skipped.
            if line:
                records.append(decoder.raw_decode(line)[0])
    return records


def _commonsenseqa_question(record):
    """Render a CommonsenseQA record as its stem plus the labelled choices."""
    body = record["question"]
    choices = "".join(
        " (" + c["label"] + ") " + c["text"] for c in body["choices"]
    )
    return body["stem"].strip() + " " + "Answer Choices:" + choices


def _load_questions(data, datapath):
    """Return all question strings of dataset ``data``, or None if unsupported."""
    if data == "gsm8k":
        return [rec["question"].strip() for rec in _read_jsonl(datapath)]
    if data == "commonsensqa":
        return [_commonsenseqa_question(rec) for rec in _read_jsonl(datapath)]
    if data in ("addsub", "multiarith", "singleeq"):
        return [rec["sQuestion"].strip() for rec in _read_json(datapath)]
    if data == "strategyqa":
        return [
            rec["input"].strip() for rec in _read_json(datapath)["examples"]
        ]
    if data == "svamp":
        return [
            rec["Body"].strip() + " " + rec["Question"].strip()
            for rec in _read_json(datapath)
        ]
    if data in ("coin_flip", "last_letters"):
        # These questions are used verbatim (no stripping).
        return [rec["question"] for rec in _read_json(datapath)["examples"]]
    return None


def sample_type_demo(num_type=1, data_paths=None):
    """Sample demo questions from every supported dataset.

    Args:
        num_type: Number of questions to draw (without replacement) from
            each dataset.
        data_paths: Optional mapping of dataset key to question-file path.
            When omitted, the module-level ``DATA_PATHS`` is used.

    Returns:
        A dict mapping each supported dataset key to a list of ``num_type``
        strings of the form ``"Question: <text>\\n"``. Dataset keys with no
        known file layout are left out of the result and their files are
        never opened.

    Raises:
        ValueError: From ``random.sample`` when ``num_type`` is negative or
            exceeds the number of questions in a dataset.
        OSError: When a dataset file cannot be opened.
        KeyError: When a record lacks the field expected for its dataset.
    """
    if data_paths is None:
        data_paths = DATA_PATHS

    all_demo = {}
    for data, datapath in data_paths.items():
        questions = _load_questions(data, datapath)
        if questions is None:
            continue
        # One random.sample call per dataset, in mapping order, so seeded
        # runs draw from the generator in a fixed sequence.
        sampled = random.sample(questions, num_type)
        all_demo[data] = ["Question: " + que + "\n" for que in sampled]
    return all_demo
170
-
171
-
172
def type_for_dataset(dataset_name):
    """Return the question-type label for a dataset key.

    All arithmetic datasets share the label ``"arithmetic"``; the two
    commonsense and the two symbolic datasets each have their own label.
    Any other key yields ``None``.
    """
    arithmetic_sets = ("addsub", "aqua", "gsm8k", "multiarith", "singleeq", "svamp")
    if dataset_name in arithmetic_sets:
        return "arithmetic"

    # Remaining datasets map one-to-one onto a label.
    single_labels = (
        ("commonsensqa", "commonsense-mc"),
        ("strategyqa", "commonsense-verify"),
        ("coin_flip", "symbolic-coin"),
        ("last_letters", "symbolic-letter"),
    )
    for known_name, label in single_labels:
        if dataset_name == known_name:
            return label
    return None
190
-
191
def get_type_prompt(all_demo):
    """Build the few-shot type-identification prompt from sampled demos.

    Args:
        all_demo: Mapping of dataset key to a list of question strings, each
            already formatted as ``"Question: <text>\\n"``.

    Returns:
        One string made of a ``<question>Type: <label>\\n\\n`` entry per
        dataset, in the iteration order of ``all_demo``. Only the first
        question of each dataset is used; datasets with an empty question
        list contribute nothing.

    Raises:
        ValueError: If a dataset key has no known type label.
    """
    demos = []
    for dataset_name, question_strings in all_demo.items():
        if not question_strings:
            # Nothing was sampled for this dataset; skip it instead of
            # failing on the [0] lookup.
            continue
        type_name = type_for_dataset(dataset_name)
        if type_name is None:
            raise ValueError(
                "no question type is defined for dataset %r" % (dataset_name,)
            )
        # NOTE(review): only the first sampled question is used even when
        # more were sampled -- confirm this is intended.
        demos.append(question_strings[0] + "Type: " + type_name + "\n\n")
    return "".join(demos)
198
-
199
def identify_type(question, engine):
    """Ask the language model which question type ``question`` belongs to.

    The contents of ``./demos/type`` are used as the prompt prefix, followed
    by the question and an instruction listing the allowed type labels.

    Args:
        question: The question text to classify.
        engine: Passed through as the ``engine`` keyword of
            ``decoder_for_gpt3``.

    Returns:
        The model response, stripped of surrounding whitespace and
        lower-cased.
    """
    # NOTE(review): the file is read as raw text, while the script section of
    # this module writes it with json.dumps -- confirm the quoted/escaped
    # form is what the model is meant to see.
    with open('./demos/type', 'r') as demo_file:
        prompt = demo_file.read()
    prompt = prompt + (
        "Question: " + question
        + "\nOutput the Type, choosing from <'arithmetic','commonsense-mc','commonsense-verify','symbolic-coin', 'symbolic-letter'>: "
    )
    # decoder_for_gpt3 is not defined in this file (presumably provided by
    # the llm_utils star import); 32 is presumably the token limit -- verify.
    answer = decoder_for_gpt3(prompt, 32, temperature=0, engine=engine)
    return answer.strip().lower()
207
-
208
-
209
def _main():
    """Regenerate ``./demos/type`` and classify one sample question."""
    # Sample one question per dataset and turn the samples into the prompt.
    prompt = get_type_prompt(sample_type_demo(num_type=1))
    print(prompt)

    # The prompt is stored JSON-encoded (quoted, with escaped newlines),
    # followed by a newline.
    with open('./demos/type', 'w') as out:
        out.write(json.dumps(prompt) + "\n")

    # Smoke test: classify a single example question with the stored demos.
    sample_question = "Did the 40th president of the United States forward lolcats to his friends?"
    print(identify_type(sample_question, "text-davinci-003"))


if __name__ == "__main__":
    _main()