|
import json |
|
import random |
|
from llm_utils import * |
|
|
|
# Every dataset file lives under this directory (relative to the working dir).
_DATASET_ROOT = "./dataset"

# Dataset name -> path of the file that holds its questions.
# The names double as the branch selectors in sample_type_demo() and
# type_for_dataset(), so their spelling (including "commonsensqa") matters.
DATA_PATHS = {
    dataset: _DATASET_ROOT + "/" + relative_path
    for dataset, relative_path in (
        ("addsub", "AddSub/AddSub.json"),
        ("coin_flip", "coin_flip/coin_flip.json"),
        ("commonsensqa", "CommonsenseQA/dev_rand_split.jsonl"),
        ("gsm8k", "grade-school-math/test.jsonl"),
        ("last_letters", "last_letters/last_letters.json"),
        ("multiarith", "MultiArith/MultiArith.json"),
        ("strategyqa", "StrategyQA/task.json"),
        ("singleeq", "SingleEq/questions.json"),
        ("svamp", "SVAMP/SVAMP.json"),
    )
}
|
|
|
|
|
|
|
def shuffleDict(d):
    """Return a new dict with the items of ``d`` in random key order.

    ``d`` itself is not modified. The order is drawn from the module-level
    ``random`` generator, so seed it beforehand for reproducible results.

    Args:
        d: Mapping whose items should be reordered.

    Returns:
        A new ``dict`` with the same key/value pairs as ``d``.
    """
    keys = list(d.keys())
    # A single shuffle already gives a uniformly random permutation; the
    # earlier version shuffled three times and built two throw-away lists.
    random.shuffle(keys)
    return {key: d[key] for key in keys}
|
|
|
|
|
def sample_type_demo(num_type=1):
    """Sample demo questions from every dataset listed in ``DATA_PATHS``.

    Datasets are visited in ``DATA_PATHS`` order. For each one that this
    module knows how to read, ``num_type`` questions are drawn without
    replacement using the module-level ``random`` generator.

    Args:
        num_type: Number of questions to sample per dataset.

    Returns:
        Dict mapping dataset name to a list of ``num_type`` strings, each of
        the form ``"Question: <text>\\n"``. Dataset names without a reader
        are left out.

    Raises:
        ValueError: From ``random.sample`` if ``num_type`` is negative or
            larger than the number of questions in a dataset.
        OSError: If a dataset file cannot be opened.
    """
    all_demo = {}
    for data, datapath in DATA_PATHS.items():
        questions = _load_questions(data, datapath)
        if questions is None:
            # No reader for this dataset name; its file is not opened.
            continue
        sampled = random.sample(questions, num_type)
        all_demo[data] = ["Question: " + que + "\n" for que in sampled]
    return all_demo


def _read_jsonl(datapath):
    """Return the first JSON value found on each line of ``datapath``."""
    decoder = json.JSONDecoder()
    with open(datapath, encoding="utf-8") as f:
        # raw_decode() parses the value at the start of the line and ignores
        # whatever follows it (such as the trailing newline).
        return [decoder.raw_decode(line)[0] for line in f]


def _read_json(datapath):
    """Return the parsed content of the JSON file at ``datapath``."""
    with open(datapath, encoding="utf-8") as f:
        return json.load(f)


def _format_choices(choices):
    """Render answer choices as ``Answer Choices: (label) text (label) text``."""
    return "Answer Choices:" + "".join(
        " (" + c["label"] + ") " + c["text"] for c in choices
    )


def _load_questions(data, datapath):
    """Return all question strings of dataset ``data``, or None if unsupported."""
    if data == "gsm8k":
        return [rec["question"].strip() for rec in _read_jsonl(datapath)]
    if data == "commonsensqa":
        # Multiple choice: the stem is followed by the labelled options.
        return [
            rec["question"]["stem"].strip()
            + " "
            + _format_choices(rec["question"]["choices"])
            for rec in _read_jsonl(datapath)
        ]
    if data in ("addsub", "multiarith", "singleeq"):
        return [rec["sQuestion"].strip() for rec in _read_json(datapath)]
    if data == "strategyqa":
        return [rec["input"].strip() for rec in _read_json(datapath)["examples"]]
    if data == "svamp":
        # The problem statement is split into a body and the actual question.
        return [
            rec["Body"].strip() + " " + rec["Question"].strip()
            for rec in _read_json(datapath)
        ]
    if data in ("coin_flip", "last_letters"):
        # These questions are used verbatim (no stripping).
        return [rec["question"] for rec in _read_json(datapath)["examples"]]
    return None
