Nolwenn
Initial commit
b599481
raw
history blame
3.18 kB
import json
from copy import copy
def _load_split(path):
    """Load the JSON id list stored at *path* and report its size.

    The file is read as UTF-8 text and parsed with ``json.load``. The number
    of distinct ids is printed as a quick sanity check.

    Returns:
        ``(ids, id_set)`` -- the parsed JSON value and ``set(ids)`` for fast
        membership tests.
    """
    with open(path, "r", encoding="utf-8") as handle:
        ids = json.load(handle)
    id_set = set(ids)
    print(len(id_set))
    return ids, id_set


# Split membership tables. Each file presumably holds a JSON array of
# "<dialog_id>_<turn_id>" strings -- confirm against the OpenDialKG release.
train_data_id, train_data_id_set = _load_split(
    "../opendialkg/train_data_id.json"
)
valid_data_id, valid_data_id_set = _load_split(
    "../opendialkg/valid_data_id.json"
)
test_data_id, test_data_id_set = _load_split(
    "../opendialkg/test_data_id.json"
)
# Build one JSONL record per dialog turn whose "<dialog_id>_<turn_id>" key
# is listed in a split id set, and write it to that split's output file.
# Each record holds:
#   "dialog_id" / "turn_id" - taken from the dialog and the message,
#   "context"  - texts of all earlier messages in the same dialog,
#   "entity"   - the concatenated "entity" lists of those earlier messages,
#   "rec"      - this message's "item" field,
#   "resp"     - this message's text.
# Turns that belong to no split are not written. They still contribute to
# the context/entity history of later turns.
with open("../opendialkg/data.jsonl", "r", encoding="utf-8") as f, open(
    "train_data_processed.jsonl", "w", encoding="utf-8"
) as train_w, open(
    "valid_data_processed.jsonl", "w", encoding="utf-8"
) as valid_w, open(
    "test_data_processed.jsonl", "w", encoding="utf-8"
) as test_w:
    # Checked in order and stopped at the first hit. An id present in
    # several splits is written only to the first match (train, then
    # valid, then test), the same precedence as an if/elif chain.
    split_writers = (
        (train_data_id_set, train_w),
        (valid_data_id_set, valid_w),
        (test_data_id_set, test_w),
    )
    # Stream the corpus line by line instead of materialising it all.
    for line in f:
        if not line.strip():
            # A blank line is not valid JSON; skip it rather than crash.
            continue
        dialog = json.loads(line)
        context_list = []
        entity_list = []
        for message in dialog["dialog"]:
            text = message["text"]
            entity_turn = message["entity"]
            item_turn = message["item"]
            dialog_turn_id = (
                str(dialog["dialog_id"]) + "_" + str(message["turn_id"])
            )
            for id_set, writer in split_writers:
                if dialog_turn_id in id_set:
                    data = {
                        "dialog_id": dialog["dialog_id"],
                        "turn_id": message["turn_id"],
                        # Copies snapshot the history as of this turn.
                        "context": copy(context_list),
                        "entity": copy(entity_list),
                        "rec": copy(item_turn),
                        "resp": text,
                    }
                    writer.write(json.dumps(data, ensure_ascii=False) + "\n")
                    break
            # Extend the history only after emitting this turn, so a
            # record's context/entity cover strictly earlier turns.
            context_list.append(text)
            entity_list.extend(entity_turn)
# Sanity check: re-read each generated split file and print how many lines
# (records) it contains, in train / valid / test order.
for label, path in (
    ("train:", "train_data_processed.jsonl"),
    ("valid:", "valid_data_processed.jsonl"),
    ("test:", "test_data_processed.jsonl"),
):
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    print(label, len(lines))