"""Build and use a few-shot prompt that classifies a question's reasoning type.

Run as a script, this module samples one question per dataset listed in
``DATA_PATHS``, labels it with the dataset's reasoning type, stores the
resulting prompt in ``./demos/type`` and classifies one example question.
"""
import json
import random
from llm_utils import *

DATA_PATHS = {
    "addsub": "./dataset/AddSub/AddSub.json",
    #"aqua": "./dataset/AQuA/test.json",
    #"bigbench_date": "./dataset/Bigbench_Date/task.json",
    #"object_tracking": "./dataset/Bigbench_object_tracking/task.json",
    "coin_flip": "./dataset/coin_flip/coin_flip.json",
    "commonsensqa": "./dataset/CommonsenseQA/dev_rand_split.jsonl",
    "gsm8k": "./dataset/grade-school-math/test.jsonl",
    "last_letters": "./dataset/last_letters/last_letters.json",
    "multiarith": "./dataset/MultiArith/MultiArith.json",
    "strategyqa": "./dataset/StrategyQA/task.json",
    "singleeq": "./dataset/SingleEq/questions.json",
    "svamp": "./dataset/SVAMP/SVAMP.json",
}

# File holding the few-shot type-classification prompt.  The script entry
# point writes it as a JSON-encoded string; identify_type() reads it back.
_TYPE_DEMO_PATH = './demos/type'

# Dataset name -> reasoning-type label used in the classification prompt.
_DATASET_TYPES = {
    "addsub": "arithmetic",
    "aqua": "arithmetic",
    "gsm8k": "arithmetic",
    "multiarith": "arithmetic",
    "singleeq": "arithmetic",
    "svamp": "arithmetic",
    "commonsensqa": "commonsense-mc",
    "strategyqa": "commonsense-verify",
    "coin_flip": "symbolic-coin",
    "last_letters": "symbolic-letter",
}


# https://review-of-my-life.blogspot.com/2017/11/python-dict-shuffle.html
def shuffleDict(d):
    """Return a new dict with the items of *d* in random key order.

    *d* itself is not modified.
    """
    keys = list(d.keys())
    # The original shuffled three times; the count is kept so that runs with
    # a fixed random seed keep producing the same ordering.
    for _ in range(3):
        random.shuffle(keys)
    return {key: d[key] for key in keys}


