# Source: AuRoRA / datatype_sampling.py (uploaded by Anni123 using huggingface_hub, commit b6a1d8d, 8.61 kB)
import json
import random
from llm_utils import *
# Maps each dataset key to the file its questions are loaded from.
# Commented-out pairs are datasets that are currently disabled.  The pair
# order is the iteration order of the resulting mapping.
DATA_PATHS = dict(
    (
        ("addsub", "./dataset/AddSub/AddSub.json"),
        # ("aqua", "./dataset/AQuA/test.json"),
        # ("bigbench_date", "./dataset/Bigbench_Date/task.json"),
        # ("object_tracking", "./dataset/Bigbench_object_tracking/task.json"),
        ("coin_flip", "./dataset/coin_flip/coin_flip.json"),
        ("commonsensqa", "./dataset/CommonsenseQA/dev_rand_split.jsonl"),
        ("gsm8k", "./dataset/grade-school-math/test.jsonl"),
        ("last_letters", "./dataset/last_letters/last_letters.json"),
        ("multiarith", "./dataset/MultiArith/MultiArith.json"),
        ("strategyqa", "./dataset/StrategyQA/task.json"),
        ("singleeq", "./dataset/SingleEq/questions.json"),
        ("svamp", "./dataset/SVAMP/SVAMP.json"),
    )
)
# https://review-of-my-life.blogspot.com/2017/11/python-dict-shuffle.html
def shuffleDict(d):
    """Return a new dict holding the items of *d* in a random key order.

    The input mapping is not modified.  Ordering is driven by the global
    ``random`` module state, so seeding ``random`` makes the result
    reproducible.

    Args:
        d: Mapping whose items should be reordered.

    Returns:
        A new ``dict`` with the same key/value pairs as *d*, inserted in
        shuffled key order.
    """
    keys = list(d.keys())
    # The key list has always been shuffled three times in a row.  One pass
    # would be just as random, but keeping three consumes the RNG stream the
    # same way, so seeded runs keep producing the same order.
    for _ in range(3):
        random.shuffle(keys)
    return {key: d[key] for key in keys}
def _read_json(datapath):
    """Parse the whole file at *datapath* as a single JSON document."""
    # JSON interchange files are UTF-8; do not depend on the platform locale.
    with open(datapath, encoding="utf-8") as f:
        return json.load(f)


