import argparse
import json
import os
import re

from tqdm import tqdm


def cut_sent(para):
    """Split a Chinese paragraph into sentences using rule-based punctuation matching."""
    para = re.sub(r'([。,,!?\?])([^”’])', r"\1\n\2", para)  # single-character sentence terminators
    para = re.sub(r'(\.{6})([^”’])', r"\1\n\2", para)  # English-style ellipsis "......"
    para = re.sub(r'(…{2})([^”’])', r"\1\n\2", para)  # Chinese ellipsis "……"
    para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
    # If a terminator precedes a closing quote, the quote marks the true end of the
    # sentence, so the break (\n) goes after the quote. The rules above deliberately
    # keep the closing quotes attached to their sentences.
    para = para.rstrip()  # drop any extra trailing \n at the end of the paragraph
    # Many rule sets also treat the semicolon (;) as a sentence break, but it is not
    # used as a boundary here.
    return para.split("\n")
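
# Illustrative sketch (cut_sent is not called elsewhere in this script): it breaks on
# Chinese sentence-final punctuation (and, per the first rule, on commas) while
# keeping closing quotes attached to their sentence, e.g.
#   cut_sent("今天天气很好。我们出去玩吧!好的。")
#   -> ['今天天气很好。', '我们出去玩吧!', '好的。']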


def search(pattern, sequence):
    """Return [start, end] (inclusive) index pairs for every occurrence of pattern in sequence."""
    n = len(pattern)
    res = []
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            res.append([i, i + n - 1])
    return res
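
# Illustrative sketch (search is not called elsewhere in this script): it finds every
# occurrence of a sub-sequence, e.g.
#   search("ab", "abcab") -> [[0, 1], [3, 4]]
# It works on any sliceable sequence, such as a string or a list of tokens.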


max_length = 512
stride = 128


def stride_split(question, context, answer, start):
    """Split a long context into overlapping windows of at most max_length characters.

    Returns a list of tuples (question, window_text, answer, start, end) with the
    answer span re-based onto the window, or with an empty answer and (-1, -1) for
    windows that do not contain it.
    """
    end = start + len(answer) - 1  # inclusive end index of the answer in context
    results, n = [], 0
    max_c_len = max_length - len(question) - 3  # reserve room for special tokens
    while True:
        left, right = n * stride, n * stride + max_c_len  # window covers context[left:right]
        if left <= start <= end < right:
            # The answer lies entirely inside this window; re-base the span onto it.
            results.append((question, context[left:right], answer, start - left, end - left))
        elif end < left or start >= right:
            # The answer lies entirely outside this window; emit a no-answer sample.
            results.append((question, context[left:right], '', -1, -1))
        # Windows that only partially overlap the answer are skipped.
        if right >= len(context):
            return results
        n += 1
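
# Illustrative sketch of the sliding window, assuming max_length=512 and stride=128:
# each window holds at most 512 - len(question) - 3 characters of context, and
# consecutive windows start 128 characters apart, so long contexts overlap heavily.
# For any returned tuple with a non-empty answer, the span can be recovered from the
# window itself:
#   for q, window, ans, s, e in stride_split(question, context, answer, answer_start):
#       assert ans == '' or window[s:e + 1] == ans
# (question, context, answer and answer_start above are placeholders for real fields.)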


def load_data(file_path, is_training=False):
    """Convert SQuAD-style extractive QA data into records with task_type / subtask_type / text / choices fields."""
    task_type = '抽取任务'
    subtask_type = '抽取式阅读理解'
    with open(file_path, 'r', encoding='utf8') as f:
        lines = json.load(f)
        result = []
        lines = lines['data']
        for line in tqdm(lines):
            if line['paragraphs'] == []:
                continue
            # Only the first paragraph of each article is used.
            data = line['paragraphs'][0]
            context = data['context'].strip()
            for qa in data['qas']:
                question = qa['question'].strip()
                rcv = []  # skip duplicated answer annotations
                for a in qa['answers']:
                    if a not in rcv:
                        rcv.append(a)
                        split = stride_split(question, context, a['text'], a['answer_start'])
                        for sp in split:
                            choices = []

                            choice = {}
                            choice['id'] = qa['id']
                            choice['entity_type'] = qa['question']
                            choice['label'] = 0
                            entity_list = []
                            if sp[3] >= 0 and sp[4] >= 0:
                                # This window contains the answer: record the span
                                # in window-local coordinates.
                                entity_list.append({'entity_name': sp[2],
                                                    'entity_type': '',
                                                    'entity_idx': [[sp[3], sp[4]]]})

                            choice['entity_list'] = entity_list
                            choices.append(choice)

                            result.append({'task_type': task_type,
                                           'subtask_type': subtask_type,
                                           'text': sp[1],
                                           'choices': choices,
                                           'id': 0})

        return result
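
# Illustrative sketch of one converted record produced by load_data (field values
# below are placeholders, not real data):
#   {"task_type": "抽取任务", "subtask_type": "抽取式阅读理解",
#    "text": "<one context window>",
#    "choices": [{"id": "<qa id>", "entity_type": "<question text>", "label": 0,
#                 "entity_list": [{"entity_name": "<answer text>", "entity_type": "",
#                                  "entity_idx": [[<start>, <end>]]}]}],
#    "id": 0}
# Windows that do not contain the answer keep the same structure with an empty
# entity_list.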


def save_data(data, file_path):
    """Write the converted samples as JSON Lines (one JSON object per line)."""
    with open(file_path, 'w', encoding='utf8') as f:
        for line in data:
            json_data = json.dumps(line, ensure_ascii=False)
            f.write(json_data + '\n')



if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="data preprocessing")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--save_path", type=str, default="")

    args = parser.parse_args()

    data_path = args.data_path
    save_path = args.save_path

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    file_list = ['dev', 'train', 'trial', 'test']
    train_data = []
    for file in file_list:
        file_path = os.path.join(data_path, file + '.json')
        data = load_data(file_path=file_path)
        if 'train' in file or 'trial' in file:
            # Merge the train and trial splits into a single training set.
            train_data.extend(data)
        else:
            # dev / test are converted and saved under their own names.
            output_path = os.path.join(save_path, file + '.json')
            save_data(data, output_path)

    output_path = os.path.join(save_path, 'train.json')
    save_data(train_data, output_path)
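
# Example invocation (the script name and paths are placeholders; dev.json,
# train.json, trial.json and test.json are expected under --data_path in SQuAD-style
# JSON, i.e. 'data' -> 'paragraphs' -> 'qas' -> 'answers'):
#   python preprocess_mrc.py --data_path /path/to/raw_data --save_path /path/to/output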