def _read_json(path):
    """Parse the whole file at *path* as one JSON document."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _read_jsonl(path):
    """Parse the file at *path* as JSON Lines and return the list of records.

    Blank lines are skipped; anything following the JSON value on a line is
    ignored (``raw_decode`` semantics, as in the original implementation).
    """
    decoder = json.JSONDecoder()
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            records.append(decoder.raw_decode(line)[0])
    return records


def _load_aqua(path):
    """Load AQuA questions with their answer options appended.

    Inactive while the "aqua" entry of DATA_PATHS stays commented out.
    """
    questions = []
    for record in _read_jsonl(path):
        choice = "(" + "(".join(record["options"])
        choice = choice.replace("(", " (").replace(")", ") ")
        choice = "Answer Choices:" + choice
        questions.append(record["question"].strip() + " " + choice)
    return questions


def _load_gsm8k(path):
    """Load the question text of every GSM8K record."""
    return [record["question"].strip() for record in _read_jsonl(path)]


def _load_commonsenseqa(path):
    """Load CommonsenseQA question stems followed by their labelled choices."""
    questions = []
    for record in _read_jsonl(path):
        choice = "Answer Choices:"
        for c in record["question"]["choices"]:
            choice += " (" + c["label"] + ") " + c["text"]
        questions.append(record["question"]["stem"].strip() + " " + choice)
    return questions


def _load_squestion(path):
    """Load questions stored under "sQuestion" (AddSub, MultiArith, SingleEq)."""
    return [record["sQuestion"].strip() for record in _read_json(path)]


def _load_strategyqa(path):
    """Load the "input" field of every StrategyQA example."""
    return [example["input"].strip() for example in _read_json(path)["examples"]]


def _load_svamp(path):
    """Load SVAMP problems as "<Body> <Question>"."""
    return [
        record["Body"].strip() + " " + record["Question"].strip()
        for record in _read_json(path)
    ]


def _load_examples_question(path):
    """Load the "question" field of every example (coin_flip, last_letters)."""
    return [example["question"] for example in _read_json(path)["examples"]]


# Dataset name -> function that reads all of its questions from a path.
_QUESTION_LOADERS = {
    "aqua": _load_aqua,
    "gsm8k": _load_gsm8k,
    "commonsensqa": _load_commonsenseqa,
    "addsub": _load_squestion,
    "multiarith": _load_squestion,
    "singleeq": _load_squestion,
    "strategyqa": _load_strategyqa,
    "svamp": _load_svamp,
    "coin_flip": _load_examples_question,
    "last_letters": _load_examples_question,
}


def sample_type_demo(num_type=1):
    """Randomly sample *num_type* questions from every known dataset.

    Datasets are visited in ``DATA_PATHS`` order; entries without a loader
    are skipped silently.

    Returns:
        A dict mapping each dataset name to a list of *num_type* strings of
        the form ``"Question: <text>\\n"``.

    Raises:
        ValueError: from ``random.sample`` when a dataset holds fewer than
            *num_type* questions.
    """
    all_demo = {}
    for data, datapath in DATA_PATHS.items():
        loader = _QUESTION_LOADERS.get(data)
        if loader is None:
            continue
        questions = random.sample(loader(datapath), num_type)
        all_demo[data] = ["Question: " + que + "\n" for que in questions]
    return all_demo


def type_for_dataset(dataset_name):
    """Return the reasoning-type label of *dataset_name*, or None if unknown."""
    return _DATASET_TYPES.get(dataset_name)


def get_type_prompt(all_demo):
    """Build the few-shot prompt from the first sampled question per dataset.

    Args:
        all_demo: mapping as returned by ``sample_type_demo``.

    Returns:
        The concatenation of ``"<question>Type: <type>\\n\\n"`` for every
        dataset, in the iteration order of *all_demo*.

    Raises:
        ValueError: if a dataset name has no known reasoning type.
    """
    demos = []
    for dataset_name, question_strings in all_demo.items():
        if not question_strings:
            # Nothing was sampled for this dataset, so there is no demo.
            continue
        dataset_type = type_for_dataset(dataset_name)
        if dataset_type is None:
            raise ValueError("no reasoning type known for dataset %r" % (dataset_name,))
        demos.append(question_strings[0] + "Type: " + dataset_type + "\n\n")
    return "".join(demos)


def _load_type_demo(path=_TYPE_DEMO_PATH):
    """Read the stored type-classification demos as plain prompt text.

    The script entry point stores the prompt as a JSON-encoded string, so it
    is decoded here; otherwise the quotes and literal ``\\n`` escapes would
    leak into the prompt.  A file that is not a JSON string (for example a
    hand-written plain-text prompt) is returned unchanged.
    """
    with open(path, 'r', encoding="utf-8") as f:
        raw = f.read()
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return raw
    return decoded if isinstance(decoded, str) else raw


def identify_type(question, engine):
    """Ask the language model which reasoning type *question* belongs to.

    Args:
        question: the question text to classify.
        engine: model name, passed through to ``decoder_for_gpt3``.

    Returns:
        The model response, stripped and lower-cased.
    """
    typedemo = _load_type_demo()
    typedemo += "Question: " + question + "\nOutput the Type, choosing from <'arithmetic','commonsense-mc','commonsense-verify','symbolic-coin', 'symbolic-letter'>: "
    # decoder_for_gpt3 is not defined in this file; presumably it comes from
    # the llm_utils star import.
    response = decoder_for_gpt3(typedemo, 32, temperature=0, engine=engine)
    return response.strip().lower()


def _main():
    """Regenerate the type demos file, then classify one example question."""
    all_demo = sample_type_demo(num_type=1)
    total_prompt = get_type_prompt(all_demo)
    print(total_prompt)
    with open(_TYPE_DEMO_PATH, 'w', encoding="utf-8") as f:
        # Stored JSON-encoded; _load_type_demo() decodes it again.
        f.write(json.dumps(total_prompt) + "\n")

    question = "Did the 40th president of the United States forward lolcats to his friends?"
    engine = "text-davinci-003"
    res = identify_type(question, engine)
    print(res)


if __name__ == "__main__":
    _main()