|
|
|
|
|
def type_for_dataset(dataset_name):
    """Map a dataset name to its question-type label.

    Args:
        dataset_name: Name of a dataset, e.g. ``"gsm8k"``.

    Returns:
        ``"arithmetic"``, ``"commonsense-mc"``, ``"commonsense-verify"``,
        ``"symbolic-coin"`` or ``"symbolic-letter"``; ``None`` when the name
        is not recognised.
    """
    arithmetic_sets = ("addsub", "aqua", "gsm8k", "multiarith", "singleeq", "svamp")
    if dataset_name in arithmetic_sets:
        return "arithmetic"
    if dataset_name == "commonsensqa":
        return "commonsense-mc"
    if dataset_name == "strategyqa":
        return "commonsense-verify"
    if dataset_name == "coin_flip":
        return "symbolic-coin"
    if dataset_name == "last_letters":
        return "symbolic-letter"
    # Unknown dataset: callers receive None.
    return None
|
|
|
def get_type_prompt(all_demo):
    """Build the few-shot prompt that pairs demo questions with their type.

    Only the first question of each dataset is used. Each demo has the form
    ``<question string>Type: <type label>\\n\\n`` and the demos are
    concatenated in the iteration order of ``all_demo``.

    Args:
        all_demo: Dict mapping dataset name to a sequence of question
            strings (each expected to end with a newline already).

    Returns:
        The concatenated prompt; an empty string when there is nothing to
        show.

    Raises:
        TypeError: If ``type_for_dataset`` returns ``None`` for a dataset
            name, because ``None`` cannot be concatenated to the demo text.
    """
    demos = []
    for dataset_name, question_strings in all_demo.items():
        if not question_strings:
            # Nothing was sampled for this dataset, so there is no demo to
            # emit; indexing [0] would raise IndexError here.
            continue
        demos.append(
            question_strings[0] + "Type: " + type_for_dataset(dataset_name) + "\n\n"
        )
    # Join whole demos; the earlier list += str added one character at a time.
    return "".join(demos)
|
|
|
def identify_type(question, engine):
    """Ask the language model which question type ``question`` belongs to.

    The few-shot demos are read from ``./demos/type``, the new question and
    the answer instruction are appended, and the result is sent through
    ``decoder_for_gpt3``.

    Args:
        question: The question text to classify.
        engine: Engine name forwarded to ``decoder_for_gpt3``.

    Returns:
        The model response, stripped of surrounding whitespace and
        lower-cased.

    Raises:
        OSError: If ``./demos/type`` cannot be opened.
    """
    with open('./demos/type', 'r', encoding="utf-8") as f:
        raw_demo = f.read()

    # This module's __main__ section stores the demos with json.dumps(), i.e.
    # as a quoted JSON string with escaped newlines. Decode that form so the
    # model sees real line breaks; a plain-text file is used unchanged.
    try:
        typedemo = json.loads(raw_demo)
    except ValueError:
        typedemo = raw_demo
    if not isinstance(typedemo, str):
        # Valid JSON that is not a string (e.g. a number): treat as raw text.
        typedemo = raw_demo

    typedemo += "Question: " + question + "\nOutput the Type, choosing from <'arithmetic','commonsense-mc','commonsense-verify','symbolic-coin', 'symbolic-letter'>: "
    # NOTE(review): decoder_for_gpt3 comes from llm_utils; the second
    # argument is presumably the max token count -- confirm there.
    response = decoder_for_gpt3(typedemo, 32, temperature=0, engine=engine)
    return response.strip().lower()
|
|
|
|
|
if __name__ == "__main__":
    # Sample one demo question per dataset and render the type prompt.
    prompt = get_type_prompt(sample_type_demo(num_type=1))
    print(prompt)

    # Store the prompt as a JSON-encoded string followed by a newline.
    with open('./demos/type', 'w') as demo_file:
        demo_file.write(json.dumps(prompt) + "\n")

    # Smoke test: classify one sample question with the stored demos.
    sample_question = "Did the 40th president of the United States forward lolcats to his friends?"
    res = identify_type(sample_question, "text-davinci-003")
    print(res)
|
|