def _read_json_lines(datapath):
    """Parse *datapath* as JSON Lines: one decoded value per non-blank line."""
    decoder = json.JSONDecoder()
    records = []
    with open(datapath, encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line would make raw_decode raise.
                continue
            # raw_decode returns (value, end_index) and tolerates extra text
            # after the first JSON value, which json.loads would reject.
            records.append(decoder.raw_decode(stripped)[0])
    return records


def _load_questions(data, datapath):
    """Return every question text of dataset *data*, or None if unsupported.

    Each dataset file has its own layout, so extraction is chosen by the
    dataset key used in ``DATA_PATHS``.
    """
    if data == "aqua":
        # Only reached when the "aqua" entry of DATA_PATHS is enabled.
        questions = []
        for record in _read_json_lines(datapath):
            choice = "(" + "(".join(record["options"])
            choice = choice.replace("(", " (").replace(")", ") ")
            choice = "Answer Choices:" + choice
            questions.append(record["question"].strip() + " " + choice)
        return questions
    if data == "gsm8k":
        return [
            record["question"].strip()
            for record in _read_json_lines(datapath)
        ]
    if data == "commonsensqa":
        questions = []
        for record in _read_json_lines(datapath):
            choice = "Answer Choices:" + "".join(
                " (" + c["label"] + ") " + c["text"]
                for c in record["question"]["choices"]
            )
            questions.append(record["question"]["stem"].strip() + " " + choice)
        return questions
    if data in ("addsub", "multiarith", "singleeq"):
        return [record["sQuestion"].strip() for record in _read_json(datapath)]
    if data == "strategyqa":
        return [
            record["input"].strip()
            for record in _read_json(datapath)["examples"]
        ]
    if data == "svamp":
        return [
            record["Body"].strip() + " " + record["Question"].strip()
            for record in _read_json(datapath)
        ]
    if data in ("coin_flip", "last_letters"):
        return [record["question"] for record in _read_json(datapath)["examples"]]
    return None


def sample_type_demo(num_type=1):
    """Sample demonstration questions from every dataset in ``DATA_PATHS``.

    Args:
        num_type: Number of questions to draw, without replacement, from
            each dataset.

    Returns:
        A dict mapping each supported dataset key to a list of *num_type*
        strings of the form ``"Question: <text>\\n"``.  Dataset keys with no
        known file layout are skipped.

    Raises:
        ValueError: If *num_type* is negative or larger than the number of
            questions in a dataset (raised by ``random.sample``).
        OSError: If a dataset file cannot be opened.
    """
    all_demo = {}
    for data, datapath in DATA_PATHS.items():
        questions = _load_questions(data, datapath)
        if questions is None:
            continue
        sampled = random.sample(questions, num_type)
        all_demo[data] = ["Question: " + que + "\n" for que in sampled]
    return all_demo
# Reasoning-type label for each known dataset key.
# A coarser labelling is kept here for reference but is disabled:
#   commonsensqa / strategyqa -> "commonsense"
#   coin_flip / last_letters  -> "symbolic"
_DATASET_TYPES = {
    "addsub": "arithmetic",
    "aqua": "arithmetic",
    "gsm8k": "arithmetic",
    "multiarith": "arithmetic",
    "singleeq": "arithmetic",
    "svamp": "arithmetic",
    "commonsensqa": "commonsense-mc",
    "strategyqa": "commonsense-verify",
    "coin_flip": "symbolic-coin",
    "last_letters": "symbolic-letter",
}


def type_for_dataset(dataset_name):
    """Return the reasoning-type label for *dataset_name*.

    Args:
        dataset_name: Dataset key, e.g. ``"gsm8k"`` or ``"coin_flip"``.

    Returns:
        One of ``"arithmetic"``, ``"commonsense-mc"``,
        ``"commonsense-verify"``, ``"symbolic-coin"`` or
        ``"symbolic-letter"``; ``None`` when the dataset key is not known.
    """
    return _DATASET_TYPES.get(dataset_name)
def get_type_prompt(all_demo):
    """Build the few-shot type prompt from sampled demonstration questions.

    For every dataset, the first question string is followed by a
    ``"Type: <label>"`` line and a blank line.  Only the first question of
    each dataset is used, however many were sampled.

    Args:
        all_demo: Dict mapping dataset key to a list of question strings,
            each already formatted as ``"Question: <text>\\n"``.

    Returns:
        The concatenated prompt text; an empty string when there is nothing
        to format.

    Raises:
        TypeError: If a dataset key has no type label, i.e.
            ``type_for_dataset`` returns ``None`` for it.
    """
    demos = []
    for dataset_name, question_strings in all_demo.items():
        if not question_strings:
            # Nothing was sampled for this dataset; there is no demo to emit.
            continue
        # Append whole demo strings.  `list += str` would extend the list one
        # character at a time.
        demos.append(
            question_strings[0]
            + "Type: "
            + type_for_dataset(dataset_name)
            + "\n\n"
        )
    return "".join(demos)
def identify_type(question, engine, demo_path='./demos/type'):
    """Ask the language model which reasoning type *question* belongs to.

    The few-shot demonstrations are read from *demo_path*, the question and
    the list of allowed type labels are appended, and the model's answer is
    returned stripped and lower-cased.

    Args:
        question: Question text to classify.
        engine: Engine name forwarded to ``decoder_for_gpt3``.
        demo_path: File holding the demonstration prompt.  It may contain
            either plain text or a single JSON-encoded string.

    Returns:
        The model response with surrounding whitespace removed, lower-cased.

    Raises:
        OSError: If *demo_path* cannot be opened.
    """
    with open(demo_path, 'r', encoding='utf-8') as f:
        typedemo = f.read()

    # The script section of this module stores the prompt with json.dumps, so
    # the raw file text is a quoted string with escaped newlines.  Decode it
    # back to the real prompt; plain-text files are used unchanged.
    try:
        decoded = json.loads(typedemo)
    except ValueError:
        decoded = None
    if isinstance(decoded, str):
        typedemo = decoded

    typedemo += "Question: " + question + "\nOutput the Type, choosing from <'arithmetic','commonsense-mc','commonsense-verify','symbolic-coin', 'symbolic-letter'>: "
    # decoder_for_gpt3 is presumably supplied by `from llm_utils import *`;
    # the positional 32 looks like a max-token limit -- TODO confirm.
    response = decoder_for_gpt3(typedemo, 32, temperature=0, engine=engine)
    return response.strip().lower()
if __name__ == "__main__":
    # Build a fresh type prompt from one randomly sampled question per dataset.
    all_demo = sample_type_demo(num_type=1)
    total_prompt = get_type_prompt(all_demo)
    print(total_prompt)

    # NOTE: the prompt is stored as a JSON-encoded string (surrounding quotes,
    # escaped newlines) followed by a newline, not as plain text.
    with open('./demos/type', 'w') as f:
        data_json = json.dumps(total_prompt)
        f.write(data_json + "\n")

    # Smoke test: classify one sample question with the prompt just written.
    question = "Did the 40th president of the United States forward lolcats to his friends?"
    engine = "text-davinci-003"
    res = identify_type(question, engine)
    print(res)