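"""Split OpenDialKG dialogs into per-turn train/valid/test samples.

For every turn whose "<dialog_id>_<turn_id>" key appears in one of the split
ID files, write one JSON line containing the texts of the preceding turns
("context"), the entities mentioned in those turns ("entity"), the items
annotated on the current turn ("rec"), and the current turn's text ("resp").
"""

import json
from copy import copy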

def load_id_set(path):
    """Load a JSON list of "<dialog_id>_<turn_id>" keys as a set for fast lookup."""
    with open(path, "r", encoding="utf-8") as f:
        return set(json.load(f))


train_data_id_set = load_id_set("../opendialkg/train_data_id.json")
print("train ids:", len(train_data_id_set))

valid_data_id_set = load_id_set("../opendialkg/valid_data_id.json")
print("valid ids:", len(valid_data_id_set))

test_data_id_set = load_id_set("../opendialkg/test_data_id.json")
print("test ids:", len(test_data_id_set))

with open("../opendialkg/data.jsonl", "r", encoding="utf-8") as f, open(
    "train_data_processed.jsonl", "w", encoding="utf-8"
) as train_w, open(
    "valid_data_processed.jsonl", "w", encoding="utf-8"
) as valid_w, open(
    "test_data_processed.jsonl", "w", encoding="utf-8"
) as test_w:
    # Splits are checked in this order, so a key listed in several ID files
    # goes to the first match (same precedence as an if/elif chain).
    splits = [
        (train_data_id_set, train_w),
        (valid_data_id_set, valid_w),
        (test_data_id_set, test_w),
    ]
    # Stream the file line by line instead of loading it with readlines().
    for line in f:
        dialog = json.loads(line)
        context_list = []  # texts of all earlier turns in this dialog
        entity_list = []  # entities mentioned in all earlier turns
        for message in dialog["dialog"]:
            text = message["text"]
            # mask_text = message['text_template']
            entity_turn = message["entity"]
            item_turn = message["item"]
            dialog_turn_id = f"{dialog['dialog_id']}_{message['turn_id']}"

            for id_set, writer in splits:
                if dialog_turn_id in id_set:
                    data = {
                        "dialog_id": dialog["dialog_id"],
                        "turn_id": message["turn_id"],
                        "context": copy(context_list),
                        "entity": copy(entity_list),
                        "rec": copy(item_turn),
                        "resp": text,
                    }
                    writer.write(json.dumps(data, ensure_ascii=False) + "\n")
                    break

            # Extend the history only after this turn's sample is written,
            # so "context" and "entity" never include the current turn.
            context_list.append(text)
            entity_list.extend(entity_turn)

# Report how many samples were written to each split.
for split in ("train", "valid", "test"):
    with open(f"{split}_data_processed.jsonl", "r", encoding="utf-8") as f:
        print(f"{split}:", sum(1 for _ in f))