"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets""" |
|
|
|
import torch |
|
import argparse |
|
from nltk import word_tokenize |
|
from tqdm import tqdm |
|
import numpy as np |
|
import json |
|
|
|


def get_args():
    parser = argparse.ArgumentParser(description="Preprocessing")

    parser.add_argument("--func", type=str, default=None,
                        help="which function to run")
    parser.add_argument("--raw_file", type=str, default=None,
                        help="path of the input file")
    parser.add_argument("--processed_file", type=str, default=None,
                        help="path of the output file")
    parser.add_argument("--knwl_ref_file", type=str, default=None,
                        help="path of the knowledge reference file")
    parser.add_argument("--resp_ref_file", type=str, default=None,
                        help="path of the response reference file")
    parser.add_argument("--knwl_gen_file", type=str, default=None,
                        help="path of the generated knowledge file")
    parser.add_argument("--test_file", type=str, default=None,
                        help="path of the test file")
    parser.add_argument("--train_file", type=str, default=None,
                        help="path of the train file")
    parser.add_argument("--model_file", type=str, default=None,
                        help="path of the model file")
    parser.add_argument("--data_type", type=str, default=None,
                        help="data type, one of: wow_seen, wow_unseen, woi")
    parser.add_argument("--seed", type=int, default=1234,
                        help="random seed")

    args = parser.parse_args()
    return args


def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """
    Process the Wizard of Wikipedia (WoW) dataset.
    Expected processed format:
        topic \t dialogue context \t golden knowledge \t golden response
    """
    print("> Loading data from %s" % raw_file)
    with open(raw_file, "r") as fr:
        dialog_data = json.load(fr)

    print("> Processing data ...")
    fproc = open(processed_file, "w")
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    for i, sample in enumerate(tqdm(dialog_data)):
        # each sample is one dialogue between the wizard and the apprentice
        dialog = sample["dialog"]

        turn_list = []  # accumulated dialogue history
        for j, turn in enumerate(dialog):
            text = turn["text"]
            if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
                text = text + "."

            if j == 0:
                # the first turn only opens the dialogue; no example is produced
                turn_list.append(text)
                continue

            speaker = turn["speaker"].lower()
            if "wizard" in speaker:
                checked_sentence = list(turn["checked_sentence"].values())
                checked_passage = list(turn["checked_passage"].values())

                assert len(checked_sentence) <= 1

                # get the golden knowledge sentence checked by the wizard
                if len(checked_sentence) > 0:
                    checked_sentence = checked_sentence[0]
                else:
                    checked_sentence = "no_passages_used"

                if len(checked_passage) == 1:
                    checked_passage = checked_passage[0]
                else:
                    checked_passage = "no_passages_used"

                # use the checked passage as the topic when available;
                # otherwise fall back to the chosen topic of the dialogue
                if checked_passage != "no_passages_used":
                    topic = checked_passage
                else:
                    topic = sample["chosen_topic"]

                dialog_context = " [SEP] ".join(turn_list)
                knowledge = checked_sentence
                response = text

                turn_list.append(response)

                # write one example per wizard turn
                fproc.write(topic + "\t" + dialog_context + "\t" +
                            knowledge + "\t" + response + "\n")

                if fknwl:
                    fknwl.write(knowledge + "\n")
                if fresp:
                    # tokenize the reference response for evaluation
                    response = " ".join(word_tokenize(response))
                    fresp.write(response + "\n")

            else:
                assert "apprentice" in speaker
                turn_list.append(text)

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()


def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """
    Process the Wizard of Internet (WoI) dataset.
    Expected processed format:
        topic \t dialogue context \t golden knowledge \t golden response
    """
    print("> Processing %s" % raw_file)
    fproc = open(processed_file, "w")
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    with open(raw_file, "r") as fr:
        for i, line in tqdm(enumerate(fr)):
            line = line.strip()
            item_dict = json.loads(line)

            # each line is keyed by a single dialogue id; take its value
            item_dict = list(item_dict.values())[0]

            dialog_data = item_dict['dialog_history']
            length = len(dialog_data)

            turn_list = []
            search_text = ""
            for j in range(length):
                item = dialog_data[j]
                action = item['action']

                if action == "Wizard => SearchAgent":
                    # remember the latest search query; it serves as the topic
                    search_text = item['text']

                elif action == "Wizard => Apprentice":
                    if len(turn_list) == 0:
                        # the first wizard turn only opens the dialogue
                        turn = item['text']
                        turn_list.append(turn)
                        continue

                    # "contents" holds the retrieved passages and
                    # "selected_contents" marks the sentences the wizard
                    # selected; the leading flag indicates "no passages used"
                    contents = item["context"]["contents"]
                    selects = item["context"]["selected_contents"]
                    flag = selects[0][0]
                    selects = selects[1:]
                    assert len(selects) == len(contents)

                    if flag:
                        topic = "no_topic"
                        knwl_sent = "no_passages_used"
                    else:
                        topic = search_text
                        # pick out a selected sentence as the golden knowledge
                        knwl_sent = ""
                        for content, select in zip(contents, selects):
                            content = content['content']
                            assert len(content) == len(select)
                            for c, s in zip(content, select):
                                if s:
                                    knwl_sent = c
                                    break

                        if knwl_sent == "":
                            topic = "no_topic"
                            knwl_sent = "no_passages_used"

                    dialog_context = " [SEP] ".join(turn_list)
                    response = item['text']

                    # strip newlines, carriage returns, and tabs so that each
                    # example stays on a single tab-separated line
                    topic = topic.replace("\n", "").replace("\r", "").replace("\t", "")
                    dialog_context = dialog_context.replace("\n", "").replace("\r", "").replace("\t", "")
                    knwl_sent = knwl_sent.replace("\n", "").replace("\r", "").replace("\t", "")
                    response = response.replace("\n", "").replace("\r", "").replace("\t", "")

                    # only keep examples with golden knowledge
                    if topic != "no_topic":
                        fproc.write(topic + "\t" + dialog_context + "\t" +
                                    knwl_sent + "\t" + response + "\n")
                        if fknwl:
                            fknwl.write(knwl_sent + "\n")
                        if fresp:
                            # tokenize a copy for the reference file; keep the
                            # original response text in the dialogue history
                            fresp.write(" ".join(word_tokenize(response)) + "\n")

                    turn_list.append(response)

                elif action == "Apprentice => Wizard":
                    turn = item['text']
                    turn_list.append(turn)

                else:
                    assert action == "SearchAgent => Wizard", \
                        "Please check whether you have used the correct data!"

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()


def get_database(test_datapath, train_datapath, data_type):
    """Get the database of dialogue examples, grouped by topic"""
    assert data_type in ["wow_seen", "wow_unseen", "woi"], \
        "Please input a correct data type!"

    print("> reading test data from %s" % test_datapath)
    test_topics = {}
    with open(test_datapath, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0]
            test_topics[topic] = True

    print("> reading data from %s" % train_datapath)
    train_data_by_topic = {}
    dialog_data_by_topic = {}
    dialog_examples = []
    with open(train_datapath, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]
            knowledge = splits[2]
            response = splits[3]

            # skip examples without golden knowledge, and (for wow_unseen and
            # woi) examples whose knowledge contains parentheses or does not
            # mention the topic
            if knowledge == "no_passages_used":
                continue
            if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
                continue
            if data_type != "wow_seen" and topic not in knowledge:
                continue

            # construct the knowledge-generation instance
            last_turn = turns[-1]
            instance = "( " + last_turn + " ) " + topic + " => " + knowledge

            # construct the dialogue example used as the retrieval key
            dialog_example = ""
            if data_type != "wow_seen":
                dialog_example += "( " + topic + " ) "
            for j, turn in enumerate(turns):
                if j != 0:
                    dialog_example += " "
                dialog_example += turn

            if topic in test_topics:
                # group the examples whose topic also appears in the test set
                if topic not in train_data_by_topic:
                    train_data_by_topic[topic] = [instance]
                else:
                    train_data_by_topic[topic].append(instance)

                if topic not in dialog_data_by_topic:
                    dialog_data_by_topic[topic] = [dialog_example]
                else:
                    dialog_data_by_topic[topic].append(dialog_example)

            else:
                # for the global pool, drop long or pronoun-initial knowledge
                # sentences
                if len(knowledge.split()) > 20:
                    continue
                if knowledge.startswith("It") or knowledge.startswith("it") or \
                        knowledge.startswith("This") or knowledge.startswith("this"):
                    continue

                dialog_examples.append((topic, dialog_example, instance))

    return train_data_by_topic, dialog_data_by_topic, dialog_examples


# cache of example embeddings, keyed by topic (shared across calls)
emb_dict = {}


def select_prompts_based_on_similarity(
        query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
    """Select prompt examples based on their similarity to the query"""
    with torch.no_grad():
        # embed the query
        query_ids = tokenizer.encode(query)
        query_ids = torch.LongTensor([query_ids]).cuda()
        query_emb = encoder(input_ids=query_ids).pooler_output
        query_emb = query_emb[0]

        # embed the candidate dialogue examples (cached per topic)
        if topic in emb_dict:
            example_embeddings = emb_dict[topic]
            example_embeddings = example_embeddings.cuda()
        else:
            for idx, example in enumerate(dialog_list):
                example_ids = tokenizer.encode(example)
                example_ids = torch.LongTensor([example_ids]).cuda()
                example_emb = encoder(input_ids=example_ids).pooler_output
                if idx == 0:
                    example_embeddings = example_emb
                else:
                    example_embeddings = torch.cat(
                        (example_embeddings, example_emb), dim=0)
            emb_dict[topic] = example_embeddings.cpu()

        # rank the candidates by inner-product similarity
        similarity_list = example_embeddings.matmul(query_emb)
        _, indices = torch.topk(similarity_list, k=topk)

        # reverse so that the most similar example comes last in the prompt
        indices = indices.tolist()
        indices = indices[::-1]
        selected_prompts = []
        for index in indices:
            selected_prompts.append(prompt_list[index])

    return selected_prompts


def prompt_selection_for_knowledge_generation(
        test_datapath, train_datapath, model_path, output_prompt_path, data_type):
    """Select prompt examples for the knowledge generation"""
    print("> Selecting prompts for the knowledge generation")

    train_data_by_topic, dialog_data_by_topic, dialog_examples = \
        get_database(test_datapath, train_datapath, data_type)

    from transformers import DPRQuestionEncoderTokenizer
    print("> loading tokenizer and encoder")
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        'facebook/dpr-question_encoder-single-nq-base')
    encoder = torch.load(model_path).cuda()

    print("> getting dialog embeddings")
    with torch.no_grad():
        for idx, example in tqdm(enumerate(dialog_examples)):
            dialog = example[1]
            dialog_ids = tokenizer.encode(dialog)
            dialog_ids = torch.LongTensor([dialog_ids]).cuda()
            dialog_emb = encoder(input_ids=dialog_ids).pooler_output

            if idx == 0:
                dialog_embeddings = dialog_emb
            else:
                dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)

    print("> reading test data from %s" % test_datapath)
    prompt_list_for_each_sample = []
    with open(test_datapath, "r") as f:
        for i, line in tqdm(enumerate(f)):
            line = line.strip()

            splits = line.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]

            # construct the query sentence in the same format as the dialogue
            # examples in the database
            query_sent = ""
            if data_type != "wow_seen":
                query_sent += "( " + topic + " ) "
            for j, turn in enumerate(turns):
                if j != 0:
                    query_sent += " "
                query_sent += turn

            if topic not in train_data_by_topic:
                # unseen topic: retrieve from the global example pool
                query_ids = tokenizer.encode(query_sent)
                query_ids = torch.LongTensor([query_ids]).cuda()
                query_emb = encoder(input_ids=query_ids).pooler_output
                query_emb = query_emb[0]

                # rank all examples by inner-product similarity; torch.sort is
                # ascending, so iterate in reverse to go from most to least
                # similar
                similarity_list = dialog_embeddings.matmul(query_emb)
                _, indices = torch.sort(similarity_list)
                indices = indices.tolist()
                selected_topics = {}
                selected_prompts = []
                num_prompt = 0
                for index in indices[::-1]:
                    example = dialog_examples[index]
                    topic_temp = example[0]
                    # keep at most one example per topic
                    if topic_temp not in selected_topics:
                        selected_topics[topic_temp] = True
                        selected_prompts.append(example[2])
                        num_prompt += 1
                        if num_prompt == 10:
                            break

                # reverse so the most similar example comes last in the prompt
                example_list = selected_prompts[::-1]
                key = topic + " " + turns[-1]
                prompt_list_for_each_sample.append({key: example_list})

            else:
                # seen topic: retrieve from the examples of the same topic
                num_data_sample = min(len(train_data_by_topic[topic]), 10)
                total_example_list = train_data_by_topic[topic]

                dialog_list = dialog_data_by_topic[topic]
                assert len(dialog_list) == len(train_data_by_topic[topic])

                example_list = select_prompts_based_on_similarity(
                    query_sent, dialog_list, total_example_list,
                    topic, tokenizer, encoder, topk=num_data_sample)

                key = topic + " " + turns[-1]
                prompt_list_for_each_sample.append({key: example_list})

    print("writing to %s" % output_prompt_path)
    with open(output_prompt_path, "w") as f:
        for instance in tqdm(prompt_list_for_each_sample):
            json.dump(instance, f)
            f.write("\n")


def prompt_selection_for_response_generation(input_path, output_path, seed):
    """Select prompt examples for the response generation"""
    print("> Selecting prompts for the response generation")
    print("> set random seed")
    np.random.seed(seed)

    prompt_example_list = []
    print("> reading data from %s" % input_path)
    with open(input_path, "r") as f:
        for i, line in tqdm(enumerate(f)):
            line = line.strip()
            splits = line.split("\t")

            # skip examples without golden knowledge
            topic = splits[0]
            dialog_context = splits[1]
            knowledge = splits[2]
            response = splits[3]
            turns = dialog_context.split(" [SEP] ")[-3:]
            if knowledge == "no_passages_used":
                continue

            # count the response tokens that fall in contiguous runs of at
            # least 10 tokens that all appear in the knowledge sentence
            knowledge_sent_token_list = word_tokenize(knowledge)
            knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
            knowledge_len = len(knowledge_sent_token_list)
            response_token_list = word_tokenize(response)
            response_len = len(response_token_list)
            num_overlap_token = 0
            accumulator = 0
            for token in response_token_list:
                if token in knowledge_sent_token_dict:
                    accumulator += 1
                else:
                    if accumulator >= 10:
                        num_overlap_token += accumulator
                    accumulator = 0
            if accumulator >= 10:
                num_overlap_token += accumulator

            # keep only responses that copy a substantial, but not complete,
            # portion of the knowledge sentence
            if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
                continue
            if num_overlap_token < knowledge_len * 0.8:
                continue

            # format the prompt example
            last_turn = " ".join(word_tokenize(turns[-1]))
            knowledge = " ".join(word_tokenize(knowledge))
            response = " ".join(word_tokenize(response))

            prompt_example = ""
            prompt_example += "Topic: " + topic + ". "
            prompt_example += "User says: " + last_turn + " "
            prompt_example += "We know that: " + knowledge + " "
            prompt_example += "System replies: " + response

            prompt_example_list.append(prompt_example)

    # shuffle the examples and write out the first 20 as prompts
    np.random.shuffle(prompt_example_list)

    print("> writing to %s" % output_path)
    with open(output_path, "w") as f:
        for i in tqdm(range(min(20, len(prompt_example_list)))):
            example = prompt_example_list[i]
            f.write(example + "\n")


def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
    """Prepare inputs for the response generation"""
    print("> Reading knowledge file from %s" % knwl_gen_file)
    with open(knwl_gen_file, "r") as f:
        knowledge_list = f.readlines()

    print("> Processing ...")
    with open(test_file, "r") as fr:
        with open(processed_file, "w") as fw:
            for line_num, line in enumerate(tqdm(fr)):
                line = line.strip()
                splits = line.split("\t")

                topic = splits[0]
                dialog_context = splits[1]
                response = splits[3]

                # swap in the generated knowledge, stripping any end-of-text
                # token left over from generation
                knowledge = knowledge_list[line_num]
                knowledge = knowledge.strip()
                if "<|endoftext|>" in knowledge:
                    knowledge = knowledge.replace("<|endoftext|>", "")

                fw.write(topic + "\t" + dialog_context + "\t"
                         + knowledge + "\t" + response + "\n")


if __name__ == "__main__":

    args = get_args()
    if args.func == "process_wow_dataset":
        process_wow_dataset(args.raw_file, args.processed_file,
                            args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "process_woi_dataset":
        process_woi_dataset(args.raw_file, args.processed_file,
                            args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "get_knwl_gen_prompts":
        prompt_selection_for_knowledge_generation(
            args.test_file, args.train_file, args.model_file,
            args.processed_file, args.data_type)

    elif args.func == "get_resp_gen_prompts":
        prompt_selection_for_response_generation(
            args.train_file, args.processed_file, args.seed)

    elif args.func == "prepare_input":
        prepare_input_for_response_generation(
            args.test_file, args.knwl_gen_file, args.processed_file)
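
# Example invocations, assuming this file is saved as preprocessing.py
# (the file paths below are placeholders):
#
#   python preprocessing.py --func process_wow_dataset \
#       --raw_file wow/train.json --processed_file wow/train_processed.txt
#
#   python preprocessing.py --func get_knwl_gen_prompts --data_type wow_seen \
#       --test_file wow/test_processed.txt --train_file wow/train_processed.txt \
#       --model_file dpr_encoder.pt --processed_file wow/knwl_gen_prompts.json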
|
|