File size: 3,184 Bytes
b599481 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import json
from copy import copy
def _load_id_set(path):
    """Read a JSON id list from *path*; return ``(ids, set(ids))``."""
    with open(path, "r", encoding="utf-8") as fin:
        ids = json.load(fin)
    return ids, set(ids)


# Turn ids belonging to each split. Each set is used below to decide which
# output file a dialog turn is written to. The sizes are printed as a sanity check.
train_data_id, train_data_id_set = _load_id_set("../opendialkg/train_data_id.json")
print(len(train_data_id_set))
valid_data_id, valid_data_id_set = _load_id_set("../opendialkg/valid_data_id.json")
print(len(valid_data_id_set))
test_data_id, test_data_id_set = _load_id_set("../opendialkg/test_data_id.json")
print(len(test_data_id_set))
def _iter_turn_samples(dialog):
    """Yield ``(dialog_turn_id, sample)`` for every turn of one dialog record.

    ``dialog_turn_id`` is ``"<dialog_id>_<turn_id>"``. It is the key that is
    looked up in the split id sets.
    ``sample`` holds the turn's ``text`` as ``resp`` and its ``item`` value as
    ``rec``. ``context`` and ``entity`` contain the texts and the concatenated
    ``entity`` lists of all *earlier* turns of the same dialog, in order.
    """
    dialog_id = dialog["dialog_id"]
    context_list = []
    entity_list = []
    for message in dialog["dialog"]:
        text = message["text"]
        turn_id = message["turn_id"]
        dialog_turn_id = str(dialog_id) + "_" + str(turn_id)
        sample = {
            "dialog_id": dialog_id,
            "turn_id": turn_id,
            # Snapshot the history so later appends do not alter this sample.
            "context": list(context_list),
            "entity": list(entity_list),
            "rec": copy(message["item"]),
            "resp": text,
        }
        yield dialog_turn_id, sample
        # Only after yielding does this turn become context for the next one.
        context_list.append(text)
        entity_list.extend(message["entity"])


# Stream data.jsonl and write each turn whose id is in one of the split sets
# to that split's processed file. Turns in no split are skipped, but they
# still count as context for the turns that follow them.
with open("../opendialkg/data.jsonl", "r", encoding="utf-8") as f, open(
    "train_data_processed.jsonl", "w", encoding="utf-8"
) as train_w, open(
    "valid_data_processed.jsonl", "w", encoding="utf-8"
) as valid_w, open(
    "test_data_processed.jsonl", "w", encoding="utf-8"
) as test_w:
    # The splits are checked in this order. An id present in several sets
    # goes to the first match (train, then valid, then test).
    split_writers = (
        (train_data_id_set, train_w),
        (valid_data_id_set, valid_w),
        (test_data_id_set, test_w),
    )
    for line in f:
        if not line.strip():
            # Skip blank lines (e.g. a trailing newline); json.loads would raise on them.
            continue
        for dialog_turn_id, sample in _iter_turn_samples(json.loads(line)):
            for id_set, writer in split_writers:
                if dialog_turn_id in id_set:
                    writer.write(json.dumps(sample, ensure_ascii=False) + "\n")
                    break
# Print how many samples (lines) were written to each processed split file.
for label, path in (
    ("train:", "train_data_processed.jsonl"),
    ("valid:", "valid_data_processed.jsonl"),
    ("test:", "test_data_processed.jsonl"),
):
    with open(path, "r", encoding="utf-8") as f:
        # Count lines by iterating the file instead of materialising them all.
        line_count = sum(1 for _ in f)
    print(label, line_count)
|