File size: 1,869 Bytes
8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import json
from tqdm import tqdm
import argparse
import numpy as np

def save_data(data, file_path):
    """Serialize *data* as a single JSON line and write it to *file_path*.

    The file is opened for writing as UTF-8 (replacing any existing
    content). Non-ASCII characters are written as-is rather than escaped,
    and a newline is appended after the JSON text.
    """
    with open(file_path, mode='w', encoding='utf8') as out:
        # Serialize after opening, so the target file is created/truncated
        # even if serialization fails.
        serialized = json.dumps(data, ensure_ascii=False)
        out.write(serialized + '\n')


def load_data(file_path, is_training=False):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        file_path: Path to a UTF-8 text file with one JSON document per line.
        is_training: Unused by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf8') as f:
        lines = f.readlines()
    result = []
    # Wrap the list itself rather than enumerate(lines): enumerate has no
    # length, which would leave the progress bar without a total.
    for line in tqdm(lines):
        # A blank line (e.g. a trailing newline at EOF) carries no record and
        # would otherwise make json.loads raise.
        if not line.strip():
            continue
        result.append(json.loads(line))
    return result


def recls(line):
    """Greedily give each item in *line* a distinct label from its scores.

    Each item is a dict whose ``'score'`` entry is a dict of numeric scores;
    the position of a score within that dict is the label index. The
    highest remaining score in the whole (item x label) matrix is picked
    repeatedly: that item receives that label index under ``'label'``, and
    both the item and the label are removed from further consideration.

    Args:
        line: List of item dicts, each with a ``'score'`` dict. All score
            dicts must have the same number of entries.

    Returns:
        The same list object, with ``'label'`` set (as ``int``) on every
        item that received a label. When there are more items than labels,
        the items left over after every label is used are not modified.
    """
    if not line:
        return line
    # Float dtype so masked cells can hold -inf even when scores are ints.
    mat = np.array([list(item['score'].values()) for item in line], dtype=float)
    batch, num_labels = mat.shape
    # Only min(batch, num_labels) unique (item, label) pairs exist; iterating
    # further would pick from a fully masked matrix and clobber a real label.
    for _ in range(min(batch, num_labels)):
        row, col = np.unravel_index(np.argmax(mat, axis=None), mat.shape)
        line[row]['label'] = int(col)
        # Mask with -inf, not 0: zero would outrank genuine negative scores
        # (e.g. logits / log-probabilities) and get a used cell re-selected.
        mat[row, :] = -np.inf
        mat[:, col] = -np.inf
    return line

        
        
     
                                                                                                                                     

def chid_m(data):
    """Re-label records group by group, grouping on their ``'line_id'``.

    Records sharing a ``'line_id'`` are collected (in input order) and each
    group is passed to ``recls``. The groups' results are concatenated in
    the order each ``'line_id'`` first appears in *data*.
    """
    grouped = {}
    for record in data:
        grouped.setdefault(record['line_id'], []).append(record)

    merged = []
    for group in grouped.values():
        merged += recls(group)
    return merged



def submit(file_path):
    """Build an ``id -> label`` mapping from the records in *file_path*.

    The file is read with ``load_data``, the records are re-labelled with
    ``chid_m``, and each record's ``'id'`` is mapped to its ``'label'``.
    If an id occurs more than once, the last occurrence wins.
    """
    relabelled = chid_m(load_data(file_path))
    return {entry['id']: entry['label'] for entry in tqdm(relabelled)}


if __name__ == "__main__":
    # Command-line entry point: read predictions from --data_path, re-label
    # them, and write the id -> label mapping to --save_path.
    cli = argparse.ArgumentParser(description="train")
    for option in ("--data_path", "--save_path"):
        cli.add_argument(option, type=str, default="")

    options = cli.parse_args()
    predictions = submit(options.data_path)
    save_data(predictions, options.save_path)