# File size: 5,463 Bytes
# b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import json
import random
from copy import copy
from tqdm import tqdm


def process_data(data_file_path):
    """Convert a JSONL dialog file into recommendation samples and dialog records.

    Every non-blank line of the file must be a JSON list of turns. Each turn is
    a dict providing the keys ``text``, ``text_template``, ``role``, ``item``
    and ``entity``.

    Dialogs are numbered with the module-level ``dialog_id`` counter, which is
    incremented once per dialog, so consecutive calls keep handing out fresh
    ids. If the counter has not been created yet it starts at 0.

    Args:
        data_file_path: Path of the UTF-8 encoded JSONL file to read.

    Returns:
        A ``(data_list, dialog_list)`` tuple.

        ``data_list`` holds one dict per assistant turn that has a non-empty
        ``item`` and is not the very first turn of its dialog. Each dict
        carries ``dialog_id``, ``turn_id``, snapshots of the accumulated
        context (``context``, ``context_template``, ``context_entity``,
        ``context_item``), the assistant text as ``resp`` and the turn's
        ``item``.

        ``dialog_list`` holds one ``{'dialog_id': ..., 'dialog': [...]}`` dict
        per dialog; every turn in it has a ``turn_id`` key added.
    """
    global dialog_id
    # The counter is normally created by the script entry point; fall back to
    # 0 so the function also works when this module is imported.
    globals().setdefault('dialog_id', 0)

    dialog_list = []
    data_list = []
    with open(data_file_path, encoding='utf-8') as f:
        for line in tqdm(f):
            # Blank lines (e.g. a trailing empty line) are not valid JSON.
            if not line.strip():
                continue

            turns = json.loads(line)
            context_text_list = []
            context_text_template_list = []
            context_entity_list = []
            context_item_list = []
            turn_id = 0

            dialog = {'dialog_id': dialog_id, 'dialog': []}

            for turn in turns:
                # The id is stored on the turn before any padding adjustment below.
                turn['turn_id'] = turn_id
                dialog['dialog'].append(turn)

                text = turn['text']
                text_template = turn['text_template']
                role = turn['role']
                # item = [title2id[title] for title in turn['item']]
                item = turn['item']
                entity = turn['entity']

                if role == 'assistant':
                    if not context_text_list:
                        # The dialog opens with the assistant: pad the text
                        # context with an empty utterance, consume one extra
                        # turn id, and emit no sample for this turn.
                        context_text_list.append('')
                        turn_id += 1
                    elif item:
                        # Snapshot the context lists; they keep growing below.
                        data_list.append({
                            'dialog_id': dialog_id,
                            'turn_id': turn_id,
                            'context': copy(context_text_list),
                            'context_template': copy(context_text_template_list),
                            'context_entity': copy(context_entity_list),
                            'context_item': copy(context_item_list),
                            'resp': text,
                            'item': item
                        })
                        # out_file.write(json.dumps(data, ensure_ascii=False) + '\n')

                context_text_list.append(text)
                context_text_template_list.append(text_template)
                context_entity_list.extend(entity)
                context_item_list.extend(item)
                turn_id += 1

            dialog_id += 1
            dialog_list.append(dialog)

    return data_list, dialog_list


if __name__ == '__main__':
    # process_data() numbers dialogs through this module-level counter, so the
    # ids keep increasing across the two input files.
    dialog_id = 0

    # Fixed seed: the sample/shuffle calls below produce the same split each run.
    random.seed(42)

    with open('id2info.json', encoding='utf-8') as f:
        id2info = json.load(f)
    with open('title2id.json', encoding='utf-8') as f:
        title2id = json.load(f)

    movie_data_list, movie_dialog_list = process_data('dialog_movie.jsonl')
    book_data_list, book_dialog_list = process_data('dialog_Books.jsonl')

    # Dump every dialog (movies first, then books), one JSON object per line.
    with open('data.jsonl', 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(dialog, ensure_ascii=False) + '\n'
            for dialog in movie_dialog_list + book_dialog_list
        )

    all_data_list = movie_data_list + book_data_list
    # Test and validation splits each take 15% of all samples (rounded down).
    split_size = int(len(all_data_list) * 0.15)

    def build_data_id(data):
        """Return the '<dialog_id>_<turn_id>' key identifying a sample."""
        return f"{data['dialog_id']}_{data['turn_id']}"

    test_data_list = random.sample(all_data_list, split_size)
    test_data_list.sort(key=lambda x: x['dialog_id'])

    print(len(test_data_list))

    # Everything that was not drawn for the test split.
    test_data_dialog_id_set = {build_data_id(data) for data in test_data_list}
    rest_data_list = [
        data for data in all_data_list
        if build_data_id(data) not in test_data_dialog_id_set
    ]
    assert len(rest_data_list) + len(test_data_list) == len(all_data_list)

    # Shuffle the remainder, then cut validation off the front; the tail is train.
    random.shuffle(rest_data_list)
    valid_data_list = rest_data_list[:split_size]
    train_data_list = rest_data_list[split_size:]
    assert len(valid_data_list) == len(test_data_list)

    # Write the sample ids of each split to its own JSON file.
    for id_file_path, split_data_list in (
        ('train_data_id.json', train_data_list),
        ('valid_data_id.json', valid_data_list),
        ('test_data_id.json', test_data_list),
    ):
        split_id_list = [build_data_id(data) for data in split_data_list]
        with open(id_file_path, 'w', encoding='utf-8') as f:
            json.dump(split_id_list, f, ensure_ascii=False)

    # with open('train_data.jsonl', 'w', encoding='utf-8') as f:
    #     for data in train_data_list:
    #         f.write(json.dumps(data, ensure_ascii=False) + '\n')
    #
    # with open('valid_data.jsonl', 'w', encoding='utf-8') as f:
    #     for data in valid_data_list:
    #         f.write(json.dumps(data, ensure_ascii=False) + '\n')
    #
    # with open('test_data.jsonl', 'w', encoding='utf-8') as f:
    #     for data in test_data_list:
    #         f.write(json.dumps(data, ensure_ascii=False) + '\n')

    # cnt = 0
    # with open('../../model/chat/data/opendialkg/test_data_processed.jsonl', encoding='utf-8') as f:
    #     for line in f:
    #         data = json.loads(line)
    #         data_id = f"{data['dialog_id']}_{data['turn_id']}"
    #         if data_id not in test_data_dialog_id_set:
    #             print(data_id)
    #         cnt += 1
    # assert cnt == len(test_data_list)